1use crate::multimodal;
6use crate::stream_guard::AbortOnDrop;
7use crate::traits::{
8 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
9 ModelProvider, StreamChunk, StreamError, StreamEvent, StreamOptions, StreamResult, TokenUsage,
10 ToolCall as ProviderToolCall,
11};
12use async_trait::async_trait;
13use futures_util::{StreamExt, stream};
14use reqwest::{
15 Client,
16 header::{HeaderMap, HeaderValue, USER_AGENT},
17};
18use serde::{Deserialize, Serialize};
19
/// Chat-completions provider for OpenAI-compatible HTTP APIs.
///
/// Created through one of the `new*` constructors and refined with the
/// `with_*` / `without_*` builder methods.
// The bools below are independent feature toggles, which trips clippy's
// `struct_excessive_bools` lint.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone)]
pub struct OpenAiCompatibleModelProvider {
    /// Alias string passed to the constructor, stored verbatim.
    pub alias: String,
    /// Provider name; `"zai"` / `"z.ai"` force `tool_stream` when tools are sent.
    pub name: String,
    /// Base URL with trailing `/` characters removed by the constructor.
    pub base_url: String,
    /// Optional API credential. NOTE(review): presumably applied with
    /// `apply_auth_to_request` using `auth_header` — confirm at the call site.
    pub credential: Option<String>,
    /// How the credential is attached to requests.
    pub auth_header: AuthStyle,
    /// Vision-support flag set by the constructors (consumer not visible here).
    supports_vision: bool,
    /// Optional `User-Agent` installed as a default header on the HTTP clients.
    user_agent: Option<String>,
    /// When true, system messages are merged into the first user message.
    merge_system_into_user: bool,
    /// Defaults to `!merge_system_into_user`; cleared by `without_native_tools`.
    native_tool_calling: bool,
    /// Total request timeout (seconds) for the non-streaming client; defaults to 120.
    timeout_secs: u64,
    /// Extra default headers; invalid names/values are skipped with a warning.
    extra_headers: std::collections::HashMap<String, String>,
    /// Reasoning effort, forwarded only for models matched by `reasoning_effort_for_model`.
    reasoning_effort: Option<String>,
    /// Replaces the default `/chat/completions` suffix when set.
    api_path: Option<String>,
    /// Optional `max_tokens` setting (consumer not visible here).
    max_tokens: Option<u32>,
    /// NOTE(review): models.dev lookup key — its use is outside this view.
    models_dev_key: Option<String>,
    /// NOTE(review): vendor prefix for OpenRouter model ids — its use is outside this view.
    openrouter_vendor_prefix: Option<String>,
    /// Enabled by `with_local_model_tool_sanitize` (consumer not visible here).
    local_model_tool_sanitize: bool,
    /// Enabled by `with_public_model_listing` (consumer not visible here).
    public_model_listing: bool,
}
75
/// How a provider credential is attached to outgoing requests.
#[derive(Debug, Clone)]
pub enum AuthStyle {
    /// `Authorization: Bearer <credential>`.
    Bearer,
    /// `x-api-key: <credential>`.
    XApiKey,
    /// The credential is sent verbatim in the header named by the payload.
    Custom(String),
    /// `Authorization: Bearer <jwt>`, with the JWT minted from an `id.secret`
    /// credential; falls back to a plain bearer token if minting fails.
    ZhipuJwt,
}
90
91fn zhipu_jwt_bearer(credential: &str) -> Result<String, String> {
94 let (id, secret) = credential
95 .split_once('.')
96 .ok_or_else(|| "Zhipu API key must be in 'id.secret' format".to_string())?;
97
98 #[allow(clippy::cast_possible_truncation)] let now_ms = std::time::SystemTime::now()
100 .duration_since(std::time::UNIX_EPOCH)
101 .map_err(|e| e.to_string())?
102 .as_millis() as u64;
103 let exp_ms = now_ms + 210_000; let header_b64 = base64url_no_pad(br#"{"alg":"HS256","typ":"JWT","sign_type":"SIGN"}"#);
107 let payload = format!(r#"{{"api_key":"{id}","exp":{exp_ms},"timestamp":{now_ms}}}"#);
108 let payload_b64 = base64url_no_pad(payload.as_bytes());
109
110 let signing_input = format!("{header_b64}.{payload_b64}");
111 let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, secret.as_bytes());
112 let sig = ring::hmac::sign(&key, signing_input.as_bytes());
113 let sig_b64 = base64url_no_pad(sig.as_ref());
114
115 Ok(format!("Bearer {signing_input}.{sig_b64}"))
116}
117
118fn base64url_no_pad(data: &[u8]) -> String {
119 use base64::engine::{Engine, general_purpose::URL_SAFE_NO_PAD};
120 URL_SAFE_NO_PAD.encode(data)
121}
122
123fn apply_auth_to_request(
128 req: reqwest::RequestBuilder,
129 style: &AuthStyle,
130 credential: Option<&str>,
131) -> reqwest::RequestBuilder {
132 let credential = match credential {
133 Some(c) => c,
134 None => return req,
135 };
136 match style {
137 AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
138 AuthStyle::XApiKey => req.header("x-api-key", credential),
139 AuthStyle::Custom(header) => req.header(header, credential),
140 AuthStyle::ZhipuJwt => match zhipu_jwt_bearer(credential) {
141 Ok(val) => req.header("Authorization", val),
142 Err(_) => req.header("Authorization", format!("Bearer {credential}")),
143 },
144 }
145}
146
/// Body of a `/models` listing: a `data` array of model entries.
#[derive(Deserialize)]
struct ModelsResponse {
    data: Vec<ModelEntry>,
}

/// One entry of a `/models` listing.
#[derive(Deserialize)]
struct ModelEntry {
    /// Model id; trimmed, and dropped if empty, by the normalisers.
    id: String,
    /// Optional pricing; absent in the JSON deserializes as `None`.
    #[serde(default)]
    pricing: Option<zeroclaw_api::model_provider::ModelPricing>,
}
162
163fn normalize_model_ids(body: ModelsResponse) -> Vec<String> {
164 let mut ids: Vec<String> = body
165 .data
166 .into_iter()
167 .map(|e| e.id.trim().to_string())
168 .filter(|id| !id.is_empty())
169 .collect();
170 ids.sort();
171 ids
172}
173
174fn normalize_models_with_pricing(
177 body: ModelsResponse,
178) -> Vec<zeroclaw_api::model_provider::ModelInfo> {
179 use zeroclaw_api::model_provider::ModelInfo;
180 let mut models: Vec<ModelInfo> = body
181 .data
182 .into_iter()
183 .filter(|e| !e.id.trim().is_empty())
184 .map(|e| ModelInfo {
185 id: e.id.trim().to_string(),
186 pricing: e.pricing,
187 })
188 .collect();
189 models.sort_by(|a, b| a.id.cmp(&b.id));
190 models
191}
192
193fn models_dev_to_model_info(ids: Vec<String>) -> Vec<zeroclaw_api::model_provider::ModelInfo> {
198 use zeroclaw_api::model_provider::ModelInfo;
199 ids.into_iter()
200 .map(|id| ModelInfo { id, pricing: None })
201 .collect()
202}
203
204impl OpenAiCompatibleModelProvider {
205 pub fn new(
206 alias: &str,
207 name: &str,
208 base_url: &str,
209 credential: Option<&str>,
210 auth_style: AuthStyle,
211 ) -> Self {
212 Self::new_with_options(
213 alias, name, base_url, credential, auth_style, false, None, false,
214 )
215 }
216
217 pub fn new_with_vision(
218 alias: &str,
219 name: &str,
220 base_url: &str,
221 credential: Option<&str>,
222 auth_style: AuthStyle,
223 supports_vision: bool,
224 ) -> Self {
225 Self::new_with_options(
226 alias,
227 name,
228 base_url,
229 credential,
230 auth_style,
231 supports_vision,
232 None,
233 false,
234 )
235 }
236
237 pub fn new_with_user_agent(
242 alias: &str,
243 name: &str,
244 base_url: &str,
245 credential: Option<&str>,
246 auth_style: AuthStyle,
247 user_agent: &str,
248 ) -> Self {
249 Self::new_with_options(
250 alias,
251 name,
252 base_url,
253 credential,
254 auth_style,
255 false,
256 Some(user_agent),
257 false,
258 )
259 }
260
261 pub fn new_with_user_agent_and_vision(
262 alias: &str,
263 name: &str,
264 base_url: &str,
265 credential: Option<&str>,
266 auth_style: AuthStyle,
267 user_agent: &str,
268 supports_vision: bool,
269 ) -> Self {
270 Self::new_with_options(
271 alias,
272 name,
273 base_url,
274 credential,
275 auth_style,
276 supports_vision,
277 Some(user_agent),
278 false,
279 )
280 }
281
282 pub fn new_merge_system_into_user(
285 alias: &str,
286 name: &str,
287 base_url: &str,
288 credential: Option<&str>,
289 auth_style: AuthStyle,
290 ) -> Self {
291 Self::new_with_options(
292 alias, name, base_url, credential, auth_style, false, None, true,
293 )
294 }
295
296 fn new_with_options(
297 alias: &str,
298 name: &str,
299 base_url: &str,
300 credential: Option<&str>,
301 auth_style: AuthStyle,
302 supports_vision: bool,
303 user_agent: Option<&str>,
304 merge_system_into_user: bool,
305 ) -> Self {
306 Self {
307 alias: alias.to_string(),
308 name: name.to_string(),
309 base_url: base_url.trim_end_matches('/').to_string(),
310 credential: credential.map(ToString::to_string),
311 auth_header: auth_style,
312 supports_vision,
313 user_agent: user_agent.map(ToString::to_string),
314 merge_system_into_user,
315 native_tool_calling: !merge_system_into_user,
316 timeout_secs: 120,
317 extra_headers: std::collections::HashMap::new(),
318 reasoning_effort: None,
319 api_path: None,
320 max_tokens: None,
321 models_dev_key: None,
322 openrouter_vendor_prefix: None,
323 local_model_tool_sanitize: false,
324 public_model_listing: false,
325 }
326 }
327 pub fn with_local_model_tool_sanitize(mut self) -> Self {
334 self.local_model_tool_sanitize = true;
335 self
336 }
337
338 pub fn with_public_model_listing(mut self) -> Self {
339 self.public_model_listing = true;
340 self
341 }
342
343 pub fn without_native_tools(mut self) -> Self {
345 self.native_tool_calling = false;
346 self
347 }
348
349 pub fn with_merge_system_into_user(mut self) -> Self {
352 self.merge_system_into_user = true;
353 self
354 }
355
356 pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
358 self.timeout_secs = timeout_secs;
359 self
360 }
361
362 pub fn with_extra_headers(
364 mut self,
365 headers: std::collections::HashMap<String, String>,
366 ) -> Self {
367 self.extra_headers = headers;
368 self
369 }
370
371 pub fn with_reasoning_effort(mut self, reasoning_effort: Option<String>) -> Self {
373 self.reasoning_effort = reasoning_effort;
374 self
375 }
376
377 pub fn with_api_path(mut self, api_path: Option<String>) -> Self {
380 self.api_path = api_path;
381 self
382 }
383
384 pub fn with_max_tokens(mut self, max_tokens: Option<u32>) -> Self {
386 self.max_tokens = max_tokens;
387 self
388 }
389
390 pub fn with_models_dev_key(mut self, key: &str) -> Self {
393 self.models_dev_key = Some(key.to_string());
394 self
395 }
396
397 pub fn with_openrouter_vendor_prefix(mut self, prefix: &str) -> Self {
401 self.openrouter_vendor_prefix = Some(prefix.to_string());
402 self
403 }
404
405 fn flatten_system_messages(messages: &[ChatMessage], merge: bool) -> Vec<ChatMessage> {
409 let mut saw_system = false;
410 let mut system_content = String::new();
411 let mut result: Vec<ChatMessage> = Vec::with_capacity(messages.len());
412
413 for message in messages {
414 if message.role == "system" {
415 saw_system = true;
416 if !message.content.is_empty() {
417 if !system_content.is_empty() {
418 system_content.push_str("\n\n");
419 }
420 system_content.push_str(&message.content);
421 }
422 } else {
423 result.push(message.clone());
424 }
425 }
426
427 if !saw_system {
428 return messages.to_vec();
429 }
430
431 if system_content.is_empty() {
432 return result;
433 }
434
435 if !merge {
436 result.insert(0, ChatMessage::system(system_content));
437 return result;
438 }
439
440 if let Some(first_user) = result.iter_mut().find(|m| m.role == "user") {
441 if !system_content.is_empty() {
442 first_user.content = format!("{system_content}\n\n{}", first_user.content);
443 }
444 } else {
445 result.insert(0, ChatMessage::user(&system_content));
447 }
448
449 result
450 }
451
452 fn http_client(&self) -> Client {
453 let timeout = self.timeout_secs;
454 let has_user_agent = self.user_agent.is_some();
455 let has_extra_headers = !self.extra_headers.is_empty();
456
457 if has_user_agent || has_extra_headers {
458 let mut headers = HeaderMap::new();
459 if let Some(ua) = self.user_agent.as_deref()
460 && let Ok(value) = HeaderValue::from_str(ua)
461 {
462 headers.insert(USER_AGENT, value);
463 }
464 for (key, value) in &self.extra_headers {
465 match (
466 reqwest::header::HeaderName::from_bytes(key.as_bytes()),
467 HeaderValue::from_str(value),
468 ) {
469 (Ok(name), Ok(val)) => {
470 headers.insert(name, val);
471 }
472 _ => {
473 ::zeroclaw_log::record!(
474 WARN,
475 ::zeroclaw_log::Event::new(
476 module_path!(),
477 ::zeroclaw_log::Action::Note
478 )
479 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
480 .with_attrs(::serde_json::json!({"header": key})),
481 "Skipping invalid extra header name or value"
482 );
483 }
484 }
485 }
486
487 let builder = Client::builder()
488 .timeout(std::time::Duration::from_secs(timeout))
489 .connect_timeout(std::time::Duration::from_secs(10))
490 .default_headers(headers);
491 let builder = zeroclaw_config::schema::apply_runtime_proxy_to_builder(
492 builder,
493 "model_provider.compatible",
494 );
495
496 return builder.build().unwrap_or_else(|error| {
497 ::zeroclaw_log::record!(
498 WARN,
499 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
500 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
501 .with_attrs(
502 ::serde_json::json!({"error": super::format_error_chain(&error)})
503 ),
504 "Failed to build proxied timeout client with custom headers: "
505 );
506 Client::new()
507 });
508 }
509
510 zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
511 "model_provider.compatible",
512 timeout,
513 10,
514 )
515 }
516
517 fn streaming_http_client(&self) -> Client {
521 let has_user_agent = self.user_agent.is_some();
522 let has_extra_headers = !self.extra_headers.is_empty();
523
524 if has_user_agent || has_extra_headers {
525 let mut headers = HeaderMap::new();
526 if let Some(ua) = self.user_agent.as_deref()
527 && let Ok(value) = HeaderValue::from_str(ua)
528 {
529 headers.insert(USER_AGENT, value);
530 }
531 for (key, value) in &self.extra_headers {
532 match (
533 reqwest::header::HeaderName::from_bytes(key.as_bytes()),
534 HeaderValue::from_str(value),
535 ) {
536 (Ok(name), Ok(val)) => {
537 headers.insert(name, val);
538 }
539 _ => {
540 ::zeroclaw_log::record!(
541 WARN,
542 ::zeroclaw_log::Event::new(
543 module_path!(),
544 ::zeroclaw_log::Action::Note
545 )
546 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
547 .with_attrs(::serde_json::json!({"header": key})),
548 "Skipping invalid extra header name or value"
549 );
550 }
551 }
552 }
553
554 let builder = Client::builder()
555 .connect_timeout(std::time::Duration::from_secs(10))
556 .default_headers(headers);
557 let builder = zeroclaw_config::schema::apply_runtime_proxy_to_builder(
558 builder,
559 "provider.compatible",
560 );
561 return builder.build().unwrap_or_else(|error| {
562 ::zeroclaw_log::record!(
563 WARN,
564 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
565 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
566 .with_attrs(
567 ::serde_json::json!({"error": super::format_error_chain(&error)})
568 ),
569 "Failed to build proxied streaming client with custom headers: "
570 );
571 Client::new()
572 });
573 }
574
575 let builder = Client::builder().connect_timeout(std::time::Duration::from_secs(10));
576 let builder =
577 zeroclaw_config::schema::apply_runtime_proxy_to_builder(builder, "provider.compatible");
578 builder.build().unwrap_or_else(|error| {
579 ::zeroclaw_log::record!(
580 WARN,
581 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
582 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
583 .with_attrs(::serde_json::json!({"error": super::format_error_chain(&error)})),
584 "Failed to build proxied streaming client: "
585 );
586 Client::new()
587 })
588 }
589
590 fn chat_completions_url(&self) -> String {
594 if let Some(ref api_path) = self.api_path {
596 let separator = if api_path.starts_with('/') { "" } else { "/" };
597 return format!("{}{separator}{api_path}", self.base_url);
598 }
599
600 let has_full_endpoint = reqwest::Url::parse(&self.base_url)
601 .map(|url| {
602 url.path()
603 .trim_end_matches('/')
604 .ends_with("/chat/completions")
605 })
606 .unwrap_or_else(|_| {
607 self.base_url
608 .trim_end_matches('/')
609 .ends_with("/chat/completions")
610 });
611
612 if has_full_endpoint {
613 self.base_url.clone()
614 } else {
615 format!("{}/chat/completions", self.base_url)
616 }
617 }
618
619 fn requires_tool_stream(&self) -> bool {
620 let host_requires_tool_stream = reqwest::Url::parse(&self.base_url)
621 .ok()
622 .and_then(|url| url.host_str().map(str::to_ascii_lowercase))
623 .is_some_and(|host| host == "api.z.ai" || host.ends_with(".z.ai"));
624
625 host_requires_tool_stream || matches!(self.name.as_str(), "zai" | "z.ai")
626 }
627
628 fn tool_stream_for_tools(&self, has_tools: bool) -> Option<bool> {
629 if has_tools && self.requires_tool_stream() {
630 Some(true)
631 } else {
632 None
633 }
634 }
635
636 fn model_requires_system_merge(model: &str) -> bool {
640 let id = model
641 .rsplit('/')
642 .next()
643 .unwrap_or(model)
644 .to_ascii_lowercase();
645 id.contains("deepseek-v3") || id.contains("deepseek_v3")
646 }
647
648 fn effective_merge_system(&self, model: &str) -> bool {
651 self.merge_system_into_user || Self::model_requires_system_merge(model)
652 }
653
654 fn reasoning_effort_for_model(&self, model: &str) -> Option<String> {
655 let effort = self.reasoning_effort.as_ref()?;
656 let id = model
657 .rsplit('/')
658 .next()
659 .unwrap_or(model)
660 .to_ascii_lowercase();
661 let is_gpt5_chat_latest = id.starts_with("gpt-5") && id.ends_with("-chat-latest");
666 let is_openai_reasoning_model = id == "o1"
667 || id.starts_with("o1-")
668 || id == "o3"
669 || id.starts_with("o3-")
670 || id == "o4"
671 || id.starts_with("o4-")
672 || (id.starts_with("gpt-5") && !is_gpt5_chat_latest);
673 let is_likely_codex_supported = id.contains("codex") && id.starts_with("gpt-");
674
675 (is_openai_reasoning_model || is_likely_codex_supported).then(|| effort.clone())
676 }
677}
678
/// Chat-completions request body whose messages carry only `role` and `content`
/// (compare `NativeChatRequest`, whose messages can also carry tool calls).
/// Every `None` field is omitted from the serialized JSON.
#[derive(Debug, Serialize)]
struct ApiChatRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptionsBody>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

/// `stream_options` object. NOTE(review): `include_usage` presumably mirrors
/// OpenAI's option to report token usage in the stream — confirm per provider.
#[derive(Debug, Serialize, Clone, Copy)]
struct StreamOptionsBody {
    include_usage: bool,
}

/// A request message with a role and either plain-text or multi-part content.
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: MessageContent,
}

/// Message content; `untagged` serializes `Text` as a bare JSON string and
/// `Parts` as a JSON array.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum MessageContent {
    Text(String),
    Parts(Vec<MessagePart>),
}

/// One content part, serialized as `{"type":"text","text":…}` or
/// `{"type":"image_url","image_url":{"url":…}}`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum MessagePart {
    Text { text: String },
    ImageUrl { image_url: ImageUrlPart },
}

/// The `image_url` object of an image content part.
#[derive(Debug, Serialize)]
struct ImageUrlPart {
    url: String,
}
734
/// Non-streaming chat-completions response: the choices plus optional usage.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
    #[serde(default)]
    usage: Option<UsageInfo>,
}

/// Token counts reported by the API; each count may be absent.
#[derive(Debug, Deserialize)]
struct UsageInfo {
    #[serde(default)]
    prompt_tokens: Option<u64>,
    #[serde(default)]
    completion_tokens: Option<u64>,
}

/// One response choice, holding the assistant message.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}
754
/// Removes every `<think>…</think>` block from `s` and trims the result.
///
/// An opening `<think>` with no matching `</think>` discards everything from that
/// tag to the end of the input. A stray `</think>` with no opening tag is kept.
fn strip_think_tags(s: &str) -> String {
    const OPEN: &str = "<think>";
    const CLOSE: &str = "</think>";

    let mut kept = String::with_capacity(s.len());
    let mut remaining = s;
    while let Some((before, after_open)) = remaining.split_once(OPEN) {
        kept.push_str(before);
        match after_open.split_once(CLOSE) {
            Some((_, after_close)) => remaining = after_close,
            // Unclosed think block: drop the rest of the input.
            None => return kept.trim().to_string(),
        }
    }
    kept.push_str(remaining);
    kept.trim().to_string()
}
778
779fn openai_assistant_content_plaintext(content: Option<OpenAiAssistantContent>) -> Option<String> {
784 match content? {
785 OpenAiAssistantContent::Text(s) => {
786 if s.is_empty() {
787 None
788 } else {
789 Some(s)
790 }
791 }
792 OpenAiAssistantContent::Parts(parts) => {
793 let mut text = String::new();
794 for part in parts {
795 if part.kind.as_deref() != Some("text") {
796 continue;
797 }
798 let Some(part_text) = part.text.filter(|text| !text.is_empty()) else {
799 continue;
800 };
801 if !text.is_empty() {
802 text.push('\n');
803 }
804 text.push_str(&part_text);
805 }
806
807 if text.is_empty() { None } else { Some(text) }
808 }
809 }
810}
811
/// Assistant `content` as returned by the API: either a plain string or an array
/// of typed parts (`untagged`, so serde tries each shape in turn).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAiAssistantContent {
    Text(String),
    Parts(Vec<OpenAiAssistantContentPart>),
}

/// One part of an array-form assistant `content`; only the `type` and `text`
/// fields are read.
#[derive(Debug, Deserialize)]
struct OpenAiAssistantContentPart {
    #[serde(rename = "type")]
    kind: Option<String>,
    text: Option<String>,
}
825
/// Normalised assistant message.
///
/// Deserialization goes through `RawResponseMessage` (`serde(from)`): content is
/// flattened to plain text, and `reasoning` is accepted as a fallback for
/// `reasoning_content`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(from = "RawResponseMessage")]
struct ResponseMessage {
    content: Option<String>,
    reasoning_content: Option<String>,
    tool_calls: Option<Vec<ToolCall>>,
}

/// Wire shape of an assistant message before normalisation; every field is optional.
#[derive(Debug, Deserialize)]
struct RawResponseMessage {
    #[serde(default)]
    content: Option<OpenAiAssistantContent>,
    #[serde(default)]
    reasoning_content: Option<String>,
    // Alternate field name for reasoning text; used only when `reasoning_content` is absent.
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
859
860impl From<RawResponseMessage> for ResponseMessage {
861 fn from(raw: RawResponseMessage) -> Self {
862 let reasoning_content = raw.reasoning_content.or(raw.reasoning);
865 ResponseMessage {
866 content: openai_assistant_content_plaintext(raw.content),
867 reasoning_content,
868 tool_calls: raw.tool_calls,
869 }
870 }
871}
872
873impl ResponseMessage {
874 fn effective_content(&self) -> String {
880 if let Some(content) = self.content.as_ref().filter(|c| !c.is_empty()) {
881 let stripped = strip_think_tags(content);
882 if !stripped.is_empty() {
883 return stripped;
884 }
885 }
886
887 self.reasoning_content
888 .as_ref()
889 .map(|c| strip_think_tags(c))
890 .filter(|c| !c.is_empty())
891 .unwrap_or_default()
892 }
893
894 fn effective_content_optional(&self) -> Option<String> {
895 if let Some(content) = self.content.as_ref().filter(|c| !c.is_empty()) {
896 let stripped = strip_think_tags(content);
897 if !stripped.is_empty() {
898 return Some(stripped);
899 }
900 }
901
902 self.reasoning_content
903 .as_ref()
904 .map(|c| strip_think_tags(c))
905 .filter(|c| !c.is_empty())
906 }
907}
908
/// A tool call in a non-streaming response or an outgoing native message.
///
/// Accepts both the nested shape (`function.name` / `function.arguments`) and flat
/// `name` / `arguments` / `parameters` fields. `function_name` and
/// `function_arguments` resolve them in that order of precedence. `None` fields
/// are omitted when serializing.
#[derive(Debug, Deserialize, Serialize)]
struct ToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type")]
    #[serde(default, skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    function: Option<Function>,

    /// Flat tool name, used when `function.name` is absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Flat argument string, used when `function.arguments` is absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    arguments: Option<String>,

    /// Arguments given as a JSON value; re-serialized to a string as the last fallback.
    #[serde(
        rename = "parameters",
        default,
        skip_serializing_if = "Option::is_none"
    )]
    parameters: Option<serde_json::Value>,

    /// Opaque provider-specific data carried alongside the call.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    extra_content: Option<serde_json::Value>,
}
937
938impl ToolCall {
939 fn function_name(&self) -> Option<String> {
941 if let Some(ref func) = self.function
943 && let Some(ref name) = func.name
944 {
945 return Some(name.clone());
946 }
947 self.name.clone()
949 }
950
951 fn function_arguments(&self) -> Option<String> {
953 if let Some(ref func) = self.function
955 && let Some(ref args) = func.arguments
956 {
957 return Some(args.clone());
958 }
959 if let Some(ref args) = self.arguments {
961 return Some(args.clone());
962 }
963 if let Some(ref params) = self.parameters {
965 return serde_json::to_string(params).ok();
966 }
967 None
968 }
969}
970
/// The nested `function` object of a tool call; both fields are optional.
#[derive(Debug, Deserialize, Serialize)]
struct Function {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

/// Chat-completions request body whose messages can carry tool calls, tool-call
/// ids and reasoning content. Every `None` field is omitted from the JSON.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    model: String,
    messages: Vec<NativeMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptionsBody>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

/// A request message for `NativeChatRequest`; optional fields are omitted when `None`.
#[derive(Debug, Serialize)]
struct NativeMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
1020
/// One parsed SSE `data:` payload from a streaming response.
/// Missing `choices` deserialize as empty.
#[derive(Debug, Deserialize)]
struct StreamChunkResponse {
    #[serde(default)]
    choices: Vec<StreamChoice>,
    /// Token usage, when the chunk includes it.
    #[serde(default)]
    usage: Option<UsageInfo>,
}

/// One choice within a streamed chunk.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    #[serde(default)]
    delta: StreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental message delta. Deserialized through `RawStreamDelta` (manual impl),
/// which folds a `reasoning` field into `reasoning_content` when the latter is absent.
#[derive(Debug, Default)]
struct StreamDelta {
    content: Option<String>,
    reasoning_content: Option<String>,
    tool_calls: Option<Vec<StreamToolCallDelta>>,
}

/// Wire shape of a stream delta before `reasoning` is folded into `reasoning_content`.
#[derive(Debug, Deserialize, Default)]
struct RawStreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCallDelta>>,
}
1073
1074impl<'de> Deserialize<'de> for StreamDelta {
1075 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1076 where
1077 D: serde::Deserializer<'de>,
1078 {
1079 let raw = RawStreamDelta::deserialize(deserializer)?;
1080 Ok(StreamDelta {
1081 content: raw.content,
1082 reasoning_content: raw.reasoning_content.or(raw.reasoning),
1083 tool_calls: raw.tool_calls,
1084 })
1085 }
1086}
1087
/// One streamed fragment of a tool call. Fields may arrive either under a nested
/// `function` object or flat; every field is optional.
#[derive(Debug, Deserialize)]
struct StreamToolCallDelta {
    #[serde(default)]
    index: Option<usize>,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<StreamFunctionDelta>,
    /// Flat name, consulted only when `function.name` is absent.
    #[serde(default)]
    name: Option<String>,
    /// Flat argument fragment, consulted only when `function.arguments` is absent.
    #[serde(default)]
    arguments: Option<String>,
    #[serde(default)]
    extra_content: Option<serde_json::Value>,
}

/// Nested `function` fragment of a streamed tool call.
#[derive(Debug, Deserialize)]
struct StreamFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

/// Accumulates streamed fragments of a single tool call.
#[derive(Debug, Default)]
struct StreamToolCallAccumulator {
    id: Option<String>,
    name: Option<String>,
    /// Concatenation of all non-empty argument fragments received so far.
    arguments: String,
    extra_content: Option<serde_json::Value>,
}
1120
1121impl StreamToolCallAccumulator {
1122 fn apply_delta(&mut self, delta: &StreamToolCallDelta) {
1123 if let Some(id) = delta.id.as_ref().filter(|value| !value.is_empty()) {
1124 self.id = Some(id.clone());
1125 }
1126
1127 let delta_name = delta
1128 .function
1129 .as_ref()
1130 .and_then(|function| function.name.as_ref())
1131 .or(delta.name.as_ref())
1132 .filter(|value| !value.is_empty());
1133 if let Some(name) = delta_name {
1134 self.name = Some(name.clone());
1135 }
1136
1137 if let Some(arguments_delta) = delta
1138 .function
1139 .as_ref()
1140 .and_then(|function| function.arguments.as_ref())
1141 .or(delta.arguments.as_ref())
1142 .filter(|value| !value.is_empty())
1143 {
1144 self.arguments.push_str(arguments_delta);
1145 }
1146
1147 if let Some(extra) = delta.extra_content.as_ref() {
1149 self.extra_content = Some(extra.clone());
1150 }
1151 }
1152
1153 fn into_provider_tool_call(
1154 self,
1155 targets_mistral_tool_call_contract: bool,
1156 used_tool_call_ids: &mut std::collections::HashSet<String>,
1157 ) -> Option<ProviderToolCall> {
1158 let name = self.name?;
1159 let arguments = if self.arguments.trim().is_empty() {
1160 "{}".to_string()
1161 } else {
1162 self.arguments
1163 };
1164 let normalized_arguments = if serde_json::from_str::<serde_json::Value>(&arguments).is_ok()
1165 {
1166 arguments
1167 } else {
1168 ::zeroclaw_log::record!(
1169 WARN,
1170 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1171 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
1172 .with_attrs(::serde_json::json!({"function": name, "arguments": arguments})),
1173 "Invalid JSON in streamed native tool-call arguments, using empty object"
1174 );
1175 "{}".to_string()
1176 };
1177
1178 Some(ProviderToolCall {
1179 id: reserve_tool_call_id_for_contract(
1180 targets_mistral_tool_call_contract,
1181 self.id,
1182 used_tool_call_ids,
1183 ),
1184 name,
1185 arguments: normalized_arguments,
1186 extra_content: self.extra_content,
1187 })
1188 }
1189}
1190
1191fn parse_sse_chunk(line: &str) -> StreamResult<Option<StreamChunkResponse>> {
1192 let line = line.trim();
1193
1194 if line.is_empty() || line.starts_with(':') {
1195 return Ok(None);
1196 }
1197
1198 let Some(data) = line.strip_prefix("data:") else {
1199 return Ok(None);
1200 };
1201 let data = data.trim();
1202
1203 if data == "[DONE]" {
1204 return Ok(None);
1205 }
1206
1207 serde_json::from_str(data)
1208 .map(Some)
1209 .map_err(StreamError::Json)
1210}
1211
1212fn parse_proxy_tool_event(line: &str) -> Option<StreamEvent> {
1216 let data = line.trim().strip_prefix("data:")?.trim();
1217 let obj: serde_json::Value = serde_json::from_str(data).ok()?;
1218
1219 if let Some(ts) = obj.get("x_tool_start") {
1220 let Some(name) = ts.get("name").and_then(|v| v.as_str()) else {
1221 ::zeroclaw_log::record!(
1222 DEBUG,
1223 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
1224 "proxy x_tool_start event missing required 'name' field"
1225 );
1226 return None;
1227 };
1228 let name = name.to_string();
1229 let args = ts
1230 .get("arguments")
1231 .and_then(|v| v.as_str())
1232 .unwrap_or("{}")
1233 .to_string();
1234 return Some(StreamEvent::PreExecutedToolCall { name, args });
1235 }
1236
1237 if let Some(tr) = obj.get("x_tool_result") {
1238 let name = tr
1239 .get("name")
1240 .and_then(|v| v.as_str())
1241 .unwrap_or("unknown")
1242 .to_string();
1243 let output = tr
1244 .get("output")
1245 .and_then(|v| v.as_str())
1246 .unwrap_or("")
1247 .to_string();
1248 return Some(StreamEvent::PreExecutedToolResult { name, output });
1249 }
1250
1251 None
1252}
1253
1254fn extract_sse_text_delta(choice: &StreamChoice) -> Option<String> {
1255 if let Some(content) = &choice.delta.content
1256 && !content.is_empty()
1257 {
1258 return Some(content.clone());
1259 }
1260
1261 None
1262}
1263
1264fn extract_sse_reasoning_delta(choice: &StreamChoice) -> Option<String> {
1265 choice
1266 .delta
1267 .reasoning_content
1268 .as_ref()
1269 .filter(|value| !value.is_empty())
1270 .cloned()
1271}
1272
/// Whether `id` satisfies Mistral's tool-call id format: exactly nine ASCII
/// alphanumeric characters.
fn is_valid_mistral_tool_call_id(id: &str) -> bool {
    const MISTRAL_TOOL_CALL_ID_LEN: usize = 9;
    // Any non-ASCII character contributes bytes >= 0x80, which fail the byte check,
    // so checking bytes is equivalent to checking chars here.
    id.len() == MISTRAL_TOOL_CALL_ID_LEN && id.bytes().all(|b| b.is_ascii_alphanumeric())
}
1276
1277fn reserve_tool_call_id_for_contract(
1278 targets_mistral_tool_call_contract: bool,
1279 raw_id: Option<String>,
1280 used_ids: &mut std::collections::HashSet<String>,
1281) -> String {
1282 if !targets_mistral_tool_call_contract {
1283 let id = raw_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
1284 if used_ids.insert(id.clone()) {
1285 return id;
1286 }
1287
1288 loop {
1289 let candidate = uuid::Uuid::new_v4().to_string();
1290 if used_ids.insert(candidate.clone()) {
1291 return candidate;
1292 }
1293 }
1294 }
1295
1296 if let Some(id) = raw_id.as_deref()
1297 && is_valid_mistral_tool_call_id(id)
1298 && used_ids.insert(id.to_string())
1299 {
1300 return id.to_string();
1301 }
1302
1303 let mut candidate = raw_id
1304 .as_deref()
1305 .unwrap_or_default()
1306 .chars()
1307 .filter(|c| c.is_ascii_alphanumeric())
1308 .take(9)
1309 .collect::<String>();
1310
1311 if candidate.len() < 9 {
1312 candidate.extend(
1313 uuid::Uuid::new_v4()
1314 .as_simple()
1315 .to_string()
1316 .chars()
1317 .take(9 - candidate.len()),
1318 );
1319 }
1320
1321 if used_ids.insert(candidate.clone()) {
1322 return candidate;
1323 }
1324
1325 loop {
1326 let generated = uuid::Uuid::new_v4()
1327 .as_simple()
1328 .to_string()
1329 .chars()
1330 .take(9)
1331 .collect::<String>();
1332 if used_ids.insert(generated.clone()) {
1333 return generated;
1334 }
1335 }
1336}
1337
1338fn parse_sse_line(line: &str) -> StreamResult<Option<StreamChunk>> {
1345 let chunk = match parse_sse_chunk(line)? {
1346 Some(c) => c,
1347 None => return Ok(None),
1348 };
1349
1350 if let Some(choice) = chunk.choices.first() {
1351 if let Some(content) = &choice.delta.content
1352 && !content.is_empty()
1353 {
1354 return Ok(Some(StreamChunk::delta(content.clone())));
1355 }
1356 if let Some(reasoning) = &choice.delta.reasoning_content
1357 && !reasoning.is_empty()
1358 {
1359 return Ok(Some(StreamChunk::reasoning(reasoning.clone())));
1360 }
1361 }
1362
1363 Ok(None)
1364}
1365
1366fn sse_bytes_to_chunks(
1368 response: reqwest::Response,
1369 count_tokens: bool,
1370) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1371 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1372
1373 let handle = ::zeroclaw_spawn::spawn!(async move {
1374 let mut buffer = String::new();
1375
1376 match response.error_for_status_ref() {
1377 Ok(_) => {}
1378 Err(e) => {
1379 let _ = tx
1380 .send(Err(StreamError::Http(super::format_error_chain(&e))))
1381 .await;
1382 return;
1383 }
1384 }
1385
1386 let mut bytes_stream = response.bytes_stream();
1387 let mut utf8_buf: Vec<u8> = Vec::new();
1390
1391 while let Some(item) = bytes_stream.next().await {
1392 match item {
1393 Ok(bytes) => {
1394 utf8_buf.extend_from_slice(&bytes);
1395 let text = match std::str::from_utf8(&utf8_buf) {
1396 Ok(s) => {
1397 let owned = s.to_string();
1398 utf8_buf.clear();
1399 owned
1400 }
1401 Err(e) => {
1402 let valid_up_to = e.valid_up_to();
1403 if valid_up_to == 0 && utf8_buf.len() < 4 {
1404 continue;
1406 }
1407 let valid =
1408 String::from_utf8_lossy(&utf8_buf[..valid_up_to]).into_owned();
1409 utf8_buf.drain(..valid_up_to);
1410 valid
1411 }
1412 };
1413 if text.is_empty() {
1414 continue;
1415 }
1416
1417 buffer.push_str(&text);
1418
1419 while let Some(pos) = buffer.find('\n') {
1420 let line = buffer[..pos].to_string();
1421 buffer.drain(..=pos);
1422
1423 match parse_sse_line(&line) {
1424 Ok(Some(chunk)) => {
1425 let chunk = if count_tokens {
1426 chunk.with_token_estimate()
1427 } else {
1428 chunk
1429 };
1430 if tx.send(Ok(chunk)).await.is_err() {
1431 return; }
1433 }
1434 Ok(None) => {}
1435 Err(e) => {
1436 let _ = tx.send(Err(e)).await;
1437 return;
1438 }
1439 }
1440 }
1441 }
1442 Err(e) => {
1443 let _ = tx
1444 .send(Err(StreamError::Http(super::format_error_chain(&e))))
1445 .await;
1446 return;
1447 }
1448 }
1449 }
1450
1451 let _ = tx.send(Ok(StreamChunk::final_chunk())).await;
1452 });
1453
1454 let guard = AbortOnDrop::new(handle.abort_handle());
1455 stream::unfold((rx, guard), |(mut rx, guard)| async {
1456 rx.recv().await.map(|chunk| (chunk, (rx, guard)))
1457 })
1458 .boxed()
1459}
1460
1461pub(crate) fn sse_bytes_to_events(
1463 response: reqwest::Response,
1464 count_tokens: bool,
1465) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
1466 sse_bytes_to_events_for_contract(response, count_tokens, false)
1467}
1468
1469fn sse_bytes_to_events_for_contract(
1470 response: reqwest::Response,
1471 count_tokens: bool,
1472 targets_mistral_tool_call_contract: bool,
1473) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
1474 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
1475
1476 let handle = ::zeroclaw_spawn::spawn!(async move {
1477 let mut buffer = String::new();
1478 let mut tool_calls: Vec<StreamToolCallAccumulator> = Vec::new();
1479 let mut used_tool_call_ids = std::collections::HashSet::new();
1480 let mut emitted_tool_calls = false;
1481
1482 match response.error_for_status_ref() {
1483 Ok(_) => {}
1484 Err(e) => {
1485 let _ = tx
1486 .send(Err(StreamError::Http(super::format_error_chain(&e))))
1487 .await;
1488 return;
1489 }
1490 }
1491
1492 let mut bytes_stream = response.bytes_stream();
1493 let mut utf8_buf: Vec<u8> = Vec::new();
1495 while let Some(item) = bytes_stream.next().await {
1496 match item {
1497 Ok(bytes) => {
1498 utf8_buf.extend_from_slice(&bytes);
1499 let text = match std::str::from_utf8(&utf8_buf) {
1500 Ok(s) => {
1501 let owned = s.to_string();
1502 utf8_buf.clear();
1503 owned
1504 }
1505 Err(e) => {
1506 let valid_up_to = e.valid_up_to();
1507 if valid_up_to == 0 && utf8_buf.len() < 4 {
1508 continue;
1509 }
1510 let valid =
1511 String::from_utf8_lossy(&utf8_buf[..valid_up_to]).into_owned();
1512 utf8_buf.drain(..valid_up_to);
1513 valid
1514 }
1515 };
1516 if text.is_empty() {
1517 continue;
1518 }
1519
1520 buffer.push_str(&text);
1521
1522 while let Some(pos) = buffer.find('\n') {
1523 let line = buffer[..pos].to_string();
1524 buffer.drain(..=pos);
1525
1526 if let Some(event) = parse_proxy_tool_event(&line) {
1529 if tx.send(Ok(event)).await.is_err() {
1530 return;
1531 }
1532 continue;
1533 }
1534
1535 let chunk = match parse_sse_chunk(&line) {
1536 Ok(Some(chunk)) => chunk,
1537 Ok(None) => continue,
1538 Err(e) => {
1539 let _ = tx.send(Err(e)).await;
1540 return;
1541 }
1542 };
1543
1544 let mut should_emit_tool_calls = false;
1545 for choice in &chunk.choices {
1546 if let Some(reasoning_delta) = extract_sse_reasoning_delta(choice) {
1547 let reasoning_chunk = StreamChunk::reasoning(reasoning_delta);
1548 if tx
1549 .send(Ok(StreamEvent::TextDelta(reasoning_chunk)))
1550 .await
1551 .is_err()
1552 {
1553 return;
1554 }
1555 }
1556 if let Some(text_delta) = extract_sse_text_delta(choice) {
1557 let mut text_chunk = StreamChunk::delta(text_delta);
1558 if count_tokens {
1559 text_chunk = text_chunk.with_token_estimate();
1560 }
1561 if tx
1562 .send(Ok(StreamEvent::TextDelta(text_chunk)))
1563 .await
1564 .is_err()
1565 {
1566 return;
1567 }
1568 }
1569
1570 if let Some(deltas) = choice.delta.tool_calls.as_ref() {
1571 for delta in deltas {
1572 let index = delta.index.unwrap_or(tool_calls.len());
1573 if index >= tool_calls.len() {
1574 tool_calls.resize_with(index + 1, Default::default);
1575 }
1576 if let Some(acc) = tool_calls.get_mut(index) {
1577 acc.apply_delta(delta);
1578 }
1579 }
1580 }
1581
1582 if choice.finish_reason.as_deref() == Some("tool_calls") {
1583 should_emit_tool_calls = true;
1584 }
1585 }
1586
1587 if let Some(usage) = chunk.usage.as_ref() {
1588 let token_usage = zeroclaw_api::model_provider::TokenUsage {
1589 input_tokens: usage.prompt_tokens,
1590 output_tokens: usage.completion_tokens,
1591 cached_input_tokens: None,
1592 };
1593 if tx.send(Ok(StreamEvent::Usage(token_usage))).await.is_err() {
1594 return;
1595 }
1596 }
1597
1598 if should_emit_tool_calls && !emitted_tool_calls {
1599 emitted_tool_calls = true;
1600 for tool_call in tool_calls.drain(..).filter_map(|tool_call| {
1601 tool_call.into_provider_tool_call(
1602 targets_mistral_tool_call_contract,
1603 &mut used_tool_call_ids,
1604 )
1605 }) {
1606 if tx.send(Ok(StreamEvent::ToolCall(tool_call))).await.is_err() {
1607 return;
1608 }
1609 }
1610 }
1611 }
1612 }
1613 Err(e) => {
1614 let _ = tx
1615 .send(Err(StreamError::Http(super::format_error_chain(&e))))
1616 .await;
1617 return;
1618 }
1619 }
1620 }
1621
1622 if !emitted_tool_calls {
1623 for tool_call in tool_calls.drain(..).filter_map(|tool_call| {
1624 tool_call.into_provider_tool_call(
1625 targets_mistral_tool_call_contract,
1626 &mut used_tool_call_ids,
1627 )
1628 }) {
1629 if tx.send(Ok(StreamEvent::ToolCall(tool_call))).await.is_err() {
1630 return;
1631 }
1632 }
1633 }
1634
1635 let _ = tx.send(Ok(StreamEvent::Final)).await;
1636 });
1637
1638 let guard = AbortOnDrop::new(handle.abort_handle());
1639 stream::unfold((rx, guard), |(mut rx, guard)| async move {
1640 rx.recv().await.map(|event| (event, (rx, guard)))
1641 })
1642 .boxed()
1643}
1644
1645fn parse_chat_response_body(name: &str, body: &str) -> anyhow::Result<ApiChatResponse> {
1646 serde_json::from_str(body).map_err(|_| {
1647 let sanitized = super::sanitize_api_error(body);
1648 ::zeroclaw_log::record!(
1649 ERROR,
1650 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1651 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1652 .with_attrs(::serde_json::json!({
1653 "model_provider": name,
1654 "body": &sanitized,
1655 })),
1656 "compatible: unexpected chat-completions payload"
1657 );
1658 anyhow::Error::msg(format!(
1659 "{name} API returned an unexpected chat-completions payload; body={sanitized}"
1660 ))
1661 })
1662}
1663
1664impl OpenAiCompatibleModelProvider {
1665 fn apply_auth_header(
1666 &self,
1667 req: reqwest::RequestBuilder,
1668 credential: Option<&str>,
1669 ) -> reqwest::RequestBuilder {
1670 apply_auth_to_request(req, &self.auth_header, credential)
1671 }
1672
1673 fn convert_tool_specs(
1674 tools: Option<&[zeroclaw_api::tool::ToolSpec]>,
1675 ) -> Option<Vec<serde_json::Value>> {
1676 tools.map(|items| {
1677 items
1678 .iter()
1679 .map(|tool| {
1680 let params = zeroclaw_api::schema::SchemaCleanr::clean_for_openai(
1681 tool.parameters.clone(),
1682 );
1683 serde_json::json!({
1684 "type": "function",
1685 "function": {
1686 "name": tool.name,
1687 "description": tool.description,
1688 "parameters": params,
1689 }
1690 })
1691 })
1692 .collect()
1693 })
1694 }
1695
1696 fn convert_tool_specs_for_model(
1702 &self,
1703 tools: Option<&[zeroclaw_api::tool::ToolSpec]>,
1704 model: &str,
1705 ) -> Option<Vec<serde_json::Value>> {
1706 let converted = Self::convert_tool_specs(tools)?;
1707 if !self.local_model_tool_sanitize || !Self::should_sanitize_local_tool_schema(model) {
1708 return Some(converted);
1709 }
1710 Some(
1711 converted
1712 .into_iter()
1713 .map(|mut tool| {
1714 let Some(raw_parameters) = tool.get("parameters").cloned() else {
1715 return tool;
1716 };
1717 let cleaned = zeroclaw_api::schema::SchemaCleanr::clean(
1718 raw_parameters,
1719 zeroclaw_api::schema::CleaningStrategy::Conservative,
1720 );
1721 if let Some(obj) = tool.as_object_mut() {
1722 obj.insert("parameters".to_string(), cleaned);
1723 }
1724 tool
1725 })
1726 .collect(),
1727 )
1728 }
1729
1730 fn should_sanitize_local_tool_schema(model: &str) -> bool {
1731 let lower = model.to_ascii_lowercase();
1732 model.is_empty() || lower.contains("gemma-4") || lower.contains("gemma4")
1733 }
1734
1735 fn build_native_tool_chat_request(
1736 &self,
1737 effective_messages: &[ChatMessage],
1738 tools: Option<Vec<serde_json::Value>>,
1739 model: &str,
1740 temperature: Option<f64>,
1741 allow_user_image_parts: bool,
1742 ) -> NativeChatRequest {
1743 let has_tool_entries = tools.as_ref().is_some_and(|tools| !tools.is_empty());
1744 let tool_choice = tools.as_ref().map(|_| "auto".to_string());
1745
1746 NativeChatRequest {
1747 model: model.to_string(),
1748 messages: self.convert_messages_for_native(effective_messages, allow_user_image_parts),
1749 temperature,
1750 stream: Some(false),
1751 stream_options: None,
1754 reasoning_effort: self.reasoning_effort_for_model(model),
1755 tool_stream: self.tool_stream_for_tools(has_tool_entries),
1756 tools,
1757 tool_choice,
1758 max_tokens: self.max_tokens,
1759 }
1760 }
1761
1762 async fn normalize_messages_for_upstream(
1779 messages: &[ChatMessage],
1780 ) -> anyhow::Result<Vec<ChatMessage>> {
1781 let config = zeroclaw_config::schema::MultimodalConfig::default();
1782 let prepared = multimodal::prepare_messages_for_provider(messages, &config).await?;
1783 Ok(prepared.messages)
1784 }
1785
1786 fn to_message_content(
1787 role: &str,
1788 content: &str,
1789 allow_user_image_parts: bool,
1790 ) -> MessageContent {
1791 if role != "user" || !allow_user_image_parts {
1792 return MessageContent::Text(content.to_string());
1793 }
1794
1795 let (cleaned_text, image_refs) = multimodal::parse_image_markers(content);
1796 if image_refs.is_empty() {
1797 return MessageContent::Text(content.to_string());
1798 }
1799
1800 let mut parts = Vec::with_capacity(image_refs.len() + 1);
1801 let trimmed_text = cleaned_text.trim();
1802 if !trimmed_text.is_empty() {
1803 parts.push(MessagePart::Text {
1804 text: trimmed_text.to_string(),
1805 });
1806 }
1807
1808 for image_ref in image_refs {
1809 parts.push(MessagePart::ImageUrl {
1810 image_url: ImageUrlPart { url: image_ref },
1811 });
1812 }
1813
1814 MessageContent::Parts(parts)
1815 }
1816
1817 fn convert_messages_for_native(
1818 &self,
1819 messages: &[ChatMessage],
1820 allow_user_image_parts: bool,
1821 ) -> Vec<NativeMessage> {
1822 let targets_mistral_tool_call_contract = self.targets_mistral_tool_call_contract();
1823 let mut used_tool_call_ids = std::collections::HashSet::new();
1824 let mut tool_call_id_map = std::collections::HashMap::new();
1825
1826 messages
1827 .iter()
1828 .map(|message| {
1829 if message.role == "assistant"
1830 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
1831 && let Some(tool_calls_value) = value.get("tool_calls")
1832 && let Ok(parsed_calls) =
1833 serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
1834 {
1835 let tool_calls = parsed_calls
1836 .into_iter()
1837 .map(|tc| ToolCall {
1838 id: Some({
1839 let normalized_id = reserve_tool_call_id_for_contract(
1840 targets_mistral_tool_call_contract,
1841 Some(tc.id.clone()),
1842 &mut used_tool_call_ids,
1843 );
1844 tool_call_id_map.insert(tc.id, normalized_id.clone());
1845 normalized_id
1846 }),
1847 kind: Some("function".to_string()),
1848 function: Some(Function {
1849 name: Some(tc.name),
1850 arguments: Some(tc.arguments),
1851 }),
1852 name: None,
1853 arguments: None,
1854 parameters: None,
1855 extra_content: tc.extra_content,
1858 })
1859 .collect::<Vec<_>>();
1860
1861 let content = value
1862 .get("content")
1863 .and_then(serde_json::Value::as_str)
1864 .map(|value| MessageContent::Text(value.to_string()));
1865
1866 let reasoning_content = value
1869 .get("reasoning_content")
1870 .or_else(|| value.get("reasoning"))
1871 .and_then(serde_json::Value::as_str)
1872 .map(ToString::to_string);
1873
1874 return NativeMessage {
1875 role: "assistant".to_string(),
1876 content,
1877 tool_call_id: None,
1878 tool_calls: Some(tool_calls),
1879 reasoning_content,
1880 };
1881 }
1882
1883 if message.role == "assistant"
1891 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
1892 && value.get("tool_calls").is_none()
1893 && let Some(reasoning_content) = value
1894 .get("reasoning_content")
1895 .and_then(serde_json::Value::as_str)
1896 && matches!(
1897 value.get("content"),
1898 None | Some(serde_json::Value::Null | serde_json::Value::String(_))
1899 )
1900 {
1901 let content = value
1902 .get("content")
1903 .and_then(serde_json::Value::as_str)
1904 .map(|value| MessageContent::Text(value.to_string()));
1905
1906 return NativeMessage {
1907 role: "assistant".to_string(),
1908 content,
1909 tool_call_id: None,
1910 tool_calls: None,
1911 reasoning_content: Some(reasoning_content.to_string()),
1912 };
1913 }
1914
1915 if message.role == "tool"
1916 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
1917 {
1918 let tool_call_id = value
1919 .get("tool_call_id")
1920 .and_then(serde_json::Value::as_str)
1921 .map(|raw_id| {
1922 tool_call_id_map.get(raw_id).cloned().unwrap_or_else(|| {
1923 let normalized_id = reserve_tool_call_id_for_contract(
1924 targets_mistral_tool_call_contract,
1925 Some(raw_id.to_string()),
1926 &mut used_tool_call_ids,
1927 );
1928 tool_call_id_map.insert(raw_id.to_string(), normalized_id.clone());
1929 normalized_id
1930 })
1931 });
1932 let content = value
1933 .get("content")
1934 .and_then(serde_json::Value::as_str)
1935 .map(|value| MessageContent::Text(value.to_string()))
1936 .or_else(|| Some(MessageContent::Text(message.content.clone())));
1937
1938 return NativeMessage {
1939 role: "tool".to_string(),
1940 content,
1941 tool_call_id,
1942 tool_calls: None,
1943 reasoning_content: None,
1944 };
1945 }
1946
1947 NativeMessage {
1948 role: message.role.clone(),
1949 content: Some(Self::to_message_content(
1950 &message.role,
1951 &message.content,
1952 allow_user_image_parts,
1953 )),
1954 tool_call_id: None,
1955 tool_calls: None,
1956 reasoning_content: None,
1957 }
1958 })
1959 .collect()
1960 }
1961
1962 fn strip_native_tool_messages(&self, messages: &[ChatMessage]) -> Vec<ChatMessage> {
1976 if self.native_tool_calling {
1977 return messages.to_vec();
1978 }
1979 let intermediate = messages.iter().filter_map(|msg| {
1980 if msg.role == "tool" {
1981 return None;
1982 }
1983 if msg.role == "assistant"
1984 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&msg.content)
1985 && value.get("tool_calls").is_some()
1986 {
1987 let text = value
1988 .get("content")
1989 .and_then(serde_json::Value::as_str)
1990 .unwrap_or("")
1991 .to_string();
1992 return if text.is_empty() {
1993 None
1994 } else {
1995 Some(ChatMessage::assistant(&text))
1996 };
1997 }
1998 Some(msg.clone())
1999 });
2000
2001 let mut coalesced: Vec<ChatMessage> = Vec::with_capacity(messages.len());
2012 for msg in intermediate {
2013 match coalesced.last_mut() {
2014 Some(last) if last.role == "assistant" && msg.role == "assistant" => {
2015 if !last.content.is_empty() && !msg.content.is_empty() {
2016 last.content.push_str("\n\n");
2017 }
2018 last.content.push_str(&msg.content);
2019 }
2020 _ => coalesced.push(msg),
2021 }
2022 }
2023 coalesced
2024 }
2025
2026 fn with_prompt_guided_tool_instructions(
2027 messages: &[ChatMessage],
2028 tools: Option<&[zeroclaw_api::tool::ToolSpec]>,
2029 ) -> Vec<ChatMessage> {
2030 let Some(tools) = tools else {
2031 return messages.to_vec();
2032 };
2033
2034 if tools.is_empty() {
2035 return messages.to_vec();
2036 }
2037
2038 let instructions = zeroclaw_api::model_provider::build_tool_instructions_text(tools);
2039 let mut modified_messages = messages.to_vec();
2040
2041 if let Some(system_message) = modified_messages.iter_mut().find(|m| m.role == "system") {
2042 if !system_message.content.is_empty() {
2043 system_message.content.push_str("\n\n");
2044 }
2045 system_message.content.push_str(&instructions);
2046 } else {
2047 modified_messages.insert(0, ChatMessage::system(instructions));
2048 }
2049
2050 modified_messages
2051 }
2052
2053 fn targets_mistral_tool_call_contract(&self) -> bool {
2054 if self.name.eq_ignore_ascii_case("mistral") {
2055 return true;
2056 }
2057
2058 reqwest::Url::parse(&self.base_url)
2059 .ok()
2060 .and_then(|url| url.host_str().map(|h| h.to_ascii_lowercase()))
2061 .is_some_and(|host| host == "mistral.ai" || host.ends_with(".mistral.ai"))
2062 }
2063
2064 fn reserve_tool_call_id(
2065 &self,
2066 raw_id: Option<String>,
2067 used_ids: &mut std::collections::HashSet<String>,
2068 ) -> String {
2069 reserve_tool_call_id_for_contract(
2070 self.targets_mistral_tool_call_contract(),
2071 raw_id,
2072 used_ids,
2073 )
2074 }
2075
2076 fn parse_native_response(&self, message: ResponseMessage) -> ProviderChatResponse {
2077 let text = message.effective_content_optional();
2078 let reasoning_content = message.reasoning_content.clone();
2079 let mut used_tool_call_ids = std::collections::HashSet::new();
2080 let tool_calls = message
2081 .tool_calls
2082 .unwrap_or_default()
2083 .into_iter()
2084 .filter_map(|tc| {
2085 let name = tc.function_name()?;
2086 let arguments = tc.function_arguments().unwrap_or_else(|| "{}".to_string());
2087 let normalized_arguments = if serde_json::from_str::<serde_json::Value>(&arguments)
2088 .is_ok()
2089 {
2090 arguments
2091 } else {
2092 ::zeroclaw_log::record!(
2093 WARN,
2094 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
2095 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
2096 .with_attrs(
2097 ::serde_json::json!({"function": name, "arguments": arguments})
2098 ),
2099 "Invalid JSON in native tool-call arguments, using empty object"
2100 );
2101 "{}".to_string()
2102 };
2103 Some(ProviderToolCall {
2104 id: self.reserve_tool_call_id(tc.id, &mut used_tool_call_ids),
2105 name,
2106 arguments: normalized_arguments,
2107 extra_content: tc.extra_content,
2108 })
2109 })
2110 .collect::<Vec<_>>();
2111
2112 ProviderChatResponse {
2113 text,
2114 tool_calls,
2115 usage: None,
2116 reasoning_content,
2117 }
2118 }
2119
2120 fn is_native_tool_schema_unsupported(status: reqwest::StatusCode, error: &str) -> bool {
2121 if !matches!(
2122 status,
2123 reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::UNPROCESSABLE_ENTITY
2124 ) {
2125 return false;
2126 }
2127
2128 let lower = error.to_lowercase();
2129 [
2130 "unknown parameter: tools",
2131 "unsupported parameter: tools",
2132 "unrecognized field `tools`",
2133 "does not support tools",
2134 "function calling is not supported",
2135 "tool_choice",
2136 "tool call validation failed",
2137 "was not in request",
2138 ]
2139 .iter()
2140 .any(|hint| lower.contains(hint))
2141 }
2142}
2143
2144#[async_trait]
2145impl ModelProvider for OpenAiCompatibleModelProvider {
2146 fn capabilities(&self) -> zeroclaw_api::model_provider::ProviderCapabilities {
2147 zeroclaw_api::model_provider::ProviderCapabilities {
2148 native_tool_calling: self.native_tool_calling,
2149 vision: self.supports_vision,
2150 prompt_caching: false,
2151 extended_thinking: false,
2152 }
2153 }
2154
2155 async fn list_models(&self) -> anyhow::Result<Vec<String>> {
2156 let list_credential = self.credential.as_deref();
2160 if list_credential.is_some() || self.public_model_listing {
2161 let url = format!("{}/models", self.base_url);
2162 let response = self
2163 .apply_auth_header(self.http_client().get(&url), list_credential)
2164 .send()
2165 .await
2166 .map_err(|e| {
2167 ::zeroclaw_log::record!(
2168 ERROR,
2169 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2170 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2171 .with_attrs(::serde_json::json!({
2172 "model_provider": &self.name,
2173 "url": &url,
2174 "phase": "model_list_request",
2175 "error": super::format_error_chain(&e),
2176 })),
2177 "compatible: model list request failed"
2178 );
2179 anyhow::Error::msg(format!(
2180 "{} model list request failed: {url}: {e}",
2181 self.name
2182 ))
2183 })?;
2184 if !response.status().is_success() {
2185 let status = response.status();
2186 anyhow::bail!("{} model list failed at {url}: HTTP {status}", self.name);
2187 }
2188 let body: ModelsResponse = response.json().await.map_err(|e| {
2189 ::zeroclaw_log::record!(
2190 ERROR,
2191 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2192 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2193 .with_attrs(::serde_json::json!({
2194 "model_provider": &self.name,
2195 "phase": "model_list_parse",
2196 "error": super::format_error_chain(&e),
2197 })),
2198 "compatible: model list returned invalid JSON"
2199 );
2200 anyhow::Error::msg(format!(
2201 "{} model list returned invalid JSON: {e}",
2202 self.name
2203 ))
2204 })?;
2205 return Ok(normalize_model_ids(body));
2206 }
2207 if let Some(key) = &self.models_dev_key {
2210 match crate::models_dev::list_models_for(key).await {
2211 Ok(models) if !models.is_empty() => return Ok(models),
2212 Ok(_) => {} Err(e) => {
2214 if self.openrouter_vendor_prefix.is_none() {
2215 return Err(e);
2216 }
2217 }
2218 }
2219 }
2220 match &self.openrouter_vendor_prefix {
2221 Some(prefix) => crate::openrouter_catalog::list_models_for_vendor(prefix).await,
2222 None => anyhow::bail!("live model listing is not supported for this model_provider"),
2223 }
2224 }
2225
2226 async fn list_models_with_pricing(
2227 &self,
2228 ) -> anyhow::Result<Vec<zeroclaw_api::model_provider::ModelInfo>> {
2229 let list_credential = self.credential.as_deref();
2232 if list_credential.is_some() || self.public_model_listing {
2233 let url = format!("{}/models", self.base_url);
2234 let response = self
2235 .apply_auth_header(self.http_client().get(&url), list_credential)
2236 .send()
2237 .await
2238 .map_err(|e| {
2239 ::zeroclaw_log::record!(
2240 ERROR,
2241 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2242 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2243 .with_attrs(::serde_json::json!({
2244 "model_provider": &self.name,
2245 "url": &url,
2246 "phase": "model_list_request",
2247 "error": super::format_error_chain(&e),
2248 })),
2249 "compatible: model list request failed"
2250 );
2251 anyhow::Error::msg(format!(
2252 "{} model list request failed: {url}: {e}",
2253 self.name
2254 ))
2255 })?;
2256 if !response.status().is_success() {
2257 let status = response.status();
2258 anyhow::bail!("{} model list failed at {url}: HTTP {status}", self.name);
2259 }
2260 let body: ModelsResponse = response.json().await.map_err(|e| {
2261 ::zeroclaw_log::record!(
2262 ERROR,
2263 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2264 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2265 .with_attrs(::serde_json::json!({
2266 "model_provider": &self.name,
2267 "phase": "model_list_parse",
2268 "error": super::format_error_chain(&e),
2269 })),
2270 "compatible: model list returned invalid JSON"
2271 );
2272 anyhow::Error::msg(format!(
2273 "{} model list returned invalid JSON: {e}",
2274 self.name
2275 ))
2276 })?;
2277 return Ok(normalize_models_with_pricing(body));
2278 }
2279 if let Some(key) = &self.models_dev_key {
2282 match crate::models_dev::list_models_for(key).await {
2283 Ok(models) if !models.is_empty() => {
2284 return Ok(models_dev_to_model_info(models));
2285 }
2286 Ok(_) => {} Err(_) if self.openrouter_vendor_prefix.is_none() => {
2288 return Ok(Vec::new());
2289 }
2290 Err(_) => {} }
2292 }
2293 match &self.openrouter_vendor_prefix {
2294 Some(prefix) => {
2295 crate::openrouter_catalog::list_models_for_vendor_with_pricing(prefix).await
2296 }
2297 None => Ok(Vec::new()),
2298 }
2299 }
2300
2301 async fn chat_with_system(
2302 &self,
2303 system_prompt: Option<&str>,
2304 message: &str,
2305 model: &str,
2306 temperature: Option<f64>,
2307 ) -> anyhow::Result<String> {
2308 let credential = self.credential.as_deref();
2309
2310 let user_msg = ChatMessage {
2314 role: "user".to_string(),
2315 content: message.to_string(),
2316 };
2317 let normalized_user =
2318 Self::normalize_messages_for_upstream(std::slice::from_ref(&user_msg))
2319 .await?
2320 .pop()
2321 .unwrap_or(user_msg);
2322 let normalized_message = normalized_user.content;
2323
2324 let merge = self.effective_merge_system(model);
2325 let mut messages = Vec::new();
2326
2327 if merge {
2328 let content = match system_prompt {
2329 Some(sys) => format!("{sys}\n\n{normalized_message}"),
2330 None => normalized_message,
2331 };
2332 messages.push(Message {
2333 role: "user".to_string(),
2334 content: Self::to_message_content("user", &content, !merge),
2335 });
2336 } else {
2337 if let Some(sys) = system_prompt {
2338 messages.push(Message {
2339 role: "system".to_string(),
2340 content: MessageContent::Text(sys.to_string()),
2341 });
2342 }
2343 messages.push(Message {
2344 role: "user".to_string(),
2345 content: Self::to_message_content("user", &normalized_message, true),
2346 });
2347 }
2348
2349 let request = ApiChatRequest {
2350 model: model.to_string(),
2351 messages,
2352 temperature,
2353 stream: Some(false),
2354 stream_options: None,
2355 reasoning_effort: self.reasoning_effort_for_model(model),
2356 tool_stream: None,
2357 tools: None,
2358 tool_choice: None,
2359 max_tokens: self.max_tokens,
2360 };
2361
2362 let url = self.chat_completions_url();
2363
2364 let response = match self
2365 .apply_auth_header(self.http_client().post(&url).json(&request), credential)
2366 .send()
2367 .await
2368 {
2369 Ok(response) => response,
2370 Err(chat_error) => {
2371 return Err(chat_error.into());
2372 }
2373 };
2374
2375 if !response.status().is_success() {
2376 let status = response.status();
2377 let error = response.text().await?;
2378 let sanitized = super::sanitize_api_error(&error);
2379 anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
2380 }
2381
2382 let body = response.text().await?;
2383 let chat_response = parse_chat_response_body(&self.name, &body)?;
2384
2385 chat_response
2386 .choices
2387 .into_iter()
2388 .next()
2389 .map(|c| {
2390 if c.message.tool_calls.is_some()
2391 && c.message
2392 .tool_calls
2393 .as_ref()
2394 .is_some_and(|t: &Vec<_>| !t.is_empty())
2395 {
2396 serde_json::to_string(&c.message)
2397 .unwrap_or_else(|_| c.message.effective_content())
2398 } else {
2399 c.message.effective_content()
2400 }
2401 })
2402 .ok_or_else(|| {
2403 ::zeroclaw_log::record!(
2404 ERROR,
2405 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2406 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2407 .with_attrs(::serde_json::json!({"model_provider": &self.name})),
2408 "compatible: empty choices in response"
2409 );
2410 anyhow::Error::msg(format!("No response from {}", self.name))
2411 })
2412 }
2413
2414 async fn chat_with_history(
2415 &self,
2416 messages: &[ChatMessage],
2417 model: &str,
2418 temperature: Option<f64>,
2419 ) -> anyhow::Result<String> {
2420 let credential = self.credential.as_deref();
2421
2422 let normalized = Self::normalize_messages_for_upstream(messages).await?;
2423 let merge = self.effective_merge_system(model);
2424 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2425 let effective_messages = self.strip_native_tool_messages(&effective_messages);
2427 let api_messages: Vec<Message> = effective_messages
2428 .iter()
2429 .map(|m| Message {
2430 role: m.role.clone(),
2431 content: Self::to_message_content(&m.role, &m.content, !merge),
2432 })
2433 .collect();
2434
2435 let request = ApiChatRequest {
2436 model: model.to_string(),
2437 messages: api_messages,
2438 temperature,
2439 stream: Some(false),
2440 stream_options: None,
2441 reasoning_effort: self.reasoning_effort_for_model(model),
2442 tool_stream: None,
2443 tools: None,
2444 tool_choice: None,
2445 max_tokens: self.max_tokens,
2446 };
2447
2448 let url = self.chat_completions_url();
2449 let response = match self
2450 .apply_auth_header(self.http_client().post(&url).json(&request), credential)
2451 .send()
2452 .await
2453 {
2454 Ok(response) => response,
2455 Err(chat_error) => return Err(chat_error.into()),
2456 };
2457
2458 if !response.status().is_success() {
2459 return Err(super::api_error(&self.name, response).await);
2460 }
2461
2462 let body = response.text().await?;
2463 let chat_response = parse_chat_response_body(&self.name, &body)?;
2464
2465 chat_response
2466 .choices
2467 .into_iter()
2468 .next()
2469 .map(|c| {
2470 if c.message.tool_calls.is_some()
2471 && c.message
2472 .tool_calls
2473 .as_ref()
2474 .is_some_and(|t: &Vec<_>| !t.is_empty())
2475 {
2476 serde_json::to_string(&c.message)
2477 .unwrap_or_else(|_| c.message.effective_content())
2478 } else {
2479 c.message.effective_content()
2480 }
2481 })
2482 .ok_or_else(|| {
2483 ::zeroclaw_log::record!(
2484 ERROR,
2485 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2486 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2487 .with_attrs(::serde_json::json!({"model_provider": &self.name})),
2488 "compatible: empty choices in response"
2489 );
2490 anyhow::Error::msg(format!("No response from {}", self.name))
2491 })
2492 }
2493
2494 async fn chat_with_tools(
2495 &self,
2496 messages: &[ChatMessage],
2497 tools: &[serde_json::Value],
2498 model: &str,
2499 temperature: Option<f64>,
2500 ) -> anyhow::Result<ProviderChatResponse> {
2501 let credential = self.credential.as_deref();
2502
2503 let normalized = Self::normalize_messages_for_upstream(messages).await?;
2504 let merge = self.effective_merge_system(model);
2505 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2506 let effective_messages = if self.native_tool_calling {
2507 effective_messages
2508 } else {
2509 self.strip_native_tool_messages(&effective_messages)
2510 };
2511 let tools = if tools.is_empty() {
2512 None
2513 } else {
2514 Some(tools.to_vec())
2515 };
2516 let request = self.build_native_tool_chat_request(
2517 &effective_messages,
2518 tools,
2519 model,
2520 temperature,
2521 !merge,
2522 );
2523
2524 let url = self.chat_completions_url();
2525 let response = match self
2526 .apply_auth_header(self.http_client().post(&url).json(&request), credential)
2527 .send()
2528 .await
2529 {
2530 Ok(response) => response,
2531 Err(error) => {
2532 ::zeroclaw_log::record!(
2533 WARN,
2534 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
2535 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
2536 &format!(
2537 "{} native tool call transport failed: {error}; falling back to history path",
2538 self.name
2539 )
2540 );
2541 let text = self.chat_with_history(messages, model, temperature).await?;
2542 return Ok(ProviderChatResponse {
2543 text: Some(text),
2544 tool_calls: vec![],
2545 usage: None,
2546 reasoning_content: None,
2547 });
2548 }
2549 };
2550
2551 if !response.status().is_success() {
2552 return Err(super::api_error(&self.name, response).await);
2553 }
2554
2555 let body = response.text().await?;
2556 let chat_response = parse_chat_response_body(&self.name, &body)?;
2557 let usage = chat_response.usage.map(|u| TokenUsage {
2558 input_tokens: u.prompt_tokens,
2559 output_tokens: u.completion_tokens,
2560 cached_input_tokens: None,
2561 });
2562 let choice = chat_response.choices.into_iter().next().ok_or_else(|| {
2563 ::zeroclaw_log::record!(
2564 ERROR,
2565 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2566 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2567 .with_attrs(::serde_json::json!({"model_provider": &self.name})),
2568 "compatible: empty choices in response"
2569 );
2570 anyhow::Error::msg(format!("No response from {}", self.name))
2571 })?;
2572
2573 let text = choice.message.effective_content_optional();
2574 let reasoning_content = choice.message.reasoning_content;
2575 let mut used_tool_call_ids = std::collections::HashSet::new();
2576 let tool_calls = choice
2577 .message
2578 .tool_calls
2579 .unwrap_or_default()
2580 .into_iter()
2581 .filter_map(|tc| {
2582 let function = tc.function?;
2583 let name = function.name?;
2584 let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
2585 Some(ProviderToolCall {
2586 id: self.reserve_tool_call_id(tc.id, &mut used_tool_call_ids),
2587 name,
2588 arguments,
2589 extra_content: tc.extra_content,
2590 })
2591 })
2592 .collect::<Vec<_>>();
2593
2594 Ok(ProviderChatResponse {
2595 text,
2596 tool_calls,
2597 usage,
2598 reasoning_content,
2599 })
2600 }
2601
2602 async fn chat(
2603 &self,
2604 request: ProviderChatRequest<'_>,
2605 model: &str,
2606 temperature: Option<f64>,
2607 ) -> anyhow::Result<ProviderChatResponse> {
2608 let credential = self.credential.as_deref();
2609
2610 let normalized = Self::normalize_messages_for_upstream(request.messages).await?;
2611 let merge = self.effective_merge_system(model);
2612 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2613 let effective_messages = if self.native_tool_calling {
2614 effective_messages
2615 } else {
2616 self.strip_native_tool_messages(&effective_messages)
2617 };
2618
2619 let tools = self.convert_tool_specs_for_model(request.tools, model);
2622 let native_request = self.build_native_tool_chat_request(
2623 &effective_messages,
2624 tools,
2625 model,
2626 temperature,
2627 !merge,
2628 );
2629
2630 let url = self.chat_completions_url();
2631 let response = match self
2632 .apply_auth_header(
2633 self.http_client().post(&url).json(&native_request),
2634 credential,
2635 )
2636 .send()
2637 .await
2638 {
2639 Ok(response) => response,
2640 Err(chat_error) => return Err(chat_error.into()),
2641 };
2642
2643 if !response.status().is_success() {
2644 let status = response.status();
2645 let error = response.text().await?;
2646 let sanitized = super::sanitize_api_error(&error);
2647
2648 if Self::is_native_tool_schema_unsupported(status, &sanitized) {
2649 let fallback_messages =
2650 Self::with_prompt_guided_tool_instructions(request.messages, request.tools);
2651 let text = self
2652 .chat_with_history(&fallback_messages, model, temperature)
2653 .await?;
2654 return Ok(ProviderChatResponse {
2655 text: Some(text),
2656 tool_calls: vec![],
2657 usage: None,
2658 reasoning_content: None,
2659 });
2660 }
2661
2662 anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
2663 }
2664
2665 let native_response: ApiChatResponse = response.json().await?;
2666 let usage = native_response.usage.map(|u| TokenUsage {
2667 input_tokens: u.prompt_tokens,
2668 output_tokens: u.completion_tokens,
2669 cached_input_tokens: None,
2670 });
2671 let message = native_response
2672 .choices
2673 .into_iter()
2674 .next()
2675 .map(|choice| choice.message)
2676 .ok_or_else(|| {
2677 ::zeroclaw_log::record!(
2678 ERROR,
2679 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2680 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2681 .with_attrs(::serde_json::json!({"model_provider": &self.name})),
2682 "compatible: empty choices in response"
2683 );
2684 anyhow::Error::msg(format!("No response from {}", self.name))
2685 })?;
2686
2687 let mut result = self.parse_native_response(message);
2688 result.usage = usage;
2689 Ok(result)
2690 }
2691
2692 fn supports_native_tools(&self) -> bool {
2693 self.native_tool_calling
2694 }
2695
2696 fn supports_streaming(&self) -> bool {
2697 true
2698 }
2699
2700 fn supports_streaming_tool_events(&self) -> bool {
2701 self.native_tool_calling
2703 }
2704
2705 fn stream_chat(
2706 &self,
2707 request: ProviderChatRequest<'_>,
2708 model: &str,
2709 temperature: Option<f64>,
2710 options: StreamOptions,
2711 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
2712 if !options.enabled {
2713 return stream::once(async { Ok(StreamEvent::Final) }).boxed();
2714 }
2715
2716 let provider = self.clone();
2717 let messages_owned: Vec<ChatMessage> = request.messages.to_vec();
2718 let tools_owned: Option<Vec<zeroclaw_api::tool::ToolSpec>> =
2719 request.tools.map(<[zeroclaw_api::tool::ToolSpec]>::to_vec);
2720 let model = model.to_string();
2721 let count_tokens = options.count_tokens;
2722 let options_enabled = options.enabled;
2723
2724 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
2725
2726 let handle = ::zeroclaw_spawn::spawn!(async move {
2727 let normalized = match Self::normalize_messages_for_upstream(&messages_owned).await {
2728 Ok(n) => n,
2729 Err(err) => {
2730 let _ = tx
2731 .send(Err(StreamError::ModelProvider(err.to_string())))
2732 .await;
2733 return;
2734 }
2735 };
2736
2737 let merge = provider.effective_merge_system(&model);
2738 let has_tools = tools_owned.as_ref().is_some_and(|tools| !tools.is_empty());
2739 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2740 let effective_messages = provider.strip_native_tool_messages(&effective_messages);
2741 let tools = provider.convert_tool_specs_for_model(tools_owned.as_deref(), &model);
2742
2743 let payload_result = if has_tools {
2744 serde_json::to_value(NativeChatRequest {
2745 model: model.clone(),
2746 messages: provider.convert_messages_for_native(&effective_messages, !merge),
2747 temperature,
2748 reasoning_effort: provider.reasoning_effort_for_model(&model),
2749 tool_stream: if options_enabled {
2750 provider.tool_stream_for_tools(true)
2751 } else {
2752 None
2753 },
2754 stream: Some(options_enabled),
2755 stream_options: options_enabled.then_some(StreamOptionsBody {
2759 include_usage: true,
2760 }),
2761 tools: tools.clone(),
2762 tool_choice: tools.as_ref().map(|_| "auto".to_string()),
2763 max_tokens: provider.max_tokens,
2764 })
2765 } else {
2766 let messages = effective_messages
2767 .iter()
2768 .map(|message| Message {
2769 role: message.role.clone(),
2770 content: Self::to_message_content(&message.role, &message.content, !merge),
2771 })
2772 .collect();
2773
2774 serde_json::to_value(ApiChatRequest {
2775 model: model.clone(),
2776 messages,
2777 temperature,
2778 reasoning_effort: provider.reasoning_effort_for_model(&model),
2779 tool_stream: if options_enabled {
2780 provider.tool_stream_for_tools(false)
2781 } else {
2782 None
2783 },
2784 stream: Some(options_enabled),
2785 stream_options: options_enabled.then_some(StreamOptionsBody {
2786 include_usage: true,
2787 }),
2788 tools: None,
2789 tool_choice: None,
2790 max_tokens: provider.max_tokens,
2791 })
2792 };
2793
2794 let payload = match payload_result {
2795 Ok(payload) => payload,
2796 Err(error) => {
2797 let _ = tx.send(Err(StreamError::Json(error))).await;
2798 return;
2799 }
2800 };
2801
2802 let url = provider.chat_completions_url();
2803 let client = provider.streaming_http_client();
2804 let auth_header = provider.auth_header.clone();
2805 let credential = provider.credential.clone();
2806 let targets_mistral_tool_call_contract = provider.targets_mistral_tool_call_contract();
2807
2808 let mut req_builder = client.post(&url).json(&payload);
2809 req_builder = apply_auth_to_request(req_builder, &auth_header, credential.as_deref());
2810 req_builder = req_builder.header("Accept", "text/event-stream");
2811
2812 let response = match req_builder.send().await {
2813 Ok(r) => r,
2814 Err(e) => {
2815 let _ = tx
2816 .send(Err(StreamError::Http(super::format_error_chain(&e))))
2817 .await;
2818 return;
2819 }
2820 };
2821
2822 if !response.status().is_success() {
2823 let status = response.status();
2824 let error = match response.text().await {
2825 Ok(text) => text,
2826 Err(_) => format!("HTTP error: {}", status),
2827 };
2828 let _ = tx
2829 .send(Err(StreamError::ModelProvider(format!(
2830 "{}: {}",
2831 status, error
2832 ))))
2833 .await;
2834 return;
2835 }
2836
2837 let mut event_stream = sse_bytes_to_events_for_contract(
2838 response,
2839 count_tokens,
2840 targets_mistral_tool_call_contract,
2841 );
2842 while let Some(event) = event_stream.next().await {
2843 if tx.send(event).await.is_err() {
2844 break;
2845 }
2846 }
2847 });
2848
2849 let guard = AbortOnDrop::new(handle.abort_handle());
2850 stream::unfold((rx, guard), |(mut rx, guard)| async move {
2851 rx.recv().await.map(|event| (event, (rx, guard)))
2852 })
2853 .boxed()
2854 }
2855
2856 fn stream_chat_with_system(
2857 &self,
2858 system_prompt: Option<&str>,
2859 message: &str,
2860 model: &str,
2861 temperature: Option<f64>,
2862 options: StreamOptions,
2863 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
2864 let provider = self.clone();
2865 let system_prompt_owned: Option<String> = system_prompt.map(str::to_string);
2866 let message_owned = message.to_string();
2867 let model = model.to_string();
2868 let count_tokens = options.count_tokens;
2869 let options_enabled = options.enabled;
2870
2871 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
2873
2874 let handle = ::zeroclaw_spawn::spawn!(async move {
2875 let user_msg = ChatMessage {
2879 role: "user".to_string(),
2880 content: message_owned,
2881 };
2882 let normalized_user = match Self::normalize_messages_for_upstream(std::slice::from_ref(
2883 &user_msg,
2884 ))
2885 .await
2886 {
2887 Ok(mut msgs) => msgs.pop().unwrap_or(user_msg),
2888 Err(err) => {
2889 let _ = tx
2890 .send(Err(StreamError::ModelProvider(err.to_string())))
2891 .await;
2892 return;
2893 }
2894 };
2895 let normalized_message_content = normalized_user.content;
2896
2897 let merge = provider.effective_merge_system(&model);
2898 let mut messages = Vec::new();
2899 if merge {
2900 let content = match system_prompt_owned.as_deref() {
2901 Some(sys) => format!("{sys}\n\n{normalized_message_content}"),
2902 None => normalized_message_content,
2903 };
2904 messages.push(Message {
2905 role: "user".to_string(),
2906 content: Self::to_message_content("user", &content, !merge),
2907 });
2908 } else {
2909 if let Some(sys) = system_prompt_owned {
2910 messages.push(Message {
2911 role: "system".to_string(),
2912 content: MessageContent::Text(sys),
2913 });
2914 }
2915 messages.push(Message {
2916 role: "user".to_string(),
2917 content: Self::to_message_content("user", &normalized_message_content, !merge),
2918 });
2919 }
2920
2921 let request = ApiChatRequest {
2922 model: model.clone(),
2923 messages,
2924 temperature,
2925 stream: Some(options_enabled),
2926 stream_options: options_enabled.then_some(StreamOptionsBody {
2927 include_usage: true,
2928 }),
2929 reasoning_effort: provider.reasoning_effort_for_model(&model),
2930 tool_stream: None,
2931 tools: None,
2932 tool_choice: None,
2933 max_tokens: provider.max_tokens,
2934 };
2935
2936 let url = provider.chat_completions_url();
2937 let client = provider.streaming_http_client();
2938 let auth_header = provider.auth_header.clone();
2939 let credential = provider.credential.clone();
2940
2941 let mut req_builder = client.post(&url).json(&request);
2943
2944 req_builder = apply_auth_to_request(req_builder, &auth_header, credential.as_deref());
2946
2947 req_builder = req_builder.header("Accept", "text/event-stream");
2949
2950 let response = match req_builder.send().await {
2952 Ok(r) => r,
2953 Err(e) => {
2954 let _ = tx
2955 .send(Err(StreamError::Http(super::format_error_chain(&e))))
2956 .await;
2957 return;
2958 }
2959 };
2960
2961 if !response.status().is_success() {
2963 let status = response.status();
2964 let error = match response.text().await {
2965 Ok(e) => e,
2966 Err(_) => format!("HTTP error: {}", status),
2967 };
2968 let _ = tx
2969 .send(Err(StreamError::ModelProvider(format!(
2970 "{}: {}",
2971 status, error
2972 ))))
2973 .await;
2974 return;
2975 }
2976
2977 let mut chunk_stream = sse_bytes_to_chunks(response, count_tokens);
2979 while let Some(chunk) = chunk_stream.next().await {
2980 if tx.send(chunk).await.is_err() {
2981 break; }
2983 }
2984 });
2985
2986 let guard = AbortOnDrop::new(handle.abort_handle());
2988 stream::unfold((rx, guard), |(mut rx, guard)| async move {
2989 rx.recv().await.map(|chunk| (chunk, (rx, guard)))
2990 })
2991 .boxed()
2992 }
2993
2994 fn stream_chat_with_history(
2995 &self,
2996 messages: &[ChatMessage],
2997 model: &str,
2998 temperature: Option<f64>,
2999 options: StreamOptions,
3000 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
3001 let provider = self.clone();
3002 let messages_owned: Vec<ChatMessage> = messages.to_vec();
3003 let model = model.to_string();
3004 let count_tokens = options.count_tokens;
3005 let options_enabled = options.enabled;
3006
3007 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
3008
3009 let handle = ::zeroclaw_spawn::spawn!(async move {
3010 let normalized = match Self::normalize_messages_for_upstream(&messages_owned).await {
3011 Ok(n) => n,
3012 Err(err) => {
3013 let _ = tx
3014 .send(Err(StreamError::ModelProvider(err.to_string())))
3015 .await;
3016 return;
3017 }
3018 };
3019
3020 let merge = provider.effective_merge_system(&model);
3021 let effective_messages = Self::flatten_system_messages(&normalized, merge);
3022 let effective_messages = provider.strip_native_tool_messages(&effective_messages);
3023 let api_messages: Vec<Message> = effective_messages
3024 .iter()
3025 .map(|m| Message {
3026 role: m.role.clone(),
3027 content: Self::to_message_content(&m.role, &m.content, !merge),
3028 })
3029 .collect();
3030
3031 let request = ApiChatRequest {
3032 model: model.clone(),
3033 messages: api_messages,
3034 temperature,
3035 stream: Some(options_enabled),
3036 stream_options: options_enabled.then_some(StreamOptionsBody {
3037 include_usage: true,
3038 }),
3039 reasoning_effort: provider.reasoning_effort_for_model(&model),
3040 tool_stream: None,
3041 tools: None,
3042 tool_choice: None,
3043 max_tokens: provider.max_tokens,
3044 };
3045
3046 let url = provider.chat_completions_url();
3047 let client = provider.streaming_http_client();
3048 let auth_header = provider.auth_header.clone();
3049 let credential = provider.credential.clone();
3050
3051 let mut req_builder = client.post(&url).json(&request);
3052 req_builder = apply_auth_to_request(req_builder, &auth_header, credential.as_deref());
3053 req_builder = req_builder.header("Accept", "text/event-stream");
3054
3055 let response = match req_builder.send().await {
3056 Ok(r) => r,
3057 Err(e) => {
3058 let _ = tx
3059 .send(Err(StreamError::Http(super::format_error_chain(&e))))
3060 .await;
3061 return;
3062 }
3063 };
3064
3065 if !response.status().is_success() {
3066 let status = response.status();
3067 let error = match response.text().await {
3068 Ok(e) => e,
3069 Err(_) => format!("HTTP error: {}", status),
3070 };
3071 let _ = tx
3072 .send(Err(StreamError::ModelProvider(format!(
3073 "{}: {}",
3074 status, error
3075 ))))
3076 .await;
3077 return;
3078 }
3079
3080 let mut chunk_stream = sse_bytes_to_chunks(response, count_tokens);
3081 while let Some(chunk) = chunk_stream.next().await {
3082 if tx.send(chunk).await.is_err() {
3083 break;
3084 }
3085 }
3086 });
3087
3088 let guard = AbortOnDrop::new(handle.abort_handle());
3089 stream::unfold((rx, guard), |(mut rx, guard)| async move {
3090 rx.recv().await.map(|chunk| (chunk, (rx, guard)))
3091 })
3092 .boxed()
3093 }
3094
3095 async fn warmup(&self) -> anyhow::Result<()> {
3096 let url = self.chat_completions_url();
3099 let _ = self
3100 .apply_auth_header(self.http_client().get(&url), self.credential.as_deref())
3101 .send()
3102 .await?;
3103 Ok(())
3104 }
3105}
3106
3107impl ::zeroclaw_api::attribution::Attributable for OpenAiCompatibleModelProvider {
3108 fn role(&self) -> ::zeroclaw_api::attribution::Role {
3109 ::zeroclaw_api::attribution::Role::Provider(
3110 ::zeroclaw_api::attribution::ProviderKind::Model(
3111 ::zeroclaw_api::attribution::ModelProviderKind::Plugin,
3112 ),
3113 )
3114 }
3115 fn alias(&self) -> &str {
3116 &self.alias
3117 }
3118}
3119
3120#[cfg(test)]
3121mod tests {
3122 use super::*;
3123
3124 fn make_model_provider(
3125 name: &str,
3126 url: &str,
3127 key: Option<&str>,
3128 ) -> OpenAiCompatibleModelProvider {
3129 OpenAiCompatibleModelProvider::new("test", name, url, key, AuthStyle::Bearer)
3130 }
3131
3132 #[test]
3133 fn creates_with_key() {
3134 let p = make_model_provider(
3135 "venice",
3136 "https://api.venice.ai",
3137 Some("venice-test-credential"),
3138 );
3139 assert_eq!(p.name, "venice");
3140 assert_eq!(p.base_url, "https://api.venice.ai");
3141 assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
3142 }
3143
3144 #[test]
3145 fn creates_without_key() {
3146 let p = make_model_provider("test", "https://example.com", None);
3147 assert!(p.credential.is_none());
3148 }
3149
3150 #[test]
3151 fn strips_trailing_slash() {
3152 let p = make_model_provider("test", "https://example.com/", None);
3153 assert_eq!(p.base_url, "https://example.com");
3154 }
3155
3156 #[tokio::test]
3157 async fn chat_without_key_attempts_request() {
3158 let p = make_model_provider("Local", "http://127.0.0.1:1", None);
3159 let result = p
3160 .chat_with_system(None, "hello", "default", Some(0.7))
3161 .await;
3162 assert!(result.is_err());
3163 let err_msg = result.unwrap_err().to_string();
3164 assert!(
3165 !err_msg.contains("API key not set"),
3166 "should not get credential error, got: {err_msg}"
3167 );
3168 }
3169
3170 #[test]
3171 fn native_chat_request_with_tools_includes_stream_options() {
3172 let req = NativeChatRequest {
3177 model: "gpt-4o".to_string(),
3178 messages: vec![NativeMessage {
3179 role: "user".to_string(),
3180 content: Some(MessageContent::Text("hello".to_string())),
3181 tool_call_id: None,
3182 tool_calls: None,
3183 reasoning_content: None,
3184 }],
3185 temperature: Some(0.7),
3186 stream: Some(true),
3187 stream_options: Some(StreamOptionsBody {
3188 include_usage: true,
3189 }),
3190 reasoning_effort: None,
3191 tool_stream: None,
3192 tools: Some(vec![serde_json::json!({"name": "echo"})]),
3193 tool_choice: Some("auto".to_string()),
3194 max_tokens: None,
3195 };
3196 let value: serde_json::Value = serde_json::to_value(&req).unwrap();
3197 assert_eq!(
3198 value
3199 .get("stream_options")
3200 .and_then(|v| v.get("include_usage"))
3201 .and_then(serde_json::Value::as_bool),
3202 Some(true),
3203 "tool-enabled streaming request must serialize stream_options.include_usage=true; \
3204 without it OpenAI-compatible providers omit the final usage event"
3205 );
3206 }
3207
3208 #[test]
3209 fn native_chat_request_omits_stream_options_when_none() {
3210 let req = NativeChatRequest {
3214 model: "gpt-4o".to_string(),
3215 messages: vec![],
3216 temperature: Some(0.7),
3217 stream: Some(false),
3218 stream_options: None,
3219 reasoning_effort: None,
3220 tool_stream: None,
3221 tools: None,
3222 tool_choice: None,
3223 max_tokens: None,
3224 };
3225 let value: serde_json::Value = serde_json::to_value(&req).unwrap();
3226 assert!(
3227 value.get("stream_options").is_none(),
3228 "non-streaming NativeChatRequest must not emit a stream_options key"
3229 );
3230 }
3231
3232 #[test]
3233 fn normalize_model_ids_trims_filters_and_sorts() {
3234 let body = serde_json::from_value(serde_json::json!({
3235 "data": [
3236 {"id": " zeta-model "},
3237 {"id": ""},
3238 {"id": "alpha-model"}
3239 ]
3240 }))
3241 .unwrap();
3242
3243 assert_eq!(normalize_model_ids(body), vec!["alpha-model", "zeta-model"]);
3244 }
3245
3246 #[test]
3247 fn request_serializes_correctly() {
3248 let req = ApiChatRequest {
3249 model: "llama-3.3-70b".to_string(),
3250 messages: vec![
3251 Message {
3252 role: "system".to_string(),
3253 content: MessageContent::Text("You are ZeroClaw".to_string()),
3254 },
3255 Message {
3256 role: "user".to_string(),
3257 content: MessageContent::Text("hello".to_string()),
3258 },
3259 ],
3260 temperature: Some(0.4),
3261 stream: Some(false),
3262 stream_options: None,
3263 reasoning_effort: None,
3264 tool_stream: None,
3265 tools: None,
3266 tool_choice: None,
3267 max_tokens: None,
3268 };
3269 let json = serde_json::to_string(&req).unwrap();
3270 assert!(json.contains("llama-3.3-70b"));
3271 assert!(json.contains("system"));
3272 assert!(json.contains("user"));
3273 assert!(!json.contains("tools"));
3275 assert!(!json.contains("tool_choice"));
3276 }
3277
3278 #[test]
3279 fn response_deserializes() {
3280 let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
3281 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
3282 assert_eq!(
3283 resp.choices[0].message.content,
3284 Some("Hello from Venice!".to_string())
3285 );
3286 }
3287
3288 #[test]
3289 fn response_deserializes_content_as_openai_text_parts_array() {
3290 let json =
3291 r#"{"choices":[{"message":{"content":[{"type":"text","text":"Hello array"}]}}]}"#;
3292 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
3293 assert_eq!(
3294 resp.choices[0].message.content.as_deref(),
3295 Some("Hello array")
3296 );
3297 }
3298
3299 #[test]
3300 fn response_deserializes_multiple_text_parts_with_newlines() {
3301 let json = r#"{"choices":[{"message":{"content":[{"type":"text","text":"Hello"},{"type":"image_url","image_url":{"url":"https://example.com/image.png"}},{"type":"text","text":"array"}]}}]}"#;
3302 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
3303 assert_eq!(
3304 resp.choices[0].message.content.as_deref(),
3305 Some("Hello\narray")
3306 );
3307 }
3308
3309 #[test]
3310 fn response_rejects_unsupported_top_level_content_shape() {
3311 let json = r#"{"choices":[{"message":{"content":{"type":"text","text":"Hello object"}}}]}"#;
3312 serde_json::from_str::<ApiChatResponse>(json)
3313 .expect_err("object-shaped assistant content must remain an invalid payload");
3314 }
3315
3316 #[test]
3317 fn response_empty_choices() {
3318 let json = r#"{"choices":[]}"#;
3319 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
3320 assert!(resp.choices.is_empty());
3321 }
3322
3323 #[test]
3324 fn parse_chat_response_body_reports_sanitized_snippet() {
3325 let body = r#"{"choices":"invalid","api_key":"sk-test-secret-value"}"#;
3326 let err = parse_chat_response_body("custom", body).expect_err("payload should fail");
3327 let msg = err.to_string();
3328
3329 assert!(msg.contains("custom API returned an unexpected chat-completions payload"));
3330 assert!(msg.contains("body="));
3331 assert!(msg.contains("[REDACTED]"));
3332 assert!(!msg.contains("sk-test-secret-value"));
3333 }
3334
3335 #[test]
3336 fn x_api_key_auth_style() {
3337 let p = OpenAiCompatibleModelProvider::new(
3338 "test",
3339 "moonshot",
3340 "https://api.moonshot.cn",
3341 Some("ms-key"),
3342 AuthStyle::XApiKey,
3343 );
3344 assert!(matches!(p.auth_header, AuthStyle::XApiKey));
3345 }
3346
3347 #[test]
3348 fn custom_auth_style() {
3349 let p = OpenAiCompatibleModelProvider::new(
3350 "test",
3351 "custom",
3352 "https://api.example.com",
3353 Some("key"),
3354 AuthStyle::Custom("X-Custom-Key".into()),
3355 );
3356 assert!(matches!(p.auth_header, AuthStyle::Custom(_)));
3357 }
3358
3359 #[test]
3360 fn zhipu_jwt_produces_valid_three_part_token() {
3361 let result = zhipu_jwt_bearer("testid.testsecret").unwrap();
3362 assert!(result.starts_with("Bearer "));
3363 let jwt = result.strip_prefix("Bearer ").unwrap();
3364 let parts: Vec<&str> = jwt.split('.').collect();
3365 assert_eq!(parts.len(), 3, "JWT must have 3 dot-separated parts: {jwt}");
3366 }
3367
3368 #[test]
3369 fn zhipu_jwt_header_is_correct() {
3370 use base64::engine::{Engine, general_purpose::URL_SAFE_NO_PAD};
3371 let result = zhipu_jwt_bearer("myid.mysecret").unwrap();
3372 let jwt = result.strip_prefix("Bearer ").unwrap();
3373 let header_b64 = jwt.split('.').next().unwrap();
3374 let header_bytes = URL_SAFE_NO_PAD.decode(header_b64).unwrap();
3375 let header: serde_json::Value = serde_json::from_slice(&header_bytes).unwrap();
3376 assert_eq!(header["alg"], "HS256");
3377 assert_eq!(header["typ"], "JWT");
3378 assert_eq!(header["sign_type"], "SIGN");
3379 }
3380
3381 #[test]
3382 fn zhipu_jwt_payload_contains_api_key_and_timestamps() {
3383 use base64::engine::{Engine, general_purpose::URL_SAFE_NO_PAD};
3384 let result = zhipu_jwt_bearer("myapiid.mysecretkey").unwrap();
3385 let jwt = result.strip_prefix("Bearer ").unwrap();
3386 let payload_b64 = jwt.split('.').nth(1).unwrap();
3387 let payload_bytes = URL_SAFE_NO_PAD.decode(payload_b64).unwrap();
3388 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).unwrap();
3389 assert_eq!(payload["api_key"], "myapiid");
3390 assert!(payload["exp"].is_number());
3391 assert!(payload["timestamp"].is_number());
3392 let ts = payload["timestamp"].as_u64().unwrap();
3394 let exp = payload["exp"].as_u64().unwrap();
3395 assert_eq!(exp - ts, 210_000);
3396 }
3397
3398 #[test]
3399 fn zhipu_jwt_signature_is_verifiable() {
3400 let secret = "testsecret123";
3401 let credential = format!("testid.{secret}");
3402 let result = zhipu_jwt_bearer(&credential).unwrap();
3403 let jwt = result.strip_prefix("Bearer ").unwrap();
3404 let parts: Vec<&str> = jwt.split('.').collect();
3405 let signing_input = format!("{}.{}", parts[0], parts[1]);
3406
3407 let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, secret.as_bytes());
3409 use base64::engine::{Engine, general_purpose::URL_SAFE_NO_PAD};
3410 let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
3411 ring::hmac::verify(&key, signing_input.as_bytes(), &sig_bytes)
3412 .expect("signature must verify");
3413 }
3414
3415 #[test]
3416 fn zhipu_jwt_rejects_invalid_key_format() {
3417 assert!(zhipu_jwt_bearer("no-dot-here").is_err());
3418 assert!(zhipu_jwt_bearer("").is_err());
3419 }
3420
3421 #[test]
3422 fn zhipu_jwt_auth_style_applies_correctly() {
3423 let p = OpenAiCompatibleModelProvider::new(
3424 "test",
3425 "Z.AI",
3426 "https://api.z.ai/api/coding/paas/v4",
3427 Some("testid.testsecret"),
3428 AuthStyle::ZhipuJwt,
3429 );
3430 assert!(matches!(p.auth_header, AuthStyle::ZhipuJwt));
3431 }
3432
3433 #[tokio::test]
3434 async fn all_compatible_providers_attempt_request_without_key() {
3435 let model_providers = vec![
3436 make_model_provider("Venice", "http://127.0.0.1:1", None),
3437 make_model_provider("Moonshot", "http://127.0.0.1:1", None),
3438 make_model_provider("GLM", "http://127.0.0.1:1", None),
3439 make_model_provider("MiniMax", "http://127.0.0.1:1", None),
3440 make_model_provider("Groq", "http://127.0.0.1:1", None),
3441 make_model_provider("Mistral", "http://127.0.0.1:1", None),
3442 make_model_provider("xAI", "http://127.0.0.1:1", None),
3443 make_model_provider("Astrai", "http://127.0.0.1:1", None),
3444 ];
3445
3446 for p in model_providers {
3447 let result = p.chat_with_system(None, "test", "model", Some(0.7)).await;
3448 assert!(result.is_err(), "{} should fail (unreachable host)", p.name);
3449 let err_msg = result.unwrap_err().to_string();
3450 assert!(
3451 !err_msg.contains("API key not set"),
3452 "{} should get transport error, not credential error, got: {err_msg}",
3453 p.name
3454 );
3455 }
3456 }
3457
3458 #[test]
3459 fn tool_call_function_name_falls_back_to_top_level_name() {
3460 let call: ToolCall = serde_json::from_value(serde_json::json!({
3461 "name": "memory_recall",
3462 "arguments": "{\"query\":\"latest roadmap\"}"
3463 }))
3464 .unwrap();
3465
3466 assert_eq!(call.function_name().as_deref(), Some("memory_recall"));
3467 }
3468
3469 #[test]
3470 fn tool_call_function_arguments_falls_back_to_parameters_object() {
3471 let call: ToolCall = serde_json::from_value(serde_json::json!({
3472 "name": "shell",
3473 "parameters": {"command": "pwd"}
3474 }))
3475 .unwrap();
3476
3477 assert_eq!(
3478 call.function_arguments().as_deref(),
3479 Some("{\"command\":\"pwd\"}")
3480 );
3481 }
3482
3483 #[test]
3484 fn tool_call_function_arguments_prefers_nested_function_field() {
3485 let call: ToolCall = serde_json::from_value(serde_json::json!({
3486 "name": "ignored_name",
3487 "arguments": "{\"query\":\"ignored\"}",
3488 "function": {
3489 "name": "memory_recall",
3490 "arguments": "{\"query\":\"preferred\"}"
3491 }
3492 }))
3493 .unwrap();
3494
3495 assert_eq!(call.function_name().as_deref(), Some("memory_recall"));
3496 assert_eq!(
3497 call.function_arguments().as_deref(),
3498 Some("{\"query\":\"preferred\"}")
3499 );
3500 }
3501
3502 #[test]
3507 fn chat_completions_url_standard_openai() {
3508 let p = make_model_provider("openai", "https://api.openai.com/v1", None);
3510 assert_eq!(
3511 p.chat_completions_url(),
3512 "https://api.openai.com/v1/chat/completions"
3513 );
3514 }
3515
3516 #[test]
3517 fn chat_completions_url_trailing_slash() {
3518 let p = make_model_provider("test", "https://api.example.com/v1/", None);
3520 assert_eq!(
3521 p.chat_completions_url(),
3522 "https://api.example.com/v1/chat/completions"
3523 );
3524 }
3525
3526 #[test]
3527 fn chat_completions_url_volcengine_ark() {
3528 let p = make_model_provider(
3530 "volcengine",
3531 "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions",
3532 None,
3533 );
3534 assert_eq!(
3535 p.chat_completions_url(),
3536 "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions"
3537 );
3538 }
3539
3540 #[test]
3541 fn chat_completions_url_custom_full_endpoint() {
3542 let p = make_model_provider(
3544 "custom",
3545 "https://my-api.example.com/v2/llm/chat/completions",
3546 None,
3547 );
3548 assert_eq!(
3549 p.chat_completions_url(),
3550 "https://my-api.example.com/v2/llm/chat/completions"
3551 );
3552 }
3553
3554 #[test]
3555 fn chat_completions_url_requires_exact_suffix_match() {
3556 let p = make_model_provider(
3557 "custom",
3558 "https://my-api.example.com/v2/llm/chat/completions-proxy",
3559 None,
3560 );
3561 assert_eq!(
3562 p.chat_completions_url(),
3563 "https://my-api.example.com/v2/llm/chat/completions-proxy/chat/completions"
3564 );
3565 }
3566
3567 #[test]
3568 fn chat_completions_url_without_v1() {
3569 let p = make_model_provider("test", "https://api.example.com", None);
3571 assert_eq!(
3572 p.chat_completions_url(),
3573 "https://api.example.com/chat/completions"
3574 );
3575 }
3576
3577 #[test]
3578 fn chat_completions_url_base_with_v1() {
3579 let p = make_model_provider("test", "https://api.example.com/v1", None);
3581 assert_eq!(
3582 p.chat_completions_url(),
3583 "https://api.example.com/v1/chat/completions"
3584 );
3585 }
3586
3587 #[test]
3592 fn chat_completions_url_zai() {
3593 let p = make_model_provider("zai", "https://api.z.ai/api/paas/v4", None);
3595 assert_eq!(
3596 p.chat_completions_url(),
3597 "https://api.z.ai/api/paas/v4/chat/completions"
3598 );
3599 }
3600
3601 #[test]
3602 fn chat_completions_url_minimax() {
3603 let p = make_model_provider("minimax", "https://api.minimaxi.com/v1", None);
3605 assert_eq!(
3606 p.chat_completions_url(),
3607 "https://api.minimaxi.com/v1/chat/completions"
3608 );
3609 }
3610
/// GLM (bigmodel.cn) uses a `/api/paas/v4` base that is extended as-is.
#[test]
fn chat_completions_url_glm() {
    let base = "https://open.bigmodel.cn/api/paas/v4";
    let provider = make_model_provider("glm", base, None);
    assert_eq!(
        provider.chat_completions_url(),
        format!("{base}/chat/completions")
    );
}
3620
/// OpenCode Zen's `/zen/v1` base is extended with the endpoint path.
#[test]
fn chat_completions_url_opencode() {
    let base = "https://opencode.ai/zen/v1";
    let provider = make_model_provider("opencode", base, None);
    assert_eq!(
        provider.chat_completions_url(),
        format!("{base}/chat/completions")
    );
}
3630
/// The nested `/zen/go/v1` base is extended with the endpoint path.
#[test]
fn chat_completions_url_opencode_go() {
    let base = "https://opencode.ai/zen/go/v1";
    let provider = make_model_provider("opencode-go", base, None);
    assert_eq!(
        provider.chat_completions_url(),
        format!("{base}/chat/completions")
    );
}
3640
/// A generic (non-Mistral) provider surfaces the upstream tool-call id unchanged.
#[test]
fn parse_native_response_preserves_tool_call_id() {
    let provider = make_model_provider("test", "https://example.com", None);

    let shell_call = ToolCall {
        id: Some("call_123".to_string()),
        kind: Some("function".to_string()),
        function: Some(Function {
            name: Some("shell".to_string()),
            arguments: Some(r#"{"command":"pwd"}"#.to_string()),
        }),
        name: None,
        arguments: None,
        parameters: None,
        extra_content: None,
    };
    let response = ResponseMessage {
        content: None,
        tool_calls: Some(vec![shell_call]),
        reasoning_content: None,
    };

    let parsed = provider.parse_native_response(response);
    assert_eq!(parsed.tool_calls.len(), 1);
    let first = &parsed.tool_calls[0];
    assert_eq!(first.id, "call_123");
    assert_eq!(first.name, "shell");
}
3666
/// A Mistral provider rewrites a long upstream tool-call id into a
/// 9-character ASCII-alphanumeric id.
#[test]
fn parse_native_response_mistral_normalizes_invalid_tool_call_id() {
    let upstream_id = "xvL0p9bZ41j2X0O3Q1y9vL0p9bZ41j2X";
    let provider = make_model_provider("Mistral", "https://api.mistral.ai/v1", None);
    let response = ResponseMessage {
        content: None,
        tool_calls: Some(vec![ToolCall {
            id: Some(upstream_id.to_string()),
            kind: Some("function".to_string()),
            function: Some(Function {
                name: Some("shell".to_string()),
                arguments: Some(r#"{"command":"pwd"}"#.to_string()),
            }),
            name: None,
            arguments: None,
            parameters: None,
            extra_content: None,
        }]),
        reasoning_content: None,
    };

    let parsed = provider.parse_native_response(response);
    assert_eq!(parsed.tool_calls.len(), 1);
    let normalized = &parsed.tool_calls[0].id;
    assert_eq!(normalized.len(), 9);
    // Byte-wise check: any non-ASCII char would yield a byte >= 0x80 and fail.
    assert!(normalized.bytes().all(|b| b.is_ascii_alphanumeric()));
}
3693
/// When the upstream omits the tool-call id, a Mistral provider synthesizes one
/// in the 9-character ASCII-alphanumeric shape.
#[test]
fn parse_native_response_mistral_generates_valid_id_when_missing() {
    let provider = make_model_provider("Mistral", "https://api.mistral.ai/v1", None);
    let id_less_call = ToolCall {
        id: None,
        kind: Some("function".to_string()),
        function: Some(Function {
            name: Some("shell".to_string()),
            arguments: Some(r#"{"command":"pwd"}"#.to_string()),
        }),
        name: None,
        arguments: None,
        parameters: None,
        extra_content: None,
    };

    let parsed = provider.parse_native_response(ResponseMessage {
        content: None,
        tool_calls: Some(vec![id_less_call]),
        reasoning_content: None,
    });

    assert_eq!(parsed.tool_calls.len(), 1);
    let generated = &parsed.tool_calls[0].id;
    assert_eq!(generated.len(), 9);
    assert!(generated.bytes().all(|b| b.is_ascii_alphanumeric()));
}
3720
/// Id normalization also applies when a provider with another name (`Custom`)
/// points at the Mistral endpoint.
#[test]
fn parse_native_response_custom_mistral_endpoint_normalizes_tool_call_id() {
    let provider = make_model_provider("Custom", "https://api.mistral.ai/v1", None);
    let long_id_call = ToolCall {
        id: Some("xvL0p9bZ41j2X0O3Q1y9vL0p9bZ41j2X".to_string()),
        kind: Some("function".to_string()),
        function: Some(Function {
            name: Some("shell".to_string()),
            arguments: Some(r#"{"command":"pwd"}"#.to_string()),
        }),
        name: None,
        arguments: None,
        parameters: None,
        extra_content: None,
    };

    let parsed = provider.parse_native_response(ResponseMessage {
        content: None,
        tool_calls: Some(vec![long_id_call]),
        reasoning_content: None,
    });

    assert_eq!(parsed.tool_calls.len(), 1);
    let normalized = &parsed.tool_calls[0].id;
    assert_eq!(normalized.len(), 9);
    assert!(normalized.bytes().all(|b| b.is_ascii_alphanumeric()));
}
3747
/// Two upstream ids sharing their first nine characters must still normalize to
/// two distinct Mistral-shaped ids (naive truncation would collide).
#[test]
fn parse_native_response_mistral_avoids_id_collision_after_normalization() {
    // Builds a native tool call with the given id, function name and JSON args.
    let make_call = |id: &str, tool: &str, args: &str| ToolCall {
        id: Some(id.to_string()),
        kind: Some("function".to_string()),
        function: Some(Function {
            name: Some(tool.to_string()),
            arguments: Some(args.to_string()),
        }),
        name: None,
        arguments: None,
        parameters: None,
        extra_content: None,
    };

    let provider = make_model_provider("Mistral", "https://api.mistral.ai/v1", None);
    let response = ResponseMessage {
        content: None,
        tool_calls: Some(vec![
            make_call("ABCDEFGHI123", "shell", r#"{"command":"pwd"}"#),
            make_call("ABCDEFGHIxyz", "echo", r#"{"text":"ok"}"#),
        ]),
        reasoning_content: None,
    };

    let parsed = provider.parse_native_response(response);
    assert_eq!(parsed.tool_calls.len(), 2);
    let (first, second) = (&parsed.tool_calls[0].id, &parsed.tool_calls[1].id);
    assert_eq!(first.len(), 9);
    assert_eq!(second.len(), 9);
    assert!(first.bytes().all(|b| b.is_ascii_alphanumeric()));
    assert!(second.bytes().all(|b| b.is_ascii_alphanumeric()));
    assert_ne!(first, second);
}
3792
/// A stored tool-result JSON envelope is unpacked into a native `tool` message
/// with a separate `tool_call_id` and plain-text content.
#[test]
fn convert_messages_for_native_maps_tool_result_payload() {
    let history = vec![ChatMessage::tool(
        r#"{"tool_call_id":"call_abc","content":"done"}"#,
    )];
    let provider = make_model_provider("test", "https://example.com", None);

    let converted = provider.convert_messages_for_native(&history, true);
    assert_eq!(converted.len(), 1);
    let tool_msg = &converted[0];
    assert_eq!(tool_msg.role, "tool");
    assert_eq!(tool_msg.tool_call_id.as_deref(), Some("call_abc"));
    assert!(matches!(
        tool_msg.content.as_ref(),
        Some(MessageContent::Text(text)) if text == "done"
    ));
}
3809
/// For Mistral, an invalid upstream tool-call id is remapped, and the assistant
/// call and its tool result must serialize with the *same* valid id.
#[test]
fn native_chat_request_mistral_serializes_matching_valid_tool_call_ids() {
    const UPSTREAM_ID: &str = "chatcmpl-tool-abc";
    let provider = make_model_provider("Mistral", "https://api.mistral.ai/v1", None);

    let assistant_turn = serde_json::json!({
        "content": "",
        "tool_calls": [{
            "id": UPSTREAM_ID,
            "name": "shell",
            "arguments": "{\"cmd\":\"pwd\"}"
        }]
    })
    .to_string();
    let tool_turn = serde_json::json!({
        "tool_call_id": UPSTREAM_ID,
        "content": "done"
    })
    .to_string();
    let history = vec![
        ChatMessage::assistant(assistant_turn),
        ChatMessage::tool(tool_turn),
    ];

    let shell_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run a shell command",
            "parameters": {"type": "object"}
        }
    });
    let request = NativeChatRequest {
        model: "mistral-large-latest".to_string(),
        messages: provider.convert_messages_for_native(&history, true),
        temperature: Some(0.7),
        stream: Some(false),
        stream_options: None,
        reasoning_effort: None,
        tool_stream: None,
        tools: Some(vec![shell_tool]),
        tool_choice: Some("auto".to_string()),
        max_tokens: None,
    };

    let serialized = serde_json::to_value(&request).unwrap();
    let assistant_id = serialized["messages"][0]["tool_calls"][0]["id"]
        .as_str()
        .expect("assistant tool call id should serialize");
    let tool_id = serialized["messages"][1]["tool_call_id"]
        .as_str()
        .expect("tool result id should serialize");

    assert_ne!(assistant_id, UPSTREAM_ID);
    assert!(is_valid_mistral_tool_call_id(assistant_id));
    assert_eq!(assistant_id, tool_id);
}
3865
/// With user image parts disabled, an `[IMAGE:...]` marker stays inline as text
/// rather than becoming a separate image content part.
#[test]
fn convert_messages_for_native_keeps_user_image_markers_as_text_when_disabled() {
    const RAW: &str = "System primer [IMAGE:data:image/png;base64,abcd] user turn";
    let history = vec![ChatMessage::user(RAW)];
    let provider = make_model_provider("test", "https://example.com", None);

    let converted = provider.convert_messages_for_native(&history, false);
    assert_eq!(converted.len(), 1);
    let user_msg = &converted[0];
    assert_eq!(user_msg.role, "user");
    assert!(matches!(
        user_msg.content.as_ref(),
        Some(MessageContent::Text(text)) if text == RAW
    ));
}
3882
/// In merge mode, every system message is folded (in order, blank-line separated)
/// into the first user turn, and no system role survives.
#[test]
fn flatten_system_messages_merges_into_first_user() {
    let conversation = vec![
        ChatMessage::system("core policy"),
        ChatMessage::assistant("ack"),
        ChatMessage::system("delivery rules"),
        ChatMessage::user("hello"),
        ChatMessage::assistant("post-user"),
    ];

    let flattened = OpenAiCompatibleModelProvider::flatten_system_messages(&conversation, true);
    let shape: Vec<(&str, &str)> = flattened
        .iter()
        .map(|m| (m.role.as_str(), m.content.as_str()))
        .collect();
    assert_eq!(
        shape,
        vec![
            ("assistant", "ack"),
            ("user", "core policy\n\ndelivery rules\n\nhello"),
            ("assistant", "post-user"),
        ]
    );
    assert!(flattened.iter().all(|m| m.role != "system"));
}
3903
/// In merge mode with no user turn, the system text becomes a new leading user turn.
#[test]
fn flatten_system_messages_inserts_user_when_missing() {
    let conversation = vec![
        ChatMessage::system("core policy"),
        ChatMessage::assistant("ack"),
    ];

    let flattened = OpenAiCompatibleModelProvider::flatten_system_messages(&conversation, true);
    let shape: Vec<(&str, &str)> = flattened
        .iter()
        .map(|m| (m.role.as_str(), m.content.as_str()))
        .collect();
    assert_eq!(shape, vec![("user", "core policy"), ("assistant", "ack")]);
}
3918
/// An opening `<think>` with no closing tag hides everything after it.
#[test]
fn strip_think_tags_drops_unclosed_block_suffix() {
    let stripped = strip_think_tags("visible<think>hidden");
    assert_eq!(stripped, "visible");
}
3924
/// Schema-rejection detection depends on the status code: the same body text is
/// recognised under 400 but not under 401.
#[test]
fn native_tool_schema_unsupported_detection_is_precise() {
    const BODY: &str = "unknown parameter: tools";
    let detect = OpenAiCompatibleModelProvider::is_native_tool_schema_unsupported;

    assert!(detect(reqwest::StatusCode::BAD_REQUEST, BODY));
    assert!(!detect(reqwest::StatusCode::UNAUTHORIZED, BODY));
}
3940
/// Groq's 400 "tool call validation failed" error counts as a native-tool
/// schema rejection.
#[test]
fn native_tool_schema_unsupported_detects_groq_tool_validation_error() {
    let groq_error = r#"Groq API error (400 Bad Request): {"error":{"message":"tool call validation failed: attempted to call tool 'memory_recall={\"limit\":5}' which was not in request"}}"#;
    let unsupported = OpenAiCompatibleModelProvider::is_native_tool_schema_unsupported(
        reqwest::StatusCode::BAD_REQUEST,
        groq_error,
    );
    assert!(unsupported);
}
3950
/// The prompt-guided fallback prepends a system message that lists the
/// available tools by name.
#[test]
fn prompt_guided_tool_fallback_injects_system_instruction() {
    let conversation = vec![ChatMessage::user("check status")];
    let shell_spec = zeroclaw_api::tool::ToolSpec {
        name: "shell_exec".to_string(),
        description: "Execute shell command".to_string(),
        parameters: serde_json::json!({
            "type": "object",
            "properties": {
                "command": { "type": "string" }
            },
            "required": ["command"]
        }),
    };
    let tools = vec![shell_spec];

    let augmented = OpenAiCompatibleModelProvider::with_prompt_guided_tool_instructions(
        &conversation,
        Some(&tools),
    );
    assert!(!augmented.is_empty());
    let injected = &augmented[0];
    assert_eq!(injected.role, "system");
    assert!(injected.content.contains("Available Tools"));
    assert!(injected.content.contains("shell_exec"));
}
3975
/// The configured reasoning effort is forwarded only to the model names listed
/// below; chat-router aliases and non-OpenAI names get `None`.
#[test]
fn reasoning_effort_only_applies_to_openai_and_selected_codex_models() {
    let provider = make_model_provider("test", "https://example.com", None)
        .with_reasoning_effort(Some("high".to_string()));
    let high = Some("high".to_string());

    // o-series and gpt-5 family names, bare or `openai/`-prefixed.
    for model in [
        "o1-preview",
        "openai/o3-mini",
        "o4-mini",
        "gpt-5",
        "gpt-5.3-codex",
        "openai/gpt-5",
    ] {
        assert_eq!(provider.reasoning_effort_for_model(model), high);
    }

    for model in ["gpt-5-chat-latest", "gpt-5.1-chat-latest"] {
        assert_eq!(
            provider.reasoning_effort_for_model(model),
            None,
            "gpt-5*-chat-latest are non-reasoning chat-router models and must not receive reasoning_effort",
        );
    }

    // `gpt-*` codex names qualify; other `*-codex` names do not.
    assert_eq!(provider.reasoning_effort_for_model("gpt-4-codex"), high);
    assert_eq!(
        provider.reasoning_effort_for_model("llama-3-codex"),
        None,
        "generic codex-like model names must not receive OpenAI-only reasoning_effort",
    );
    assert_eq!(provider.reasoning_effort_for_model("llama-3.3-70b"), None);
}
4029
4030 #[tokio::test]
4031 async fn warmup_without_key_attempts_connection() {
4032 let model_provider = make_model_provider("test", "http://127.0.0.1:1", None);
4033 let result = model_provider.warmup().await;
4034 assert!(result.is_err());
4035 let err_msg = result.unwrap_err().to_string();
4036 assert!(
4037 !err_msg.contains("API key not set"),
4038 "should not get credential error, got: {err_msg}"
4039 );
4040 }
4041
/// A default-constructed provider advertises native tools but no vision.
#[test]
fn capabilities_reports_native_tool_calling() {
    let provider = make_model_provider("test", "https://example.com", None);
    let caps = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider);
    assert!(caps.native_tool_calling);
    assert!(!caps.vision);
}
4053
/// `new_with_vision(.., true)` turns on vision while keeping native tools.
#[test]
fn capabilities_reports_vision_for_qwen_compatible_provider() {
    let qwen = OpenAiCompatibleModelProvider::new_with_vision(
        "test",
        "Qwen",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
        Some("k"),
        AuthStyle::Bearer,
        true,
    );
    let caps = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&qwen);
    assert!(caps.native_tool_calling);
    assert!(caps.vision);
}
4068
/// Enabling system-into-user merging via the builder keeps native tool calling on.
#[test]
fn minimax_provider_supports_native_tool_calling_with_system_merge() {
    let minimax = OpenAiCompatibleModelProvider::new(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    )
    .with_merge_system_into_user();

    let caps = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&minimax);
    assert!(
        caps.native_tool_calling,
        "MiniMax should preserve native tool calling when system messages are merged"
    );
    assert!(!caps.vision);
}
4086
/// For a provider built with `new_merge_system_into_user`, stripping drops the
/// tool round-trip and coalesces the surrounding assistant turns into one.
#[test]
fn strip_native_tool_messages_removes_tool_and_tool_calls() {
    let provider = OpenAiCompatibleModelProvider::new_merge_system_into_user(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );
    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("search for cats"),
        ChatMessage::assistant(
            r#"{"content":"I'll search","tool_calls":[{"id":"chatcmpl-tool-abc","name":"web_search","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(
            r#"{"tool_call_id":"chatcmpl-tool-abc","content":"Found 10 results"}"#,
        ),
        ChatMessage::assistant("Here are the results about cats"),
        ChatMessage::user("thanks"),
    ];

    let stripped = provider.strip_native_tool_messages(&history);

    // Expected shape: system, user, one merged assistant turn, final user.
    assert_eq!(stripped.len(), 4);
    assert_eq!(stripped[0].role, "system");
    assert_eq!(stripped[1].role, "user");
    assert_eq!(stripped[1].content, "search for cats");

    let merged = &stripped[2];
    assert_eq!(merged.role, "assistant");
    assert!(
        merged.content.starts_with("I'll search"),
        "coalesced assistant must preserve the pre-tool narration; got {:?}",
        merged.content
    );
    assert!(
        merged.content.contains("Here are the results about cats"),
        "coalesced assistant must preserve the post-tool reply; got {:?}",
        merged.content
    );
    assert!(
        !merged.content.contains("tool_calls"),
        "tool_calls structure must be stripped"
    );

    assert_eq!(stripped[3].role, "user");
}
4138
/// An assistant tool-call turn with empty narration leaves nothing behind once
/// stripped; only the later "Done" reply remains as the assistant turn.
#[test]
fn strip_native_tool_messages_drops_empty_assistant_tool_calls() {
    let provider = OpenAiCompatibleModelProvider::new_merge_system_into_user(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );
    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("do it"),
        ChatMessage::assistant(
            r#"{"content":"","tool_calls":[{"id":"tc1","name":"shell","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(r#"{"tool_call_id":"tc1","content":"ok"}"#),
        ChatMessage::assistant("Done"),
    ];

    let stripped = provider.strip_native_tool_messages(&history);
    assert_eq!(stripped.len(), 3);
    assert_eq!(stripped[0].role, "system");
    assert_eq!(stripped[1].role, "user");

    let reply = &stripped[2];
    assert_eq!(reply.role, "assistant");
    assert_eq!(reply.content, "Done");
}
4165
/// A history with no tool traffic passes through stripping unchanged.
#[test]
fn strip_native_tool_messages_preserves_regular_messages() {
    let provider = OpenAiCompatibleModelProvider::new_merge_system_into_user(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );
    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("hello"),
        ChatMessage::assistant("hi there"),
        ChatMessage::user("bye"),
    ];

    let stripped = provider.strip_native_tool_messages(&history);
    assert_eq!(stripped.len(), 4);
    for (before, after) in history.iter().zip(&stripped) {
        assert_eq!(before.role, after.role);
        assert_eq!(before.content, after.content);
    }
}
4188
/// When the provider supports native tool calling, stripping is a no-op even for
/// histories that contain tool calls and tool results.
#[test]
fn strip_native_tool_messages_passthrough_when_native_tool_calling_enabled() {
    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("search for cats"),
        ChatMessage::assistant(
            r#"{"content":"I'll search","tool_calls":[{"id":"chatcmpl-tool-abc","name":"web_search","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(
            r#"{"tool_call_id":"chatcmpl-tool-abc","content":"Found 10 results"}"#,
        ),
        ChatMessage::assistant("Here are the results about cats"),
    ];
    let provider = OpenAiCompatibleModelProvider::new(
        "test",
        "NativeToolProvider",
        "https://api.example.com/v1",
        Some("k"),
        AuthStyle::Bearer,
    );

    // Precondition: this test only means something if native tools are on.
    let native = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider)
        .native_tool_calling;
    assert!(
        native,
        "model_provider must have native_tool_calling enabled for this test"
    );

    let passed_through = provider.strip_native_tool_messages(&history);
    assert_eq!(passed_through.len(), history.len());
    for (before, after) in history.iter().zip(&passed_through) {
        assert_eq!(before.role, after.role);
        assert_eq!(before.content, after.content);
    }
}
4223
/// The user-agent constructor stores the agent string and keeps the default
/// capabilities (native tools on, vision off).
#[test]
fn user_agent_constructor_keeps_native_tool_calling_enabled() {
    let agent = "zeroclaw-test/1.0";
    let provider = OpenAiCompatibleModelProvider::new_with_user_agent(
        "test",
        "TestProvider",
        "https://example.com",
        Some("k"),
        AuthStyle::Bearer,
        agent,
    );

    let caps = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider);
    assert!(caps.native_tool_calling);
    assert!(!caps.vision);
    assert_eq!(provider.user_agent.as_deref(), Some(agent));
}
4239
/// The combined user-agent + vision constructor honours both settings.
#[test]
fn user_agent_and_vision_constructor_preserves_capability_flags() {
    let agent = "zeroclaw-test/vision";
    let provider = OpenAiCompatibleModelProvider::new_with_user_agent_and_vision(
        "test",
        "VisionModelProvider",
        "https://example.com",
        Some("k"),
        AuthStyle::Bearer,
        agent,
        true,
    );

    let caps = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider);
    assert!(caps.native_tool_calling);
    assert!(caps.vision);
    assert_eq!(provider.user_agent.as_deref(), Some(agent));
}
4256
/// With image parts enabled, a user message containing an `[IMAGE:...]` marker
/// serializes as an OpenAI parts array: one text part, then one `image_url` part.
#[test]
fn to_message_content_converts_image_markers_to_openai_parts() {
    let content = "Describe this\n\n[IMAGE:data:image/png;base64,abcd]";
    let rendered = OpenAiCompatibleModelProvider::to_message_content("user", content, true);
    let value = serde_json::to_value(rendered).unwrap();

    let parts = value
        .as_array()
        .expect("multimodal content should be an array");
    assert_eq!(parts.len(), 2);

    let (text_part, image_part) = (&parts[0], &parts[1]);
    assert_eq!(text_part["type"], "text");
    assert_eq!(text_part["text"], "Describe this");
    assert_eq!(image_part["type"], "image_url");
    assert_eq!(image_part["image_url"]["url"], "data:image/png;base64,abcd");
}
4273
/// With image parts disabled, the content serializes as the original plain string.
#[test]
fn to_message_content_keeps_markers_as_text_when_user_image_parts_disabled() {
    let content = "Policy [IMAGE:data:image/png;base64,abcd]";
    let rendered = OpenAiCompatibleModelProvider::to_message_content("user", content, false);
    let value = serde_json::to_value(rendered).unwrap();
    assert_eq!(value, serde_json::json!(content));
}
4283
/// Non-user roles always serialize as a plain string, even with image parts enabled.
#[test]
fn to_message_content_keeps_plain_text_for_non_user_roles() {
    let prompt = "You are a helpful assistant.";
    let rendered = OpenAiCompatibleModelProvider::to_message_content("system", prompt, true);
    let value = serde_json::to_value(rendered).unwrap();
    assert_eq!(value, serde_json::json!(prompt));
}
4294
4295 #[tokio::test]
4296 async fn normalize_messages_for_upstream_rewrites_local_image_path_to_data_uri() {
4297 let tmp = tempfile::TempDir::new().expect("tempdir");
4301 let path = tmp.path().join("pixel.png");
4302 let png: [u8; 67] = [
4304 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48,
4305 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x06, 0x00, 0x00,
4306 0x00, 0x1F, 0x15, 0xC4, 0x89, 0x00, 0x00, 0x00, 0x0A, 0x49, 0x44, 0x41, 0x54, 0x78,
4307 0x9C, 0x63, 0x00, 0x01, 0x00, 0x00, 0x05, 0x00, 0x01, 0x0D, 0x0A, 0x2D, 0xB4, 0x00,
4308 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
4309 ];
4310 std::fs::write(&path, png).expect("write pixel.png");
4311 let path_str = path.to_string_lossy().into_owned();
4312
4313 let msg = ChatMessage {
4314 role: "user".into(),
4315 content: format!("Caption please [IMAGE:{}]", path_str),
4316 };
4317
4318 let normalized = OpenAiCompatibleModelProvider::normalize_messages_for_upstream(
4319 std::slice::from_ref(&msg),
4320 )
4321 .await
4322 .expect("normalize ok");
4323
4324 assert_eq!(normalized.len(), 1);
4325 let content = &normalized[0].content;
4326 assert!(
4327 content.contains("[IMAGE:data:image/png;base64,"),
4328 "expected base64 data URI in normalized content, got: {content}"
4329 );
4330 assert!(
4331 !content.contains(&path_str),
4332 "raw local path must not leak to upstream, got: {content}"
4333 );
4334 }
4335
/// An `ApiChatRequest` that carries tools serializes both the `tools` array and
/// `tool_choice`.
#[test]
fn request_serializes_with_tools() {
    let weather_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"}
                }
            }
        }
    });
    let request = ApiChatRequest {
        model: "test-model".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: MessageContent::Text("What is the weather?".to_string()),
        }],
        temperature: Some(0.7),
        stream: Some(false),
        stream_options: None,
        reasoning_effort: None,
        tool_stream: None,
        tools: Some(vec![weather_tool]),
        tool_choice: Some("auto".to_string()),
        max_tokens: None,
    };

    let body = serde_json::to_string(&request).unwrap();
    for needle in ["\"tools\"", "get_weather", "\"tool_choice\":\"auto\""] {
        assert!(body.contains(needle));
    }
}
4372
/// For the `zai` provider, a request carrying tools serializes `"tool_stream":true`.
#[test]
fn zai_tool_requests_enable_tool_stream() {
    let provider = make_model_provider("zai", "https://api.z.ai/api/paas/v4", None);
    let shell_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run a shell command",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {"type": "string"}
                }
            }
        }
    });
    let request = ApiChatRequest {
        model: "glm-5".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: MessageContent::Text("List /tmp".to_string()),
        }],
        temperature: Some(0.7),
        stream: Some(false),
        stream_options: None,
        reasoning_effort: None,
        tool_stream: provider.tool_stream_for_tools(true),
        tools: Some(vec![shell_tool]),
        tool_choice: Some("auto".to_string()),
        max_tokens: None,
    };

    let body = serde_json::to_string(&request).unwrap();
    assert!(body.contains("\"tool_stream\":true"));
}
4407
/// For a non-Z.AI provider, the `tool_stream` key is absent from the serialized request.
#[test]
fn non_zai_tool_requests_omit_tool_stream() {
    let provider = make_model_provider("test", "https://api.example.com/v1", None);
    let shell_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run a shell command",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {"type": "string"}
                }
            }
        }
    });
    let request = ApiChatRequest {
        model: "test-model".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: MessageContent::Text("List /tmp".to_string()),
        }],
        temperature: Some(0.7),
        stream: Some(false),
        stream_options: None,
        reasoning_effort: None,
        tool_stream: provider.tool_stream_for_tools(true),
        tools: Some(vec![shell_tool]),
        tool_choice: Some("auto".to_string()),
        max_tokens: None,
    };

    let body = serde_json::to_string(&request).unwrap();
    assert!(!body.contains("\"tool_stream\""));
}
4442
/// A non-Z.AI host yields `None` from `tool_stream_for_tools` for both flag values.
#[test]
fn non_zai_provider_omits_tool_stream_regardless_of_streaming() {
    let provider = make_model_provider("custom", "https://proxy.example.com/v1", None);
    for flag in [true, false] {
        assert_eq!(provider.tool_stream_for_tools(flag), None);
    }
}
4450
/// Detection is host-based: a provider named `custom` pointing at api.z.ai still
/// gets `tool_stream` enabled.
#[test]
fn z_ai_host_enables_tool_stream_for_custom_profiles() {
    let base = "https://api.z.ai/api/coding/paas/v4";
    let provider = make_model_provider("custom", base, None);
    assert_eq!(provider.tool_stream_for_tools(true), Some(true));
}
4457
/// A response with `content: null` and an id-less tool call deserializes,
/// exposing the function name and raw argument string.
#[test]
fn response_with_tool_calls_deserializes() {
    let raw = r#"{
        "choices": [{
            "message": {
                "content": null,
                "tool_calls": [{
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"location\":\"London\"}"
                    }
                }]
            }
        }]
    }"#;

    let response: ApiChatResponse = serde_json::from_str(raw).unwrap();
    let message = &response.choices[0].message;
    assert!(message.content.is_none());

    let calls = message.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    let function = calls[0].function.as_ref().unwrap();
    assert_eq!(function.name.as_deref(), Some("get_weather"));
    assert_eq!(
        function.arguments.as_deref(),
        Some("{\"location\":\"London\"}")
    );
}
4494
/// Text content and multiple tool calls in one message deserialize together,
/// with the tool calls kept in their original order.
#[test]
fn response_with_multiple_tool_calls() {
    let raw = r#"{
        "choices": [{
            "message": {
                "content": "I'll check both.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\":\"London\"}"
                        }
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": "{\"timezone\":\"UTC\"}"
                        }
                    }
                ]
            }
        }]
    }"#;

    let response: ApiChatResponse = serde_json::from_str(raw).unwrap();
    let message = &response.choices[0].message;
    assert_eq!(message.content.as_deref(), Some("I'll check both."));

    let calls = message.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 2);
    let names: Vec<Option<&str>> = calls
        .iter()
        .map(|call| call.function.as_ref().unwrap().name.as_deref())
        .collect();
    assert_eq!(names, vec![Some("get_weather"), Some("get_time")]);
}
4535
4536 #[tokio::test]
4537 async fn chat_with_tools_without_key_attempts_request() {
4538 let p = make_model_provider("TestProvider", "http://127.0.0.1:1", None);
4539 let messages = vec![ChatMessage {
4540 role: "user".to_string(),
4541 content: "hello".to_string(),
4542 }];
4543 let tools = vec![serde_json::json!({
4544 "type": "function",
4545 "function": {
4546 "name": "test_tool",
4547 "description": "A test tool",
4548 "parameters": {}
4549 }
4550 })];
4551
4552 let result = p
4553 .chat_with_tools(&messages, &tools, "model", Some(0.7))
4554 .await;
4555 assert!(result.is_err());
4556 let err_msg = result.unwrap_err().to_string();
4557 assert!(
4558 !err_msg.contains("API key not set"),
4559 "should not get credential error, got: {err_msg}"
4560 );
4561 }
4562
/// `reasoning_content` on an assistant tool-call turn survives into the native
/// tool request, and the structured `tool_calls` array is kept alongside it.
#[test]
fn chat_with_tools_request_preserves_reasoning_content_in_history() {
    let provider = make_model_provider("DeepSeek", "https://api.deepseek.example/v1", None);
    let assistant_turn = serde_json::json!({
        "content": "I will inspect the workspace.",
        "tool_calls": [{
            "id": "call_1",
            "name": "shell",
            "arguments": "{\"cmd\":\"ls\"}"
        }],
        "reasoning_content": "Need to inspect the current files before answering."
    });
    let history = vec![
        ChatMessage::assistant(assistant_turn.to_string()),
        ChatMessage::tool(r#"{"tool_call_id":"call_1","content":"src\nCargo.toml"}"#),
        ChatMessage::user("continue"),
    ];
    let shell_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run a shell command",
            "parameters": {}
        }
    });

    let request = provider.build_native_tool_chat_request(
        &history,
        Some(vec![shell_tool]),
        "deepseek-v4-flash",
        Some(0.7),
        true,
    );
    let serialized = serde_json::to_value(&request).unwrap();
    let assistant = &serialized["messages"][0];

    assert_eq!(assistant["role"], "assistant");
    assert_eq!(
        assistant["reasoning_content"],
        "Need to inspect the current files before answering."
    );
    assert!(
        assistant["tool_calls"].is_array(),
        "assistant tool-call history must stay native in chat_with_tools requests"
    );
    assert_eq!(serialized["tools"][0]["function"]["name"], "shell");
    assert_eq!(serialized["tool_choice"], "auto");
}
4611
/// A text-only response deserializes with its content intact.
///
/// Despite the test name, what is pinned here is that a missing `tool_calls`
/// key deserializes to `None` (not to `Some(vec![])`).
#[test]
fn response_with_no_tool_calls_has_empty_vec() {
    let raw = r#"{"choices":[{"message":{"content":"Just text, no tools."}}]}"#;
    let response: ApiChatResponse = serde_json::from_str(raw).unwrap();
    let message = &response.choices[0].message;
    assert_eq!(message.content.as_deref(), Some("Just text, no tools."));
    assert!(message.tool_calls.is_none());
}
4620
/// In merge mode, system messages (even ones appearing before earlier turns) are
/// folded into the first user turn; other roles, including `tool`, keep their order.
#[test]
fn flatten_system_messages_merges_into_first_user_and_removes_system_roles() {
    let conversation = vec![
        ChatMessage::system("System A"),
        ChatMessage::assistant("Earlier assistant turn"),
        ChatMessage::system("System B"),
        ChatMessage::user("User turn"),
        ChatMessage::tool(r#"{"ok":true}"#),
    ];

    let flattened = OpenAiCompatibleModelProvider::flatten_system_messages(&conversation, true);
    let roles: Vec<&str> = flattened.iter().map(|m| m.role.as_str()).collect();
    assert_eq!(roles, vec!["assistant", "user", "tool"]);
    assert_eq!(
        flattened[1].content,
        "System A\n\nSystem B\n\nUser turn".to_string()
    );
    assert!(!roles.contains(&"system"));
}
4642
/// Without the merge flag (`false`), all system text is gathered into a single
/// leading `system` message, and the remaining turns keep their order.
#[test]
fn flatten_system_messages_keeps_system_only_at_start_without_user_merge() {
    let input = vec![
        ChatMessage::system("System A"),
        ChatMessage::user("User turn"),
        ChatMessage::assistant("Assistant turn"),
        ChatMessage::system("System B"),
        ChatMessage::user("Follow-up"),
    ];

    let out = OpenAiCompatibleModelProvider::flatten_system_messages(&input, false);

    let roles: Vec<&str> = out.iter().map(|m| m.role.as_str()).collect();
    assert_eq!(roles, vec!["system", "user", "assistant", "user"]);

    let system_count = out.iter().filter(|m| m.role == "system").count();
    assert_eq!(system_count, 1);

    let head = &out[0].content;
    assert!(head.contains("System A"));
    assert!(head.contains("System B"));
}
4671
/// Empty system messages contribute nothing and are removed entirely.
#[test]
fn flatten_system_messages_drops_empty_system_messages() {
    let input = vec![
        ChatMessage::system(""),
        ChatMessage::user("User turn"),
        ChatMessage::system(""),
    ];

    let out = OpenAiCompatibleModelProvider::flatten_system_messages(&input, false);

    assert_eq!(out.len(), 1);
    let only = &out[0];
    assert_eq!(only.role, "user");
    assert_eq!(only.content, "User turn");
}
4686
/// When merging is requested but no user turn exists, the system text becomes
/// a new leading user message placed ahead of the assistant turn.
#[test]
fn flatten_system_messages_inserts_synthetic_user_when_no_user_exists() {
    let input = vec![
        ChatMessage::assistant("Assistant only"),
        ChatMessage::system("Synthetic system"),
    ];

    let out = OpenAiCompatibleModelProvider::flatten_system_messages(&input, true);

    assert_eq!(out.len(), 2);
    let (synthetic, assistant) = (&out[0], &out[1]);
    assert_eq!(synthetic.role, "user");
    assert_eq!(synthetic.content, "Synthetic system");
    assert_eq!(assistant.role, "assistant");
}
4700
/// Every `<think>…</think>` span is removed while the surrounding prose is kept.
#[test]
fn strip_think_tags_removes_multiple_blocks_with_surrounding_text() {
    assert_eq!(
        strip_think_tags("Answer A <think>hidden 1</think> and B <think>hidden 2</think> done"),
        "Answer A and B done"
    );
}
4707
/// An opening `<think>` with no closing tag hides everything after it.
#[test]
fn strip_think_tags_drops_tail_for_unclosed_block() {
    assert_eq!(strip_think_tags("Visible<think>hidden tail"), "Visible");
}
4714
/// An empty `content` string falls back to `reasoning_content`.
#[test]
fn reasoning_content_fallback_when_content_empty() {
    let payload = r#"{"choices":[{"message":{"content":"","reasoning_content":"Thinking output here"}}]}"#;
    let parsed = serde_json::from_str::<ApiChatResponse>(payload).unwrap();
    assert_eq!(
        parsed.choices[0].message.effective_content(),
        "Thinking output here"
    );
}
4727
/// A JSON `null` content falls back to `reasoning_content`.
#[test]
fn reasoning_content_fallback_when_content_null() {
    let payload = r#"{"choices":[{"message":{"content":null,"reasoning_content":"Fallback text"}}]}"#;
    let parsed = serde_json::from_str::<ApiChatResponse>(payload).unwrap();
    assert_eq!(parsed.choices[0].message.effective_content(), "Fallback text");
}
4737
/// A message with no `content` key at all falls back to `reasoning_content`.
#[test]
fn reasoning_content_fallback_when_content_missing() {
    let payload = r#"{"choices":[{"message":{"reasoning_content":"Only reasoning"}}]}"#;
    let parsed = serde_json::from_str::<ApiChatResponse>(payload).unwrap();
    assert_eq!(parsed.choices[0].message.effective_content(), "Only reasoning");
}
4746
/// Non-empty `content` takes priority over any `reasoning_content`.
#[test]
fn reasoning_content_not_used_when_content_present() {
    let payload = r#"{"choices":[{"message":{"content":"Normal response","reasoning_content":"Should be ignored"}}]}"#;
    let parsed = serde_json::from_str::<ApiChatResponse>(payload).unwrap();
    assert_eq!(parsed.choices[0].message.effective_content(), "Normal response");
}
4755
/// Content made only of `<think>` tags counts as empty, so both accessors
/// fall back to `reasoning_content`.
#[test]
fn reasoning_content_used_when_content_only_think_tags() {
    let payload = r#"{"choices":[{"message":{"content":"<think>secret</think>","reasoning_content":"Fallback text"}}]}"#;
    let parsed = serde_json::from_str::<ApiChatResponse>(payload).unwrap();
    let message = &parsed.choices[0].message;

    assert_eq!(message.effective_content(), "Fallback text");

    let optional = message.effective_content_optional();
    assert_eq!(optional.as_deref(), Some("Fallback text"));
}
4767
/// With neither `content` nor `reasoning_content`, the effective content is "".
#[test]
fn reasoning_content_both_absent_returns_empty() {
    let payload = r#"{"choices":[{"message":{}}]}"#;
    let parsed = serde_json::from_str::<ApiChatResponse>(payload).unwrap();
    assert_eq!(parsed.choices[0].message.effective_content(), "");
}
4776
/// A response without any reasoning field leaves `reasoning_content` unset
/// and returns the plain content.
#[test]
fn reasoning_content_ignored_by_normal_models() {
    let payload = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
    let parsed = serde_json::from_str::<ApiChatResponse>(payload).unwrap();
    let message = &parsed.choices[0].message;
    assert!(message.reasoning_content.is_none());
    assert_eq!(message.effective_content(), "Hello from Venice!");
}
4786
/// A plain content delta is surfaced as `delta` with no reasoning.
#[test]
fn parse_sse_line_with_content() {
    let parsed = parse_sse_line(r#"data: {"choices":[{"delta":{"content":"hello"}}]}"#)
        .unwrap()
        .unwrap();
    assert_eq!(parsed.delta, "hello");
    assert!(parsed.reasoning.is_none());
}
4798
/// A reasoning-only delta leaves `delta` empty and fills `reasoning`.
#[test]
fn parse_sse_line_with_reasoning_content() {
    let parsed =
        parse_sse_line(r#"data: {"choices":[{"delta":{"reasoning_content":"thinking..."}}]}"#)
            .unwrap()
            .unwrap();
    assert!(parsed.delta.is_empty());
    assert_eq!(parsed.reasoning.as_deref(), Some("thinking..."));
}
4806
/// When a delta carries both content and reasoning, content wins and
/// reasoning is dropped.
#[test]
fn parse_sse_line_with_both_prefers_content() {
    let sse = r#"data: {"choices":[{"delta":{"content":"real answer","reasoning_content":"thinking..."}}]}"#;
    let parsed = parse_sse_line(sse).unwrap().unwrap();
    assert_eq!(parsed.delta, "real answer");
    assert!(parsed.reasoning.is_none());
}
4814
/// An empty content string does not suppress the reasoning delta.
#[test]
fn parse_sse_line_with_empty_content_falls_back_to_reasoning() {
    let sse = r#"data: {"choices":[{"delta":{"content":"","reasoning_content":"thinking..."}}]}"#;
    let parsed = parse_sse_line(sse).unwrap().unwrap();
    assert!(parsed.delta.is_empty());
    assert_eq!(parsed.reasoning.as_deref(), Some("thinking..."));
}
4823
/// The short `reasoning` key is accepted as an alias for `reasoning_content`.
#[test]
fn parse_sse_line_accepts_reasoning_alias() {
    let sse = r#"data: {"choices":[{"delta":{"reasoning":"thinking via vllm..."}}]}"#;
    let parsed = parse_sse_line(sse).unwrap().unwrap();
    assert!(parsed.delta.is_empty());
    assert_eq!(parsed.reasoning.as_deref(), Some("thinking via vllm..."));
}
4834
/// An empty content string combined with the `reasoning` alias still yields
/// the reasoning text.
#[test]
fn parse_sse_line_with_empty_content_and_reasoning_alias() {
    let sse = r#"data: {"choices":[{"delta":{"content":"","reasoning":"vllm thought"}}]}"#;
    let parsed = parse_sse_line(sse).unwrap().unwrap();
    assert!(parsed.delta.is_empty());
    assert_eq!(parsed.reasoning.as_deref(), Some("vllm thought"));
}
4842
/// On the non-streaming path, the `reasoning` alias populates
/// `reasoning_content`, and that text becomes the effective content.
#[test]
fn response_message_accepts_reasoning_alias_on_non_stream_path() {
    let payload = r#"{"content":null,"reasoning":"chain-of-thought via vllm","tool_calls":null}"#;
    let message = serde_json::from_str::<ResponseMessage>(payload).unwrap();

    assert!(message.content.is_none());
    assert_eq!(
        message.reasoning_content.as_deref(),
        Some("chain-of-thought via vllm"),
        "the `reasoning` alias must populate the canonical reasoning_content field",
    );
    assert_eq!(message.effective_content(), "chain-of-thought via vllm");
}
4857
/// The canonical `reasoning_content` key keeps deserializing after the alias was added.
#[test]
fn response_message_canonical_reasoning_content_still_works() {
    let payload = r#"{"content":null,"reasoning_content":"canonical thought","tool_calls":null}"#;
    let message = serde_json::from_str::<ResponseMessage>(payload).unwrap();
    assert_eq!(message.reasoning_content.as_deref(), Some("canonical thought"));
}
4865
/// When both `reasoning_content` and `reasoning` are present, deserialization
/// succeeds and the canonical key wins.
#[test]
fn response_message_with_both_keys_prefers_canonical_reasoning_content() {
    let payload = r#"{"content":null,"reasoning_content":"canonical","reasoning":"alias","tool_calls":null}"#;
    let message = serde_json::from_str::<ResponseMessage>(payload)
        .expect("payload with both reasoning_content and reasoning must deserialize");
    let reasoning = message.reasoning_content.as_deref();
    assert_eq!(
        reasoning,
        Some("canonical"),
        "canonical reasoning_content must win when both fields are present",
    );
}
4883
/// The `reasoning` key alone populates the canonical `reasoning_content` field.
#[test]
fn response_message_with_only_alias_populates_canonical_field() {
    let payload = r#"{"content":null,"reasoning":"alias only","tool_calls":null}"#;
    let message = serde_json::from_str::<ResponseMessage>(payload).unwrap();
    assert_eq!(message.reasoning_content.as_deref(), Some("alias only"));
}
4892
/// In streaming deltas, `reasoning_content` also takes precedence over `reasoning`.
#[test]
fn stream_delta_with_both_keys_prefers_canonical_reasoning_content() {
    let sse = r#"data: {"choices":[{"delta":{"reasoning_content":"canonical","reasoning":"alias"}}]}"#;
    let parsed = parse_sse_line(sse)
        .expect("parse must succeed")
        .expect("non-empty chunk");
    assert_eq!(parsed.reasoning.as_deref(), Some("canonical"));
}
4903
/// Reasoning is taken from the first key present, checking `reasoning_content`
/// before `reasoning`. Exercised on all four combinations of the two keys.
#[test]
fn round_trip_reasoning_extraction_accepts_alias() {
    // The first key that exists wins, even if its value is not a string.
    let extract_reasoning = |value: &serde_json::Value| -> Option<String> {
        ["reasoning_content", "reasoning"]
            .iter()
            .find_map(|key| value.get(*key))
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };

    let canonical =
        serde_json::from_str::<serde_json::Value>(r#"{"reasoning_content":"canonical","tool_calls":[]}"#)
            .unwrap();
    let alias = serde_json::from_str::<serde_json::Value>(r#"{"reasoning":"vllm","tool_calls":[]}"#)
        .unwrap();
    let neither = serde_json::from_str::<serde_json::Value>(r#"{"tool_calls":[]}"#).unwrap();
    let both = serde_json::from_str::<serde_json::Value>(
        r#"{"reasoning_content":"canonical","reasoning":"alias","tool_calls":[]}"#,
    )
    .unwrap();

    assert_eq!(extract_reasoning(&canonical).as_deref(), Some("canonical"));
    assert_eq!(extract_reasoning(&alias).as_deref(), Some("vllm"));
    assert_eq!(extract_reasoning(&neither), None);
    assert_eq!(extract_reasoning(&both).as_deref(), Some("canonical"));
}
4932
/// The `[DONE]` terminator parses successfully to "no chunk".
#[test]
fn parse_sse_line_done_sentinel() {
    let parsed = parse_sse_line("data: [DONE]").unwrap();
    assert!(parsed.is_none());
}
4939
/// A tool-call delta keeps its index, id and function name through chunk parsing.
#[test]
fn parse_sse_chunk_with_tool_call_delta() {
    let sse = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"name":"shell","arguments":"{\"command\":\"date\"}"}}]}}]}"#;
    let chunk = parse_sse_chunk(sse)
        .unwrap()
        .expect("chunk should be parsed");

    let deltas = chunk
        .choices
        .first()
        .expect("choice should exist")
        .delta
        .tool_calls
        .as_ref()
        .expect("tool call deltas should exist");
    assert_eq!(deltas.len(), 1);

    let head = &deltas[0];
    assert_eq!(head.index, Some(0));
    assert_eq!(head.id.as_deref(), Some("call_1"));

    let function_name = head.function.as_ref().and_then(|f| f.name.as_deref());
    assert_eq!(function_name, Some("shell"));
}
4963
/// Two argument fragments that share an index are concatenated into one tool
/// call. The id and name come from the first fragment.
#[test]
fn stream_tool_call_accumulator_combines_deltas() {
    let opening = StreamToolCallDelta {
        index: Some(0),
        id: Some(String::from("call_1")),
        function: Some(StreamFunctionDelta {
            name: Some(String::from("shell")),
            arguments: Some(String::from("{\"command\":\"")),
        }),
        name: None,
        arguments: None,
        extra_content: None,
    };
    let closing = StreamToolCallDelta {
        index: Some(0),
        id: None,
        function: Some(StreamFunctionDelta {
            name: None,
            arguments: Some(String::from("date\"}")),
        }),
        name: None,
        arguments: None,
        extra_content: None,
    };

    let mut accumulator = StreamToolCallAccumulator::default();
    accumulator.apply_delta(&opening);
    accumulator.apply_delta(&closing);

    let mut seen_ids = std::collections::HashSet::new();
    let call = accumulator
        .into_provider_tool_call(false, &mut seen_ids)
        .expect("accumulator should emit tool call");

    assert_eq!(call.id, "call_1");
    assert_eq!(call.name, "shell");
    assert_eq!(call.arguments, r#"{"command":"date"}"#);
}
4998
/// With the Mistral flag (`true`), an id that is not a 9-character alphanumeric
/// string is replaced by a generated one of that shape.
#[test]
fn stream_tool_call_accumulator_mistral_normalizes_invalid_id() {
    let mut accumulator = StreamToolCallAccumulator::default();
    let fragment = StreamToolCallDelta {
        index: Some(0),
        id: Some(String::from("chatcmpl-tool-abc")),
        function: Some(StreamFunctionDelta {
            name: Some(String::from("shell")),
            arguments: Some(String::from(r#"{"command":"date"}"#)),
        }),
        name: None,
        arguments: None,
        extra_content: None,
    };
    accumulator.apply_delta(&fragment);

    let mut seen_ids = std::collections::HashSet::new();
    let call = accumulator
        .into_provider_tool_call(true, &mut seen_ids)
        .expect("accumulator should emit tool call");

    assert_eq!(call.id.len(), 9);
    assert!(call.id.chars().all(|c| c.is_ascii_alphanumeric()));
    assert_ne!(call.id, "chatcmpl-tool-abc");
}
5023
/// Prompt and completion token counts are read from the `usage` object.
#[test]
fn api_response_parses_usage() {
    let payload = r#"{
        "choices": [{"message": {"content": "Hello"}}],
        "usage": {"prompt_tokens": 150, "completion_tokens": 60}
    }"#;
    let parsed = serde_json::from_str::<ApiChatResponse>(payload).unwrap();
    let counts = parsed.usage.unwrap();
    assert_eq!(counts.prompt_tokens, Some(150));
    assert_eq!(counts.completion_tokens, Some(60));
}
5035
/// A missing `usage` object deserializes as `None` rather than failing.
#[test]
fn api_response_parses_without_usage() {
    let parsed =
        serde_json::from_str::<ApiChatResponse>(r#"{"choices": [{"message": {"content": "Hello"}}]}"#)
            .unwrap();
    assert!(parsed.usage.is_none());
}
5042
/// `parse_native_response` carries the reasoning text, the visible text and
/// the tool calls through to its result.
#[test]
fn parse_native_response_captures_reasoning_content() {
    let provider = make_model_provider("test", "https://example.com", None);

    let shell_call = ToolCall {
        id: Some(String::from("call_1")),
        kind: Some(String::from("function")),
        function: Some(Function {
            name: Some(String::from("shell")),
            arguments: Some(String::from(r#"{"cmd":"ls"}"#)),
        }),
        name: None,
        arguments: None,
        parameters: None,
        extra_content: None,
    };
    let message = ResponseMessage {
        content: Some(String::from("answer")),
        reasoning_content: Some(String::from("thinking step")),
        tool_calls: Some(vec![shell_call]),
    };

    let parsed = provider.parse_native_response(message);
    assert_eq!(parsed.reasoning_content.as_deref(), Some("thinking step"));
    assert_eq!(parsed.text.as_deref(), Some("answer"));
    assert_eq!(parsed.tool_calls.len(), 1);
}
5072
/// Without reasoning on the input, the parsed response carries none.
#[test]
fn parse_native_response_none_reasoning_content_for_normal_model() {
    let provider = make_model_provider("test", "https://example.com", None);
    let parsed = provider.parse_native_response(ResponseMessage {
        content: Some(String::from("hello")),
        reasoning_content: None,
        tool_calls: None,
    });
    assert!(parsed.reasoning_content.is_none());
    assert_eq!(parsed.text.as_deref(), Some("hello"));
}
5086
/// An assistant history entry with `tool_calls` and `reasoning_content`
/// keeps both after native conversion.
#[test]
fn convert_messages_for_native_round_trips_reasoning_content() {
    let turn = serde_json::json!({
        "content": "I will check",
        "tool_calls": [{
            "id": "tc_1",
            "name": "shell",
            "arguments": "{\"cmd\":\"ls\"}"
        }],
        "reasoning_content": "Let me think about this..."
    })
    .to_string();

    let history = vec![ChatMessage::assistant(turn)];
    let provider = make_model_provider("test", "https://example.com", None);
    let converted = provider.convert_messages_for_native(&history, true);

    assert_eq!(converted.len(), 1);
    let only = &converted[0];
    assert_eq!(only.role, "assistant");
    assert_eq!(
        only.reasoning_content.as_deref(),
        Some("Let me think about this...")
    );
    assert!(only.tool_calls.is_some());
}
5111
/// A tool-call history entry without reasoning converts with no reasoning attached.
#[test]
fn convert_messages_for_native_no_reasoning_content_when_absent() {
    let turn = serde_json::json!({
        "content": "I will check",
        "tool_calls": [{
            "id": "tc_1",
            "name": "shell",
            "arguments": "{\"cmd\":\"ls\"}"
        }]
    })
    .to_string();

    let history = vec![ChatMessage::assistant(turn)];
    let provider = make_model_provider("test", "https://example.com", None);
    let converted = provider.convert_messages_for_native(&history, true);

    assert_eq!(converted.len(), 1);
    assert!(converted[0].reasoning_content.is_none());
}
5130
/// A plain-text assistant turn still carries its `reasoning_content`, and its
/// visible text becomes the native text content.
#[test]
fn convert_messages_for_native_round_trips_reasoning_content_without_tool_calls() {
    let turn = serde_json::json!({
        "content": "Direct answer.",
        "reasoning_content": "Let me think step by step..."
    })
    .to_string();

    let history = vec![ChatMessage::assistant(turn)];
    let provider = make_model_provider("test", "https://example.com", None);
    let converted = provider.convert_messages_for_native(&history, true);

    assert_eq!(converted.len(), 1);
    let only = &converted[0];
    assert_eq!(only.role, "assistant");
    assert!(
        only.tool_calls.is_none(),
        "no tool_calls on a plain-text turn"
    );
    assert_eq!(
        only.reasoning_content.as_deref(),
        Some("Let me think step by step...")
    );
    match &only.content {
        Some(MessageContent::Text(text)) => assert_eq!(text, "Direct answer."),
        other => panic!("expected text content, got {other:?}"),
    }
}
5162
/// JSON that has only a `content` key is not treated as structured history.
/// The raw string is passed through verbatim.
#[test]
fn convert_messages_for_native_content_only_json_falls_through() {
    let raw_json = serde_json::json!({"content": "raw"}).to_string();
    let history = vec![ChatMessage::assistant(raw_json.clone())];
    let provider = make_model_provider("test", "https://example.com", None);
    let converted = provider.convert_messages_for_native(&history, true);

    assert_eq!(converted.len(), 1);
    let only = &converted[0];
    assert!(only.reasoning_content.is_none());
    assert!(only.tool_calls.is_none());
    match &only.content {
        Some(MessageContent::Text(text)) => assert_eq!(text.as_str(), raw_json.as_str()),
        other => panic!("expected text content from fallback, got {other:?}"),
    }
}
5180
/// A non-string (`null`) `reasoning_content` does not count as structured
/// history, so the raw JSON text is kept verbatim.
#[test]
fn convert_messages_for_native_non_string_reasoning_content_falls_through() {
    let raw_json = serde_json::json!({
        "content": "raw",
        "reasoning_content": null
    })
    .to_string();
    let history = vec![ChatMessage::assistant(raw_json.clone())];
    let provider = make_model_provider("test", "https://example.com", None);
    let converted = provider.convert_messages_for_native(&history, true);

    assert_eq!(converted.len(), 1);
    let only = &converted[0];
    assert!(only.reasoning_content.is_none());
    assert!(only.tool_calls.is_none());
    match &only.content {
        Some(MessageContent::Text(text)) => assert_eq!(text.as_str(), raw_json.as_str()),
        other => panic!("expected text content from fallback, got {other:?}"),
    }
}
5201
/// JSON with none of the recognised keys is passed through as raw text.
#[test]
fn convert_messages_for_native_unrelated_json_falls_through() {
    let raw = serde_json::json!({"foo": "bar"}).to_string();
    let history = vec![ChatMessage::assistant(raw)];
    let provider = make_model_provider("test", "https://example.com", None);
    let converted = provider.convert_messages_for_native(&history, true);

    assert_eq!(converted.len(), 1);
    let only = &converted[0];
    assert!(only.reasoning_content.is_none());
    assert!(only.tool_calls.is_none());
    match &only.content {
        Some(MessageContent::Text(t)) => {
            let keeps_raw_key = t.contains("\"foo\"");
            assert!(
                keeps_raw_key,
                "expected raw JSON in fallback content, got {t:?}"
            );
        }
        other => panic!("expected text content from fallback, got {other:?}"),
    }
}
5225
/// `NativeMessage` omits the `reasoning_content` key when it is `None` and
/// emits it when it is `Some`.
#[test]
fn convert_messages_for_native_reasoning_content_serialized_only_when_present() {
    let bare = NativeMessage {
        role: String::from("assistant"),
        content: Some(MessageContent::Text(String::from("hi"))),
        tool_call_id: None,
        tool_calls: None,
        reasoning_content: None,
    };
    let bare_wire = serde_json::to_string(&bare).unwrap();
    assert!(
        !bare_wire.contains("reasoning_content"),
        "reasoning_content should be omitted when None"
    );

    // Same message, differing only in reasoning_content (remaining fields moved from `bare`).
    let with_reasoning = NativeMessage {
        reasoning_content: Some(String::from("thinking...")),
        ..bare
    };
    let full_wire = serde_json::to_string(&with_reasoning).unwrap();
    assert!(
        full_wire.contains("reasoning_content"),
        "reasoning_content should be present when Some"
    );
    assert!(full_wire.contains("thinking..."));
}
5256
/// A freshly built provider uses a 120-second request timeout.
#[test]
fn default_timeout_is_120s() {
    let provider = make_model_provider("test", "https://example.com", None);
    assert_eq!(provider.timeout_secs, 120);
}
5262
/// `with_timeout_secs` replaces the default timeout.
#[test]
fn with_timeout_secs_overrides_default() {
    let provider = make_model_provider("test", "https://example.com", None);
    let provider = provider.with_timeout_secs(300);
    assert_eq!(provider.timeout_secs, 300);
}
5268
/// A freshly built provider carries no extra headers.
#[test]
fn extra_headers_default_empty() {
    let provider = make_model_provider("test", "https://example.com", None);
    assert!(provider.extra_headers.is_empty());
}
5274
/// `with_extra_headers` stores each supplied header and its value.
#[test]
fn with_extra_headers_sets_headers() {
    let headers = std::collections::HashMap::from([
        ("X-Title".to_string(), "zeroclaw".to_string()),
        (
            "HTTP-Referer".to_string(),
            "https://example.com".to_string(),
        ),
    ]);
    let provider =
        make_model_provider("test", "https://example.com", None).with_extra_headers(headers);

    let stored = &provider.extra_headers;
    assert_eq!(stored.len(), 2);
    assert_eq!(stored.get("X-Title").unwrap(), "zeroclaw");
    assert_eq!(stored.get("HTTP-Referer").unwrap(), "https://example.com");
}
5292
/// Building the HTTP client succeeds when extra headers, including a
/// `User-Agent`, are configured.
#[test]
fn http_client_with_extra_headers_builds_successfully() {
    let headers = std::collections::HashMap::from([
        ("X-Title".to_string(), "zeroclaw".to_string()),
        ("User-Agent".to_string(), "TestAgent/1.0".to_string()),
    ]);
    let provider =
        make_model_provider("test", "https://example.com", None).with_extra_headers(headers);
    let _built = provider.http_client();
}
5303
/// Building the HTTP client succeeds with default settings.
#[test]
fn http_client_without_extra_headers_or_user_agent() {
    let provider = make_model_provider("test", "https://example.com", None);
    let _built = provider.http_client();
}
5310
/// A custom user agent and extra headers coexist, and the client still builds.
#[test]
fn extra_headers_combined_with_user_agent() {
    let headers =
        std::collections::HashMap::from([("X-Title".to_string(), "zeroclaw".to_string())]);
    let provider = OpenAiCompatibleModelProvider::new_with_user_agent(
        "test",
        "test",
        "https://example.com",
        None,
        AuthStyle::Bearer,
        "CustomAgent/1.0",
    )
    .with_extra_headers(headers);

    assert_eq!(provider.user_agent.as_deref(), Some("CustomAgent/1.0"));
    assert_eq!(provider.extra_headers.len(), 1);
    let _built = provider.http_client();
}
5329
/// For a standard OpenAI-shaped tool call, the `None` compatibility fields are
/// left out of the JSON and the standard fields are present.
#[test]
fn tool_call_none_fields_omitted_from_json() {
    let call = ToolCall {
        id: Some(String::from("call_1")),
        kind: Some(String::from("function")),
        function: Some(Function {
            name: Some(String::from("shell")),
            arguments: Some(String::from("{\"command\":\"ls\"}")),
        }),
        name: None,
        arguments: None,
        parameters: None,
        extra_content: None,
    };

    let serialized = serde_json::to_value(&call).unwrap();
    let fields = serialized.as_object().unwrap();
    for absent in ["name", "arguments", "parameters"] {
        assert!(!fields.contains_key(absent));
    }
    // `kind` is serialized under the wire name `type`.
    for present in ["id", "type", "function"] {
        assert!(fields.contains_key(present));
    }
}
5355
/// A tool call that uses only the flat compatibility fields (`name`,
/// `arguments`) serializes just those fields.
#[test]
fn tool_call_with_compat_fields_serializes_them() {
    let call = ToolCall {
        id: None,
        kind: None,
        function: None,
        name: Some(String::from("shell")),
        arguments: Some(String::from("{\"command\":\"ls\"}")),
        parameters: None,
        extra_content: None,
    };

    let serialized = serde_json::to_value(&call).unwrap();
    assert_eq!(serialized["name"], "shell");
    assert_eq!(serialized["arguments"], "{\"command\":\"ls\"}");

    let fields = serialized.as_object().unwrap();
    for absent in ["id", "type", "function", "parameters"] {
        assert!(!fields.contains_key(absent));
    }
}
5377
/// An `x_tool_start` frame becomes a pre-executed tool-call event that keeps
/// the name and the decoded arguments string.
#[test]
fn proxy_tool_start_valid() {
    let event = parse_proxy_tool_event(
        r#"data: {"x_tool_start":{"name":"bash","arguments":"{\"cmd\":\"ls\"}"}}"#,
    );
    let is_expected = matches!(
        event,
        Some(StreamEvent::PreExecutedToolCall { ref name, ref args })
            if name == "bash" && args == r#"{"cmd":"ls"}"#
    );
    assert!(is_expected);
}
5390
/// An `x_tool_start` frame without a tool name yields no event.
#[test]
fn proxy_tool_start_missing_name_returns_none() {
    let event = parse_proxy_tool_event(r#"data: {"x_tool_start":{"arguments":"{}"}}"#);
    assert!(event.is_none());
}
5396
/// Missing arguments on an `x_tool_start` frame default to `"{}"`.
#[test]
fn proxy_tool_start_missing_arguments_defaults() {
    let event = parse_proxy_tool_event(r#"data: {"x_tool_start":{"name":"read"}}"#);
    let is_expected = matches!(
        event,
        Some(StreamEvent::PreExecutedToolCall { ref name, ref args })
            if name == "read" && args == "{}"
    );
    assert!(is_expected);
}
5407
/// An `x_tool_result` frame becomes a pre-executed tool-result event.
#[test]
fn proxy_tool_result_valid() {
    let event =
        parse_proxy_tool_event(r#"data: {"x_tool_result":{"name":"bash","output":"hello world"}}"#);
    let is_expected = matches!(
        event,
        Some(StreamEvent::PreExecutedToolResult { ref name, ref output })
            if name == "bash" && output == "hello world"
    );
    assert!(is_expected);
}
5418
/// An empty `x_tool_result` object falls back to name `"unknown"` and empty output.
#[test]
fn proxy_tool_result_missing_fields_uses_defaults() {
    let event = parse_proxy_tool_event(r#"data: {"x_tool_result":{}}"#);
    let is_expected = matches!(
        event,
        Some(StreamEvent::PreExecutedToolResult { ref name, ref output })
            if name == "unknown" && output.is_empty()
    );
    assert!(is_expected);
}
5429
/// A `data:` payload that is not JSON yields no proxy event.
#[test]
fn proxy_tool_event_non_json_returns_none() {
    let event = parse_proxy_tool_event("data: not json");
    assert!(event.is_none());
}
5434
/// Lines without the SSE `data: ` prefix are ignored.
#[test]
fn proxy_tool_event_no_data_prefix_returns_none() {
    let event = parse_proxy_tool_event(r#"{"x_tool_start":{"name":"bash"}}"#);
    assert!(event.is_none());
}
5440
/// An ordinary OpenAI completion chunk is not a proxy tool event.
#[test]
fn proxy_tool_event_standard_openai_chunk_returns_none() {
    let event = parse_proxy_tool_event(
        r#"data: {"id":"chatcmpl-1","choices":[{"delta":{"content":"hi"}}]}"#,
    );
    assert!(event.is_none());
}
5446
/// The `[DONE]` terminator is not a proxy tool event.
#[test]
fn proxy_tool_event_done_sentinel_returns_none() {
    let event = parse_proxy_tool_event("data: [DONE]");
    assert!(event.is_none());
}
5451
/// Stripping native tool messages must not leave two consecutive messages with
/// the same role. The pre-tool narration and the final reply are merged into
/// one assistant turn.
#[test]
fn strip_native_tool_messages_coalesces_adjacent_assistants() {
    let history = vec![
        ChatMessage::user("search for cats"),
        ChatMessage::assistant(
            r#"{"content":"I'll search","tool_calls":[{"id":"t1","name":"web_search","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(r#"{"tool_call_id":"t1","content":"Found 10 results"}"#),
        ChatMessage::assistant("Here are the results about cats"),
    ];
    let provider = OpenAiCompatibleModelProvider::new_merge_system_into_user(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );

    let stripped = provider.strip_native_tool_messages(&history);
    let roles = stripped
        .iter()
        .map(|m| m.role.as_str())
        .collect::<Vec<&str>>();

    let has_repeated_role = roles.windows(2).any(|pair| pair[0] == pair[1]);
    assert!(
        !has_repeated_role,
        "no two consecutive messages should share a role; got {roles:?}"
    );
    assert_eq!(roles, vec!["user", "assistant"]);
    assert_eq!(stripped[0].content, "search for cats");

    let merged = &stripped[1].content;
    assert!(
        merged.contains("I'll search") && merged.contains("Here are the results about cats"),
        "merged assistant should preserve both the pre-tool narration and the final reply; \
         got {:?}",
        merged
    );
}
5496
/// An empty pre-tool narration adds no stray text to the merged assistant
/// turn. Only the final reply remains.
#[test]
fn strip_native_tool_messages_drops_empty_narration_cleanly() {
    let history = vec![
        ChatMessage::user("search for cats"),
        ChatMessage::assistant(
            r#"{"content":"","tool_calls":[{"id":"t1","name":"web_search","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(r#"{"tool_call_id":"t1","content":"Found"}"#),
        ChatMessage::assistant("Here are the results"),
    ];
    let provider = OpenAiCompatibleModelProvider::new_merge_system_into_user(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );

    let stripped = provider.strip_native_tool_messages(&history);
    let roles: Vec<&str> = stripped.iter().map(|m| m.role.as_str()).collect();
    assert_eq!(roles, vec!["user", "assistant"]);
    assert_eq!(stripped[1].content, "Here are the results");
}
5525
5526 #[tokio::test]
5530 async fn dropping_stream_aborts_forwarder_and_closes_upstream_socket() {
5531 use axum::Router;
5532 use axum::response::IntoResponse;
5533 use axum::routing::post;
5534 use futures_util::StreamExt as _;
5535 use std::sync::Arc;
5536 use std::sync::atomic::{AtomicBool, Ordering};
5537 use tokio::net::TcpListener;
5538
5539 let handler_dropped = Arc::new(AtomicBool::new(false));
5540 let handler_dropped_for_route = Arc::clone(&handler_dropped);
5541
5542 let app = Router::new().route(
5543 "/chat/completions",
5544 post(move || {
5545 let dropped = Arc::clone(&handler_dropped_for_route);
5546 async move {
5547 let sentinel = scopeguard::guard((), move |()| {
5548 dropped.store(true, Ordering::SeqCst);
5549 });
5550 let first = futures_util::stream::once(async {
5551 Ok::<_, std::convert::Infallible>(axum::body::Bytes::from(
5552 "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n",
5553 ))
5554 });
5555 let park = futures_util::stream::poll_fn(move |_cx| {
5556 let _ = &sentinel;
5557 std::task::Poll::Pending
5558 });
5559 axum::body::Body::from_stream(first.chain(park)).into_response()
5560 }
5561 }),
5562 );
5563
5564 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
5565 let addr = listener.local_addr().unwrap();
5566 let server = ::zeroclaw_spawn::spawn!(async move {
5567 axum::serve(listener, app).await.unwrap();
5568 });
5569
5570 let provider = OpenAiCompatibleModelProvider::new(
5571 "test",
5572 "test",
5573 &format!("http://{addr}"),
5574 Some("k"),
5575 AuthStyle::Bearer,
5576 );
5577
5578 let mut stream = provider.stream_chat(
5579 crate::traits::ChatRequest {
5580 messages: &[ChatMessage::user("hi")],
5581 tools: None,
5582 thinking: None,
5583 },
5584 "gpt-test",
5585 Some(0.0),
5586 StreamOptions {
5587 enabled: true,
5588 count_tokens: false,
5589 },
5590 );
5591
5592 let first = stream.next().await;
5593 assert!(first.is_some(), "expected at least the first SSE chunk");
5594
5595 drop(stream);
5596
5597 let observed = tokio::time::timeout(std::time::Duration::from_secs(5), async {
5598 loop {
5599 if handler_dropped.load(Ordering::SeqCst) {
5600 return;
5601 }
5602 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
5603 }
5604 })
5605 .await;
5606
5607 server.abort();
5608 assert!(
5609 observed.is_ok(),
5610 "dropped stream must abort the forwarder and close the upstream socket"
5611 );
5612 }
5613
5614 fn minimal_request(temperature: Option<f64>) -> ApiChatRequest {
5615 ApiChatRequest {
5616 model: "any-model".to_string(),
5617 messages: vec![Message {
5618 role: "user".to_string(),
5619 content: MessageContent::Text("hi".to_string()),
5620 }],
5621 temperature,
5622 stream: None,
5623 stream_options: None,
5624 reasoning_effort: None,
5625 tool_stream: None,
5626 tools: None,
5627 tool_choice: None,
5628 max_tokens: None,
5629 }
5630 }
5631
/// A `None` temperature leaves the key out of the JSON body entirely, rather
/// than sending `null`.
#[test]
fn unset_temperature_is_omitted_from_wire() {
    let body = serde_json::to_value(minimal_request(None)).unwrap();
    let has_temperature = body.get("temperature").is_some();
    assert!(
        !has_temperature,
        "unset temperature must be omitted from the request body, got: {body}"
    );
}
5643
/// An explicit temperature is serialized with its value unchanged.
#[test]
fn explicit_temperature_is_sent_on_wire() {
    let body = serde_json::to_value(minimal_request(Some(0.5))).unwrap();
    let sent = body.get("temperature").and_then(serde_json::Value::as_f64);
    assert_eq!(
        sent,
        Some(0.5),
        "explicit temperature must be sent verbatim, got: {body}"
    );
}
5653
/// Converting a models.dev id list keeps the ids in order and attaches no pricing.
#[test]
fn models_dev_to_model_info_returns_no_pricing() {
    let ids: Vec<String> = ["openai/gpt-4o", "anthropic/claude-sonnet-4-6"]
        .iter()
        .map(|id| id.to_string())
        .collect();

    let models = models_dev_to_model_info(ids);

    assert_eq!(models.len(), 2);
    let (first, second) = (&models[0], &models[1]);
    assert_eq!(first.id, "openai/gpt-4o");
    assert!(first.pricing.is_none());
    assert_eq!(second.id, "anthropic/claude-sonnet-4-6");
    assert!(second.pricing.is_none());
}
5670
/// Public model listing is off unless explicitly enabled.
#[test]
fn public_model_listing_flag_defaults_false() {
    let provider = make_model_provider("test", "https://example.com", None);
    let enabled = provider.public_model_listing;
    assert!(!enabled);
}
5678
/// `with_public_model_listing` switches the flag on.
#[test]
fn public_model_listing_flag_can_be_set() {
    let provider = make_model_provider("test", "https://example.com", None);
    let provider = provider.with_public_model_listing();
    assert!(provider.public_model_listing);
}
5686}