Skip to main content

zeroclaw_providers/
reliable.rs

1use super::ModelProvider;
2use super::traits::{
3    ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
4};
5use async_trait::async_trait;
6use futures_util::{StreamExt, stream};
7use std::cell::RefCell;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::Duration;
11
12// ── ModelProvider Fallback Notification ──────────────────────────────────────
13// When ReliableModelProvider uses a fallback (different model_provider or model than
14// requested), it records the details here so channel code can notify the user.
15// Uses tokio::task_local to avoid cross-request leakage between concurrent
16// users (the old global static had a race window).
17
/// Info about a model_provider fallback that occurred during a request.
///
/// Stored in the task-local slot by `record_provider_fallback` (called from
/// `ReliableModelProvider` after a request recovers) and consumed by
/// `take_last_provider_fallback`.
#[derive(Debug, Clone)]
pub struct ProviderFallbackInfo {
    /// ModelProvider that was originally requested. `ReliableModelProvider`
    /// fills this with the name of its first (primary) provider entry.
    pub requested_provider: String,
    /// Model that was originally requested.
    pub requested_model: String,
    /// ModelProvider that actually served the request.
    pub actual_provider: String,
    /// Model that actually served the request.
    pub actual_model: String,
}
30
// Per-task slot holding the most recent fallback event (`None` until
// `record_provider_fallback` writes into it). It only exists inside a
// `scope_provider_fallback` scope; outside one, `try_with` fails and both the
// recorder and `take_last_provider_fallback` silently do nothing.
tokio::task_local! {
    static PROVIDER_FALLBACK: RefCell<Option<ProviderFallbackInfo>>;
}
34
35/// Take (consume) the last model_provider fallback info, if any.
36/// Must be called within a `scope_provider_fallback` scope.
37pub fn take_last_provider_fallback() -> Option<ProviderFallbackInfo> {
38    PROVIDER_FALLBACK
39        .try_with(|cell| cell.borrow_mut().take())
40        .ok()
41        .flatten()
42}
43
44/// Run the given future within a provider-fallback scope.
45/// Both `record_provider_fallback` (inside ReliableModelProvider) and
46/// `take_last_provider_fallback` (post-loop channel code) must execute
47/// within this scope for the data to be visible.
48pub async fn scope_provider_fallback<F: std::future::Future>(future: F) -> F::Output {
49    PROVIDER_FALLBACK.scope(RefCell::new(None), future).await
50}
51
52/// Record a model_provider fallback event.
53fn record_provider_fallback(
54    requested_provider: &str,
55    requested_model: &str,
56    actual_provider: &str,
57    actual_model: &str,
58) {
59    let _ = PROVIDER_FALLBACK.try_with(|cell| {
60        *cell.borrow_mut() = Some(ProviderFallbackInfo {
61            requested_provider: requested_provider.to_string(),
62            requested_model: requested_model.to_string(),
63            actual_provider: actual_provider.to_string(),
64            actual_model: actual_model.to_string(),
65        });
66    });
67}
68
69// ── Error Classification ─────────────────────────────────────────────────
70// Errors are split into retryable (transient server/network failures) and
71// non-retryable (permanent client errors). This distinction drives whether
72// the retry loop continues, falls back to the next model_provider, or aborts
73// immediately — avoiding wasted latency on errors that cannot self-heal.
74
75/// Check if an error is non-retryable (client errors that won't resolve with retries).
76pub fn is_non_retryable(err: &anyhow::Error) -> bool {
77    // Context window errors are NOT non-retryable — they can be recovered
78    // by truncating conversation history, so let the retry loop handle them.
79    if is_context_window_exceeded(err) {
80        return false;
81    }
82
83    // Tool schema validation errors are NOT non-retryable — the model_provider's
84    // built-in fallback in compatible.rs can recover by switching to
85    // prompt-guided tool instructions.
86    if is_tool_schema_error(err) {
87        return false;
88    }
89
90    // 4xx errors are generally non-retryable (bad request, auth failure, etc.),
91    // except 429 (rate-limit — transient) and 408 (timeout — worth retrying).
92    if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
93        && let Some(status) = reqwest_err.status()
94    {
95        let code = status.as_u16();
96        return status.is_client_error() && code != 429 && code != 408;
97    }
98    // Fallback: parse status codes from stringified errors (some model_providers
99    // embed codes in error messages rather than returning typed HTTP errors).
100    let msg = err.to_string();
101    for word in msg.split(|c: char| !c.is_ascii_digit()) {
102        if let Ok(code) = word.parse::<u16>()
103            && (400..500).contains(&code)
104        {
105            return code != 429 && code != 408;
106        }
107    }
108
109    // Heuristic: detect auth/model failures by keyword when no HTTP status
110    // is available (e.g. gRPC or custom transport errors).
111    let msg_lower = msg.to_lowercase();
112    let auth_failure_hints = [
113        "invalid api key",
114        "incorrect api key",
115        "missing api key",
116        "api key not set",
117        "authentication failed",
118        "auth failed",
119        "unauthorized",
120        "forbidden",
121        "permission denied",
122        "access denied",
123        "invalid token",
124    ];
125
126    if auth_failure_hints
127        .iter()
128        .any(|hint| msg_lower.contains(hint))
129    {
130        return true;
131    }
132
133    msg_lower.contains("model")
134        && (msg_lower.contains("not found")
135            || msg_lower.contains("unknown")
136            || msg_lower.contains("unsupported")
137            || msg_lower.contains("does not exist")
138            || msg_lower.contains("invalid"))
139}
140
141/// Check if an error indicates an authentication/authorization failure.
142/// Used by channels to evict cached model_providers whose OAuth tokens may have
143/// expired so the next request triggers a fresh credential resolution.
144pub fn is_auth_error(err: &anyhow::Error) -> bool {
145    if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
146        && let Some(status) = reqwest_err.status()
147    {
148        let code = status.as_u16();
149        return code == 401 || code == 403;
150    }
151
152    let msg_lower = err.to_string().to_lowercase();
153    let hints = [
154        "401 unauthorized",
155        "403 forbidden",
156        "invalid api key",
157        "incorrect api key",
158        "authentication failed",
159        "auth failed",
160        "unauthorized",
161        "invalid token",
162        "token expired",
163        "access_token",
164    ];
165
166    hints.iter().any(|hint| msg_lower.contains(hint))
167}
168
169/// Check if an error is a tool schema validation failure (e.g. Groq returning
170/// "tool call validation failed: attempted to call tool '...' which was not in request").
171/// These errors should NOT be classified as non-retryable because the model_provider's
172/// built-in fallback logic (`compatible.rs::is_native_tool_schema_unsupported`)
173/// can recover by switching to prompt-guided tool instructions.
174pub fn is_tool_schema_error(err: &anyhow::Error) -> bool {
175    let lower = err.to_string().to_lowercase();
176    let hints = [
177        "tool call validation failed",
178        "was not in request",
179        "not found in tool list",
180        "invalid_tool_call",
181    ];
182    hints.iter().any(|hint| lower.contains(hint))
183}
184
185pub fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
186    let lower = err.to_string().to_lowercase();
187    let hints = [
188        "exceeds the context window",
189        "exceeds the available context size",
190        "context window of this model",
191        "maximum context length",
192        "context length exceeded",
193        "too many tokens",
194        "token limit exceeded",
195        "prompt is too long",
196        "input is too long",
197        "prompt exceeds max length",
198    ];
199
200    hints.iter().any(|hint| lower.contains(hint))
201}
202
203/// Check if an error is a rate-limit (429) error.
204fn is_rate_limited(err: &anyhow::Error) -> bool {
205    if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
206        && let Some(status) = reqwest_err.status()
207    {
208        return status.as_u16() == 429;
209    }
210    let msg = err.to_string();
211    msg.contains("429")
212        && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
213}
214
215/// Check if a 429 is a business/quota-plan error that retries cannot fix.
216///
217/// Examples:
218/// - plan does not include requested model
219/// - insufficient balance / package not active
220/// - known model_provider business codes (e.g. Z.AI: 1311, 1113)
221fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool {
222    if !is_rate_limited(err) {
223        return false;
224    }
225
226    let msg = err.to_string();
227    let lower = msg.to_lowercase();
228
229    let business_hints = [
230        "plan does not include",
231        "doesn't include",
232        "not include",
233        "insufficient balance",
234        "insufficient_balance",
235        "insufficient quota",
236        "insufficient_quota",
237        "quota exhausted",
238        "out of credits",
239        "no available package",
240        "package not active",
241        "purchase package",
242        "model not available for your plan",
243    ];
244
245    if business_hints.iter().any(|hint| lower.contains(hint)) {
246        return true;
247    }
248
249    // Known model_provider business codes observed for 429 where retry is futile.
250    for token in lower.split(|c: char| !c.is_ascii_digit()) {
251        if let Ok(code) = token.parse::<u16>()
252            && matches!(code, 1113 | 1311)
253        {
254            return true;
255        }
256    }
257
258    false
259}
260
261/// Try to extract a Retry-After value (in milliseconds) from an error message.
262/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string.
263fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
264    let msg = err.to_string();
265    let lower = msg.to_lowercase();
266
267    // Look for "retry-after: <number>" or "retry_after: <number>"
268    for prefix in &[
269        "retry-after:",
270        "retry_after:",
271        "retry-after ",
272        "retry_after ",
273    ] {
274        if let Some(pos) = lower.find(prefix) {
275            let after = &msg[pos + prefix.len()..];
276            let num_str: String = after
277                .trim()
278                .chars()
279                .take_while(|c| c.is_ascii_digit() || *c == '.')
280                .collect();
281            if let Ok(secs) = num_str.parse::<f64>()
282                && secs.is_finite()
283                && secs >= 0.0
284            {
285                let millis = Duration::from_secs_f64(secs).as_millis();
286                if let Ok(value) = u64::try_from(millis) {
287                    return Some(value);
288                }
289            }
290        }
291    }
292    None
293}
294
/// Map the two error-classification flags to the short reason label used in
/// failure records and retry logs.
fn failure_reason(rate_limited: bool, non_retryable: bool) -> &'static str {
    match (rate_limited, non_retryable) {
        (true, true) => "rate_limited_non_retryable",
        (true, false) => "rate_limited",
        (false, true) => "non_retryable",
        (false, false) => "retryable",
    }
}
306
307fn compact_error_detail(err: &anyhow::Error) -> String {
308    super::sanitize_api_error(&format!("{err:#}"))
309        .split_whitespace()
310        .collect::<Vec<_>>()
311        .join(" ")
312}
313
314/// Truncate conversation history by dropping the oldest non-system messages.
315/// Returns the number of messages dropped. Keeps at least the system message
316/// (if any) and the most recent user message.
317fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize {
318    // Find all non-system message indices
319    let non_system: Vec<usize> = messages
320        .iter()
321        .enumerate()
322        .filter(|(_, m)| m.role != "system")
323        .map(|(i, _)| i)
324        .collect();
325
326    // Keep at least the last non-system message (most recent user turn)
327    if non_system.len() <= 1 {
328        return 0;
329    }
330
331    // Drop the oldest half of non-system messages
332    let drop_count = non_system.len() / 2;
333    let indices_to_remove: Vec<usize> = non_system[..drop_count].to_vec();
334
335    // Remove in reverse order to preserve indices
336    for &idx in indices_to_remove.iter().rev() {
337        messages.remove(idx);
338    }
339
340    drop_count
341}
342
/// Append one human-readable diagnostic line describing a failed attempt
/// (provider, model, 1-based attempt counter, reason label, error detail).
fn push_failure(
    failures: &mut Vec<String>,
    provider_name: &str,
    model: &str,
    attempt: u32,
    max_attempts: u32,
    reason: &str,
    error_detail: &str,
) {
    let progress = format!("{attempt}/{max_attempts}");
    let entry = format!(
        "model_provider={provider_name} model={model} attempt {progress}: {reason}; error={error_detail}"
    );
    failures.push(entry);
}
356
357// ── Resilient ModelProvider Wrapper ────────────────────────────────────────────
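// Three-level strategy: model chain → model_provider chain → retry loop.
//   Outer loop:  iterate the requested model, then any per-model fallback
//                models (test-only; see `model_fallbacks`).
//   Middle loop: iterate registered model_providers in priority order. The production
//                caller always wires a single primary; tests construct multi-
//                element chains directly to exercise failover semantics.
//   Inner loop:  retry the same (model_provider, model) pair with exponential backoff.
//                On rate-limit errors the next configured API key is selected and
//                logged, but it cannot be applied (the ModelProvider trait has no
//                key setter), so retries keep the original key.
// Loop invariant: `failures` accumulates every failed attempt, except a
// context-window failure that triggers history truncation (that one is logged
// instead), so the final error message gives operators a diagnostic trail.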
366
/// ModelProvider wrapper with retry + auth-key rotation. The model_provider Vec exists
/// for tests to exercise multi-provider failover; production wiring always
/// passes a single primary. Per-model failover chains are also test-only —
/// the schema no longer surfaces them.
///
/// NOTE(review): "auth-key rotation" only *selects* the next key from
/// `api_keys` and logs it. The `ModelProvider` trait has no way to apply a key,
/// so retries still use the wrapped provider's original key.
pub struct ReliableModelProvider {
    /// `[model_providers.<family>.<alias>]` config-key alias.
    alias: String,
    /// `(name, provider)` pairs tried in order; the first entry is the primary.
    model_providers: Vec<(String, Box<dyn ModelProvider>)>,
    /// Extra attempts per (model_provider, model) pair after the first call
    /// (each pair gets `max_retries + 1` calls at most).
    max_retries: u32,
    /// Initial retry delay in milliseconds (clamped to at least 50 by `new`);
    /// it doubles after each retry, capped at 10 000 ms.
    base_backoff_ms: u64,
    /// Extra API keys for rotation (index tracks round-robin position).
    api_keys: Vec<String>,
    /// Monotonic counter; `rotate_key` takes it modulo `api_keys.len()`.
    key_index: AtomicUsize,
    /// Per-model failover chains. Test-only: model_name → [alt1, alt2, ...].
    model_fallbacks: HashMap<String, Vec<String>>,
}
383
384impl ReliableModelProvider {
385    pub fn new(
386        alias: &str,
387        model_providers: Vec<(String, Box<dyn ModelProvider>)>,
388        max_retries: u32,
389        base_backoff_ms: u64,
390    ) -> Self {
391        Self {
392            alias: alias.to_string(),
393            model_providers,
394            max_retries,
395            base_backoff_ms: base_backoff_ms.max(50),
396            api_keys: Vec::new(),
397            key_index: AtomicUsize::new(0),
398            model_fallbacks: HashMap::new(),
399        }
400    }
401    /// Set additional API keys for round-robin rotation on rate-limit errors.
402    pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
403        self.api_keys = keys;
404        self
405    }
406
407    /// Test-only hook: install per-model failover chains. Production builds
408    /// never call this — the schema has no surface for it.
409    #[cfg(test)]
410    pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
411        self.model_fallbacks = fallbacks;
412        self
413    }
414
415    /// Build the list of models to try: [original, alt1, alt2, ...]
416    fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
417        let mut chain = vec![model];
418        if let Some(fallbacks) = self.model_fallbacks.get(model) {
419            chain.extend(fallbacks.iter().map(|s| s.as_str()));
420        }
421        chain
422    }
423
424    /// Advance to the next API key and return it, or None if no extra keys configured.
425    fn rotate_key(&self) -> Option<&str> {
426        if self.api_keys.is_empty() {
427            return None;
428        }
429        let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
430        Some(&self.api_keys[idx])
431    }
432
433    /// Compute backoff duration, respecting Retry-After if present.
434    fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
435        if let Some(retry_after) = parse_retry_after_ms(err) {
436            // Use Retry-After but cap at 30s to avoid indefinite waits
437            retry_after.min(30_000).max(base)
438        } else {
439            base
440        }
441    }
442}
443
444#[async_trait]
445impl ModelProvider for ReliableModelProvider {
446    async fn warmup(&self) -> anyhow::Result<()> {
447        for (name, model_provider) in &self.model_providers {
448            ::zeroclaw_log::record!(
449                INFO,
450                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
451                    .with_attrs(::serde_json::json!({"model_provider": name})),
452                "Warming up model_provider connection pool"
453            );
454            if model_provider.warmup().await.is_err() {
455                ::zeroclaw_log::record!(
456                    WARN,
457                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
458                        .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
459                        .with_attrs(::serde_json::json!({"model_provider": name})),
460                    "Warmup failed (non-fatal)"
461                );
462            }
463        }
464        Ok(())
465    }
466
467    async fn chat_with_system(
468        &self,
469        system_prompt: Option<&str>,
470        message: &str,
471        model: &str,
472        temperature: Option<f64>,
473    ) -> anyhow::Result<String> {
474        let models = self.model_chain(model);
475        let mut failures = Vec::new();
476
477        // Outer: model fallback chain. Middle: model_provider priority. Inner: retries.
478        // Each iteration: attempt one (model_provider, model) call. On success, return
479        // immediately. On non-retryable error, break to next model_provider. On
480        // retryable error, sleep with exponential backoff and retry.
481        for current_model in &models {
482            for (provider_name, model_provider) in &self.model_providers {
483                let mut backoff_ms = self.base_backoff_ms;
484
485                for attempt in 0..=self.max_retries {
486                    match model_provider
487                        .chat_with_system(system_prompt, message, current_model, temperature)
488                        .await
489                    {
490                        Ok(resp) => {
491                            if attempt > 0
492                                || *current_model != model
493                                || self.model_providers.first().map(|(n, _)| n.as_str())
494                                    != Some(provider_name)
495                            {
496                                ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model})), "ModelProvider recovered (failover/retry)");
497                                let primary = self
498                                    .model_providers
499                                    .first()
500                                    .map(|(n, _)| n.as_str())
501                                    .unwrap_or("");
502                                record_provider_fallback(
503                                    primary,
504                                    model,
505                                    provider_name,
506                                    current_model,
507                                );
508                            }
509                            return Ok(resp);
510                        }
511                        Err(e) => {
512                            // Context window exceeded: no history to truncate
513                            // in chat_with_system, bail immediately.
514                            if is_context_window_exceeded(&e) {
515                                let error_detail = compact_error_detail(&e);
516                                push_failure(
517                                    &mut failures,
518                                    provider_name,
519                                    current_model,
520                                    attempt + 1,
521                                    self.max_retries + 1,
522                                    "non_retryable",
523                                    &error_detail,
524                                );
525                                anyhow::bail!(
526                                    "Request exceeds model context window. Attempts:\n{}",
527                                    failures.join("\n")
528                                );
529                            }
530
531                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
532                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
533                            let rate_limited = is_rate_limited(&e);
534                            let failure_reason = failure_reason(rate_limited, non_retryable);
535                            let error_detail = compact_error_detail(&e);
536
537                            push_failure(
538                                &mut failures,
539                                provider_name,
540                                current_model,
541                                attempt + 1,
542                                self.max_retries + 1,
543                                failure_reason,
544                                &error_detail,
545                            );
546
547                            // Rate-limit with rotatable keys: cycle to the next API key
548                            // so the retry hits a different quota bucket.
549                            if rate_limited
550                                && !non_retryable_rate_limit
551                                && let Some(new_key) = self.rotate_key()
552                            {
553                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
554                                     but cannot apply (ModelProvider trait has no set_api_key). \
555                                     Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
556                            }
557
558                            if non_retryable {
559                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
560                                break;
561                            }
562
563                            if attempt < self.max_retries {
564                                let wait = self.compute_backoff(backoff_ms, &e);
565                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
566                                tokio::time::sleep(Duration::from_millis(wait)).await;
567                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
568                            }
569                        }
570                    }
571                }
572
573                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
574            }
575
576            if *current_model != model {
577                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"original_model": model, "fallback_model": *current_model})), "Model fallback exhausted all model_providers, trying next fallback model");
578            }
579        }
580
581        anyhow::bail!(
582            "All model_providers/models failed. Attempts:\n{}",
583            failures.join("\n")
584        )
585    }
586
587    async fn chat_with_history(
588        &self,
589        messages: &[ChatMessage],
590        model: &str,
591        temperature: Option<f64>,
592    ) -> anyhow::Result<String> {
593        let models = self.model_chain(model);
594        let mut failures = Vec::new();
595        let mut effective_messages = messages.to_vec();
596        let mut context_truncated = false;
597
598        for current_model in &models {
599            for (provider_name, model_provider) in &self.model_providers {
600                let mut backoff_ms = self.base_backoff_ms;
601
602                for attempt in 0..=self.max_retries {
603                    match model_provider
604                        .chat_with_history(&effective_messages, current_model, temperature)
605                        .await
606                    {
607                        Ok(resp) => {
608                            if attempt > 0
609                                || *current_model != model
610                                || context_truncated
611                                || self.model_providers.first().map(|(n, _)| n.as_str())
612                                    != Some(provider_name)
613                            {
614                                ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
615                                let primary = self
616                                    .model_providers
617                                    .first()
618                                    .map(|(n, _)| n.as_str())
619                                    .unwrap_or("");
620                                record_provider_fallback(
621                                    primary,
622                                    model,
623                                    provider_name,
624                                    current_model,
625                                );
626                            }
627                            return Ok(resp);
628                        }
629                        Err(e) => {
630                            // Context window exceeded: truncate history and retry
631                            if is_context_window_exceeded(&e) && !context_truncated {
632                                let dropped = truncate_for_context(&mut effective_messages);
633                                if dropped > 0 {
634                                    context_truncated = true;
635                                    ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
636                                    continue; // Retry with truncated messages (counts as an attempt)
637                                }
638                                // Nothing to truncate (system prompt alone exceeds
639                                // the model's context window) — bail immediately
640                                // instead of wasting retry attempts.
641                                let error_detail = compact_error_detail(&e);
642                                push_failure(
643                                    &mut failures,
644                                    provider_name,
645                                    current_model,
646                                    attempt + 1,
647                                    self.max_retries + 1,
648                                    "non_retryable",
649                                    &error_detail,
650                                );
651                                anyhow::bail!(
652                                    "Request exceeds model context window and cannot be reduced further. \
653                                     Try using a model with a larger context window, reducing the number \
654                                     of tools/skills, or enabling compact_context in config. Attempts:\n{}",
655                                    failures.join("\n")
656                                );
657                            }
658
659                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
660                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
661                            let rate_limited = is_rate_limited(&e);
662                            let failure_reason = failure_reason(rate_limited, non_retryable);
663                            let error_detail = compact_error_detail(&e);
664
665                            push_failure(
666                                &mut failures,
667                                provider_name,
668                                current_model,
669                                attempt + 1,
670                                self.max_retries + 1,
671                                failure_reason,
672                                &error_detail,
673                            );
674
675                            if rate_limited
676                                && !non_retryable_rate_limit
677                                && let Some(new_key) = self.rotate_key()
678                            {
679                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
680                                     but cannot apply (ModelProvider trait has no set_api_key). \
681                                     Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
682                            }
683
684                            if non_retryable {
685                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
686                                break;
687                            }
688
689                            if attempt < self.max_retries {
690                                let wait = self.compute_backoff(backoff_ms, &e);
691                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
692                                tokio::time::sleep(Duration::from_millis(wait)).await;
693                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
694                            }
695                        }
696                    }
697                }
698
699                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
700            }
701        }
702
703        anyhow::bail!(
704            "All model_providers/models failed. Attempts:\n{}",
705            failures.join("\n")
706        )
707    }
708
709    fn supports_native_tools(&self) -> bool {
710        self.model_providers
711            .first()
712            .map(|(_, p)| p.supports_native_tools())
713            .unwrap_or(false)
714    }
715
716    fn supports_vision(&self) -> bool {
717        self.model_providers
718            .first()
719            .map(|(_, p)| p.supports_vision())
720            .unwrap_or(false)
721    }
722
723    async fn chat_with_tools(
724        &self,
725        messages: &[ChatMessage],
726        tools: &[serde_json::Value],
727        model: &str,
728        temperature: Option<f64>,
729    ) -> anyhow::Result<ChatResponse> {
730        let models = self.model_chain(model);
731        let mut failures = Vec::new();
732        let mut effective_messages = messages.to_vec();
733        let mut context_truncated = false;
734
735        for current_model in &models {
736            for (provider_name, model_provider) in &self.model_providers {
737                let mut backoff_ms = self.base_backoff_ms;
738
739                for attempt in 0..=self.max_retries {
740                    match model_provider
741                        .chat_with_tools(&effective_messages, tools, current_model, temperature)
742                        .await
743                    {
744                        Ok(resp) => {
745                            if attempt > 0
746                                || *current_model != model
747                                || context_truncated
748                                || self.model_providers.first().map(|(n, _)| n.as_str())
749                                    != Some(provider_name)
750                            {
751                                ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
752                                let primary = self
753                                    .model_providers
754                                    .first()
755                                    .map(|(n, _)| n.as_str())
756                                    .unwrap_or("");
757                                record_provider_fallback(
758                                    primary,
759                                    model,
760                                    provider_name,
761                                    current_model,
762                                );
763                            }
764                            return Ok(resp);
765                        }
766                        Err(e) => {
767                            // Context window exceeded: truncate history and retry
768                            if is_context_window_exceeded(&e) && !context_truncated {
769                                let dropped = truncate_for_context(&mut effective_messages);
770                                if dropped > 0 {
771                                    context_truncated = true;
772                                    ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
773                                    continue; // Retry with truncated messages (counts as an attempt)
774                                }
775                                // Nothing to truncate (system prompt alone exceeds
776                                // the model's context window) — bail immediately
777                                // instead of wasting retry attempts.
778                                let error_detail = compact_error_detail(&e);
779                                push_failure(
780                                    &mut failures,
781                                    provider_name,
782                                    current_model,
783                                    attempt + 1,
784                                    self.max_retries + 1,
785                                    "non_retryable",
786                                    &error_detail,
787                                );
788                                anyhow::bail!(
789                                    "Request exceeds model context window and cannot be reduced further. \
790                                     Try using a model with a larger context window, reducing the number \
791                                     of tools/skills, or enabling compact_context in config. Attempts:\n{}",
792                                    failures.join("\n")
793                                );
794                            }
795
796                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
797                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
798                            let rate_limited = is_rate_limited(&e);
799                            let failure_reason = failure_reason(rate_limited, non_retryable);
800                            let error_detail = compact_error_detail(&e);
801
802                            push_failure(
803                                &mut failures,
804                                provider_name,
805                                current_model,
806                                attempt + 1,
807                                self.max_retries + 1,
808                                failure_reason,
809                                &error_detail,
810                            );
811
812                            if rate_limited
813                                && !non_retryable_rate_limit
814                                && let Some(new_key) = self.rotate_key()
815                            {
816                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
817                                     but cannot apply (ModelProvider trait has no set_api_key). \
818                                     Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
819                            }
820
821                            if non_retryable {
822                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
823                                break;
824                            }
825
826                            if attempt < self.max_retries {
827                                let wait = self.compute_backoff(backoff_ms, &e);
828                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
829                                tokio::time::sleep(Duration::from_millis(wait)).await;
830                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
831                            }
832                        }
833                    }
834                }
835
836                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
837            }
838        }
839
840        anyhow::bail!(
841            "All model_providers/models failed. Attempts:\n{}",
842            failures.join("\n")
843        )
844    }
845
846    async fn chat(
847        &self,
848        request: ChatRequest<'_>,
849        model: &str,
850        temperature: Option<f64>,
851    ) -> anyhow::Result<ChatResponse> {
852        let models = self.model_chain(model);
853        let mut failures = Vec::new();
854        let mut effective_messages = request.messages.to_vec();
855        let mut context_truncated = false;
856
857        for current_model in &models {
858            for (provider_name, model_provider) in &self.model_providers {
859                let mut backoff_ms = self.base_backoff_ms;
860
861                for attempt in 0..=self.max_retries {
862                    let req = ChatRequest {
863                        messages: &effective_messages,
864                        tools: request.tools,
865                        thinking: request.thinking,
866                    };
867                    match model_provider.chat(req, current_model, temperature).await {
868                        Ok(resp) => {
869                            if attempt > 0
870                                || *current_model != model
871                                || context_truncated
872                                || self.model_providers.first().map(|(n, _)| n.as_str())
873                                    != Some(provider_name)
874                            {
875                                ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
876                                let primary = self
877                                    .model_providers
878                                    .first()
879                                    .map(|(n, _)| n.as_str())
880                                    .unwrap_or("");
881                                record_provider_fallback(
882                                    primary,
883                                    model,
884                                    provider_name,
885                                    current_model,
886                                );
887                            }
888                            return Ok(resp);
889                        }
890                        Err(e) => {
891                            // Context window exceeded: truncate history and retry
892                            if is_context_window_exceeded(&e) && !context_truncated {
893                                let dropped = truncate_for_context(&mut effective_messages);
894                                if dropped > 0 {
895                                    context_truncated = true;
896                                    ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
897                                    continue; // Retry with truncated messages (counts as an attempt)
898                                }
899                                // Nothing to truncate (system prompt alone exceeds
900                                // the model's context window) — bail immediately
901                                // instead of wasting retry attempts.
902                                let error_detail = compact_error_detail(&e);
903                                push_failure(
904                                    &mut failures,
905                                    provider_name,
906                                    current_model,
907                                    attempt + 1,
908                                    self.max_retries + 1,
909                                    "non_retryable",
910                                    &error_detail,
911                                );
912                                anyhow::bail!(
913                                    "Request exceeds model context window and cannot be reduced further. \
914                                     Try using a model with a larger context window, reducing the number \
915                                     of tools/skills, or enabling compact_context in config. Attempts:\n{}",
916                                    failures.join("\n")
917                                );
918                            }
919
920                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
921                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
922                            let rate_limited = is_rate_limited(&e);
923                            let failure_reason = failure_reason(rate_limited, non_retryable);
924                            let error_detail = compact_error_detail(&e);
925
926                            push_failure(
927                                &mut failures,
928                                provider_name,
929                                current_model,
930                                attempt + 1,
931                                self.max_retries + 1,
932                                failure_reason,
933                                &error_detail,
934                            );
935
936                            if rate_limited
937                                && !non_retryable_rate_limit
938                                && let Some(new_key) = self.rotate_key()
939                            {
940                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
941                                     but cannot apply (ModelProvider trait has no set_api_key). \
942                                     Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
943                            }
944
945                            if non_retryable {
946                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
947                                break;
948                            }
949
950                            if attempt < self.max_retries {
951                                let wait = self.compute_backoff(backoff_ms, &e);
952                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
953                                tokio::time::sleep(Duration::from_millis(wait)).await;
954                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
955                            }
956                        }
957                    }
958                }
959
960                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
961            }
962
963            if *current_model != model {
964                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"original_model": model, "fallback_model": *current_model})), "Model fallback exhausted all model_providers, trying next fallback model");
965            }
966        }
967
968        anyhow::bail!(
969            "All model_providers/models failed. Attempts:\n{}",
970            failures.join("\n")
971        )
972    }
973
974    fn supports_streaming(&self) -> bool {
975        self.model_providers
976            .iter()
977            .any(|(_, p)| p.supports_streaming())
978    }
979
980    fn supports_streaming_tool_events(&self) -> bool {
981        self.model_providers
982            .iter()
983            .any(|(_, p)| p.supports_streaming_tool_events())
984    }
985
986    fn stream_chat(
987        &self,
988        request: ChatRequest<'_>,
989        model: &str,
990        temperature: Option<f64>,
991        options: StreamOptions,
992    ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
993        let needs_tool_events = request.tools.is_some_and(|tools| !tools.is_empty());
994
995        for (provider_name, model_provider) in &self.model_providers {
996            if !model_provider.supports_streaming() || !options.enabled {
997                continue;
998            }
999
1000            if needs_tool_events && !model_provider.supports_streaming_tool_events() {
1001                continue;
1002            }
1003
1004            let provider_clone = provider_name.clone();
1005
1006            let current_model = self
1007                .model_chain(model)
1008                .first()
1009                .copied()
1010                .unwrap_or(model)
1011                .to_string();
1012
1013            let req = ChatRequest {
1014                messages: request.messages,
1015                tools: request.tools,
1016                thinking: request.thinking,
1017            };
1018            let stream = model_provider.stream_chat(req, &current_model, temperature, options);
1019            let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
1020
1021            tokio::spawn(async move {
1022                let mut stream = stream;
1023                while let Some(event) = stream.next().await {
1024                    if let Err(ref e) = event {
1025                        ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1026                    }
1027                    if tx.send(event).await.is_err() {
1028                        break;
1029                    }
1030                }
1031            });
1032
1033            return stream::unfold(rx, |mut rx| async move {
1034                rx.recv().await.map(|event| (event, rx))
1035            })
1036            .boxed();
1037        }
1038
1039        let message = if needs_tool_events {
1040            "No model_provider supports streaming tool events".to_string()
1041        } else {
1042            "No model_provider supports streaming".to_string()
1043        };
1044        stream::once(async move { Err(super::traits::StreamError::ModelProvider(message)) }).boxed()
1045    }
1046
1047    fn stream_chat_with_system(
1048        &self,
1049        system_prompt: Option<&str>,
1050        message: &str,
1051        model: &str,
1052        temperature: Option<f64>,
1053        options: StreamOptions,
1054    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1055        // Try each model_provider/model combination for streaming
1056        // For streaming, we use the first model_provider that supports it and has streaming enabled
1057        for (provider_name, model_provider) in &self.model_providers {
1058            if !model_provider.supports_streaming() || !options.enabled {
1059                continue;
1060            }
1061
1062            // Clone model_provider data for the stream
1063            let provider_clone = provider_name.clone();
1064
1065            // Try the first model in the chain for streaming
1066            let current_model = match self.model_chain(model).first() {
1067                Some(m) => (*m).to_string(),
1068                None => model.to_string(),
1069            };
1070
1071            // For streaming, we attempt once and propagate errors
1072            // The caller can retry the entire request if needed
1073            let stream = model_provider.stream_chat_with_system(
1074                system_prompt,
1075                message,
1076                &current_model,
1077                temperature,
1078                options,
1079            );
1080
1081            // Use a channel to bridge the stream with logging
1082            let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1083
1084            tokio::spawn(async move {
1085                let mut stream = stream;
1086                while let Some(chunk) = stream.next().await {
1087                    if let Err(ref e) = chunk {
1088                        ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1089                    }
1090                    if tx.send(chunk).await.is_err() {
1091                        break; // Receiver dropped
1092                    }
1093                }
1094            });
1095
1096            // Convert channel receiver to stream
1097            return stream::unfold(rx, |mut rx| async move {
1098                rx.recv().await.map(|chunk| (chunk, rx))
1099            })
1100            .boxed();
1101        }
1102
1103        // No streaming support available
1104        stream::once(async move {
1105            Err(super::traits::StreamError::ModelProvider(
1106                "No model_provider supports streaming".to_string(),
1107            ))
1108        })
1109        .boxed()
1110    }
1111
1112    fn stream_chat_with_history(
1113        &self,
1114        messages: &[ChatMessage],
1115        model: &str,
1116        temperature: Option<f64>,
1117        options: StreamOptions,
1118    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1119        // Try each model_provider/model combination for streaming with history.
1120        // Mirrors stream_chat_with_system but delegates to the underlying
1121        // model_provider's stream_chat_with_history, preserving the full conversation.
1122        for (provider_name, model_provider) in &self.model_providers {
1123            if !model_provider.supports_streaming() || !options.enabled {
1124                continue;
1125            }
1126
1127            let provider_clone = provider_name.clone();
1128
1129            let current_model = match self.model_chain(model).first() {
1130                Some(m) => (*m).to_string(),
1131                None => model.to_string(),
1132            };
1133
1134            let stream = model_provider.stream_chat_with_history(
1135                messages,
1136                &current_model,
1137                temperature,
1138                options,
1139            );
1140
1141            let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1142
1143            tokio::spawn(async move {
1144                let mut stream = stream;
1145                while let Some(chunk) = stream.next().await {
1146                    if let Err(ref e) = chunk {
1147                        ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1148                    }
1149                    if tx.send(chunk).await.is_err() {
1150                        break; // Receiver dropped
1151                    }
1152                }
1153            });
1154
1155            return stream::unfold(rx, |mut rx| async move {
1156                rx.recv().await.map(|chunk| (chunk, rx))
1157            })
1158            .boxed();
1159        }
1160
1161        // No streaming support available
1162        stream::once(async move {
1163            Err(super::traits::StreamError::ModelProvider(
1164                "No model_provider supports streaming".to_string(),
1165            ))
1166        })
1167        .boxed()
1168    }
1169}
1170
1171impl ::zeroclaw_api::attribution::Attributable for ReliableModelProvider {
1172    fn role(&self) -> ::zeroclaw_api::attribution::Role {
1173        ::zeroclaw_api::attribution::Role::Provider(
1174            ::zeroclaw_api::attribution::ProviderKind::Model(
1175                ::zeroclaw_api::attribution::ModelProviderKind::Reliable,
1176            ),
1177        )
1178    }
1179    fn alias(&self) -> &str {
1180        &self.alias
1181    }
1182}
1183
1184#[cfg(test)]
1185mod tests {
1186    use super::*;
1187    use futures_util::StreamExt;
1188    use std::sync::Arc;
1189    use zeroclaw_api::tool::ToolSpec;
1190
1191    struct MockModelProvider {
1192        calls: Arc<AtomicUsize>,
1193        fail_until_attempt: usize,
1194        response: &'static str,
1195        error: &'static str,
1196    }
1197
1198    #[async_trait]
1199    impl ModelProvider for MockModelProvider {
1200        async fn chat_with_system(
1201            &self,
1202            _system_prompt: Option<&str>,
1203            _message: &str,
1204            _model: &str,
1205            _temperature: Option<f64>,
1206        ) -> anyhow::Result<String> {
1207            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1208            if attempt <= self.fail_until_attempt {
1209                anyhow::bail!(self.error);
1210            }
1211            Ok(self.response.to_string())
1212        }
1213
1214        async fn chat_with_history(
1215            &self,
1216            _messages: &[ChatMessage],
1217            _model: &str,
1218            _temperature: Option<f64>,
1219        ) -> anyhow::Result<String> {
1220            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1221            if attempt <= self.fail_until_attempt {
1222                anyhow::bail!(self.error);
1223            }
1224            Ok(self.response.to_string())
1225        }
1226    }
1227    impl ::zeroclaw_api::attribution::Attributable for MockModelProvider {
1228        fn role(&self) -> ::zeroclaw_api::attribution::Role {
1229            ::zeroclaw_api::attribution::Role::Provider(
1230                ::zeroclaw_api::attribution::ProviderKind::Model(
1231                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
1232                ),
1233            )
1234        }
1235        fn alias(&self) -> &str {
1236            "MockModelProvider"
1237        }
1238    }
1239
1240    /// Mock that records which model was used for each call.
1241    struct ModelAwareMock {
1242        calls: Arc<AtomicUsize>,
1243        models_seen: parking_lot::Mutex<Vec<String>>,
1244        fail_models: Vec<&'static str>,
1245        response: &'static str,
1246    }
1247
1248    #[async_trait]
1249    impl ModelProvider for ModelAwareMock {
1250        async fn chat_with_system(
1251            &self,
1252            _system_prompt: Option<&str>,
1253            _message: &str,
1254            model: &str,
1255            _temperature: Option<f64>,
1256        ) -> anyhow::Result<String> {
1257            self.calls.fetch_add(1, Ordering::SeqCst);
1258            self.models_seen.lock().push(model.to_string());
1259            if self.fail_models.contains(&model) {
1260                anyhow::bail!("500 model {} unavailable", model);
1261            }
1262            Ok(self.response.to_string())
1263        }
1264    }
1265    impl ::zeroclaw_api::attribution::Attributable for ModelAwareMock {
1266        fn role(&self) -> ::zeroclaw_api::attribution::Role {
1267            ::zeroclaw_api::attribution::Role::Provider(
1268                ::zeroclaw_api::attribution::ProviderKind::Model(
1269                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
1270                ),
1271            )
1272        }
1273        fn alias(&self) -> &str {
1274            "ModelAwareMock"
1275        }
1276    }
1277
1278    // ── Existing tests (preserved) ──
1279
1280    #[tokio::test]
1281    async fn succeeds_without_retry() {
1282        let calls = Arc::new(AtomicUsize::new(0));
1283        let model_provider = ReliableModelProvider::new(
1284            "test",
1285            vec![(
1286                "primary".into(),
1287                Box::new(MockModelProvider {
1288                    calls: Arc::clone(&calls),
1289                    fail_until_attempt: 0,
1290                    response: "ok",
1291                    error: "boom",
1292                }),
1293            )],
1294            2,
1295            1,
1296        );
1297
1298        let result = model_provider
1299            .simple_chat("hello", "test", Some(0.0))
1300            .await
1301            .unwrap();
1302        assert_eq!(result, "ok");
1303        assert_eq!(calls.load(Ordering::SeqCst), 1);
1304    }
1305
1306    #[tokio::test]
1307    async fn retries_then_recovers() {
1308        let calls = Arc::new(AtomicUsize::new(0));
1309        let model_provider = ReliableModelProvider::new(
1310            "test",
1311            vec![(
1312                "primary".into(),
1313                Box::new(MockModelProvider {
1314                    calls: Arc::clone(&calls),
1315                    fail_until_attempt: 1,
1316                    response: "recovered",
1317                    error: "temporary",
1318                }),
1319            )],
1320            2,
1321            1,
1322        );
1323
1324        let result = model_provider
1325            .simple_chat("hello", "test", Some(0.0))
1326            .await
1327            .unwrap();
1328        assert_eq!(result, "recovered");
1329        assert_eq!(calls.load(Ordering::SeqCst), 2);
1330    }
1331
1332    #[tokio::test]
1333    async fn falls_back_after_retries_exhausted() {
1334        let primary_calls = Arc::new(AtomicUsize::new(0));
1335        let fallback_calls = Arc::new(AtomicUsize::new(0));
1336
1337        let model_provider = ReliableModelProvider::new(
1338            "test",
1339            vec![
1340                (
1341                    "primary".into(),
1342                    Box::new(MockModelProvider {
1343                        calls: Arc::clone(&primary_calls),
1344                        fail_until_attempt: usize::MAX,
1345                        response: "never",
1346                        error: "primary down",
1347                    }),
1348                ),
1349                (
1350                    "fallback".into(),
1351                    Box::new(MockModelProvider {
1352                        calls: Arc::clone(&fallback_calls),
1353                        fail_until_attempt: 0,
1354                        response: "from fallback",
1355                        error: "fallback down",
1356                    }),
1357                ),
1358            ],
1359            1,
1360            1,
1361        );
1362
1363        let result = model_provider
1364            .simple_chat("hello", "test", Some(0.0))
1365            .await
1366            .unwrap();
1367        assert_eq!(result, "from fallback");
1368        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1369        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1370    }
1371
1372    #[tokio::test]
1373    async fn returns_aggregated_error_when_all_providers_fail() {
1374        let model_provider = ReliableModelProvider::new(
1375            "test",
1376            vec![
1377                (
1378                    "p1".into(),
1379                    Box::new(MockModelProvider {
1380                        calls: Arc::new(AtomicUsize::new(0)),
1381                        fail_until_attempt: usize::MAX,
1382                        response: "never",
1383                        error: "p1 error",
1384                    }),
1385                ),
1386                (
1387                    "p2".into(),
1388                    Box::new(MockModelProvider {
1389                        calls: Arc::new(AtomicUsize::new(0)),
1390                        fail_until_attempt: usize::MAX,
1391                        response: "never",
1392                        error: "p2 error",
1393                    }),
1394                ),
1395            ],
1396            0,
1397            1,
1398        );
1399
1400        let err = model_provider
1401            .simple_chat("hello", "test", Some(0.0))
1402            .await
1403            .expect_err("all model_providers should fail");
1404        let msg = err.to_string();
1405        assert!(msg.contains("All model_providers/models failed"));
1406        assert!(msg.contains("model_provider=p1 model=test"));
1407        assert!(msg.contains("model_provider=p2 model=test"));
1408        assert!(msg.contains("error=p1 error"));
1409        assert!(msg.contains("error=p2 error"));
1410        assert!(msg.contains("retryable"));
1411    }
1412
1413    #[test]
1414    fn non_retryable_detects_common_patterns() {
1415        assert!(is_non_retryable(&anyhow::Error::msg("400 Bad Request")));
1416        assert!(is_non_retryable(&anyhow::Error::msg("401 Unauthorized")));
1417        assert!(is_non_retryable(&anyhow::Error::msg("403 Forbidden")));
1418        assert!(is_non_retryable(&anyhow::Error::msg("404 Not Found")));
1419        assert!(is_non_retryable(&anyhow::Error::msg(
1420            "invalid api key provided"
1421        )));
1422        assert!(is_non_retryable(&anyhow::Error::msg(
1423            "authentication failed"
1424        )));
1425        assert!(is_non_retryable(&anyhow::Error::msg(
1426            "model glm-4.7 not found"
1427        )));
1428        assert!(is_non_retryable(&anyhow::Error::msg(
1429            "unsupported model: glm-4.7"
1430        )));
1431        assert!(!is_non_retryable(&anyhow::Error::msg(
1432            "429 Too Many Requests"
1433        )));
1434        assert!(!is_non_retryable(&anyhow::Error::msg(
1435            "408 Request Timeout"
1436        )));
1437        assert!(!is_non_retryable(&anyhow::Error::msg(
1438            "500 Internal Server Error"
1439        )));
1440        assert!(!is_non_retryable(&anyhow::Error::msg("502 Bad Gateway")));
1441        assert!(!is_non_retryable(&anyhow::Error::msg("timeout")));
1442        assert!(!is_non_retryable(&anyhow::Error::msg("connection reset")));
1443        assert!(!is_non_retryable(&anyhow::Error::msg(
1444            "model overloaded, try again later"
1445        )));
1446        // Context window errors are now recoverable (not non-retryable)
1447        assert!(!is_non_retryable(&anyhow::Error::msg(
1448            "OpenAI Codex stream error: Your input exceeds the context window of this model."
1449        )));
1450    }
1451
1452    #[test]
1453    fn auth_error_detects_common_patterns() {
1454        assert!(is_auth_error(&anyhow::Error::msg("401 Unauthorized")));
1455        assert!(is_auth_error(&anyhow::Error::msg("403 Forbidden")));
1456        assert!(is_auth_error(&anyhow::Error::msg("invalid api key")));
1457        assert!(is_auth_error(&anyhow::Error::msg("authentication failed")));
1458        assert!(is_auth_error(&anyhow::Error::msg("token expired")));
1459        assert!(!is_auth_error(&anyhow::Error::msg("400 Bad Request")));
1460        assert!(!is_auth_error(&anyhow::Error::msg("429 Too Many Requests")));
1461        assert!(!is_auth_error(&anyhow::Error::msg("timeout")));
1462        assert!(!is_auth_error(&anyhow::Error::msg("connection reset")));
1463    }
1464
1465    #[tokio::test]
1466    async fn context_window_error_aborts_retries_and_model_fallbacks() {
1467        let calls = Arc::new(AtomicUsize::new(0));
1468        let mut model_fallbacks = std::collections::HashMap::new();
1469        model_fallbacks.insert(
1470            "gpt-5.3-codex".to_string(),
1471            vec!["gpt-5.2-codex".to_string()],
1472        );
1473
1474        let model_provider = ReliableModelProvider::new("test", vec![(
1475                "openai-codex".into(),
1476                Box::new(MockModelProvider {
1477                    calls: Arc::clone(&calls),
1478                    fail_until_attempt: usize::MAX,
1479                    response: "never",
1480                    error: "OpenAI Codex stream error: Your input exceeds the context window of this model. Please adjust your input and try again.",
1481                }),
1482            )],
1483            4,
1484            1,
1485        )
1486        .with_model_fallbacks(model_fallbacks);
1487
1488        let err = model_provider
1489            .simple_chat("hello", "gpt-5.3-codex", Some(0.0))
1490            .await
1491            .expect_err("context window overflow should fail fast");
1492        let msg = err.to_string();
1493
1494        assert!(msg.contains("context window"));
1495        // chat_with_system has no history to truncate, so it bails immediately
1496        assert_eq!(calls.load(Ordering::SeqCst), 1);
1497    }
1498
1499    #[tokio::test]
1500    async fn aggregated_error_marks_non_retryable_model_mismatch_with_details() {
1501        let calls = Arc::new(AtomicUsize::new(0));
1502        let model_provider = ReliableModelProvider::new(
1503            "test",
1504            vec![(
1505                "custom".into(),
1506                Box::new(MockModelProvider {
1507                    calls: Arc::clone(&calls),
1508                    fail_until_attempt: usize::MAX,
1509                    response: "never",
1510                    error: "unsupported model: glm-4.7",
1511                }),
1512            )],
1513            3,
1514            1,
1515        );
1516
1517        let err = model_provider
1518            .simple_chat("hello", "glm-4.7", Some(0.0))
1519            .await
1520            .expect_err("model_provider should fail");
1521        let msg = err.to_string();
1522
1523        assert!(msg.contains("non_retryable"));
1524        assert!(msg.contains("error=unsupported model: glm-4.7"));
1525        // Non-retryable errors should not consume retry budget.
1526        assert_eq!(calls.load(Ordering::SeqCst), 1);
1527    }
1528
1529    #[tokio::test]
1530    async fn skips_retries_on_non_retryable_error() {
1531        let primary_calls = Arc::new(AtomicUsize::new(0));
1532        let fallback_calls = Arc::new(AtomicUsize::new(0));
1533
1534        let model_provider = ReliableModelProvider::new(
1535            "test",
1536            vec![
1537                (
1538                    "primary".into(),
1539                    Box::new(MockModelProvider {
1540                        calls: Arc::clone(&primary_calls),
1541                        fail_until_attempt: usize::MAX,
1542                        response: "never",
1543                        error: "401 Unauthorized",
1544                    }),
1545                ),
1546                (
1547                    "fallback".into(),
1548                    Box::new(MockModelProvider {
1549                        calls: Arc::clone(&fallback_calls),
1550                        fail_until_attempt: 0,
1551                        response: "from fallback",
1552                        error: "fallback err",
1553                    }),
1554                ),
1555            ],
1556            3,
1557            1,
1558        );
1559
1560        let result = model_provider
1561            .simple_chat("hello", "test", Some(0.0))
1562            .await
1563            .unwrap();
1564        assert_eq!(result, "from fallback");
1565        // Primary should have been called only once (no retries)
1566        assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
1567        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1568    }
1569
1570    #[tokio::test]
1571    async fn chat_with_history_retries_then_recovers() {
1572        let calls = Arc::new(AtomicUsize::new(0));
1573        let model_provider = ReliableModelProvider::new(
1574            "test",
1575            vec![(
1576                "primary".into(),
1577                Box::new(MockModelProvider {
1578                    calls: Arc::clone(&calls),
1579                    fail_until_attempt: 1,
1580                    response: "history ok",
1581                    error: "temporary",
1582                }),
1583            )],
1584            2,
1585            1,
1586        );
1587
1588        let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
1589        let result = model_provider
1590            .chat_with_history(&messages, "test", Some(0.0))
1591            .await
1592            .unwrap();
1593        assert_eq!(result, "history ok");
1594        assert_eq!(calls.load(Ordering::SeqCst), 2);
1595    }
1596
1597    #[tokio::test]
1598    async fn chat_with_history_falls_back() {
1599        let primary_calls = Arc::new(AtomicUsize::new(0));
1600        let fallback_calls = Arc::new(AtomicUsize::new(0));
1601
1602        let model_provider = ReliableModelProvider::new(
1603            "test",
1604            vec![
1605                (
1606                    "primary".into(),
1607                    Box::new(MockModelProvider {
1608                        calls: Arc::clone(&primary_calls),
1609                        fail_until_attempt: usize::MAX,
1610                        response: "never",
1611                        error: "primary down",
1612                    }),
1613                ),
1614                (
1615                    "fallback".into(),
1616                    Box::new(MockModelProvider {
1617                        calls: Arc::clone(&fallback_calls),
1618                        fail_until_attempt: 0,
1619                        response: "fallback ok",
1620                        error: "fallback err",
1621                    }),
1622                ),
1623            ],
1624            1,
1625            1,
1626        );
1627
1628        let messages = vec![ChatMessage::user("hello")];
1629        let result = model_provider
1630            .chat_with_history(&messages, "test", Some(0.0))
1631            .await
1632            .unwrap();
1633        assert_eq!(result, "fallback ok");
1634        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1635        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1636    }
1637
1638    // ── New tests: model failover ──
1639
1640    #[tokio::test]
1641    async fn model_failover_tries_fallback_model() {
1642        let calls = Arc::new(AtomicUsize::new(0));
1643        let mock = Arc::new(ModelAwareMock {
1644            calls: Arc::clone(&calls),
1645            models_seen: parking_lot::Mutex::new(Vec::new()),
1646            fail_models: vec!["claude-opus"],
1647            response: "ok from sonnet",
1648        });
1649
1650        let mut fallbacks = HashMap::new();
1651        fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
1652
1653        let model_provider = ReliableModelProvider::new(
1654            "test",
1655            vec![(
1656                "anthropic".into(),
1657                Box::new(mock.clone()) as Box<dyn ModelProvider>,
1658            )],
1659            0, // no retries — force immediate model failover
1660            1,
1661        )
1662        .with_model_fallbacks(fallbacks);
1663
1664        let result = model_provider
1665            .simple_chat("hello", "claude-opus", Some(0.0))
1666            .await
1667            .unwrap();
1668        assert_eq!(result, "ok from sonnet");
1669
1670        let seen = mock.models_seen.lock();
1671        assert_eq!(seen.len(), 2);
1672        assert_eq!(seen[0], "claude-opus");
1673        assert_eq!(seen[1], "claude-sonnet");
1674    }
1675
1676    #[tokio::test]
1677    async fn model_failover_all_models_fail() {
1678        let calls = Arc::new(AtomicUsize::new(0));
1679        let mock = Arc::new(ModelAwareMock {
1680            calls: Arc::clone(&calls),
1681            models_seen: parking_lot::Mutex::new(Vec::new()),
1682            fail_models: vec!["model-a", "model-b", "model-c"],
1683            response: "never",
1684        });
1685
1686        let mut fallbacks = HashMap::new();
1687        fallbacks.insert(
1688            "model-a".to_string(),
1689            vec!["model-b".to_string(), "model-c".to_string()],
1690        );
1691
1692        let model_provider = ReliableModelProvider::new(
1693            "test",
1694            vec![(
1695                "p1".into(),
1696                Box::new(mock.clone()) as Box<dyn ModelProvider>,
1697            )],
1698            0,
1699            1,
1700        )
1701        .with_model_fallbacks(fallbacks);
1702
1703        let err = model_provider
1704            .simple_chat("hello", "model-a", Some(0.0))
1705            .await
1706            .expect_err("all models should fail");
1707        assert!(
1708            err.to_string()
1709                .contains("All model_providers/models failed")
1710        );
1711
1712        let seen = mock.models_seen.lock();
1713        assert_eq!(seen.len(), 3);
1714    }
1715
1716    #[tokio::test]
1717    async fn no_model_fallbacks_behaves_like_before() {
1718        let calls = Arc::new(AtomicUsize::new(0));
1719        let model_provider = ReliableModelProvider::new(
1720            "test",
1721            vec![(
1722                "primary".into(),
1723                Box::new(MockModelProvider {
1724                    calls: Arc::clone(&calls),
1725                    fail_until_attempt: 0,
1726                    response: "ok",
1727                    error: "boom",
1728                }),
1729            )],
1730            2,
1731            1,
1732        );
1733        // No model_fallbacks set — should work exactly as before
1734        let result = model_provider
1735            .simple_chat("hello", "test", Some(0.0))
1736            .await
1737            .unwrap();
1738        assert_eq!(result, "ok");
1739        assert_eq!(calls.load(Ordering::SeqCst), 1);
1740    }
1741
1742    // ── New tests: auth rotation ──
1743
1744    #[tokio::test]
1745    async fn auth_rotation_cycles_keys() {
1746        let model_provider = ReliableModelProvider::new(
1747            "test",
1748            vec![(
1749                "p".into(),
1750                Box::new(MockModelProvider {
1751                    calls: Arc::new(AtomicUsize::new(0)),
1752                    fail_until_attempt: 0,
1753                    response: "ok",
1754                    error: "",
1755                }),
1756            )],
1757            0,
1758            1,
1759        )
1760        .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);
1761
1762        // Rotate 5 times, verify round-robin
1763        let keys: Vec<&str> = (0..5)
1764            .map(|_| model_provider.rotate_key().unwrap())
1765            .collect();
1766        assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
1767    }
1768
1769    #[tokio::test]
1770    async fn auth_rotation_returns_none_when_empty() {
1771        let model_provider = ReliableModelProvider::new("test", vec![], 0, 1);
1772        assert!(model_provider.rotate_key().is_none());
1773    }
1774
1775    // ── New tests: Retry-After parsing ──
1776
1777    #[test]
1778    fn parse_retry_after_integer() {
1779        let err = anyhow::Error::msg("429 Too Many Requests, Retry-After: 5");
1780        assert_eq!(parse_retry_after_ms(&err), Some(5000));
1781    }
1782
1783    #[test]
1784    fn parse_retry_after_float() {
1785        let err = anyhow::Error::msg("Rate limited. retry_after: 2.5 seconds");
1786        assert_eq!(parse_retry_after_ms(&err), Some(2500));
1787    }
1788
1789    #[test]
1790    fn parse_retry_after_missing() {
1791        let err = anyhow::Error::msg("500 Internal Server Error");
1792        assert_eq!(parse_retry_after_ms(&err), None);
1793    }
1794
1795    #[test]
1796    fn rate_limited_detection() {
1797        assert!(is_rate_limited(&anyhow::Error::msg(
1798            "429 Too Many Requests"
1799        )));
1800        assert!(is_rate_limited(&anyhow::Error::msg(
1801            "HTTP 429 rate limit exceeded"
1802        )));
1803        assert!(!is_rate_limited(&anyhow::Error::msg("401 Unauthorized")));
1804        assert!(!is_rate_limited(&anyhow::Error::msg(
1805            "500 Internal Server Error"
1806        )));
1807    }
1808
1809    #[test]
1810    fn non_retryable_rate_limit_detects_plan_restricted_model() {
1811        let err = anyhow::Error::msg(
1812            "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"the current account plan does not include glm-5\"}",
1813        );
1814        assert!(
1815            is_non_retryable_rate_limit(&err),
1816            "plan-restricted 429 should skip retries"
1817        );
1818    }
1819
1820    #[test]
1821    fn non_retryable_rate_limit_detects_insufficient_balance() {
1822        let err = anyhow::Error::msg(
1823            "API error (429 Too Many Requests): {\"code\":1113,\"message\":\"insufficient balance\"}",
1824        );
1825        assert!(
1826            is_non_retryable_rate_limit(&err),
1827            "insufficient-balance 429 should skip retries"
1828        );
1829    }
1830
1831    #[test]
1832    fn non_retryable_rate_limit_does_not_flag_generic_429() {
1833        let err = anyhow::Error::msg("429 Too Many Requests: rate limit exceeded");
1834        assert!(
1835            !is_non_retryable_rate_limit(&err),
1836            "generic rate-limit 429 should remain retryable"
1837        );
1838    }
1839
1840    #[test]
1841    fn compute_backoff_uses_retry_after() {
1842        let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
1843        let err = anyhow::Error::msg("429 Retry-After: 3");
1844        assert_eq!(model_provider.compute_backoff(500, &err), 3_000);
1845    }
1846
1847    #[test]
1848    fn compute_backoff_caps_at_30s() {
1849        let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
1850        let err = anyhow::Error::msg("429 Retry-After: 120");
1851        assert_eq!(model_provider.compute_backoff(500, &err), 30_000);
1852    }
1853
1854    #[test]
1855    fn compute_backoff_falls_back_to_base() {
1856        let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
1857        let err = anyhow::Error::msg("500 Server Error");
1858        assert_eq!(model_provider.compute_backoff(500, &err), 500);
1859    }
1860
1861    // ── §2.1 API auth error (401/403) tests ──────────────────
1862
1863    #[test]
1864    fn non_retryable_detects_401() {
1865        let err = anyhow::Error::msg("API error (401 Unauthorized): invalid api key");
1866        assert!(
1867            is_non_retryable(&err),
1868            "401 errors must be detected as non-retryable"
1869        );
1870    }
1871
1872    #[test]
1873    fn non_retryable_detects_403() {
1874        let err = anyhow::Error::msg("API error (403 Forbidden): access denied");
1875        assert!(
1876            is_non_retryable(&err),
1877            "403 errors must be detected as non-retryable"
1878        );
1879    }
1880
1881    #[test]
1882    fn non_retryable_detects_404() {
1883        let err = anyhow::Error::msg("API error (404 Not Found): model not found");
1884        assert!(
1885            is_non_retryable(&err),
1886            "404 errors must be detected as non-retryable"
1887        );
1888    }
1889
1890    #[test]
1891    fn non_retryable_does_not_flag_429() {
1892        let err = anyhow::Error::msg("429 Too Many Requests");
1893        assert!(
1894            !is_non_retryable(&err),
1895            "429 must NOT be treated as non-retryable (it is retryable with backoff)"
1896        );
1897    }
1898
1899    #[test]
1900    fn non_retryable_does_not_flag_408() {
1901        let err = anyhow::Error::msg("408 Request Timeout");
1902        assert!(
1903            !is_non_retryable(&err),
1904            "408 must NOT be treated as non-retryable (it is retryable)"
1905        );
1906    }
1907
1908    #[test]
1909    fn non_retryable_does_not_flag_500() {
1910        let err = anyhow::Error::msg("500 Internal Server Error");
1911        assert!(
1912            !is_non_retryable(&err),
1913            "500 must NOT be treated as non-retryable (server errors are retryable)"
1914        );
1915    }
1916
1917    #[test]
1918    fn non_retryable_does_not_flag_502() {
1919        let err = anyhow::Error::msg("502 Bad Gateway");
1920        assert!(
1921            !is_non_retryable(&err),
1922            "502 must NOT be treated as non-retryable"
1923        );
1924    }
1925
1926    // ── §2.2 Rate limit Retry-After edge cases ───────────────
1927
1928    #[test]
1929    fn parse_retry_after_zero() {
1930        let err = anyhow::Error::msg("429 Too Many Requests, Retry-After: 0");
1931        assert_eq!(
1932            parse_retry_after_ms(&err),
1933            Some(0),
1934            "Retry-After: 0 should parse as 0ms"
1935        );
1936    }
1937
1938    #[test]
1939    fn parse_retry_after_with_underscore_separator() {
1940        let err = anyhow::Error::msg("rate limited, retry_after: 10");
1941        assert_eq!(
1942            parse_retry_after_ms(&err),
1943            Some(10_000),
1944            "retry_after with underscore must be parsed"
1945        );
1946    }
1947
1948    #[test]
1949    fn parse_retry_after_space_separator() {
1950        let err = anyhow::Error::msg("Retry-After 7");
1951        assert_eq!(
1952            parse_retry_after_ms(&err),
1953            Some(7000),
1954            "Retry-After with space separator must be parsed"
1955        );
1956    }
1957
1958    #[test]
1959    fn rate_limited_false_for_generic_error() {
1960        let err = anyhow::Error::msg("Connection refused");
1961        assert!(
1962            !is_rate_limited(&err),
1963            "generic errors must not be flagged as rate-limited"
1964        );
1965    }
1966
1967    // ── §2.3 Malformed API response error classification ─────
1968
1969    #[tokio::test]
1970    async fn non_retryable_skips_retries_for_401() {
1971        let calls = Arc::new(AtomicUsize::new(0));
1972        let model_provider = ReliableModelProvider::new(
1973            "test",
1974            vec![(
1975                "primary".into(),
1976                Box::new(MockModelProvider {
1977                    calls: Arc::clone(&calls),
1978                    fail_until_attempt: usize::MAX,
1979                    response: "never",
1980                    error: "API error (401 Unauthorized): invalid key",
1981                }),
1982            )],
1983            5,
1984            1,
1985        );
1986
1987        let result = model_provider.simple_chat("hello", "test", Some(0.0)).await;
1988        assert!(result.is_err(), "401 should fail without retries");
1989        assert_eq!(
1990            calls.load(Ordering::SeqCst),
1991            1,
1992            "must not retry on 401 — should be exactly 1 call"
1993        );
1994    }
1995
1996    #[tokio::test]
1997    async fn non_retryable_rate_limit_skips_retries_for_plan_errors() {
1998        let calls = Arc::new(AtomicUsize::new(0));
1999        let model_provider = ReliableModelProvider::new(
2000            "test",
2001            vec![(
2002                "primary".into(),
2003                Box::new(MockModelProvider {
2004                    calls: Arc::clone(&calls),
2005                    fail_until_attempt: usize::MAX,
2006                    response: "never",
2007                    error: "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"plan does not include glm-5\"}",
2008                }),
2009            )],
2010            5,
2011            1,
2012        );
2013
2014        let result = model_provider.simple_chat("hello", "test", Some(0.0)).await;
2015        assert!(
2016            result.is_err(),
2017            "plan-restricted 429 should fail quickly without retrying"
2018        );
2019        assert_eq!(
2020            calls.load(Ordering::SeqCst),
2021            1,
2022            "must not retry non-retryable 429 business errors"
2023        );
2024    }
2025
2026    // Arc<ModelAwareMock> ModelProvider impl provided by blanket impl in zeroclaw-types.
2027
2028    /// Mock model_provider that implements `chat()` with native tool support.
2029    struct NativeToolMock {
2030        calls: Arc<AtomicUsize>,
2031        fail_until_attempt: usize,
2032        response_text: &'static str,
2033        tool_calls: Vec<super::super::traits::ToolCall>,
2034        error: &'static str,
2035    }
2036
2037    #[async_trait]
2038    impl ModelProvider for NativeToolMock {
2039        async fn chat_with_system(
2040            &self,
2041            _system_prompt: Option<&str>,
2042            _message: &str,
2043            _model: &str,
2044            _temperature: Option<f64>,
2045        ) -> anyhow::Result<String> {
2046            Ok(self.response_text.to_string())
2047        }
2048
2049        fn supports_native_tools(&self) -> bool {
2050            true
2051        }
2052
2053        async fn chat(
2054            &self,
2055            _request: ChatRequest<'_>,
2056            _model: &str,
2057            _temperature: Option<f64>,
2058        ) -> anyhow::Result<ChatResponse> {
2059            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2060            if attempt <= self.fail_until_attempt {
2061                anyhow::bail!(self.error);
2062            }
2063            Ok(ChatResponse {
2064                text: Some(self.response_text.to_string()),
2065                tool_calls: self.tool_calls.clone(),
2066                usage: None,
2067                reasoning_content: None,
2068            })
2069        }
2070    }
2071    impl ::zeroclaw_api::attribution::Attributable for NativeToolMock {
2072        fn role(&self) -> ::zeroclaw_api::attribution::Role {
2073            ::zeroclaw_api::attribution::Role::Provider(
2074                ::zeroclaw_api::attribution::ProviderKind::Model(
2075                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2076                ),
2077            )
2078        }
2079        fn alias(&self) -> &str {
2080            "NativeToolMock"
2081        }
2082    }
2083
2084    #[tokio::test]
2085    async fn chat_delegates_to_inner_provider() {
2086        let calls = Arc::new(AtomicUsize::new(0));
2087        let tool_call = super::super::traits::ToolCall {
2088            id: "call_1".to_string(),
2089            name: "shell".to_string(),
2090            arguments: r#"{"command":"date"}"#.to_string(),
2091            extra_content: None,
2092        };
2093        let model_provider = ReliableModelProvider::new(
2094            "test",
2095            vec![(
2096                "primary".into(),
2097                Box::new(NativeToolMock {
2098                    calls: Arc::clone(&calls),
2099                    fail_until_attempt: 0,
2100                    response_text: "ok",
2101                    tool_calls: vec![tool_call.clone()],
2102                    error: "boom",
2103                }) as Box<dyn ModelProvider>,
2104            )],
2105            2,
2106            1,
2107        );
2108
2109        let messages = vec![ChatMessage::user("what time is it?")];
2110        let request = ChatRequest {
2111            messages: &messages,
2112            tools: None,
2113            thinking: None,
2114        };
2115        let result = model_provider
2116            .chat(request, "test-model", Some(0.0))
2117            .await
2118            .unwrap();
2119
2120        assert_eq!(result.text.as_deref(), Some("ok"));
2121        assert_eq!(result.tool_calls.len(), 1);
2122        assert_eq!(result.tool_calls[0].name, "shell");
2123        assert_eq!(calls.load(Ordering::SeqCst), 1);
2124    }
2125
2126    #[tokio::test]
2127    async fn chat_retries_and_recovers() {
2128        let calls = Arc::new(AtomicUsize::new(0));
2129        let tool_call = super::super::traits::ToolCall {
2130            id: "call_1".to_string(),
2131            name: "shell".to_string(),
2132            arguments: r#"{"command":"date"}"#.to_string(),
2133            extra_content: None,
2134        };
2135        let model_provider = ReliableModelProvider::new(
2136            "test",
2137            vec![(
2138                "primary".into(),
2139                Box::new(NativeToolMock {
2140                    calls: Arc::clone(&calls),
2141                    fail_until_attempt: 2,
2142                    response_text: "recovered",
2143                    tool_calls: vec![tool_call],
2144                    error: "temporary failure",
2145                }) as Box<dyn ModelProvider>,
2146            )],
2147            3,
2148            1,
2149        );
2150
2151        let messages = vec![ChatMessage::user("test")];
2152        let request = ChatRequest {
2153            messages: &messages,
2154            tools: None,
2155            thinking: None,
2156        };
2157        let result = model_provider
2158            .chat(request, "test-model", Some(0.0))
2159            .await
2160            .unwrap();
2161
2162        assert_eq!(result.text.as_deref(), Some("recovered"));
2163        assert!(
2164            calls.load(Ordering::SeqCst) > 1,
2165            "should have retried at least once"
2166        );
2167    }
2168
2169    #[tokio::test]
2170    async fn chat_preserves_native_tools_support() {
2171        let calls = Arc::new(AtomicUsize::new(0));
2172        let model_provider = ReliableModelProvider::new(
2173            "test",
2174            vec![(
2175                "primary".into(),
2176                Box::new(NativeToolMock {
2177                    calls: Arc::clone(&calls),
2178                    fail_until_attempt: 0,
2179                    response_text: "ok",
2180                    tool_calls: vec![],
2181                    error: "boom",
2182                }) as Box<dyn ModelProvider>,
2183            )],
2184            2,
2185            1,
2186        );
2187
2188        assert!(
2189            model_provider.supports_native_tools(),
2190            "ReliableModelProvider must propagate supports_native_tools from inner model_provider"
2191        );
2192    }
2193
2194    // ── Gap 2-4: Parity tests for chat() ────────────────────────
2195
2196    /// Gap 2: `chat()` returns an aggregated error when all model_providers fail,
2197    /// matching behavior of `returns_aggregated_error_when_all_providers_fail`.
2198    #[tokio::test]
2199    async fn chat_returns_aggregated_error_when_all_providers_fail() {
2200        let model_provider = ReliableModelProvider::new(
2201            "test",
2202            vec![
2203                (
2204                    "p1".into(),
2205                    Box::new(NativeToolMock {
2206                        calls: Arc::new(AtomicUsize::new(0)),
2207                        fail_until_attempt: usize::MAX,
2208                        response_text: "never",
2209                        tool_calls: vec![],
2210                        error: "p1 chat error",
2211                    }) as Box<dyn ModelProvider>,
2212                ),
2213                (
2214                    "p2".into(),
2215                    Box::new(NativeToolMock {
2216                        calls: Arc::new(AtomicUsize::new(0)),
2217                        fail_until_attempt: usize::MAX,
2218                        response_text: "never",
2219                        tool_calls: vec![],
2220                        error: "p2 chat error",
2221                    }) as Box<dyn ModelProvider>,
2222                ),
2223            ],
2224            0,
2225            1,
2226        );
2227
2228        let messages = vec![ChatMessage::user("hello")];
2229        let request = ChatRequest {
2230            messages: &messages,
2231            tools: None,
2232            thinking: None,
2233        };
2234        let err = model_provider
2235            .chat(request, "test", Some(0.0))
2236            .await
2237            .expect_err("all model_providers should fail");
2238        let msg = err.to_string();
2239        assert!(msg.contains("All model_providers/models failed"));
2240        assert!(msg.contains("model_provider=p1 model=test"));
2241        assert!(msg.contains("model_provider=p2 model=test"));
2242        assert!(msg.contains("error=p1 chat error"));
2243        assert!(msg.contains("error=p2 chat error"));
2244        assert!(msg.contains("retryable"));
2245    }
2246
2247    /// Mock that records model names and can fail specific models,
2248    /// implementing `chat()` for native tool calling parity tests.
2249    struct NativeModelAwareMock {
2250        calls: Arc<AtomicUsize>,
2251        models_seen: parking_lot::Mutex<Vec<String>>,
2252        fail_models: Vec<&'static str>,
2253        response_text: &'static str,
2254    }
2255
2256    #[async_trait]
2257    impl ModelProvider for NativeModelAwareMock {
2258        async fn chat_with_system(
2259            &self,
2260            _system_prompt: Option<&str>,
2261            _message: &str,
2262            _model: &str,
2263            _temperature: Option<f64>,
2264        ) -> anyhow::Result<String> {
2265            Ok(self.response_text.to_string())
2266        }
2267
2268        fn supports_native_tools(&self) -> bool {
2269            true
2270        }
2271
2272        async fn chat(
2273            &self,
2274            _request: ChatRequest<'_>,
2275            model: &str,
2276            _temperature: Option<f64>,
2277        ) -> anyhow::Result<ChatResponse> {
2278            self.calls.fetch_add(1, Ordering::SeqCst);
2279            self.models_seen.lock().push(model.to_string());
2280            if self.fail_models.contains(&model) {
2281                anyhow::bail!("500 model {} unavailable", model);
2282            }
2283            Ok(ChatResponse {
2284                text: Some(self.response_text.to_string()),
2285                tool_calls: vec![],
2286                usage: None,
2287                reasoning_content: None,
2288            })
2289        }
2290    }
2291    impl ::zeroclaw_api::attribution::Attributable for NativeModelAwareMock {
2292        fn role(&self) -> ::zeroclaw_api::attribution::Role {
2293            ::zeroclaw_api::attribution::Role::Provider(
2294                ::zeroclaw_api::attribution::ProviderKind::Model(
2295                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2296                ),
2297            )
2298        }
2299        fn alias(&self) -> &str {
2300            "NativeModelAwareMock"
2301        }
2302    }
2303
2304    // Arc<NativeModelAwareMock> ModelProvider impl provided by blanket impl in zeroclaw-types.
2305
2306    /// Gap 3: `chat()` tries fallback models on failure,
2307    /// matching behavior of `model_failover_tries_fallback_model`.
2308    #[tokio::test]
2309    async fn chat_tries_model_failover_on_failure() {
2310        let calls = Arc::new(AtomicUsize::new(0));
2311        let mock = Arc::new(NativeModelAwareMock {
2312            calls: Arc::clone(&calls),
2313            models_seen: parking_lot::Mutex::new(Vec::new()),
2314            fail_models: vec!["claude-opus"],
2315            response_text: "ok from sonnet",
2316        });
2317
2318        let mut fallbacks = HashMap::new();
2319        fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
2320
2321        let model_provider = ReliableModelProvider::new(
2322            "test",
2323            vec![(
2324                "anthropic".into(),
2325                Box::new(mock.clone()) as Box<dyn ModelProvider>,
2326            )],
2327            0, // no retries — force immediate model failover
2328            1,
2329        )
2330        .with_model_fallbacks(fallbacks);
2331
2332        let messages = vec![ChatMessage::user("hello")];
2333        let request = ChatRequest {
2334            messages: &messages,
2335            tools: None,
2336            thinking: None,
2337        };
2338        let result = model_provider
2339            .chat(request, "claude-opus", Some(0.0))
2340            .await
2341            .unwrap();
2342        assert_eq!(result.text.as_deref(), Some("ok from sonnet"));
2343
2344        let seen = mock.models_seen.lock();
2345        assert_eq!(seen.len(), 2);
2346        assert_eq!(seen[0], "claude-opus");
2347        assert_eq!(seen[1], "claude-sonnet");
2348    }
2349
2350    /// Gap 4: `chat()` skips retries on non-retryable errors (401, 403, etc.),
2351    /// matching behavior of `skips_retries_on_non_retryable_error`.
2352    #[tokio::test]
2353    async fn chat_skips_non_retryable_errors() {
2354        let primary_calls = Arc::new(AtomicUsize::new(0));
2355        let fallback_calls = Arc::new(AtomicUsize::new(0));
2356
2357        let model_provider = ReliableModelProvider::new(
2358            "test",
2359            vec![
2360                (
2361                    "primary".into(),
2362                    Box::new(NativeToolMock {
2363                        calls: Arc::clone(&primary_calls),
2364                        fail_until_attempt: usize::MAX,
2365                        response_text: "never",
2366                        tool_calls: vec![],
2367                        error: "401 Unauthorized",
2368                    }) as Box<dyn ModelProvider>,
2369                ),
2370                (
2371                    "fallback".into(),
2372                    Box::new(NativeToolMock {
2373                        calls: Arc::clone(&fallback_calls),
2374                        fail_until_attempt: 0,
2375                        response_text: "from fallback",
2376                        tool_calls: vec![],
2377                        error: "fallback err",
2378                    }) as Box<dyn ModelProvider>,
2379                ),
2380            ],
2381            3,
2382            1,
2383        );
2384
2385        let messages = vec![ChatMessage::user("hello")];
2386        let request = ChatRequest {
2387            messages: &messages,
2388            tools: None,
2389            thinking: None,
2390        };
2391        let result = model_provider
2392            .chat(request, "test", Some(0.0))
2393            .await
2394            .unwrap();
2395        assert_eq!(result.text.as_deref(), Some("from fallback"));
2396        // Primary should have been called only once (no retries)
2397        assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
2398        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
2399    }
2400
2401    // ── Context window truncation tests ─────────────────────────
2402
2403    #[test]
2404    fn context_window_error_is_not_non_retryable() {
2405        // Context window errors should be recoverable via truncation
2406        assert!(!is_non_retryable(&anyhow::Error::msg(
2407            "exceeds the context window"
2408        )));
2409        assert!(!is_non_retryable(&anyhow::Error::msg(
2410            "maximum context length exceeded"
2411        )));
2412        assert!(!is_non_retryable(&anyhow::Error::msg(
2413            "too many tokens in the request"
2414        )));
2415        assert!(!is_non_retryable(&anyhow::Error::msg(
2416            "token limit exceeded"
2417        )));
2418    }
2419
2420    #[test]
2421    fn is_context_window_exceeded_detects_llamacpp() {
2422        assert!(is_context_window_exceeded(&anyhow::Error::msg(
2423            "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2424        )));
2425    }
2426
2427    #[test]
2428    fn truncate_for_context_drops_oldest_non_system() {
2429        let mut messages = vec![
2430            ChatMessage::system("sys"),
2431            ChatMessage::user("msg1"),
2432            ChatMessage::assistant("resp1"),
2433            ChatMessage::user("msg2"),
2434            ChatMessage::assistant("resp2"),
2435            ChatMessage::user("msg3"),
2436        ];
2437
2438        let dropped = truncate_for_context(&mut messages);
2439
2440        // 5 non-system messages, drop oldest half = 2
2441        assert_eq!(dropped, 2);
2442        // System message preserved
2443        assert_eq!(messages[0].role, "system");
2444        // Remaining messages should be the newer ones
2445        assert_eq!(messages.len(), 4); // system + 3 remaining non-system
2446        // The last message should still be the most recent user message
2447        assert_eq!(messages.last().unwrap().content, "msg3");
2448    }
2449
2450    #[test]
2451    fn truncate_for_context_preserves_system_and_last_message() {
2452        // Only one non-system message: nothing to drop
2453        let mut messages = vec![ChatMessage::system("sys"), ChatMessage::user("only")];
2454        let dropped = truncate_for_context(&mut messages);
2455        assert_eq!(dropped, 0);
2456        assert_eq!(messages.len(), 2);
2457
2458        // No system message, only one user message
2459        let mut messages = vec![ChatMessage::user("only")];
2460        let dropped = truncate_for_context(&mut messages);
2461        assert_eq!(dropped, 0);
2462        assert_eq!(messages.len(), 1);
2463    }
2464
2465    /// Mock that fails with context error on first N calls, then succeeds.
2466    /// Tracks the number of messages received on each call.
2467    struct ContextOverflowMock {
2468        calls: Arc<AtomicUsize>,
2469        fail_until_attempt: usize,
2470        message_counts: parking_lot::Mutex<Vec<usize>>,
2471    }
2472
2473    #[async_trait]
2474    impl ModelProvider for ContextOverflowMock {
2475        async fn chat_with_system(
2476            &self,
2477            _system_prompt: Option<&str>,
2478            _message: &str,
2479            _model: &str,
2480            _temperature: Option<f64>,
2481        ) -> anyhow::Result<String> {
2482            Ok("ok".to_string())
2483        }
2484
2485        async fn chat_with_history(
2486            &self,
2487            messages: &[ChatMessage],
2488            _model: &str,
2489            _temperature: Option<f64>,
2490        ) -> anyhow::Result<String> {
2491            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2492            self.message_counts.lock().push(messages.len());
2493            if attempt <= self.fail_until_attempt {
2494                anyhow::bail!(
2495                    "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2496                );
2497            }
2498            Ok("recovered after truncation".to_string())
2499        }
2500    }
2501    impl ::zeroclaw_api::attribution::Attributable for ContextOverflowMock {
2502        fn role(&self) -> ::zeroclaw_api::attribution::Role {
2503            ::zeroclaw_api::attribution::Role::Provider(
2504                ::zeroclaw_api::attribution::ProviderKind::Model(
2505                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2506                ),
2507            )
2508        }
2509        fn alias(&self) -> &str {
2510            "ContextOverflowMock"
2511        }
2512    }
2513
2514    #[tokio::test]
2515    async fn chat_with_history_truncates_on_context_overflow() {
2516        let calls = Arc::new(AtomicUsize::new(0));
2517        let mock = ContextOverflowMock {
2518            calls: Arc::clone(&calls),
2519            fail_until_attempt: 1, // fail first call, succeed after truncation
2520            message_counts: parking_lot::Mutex::new(Vec::new()),
2521        };
2522
2523        let model_provider = ReliableModelProvider::new(
2524            "test",
2525            vec![("local".into(), Box::new(mock) as Box<dyn ModelProvider>)],
2526            3,
2527            1,
2528        );
2529
2530        let messages = vec![
2531            ChatMessage::system("system prompt"),
2532            ChatMessage::user("old message 1"),
2533            ChatMessage::assistant("old response 1"),
2534            ChatMessage::user("old message 2"),
2535            ChatMessage::assistant("old response 2"),
2536            ChatMessage::user("current question"),
2537        ];
2538
2539        let result = model_provider
2540            .chat_with_history(&messages, "local-model", Some(0.0))
2541            .await
2542            .unwrap();
2543        assert_eq!(result, "recovered after truncation");
2544        // Should have been called twice: once with full messages, once with truncated
2545        assert_eq!(calls.load(Ordering::SeqCst), 2);
2546    }
2547
2548    #[tokio::test]
2549    async fn context_overflow_with_no_history_to_truncate_bails_immediately() {
2550        let calls = Arc::new(AtomicUsize::new(0));
2551        let mock = ContextOverflowMock {
2552            calls: Arc::clone(&calls),
2553            fail_until_attempt: 999, // always fail
2554            message_counts: parking_lot::Mutex::new(Vec::new()),
2555        };
2556
2557        let model_provider = ReliableModelProvider::new(
2558            "test",
2559            vec![("local".into(), Box::new(mock) as Box<dyn ModelProvider>)],
2560            3,
2561            1,
2562        );
2563
2564        // Only system + one user message — nothing to truncate
2565        let messages = vec![
2566            ChatMessage::system("huge system prompt that exceeds context window"),
2567            ChatMessage::user("hello"),
2568        ];
2569
2570        let result = model_provider
2571            .chat_with_history(&messages, "local-model", Some(0.0))
2572            .await;
2573        assert!(result.is_err());
2574        let err_msg = result.unwrap_err().to_string();
2575        assert!(
2576            err_msg.contains("cannot be reduced further"),
2577            "Should bail with actionable message, got: {err_msg}"
2578        );
2579        // Should only be called once — no useless retries
2580        assert_eq!(
2581            calls.load(Ordering::SeqCst),
2582            1,
2583            "Should not retry when truncation is impossible"
2584        );
2585    }
2586
2587    // ── Tool schema error detection tests ───────────────────────────────
2588
2589    #[test]
2590    fn tool_schema_error_detects_groq_validation_failure() {
2591        let msg = r#"Groq API error (400 Bad Request): {"error":{"message":"tool call validation failed: attempted to call tool 'memory_recall' which was not in request"}}"#;
2592        let err = anyhow::Error::msg(msg.to_string());
2593        assert!(is_tool_schema_error(&err));
2594    }
2595
2596    #[test]
2597    fn tool_schema_error_detects_not_in_request() {
2598        let err = anyhow::Error::msg("tool 'search' was not in request");
2599        assert!(is_tool_schema_error(&err));
2600    }
2601
2602    #[test]
2603    fn tool_schema_error_detects_not_found_in_tool_list() {
2604        let err = anyhow::Error::msg("function 'foo' not found in tool list");
2605        assert!(is_tool_schema_error(&err));
2606    }
2607
2608    #[test]
2609    fn tool_schema_error_detects_invalid_tool_call() {
2610        let err = anyhow::Error::msg("invalid_tool_call: no matching function");
2611        assert!(is_tool_schema_error(&err));
2612    }
2613
2614    #[test]
2615    fn tool_schema_error_ignores_unrelated_errors() {
2616        let err = anyhow::Error::msg("invalid api key");
2617        assert!(!is_tool_schema_error(&err));
2618
2619        let err = anyhow::Error::msg("model not found");
2620        assert!(!is_tool_schema_error(&err));
2621    }
2622
2623    #[test]
2624    fn non_retryable_returns_false_for_tool_schema_400() {
2625        // A 400 error with tool schema validation text should NOT be non-retryable.
2626        let msg = "400 Bad Request: tool call validation failed: attempted to call tool 'x' which was not in request";
2627        let err = anyhow::Error::msg(msg.to_string());
2628        assert!(!is_non_retryable(&err));
2629    }
2630
2631    #[test]
2632    fn non_retryable_returns_true_for_other_400_errors() {
2633        // A regular 400 error (e.g. invalid API key) should still be non-retryable.
2634        let err = anyhow::Error::msg("400 Bad Request: invalid api key provided");
2635        assert!(is_non_retryable(&err));
2636    }
2637
2638    struct StreamingToolEventMock {
2639        stream_calls: Arc<AtomicUsize>,
2640        supports_tool_events: bool,
2641    }
2642
2643    impl StreamingToolEventMock {
2644        fn new(supports_tool_events: bool) -> Self {
2645            Self {
2646                stream_calls: Arc::new(AtomicUsize::new(0)),
2647                supports_tool_events,
2648            }
2649        }
2650    }
2651
2652    #[async_trait]
2653    impl ModelProvider for StreamingToolEventMock {
2654        async fn chat_with_system(
2655            &self,
2656            _system_prompt: Option<&str>,
2657            _message: &str,
2658            _model: &str,
2659            _temperature: Option<f64>,
2660        ) -> anyhow::Result<String> {
2661            Ok("ok".to_string())
2662        }
2663
2664        fn supports_streaming(&self) -> bool {
2665            true
2666        }
2667
2668        fn supports_streaming_tool_events(&self) -> bool {
2669            self.supports_tool_events
2670        }
2671
2672        fn stream_chat(
2673            &self,
2674            _request: ChatRequest<'_>,
2675            _model: &str,
2676            _temperature: Option<f64>,
2677            _options: StreamOptions,
2678        ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
2679            self.stream_calls.fetch_add(1, Ordering::SeqCst);
2680            stream::iter(vec![
2681                Ok(StreamEvent::ToolCall(super::super::traits::ToolCall {
2682                    id: "call_1".to_string(),
2683                    name: "shell".to_string(),
2684                    arguments: r#"{"command":"date"}"#.to_string(),
2685                    extra_content: None,
2686                })),
2687                Ok(StreamEvent::Final),
2688            ])
2689            .boxed()
2690        }
2691    }
2692    impl ::zeroclaw_api::attribution::Attributable for StreamingToolEventMock {
2693        fn role(&self) -> ::zeroclaw_api::attribution::Role {
2694            ::zeroclaw_api::attribution::Role::Provider(
2695                ::zeroclaw_api::attribution::ProviderKind::Model(
2696                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2697                ),
2698            )
2699        }
2700        fn alias(&self) -> &str {
2701            "StreamingToolEventMock"
2702        }
2703    }
2704
2705    // Arc<StreamingToolEventMock> ModelProvider impl provided by blanket impl in zeroclaw-types.
2706
2707    #[tokio::test]
2708    async fn stream_chat_prefers_provider_with_tool_event_support() {
2709        let primary = Arc::new(StreamingToolEventMock::new(false));
2710        let fallback = Arc::new(StreamingToolEventMock::new(true));
2711        let model_provider = ReliableModelProvider::new(
2712            "test",
2713            vec![
2714                (
2715                    "primary".into(),
2716                    Box::new(Arc::clone(&primary)) as Box<dyn ModelProvider>,
2717                ),
2718                (
2719                    "fallback".into(),
2720                    Box::new(Arc::clone(&fallback)) as Box<dyn ModelProvider>,
2721                ),
2722            ],
2723            0,
2724            1,
2725        );
2726
2727        let messages = vec![ChatMessage::user("hello")];
2728        let tools = vec![ToolSpec {
2729            name: "shell".to_string(),
2730            description: "run shell".to_string(),
2731            parameters: serde_json::json!({
2732                "type": "object",
2733                "properties": {
2734                    "command": { "type": "string" }
2735                }
2736            }),
2737        }];
2738        let mut stream = model_provider.stream_chat(
2739            ChatRequest {
2740                messages: &messages,
2741                tools: Some(&tools),
2742                thinking: None,
2743            },
2744            "model",
2745            Some(0.0),
2746            StreamOptions::new(true),
2747        );
2748
2749        let first = stream.next().await.unwrap().unwrap();
2750        let second = stream.next().await.unwrap().unwrap();
2751        assert!(stream.next().await.is_none());
2752
2753        match first {
2754            StreamEvent::ToolCall(call) => assert_eq!(call.name, "shell"),
2755            other => panic!("expected tool-call event, got {other:?}"),
2756        }
2757        assert!(matches!(second, StreamEvent::Final));
2758        assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
2759        assert_eq!(fallback.stream_calls.load(Ordering::SeqCst), 1);
2760    }
2761
2762    #[tokio::test]
2763    async fn stream_chat_errors_when_no_provider_supports_tool_events() {
2764        let primary = Arc::new(StreamingToolEventMock::new(false));
2765        let model_provider = ReliableModelProvider::new(
2766            "test",
2767            vec![(
2768                "primary".into(),
2769                Box::new(Arc::clone(&primary)) as Box<dyn ModelProvider>,
2770            )],
2771            0,
2772            1,
2773        );
2774
2775        let messages = vec![ChatMessage::user("hello")];
2776        let tools = vec![ToolSpec {
2777            name: "shell".to_string(),
2778            description: "run shell".to_string(),
2779            parameters: serde_json::json!({"type": "object"}),
2780        }];
2781        let mut stream = model_provider.stream_chat(
2782            ChatRequest {
2783                messages: &messages,
2784                tools: Some(&tools),
2785                thinking: None,
2786            },
2787            "model",
2788            Some(0.0),
2789            StreamOptions::new(true),
2790        );
2791
2792        let first = stream.next().await.unwrap();
2793        let err = first.expect_err("stream should fail without tool-event support");
2794        assert!(
2795            err.to_string()
2796                .contains("No model_provider supports streaming tool events"),
2797            "unexpected stream error: {err}"
2798        );
2799        assert!(stream.next().await.is_none());
2800        assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
2801    }
2802
2803    // ── stream_chat_with_history failover tests ──────────────────────
2804
2805    /// Mock model_provider that supports streaming via stream_chat_with_history.
2806    struct StreamingHistoryMock {
2807        stream_calls: Arc<AtomicUsize>,
2808        supports: bool,
2809    }
2810
2811    #[async_trait]
2812    impl ModelProvider for StreamingHistoryMock {
2813        async fn chat_with_system(
2814            &self,
2815            _system_prompt: Option<&str>,
2816            _message: &str,
2817            _model: &str,
2818            _temperature: Option<f64>,
2819        ) -> anyhow::Result<String> {
2820            Ok("ok".to_string())
2821        }
2822
2823        fn supports_streaming(&self) -> bool {
2824            self.supports
2825        }
2826
2827        fn stream_chat_with_history(
2828            &self,
2829            messages: &[ChatMessage],
2830            _model: &str,
2831            _temperature: Option<f64>,
2832            _options: StreamOptions,
2833        ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
2834            self.stream_calls.fetch_add(1, Ordering::SeqCst);
2835            // Echo the number of messages as the delta to verify history was passed through
2836            let msg_count = messages.len().to_string();
2837            stream::iter(vec![
2838                Ok(StreamChunk::delta(msg_count)),
2839                Ok(StreamChunk::final_chunk()),
2840            ])
2841            .boxed()
2842        }
2843    }
2844    impl ::zeroclaw_api::attribution::Attributable for StreamingHistoryMock {
2845        fn role(&self) -> ::zeroclaw_api::attribution::Role {
2846            ::zeroclaw_api::attribution::Role::Provider(
2847                ::zeroclaw_api::attribution::ProviderKind::Model(
2848                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2849                ),
2850            )
2851        }
2852        fn alias(&self) -> &str {
2853            "StreamingHistoryMock"
2854        }
2855    }
2856
2857    #[tokio::test]
2858    async fn stream_chat_with_history_delegates_to_streaming_provider() {
2859        let calls = Arc::new(AtomicUsize::new(0));
2860        let model_provider = ReliableModelProvider::new(
2861            "test",
2862            vec![(
2863                "primary".into(),
2864                Box::new(StreamingHistoryMock {
2865                    stream_calls: Arc::clone(&calls),
2866                    supports: true,
2867                }) as Box<dyn ModelProvider>,
2868            )],
2869            0,
2870            1,
2871        );
2872
2873        let messages = vec![
2874            ChatMessage::system("system"),
2875            ChatMessage::user("msg1"),
2876            ChatMessage::assistant("resp1"),
2877            ChatMessage::user("msg2"),
2878        ];
2879        let mut stream = model_provider.stream_chat_with_history(
2880            &messages,
2881            "model",
2882            Some(0.0),
2883            StreamOptions::new(true),
2884        );
2885
2886        let first = stream.next().await.unwrap().unwrap();
2887        assert_eq!(
2888            first.delta, "4",
2889            "should pass all 4 messages to model_provider"
2890        );
2891        let second = stream.next().await.unwrap().unwrap();
2892        assert!(second.is_final);
2893        assert!(stream.next().await.is_none());
2894        assert_eq!(calls.load(Ordering::SeqCst), 1);
2895    }
2896
2897    #[tokio::test]
2898    async fn stream_chat_with_history_skips_non_streaming_providers() {
2899        let non_streaming_calls = Arc::new(AtomicUsize::new(0));
2900        let streaming_calls = Arc::new(AtomicUsize::new(0));
2901
2902        let model_provider = ReliableModelProvider::new(
2903            "test",
2904            vec![
2905                (
2906                    "non-streaming".into(),
2907                    Box::new(StreamingHistoryMock {
2908                        stream_calls: Arc::clone(&non_streaming_calls),
2909                        supports: false,
2910                    }) as Box<dyn ModelProvider>,
2911                ),
2912                (
2913                    "streaming".into(),
2914                    Box::new(StreamingHistoryMock {
2915                        stream_calls: Arc::clone(&streaming_calls),
2916                        supports: true,
2917                    }) as Box<dyn ModelProvider>,
2918                ),
2919            ],
2920            0,
2921            1,
2922        );
2923
2924        let messages = vec![ChatMessage::user("hello")];
2925        let mut stream = model_provider.stream_chat_with_history(
2926            &messages,
2927            "model",
2928            Some(0.0),
2929            StreamOptions::new(true),
2930        );
2931
2932        let first = stream.next().await.unwrap().unwrap();
2933        assert_eq!(first.delta, "1");
2934        assert_eq!(
2935            non_streaming_calls.load(Ordering::SeqCst),
2936            0,
2937            "non-streaming model_provider should be skipped"
2938        );
2939        assert_eq!(
2940            streaming_calls.load(Ordering::SeqCst),
2941            1,
2942            "streaming model_provider should be used"
2943        );
2944    }
2945
2946    #[tokio::test]
2947    async fn stream_chat_with_history_errors_when_no_provider_supports_streaming() {
2948        let model_provider = ReliableModelProvider::new(
2949            "test",
2950            vec![(
2951                "non-streaming".into(),
2952                Box::new(StreamingHistoryMock {
2953                    stream_calls: Arc::new(AtomicUsize::new(0)),
2954                    supports: false,
2955                }) as Box<dyn ModelProvider>,
2956            )],
2957            0,
2958            1,
2959        );
2960
2961        let messages = vec![ChatMessage::user("hello")];
2962        let mut stream = model_provider.stream_chat_with_history(
2963            &messages,
2964            "model",
2965            Some(0.0),
2966            StreamOptions::new(true),
2967        );
2968
2969        let first = stream.next().await.unwrap();
2970        let err = first.expect_err("should fail when no model_provider supports streaming");
2971        assert!(
2972            err.to_string()
2973                .contains("No model_provider supports streaming"),
2974            "unexpected error: {err}"
2975        );
2976        assert!(stream.next().await.is_none());
2977    }
2978
2979    #[tokio::test]
2980    async fn fallback_records_provider_fallback_info() {
2981        scope_provider_fallback(async {
2982            let model_provider = ReliableModelProvider::new(
2983                "test",
2984                vec![
2985                    (
2986                        "broken".into(),
2987                        Box::new(MockModelProvider {
2988                            calls: Arc::new(AtomicUsize::new(0)),
2989                            fail_until_attempt: 99, // always fail
2990                            response: "unused",
2991                            error: "401 Unauthorized",
2992                        }),
2993                    ),
2994                    (
2995                        "working".into(),
2996                        Box::new(MockModelProvider {
2997                            calls: Arc::new(AtomicUsize::new(0)),
2998                            fail_until_attempt: 0,
2999                            response: "hello from working",
3000                            error: "unused",
3001                        }),
3002                    ),
3003                ],
3004                2,
3005                1,
3006            );
3007
3008            let resp = model_provider
3009                .simple_chat("hi", "test-model", Some(0.0))
3010                .await
3011                .unwrap();
3012            assert_eq!(resp, "hello from working");
3013
3014            let fb = take_last_provider_fallback();
3015            assert!(fb.is_some(), "fallback info should be recorded");
3016            let fb = fb.unwrap();
3017            assert_eq!(fb.requested_provider, "broken");
3018            assert_eq!(fb.actual_provider, "working");
3019            assert_eq!(fb.actual_model, "test-model");
3020
3021            // Second take should be None.
3022            assert!(take_last_provider_fallback().is_none());
3023        })
3024        .await;
3025    }
3026
3027    // Regression for #6589: ReliableModelProvider::supports_vision() must reflect the
3028    // primary (first) provider, not .any() across the fallback chain. This mirrors
3029    // supports_native_tools() which already uses .first().
3030    #[test]
3031    fn supports_vision_reflects_first_provider_not_any_fallback() {
3032        struct VisionMock(bool);
3033
3034        #[async_trait]
3035        impl ModelProvider for VisionMock {
3036            async fn chat_with_system(
3037                &self,
3038                _system_prompt: Option<&str>,
3039                _message: &str,
3040                _model: &str,
3041                _temperature: Option<f64>,
3042            ) -> anyhow::Result<String> {
3043                Ok(String::new())
3044            }
3045
3046            fn supports_vision(&self) -> bool {
3047                self.0
3048            }
3049        }
3050        impl ::zeroclaw_api::attribution::Attributable for VisionMock {
3051            fn role(&self) -> ::zeroclaw_api::attribution::Role {
3052                ::zeroclaw_api::attribution::Role::Provider(
3053                    ::zeroclaw_api::attribution::ProviderKind::Model(
3054                        ::zeroclaw_api::attribution::ModelProviderKind::Custom,
3055                    ),
3056                )
3057            }
3058            fn alias(&self) -> &str {
3059                "VisionMock"
3060            }
3061        }
3062
3063        let provider = ReliableModelProvider::new(
3064            "test",
3065            vec![
3066                (
3067                    "primary".into(),
3068                    Box::new(VisionMock(false)) as Box<dyn ModelProvider>,
3069                ),
3070                (
3071                    "fallback".into(),
3072                    Box::new(VisionMock(true)) as Box<dyn ModelProvider>,
3073                ),
3074            ],
3075            0,
3076            0,
3077        );
3078
3079        assert!(
3080            !provider.supports_vision(),
3081            "ReliableModelProvider with non-vision primary must report supports_vision()=false even when a fallback supports vision"
3082        );
3083
3084        let provider = ReliableModelProvider::new(
3085            "test",
3086            vec![
3087                (
3088                    "primary".into(),
3089                    Box::new(VisionMock(true)) as Box<dyn ModelProvider>,
3090                ),
3091                (
3092                    "fallback".into(),
3093                    Box::new(VisionMock(false)) as Box<dyn ModelProvider>,
3094                ),
3095            ],
3096            0,
3097            0,
3098        );
3099
3100        assert!(provider.supports_vision());
3101    }
3102}