1use super::ModelProvider;
2use super::traits::{
3 ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
4};
5use async_trait::async_trait;
6use futures_util::stream::BoxStream;
7use std::collections::HashMap;
8
/// Sums every price entry recorded for `model` in a single provider's pricing table.
///
/// Three keys are consulted, in this order: `model` itself, `"{model}.input"`
/// and `"{model}.output"`. Whatever is present is added together.
///
/// Returns `None` when none of the three keys exists, so callers can tell
/// "no pricing data" apart from a genuine total of `0.0`.
fn score_model(pricing: &HashMap<String, f64>, model: &str) -> Option<f64> {
    let input_key = format!("{model}.input");
    let output_key = format!("{model}.output");

    let lookups = [
        pricing.get(model),
        pricing.get(input_key.as_str()),
        pricing.get(output_key.as_str()),
    ];

    let mut hits = lookups.iter().copied().flatten().peekable();
    // Bail out early when no key matched at all.
    hits.peek()?;
    Some(hits.fold(0.0, |sum, price| sum + *price))
}
29
/// Target of a routing hint: which registered provider to use and which
/// model name to send to it.
#[derive(Debug, Clone)]
pub struct Route {
    /// Name of the provider, matched against the names given to
    /// `RouterModelProvider::new`.
    pub provider_name: String,
    /// Model identifier forwarded to that provider.
    pub model: String,
}
36
37pub struct RouterModelProvider {
46 alias: String,
48 routes: HashMap<String, (usize, String)>, model_providers: Vec<(String, Box<dyn ModelProvider>)>,
50 default_index: usize,
51 default_model: String,
52}
53
54impl RouterModelProvider {
55 pub fn new(
60 alias: &str,
61 model_providers: Vec<(String, Box<dyn ModelProvider>)>,
62 routes: Vec<(String, Route)>,
63 default_model: String,
64 ) -> Self {
65 let name_to_index: HashMap<&str, usize> = model_providers
67 .iter()
68 .enumerate()
69 .map(|(i, (name, _))| (name.as_str(), i))
70 .collect();
71
72 let resolved_routes: HashMap<String, (usize, String)> = routes
74 .into_iter()
75 .filter_map(|(hint, route)| {
76 let index = name_to_index.get(route.provider_name.as_str()).copied();
77 match index {
78 Some(i) => Some((hint, (i, route.model))),
79 None => {
80 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"hint": hint, "model_provider": route.provider_name})), "Route references unknown model_provider, skipping");
81 None
82 }
83 }
84 })
85 .collect();
86
87 Self {
88 alias: alias.to_string(),
89 routes: resolved_routes,
90 model_providers,
91 default_index: 0,
92 default_model,
93 }
94 }
95 pub fn resolve_cost_optimized(
104 &self,
105 model: &str,
106 model_provider_pricing: &HashMap<String, HashMap<String, f64>>,
107 required_vision: bool,
108 required_tools: bool,
109 ) -> (usize, String) {
110 let hint = model.strip_prefix("hint:");
111 let is_cost_hint = matches!(hint, Some("cost-optimized" | "cheapest"));
112
113 if !is_cost_hint {
114 return self.resolve(model);
115 }
116
117 let mut candidates: Vec<(usize, String, f64)> = Vec::new();
118
119 for (idx, route_model) in self.routes.values() {
120 if let Some((_, model_provider)) = self.model_providers.get(*idx) {
122 if required_vision && !model_provider.supports_vision() {
123 continue;
124 }
125 if required_tools && !model_provider.supports_native_tools() {
126 continue;
127 }
128 }
129
130 let Some((model_provider_name, _)) = self.model_providers.get(*idx) else {
131 continue;
132 };
133 if let Some(pricing) = model_provider_pricing.get(model_provider_name)
134 && let Some(total_cost) = score_model(pricing, route_model)
135 {
136 candidates.push((*idx, route_model.clone(), total_cost));
137 }
138 }
139
140 candidates.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
142
143 if let Some((idx, route_model, _)) = candidates.into_iter().next() {
144 return (idx, route_model);
145 }
146
147 ::zeroclaw_log::record!(
149 WARN,
150 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
151 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
152 "No cost-optimized route found with matching pricing data, \
153 falling back to default"
154 );
155 (self.default_index, self.default_model.clone())
156 }
157
158 fn resolve(&self, model: &str) -> (usize, String) {
164 if let Some(hint) = model.strip_prefix("hint:") {
165 if let Some((idx, resolved_model)) = self.routes.get(hint) {
166 return (*idx, resolved_model.clone());
167 }
168 ::zeroclaw_log::record!(
169 WARN,
170 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
171 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
172 .with_attrs(::serde_json::json!({"hint": hint})),
173 "Unknown route hint, falling back to default model_provider"
174 );
175 }
176
177 (self.default_index, model.to_string())
179 }
180}
181
/// Pricing data plus capability requirements used to score provider/model
/// pairs by cost.
#[derive(Debug, Clone)]
pub struct CostOptimizedStrategy {
    /// Provider name -> pricing table for that provider.
    pub model_provider_pricing: HashMap<String, HashMap<String, f64>>,
    /// Whether vision support is required.
    pub required_vision: bool,
    /// Whether native tool support is required.
    pub required_tools: bool,
}
198
199impl CostOptimizedStrategy {
200 pub fn new(model_provider_pricing: HashMap<String, HashMap<String, f64>>) -> Self {
203 Self {
204 model_provider_pricing,
205 required_vision: false,
206 required_tools: false,
207 }
208 }
209
210 pub fn with_vision(mut self, required: bool) -> Self {
212 self.required_vision = required;
213 self
214 }
215
216 pub fn with_tools(mut self, required: bool) -> Self {
218 self.required_tools = required;
219 self
220 }
221
222 pub fn score(&self, model_provider_name: &str, model: &str) -> Option<f64> {
225 let pricing = self.model_provider_pricing.get(model_provider_name)?;
226 score_model(pricing, model)
227 }
228}
229
230#[async_trait]
231impl ModelProvider for RouterModelProvider {
232 async fn chat_with_system(
233 &self,
234 system_prompt: Option<&str>,
235 message: &str,
236 model: &str,
237 temperature: Option<f64>,
238 ) -> anyhow::Result<String> {
239 let (provider_idx, resolved_model) = self.resolve(model);
240
241 let (provider_name, model_provider) = &self.model_providers[provider_idx];
242 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name.as_str(), "model": resolved_model.as_str()})), "router dispatching request");
246
247 model_provider
248 .chat_with_system(system_prompt, message, &resolved_model, temperature)
249 .await
250 }
251
252 async fn chat_with_history(
253 &self,
254 messages: &[ChatMessage],
255 model: &str,
256 temperature: Option<f64>,
257 ) -> anyhow::Result<String> {
258 let (provider_idx, resolved_model) = self.resolve(model);
259 let (_, model_provider) = &self.model_providers[provider_idx];
260 model_provider
261 .chat_with_history(messages, &resolved_model, temperature)
262 .await
263 }
264
265 async fn chat(
266 &self,
267 request: ChatRequest<'_>,
268 model: &str,
269 temperature: Option<f64>,
270 ) -> anyhow::Result<ChatResponse> {
271 let (provider_idx, resolved_model) = self.resolve(model);
272 let (_, model_provider) = &self.model_providers[provider_idx];
273 model_provider
274 .chat(request, &resolved_model, temperature)
275 .await
276 }
277
278 async fn chat_with_tools(
279 &self,
280 messages: &[ChatMessage],
281 tools: &[serde_json::Value],
282 model: &str,
283 temperature: Option<f64>,
284 ) -> anyhow::Result<ChatResponse> {
285 let (provider_idx, resolved_model) = self.resolve(model);
286 let (_, model_provider) = &self.model_providers[provider_idx];
287 model_provider
288 .chat_with_tools(messages, tools, &resolved_model, temperature)
289 .await
290 }
291
292 fn supports_native_tools(&self) -> bool {
293 self.model_providers
294 .get(self.default_index)
295 .map(|(_, p)| p.supports_native_tools())
296 .unwrap_or(false)
297 }
298
299 fn supports_streaming(&self) -> bool {
300 self.model_providers
301 .iter()
302 .any(|(_, model_provider)| model_provider.supports_streaming())
303 }
304
305 fn supports_streaming_tool_events(&self) -> bool {
306 self.model_providers
307 .iter()
308 .any(|(_, model_provider)| model_provider.supports_streaming_tool_events())
309 }
310
311 fn stream_chat_with_system(
312 &self,
313 system_prompt: Option<&str>,
314 message: &str,
315 model: &str,
316 temperature: Option<f64>,
317 options: StreamOptions,
318 ) -> BoxStream<'static, StreamResult<StreamChunk>> {
319 let (provider_idx, resolved_model) = self.resolve(model);
320 let (_, model_provider) = &self.model_providers[provider_idx];
321 model_provider.stream_chat_with_system(
322 system_prompt,
323 message,
324 &resolved_model,
325 temperature,
326 options,
327 )
328 }
329
330 fn stream_chat_with_history(
331 &self,
332 messages: &[ChatMessage],
333 model: &str,
334 temperature: Option<f64>,
335 options: StreamOptions,
336 ) -> BoxStream<'static, StreamResult<StreamChunk>> {
337 let (provider_idx, resolved_model) = self.resolve(model);
338 let (_, model_provider) = &self.model_providers[provider_idx];
339 model_provider.stream_chat_with_history(messages, &resolved_model, temperature, options)
340 }
341
342 fn stream_chat(
343 &self,
344 request: ChatRequest<'_>,
345 model: &str,
346 temperature: Option<f64>,
347 options: StreamOptions,
348 ) -> BoxStream<'static, StreamResult<StreamEvent>> {
349 let (provider_idx, resolved_model) = self.resolve(model);
350 let (_, model_provider) = &self.model_providers[provider_idx];
351 model_provider.stream_chat(request, &resolved_model, temperature, options)
352 }
353
354 fn supports_vision(&self) -> bool {
355 self.model_providers
356 .get(self.default_index)
357 .map(|(_, p)| p.supports_vision())
358 .unwrap_or(false)
359 }
360
361 async fn warmup(&self) -> anyhow::Result<()> {
362 for (name, model_provider) in &self.model_providers {
363 ::zeroclaw_log::record!(
364 INFO,
365 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
366 .with_attrs(::serde_json::json!({"model_provider": name})),
367 "Warming up routed model_provider"
368 );
369 if let Err(e) = model_provider.warmup().await {
370 ::zeroclaw_log::record!(
371 WARN,
372 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
373 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
374 .with_attrs(
375 ::serde_json::json!({"error": format!("{}", e), "model_provider": name})
376 ),
377 "Warmup failed (non-fatal)"
378 );
379 }
380 }
381 Ok(())
382 }
383}
384
385impl ::zeroclaw_api::attribution::Attributable for RouterModelProvider {
386 fn role(&self) -> ::zeroclaw_api::attribution::Role {
387 ::zeroclaw_api::attribution::Role::Provider(
388 ::zeroclaw_api::attribution::ProviderKind::Model(
389 ::zeroclaw_api::attribution::ModelProviderKind::Router,
390 ),
391 )
392 }
393 fn alias(&self) -> &str {
394 &self.alias
395 }
396}
397
398#[cfg(test)]
399mod tests {
400 use super::*;
401 use futures_util::StreamExt;
402 use std::sync::Arc;
403 use std::sync::atomic::{AtomicUsize, Ordering};
404 use zeroclaw_api::tool::ToolSpec;
405
406 struct MockModelProvider {
407 calls: Arc<AtomicUsize>,
408 response: &'static str,
409 last_model: parking_lot::Mutex<String>,
410 vision: bool,
411 }
412
413 impl MockModelProvider {
414 fn new(response: &'static str) -> Self {
415 Self {
416 calls: Arc::new(AtomicUsize::new(0)),
417 response,
418 last_model: parking_lot::Mutex::new(String::new()),
419 vision: false,
420 }
421 }
422
423 fn with_vision(mut self, vision: bool) -> Self {
424 self.vision = vision;
425 self
426 }
427
428 fn call_count(&self) -> usize {
429 self.calls.load(Ordering::SeqCst)
430 }
431
432 fn last_model(&self) -> String {
433 self.last_model.lock().clone()
434 }
435 }
436
437 #[async_trait]
438 impl ModelProvider for MockModelProvider {
439 async fn chat_with_system(
440 &self,
441 _system_prompt: Option<&str>,
442 _message: &str,
443 model: &str,
444 _temperature: Option<f64>,
445 ) -> anyhow::Result<String> {
446 self.calls.fetch_add(1, Ordering::SeqCst);
447 *self.last_model.lock() = model.to_string();
448 Ok(self.response.to_string())
449 }
450
451 fn supports_vision(&self) -> bool {
452 self.vision
453 }
454 }
455 impl ::zeroclaw_api::attribution::Attributable for MockModelProvider {
456 fn role(&self) -> ::zeroclaw_api::attribution::Role {
457 ::zeroclaw_api::attribution::Role::Provider(
458 ::zeroclaw_api::attribution::ProviderKind::Model(
459 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
460 ),
461 )
462 }
463 fn alias(&self) -> &str {
464 "MockModelProvider"
465 }
466 }
467
468 fn make_router(
469 model_providers: Vec<(&'static str, &'static str)>,
470 routes: Vec<(&str, &str, &str)>,
471 ) -> (RouterModelProvider, Vec<Arc<MockModelProvider>>) {
472 let mocks: Vec<Arc<MockModelProvider>> = model_providers
473 .iter()
474 .map(|(_, response)| Arc::new(MockModelProvider::new(response)))
475 .collect();
476
477 let provider_list: Vec<(String, Box<dyn ModelProvider>)> = model_providers
478 .iter()
479 .zip(mocks.iter())
480 .map(|((name, _), mock)| {
481 (
482 (*name).to_string(),
483 Box::new(Arc::clone(mock)) as Box<dyn ModelProvider>,
484 )
485 })
486 .collect();
487
488 let route_list: Vec<(String, Route)> = routes
489 .iter()
490 .map(|(hint, provider_name, model)| {
491 (
492 (*hint).to_string(),
493 Route {
494 provider_name: (*provider_name).to_string(),
495 model: (*model).to_string(),
496 },
497 )
498 })
499 .collect();
500
501 let router = RouterModelProvider::new(
502 "test",
503 provider_list,
504 route_list,
505 "default-model".to_string(),
506 );
507
508 (router, mocks)
509 }
510
511 struct StreamingMockModelProvider {
514 stream_calls: Arc<AtomicUsize>,
515 last_stream_model: parking_lot::Mutex<String>,
516 response: &'static str,
517 }
518
519 impl StreamingMockModelProvider {
520 fn new(response: &'static str) -> Self {
521 Self {
522 stream_calls: Arc::new(AtomicUsize::new(0)),
523 last_stream_model: parking_lot::Mutex::new(String::new()),
524 response,
525 }
526 }
527
528 fn stream_response(&self, model: &str) -> BoxStream<'static, StreamResult<StreamChunk>> {
529 self.stream_calls.fetch_add(1, Ordering::SeqCst);
530 *self.last_stream_model.lock() = model.to_string();
531 let chunks = vec![
532 Ok(StreamChunk::delta(self.response)),
533 Ok(StreamChunk::final_chunk()),
534 ];
535 futures_util::stream::iter(chunks).boxed()
536 }
537 }
538
539 #[async_trait]
540 impl ModelProvider for StreamingMockModelProvider {
541 async fn chat_with_system(
542 &self,
543 _system_prompt: Option<&str>,
544 _message: &str,
545 _model: &str,
546 _temperature: Option<f64>,
547 ) -> anyhow::Result<String> {
548 Ok("ok".to_string())
549 }
550
551 fn supports_streaming(&self) -> bool {
552 true
553 }
554
555 fn stream_chat_with_system(
556 &self,
557 _system_prompt: Option<&str>,
558 _message: &str,
559 model: &str,
560 _temperature: Option<f64>,
561 _options: StreamOptions,
562 ) -> BoxStream<'static, StreamResult<StreamChunk>> {
563 self.stream_response(model)
564 }
565
566 fn stream_chat_with_history(
567 &self,
568 _messages: &[ChatMessage],
569 model: &str,
570 _temperature: Option<f64>,
571 _options: StreamOptions,
572 ) -> BoxStream<'static, StreamResult<StreamChunk>> {
573 self.stream_response(model)
574 }
575 }
576 impl ::zeroclaw_api::attribution::Attributable for StreamingMockModelProvider {
577 fn role(&self) -> ::zeroclaw_api::attribution::Role {
578 ::zeroclaw_api::attribution::Role::Provider(
579 ::zeroclaw_api::attribution::ProviderKind::Model(
580 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
581 ),
582 )
583 }
584 fn alias(&self) -> &str {
585 "StreamingMockModelProvider"
586 }
587 }
588
589 struct ToolEventStreamingMockModelProvider {
592 stream_calls: Arc<AtomicUsize>,
593 tool_event_calls: Arc<AtomicUsize>,
594 last_stream_model: parking_lot::Mutex<String>,
595 }
596
597 impl ToolEventStreamingMockModelProvider {
598 fn new() -> Self {
599 Self {
600 stream_calls: Arc::new(AtomicUsize::new(0)),
601 tool_event_calls: Arc::new(AtomicUsize::new(0)),
602 last_stream_model: parking_lot::Mutex::new(String::new()),
603 }
604 }
605 }
606
607 #[async_trait]
608 impl ModelProvider for ToolEventStreamingMockModelProvider {
609 async fn chat_with_system(
610 &self,
611 _system_prompt: Option<&str>,
612 _message: &str,
613 _model: &str,
614 _temperature: Option<f64>,
615 ) -> anyhow::Result<String> {
616 Ok("ok".to_string())
617 }
618
619 fn supports_streaming(&self) -> bool {
620 true
621 }
622
623 fn supports_streaming_tool_events(&self) -> bool {
624 true
625 }
626
627 fn stream_chat(
628 &self,
629 request: ChatRequest<'_>,
630 model: &str,
631 _temperature: Option<f64>,
632 _options: StreamOptions,
633 ) -> BoxStream<'static, StreamResult<StreamEvent>> {
634 self.stream_calls.fetch_add(1, Ordering::SeqCst);
635 if request.tools.is_some_and(|tools| !tools.is_empty()) {
636 self.tool_event_calls.fetch_add(1, Ordering::SeqCst);
637 }
638 *self.last_stream_model.lock() = model.to_string();
639 futures_util::stream::iter(vec![
640 Ok(StreamEvent::ToolCall(crate::traits::ToolCall {
641 id: "call_router_1".to_string(),
642 name: "shell".to_string(),
643 arguments: r#"{"command":"date"}"#.to_string(),
644 extra_content: None,
645 })),
646 Ok(StreamEvent::Final),
647 ])
648 .boxed()
649 }
650 }
651 impl ::zeroclaw_api::attribution::Attributable for ToolEventStreamingMockModelProvider {
652 fn role(&self) -> ::zeroclaw_api::attribution::Role {
653 ::zeroclaw_api::attribution::Role::Provider(
654 ::zeroclaw_api::attribution::ProviderKind::Model(
655 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
656 ),
657 )
658 }
659 fn alias(&self) -> &str {
660 "ToolEventStreamingMockModelProvider"
661 }
662 }
663
664 #[tokio::test]
667 async fn routes_hint_to_correct_provider() {
668 let (router, mocks) = make_router(
669 vec![("fast", "fast-response"), ("smart", "smart-response")],
670 vec![
671 ("fast", "fast", "llama-3-70b"),
672 ("reasoning", "smart", "claude-opus"),
673 ],
674 );
675
676 let result = router
677 .simple_chat("hello", "hint:reasoning", Some(0.5))
678 .await
679 .unwrap();
680 assert_eq!(result, "smart-response");
681 assert_eq!(mocks[1].call_count(), 1);
682 assert_eq!(mocks[1].last_model(), "claude-opus");
683 assert_eq!(mocks[0].call_count(), 0);
684 }
685
686 #[tokio::test]
687 async fn routes_fast_hint() {
688 let (router, mocks) = make_router(
689 vec![("fast", "fast-response"), ("smart", "smart-response")],
690 vec![("fast", "fast", "llama-3-70b")],
691 );
692
693 let result = router
694 .simple_chat("hello", "hint:fast", Some(0.5))
695 .await
696 .unwrap();
697 assert_eq!(result, "fast-response");
698 assert_eq!(mocks[0].call_count(), 1);
699 assert_eq!(mocks[0].last_model(), "llama-3-70b");
700 }
701
702 #[tokio::test]
703 async fn unknown_hint_falls_back_to_default() {
704 let (router, mocks) = make_router(
705 vec![("default", "default-response"), ("other", "other-response")],
706 vec![],
707 );
708
709 let result = router
710 .simple_chat("hello", "hint:nonexistent", Some(0.5))
711 .await
712 .unwrap();
713 assert_eq!(result, "default-response");
714 assert_eq!(mocks[0].call_count(), 1);
715 assert_eq!(mocks[0].last_model(), "hint:nonexistent");
717 }
718
719 #[tokio::test]
720 async fn non_hint_model_uses_default_provider() {
721 let (router, mocks) = make_router(
722 vec![
723 ("primary", "primary-response"),
724 ("secondary", "secondary-response"),
725 ],
726 vec![("code", "secondary", "codellama")],
727 );
728
729 let result = router
730 .simple_chat("hello", "anthropic/claude-sonnet-4-20250514", Some(0.5))
731 .await
732 .unwrap();
733 assert_eq!(result, "primary-response");
734 assert_eq!(mocks[0].call_count(), 1);
735 assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
736 }
737
/// `resolve` leaves a non-hint model name untouched and picks provider 0.
#[test]
fn resolve_preserves_model_for_non_hints() {
    let providers = vec![("default", "ok")];
    let (router, _) = make_router(providers, vec![]);

    let (provider_idx, resolved) = router.resolve("gpt-4o");

    assert_eq!(provider_idx, 0);
    assert_eq!(resolved, "gpt-4o");
}
746
/// `resolve` maps a known hint to the routed provider index and model.
#[test]
fn resolve_strips_hint_prefix() {
    let providers = vec![("fast", "ok"), ("smart", "ok")];
    let routes = vec![("reasoning", "smart", "claude-opus")];
    let (router, _) = make_router(providers, routes);

    let (provider_idx, resolved) = router.resolve("hint:reasoning");

    assert_eq!(provider_idx, 1);
    assert_eq!(resolved, "claude-opus");
}
758
/// A route naming a provider that was never registered is not kept.
#[test]
fn skips_routes_with_unknown_provider() {
    let providers = vec![("default", "ok")];
    let dangling_routes = vec![("broken", "nonexistent", "model")];
    let (router, _) = make_router(providers, dangling_routes);

    assert!(!router.routes.contains_key("broken"));
}
769
770 #[tokio::test]
771 async fn warmup_calls_all_providers() {
772 let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);
773
774 assert!(router.warmup().await.is_ok());
776 }
777
778 #[tokio::test]
779 async fn chat_with_system_passes_system_prompt() {
780 let mock = Arc::new(MockModelProvider::new("response"));
781 let router = RouterModelProvider::new(
782 "test",
783 vec![(
784 "default".into(),
785 Box::new(Arc::clone(&mock)) as Box<dyn ModelProvider>,
786 )],
787 vec![],
788 "model".into(),
789 );
790
791 let result = router
792 .chat_with_system(Some("system"), "hello", "model", Some(0.5))
793 .await
794 .unwrap();
795 assert_eq!(result, "response");
796 assert_eq!(mock.call_count(), 1);
797 }
798
799 #[tokio::test]
800 async fn chat_with_tools_delegates_to_resolved_provider() {
801 let mock = Arc::new(MockModelProvider::new("tool-response"));
802 let router = RouterModelProvider::new(
803 "test",
804 vec![(
805 "default".into(),
806 Box::new(Arc::clone(&mock)) as Box<dyn ModelProvider>,
807 )],
808 vec![],
809 "model".into(),
810 );
811
812 let messages = vec![ChatMessage {
813 role: "user".to_string(),
814 content: "use tools".to_string(),
815 }];
816 let tools = vec![serde_json::json!({
817 "type": "function",
818 "function": {
819 "name": "shell",
820 "description": "Run shell command",
821 "parameters": {}
822 }
823 })];
824
825 let result = router
828 .chat_with_tools(&messages, &tools, "model", Some(0.7))
829 .await
830 .unwrap();
831 assert_eq!(result.text.as_deref(), Some("tool-response"));
832 assert_eq!(mock.call_count(), 1);
833 assert_eq!(mock.last_model(), "model");
834 }
835
836 #[tokio::test]
837 async fn chat_with_tools_routes_hint_correctly() {
838 let (router, mocks) = make_router(
839 vec![("fast", "fast-tool"), ("smart", "smart-tool")],
840 vec![("reasoning", "smart", "claude-opus")],
841 );
842
843 let messages = vec![ChatMessage {
844 role: "user".to_string(),
845 content: "reason about this".to_string(),
846 }];
847 let tools = vec![serde_json::json!({"type": "function", "function": {"name": "test"}})];
848
849 let result = router
850 .chat_with_tools(&messages, &tools, "hint:reasoning", Some(0.5))
851 .await
852 .unwrap();
853 assert_eq!(result.text.as_deref(), Some("smart-tool"));
854 assert_eq!(mocks[1].call_count(), 1);
855 assert_eq!(mocks[1].last_model(), "claude-opus");
856 assert_eq!(mocks[0].call_count(), 0);
857 }
858
859 use crate::traits::ProviderCapabilities;
862
/// Provider double whose advertised capabilities are configurable.
struct CapableMockModelProvider {
    /// Canned reply returned from `chat_with_system`.
    response: &'static str,
    /// Reported as the `vision` capability.
    vision: bool,
    /// Reported as the `native_tool_calling` capability.
    tools: bool,
}
869
870 impl CapableMockModelProvider {
871 fn new(response: &'static str, vision: bool, tools: bool) -> Self {
872 Self {
873 response,
874 vision,
875 tools,
876 }
877 }
878 }
879
880 #[async_trait]
881 impl ModelProvider for CapableMockModelProvider {
882 fn capabilities(&self) -> ProviderCapabilities {
883 ProviderCapabilities {
884 native_tool_calling: self.tools,
885 vision: self.vision,
886 prompt_caching: false,
887 extended_thinking: false,
888 }
889 }
890
891 async fn chat_with_system(
892 &self,
893 _system_prompt: Option<&str>,
894 _message: &str,
895 _model: &str,
896 _temperature: Option<f64>,
897 ) -> anyhow::Result<String> {
898 Ok(self.response.to_string())
899 }
900 }
901 impl ::zeroclaw_api::attribution::Attributable for CapableMockModelProvider {
902 fn role(&self) -> ::zeroclaw_api::attribution::Role {
903 ::zeroclaw_api::attribution::Role::Provider(
904 ::zeroclaw_api::attribution::ProviderKind::Model(
905 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
906 ),
907 )
908 }
909 fn alias(&self) -> &str {
910 "CapableMockModelProvider"
911 }
912 }
913
/// Builds a nested pricing map from `(provider, model, input, output)` tuples.
///
/// Each tuple contributes two entries to its provider's table:
/// `"{model}.input"` and `"{model}.output"`. A later tuple for the same
/// provider and model overwrites the earlier prices.
fn make_pricing(entries: Vec<(&str, &str, f64, f64)>) -> HashMap<String, HashMap<String, f64>> {
    let seed: HashMap<String, HashMap<String, f64>> = HashMap::new();
    entries.into_iter().fold(
        seed,
        |mut table, (provider, model, input_price, output_price)| {
            table.entry(provider.to_owned()).or_default().extend([
                (format!("{model}.input"), input_price),
                (format!("{model}.output"), output_price),
            ]);
            table
        },
    )
}
925
/// With two priced routes and no capability requirements, the cheaper one wins.
#[test]
fn cost_optimized_selects_cheapest_provider() {
    let route = |hint: &str, provider: &str, model: &str| {
        (
            hint.to_string(),
            Route {
                provider_name: provider.to_string(),
                model: model.to_string(),
            },
        )
    };

    let expensive: Box<dyn ModelProvider> =
        Box::new(CapableMockModelProvider::new("exp", false, false));
    let cheap: Box<dyn ModelProvider> =
        Box::new(CapableMockModelProvider::new("chp", false, false));
    let providers = vec![("expensive".to_string(), expensive), ("cheap".to_string(), cheap)];
    let route_table = vec![
        route("expensive", "expensive", "big-model"),
        route("cheap", "cheap", "small-model"),
    ];
    let router =
        RouterModelProvider::new("test", providers, route_table, "default-model".to_string());

    let pricing = make_pricing(vec![
        ("expensive", "big-model", 15.0, 75.0),
        ("cheap", "small-model", 0.25, 1.25),
    ]);

    let (provider_idx, chosen) =
        router.resolve_cost_optimized("hint:cost-optimized", &pricing, false, false);

    assert_eq!(chosen, "small-model");
    assert_eq!(provider_idx, 1);
}
967
/// When vision is required, a cheaper provider without vision is skipped.
#[test]
fn cost_optimized_respects_vision_requirement() {
    let route = |hint: &str, provider: &str, model: &str| {
        (
            hint.to_string(),
            Route {
                provider_name: provider.to_string(),
                model: model.to_string(),
            },
        )
    };

    let blind: Box<dyn ModelProvider> =
        Box::new(CapableMockModelProvider::new("nv", false, false));
    let sighted: Box<dyn ModelProvider> =
        Box::new(CapableMockModelProvider::new("hv", true, false));
    let providers = vec![("no-vision".to_string(), blind), ("has-vision".to_string(), sighted)];
    let route_table = vec![
        route("cheap", "no-vision", "cheap-model"),
        route("vision", "has-vision", "vision-model"),
    ];
    let router =
        RouterModelProvider::new("test", providers, route_table, "default-model".to_string());

    let pricing = make_pricing(vec![
        ("no-vision", "cheap-model", 0.10, 0.40),
        ("has-vision", "vision-model", 3.0, 15.0),
    ]);

    let (_, chosen) = router.resolve_cost_optimized("hint:cheapest", &pricing, true, false);

    assert_eq!(chosen, "vision-model");
}
1008
/// When native tools are required, a cheaper provider without them is skipped.
#[test]
fn cost_optimized_respects_tools_requirement() {
    let route = |hint: &str, provider: &str, model: &str| {
        (
            hint.to_string(),
            Route {
                provider_name: provider.to_string(),
                model: model.to_string(),
            },
        )
    };

    let plain: Box<dyn ModelProvider> =
        Box::new(CapableMockModelProvider::new("nt", false, false));
    let tooled: Box<dyn ModelProvider> =
        Box::new(CapableMockModelProvider::new("ht", false, true));
    let providers = vec![("no-tools".to_string(), plain), ("has-tools".to_string(), tooled)];
    let route_table = vec![
        route("basic", "no-tools", "basic-model"),
        route("tools", "has-tools", "tools-model"),
    ];
    let router =
        RouterModelProvider::new("test", providers, route_table, "default-model".to_string());

    let pricing = make_pricing(vec![
        ("no-tools", "basic-model", 0.10, 0.40),
        ("has-tools", "tools-model", 5.0, 15.0),
    ]);

    let (_, chosen) = router.resolve_cost_optimized("hint:cost-optimized", &pricing, false, true);

    assert_eq!(chosen, "tools-model");
}
1049
/// Without any pricing data the cost hint resolves to the default provider
/// and the default model.
#[test]
fn cost_optimized_falls_back_when_no_pricing() {
    let providers = vec![("default", "ok"), ("other", "ok")];
    let routes = vec![("route-a", "other", "some-model")];
    let (router, _) = make_router(providers, routes);

    let no_pricing: HashMap<String, HashMap<String, f64>> = HashMap::new();
    let (provider_idx, chosen) =
        router.resolve_cost_optimized("hint:cost-optimized", &no_pricing, false, false);

    assert_eq!(provider_idx, 0);
    assert_eq!(chosen, "default-model");
}
1064
/// A single priced route is selected for the `cheapest` hint.
#[test]
fn cost_optimized_with_single_route() {
    let only: Box<dyn ModelProvider> =
        Box::new(CapableMockModelProvider::new("ok", false, false));
    let providers = vec![("only".to_string(), only)];
    let single_route = Route {
        provider_name: "only".to_string(),
        model: "the-model".to_string(),
    };
    let route_table = vec![("single".to_string(), single_route)];
    let router =
        RouterModelProvider::new("test", providers, route_table, "default-model".to_string());

    let pricing = make_pricing(vec![("only", "the-model", 1.0, 2.0)]);

    let (provider_idx, chosen) =
        router.resolve_cost_optimized("hint:cheapest", &pricing, false, false);

    assert_eq!(provider_idx, 0);
    assert_eq!(chosen, "the-model");
}
1087
/// Among three priced routes, the one with the lowest input+output total wins.
#[test]
fn cost_optimized_prefers_lower_total_cost() {
    let route = |hint: &str, provider: &str, model: &str| {
        (
            hint.to_string(),
            Route {
                provider_name: provider.to_string(),
                model: model.to_string(),
            },
        )
    };
    let provider = |name: &str, response: &'static str| -> (String, Box<dyn ModelProvider>) {
        (
            name.to_string(),
            Box::new(CapableMockModelProvider::new(response, false, false)),
        )
    };

    let providers = vec![provider("p1", "r1"), provider("p2", "r2"), provider("p3", "r3")];
    let route_table = vec![
        route("a", "p1", "model-a"),
        route("b", "p2", "model-b"),
        route("c", "p3", "model-c"),
    ];
    let router =
        RouterModelProvider::new("test", providers, route_table, "default-model".to_string());

    let pricing = make_pricing(vec![
        ("p1", "model-a", 10.0, 50.0),
        ("p2", "model-b", 0.15, 0.60),
        ("p3", "model-c", 3.0, 15.0),
    ]);

    let (provider_idx, chosen) =
        router.resolve_cost_optimized("hint:cost-optimized", &pricing, false, false);

    assert_eq!(chosen, "model-b");
    assert_eq!(provider_idx, 1);
}
1141
/// `CostOptimizedStrategy::score` adds input and output prices and yields
/// `None` for an unknown provider or model.
#[test]
fn cost_optimized_strategy_score() {
    let pricing = make_pricing(vec![
        ("cheap-provider", "cheap-model", 0.10, 0.40),
        ("expensive-provider", "expensive-model", 15.0, 75.0),
    ]);
    let strategy = CostOptimizedStrategy::new(pricing);
    let close_to = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;

    let cheap_score = strategy.score("cheap-provider", "cheap-model").unwrap();
    assert!(close_to(cheap_score, 0.50));

    let expensive_score = strategy
        .score("expensive-provider", "expensive-model")
        .unwrap();
    assert!(close_to(expensive_score, 90.0));

    assert!(strategy.score("cheap-provider", "unknown").is_none());
    assert!(strategy.score("unknown-provider", "cheap-model").is_none());
}
1164
1165 #[tokio::test]
1166 async fn supports_streaming_returns_true_when_any_provider_supports_it() {
1167 let streaming = Arc::new(StreamingMockModelProvider::new("stream"));
1168 let router = RouterModelProvider::new(
1169 "test",
1170 vec![
1171 (
1172 "default".into(),
1173 Box::new(MockModelProvider::new("default")) as Box<dyn ModelProvider>,
1174 ),
1175 (
1176 "streaming".into(),
1177 Box::new(Arc::clone(&streaming)) as Box<dyn ModelProvider>,
1178 ),
1179 ],
1180 vec![(
1181 "reasoning".into(),
1182 Route {
1183 provider_name: "streaming".into(),
1184 model: "claude-opus".into(),
1185 },
1186 )],
1187 "model".into(),
1188 );
1189
1190 assert!(router.supports_streaming());
1191 }
1192
1193 #[tokio::test]
1194 async fn stream_chat_with_system_routes_hint_to_correct_provider_and_model() {
1195 let streaming = Arc::new(StreamingMockModelProvider::new("streamed system response"));
1196 let router = RouterModelProvider::new(
1197 "test",
1198 vec![
1199 (
1200 "default".into(),
1201 Box::new(MockModelProvider::new("default")) as Box<dyn ModelProvider>,
1202 ),
1203 (
1204 "streaming".into(),
1205 Box::new(Arc::clone(&streaming)) as Box<dyn ModelProvider>,
1206 ),
1207 ],
1208 vec![(
1209 "reasoning".into(),
1210 Route {
1211 provider_name: "streaming".into(),
1212 model: "claude-opus".into(),
1213 },
1214 )],
1215 "model".into(),
1216 );
1217
1218 let mut stream = router.stream_chat_with_system(
1219 Some("system"),
1220 "hello",
1221 "hint:reasoning",
1222 Some(0.0),
1223 StreamOptions::new(true),
1224 );
1225
1226 let mut collected = String::new();
1227 while let Some(chunk) = stream.next().await {
1228 let chunk = chunk.expect("stream chunk should be ok");
1229 collected.push_str(&chunk.delta);
1230 }
1231
1232 assert_eq!(collected, "streamed system response");
1233 assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
1234 assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
1235 }
1236
1237 #[tokio::test]
1238 async fn stream_chat_with_history_routes_hint_to_correct_provider_and_model() {
1239 let streaming = Arc::new(StreamingMockModelProvider::new("streamed response"));
1240 let router = RouterModelProvider::new(
1241 "test",
1242 vec![
1243 (
1244 "default".into(),
1245 Box::new(MockModelProvider::new("default")) as Box<dyn ModelProvider>,
1246 ),
1247 (
1248 "streaming".into(),
1249 Box::new(Arc::clone(&streaming)) as Box<dyn ModelProvider>,
1250 ),
1251 ],
1252 vec![(
1253 "reasoning".into(),
1254 Route {
1255 provider_name: "streaming".into(),
1256 model: "claude-opus".into(),
1257 },
1258 )],
1259 "model".into(),
1260 );
1261
1262 let messages = vec![ChatMessage::user("hello")];
1263 let mut stream = router.stream_chat_with_history(
1264 &messages,
1265 "hint:reasoning",
1266 Some(0.0),
1267 StreamOptions::new(true),
1268 );
1269
1270 let mut collected = String::new();
1271 while let Some(chunk) = stream.next().await {
1272 let chunk = chunk.expect("stream chunk should be ok");
1273 collected.push_str(&chunk.delta);
1274 }
1275
1276 assert_eq!(collected, "streamed response");
1277 assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
1278 assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
1279 }
1280
1281 #[tokio::test]
1282 async fn stream_chat_routes_hint_with_structured_tool_events() {
1283 let streaming = Arc::new(ToolEventStreamingMockModelProvider::new());
1284 let router = RouterModelProvider::new(
1285 "test",
1286 vec![
1287 (
1288 "default".into(),
1289 Box::new(MockModelProvider::new("default")) as Box<dyn ModelProvider>,
1290 ),
1291 (
1292 "streaming".into(),
1293 Box::new(Arc::clone(&streaming)) as Box<dyn ModelProvider>,
1294 ),
1295 ],
1296 vec![(
1297 "reasoning".into(),
1298 Route {
1299 provider_name: "streaming".into(),
1300 model: "claude-opus".into(),
1301 },
1302 )],
1303 "model".into(),
1304 );
1305
1306 let messages = vec![ChatMessage::user("hello")];
1307 let tools = vec![ToolSpec {
1308 name: "shell".to_string(),
1309 description: "run shell commands".to_string(),
1310 parameters: serde_json::json!({
1311 "type": "object",
1312 "properties": {
1313 "command": { "type": "string" }
1314 }
1315 }),
1316 }];
1317
1318 let mut stream = router.stream_chat(
1319 ChatRequest {
1320 messages: &messages,
1321 tools: Some(&tools),
1322 thinking: None,
1323 },
1324 "hint:reasoning",
1325 Some(0.0),
1326 StreamOptions::new(true),
1327 );
1328
1329 let first = stream.next().await.unwrap().unwrap();
1330 let second = stream.next().await.unwrap().unwrap();
1331 assert!(stream.next().await.is_none());
1332
1333 match first {
1334 StreamEvent::ToolCall(call) => {
1335 assert_eq!(call.name, "shell");
1336 assert_eq!(call.arguments, r#"{"command":"date"}"#);
1337 }
1338 other => panic!("expected tool-call event, got {other:?}"),
1339 }
1340 assert!(matches!(second, StreamEvent::Final));
1341 assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
1342 assert_eq!(streaming.tool_event_calls.load(Ordering::SeqCst), 1);
1343 assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
1344 }
1345
/// `supports_vision()` must follow the default provider only: a router whose
/// default provider lacks vision reports `false` even though a vision-capable
/// provider is reachable through a route.
#[test]
fn supports_vision_reflects_default_provider_not_any_route() {
    let default_provider = Box::new(MockModelProvider::new("nope").with_vision(false));
    let vision_route_provider = Box::new(MockModelProvider::new("ok").with_vision(true));

    let router = RouterModelProvider::new(
        "test",
        vec![
            ("default".into(), default_provider as Box<dyn ModelProvider>),
            (
                "vision".into(),
                vision_route_provider as Box<dyn ModelProvider>,
            ),
        ],
        vec![(
            // Route keys are bare hint names; the `hint:` prefix belongs to the
            // model string passed by callers (the streaming tests in this module
            // key their route as "reasoning" and request "hint:reasoning").
            // Keying this route as "hint:vision" would presumably leave it
            // unreachable, so the "vision-capable route exists" premise of this
            // test would not actually be set up.
            "vision".into(),
            Route {
                provider_name: "vision".into(),
                model: "vision-model".into(),
            },
        )],
        "default-model".into(),
    );

    assert!(
        !router.supports_vision(),
        "router with non-vision default must report supports_vision()=false even when a vision-capable route exists"
    );
}
1379
/// With no routes configured, a router whose `default` provider is built
/// `with_vision(true)` must report `supports_vision()` as true, even though the
/// additional `aux` provider is built `with_vision(false)`.
#[test]
fn supports_vision_true_when_default_provider_supports_vision() {
    let vision_capable: Box<dyn ModelProvider> =
        Box::new(MockModelProvider::new("ok").with_vision(true));
    let vision_less: Box<dyn ModelProvider> =
        Box::new(MockModelProvider::new("nope").with_vision(false));

    let providers = vec![
        (String::from("default"), vision_capable),
        (String::from("aux"), vision_less),
    ];

    let router = RouterModelProvider::new(
        "test",
        providers,
        Vec::new(),
        String::from("default-model"),
    );

    assert!(router.supports_vision());
}
1397}