Skip to main content

zeroclaw_runtime/tools/
model_switch.rs

1use crate::agent::loop_::get_model_switch_state;
2use crate::security::SecurityPolicy;
3use crate::security::policy::ToolOperation;
4use async_trait::async_trait;
5use serde_json::json;
6use std::sync::Arc;
7use zeroclaw_api::tool::{Tool, ToolResult};
8
9pub struct ModelSwitchTool {
10    security: Arc<SecurityPolicy>,
11}
12
13impl ModelSwitchTool {
14    pub fn new(security: Arc<SecurityPolicy>) -> Self {
15        Self { security }
16    }
17}
18
19#[async_trait]
20impl Tool for ModelSwitchTool {
21    fn name(&self) -> &str {
22        "model_switch"
23    }
24
25    fn description(&self) -> &str {
26        "Switch the AI model at runtime. Use 'get' to see current model, 'list_model_providers' to see available model_providers, 'list_models' to see models for a model_provider, or 'set' to switch to a different model. The switch takes effect immediately for the current conversation."
27    }
28
29    fn parameters_schema(&self) -> serde_json::Value {
30        json!({
31            "type": "object",
32            "properties": {
33                "action": {
34                    "type": "string",
35                    "enum": ["get", "set", "list_model_providers", "list_models"],
36                    "description": "Action to perform: get current model, set a new model, list available model_providers, or list models for a model_provider"
37                },
38                "model_provider": {
39                    "type": "string",
40                    "description": "ModelProvider name (e.g., 'openai', 'anthropic', 'groq', 'ollama'). Required for 'set' and 'list_models' actions."
41                },
42                "model": {
43                    "type": "string",
44                    "description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action."
45                }
46            },
47            "required": ["action"]
48        })
49    }
50
51    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
52        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get");
53
54        if let Err(error) = self
55            .security
56            .enforce_tool_operation(ToolOperation::Act, "model_switch")
57        {
58            return Ok(ToolResult {
59                success: false,
60                output: String::new(),
61                error: Some(error),
62            });
63        }
64
65        match action {
66            "get" => self.handle_get(),
67            "set" => self.handle_set(&args),
68            "list_model_providers" => self.handle_list_providers(),
69            "list_models" => self.handle_list_models(&args),
70            _ => Ok(ToolResult {
71                success: false,
72                output: String::new(),
73                error: Some(format!(
74                    "Unknown action: {}. Valid actions: get, set, list_model_providers, list_models",
75                    action
76                )),
77            }),
78        }
79    }
80}
81
82impl ModelSwitchTool {
83    fn handle_get(&self) -> anyhow::Result<ToolResult> {
84        let switch_state = get_model_switch_state();
85        let pending = switch_state.lock().unwrap().clone();
86
87        Ok(ToolResult {
88            success: true,
89            output: serde_json::to_string_pretty(&json!({
90                "pending_switch": pending,
91                "note": "To switch models, use action 'set' with model_provider and model parameters"
92            }))?,
93            error: None,
94        })
95    }
96
97    fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
98        let model_provider = args.get("model_provider").and_then(|v| v.as_str());
99
100        let model_provider = match model_provider {
101            Some(p) => p,
102            None => {
103                return Ok(ToolResult {
104                    success: false,
105                    output: String::new(),
106                    error: Some("Missing 'model_provider' parameter for 'set' action".to_string()),
107                });
108            }
109        };
110
111        let model = args.get("model").and_then(|v| v.as_str());
112
113        let model = match model {
114            Some(m) => m,
115            None => {
116                return Ok(ToolResult {
117                    success: false,
118                    output: String::new(),
119                    error: Some("Missing 'model' parameter for 'set' action".to_string()),
120                });
121            }
122        };
123
124        // Validate the model_provider exists. Legacy colon-URL forms
125        // ("custom:https://..." and "anthropic-custom:...") are collapsed at
126        // TOML load by `normalize_model_provider_type` in `schema/v2.rs` into
127        // the typed `custom` family slot, so the runtime only sees canonical
128        // model-provider names. Validate against the static catalog directly.
129        let known_model_providers = zeroclaw_providers::list_model_providers();
130        let model_provider_valid = known_model_providers
131            .iter()
132            .any(|p| p.name.eq_ignore_ascii_case(model_provider));
133
134        if !model_provider_valid {
135            return Ok(ToolResult {
136                success: false,
137                output: serde_json::to_string_pretty(&json!({
138                    "available_model_providers": known_model_providers.iter().map(|p| p.name).collect::<Vec<_>>()
139                }))?,
140                error: Some(format!(
141                    "Unknown model model_provider: {}. Use 'list_model_providers' to see available options.",
142                    model_provider
143                )),
144            });
145        }
146
147        // Set the global model switch request
148        let switch_state = get_model_switch_state();
149        *switch_state.lock().unwrap() = Some((model_provider.to_string(), model.to_string()));
150
151        Ok(ToolResult {
152            success: true,
153            output: serde_json::to_string_pretty(&json!({
154                "message": "Model switch requested",
155                "model_provider": model_provider,
156                "model": model,
157                "note": "The agent will switch to this model on the next turn. Use 'get' to check pending switch."
158            }))?,
159            error: None,
160        })
161    }
162
163    fn handle_list_providers(&self) -> anyhow::Result<ToolResult> {
164        let providers_list = zeroclaw_providers::list_model_providers();
165
166        let model_providers: Vec<serde_json::Value> = providers_list
167            .iter()
168            .map(|p| {
169                json!({
170                    "name": p.name,
171                    "display_name": p.display_name,
172                    "local": p.local
173                })
174            })
175            .collect();
176
177        Ok(ToolResult {
178            success: true,
179            output: serde_json::to_string_pretty(&json!({
180                "model_providers": model_providers,
181                "count": model_providers.len(),
182                "example": "Use action 'set' with model_provider and model to switch"
183            }))?,
184            error: None,
185        })
186    }
187
188    fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
189        let model_provider = args.get("model_provider").and_then(|v| v.as_str());
190
191        let model_provider = match model_provider {
192            Some(p) => p,
193            None => {
194                return Ok(ToolResult {
195                    success: false,
196                    output: String::new(),
197                    error: Some(
198                        "Missing 'model_provider' parameter for 'list_models' action".to_string(),
199                    ),
200                });
201            }
202        };
203
204        // Return common models for known model_providers
205        let models = match model_provider.to_lowercase().as_str() {
206            "openai" => vec![
207                "gpt-4o",
208                "gpt-4o-mini",
209                "gpt-4-turbo",
210                "gpt-4",
211                "gpt-3.5-turbo",
212            ],
213            "anthropic" => vec![
214                "claude-sonnet-4-6",
215                "claude-sonnet-4-5",
216                "claude-3-5-sonnet",
217                "claude-3-opus",
218                "claude-3-haiku",
219            ],
220            "openrouter" => vec![
221                "anthropic/claude-sonnet-4-6",
222                "openai/gpt-4o",
223                "google/gemini-pro",
224                "meta-llama/llama-3-70b-instruct",
225            ],
226            "groq" => vec![
227                "llama-3.3-70b-versatile",
228                "mixtral-8x7b-32768",
229                "llama-3.1-70b-speculative",
230            ],
231            "ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"],
232            "deepseek" => vec!["deepseek-chat", "deepseek-coder"],
233            "mistral" => vec![
234                "mistral-large-latest",
235                "mistral-small-latest",
236                "mistral-nemo",
237            ],
238            "google" | "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"],
239            "xai" | "grok" => vec!["grok-2", "grok-2-vision", "grok-beta"],
240            _ => vec![],
241        };
242
243        if models.is_empty() {
244            return Ok(ToolResult {
245                success: true,
246                output: serde_json::to_string_pretty(&json!({
247                    "model_provider": model_provider,
248                    "models": [],
249                    "note": "No common models listed for this model_provider. Check model_provider documentation for available models."
250                }))?,
251                error: None,
252            });
253        }
254
255        Ok(ToolResult {
256            success: true,
257            output: serde_json::to_string_pretty(&json!({
258                "model_provider": model_provider,
259                "models": models,
260                "example": "Use action 'set' with this model_provider and a model ID to switch"
261            }))?,
262            error: None,
263        })
264    }
265}