1use super::ModelProvider;
2use super::traits::{
3 ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
4};
5use async_trait::async_trait;
6use futures_util::{StreamExt, stream};
7use std::cell::RefCell;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::Duration;
11
/// Snapshot of which provider/model actually served a request versus what was
/// asked for.
///
/// Written by `record_provider_fallback` and read back (once) through
/// `take_last_provider_fallback`.
#[derive(Debug, Clone)]
pub struct ProviderFallbackInfo {
    /// Provider name recorded as the one originally requested.
    pub requested_provider: String,
    /// Model name recorded as the one originally requested.
    pub requested_model: String,
    /// Provider name that produced the successful response.
    pub actual_provider: String,
    /// Model name that produced the successful response.
    pub actual_model: String,
}
30
// Task-local slot holding the most recently recorded fallback info, if any.
// It only exists inside a future wrapped by `scope_provider_fallback`; the
// accessors in this file use `try_with`, so they are no-ops outside a scope.
tokio::task_local! {
    static PROVIDER_FALLBACK: RefCell<Option<ProviderFallbackInfo>>;
}
34
35pub fn take_last_provider_fallback() -> Option<ProviderFallbackInfo> {
38 PROVIDER_FALLBACK
39 .try_with(|cell| cell.borrow_mut().take())
40 .ok()
41 .flatten()
42}
43
44pub async fn scope_provider_fallback<F: std::future::Future>(future: F) -> F::Output {
49 PROVIDER_FALLBACK.scope(RefCell::new(None), future).await
50}
51
52fn record_provider_fallback(
54 requested_provider: &str,
55 requested_model: &str,
56 actual_provider: &str,
57 actual_model: &str,
58) {
59 let _ = PROVIDER_FALLBACK.try_with(|cell| {
60 *cell.borrow_mut() = Some(ProviderFallbackInfo {
61 requested_provider: requested_provider.to_string(),
62 requested_model: requested_model.to_string(),
63 actual_provider: actual_provider.to_string(),
64 actual_model: actual_model.to_string(),
65 });
66 });
67}
68
69pub fn is_non_retryable(err: &anyhow::Error) -> bool {
77 if is_context_window_exceeded(err) {
80 return false;
81 }
82
83 if is_tool_schema_error(err) {
87 return false;
88 }
89
90 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
93 && let Some(status) = reqwest_err.status()
94 {
95 let code = status.as_u16();
96 return status.is_client_error() && code != 429 && code != 408;
97 }
98 let msg = err.to_string();
101 for word in msg.split(|c: char| !c.is_ascii_digit()) {
102 if let Ok(code) = word.parse::<u16>()
103 && (400..500).contains(&code)
104 {
105 return code != 429 && code != 408;
106 }
107 }
108
109 let msg_lower = msg.to_lowercase();
112 let auth_failure_hints = [
113 "invalid api key",
114 "incorrect api key",
115 "missing api key",
116 "api key not set",
117 "authentication failed",
118 "auth failed",
119 "unauthorized",
120 "forbidden",
121 "permission denied",
122 "access denied",
123 "invalid token",
124 ];
125
126 if auth_failure_hints
127 .iter()
128 .any(|hint| msg_lower.contains(hint))
129 {
130 return true;
131 }
132
133 msg_lower.contains("model")
134 && (msg_lower.contains("not found")
135 || msg_lower.contains("unknown")
136 || msg_lower.contains("unsupported")
137 || msg_lower.contains("does not exist")
138 || msg_lower.contains("invalid"))
139}
140
141pub fn is_auth_error(err: &anyhow::Error) -> bool {
145 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
146 && let Some(status) = reqwest_err.status()
147 {
148 let code = status.as_u16();
149 return code == 401 || code == 403;
150 }
151
152 let msg_lower = err.to_string().to_lowercase();
153 let hints = [
154 "401 unauthorized",
155 "403 forbidden",
156 "invalid api key",
157 "incorrect api key",
158 "authentication failed",
159 "auth failed",
160 "unauthorized",
161 "invalid token",
162 "token expired",
163 "access_token",
164 ];
165
166 hints.iter().any(|hint| msg_lower.contains(hint))
167}
168
169pub fn is_tool_schema_error(err: &anyhow::Error) -> bool {
175 let lower = err.to_string().to_lowercase();
176 let hints = [
177 "tool call validation failed",
178 "was not in request",
179 "not found in tool list",
180 "invalid_tool_call",
181 ];
182 hints.iter().any(|hint| lower.contains(hint))
183}
184
185pub fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
186 let lower = err.to_string().to_lowercase();
187 let hints = [
188 "exceeds the context window",
189 "exceeds the available context size",
190 "context window of this model",
191 "maximum context length",
192 "context length exceeded",
193 "too many tokens",
194 "token limit exceeded",
195 "prompt is too long",
196 "input is too long",
197 "prompt exceeds max length",
198 ];
199
200 hints.iter().any(|hint| lower.contains(hint))
201}
202
203fn is_rate_limited(err: &anyhow::Error) -> bool {
205 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
206 && let Some(status) = reqwest_err.status()
207 {
208 return status.as_u16() == 429;
209 }
210 let msg = err.to_string();
211 msg.contains("429")
212 && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
213}
214
215fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool {
222 if !is_rate_limited(err) {
223 return false;
224 }
225
226 let msg = err.to_string();
227 let lower = msg.to_lowercase();
228
229 let business_hints = [
230 "plan does not include",
231 "doesn't include",
232 "not include",
233 "insufficient balance",
234 "insufficient_balance",
235 "insufficient quota",
236 "insufficient_quota",
237 "quota exhausted",
238 "out of credits",
239 "no available package",
240 "package not active",
241 "purchase package",
242 "model not available for your plan",
243 ];
244
245 if business_hints.iter().any(|hint| lower.contains(hint)) {
246 return true;
247 }
248
249 for token in lower.split(|c: char| !c.is_ascii_digit()) {
251 if let Ok(code) = token.parse::<u16>()
252 && matches!(code, 1113 | 1311)
253 {
254 return true;
255 }
256 }
257
258 false
259}
260
261fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
264 let msg = err.to_string();
265 let lower = msg.to_lowercase();
266
267 for prefix in &[
269 "retry-after:",
270 "retry_after:",
271 "retry-after ",
272 "retry_after ",
273 ] {
274 if let Some(pos) = lower.find(prefix) {
275 let after = &msg[pos + prefix.len()..];
276 let num_str: String = after
277 .trim()
278 .chars()
279 .take_while(|c| c.is_ascii_digit() || *c == '.')
280 .collect();
281 if let Ok(secs) = num_str.parse::<f64>()
282 && secs.is_finite()
283 && secs >= 0.0
284 {
285 let millis = Duration::from_secs_f64(secs).as_millis();
286 if let Ok(value) = u64::try_from(millis) {
287 return Some(value);
288 }
289 }
290 }
291 }
292 None
293}
294
/// Maps the two failure classification flags to the label used in failure
/// records and log attributes.
fn failure_reason(rate_limited: bool, non_retryable: bool) -> &'static str {
    match (rate_limited, non_retryable) {
        (true, true) => "rate_limited_non_retryable",
        (true, false) => "rate_limited",
        (false, true) => "non_retryable",
        (false, false) => "retryable",
    }
}
306
307fn compact_error_detail(err: &anyhow::Error) -> String {
308 super::sanitize_api_error(&format!("{err:#}"))
309 .split_whitespace()
310 .collect::<Vec<_>>()
311 .join(" ")
312}
313
314fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize {
318 let non_system: Vec<usize> = messages
320 .iter()
321 .enumerate()
322 .filter(|(_, m)| m.role != "system")
323 .map(|(i, _)| i)
324 .collect();
325
326 if non_system.len() <= 1 {
328 return 0;
329 }
330
331 let drop_count = non_system.len() / 2;
333 let indices_to_remove: Vec<usize> = non_system[..drop_count].to_vec();
334
335 for &idx in indices_to_remove.iter().rev() {
337 messages.remove(idx);
338 }
339
340 drop_count
341}
342
/// Appends one formatted attempt record to `failures`.
///
/// The record names the provider, model, attempt counter (`attempt` of
/// `max_attempts`), the classification `reason` and the error text.
fn push_failure(
    failures: &mut Vec<String>,
    provider_name: &str,
    model: &str,
    attempt: u32,
    max_attempts: u32,
    reason: &str,
    error_detail: &str,
) {
    let entry = format!(
        "model_provider={provider_name} model={model} attempt {attempt}/{max_attempts}: {reason}; error={error_detail}"
    );
    failures.push(entry);
}
356
/// A `ModelProvider` wrapper that adds retries with backoff, provider
/// failover and model fallback on top of a list of inner providers.
pub struct ReliableModelProvider {
    /// Alias passed to `new`. NOTE(review): not read by any code visible in
    /// this part of the file.
    alias: String,
    /// Named inner providers, tried in order; the first entry is treated as
    /// the primary one when reporting fallbacks.
    model_providers: Vec<(String, Box<dyn ModelProvider>)>,
    /// Number of retries per provider/model pair (so `max_retries + 1` attempts).
    max_retries: u32,
    /// Initial backoff in milliseconds; `new` raises it to at least 50.
    base_backoff_ms: u64,
    /// Extra API keys cycled through by `rotate_key`.
    api_keys: Vec<String>,
    /// Monotonic counter used to pick the next entry of `api_keys`.
    key_index: AtomicUsize,
    /// For a requested model name, the fallback model names to try after it.
    model_fallbacks: HashMap<String, Vec<String>>,
}
383
384impl ReliableModelProvider {
385 pub fn new(
386 alias: &str,
387 model_providers: Vec<(String, Box<dyn ModelProvider>)>,
388 max_retries: u32,
389 base_backoff_ms: u64,
390 ) -> Self {
391 Self {
392 alias: alias.to_string(),
393 model_providers,
394 max_retries,
395 base_backoff_ms: base_backoff_ms.max(50),
396 api_keys: Vec::new(),
397 key_index: AtomicUsize::new(0),
398 model_fallbacks: HashMap::new(),
399 }
400 }
401 pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
403 self.api_keys = keys;
404 self
405 }
406
407 #[cfg(test)]
410 pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
411 self.model_fallbacks = fallbacks;
412 self
413 }
414
415 fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
417 let mut chain = vec![model];
418 if let Some(fallbacks) = self.model_fallbacks.get(model) {
419 chain.extend(fallbacks.iter().map(|s| s.as_str()));
420 }
421 chain
422 }
423
424 fn rotate_key(&self) -> Option<&str> {
426 if self.api_keys.is_empty() {
427 return None;
428 }
429 let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
430 Some(&self.api_keys[idx])
431 }
432
433 fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
435 if let Some(retry_after) = parse_retry_after_ms(err) {
436 retry_after.min(30_000).max(base)
438 } else {
439 base
440 }
441 }
442}
443
444#[async_trait]
445impl ModelProvider for ReliableModelProvider {
446 async fn warmup(&self) -> anyhow::Result<()> {
447 for (name, model_provider) in &self.model_providers {
448 ::zeroclaw_log::record!(
449 INFO,
450 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
451 .with_attrs(::serde_json::json!({"model_provider": name})),
452 "Warming up model_provider connection pool"
453 );
454 if model_provider.warmup().await.is_err() {
455 ::zeroclaw_log::record!(
456 WARN,
457 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
458 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
459 .with_attrs(::serde_json::json!({"model_provider": name})),
460 "Warmup failed (non-fatal)"
461 );
462 }
463 }
464 Ok(())
465 }
466
    /// Sends a single-message chat request with retries, provider failover
    /// and model fallback.
    ///
    /// For each model in `model_chain(model)`, every inner provider is tried in
    /// order, with up to `max_retries + 1` attempts each. The first successful
    /// response is returned.
    ///
    /// # Errors
    /// - Returns immediately when a provider reports a context-window error.
    /// - Otherwise returns an error listing every recorded attempt once all
    ///   models and providers have failed.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: Option<f64>,
    ) -> anyhow::Result<String> {
        let models = self.model_chain(model);
        let mut failures = Vec::new();

        for current_model in &models {
            for (provider_name, model_provider) in &self.model_providers {
                // Backoff restarts from the base value for each provider.
                let mut backoff_ms = self.base_backoff_ms;

                for attempt in 0..=self.max_retries {
                    match model_provider
                        .chat_with_system(system_prompt, message, current_model, temperature)
                        .await
                    {
                        Ok(resp) => {
                            // Anything other than a first-attempt success on the first
                            // provider with the requested model is logged and recorded
                            // as a recovery.
                            if attempt > 0
                                || *current_model != model
                                || self.model_providers.first().map(|(n, _)| n.as_str())
                                    != Some(provider_name)
                            {
                                ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model})), "ModelProvider recovered (failover/retry)");
                                let primary = self
                                    .model_providers
                                    .first()
                                    .map(|(n, _)| n.as_str())
                                    .unwrap_or("");
                                record_provider_fallback(
                                    primary,
                                    model,
                                    provider_name,
                                    current_model,
                                );
                            }
                            return Ok(resp);
                        }
                        Err(e) => {
                            // This method has no message history to shrink, so a
                            // context-window error aborts the whole call.
                            if is_context_window_exceeded(&e) {
                                let error_detail = compact_error_detail(&e);
                                push_failure(
                                    &mut failures,
                                    provider_name,
                                    current_model,
                                    attempt + 1,
                                    self.max_retries + 1,
                                    "non_retryable",
                                    &error_detail,
                                );
                                anyhow::bail!(
                                    "Request exceeds model context window. Attempts:\n{}",
                                    failures.join("\n")
                                );
                            }

                            // Classify the failure and record it before deciding what to do.
                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
                            let rate_limited = is_rate_limited(&e);
                            let failure_reason = failure_reason(rate_limited, non_retryable);
                            let error_detail = compact_error_detail(&e);

                            push_failure(
                                &mut failures,
                                provider_name,
                                current_model,
                                attempt + 1,
                                self.max_retries + 1,
                                failure_reason,
                                &error_detail,
                            );

                            // The rotated key is only reported in the log; it is not
                            // handed to the inner provider (see the message text).
                            if rate_limited
                                && !non_retryable_rate_limit
                                && let Some(new_key) = self.rotate_key()
                            {
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
                                    but cannot apply (ModelProvider trait has no set_api_key). \
                                    Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
                            }

                            // Give up on this provider for the current model.
                            if non_retryable {
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
                                break;
                            }

                            // Sleep only when another attempt follows; the stored
                            // backoff doubles each time, capped at 10 seconds.
                            if attempt < self.max_retries {
                                let wait = self.compute_backoff(backoff_ms, &e);
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
                                tokio::time::sleep(Duration::from_millis(wait)).await;
                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                            }
                        }
                    }
                }

                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
            }

            // Only logged for fallback models, not for the originally requested one.
            if *current_model != model {
                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"original_model": model, "fallback_model": *current_model})), "Model fallback exhausted all model_providers, trying next fallback model");
            }
        }

        anyhow::bail!(
            "All model_providers/models failed. Attempts:\n{}",
            failures.join("\n")
        )
    }
586
    /// Sends a chat request built from a message history, with retries,
    /// provider failover, model fallback and one-time history truncation.
    ///
    /// For each model in `model_chain(model)`, every inner provider is tried in
    /// order, with up to `max_retries + 1` attempts each. On the first
    /// context-window error, the history copy is shortened with
    /// `truncate_for_context` and the loop moves on to the next attempt.
    ///
    /// # Errors
    /// - Returns immediately when a context-window error occurs before any
    ///   truncation and nothing can be dropped from the history.
    /// - Otherwise returns an error listing every recorded attempt once all
    ///   models and providers have failed.
    async fn chat_with_history(
        &self,
        messages: &[ChatMessage],
        model: &str,
        temperature: Option<f64>,
    ) -> anyhow::Result<String> {
        let models = self.model_chain(model);
        let mut failures = Vec::new();
        // Working copy of the history; the caller's slice is never modified.
        let mut effective_messages = messages.to_vec();
        // Set once and never reset, so truncation happens at most once per call.
        let mut context_truncated = false;

        for current_model in &models {
            for (provider_name, model_provider) in &self.model_providers {
                // Backoff restarts from the base value for each provider.
                let mut backoff_ms = self.base_backoff_ms;

                for attempt in 0..=self.max_retries {
                    match model_provider
                        .chat_with_history(&effective_messages, current_model, temperature)
                        .await
                    {
                        Ok(resp) => {
                            // Anything other than a first-attempt, untruncated success on
                            // the first provider with the requested model is logged and
                            // recorded as a recovery.
                            if attempt > 0
                                || *current_model != model
                                || context_truncated
                                || self.model_providers.first().map(|(n, _)| n.as_str())
                                    != Some(provider_name)
                            {
                                ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
                                let primary = self
                                    .model_providers
                                    .first()
                                    .map(|(n, _)| n.as_str())
                                    .unwrap_or("");
                                record_provider_fallback(
                                    primary,
                                    model,
                                    provider_name,
                                    current_model,
                                );
                            }
                            return Ok(resp);
                        }
                        Err(e) => {
                            // NOTE(review): once `context_truncated` is set, later
                            // context-window errors skip this branch and take the generic
                            // path below, where `is_non_retryable` reports them as retryable.
                            if is_context_window_exceeded(&e) && !context_truncated {
                                let dropped = truncate_for_context(&mut effective_messages);
                                if dropped > 0 {
                                    context_truncated = true;
                                    ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
                                    // Uses up this attempt: no failure is recorded and no
                                    // backoff sleep happens before the next one.
                                    continue;
                                }
                                // Nothing could be dropped: report and abort the whole call.
                                let error_detail = compact_error_detail(&e);
                                push_failure(
                                    &mut failures,
                                    provider_name,
                                    current_model,
                                    attempt + 1,
                                    self.max_retries + 1,
                                    "non_retryable",
                                    &error_detail,
                                );
                                anyhow::bail!(
                                    "Request exceeds model context window and cannot be reduced further. \
                                    Try using a model with a larger context window, reducing the number \
                                    of tools/skills, or enabling compact_context in config. Attempts:\n{}",
                                    failures.join("\n")
                                );
                            }

                            // Classify the failure and record it before deciding what to do.
                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
                            let rate_limited = is_rate_limited(&e);
                            let failure_reason = failure_reason(rate_limited, non_retryable);
                            let error_detail = compact_error_detail(&e);

                            push_failure(
                                &mut failures,
                                provider_name,
                                current_model,
                                attempt + 1,
                                self.max_retries + 1,
                                failure_reason,
                                &error_detail,
                            );

                            // The rotated key is only reported in the log; it is not
                            // handed to the inner provider (see the message text).
                            if rate_limited
                                && !non_retryable_rate_limit
                                && let Some(new_key) = self.rotate_key()
                            {
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
                                    but cannot apply (ModelProvider trait has no set_api_key). \
                                    Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
                            }

                            // Give up on this provider for the current model.
                            if non_retryable {
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
                                break;
                            }

                            // Sleep only when another attempt follows; the stored
                            // backoff doubles each time, capped at 10 seconds.
                            if attempt < self.max_retries {
                                let wait = self.compute_backoff(backoff_ms, &e);
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
                                tokio::time::sleep(Duration::from_millis(wait)).await;
                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                            }
                        }
                    }
                }

                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
            }
        }

        anyhow::bail!(
            "All model_providers/models failed. Attempts:\n{}",
            failures.join("\n")
        )
    }
708
709 fn supports_native_tools(&self) -> bool {
710 self.model_providers
711 .first()
712 .map(|(_, p)| p.supports_native_tools())
713 .unwrap_or(false)
714 }
715
716 fn supports_vision(&self) -> bool {
717 self.model_providers
718 .first()
719 .map(|(_, p)| p.supports_vision())
720 .unwrap_or(false)
721 }
722
    /// Sends a tool-enabled chat request with retries, provider failover,
    /// model fallback and one-time history truncation.
    ///
    /// For each model in `model_chain(model)`, every inner provider is tried in
    /// order, with up to `max_retries + 1` attempts each. On the first
    /// context-window error, the history copy is shortened with
    /// `truncate_for_context` and the loop moves on to the next attempt.
    /// `tools` is passed through unchanged on every attempt.
    ///
    /// # Errors
    /// - Returns immediately when a context-window error occurs before any
    ///   truncation and nothing can be dropped from the history.
    /// - Otherwise returns an error listing every recorded attempt once all
    ///   models and providers have failed.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[serde_json::Value],
        model: &str,
        temperature: Option<f64>,
    ) -> anyhow::Result<ChatResponse> {
        let models = self.model_chain(model);
        let mut failures = Vec::new();
        // Working copy of the history; the caller's slice is never modified.
        let mut effective_messages = messages.to_vec();
        // Set once and never reset, so truncation happens at most once per call.
        let mut context_truncated = false;

        for current_model in &models {
            for (provider_name, model_provider) in &self.model_providers {
                // Backoff restarts from the base value for each provider.
                let mut backoff_ms = self.base_backoff_ms;

                for attempt in 0..=self.max_retries {
                    match model_provider
                        .chat_with_tools(&effective_messages, tools, current_model, temperature)
                        .await
                    {
                        Ok(resp) => {
                            // Anything other than a first-attempt, untruncated success on
                            // the first provider with the requested model is logged and
                            // recorded as a recovery.
                            if attempt > 0
                                || *current_model != model
                                || context_truncated
                                || self.model_providers.first().map(|(n, _)| n.as_str())
                                    != Some(provider_name)
                            {
                                ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
                                let primary = self
                                    .model_providers
                                    .first()
                                    .map(|(n, _)| n.as_str())
                                    .unwrap_or("");
                                record_provider_fallback(
                                    primary,
                                    model,
                                    provider_name,
                                    current_model,
                                );
                            }
                            return Ok(resp);
                        }
                        Err(e) => {
                            // NOTE(review): once `context_truncated` is set, later
                            // context-window errors skip this branch and take the generic
                            // path below, where `is_non_retryable` reports them as retryable.
                            if is_context_window_exceeded(&e) && !context_truncated {
                                let dropped = truncate_for_context(&mut effective_messages);
                                if dropped > 0 {
                                    context_truncated = true;
                                    ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
                                    // Uses up this attempt: no failure is recorded and no
                                    // backoff sleep happens before the next one.
                                    continue;
                                }
                                // Nothing could be dropped: report and abort the whole call.
                                let error_detail = compact_error_detail(&e);
                                push_failure(
                                    &mut failures,
                                    provider_name,
                                    current_model,
                                    attempt + 1,
                                    self.max_retries + 1,
                                    "non_retryable",
                                    &error_detail,
                                );
                                anyhow::bail!(
                                    "Request exceeds model context window and cannot be reduced further. \
                                    Try using a model with a larger context window, reducing the number \
                                    of tools/skills, or enabling compact_context in config. Attempts:\n{}",
                                    failures.join("\n")
                                );
                            }

                            // Classify the failure and record it before deciding what to do.
                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
                            let rate_limited = is_rate_limited(&e);
                            let failure_reason = failure_reason(rate_limited, non_retryable);
                            let error_detail = compact_error_detail(&e);

                            push_failure(
                                &mut failures,
                                provider_name,
                                current_model,
                                attempt + 1,
                                self.max_retries + 1,
                                failure_reason,
                                &error_detail,
                            );

                            // The rotated key is only reported in the log; it is not
                            // handed to the inner provider (see the message text).
                            if rate_limited
                                && !non_retryable_rate_limit
                                && let Some(new_key) = self.rotate_key()
                            {
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
                                    but cannot apply (ModelProvider trait has no set_api_key). \
                                    Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
                            }

                            // Give up on this provider for the current model.
                            if non_retryable {
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
                                break;
                            }

                            // Sleep only when another attempt follows; the stored
                            // backoff doubles each time, capped at 10 seconds.
                            if attempt < self.max_retries {
                                let wait = self.compute_backoff(backoff_ms, &e);
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
                                tokio::time::sleep(Duration::from_millis(wait)).await;
                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                            }
                        }
                    }
                }

                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
            }
        }

        anyhow::bail!(
            "All model_providers/models failed. Attempts:\n{}",
            failures.join("\n")
        )
    }
845
    /// Sends a structured `ChatRequest` with retries, provider failover, model
    /// fallback and one-time history truncation.
    ///
    /// For each model in `model_chain(model)`, every inner provider is tried in
    /// order, with up to `max_retries + 1` attempts each. A fresh `ChatRequest`
    /// is built for every attempt from the current (possibly truncated) copy of
    /// the messages plus the caller's `tools` and `thinking` values.
    ///
    /// # Errors
    /// - Returns immediately when a context-window error occurs before any
    ///   truncation and nothing can be dropped from the history.
    /// - Otherwise returns an error listing every recorded attempt once all
    ///   models and providers have failed.
    async fn chat(
        &self,
        request: ChatRequest<'_>,
        model: &str,
        temperature: Option<f64>,
    ) -> anyhow::Result<ChatResponse> {
        let models = self.model_chain(model);
        let mut failures = Vec::new();
        // Working copy of the history; the caller's messages are never modified.
        let mut effective_messages = request.messages.to_vec();
        // Set once and never reset, so truncation happens at most once per call.
        let mut context_truncated = false;

        for current_model in &models {
            for (provider_name, model_provider) in &self.model_providers {
                // Backoff restarts from the base value for each provider.
                let mut backoff_ms = self.base_backoff_ms;

                for attempt in 0..=self.max_retries {
                    // Rebuilt on every attempt so it borrows the current message copy.
                    let req = ChatRequest {
                        messages: &effective_messages,
                        tools: request.tools,
                        thinking: request.thinking,
                    };
                    match model_provider.chat(req, current_model, temperature).await {
                        Ok(resp) => {
                            // Anything other than a first-attempt, untruncated success on
                            // the first provider with the requested model is logged and
                            // recorded as a recovery.
                            if attempt > 0
                                || *current_model != model
                                || context_truncated
                                || self.model_providers.first().map(|(n, _)| n.as_str())
                                    != Some(provider_name)
                            {
                                ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
                                let primary = self
                                    .model_providers
                                    .first()
                                    .map(|(n, _)| n.as_str())
                                    .unwrap_or("");
                                record_provider_fallback(
                                    primary,
                                    model,
                                    provider_name,
                                    current_model,
                                );
                            }
                            return Ok(resp);
                        }
                        Err(e) => {
                            // NOTE(review): once `context_truncated` is set, later
                            // context-window errors skip this branch and take the generic
                            // path below, where `is_non_retryable` reports them as retryable.
                            if is_context_window_exceeded(&e) && !context_truncated {
                                let dropped = truncate_for_context(&mut effective_messages);
                                if dropped > 0 {
                                    context_truncated = true;
                                    ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
                                    // Uses up this attempt: no failure is recorded and no
                                    // backoff sleep happens before the next one.
                                    continue;
                                }
                                // Nothing could be dropped: report and abort the whole call.
                                let error_detail = compact_error_detail(&e);
                                push_failure(
                                    &mut failures,
                                    provider_name,
                                    current_model,
                                    attempt + 1,
                                    self.max_retries + 1,
                                    "non_retryable",
                                    &error_detail,
                                );
                                anyhow::bail!(
                                    "Request exceeds model context window and cannot be reduced further. \
                                    Try using a model with a larger context window, reducing the number \
                                    of tools/skills, or enabling compact_context in config. Attempts:\n{}",
                                    failures.join("\n")
                                );
                            }

                            // Classify the failure and record it before deciding what to do.
                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
                            let rate_limited = is_rate_limited(&e);
                            let failure_reason = failure_reason(rate_limited, non_retryable);
                            let error_detail = compact_error_detail(&e);

                            push_failure(
                                &mut failures,
                                provider_name,
                                current_model,
                                attempt + 1,
                                self.max_retries + 1,
                                failure_reason,
                                &error_detail,
                            );

                            // The rotated key is only reported in the log; it is not
                            // handed to the inner provider (see the message text).
                            if rate_limited
                                && !non_retryable_rate_limit
                                && let Some(new_key) = self.rotate_key()
                            {
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
                                    but cannot apply (ModelProvider trait has no set_api_key). \
                                    Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
                            }

                            // Give up on this provider for the current model.
                            if non_retryable {
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
                                break;
                            }

                            // Sleep only when another attempt follows; the stored
                            // backoff doubles each time, capped at 10 seconds.
                            if attempt < self.max_retries {
                                let wait = self.compute_backoff(backoff_ms, &e);
                                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
                                tokio::time::sleep(Duration::from_millis(wait)).await;
                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                            }
                        }
                    }
                }

                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
            }

            // Only logged for fallback models, not for the originally requested one.
            if *current_model != model {
                ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"original_model": model, "fallback_model": *current_model})), "Model fallback exhausted all model_providers, trying next fallback model");
            }
        }

        anyhow::bail!(
            "All model_providers/models failed. Attempts:\n{}",
            failures.join("\n")
        )
    }
973
974 fn supports_streaming(&self) -> bool {
975 self.model_providers
976 .iter()
977 .any(|(_, p)| p.supports_streaming())
978 }
979
980 fn supports_streaming_tool_events(&self) -> bool {
981 self.model_providers
982 .iter()
983 .any(|(_, p)| p.supports_streaming_tool_events())
984 }
985
986 fn stream_chat(
987 &self,
988 request: ChatRequest<'_>,
989 model: &str,
990 temperature: Option<f64>,
991 options: StreamOptions,
992 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
993 let needs_tool_events = request.tools.is_some_and(|tools| !tools.is_empty());
994
995 for (provider_name, model_provider) in &self.model_providers {
996 if !model_provider.supports_streaming() || !options.enabled {
997 continue;
998 }
999
1000 if needs_tool_events && !model_provider.supports_streaming_tool_events() {
1001 continue;
1002 }
1003
1004 let provider_clone = provider_name.clone();
1005
1006 let current_model = self
1007 .model_chain(model)
1008 .first()
1009 .copied()
1010 .unwrap_or(model)
1011 .to_string();
1012
1013 let req = ChatRequest {
1014 messages: request.messages,
1015 tools: request.tools,
1016 thinking: request.thinking,
1017 };
1018 let stream = model_provider.stream_chat(req, ¤t_model, temperature, options);
1019 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
1020
1021 tokio::spawn(async move {
1022 let mut stream = stream;
1023 while let Some(event) = stream.next().await {
1024 if let Err(ref e) = event {
1025 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1026 }
1027 if tx.send(event).await.is_err() {
1028 break;
1029 }
1030 }
1031 });
1032
1033 return stream::unfold(rx, |mut rx| async move {
1034 rx.recv().await.map(|event| (event, rx))
1035 })
1036 .boxed();
1037 }
1038
1039 let message = if needs_tool_events {
1040 "No model_provider supports streaming tool events".to_string()
1041 } else {
1042 "No model_provider supports streaming".to_string()
1043 };
1044 stream::once(async move { Err(super::traits::StreamError::ModelProvider(message)) }).boxed()
1045 }
1046
1047 fn stream_chat_with_system(
1048 &self,
1049 system_prompt: Option<&str>,
1050 message: &str,
1051 model: &str,
1052 temperature: Option<f64>,
1053 options: StreamOptions,
1054 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1055 for (provider_name, model_provider) in &self.model_providers {
1058 if !model_provider.supports_streaming() || !options.enabled {
1059 continue;
1060 }
1061
1062 let provider_clone = provider_name.clone();
1064
1065 let current_model = match self.model_chain(model).first() {
1067 Some(m) => (*m).to_string(),
1068 None => model.to_string(),
1069 };
1070
1071 let stream = model_provider.stream_chat_with_system(
1074 system_prompt,
1075 message,
1076 ¤t_model,
1077 temperature,
1078 options,
1079 );
1080
1081 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1083
1084 tokio::spawn(async move {
1085 let mut stream = stream;
1086 while let Some(chunk) = stream.next().await {
1087 if let Err(ref e) = chunk {
1088 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1089 }
1090 if tx.send(chunk).await.is_err() {
1091 break; }
1093 }
1094 });
1095
1096 return stream::unfold(rx, |mut rx| async move {
1098 rx.recv().await.map(|chunk| (chunk, rx))
1099 })
1100 .boxed();
1101 }
1102
1103 stream::once(async move {
1105 Err(super::traits::StreamError::ModelProvider(
1106 "No model_provider supports streaming".to_string(),
1107 ))
1108 })
1109 .boxed()
1110 }
1111
1112 fn stream_chat_with_history(
1113 &self,
1114 messages: &[ChatMessage],
1115 model: &str,
1116 temperature: Option<f64>,
1117 options: StreamOptions,
1118 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1119 for (provider_name, model_provider) in &self.model_providers {
1123 if !model_provider.supports_streaming() || !options.enabled {
1124 continue;
1125 }
1126
1127 let provider_clone = provider_name.clone();
1128
1129 let current_model = match self.model_chain(model).first() {
1130 Some(m) => (*m).to_string(),
1131 None => model.to_string(),
1132 };
1133
1134 let stream = model_provider.stream_chat_with_history(
1135 messages,
1136 ¤t_model,
1137 temperature,
1138 options,
1139 );
1140
1141 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1142
1143 tokio::spawn(async move {
1144 let mut stream = stream;
1145 while let Some(chunk) = stream.next().await {
1146 if let Err(ref e) = chunk {
1147 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1148 }
1149 if tx.send(chunk).await.is_err() {
1150 break; }
1152 }
1153 });
1154
1155 return stream::unfold(rx, |mut rx| async move {
1156 rx.recv().await.map(|chunk| (chunk, rx))
1157 })
1158 .boxed();
1159 }
1160
1161 stream::once(async move {
1163 Err(super::traits::StreamError::ModelProvider(
1164 "No model_provider supports streaming".to_string(),
1165 ))
1166 })
1167 .boxed()
1168 }
1169}
1170
1171impl ::zeroclaw_api::attribution::Attributable for ReliableModelProvider {
1172 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1173 ::zeroclaw_api::attribution::Role::Provider(
1174 ::zeroclaw_api::attribution::ProviderKind::Model(
1175 ::zeroclaw_api::attribution::ModelProviderKind::Reliable,
1176 ),
1177 )
1178 }
1179 fn alias(&self) -> &str {
1180 &self.alias
1181 }
1182}
1183
1184#[cfg(test)]
1185mod tests {
1186 use super::*;
1187 use futures_util::StreamExt;
1188 use std::sync::Arc;
1189 use zeroclaw_api::tool::ToolSpec;
1190
    /// Test double whose chat methods fail for a fixed number of leading calls and
    /// succeed afterwards.
    struct MockModelProvider {
        /// Shared counter incremented on every chat call.
        calls: Arc<AtomicUsize>,
        /// Calls numbered (1-based) up to and including this value return `error`.
        fail_until_attempt: usize,
        /// Text returned once the failing window has passed.
        response: &'static str,
        /// Message of the simulated failure.
        error: &'static str,
    }
1197
1198 #[async_trait]
1199 impl ModelProvider for MockModelProvider {
1200 async fn chat_with_system(
1201 &self,
1202 _system_prompt: Option<&str>,
1203 _message: &str,
1204 _model: &str,
1205 _temperature: Option<f64>,
1206 ) -> anyhow::Result<String> {
1207 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1208 if attempt <= self.fail_until_attempt {
1209 anyhow::bail!(self.error);
1210 }
1211 Ok(self.response.to_string())
1212 }
1213
1214 async fn chat_with_history(
1215 &self,
1216 _messages: &[ChatMessage],
1217 _model: &str,
1218 _temperature: Option<f64>,
1219 ) -> anyhow::Result<String> {
1220 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1221 if attempt <= self.fail_until_attempt {
1222 anyhow::bail!(self.error);
1223 }
1224 Ok(self.response.to_string())
1225 }
1226 }
1227 impl ::zeroclaw_api::attribution::Attributable for MockModelProvider {
1228 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1229 ::zeroclaw_api::attribution::Role::Provider(
1230 ::zeroclaw_api::attribution::ProviderKind::Model(
1231 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
1232 ),
1233 )
1234 }
1235 fn alias(&self) -> &str {
1236 "MockModelProvider"
1237 }
1238 }
1239
    /// Test double that records which models it was asked for and fails for a
    /// configured set of model names.
    struct ModelAwareMock {
        /// Shared counter incremented on every `chat_with_system` call.
        calls: Arc<AtomicUsize>,
        /// Every model name received, in call order.
        models_seen: parking_lot::Mutex<Vec<String>>,
        /// Model names for which the call returns an error.
        fail_models: Vec<&'static str>,
        /// Text returned for any model not listed in `fail_models`.
        response: &'static str,
    }
1247
1248 #[async_trait]
1249 impl ModelProvider for ModelAwareMock {
1250 async fn chat_with_system(
1251 &self,
1252 _system_prompt: Option<&str>,
1253 _message: &str,
1254 model: &str,
1255 _temperature: Option<f64>,
1256 ) -> anyhow::Result<String> {
1257 self.calls.fetch_add(1, Ordering::SeqCst);
1258 self.models_seen.lock().push(model.to_string());
1259 if self.fail_models.contains(&model) {
1260 anyhow::bail!("500 model {} unavailable", model);
1261 }
1262 Ok(self.response.to_string())
1263 }
1264 }
1265 impl ::zeroclaw_api::attribution::Attributable for ModelAwareMock {
1266 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1267 ::zeroclaw_api::attribution::Role::Provider(
1268 ::zeroclaw_api::attribution::ProviderKind::Model(
1269 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
1270 ),
1271 )
1272 }
1273 fn alias(&self) -> &str {
1274 "ModelAwareMock"
1275 }
1276 }
1277
1278 #[tokio::test]
1281 async fn succeeds_without_retry() {
1282 let calls = Arc::new(AtomicUsize::new(0));
1283 let model_provider = ReliableModelProvider::new(
1284 "test",
1285 vec![(
1286 "primary".into(),
1287 Box::new(MockModelProvider {
1288 calls: Arc::clone(&calls),
1289 fail_until_attempt: 0,
1290 response: "ok",
1291 error: "boom",
1292 }),
1293 )],
1294 2,
1295 1,
1296 );
1297
1298 let result = model_provider
1299 .simple_chat("hello", "test", Some(0.0))
1300 .await
1301 .unwrap();
1302 assert_eq!(result, "ok");
1303 assert_eq!(calls.load(Ordering::SeqCst), 1);
1304 }
1305
1306 #[tokio::test]
1307 async fn retries_then_recovers() {
1308 let calls = Arc::new(AtomicUsize::new(0));
1309 let model_provider = ReliableModelProvider::new(
1310 "test",
1311 vec![(
1312 "primary".into(),
1313 Box::new(MockModelProvider {
1314 calls: Arc::clone(&calls),
1315 fail_until_attempt: 1,
1316 response: "recovered",
1317 error: "temporary",
1318 }),
1319 )],
1320 2,
1321 1,
1322 );
1323
1324 let result = model_provider
1325 .simple_chat("hello", "test", Some(0.0))
1326 .await
1327 .unwrap();
1328 assert_eq!(result, "recovered");
1329 assert_eq!(calls.load(Ordering::SeqCst), 2);
1330 }
1331
1332 #[tokio::test]
1333 async fn falls_back_after_retries_exhausted() {
1334 let primary_calls = Arc::new(AtomicUsize::new(0));
1335 let fallback_calls = Arc::new(AtomicUsize::new(0));
1336
1337 let model_provider = ReliableModelProvider::new(
1338 "test",
1339 vec![
1340 (
1341 "primary".into(),
1342 Box::new(MockModelProvider {
1343 calls: Arc::clone(&primary_calls),
1344 fail_until_attempt: usize::MAX,
1345 response: "never",
1346 error: "primary down",
1347 }),
1348 ),
1349 (
1350 "fallback".into(),
1351 Box::new(MockModelProvider {
1352 calls: Arc::clone(&fallback_calls),
1353 fail_until_attempt: 0,
1354 response: "from fallback",
1355 error: "fallback down",
1356 }),
1357 ),
1358 ],
1359 1,
1360 1,
1361 );
1362
1363 let result = model_provider
1364 .simple_chat("hello", "test", Some(0.0))
1365 .await
1366 .unwrap();
1367 assert_eq!(result, "from fallback");
1368 assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1369 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1370 }
1371
1372 #[tokio::test]
1373 async fn returns_aggregated_error_when_all_providers_fail() {
1374 let model_provider = ReliableModelProvider::new(
1375 "test",
1376 vec![
1377 (
1378 "p1".into(),
1379 Box::new(MockModelProvider {
1380 calls: Arc::new(AtomicUsize::new(0)),
1381 fail_until_attempt: usize::MAX,
1382 response: "never",
1383 error: "p1 error",
1384 }),
1385 ),
1386 (
1387 "p2".into(),
1388 Box::new(MockModelProvider {
1389 calls: Arc::new(AtomicUsize::new(0)),
1390 fail_until_attempt: usize::MAX,
1391 response: "never",
1392 error: "p2 error",
1393 }),
1394 ),
1395 ],
1396 0,
1397 1,
1398 );
1399
1400 let err = model_provider
1401 .simple_chat("hello", "test", Some(0.0))
1402 .await
1403 .expect_err("all model_providers should fail");
1404 let msg = err.to_string();
1405 assert!(msg.contains("All model_providers/models failed"));
1406 assert!(msg.contains("model_provider=p1 model=test"));
1407 assert!(msg.contains("model_provider=p2 model=test"));
1408 assert!(msg.contains("error=p1 error"));
1409 assert!(msg.contains("error=p2 error"));
1410 assert!(msg.contains("retryable"));
1411 }
1412
1413 #[test]
1414 fn non_retryable_detects_common_patterns() {
1415 assert!(is_non_retryable(&anyhow::Error::msg("400 Bad Request")));
1416 assert!(is_non_retryable(&anyhow::Error::msg("401 Unauthorized")));
1417 assert!(is_non_retryable(&anyhow::Error::msg("403 Forbidden")));
1418 assert!(is_non_retryable(&anyhow::Error::msg("404 Not Found")));
1419 assert!(is_non_retryable(&anyhow::Error::msg(
1420 "invalid api key provided"
1421 )));
1422 assert!(is_non_retryable(&anyhow::Error::msg(
1423 "authentication failed"
1424 )));
1425 assert!(is_non_retryable(&anyhow::Error::msg(
1426 "model glm-4.7 not found"
1427 )));
1428 assert!(is_non_retryable(&anyhow::Error::msg(
1429 "unsupported model: glm-4.7"
1430 )));
1431 assert!(!is_non_retryable(&anyhow::Error::msg(
1432 "429 Too Many Requests"
1433 )));
1434 assert!(!is_non_retryable(&anyhow::Error::msg(
1435 "408 Request Timeout"
1436 )));
1437 assert!(!is_non_retryable(&anyhow::Error::msg(
1438 "500 Internal Server Error"
1439 )));
1440 assert!(!is_non_retryable(&anyhow::Error::msg("502 Bad Gateway")));
1441 assert!(!is_non_retryable(&anyhow::Error::msg("timeout")));
1442 assert!(!is_non_retryable(&anyhow::Error::msg("connection reset")));
1443 assert!(!is_non_retryable(&anyhow::Error::msg(
1444 "model overloaded, try again later"
1445 )));
1446 assert!(!is_non_retryable(&anyhow::Error::msg(
1448 "OpenAI Codex stream error: Your input exceeds the context window of this model."
1449 )));
1450 }
1451
1452 #[test]
1453 fn auth_error_detects_common_patterns() {
1454 assert!(is_auth_error(&anyhow::Error::msg("401 Unauthorized")));
1455 assert!(is_auth_error(&anyhow::Error::msg("403 Forbidden")));
1456 assert!(is_auth_error(&anyhow::Error::msg("invalid api key")));
1457 assert!(is_auth_error(&anyhow::Error::msg("authentication failed")));
1458 assert!(is_auth_error(&anyhow::Error::msg("token expired")));
1459 assert!(!is_auth_error(&anyhow::Error::msg("400 Bad Request")));
1460 assert!(!is_auth_error(&anyhow::Error::msg("429 Too Many Requests")));
1461 assert!(!is_auth_error(&anyhow::Error::msg("timeout")));
1462 assert!(!is_auth_error(&anyhow::Error::msg("connection reset")));
1463 }
1464
1465 #[tokio::test]
1466 async fn context_window_error_aborts_retries_and_model_fallbacks() {
1467 let calls = Arc::new(AtomicUsize::new(0));
1468 let mut model_fallbacks = std::collections::HashMap::new();
1469 model_fallbacks.insert(
1470 "gpt-5.3-codex".to_string(),
1471 vec!["gpt-5.2-codex".to_string()],
1472 );
1473
1474 let model_provider = ReliableModelProvider::new("test", vec![(
1475 "openai-codex".into(),
1476 Box::new(MockModelProvider {
1477 calls: Arc::clone(&calls),
1478 fail_until_attempt: usize::MAX,
1479 response: "never",
1480 error: "OpenAI Codex stream error: Your input exceeds the context window of this model. Please adjust your input and try again.",
1481 }),
1482 )],
1483 4,
1484 1,
1485 )
1486 .with_model_fallbacks(model_fallbacks);
1487
1488 let err = model_provider
1489 .simple_chat("hello", "gpt-5.3-codex", Some(0.0))
1490 .await
1491 .expect_err("context window overflow should fail fast");
1492 let msg = err.to_string();
1493
1494 assert!(msg.contains("context window"));
1495 assert_eq!(calls.load(Ordering::SeqCst), 1);
1497 }
1498
1499 #[tokio::test]
1500 async fn aggregated_error_marks_non_retryable_model_mismatch_with_details() {
1501 let calls = Arc::new(AtomicUsize::new(0));
1502 let model_provider = ReliableModelProvider::new(
1503 "test",
1504 vec![(
1505 "custom".into(),
1506 Box::new(MockModelProvider {
1507 calls: Arc::clone(&calls),
1508 fail_until_attempt: usize::MAX,
1509 response: "never",
1510 error: "unsupported model: glm-4.7",
1511 }),
1512 )],
1513 3,
1514 1,
1515 );
1516
1517 let err = model_provider
1518 .simple_chat("hello", "glm-4.7", Some(0.0))
1519 .await
1520 .expect_err("model_provider should fail");
1521 let msg = err.to_string();
1522
1523 assert!(msg.contains("non_retryable"));
1524 assert!(msg.contains("error=unsupported model: glm-4.7"));
1525 assert_eq!(calls.load(Ordering::SeqCst), 1);
1527 }
1528
1529 #[tokio::test]
1530 async fn skips_retries_on_non_retryable_error() {
1531 let primary_calls = Arc::new(AtomicUsize::new(0));
1532 let fallback_calls = Arc::new(AtomicUsize::new(0));
1533
1534 let model_provider = ReliableModelProvider::new(
1535 "test",
1536 vec![
1537 (
1538 "primary".into(),
1539 Box::new(MockModelProvider {
1540 calls: Arc::clone(&primary_calls),
1541 fail_until_attempt: usize::MAX,
1542 response: "never",
1543 error: "401 Unauthorized",
1544 }),
1545 ),
1546 (
1547 "fallback".into(),
1548 Box::new(MockModelProvider {
1549 calls: Arc::clone(&fallback_calls),
1550 fail_until_attempt: 0,
1551 response: "from fallback",
1552 error: "fallback err",
1553 }),
1554 ),
1555 ],
1556 3,
1557 1,
1558 );
1559
1560 let result = model_provider
1561 .simple_chat("hello", "test", Some(0.0))
1562 .await
1563 .unwrap();
1564 assert_eq!(result, "from fallback");
1565 assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
1567 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1568 }
1569
1570 #[tokio::test]
1571 async fn chat_with_history_retries_then_recovers() {
1572 let calls = Arc::new(AtomicUsize::new(0));
1573 let model_provider = ReliableModelProvider::new(
1574 "test",
1575 vec![(
1576 "primary".into(),
1577 Box::new(MockModelProvider {
1578 calls: Arc::clone(&calls),
1579 fail_until_attempt: 1,
1580 response: "history ok",
1581 error: "temporary",
1582 }),
1583 )],
1584 2,
1585 1,
1586 );
1587
1588 let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
1589 let result = model_provider
1590 .chat_with_history(&messages, "test", Some(0.0))
1591 .await
1592 .unwrap();
1593 assert_eq!(result, "history ok");
1594 assert_eq!(calls.load(Ordering::SeqCst), 2);
1595 }
1596
1597 #[tokio::test]
1598 async fn chat_with_history_falls_back() {
1599 let primary_calls = Arc::new(AtomicUsize::new(0));
1600 let fallback_calls = Arc::new(AtomicUsize::new(0));
1601
1602 let model_provider = ReliableModelProvider::new(
1603 "test",
1604 vec![
1605 (
1606 "primary".into(),
1607 Box::new(MockModelProvider {
1608 calls: Arc::clone(&primary_calls),
1609 fail_until_attempt: usize::MAX,
1610 response: "never",
1611 error: "primary down",
1612 }),
1613 ),
1614 (
1615 "fallback".into(),
1616 Box::new(MockModelProvider {
1617 calls: Arc::clone(&fallback_calls),
1618 fail_until_attempt: 0,
1619 response: "fallback ok",
1620 error: "fallback err",
1621 }),
1622 ),
1623 ],
1624 1,
1625 1,
1626 );
1627
1628 let messages = vec![ChatMessage::user("hello")];
1629 let result = model_provider
1630 .chat_with_history(&messages, "test", Some(0.0))
1631 .await
1632 .unwrap();
1633 assert_eq!(result, "fallback ok");
1634 assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1635 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1636 }
1637
1638 #[tokio::test]
1641 async fn model_failover_tries_fallback_model() {
1642 let calls = Arc::new(AtomicUsize::new(0));
1643 let mock = Arc::new(ModelAwareMock {
1644 calls: Arc::clone(&calls),
1645 models_seen: parking_lot::Mutex::new(Vec::new()),
1646 fail_models: vec!["claude-opus"],
1647 response: "ok from sonnet",
1648 });
1649
1650 let mut fallbacks = HashMap::new();
1651 fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
1652
1653 let model_provider = ReliableModelProvider::new(
1654 "test",
1655 vec![(
1656 "anthropic".into(),
1657 Box::new(mock.clone()) as Box<dyn ModelProvider>,
1658 )],
1659 0, 1,
1661 )
1662 .with_model_fallbacks(fallbacks);
1663
1664 let result = model_provider
1665 .simple_chat("hello", "claude-opus", Some(0.0))
1666 .await
1667 .unwrap();
1668 assert_eq!(result, "ok from sonnet");
1669
1670 let seen = mock.models_seen.lock();
1671 assert_eq!(seen.len(), 2);
1672 assert_eq!(seen[0], "claude-opus");
1673 assert_eq!(seen[1], "claude-sonnet");
1674 }
1675
1676 #[tokio::test]
1677 async fn model_failover_all_models_fail() {
1678 let calls = Arc::new(AtomicUsize::new(0));
1679 let mock = Arc::new(ModelAwareMock {
1680 calls: Arc::clone(&calls),
1681 models_seen: parking_lot::Mutex::new(Vec::new()),
1682 fail_models: vec!["model-a", "model-b", "model-c"],
1683 response: "never",
1684 });
1685
1686 let mut fallbacks = HashMap::new();
1687 fallbacks.insert(
1688 "model-a".to_string(),
1689 vec!["model-b".to_string(), "model-c".to_string()],
1690 );
1691
1692 let model_provider = ReliableModelProvider::new(
1693 "test",
1694 vec![(
1695 "p1".into(),
1696 Box::new(mock.clone()) as Box<dyn ModelProvider>,
1697 )],
1698 0,
1699 1,
1700 )
1701 .with_model_fallbacks(fallbacks);
1702
1703 let err = model_provider
1704 .simple_chat("hello", "model-a", Some(0.0))
1705 .await
1706 .expect_err("all models should fail");
1707 assert!(
1708 err.to_string()
1709 .contains("All model_providers/models failed")
1710 );
1711
1712 let seen = mock.models_seen.lock();
1713 assert_eq!(seen.len(), 3);
1714 }
1715
1716 #[tokio::test]
1717 async fn no_model_fallbacks_behaves_like_before() {
1718 let calls = Arc::new(AtomicUsize::new(0));
1719 let model_provider = ReliableModelProvider::new(
1720 "test",
1721 vec![(
1722 "primary".into(),
1723 Box::new(MockModelProvider {
1724 calls: Arc::clone(&calls),
1725 fail_until_attempt: 0,
1726 response: "ok",
1727 error: "boom",
1728 }),
1729 )],
1730 2,
1731 1,
1732 );
1733 let result = model_provider
1735 .simple_chat("hello", "test", Some(0.0))
1736 .await
1737 .unwrap();
1738 assert_eq!(result, "ok");
1739 assert_eq!(calls.load(Ordering::SeqCst), 1);
1740 }
1741
1742 #[tokio::test]
1745 async fn auth_rotation_cycles_keys() {
1746 let model_provider = ReliableModelProvider::new(
1747 "test",
1748 vec![(
1749 "p".into(),
1750 Box::new(MockModelProvider {
1751 calls: Arc::new(AtomicUsize::new(0)),
1752 fail_until_attempt: 0,
1753 response: "ok",
1754 error: "",
1755 }),
1756 )],
1757 0,
1758 1,
1759 )
1760 .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);
1761
1762 let keys: Vec<&str> = (0..5)
1764 .map(|_| model_provider.rotate_key().unwrap())
1765 .collect();
1766 assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
1767 }
1768
1769 #[tokio::test]
1770 async fn auth_rotation_returns_none_when_empty() {
1771 let model_provider = ReliableModelProvider::new("test", vec![], 0, 1);
1772 assert!(model_provider.rotate_key().is_none());
1773 }
1774
1775 #[test]
1778 fn parse_retry_after_integer() {
1779 let err = anyhow::Error::msg("429 Too Many Requests, Retry-After: 5");
1780 assert_eq!(parse_retry_after_ms(&err), Some(5000));
1781 }
1782
1783 #[test]
1784 fn parse_retry_after_float() {
1785 let err = anyhow::Error::msg("Rate limited. retry_after: 2.5 seconds");
1786 assert_eq!(parse_retry_after_ms(&err), Some(2500));
1787 }
1788
1789 #[test]
1790 fn parse_retry_after_missing() {
1791 let err = anyhow::Error::msg("500 Internal Server Error");
1792 assert_eq!(parse_retry_after_ms(&err), None);
1793 }
1794
1795 #[test]
1796 fn rate_limited_detection() {
1797 assert!(is_rate_limited(&anyhow::Error::msg(
1798 "429 Too Many Requests"
1799 )));
1800 assert!(is_rate_limited(&anyhow::Error::msg(
1801 "HTTP 429 rate limit exceeded"
1802 )));
1803 assert!(!is_rate_limited(&anyhow::Error::msg("401 Unauthorized")));
1804 assert!(!is_rate_limited(&anyhow::Error::msg(
1805 "500 Internal Server Error"
1806 )));
1807 }
1808
1809 #[test]
1810 fn non_retryable_rate_limit_detects_plan_restricted_model() {
1811 let err = anyhow::Error::msg(
1812 "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"the current account plan does not include glm-5\"}",
1813 );
1814 assert!(
1815 is_non_retryable_rate_limit(&err),
1816 "plan-restricted 429 should skip retries"
1817 );
1818 }
1819
1820 #[test]
1821 fn non_retryable_rate_limit_detects_insufficient_balance() {
1822 let err = anyhow::Error::msg(
1823 "API error (429 Too Many Requests): {\"code\":1113,\"message\":\"insufficient balance\"}",
1824 );
1825 assert!(
1826 is_non_retryable_rate_limit(&err),
1827 "insufficient-balance 429 should skip retries"
1828 );
1829 }
1830
1831 #[test]
1832 fn non_retryable_rate_limit_does_not_flag_generic_429() {
1833 let err = anyhow::Error::msg("429 Too Many Requests: rate limit exceeded");
1834 assert!(
1835 !is_non_retryable_rate_limit(&err),
1836 "generic rate-limit 429 should remain retryable"
1837 );
1838 }
1839
1840 #[test]
1841 fn compute_backoff_uses_retry_after() {
1842 let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
1843 let err = anyhow::Error::msg("429 Retry-After: 3");
1844 assert_eq!(model_provider.compute_backoff(500, &err), 3_000);
1845 }
1846
1847 #[test]
1848 fn compute_backoff_caps_at_30s() {
1849 let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
1850 let err = anyhow::Error::msg("429 Retry-After: 120");
1851 assert_eq!(model_provider.compute_backoff(500, &err), 30_000);
1852 }
1853
1854 #[test]
1855 fn compute_backoff_falls_back_to_base() {
1856 let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
1857 let err = anyhow::Error::msg("500 Server Error");
1858 assert_eq!(model_provider.compute_backoff(500, &err), 500);
1859 }
1860
1861 #[test]
1864 fn non_retryable_detects_401() {
1865 let err = anyhow::Error::msg("API error (401 Unauthorized): invalid api key");
1866 assert!(
1867 is_non_retryable(&err),
1868 "401 errors must be detected as non-retryable"
1869 );
1870 }
1871
1872 #[test]
1873 fn non_retryable_detects_403() {
1874 let err = anyhow::Error::msg("API error (403 Forbidden): access denied");
1875 assert!(
1876 is_non_retryable(&err),
1877 "403 errors must be detected as non-retryable"
1878 );
1879 }
1880
1881 #[test]
1882 fn non_retryable_detects_404() {
1883 let err = anyhow::Error::msg("API error (404 Not Found): model not found");
1884 assert!(
1885 is_non_retryable(&err),
1886 "404 errors must be detected as non-retryable"
1887 );
1888 }
1889
1890 #[test]
1891 fn non_retryable_does_not_flag_429() {
1892 let err = anyhow::Error::msg("429 Too Many Requests");
1893 assert!(
1894 !is_non_retryable(&err),
1895 "429 must NOT be treated as non-retryable (it is retryable with backoff)"
1896 );
1897 }
1898
1899 #[test]
1900 fn non_retryable_does_not_flag_408() {
1901 let err = anyhow::Error::msg("408 Request Timeout");
1902 assert!(
1903 !is_non_retryable(&err),
1904 "408 must NOT be treated as non-retryable (it is retryable)"
1905 );
1906 }
1907
1908 #[test]
1909 fn non_retryable_does_not_flag_500() {
1910 let err = anyhow::Error::msg("500 Internal Server Error");
1911 assert!(
1912 !is_non_retryable(&err),
1913 "500 must NOT be treated as non-retryable (server errors are retryable)"
1914 );
1915 }
1916
1917 #[test]
1918 fn non_retryable_does_not_flag_502() {
1919 let err = anyhow::Error::msg("502 Bad Gateway");
1920 assert!(
1921 !is_non_retryable(&err),
1922 "502 must NOT be treated as non-retryable"
1923 );
1924 }
1925
1926 #[test]
1929 fn parse_retry_after_zero() {
1930 let err = anyhow::Error::msg("429 Too Many Requests, Retry-After: 0");
1931 assert_eq!(
1932 parse_retry_after_ms(&err),
1933 Some(0),
1934 "Retry-After: 0 should parse as 0ms"
1935 );
1936 }
1937
1938 #[test]
1939 fn parse_retry_after_with_underscore_separator() {
1940 let err = anyhow::Error::msg("rate limited, retry_after: 10");
1941 assert_eq!(
1942 parse_retry_after_ms(&err),
1943 Some(10_000),
1944 "retry_after with underscore must be parsed"
1945 );
1946 }
1947
1948 #[test]
1949 fn parse_retry_after_space_separator() {
1950 let err = anyhow::Error::msg("Retry-After 7");
1951 assert_eq!(
1952 parse_retry_after_ms(&err),
1953 Some(7000),
1954 "Retry-After with space separator must be parsed"
1955 );
1956 }
1957
1958 #[test]
1959 fn rate_limited_false_for_generic_error() {
1960 let err = anyhow::Error::msg("Connection refused");
1961 assert!(
1962 !is_rate_limited(&err),
1963 "generic errors must not be flagged as rate-limited"
1964 );
1965 }
1966
1967 #[tokio::test]
1970 async fn non_retryable_skips_retries_for_401() {
1971 let calls = Arc::new(AtomicUsize::new(0));
1972 let model_provider = ReliableModelProvider::new(
1973 "test",
1974 vec![(
1975 "primary".into(),
1976 Box::new(MockModelProvider {
1977 calls: Arc::clone(&calls),
1978 fail_until_attempt: usize::MAX,
1979 response: "never",
1980 error: "API error (401 Unauthorized): invalid key",
1981 }),
1982 )],
1983 5,
1984 1,
1985 );
1986
1987 let result = model_provider.simple_chat("hello", "test", Some(0.0)).await;
1988 assert!(result.is_err(), "401 should fail without retries");
1989 assert_eq!(
1990 calls.load(Ordering::SeqCst),
1991 1,
1992 "must not retry on 401 — should be exactly 1 call"
1993 );
1994 }
1995
1996 #[tokio::test]
1997 async fn non_retryable_rate_limit_skips_retries_for_plan_errors() {
1998 let calls = Arc::new(AtomicUsize::new(0));
1999 let model_provider = ReliableModelProvider::new(
2000 "test",
2001 vec![(
2002 "primary".into(),
2003 Box::new(MockModelProvider {
2004 calls: Arc::clone(&calls),
2005 fail_until_attempt: usize::MAX,
2006 response: "never",
2007 error: "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"plan does not include glm-5\"}",
2008 }),
2009 )],
2010 5,
2011 1,
2012 );
2013
2014 let result = model_provider.simple_chat("hello", "test", Some(0.0)).await;
2015 assert!(
2016 result.is_err(),
2017 "plan-restricted 429 should fail quickly without retrying"
2018 );
2019 assert_eq!(
2020 calls.load(Ordering::SeqCst),
2021 1,
2022 "must not retry non-retryable 429 business errors"
2023 );
2024 }
2025
    /// Test double that reports native tool support and whose `chat` method fails
    /// for a fixed number of leading calls before answering with tool calls.
    struct NativeToolMock {
        /// Shared counter incremented on every `chat` call.
        calls: Arc<AtomicUsize>,
        /// `chat` calls numbered (1-based) up to and including this value fail.
        fail_until_attempt: usize,
        /// Text placed in the successful response.
        response_text: &'static str,
        /// Tool calls cloned into every successful `chat` response.
        tool_calls: Vec<super::super::traits::ToolCall>,
        /// Message of the simulated `chat` failure.
        error: &'static str,
    }
2036
2037 #[async_trait]
2038 impl ModelProvider for NativeToolMock {
2039 async fn chat_with_system(
2040 &self,
2041 _system_prompt: Option<&str>,
2042 _message: &str,
2043 _model: &str,
2044 _temperature: Option<f64>,
2045 ) -> anyhow::Result<String> {
2046 Ok(self.response_text.to_string())
2047 }
2048
2049 fn supports_native_tools(&self) -> bool {
2050 true
2051 }
2052
2053 async fn chat(
2054 &self,
2055 _request: ChatRequest<'_>,
2056 _model: &str,
2057 _temperature: Option<f64>,
2058 ) -> anyhow::Result<ChatResponse> {
2059 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2060 if attempt <= self.fail_until_attempt {
2061 anyhow::bail!(self.error);
2062 }
2063 Ok(ChatResponse {
2064 text: Some(self.response_text.to_string()),
2065 tool_calls: self.tool_calls.clone(),
2066 usage: None,
2067 reasoning_content: None,
2068 })
2069 }
2070 }
2071 impl ::zeroclaw_api::attribution::Attributable for NativeToolMock {
2072 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2073 ::zeroclaw_api::attribution::Role::Provider(
2074 ::zeroclaw_api::attribution::ProviderKind::Model(
2075 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2076 ),
2077 )
2078 }
2079 fn alias(&self) -> &str {
2080 "NativeToolMock"
2081 }
2082 }
2083
2084 #[tokio::test]
2085 async fn chat_delegates_to_inner_provider() {
2086 let calls = Arc::new(AtomicUsize::new(0));
2087 let tool_call = super::super::traits::ToolCall {
2088 id: "call_1".to_string(),
2089 name: "shell".to_string(),
2090 arguments: r#"{"command":"date"}"#.to_string(),
2091 extra_content: None,
2092 };
2093 let model_provider = ReliableModelProvider::new(
2094 "test",
2095 vec![(
2096 "primary".into(),
2097 Box::new(NativeToolMock {
2098 calls: Arc::clone(&calls),
2099 fail_until_attempt: 0,
2100 response_text: "ok",
2101 tool_calls: vec![tool_call.clone()],
2102 error: "boom",
2103 }) as Box<dyn ModelProvider>,
2104 )],
2105 2,
2106 1,
2107 );
2108
2109 let messages = vec![ChatMessage::user("what time is it?")];
2110 let request = ChatRequest {
2111 messages: &messages,
2112 tools: None,
2113 thinking: None,
2114 };
2115 let result = model_provider
2116 .chat(request, "test-model", Some(0.0))
2117 .await
2118 .unwrap();
2119
2120 assert_eq!(result.text.as_deref(), Some("ok"));
2121 assert_eq!(result.tool_calls.len(), 1);
2122 assert_eq!(result.tool_calls[0].name, "shell");
2123 assert_eq!(calls.load(Ordering::SeqCst), 1);
2124 }
2125
2126 #[tokio::test]
2127 async fn chat_retries_and_recovers() {
2128 let calls = Arc::new(AtomicUsize::new(0));
2129 let tool_call = super::super::traits::ToolCall {
2130 id: "call_1".to_string(),
2131 name: "shell".to_string(),
2132 arguments: r#"{"command":"date"}"#.to_string(),
2133 extra_content: None,
2134 };
2135 let model_provider = ReliableModelProvider::new(
2136 "test",
2137 vec![(
2138 "primary".into(),
2139 Box::new(NativeToolMock {
2140 calls: Arc::clone(&calls),
2141 fail_until_attempt: 2,
2142 response_text: "recovered",
2143 tool_calls: vec![tool_call],
2144 error: "temporary failure",
2145 }) as Box<dyn ModelProvider>,
2146 )],
2147 3,
2148 1,
2149 );
2150
2151 let messages = vec![ChatMessage::user("test")];
2152 let request = ChatRequest {
2153 messages: &messages,
2154 tools: None,
2155 thinking: None,
2156 };
2157 let result = model_provider
2158 .chat(request, "test-model", Some(0.0))
2159 .await
2160 .unwrap();
2161
2162 assert_eq!(result.text.as_deref(), Some("recovered"));
2163 assert!(
2164 calls.load(Ordering::SeqCst) > 1,
2165 "should have retried at least once"
2166 );
2167 }
2168
2169 #[tokio::test]
2170 async fn chat_preserves_native_tools_support() {
2171 let calls = Arc::new(AtomicUsize::new(0));
2172 let model_provider = ReliableModelProvider::new(
2173 "test",
2174 vec![(
2175 "primary".into(),
2176 Box::new(NativeToolMock {
2177 calls: Arc::clone(&calls),
2178 fail_until_attempt: 0,
2179 response_text: "ok",
2180 tool_calls: vec![],
2181 error: "boom",
2182 }) as Box<dyn ModelProvider>,
2183 )],
2184 2,
2185 1,
2186 );
2187
2188 assert!(
2189 model_provider.supports_native_tools(),
2190 "ReliableModelProvider must propagate supports_native_tools from inner model_provider"
2191 );
2192 }
2193
2194 #[tokio::test]
2199 async fn chat_returns_aggregated_error_when_all_providers_fail() {
2200 let model_provider = ReliableModelProvider::new(
2201 "test",
2202 vec![
2203 (
2204 "p1".into(),
2205 Box::new(NativeToolMock {
2206 calls: Arc::new(AtomicUsize::new(0)),
2207 fail_until_attempt: usize::MAX,
2208 response_text: "never",
2209 tool_calls: vec![],
2210 error: "p1 chat error",
2211 }) as Box<dyn ModelProvider>,
2212 ),
2213 (
2214 "p2".into(),
2215 Box::new(NativeToolMock {
2216 calls: Arc::new(AtomicUsize::new(0)),
2217 fail_until_attempt: usize::MAX,
2218 response_text: "never",
2219 tool_calls: vec![],
2220 error: "p2 chat error",
2221 }) as Box<dyn ModelProvider>,
2222 ),
2223 ],
2224 0,
2225 1,
2226 );
2227
2228 let messages = vec![ChatMessage::user("hello")];
2229 let request = ChatRequest {
2230 messages: &messages,
2231 tools: None,
2232 thinking: None,
2233 };
2234 let err = model_provider
2235 .chat(request, "test", Some(0.0))
2236 .await
2237 .expect_err("all model_providers should fail");
2238 let msg = err.to_string();
2239 assert!(msg.contains("All model_providers/models failed"));
2240 assert!(msg.contains("model_provider=p1 model=test"));
2241 assert!(msg.contains("model_provider=p2 model=test"));
2242 assert!(msg.contains("error=p1 chat error"));
2243 assert!(msg.contains("error=p2 chat error"));
2244 assert!(msg.contains("retryable"));
2245 }
2246
2247 struct NativeModelAwareMock {
2250 calls: Arc<AtomicUsize>,
2251 models_seen: parking_lot::Mutex<Vec<String>>,
2252 fail_models: Vec<&'static str>,
2253 response_text: &'static str,
2254 }
2255
2256 #[async_trait]
2257 impl ModelProvider for NativeModelAwareMock {
2258 async fn chat_with_system(
2259 &self,
2260 _system_prompt: Option<&str>,
2261 _message: &str,
2262 _model: &str,
2263 _temperature: Option<f64>,
2264 ) -> anyhow::Result<String> {
2265 Ok(self.response_text.to_string())
2266 }
2267
2268 fn supports_native_tools(&self) -> bool {
2269 true
2270 }
2271
2272 async fn chat(
2273 &self,
2274 _request: ChatRequest<'_>,
2275 model: &str,
2276 _temperature: Option<f64>,
2277 ) -> anyhow::Result<ChatResponse> {
2278 self.calls.fetch_add(1, Ordering::SeqCst);
2279 self.models_seen.lock().push(model.to_string());
2280 if self.fail_models.contains(&model) {
2281 anyhow::bail!("500 model {} unavailable", model);
2282 }
2283 Ok(ChatResponse {
2284 text: Some(self.response_text.to_string()),
2285 tool_calls: vec![],
2286 usage: None,
2287 reasoning_content: None,
2288 })
2289 }
2290 }
2291 impl ::zeroclaw_api::attribution::Attributable for NativeModelAwareMock {
2292 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2293 ::zeroclaw_api::attribution::Role::Provider(
2294 ::zeroclaw_api::attribution::ProviderKind::Model(
2295 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2296 ),
2297 )
2298 }
2299 fn alias(&self) -> &str {
2300 "NativeModelAwareMock"
2301 }
2302 }
2303
2304 #[tokio::test]
2309 async fn chat_tries_model_failover_on_failure() {
2310 let calls = Arc::new(AtomicUsize::new(0));
2311 let mock = Arc::new(NativeModelAwareMock {
2312 calls: Arc::clone(&calls),
2313 models_seen: parking_lot::Mutex::new(Vec::new()),
2314 fail_models: vec!["claude-opus"],
2315 response_text: "ok from sonnet",
2316 });
2317
2318 let mut fallbacks = HashMap::new();
2319 fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
2320
2321 let model_provider = ReliableModelProvider::new(
2322 "test",
2323 vec![(
2324 "anthropic".into(),
2325 Box::new(mock.clone()) as Box<dyn ModelProvider>,
2326 )],
2327 0, 1,
2329 )
2330 .with_model_fallbacks(fallbacks);
2331
2332 let messages = vec![ChatMessage::user("hello")];
2333 let request = ChatRequest {
2334 messages: &messages,
2335 tools: None,
2336 thinking: None,
2337 };
2338 let result = model_provider
2339 .chat(request, "claude-opus", Some(0.0))
2340 .await
2341 .unwrap();
2342 assert_eq!(result.text.as_deref(), Some("ok from sonnet"));
2343
2344 let seen = mock.models_seen.lock();
2345 assert_eq!(seen.len(), 2);
2346 assert_eq!(seen[0], "claude-opus");
2347 assert_eq!(seen[1], "claude-sonnet");
2348 }
2349
2350 #[tokio::test]
2353 async fn chat_skips_non_retryable_errors() {
2354 let primary_calls = Arc::new(AtomicUsize::new(0));
2355 let fallback_calls = Arc::new(AtomicUsize::new(0));
2356
2357 let model_provider = ReliableModelProvider::new(
2358 "test",
2359 vec![
2360 (
2361 "primary".into(),
2362 Box::new(NativeToolMock {
2363 calls: Arc::clone(&primary_calls),
2364 fail_until_attempt: usize::MAX,
2365 response_text: "never",
2366 tool_calls: vec![],
2367 error: "401 Unauthorized",
2368 }) as Box<dyn ModelProvider>,
2369 ),
2370 (
2371 "fallback".into(),
2372 Box::new(NativeToolMock {
2373 calls: Arc::clone(&fallback_calls),
2374 fail_until_attempt: 0,
2375 response_text: "from fallback",
2376 tool_calls: vec![],
2377 error: "fallback err",
2378 }) as Box<dyn ModelProvider>,
2379 ),
2380 ],
2381 3,
2382 1,
2383 );
2384
2385 let messages = vec![ChatMessage::user("hello")];
2386 let request = ChatRequest {
2387 messages: &messages,
2388 tools: None,
2389 thinking: None,
2390 };
2391 let result = model_provider
2392 .chat(request, "test", Some(0.0))
2393 .await
2394 .unwrap();
2395 assert_eq!(result.text.as_deref(), Some("from fallback"));
2396 assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
2398 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
2399 }
2400
2401 #[test]
2404 fn context_window_error_is_not_non_retryable() {
2405 assert!(!is_non_retryable(&anyhow::Error::msg(
2407 "exceeds the context window"
2408 )));
2409 assert!(!is_non_retryable(&anyhow::Error::msg(
2410 "maximum context length exceeded"
2411 )));
2412 assert!(!is_non_retryable(&anyhow::Error::msg(
2413 "too many tokens in the request"
2414 )));
2415 assert!(!is_non_retryable(&anyhow::Error::msg(
2416 "token limit exceeded"
2417 )));
2418 }
2419
2420 #[test]
2421 fn is_context_window_exceeded_detects_llamacpp() {
2422 assert!(is_context_window_exceeded(&anyhow::Error::msg(
2423 "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2424 )));
2425 }
2426
2427 #[test]
2428 fn truncate_for_context_drops_oldest_non_system() {
2429 let mut messages = vec![
2430 ChatMessage::system("sys"),
2431 ChatMessage::user("msg1"),
2432 ChatMessage::assistant("resp1"),
2433 ChatMessage::user("msg2"),
2434 ChatMessage::assistant("resp2"),
2435 ChatMessage::user("msg3"),
2436 ];
2437
2438 let dropped = truncate_for_context(&mut messages);
2439
2440 assert_eq!(dropped, 2);
2442 assert_eq!(messages[0].role, "system");
2444 assert_eq!(messages.len(), 4); assert_eq!(messages.last().unwrap().content, "msg3");
2448 }
2449
2450 #[test]
2451 fn truncate_for_context_preserves_system_and_last_message() {
2452 let mut messages = vec![ChatMessage::system("sys"), ChatMessage::user("only")];
2454 let dropped = truncate_for_context(&mut messages);
2455 assert_eq!(dropped, 0);
2456 assert_eq!(messages.len(), 2);
2457
2458 let mut messages = vec![ChatMessage::user("only")];
2460 let dropped = truncate_for_context(&mut messages);
2461 assert_eq!(dropped, 0);
2462 assert_eq!(messages.len(), 1);
2463 }
2464
2465 struct ContextOverflowMock {
2468 calls: Arc<AtomicUsize>,
2469 fail_until_attempt: usize,
2470 message_counts: parking_lot::Mutex<Vec<usize>>,
2471 }
2472
2473 #[async_trait]
2474 impl ModelProvider for ContextOverflowMock {
2475 async fn chat_with_system(
2476 &self,
2477 _system_prompt: Option<&str>,
2478 _message: &str,
2479 _model: &str,
2480 _temperature: Option<f64>,
2481 ) -> anyhow::Result<String> {
2482 Ok("ok".to_string())
2483 }
2484
2485 async fn chat_with_history(
2486 &self,
2487 messages: &[ChatMessage],
2488 _model: &str,
2489 _temperature: Option<f64>,
2490 ) -> anyhow::Result<String> {
2491 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2492 self.message_counts.lock().push(messages.len());
2493 if attempt <= self.fail_until_attempt {
2494 anyhow::bail!(
2495 "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2496 );
2497 }
2498 Ok("recovered after truncation".to_string())
2499 }
2500 }
2501 impl ::zeroclaw_api::attribution::Attributable for ContextOverflowMock {
2502 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2503 ::zeroclaw_api::attribution::Role::Provider(
2504 ::zeroclaw_api::attribution::ProviderKind::Model(
2505 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2506 ),
2507 )
2508 }
2509 fn alias(&self) -> &str {
2510 "ContextOverflowMock"
2511 }
2512 }
2513
2514 #[tokio::test]
2515 async fn chat_with_history_truncates_on_context_overflow() {
2516 let calls = Arc::new(AtomicUsize::new(0));
2517 let mock = ContextOverflowMock {
2518 calls: Arc::clone(&calls),
2519 fail_until_attempt: 1, message_counts: parking_lot::Mutex::new(Vec::new()),
2521 };
2522
2523 let model_provider = ReliableModelProvider::new(
2524 "test",
2525 vec![("local".into(), Box::new(mock) as Box<dyn ModelProvider>)],
2526 3,
2527 1,
2528 );
2529
2530 let messages = vec![
2531 ChatMessage::system("system prompt"),
2532 ChatMessage::user("old message 1"),
2533 ChatMessage::assistant("old response 1"),
2534 ChatMessage::user("old message 2"),
2535 ChatMessage::assistant("old response 2"),
2536 ChatMessage::user("current question"),
2537 ];
2538
2539 let result = model_provider
2540 .chat_with_history(&messages, "local-model", Some(0.0))
2541 .await
2542 .unwrap();
2543 assert_eq!(result, "recovered after truncation");
2544 assert_eq!(calls.load(Ordering::SeqCst), 2);
2546 }
2547
2548 #[tokio::test]
2549 async fn context_overflow_with_no_history_to_truncate_bails_immediately() {
2550 let calls = Arc::new(AtomicUsize::new(0));
2551 let mock = ContextOverflowMock {
2552 calls: Arc::clone(&calls),
2553 fail_until_attempt: 999, message_counts: parking_lot::Mutex::new(Vec::new()),
2555 };
2556
2557 let model_provider = ReliableModelProvider::new(
2558 "test",
2559 vec![("local".into(), Box::new(mock) as Box<dyn ModelProvider>)],
2560 3,
2561 1,
2562 );
2563
2564 let messages = vec![
2566 ChatMessage::system("huge system prompt that exceeds context window"),
2567 ChatMessage::user("hello"),
2568 ];
2569
2570 let result = model_provider
2571 .chat_with_history(&messages, "local-model", Some(0.0))
2572 .await;
2573 assert!(result.is_err());
2574 let err_msg = result.unwrap_err().to_string();
2575 assert!(
2576 err_msg.contains("cannot be reduced further"),
2577 "Should bail with actionable message, got: {err_msg}"
2578 );
2579 assert_eq!(
2581 calls.load(Ordering::SeqCst),
2582 1,
2583 "Should not retry when truncation is impossible"
2584 );
2585 }
2586
2587 #[test]
2590 fn tool_schema_error_detects_groq_validation_failure() {
2591 let msg = r#"Groq API error (400 Bad Request): {"error":{"message":"tool call validation failed: attempted to call tool 'memory_recall' which was not in request"}}"#;
2592 let err = anyhow::Error::msg(msg.to_string());
2593 assert!(is_tool_schema_error(&err));
2594 }
2595
2596 #[test]
2597 fn tool_schema_error_detects_not_in_request() {
2598 let err = anyhow::Error::msg("tool 'search' was not in request");
2599 assert!(is_tool_schema_error(&err));
2600 }
2601
2602 #[test]
2603 fn tool_schema_error_detects_not_found_in_tool_list() {
2604 let err = anyhow::Error::msg("function 'foo' not found in tool list");
2605 assert!(is_tool_schema_error(&err));
2606 }
2607
2608 #[test]
2609 fn tool_schema_error_detects_invalid_tool_call() {
2610 let err = anyhow::Error::msg("invalid_tool_call: no matching function");
2611 assert!(is_tool_schema_error(&err));
2612 }
2613
2614 #[test]
2615 fn tool_schema_error_ignores_unrelated_errors() {
2616 let err = anyhow::Error::msg("invalid api key");
2617 assert!(!is_tool_schema_error(&err));
2618
2619 let err = anyhow::Error::msg("model not found");
2620 assert!(!is_tool_schema_error(&err));
2621 }
2622
2623 #[test]
2624 fn non_retryable_returns_false_for_tool_schema_400() {
2625 let msg = "400 Bad Request: tool call validation failed: attempted to call tool 'x' which was not in request";
2627 let err = anyhow::Error::msg(msg.to_string());
2628 assert!(!is_non_retryable(&err));
2629 }
2630
2631 #[test]
2632 fn non_retryable_returns_true_for_other_400_errors() {
2633 let err = anyhow::Error::msg("400 Bad Request: invalid api key provided");
2635 assert!(is_non_retryable(&err));
2636 }
2637
2638 struct StreamingToolEventMock {
2639 stream_calls: Arc<AtomicUsize>,
2640 supports_tool_events: bool,
2641 }
2642
2643 impl StreamingToolEventMock {
2644 fn new(supports_tool_events: bool) -> Self {
2645 Self {
2646 stream_calls: Arc::new(AtomicUsize::new(0)),
2647 supports_tool_events,
2648 }
2649 }
2650 }
2651
2652 #[async_trait]
2653 impl ModelProvider for StreamingToolEventMock {
2654 async fn chat_with_system(
2655 &self,
2656 _system_prompt: Option<&str>,
2657 _message: &str,
2658 _model: &str,
2659 _temperature: Option<f64>,
2660 ) -> anyhow::Result<String> {
2661 Ok("ok".to_string())
2662 }
2663
2664 fn supports_streaming(&self) -> bool {
2665 true
2666 }
2667
2668 fn supports_streaming_tool_events(&self) -> bool {
2669 self.supports_tool_events
2670 }
2671
2672 fn stream_chat(
2673 &self,
2674 _request: ChatRequest<'_>,
2675 _model: &str,
2676 _temperature: Option<f64>,
2677 _options: StreamOptions,
2678 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
2679 self.stream_calls.fetch_add(1, Ordering::SeqCst);
2680 stream::iter(vec![
2681 Ok(StreamEvent::ToolCall(super::super::traits::ToolCall {
2682 id: "call_1".to_string(),
2683 name: "shell".to_string(),
2684 arguments: r#"{"command":"date"}"#.to_string(),
2685 extra_content: None,
2686 })),
2687 Ok(StreamEvent::Final),
2688 ])
2689 .boxed()
2690 }
2691 }
2692 impl ::zeroclaw_api::attribution::Attributable for StreamingToolEventMock {
2693 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2694 ::zeroclaw_api::attribution::Role::Provider(
2695 ::zeroclaw_api::attribution::ProviderKind::Model(
2696 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2697 ),
2698 )
2699 }
2700 fn alias(&self) -> &str {
2701 "StreamingToolEventMock"
2702 }
2703 }
2704
2705 #[tokio::test]
2708 async fn stream_chat_prefers_provider_with_tool_event_support() {
2709 let primary = Arc::new(StreamingToolEventMock::new(false));
2710 let fallback = Arc::new(StreamingToolEventMock::new(true));
2711 let model_provider = ReliableModelProvider::new(
2712 "test",
2713 vec![
2714 (
2715 "primary".into(),
2716 Box::new(Arc::clone(&primary)) as Box<dyn ModelProvider>,
2717 ),
2718 (
2719 "fallback".into(),
2720 Box::new(Arc::clone(&fallback)) as Box<dyn ModelProvider>,
2721 ),
2722 ],
2723 0,
2724 1,
2725 );
2726
2727 let messages = vec![ChatMessage::user("hello")];
2728 let tools = vec![ToolSpec {
2729 name: "shell".to_string(),
2730 description: "run shell".to_string(),
2731 parameters: serde_json::json!({
2732 "type": "object",
2733 "properties": {
2734 "command": { "type": "string" }
2735 }
2736 }),
2737 }];
2738 let mut stream = model_provider.stream_chat(
2739 ChatRequest {
2740 messages: &messages,
2741 tools: Some(&tools),
2742 thinking: None,
2743 },
2744 "model",
2745 Some(0.0),
2746 StreamOptions::new(true),
2747 );
2748
2749 let first = stream.next().await.unwrap().unwrap();
2750 let second = stream.next().await.unwrap().unwrap();
2751 assert!(stream.next().await.is_none());
2752
2753 match first {
2754 StreamEvent::ToolCall(call) => assert_eq!(call.name, "shell"),
2755 other => panic!("expected tool-call event, got {other:?}"),
2756 }
2757 assert!(matches!(second, StreamEvent::Final));
2758 assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
2759 assert_eq!(fallback.stream_calls.load(Ordering::SeqCst), 1);
2760 }
2761
2762 #[tokio::test]
2763 async fn stream_chat_errors_when_no_provider_supports_tool_events() {
2764 let primary = Arc::new(StreamingToolEventMock::new(false));
2765 let model_provider = ReliableModelProvider::new(
2766 "test",
2767 vec![(
2768 "primary".into(),
2769 Box::new(Arc::clone(&primary)) as Box<dyn ModelProvider>,
2770 )],
2771 0,
2772 1,
2773 );
2774
2775 let messages = vec![ChatMessage::user("hello")];
2776 let tools = vec![ToolSpec {
2777 name: "shell".to_string(),
2778 description: "run shell".to_string(),
2779 parameters: serde_json::json!({"type": "object"}),
2780 }];
2781 let mut stream = model_provider.stream_chat(
2782 ChatRequest {
2783 messages: &messages,
2784 tools: Some(&tools),
2785 thinking: None,
2786 },
2787 "model",
2788 Some(0.0),
2789 StreamOptions::new(true),
2790 );
2791
2792 let first = stream.next().await.unwrap();
2793 let err = first.expect_err("stream should fail without tool-event support");
2794 assert!(
2795 err.to_string()
2796 .contains("No model_provider supports streaming tool events"),
2797 "unexpected stream error: {err}"
2798 );
2799 assert!(stream.next().await.is_none());
2800 assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
2801 }
2802
2803 struct StreamingHistoryMock {
2807 stream_calls: Arc<AtomicUsize>,
2808 supports: bool,
2809 }
2810
2811 #[async_trait]
2812 impl ModelProvider for StreamingHistoryMock {
2813 async fn chat_with_system(
2814 &self,
2815 _system_prompt: Option<&str>,
2816 _message: &str,
2817 _model: &str,
2818 _temperature: Option<f64>,
2819 ) -> anyhow::Result<String> {
2820 Ok("ok".to_string())
2821 }
2822
2823 fn supports_streaming(&self) -> bool {
2824 self.supports
2825 }
2826
2827 fn stream_chat_with_history(
2828 &self,
2829 messages: &[ChatMessage],
2830 _model: &str,
2831 _temperature: Option<f64>,
2832 _options: StreamOptions,
2833 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
2834 self.stream_calls.fetch_add(1, Ordering::SeqCst);
2835 let msg_count = messages.len().to_string();
2837 stream::iter(vec![
2838 Ok(StreamChunk::delta(msg_count)),
2839 Ok(StreamChunk::final_chunk()),
2840 ])
2841 .boxed()
2842 }
2843 }
2844 impl ::zeroclaw_api::attribution::Attributable for StreamingHistoryMock {
2845 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2846 ::zeroclaw_api::attribution::Role::Provider(
2847 ::zeroclaw_api::attribution::ProviderKind::Model(
2848 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2849 ),
2850 )
2851 }
2852 fn alias(&self) -> &str {
2853 "StreamingHistoryMock"
2854 }
2855 }
2856
2857 #[tokio::test]
2858 async fn stream_chat_with_history_delegates_to_streaming_provider() {
2859 let calls = Arc::new(AtomicUsize::new(0));
2860 let model_provider = ReliableModelProvider::new(
2861 "test",
2862 vec![(
2863 "primary".into(),
2864 Box::new(StreamingHistoryMock {
2865 stream_calls: Arc::clone(&calls),
2866 supports: true,
2867 }) as Box<dyn ModelProvider>,
2868 )],
2869 0,
2870 1,
2871 );
2872
2873 let messages = vec![
2874 ChatMessage::system("system"),
2875 ChatMessage::user("msg1"),
2876 ChatMessage::assistant("resp1"),
2877 ChatMessage::user("msg2"),
2878 ];
2879 let mut stream = model_provider.stream_chat_with_history(
2880 &messages,
2881 "model",
2882 Some(0.0),
2883 StreamOptions::new(true),
2884 );
2885
2886 let first = stream.next().await.unwrap().unwrap();
2887 assert_eq!(
2888 first.delta, "4",
2889 "should pass all 4 messages to model_provider"
2890 );
2891 let second = stream.next().await.unwrap().unwrap();
2892 assert!(second.is_final);
2893 assert!(stream.next().await.is_none());
2894 assert_eq!(calls.load(Ordering::SeqCst), 1);
2895 }
2896
2897 #[tokio::test]
2898 async fn stream_chat_with_history_skips_non_streaming_providers() {
2899 let non_streaming_calls = Arc::new(AtomicUsize::new(0));
2900 let streaming_calls = Arc::new(AtomicUsize::new(0));
2901
2902 let model_provider = ReliableModelProvider::new(
2903 "test",
2904 vec![
2905 (
2906 "non-streaming".into(),
2907 Box::new(StreamingHistoryMock {
2908 stream_calls: Arc::clone(&non_streaming_calls),
2909 supports: false,
2910 }) as Box<dyn ModelProvider>,
2911 ),
2912 (
2913 "streaming".into(),
2914 Box::new(StreamingHistoryMock {
2915 stream_calls: Arc::clone(&streaming_calls),
2916 supports: true,
2917 }) as Box<dyn ModelProvider>,
2918 ),
2919 ],
2920 0,
2921 1,
2922 );
2923
2924 let messages = vec![ChatMessage::user("hello")];
2925 let mut stream = model_provider.stream_chat_with_history(
2926 &messages,
2927 "model",
2928 Some(0.0),
2929 StreamOptions::new(true),
2930 );
2931
2932 let first = stream.next().await.unwrap().unwrap();
2933 assert_eq!(first.delta, "1");
2934 assert_eq!(
2935 non_streaming_calls.load(Ordering::SeqCst),
2936 0,
2937 "non-streaming model_provider should be skipped"
2938 );
2939 assert_eq!(
2940 streaming_calls.load(Ordering::SeqCst),
2941 1,
2942 "streaming model_provider should be used"
2943 );
2944 }
2945
2946 #[tokio::test]
2947 async fn stream_chat_with_history_errors_when_no_provider_supports_streaming() {
2948 let model_provider = ReliableModelProvider::new(
2949 "test",
2950 vec![(
2951 "non-streaming".into(),
2952 Box::new(StreamingHistoryMock {
2953 stream_calls: Arc::new(AtomicUsize::new(0)),
2954 supports: false,
2955 }) as Box<dyn ModelProvider>,
2956 )],
2957 0,
2958 1,
2959 );
2960
2961 let messages = vec![ChatMessage::user("hello")];
2962 let mut stream = model_provider.stream_chat_with_history(
2963 &messages,
2964 "model",
2965 Some(0.0),
2966 StreamOptions::new(true),
2967 );
2968
2969 let first = stream.next().await.unwrap();
2970 let err = first.expect_err("should fail when no model_provider supports streaming");
2971 assert!(
2972 err.to_string()
2973 .contains("No model_provider supports streaming"),
2974 "unexpected error: {err}"
2975 );
2976 assert!(stream.next().await.is_none());
2977 }
2978
2979 #[tokio::test]
2980 async fn fallback_records_provider_fallback_info() {
2981 scope_provider_fallback(async {
2982 let model_provider = ReliableModelProvider::new(
2983 "test",
2984 vec![
2985 (
2986 "broken".into(),
2987 Box::new(MockModelProvider {
2988 calls: Arc::new(AtomicUsize::new(0)),
2989 fail_until_attempt: 99, response: "unused",
2991 error: "401 Unauthorized",
2992 }),
2993 ),
2994 (
2995 "working".into(),
2996 Box::new(MockModelProvider {
2997 calls: Arc::new(AtomicUsize::new(0)),
2998 fail_until_attempt: 0,
2999 response: "hello from working",
3000 error: "unused",
3001 }),
3002 ),
3003 ],
3004 2,
3005 1,
3006 );
3007
3008 let resp = model_provider
3009 .simple_chat("hi", "test-model", Some(0.0))
3010 .await
3011 .unwrap();
3012 assert_eq!(resp, "hello from working");
3013
3014 let fb = take_last_provider_fallback();
3015 assert!(fb.is_some(), "fallback info should be recorded");
3016 let fb = fb.unwrap();
3017 assert_eq!(fb.requested_provider, "broken");
3018 assert_eq!(fb.actual_provider, "working");
3019 assert_eq!(fb.actual_model, "test-model");
3020
3021 assert!(take_last_provider_fallback().is_none());
3023 })
3024 .await;
3025 }
3026
3027 #[test]
3031 fn supports_vision_reflects_first_provider_not_any_fallback() {
3032 struct VisionMock(bool);
3033
3034 #[async_trait]
3035 impl ModelProvider for VisionMock {
3036 async fn chat_with_system(
3037 &self,
3038 _system_prompt: Option<&str>,
3039 _message: &str,
3040 _model: &str,
3041 _temperature: Option<f64>,
3042 ) -> anyhow::Result<String> {
3043 Ok(String::new())
3044 }
3045
3046 fn supports_vision(&self) -> bool {
3047 self.0
3048 }
3049 }
3050 impl ::zeroclaw_api::attribution::Attributable for VisionMock {
3051 fn role(&self) -> ::zeroclaw_api::attribution::Role {
3052 ::zeroclaw_api::attribution::Role::Provider(
3053 ::zeroclaw_api::attribution::ProviderKind::Model(
3054 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
3055 ),
3056 )
3057 }
3058 fn alias(&self) -> &str {
3059 "VisionMock"
3060 }
3061 }
3062
3063 let provider = ReliableModelProvider::new(
3064 "test",
3065 vec![
3066 (
3067 "primary".into(),
3068 Box::new(VisionMock(false)) as Box<dyn ModelProvider>,
3069 ),
3070 (
3071 "fallback".into(),
3072 Box::new(VisionMock(true)) as Box<dyn ModelProvider>,
3073 ),
3074 ],
3075 0,
3076 0,
3077 );
3078
3079 assert!(
3080 !provider.supports_vision(),
3081 "ReliableModelProvider with non-vision primary must report supports_vision()=false even when a fallback supports vision"
3082 );
3083
3084 let provider = ReliableModelProvider::new(
3085 "test",
3086 vec![
3087 (
3088 "primary".into(),
3089 Box::new(VisionMock(true)) as Box<dyn ModelProvider>,
3090 ),
3091 (
3092 "fallback".into(),
3093 Box::new(VisionMock(false)) as Box<dyn ModelProvider>,
3094 ),
3095 ],
3096 0,
3097 0,
3098 );
3099
3100 assert!(provider.supports_vision());
3101 }
3102}