1use crate::agent::loop_::get_model_switch_state;
2use crate::security::SecurityPolicy;
3use crate::security::policy::ToolOperation;
4use async_trait::async_trait;
5use serde_json::json;
6use std::sync::Arc;
7use zeroclaw_api::tool::{Tool, ToolResult};
8use zeroclaw_config::schema::Config;
9
10fn configured_model_provider_profiles(config: &Config) -> Vec<String> {
11 let mut profiles = config
12 .providers
13 .models
14 .iter_entries()
15 .map(|(family, alias, _profile)| format!("{family}.{alias}"))
16 .collect::<Vec<_>>();
17 profiles.sort();
18 profiles
19}
20
21fn resolve_model_provider_profile_ref(config: &Config, raw: &str) -> Result<String, String> {
22 let raw = raw.trim();
23 let Some((family, alias)) = raw.split_once('.') else {
24 return Err(format!(
25 "model_provider must be a dotted `<type>.<alias>` provider profile reference, got `{raw}`"
26 ));
27 };
28 let family = family.trim();
29 let alias = alias.trim();
30 if family.is_empty() || alias.is_empty() {
31 return Err(format!(
32 "model_provider must be a dotted `<type>.<alias>` provider profile reference, got `{raw}`"
33 ));
34 }
35
36 if config.providers.models.find(family, alias).is_none() {
37 let available = configured_model_provider_profiles(config);
38 let available = if available.is_empty() {
39 "no configured provider profiles".to_string()
40 } else {
41 available.join(", ")
42 };
43 return Err(format!(
44 "model_provider `{raw}` is not a configured provider profile. Add a [providers.models.{family}.{alias}] entry or use one of: {available}"
45 ));
46 }
47
48 Ok(format!("{family}.{alias}"))
49}
50
51pub struct ModelSwitchTool {
52 security: Arc<SecurityPolicy>,
53 config: Arc<Config>,
54}
55
56impl ModelSwitchTool {
57 pub const NAME: &'static str = "model_switch";
60
61 pub fn new(security: Arc<SecurityPolicy>, config: Arc<Config>) -> Self {
62 Self { security, config }
63 }
64}
65
66#[async_trait]
67impl Tool for ModelSwitchTool {
68 fn name(&self) -> &str {
69 Self::NAME
70 }
71
72 fn description(&self) -> &str {
73 "Request a runtime model switch using a configured provider profile plus provider-local model. Use 'get' to see the pending switch, 'list_model_providers' to see provider families, 'list_models' to see common models for a provider profile, or 'set' with a dotted provider profile ref such as 'openai.default'. The switch is runtime/session state and does not write config."
74 }
75
76 fn parameters_schema(&self) -> serde_json::Value {
77 json!({
78 "type": "object",
79 "properties": {
80 "action": {
81 "type": "string",
82 "enum": ["get", "set", "list_model_providers", "list_models"],
83 "description": "Action to perform: get pending switch state, set a runtime provider-profile/model switch, list available provider families, or list common models for a provider profile"
84 },
85 "model_provider": {
86 "type": "string",
87 "description": "Dotted provider profile reference (e.g., 'openai.default', 'anthropic.sonnet', 'ollama.local'). Required for 'set' and 'list_models' actions."
88 },
89 "model": {
90 "type": "string",
91 "description": "Model ID (e.g., 'gpt-4o', 'claude-sonnet-4-6'). Required for 'set' action."
92 }
93 },
94 "required": ["action"]
95 })
96 }
97
98 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
99 let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("get");
100
101 if let Err(error) = self
102 .security
103 .enforce_tool_operation(ToolOperation::Act, "model_switch")
104 {
105 return Ok(ToolResult {
106 success: false,
107 output: String::new(),
108 error: Some(error),
109 });
110 }
111
112 match action {
113 "get" => self.handle_get(),
114 "set" => self.handle_set(&args),
115 "list_model_providers" => self.handle_list_providers(),
116 "list_models" => self.handle_list_models(&args),
117 _ => Ok(ToolResult {
118 success: false,
119 output: String::new(),
120 error: Some(format!(
121 "Unknown action: {}. Valid actions: get, set, list_model_providers, list_models",
122 action
123 )),
124 }),
125 }
126 }
127}
128
129impl ModelSwitchTool {
130 fn handle_get(&self) -> anyhow::Result<ToolResult> {
131 let switch_state = get_model_switch_state();
132 let pending = switch_state.lock().unwrap().clone();
133
134 Ok(ToolResult {
135 success: true,
136 output: serde_json::to_string_pretty(&json!({
137 "pending_switch": pending,
138 "note": "To switch models, use action 'set' with dotted <type>.<alias> model_provider and model parameters"
139 }))?,
140 error: None,
141 })
142 }
143
144 fn handle_set(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
145 let model_provider = args.get("model_provider").and_then(|v| v.as_str());
146
147 let model_provider = match model_provider {
148 Some(p) => p,
149 None => {
150 return Ok(ToolResult {
151 success: false,
152 output: String::new(),
153 error: Some("Missing 'model_provider' parameter for 'set' action".to_string()),
154 });
155 }
156 };
157
158 let model = args.get("model").and_then(|v| v.as_str());
159
160 let model = match model {
161 Some(m) => m,
162 None => {
163 return Ok(ToolResult {
164 success: false,
165 output: String::new(),
166 error: Some("Missing 'model' parameter for 'set' action".to_string()),
167 });
168 }
169 };
170
171 let model_provider = match resolve_model_provider_profile_ref(&self.config, model_provider)
172 {
173 Ok(model_provider) => model_provider,
174 Err(error) => {
175 let known_model_providers = zeroclaw_providers::list_model_providers();
176 let configured_profiles = configured_model_provider_profiles(&self.config);
177 return Ok(ToolResult {
178 success: false,
179 output: serde_json::to_string_pretty(&json!({
180 "provider_ref_shape": "<type>.<alias>",
181 "available_provider_families": known_model_providers.iter().map(|p| p.name).collect::<Vec<_>>(),
182 "configured_provider_profiles": configured_profiles
183 }))?,
184 error: Some(error),
185 });
186 }
187 };
188
189 let model = model.trim();
190 if model.is_empty() {
191 return Ok(ToolResult {
192 success: false,
193 output: String::new(),
194 error: Some("Model ID cannot be empty".to_string()),
195 });
196 }
197
198 let switch_state = get_model_switch_state();
200 *switch_state.lock().unwrap() = Some((model_provider.clone(), model.to_string()));
201
202 Ok(ToolResult {
203 success: true,
204 output: serde_json::to_string_pretty(&json!({
205 "message": "Model switch requested",
206 "model_provider": model_provider,
207 "model": model,
208 "note": "The active runtime path will consume this provider-profile/model switch where model_switch is supported. This does not write persisted config."
209 }))?,
210 error: None,
211 })
212 }
213
214 fn handle_list_providers(&self) -> anyhow::Result<ToolResult> {
215 let providers_list = zeroclaw_providers::list_model_providers();
216 let configured_profiles = configured_model_provider_profiles(&self.config);
217 let configured_count = configured_profiles.len();
218
219 let model_providers: Vec<serde_json::Value> = providers_list
220 .iter()
221 .map(|p| {
222 json!({
223 "name": p.name,
224 "display_name": p.display_name,
225 "local": p.local
226 })
227 })
228 .collect();
229
230 Ok(ToolResult {
231 success: true,
232 output: serde_json::to_string_pretty(&json!({
233 "model_providers": model_providers,
234 "count": model_providers.len(),
235 "configured_provider_profiles": configured_profiles,
236 "configured_count": configured_count,
237 "provider_ref_shape": "<type>.<alias>",
238 "example": "Use action 'set' with a dotted provider profile ref such as 'openai.default'"
239 }))?,
240 error: None,
241 })
242 }
243
244 fn handle_list_models(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
245 let model_provider = args.get("model_provider").and_then(|v| v.as_str());
246
247 let model_provider = match model_provider {
248 Some(p) => p,
249 None => {
250 return Ok(ToolResult {
251 success: false,
252 output: String::new(),
253 error: Some(
254 "Missing 'model_provider' parameter for 'list_models' action".to_string(),
255 ),
256 });
257 }
258 };
259
260 let model_provider = match resolve_model_provider_profile_ref(&self.config, model_provider)
261 {
262 Ok(model_provider) => model_provider,
263 Err(error) => {
264 return Ok(ToolResult {
265 success: false,
266 output: serde_json::to_string_pretty(&json!({
267 "provider_ref_shape": "<type>.<alias>",
268 "configured_provider_profiles": configured_model_provider_profiles(&self.config)
269 }))?,
270 error: Some(error),
271 });
272 }
273 };
274 let provider_family = model_provider
275 .split_once('.')
276 .map(|(family, _alias)| family)
277 .unwrap_or(model_provider.as_str());
278
279 let models = match provider_family.to_lowercase().as_str() {
281 "openai" => vec![
282 "gpt-4o",
283 "gpt-4o-mini",
284 "gpt-4-turbo",
285 "gpt-4",
286 "gpt-3.5-turbo",
287 ],
288 "anthropic" => vec![
289 "claude-sonnet-4-6",
290 "claude-sonnet-4-5",
291 "claude-3-5-sonnet",
292 "claude-3-opus",
293 "claude-3-haiku",
294 ],
295 "openrouter" => vec![
296 "anthropic/claude-sonnet-4-6",
297 "openai/gpt-4o",
298 "google/gemini-pro",
299 "meta-llama/llama-3-70b-instruct",
300 ],
301 "groq" => vec![
302 "llama-3.3-70b-versatile",
303 "mixtral-8x7b-32768",
304 "llama-3.1-70b-speculative",
305 ],
306 "ollama" => vec!["llama3", "llama3.1", "mistral", "codellama", "phi3"],
307 "deepseek" => vec!["deepseek-chat", "deepseek-coder"],
308 "mistral" => vec![
309 "mistral-large-latest",
310 "mistral-small-latest",
311 "mistral-nemo",
312 ],
313 "gemini" => vec!["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"],
314 "xai" => vec!["grok-2", "grok-2-vision", "grok-beta"],
315 _ => vec![],
316 };
317
318 if models.is_empty() {
319 return Ok(ToolResult {
320 success: true,
321 output: serde_json::to_string_pretty(&json!({
322 "model_provider": model_provider,
323 "models": [],
324 "note": "No common models listed for this model_provider family. Check model_provider documentation for available models."
325 }))?,
326 error: None,
327 });
328 }
329
330 Ok(ToolResult {
331 success: true,
332 output: serde_json::to_string_pretty(&json!({
333 "model_provider": model_provider,
334 "models": models,
335 "example": "Use action 'set' with this model_provider and a model ID to switch"
336 }))?,
337 error: None,
338 })
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::loop_::{clear_model_switch_request, get_model_switch_state};

    /// Serializes tests that read or write the process-global switch state.
    static MODEL_SWITCH_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes the test lock and clears any pending switch request.
    ///
    /// A test that panics while holding the lock poisons it. Recovering the
    /// guard stops that single failure from cascading into every later test.
    fn lock_switch_state() -> std::sync::MutexGuard<'static, ()> {
        let guard = MODEL_SWITCH_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        clear_model_switch_request();
        guard
    }

    fn test_config() -> Config {
        let mut config = Config::default();
        config.providers.models.ensure("openai", "default").unwrap();
        config.providers.models.ensure("custom", "local").unwrap();
        config
    }

    fn tool() -> ModelSwitchTool {
        ModelSwitchTool::new(Arc::new(SecurityPolicy::default()), Arc::new(test_config()))
    }

    fn pending_switch() -> Option<(String, String)> {
        get_model_switch_state()
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }

    fn error_text(result: &ToolResult) -> &str {
        result.error.as_deref().unwrap_or_default()
    }

    fn parse_output(result: &ToolResult) -> serde_json::Value {
        serde_json::from_str(&result.output).expect("output should be json")
    }

    #[test]
    fn set_rejects_bare_provider_family() {
        let _guard = lock_switch_state();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai",
                "model": "gpt-4o"
            }))
            .expect("set should return a tool result");

        assert!(!result.success);
        assert!(
            error_text(&result).contains("dotted `<type>.<alias>`"),
            "unexpected error: {:?}",
            result.error
        );
        assert_eq!(pending_switch(), None);
    }

    #[test]
    fn set_rejects_empty_alias() {
        let _guard = lock_switch_state();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai.",
                "model": "gpt-4o"
            }))
            .expect("set should return a tool result");

        assert!(!result.success);
        assert!(
            error_text(&result).contains("dotted `<type>.<alias>`"),
            "unexpected error: {:?}",
            result.error
        );
        assert_eq!(pending_switch(), None);
    }

    #[test]
    fn set_accepts_dotted_provider_profile_ref() {
        let _guard = lock_switch_state();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai.default",
                "model": "gpt-4o"
            }))
            .expect("set should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        assert_eq!(
            pending_switch(),
            Some(("openai.default".to_string(), "gpt-4o".to_string()))
        );

        clear_model_switch_request();
    }

    #[test]
    fn set_normalizes_whitespace_in_provider_ref_and_model() {
        let _guard = lock_switch_state();

        let result = tool()
            .handle_set(&json!({
                "model_provider": " openai . default ",
                "model": " gpt-4o "
            }))
            .expect("set should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        assert_eq!(
            pending_switch(),
            Some(("openai.default".to_string(), "gpt-4o".to_string()))
        );

        clear_model_switch_request();
    }

    #[test]
    fn set_rejects_unconfigured_provider_profile_ref() {
        let _guard = lock_switch_state();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai.missing",
                "model": "gpt-4o"
            }))
            .expect("set should return a tool result");

        assert!(!result.success);
        assert!(
            error_text(&result).contains("configured provider profile"),
            "unexpected error: {:?}",
            result.error
        );
        assert_eq!(pending_switch(), None);
    }

    #[test]
    fn set_rejects_missing_model() {
        let _guard = lock_switch_state();

        let result = tool()
            .handle_set(&json!({ "model_provider": "openai.default" }))
            .expect("set should return a tool result");

        assert!(!result.success);
        assert_eq!(
            result.error.as_deref(),
            Some("Missing 'model' parameter for 'set' action")
        );
        assert_eq!(pending_switch(), None);
    }

    #[test]
    fn set_rejects_blank_model() {
        let _guard = lock_switch_state();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "openai.default",
                "model": "   "
            }))
            .expect("set should return a tool result");

        assert!(!result.success);
        assert_eq!(result.error.as_deref(), Some("Model ID cannot be empty"));
        assert_eq!(pending_switch(), None);
    }

    #[test]
    fn set_accepts_configured_custom_provider_profile_ref() {
        let _guard = lock_switch_state();

        let result = tool()
            .handle_set(&json!({
                "model_provider": "custom.local",
                "model": "local-model"
            }))
            .expect("set should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        assert_eq!(
            pending_switch(),
            Some(("custom.local".to_string(), "local-model".to_string()))
        );

        clear_model_switch_request();
    }

    #[test]
    fn list_models_accepts_dotted_provider_profile_ref() {
        let result = tool()
            .handle_list_models(&json!({
                "model_provider": "openai.default"
            }))
            .expect("list_models should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        let output = parse_output(&result);
        assert_eq!(output["model_provider"], "openai.default");
        assert!(output["models"]
            .as_array()
            .expect("models should be an array")
            .iter()
            .any(|model| model == "gpt-4o"));
    }

    #[test]
    fn list_models_returns_no_suggestions_for_custom_family() {
        let result = tool()
            .handle_list_models(&json!({
                "model_provider": "custom.local"
            }))
            .expect("list_models should return a tool result");

        assert!(result.success, "unexpected error: {:?}", result.error);
        let output = parse_output(&result);
        assert_eq!(output["model_provider"], "custom.local");
        assert!(output["models"]
            .as_array()
            .expect("models should be an array")
            .is_empty());
    }

    #[test]
    fn list_models_rejects_unconfigured_profile() {
        let result = tool()
            .handle_list_models(&json!({
                "model_provider": "openai.missing"
            }))
            .expect("list_models should return a tool result");

        assert!(!result.success);
        assert!(
            error_text(&result).contains("configured provider profile"),
            "unexpected error: {:?}",
            result.error
        );
        let output = parse_output(&result);
        assert!(output["configured_provider_profiles"]
            .as_array()
            .expect("configured_provider_profiles should be an array")
            .iter()
            .any(|profile| profile == "openai.default"));
    }
}