// zeroclaw_providers/router.rs

1use super::ModelProvider;
2use super::traits::{
3    ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
4};
5use async_trait::async_trait;
6use futures_util::stream::BoxStream;
7use std::collections::HashMap;
8
/// Price `model` using a user-keyed pricing map.
///
/// Three keys are consulted, in this order: the bare `model`, then
/// `"{model}.input"`, then `"{model}.output"`. Every key that is present
/// contributes its value to the sum.
///
/// Returns `Some(sum)` if at least one key is present, even when the sum is
/// `0.0`. Returns `None` if none of the keys exist.
fn score_model(pricing: &HashMap<String, f64>, model: &str) -> Option<f64> {
    let lookups = [
        pricing.get(model),
        pricing.get(&format!("{model}.input")),
        pricing.get(&format!("{model}.output")),
    ];
    lookups
        .into_iter()
        .flatten()
        .fold(None, |sum: Option<f64>, price| Some(sum.unwrap_or(0.0) + *price))
}
29
/// One routing-table entry: the model_provider (by registered name) and the
/// model that a task hint resolves to.
#[derive(Debug, Clone)]
pub struct Route {
    /// Registered name of the target model_provider. Routes whose name matches
    /// no model_provider passed to `RouterModelProvider::new` are dropped there.
    pub provider_name: String,
    /// Model identifier forwarded to that model_provider.
    pub model: String,
}
36
37/// Multi-model router — routes requests to different model_provider+model combos
38/// based on a task hint encoded in the model parameter.
39///
40/// The model parameter can be:
41/// - A regular model name (e.g. "anthropic/claude-sonnet-4") → uses default model_provider
42/// - A hint-prefixed string (e.g. "hint:reasoning") → resolves via route table
43///
44/// This wraps multiple pre-created model_providers and selects the right one per request.
45pub struct RouterModelProvider {
46    /// `[model_providers.<family>.<alias>]` config-key alias.
47    alias: String,
48    routes: HashMap<String, (usize, String)>, // hint → (provider_index, model)
49    model_providers: Vec<(String, Box<dyn ModelProvider>)>,
50    default_index: usize,
51    default_model: String,
52}
53
54impl RouterModelProvider {
55    /// Create a new router with a default model_provider and optional routes.
56    ///
57    /// `model_providers` is a list of (name, model_provider) pairs. The first one is the default.
58    /// `routes` maps hint names to Route structs containing provider_name and model.
59    pub fn new(
60        alias: &str,
61        model_providers: Vec<(String, Box<dyn ModelProvider>)>,
62        routes: Vec<(String, Route)>,
63        default_model: String,
64    ) -> Self {
65        // Build model_provider name → index lookup
66        let name_to_index: HashMap<&str, usize> = model_providers
67            .iter()
68            .enumerate()
69            .map(|(i, (name, _))| (name.as_str(), i))
70            .collect();
71
72        // Resolve routes to model_provider indices
73        let resolved_routes: HashMap<String, (usize, String)> = routes
74            .into_iter()
75            .filter_map(|(hint, route)| {
76                let index = name_to_index.get(route.provider_name.as_str()).copied();
77                match index {
78                    Some(i) => Some((hint, (i, route.model))),
79                    None => {
80                        ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"hint": hint, "model_provider": route.provider_name})), "Route references unknown model_provider, skipping");
81                        None
82                    }
83                }
84            })
85            .collect();
86
87        Self {
88            alias: alias.to_string(),
89            routes: resolved_routes,
90            model_providers,
91            default_index: 0,
92            default_model,
93        }
94    }
95    /// Resolve a model parameter to the cheapest qualifying route based on pricing.
96    ///
97    /// If the model starts with `"hint:cost-optimized"` or `"hint:cheapest"`, this
98    /// method scores each route by `input_price + output_price` (a simple proxy for
99    /// total cost), optionally filtering by capability requirements, and returns the
100    /// cheapest qualifying route.
101    ///
102    /// Falls back to the default route when no pricing data matches.
103    pub fn resolve_cost_optimized(
104        &self,
105        model: &str,
106        model_provider_pricing: &HashMap<String, HashMap<String, f64>>,
107        required_vision: bool,
108        required_tools: bool,
109    ) -> (usize, String) {
110        let hint = model.strip_prefix("hint:");
111        let is_cost_hint = matches!(hint, Some("cost-optimized" | "cheapest"));
112
113        if !is_cost_hint {
114            return self.resolve(model);
115        }
116
117        let mut candidates: Vec<(usize, String, f64)> = Vec::new();
118
119        for (idx, route_model) in self.routes.values() {
120            // Capability filtering
121            if let Some((_, model_provider)) = self.model_providers.get(*idx) {
122                if required_vision && !model_provider.supports_vision() {
123                    continue;
124                }
125                if required_tools && !model_provider.supports_native_tools() {
126                    continue;
127                }
128            }
129
130            let Some((model_provider_name, _)) = self.model_providers.get(*idx) else {
131                continue;
132            };
133            if let Some(pricing) = model_provider_pricing.get(model_provider_name)
134                && let Some(total_cost) = score_model(pricing, route_model)
135            {
136                candidates.push((*idx, route_model.clone(), total_cost));
137            }
138        }
139
140        // Sort by total cost (ascending) and pick the cheapest
141        candidates.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
142
143        if let Some((idx, route_model, _)) = candidates.into_iter().next() {
144            return (idx, route_model);
145        }
146
147        // Fallback to default
148        ::zeroclaw_log::record!(
149            WARN,
150            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
151                .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
152            "No cost-optimized route found with matching pricing data, \
153             falling back to default"
154        );
155        (self.default_index, self.default_model.clone())
156    }
157
158    /// Resolve a model parameter to a (model_provider, actual_model) pair.
159    ///
160    /// If the model starts with "hint:", look up the hint in the route table.
161    /// Otherwise, use the default model_provider with the given model name.
162    /// Resolve a model parameter to a (provider_index, actual_model) pair.
163    fn resolve(&self, model: &str) -> (usize, String) {
164        if let Some(hint) = model.strip_prefix("hint:") {
165            if let Some((idx, resolved_model)) = self.routes.get(hint) {
166                return (*idx, resolved_model.clone());
167            }
168            ::zeroclaw_log::record!(
169                WARN,
170                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
171                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
172                    .with_attrs(::serde_json::json!({"hint": hint})),
173                "Unknown route hint, falling back to default model_provider"
174            );
175        }
176
177        // Not a hint or hint not found — use default model_provider with the model as-is
178        (self.default_index, model.to_string())
179    }
180}
181
/// Inputs for cost-optimized route selection: per-model_provider pricing plus
/// the capabilities a request needs.
///
/// The outer map is keyed by model_provider name, i.e. the name the router
/// registered the model_provider under. Each inner map holds user-defined keys
/// (a model identifier, optionally suffixed with `.input` or `.output`) mapped
/// to rates. The rates are documented as USD per 1M tokens, but the scoring
/// code only sums them and is unit-agnostic.
#[derive(Debug, Clone)]
pub struct CostOptimizedStrategy {
    /// model_provider name → user-keyed pricing map.
    pub model_provider_pricing: HashMap<String, HashMap<String, f64>>,
    /// The request needs vision (image input) support.
    pub required_vision: bool,
    /// The request needs native tool-calling support.
    pub required_tools: bool,
}
198
199impl CostOptimizedStrategy {
200    /// Create a new cost-optimized strategy with the given per-provider
201    /// pricing data.
202    pub fn new(model_provider_pricing: HashMap<String, HashMap<String, f64>>) -> Self {
203        Self {
204            model_provider_pricing,
205            required_vision: false,
206            required_tools: false,
207        }
208    }
209
210    /// Set whether vision support is required.
211    pub fn with_vision(mut self, required: bool) -> Self {
212        self.required_vision = required;
213        self
214    }
215
216    /// Set whether native tool support is required.
217    pub fn with_tools(mut self, required: bool) -> Self {
218        self.required_tools = required;
219        self
220    }
221
222    /// Score a route by summing pricing entries that match the model.
223    /// Returns `None` if no pricing data is available for the route.
224    pub fn score(&self, model_provider_name: &str, model: &str) -> Option<f64> {
225        let pricing = self.model_provider_pricing.get(model_provider_name)?;
226        score_model(pricing, model)
227    }
228}
229
230#[async_trait]
231impl ModelProvider for RouterModelProvider {
232    async fn chat_with_system(
233        &self,
234        system_prompt: Option<&str>,
235        message: &str,
236        model: &str,
237        temperature: Option<f64>,
238    ) -> anyhow::Result<String> {
239        let (provider_idx, resolved_model) = self.resolve(model);
240
241        let (provider_name, model_provider) = &self.model_providers[provider_idx];
242        // `provider_name` is the configured `<type>.<alias>` key the
243        // caller registered with `RouterModelProvider::new` — already a
244        // composite. Layer's `set_composite` splits it on emit.
245        ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name.as_str(), "model": resolved_model.as_str()})), "router dispatching request");
246
247        model_provider
248            .chat_with_system(system_prompt, message, &resolved_model, temperature)
249            .await
250    }
251
252    async fn chat_with_history(
253        &self,
254        messages: &[ChatMessage],
255        model: &str,
256        temperature: Option<f64>,
257    ) -> anyhow::Result<String> {
258        let (provider_idx, resolved_model) = self.resolve(model);
259        let (_, model_provider) = &self.model_providers[provider_idx];
260        model_provider
261            .chat_with_history(messages, &resolved_model, temperature)
262            .await
263    }
264
265    async fn chat(
266        &self,
267        request: ChatRequest<'_>,
268        model: &str,
269        temperature: Option<f64>,
270    ) -> anyhow::Result<ChatResponse> {
271        let (provider_idx, resolved_model) = self.resolve(model);
272        let (_, model_provider) = &self.model_providers[provider_idx];
273        model_provider
274            .chat(request, &resolved_model, temperature)
275            .await
276    }
277
278    async fn chat_with_tools(
279        &self,
280        messages: &[ChatMessage],
281        tools: &[serde_json::Value],
282        model: &str,
283        temperature: Option<f64>,
284    ) -> anyhow::Result<ChatResponse> {
285        let (provider_idx, resolved_model) = self.resolve(model);
286        let (_, model_provider) = &self.model_providers[provider_idx];
287        model_provider
288            .chat_with_tools(messages, tools, &resolved_model, temperature)
289            .await
290    }
291
292    fn supports_native_tools(&self) -> bool {
293        self.model_providers
294            .get(self.default_index)
295            .map(|(_, p)| p.supports_native_tools())
296            .unwrap_or(false)
297    }
298
299    fn supports_streaming(&self) -> bool {
300        self.model_providers
301            .iter()
302            .any(|(_, model_provider)| model_provider.supports_streaming())
303    }
304
305    fn supports_streaming_tool_events(&self) -> bool {
306        self.model_providers
307            .iter()
308            .any(|(_, model_provider)| model_provider.supports_streaming_tool_events())
309    }
310
311    fn stream_chat_with_system(
312        &self,
313        system_prompt: Option<&str>,
314        message: &str,
315        model: &str,
316        temperature: Option<f64>,
317        options: StreamOptions,
318    ) -> BoxStream<'static, StreamResult<StreamChunk>> {
319        let (provider_idx, resolved_model) = self.resolve(model);
320        let (_, model_provider) = &self.model_providers[provider_idx];
321        model_provider.stream_chat_with_system(
322            system_prompt,
323            message,
324            &resolved_model,
325            temperature,
326            options,
327        )
328    }
329
330    fn stream_chat_with_history(
331        &self,
332        messages: &[ChatMessage],
333        model: &str,
334        temperature: Option<f64>,
335        options: StreamOptions,
336    ) -> BoxStream<'static, StreamResult<StreamChunk>> {
337        let (provider_idx, resolved_model) = self.resolve(model);
338        let (_, model_provider) = &self.model_providers[provider_idx];
339        model_provider.stream_chat_with_history(messages, &resolved_model, temperature, options)
340    }
341
342    fn stream_chat(
343        &self,
344        request: ChatRequest<'_>,
345        model: &str,
346        temperature: Option<f64>,
347        options: StreamOptions,
348    ) -> BoxStream<'static, StreamResult<StreamEvent>> {
349        let (provider_idx, resolved_model) = self.resolve(model);
350        let (_, model_provider) = &self.model_providers[provider_idx];
351        model_provider.stream_chat(request, &resolved_model, temperature, options)
352    }
353
354    fn supports_vision(&self) -> bool {
355        self.model_providers
356            .get(self.default_index)
357            .map(|(_, p)| p.supports_vision())
358            .unwrap_or(false)
359    }
360
361    async fn warmup(&self) -> anyhow::Result<()> {
362        for (name, model_provider) in &self.model_providers {
363            ::zeroclaw_log::record!(
364                INFO,
365                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
366                    .with_attrs(::serde_json::json!({"model_provider": name})),
367                "Warming up routed model_provider"
368            );
369            if let Err(e) = model_provider.warmup().await {
370                ::zeroclaw_log::record!(
371                    WARN,
372                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
373                        .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
374                        .with_attrs(
375                            ::serde_json::json!({"error": format!("{}", e), "model_provider": name})
376                        ),
377                    "Warmup failed (non-fatal)"
378                );
379            }
380        }
381        Ok(())
382    }
383}
384
385impl ::zeroclaw_api::attribution::Attributable for RouterModelProvider {
386    fn role(&self) -> ::zeroclaw_api::attribution::Role {
387        ::zeroclaw_api::attribution::Role::Provider(
388            ::zeroclaw_api::attribution::ProviderKind::Model(
389                ::zeroclaw_api::attribution::ModelProviderKind::Router,
390            ),
391        )
392    }
393    fn alias(&self) -> &str {
394        &self.alias
395    }
396}
397
398#[cfg(test)]
399mod tests {
400    use super::*;
401    use futures_util::StreamExt;
402    use std::sync::Arc;
403    use std::sync::atomic::{AtomicUsize, Ordering};
404    use zeroclaw_api::tool::ToolSpec;
405
406    struct MockModelProvider {
407        calls: Arc<AtomicUsize>,
408        response: &'static str,
409        last_model: parking_lot::Mutex<String>,
410        vision: bool,
411    }
412
413    impl MockModelProvider {
414        fn new(response: &'static str) -> Self {
415            Self {
416                calls: Arc::new(AtomicUsize::new(0)),
417                response,
418                last_model: parking_lot::Mutex::new(String::new()),
419                vision: false,
420            }
421        }
422
423        fn with_vision(mut self, vision: bool) -> Self {
424            self.vision = vision;
425            self
426        }
427
428        fn call_count(&self) -> usize {
429            self.calls.load(Ordering::SeqCst)
430        }
431
432        fn last_model(&self) -> String {
433            self.last_model.lock().clone()
434        }
435    }
436
437    #[async_trait]
438    impl ModelProvider for MockModelProvider {
439        async fn chat_with_system(
440            &self,
441            _system_prompt: Option<&str>,
442            _message: &str,
443            model: &str,
444            _temperature: Option<f64>,
445        ) -> anyhow::Result<String> {
446            self.calls.fetch_add(1, Ordering::SeqCst);
447            *self.last_model.lock() = model.to_string();
448            Ok(self.response.to_string())
449        }
450
451        fn supports_vision(&self) -> bool {
452            self.vision
453        }
454    }
455    impl ::zeroclaw_api::attribution::Attributable for MockModelProvider {
456        fn role(&self) -> ::zeroclaw_api::attribution::Role {
457            ::zeroclaw_api::attribution::Role::Provider(
458                ::zeroclaw_api::attribution::ProviderKind::Model(
459                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
460                ),
461            )
462        }
463        fn alias(&self) -> &str {
464            "MockModelProvider"
465        }
466    }
467
468    fn make_router(
469        model_providers: Vec<(&'static str, &'static str)>,
470        routes: Vec<(&str, &str, &str)>,
471    ) -> (RouterModelProvider, Vec<Arc<MockModelProvider>>) {
472        let mocks: Vec<Arc<MockModelProvider>> = model_providers
473            .iter()
474            .map(|(_, response)| Arc::new(MockModelProvider::new(response)))
475            .collect();
476
477        let provider_list: Vec<(String, Box<dyn ModelProvider>)> = model_providers
478            .iter()
479            .zip(mocks.iter())
480            .map(|((name, _), mock)| {
481                (
482                    (*name).to_string(),
483                    Box::new(Arc::clone(mock)) as Box<dyn ModelProvider>,
484                )
485            })
486            .collect();
487
488        let route_list: Vec<(String, Route)> = routes
489            .iter()
490            .map(|(hint, provider_name, model)| {
491                (
492                    (*hint).to_string(),
493                    Route {
494                        provider_name: (*provider_name).to_string(),
495                        model: (*model).to_string(),
496                    },
497                )
498            })
499            .collect();
500
501        let router = RouterModelProvider::new(
502            "test",
503            provider_list,
504            route_list,
505            "default-model".to_string(),
506        );
507
508        (router, mocks)
509    }
510
511    // Arc<MockModelProvider> ModelProvider impl provided by blanket impl in zeroclaw-types.
512
513    struct StreamingMockModelProvider {
514        stream_calls: Arc<AtomicUsize>,
515        last_stream_model: parking_lot::Mutex<String>,
516        response: &'static str,
517    }
518
519    impl StreamingMockModelProvider {
520        fn new(response: &'static str) -> Self {
521            Self {
522                stream_calls: Arc::new(AtomicUsize::new(0)),
523                last_stream_model: parking_lot::Mutex::new(String::new()),
524                response,
525            }
526        }
527
528        fn stream_response(&self, model: &str) -> BoxStream<'static, StreamResult<StreamChunk>> {
529            self.stream_calls.fetch_add(1, Ordering::SeqCst);
530            *self.last_stream_model.lock() = model.to_string();
531            let chunks = vec![
532                Ok(StreamChunk::delta(self.response)),
533                Ok(StreamChunk::final_chunk()),
534            ];
535            futures_util::stream::iter(chunks).boxed()
536        }
537    }
538
539    #[async_trait]
540    impl ModelProvider for StreamingMockModelProvider {
541        async fn chat_with_system(
542            &self,
543            _system_prompt: Option<&str>,
544            _message: &str,
545            _model: &str,
546            _temperature: Option<f64>,
547        ) -> anyhow::Result<String> {
548            Ok("ok".to_string())
549        }
550
551        fn supports_streaming(&self) -> bool {
552            true
553        }
554
555        fn stream_chat_with_system(
556            &self,
557            _system_prompt: Option<&str>,
558            _message: &str,
559            model: &str,
560            _temperature: Option<f64>,
561            _options: StreamOptions,
562        ) -> BoxStream<'static, StreamResult<StreamChunk>> {
563            self.stream_response(model)
564        }
565
566        fn stream_chat_with_history(
567            &self,
568            _messages: &[ChatMessage],
569            model: &str,
570            _temperature: Option<f64>,
571            _options: StreamOptions,
572        ) -> BoxStream<'static, StreamResult<StreamChunk>> {
573            self.stream_response(model)
574        }
575    }
576    impl ::zeroclaw_api::attribution::Attributable for StreamingMockModelProvider {
577        fn role(&self) -> ::zeroclaw_api::attribution::Role {
578            ::zeroclaw_api::attribution::Role::Provider(
579                ::zeroclaw_api::attribution::ProviderKind::Model(
580                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
581                ),
582            )
583        }
584        fn alias(&self) -> &str {
585            "StreamingMockModelProvider"
586        }
587    }
588
589    // Arc<StreamingMockModelProvider> ModelProvider impl provided by blanket impl in zeroclaw-types.
590
591    struct ToolEventStreamingMockModelProvider {
592        stream_calls: Arc<AtomicUsize>,
593        tool_event_calls: Arc<AtomicUsize>,
594        last_stream_model: parking_lot::Mutex<String>,
595    }
596
597    impl ToolEventStreamingMockModelProvider {
598        fn new() -> Self {
599            Self {
600                stream_calls: Arc::new(AtomicUsize::new(0)),
601                tool_event_calls: Arc::new(AtomicUsize::new(0)),
602                last_stream_model: parking_lot::Mutex::new(String::new()),
603            }
604        }
605    }
606
607    #[async_trait]
608    impl ModelProvider for ToolEventStreamingMockModelProvider {
609        async fn chat_with_system(
610            &self,
611            _system_prompt: Option<&str>,
612            _message: &str,
613            _model: &str,
614            _temperature: Option<f64>,
615        ) -> anyhow::Result<String> {
616            Ok("ok".to_string())
617        }
618
619        fn supports_streaming(&self) -> bool {
620            true
621        }
622
623        fn supports_streaming_tool_events(&self) -> bool {
624            true
625        }
626
627        fn stream_chat(
628            &self,
629            request: ChatRequest<'_>,
630            model: &str,
631            _temperature: Option<f64>,
632            _options: StreamOptions,
633        ) -> BoxStream<'static, StreamResult<StreamEvent>> {
634            self.stream_calls.fetch_add(1, Ordering::SeqCst);
635            if request.tools.is_some_and(|tools| !tools.is_empty()) {
636                self.tool_event_calls.fetch_add(1, Ordering::SeqCst);
637            }
638            *self.last_stream_model.lock() = model.to_string();
639            futures_util::stream::iter(vec![
640                Ok(StreamEvent::ToolCall(crate::traits::ToolCall {
641                    id: "call_router_1".to_string(),
642                    name: "shell".to_string(),
643                    arguments: r#"{"command":"date"}"#.to_string(),
644                    extra_content: None,
645                })),
646                Ok(StreamEvent::Final),
647            ])
648            .boxed()
649        }
650    }
651    impl ::zeroclaw_api::attribution::Attributable for ToolEventStreamingMockModelProvider {
652        fn role(&self) -> ::zeroclaw_api::attribution::Role {
653            ::zeroclaw_api::attribution::Role::Provider(
654                ::zeroclaw_api::attribution::ProviderKind::Model(
655                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
656                ),
657            )
658        }
659        fn alias(&self) -> &str {
660            "ToolEventStreamingMockModelProvider"
661        }
662    }
663
664    // Arc<ToolEventStreamingMockModelProvider> ModelProvider impl provided by blanket impl in zeroclaw-types.
665
666    #[tokio::test]
667    async fn routes_hint_to_correct_provider() {
668        let (router, mocks) = make_router(
669            vec![("fast", "fast-response"), ("smart", "smart-response")],
670            vec![
671                ("fast", "fast", "llama-3-70b"),
672                ("reasoning", "smart", "claude-opus"),
673            ],
674        );
675
676        let result = router
677            .simple_chat("hello", "hint:reasoning", Some(0.5))
678            .await
679            .unwrap();
680        assert_eq!(result, "smart-response");
681        assert_eq!(mocks[1].call_count(), 1);
682        assert_eq!(mocks[1].last_model(), "claude-opus");
683        assert_eq!(mocks[0].call_count(), 0);
684    }
685
686    #[tokio::test]
687    async fn routes_fast_hint() {
688        let (router, mocks) = make_router(
689            vec![("fast", "fast-response"), ("smart", "smart-response")],
690            vec![("fast", "fast", "llama-3-70b")],
691        );
692
693        let result = router
694            .simple_chat("hello", "hint:fast", Some(0.5))
695            .await
696            .unwrap();
697        assert_eq!(result, "fast-response");
698        assert_eq!(mocks[0].call_count(), 1);
699        assert_eq!(mocks[0].last_model(), "llama-3-70b");
700    }
701
702    #[tokio::test]
703    async fn unknown_hint_falls_back_to_default() {
704        let (router, mocks) = make_router(
705            vec![("default", "default-response"), ("other", "other-response")],
706            vec![],
707        );
708
709        let result = router
710            .simple_chat("hello", "hint:nonexistent", Some(0.5))
711            .await
712            .unwrap();
713        assert_eq!(result, "default-response");
714        assert_eq!(mocks[0].call_count(), 1);
715        // Falls back to default with the hint as model name
716        assert_eq!(mocks[0].last_model(), "hint:nonexistent");
717    }
718
719    #[tokio::test]
720    async fn non_hint_model_uses_default_provider() {
721        let (router, mocks) = make_router(
722            vec![
723                ("primary", "primary-response"),
724                ("secondary", "secondary-response"),
725            ],
726            vec![("code", "secondary", "codellama")],
727        );
728
729        let result = router
730            .simple_chat("hello", "anthropic/claude-sonnet-4-20250514", Some(0.5))
731            .await
732            .unwrap();
733        assert_eq!(result, "primary-response");
734        assert_eq!(mocks[0].call_count(), 1);
735        assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
736    }
737
738    #[test]
739    fn resolve_preserves_model_for_non_hints() {
740        let (router, _) = make_router(vec![("default", "ok")], vec![]);
741
742        let (idx, model) = router.resolve("gpt-4o");
743        assert_eq!(idx, 0);
744        assert_eq!(model, "gpt-4o");
745    }
746
747    #[test]
748    fn resolve_strips_hint_prefix() {
749        let (router, _) = make_router(
750            vec![("fast", "ok"), ("smart", "ok")],
751            vec![("reasoning", "smart", "claude-opus")],
752        );
753
754        let (idx, model) = router.resolve("hint:reasoning");
755        assert_eq!(idx, 1);
756        assert_eq!(model, "claude-opus");
757    }
758
759    #[test]
760    fn skips_routes_with_unknown_provider() {
761        let (router, _) = make_router(
762            vec![("default", "ok")],
763            vec![("broken", "nonexistent", "model")],
764        );
765
766        // Route should not exist
767        assert!(!router.routes.contains_key("broken"));
768    }
769
770    #[tokio::test]
771    async fn warmup_calls_all_providers() {
772        let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);
773
774        // Warmup should not error
775        assert!(router.warmup().await.is_ok());
776    }
777
778    #[tokio::test]
779    async fn chat_with_system_passes_system_prompt() {
780        let mock = Arc::new(MockModelProvider::new("response"));
781        let router = RouterModelProvider::new(
782            "test",
783            vec![(
784                "default".into(),
785                Box::new(Arc::clone(&mock)) as Box<dyn ModelProvider>,
786            )],
787            vec![],
788            "model".into(),
789        );
790
791        let result = router
792            .chat_with_system(Some("system"), "hello", "model", Some(0.5))
793            .await
794            .unwrap();
795        assert_eq!(result, "response");
796        assert_eq!(mock.call_count(), 1);
797    }
798
799    #[tokio::test]
800    async fn chat_with_tools_delegates_to_resolved_provider() {
801        let mock = Arc::new(MockModelProvider::new("tool-response"));
802        let router = RouterModelProvider::new(
803            "test",
804            vec![(
805                "default".into(),
806                Box::new(Arc::clone(&mock)) as Box<dyn ModelProvider>,
807            )],
808            vec![],
809            "model".into(),
810        );
811
812        let messages = vec![ChatMessage {
813            role: "user".to_string(),
814            content: "use tools".to_string(),
815        }];
816        let tools = vec![serde_json::json!({
817            "type": "function",
818            "function": {
819                "name": "shell",
820                "description": "Run shell command",
821                "parameters": {}
822            }
823        })];
824
825        // chat_with_tools should delegate through the router to the mock.
826        // MockModelProvider's default chat_with_tools calls chat_with_history -> chat_with_system.
827        let result = router
828            .chat_with_tools(&messages, &tools, "model", Some(0.7))
829            .await
830            .unwrap();
831        assert_eq!(result.text.as_deref(), Some("tool-response"));
832        assert_eq!(mock.call_count(), 1);
833        assert_eq!(mock.last_model(), "model");
834    }
835
836    #[tokio::test]
837    async fn chat_with_tools_routes_hint_correctly() {
838        let (router, mocks) = make_router(
839            vec![("fast", "fast-tool"), ("smart", "smart-tool")],
840            vec![("reasoning", "smart", "claude-opus")],
841        );
842
843        let messages = vec![ChatMessage {
844            role: "user".to_string(),
845            content: "reason about this".to_string(),
846        }];
847        let tools = vec![serde_json::json!({"type": "function", "function": {"name": "test"}})];
848
849        let result = router
850            .chat_with_tools(&messages, &tools, "hint:reasoning", Some(0.5))
851            .await
852            .unwrap();
853        assert_eq!(result.text.as_deref(), Some("smart-tool"));
854        assert_eq!(mocks[1].call_count(), 1);
855        assert_eq!(mocks[1].last_model(), "claude-opus");
856        assert_eq!(mocks[0].call_count(), 0);
857    }
858
859    // ── Cost-optimized routing tests ────────────────────────────────
860
861    use crate::traits::ProviderCapabilities;
862
863    /// Mock model_provider with configurable capability flags.
864    struct CapableMockModelProvider {
865        response: &'static str,
866        vision: bool,
867        tools: bool,
868    }
869
870    impl CapableMockModelProvider {
871        fn new(response: &'static str, vision: bool, tools: bool) -> Self {
872            Self {
873                response,
874                vision,
875                tools,
876            }
877        }
878    }
879
880    #[async_trait]
881    impl ModelProvider for CapableMockModelProvider {
882        fn capabilities(&self) -> ProviderCapabilities {
883            ProviderCapabilities {
884                native_tool_calling: self.tools,
885                vision: self.vision,
886                prompt_caching: false,
887                extended_thinking: false,
888            }
889        }
890
891        async fn chat_with_system(
892            &self,
893            _system_prompt: Option<&str>,
894            _message: &str,
895            _model: &str,
896            _temperature: Option<f64>,
897        ) -> anyhow::Result<String> {
898            Ok(self.response.to_string())
899        }
900    }
901    impl ::zeroclaw_api::attribution::Attributable for CapableMockModelProvider {
902        fn role(&self) -> ::zeroclaw_api::attribution::Role {
903            ::zeroclaw_api::attribution::Role::Provider(
904                ::zeroclaw_api::attribution::ProviderKind::Model(
905                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
906                ),
907            )
908        }
909        fn alias(&self) -> &str {
910            "CapableMockModelProvider"
911        }
912    }
913
914    /// Build a per-provider pricing map for tests. Each tuple is
915    /// `(provider_name, model, input_per_mtok, output_per_mtok)`.
916    fn make_pricing(entries: Vec<(&str, &str, f64, f64)>) -> HashMap<String, HashMap<String, f64>> {
917        let mut map: HashMap<String, HashMap<String, f64>> = HashMap::new();
918        for (model_provider, model, input, output) in entries {
919            let inner = map.entry(model_provider.to_string()).or_default();
920            inner.insert(format!("{model}.input"), input);
921            inner.insert(format!("{model}.output"), output);
922        }
923        map
924    }
925
926    #[test]
927    fn cost_optimized_selects_cheapest_provider() {
928        let model_providers: Vec<(String, Box<dyn ModelProvider>)> = vec![
929            (
930                "expensive".into(),
931                Box::new(CapableMockModelProvider::new("exp", false, false)),
932            ),
933            (
934                "cheap".into(),
935                Box::new(CapableMockModelProvider::new("chp", false, false)),
936            ),
937        ];
938        let routes = vec![
939            (
940                "expensive".to_string(),
941                Route {
942                    provider_name: "expensive".into(),
943                    model: "big-model".into(),
944                },
945            ),
946            (
947                "cheap".to_string(),
948                Route {
949                    provider_name: "cheap".into(),
950                    model: "small-model".into(),
951                },
952            ),
953        ];
954        let router =
955            RouterModelProvider::new("test", model_providers, routes, "default-model".into());
956
957        let prices = make_pricing(vec![
958            ("expensive", "big-model", 15.0, 75.0),
959            ("cheap", "small-model", 0.25, 1.25),
960        ]);
961
962        let (idx, model) =
963            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
964        assert_eq!(model, "small-model");
965        assert_eq!(idx, 1);
966    }
967
968    #[test]
969    fn cost_optimized_respects_vision_requirement() {
970        let model_providers: Vec<(String, Box<dyn ModelProvider>)> = vec![
971            (
972                "no-vision".into(),
973                Box::new(CapableMockModelProvider::new("nv", false, false)),
974            ),
975            (
976                "has-vision".into(),
977                Box::new(CapableMockModelProvider::new("hv", true, false)),
978            ),
979        ];
980        let routes = vec![
981            (
982                "cheap".to_string(),
983                Route {
984                    provider_name: "no-vision".into(),
985                    model: "cheap-model".into(),
986                },
987            ),
988            (
989                "vision".to_string(),
990                Route {
991                    provider_name: "has-vision".into(),
992                    model: "vision-model".into(),
993                },
994            ),
995        ];
996        let router =
997            RouterModelProvider::new("test", model_providers, routes, "default-model".into());
998
999        let prices = make_pricing(vec![
1000            ("no-vision", "cheap-model", 0.10, 0.40),
1001            ("has-vision", "vision-model", 3.0, 15.0),
1002        ]);
1003
1004        // With vision required, the cheap model (no vision) is filtered out
1005        let (_, model) = router.resolve_cost_optimized("hint:cheapest", &prices, true, false);
1006        assert_eq!(model, "vision-model");
1007    }
1008
1009    #[test]
1010    fn cost_optimized_respects_tools_requirement() {
1011        let model_providers: Vec<(String, Box<dyn ModelProvider>)> = vec![
1012            (
1013                "no-tools".into(),
1014                Box::new(CapableMockModelProvider::new("nt", false, false)),
1015            ),
1016            (
1017                "has-tools".into(),
1018                Box::new(CapableMockModelProvider::new("ht", false, true)),
1019            ),
1020        ];
1021        let routes = vec![
1022            (
1023                "basic".to_string(),
1024                Route {
1025                    provider_name: "no-tools".into(),
1026                    model: "basic-model".into(),
1027                },
1028            ),
1029            (
1030                "tools".to_string(),
1031                Route {
1032                    provider_name: "has-tools".into(),
1033                    model: "tools-model".into(),
1034                },
1035            ),
1036        ];
1037        let router =
1038            RouterModelProvider::new("test", model_providers, routes, "default-model".into());
1039
1040        let prices = make_pricing(vec![
1041            ("no-tools", "basic-model", 0.10, 0.40),
1042            ("has-tools", "tools-model", 5.0, 15.0),
1043        ]);
1044
1045        // With tools required, the basic model (no tools) is filtered out
1046        let (_, model) = router.resolve_cost_optimized("hint:cost-optimized", &prices, false, true);
1047        assert_eq!(model, "tools-model");
1048    }
1049
1050    #[test]
1051    fn cost_optimized_falls_back_when_no_pricing() {
1052        let (router, _) = make_router(
1053            vec![("default", "ok"), ("other", "ok")],
1054            vec![("route-a", "other", "some-model")],
1055        );
1056
1057        // Empty pricing map — no matches possible
1058        let prices: HashMap<String, HashMap<String, f64>> = HashMap::new();
1059        let (idx, model) =
1060            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
1061        assert_eq!(idx, 0);
1062        assert_eq!(model, "default-model");
1063    }
1064
1065    #[test]
1066    fn cost_optimized_with_single_route() {
1067        let model_providers: Vec<(String, Box<dyn ModelProvider>)> = vec![(
1068            "only".into(),
1069            Box::new(CapableMockModelProvider::new("ok", false, false)),
1070        )];
1071        let routes = vec![(
1072            "single".to_string(),
1073            Route {
1074                provider_name: "only".into(),
1075                model: "the-model".into(),
1076            },
1077        )];
1078        let router =
1079            RouterModelProvider::new("test", model_providers, routes, "default-model".into());
1080
1081        let prices = make_pricing(vec![("only", "the-model", 1.0, 2.0)]);
1082
1083        let (idx, model) = router.resolve_cost_optimized("hint:cheapest", &prices, false, false);
1084        assert_eq!(idx, 0);
1085        assert_eq!(model, "the-model");
1086    }
1087
1088    #[test]
1089    fn cost_optimized_prefers_lower_total_cost() {
1090        let model_providers: Vec<(String, Box<dyn ModelProvider>)> = vec![
1091            (
1092                "p1".into(),
1093                Box::new(CapableMockModelProvider::new("r1", false, false)),
1094            ),
1095            (
1096                "p2".into(),
1097                Box::new(CapableMockModelProvider::new("r2", false, false)),
1098            ),
1099            (
1100                "p3".into(),
1101                Box::new(CapableMockModelProvider::new("r3", false, false)),
1102            ),
1103        ];
1104        let routes = vec![
1105            (
1106                "a".to_string(),
1107                Route {
1108                    provider_name: "p1".into(),
1109                    model: "model-a".into(),
1110                },
1111            ),
1112            (
1113                "b".to_string(),
1114                Route {
1115                    provider_name: "p2".into(),
1116                    model: "model-b".into(),
1117                },
1118            ),
1119            (
1120                "c".to_string(),
1121                Route {
1122                    provider_name: "p3".into(),
1123                    model: "model-c".into(),
1124                },
1125            ),
1126        ];
1127        let router =
1128            RouterModelProvider::new("test", model_providers, routes, "default-model".into());
1129
1130        let prices = make_pricing(vec![
1131            ("p1", "model-a", 10.0, 50.0), // total: 60
1132            ("p2", "model-b", 0.15, 0.60), // total: 0.75 (cheapest)
1133            ("p3", "model-c", 3.0, 15.0),  // total: 18
1134        ]);
1135
1136        let (idx, model) =
1137            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
1138        assert_eq!(model, "model-b");
1139        assert_eq!(idx, 1);
1140    }
1141
1142    #[test]
1143    fn cost_optimized_strategy_score() {
1144        let prices = make_pricing(vec![
1145            ("cheap-provider", "cheap-model", 0.10, 0.40),
1146            ("expensive-provider", "expensive-model", 15.0, 75.0),
1147        ]);
1148        let strategy = CostOptimizedStrategy::new(prices);
1149
1150        assert!(
1151            (strategy.score("cheap-provider", "cheap-model").unwrap() - 0.50).abs() < f64::EPSILON
1152        );
1153        assert!(
1154            (strategy
1155                .score("expensive-provider", "expensive-model")
1156                .unwrap()
1157                - 90.0)
1158                .abs()
1159                < f64::EPSILON
1160        );
1161        assert!(strategy.score("cheap-provider", "unknown").is_none());
1162        assert!(strategy.score("unknown-provider", "cheap-model").is_none());
1163    }
1164
1165    #[tokio::test]
1166    async fn supports_streaming_returns_true_when_any_provider_supports_it() {
1167        let streaming = Arc::new(StreamingMockModelProvider::new("stream"));
1168        let router = RouterModelProvider::new(
1169            "test",
1170            vec![
1171                (
1172                    "default".into(),
1173                    Box::new(MockModelProvider::new("default")) as Box<dyn ModelProvider>,
1174                ),
1175                (
1176                    "streaming".into(),
1177                    Box::new(Arc::clone(&streaming)) as Box<dyn ModelProvider>,
1178                ),
1179            ],
1180            vec![(
1181                "reasoning".into(),
1182                Route {
1183                    provider_name: "streaming".into(),
1184                    model: "claude-opus".into(),
1185                },
1186            )],
1187            "model".into(),
1188        );
1189
1190        assert!(router.supports_streaming());
1191    }
1192
1193    #[tokio::test]
1194    async fn stream_chat_with_system_routes_hint_to_correct_provider_and_model() {
1195        let streaming = Arc::new(StreamingMockModelProvider::new("streamed system response"));
1196        let router = RouterModelProvider::new(
1197            "test",
1198            vec![
1199                (
1200                    "default".into(),
1201                    Box::new(MockModelProvider::new("default")) as Box<dyn ModelProvider>,
1202                ),
1203                (
1204                    "streaming".into(),
1205                    Box::new(Arc::clone(&streaming)) as Box<dyn ModelProvider>,
1206                ),
1207            ],
1208            vec![(
1209                "reasoning".into(),
1210                Route {
1211                    provider_name: "streaming".into(),
1212                    model: "claude-opus".into(),
1213                },
1214            )],
1215            "model".into(),
1216        );
1217
1218        let mut stream = router.stream_chat_with_system(
1219            Some("system"),
1220            "hello",
1221            "hint:reasoning",
1222            Some(0.0),
1223            StreamOptions::new(true),
1224        );
1225
1226        let mut collected = String::new();
1227        while let Some(chunk) = stream.next().await {
1228            let chunk = chunk.expect("stream chunk should be ok");
1229            collected.push_str(&chunk.delta);
1230        }
1231
1232        assert_eq!(collected, "streamed system response");
1233        assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
1234        assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
1235    }
1236
1237    #[tokio::test]
1238    async fn stream_chat_with_history_routes_hint_to_correct_provider_and_model() {
1239        let streaming = Arc::new(StreamingMockModelProvider::new("streamed response"));
1240        let router = RouterModelProvider::new(
1241            "test",
1242            vec![
1243                (
1244                    "default".into(),
1245                    Box::new(MockModelProvider::new("default")) as Box<dyn ModelProvider>,
1246                ),
1247                (
1248                    "streaming".into(),
1249                    Box::new(Arc::clone(&streaming)) as Box<dyn ModelProvider>,
1250                ),
1251            ],
1252            vec![(
1253                "reasoning".into(),
1254                Route {
1255                    provider_name: "streaming".into(),
1256                    model: "claude-opus".into(),
1257                },
1258            )],
1259            "model".into(),
1260        );
1261
1262        let messages = vec![ChatMessage::user("hello")];
1263        let mut stream = router.stream_chat_with_history(
1264            &messages,
1265            "hint:reasoning",
1266            Some(0.0),
1267            StreamOptions::new(true),
1268        );
1269
1270        let mut collected = String::new();
1271        while let Some(chunk) = stream.next().await {
1272            let chunk = chunk.expect("stream chunk should be ok");
1273            collected.push_str(&chunk.delta);
1274        }
1275
1276        assert_eq!(collected, "streamed response");
1277        assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
1278        assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
1279    }
1280
1281    #[tokio::test]
1282    async fn stream_chat_routes_hint_with_structured_tool_events() {
1283        let streaming = Arc::new(ToolEventStreamingMockModelProvider::new());
1284        let router = RouterModelProvider::new(
1285            "test",
1286            vec![
1287                (
1288                    "default".into(),
1289                    Box::new(MockModelProvider::new("default")) as Box<dyn ModelProvider>,
1290                ),
1291                (
1292                    "streaming".into(),
1293                    Box::new(Arc::clone(&streaming)) as Box<dyn ModelProvider>,
1294                ),
1295            ],
1296            vec![(
1297                "reasoning".into(),
1298                Route {
1299                    provider_name: "streaming".into(),
1300                    model: "claude-opus".into(),
1301                },
1302            )],
1303            "model".into(),
1304        );
1305
1306        let messages = vec![ChatMessage::user("hello")];
1307        let tools = vec![ToolSpec {
1308            name: "shell".to_string(),
1309            description: "run shell commands".to_string(),
1310            parameters: serde_json::json!({
1311                "type": "object",
1312                "properties": {
1313                    "command": { "type": "string" }
1314                }
1315            }),
1316        }];
1317
1318        let mut stream = router.stream_chat(
1319            ChatRequest {
1320                messages: &messages,
1321                tools: Some(&tools),
1322                thinking: None,
1323            },
1324            "hint:reasoning",
1325            Some(0.0),
1326            StreamOptions::new(true),
1327        );
1328
1329        let first = stream.next().await.unwrap().unwrap();
1330        let second = stream.next().await.unwrap().unwrap();
1331        assert!(stream.next().await.is_none());
1332
1333        match first {
1334            StreamEvent::ToolCall(call) => {
1335                assert_eq!(call.name, "shell");
1336                assert_eq!(call.arguments, r#"{"command":"date"}"#);
1337            }
1338            other => panic!("expected tool-call event, got {other:?}"),
1339        }
1340        assert!(matches!(second, StreamEvent::Final));
1341        assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
1342        assert_eq!(streaming.tool_event_calls.load(Ordering::SeqCst), 1);
1343        assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
1344    }
1345
1346    // Regression for #6589: supports_vision() must reflect the default provider,
1347    // not .any() across all sub-providers. Otherwise the multimodal.vision_provider
1348    // fallback in run_tool_call_loop and the image-marker stripping in the context
1349    // compressor are silently bypassed in mixed-provider configurations.
1350    #[test]
1351    fn supports_vision_reflects_default_provider_not_any_route() {
1352        let default_provider = Box::new(MockModelProvider::new("nope").with_vision(false));
1353        let vision_route_provider = Box::new(MockModelProvider::new("ok").with_vision(true));
1354
1355        let router = RouterModelProvider::new(
1356            "test",
1357            vec![
1358                ("default".into(), default_provider as Box<dyn ModelProvider>),
1359                (
1360                    "vision".into(),
1361                    vision_route_provider as Box<dyn ModelProvider>,
1362                ),
1363            ],
1364            vec![(
1365                "hint:vision".into(),
1366                Route {
1367                    provider_name: "vision".into(),
1368                    model: "vision-model".into(),
1369                },
1370            )],
1371            "default-model".into(),
1372        );
1373
1374        assert!(
1375            !router.supports_vision(),
1376            "router with non-vision default must report supports_vision()=false even when a vision-capable route exists"
1377        );
1378    }
1379
1380    #[test]
1381    fn supports_vision_true_when_default_provider_supports_vision() {
1382        let default_provider = Box::new(MockModelProvider::new("ok").with_vision(true));
1383        let aux_provider = Box::new(MockModelProvider::new("nope").with_vision(false));
1384
1385        let router = RouterModelProvider::new(
1386            "test",
1387            vec![
1388                ("default".into(), default_provider as Box<dyn ModelProvider>),
1389                ("aux".into(), aux_provider as Box<dyn ModelProvider>),
1390            ],
1391            vec![],
1392            "default-model".into(),
1393        );
1394
1395        assert!(router.supports_vision());
1396    }
1397}