1use super::ModelProvider;
2use super::stream_guard::AbortOnDrop;
3use super::traits::{
4 ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
5};
6use async_trait::async_trait;
7use futures_util::{StreamExt, stream};
8use std::cell::RefCell;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::time::Duration;
12
/// Describes a request that ended up being served by a different
/// provider/model combination (or only after a retry) than the one the
/// caller asked for.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare recorded
/// fallbacks directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderFallbackInfo {
    /// Name of the provider that was requested (the first configured one).
    pub requested_provider: String,
    /// Model name the caller originally asked for.
    pub requested_model: String,
    /// Name of the provider that actually produced the response.
    pub actual_provider: String,
    /// Model name that actually produced the response.
    pub actual_model: String,
}
31
// Task-local slot holding the most recently recorded fallback, if any.
// It only exists while a future is running inside `scope_provider_fallback`;
// accesses elsewhere go through `try_with` and are silently ignored.
tokio::task_local! {
    static PROVIDER_FALLBACK: RefCell<Option<ProviderFallbackInfo>>;
}
35
36pub fn take_last_provider_fallback() -> Option<ProviderFallbackInfo> {
39 PROVIDER_FALLBACK
40 .try_with(|cell| cell.borrow_mut().take())
41 .ok()
42 .flatten()
43}
44
45pub async fn scope_provider_fallback<F: std::future::Future>(future: F) -> F::Output {
50 PROVIDER_FALLBACK.scope(RefCell::new(None), future).await
51}
52
53fn record_provider_fallback(
55 requested_provider: &str,
56 requested_model: &str,
57 actual_provider: &str,
58 actual_model: &str,
59) {
60 let _ = PROVIDER_FALLBACK.try_with(|cell| {
61 *cell.borrow_mut() = Some(ProviderFallbackInfo {
62 requested_provider: requested_provider.to_string(),
63 requested_model: requested_model.to_string(),
64 actual_provider: actual_provider.to_string(),
65 actual_model: actual_model.to_string(),
66 });
67 });
68}
69
70pub fn is_non_retryable(err: &anyhow::Error) -> bool {
78 if is_context_window_exceeded(err) {
81 return false;
82 }
83
84 if is_tool_schema_error(err) {
88 return false;
89 }
90
91 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
94 && let Some(status) = reqwest_err.status()
95 {
96 let code = status.as_u16();
97 return status.is_client_error() && code != 429 && code != 408;
98 }
99 let msg = err.to_string();
102 for word in msg.split(|c: char| !c.is_ascii_digit()) {
103 if let Ok(code) = word.parse::<u16>()
104 && (400..500).contains(&code)
105 {
106 return code != 429 && code != 408;
107 }
108 }
109
110 let msg_lower = msg.to_lowercase();
113 let auth_failure_hints = [
114 "invalid api key",
115 "incorrect api key",
116 "missing api key",
117 "api key not set",
118 "authentication failed",
119 "auth failed",
120 "unauthorized",
121 "forbidden",
122 "permission denied",
123 "access denied",
124 "invalid token",
125 ];
126
127 if auth_failure_hints
128 .iter()
129 .any(|hint| msg_lower.contains(hint))
130 {
131 return true;
132 }
133
134 msg_lower.contains("model")
135 && (msg_lower.contains("not found")
136 || msg_lower.contains("unknown")
137 || msg_lower.contains("unsupported")
138 || msg_lower.contains("does not exist")
139 || msg_lower.contains("invalid"))
140}
141
142pub fn is_auth_error(err: &anyhow::Error) -> bool {
146 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
147 && let Some(status) = reqwest_err.status()
148 {
149 let code = status.as_u16();
150 return code == 401 || code == 403;
151 }
152
153 let msg_lower = err.to_string().to_lowercase();
154 let hints = [
155 "401 unauthorized",
156 "403 forbidden",
157 "invalid api key",
158 "incorrect api key",
159 "authentication failed",
160 "auth failed",
161 "unauthorized",
162 "invalid token",
163 "token expired",
164 "access_token",
165 ];
166
167 hints.iter().any(|hint| msg_lower.contains(hint))
168}
169
170pub fn is_tool_schema_error(err: &anyhow::Error) -> bool {
176 let lower = err.to_string().to_lowercase();
177 let hints = [
178 "tool call validation failed",
179 "was not in request",
180 "not found in tool list",
181 "invalid_tool_call",
182 ];
183 hints.iter().any(|hint| lower.contains(hint))
184}
185
186pub fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
187 let lower = err.to_string().to_lowercase();
188 let hints = [
189 "exceeds the context window",
190 "exceeds the available context size",
191 "context window of this model",
192 "maximum context length",
193 "context length exceeded",
194 "too many tokens",
195 "token limit exceeded",
196 "prompt is too long",
197 "input is too long",
198 "prompt exceeds max length",
199 ];
200
201 hints.iter().any(|hint| lower.contains(hint))
202}
203
204fn is_rate_limited(err: &anyhow::Error) -> bool {
206 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
207 && let Some(status) = reqwest_err.status()
208 {
209 return status.as_u16() == 429;
210 }
211 let msg = err.to_string();
212 msg.contains("429")
213 && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
214}
215
216fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool {
223 if !is_rate_limited(err) {
224 return false;
225 }
226
227 let msg = err.to_string();
228 let lower = msg.to_lowercase();
229
230 let business_hints = [
231 "plan does not include",
232 "doesn't include",
233 "not include",
234 "insufficient balance",
235 "insufficient_balance",
236 "insufficient quota",
237 "insufficient_quota",
238 "quota exhausted",
239 "out of credits",
240 "no available package",
241 "package not active",
242 "purchase package",
243 "model not available for your plan",
244 ];
245
246 if business_hints.iter().any(|hint| lower.contains(hint)) {
247 return true;
248 }
249
250 for token in lower.split(|c: char| !c.is_ascii_digit()) {
252 if let Ok(code) = token.parse::<u16>()
253 && matches!(code, 1113 | 1311)
254 {
255 return true;
256 }
257 }
258
259 false
260}
261
262fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
265 let msg = err.to_string();
266 let lower = msg.to_lowercase();
267
268 for prefix in &[
270 "retry-after:",
271 "retry_after:",
272 "retry-after ",
273 "retry_after ",
274 ] {
275 if let Some(pos) = lower.find(prefix) {
276 let after = &msg[pos + prefix.len()..];
277 let num_str: String = after
278 .trim()
279 .chars()
280 .take_while(|c| c.is_ascii_digit() || *c == '.')
281 .collect();
282 if let Ok(secs) = num_str.parse::<f64>()
283 && secs.is_finite()
284 && secs >= 0.0
285 {
286 let millis = Duration::from_secs_f64(secs).as_millis();
287 if let Ok(value) = u64::try_from(millis) {
288 return Some(value);
289 }
290 }
291 }
292 }
293 None
294}
295
/// Maps the two error classification flags to the label used in failure
/// records and log attributes.
fn failure_reason(rate_limited: bool, non_retryable: bool) -> &'static str {
    match (rate_limited, non_retryable) {
        (true, true) => "rate_limited_non_retryable",
        (true, false) => "rate_limited",
        (false, true) => "non_retryable",
        (false, false) => "retryable",
    }
}
307
308fn compact_error_detail(err: &anyhow::Error) -> String {
309 super::sanitize_api_error(&format!("{err:#}"))
310 .split_whitespace()
311 .collect::<Vec<_>>()
312 .join(" ")
313}
314
315fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize {
319 let non_system: Vec<usize> = messages
321 .iter()
322 .enumerate()
323 .filter(|(_, m)| m.role != "system")
324 .map(|(i, _)| i)
325 .collect();
326
327 if non_system.len() <= 1 {
329 return 0;
330 }
331
332 let drop_count = non_system.len() / 2;
334 let indices_to_remove: Vec<usize> = non_system[..drop_count].to_vec();
335
336 for &idx in indices_to_remove.iter().rev() {
338 messages.remove(idx);
339 }
340
341 drop_count
342}
343
/// Appends one formatted attempt record to `failures`.
///
/// The record names the provider, the model, the attempt counter
/// (`attempt/max_attempts`), the failure reason and the error detail.
fn push_failure(
    failures: &mut Vec<String>,
    provider_name: &str,
    model: &str,
    attempt: u32,
    max_attempts: u32,
    reason: &str,
    error_detail: &str,
) {
    let entry = format!(
        "model_provider={provider_name} model={model} attempt {attempt}/{max_attempts}: {reason}; error={error_detail}"
    );
    failures.push(entry);
}
357
358fn is_empty_completion(resp: &ChatResponse) -> bool {
368 resp.text_or_empty().trim().is_empty()
369 && resp.tool_calls.is_empty()
370 && resp
371 .reasoning_content
372 .as_deref()
373 .is_none_or(|r| r.trim().is_empty())
374}
375
/// A `ModelProvider` wrapper that adds retries with exponential backoff,
/// failover across several providers, and per-model fallback chains.
pub struct ReliableModelProvider {
    // Name given to this wrapper at construction.
    // NOTE(review): not read anywhere in the visible part of the file.
    alias: String,
    // Wrapped providers as (name, provider) pairs; tried in this order.
    model_providers: Vec<(String, Box<dyn ModelProvider>)>,
    // Extra attempts per provider; each one gets `max_retries + 1` attempts.
    max_retries: u32,
    // Initial backoff in milliseconds (floored at 50 by `new`); it doubles
    // after each wait, capped at 10_000.
    base_backoff_ms: u64,
    // Candidate API keys handed out round-robin by `rotate_key`.
    api_keys: Vec<String>,
    // Monotonic counter used to pick the next entry of `api_keys`.
    key_index: AtomicUsize,
    // Maps a requested model to the ordered fallback models tried after it.
    model_fallbacks: HashMap<String, Vec<String>>,
}
402
403impl ReliableModelProvider {
404 pub fn new(
405 alias: &str,
406 model_providers: Vec<(String, Box<dyn ModelProvider>)>,
407 max_retries: u32,
408 base_backoff_ms: u64,
409 ) -> Self {
410 Self {
411 alias: alias.to_string(),
412 model_providers,
413 max_retries,
414 base_backoff_ms: base_backoff_ms.max(50),
415 api_keys: Vec::new(),
416 key_index: AtomicUsize::new(0),
417 model_fallbacks: HashMap::new(),
418 }
419 }
420 pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
422 self.api_keys = keys;
423 self
424 }
425
426 #[cfg(test)]
429 pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
430 self.model_fallbacks = fallbacks;
431 self
432 }
433
434 fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
436 let mut chain = vec![model];
437 if let Some(fallbacks) = self.model_fallbacks.get(model) {
438 chain.extend(fallbacks.iter().map(|s| s.as_str()));
439 }
440 chain
441 }
442
443 fn rotate_key(&self) -> Option<&str> {
445 if self.api_keys.is_empty() {
446 return None;
447 }
448 let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
449 Some(&self.api_keys[idx])
450 }
451
452 fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
454 if let Some(retry_after) = parse_retry_after_ms(err) {
455 retry_after.min(30_000).max(base)
457 } else {
458 base
459 }
460 }
461
462 async fn backoff_after_empty_completion(
467 &self,
468 failures: &mut Vec<String>,
469 provider_name: &str,
470 model: &str,
471 attempt: u32,
472 backoff_ms: &mut u64,
473 ) {
474 push_failure(
475 failures,
476 provider_name,
477 model,
478 attempt + 1,
479 self.max_retries + 1,
480 "empty_response",
481 "model_provider returned an empty completion",
482 );
483 ::zeroclaw_log::record!(
484 WARN,
485 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
486 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
487 .with_attrs(::serde_json::json!({
488 "model_provider": provider_name,
489 "model": model,
490 "attempt": attempt + 1,
491 "backoff_ms": *backoff_ms
492 })),
493 "Empty completion; retrying"
494 );
495 tokio::time::sleep(Duration::from_millis(*backoff_ms)).await;
496 *backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
497 }
498}
499
500#[async_trait]
501impl ModelProvider for ReliableModelProvider {
502 async fn warmup(&self) -> anyhow::Result<()> {
503 for (name, model_provider) in &self.model_providers {
504 ::zeroclaw_log::record!(
505 INFO,
506 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
507 .with_attrs(::serde_json::json!({"model_provider": name})),
508 "Warming up model_provider connection pool"
509 );
510 if model_provider.warmup().await.is_err() {
511 ::zeroclaw_log::record!(
512 WARN,
513 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
514 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
515 .with_attrs(::serde_json::json!({"model_provider": name})),
516 "Warmup failed (non-fatal)"
517 );
518 }
519 }
520 Ok(())
521 }
522
523 async fn chat_with_system(
524 &self,
525 system_prompt: Option<&str>,
526 message: &str,
527 model: &str,
528 temperature: Option<f64>,
529 ) -> anyhow::Result<String> {
530 let models = self.model_chain(model);
531 let mut failures = Vec::new();
532
533 for current_model in &models {
538 for (provider_name, model_provider) in &self.model_providers {
539 let mut backoff_ms = self.base_backoff_ms;
540
541 for attempt in 0..=self.max_retries {
542 match model_provider
543 .chat_with_system(system_prompt, message, current_model, temperature)
544 .await
545 {
546 Ok(resp) => {
547 if attempt < self.max_retries && resp.trim().is_empty() {
550 self.backoff_after_empty_completion(
551 &mut failures,
552 provider_name,
553 current_model,
554 attempt,
555 &mut backoff_ms,
556 )
557 .await;
558 continue;
559 }
560 if attempt > 0
561 || *current_model != model
562 || self.model_providers.first().map(|(n, _)| n.as_str())
563 != Some(provider_name)
564 {
565 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model})), "ModelProvider recovered (failover/retry)");
566 let primary = self
567 .model_providers
568 .first()
569 .map(|(n, _)| n.as_str())
570 .unwrap_or("");
571 record_provider_fallback(
572 primary,
573 model,
574 provider_name,
575 current_model,
576 );
577 }
578 return Ok(resp);
579 }
580 Err(e) => {
581 if is_context_window_exceeded(&e) {
584 let error_detail = compact_error_detail(&e);
585 push_failure(
586 &mut failures,
587 provider_name,
588 current_model,
589 attempt + 1,
590 self.max_retries + 1,
591 "non_retryable",
592 &error_detail,
593 );
594 anyhow::bail!(
595 "Request exceeds model context window. Attempts:\n{}",
596 failures.join("\n")
597 );
598 }
599
600 let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
601 let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
602 let rate_limited = is_rate_limited(&e);
603 let failure_reason = failure_reason(rate_limited, non_retryable);
604 let error_detail = compact_error_detail(&e);
605
606 push_failure(
607 &mut failures,
608 provider_name,
609 current_model,
610 attempt + 1,
611 self.max_retries + 1,
612 failure_reason,
613 &error_detail,
614 );
615
616 if rate_limited
619 && !non_retryable_rate_limit
620 && let Some(new_key) = self.rotate_key()
621 {
622 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
623 but cannot apply (ModelProvider trait has no set_api_key). \
624 Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
625 }
626
627 if non_retryable {
628 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
629 break;
630 }
631
632 if attempt < self.max_retries {
633 let wait = self.compute_backoff(backoff_ms, &e);
634 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
635 tokio::time::sleep(Duration::from_millis(wait)).await;
636 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
637 }
638 }
639 }
640 }
641
642 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
643 }
644
645 if *current_model != model {
646 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"original_model": model, "fallback_model": *current_model})), "Model fallback exhausted all model_providers, trying next fallback model");
647 }
648 }
649
650 anyhow::bail!(
651 "All model_providers/models failed. Attempts:\n{}",
652 failures.join("\n")
653 )
654 }
655
656 async fn chat_with_history(
657 &self,
658 messages: &[ChatMessage],
659 model: &str,
660 temperature: Option<f64>,
661 ) -> anyhow::Result<String> {
662 let models = self.model_chain(model);
663 let mut failures = Vec::new();
664 let mut effective_messages = messages.to_vec();
665 let mut context_truncated = false;
666
667 for current_model in &models {
668 for (provider_name, model_provider) in &self.model_providers {
669 let mut backoff_ms = self.base_backoff_ms;
670
671 for attempt in 0..=self.max_retries {
672 match model_provider
673 .chat_with_history(&effective_messages, current_model, temperature)
674 .await
675 {
676 Ok(resp) => {
677 if attempt < self.max_retries && resp.trim().is_empty() {
680 self.backoff_after_empty_completion(
681 &mut failures,
682 provider_name,
683 current_model,
684 attempt,
685 &mut backoff_ms,
686 )
687 .await;
688 continue;
689 }
690 if attempt > 0
691 || *current_model != model
692 || context_truncated
693 || self.model_providers.first().map(|(n, _)| n.as_str())
694 != Some(provider_name)
695 {
696 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
697 let primary = self
698 .model_providers
699 .first()
700 .map(|(n, _)| n.as_str())
701 .unwrap_or("");
702 record_provider_fallback(
703 primary,
704 model,
705 provider_name,
706 current_model,
707 );
708 }
709 return Ok(resp);
710 }
711 Err(e) => {
712 if is_context_window_exceeded(&e) && !context_truncated {
714 let dropped = truncate_for_context(&mut effective_messages);
715 if dropped > 0 {
716 context_truncated = true;
717 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
718 continue; }
720 let error_detail = compact_error_detail(&e);
724 push_failure(
725 &mut failures,
726 provider_name,
727 current_model,
728 attempt + 1,
729 self.max_retries + 1,
730 "non_retryable",
731 &error_detail,
732 );
733 anyhow::bail!(
734 "Request exceeds model context window and cannot be reduced further. \
735 Try using a model with a larger context window, reducing the number \
736 of tools/skills, or enabling compact_context in config. Attempts:\n{}",
737 failures.join("\n")
738 );
739 }
740
741 let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
742 let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
743 let rate_limited = is_rate_limited(&e);
744 let failure_reason = failure_reason(rate_limited, non_retryable);
745 let error_detail = compact_error_detail(&e);
746
747 push_failure(
748 &mut failures,
749 provider_name,
750 current_model,
751 attempt + 1,
752 self.max_retries + 1,
753 failure_reason,
754 &error_detail,
755 );
756
757 if rate_limited
758 && !non_retryable_rate_limit
759 && let Some(new_key) = self.rotate_key()
760 {
761 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
762 but cannot apply (ModelProvider trait has no set_api_key). \
763 Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
764 }
765
766 if non_retryable {
767 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
768 break;
769 }
770
771 if attempt < self.max_retries {
772 let wait = self.compute_backoff(backoff_ms, &e);
773 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
774 tokio::time::sleep(Duration::from_millis(wait)).await;
775 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
776 }
777 }
778 }
779 }
780
781 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
782 }
783 }
784
785 anyhow::bail!(
786 "All model_providers/models failed. Attempts:\n{}",
787 failures.join("\n")
788 )
789 }
790
791 fn supports_native_tools(&self) -> bool {
792 self.model_providers
793 .first()
794 .map(|(_, p)| p.supports_native_tools())
795 .unwrap_or(false)
796 }
797
798 fn supports_vision(&self) -> bool {
799 self.model_providers
800 .first()
801 .map(|(_, p)| p.supports_vision())
802 .unwrap_or(false)
803 }
804
805 async fn chat_with_tools(
806 &self,
807 messages: &[ChatMessage],
808 tools: &[serde_json::Value],
809 model: &str,
810 temperature: Option<f64>,
811 ) -> anyhow::Result<ChatResponse> {
812 let models = self.model_chain(model);
813 let mut failures = Vec::new();
814 let mut effective_messages = messages.to_vec();
815 let mut context_truncated = false;
816
817 for current_model in &models {
818 for (provider_name, model_provider) in &self.model_providers {
819 let mut backoff_ms = self.base_backoff_ms;
820
821 for attempt in 0..=self.max_retries {
822 match model_provider
823 .chat_with_tools(&effective_messages, tools, current_model, temperature)
824 .await
825 {
826 Ok(resp) => {
827 if attempt < self.max_retries && is_empty_completion(&resp) {
831 self.backoff_after_empty_completion(
832 &mut failures,
833 provider_name,
834 current_model,
835 attempt,
836 &mut backoff_ms,
837 )
838 .await;
839 continue;
840 }
841 if attempt > 0
842 || *current_model != model
843 || context_truncated
844 || self.model_providers.first().map(|(n, _)| n.as_str())
845 != Some(provider_name)
846 {
847 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
848 let primary = self
849 .model_providers
850 .first()
851 .map(|(n, _)| n.as_str())
852 .unwrap_or("");
853 record_provider_fallback(
854 primary,
855 model,
856 provider_name,
857 current_model,
858 );
859 }
860 return Ok(resp);
861 }
862 Err(e) => {
863 if is_context_window_exceeded(&e) && !context_truncated {
865 let dropped = truncate_for_context(&mut effective_messages);
866 if dropped > 0 {
867 context_truncated = true;
868 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
869 continue; }
871 let error_detail = compact_error_detail(&e);
875 push_failure(
876 &mut failures,
877 provider_name,
878 current_model,
879 attempt + 1,
880 self.max_retries + 1,
881 "non_retryable",
882 &error_detail,
883 );
884 anyhow::bail!(
885 "Request exceeds model context window and cannot be reduced further. \
886 Try using a model with a larger context window, reducing the number \
887 of tools/skills, or enabling compact_context in config. Attempts:\n{}",
888 failures.join("\n")
889 );
890 }
891
892 let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
893 let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
894 let rate_limited = is_rate_limited(&e);
895 let failure_reason = failure_reason(rate_limited, non_retryable);
896 let error_detail = compact_error_detail(&e);
897
898 push_failure(
899 &mut failures,
900 provider_name,
901 current_model,
902 attempt + 1,
903 self.max_retries + 1,
904 failure_reason,
905 &error_detail,
906 );
907
908 if rate_limited
909 && !non_retryable_rate_limit
910 && let Some(new_key) = self.rotate_key()
911 {
912 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
913 but cannot apply (ModelProvider trait has no set_api_key). \
914 Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
915 }
916
917 if non_retryable {
918 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
919 break;
920 }
921
922 if attempt < self.max_retries {
923 let wait = self.compute_backoff(backoff_ms, &e);
924 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
925 tokio::time::sleep(Duration::from_millis(wait)).await;
926 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
927 }
928 }
929 }
930 }
931
932 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
933 }
934 }
935
936 anyhow::bail!(
937 "All model_providers/models failed. Attempts:\n{}",
938 failures.join("\n")
939 )
940 }
941
942 async fn chat(
943 &self,
944 request: ChatRequest<'_>,
945 model: &str,
946 temperature: Option<f64>,
947 ) -> anyhow::Result<ChatResponse> {
948 let models = self.model_chain(model);
949 let mut failures = Vec::new();
950 let mut effective_messages = request.messages.to_vec();
951 let mut context_truncated = false;
952
953 for current_model in &models {
954 for (provider_name, model_provider) in &self.model_providers {
955 let mut backoff_ms = self.base_backoff_ms;
956
957 for attempt in 0..=self.max_retries {
958 let req = ChatRequest {
959 messages: &effective_messages,
960 tools: request.tools,
961 thinking: request.thinking,
962 };
963 match model_provider.chat(req, current_model, temperature).await {
964 Ok(resp) => {
965 if attempt < self.max_retries && is_empty_completion(&resp) {
969 self.backoff_after_empty_completion(
970 &mut failures,
971 provider_name,
972 current_model,
973 attempt,
974 &mut backoff_ms,
975 )
976 .await;
977 continue;
978 }
979 if attempt > 0
980 || *current_model != model
981 || context_truncated
982 || self.model_providers.first().map(|(n, _)| n.as_str())
983 != Some(provider_name)
984 {
985 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt, "original_model": model, "context_truncated": context_truncated})), "ModelProvider recovered (failover/retry)");
986 let primary = self
987 .model_providers
988 .first()
989 .map(|(n, _)| n.as_str())
990 .unwrap_or("");
991 record_provider_fallback(
992 primary,
993 model,
994 provider_name,
995 current_model,
996 );
997 }
998 return Ok(resp);
999 }
1000 Err(e) => {
1001 if is_context_window_exceeded(&e) && !context_truncated {
1003 let dropped = truncate_for_context(&mut effective_messages);
1004 if dropped > 0 {
1005 context_truncated = true;
1006 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "dropped": dropped, "remaining": effective_messages.len()})), "Context window exceeded; truncated history and retrying");
1007 continue; }
1009 let error_detail = compact_error_detail(&e);
1013 push_failure(
1014 &mut failures,
1015 provider_name,
1016 current_model,
1017 attempt + 1,
1018 self.max_retries + 1,
1019 "non_retryable",
1020 &error_detail,
1021 );
1022 anyhow::bail!(
1023 "Request exceeds model context window and cannot be reduced further. \
1024 Try using a model with a larger context window, reducing the number \
1025 of tools/skills, or enabling compact_context in config. Attempts:\n{}",
1026 failures.join("\n")
1027 );
1028 }
1029
1030 let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
1031 let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
1032 let rate_limited = is_rate_limited(&e);
1033 let failure_reason = failure_reason(rate_limited, non_retryable);
1034 let error_detail = compact_error_detail(&e);
1035
1036 push_failure(
1037 &mut failures,
1038 provider_name,
1039 current_model,
1040 attempt + 1,
1041 self.max_retries + 1,
1042 failure_reason,
1043 &error_detail,
1044 );
1045
1046 if rate_limited
1047 && !non_retryable_rate_limit
1048 && let Some(new_key) = self.rotate_key()
1049 {
1050 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "error": error_detail})), &format!("Rate limited; key rotation selected key ending ...{} \
1051 but cannot apply (ModelProvider trait has no set_api_key). \
1052 Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..]));
1053 }
1054
1055 if non_retryable {
1056 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "error": error_detail})), "Non-retryable error, moving on");
1057 break;
1058 }
1059
1060 if attempt < self.max_retries {
1061 let wait = self.compute_backoff(backoff_ms, &e);
1062 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model, "attempt": attempt + 1, "backoff_ms": wait, "reason": failure_reason, "error": error_detail})), "ModelProvider call failed, retrying");
1063 tokio::time::sleep(Duration::from_millis(wait)).await;
1064 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
1065 }
1066 }
1067 }
1068 }
1069
1070 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_name, "model": *current_model})), "Exhausted retries, trying next model_provider/model");
1071 }
1072
1073 if *current_model != model {
1074 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"original_model": model, "fallback_model": *current_model})), "Model fallback exhausted all model_providers, trying next fallback model");
1075 }
1076 }
1077
1078 anyhow::bail!(
1079 "All model_providers/models failed. Attempts:\n{}",
1080 failures.join("\n")
1081 )
1082 }
1083
1084 fn supports_streaming(&self) -> bool {
1085 self.model_providers
1086 .iter()
1087 .any(|(_, p)| p.supports_streaming())
1088 }
1089
1090 fn supports_streaming_tool_events(&self) -> bool {
1091 self.model_providers
1092 .iter()
1093 .any(|(_, p)| p.supports_streaming_tool_events())
1094 }
1095
1096 fn stream_chat(
1097 &self,
1098 request: ChatRequest<'_>,
1099 model: &str,
1100 temperature: Option<f64>,
1101 options: StreamOptions,
1102 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
1103 let needs_tool_events = request.tools.is_some_and(|tools| !tools.is_empty());
1104
1105 for (provider_name, model_provider) in &self.model_providers {
1106 if !model_provider.supports_streaming() || !options.enabled {
1107 continue;
1108 }
1109
1110 if needs_tool_events && !model_provider.supports_streaming_tool_events() {
1111 continue;
1112 }
1113
1114 let provider_clone = provider_name.clone();
1115
1116 let current_model = self
1117 .model_chain(model)
1118 .first()
1119 .copied()
1120 .unwrap_or(model)
1121 .to_string();
1122
1123 let req = ChatRequest {
1124 messages: request.messages,
1125 tools: request.tools,
1126 thinking: request.thinking,
1127 };
1128 let stream = model_provider.stream_chat(req, ¤t_model, temperature, options);
1129 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
1130
1131 let handle = ::zeroclaw_spawn::spawn!(async move {
1132 let mut stream = stream;
1133 while let Some(event) = stream.next().await {
1134 if let Err(ref e) = event {
1135 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1136 }
1137 if tx.send(event).await.is_err() {
1138 break;
1139 }
1140 }
1141 });
1142
1143 let guard = AbortOnDrop::new(handle.abort_handle());
1144 return stream::unfold((rx, guard), |(mut rx, guard)| async move {
1145 rx.recv().await.map(|event| (event, (rx, guard)))
1146 })
1147 .boxed();
1148 }
1149
1150 let message = if needs_tool_events {
1151 "No model_provider supports streaming tool events".to_string()
1152 } else {
1153 "No model_provider supports streaming".to_string()
1154 };
1155 stream::once(async move { Err(super::traits::StreamError::ModelProvider(message)) }).boxed()
1156 }
1157
1158 fn stream_chat_with_system(
1159 &self,
1160 system_prompt: Option<&str>,
1161 message: &str,
1162 model: &str,
1163 temperature: Option<f64>,
1164 options: StreamOptions,
1165 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1166 for (provider_name, model_provider) in &self.model_providers {
1169 if !model_provider.supports_streaming() || !options.enabled {
1170 continue;
1171 }
1172
1173 let provider_clone = provider_name.clone();
1175
1176 let current_model = match self.model_chain(model).first() {
1178 Some(m) => (*m).to_string(),
1179 None => model.to_string(),
1180 };
1181
1182 let stream = model_provider.stream_chat_with_system(
1185 system_prompt,
1186 message,
1187 ¤t_model,
1188 temperature,
1189 options,
1190 );
1191
1192 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1194
1195 let handle = ::zeroclaw_spawn::spawn!(async move {
1196 let mut stream = stream;
1197 while let Some(chunk) = stream.next().await {
1198 if let Err(ref e) = chunk {
1199 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1200 }
1201 if tx.send(chunk).await.is_err() {
1202 break; }
1204 }
1205 });
1206
1207 let guard = AbortOnDrop::new(handle.abort_handle());
1209 return stream::unfold((rx, guard), |(mut rx, guard)| async move {
1210 rx.recv().await.map(|chunk| (chunk, (rx, guard)))
1211 })
1212 .boxed();
1213 }
1214
1215 stream::once(async move {
1217 Err(super::traits::StreamError::ModelProvider(
1218 "No model_provider supports streaming".to_string(),
1219 ))
1220 })
1221 .boxed()
1222 }
1223
1224 fn stream_chat_with_history(
1225 &self,
1226 messages: &[ChatMessage],
1227 model: &str,
1228 temperature: Option<f64>,
1229 options: StreamOptions,
1230 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1231 for (provider_name, model_provider) in &self.model_providers {
1235 if !model_provider.supports_streaming() || !options.enabled {
1236 continue;
1237 }
1238
1239 let provider_clone = provider_name.clone();
1240
1241 let current_model = match self.model_chain(model).first() {
1242 Some(m) => (*m).to_string(),
1243 None => model.to_string(),
1244 };
1245
1246 let stream = model_provider.stream_chat_with_history(
1247 messages,
1248 ¤t_model,
1249 temperature,
1250 options,
1251 );
1252
1253 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1254
1255 let handle = ::zeroclaw_spawn::spawn!(async move {
1256 let mut stream = stream;
1257 while let Some(chunk) = stream.next().await {
1258 if let Err(ref e) = chunk {
1259 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"model_provider": provider_clone, "model": current_model, "e": e.to_string()})), "Streaming error: ");
1260 }
1261 if tx.send(chunk).await.is_err() {
1262 break; }
1264 }
1265 });
1266
1267 let guard = AbortOnDrop::new(handle.abort_handle());
1268 return stream::unfold((rx, guard), |(mut rx, guard)| async move {
1269 rx.recv().await.map(|chunk| (chunk, (rx, guard)))
1270 })
1271 .boxed();
1272 }
1273
1274 stream::once(async move {
1276 Err(super::traits::StreamError::ModelProvider(
1277 "No model_provider supports streaming".to_string(),
1278 ))
1279 })
1280 .boxed()
1281 }
1282}
1283
1284impl ::zeroclaw_api::attribution::Attributable for ReliableModelProvider {
1285 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1286 ::zeroclaw_api::attribution::Role::Provider(
1287 ::zeroclaw_api::attribution::ProviderKind::Model(
1288 ::zeroclaw_api::attribution::ModelProviderKind::Reliable,
1289 ),
1290 )
1291 }
1292 fn alias(&self) -> &str {
1293 &self.alias
1294 }
1295}
1296
1297#[cfg(test)]
1298mod tests {
1299 use super::*;
1300 use futures_util::StreamExt;
1301 use std::sync::Arc;
1302 use zeroclaw_api::tool::ToolSpec;
1303
1304 struct MockModelProvider {
1305 calls: Arc<AtomicUsize>,
1306 fail_until_attempt: usize,
1307 response: &'static str,
1308 error: &'static str,
1309 }
1310
1311 #[async_trait]
1312 impl ModelProvider for MockModelProvider {
1313 async fn chat_with_system(
1314 &self,
1315 _system_prompt: Option<&str>,
1316 _message: &str,
1317 _model: &str,
1318 _temperature: Option<f64>,
1319 ) -> anyhow::Result<String> {
1320 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1321 if attempt <= self.fail_until_attempt {
1322 anyhow::bail!(self.error);
1323 }
1324 Ok(self.response.to_string())
1325 }
1326
1327 async fn chat_with_history(
1328 &self,
1329 _messages: &[ChatMessage],
1330 _model: &str,
1331 _temperature: Option<f64>,
1332 ) -> anyhow::Result<String> {
1333 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1334 if attempt <= self.fail_until_attempt {
1335 anyhow::bail!(self.error);
1336 }
1337 Ok(self.response.to_string())
1338 }
1339 }
1340 impl ::zeroclaw_api::attribution::Attributable for MockModelProvider {
1341 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1342 ::zeroclaw_api::attribution::Role::Provider(
1343 ::zeroclaw_api::attribution::ProviderKind::Model(
1344 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
1345 ),
1346 )
1347 }
1348 fn alias(&self) -> &str {
1349 "MockModelProvider"
1350 }
1351 }
1352
1353 struct ModelAwareMock {
1355 calls: Arc<AtomicUsize>,
1356 models_seen: parking_lot::Mutex<Vec<String>>,
1357 fail_models: Vec<&'static str>,
1358 response: &'static str,
1359 }
1360
1361 #[async_trait]
1362 impl ModelProvider for ModelAwareMock {
1363 async fn chat_with_system(
1364 &self,
1365 _system_prompt: Option<&str>,
1366 _message: &str,
1367 model: &str,
1368 _temperature: Option<f64>,
1369 ) -> anyhow::Result<String> {
1370 self.calls.fetch_add(1, Ordering::SeqCst);
1371 self.models_seen.lock().push(model.to_string());
1372 if self.fail_models.contains(&model) {
1373 anyhow::bail!("500 model {} unavailable", model);
1374 }
1375 Ok(self.response.to_string())
1376 }
1377 }
1378 impl ::zeroclaw_api::attribution::Attributable for ModelAwareMock {
1379 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1380 ::zeroclaw_api::attribution::Role::Provider(
1381 ::zeroclaw_api::attribution::ProviderKind::Model(
1382 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
1383 ),
1384 )
1385 }
1386 fn alias(&self) -> &str {
1387 "ModelAwareMock"
1388 }
1389 }
1390
1391 #[tokio::test]
1394 async fn succeeds_without_retry() {
1395 let calls = Arc::new(AtomicUsize::new(0));
1396 let model_provider = ReliableModelProvider::new(
1397 "test",
1398 vec![(
1399 "primary".into(),
1400 Box::new(MockModelProvider {
1401 calls: Arc::clone(&calls),
1402 fail_until_attempt: 0,
1403 response: "ok",
1404 error: "boom",
1405 }),
1406 )],
1407 2,
1408 1,
1409 );
1410
1411 let result = model_provider
1412 .simple_chat("hello", "test", Some(0.0))
1413 .await
1414 .unwrap();
1415 assert_eq!(result, "ok");
1416 assert_eq!(calls.load(Ordering::SeqCst), 1);
1417 }
1418
1419 #[tokio::test]
1420 async fn retries_then_recovers() {
1421 let calls = Arc::new(AtomicUsize::new(0));
1422 let model_provider = ReliableModelProvider::new(
1423 "test",
1424 vec![(
1425 "primary".into(),
1426 Box::new(MockModelProvider {
1427 calls: Arc::clone(&calls),
1428 fail_until_attempt: 1,
1429 response: "recovered",
1430 error: "temporary",
1431 }),
1432 )],
1433 2,
1434 1,
1435 );
1436
1437 let result = model_provider
1438 .simple_chat("hello", "test", Some(0.0))
1439 .await
1440 .unwrap();
1441 assert_eq!(result, "recovered");
1442 assert_eq!(calls.load(Ordering::SeqCst), 2);
1443 }
1444
1445 #[tokio::test]
1446 async fn falls_back_after_retries_exhausted() {
1447 let primary_calls = Arc::new(AtomicUsize::new(0));
1448 let fallback_calls = Arc::new(AtomicUsize::new(0));
1449
1450 let model_provider = ReliableModelProvider::new(
1451 "test",
1452 vec![
1453 (
1454 "primary".into(),
1455 Box::new(MockModelProvider {
1456 calls: Arc::clone(&primary_calls),
1457 fail_until_attempt: usize::MAX,
1458 response: "never",
1459 error: "primary down",
1460 }),
1461 ),
1462 (
1463 "fallback".into(),
1464 Box::new(MockModelProvider {
1465 calls: Arc::clone(&fallback_calls),
1466 fail_until_attempt: 0,
1467 response: "from fallback",
1468 error: "fallback down",
1469 }),
1470 ),
1471 ],
1472 1,
1473 1,
1474 );
1475
1476 let result = model_provider
1477 .simple_chat("hello", "test", Some(0.0))
1478 .await
1479 .unwrap();
1480 assert_eq!(result, "from fallback");
1481 assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1482 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1483 }
1484
1485 struct EmptyThenTextMock {
1490 calls: Arc<AtomicUsize>,
1491 empty_until_attempt: usize,
1492 response: &'static str,
1493 }
1494
1495 #[async_trait]
1496 impl ModelProvider for EmptyThenTextMock {
1497 async fn chat_with_system(
1498 &self,
1499 _system_prompt: Option<&str>,
1500 _message: &str,
1501 _model: &str,
1502 _temperature: Option<f64>,
1503 ) -> anyhow::Result<String> {
1504 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1505 if attempt <= self.empty_until_attempt {
1506 Ok(String::new())
1507 } else {
1508 Ok(self.response.to_string())
1509 }
1510 }
1511 }
1512 impl ::zeroclaw_api::attribution::Attributable for EmptyThenTextMock {
1513 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1514 ::zeroclaw_api::attribution::Role::Provider(
1515 ::zeroclaw_api::attribution::ProviderKind::Model(
1516 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
1517 ),
1518 )
1519 }
1520 fn alias(&self) -> &str {
1521 "EmptyThenTextMock"
1522 }
1523 }
1524
1525 #[tokio::test]
1526 async fn chat_retries_empty_completion_then_succeeds() {
1527 let calls = Arc::new(AtomicUsize::new(0));
1528 let model_provider = ReliableModelProvider::new(
1529 "test",
1530 vec![(
1531 "primary".into(),
1532 Box::new(EmptyThenTextMock {
1533 calls: Arc::clone(&calls),
1534 empty_until_attempt: 1,
1535 response: "recovered",
1536 }),
1537 )],
1538 3,
1539 1,
1540 );
1541
1542 let messages = vec![ChatMessage::user("hello")];
1543 let request = ChatRequest {
1544 messages: &messages,
1545 tools: None,
1546 thinking: None,
1547 };
1548 let result = model_provider
1549 .chat(request, "test", Some(0.0))
1550 .await
1551 .unwrap();
1552 assert_eq!(result.text.as_deref(), Some("recovered"));
1553 assert_eq!(calls.load(Ordering::SeqCst), 2);
1555 }
1556
1557 #[tokio::test]
1558 async fn chat_with_tools_retries_empty_completion_then_succeeds() {
1559 let calls = Arc::new(AtomicUsize::new(0));
1560 let model_provider = ReliableModelProvider::new(
1561 "test",
1562 vec![(
1563 "primary".into(),
1564 Box::new(EmptyThenTextMock {
1565 calls: Arc::clone(&calls),
1566 empty_until_attempt: 1,
1567 response: "recovered",
1568 }),
1569 )],
1570 3,
1571 1,
1572 );
1573
1574 let messages = vec![ChatMessage::user("hello")];
1575 let result = model_provider
1576 .chat_with_tools(&messages, &[], "test", Some(0.0))
1577 .await
1578 .unwrap();
1579 assert_eq!(result.text.as_deref(), Some("recovered"));
1580 assert_eq!(calls.load(Ordering::SeqCst), 2);
1581 }
1582
1583 #[tokio::test]
1584 async fn chat_with_history_retries_empty_string_then_succeeds() {
1585 let calls = Arc::new(AtomicUsize::new(0));
1586 let model_provider = ReliableModelProvider::new(
1587 "test",
1588 vec![(
1589 "primary".into(),
1590 Box::new(EmptyThenTextMock {
1591 calls: Arc::clone(&calls),
1592 empty_until_attempt: 1,
1593 response: "recovered",
1594 }),
1595 )],
1596 3,
1597 1,
1598 );
1599
1600 let messages = vec![ChatMessage::user("hello")];
1601 let result = model_provider
1602 .chat_with_history(&messages, "test", Some(0.0))
1603 .await
1604 .unwrap();
1605 assert_eq!(result, "recovered");
1606 assert_eq!(calls.load(Ordering::SeqCst), 2);
1607 }
1608
1609 #[tokio::test]
1610 async fn chat_with_system_retries_empty_string_then_succeeds() {
1611 let calls = Arc::new(AtomicUsize::new(0));
1612 let model_provider = ReliableModelProvider::new(
1613 "test",
1614 vec![(
1615 "primary".into(),
1616 Box::new(EmptyThenTextMock {
1617 calls: Arc::clone(&calls),
1618 empty_until_attempt: 1,
1619 response: "recovered",
1620 }),
1621 )],
1622 3,
1623 1,
1624 );
1625
1626 let result = model_provider
1629 .simple_chat("hello", "test", Some(0.0))
1630 .await
1631 .unwrap();
1632 assert_eq!(result, "recovered");
1633 assert_eq!(calls.load(Ordering::SeqCst), 2);
1634 }
1635
1636 #[tokio::test]
1637 async fn chat_persistent_empty_returns_blank_without_error() {
1638 let calls = Arc::new(AtomicUsize::new(0));
1639 let model_provider = ReliableModelProvider::new(
1640 "test",
1641 vec![(
1642 "primary".into(),
1643 Box::new(EmptyThenTextMock {
1644 calls: Arc::clone(&calls),
1645 empty_until_attempt: usize::MAX, response: "never",
1647 }),
1648 )],
1649 2,
1650 1,
1651 );
1652
1653 let messages = vec![ChatMessage::user("hello")];
1654 let request = ChatRequest {
1655 messages: &messages,
1656 tools: None,
1657 thinking: None,
1658 };
1659 let result = model_provider
1662 .chat(request, "test", Some(0.0))
1663 .await
1664 .unwrap();
1665 assert_eq!(result.text.as_deref(), Some(""));
1666 assert_eq!(calls.load(Ordering::SeqCst), 3);
1668 }
1669
1670 #[tokio::test]
1671 async fn chat_nonempty_response_is_not_retried() {
1672 let calls = Arc::new(AtomicUsize::new(0));
1673 let model_provider = ReliableModelProvider::new(
1674 "test",
1675 vec![(
1676 "primary".into(),
1677 Box::new(EmptyThenTextMock {
1678 calls: Arc::clone(&calls),
1679 empty_until_attempt: 0, response: "direct",
1681 }),
1682 )],
1683 3,
1684 1,
1685 );
1686
1687 let messages = vec![ChatMessage::user("hello")];
1688 let request = ChatRequest {
1689 messages: &messages,
1690 tools: None,
1691 thinking: None,
1692 };
1693 let result = model_provider
1694 .chat(request, "test", Some(0.0))
1695 .await
1696 .unwrap();
1697 assert_eq!(result.text.as_deref(), Some("direct"));
1698 assert_eq!(calls.load(Ordering::SeqCst), 1);
1700 }
1701
1702 #[tokio::test]
1703 async fn returns_aggregated_error_when_all_providers_fail() {
1704 let model_provider = ReliableModelProvider::new(
1705 "test",
1706 vec![
1707 (
1708 "p1".into(),
1709 Box::new(MockModelProvider {
1710 calls: Arc::new(AtomicUsize::new(0)),
1711 fail_until_attempt: usize::MAX,
1712 response: "never",
1713 error: "p1 error",
1714 }),
1715 ),
1716 (
1717 "p2".into(),
1718 Box::new(MockModelProvider {
1719 calls: Arc::new(AtomicUsize::new(0)),
1720 fail_until_attempt: usize::MAX,
1721 response: "never",
1722 error: "p2 error",
1723 }),
1724 ),
1725 ],
1726 0,
1727 1,
1728 );
1729
1730 let err = model_provider
1731 .simple_chat("hello", "test", Some(0.0))
1732 .await
1733 .expect_err("all model_providers should fail");
1734 let msg = err.to_string();
1735 assert!(msg.contains("All model_providers/models failed"));
1736 assert!(msg.contains("model_provider=p1 model=test"));
1737 assert!(msg.contains("model_provider=p2 model=test"));
1738 assert!(msg.contains("error=p1 error"));
1739 assert!(msg.contains("error=p2 error"));
1740 assert!(msg.contains("retryable"));
1741 }
1742
1743 #[test]
1744 fn non_retryable_detects_common_patterns() {
1745 assert!(is_non_retryable(&anyhow::Error::msg("400 Bad Request")));
1746 assert!(is_non_retryable(&anyhow::Error::msg("401 Unauthorized")));
1747 assert!(is_non_retryable(&anyhow::Error::msg("403 Forbidden")));
1748 assert!(is_non_retryable(&anyhow::Error::msg("404 Not Found")));
1749 assert!(is_non_retryable(&anyhow::Error::msg(
1750 "invalid api key provided"
1751 )));
1752 assert!(is_non_retryable(&anyhow::Error::msg(
1753 "authentication failed"
1754 )));
1755 assert!(is_non_retryable(&anyhow::Error::msg(
1756 "model glm-4.7 not found"
1757 )));
1758 assert!(is_non_retryable(&anyhow::Error::msg(
1759 "unsupported model: glm-4.7"
1760 )));
1761 assert!(!is_non_retryable(&anyhow::Error::msg(
1762 "429 Too Many Requests"
1763 )));
1764 assert!(!is_non_retryable(&anyhow::Error::msg(
1765 "408 Request Timeout"
1766 )));
1767 assert!(!is_non_retryable(&anyhow::Error::msg(
1768 "500 Internal Server Error"
1769 )));
1770 assert!(!is_non_retryable(&anyhow::Error::msg("502 Bad Gateway")));
1771 assert!(!is_non_retryable(&anyhow::Error::msg("timeout")));
1772 assert!(!is_non_retryable(&anyhow::Error::msg("connection reset")));
1773 assert!(!is_non_retryable(&anyhow::Error::msg(
1774 "model overloaded, try again later"
1775 )));
1776 assert!(!is_non_retryable(&anyhow::Error::msg(
1778 "OpenAI Codex stream error: Your input exceeds the context window of this model."
1779 )));
1780 }
1781
1782 #[test]
1783 fn auth_error_detects_common_patterns() {
1784 assert!(is_auth_error(&anyhow::Error::msg("401 Unauthorized")));
1785 assert!(is_auth_error(&anyhow::Error::msg("403 Forbidden")));
1786 assert!(is_auth_error(&anyhow::Error::msg("invalid api key")));
1787 assert!(is_auth_error(&anyhow::Error::msg("authentication failed")));
1788 assert!(is_auth_error(&anyhow::Error::msg("token expired")));
1789 assert!(!is_auth_error(&anyhow::Error::msg("400 Bad Request")));
1790 assert!(!is_auth_error(&anyhow::Error::msg("429 Too Many Requests")));
1791 assert!(!is_auth_error(&anyhow::Error::msg("timeout")));
1792 assert!(!is_auth_error(&anyhow::Error::msg("connection reset")));
1793 }
1794
1795 #[tokio::test]
1796 async fn context_window_error_aborts_retries_and_model_fallbacks() {
1797 let calls = Arc::new(AtomicUsize::new(0));
1798 let mut model_fallbacks = std::collections::HashMap::new();
1799 model_fallbacks.insert(
1800 "gpt-5.3-codex".to_string(),
1801 vec!["gpt-5.2-codex".to_string()],
1802 );
1803
1804 let model_provider = ReliableModelProvider::new("test", vec![(
1805 "openai-codex".into(),
1806 Box::new(MockModelProvider {
1807 calls: Arc::clone(&calls),
1808 fail_until_attempt: usize::MAX,
1809 response: "never",
1810 error: "OpenAI Codex stream error: Your input exceeds the context window of this model. Please adjust your input and try again.",
1811 }),
1812 )],
1813 4,
1814 1,
1815 )
1816 .with_model_fallbacks(model_fallbacks);
1817
1818 let err = model_provider
1819 .simple_chat("hello", "gpt-5.3-codex", Some(0.0))
1820 .await
1821 .expect_err("context window overflow should fail fast");
1822 let msg = err.to_string();
1823
1824 assert!(msg.contains("context window"));
1825 assert_eq!(calls.load(Ordering::SeqCst), 1);
1827 }
1828
1829 #[tokio::test]
1830 async fn aggregated_error_marks_non_retryable_model_mismatch_with_details() {
1831 let calls = Arc::new(AtomicUsize::new(0));
1832 let model_provider = ReliableModelProvider::new(
1833 "test",
1834 vec![(
1835 "custom".into(),
1836 Box::new(MockModelProvider {
1837 calls: Arc::clone(&calls),
1838 fail_until_attempt: usize::MAX,
1839 response: "never",
1840 error: "unsupported model: glm-4.7",
1841 }),
1842 )],
1843 3,
1844 1,
1845 );
1846
1847 let err = model_provider
1848 .simple_chat("hello", "glm-4.7", Some(0.0))
1849 .await
1850 .expect_err("model_provider should fail");
1851 let msg = err.to_string();
1852
1853 assert!(msg.contains("non_retryable"));
1854 assert!(msg.contains("error=unsupported model: glm-4.7"));
1855 assert_eq!(calls.load(Ordering::SeqCst), 1);
1857 }
1858
1859 #[tokio::test]
1860 async fn skips_retries_on_non_retryable_error() {
1861 let primary_calls = Arc::new(AtomicUsize::new(0));
1862 let fallback_calls = Arc::new(AtomicUsize::new(0));
1863
1864 let model_provider = ReliableModelProvider::new(
1865 "test",
1866 vec![
1867 (
1868 "primary".into(),
1869 Box::new(MockModelProvider {
1870 calls: Arc::clone(&primary_calls),
1871 fail_until_attempt: usize::MAX,
1872 response: "never",
1873 error: "401 Unauthorized",
1874 }),
1875 ),
1876 (
1877 "fallback".into(),
1878 Box::new(MockModelProvider {
1879 calls: Arc::clone(&fallback_calls),
1880 fail_until_attempt: 0,
1881 response: "from fallback",
1882 error: "fallback err",
1883 }),
1884 ),
1885 ],
1886 3,
1887 1,
1888 );
1889
1890 let result = model_provider
1891 .simple_chat("hello", "test", Some(0.0))
1892 .await
1893 .unwrap();
1894 assert_eq!(result, "from fallback");
1895 assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
1897 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1898 }
1899
1900 #[tokio::test]
1901 async fn chat_with_history_retries_then_recovers() {
1902 let calls = Arc::new(AtomicUsize::new(0));
1903 let model_provider = ReliableModelProvider::new(
1904 "test",
1905 vec![(
1906 "primary".into(),
1907 Box::new(MockModelProvider {
1908 calls: Arc::clone(&calls),
1909 fail_until_attempt: 1,
1910 response: "history ok",
1911 error: "temporary",
1912 }),
1913 )],
1914 2,
1915 1,
1916 );
1917
1918 let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
1919 let result = model_provider
1920 .chat_with_history(&messages, "test", Some(0.0))
1921 .await
1922 .unwrap();
1923 assert_eq!(result, "history ok");
1924 assert_eq!(calls.load(Ordering::SeqCst), 2);
1925 }
1926
1927 #[tokio::test]
1928 async fn chat_with_history_falls_back() {
1929 let primary_calls = Arc::new(AtomicUsize::new(0));
1930 let fallback_calls = Arc::new(AtomicUsize::new(0));
1931
1932 let model_provider = ReliableModelProvider::new(
1933 "test",
1934 vec![
1935 (
1936 "primary".into(),
1937 Box::new(MockModelProvider {
1938 calls: Arc::clone(&primary_calls),
1939 fail_until_attempt: usize::MAX,
1940 response: "never",
1941 error: "primary down",
1942 }),
1943 ),
1944 (
1945 "fallback".into(),
1946 Box::new(MockModelProvider {
1947 calls: Arc::clone(&fallback_calls),
1948 fail_until_attempt: 0,
1949 response: "fallback ok",
1950 error: "fallback err",
1951 }),
1952 ),
1953 ],
1954 1,
1955 1,
1956 );
1957
1958 let messages = vec![ChatMessage::user("hello")];
1959 let result = model_provider
1960 .chat_with_history(&messages, "test", Some(0.0))
1961 .await
1962 .unwrap();
1963 assert_eq!(result, "fallback ok");
1964 assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1965 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1966 }
1967
1968 #[tokio::test]
1971 async fn model_failover_tries_fallback_model() {
1972 let calls = Arc::new(AtomicUsize::new(0));
1973 let mock = Arc::new(ModelAwareMock {
1974 calls: Arc::clone(&calls),
1975 models_seen: parking_lot::Mutex::new(Vec::new()),
1976 fail_models: vec!["claude-opus"],
1977 response: "ok from sonnet",
1978 });
1979
1980 let mut fallbacks = HashMap::new();
1981 fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
1982
1983 let model_provider = ReliableModelProvider::new(
1984 "test",
1985 vec![(
1986 "anthropic".into(),
1987 Box::new(mock.clone()) as Box<dyn ModelProvider>,
1988 )],
1989 0, 1,
1991 )
1992 .with_model_fallbacks(fallbacks);
1993
1994 let result = model_provider
1995 .simple_chat("hello", "claude-opus", Some(0.0))
1996 .await
1997 .unwrap();
1998 assert_eq!(result, "ok from sonnet");
1999
2000 let seen = mock.models_seen.lock();
2001 assert_eq!(seen.len(), 2);
2002 assert_eq!(seen[0], "claude-opus");
2003 assert_eq!(seen[1], "claude-sonnet");
2004 }
2005
2006 #[tokio::test]
2007 async fn model_failover_all_models_fail() {
2008 let calls = Arc::new(AtomicUsize::new(0));
2009 let mock = Arc::new(ModelAwareMock {
2010 calls: Arc::clone(&calls),
2011 models_seen: parking_lot::Mutex::new(Vec::new()),
2012 fail_models: vec!["model-a", "model-b", "model-c"],
2013 response: "never",
2014 });
2015
2016 let mut fallbacks = HashMap::new();
2017 fallbacks.insert(
2018 "model-a".to_string(),
2019 vec!["model-b".to_string(), "model-c".to_string()],
2020 );
2021
2022 let model_provider = ReliableModelProvider::new(
2023 "test",
2024 vec![(
2025 "p1".into(),
2026 Box::new(mock.clone()) as Box<dyn ModelProvider>,
2027 )],
2028 0,
2029 1,
2030 )
2031 .with_model_fallbacks(fallbacks);
2032
2033 let err = model_provider
2034 .simple_chat("hello", "model-a", Some(0.0))
2035 .await
2036 .expect_err("all models should fail");
2037 assert!(
2038 err.to_string()
2039 .contains("All model_providers/models failed")
2040 );
2041
2042 let seen = mock.models_seen.lock();
2043 assert_eq!(seen.len(), 3);
2044 }
2045
2046 #[tokio::test]
2047 async fn no_model_fallbacks_behaves_like_before() {
2048 let calls = Arc::new(AtomicUsize::new(0));
2049 let model_provider = ReliableModelProvider::new(
2050 "test",
2051 vec![(
2052 "primary".into(),
2053 Box::new(MockModelProvider {
2054 calls: Arc::clone(&calls),
2055 fail_until_attempt: 0,
2056 response: "ok",
2057 error: "boom",
2058 }),
2059 )],
2060 2,
2061 1,
2062 );
2063 let result = model_provider
2065 .simple_chat("hello", "test", Some(0.0))
2066 .await
2067 .unwrap();
2068 assert_eq!(result, "ok");
2069 assert_eq!(calls.load(Ordering::SeqCst), 1);
2070 }
2071
2072 #[tokio::test]
2075 async fn auth_rotation_cycles_keys() {
2076 let model_provider = ReliableModelProvider::new(
2077 "test",
2078 vec![(
2079 "p".into(),
2080 Box::new(MockModelProvider {
2081 calls: Arc::new(AtomicUsize::new(0)),
2082 fail_until_attempt: 0,
2083 response: "ok",
2084 error: "",
2085 }),
2086 )],
2087 0,
2088 1,
2089 )
2090 .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);
2091
2092 let keys: Vec<&str> = (0..5)
2094 .map(|_| model_provider.rotate_key().unwrap())
2095 .collect();
2096 assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
2097 }
2098
2099 #[tokio::test]
2100 async fn auth_rotation_returns_none_when_empty() {
2101 let model_provider = ReliableModelProvider::new("test", vec![], 0, 1);
2102 assert!(model_provider.rotate_key().is_none());
2103 }
2104
2105 #[test]
2108 fn parse_retry_after_integer() {
2109 let err = anyhow::Error::msg("429 Too Many Requests, Retry-After: 5");
2110 assert_eq!(parse_retry_after_ms(&err), Some(5000));
2111 }
2112
2113 #[test]
2114 fn parse_retry_after_float() {
2115 let err = anyhow::Error::msg("Rate limited. retry_after: 2.5 seconds");
2116 assert_eq!(parse_retry_after_ms(&err), Some(2500));
2117 }
2118
2119 #[test]
2120 fn parse_retry_after_missing() {
2121 let err = anyhow::Error::msg("500 Internal Server Error");
2122 assert_eq!(parse_retry_after_ms(&err), None);
2123 }
2124
2125 #[test]
2126 fn rate_limited_detection() {
2127 assert!(is_rate_limited(&anyhow::Error::msg(
2128 "429 Too Many Requests"
2129 )));
2130 assert!(is_rate_limited(&anyhow::Error::msg(
2131 "HTTP 429 rate limit exceeded"
2132 )));
2133 assert!(!is_rate_limited(&anyhow::Error::msg("401 Unauthorized")));
2134 assert!(!is_rate_limited(&anyhow::Error::msg(
2135 "500 Internal Server Error"
2136 )));
2137 }
2138
2139 #[test]
2140 fn non_retryable_rate_limit_detects_plan_restricted_model() {
2141 let err = anyhow::Error::msg(
2142 "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"the current account plan does not include glm-5\"}",
2143 );
2144 assert!(
2145 is_non_retryable_rate_limit(&err),
2146 "plan-restricted 429 should skip retries"
2147 );
2148 }
2149
2150 #[test]
2151 fn non_retryable_rate_limit_detects_insufficient_balance() {
2152 let err = anyhow::Error::msg(
2153 "API error (429 Too Many Requests): {\"code\":1113,\"message\":\"insufficient balance\"}",
2154 );
2155 assert!(
2156 is_non_retryable_rate_limit(&err),
2157 "insufficient-balance 429 should skip retries"
2158 );
2159 }
2160
2161 #[test]
2162 fn non_retryable_rate_limit_does_not_flag_generic_429() {
2163 let err = anyhow::Error::msg("429 Too Many Requests: rate limit exceeded");
2164 assert!(
2165 !is_non_retryable_rate_limit(&err),
2166 "generic rate-limit 429 should remain retryable"
2167 );
2168 }
2169
2170 #[test]
2171 fn compute_backoff_uses_retry_after() {
2172 let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
2173 let err = anyhow::Error::msg("429 Retry-After: 3");
2174 assert_eq!(model_provider.compute_backoff(500, &err), 3_000);
2175 }
2176
2177 #[test]
2178 fn compute_backoff_caps_at_30s() {
2179 let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
2180 let err = anyhow::Error::msg("429 Retry-After: 120");
2181 assert_eq!(model_provider.compute_backoff(500, &err), 30_000);
2182 }
2183
2184 #[test]
2185 fn compute_backoff_falls_back_to_base() {
2186 let model_provider = ReliableModelProvider::new("test", vec![], 0, 500);
2187 let err = anyhow::Error::msg("500 Server Error");
2188 assert_eq!(model_provider.compute_backoff(500, &err), 500);
2189 }
2190
2191 #[test]
2194 fn non_retryable_detects_401() {
2195 let err = anyhow::Error::msg("API error (401 Unauthorized): invalid api key");
2196 assert!(
2197 is_non_retryable(&err),
2198 "401 errors must be detected as non-retryable"
2199 );
2200 }
2201
2202 #[test]
2203 fn non_retryable_detects_403() {
2204 let err = anyhow::Error::msg("API error (403 Forbidden): access denied");
2205 assert!(
2206 is_non_retryable(&err),
2207 "403 errors must be detected as non-retryable"
2208 );
2209 }
2210
2211 #[test]
2212 fn non_retryable_detects_404() {
2213 let err = anyhow::Error::msg("API error (404 Not Found): model not found");
2214 assert!(
2215 is_non_retryable(&err),
2216 "404 errors must be detected as non-retryable"
2217 );
2218 }
2219
2220 #[test]
2221 fn non_retryable_does_not_flag_429() {
2222 let err = anyhow::Error::msg("429 Too Many Requests");
2223 assert!(
2224 !is_non_retryable(&err),
2225 "429 must NOT be treated as non-retryable (it is retryable with backoff)"
2226 );
2227 }
2228
2229 #[test]
2230 fn non_retryable_does_not_flag_408() {
2231 let err = anyhow::Error::msg("408 Request Timeout");
2232 assert!(
2233 !is_non_retryable(&err),
2234 "408 must NOT be treated as non-retryable (it is retryable)"
2235 );
2236 }
2237
2238 #[test]
2239 fn non_retryable_does_not_flag_500() {
2240 let err = anyhow::Error::msg("500 Internal Server Error");
2241 assert!(
2242 !is_non_retryable(&err),
2243 "500 must NOT be treated as non-retryable (server errors are retryable)"
2244 );
2245 }
2246
2247 #[test]
2248 fn non_retryable_does_not_flag_502() {
2249 let err = anyhow::Error::msg("502 Bad Gateway");
2250 assert!(
2251 !is_non_retryable(&err),
2252 "502 must NOT be treated as non-retryable"
2253 );
2254 }
2255
2256 #[test]
2259 fn parse_retry_after_zero() {
2260 let err = anyhow::Error::msg("429 Too Many Requests, Retry-After: 0");
2261 assert_eq!(
2262 parse_retry_after_ms(&err),
2263 Some(0),
2264 "Retry-After: 0 should parse as 0ms"
2265 );
2266 }
2267
2268 #[test]
2269 fn parse_retry_after_with_underscore_separator() {
2270 let err = anyhow::Error::msg("rate limited, retry_after: 10");
2271 assert_eq!(
2272 parse_retry_after_ms(&err),
2273 Some(10_000),
2274 "retry_after with underscore must be parsed"
2275 );
2276 }
2277
2278 #[test]
2279 fn parse_retry_after_space_separator() {
2280 let err = anyhow::Error::msg("Retry-After 7");
2281 assert_eq!(
2282 parse_retry_after_ms(&err),
2283 Some(7000),
2284 "Retry-After with space separator must be parsed"
2285 );
2286 }
2287
2288 #[test]
2289 fn rate_limited_false_for_generic_error() {
2290 let err = anyhow::Error::msg("Connection refused");
2291 assert!(
2292 !is_rate_limited(&err),
2293 "generic errors must not be flagged as rate-limited"
2294 );
2295 }
2296
2297 #[tokio::test]
2300 async fn non_retryable_skips_retries_for_401() {
2301 let calls = Arc::new(AtomicUsize::new(0));
2302 let model_provider = ReliableModelProvider::new(
2303 "test",
2304 vec![(
2305 "primary".into(),
2306 Box::new(MockModelProvider {
2307 calls: Arc::clone(&calls),
2308 fail_until_attempt: usize::MAX,
2309 response: "never",
2310 error: "API error (401 Unauthorized): invalid key",
2311 }),
2312 )],
2313 5,
2314 1,
2315 );
2316
2317 let result = model_provider.simple_chat("hello", "test", Some(0.0)).await;
2318 assert!(result.is_err(), "401 should fail without retries");
2319 assert_eq!(
2320 calls.load(Ordering::SeqCst),
2321 1,
2322 "must not retry on 401 — should be exactly 1 call"
2323 );
2324 }
2325
2326 #[tokio::test]
2327 async fn non_retryable_rate_limit_skips_retries_for_plan_errors() {
2328 let calls = Arc::new(AtomicUsize::new(0));
2329 let model_provider = ReliableModelProvider::new(
2330 "test",
2331 vec![(
2332 "primary".into(),
2333 Box::new(MockModelProvider {
2334 calls: Arc::clone(&calls),
2335 fail_until_attempt: usize::MAX,
2336 response: "never",
2337 error: "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"plan does not include glm-5\"}",
2338 }),
2339 )],
2340 5,
2341 1,
2342 );
2343
2344 let result = model_provider.simple_chat("hello", "test", Some(0.0)).await;
2345 assert!(
2346 result.is_err(),
2347 "plan-restricted 429 should fail quickly without retrying"
2348 );
2349 assert_eq!(
2350 calls.load(Ordering::SeqCst),
2351 1,
2352 "must not retry non-retryable 429 business errors"
2353 );
2354 }
2355
2356 struct NativeToolMock {
2360 calls: Arc<AtomicUsize>,
2361 fail_until_attempt: usize,
2362 response_text: &'static str,
2363 tool_calls: Vec<super::super::traits::ToolCall>,
2364 error: &'static str,
2365 }
2366
2367 #[async_trait]
2368 impl ModelProvider for NativeToolMock {
2369 async fn chat_with_system(
2370 &self,
2371 _system_prompt: Option<&str>,
2372 _message: &str,
2373 _model: &str,
2374 _temperature: Option<f64>,
2375 ) -> anyhow::Result<String> {
2376 Ok(self.response_text.to_string())
2377 }
2378
2379 fn supports_native_tools(&self) -> bool {
2380 true
2381 }
2382
2383 async fn chat(
2384 &self,
2385 _request: ChatRequest<'_>,
2386 _model: &str,
2387 _temperature: Option<f64>,
2388 ) -> anyhow::Result<ChatResponse> {
2389 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2390 if attempt <= self.fail_until_attempt {
2391 anyhow::bail!(self.error);
2392 }
2393 Ok(ChatResponse {
2394 text: Some(self.response_text.to_string()),
2395 tool_calls: self.tool_calls.clone(),
2396 usage: None,
2397 reasoning_content: None,
2398 })
2399 }
2400 }
2401 impl ::zeroclaw_api::attribution::Attributable for NativeToolMock {
2402 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2403 ::zeroclaw_api::attribution::Role::Provider(
2404 ::zeroclaw_api::attribution::ProviderKind::Model(
2405 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2406 ),
2407 )
2408 }
2409 fn alias(&self) -> &str {
2410 "NativeToolMock"
2411 }
2412 }
2413
2414 #[tokio::test]
2415 async fn chat_delegates_to_inner_provider() {
2416 let calls = Arc::new(AtomicUsize::new(0));
2417 let tool_call = super::super::traits::ToolCall {
2418 id: "call_1".to_string(),
2419 name: "shell".to_string(),
2420 arguments: r#"{"command":"date"}"#.to_string(),
2421 extra_content: None,
2422 };
2423 let model_provider = ReliableModelProvider::new(
2424 "test",
2425 vec![(
2426 "primary".into(),
2427 Box::new(NativeToolMock {
2428 calls: Arc::clone(&calls),
2429 fail_until_attempt: 0,
2430 response_text: "ok",
2431 tool_calls: vec![tool_call.clone()],
2432 error: "boom",
2433 }) as Box<dyn ModelProvider>,
2434 )],
2435 2,
2436 1,
2437 );
2438
2439 let messages = vec![ChatMessage::user("what time is it?")];
2440 let request = ChatRequest {
2441 messages: &messages,
2442 tools: None,
2443 thinking: None,
2444 };
2445 let result = model_provider
2446 .chat(request, "test-model", Some(0.0))
2447 .await
2448 .unwrap();
2449
2450 assert_eq!(result.text.as_deref(), Some("ok"));
2451 assert_eq!(result.tool_calls.len(), 1);
2452 assert_eq!(result.tool_calls[0].name, "shell");
2453 assert_eq!(calls.load(Ordering::SeqCst), 1);
2454 }
2455
2456 #[tokio::test]
2457 async fn chat_retries_and_recovers() {
2458 let calls = Arc::new(AtomicUsize::new(0));
2459 let tool_call = super::super::traits::ToolCall {
2460 id: "call_1".to_string(),
2461 name: "shell".to_string(),
2462 arguments: r#"{"command":"date"}"#.to_string(),
2463 extra_content: None,
2464 };
2465 let model_provider = ReliableModelProvider::new(
2466 "test",
2467 vec![(
2468 "primary".into(),
2469 Box::new(NativeToolMock {
2470 calls: Arc::clone(&calls),
2471 fail_until_attempt: 2,
2472 response_text: "recovered",
2473 tool_calls: vec![tool_call],
2474 error: "temporary failure",
2475 }) as Box<dyn ModelProvider>,
2476 )],
2477 3,
2478 1,
2479 );
2480
2481 let messages = vec![ChatMessage::user("test")];
2482 let request = ChatRequest {
2483 messages: &messages,
2484 tools: None,
2485 thinking: None,
2486 };
2487 let result = model_provider
2488 .chat(request, "test-model", Some(0.0))
2489 .await
2490 .unwrap();
2491
2492 assert_eq!(result.text.as_deref(), Some("recovered"));
2493 assert!(
2494 calls.load(Ordering::SeqCst) > 1,
2495 "should have retried at least once"
2496 );
2497 }
2498
2499 #[tokio::test]
2500 async fn chat_preserves_native_tools_support() {
2501 let calls = Arc::new(AtomicUsize::new(0));
2502 let model_provider = ReliableModelProvider::new(
2503 "test",
2504 vec![(
2505 "primary".into(),
2506 Box::new(NativeToolMock {
2507 calls: Arc::clone(&calls),
2508 fail_until_attempt: 0,
2509 response_text: "ok",
2510 tool_calls: vec![],
2511 error: "boom",
2512 }) as Box<dyn ModelProvider>,
2513 )],
2514 2,
2515 1,
2516 );
2517
2518 assert!(
2519 model_provider.supports_native_tools(),
2520 "ReliableModelProvider must propagate supports_native_tools from inner model_provider"
2521 );
2522 }
2523
2524 #[tokio::test]
2529 async fn chat_returns_aggregated_error_when_all_providers_fail() {
2530 let model_provider = ReliableModelProvider::new(
2531 "test",
2532 vec![
2533 (
2534 "p1".into(),
2535 Box::new(NativeToolMock {
2536 calls: Arc::new(AtomicUsize::new(0)),
2537 fail_until_attempt: usize::MAX,
2538 response_text: "never",
2539 tool_calls: vec![],
2540 error: "p1 chat error",
2541 }) as Box<dyn ModelProvider>,
2542 ),
2543 (
2544 "p2".into(),
2545 Box::new(NativeToolMock {
2546 calls: Arc::new(AtomicUsize::new(0)),
2547 fail_until_attempt: usize::MAX,
2548 response_text: "never",
2549 tool_calls: vec![],
2550 error: "p2 chat error",
2551 }) as Box<dyn ModelProvider>,
2552 ),
2553 ],
2554 0,
2555 1,
2556 );
2557
2558 let messages = vec![ChatMessage::user("hello")];
2559 let request = ChatRequest {
2560 messages: &messages,
2561 tools: None,
2562 thinking: None,
2563 };
2564 let err = model_provider
2565 .chat(request, "test", Some(0.0))
2566 .await
2567 .expect_err("all model_providers should fail");
2568 let msg = err.to_string();
2569 assert!(msg.contains("All model_providers/models failed"));
2570 assert!(msg.contains("model_provider=p1 model=test"));
2571 assert!(msg.contains("model_provider=p2 model=test"));
2572 assert!(msg.contains("error=p1 chat error"));
2573 assert!(msg.contains("error=p2 chat error"));
2574 assert!(msg.contains("retryable"));
2575 }
2576
2577 struct NativeModelAwareMock {
2580 calls: Arc<AtomicUsize>,
2581 models_seen: parking_lot::Mutex<Vec<String>>,
2582 fail_models: Vec<&'static str>,
2583 response_text: &'static str,
2584 }
2585
2586 #[async_trait]
2587 impl ModelProvider for NativeModelAwareMock {
2588 async fn chat_with_system(
2589 &self,
2590 _system_prompt: Option<&str>,
2591 _message: &str,
2592 _model: &str,
2593 _temperature: Option<f64>,
2594 ) -> anyhow::Result<String> {
2595 Ok(self.response_text.to_string())
2596 }
2597
2598 fn supports_native_tools(&self) -> bool {
2599 true
2600 }
2601
2602 async fn chat(
2603 &self,
2604 _request: ChatRequest<'_>,
2605 model: &str,
2606 _temperature: Option<f64>,
2607 ) -> anyhow::Result<ChatResponse> {
2608 self.calls.fetch_add(1, Ordering::SeqCst);
2609 self.models_seen.lock().push(model.to_string());
2610 if self.fail_models.contains(&model) {
2611 anyhow::bail!("500 model {} unavailable", model);
2612 }
2613 Ok(ChatResponse {
2614 text: Some(self.response_text.to_string()),
2615 tool_calls: vec![],
2616 usage: None,
2617 reasoning_content: None,
2618 })
2619 }
2620 }
2621 impl ::zeroclaw_api::attribution::Attributable for NativeModelAwareMock {
2622 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2623 ::zeroclaw_api::attribution::Role::Provider(
2624 ::zeroclaw_api::attribution::ProviderKind::Model(
2625 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2626 ),
2627 )
2628 }
2629 fn alias(&self) -> &str {
2630 "NativeModelAwareMock"
2631 }
2632 }
2633
2634 #[tokio::test]
2639 async fn chat_tries_model_failover_on_failure() {
2640 let calls = Arc::new(AtomicUsize::new(0));
2641 let mock = Arc::new(NativeModelAwareMock {
2642 calls: Arc::clone(&calls),
2643 models_seen: parking_lot::Mutex::new(Vec::new()),
2644 fail_models: vec!["claude-opus"],
2645 response_text: "ok from sonnet",
2646 });
2647
2648 let mut fallbacks = HashMap::new();
2649 fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
2650
2651 let model_provider = ReliableModelProvider::new(
2652 "test",
2653 vec![(
2654 "anthropic".into(),
2655 Box::new(mock.clone()) as Box<dyn ModelProvider>,
2656 )],
2657 0, 1,
2659 )
2660 .with_model_fallbacks(fallbacks);
2661
2662 let messages = vec![ChatMessage::user("hello")];
2663 let request = ChatRequest {
2664 messages: &messages,
2665 tools: None,
2666 thinking: None,
2667 };
2668 let result = model_provider
2669 .chat(request, "claude-opus", Some(0.0))
2670 .await
2671 .unwrap();
2672 assert_eq!(result.text.as_deref(), Some("ok from sonnet"));
2673
2674 let seen = mock.models_seen.lock();
2675 assert_eq!(seen.len(), 2);
2676 assert_eq!(seen[0], "claude-opus");
2677 assert_eq!(seen[1], "claude-sonnet");
2678 }
2679
2680 #[tokio::test]
2683 async fn chat_skips_non_retryable_errors() {
2684 let primary_calls = Arc::new(AtomicUsize::new(0));
2685 let fallback_calls = Arc::new(AtomicUsize::new(0));
2686
2687 let model_provider = ReliableModelProvider::new(
2688 "test",
2689 vec![
2690 (
2691 "primary".into(),
2692 Box::new(NativeToolMock {
2693 calls: Arc::clone(&primary_calls),
2694 fail_until_attempt: usize::MAX,
2695 response_text: "never",
2696 tool_calls: vec![],
2697 error: "401 Unauthorized",
2698 }) as Box<dyn ModelProvider>,
2699 ),
2700 (
2701 "fallback".into(),
2702 Box::new(NativeToolMock {
2703 calls: Arc::clone(&fallback_calls),
2704 fail_until_attempt: 0,
2705 response_text: "from fallback",
2706 tool_calls: vec![],
2707 error: "fallback err",
2708 }) as Box<dyn ModelProvider>,
2709 ),
2710 ],
2711 3,
2712 1,
2713 );
2714
2715 let messages = vec![ChatMessage::user("hello")];
2716 let request = ChatRequest {
2717 messages: &messages,
2718 tools: None,
2719 thinking: None,
2720 };
2721 let result = model_provider
2722 .chat(request, "test", Some(0.0))
2723 .await
2724 .unwrap();
2725 assert_eq!(result.text.as_deref(), Some("from fallback"));
2726 assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
2728 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
2729 }
2730
2731 #[test]
2734 fn context_window_error_is_not_non_retryable() {
2735 assert!(!is_non_retryable(&anyhow::Error::msg(
2737 "exceeds the context window"
2738 )));
2739 assert!(!is_non_retryable(&anyhow::Error::msg(
2740 "maximum context length exceeded"
2741 )));
2742 assert!(!is_non_retryable(&anyhow::Error::msg(
2743 "too many tokens in the request"
2744 )));
2745 assert!(!is_non_retryable(&anyhow::Error::msg(
2746 "token limit exceeded"
2747 )));
2748 }
2749
2750 #[test]
2751 fn is_context_window_exceeded_detects_llamacpp() {
2752 assert!(is_context_window_exceeded(&anyhow::Error::msg(
2753 "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2754 )));
2755 }
2756
2757 #[test]
2758 fn truncate_for_context_drops_oldest_non_system() {
2759 let mut messages = vec![
2760 ChatMessage::system("sys"),
2761 ChatMessage::user("msg1"),
2762 ChatMessage::assistant("resp1"),
2763 ChatMessage::user("msg2"),
2764 ChatMessage::assistant("resp2"),
2765 ChatMessage::user("msg3"),
2766 ];
2767
2768 let dropped = truncate_for_context(&mut messages);
2769
2770 assert_eq!(dropped, 2);
2772 assert_eq!(messages[0].role, "system");
2774 assert_eq!(messages.len(), 4); assert_eq!(messages.last().unwrap().content, "msg3");
2778 }
2779
2780 #[test]
2781 fn truncate_for_context_preserves_system_and_last_message() {
2782 let mut messages = vec![ChatMessage::system("sys"), ChatMessage::user("only")];
2784 let dropped = truncate_for_context(&mut messages);
2785 assert_eq!(dropped, 0);
2786 assert_eq!(messages.len(), 2);
2787
2788 let mut messages = vec![ChatMessage::user("only")];
2790 let dropped = truncate_for_context(&mut messages);
2791 assert_eq!(dropped, 0);
2792 assert_eq!(messages.len(), 1);
2793 }
2794
2795 struct ContextOverflowMock {
2798 calls: Arc<AtomicUsize>,
2799 fail_until_attempt: usize,
2800 message_counts: parking_lot::Mutex<Vec<usize>>,
2801 }
2802
2803 #[async_trait]
2804 impl ModelProvider for ContextOverflowMock {
2805 async fn chat_with_system(
2806 &self,
2807 _system_prompt: Option<&str>,
2808 _message: &str,
2809 _model: &str,
2810 _temperature: Option<f64>,
2811 ) -> anyhow::Result<String> {
2812 Ok("ok".to_string())
2813 }
2814
2815 async fn chat_with_history(
2816 &self,
2817 messages: &[ChatMessage],
2818 _model: &str,
2819 _temperature: Option<f64>,
2820 ) -> anyhow::Result<String> {
2821 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2822 self.message_counts.lock().push(messages.len());
2823 if attempt <= self.fail_until_attempt {
2824 anyhow::bail!(
2825 "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2826 );
2827 }
2828 Ok("recovered after truncation".to_string())
2829 }
2830 }
2831 impl ::zeroclaw_api::attribution::Attributable for ContextOverflowMock {
2832 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2833 ::zeroclaw_api::attribution::Role::Provider(
2834 ::zeroclaw_api::attribution::ProviderKind::Model(
2835 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
2836 ),
2837 )
2838 }
2839 fn alias(&self) -> &str {
2840 "ContextOverflowMock"
2841 }
2842 }
2843
2844 #[tokio::test]
2845 async fn chat_with_history_truncates_on_context_overflow() {
2846 let calls = Arc::new(AtomicUsize::new(0));
2847 let mock = ContextOverflowMock {
2848 calls: Arc::clone(&calls),
2849 fail_until_attempt: 1, message_counts: parking_lot::Mutex::new(Vec::new()),
2851 };
2852
2853 let model_provider = ReliableModelProvider::new(
2854 "test",
2855 vec![("local".into(), Box::new(mock) as Box<dyn ModelProvider>)],
2856 3,
2857 1,
2858 );
2859
2860 let messages = vec![
2861 ChatMessage::system("system prompt"),
2862 ChatMessage::user("old message 1"),
2863 ChatMessage::assistant("old response 1"),
2864 ChatMessage::user("old message 2"),
2865 ChatMessage::assistant("old response 2"),
2866 ChatMessage::user("current question"),
2867 ];
2868
2869 let result = model_provider
2870 .chat_with_history(&messages, "local-model", Some(0.0))
2871 .await
2872 .unwrap();
2873 assert_eq!(result, "recovered after truncation");
2874 assert_eq!(calls.load(Ordering::SeqCst), 2);
2876 }
2877
2878 #[tokio::test]
2879 async fn context_overflow_with_no_history_to_truncate_bails_immediately() {
2880 let calls = Arc::new(AtomicUsize::new(0));
2881 let mock = ContextOverflowMock {
2882 calls: Arc::clone(&calls),
2883 fail_until_attempt: 999, message_counts: parking_lot::Mutex::new(Vec::new()),
2885 };
2886
2887 let model_provider = ReliableModelProvider::new(
2888 "test",
2889 vec![("local".into(), Box::new(mock) as Box<dyn ModelProvider>)],
2890 3,
2891 1,
2892 );
2893
2894 let messages = vec![
2896 ChatMessage::system("huge system prompt that exceeds context window"),
2897 ChatMessage::user("hello"),
2898 ];
2899
2900 let result = model_provider
2901 .chat_with_history(&messages, "local-model", Some(0.0))
2902 .await;
2903 assert!(result.is_err());
2904 let err_msg = result.unwrap_err().to_string();
2905 assert!(
2906 err_msg.contains("cannot be reduced further"),
2907 "Should bail with actionable message, got: {err_msg}"
2908 );
2909 assert_eq!(
2911 calls.load(Ordering::SeqCst),
2912 1,
2913 "Should not retry when truncation is impossible"
2914 );
2915 }
2916
2917 #[test]
2920 fn tool_schema_error_detects_groq_validation_failure() {
2921 let msg = r#"Groq API error (400 Bad Request): {"error":{"message":"tool call validation failed: attempted to call tool 'memory_recall' which was not in request"}}"#;
2922 let err = anyhow::Error::msg(msg.to_string());
2923 assert!(is_tool_schema_error(&err));
2924 }
2925
2926 #[test]
2927 fn tool_schema_error_detects_not_in_request() {
2928 let err = anyhow::Error::msg("tool 'search' was not in request");
2929 assert!(is_tool_schema_error(&err));
2930 }
2931
2932 #[test]
2933 fn tool_schema_error_detects_not_found_in_tool_list() {
2934 let err = anyhow::Error::msg("function 'foo' not found in tool list");
2935 assert!(is_tool_schema_error(&err));
2936 }
2937
2938 #[test]
2939 fn tool_schema_error_detects_invalid_tool_call() {
2940 let err = anyhow::Error::msg("invalid_tool_call: no matching function");
2941 assert!(is_tool_schema_error(&err));
2942 }
2943
2944 #[test]
2945 fn tool_schema_error_ignores_unrelated_errors() {
2946 let err = anyhow::Error::msg("invalid api key");
2947 assert!(!is_tool_schema_error(&err));
2948
2949 let err = anyhow::Error::msg("model not found");
2950 assert!(!is_tool_schema_error(&err));
2951 }
2952
2953 #[test]
2954 fn non_retryable_returns_false_for_tool_schema_400() {
2955 let msg = "400 Bad Request: tool call validation failed: attempted to call tool 'x' which was not in request";
2957 let err = anyhow::Error::msg(msg.to_string());
2958 assert!(!is_non_retryable(&err));
2959 }
2960
2961 #[test]
2962 fn non_retryable_returns_true_for_other_400_errors() {
2963 let err = anyhow::Error::msg("400 Bad Request: invalid api key provided");
2965 assert!(is_non_retryable(&err));
2966 }
2967
2968 struct StreamingToolEventMock {
2969 stream_calls: Arc<AtomicUsize>,
2970 supports_tool_events: bool,
2971 }
2972
2973 impl StreamingToolEventMock {
2974 fn new(supports_tool_events: bool) -> Self {
2975 Self {
2976 stream_calls: Arc::new(AtomicUsize::new(0)),
2977 supports_tool_events,
2978 }
2979 }
2980 }
2981
2982 #[async_trait]
2983 impl ModelProvider for StreamingToolEventMock {
2984 async fn chat_with_system(
2985 &self,
2986 _system_prompt: Option<&str>,
2987 _message: &str,
2988 _model: &str,
2989 _temperature: Option<f64>,
2990 ) -> anyhow::Result<String> {
2991 Ok("ok".to_string())
2992 }
2993
2994 fn supports_streaming(&self) -> bool {
2995 true
2996 }
2997
2998 fn supports_streaming_tool_events(&self) -> bool {
2999 self.supports_tool_events
3000 }
3001
3002 fn stream_chat(
3003 &self,
3004 _request: ChatRequest<'_>,
3005 _model: &str,
3006 _temperature: Option<f64>,
3007 _options: StreamOptions,
3008 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
3009 self.stream_calls.fetch_add(1, Ordering::SeqCst);
3010 stream::iter(vec![
3011 Ok(StreamEvent::ToolCall(super::super::traits::ToolCall {
3012 id: "call_1".to_string(),
3013 name: "shell".to_string(),
3014 arguments: r#"{"command":"date"}"#.to_string(),
3015 extra_content: None,
3016 })),
3017 Ok(StreamEvent::Final),
3018 ])
3019 .boxed()
3020 }
3021 }
3022 impl ::zeroclaw_api::attribution::Attributable for StreamingToolEventMock {
3023 fn role(&self) -> ::zeroclaw_api::attribution::Role {
3024 ::zeroclaw_api::attribution::Role::Provider(
3025 ::zeroclaw_api::attribution::ProviderKind::Model(
3026 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
3027 ),
3028 )
3029 }
3030 fn alias(&self) -> &str {
3031 "StreamingToolEventMock"
3032 }
3033 }
3034
3035 #[tokio::test]
3038 async fn stream_chat_prefers_provider_with_tool_event_support() {
3039 let primary = Arc::new(StreamingToolEventMock::new(false));
3040 let fallback = Arc::new(StreamingToolEventMock::new(true));
3041 let model_provider = ReliableModelProvider::new(
3042 "test",
3043 vec![
3044 (
3045 "primary".into(),
3046 Box::new(Arc::clone(&primary)) as Box<dyn ModelProvider>,
3047 ),
3048 (
3049 "fallback".into(),
3050 Box::new(Arc::clone(&fallback)) as Box<dyn ModelProvider>,
3051 ),
3052 ],
3053 0,
3054 1,
3055 );
3056
3057 let messages = vec![ChatMessage::user("hello")];
3058 let tools = vec![ToolSpec {
3059 name: "shell".to_string(),
3060 description: "run shell".to_string(),
3061 parameters: serde_json::json!({
3062 "type": "object",
3063 "properties": {
3064 "command": { "type": "string" }
3065 }
3066 }),
3067 }];
3068 let mut stream = model_provider.stream_chat(
3069 ChatRequest {
3070 messages: &messages,
3071 tools: Some(&tools),
3072 thinking: None,
3073 },
3074 "model",
3075 Some(0.0),
3076 StreamOptions::new(true),
3077 );
3078
3079 let first = stream.next().await.unwrap().unwrap();
3080 let second = stream.next().await.unwrap().unwrap();
3081 assert!(stream.next().await.is_none());
3082
3083 match first {
3084 StreamEvent::ToolCall(call) => assert_eq!(call.name, "shell"),
3085 other => panic!("expected tool-call event, got {other:?}"),
3086 }
3087 assert!(matches!(second, StreamEvent::Final));
3088 assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
3089 assert_eq!(fallback.stream_calls.load(Ordering::SeqCst), 1);
3090 }
3091
3092 #[tokio::test]
3093 async fn stream_chat_errors_when_no_provider_supports_tool_events() {
3094 let primary = Arc::new(StreamingToolEventMock::new(false));
3095 let model_provider = ReliableModelProvider::new(
3096 "test",
3097 vec![(
3098 "primary".into(),
3099 Box::new(Arc::clone(&primary)) as Box<dyn ModelProvider>,
3100 )],
3101 0,
3102 1,
3103 );
3104
3105 let messages = vec![ChatMessage::user("hello")];
3106 let tools = vec![ToolSpec {
3107 name: "shell".to_string(),
3108 description: "run shell".to_string(),
3109 parameters: serde_json::json!({"type": "object"}),
3110 }];
3111 let mut stream = model_provider.stream_chat(
3112 ChatRequest {
3113 messages: &messages,
3114 tools: Some(&tools),
3115 thinking: None,
3116 },
3117 "model",
3118 Some(0.0),
3119 StreamOptions::new(true),
3120 );
3121
3122 let first = stream.next().await.unwrap();
3123 let err = first.expect_err("stream should fail without tool-event support");
3124 assert!(
3125 err.to_string()
3126 .contains("No model_provider supports streaming tool events"),
3127 "unexpected stream error: {err}"
3128 );
3129 assert!(stream.next().await.is_none());
3130 assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
3131 }
3132
3133 struct StreamingHistoryMock {
3137 stream_calls: Arc<AtomicUsize>,
3138 supports: bool,
3139 }
3140
3141 #[async_trait]
3142 impl ModelProvider for StreamingHistoryMock {
3143 async fn chat_with_system(
3144 &self,
3145 _system_prompt: Option<&str>,
3146 _message: &str,
3147 _model: &str,
3148 _temperature: Option<f64>,
3149 ) -> anyhow::Result<String> {
3150 Ok("ok".to_string())
3151 }
3152
3153 fn supports_streaming(&self) -> bool {
3154 self.supports
3155 }
3156
3157 fn stream_chat_with_history(
3158 &self,
3159 messages: &[ChatMessage],
3160 _model: &str,
3161 _temperature: Option<f64>,
3162 _options: StreamOptions,
3163 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
3164 self.stream_calls.fetch_add(1, Ordering::SeqCst);
3165 let msg_count = messages.len().to_string();
3167 stream::iter(vec![
3168 Ok(StreamChunk::delta(msg_count)),
3169 Ok(StreamChunk::final_chunk()),
3170 ])
3171 .boxed()
3172 }
3173 }
3174 impl ::zeroclaw_api::attribution::Attributable for StreamingHistoryMock {
3175 fn role(&self) -> ::zeroclaw_api::attribution::Role {
3176 ::zeroclaw_api::attribution::Role::Provider(
3177 ::zeroclaw_api::attribution::ProviderKind::Model(
3178 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
3179 ),
3180 )
3181 }
3182 fn alias(&self) -> &str {
3183 "StreamingHistoryMock"
3184 }
3185 }
3186
3187 #[tokio::test]
3188 async fn stream_chat_with_history_delegates_to_streaming_provider() {
3189 let calls = Arc::new(AtomicUsize::new(0));
3190 let model_provider = ReliableModelProvider::new(
3191 "test",
3192 vec![(
3193 "primary".into(),
3194 Box::new(StreamingHistoryMock {
3195 stream_calls: Arc::clone(&calls),
3196 supports: true,
3197 }) as Box<dyn ModelProvider>,
3198 )],
3199 0,
3200 1,
3201 );
3202
3203 let messages = vec![
3204 ChatMessage::system("system"),
3205 ChatMessage::user("msg1"),
3206 ChatMessage::assistant("resp1"),
3207 ChatMessage::user("msg2"),
3208 ];
3209 let mut stream = model_provider.stream_chat_with_history(
3210 &messages,
3211 "model",
3212 Some(0.0),
3213 StreamOptions::new(true),
3214 );
3215
3216 let first = stream.next().await.unwrap().unwrap();
3217 assert_eq!(
3218 first.delta, "4",
3219 "should pass all 4 messages to model_provider"
3220 );
3221 let second = stream.next().await.unwrap().unwrap();
3222 assert!(second.is_final);
3223 assert!(stream.next().await.is_none());
3224 assert_eq!(calls.load(Ordering::SeqCst), 1);
3225 }
3226
3227 #[tokio::test]
3228 async fn stream_chat_with_history_skips_non_streaming_providers() {
3229 let non_streaming_calls = Arc::new(AtomicUsize::new(0));
3230 let streaming_calls = Arc::new(AtomicUsize::new(0));
3231
3232 let model_provider = ReliableModelProvider::new(
3233 "test",
3234 vec![
3235 (
3236 "non-streaming".into(),
3237 Box::new(StreamingHistoryMock {
3238 stream_calls: Arc::clone(&non_streaming_calls),
3239 supports: false,
3240 }) as Box<dyn ModelProvider>,
3241 ),
3242 (
3243 "streaming".into(),
3244 Box::new(StreamingHistoryMock {
3245 stream_calls: Arc::clone(&streaming_calls),
3246 supports: true,
3247 }) as Box<dyn ModelProvider>,
3248 ),
3249 ],
3250 0,
3251 1,
3252 );
3253
3254 let messages = vec![ChatMessage::user("hello")];
3255 let mut stream = model_provider.stream_chat_with_history(
3256 &messages,
3257 "model",
3258 Some(0.0),
3259 StreamOptions::new(true),
3260 );
3261
3262 let first = stream.next().await.unwrap().unwrap();
3263 assert_eq!(first.delta, "1");
3264 assert_eq!(
3265 non_streaming_calls.load(Ordering::SeqCst),
3266 0,
3267 "non-streaming model_provider should be skipped"
3268 );
3269 assert_eq!(
3270 streaming_calls.load(Ordering::SeqCst),
3271 1,
3272 "streaming model_provider should be used"
3273 );
3274 }
3275
3276 #[tokio::test]
3277 async fn stream_chat_with_history_errors_when_no_provider_supports_streaming() {
3278 let model_provider = ReliableModelProvider::new(
3279 "test",
3280 vec![(
3281 "non-streaming".into(),
3282 Box::new(StreamingHistoryMock {
3283 stream_calls: Arc::new(AtomicUsize::new(0)),
3284 supports: false,
3285 }) as Box<dyn ModelProvider>,
3286 )],
3287 0,
3288 1,
3289 );
3290
3291 let messages = vec![ChatMessage::user("hello")];
3292 let mut stream = model_provider.stream_chat_with_history(
3293 &messages,
3294 "model",
3295 Some(0.0),
3296 StreamOptions::new(true),
3297 );
3298
3299 let first = stream.next().await.unwrap();
3300 let err = first.expect_err("should fail when no model_provider supports streaming");
3301 assert!(
3302 err.to_string()
3303 .contains("No model_provider supports streaming"),
3304 "unexpected error: {err}"
3305 );
3306 assert!(stream.next().await.is_none());
3307 }
3308
3309 #[tokio::test]
3310 async fn fallback_records_provider_fallback_info() {
3311 scope_provider_fallback(async {
3312 let model_provider = ReliableModelProvider::new(
3313 "test",
3314 vec![
3315 (
3316 "broken".into(),
3317 Box::new(MockModelProvider {
3318 calls: Arc::new(AtomicUsize::new(0)),
3319 fail_until_attempt: 99, response: "unused",
3321 error: "401 Unauthorized",
3322 }),
3323 ),
3324 (
3325 "working".into(),
3326 Box::new(MockModelProvider {
3327 calls: Arc::new(AtomicUsize::new(0)),
3328 fail_until_attempt: 0,
3329 response: "hello from working",
3330 error: "unused",
3331 }),
3332 ),
3333 ],
3334 2,
3335 1,
3336 );
3337
3338 let resp = model_provider
3339 .simple_chat("hi", "test-model", Some(0.0))
3340 .await
3341 .unwrap();
3342 assert_eq!(resp, "hello from working");
3343
3344 let fb = take_last_provider_fallback();
3345 assert!(fb.is_some(), "fallback info should be recorded");
3346 let fb = fb.unwrap();
3347 assert_eq!(fb.requested_provider, "broken");
3348 assert_eq!(fb.actual_provider, "working");
3349 assert_eq!(fb.actual_model, "test-model");
3350
3351 assert!(take_last_provider_fallback().is_none());
3353 })
3354 .await;
3355 }
3356
3357 #[test]
3361 fn supports_vision_reflects_first_provider_not_any_fallback() {
3362 struct VisionMock(bool);
3363
3364 #[async_trait]
3365 impl ModelProvider for VisionMock {
3366 async fn chat_with_system(
3367 &self,
3368 _system_prompt: Option<&str>,
3369 _message: &str,
3370 _model: &str,
3371 _temperature: Option<f64>,
3372 ) -> anyhow::Result<String> {
3373 Ok(String::new())
3374 }
3375
3376 fn supports_vision(&self) -> bool {
3377 self.0
3378 }
3379 }
3380 impl ::zeroclaw_api::attribution::Attributable for VisionMock {
3381 fn role(&self) -> ::zeroclaw_api::attribution::Role {
3382 ::zeroclaw_api::attribution::Role::Provider(
3383 ::zeroclaw_api::attribution::ProviderKind::Model(
3384 ::zeroclaw_api::attribution::ModelProviderKind::Custom,
3385 ),
3386 )
3387 }
3388 fn alias(&self) -> &str {
3389 "VisionMock"
3390 }
3391 }
3392
3393 let provider = ReliableModelProvider::new(
3394 "test",
3395 vec![
3396 (
3397 "primary".into(),
3398 Box::new(VisionMock(false)) as Box<dyn ModelProvider>,
3399 ),
3400 (
3401 "fallback".into(),
3402 Box::new(VisionMock(true)) as Box<dyn ModelProvider>,
3403 ),
3404 ],
3405 0,
3406 0,
3407 );
3408
3409 assert!(
3410 !provider.supports_vision(),
3411 "ReliableModelProvider with non-vision primary must report supports_vision()=false even when a fallback supports vision"
3412 );
3413
3414 let provider = ReliableModelProvider::new(
3415 "test",
3416 vec![
3417 (
3418 "primary".into(),
3419 Box::new(VisionMock(true)) as Box<dyn ModelProvider>,
3420 ),
3421 (
3422 "fallback".into(),
3423 Box::new(VisionMock(false)) as Box<dyn ModelProvider>,
3424 ),
3425 ],
3426 0,
3427 0,
3428 );
3429
3430 assert!(provider.supports_vision());
3431 }
3432}