1use crate::agent::loop_::get_model_switch_state;
2use crate::security::SecurityPolicy;
3use crate::security::policy::ToolOperation;
4use async_trait::async_trait;
5use serde_json::json;
6use std::sync::Arc;
7use zeroclaw_api::tool::{Tool, ToolResult};
8
9pub struct ModelSwitchTool {
10 security: Arc<SecurityPolicy>,
11}
12
13impl ModelSwitchTool {
14 pub fn new(security: Arc<SecurityPolicy>) -> Self {
15 Self { security }
16 }
17}
18
19#[async_trait]
20impl Tool for ModelSwitchTool {
21 fn name(&self) -> &str {
22 "model_switch"
23 }
24
25 fn description(&self) -> &str {
26 "Switch the AI model at runtime. Use 'get' to see current model, 'list_model_providers' to see available model_providers, 'list_models' to see models for a model_provider, or 'set' to switch to a different model. The switch takes effect immediately for the current conversation."
27 }
28
29 fn parameters_schema(&self) -> serde_json::Value {
30 json!({
31 "type": "object",
32 "properties": {
33 "action": {
34 "type": "string",
35 "enum": ["get", "set", "list_model_providers", "list_models"],
36 "description": "Action to perform: get current model, set a new model, list available model_providers, or list models for a model_provider"
37 },
38 "model_provider": {
39 "type": "string",
40 "description": "ModelProvider name (e.g., 'openai', 'anthropic', 'groq', 'ollama'). Required for 'set' and 'list_models' actions."
41 },
42 "model": {
43 "type": "string",
44 "description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action."
45 }
46 },
47 "required": ["action"]
48 })
49 }
50
51 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
52 let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get");
53
54 if let Err(error) = self
55 .security
56 .enforce_tool_operation(ToolOperation::Act, "model_switch")
57 {
58 return Ok(ToolResult {
59 success: false,
60 output: String::new(),
61 error: Some(error),
62 });
63 }
64
65 match action {
66 "get" => self.handle_get(),
67 "set" => self.handle_set(&args),
68 "list_model_providers" => self.handle_list_providers(),
69 "list_models" => self.handle_list_models(&args),
70 _ => Ok(ToolResult {
71 success: false,
72 output: String::new(),
73 error: Some(format!(
74 "Unknown action: {}. Valid actions: get, set, list_model_providers, list_models",
75 action
76 )),
77 }),
78 }
79 }
80}
81
82impl ModelSwitchTool {
83 fn handle_get(&self) -> anyhow::Result<ToolResult> {
84 let switch_state = get_model_switch_state();
85 let pending = switch_state.lock().unwrap().clone();
86
87 Ok(ToolResult {
88 success: true,
89 output: serde_json::to_string_pretty(&json!({
90 "pending_switch": pending,
91 "note": "To switch models, use action 'set' with model_provider and model parameters"
92 }))?,
93 error: None,
94 })
95 }
96
97 fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
98 let model_provider = args.get("model_provider").and_then(|v| v.as_str());
99
100 let model_provider = match model_provider {
101 Some(p) => p,
102 None => {
103 return Ok(ToolResult {
104 success: false,
105 output: String::new(),
106 error: Some("Missing 'model_provider' parameter for 'set' action".to_string()),
107 });
108 }
109 };
110
111 let model = args.get("model").and_then(|v| v.as_str());
112
113 let model = match model {
114 Some(m) => m,
115 None => {
116 return Ok(ToolResult {
117 success: false,
118 output: String::new(),
119 error: Some("Missing 'model' parameter for 'set' action".to_string()),
120 });
121 }
122 };
123
124 let known_model_providers = zeroclaw_providers::list_model_providers();
130 let model_provider_valid = known_model_providers
131 .iter()
132 .any(|p| p.name.eq_ignore_ascii_case(model_provider));
133
134 if !model_provider_valid {
135 return Ok(ToolResult {
136 success: false,
137 output: serde_json::to_string_pretty(&json!({
138 "available_model_providers": known_model_providers.iter().map(|p| p.name).collect::<Vec<_>>()
139 }))?,
140 error: Some(format!(
141 "Unknown model model_provider: {}. Use 'list_model_providers' to see available options.",
142 model_provider
143 )),
144 });
145 }
146
147 let switch_state = get_model_switch_state();
149 *switch_state.lock().unwrap() = Some((model_provider.to_string(), model.to_string()));
150
151 Ok(ToolResult {
152 success: true,
153 output: serde_json::to_string_pretty(&json!({
154 "message": "Model switch requested",
155 "model_provider": model_provider,
156 "model": model,
157 "note": "The agent will switch to this model on the next turn. Use 'get' to check pending switch."
158 }))?,
159 error: None,
160 })
161 }
162
163 fn handle_list_providers(&self) -> anyhow::Result<ToolResult> {
164 let providers_list = zeroclaw_providers::list_model_providers();
165
166 let model_providers: Vec<serde_json::Value> = providers_list
167 .iter()
168 .map(|p| {
169 json!({
170 "name": p.name,
171 "display_name": p.display_name,
172 "local": p.local
173 })
174 })
175 .collect();
176
177 Ok(ToolResult {
178 success: true,
179 output: serde_json::to_string_pretty(&json!({
180 "model_providers": model_providers,
181 "count": model_providers.len(),
182 "example": "Use action 'set' with model_provider and model to switch"
183 }))?,
184 error: None,
185 })
186 }
187
188 fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
189 let model_provider = args.get("model_provider").and_then(|v| v.as_str());
190
191 let model_provider = match model_provider {
192 Some(p) => p,
193 None => {
194 return Ok(ToolResult {
195 success: false,
196 output: String::new(),
197 error: Some(
198 "Missing 'model_provider' parameter for 'list_models' action".to_string(),
199 ),
200 });
201 }
202 };
203
204 let models = match model_provider.to_lowercase().as_str() {
206 "openai" => vec![
207 "gpt-4o",
208 "gpt-4o-mini",
209 "gpt-4-turbo",
210 "gpt-4",
211 "gpt-3.5-turbo",
212 ],
213 "anthropic" => vec![
214 "claude-sonnet-4-6",
215 "claude-sonnet-4-5",
216 "claude-3-5-sonnet",
217 "claude-3-opus",
218 "claude-3-haiku",
219 ],
220 "openrouter" => vec![
221 "anthropic/claude-sonnet-4-6",
222 "openai/gpt-4o",
223 "google/gemini-pro",
224 "meta-llama/llama-3-70b-instruct",
225 ],
226 "groq" => vec![
227 "llama-3.3-70b-versatile",
228 "mixtral-8x7b-32768",
229 "llama-3.1-70b-speculative",
230 ],
231 "ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"],
232 "deepseek" => vec!["deepseek-chat", "deepseek-coder"],
233 "mistral" => vec![
234 "mistral-large-latest",
235 "mistral-small-latest",
236 "mistral-nemo",
237 ],
238 "google" | "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"],
239 "xai" | "grok" => vec!["grok-2", "grok-2-vision", "grok-beta"],
240 _ => vec![],
241 };
242
243 if models.is_empty() {
244 return Ok(ToolResult {
245 success: true,
246 output: serde_json::to_string_pretty(&json!({
247 "model_provider": model_provider,
248 "models": [],
249 "note": "No common models listed for this model_provider. Check model_provider documentation for available models."
250 }))?,
251 error: None,
252 });
253 }
254
255 Ok(ToolResult {
256 success: true,
257 output: serde_json::to_string_pretty(&json!({
258 "model_provider": model_provider,
259 "models": models,
260 "example": "Use action 'set' with this model_provider and a model ID to switch"
261 }))?,
262 error: None,
263 })
264 }
265}