omni_director/providers/
registry.rs1use super::{Provider, ProviderError, ProviderResult, ProviderContext, ProviderMetadata};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10pub struct ProviderRegistry {
12 providers: RwLock<HashMap<String, Arc<dyn Provider>>>,
14 metadata: RwLock<HashMap<String, ProviderMetadata>>,
16 context: Arc<dyn ProviderContext>,
18}
19
20impl ProviderRegistry {
21 pub fn new(context: Arc<dyn ProviderContext>) -> Self {
23 Self {
24 providers: RwLock::new(HashMap::new()),
25 metadata: RwLock::new(HashMap::new()),
26 context,
27 }
28 }
29
30 pub async fn register_provider(
32 &self,
33 provider: Arc<dyn Provider>,
34 metadata: ProviderMetadata,
35 ) -> ProviderResult<()> {
36 let name = provider.name().to_string();
37
38 {
40 let providers = self.providers.read().await;
41 if providers.contains_key(&name) {
42 return Err(ProviderError::InitializationFailed(
43 format!("Provider '{}' already registered", name)
44 ));
45 }
46 }
47
48 {
50 let mut providers = self.providers.write().await;
51 let mut metadata_store = self.metadata.write().await;
52
53 providers.insert(name.clone(), provider);
54 metadata_store.insert(name.clone(), metadata);
55 }
56
57 println!("✅ Registered provider: {}", name);
58 Ok(())
59 }
60
61 pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
63 let providers = self.providers.read().await;
64 providers.get(name).cloned()
65 }
66
67 pub async fn get_metadata(&self, name: &str) -> Option<ProviderMetadata> {
69 let metadata = self.metadata.read().await;
70 metadata.get(name).cloned()
71 }
72
73 pub async fn list_providers(&self) -> Vec<String> {
75 let providers = self.providers.read().await;
76 providers.keys().cloned().collect()
77 }
78
79 pub async fn has_provider(&self, name: &str) -> bool {
81 let providers = self.providers.read().await;
82 providers.contains_key(name)
83 }
84
85 pub async fn list_metadata(&self) -> Vec<ProviderMetadata> {
87 let metadata = self.metadata.read().await;
88 metadata.values().cloned().collect()
89 }
90
91 pub async fn execute_operation(
93 &self,
94 provider_name: &str,
95 feature: &str,
96 operation: &str,
97 args: HashMap<String, serde_json::Value>,
98 ) -> ProviderResult<serde_json::Value> {
99 let provider = self.get_provider(provider_name).await
100 .ok_or_else(|| ProviderError::NotFound(provider_name.to_string()))?;
101
102 if !provider.supports_feature(feature) {
104 return Err(ProviderError::FeatureNotSupported {
105 provider: provider_name.to_string(),
106 feature: feature.to_string(),
107 });
108 }
109
110 let operations = provider.feature_operations(feature)?;
112 if !operations.contains(&operation.to_string()) {
113 return Err(ProviderError::OperationNotSupported {
114 provider: provider_name.to_string(),
115 feature: feature.to_string(),
116 operation: operation.to_string(),
117 });
118 }
119
120 provider.execute_operation(feature, operation, args, self.context.as_ref())
122 }
123
124 pub async fn get_feature_operations(
126 &self,
127 provider_name: &str,
128 feature: &str,
129 ) -> ProviderResult<Vec<String>> {
130 let provider = self.get_provider(provider_name).await
131 .ok_or_else(|| ProviderError::NotFound(provider_name.to_string()))?;
132
133 if !provider.supports_feature(feature) {
134 return Err(ProviderError::FeatureNotSupported {
135 provider: provider_name.to_string(),
136 feature: feature.to_string(),
137 });
138 }
139
140 provider.feature_operations(feature)
141 }
142
143 pub async fn get_statistics(&self) -> ProviderRegistryStats {
145 let providers = self.providers.read().await;
146 let metadata = self.metadata.read().await;
147
148 let mut total_features = 0;
149 let mut total_operations = 0;
150 let mut provider_stats = HashMap::new();
151
152 for (name, meta) in metadata.iter() {
153 let feature_count = meta.features.len();
154 let operation_count: usize = meta.features.iter()
155 .map(|f| f.operations.len())
156 .sum();
157
158 total_features += feature_count;
159 total_operations += operation_count;
160
161 provider_stats.insert(name.clone(), ProviderStats {
162 feature_count,
163 operation_count,
164 supported_features: meta.feature_names(),
165 });
166 }
167
168 ProviderRegistryStats {
169 total_providers: providers.len(),
170 total_features,
171 total_operations,
172 provider_stats,
173 }
174 }
175
176 pub async fn initialize_all(&self) -> ProviderResult<()> {
178 let providers = self.providers.read().await;
179 let errors: Vec<String> = Vec::new();
180
181 for (name, _provider) in providers.iter() {
182 println!("🔧 Initializing provider: {}", name);
185 }
187
188 if !errors.is_empty() {
189 return Err(ProviderError::InitializationFailed(
190 format!("Failed to initialize {} providers", errors.len())
191 ));
192 }
193
194 Ok(())
195 }
196
197 pub async fn shutdown_all(&self) -> ProviderResult<()> {
199 let providers = self.providers.read().await;
200 let errors: Vec<String> = Vec::new();
201
202 for (name, _provider) in providers.iter() {
203 println!("🛑 Shutting down provider: {}", name);
204 }
206
207 if !errors.is_empty() {
208 return Err(ProviderError::ExecutionFailed(
209 format!("Failed to shutdown {} providers", errors.len())
210 ));
211 }
212
213 Ok(())
214 }
215}
216
217#[derive(Debug, Clone, serde::Serialize)]
219pub struct ProviderRegistryStats {
220 pub total_providers: usize,
221 pub total_features: usize,
222 pub total_operations: usize,
223 pub provider_stats: HashMap<String, ProviderStats>,
224}
225
226#[derive(Debug, Clone, serde::Serialize)]
228pub struct ProviderStats {
229 pub feature_count: usize,
230 pub operation_count: usize,
231 pub supported_features: Vec<String>,
232}