Skip to main content

zeroclaw_runtime/tools/
model_switch.rs

1use crate::agent::loop_::get_model_switch_state;
2use crate::security::SecurityPolicy;
3use crate::security::policy::ToolOperation;
4use async_trait::async_trait;
5use serde_json::json;
6use std::sync::Arc;
7use zeroclaw_api::tool::{Tool, ToolResult};
8use zeroclaw_config::schema::Config;
9
10fn configured_model_provider_profiles(config: &Config) -> Vec<String> {
11    let mut profiles = config
12        .providers
13        .models
14        .iter_entries()
15        .map(|(family, alias, _profile)| format!("{family}.{alias}"))
16        .collect::<Vec<_>>();
17    profiles.sort();
18    profiles
19}
20
21fn resolve_model_provider_profile_ref(config: &Config, raw: &str) -> Result<String, String> {
22    let raw = raw.trim();
23    let Some((family, alias)) = raw.split_once('.') else {
24        return Err(format!(
25            "model_provider must be a dotted `<type>.<alias>` provider profile reference, got `{raw}`"
26        ));
27    };
28    let family = family.trim();
29    let alias = alias.trim();
30    if family.is_empty() || alias.is_empty() {
31        return Err(format!(
32            "model_provider must be a dotted `<type>.<alias>` provider profile reference, got `{raw}`"
33        ));
34    }
35
36    if config.providers.models.find(family, alias).is_none() {
37        let available = configured_model_provider_profiles(config);
38        let available = if available.is_empty() {
39            "no configured provider profiles".to_string()
40        } else {
41            available.join(", ")
42        };
43        return Err(format!(
44            "model_provider `{raw}` is not a configured provider profile. Add a [providers.models.{family}.{alias}] entry or use one of: {available}"
45        ));
46    }
47
48    Ok(format!("{family}.{alias}"))
49}
50
51pub struct ModelSwitchTool {
52    security: Arc<SecurityPolicy>,
53    config: Arc<Config>,
54}
55
56impl ModelSwitchTool {
57    /// Canonical tool name. Referenced by the subagent registry filter so
58    /// a rename cannot desync the two.
59    pub const NAME: &'static str = "model_switch";
60
61    pub fn new(security: Arc<SecurityPolicy>, config: Arc<Config>) -> Self {
62        Self { security, config }
63    }
64}
65
66#[async_trait]
67impl Tool for ModelSwitchTool {
68    fn name(&self) -> &str {
69        Self::NAME
70    }
71
72    fn description(&self) -> &str {
73        "Request a runtime model switch using a configured provider profile plus provider-local model. Use 'get' to see the pending switch, 'list_model_providers' to see provider families, 'list_models' to see common models for a provider profile, or 'set' with a dotted provider profile ref such as 'openai.default'. The switch is runtime/session state and does not write config."
74    }
75
76    fn parameters_schema(&self) -> serde_json::Value {
77        json!({
78            "type": "object",
79            "properties": {
80                "action": {
81                    "type": "string",
82                    "enum": ["get", "set", "list_model_providers", "list_models"],
83                    "description": "Action to perform: get pending switch state, set a runtime provider-profile/model switch, list available provider families, or list common models for a provider profile"
84                },
85                "model_provider": {
86                    "type": "string",
87                    "description": "Dotted provider profile reference (e.g., 'openai.default', 'anthropic.sonnet', 'ollama.local'). Required for 'set' and 'list_models' actions."
88                },
89                "model": {
90                    "type": "string",
91                    "description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action."
92                }
93            },
94            "required": ["action"]
95        })
96    }
97
98    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
99        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get");
100
101        if let Err(error) = self
102            .security
103            .enforce_tool_operation(ToolOperation::Act, "model_switch")
104        {
105            return Ok(ToolResult {
106                success: false,
107                output: String::new(),
108                error: Some(error),
109            });
110        }
111
112        match action {
113            "get" => self.handle_get(),
114            "set" => self.handle_set(&args),
115            "list_model_providers" => self.handle_list_providers(),
116            "list_models" => self.handle_list_models(&args),
117            _ => Ok(ToolResult {
118                success: false,
119                output: String::new(),
120                error: Some(format!(
121                    "Unknown action: {}. Valid actions: get, set, list_model_providers, list_models",
122                    action
123                )),
124            }),
125        }
126    }
127}
128
129impl ModelSwitchTool {
130    fn handle_get(&self) -> anyhow::Result<ToolResult> {
131        let switch_state = get_model_switch_state();
132        let pending = switch_state.lock().unwrap().clone();
133
134        Ok(ToolResult {
135            success: true,
136            output: serde_json::to_string_pretty(&json!({
137                "pending_switch": pending,
138                "note": "To switch models, use action 'set' with dotted <type>.<alias> model_provider and model parameters"
139            }))?,
140            error: None,
141        })
142    }
143
144    fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
145        let model_provider = args.get("model_provider").and_then(|v| v.as_str());
146
147        let model_provider = match model_provider {
148            Some(p) => p,
149            None => {
150                return Ok(ToolResult {
151                    success: false,
152                    output: String::new(),
153                    error: Some("Missing 'model_provider' parameter for 'set' action".to_string()),
154                });
155            }
156        };
157
158        let model = args.get("model").and_then(|v| v.as_str());
159
160        let model = match model {
161            Some(m) => m,
162            None => {
163                return Ok(ToolResult {
164                    success: false,
165                    output: String::new(),
166                    error: Some("Missing 'model' parameter for 'set' action".to_string()),
167                });
168            }
169        };
170
171        let model_provider = match resolve_model_provider_profile_ref(&self.config, model_provider)
172        {
173            Ok(model_provider) => model_provider,
174            Err(error) => {
175                let known_model_providers = zeroclaw_providers::list_model_providers();
176                let configured_profiles = configured_model_provider_profiles(&self.config);
177                return Ok(ToolResult {
178                    success: false,
179                    output: serde_json::to_string_pretty(&json!({
180                        "provider_ref_shape": "<type>.<alias>",
181                        "available_provider_families": known_model_providers.iter().map(|p| p.name).collect::<Vec<_>>(),
182                        "configured_provider_profiles": configured_profiles
183                    }))?,
184                    error: Some(error),
185                });
186            }
187        };
188
189        let model = model.trim();
190        if model.is_empty() {
191            return Ok(ToolResult {
192                success: false,
193                output: String::new(),
194                error: Some("Model ID cannot be empty".to_string()),
195            });
196        }
197
198        // Set the global model switch request
199        let switch_state = get_model_switch_state();
200        *switch_state.lock().unwrap() = Some((model_provider.clone(), model.to_string()));
201
202        Ok(ToolResult {
203            success: true,
204            output: serde_json::to_string_pretty(&json!({
205                "message": "Model switch requested",
206                "model_provider": model_provider,
207                "model": model,
208                "note": "The active runtime path will consume this provider-profile/model switch where model_switch is supported. This does not write persisted config."
209            }))?,
210            error: None,
211        })
212    }
213
214    fn handle_list_providers(&self) -> anyhow::Result<ToolResult> {
215        let providers_list = zeroclaw_providers::list_model_providers();
216        let configured_profiles = configured_model_provider_profiles(&self.config);
217        let configured_count = configured_profiles.len();
218
219        let model_providers: Vec<serde_json::Value> = providers_list
220            .iter()
221            .map(|p| {
222                json!({
223                    "name": p.name,
224                    "display_name": p.display_name,
225                    "local": p.local
226                })
227            })
228            .collect();
229
230        Ok(ToolResult {
231            success: true,
232            output: serde_json::to_string_pretty(&json!({
233                "model_providers": model_providers,
234                "count": model_providers.len(),
235                "configured_provider_profiles": configured_profiles,
236                "configured_count": configured_count,
237                "provider_ref_shape": "<type>.<alias>",
238                "example": "Use action 'set' with a dotted provider profile ref such as 'openai.default'"
239            }))?,
240            error: None,
241        })
242    }
243
244    fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
245        let model_provider = args.get("model_provider").and_then(|v| v.as_str());
246
247        let model_provider = match model_provider {
248            Some(p) => p,
249            None => {
250                return Ok(ToolResult {
251                    success: false,
252                    output: String::new(),
253                    error: Some(
254                        "Missing 'model_provider' parameter for 'list_models' action".to_string(),
255                    ),
256                });
257            }
258        };
259
260        let model_provider = match resolve_model_provider_profile_ref(&self.config, model_provider)
261        {
262            Ok(model_provider) => model_provider,
263            Err(error) => {
264                return Ok(ToolResult {
265                    success: false,
266                    output: serde_json::to_string_pretty(&json!({
267                        "provider_ref_shape": "<type>.<alias>",
268                        "configured_provider_profiles": configured_model_provider_profiles(&self.config)
269                    }))?,
270                    error: Some(error),
271                });
272            }
273        };
274        let provider_family = model_provider
275            .split_once('.')
276            .map(|(family, _alias)| family)
277            .unwrap_or(model_provider.as_str());
278
279        // Return common models for known model_provider families.
280        let models = match provider_family.to_lowercase().as_str() {
281            "openai" => vec![
282                "gpt-4o",
283                "gpt-4o-mini",
284                "gpt-4-turbo",
285                "gpt-4",
286                "gpt-3.5-turbo",
287            ],
288            "anthropic" => vec![
289                "claude-sonnet-4-6",
290                "claude-sonnet-4-5",
291                "claude-3-5-sonnet",
292                "claude-3-opus",
293                "claude-3-haiku",
294            ],
295            "openrouter" => vec![
296                "anthropic/claude-sonnet-4-6",
297                "openai/gpt-4o",
298                "google/gemini-pro",
299                "meta-llama/llama-3-70b-instruct",
300            ],
301            "groq" => vec![
302                "llama-3.3-70b-versatile",
303                "mixtral-8x7b-32768",
304                "llama-3.1-70b-speculative",
305            ],
306            "ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"],
307            "deepseek" => vec!["deepseek-chat", "deepseek-coder"],
308            "mistral" => vec![
309                "mistral-large-latest",
310                "mistral-small-latest",
311                "mistral-nemo",
312            ],
313            "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"],
314            "xai" => vec!["grok-2", "grok-2-vision", "grok-beta"],
315            _ => vec![],
316        };
317
318        if models.is_empty() {
319            return Ok(ToolResult {
320                success: true,
321                output: serde_json::to_string_pretty(&json!({
322                    "model_provider": model_provider,
323                    "models": [],
324                    "note": "No common models listed for this model_provider family. Check model_provider documentation for available models."
325                }))?,
326                error: None,
327            });
328        }
329
330        Ok(ToolResult {
331            success: true,
332            output: serde_json::to_string_pretty(&json!({
333                "model_provider": model_provider,
334                "models": models,
335                "example": "Use action 'set' with this model_provider and a model ID to switch"
336            }))?,
337            error: None,
338        })
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::loop_::{clear_model_switch_request, get_model_switch_state};

    /// Serializes tests in this module that read or write the shared switch slot.
    static MODEL_SWITCH_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire the test lock even if an earlier test panicked while holding it.
    ///
    /// Without this, one failing test poisons the lock and every later test
    /// fails with an unrelated `PoisonError`, hiding the real failure.
    fn lock_switch_tests() -> std::sync::MutexGuard<'static, ()> {
        MODEL_SWITCH_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    fn test_config() -> Config {
        let mut config = Config::default();
        config.providers.models.ensure("openai", "default").unwrap();
        config.providers.models.ensure("custom", "local").unwrap();
        config
    }

    fn tool() -> ModelSwitchTool {
        ModelSwitchTool::new(Arc::new(SecurityPolicy::default()), Arc::new(test_config()))
    }

    fn pending_switch() -> Option<(String, String)> {
        get_model_switch_state()
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }

    #[test]
    fn set_rejects_bare_provider_family() {
        let _guard = lock_switch_tests();
        clear_model_switch_request();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai",
                "model": "gpt-4o"
            }))
            .expect("set should return a tool result");

        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap_or_default()
                .contains("dotted `<type>.<alias>`"),
            "unexpected error: {:?}",
            result.error
        );
        assert_eq!(pending_switch(), None);
    }

    #[test]
    fn set_accepts_dotted_provider_profile_ref() {
        let _guard = lock_switch_tests();
        clear_model_switch_request();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai.default",
                "model": "gpt-4o"
            }))
            .expect("set should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        assert_eq!(
            pending_switch(),
            Some(("openai.default".to_string(), "gpt-4o".to_string()))
        );

        clear_model_switch_request();
    }

    #[test]
    fn set_trims_model_id() {
        let _guard = lock_switch_tests();
        clear_model_switch_request();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai.default",
                "model": "  gpt-4o  "
            }))
            .expect("set should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        assert_eq!(
            pending_switch(),
            Some(("openai.default".to_string(), "gpt-4o".to_string()))
        );

        clear_model_switch_request();
    }

    #[test]
    fn set_rejects_blank_model_id() {
        let _guard = lock_switch_tests();
        clear_model_switch_request();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai.default",
                "model": "   "
            }))
            .expect("set should return a tool result");

        assert!(!result.success);
        assert_eq!(result.error.as_deref(), Some("Model ID cannot be empty"));
        assert_eq!(pending_switch(), None);
    }

    #[test]
    fn set_rejects_unconfigured_provider_profile_ref() {
        let _guard = lock_switch_tests();
        clear_model_switch_request();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai.missing",
                "model": "gpt-4o"
            }))
            .expect("set should return a tool result");

        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap_or_default()
                .contains("configured provider profile"),
            "unexpected error: {:?}",
            result.error
        );
        assert_eq!(pending_switch(), None);
    }

    #[test]
    fn set_accepts_configured_custom_provider_profile_ref() {
        let _guard = lock_switch_tests();
        clear_model_switch_request();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "custom.local",
                "model": "local-model"
            }))
            .expect("set should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        assert_eq!(
            pending_switch(),
            Some(("custom.local".to_string(), "local-model".to_string()))
        );

        clear_model_switch_request();
    }

    #[test]
    fn get_reports_pending_switch() {
        let _guard = lock_switch_tests();
        clear_model_switch_request();

        let tool = tool();
        let set = tool
            .handle_set(&json!({
                "model_provider": "openai.default",
                "model": "gpt-4o"
            }))
            .expect("set should return a tool result");
        assert!(set.success, "unexpected error: {:?}", set.error);

        let result = tool.handle_get().expect("get should return a tool result");
        assert!(result.success, "unexpected error: {:?}", result.error);
        let output: serde_json::Value =
            serde_json::from_str(&result.output).expect("output should be json");
        assert_eq!(output["pending_switch"], json!(["openai.default", "gpt-4o"]));

        clear_model_switch_request();
    }

    #[test]
    fn list_models_accepts_dotted_provider_profile_ref() {
        let result = tool()
            .handle_list_models(&json!({
                "model_provider": "openai.default"
            }))
            .expect("list_models should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        let output: serde_json::Value =
            serde_json::from_str(&result.output).expect("output should be json");
        assert_eq!(output["model_provider"], "openai.default");
        assert!(
            output["models"]
                .as_array()
                .expect("models should be an array")
                .iter()
                .any(|model| model == "gpt-4o")
        );
    }

    #[test]
    fn list_models_rejects_bare_provider_family() {
        let result = tool()
            .handle_list_models(&json!({
                "model_provider": "openai"
            }))
            .expect("list_models should return a tool result");

        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap_or_default()
                .contains("dotted `<type>.<alias>`"),
            "unexpected error: {:?}",
            result.error
        );
    }
}