1use crate::multimodal;
6use crate::traits::{
7 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
8 ModelProvider, StreamChunk, StreamError, StreamEvent, StreamOptions, StreamResult, TokenUsage,
9 ToolCall as ProviderToolCall,
10};
11use async_trait::async_trait;
12use futures_util::{StreamExt, stream};
13use reqwest::{
14 Client,
15 header::{HeaderMap, HeaderValue, USER_AGENT},
16};
17use serde::{Deserialize, Serialize};
18
/// A chat model provider that talks to an OpenAI-compatible HTTP API.
///
/// Instances are created through the `new*` constructors and then adjusted with
/// the `with_*` / `without_*` builder methods.
// Several independent feature flags live here (vision, system-message merging,
// native tool calling, local tool sanitizing, unauthenticated model listing),
// which is what trips `clippy::struct_excessive_bools`.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone)]
pub struct OpenAiCompatibleModelProvider {
    /// Alias the provider was created with (not read in this part of the file — confirm usage).
    pub alias: String,
    /// Provider name; the names `"zai"` / `"z.ai"` force `tool_stream` when tools are sent.
    pub name: String,
    /// API base URL; the constructors strip any trailing `/`.
    pub base_url: String,
    /// API credential, if any (presumably attached via `apply_auth_to_request` — confirm at call site).
    pub credential: Option<String>,
    /// How the credential is attached to outgoing requests.
    pub auth_header: AuthStyle,
    /// Vision capability flag (presumably gates image inputs — not read in this part of the file).
    supports_vision: bool,
    /// Optional `User-Agent` installed as a default header on the HTTP clients.
    user_agent: Option<String>,
    /// When set, system text is folded into the first user message instead of a `system` role.
    merge_system_into_user: bool,
    /// Native tool-calling flag; the constructors default it to `!merge_system_into_user`.
    native_tool_calling: bool,
    /// Total timeout, in seconds, for the non-streaming client (constructor default: 120).
    timeout_secs: u64,
    /// Extra default headers; entries that are not valid HTTP headers are skipped with a warning.
    extra_headers: std::collections::HashMap<String, String>,
    /// Requested reasoning effort; only returned by `reasoning_effort_for_model` for recognized models.
    reasoning_effort: Option<String>,
    /// Overrides the `/chat/completions` path appended to `base_url`.
    api_path: Option<String>,
    /// Optional token cap (presumably copied into request `max_tokens` — confirm).
    max_tokens: Option<u32>,
    /// Optional models.dev lookup key (usage not visible here — confirm).
    models_dev_key: Option<String>,
    /// Optional OpenRouter vendor prefix (usage not visible here — confirm).
    openrouter_vendor_prefix: Option<String>,
    /// Flag enabling tool sanitizing for local models (usage not visible here — confirm).
    local_model_tool_sanitize: bool,
    /// Flag allowing model listing without credentials (usage not visible here — confirm).
    unauthenticated_model_listing: bool,
}
73
/// How a provider credential is attached to requests (see `apply_auth_to_request`).
#[derive(Debug, Clone)]
pub enum AuthStyle {
    /// `Authorization: Bearer <credential>`.
    Bearer,
    /// `x-api-key: <credential>`.
    XApiKey,
    /// The credential is sent verbatim under the given header name.
    Custom(String),
    /// A short-lived HS256 JWT derived from an `id.secret` key, sent as
    /// `Authorization: Bearer <jwt>`; if the JWT cannot be built, the raw
    /// credential is sent as a bearer token instead.
    ZhipuJwt,
}
88
89fn zhipu_jwt_bearer(credential: &str) -> Result<String, String> {
92 let (id, secret) = credential
93 .split_once('.')
94 .ok_or_else(|| "Zhipu API key must be in 'id.secret' format".to_string())?;
95
96 #[allow(clippy::cast_possible_truncation)] let now_ms = std::time::SystemTime::now()
98 .duration_since(std::time::UNIX_EPOCH)
99 .map_err(|e| e.to_string())?
100 .as_millis() as u64;
101 let exp_ms = now_ms + 210_000; let header_b64 = base64url_no_pad(br#"{"alg":"HS256","typ":"JWT","sign_type":"SIGN"}"#);
105 let payload = format!(r#"{{"api_key":"{id}","exp":{exp_ms},"timestamp":{now_ms}}}"#);
106 let payload_b64 = base64url_no_pad(payload.as_bytes());
107
108 let signing_input = format!("{header_b64}.{payload_b64}");
109 let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, secret.as_bytes());
110 let sig = ring::hmac::sign(&key, signing_input.as_bytes());
111 let sig_b64 = base64url_no_pad(sig.as_ref());
112
113 Ok(format!("Bearer {signing_input}.{sig_b64}"))
114}
115
116fn base64url_no_pad(data: &[u8]) -> String {
117 use base64::engine::{Engine, general_purpose::URL_SAFE_NO_PAD};
118 URL_SAFE_NO_PAD.encode(data)
119}
120
121fn apply_auth_to_request(
126 req: reqwest::RequestBuilder,
127 style: &AuthStyle,
128 credential: Option<&str>,
129) -> reqwest::RequestBuilder {
130 let credential = match credential {
131 Some(c) => c,
132 None => return req,
133 };
134 match style {
135 AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
136 AuthStyle::XApiKey => req.header("x-api-key", credential),
137 AuthStyle::Custom(header) => req.header(header, credential),
138 AuthStyle::ZhipuJwt => match zhipu_jwt_bearer(credential) {
139 Ok(val) => req.header("Authorization", val),
140 Err(_) => req.header("Authorization", format!("Bearer {credential}")),
141 },
142 }
143}
144
/// Body of a model-listing response, shaped `{"data": [{"id": ...}, ...]}`
/// (presumably the OpenAI-style `/models` endpoint — confirm at call site).
#[derive(Deserialize)]
struct ModelsResponse {
    data: Vec<ModelEntry>,
}
149
/// One entry of a model listing; only the model `id` is read.
#[derive(Deserialize)]
struct ModelEntry {
    id: String,
}
154
155fn normalize_model_ids(body: ModelsResponse) -> Vec<String> {
156 let mut ids: Vec<String> = body
157 .data
158 .into_iter()
159 .map(|e| e.id.trim().to_string())
160 .filter(|id| !id.is_empty())
161 .collect();
162 ids.sort();
163 ids
164}
165
166impl OpenAiCompatibleModelProvider {
167 pub fn new(
168 alias: &str,
169 name: &str,
170 base_url: &str,
171 credential: Option<&str>,
172 auth_style: AuthStyle,
173 ) -> Self {
174 Self::new_with_options(
175 alias, name, base_url, credential, auth_style, false, None, false,
176 )
177 }
178
179 pub fn new_with_vision(
180 alias: &str,
181 name: &str,
182 base_url: &str,
183 credential: Option<&str>,
184 auth_style: AuthStyle,
185 supports_vision: bool,
186 ) -> Self {
187 Self::new_with_options(
188 alias,
189 name,
190 base_url,
191 credential,
192 auth_style,
193 supports_vision,
194 None,
195 false,
196 )
197 }
198
199 pub fn new_with_user_agent(
204 alias: &str,
205 name: &str,
206 base_url: &str,
207 credential: Option<&str>,
208 auth_style: AuthStyle,
209 user_agent: &str,
210 ) -> Self {
211 Self::new_with_options(
212 alias,
213 name,
214 base_url,
215 credential,
216 auth_style,
217 false,
218 Some(user_agent),
219 false,
220 )
221 }
222
223 pub fn new_with_user_agent_and_vision(
224 alias: &str,
225 name: &str,
226 base_url: &str,
227 credential: Option<&str>,
228 auth_style: AuthStyle,
229 user_agent: &str,
230 supports_vision: bool,
231 ) -> Self {
232 Self::new_with_options(
233 alias,
234 name,
235 base_url,
236 credential,
237 auth_style,
238 supports_vision,
239 Some(user_agent),
240 false,
241 )
242 }
243
244 pub fn new_merge_system_into_user(
247 alias: &str,
248 name: &str,
249 base_url: &str,
250 credential: Option<&str>,
251 auth_style: AuthStyle,
252 ) -> Self {
253 Self::new_with_options(
254 alias, name, base_url, credential, auth_style, false, None, true,
255 )
256 }
257
258 fn new_with_options(
259 alias: &str,
260 name: &str,
261 base_url: &str,
262 credential: Option<&str>,
263 auth_style: AuthStyle,
264 supports_vision: bool,
265 user_agent: Option<&str>,
266 merge_system_into_user: bool,
267 ) -> Self {
268 Self {
269 alias: alias.to_string(),
270 name: name.to_string(),
271 base_url: base_url.trim_end_matches('/').to_string(),
272 credential: credential.map(ToString::to_string),
273 auth_header: auth_style,
274 supports_vision,
275 user_agent: user_agent.map(ToString::to_string),
276 merge_system_into_user,
277 native_tool_calling: !merge_system_into_user,
278 timeout_secs: 120,
279 extra_headers: std::collections::HashMap::new(),
280 reasoning_effort: None,
281 api_path: None,
282 max_tokens: None,
283 models_dev_key: None,
284 openrouter_vendor_prefix: None,
285 local_model_tool_sanitize: false,
286 unauthenticated_model_listing: false,
287 }
288 }
289 pub fn with_local_model_tool_sanitize(mut self) -> Self {
296 self.local_model_tool_sanitize = true;
297 self
298 }
299
300 pub fn with_unauthenticated_model_listing(mut self) -> Self {
301 self.unauthenticated_model_listing = true;
302 self
303 }
304
305 pub fn without_native_tools(mut self) -> Self {
307 self.native_tool_calling = false;
308 self
309 }
310
311 pub fn with_merge_system_into_user(mut self) -> Self {
314 self.merge_system_into_user = true;
315 self
316 }
317
318 pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
320 self.timeout_secs = timeout_secs;
321 self
322 }
323
324 pub fn with_extra_headers(
326 mut self,
327 headers: std::collections::HashMap<String, String>,
328 ) -> Self {
329 self.extra_headers = headers;
330 self
331 }
332
333 pub fn with_reasoning_effort(mut self, reasoning_effort: Option<String>) -> Self {
335 self.reasoning_effort = reasoning_effort;
336 self
337 }
338
339 pub fn with_api_path(mut self, api_path: Option<String>) -> Self {
342 self.api_path = api_path;
343 self
344 }
345
346 pub fn with_max_tokens(mut self, max_tokens: Option<u32>) -> Self {
348 self.max_tokens = max_tokens;
349 self
350 }
351
352 pub fn with_models_dev_key(mut self, key: &str) -> Self {
355 self.models_dev_key = Some(key.to_string());
356 self
357 }
358
359 pub fn with_openrouter_vendor_prefix(mut self, prefix: &str) -> Self {
363 self.openrouter_vendor_prefix = Some(prefix.to_string());
364 self
365 }
366
367 fn flatten_system_messages(messages: &[ChatMessage], merge: bool) -> Vec<ChatMessage> {
371 let mut saw_system = false;
372 let mut system_content = String::new();
373 let mut result: Vec<ChatMessage> = Vec::with_capacity(messages.len());
374
375 for message in messages {
376 if message.role == "system" {
377 saw_system = true;
378 if !message.content.is_empty() {
379 if !system_content.is_empty() {
380 system_content.push_str("\n\n");
381 }
382 system_content.push_str(&message.content);
383 }
384 } else {
385 result.push(message.clone());
386 }
387 }
388
389 if !saw_system {
390 return messages.to_vec();
391 }
392
393 if system_content.is_empty() {
394 return result;
395 }
396
397 if !merge {
398 result.insert(0, ChatMessage::system(system_content));
399 return result;
400 }
401
402 if let Some(first_user) = result.iter_mut().find(|m| m.role == "user") {
403 if !system_content.is_empty() {
404 first_user.content = format!("{system_content}\n\n{}", first_user.content);
405 }
406 } else {
407 result.insert(0, ChatMessage::user(&system_content));
409 }
410
411 result
412 }
413
414 fn http_client(&self) -> Client {
415 let timeout = self.timeout_secs;
416 let has_user_agent = self.user_agent.is_some();
417 let has_extra_headers = !self.extra_headers.is_empty();
418
419 if has_user_agent || has_extra_headers {
420 let mut headers = HeaderMap::new();
421 if let Some(ua) = self.user_agent.as_deref()
422 && let Ok(value) = HeaderValue::from_str(ua)
423 {
424 headers.insert(USER_AGENT, value);
425 }
426 for (key, value) in &self.extra_headers {
427 match (
428 reqwest::header::HeaderName::from_bytes(key.as_bytes()),
429 HeaderValue::from_str(value),
430 ) {
431 (Ok(name), Ok(val)) => {
432 headers.insert(name, val);
433 }
434 _ => {
435 ::zeroclaw_log::record!(
436 WARN,
437 ::zeroclaw_log::Event::new(
438 module_path!(),
439 ::zeroclaw_log::Action::Note
440 )
441 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
442 .with_attrs(::serde_json::json!({"header": key})),
443 "Skipping invalid extra header name or value"
444 );
445 }
446 }
447 }
448
449 let builder = Client::builder()
450 .timeout(std::time::Duration::from_secs(timeout))
451 .connect_timeout(std::time::Duration::from_secs(10))
452 .default_headers(headers);
453 let builder = zeroclaw_config::schema::apply_runtime_proxy_to_builder(
454 builder,
455 "model_provider.compatible",
456 );
457
458 return builder.build().unwrap_or_else(|error| {
459 ::zeroclaw_log::record!(
460 WARN,
461 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
462 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
463 .with_attrs(
464 ::serde_json::json!({"error": super::format_error_chain(&error)})
465 ),
466 "Failed to build proxied timeout client with custom headers: "
467 );
468 Client::new()
469 });
470 }
471
472 zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
473 "model_provider.compatible",
474 timeout,
475 10,
476 )
477 }
478
479 fn streaming_http_client(&self) -> Client {
483 let has_user_agent = self.user_agent.is_some();
484 let has_extra_headers = !self.extra_headers.is_empty();
485
486 if has_user_agent || has_extra_headers {
487 let mut headers = HeaderMap::new();
488 if let Some(ua) = self.user_agent.as_deref()
489 && let Ok(value) = HeaderValue::from_str(ua)
490 {
491 headers.insert(USER_AGENT, value);
492 }
493 for (key, value) in &self.extra_headers {
494 match (
495 reqwest::header::HeaderName::from_bytes(key.as_bytes()),
496 HeaderValue::from_str(value),
497 ) {
498 (Ok(name), Ok(val)) => {
499 headers.insert(name, val);
500 }
501 _ => {
502 ::zeroclaw_log::record!(
503 WARN,
504 ::zeroclaw_log::Event::new(
505 module_path!(),
506 ::zeroclaw_log::Action::Note
507 )
508 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
509 .with_attrs(::serde_json::json!({"header": key})),
510 "Skipping invalid extra header name or value"
511 );
512 }
513 }
514 }
515
516 let builder = Client::builder()
517 .connect_timeout(std::time::Duration::from_secs(10))
518 .default_headers(headers);
519 let builder = zeroclaw_config::schema::apply_runtime_proxy_to_builder(
520 builder,
521 "provider.compatible",
522 );
523 return builder.build().unwrap_or_else(|error| {
524 ::zeroclaw_log::record!(
525 WARN,
526 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
527 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
528 .with_attrs(
529 ::serde_json::json!({"error": super::format_error_chain(&error)})
530 ),
531 "Failed to build proxied streaming client with custom headers: "
532 );
533 Client::new()
534 });
535 }
536
537 let builder = Client::builder().connect_timeout(std::time::Duration::from_secs(10));
538 let builder =
539 zeroclaw_config::schema::apply_runtime_proxy_to_builder(builder, "provider.compatible");
540 builder.build().unwrap_or_else(|error| {
541 ::zeroclaw_log::record!(
542 WARN,
543 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
544 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
545 .with_attrs(::serde_json::json!({"error": super::format_error_chain(&error)})),
546 "Failed to build proxied streaming client: "
547 );
548 Client::new()
549 })
550 }
551
552 fn chat_completions_url(&self) -> String {
556 if let Some(ref api_path) = self.api_path {
558 let separator = if api_path.starts_with('/') { "" } else { "/" };
559 return format!("{}{separator}{api_path}", self.base_url);
560 }
561
562 let has_full_endpoint = reqwest::Url::parse(&self.base_url)
563 .map(|url| {
564 url.path()
565 .trim_end_matches('/')
566 .ends_with("/chat/completions")
567 })
568 .unwrap_or_else(|_| {
569 self.base_url
570 .trim_end_matches('/')
571 .ends_with("/chat/completions")
572 });
573
574 if has_full_endpoint {
575 self.base_url.clone()
576 } else {
577 format!("{}/chat/completions", self.base_url)
578 }
579 }
580
581 fn requires_tool_stream(&self) -> bool {
582 let host_requires_tool_stream = reqwest::Url::parse(&self.base_url)
583 .ok()
584 .and_then(|url| url.host_str().map(str::to_ascii_lowercase))
585 .is_some_and(|host| host == "api.z.ai" || host.ends_with(".z.ai"));
586
587 host_requires_tool_stream || matches!(self.name.as_str(), "zai" | "z.ai")
588 }
589
590 fn tool_stream_for_tools(&self, has_tools: bool) -> Option<bool> {
591 if has_tools && self.requires_tool_stream() {
592 Some(true)
593 } else {
594 None
595 }
596 }
597
598 fn model_requires_system_merge(model: &str) -> bool {
602 let id = model
603 .rsplit('/')
604 .next()
605 .unwrap_or(model)
606 .to_ascii_lowercase();
607 id.contains("deepseek-v3") || id.contains("deepseek_v3")
608 }
609
610 fn effective_merge_system(&self, model: &str) -> bool {
613 self.merge_system_into_user || Self::model_requires_system_merge(model)
614 }
615
616 fn reasoning_effort_for_model(&self, model: &str) -> Option<String> {
617 let effort = self.reasoning_effort.as_ref()?;
618 let id = model
619 .rsplit('/')
620 .next()
621 .unwrap_or(model)
622 .to_ascii_lowercase();
623 let is_openai_reasoning_model = id == "o1"
624 || id.starts_with("o1-")
625 || id == "o3"
626 || id.starts_with("o3-")
627 || id == "o4"
628 || id.starts_with("o4-")
629 || id.starts_with("gpt-5");
630 let is_likely_codex_supported = id.contains("codex") && id.starts_with("gpt-");
631
632 (is_openai_reasoning_model || is_likely_codex_supported).then(|| effort.clone())
633 }
634}
635
/// JSON body of a chat-completions request whose messages carry plain text or
/// text/image parts. `None` fields are omitted from the serialized JSON.
#[derive(Debug, Serialize)]
struct ApiChatRequest {
    model: String,
    messages: Vec<Message>,
    temperature: f64,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptionsBody>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    // Set via `tool_stream_for_tools` (presumably — confirm at call site).
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}
656
/// Value of the request's `stream_options` field, e.g. `{"include_usage": true}`.
#[derive(Debug, Serialize, Clone, Copy)]
struct StreamOptionsBody {
    include_usage: bool,
}
665
/// One outgoing chat message: a role plus text or multi-part content.
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: MessageContent,
}
671
/// Message content; `untagged`, so it serializes as a bare string or as an
/// array of parts.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum MessageContent {
    Text(String),
    Parts(Vec<MessagePart>),
}
678
/// One content part, serialized as `{"type":"text","text":...}` or
/// `{"type":"image_url","image_url":{"url":...}}`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum MessagePart {
    Text { text: String },
    ImageUrl { image_url: ImageUrlPart },
}
685
/// Image reference inside an `image_url` part (presumably an http(s) URL or a
/// `data:` URI — confirm against the multimodal helpers).
#[derive(Debug, Serialize)]
struct ImageUrlPart {
    url: String,
}
690
/// Non-streaming chat-completions response; `usage` is `None` when absent.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
    #[serde(default)]
    usage: Option<UsageInfo>,
}
697
/// Token usage reported by the server; either count may be missing.
#[derive(Debug, Deserialize)]
struct UsageInfo {
    #[serde(default)]
    prompt_tokens: Option<u64>,
    #[serde(default)]
    completion_tokens: Option<u64>,
}
705
/// One completion choice of a non-streaming response.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}
710
/// Removes every `<think>...</think>` span from `s` and trims the result.
///
/// An opening `<think>` with no matching close drops everything from that tag
/// to the end of the input. A stray `</think>` without an opening tag is kept
/// as ordinary text.
fn strip_think_tags(s: &str) -> String {
    let mut kept = String::with_capacity(s.len());
    let mut remaining = s;

    while let Some((before, after_open)) = remaining.split_once("<think>") {
        kept.push_str(before);
        match after_open.split_once("</think>") {
            Some((_, after_close)) => remaining = after_close,
            None => return kept.trim().to_string(),
        }
    }

    kept.push_str(remaining);
    kept.trim().to_string()
}
734
735fn openai_assistant_content_plaintext(content: Option<OpenAiAssistantContent>) -> Option<String> {
740 match content? {
741 OpenAiAssistantContent::Text(s) => {
742 if s.is_empty() {
743 None
744 } else {
745 Some(s)
746 }
747 }
748 OpenAiAssistantContent::Parts(parts) => {
749 let mut text = String::new();
750 for part in parts {
751 if part.kind.as_deref() != Some("text") {
752 continue;
753 }
754 let Some(part_text) = part.text.filter(|text| !text.is_empty()) else {
755 continue;
756 };
757 if !text.is_empty() {
758 text.push('\n');
759 }
760 text.push_str(&part_text);
761 }
762
763 if text.is_empty() { None } else { Some(text) }
764 }
765 }
766}
767
/// Assistant `content` as received: either a plain string or a list of parts
/// (`untagged`, so serde tries the string form first).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAiAssistantContent {
    Text(String),
    Parts(Vec<OpenAiAssistantContentPart>),
}
774
/// One received content part; `kind` is the JSON `type` field. Only parts whose
/// type is `"text"` contribute to the flattened text.
#[derive(Debug, Deserialize)]
struct OpenAiAssistantContentPart {
    #[serde(rename = "type")]
    kind: Option<String>,
    text: Option<String>,
}
781
/// Normalized assistant message.
///
/// Deserialization goes through `RawResponseMessage` (`serde(from)`), which
/// flattens `content` to plain text and merges `reasoning` into
/// `reasoning_content`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(from = "RawResponseMessage")]
struct ResponseMessage {
    content: Option<String>,
    reasoning_content: Option<String>,
    tool_calls: Option<Vec<ToolCall>>,
}
796
/// Wire shape of an assistant message; every field may be missing.
///
/// Both `reasoning_content` and `reasoning` are accepted as the reasoning text.
#[derive(Debug, Deserialize)]
struct RawResponseMessage {
    #[serde(default)]
    content: Option<OpenAiAssistantContent>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
815
816impl From<RawResponseMessage> for ResponseMessage {
817 fn from(raw: RawResponseMessage) -> Self {
818 let reasoning_content = raw.reasoning_content.or(raw.reasoning);
821 ResponseMessage {
822 content: openai_assistant_content_plaintext(raw.content),
823 reasoning_content,
824 tool_calls: raw.tool_calls,
825 }
826 }
827}
828
829impl ResponseMessage {
830 fn effective_content(&self) -> String {
836 if let Some(content) = self.content.as_ref().filter(|c| !c.is_empty()) {
837 let stripped = strip_think_tags(content);
838 if !stripped.is_empty() {
839 return stripped;
840 }
841 }
842
843 self.reasoning_content
844 .as_ref()
845 .map(|c| strip_think_tags(c))
846 .filter(|c| !c.is_empty())
847 .unwrap_or_default()
848 }
849
850 fn effective_content_optional(&self) -> Option<String> {
851 if let Some(content) = self.content.as_ref().filter(|c| !c.is_empty()) {
852 let stripped = strip_think_tags(content);
853 if !stripped.is_empty() {
854 return Some(stripped);
855 }
856 }
857
858 self.reasoning_content
859 .as_ref()
860 .map(|c| strip_think_tags(c))
861 .filter(|c| !c.is_empty())
862 }
863}
864
/// A tool call as received from (and re-sent to) the API.
///
/// Accepts the standard nested `function` form as well as flat
/// `name` / `arguments` / `parameters` fields; see `function_name` and
/// `function_arguments` for the precedence. `None` fields are omitted when
/// serialized.
#[derive(Debug, Deserialize, Serialize)]
struct ToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type")]
    #[serde(default, skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    function: Option<Function>,

    // Flat fallbacks used when `function` does not carry the value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    arguments: Option<String>,

    // Arguments given as a JSON object instead of an encoded string.
    #[serde(
        rename = "parameters",
        default,
        skip_serializing_if = "Option::is_none"
    )]
    parameters: Option<serde_json::Value>,

    // Opaque provider-specific payload carried through unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    extra_content: Option<serde_json::Value>,
}
893
894impl ToolCall {
895 fn function_name(&self) -> Option<String> {
897 if let Some(ref func) = self.function
899 && let Some(ref name) = func.name
900 {
901 return Some(name.clone());
902 }
903 self.name.clone()
905 }
906
907 fn function_arguments(&self) -> Option<String> {
909 if let Some(ref func) = self.function
911 && let Some(ref args) = func.arguments
912 {
913 return Some(args.clone());
914 }
915 if let Some(ref args) = self.arguments {
917 return Some(args.clone());
918 }
919 if let Some(ref params) = self.parameters {
921 return serde_json::to_string(params).ok();
922 }
923 None
924 }
925}
926
/// Nested `function` object of a tool call; both fields may be missing.
/// `arguments` is presumably a JSON-encoded string, per the OpenAI format.
#[derive(Debug, Deserialize, Serialize)]
struct Function {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
934
/// JSON body of a chat-completions request built from `NativeMessage`s, which
/// can carry tool calls and tool results. `None` fields are omitted.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    model: String,
    messages: Vec<NativeMessage>,
    temperature: f64,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptionsBody>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}
960
/// One message of a native request.
///
/// Content is optional (presumably for assistant turns that only carry
/// `tool_calls`). `tool_call_id` links a tool result to its call.
#[derive(Debug, Serialize)]
struct NativeMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
975
/// One SSE `data:` payload of a streamed completion.
///
/// Missing `choices` deserialize as empty; `usage` is present only on chunks
/// that report it.
#[derive(Debug, Deserialize)]
struct StreamChunkResponse {
    #[serde(default)]
    choices: Vec<StreamChoice>,
    #[serde(default)]
    usage: Option<UsageInfo>,
}
990
/// One choice within a streamed chunk; a missing `delta` becomes an empty delta.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    #[serde(default)]
    delta: StreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}
998
/// Normalized incremental delta.
///
/// Deserialized through `RawStreamDelta`, so `reasoning` is folded into
/// `reasoning_content`.
#[derive(Debug, Default)]
struct StreamDelta {
    content: Option<String>,
    reasoning_content: Option<String>,
    tool_calls: Option<Vec<StreamToolCallDelta>>,
}
1011
/// Wire shape of a streamed delta; accepts both `reasoning_content` and `reasoning`.
#[derive(Debug, Deserialize, Default)]
struct RawStreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCallDelta>>,
}
1028
1029impl<'de> Deserialize<'de> for StreamDelta {
1030 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1031 where
1032 D: serde::Deserializer<'de>,
1033 {
1034 let raw = RawStreamDelta::deserialize(deserializer)?;
1035 Ok(StreamDelta {
1036 content: raw.content,
1037 reasoning_content: raw.reasoning_content.or(raw.reasoning),
1038 tool_calls: raw.tool_calls,
1039 })
1040 }
1041}
1042
/// Streamed fragment of a tool call.
///
/// Fragments are presumably grouped by `index` across chunks (confirm at the
/// accumulating call site). Name and arguments may arrive nested in `function`
/// or as flat fields.
#[derive(Debug, Deserialize)]
struct StreamToolCallDelta {
    #[serde(default)]
    index: Option<usize>,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<StreamFunctionDelta>,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
    #[serde(default)]
    extra_content: Option<serde_json::Value>,
}
1059
/// Nested `function` fragment of a streamed tool call.
#[derive(Debug, Deserialize)]
struct StreamFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
1067
/// Builds one tool call from its streamed fragments.
///
/// `arguments` is the concatenation of every non-empty argument fragment.
/// `id`, `name` and `extra_content` keep the last non-empty (or present) value.
#[derive(Debug, Default)]
struct StreamToolCallAccumulator {
    id: Option<String>,
    name: Option<String>,
    arguments: String,
    extra_content: Option<serde_json::Value>,
}
1075
1076impl StreamToolCallAccumulator {
1077 fn apply_delta(&mut self, delta: &StreamToolCallDelta) {
1078 if let Some(id) = delta.id.as_ref().filter(|value| !value.is_empty()) {
1079 self.id = Some(id.clone());
1080 }
1081
1082 let delta_name = delta
1083 .function
1084 .as_ref()
1085 .and_then(|function| function.name.as_ref())
1086 .or(delta.name.as_ref())
1087 .filter(|value| !value.is_empty());
1088 if let Some(name) = delta_name {
1089 self.name = Some(name.clone());
1090 }
1091
1092 if let Some(arguments_delta) = delta
1093 .function
1094 .as_ref()
1095 .and_then(|function| function.arguments.as_ref())
1096 .or(delta.arguments.as_ref())
1097 .filter(|value| !value.is_empty())
1098 {
1099 self.arguments.push_str(arguments_delta);
1100 }
1101
1102 if let Some(extra) = delta.extra_content.as_ref() {
1104 self.extra_content = Some(extra.clone());
1105 }
1106 }
1107
1108 fn into_provider_tool_call(
1109 self,
1110 targets_mistral_tool_call_contract: bool,
1111 used_tool_call_ids: &mut std::collections::HashSet<String>,
1112 ) -> Option<ProviderToolCall> {
1113 let name = self.name?;
1114 let arguments = if self.arguments.trim().is_empty() {
1115 "{}".to_string()
1116 } else {
1117 self.arguments
1118 };
1119 let normalized_arguments = if serde_json::from_str::<serde_json::Value>(&arguments).is_ok()
1120 {
1121 arguments
1122 } else {
1123 ::zeroclaw_log::record!(
1124 WARN,
1125 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1126 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
1127 .with_attrs(::serde_json::json!({"function": name, "arguments": arguments})),
1128 "Invalid JSON in streamed native tool-call arguments, using empty object"
1129 );
1130 "{}".to_string()
1131 };
1132
1133 Some(ProviderToolCall {
1134 id: reserve_tool_call_id_for_contract(
1135 targets_mistral_tool_call_contract,
1136 self.id,
1137 used_tool_call_ids,
1138 ),
1139 name,
1140 arguments: normalized_arguments,
1141 extra_content: self.extra_content,
1142 })
1143 }
1144}
1145
1146fn parse_sse_chunk(line: &str) -> StreamResult<Option<StreamChunkResponse>> {
1147 let line = line.trim();
1148
1149 if line.is_empty() || line.starts_with(':') {
1150 return Ok(None);
1151 }
1152
1153 let Some(data) = line.strip_prefix("data:") else {
1154 return Ok(None);
1155 };
1156 let data = data.trim();
1157
1158 if data == "[DONE]" {
1159 return Ok(None);
1160 }
1161
1162 serde_json::from_str(data)
1163 .map(Some)
1164 .map_err(StreamError::Json)
1165}
1166
1167fn parse_proxy_tool_event(line: &str) -> Option<StreamEvent> {
1171 let data = line.trim().strip_prefix("data:")?.trim();
1172 let obj: serde_json::Value = serde_json::from_str(data).ok()?;
1173
1174 if let Some(ts) = obj.get("x_tool_start") {
1175 let Some(name) = ts.get("name").and_then(|v| v.as_str()) else {
1176 ::zeroclaw_log::record!(
1177 DEBUG,
1178 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
1179 "proxy x_tool_start event missing required 'name' field"
1180 );
1181 return None;
1182 };
1183 let name = name.to_string();
1184 let args = ts
1185 .get("arguments")
1186 .and_then(|v| v.as_str())
1187 .unwrap_or("{}")
1188 .to_string();
1189 return Some(StreamEvent::PreExecutedToolCall { name, args });
1190 }
1191
1192 if let Some(tr) = obj.get("x_tool_result") {
1193 let name = tr
1194 .get("name")
1195 .and_then(|v| v.as_str())
1196 .unwrap_or("unknown")
1197 .to_string();
1198 let output = tr
1199 .get("output")
1200 .and_then(|v| v.as_str())
1201 .unwrap_or("")
1202 .to_string();
1203 return Some(StreamEvent::PreExecutedToolResult { name, output });
1204 }
1205
1206 None
1207}
1208
1209fn extract_sse_text_delta(choice: &StreamChoice) -> Option<String> {
1210 if let Some(content) = &choice.delta.content
1211 && !content.is_empty()
1212 {
1213 return Some(content.clone());
1214 }
1215
1216 None
1217}
1218
1219fn extract_sse_reasoning_delta(choice: &StreamChoice) -> Option<String> {
1220 choice
1221 .delta
1222 .reasoning_content
1223 .as_ref()
1224 .filter(|value| !value.is_empty())
1225 .cloned()
1226}
1227
/// Checks the Mistral tool-call id shape: exactly nine ASCII alphanumerics.
fn is_valid_mistral_tool_call_id(id: &str) -> bool {
    // Every byte must be ASCII alphanumeric, so the byte count equals the char count.
    id.len() == 9 && id.bytes().all(|byte| byte.is_ascii_alphanumeric())
}
1231
1232fn reserve_tool_call_id_for_contract(
1233 targets_mistral_tool_call_contract: bool,
1234 raw_id: Option<String>,
1235 used_ids: &mut std::collections::HashSet<String>,
1236) -> String {
1237 if !targets_mistral_tool_call_contract {
1238 let id = raw_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
1239 if used_ids.insert(id.clone()) {
1240 return id;
1241 }
1242
1243 loop {
1244 let candidate = uuid::Uuid::new_v4().to_string();
1245 if used_ids.insert(candidate.clone()) {
1246 return candidate;
1247 }
1248 }
1249 }
1250
1251 if let Some(id) = raw_id.as_deref()
1252 && is_valid_mistral_tool_call_id(id)
1253 && used_ids.insert(id.to_string())
1254 {
1255 return id.to_string();
1256 }
1257
1258 let mut candidate = raw_id
1259 .as_deref()
1260 .unwrap_or_default()
1261 .chars()
1262 .filter(|c| c.is_ascii_alphanumeric())
1263 .take(9)
1264 .collect::<String>();
1265
1266 if candidate.len() < 9 {
1267 candidate.extend(
1268 uuid::Uuid::new_v4()
1269 .as_simple()
1270 .to_string()
1271 .chars()
1272 .take(9 - candidate.len()),
1273 );
1274 }
1275
1276 if used_ids.insert(candidate.clone()) {
1277 return candidate;
1278 }
1279
1280 loop {
1281 let generated = uuid::Uuid::new_v4()
1282 .as_simple()
1283 .to_string()
1284 .chars()
1285 .take(9)
1286 .collect::<String>();
1287 if used_ids.insert(generated.clone()) {
1288 return generated;
1289 }
1290 }
1291}
1292
1293fn parse_sse_line(line: &str) -> StreamResult<Option<StreamChunk>> {
1300 let chunk = match parse_sse_chunk(line)? {
1301 Some(c) => c,
1302 None => return Ok(None),
1303 };
1304
1305 if let Some(choice) = chunk.choices.first() {
1306 if let Some(content) = &choice.delta.content
1307 && !content.is_empty()
1308 {
1309 return Ok(Some(StreamChunk::delta(content.clone())));
1310 }
1311 if let Some(reasoning) = &choice.delta.reasoning_content
1312 && !reasoning.is_empty()
1313 {
1314 return Ok(Some(StreamChunk::reasoning(reasoning.clone())));
1315 }
1316 }
1317
1318 Ok(None)
1319}
1320
1321fn sse_bytes_to_chunks(
1323 response: reqwest::Response,
1324 count_tokens: bool,
1325) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1326 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1327
1328 tokio::spawn(async move {
1329 let mut buffer = String::new();
1330
1331 match response.error_for_status_ref() {
1332 Ok(_) => {}
1333 Err(e) => {
1334 let _ = tx
1335 .send(Err(StreamError::Http(super::format_error_chain(&e))))
1336 .await;
1337 return;
1338 }
1339 }
1340
1341 let mut bytes_stream = response.bytes_stream();
1342 let mut utf8_buf: Vec<u8> = Vec::new();
1345
1346 while let Some(item) = bytes_stream.next().await {
1347 match item {
1348 Ok(bytes) => {
1349 utf8_buf.extend_from_slice(&bytes);
1350 let text = match std::str::from_utf8(&utf8_buf) {
1351 Ok(s) => {
1352 let owned = s.to_string();
1353 utf8_buf.clear();
1354 owned
1355 }
1356 Err(e) => {
1357 let valid_up_to = e.valid_up_to();
1358 if valid_up_to == 0 && utf8_buf.len() < 4 {
1359 continue;
1361 }
1362 let valid =
1363 String::from_utf8_lossy(&utf8_buf[..valid_up_to]).into_owned();
1364 utf8_buf.drain(..valid_up_to);
1365 valid
1366 }
1367 };
1368 if text.is_empty() {
1369 continue;
1370 }
1371
1372 buffer.push_str(&text);
1373
1374 while let Some(pos) = buffer.find('\n') {
1375 let line = buffer[..pos].to_string();
1376 buffer.drain(..=pos);
1377
1378 match parse_sse_line(&line) {
1379 Ok(Some(chunk)) => {
1380 let chunk = if count_tokens {
1381 chunk.with_token_estimate()
1382 } else {
1383 chunk
1384 };
1385 if tx.send(Ok(chunk)).await.is_err() {
1386 return; }
1388 }
1389 Ok(None) => {}
1390 Err(e) => {
1391 let _ = tx.send(Err(e)).await;
1392 return;
1393 }
1394 }
1395 }
1396 }
1397 Err(e) => {
1398 let _ = tx
1399 .send(Err(StreamError::Http(super::format_error_chain(&e))))
1400 .await;
1401 return;
1402 }
1403 }
1404 }
1405
1406 let _ = tx.send(Ok(StreamChunk::final_chunk())).await;
1407 });
1408
1409 stream::unfold(rx, |mut rx| async {
1410 rx.recv().await.map(|chunk| (chunk, rx))
1411 })
1412 .boxed()
1413}
1414
1415pub(crate) fn sse_bytes_to_events(
1417 response: reqwest::Response,
1418 count_tokens: bool,
1419) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
1420 sse_bytes_to_events_for_contract(response, count_tokens, false)
1421}
1422
1423fn sse_bytes_to_events_for_contract(
1424 response: reqwest::Response,
1425 count_tokens: bool,
1426 targets_mistral_tool_call_contract: bool,
1427) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
1428 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
1429
1430 tokio::spawn(async move {
1431 let mut buffer = String::new();
1432 let mut tool_calls: Vec<StreamToolCallAccumulator> = Vec::new();
1433 let mut used_tool_call_ids = std::collections::HashSet::new();
1434 let mut emitted_tool_calls = false;
1435
1436 match response.error_for_status_ref() {
1437 Ok(_) => {}
1438 Err(e) => {
1439 let _ = tx
1440 .send(Err(StreamError::Http(super::format_error_chain(&e))))
1441 .await;
1442 return;
1443 }
1444 }
1445
1446 let mut bytes_stream = response.bytes_stream();
1447 let mut utf8_buf: Vec<u8> = Vec::new();
1449 while let Some(item) = bytes_stream.next().await {
1450 match item {
1451 Ok(bytes) => {
1452 utf8_buf.extend_from_slice(&bytes);
1453 let text = match std::str::from_utf8(&utf8_buf) {
1454 Ok(s) => {
1455 let owned = s.to_string();
1456 utf8_buf.clear();
1457 owned
1458 }
1459 Err(e) => {
1460 let valid_up_to = e.valid_up_to();
1461 if valid_up_to == 0 && utf8_buf.len() < 4 {
1462 continue;
1463 }
1464 let valid =
1465 String::from_utf8_lossy(&utf8_buf[..valid_up_to]).into_owned();
1466 utf8_buf.drain(..valid_up_to);
1467 valid
1468 }
1469 };
1470 if text.is_empty() {
1471 continue;
1472 }
1473
1474 buffer.push_str(&text);
1475
1476 while let Some(pos) = buffer.find('\n') {
1477 let line = buffer[..pos].to_string();
1478 buffer.drain(..=pos);
1479
1480 if let Some(event) = parse_proxy_tool_event(&line) {
1483 if tx.send(Ok(event)).await.is_err() {
1484 return;
1485 }
1486 continue;
1487 }
1488
1489 let chunk = match parse_sse_chunk(&line) {
1490 Ok(Some(chunk)) => chunk,
1491 Ok(None) => continue,
1492 Err(e) => {
1493 let _ = tx.send(Err(e)).await;
1494 return;
1495 }
1496 };
1497
1498 let mut should_emit_tool_calls = false;
1499 for choice in &chunk.choices {
1500 if let Some(reasoning_delta) = extract_sse_reasoning_delta(choice) {
1501 let reasoning_chunk = StreamChunk::reasoning(reasoning_delta);
1502 if tx
1503 .send(Ok(StreamEvent::TextDelta(reasoning_chunk)))
1504 .await
1505 .is_err()
1506 {
1507 return;
1508 }
1509 }
1510 if let Some(text_delta) = extract_sse_text_delta(choice) {
1511 let mut text_chunk = StreamChunk::delta(text_delta);
1512 if count_tokens {
1513 text_chunk = text_chunk.with_token_estimate();
1514 }
1515 if tx
1516 .send(Ok(StreamEvent::TextDelta(text_chunk)))
1517 .await
1518 .is_err()
1519 {
1520 return;
1521 }
1522 }
1523
1524 if let Some(deltas) = choice.delta.tool_calls.as_ref() {
1525 for delta in deltas {
1526 let index = delta.index.unwrap_or(tool_calls.len());
1527 if index >= tool_calls.len() {
1528 tool_calls.resize_with(index + 1, Default::default);
1529 }
1530 if let Some(acc) = tool_calls.get_mut(index) {
1531 acc.apply_delta(delta);
1532 }
1533 }
1534 }
1535
1536 if choice.finish_reason.as_deref() == Some("tool_calls") {
1537 should_emit_tool_calls = true;
1538 }
1539 }
1540
1541 if let Some(usage) = chunk.usage.as_ref() {
1542 let token_usage = zeroclaw_api::model_provider::TokenUsage {
1543 input_tokens: usage.prompt_tokens,
1544 output_tokens: usage.completion_tokens,
1545 cached_input_tokens: None,
1546 };
1547 if tx.send(Ok(StreamEvent::Usage(token_usage))).await.is_err() {
1548 return;
1549 }
1550 }
1551
1552 if should_emit_tool_calls && !emitted_tool_calls {
1553 emitted_tool_calls = true;
1554 for tool_call in tool_calls.drain(..).filter_map(|tool_call| {
1555 tool_call.into_provider_tool_call(
1556 targets_mistral_tool_call_contract,
1557 &mut used_tool_call_ids,
1558 )
1559 }) {
1560 if tx.send(Ok(StreamEvent::ToolCall(tool_call))).await.is_err() {
1561 return;
1562 }
1563 }
1564 }
1565 }
1566 }
1567 Err(e) => {
1568 let _ = tx
1569 .send(Err(StreamError::Http(super::format_error_chain(&e))))
1570 .await;
1571 return;
1572 }
1573 }
1574 }
1575
1576 if !emitted_tool_calls {
1577 for tool_call in tool_calls.drain(..).filter_map(|tool_call| {
1578 tool_call.into_provider_tool_call(
1579 targets_mistral_tool_call_contract,
1580 &mut used_tool_call_ids,
1581 )
1582 }) {
1583 if tx.send(Ok(StreamEvent::ToolCall(tool_call))).await.is_err() {
1584 return;
1585 }
1586 }
1587 }
1588
1589 let _ = tx.send(Ok(StreamEvent::Final)).await;
1590 });
1591
1592 stream::unfold(rx, |mut rx| async move {
1593 rx.recv().await.map(|event| (event, rx))
1594 })
1595 .boxed()
1596}
1597
1598fn parse_chat_response_body(name: &str, body: &str) -> anyhow::Result<ApiChatResponse> {
1599 serde_json::from_str(body).map_err(|_| {
1600 let sanitized = super::sanitize_api_error(body);
1601 ::zeroclaw_log::record!(
1602 ERROR,
1603 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1604 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1605 .with_attrs(::serde_json::json!({
1606 "model_provider": name,
1607 "body": &sanitized,
1608 })),
1609 "compatible: unexpected chat-completions payload"
1610 );
1611 anyhow::Error::msg(format!(
1612 "{name} API returned an unexpected chat-completions payload; body={sanitized}"
1613 ))
1614 })
1615}
1616
1617impl OpenAiCompatibleModelProvider {
1618 fn apply_auth_header(
1619 &self,
1620 req: reqwest::RequestBuilder,
1621 credential: Option<&str>,
1622 ) -> reqwest::RequestBuilder {
1623 apply_auth_to_request(req, &self.auth_header, credential)
1624 }
1625
1626 fn convert_tool_specs(
1627 tools: Option<&[zeroclaw_api::tool::ToolSpec]>,
1628 ) -> Option<Vec<serde_json::Value>> {
1629 tools.map(|items| {
1630 items
1631 .iter()
1632 .map(|tool| {
1633 let params = zeroclaw_api::schema::SchemaCleanr::clean_for_openai(
1634 tool.parameters.clone(),
1635 );
1636 serde_json::json!({
1637 "type": "function",
1638 "function": {
1639 "name": tool.name,
1640 "description": tool.description,
1641 "parameters": params,
1642 }
1643 })
1644 })
1645 .collect()
1646 })
1647 }
1648
1649 fn convert_tool_specs_for_model(
1655 &self,
1656 tools: Option<&[zeroclaw_api::tool::ToolSpec]>,
1657 model: &str,
1658 ) -> Option<Vec<serde_json::Value>> {
1659 let converted = Self::convert_tool_specs(tools)?;
1660 if !self.local_model_tool_sanitize || !Self::should_sanitize_local_tool_schema(model) {
1661 return Some(converted);
1662 }
1663 Some(
1664 converted
1665 .into_iter()
1666 .map(|mut tool| {
1667 let Some(raw_parameters) = tool.get("parameters").cloned() else {
1668 return tool;
1669 };
1670 let cleaned = zeroclaw_api::schema::SchemaCleanr::clean(
1671 raw_parameters,
1672 zeroclaw_api::schema::CleaningStrategy::Conservative,
1673 );
1674 if let Some(obj) = tool.as_object_mut() {
1675 obj.insert("parameters".to_string(), cleaned);
1676 }
1677 tool
1678 })
1679 .collect(),
1680 )
1681 }
1682
1683 fn should_sanitize_local_tool_schema(model: &str) -> bool {
1684 let lower = model.to_ascii_lowercase();
1685 model.is_empty() || lower.contains("gemma-4") || lower.contains("gemma4")
1686 }
1687
1688 fn build_native_tool_chat_request(
1689 &self,
1690 effective_messages: &[ChatMessage],
1691 tools: Option<Vec<serde_json::Value>>,
1692 model: &str,
1693 temperature: f64,
1694 allow_user_image_parts: bool,
1695 ) -> NativeChatRequest {
1696 let has_tool_entries = tools.as_ref().is_some_and(|tools| !tools.is_empty());
1697 let tool_choice = tools.as_ref().map(|_| "auto".to_string());
1698
1699 NativeChatRequest {
1700 model: model.to_string(),
1701 messages: self.convert_messages_for_native(effective_messages, allow_user_image_parts),
1702 temperature,
1703 stream: Some(false),
1704 stream_options: None,
1707 reasoning_effort: self.reasoning_effort_for_model(model),
1708 tool_stream: self.tool_stream_for_tools(has_tool_entries),
1709 tools,
1710 tool_choice,
1711 max_tokens: self.max_tokens,
1712 }
1713 }
1714
1715 async fn normalize_messages_for_upstream(
1732 messages: &[ChatMessage],
1733 ) -> anyhow::Result<Vec<ChatMessage>> {
1734 let config = zeroclaw_config::schema::MultimodalConfig::default();
1735 let prepared = multimodal::prepare_messages_for_provider(messages, &config).await?;
1736 Ok(prepared.messages)
1737 }
1738
1739 fn to_message_content(
1740 role: &str,
1741 content: &str,
1742 allow_user_image_parts: bool,
1743 ) -> MessageContent {
1744 if role != "user" || !allow_user_image_parts {
1745 return MessageContent::Text(content.to_string());
1746 }
1747
1748 let (cleaned_text, image_refs) = multimodal::parse_image_markers(content);
1749 if image_refs.is_empty() {
1750 return MessageContent::Text(content.to_string());
1751 }
1752
1753 let mut parts = Vec::with_capacity(image_refs.len() + 1);
1754 let trimmed_text = cleaned_text.trim();
1755 if !trimmed_text.is_empty() {
1756 parts.push(MessagePart::Text {
1757 text: trimmed_text.to_string(),
1758 });
1759 }
1760
1761 for image_ref in image_refs {
1762 parts.push(MessagePart::ImageUrl {
1763 image_url: ImageUrlPart { url: image_ref },
1764 });
1765 }
1766
1767 MessageContent::Parts(parts)
1768 }
1769
1770 fn convert_messages_for_native(
1771 &self,
1772 messages: &[ChatMessage],
1773 allow_user_image_parts: bool,
1774 ) -> Vec<NativeMessage> {
1775 let targets_mistral_tool_call_contract = self.targets_mistral_tool_call_contract();
1776 let mut used_tool_call_ids = std::collections::HashSet::new();
1777 let mut tool_call_id_map = std::collections::HashMap::new();
1778
1779 messages
1780 .iter()
1781 .map(|message| {
1782 if message.role == "assistant"
1783 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
1784 && let Some(tool_calls_value) = value.get("tool_calls")
1785 && let Ok(parsed_calls) =
1786 serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
1787 {
1788 let tool_calls = parsed_calls
1789 .into_iter()
1790 .map(|tc| ToolCall {
1791 id: Some({
1792 let normalized_id = reserve_tool_call_id_for_contract(
1793 targets_mistral_tool_call_contract,
1794 Some(tc.id.clone()),
1795 &mut used_tool_call_ids,
1796 );
1797 tool_call_id_map.insert(tc.id, normalized_id.clone());
1798 normalized_id
1799 }),
1800 kind: Some("function".to_string()),
1801 function: Some(Function {
1802 name: Some(tc.name),
1803 arguments: Some(tc.arguments),
1804 }),
1805 name: None,
1806 arguments: None,
1807 parameters: None,
1808 extra_content: tc.extra_content,
1811 })
1812 .collect::<Vec<_>>();
1813
1814 let content = value
1815 .get("content")
1816 .and_then(serde_json::Value::as_str)
1817 .map(|value| MessageContent::Text(value.to_string()));
1818
1819 let reasoning_content = value
1822 .get("reasoning_content")
1823 .or_else(|| value.get("reasoning"))
1824 .and_then(serde_json::Value::as_str)
1825 .map(ToString::to_string);
1826
1827 return NativeMessage {
1828 role: "assistant".to_string(),
1829 content,
1830 tool_call_id: None,
1831 tool_calls: Some(tool_calls),
1832 reasoning_content,
1833 };
1834 }
1835
1836 if message.role == "assistant"
1844 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
1845 && value.get("tool_calls").is_none()
1846 && let Some(reasoning_content) = value
1847 .get("reasoning_content")
1848 .and_then(serde_json::Value::as_str)
1849 && matches!(
1850 value.get("content"),
1851 None | Some(serde_json::Value::Null | serde_json::Value::String(_))
1852 )
1853 {
1854 let content = value
1855 .get("content")
1856 .and_then(serde_json::Value::as_str)
1857 .map(|value| MessageContent::Text(value.to_string()));
1858
1859 return NativeMessage {
1860 role: "assistant".to_string(),
1861 content,
1862 tool_call_id: None,
1863 tool_calls: None,
1864 reasoning_content: Some(reasoning_content.to_string()),
1865 };
1866 }
1867
1868 if message.role == "tool"
1869 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
1870 {
1871 let tool_call_id = value
1872 .get("tool_call_id")
1873 .and_then(serde_json::Value::as_str)
1874 .map(|raw_id| {
1875 tool_call_id_map.get(raw_id).cloned().unwrap_or_else(|| {
1876 let normalized_id = reserve_tool_call_id_for_contract(
1877 targets_mistral_tool_call_contract,
1878 Some(raw_id.to_string()),
1879 &mut used_tool_call_ids,
1880 );
1881 tool_call_id_map.insert(raw_id.to_string(), normalized_id.clone());
1882 normalized_id
1883 })
1884 });
1885 let content = value
1886 .get("content")
1887 .and_then(serde_json::Value::as_str)
1888 .map(|value| MessageContent::Text(value.to_string()))
1889 .or_else(|| Some(MessageContent::Text(message.content.clone())));
1890
1891 return NativeMessage {
1892 role: "tool".to_string(),
1893 content,
1894 tool_call_id,
1895 tool_calls: None,
1896 reasoning_content: None,
1897 };
1898 }
1899
1900 NativeMessage {
1901 role: message.role.clone(),
1902 content: Some(Self::to_message_content(
1903 &message.role,
1904 &message.content,
1905 allow_user_image_parts,
1906 )),
1907 tool_call_id: None,
1908 tool_calls: None,
1909 reasoning_content: None,
1910 }
1911 })
1912 .collect()
1913 }
1914
1915 fn strip_native_tool_messages(&self, messages: &[ChatMessage]) -> Vec<ChatMessage> {
1929 if self.native_tool_calling {
1930 return messages.to_vec();
1931 }
1932 let intermediate = messages.iter().filter_map(|msg| {
1933 if msg.role == "tool" {
1934 return None;
1935 }
1936 if msg.role == "assistant"
1937 && let Ok(value) = serde_json::from_str::<serde_json::Value>(&msg.content)
1938 && value.get("tool_calls").is_some()
1939 {
1940 let text = value
1941 .get("content")
1942 .and_then(serde_json::Value::as_str)
1943 .unwrap_or("")
1944 .to_string();
1945 return if text.is_empty() {
1946 None
1947 } else {
1948 Some(ChatMessage::assistant(&text))
1949 };
1950 }
1951 Some(msg.clone())
1952 });
1953
1954 let mut coalesced: Vec<ChatMessage> = Vec::with_capacity(messages.len());
1965 for msg in intermediate {
1966 match coalesced.last_mut() {
1967 Some(last) if last.role == "assistant" && msg.role == "assistant" => {
1968 if !last.content.is_empty() && !msg.content.is_empty() {
1969 last.content.push_str("\n\n");
1970 }
1971 last.content.push_str(&msg.content);
1972 }
1973 _ => coalesced.push(msg),
1974 }
1975 }
1976 coalesced
1977 }
1978
1979 fn with_prompt_guided_tool_instructions(
1980 messages: &[ChatMessage],
1981 tools: Option<&[zeroclaw_api::tool::ToolSpec]>,
1982 ) -> Vec<ChatMessage> {
1983 let Some(tools) = tools else {
1984 return messages.to_vec();
1985 };
1986
1987 if tools.is_empty() {
1988 return messages.to_vec();
1989 }
1990
1991 let instructions = zeroclaw_api::model_provider::build_tool_instructions_text(tools);
1992 let mut modified_messages = messages.to_vec();
1993
1994 if let Some(system_message) = modified_messages.iter_mut().find(|m| m.role == "system") {
1995 if !system_message.content.is_empty() {
1996 system_message.content.push_str("\n\n");
1997 }
1998 system_message.content.push_str(&instructions);
1999 } else {
2000 modified_messages.insert(0, ChatMessage::system(instructions));
2001 }
2002
2003 modified_messages
2004 }
2005
2006 fn targets_mistral_tool_call_contract(&self) -> bool {
2007 if self.name.eq_ignore_ascii_case("mistral") {
2008 return true;
2009 }
2010
2011 reqwest::Url::parse(&self.base_url)
2012 .ok()
2013 .and_then(|url| url.host_str().map(|h| h.to_ascii_lowercase()))
2014 .is_some_and(|host| host == "mistral.ai" || host.ends_with(".mistral.ai"))
2015 }
2016
2017 fn reserve_tool_call_id(
2018 &self,
2019 raw_id: Option<String>,
2020 used_ids: &mut std::collections::HashSet<String>,
2021 ) -> String {
2022 reserve_tool_call_id_for_contract(
2023 self.targets_mistral_tool_call_contract(),
2024 raw_id,
2025 used_ids,
2026 )
2027 }
2028
2029 fn parse_native_response(&self, message: ResponseMessage) -> ProviderChatResponse {
2030 let text = message.effective_content_optional();
2031 let reasoning_content = message.reasoning_content.clone();
2032 let mut used_tool_call_ids = std::collections::HashSet::new();
2033 let tool_calls = message
2034 .tool_calls
2035 .unwrap_or_default()
2036 .into_iter()
2037 .filter_map(|tc| {
2038 let name = tc.function_name()?;
2039 let arguments = tc.function_arguments().unwrap_or_else(|| "{}".to_string());
2040 let normalized_arguments = if serde_json::from_str::<serde_json::Value>(&arguments)
2041 .is_ok()
2042 {
2043 arguments
2044 } else {
2045 ::zeroclaw_log::record!(
2046 WARN,
2047 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
2048 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
2049 .with_attrs(
2050 ::serde_json::json!({"function": name, "arguments": arguments})
2051 ),
2052 "Invalid JSON in native tool-call arguments, using empty object"
2053 );
2054 "{}".to_string()
2055 };
2056 Some(ProviderToolCall {
2057 id: self.reserve_tool_call_id(tc.id, &mut used_tool_call_ids),
2058 name,
2059 arguments: normalized_arguments,
2060 extra_content: tc.extra_content,
2061 })
2062 })
2063 .collect::<Vec<_>>();
2064
2065 ProviderChatResponse {
2066 text,
2067 tool_calls,
2068 usage: None,
2069 reasoning_content,
2070 }
2071 }
2072
2073 fn is_native_tool_schema_unsupported(status: reqwest::StatusCode, error: &str) -> bool {
2074 if !matches!(
2075 status,
2076 reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::UNPROCESSABLE_ENTITY
2077 ) {
2078 return false;
2079 }
2080
2081 let lower = error.to_lowercase();
2082 [
2083 "unknown parameter: tools",
2084 "unsupported parameter: tools",
2085 "unrecognized field `tools`",
2086 "does not support tools",
2087 "function calling is not supported",
2088 "tool_choice",
2089 "tool call validation failed",
2090 "was not in request",
2091 ]
2092 .iter()
2093 .any(|hint| lower.contains(hint))
2094 }
2095}
2096
2097#[async_trait]
2098impl ModelProvider for OpenAiCompatibleModelProvider {
2099 fn capabilities(&self) -> zeroclaw_api::model_provider::ProviderCapabilities {
2100 zeroclaw_api::model_provider::ProviderCapabilities {
2101 native_tool_calling: self.native_tool_calling,
2102 vision: self.supports_vision,
2103 prompt_caching: false,
2104 extended_thinking: false,
2105 }
2106 }
2107
2108 async fn list_models(&self) -> anyhow::Result<Vec<String>> {
2109 let list_credential = self.credential.as_deref();
2114 if list_credential.is_some() || self.unauthenticated_model_listing {
2115 let url = format!("{}/models", self.base_url);
2116 let response = self
2117 .apply_auth_header(self.http_client().get(&url), list_credential)
2118 .send()
2119 .await
2120 .map_err(|e| {
2121 ::zeroclaw_log::record!(
2122 ERROR,
2123 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2124 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2125 .with_attrs(::serde_json::json!({
2126 "model_provider": &self.name,
2127 "url": &url,
2128 "phase": "model_list_request",
2129 "error": super::format_error_chain(&e),
2130 })),
2131 "compatible: model list request failed"
2132 );
2133 anyhow::Error::msg(format!(
2134 "{} model list request failed: {url}: {e}",
2135 self.name
2136 ))
2137 })?;
2138 if !response.status().is_success() {
2139 let status = response.status();
2140 anyhow::bail!("{} model list failed at {url}: HTTP {status}", self.name);
2141 }
2142 let body: ModelsResponse = response.json().await.map_err(|e| {
2143 ::zeroclaw_log::record!(
2144 ERROR,
2145 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2146 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2147 .with_attrs(::serde_json::json!({
2148 "model_provider": &self.name,
2149 "phase": "model_list_parse",
2150 "error": super::format_error_chain(&e),
2151 })),
2152 "compatible: model list returned invalid JSON"
2153 );
2154 anyhow::Error::msg(format!(
2155 "{} model list returned invalid JSON: {e}",
2156 self.name
2157 ))
2158 })?;
2159 return Ok(normalize_model_ids(body));
2160 }
2161 if let Some(key) = &self.models_dev_key {
2164 match crate::models_dev::list_models_for(key).await {
2165 Ok(models) if !models.is_empty() => return Ok(models),
2166 Ok(_) => {} Err(e) => {
2168 if self.openrouter_vendor_prefix.is_none() {
2169 return Err(e);
2170 }
2171 }
2172 }
2173 }
2174 match &self.openrouter_vendor_prefix {
2175 Some(prefix) => crate::openrouter_catalog::list_models_for_vendor(prefix).await,
2176 None => anyhow::bail!("live model listing is not supported for this model_provider"),
2177 }
2178 }
2179
2180 async fn chat_with_system(
2181 &self,
2182 system_prompt: Option<&str>,
2183 message: &str,
2184 model: &str,
2185 temperature: Option<f64>,
2186 ) -> anyhow::Result<String> {
2187 let temperature = temperature.unwrap_or(self.default_temperature());
2188 let credential = self.credential.as_deref();
2189
2190 let user_msg = ChatMessage {
2194 role: "user".to_string(),
2195 content: message.to_string(),
2196 };
2197 let normalized_user =
2198 Self::normalize_messages_for_upstream(std::slice::from_ref(&user_msg))
2199 .await?
2200 .pop()
2201 .unwrap_or(user_msg);
2202 let normalized_message = normalized_user.content;
2203
2204 let merge = self.effective_merge_system(model);
2205 let mut messages = Vec::new();
2206
2207 if merge {
2208 let content = match system_prompt {
2209 Some(sys) => format!("{sys}\n\n{normalized_message}"),
2210 None => normalized_message,
2211 };
2212 messages.push(Message {
2213 role: "user".to_string(),
2214 content: Self::to_message_content("user", &content, !merge),
2215 });
2216 } else {
2217 if let Some(sys) = system_prompt {
2218 messages.push(Message {
2219 role: "system".to_string(),
2220 content: MessageContent::Text(sys.to_string()),
2221 });
2222 }
2223 messages.push(Message {
2224 role: "user".to_string(),
2225 content: Self::to_message_content("user", &normalized_message, true),
2226 });
2227 }
2228
2229 let request = ApiChatRequest {
2230 model: model.to_string(),
2231 messages,
2232 temperature,
2233 stream: Some(false),
2234 stream_options: None,
2235 reasoning_effort: self.reasoning_effort_for_model(model),
2236 tool_stream: None,
2237 tools: None,
2238 tool_choice: None,
2239 max_tokens: self.max_tokens,
2240 };
2241
2242 let url = self.chat_completions_url();
2243
2244 let response = match self
2245 .apply_auth_header(self.http_client().post(&url).json(&request), credential)
2246 .send()
2247 .await
2248 {
2249 Ok(response) => response,
2250 Err(chat_error) => {
2251 return Err(chat_error.into());
2252 }
2253 };
2254
2255 if !response.status().is_success() {
2256 let status = response.status();
2257 let error = response.text().await?;
2258 let sanitized = super::sanitize_api_error(&error);
2259 anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
2260 }
2261
2262 let body = response.text().await?;
2263 let chat_response = parse_chat_response_body(&self.name, &body)?;
2264
2265 chat_response
2266 .choices
2267 .into_iter()
2268 .next()
2269 .map(|c| {
2270 if c.message.tool_calls.is_some()
2271 && c.message
2272 .tool_calls
2273 .as_ref()
2274 .is_some_and(|t: &Vec<_>| !t.is_empty())
2275 {
2276 serde_json::to_string(&c.message)
2277 .unwrap_or_else(|_| c.message.effective_content())
2278 } else {
2279 c.message.effective_content()
2280 }
2281 })
2282 .ok_or_else(|| {
2283 ::zeroclaw_log::record!(
2284 ERROR,
2285 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2286 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2287 .with_attrs(::serde_json::json!({"model_provider": &self.name})),
2288 "compatible: empty choices in response"
2289 );
2290 anyhow::Error::msg(format!("No response from {}", self.name))
2291 })
2292 }
2293
2294 async fn chat_with_history(
2295 &self,
2296 messages: &[ChatMessage],
2297 model: &str,
2298 temperature: Option<f64>,
2299 ) -> anyhow::Result<String> {
2300 let temperature = temperature.unwrap_or(self.default_temperature());
2301 let credential = self.credential.as_deref();
2302
2303 let normalized = Self::normalize_messages_for_upstream(messages).await?;
2304 let merge = self.effective_merge_system(model);
2305 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2306 let effective_messages = self.strip_native_tool_messages(&effective_messages);
2308 let api_messages: Vec<Message> = effective_messages
2309 .iter()
2310 .map(|m| Message {
2311 role: m.role.clone(),
2312 content: Self::to_message_content(&m.role, &m.content, !merge),
2313 })
2314 .collect();
2315
2316 let request = ApiChatRequest {
2317 model: model.to_string(),
2318 messages: api_messages,
2319 temperature,
2320 stream: Some(false),
2321 stream_options: None,
2322 reasoning_effort: self.reasoning_effort_for_model(model),
2323 tool_stream: None,
2324 tools: None,
2325 tool_choice: None,
2326 max_tokens: self.max_tokens,
2327 };
2328
2329 let url = self.chat_completions_url();
2330 let response = match self
2331 .apply_auth_header(self.http_client().post(&url).json(&request), credential)
2332 .send()
2333 .await
2334 {
2335 Ok(response) => response,
2336 Err(chat_error) => return Err(chat_error.into()),
2337 };
2338
2339 if !response.status().is_success() {
2340 return Err(super::api_error(&self.name, response).await);
2341 }
2342
2343 let body = response.text().await?;
2344 let chat_response = parse_chat_response_body(&self.name, &body)?;
2345
2346 chat_response
2347 .choices
2348 .into_iter()
2349 .next()
2350 .map(|c| {
2351 if c.message.tool_calls.is_some()
2352 && c.message
2353 .tool_calls
2354 .as_ref()
2355 .is_some_and(|t: &Vec<_>| !t.is_empty())
2356 {
2357 serde_json::to_string(&c.message)
2358 .unwrap_or_else(|_| c.message.effective_content())
2359 } else {
2360 c.message.effective_content()
2361 }
2362 })
2363 .ok_or_else(|| {
2364 ::zeroclaw_log::record!(
2365 ERROR,
2366 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2367 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2368 .with_attrs(::serde_json::json!({"model_provider": &self.name})),
2369 "compatible: empty choices in response"
2370 );
2371 anyhow::Error::msg(format!("No response from {}", self.name))
2372 })
2373 }
2374
2375 async fn chat_with_tools(
2376 &self,
2377 messages: &[ChatMessage],
2378 tools: &[serde_json::Value],
2379 model: &str,
2380 temperature: Option<f64>,
2381 ) -> anyhow::Result<ProviderChatResponse> {
2382 let temperature = temperature.unwrap_or(self.default_temperature());
2383 let credential = self.credential.as_deref();
2384
2385 let normalized = Self::normalize_messages_for_upstream(messages).await?;
2386 let merge = self.effective_merge_system(model);
2387 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2388 let effective_messages = if self.native_tool_calling {
2389 effective_messages
2390 } else {
2391 self.strip_native_tool_messages(&effective_messages)
2392 };
2393 let tools = if tools.is_empty() {
2394 None
2395 } else {
2396 Some(tools.to_vec())
2397 };
2398 let request = self.build_native_tool_chat_request(
2399 &effective_messages,
2400 tools,
2401 model,
2402 temperature,
2403 !merge,
2404 );
2405
2406 let url = self.chat_completions_url();
2407 let response = match self
2408 .apply_auth_header(self.http_client().post(&url).json(&request), credential)
2409 .send()
2410 .await
2411 {
2412 Ok(response) => response,
2413 Err(error) => {
2414 ::zeroclaw_log::record!(
2415 WARN,
2416 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
2417 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
2418 &format!(
2419 "{} native tool call transport failed: {error}; falling back to history path",
2420 self.name
2421 )
2422 );
2423 let text = self
2424 .chat_with_history(messages, model, Some(temperature))
2425 .await?;
2426 return Ok(ProviderChatResponse {
2427 text: Some(text),
2428 tool_calls: vec![],
2429 usage: None,
2430 reasoning_content: None,
2431 });
2432 }
2433 };
2434
2435 if !response.status().is_success() {
2436 return Err(super::api_error(&self.name, response).await);
2437 }
2438
2439 let body = response.text().await?;
2440 let chat_response = parse_chat_response_body(&self.name, &body)?;
2441 let usage = chat_response.usage.map(|u| TokenUsage {
2442 input_tokens: u.prompt_tokens,
2443 output_tokens: u.completion_tokens,
2444 cached_input_tokens: None,
2445 });
2446 let choice = chat_response.choices.into_iter().next().ok_or_else(|| {
2447 ::zeroclaw_log::record!(
2448 ERROR,
2449 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2450 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2451 .with_attrs(::serde_json::json!({"model_provider": &self.name})),
2452 "compatible: empty choices in response"
2453 );
2454 anyhow::Error::msg(format!("No response from {}", self.name))
2455 })?;
2456
2457 let text = choice.message.effective_content_optional();
2458 let reasoning_content = choice.message.reasoning_content;
2459 let mut used_tool_call_ids = std::collections::HashSet::new();
2460 let tool_calls = choice
2461 .message
2462 .tool_calls
2463 .unwrap_or_default()
2464 .into_iter()
2465 .filter_map(|tc| {
2466 let function = tc.function?;
2467 let name = function.name?;
2468 let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
2469 Some(ProviderToolCall {
2470 id: self.reserve_tool_call_id(tc.id, &mut used_tool_call_ids),
2471 name,
2472 arguments,
2473 extra_content: tc.extra_content,
2474 })
2475 })
2476 .collect::<Vec<_>>();
2477
2478 Ok(ProviderChatResponse {
2479 text,
2480 tool_calls,
2481 usage,
2482 reasoning_content,
2483 })
2484 }
2485
2486 async fn chat(
2487 &self,
2488 request: ProviderChatRequest<'_>,
2489 model: &str,
2490 temperature: Option<f64>,
2491 ) -> anyhow::Result<ProviderChatResponse> {
2492 let temperature = temperature.unwrap_or(self.default_temperature());
2493 let credential = self.credential.as_deref();
2494
2495 let normalized = Self::normalize_messages_for_upstream(request.messages).await?;
2496 let merge = self.effective_merge_system(model);
2497 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2498 let effective_messages = if self.native_tool_calling {
2499 effective_messages
2500 } else {
2501 self.strip_native_tool_messages(&effective_messages)
2502 };
2503
2504 let tools = self.convert_tool_specs_for_model(request.tools, model);
2507 let native_request = self.build_native_tool_chat_request(
2508 &effective_messages,
2509 tools,
2510 model,
2511 temperature,
2512 !merge,
2513 );
2514
2515 let url = self.chat_completions_url();
2516 let response = match self
2517 .apply_auth_header(
2518 self.http_client().post(&url).json(&native_request),
2519 credential,
2520 )
2521 .send()
2522 .await
2523 {
2524 Ok(response) => response,
2525 Err(chat_error) => return Err(chat_error.into()),
2526 };
2527
2528 if !response.status().is_success() {
2529 let status = response.status();
2530 let error = response.text().await?;
2531 let sanitized = super::sanitize_api_error(&error);
2532
2533 if Self::is_native_tool_schema_unsupported(status, &sanitized) {
2534 let fallback_messages =
2535 Self::with_prompt_guided_tool_instructions(request.messages, request.tools);
2536 let text = self
2537 .chat_with_history(&fallback_messages, model, Some(temperature))
2538 .await?;
2539 return Ok(ProviderChatResponse {
2540 text: Some(text),
2541 tool_calls: vec![],
2542 usage: None,
2543 reasoning_content: None,
2544 });
2545 }
2546
2547 anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
2548 }
2549
2550 let native_response: ApiChatResponse = response.json().await?;
2551 let usage = native_response.usage.map(|u| TokenUsage {
2552 input_tokens: u.prompt_tokens,
2553 output_tokens: u.completion_tokens,
2554 cached_input_tokens: None,
2555 });
2556 let message = native_response
2557 .choices
2558 .into_iter()
2559 .next()
2560 .map(|choice| choice.message)
2561 .ok_or_else(|| {
2562 ::zeroclaw_log::record!(
2563 ERROR,
2564 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
2565 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
2566 .with_attrs(::serde_json::json!({"model_provider": &self.name})),
2567 "compatible: empty choices in response"
2568 );
2569 anyhow::Error::msg(format!("No response from {}", self.name))
2570 })?;
2571
2572 let mut result = self.parse_native_response(message);
2573 result.usage = usage;
2574 Ok(result)
2575 }
2576
    /// Reports the provider's `native_tool_calling` flag.
    fn supports_native_tools(&self) -> bool {
        self.native_tool_calling
    }

    /// Always `true`: this provider implements the streaming entry points unconditionally.
    fn supports_streaming(&self) -> bool {
        true
    }

    /// Tool events are reported as streamable only when native tool calling is enabled.
    fn supports_streaming_tool_events(&self) -> bool {
        self.native_tool_calling
    }
2589
2590 fn stream_chat(
2591 &self,
2592 request: ProviderChatRequest<'_>,
2593 model: &str,
2594 temperature: Option<f64>,
2595 options: StreamOptions,
2596 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
2597 if !options.enabled {
2598 return stream::once(async { Ok(StreamEvent::Final) }).boxed();
2599 }
2600
2601 let temperature = temperature.unwrap_or(self.default_temperature());
2602 let provider = self.clone();
2603 let messages_owned: Vec<ChatMessage> = request.messages.to_vec();
2604 let tools_owned: Option<Vec<zeroclaw_api::tool::ToolSpec>> =
2605 request.tools.map(<[zeroclaw_api::tool::ToolSpec]>::to_vec);
2606 let model = model.to_string();
2607 let count_tokens = options.count_tokens;
2608 let options_enabled = options.enabled;
2609
2610 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
2611
2612 tokio::spawn(async move {
2613 let normalized = match Self::normalize_messages_for_upstream(&messages_owned).await {
2614 Ok(n) => n,
2615 Err(err) => {
2616 let _ = tx
2617 .send(Err(StreamError::ModelProvider(err.to_string())))
2618 .await;
2619 return;
2620 }
2621 };
2622
2623 let merge = provider.effective_merge_system(&model);
2624 let has_tools = tools_owned.as_ref().is_some_and(|tools| !tools.is_empty());
2625 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2626 let effective_messages = provider.strip_native_tool_messages(&effective_messages);
2627 let tools = provider.convert_tool_specs_for_model(tools_owned.as_deref(), &model);
2628
2629 let payload_result = if has_tools {
2630 serde_json::to_value(NativeChatRequest {
2631 model: model.clone(),
2632 messages: provider.convert_messages_for_native(&effective_messages, !merge),
2633 temperature,
2634 reasoning_effort: provider.reasoning_effort_for_model(&model),
2635 tool_stream: if options_enabled {
2636 provider.tool_stream_for_tools(true)
2637 } else {
2638 None
2639 },
2640 stream: Some(options_enabled),
2641 stream_options: options_enabled.then_some(StreamOptionsBody {
2645 include_usage: true,
2646 }),
2647 tools: tools.clone(),
2648 tool_choice: tools.as_ref().map(|_| "auto".to_string()),
2649 max_tokens: provider.max_tokens,
2650 })
2651 } else {
2652 let messages = effective_messages
2653 .iter()
2654 .map(|message| Message {
2655 role: message.role.clone(),
2656 content: Self::to_message_content(&message.role, &message.content, !merge),
2657 })
2658 .collect();
2659
2660 serde_json::to_value(ApiChatRequest {
2661 model: model.clone(),
2662 messages,
2663 temperature,
2664 reasoning_effort: provider.reasoning_effort_for_model(&model),
2665 tool_stream: if options_enabled {
2666 provider.tool_stream_for_tools(false)
2667 } else {
2668 None
2669 },
2670 stream: Some(options_enabled),
2671 stream_options: options_enabled.then_some(StreamOptionsBody {
2672 include_usage: true,
2673 }),
2674 tools: None,
2675 tool_choice: None,
2676 max_tokens: provider.max_tokens,
2677 })
2678 };
2679
2680 let payload = match payload_result {
2681 Ok(payload) => payload,
2682 Err(error) => {
2683 let _ = tx.send(Err(StreamError::Json(error))).await;
2684 return;
2685 }
2686 };
2687
2688 let url = provider.chat_completions_url();
2689 let client = provider.streaming_http_client();
2690 let auth_header = provider.auth_header.clone();
2691 let credential = provider.credential.clone();
2692 let targets_mistral_tool_call_contract = provider.targets_mistral_tool_call_contract();
2693
2694 let mut req_builder = client.post(&url).json(&payload);
2695 req_builder = apply_auth_to_request(req_builder, &auth_header, credential.as_deref());
2696 req_builder = req_builder.header("Accept", "text/event-stream");
2697
2698 let response = match req_builder.send().await {
2699 Ok(r) => r,
2700 Err(e) => {
2701 let _ = tx
2702 .send(Err(StreamError::Http(super::format_error_chain(&e))))
2703 .await;
2704 return;
2705 }
2706 };
2707
2708 if !response.status().is_success() {
2709 let status = response.status();
2710 let error = match response.text().await {
2711 Ok(text) => text,
2712 Err(_) => format!("HTTP error: {}", status),
2713 };
2714 let _ = tx
2715 .send(Err(StreamError::ModelProvider(format!(
2716 "{}: {}",
2717 status, error
2718 ))))
2719 .await;
2720 return;
2721 }
2722
2723 let mut event_stream = sse_bytes_to_events_for_contract(
2724 response,
2725 count_tokens,
2726 targets_mistral_tool_call_contract,
2727 );
2728 while let Some(event) = event_stream.next().await {
2729 if tx.send(event).await.is_err() {
2730 break;
2731 }
2732 }
2733 });
2734
2735 stream::unfold(rx, |mut rx| async move {
2736 rx.recv().await.map(|event| (event, rx))
2737 })
2738 .boxed()
2739 }
2740
2741 fn stream_chat_with_system(
2742 &self,
2743 system_prompt: Option<&str>,
2744 message: &str,
2745 model: &str,
2746 temperature: Option<f64>,
2747 options: StreamOptions,
2748 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
2749 let temperature = temperature.unwrap_or(self.default_temperature());
2750 let provider = self.clone();
2751 let system_prompt_owned: Option<String> = system_prompt.map(str::to_string);
2752 let message_owned = message.to_string();
2753 let model = model.to_string();
2754 let count_tokens = options.count_tokens;
2755 let options_enabled = options.enabled;
2756
2757 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
2759
2760 tokio::spawn(async move {
2761 let user_msg = ChatMessage {
2765 role: "user".to_string(),
2766 content: message_owned,
2767 };
2768 let normalized_user = match Self::normalize_messages_for_upstream(std::slice::from_ref(
2769 &user_msg,
2770 ))
2771 .await
2772 {
2773 Ok(mut msgs) => msgs.pop().unwrap_or(user_msg),
2774 Err(err) => {
2775 let _ = tx
2776 .send(Err(StreamError::ModelProvider(err.to_string())))
2777 .await;
2778 return;
2779 }
2780 };
2781 let normalized_message_content = normalized_user.content;
2782
2783 let merge = provider.effective_merge_system(&model);
2784 let mut messages = Vec::new();
2785 if merge {
2786 let content = match system_prompt_owned.as_deref() {
2787 Some(sys) => format!("{sys}\n\n{normalized_message_content}"),
2788 None => normalized_message_content,
2789 };
2790 messages.push(Message {
2791 role: "user".to_string(),
2792 content: Self::to_message_content("user", &content, !merge),
2793 });
2794 } else {
2795 if let Some(sys) = system_prompt_owned {
2796 messages.push(Message {
2797 role: "system".to_string(),
2798 content: MessageContent::Text(sys),
2799 });
2800 }
2801 messages.push(Message {
2802 role: "user".to_string(),
2803 content: Self::to_message_content("user", &normalized_message_content, !merge),
2804 });
2805 }
2806
2807 let request = ApiChatRequest {
2808 model: model.clone(),
2809 messages,
2810 temperature,
2811 stream: Some(options_enabled),
2812 stream_options: options_enabled.then_some(StreamOptionsBody {
2813 include_usage: true,
2814 }),
2815 reasoning_effort: provider.reasoning_effort_for_model(&model),
2816 tool_stream: None,
2817 tools: None,
2818 tool_choice: None,
2819 max_tokens: provider.max_tokens,
2820 };
2821
2822 let url = provider.chat_completions_url();
2823 let client = provider.streaming_http_client();
2824 let auth_header = provider.auth_header.clone();
2825 let credential = provider.credential.clone();
2826
2827 let mut req_builder = client.post(&url).json(&request);
2829
2830 req_builder = apply_auth_to_request(req_builder, &auth_header, credential.as_deref());
2832
2833 req_builder = req_builder.header("Accept", "text/event-stream");
2835
2836 let response = match req_builder.send().await {
2838 Ok(r) => r,
2839 Err(e) => {
2840 let _ = tx
2841 .send(Err(StreamError::Http(super::format_error_chain(&e))))
2842 .await;
2843 return;
2844 }
2845 };
2846
2847 if !response.status().is_success() {
2849 let status = response.status();
2850 let error = match response.text().await {
2851 Ok(e) => e,
2852 Err(_) => format!("HTTP error: {}", status),
2853 };
2854 let _ = tx
2855 .send(Err(StreamError::ModelProvider(format!(
2856 "{}: {}",
2857 status, error
2858 ))))
2859 .await;
2860 return;
2861 }
2862
2863 let mut chunk_stream = sse_bytes_to_chunks(response, count_tokens);
2865 while let Some(chunk) = chunk_stream.next().await {
2866 if tx.send(chunk).await.is_err() {
2867 break; }
2869 }
2870 });
2871
2872 stream::unfold(rx, |mut rx| async move {
2874 rx.recv().await.map(|chunk| (chunk, rx))
2875 })
2876 .boxed()
2877 }
2878
2879 fn stream_chat_with_history(
2880 &self,
2881 messages: &[ChatMessage],
2882 model: &str,
2883 temperature: Option<f64>,
2884 options: StreamOptions,
2885 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
2886 let temperature = temperature.unwrap_or(self.default_temperature());
2887 let provider = self.clone();
2888 let messages_owned: Vec<ChatMessage> = messages.to_vec();
2889 let model = model.to_string();
2890 let count_tokens = options.count_tokens;
2891 let options_enabled = options.enabled;
2892
2893 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
2894
2895 tokio::spawn(async move {
2896 let normalized = match Self::normalize_messages_for_upstream(&messages_owned).await {
2897 Ok(n) => n,
2898 Err(err) => {
2899 let _ = tx
2900 .send(Err(StreamError::ModelProvider(err.to_string())))
2901 .await;
2902 return;
2903 }
2904 };
2905
2906 let merge = provider.effective_merge_system(&model);
2907 let effective_messages = Self::flatten_system_messages(&normalized, merge);
2908 let effective_messages = provider.strip_native_tool_messages(&effective_messages);
2909 let api_messages: Vec<Message> = effective_messages
2910 .iter()
2911 .map(|m| Message {
2912 role: m.role.clone(),
2913 content: Self::to_message_content(&m.role, &m.content, !merge),
2914 })
2915 .collect();
2916
2917 let request = ApiChatRequest {
2918 model: model.clone(),
2919 messages: api_messages,
2920 temperature,
2921 stream: Some(options_enabled),
2922 stream_options: options_enabled.then_some(StreamOptionsBody {
2923 include_usage: true,
2924 }),
2925 reasoning_effort: provider.reasoning_effort_for_model(&model),
2926 tool_stream: None,
2927 tools: None,
2928 tool_choice: None,
2929 max_tokens: provider.max_tokens,
2930 };
2931
2932 let url = provider.chat_completions_url();
2933 let client = provider.streaming_http_client();
2934 let auth_header = provider.auth_header.clone();
2935 let credential = provider.credential.clone();
2936
2937 let mut req_builder = client.post(&url).json(&request);
2938 req_builder = apply_auth_to_request(req_builder, &auth_header, credential.as_deref());
2939 req_builder = req_builder.header("Accept", "text/event-stream");
2940
2941 let response = match req_builder.send().await {
2942 Ok(r) => r,
2943 Err(e) => {
2944 let _ = tx
2945 .send(Err(StreamError::Http(super::format_error_chain(&e))))
2946 .await;
2947 return;
2948 }
2949 };
2950
2951 if !response.status().is_success() {
2952 let status = response.status();
2953 let error = match response.text().await {
2954 Ok(e) => e,
2955 Err(_) => format!("HTTP error: {}", status),
2956 };
2957 let _ = tx
2958 .send(Err(StreamError::ModelProvider(format!(
2959 "{}: {}",
2960 status, error
2961 ))))
2962 .await;
2963 return;
2964 }
2965
2966 let mut chunk_stream = sse_bytes_to_chunks(response, count_tokens);
2967 while let Some(chunk) = chunk_stream.next().await {
2968 if tx.send(chunk).await.is_err() {
2969 break;
2970 }
2971 }
2972 });
2973
2974 stream::unfold(rx, |mut rx| async move {
2975 rx.recv().await.map(|chunk| (chunk, rx))
2976 })
2977 .boxed()
2978 }
2979
2980 async fn warmup(&self) -> anyhow::Result<()> {
2981 let url = self.chat_completions_url();
2984 let _ = self
2985 .apply_auth_header(self.http_client().get(&url), self.credential.as_deref())
2986 .send()
2987 .await?;
2988 Ok(())
2989 }
2990}
2991
2992impl ::zeroclaw_api::attribution::Attributable for OpenAiCompatibleModelProvider {
2993 fn role(&self) -> ::zeroclaw_api::attribution::Role {
2994 ::zeroclaw_api::attribution::Role::Provider(
2995 ::zeroclaw_api::attribution::ProviderKind::Model(
2996 ::zeroclaw_api::attribution::ModelProviderKind::Plugin,
2997 ),
2998 )
2999 }
3000 fn alias(&self) -> &str {
3001 &self.alias
3002 }
3003}
3004
3005#[cfg(test)]
3006mod tests {
3007 use super::*;
3008
    // Test helper: builds a provider with alias "test" and `Bearer` auth.
    fn make_model_provider(
        name: &str,
        url: &str,
        key: Option<&str>,
    ) -> OpenAiCompatibleModelProvider {
        OpenAiCompatibleModelProvider::new("test", name, url, key, AuthStyle::Bearer)
    }

    // The constructor keeps the given name, base URL and credential.
    #[test]
    fn creates_with_key() {
        let p = make_model_provider(
            "venice",
            "https://api.venice.ai",
            Some("venice-test-credential"),
        );
        assert_eq!(p.name, "venice");
        assert_eq!(p.base_url, "https://api.venice.ai");
        assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
    }

    // Passing no key leaves the credential unset.
    #[test]
    fn creates_without_key() {
        let p = make_model_provider("test", "https://example.com", None);
        assert!(p.credential.is_none());
    }

    // A trailing slash on the base URL is removed.
    #[test]
    fn strips_trailing_slash() {
        let p = make_model_provider("test", "https://example.com/", None);
        assert_eq!(p.base_url, "https://example.com");
    }

    // Without a key the request is still attempted (against an unreachable local port).
    // It must fail, but not with a missing-credential error.
    #[tokio::test]
    async fn chat_without_key_attempts_request() {
        let p = make_model_provider("Local", "http://127.0.0.1:1", None);
        let result = p
            .chat_with_system(None, "hello", "default", Some(0.7))
            .await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            !err_msg.contains("API key not set"),
            "should not get credential error, got: {err_msg}"
        );
    }
3054
    // A tool-enabled streaming request must serialize `stream_options.include_usage = true`.
    #[test]
    fn native_chat_request_with_tools_includes_stream_options() {
        let req = NativeChatRequest {
            model: "gpt-4o".to_string(),
            messages: vec![NativeMessage {
                role: "user".to_string(),
                content: Some(MessageContent::Text("hello".to_string())),
                tool_call_id: None,
                tool_calls: None,
                reasoning_content: None,
            }],
            temperature: 0.7,
            stream: Some(true),
            stream_options: Some(StreamOptionsBody {
                include_usage: true,
            }),
            reasoning_effort: None,
            tool_stream: None,
            tools: Some(vec![serde_json::json!({"name": "echo"})]),
            tool_choice: Some("auto".to_string()),
            max_tokens: None,
        };
        let value: serde_json::Value = serde_json::to_value(&req).unwrap();
        assert_eq!(
            value
                .get("stream_options")
                .and_then(|v| v.get("include_usage"))
                .and_then(serde_json::Value::as_bool),
            Some(true),
            "tool-enabled streaming request must serialize stream_options.include_usage=true; \
             without it OpenAI-compatible providers omit the final usage event"
        );
    }

    // With `stream_options: None`, the key must be absent from the serialized JSON.
    #[test]
    fn native_chat_request_omits_stream_options_when_none() {
        let req = NativeChatRequest {
            model: "gpt-4o".to_string(),
            messages: vec![],
            temperature: 0.7,
            stream: Some(false),
            stream_options: None,
            reasoning_effort: None,
            tool_stream: None,
            tools: None,
            tool_choice: None,
            max_tokens: None,
        };
        let value: serde_json::Value = serde_json::to_value(&req).unwrap();
        assert!(
            value.get("stream_options").is_none(),
            "non-streaming NativeChatRequest must not emit a stream_options key"
        );
    }
3116
    // Model IDs are trimmed, empty IDs are dropped, and the result is sorted.
    #[test]
    fn normalize_model_ids_trims_filters_and_sorts() {
        let body = serde_json::from_value(serde_json::json!({
            "data": [
                {"id": " zeta-model "},
                {"id": ""},
                {"id": "alpha-model"}
            ]
        }))
        .unwrap();

        assert_eq!(normalize_model_ids(body), vec!["alpha-model", "zeta-model"]);
    }
3130
    // Serialized request includes model and roles; `tools`/`tool_choice` set to None are omitted.
    #[test]
    fn request_serializes_correctly() {
        let req = ApiChatRequest {
            model: "llama-3.3-70b".to_string(),
            messages: vec![
                Message {
                    role: "system".to_string(),
                    content: MessageContent::Text("You are ZeroClaw".to_string()),
                },
                Message {
                    role: "user".to_string(),
                    content: MessageContent::Text("hello".to_string()),
                },
            ],
            temperature: 0.4,
            stream: Some(false),
            stream_options: None,
            reasoning_effort: None,
            tool_stream: None,
            tools: None,
            tool_choice: None,
            max_tokens: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("llama-3.3-70b"));
        assert!(json.contains("system"));
        assert!(json.contains("user"));
        assert!(!json.contains("tools"));
        assert!(!json.contains("tool_choice"));
    }

    // Plain-string assistant content deserializes as-is.
    #[test]
    fn response_deserializes() {
        let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(
            resp.choices[0].message.content,
            Some("Hello from Venice!".to_string())
        );
    }

    // An array of `text` parts is accepted as assistant content.
    #[test]
    fn response_deserializes_content_as_openai_text_parts_array() {
        let json =
            r#"{"choices":[{"message":{"content":[{"type":"text","text":"Hello array"}]}}]}"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(
            resp.choices[0].message.content.as_deref(),
            Some("Hello array")
        );
    }

    // Multiple text parts are joined with '\n'; the non-text `image_url` part is skipped.
    #[test]
    fn response_deserializes_multiple_text_parts_with_newlines() {
        let json = r#"{"choices":[{"message":{"content":[{"type":"text","text":"Hello"},{"type":"image_url","image_url":{"url":"https://example.com/image.png"}},{"type":"text","text":"array"}]}}]}"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(
            resp.choices[0].message.content.as_deref(),
            Some("Hello\narray")
        );
    }

    // A single object (not a string or array) as content must fail to deserialize.
    #[test]
    fn response_rejects_unsupported_top_level_content_shape() {
        let json = r#"{"choices":[{"message":{"content":{"type":"text","text":"Hello object"}}}]}"#;
        serde_json::from_str::<ApiChatResponse>(json)
            .expect_err("object-shaped assistant content must remain an invalid payload");
    }

    // An empty `choices` array still deserializes.
    #[test]
    fn response_empty_choices() {
        let json = r#"{"choices":[]}"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert!(resp.choices.is_empty());
    }

    // A malformed payload yields an error containing a redacted body snippet, never the secret.
    #[test]
    fn parse_chat_response_body_reports_sanitized_snippet() {
        let body = r#"{"choices":"invalid","api_key":"sk-test-secret-value"}"#;
        let err = parse_chat_response_body("custom", body).expect_err("payload should fail");
        let msg = err.to_string();

        assert!(msg.contains("custom API returned an unexpected chat-completions payload"));
        assert!(msg.contains("body="));
        assert!(msg.contains("[REDACTED]"));
        assert!(!msg.contains("sk-test-secret-value"));
    }
3219
    // The constructor stores the `XApiKey` auth style as given.
    #[test]
    fn x_api_key_auth_style() {
        let p = OpenAiCompatibleModelProvider::new(
            "test",
            "moonshot",
            "https://api.moonshot.cn",
            Some("ms-key"),
            AuthStyle::XApiKey,
        );
        assert!(matches!(p.auth_header, AuthStyle::XApiKey));
    }

    // The constructor stores a `Custom` header-name auth style as given.
    #[test]
    fn custom_auth_style() {
        let p = OpenAiCompatibleModelProvider::new(
            "test",
            "custom",
            "https://api.example.com",
            Some("key"),
            AuthStyle::Custom("X-Custom-Key".into()),
        );
        assert!(matches!(p.auth_header, AuthStyle::Custom(_)));
    }
3243
    // The Zhipu credential is turned into "Bearer <header>.<payload>.<signature>".
    #[test]
    fn zhipu_jwt_produces_valid_three_part_token() {
        let result = zhipu_jwt_bearer("testid.testsecret").unwrap();
        assert!(result.starts_with("Bearer "));
        let jwt = result.strip_prefix("Bearer ").unwrap();
        let parts: Vec<&str> = jwt.split('.').collect();
        assert_eq!(parts.len(), 3, "JWT must have 3 dot-separated parts: {jwt}");
    }

    // The JWT header decodes to HS256 / JWT / sign_type SIGN.
    #[test]
    fn zhipu_jwt_header_is_correct() {
        use base64::engine::{Engine, general_purpose::URL_SAFE_NO_PAD};
        let result = zhipu_jwt_bearer("myid.mysecret").unwrap();
        let jwt = result.strip_prefix("Bearer ").unwrap();
        let header_b64 = jwt.split('.').next().unwrap();
        let header_bytes = URL_SAFE_NO_PAD.decode(header_b64).unwrap();
        let header: serde_json::Value = serde_json::from_slice(&header_bytes).unwrap();
        assert_eq!(header["alg"], "HS256");
        assert_eq!(header["typ"], "JWT");
        assert_eq!(header["sign_type"], "SIGN");
    }

    // Payload carries the key id as `api_key`; `exp` is exactly 210_000 ms after `timestamp`.
    #[test]
    fn zhipu_jwt_payload_contains_api_key_and_timestamps() {
        use base64::engine::{Engine, general_purpose::URL_SAFE_NO_PAD};
        let result = zhipu_jwt_bearer("myapiid.mysecretkey").unwrap();
        let jwt = result.strip_prefix("Bearer ").unwrap();
        let payload_b64 = jwt.split('.').nth(1).unwrap();
        let payload_bytes = URL_SAFE_NO_PAD.decode(payload_b64).unwrap();
        let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).unwrap();
        assert_eq!(payload["api_key"], "myapiid");
        assert!(payload["exp"].is_number());
        assert!(payload["timestamp"].is_number());
        let ts = payload["timestamp"].as_u64().unwrap();
        let exp = payload["exp"].as_u64().unwrap();
        assert_eq!(exp - ts, 210_000);
    }

    // The signature is an HMAC-SHA256 over "<header>.<payload>" keyed by the secret part.
    #[test]
    fn zhipu_jwt_signature_is_verifiable() {
        let secret = "testsecret123";
        let credential = format!("testid.{secret}");
        let result = zhipu_jwt_bearer(&credential).unwrap();
        let jwt = result.strip_prefix("Bearer ").unwrap();
        let parts: Vec<&str> = jwt.split('.').collect();
        let signing_input = format!("{}.{}", parts[0], parts[1]);

        let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, secret.as_bytes());
        use base64::engine::{Engine, general_purpose::URL_SAFE_NO_PAD};
        let sig_bytes = URL_SAFE_NO_PAD.decode(parts[2]).unwrap();
        ring::hmac::verify(&key, signing_input.as_bytes(), &sig_bytes)
            .expect("signature must verify");
    }

    // Keys without an "id.secret" dot separator are rejected.
    #[test]
    fn zhipu_jwt_rejects_invalid_key_format() {
        assert!(zhipu_jwt_bearer("no-dot-here").is_err());
        assert!(zhipu_jwt_bearer("").is_err());
    }

    // The constructor stores the `ZhipuJwt` auth style as given.
    #[test]
    fn zhipu_jwt_auth_style_applies_correctly() {
        let p = OpenAiCompatibleModelProvider::new(
            "test",
            "Z.AI",
            "https://api.z.ai/api/coding/paas/v4",
            Some("testid.testsecret"),
            AuthStyle::ZhipuJwt,
        );
        assert!(matches!(p.auth_header, AuthStyle::ZhipuJwt));
    }
3317
    // Every listed provider name, with no key, attempts the (unreachable) request.
    // Each must fail without a missing-credential error.
    #[tokio::test]
    async fn all_compatible_providers_attempt_request_without_key() {
        let model_providers = vec![
            make_model_provider("Venice", "http://127.0.0.1:1", None),
            make_model_provider("Moonshot", "http://127.0.0.1:1", None),
            make_model_provider("GLM", "http://127.0.0.1:1", None),
            make_model_provider("MiniMax", "http://127.0.0.1:1", None),
            make_model_provider("Groq", "http://127.0.0.1:1", None),
            make_model_provider("Mistral", "http://127.0.0.1:1", None),
            make_model_provider("xAI", "http://127.0.0.1:1", None),
            make_model_provider("Astrai", "http://127.0.0.1:1", None),
        ];

        for p in model_providers {
            let result = p.chat_with_system(None, "test", "model", Some(0.7)).await;
            assert!(result.is_err(), "{} should fail (unreachable host)", p.name);
            let err_msg = result.unwrap_err().to_string();
            assert!(
                !err_msg.contains("API key not set"),
                "{} should get transport error, not credential error, got: {err_msg}",
                p.name
            );
        }
    }
3342
    // Without a nested `function`, the top-level `name` is used.
    #[test]
    fn tool_call_function_name_falls_back_to_top_level_name() {
        let call: ToolCall = serde_json::from_value(serde_json::json!({
            "name": "memory_recall",
            "arguments": "{\"query\":\"latest roadmap\"}"
        }))
        .unwrap();

        assert_eq!(call.function_name().as_deref(), Some("memory_recall"));
    }

    // A `parameters` object is serialized to a JSON string when no `arguments` are given.
    #[test]
    fn tool_call_function_arguments_falls_back_to_parameters_object() {
        let call: ToolCall = serde_json::from_value(serde_json::json!({
            "name": "shell",
            "parameters": {"command": "pwd"}
        }))
        .unwrap();

        assert_eq!(
            call.function_arguments().as_deref(),
            Some("{\"command\":\"pwd\"}")
        );
    }

    // A nested `function` name/arguments take precedence over the top-level fields.
    #[test]
    fn tool_call_function_arguments_prefers_nested_function_field() {
        let call: ToolCall = serde_json::from_value(serde_json::json!({
            "name": "ignored_name",
            "arguments": "{\"query\":\"ignored\"}",
            "function": {
                "name": "memory_recall",
                "arguments": "{\"query\":\"preferred\"}"
            }
        }))
        .unwrap();

        assert_eq!(call.function_name().as_deref(), Some("memory_recall"));
        assert_eq!(
            call.function_arguments().as_deref(),
            Some("{\"query\":\"preferred\"}")
        );
    }
3386
    // `chat_completions_url` appends "/chat/completions" to the base URL unless the
    // base already ends with exactly that path.
    #[test]
    fn chat_completions_url_standard_openai() {
        let p = make_model_provider("openai", "https://api.openai.com/v1", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://api.openai.com/v1/chat/completions"
        );
    }

    // A trailing slash on the base does not produce a double slash.
    #[test]
    fn chat_completions_url_trailing_slash() {
        let p = make_model_provider("test", "https://api.example.com/v1/", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://api.example.com/v1/chat/completions"
        );
    }

    // A base that already is a full chat-completions endpoint is used unchanged.
    #[test]
    fn chat_completions_url_volcengine_ark() {
        let p = make_model_provider(
            "volcengine",
            "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions",
            None,
        );
        assert_eq!(
            p.chat_completions_url(),
            "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions"
        );
    }

    #[test]
    fn chat_completions_url_custom_full_endpoint() {
        let p = make_model_provider(
            "custom",
            "https://my-api.example.com/v2/llm/chat/completions",
            None,
        );
        assert_eq!(
            p.chat_completions_url(),
            "https://my-api.example.com/v2/llm/chat/completions"
        );
    }

    // Only an exact "/chat/completions" suffix counts; "chat/completions-proxy" still gets the suffix.
    #[test]
    fn chat_completions_url_requires_exact_suffix_match() {
        let p = make_model_provider(
            "custom",
            "https://my-api.example.com/v2/llm/chat/completions-proxy",
            None,
        );
        assert_eq!(
            p.chat_completions_url(),
            "https://my-api.example.com/v2/llm/chat/completions-proxy/chat/completions"
        );
    }

    // No version segment is inserted when the base lacks one.
    #[test]
    fn chat_completions_url_without_v1() {
        let p = make_model_provider("test", "https://api.example.com", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://api.example.com/chat/completions"
        );
    }

    #[test]
    fn chat_completions_url_base_with_v1() {
        let p = make_model_provider("test", "https://api.example.com/v1", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://api.example.com/v1/chat/completions"
        );
    }

    // Provider-specific base URLs (non-v1 version paths included) get the standard suffix.
    #[test]
    fn chat_completions_url_zai() {
        let p = make_model_provider("zai", "https://api.z.ai/api/paas/v4", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://api.z.ai/api/paas/v4/chat/completions"
        );
    }

    #[test]
    fn chat_completions_url_minimax() {
        let p = make_model_provider("minimax", "https://api.minimaxi.com/v1", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://api.minimaxi.com/v1/chat/completions"
        );
    }

    #[test]
    fn chat_completions_url_glm() {
        let p = make_model_provider("glm", "https://open.bigmodel.cn/api/paas/v4", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://open.bigmodel.cn/api/paas/v4/chat/completions"
        );
    }

    #[test]
    fn chat_completions_url_opencode() {
        let p = make_model_provider("opencode", "https://opencode.ai/zen/v1", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://opencode.ai/zen/v1/chat/completions"
        );
    }

    #[test]
    fn chat_completions_url_opencode_go() {
        let p = make_model_provider("opencode-go", "https://opencode.ai/zen/go/v1", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://opencode.ai/zen/go/v1/chat/completions"
        );
    }
3525
    // For a non-Mistral endpoint, the upstream tool-call id is kept verbatim.
    #[test]
    fn parse_native_response_preserves_tool_call_id() {
        let provider = make_model_provider("test", "https://example.com", None);
        let message = ResponseMessage {
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: Some("call_123".to_string()),
                kind: Some("function".to_string()),
                function: Some(Function {
                    name: Some("shell".to_string()),
                    arguments: Some(r#"{"command":"pwd"}"#.to_string()),
                }),
                name: None,
                arguments: None,
                parameters: None,
                extra_content: None,
            }]),
            reasoning_content: None,
        };

        let parsed = provider.parse_native_response(message);
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].id, "call_123");
        assert_eq!(parsed.tool_calls[0].name, "shell");
    }

    // For Mistral, an over-long id is rewritten to 9 ASCII alphanumeric characters.
    #[test]
    fn parse_native_response_mistral_normalizes_invalid_tool_call_id() {
        let provider = make_model_provider("Mistral", "https://api.mistral.ai/v1", None);
        let message = ResponseMessage {
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: Some("xvL0p9bZ41j2X0O3Q1y9vL0p9bZ41j2X".to_string()),
                kind: Some("function".to_string()),
                function: Some(Function {
                    name: Some("shell".to_string()),
                    arguments: Some(r#"{"command":"pwd"}"#.to_string()),
                }),
                name: None,
                arguments: None,
                parameters: None,
                extra_content: None,
            }]),
            reasoning_content: None,
        };

        let parsed = provider.parse_native_response(message);
        assert_eq!(parsed.tool_calls.len(), 1);
        let id = &parsed.tool_calls[0].id;
        assert_eq!(id.len(), 9);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    // For Mistral, a missing id is replaced with a generated 9-character alphanumeric id.
    #[test]
    fn parse_native_response_mistral_generates_valid_id_when_missing() {
        let provider = make_model_provider("Mistral", "https://api.mistral.ai/v1", None);
        let message = ResponseMessage {
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: None,
                kind: Some("function".to_string()),
                function: Some(Function {
                    name: Some("shell".to_string()),
                    arguments: Some(r#"{"command":"pwd"}"#.to_string()),
                }),
                name: None,
                arguments: None,
                parameters: None,
                extra_content: None,
            }]),
            reasoning_content: None,
        };

        let parsed = provider.parse_native_response(message);
        assert_eq!(parsed.tool_calls.len(), 1);
        let id = &parsed.tool_calls[0].id;
        assert_eq!(id.len(), 9);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    // The Mistral id contract is applied based on the base URL, even under a non-Mistral name.
    #[test]
    fn parse_native_response_custom_mistral_endpoint_normalizes_tool_call_id() {
        let provider = make_model_provider("Custom", "https://api.mistral.ai/v1", None);
        let message = ResponseMessage {
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: Some("xvL0p9bZ41j2X0O3Q1y9vL0p9bZ41j2X".to_string()),
                kind: Some("function".to_string()),
                function: Some(Function {
                    name: Some("shell".to_string()),
                    arguments: Some(r#"{"command":"pwd"}"#.to_string()),
                }),
                name: None,
                arguments: None,
                parameters: None,
                extra_content: None,
            }]),
            reasoning_content: None,
        };

        let parsed = provider.parse_native_response(message);
        assert_eq!(parsed.tool_calls.len(), 1);
        let id = &parsed.tool_calls[0].id;
        assert_eq!(id.len(), 9);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }
3632
#[test]
fn parse_native_response_mistral_avoids_id_collision_after_normalization() {
    // Both upstream ids share the prefix "ABCDEFGHI" (nine chars). Naive
    // truncation to nine characters would collide; the normalized ids must differ.
    let make_call = |id: &str, tool: &str, args: &str| ToolCall {
        id: Some(id.to_string()),
        kind: Some("function".to_string()),
        function: Some(Function {
            name: Some(tool.to_string()),
            arguments: Some(args.to_string()),
        }),
        name: None,
        arguments: None,
        parameters: None,
        extra_content: None,
    };
    let provider = make_model_provider("Mistral", "https://api.mistral.ai/v1", None);

    let parsed = provider.parse_native_response(ResponseMessage {
        content: None,
        tool_calls: Some(vec![
            make_call("ABCDEFGHI123", "shell", r#"{"command":"pwd"}"#),
            make_call("ABCDEFGHIxyz", "echo", r#"{"text":"ok"}"#),
        ]),
        reasoning_content: None,
    });

    assert_eq!(parsed.tool_calls.len(), 2);
    let ids: Vec<_> = parsed.tool_calls.iter().map(|call| &call.id).collect();
    for id in &ids {
        assert_eq!(id.len(), 9);
        assert!(id.chars().all(|ch| ch.is_ascii_alphanumeric()));
    }
    assert_ne!(ids[0], ids[1]);
}
3677
#[test]
fn convert_messages_for_native_maps_tool_result_payload() {
    // A tool message whose content is a JSON envelope is unpacked: the id goes
    // into `tool_call_id` and the inner `content` becomes the text payload.
    let provider = make_model_provider("test", "https://example.com", None);
    let history = vec![ChatMessage::tool(
        r#"{"tool_call_id":"call_abc","content":"done"}"#,
    )];

    let converted = provider.convert_messages_for_native(&history, true);

    assert_eq!(converted.len(), 1);
    let tool_msg = &converted[0];
    assert_eq!(tool_msg.role, "tool");
    assert_eq!(tool_msg.tool_call_id.as_deref(), Some("call_abc"));
    assert!(matches!(
        tool_msg.content.as_ref(),
        Some(MessageContent::Text(text)) if text == "done"
    ));
}
3694
#[test]
fn native_chat_request_mistral_serializes_matching_valid_tool_call_ids() {
    // The upstream id is rejected by `is_valid_mistral_tool_call_id`. After
    // serialization, the assistant tool call and its tool result must carry
    // the same rewritten, valid id.
    const UPSTREAM_ID: &str = "chatcmpl-tool-abc";
    let provider = make_model_provider("Mistral", "https://api.mistral.ai/v1", None);

    let assistant_turn = serde_json::json!({
        "content": "",
        "tool_calls": [{
            "id": UPSTREAM_ID,
            "name": "shell",
            "arguments": "{\"cmd\":\"pwd\"}"
        }]
    })
    .to_string();
    let tool_turn = serde_json::json!({
        "tool_call_id": UPSTREAM_ID,
        "content": "done"
    })
    .to_string();
    let history = vec![
        ChatMessage::assistant(assistant_turn),
        ChatMessage::tool(tool_turn),
    ];

    let shell_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run a shell command",
            "parameters": {"type": "object"}
        }
    });
    let request = NativeChatRequest {
        model: "mistral-large-latest".to_string(),
        messages: provider.convert_messages_for_native(&history, true),
        temperature: 0.7,
        stream: Some(false),
        stream_options: None,
        reasoning_effort: None,
        tool_stream: None,
        tools: Some(vec![shell_tool]),
        tool_choice: Some("auto".to_string()),
        max_tokens: None,
    };

    let serialized = serde_json::to_value(&request).unwrap();
    let wire_messages = &serialized["messages"];
    let assistant_id = wire_messages[0]["tool_calls"][0]["id"]
        .as_str()
        .expect("assistant tool call id should serialize");
    let tool_id = wire_messages[1]["tool_call_id"]
        .as_str()
        .expect("tool result id should serialize");

    assert_ne!(assistant_id, UPSTREAM_ID);
    assert!(is_valid_mistral_tool_call_id(assistant_id));
    assert_eq!(assistant_id, tool_id);
}
3750
#[test]
fn convert_messages_for_native_keeps_user_image_markers_as_text_when_disabled() {
    // With user image parts disabled (second argument `false`), an
    // `[IMAGE:...]` marker must pass through as plain text, unchanged.
    let raw = "System primer [IMAGE:data:image/png;base64,abcd] user turn";
    let history = vec![ChatMessage::user(raw)];
    let provider = make_model_provider("test", "https://example.com", None);

    let converted = provider.convert_messages_for_native(&history, false);

    assert_eq!(converted.len(), 1);
    let user_msg = &converted[0];
    assert_eq!(user_msg.role, "user");
    assert!(matches!(
        user_msg.content.as_ref(),
        Some(MessageContent::Text(text)) if text == raw
    ));
}
3767
#[test]
fn flatten_system_messages_merges_into_first_user() {
    // System messages anywhere in the history are joined with blank lines and
    // prepended to the first user turn; the other turns keep their order.
    let history = vec![
        ChatMessage::system("core policy"),
        ChatMessage::assistant("ack"),
        ChatMessage::system("delivery rules"),
        ChatMessage::user("hello"),
        ChatMessage::assistant("post-user"),
    ];

    let output = OpenAiCompatibleModelProvider::flatten_system_messages(&history, true);

    let shape: Vec<(&str, &str)> = output
        .iter()
        .map(|msg| (msg.role.as_str(), msg.content.as_str()))
        .collect();
    assert_eq!(
        shape,
        [
            ("assistant", "ack"),
            ("user", "core policy\n\ndelivery rules\n\nhello"),
            ("assistant", "post-user"),
        ]
    );
    assert!(output.iter().all(|msg| msg.role != "system"));
}
3788
#[test]
fn flatten_system_messages_inserts_user_when_missing() {
    // With no user turn to merge into, the system text becomes a leading
    // synthetic user message.
    let history = vec![
        ChatMessage::system("core policy"),
        ChatMessage::assistant("ack"),
    ];

    let output = OpenAiCompatibleModelProvider::flatten_system_messages(&history, true);

    let shape: Vec<(&str, &str)> = output
        .iter()
        .map(|msg| (msg.role.as_str(), msg.content.as_str()))
        .collect();
    assert_eq!(shape, [("user", "core policy"), ("assistant", "ack")]);
}
3803
#[test]
fn strip_think_tags_drops_unclosed_block_suffix() {
    // Everything after an unterminated `<think>` is treated as hidden.
    let stripped = strip_think_tags("visible<think>hidden");
    assert_eq!(stripped, "visible");
}
3809
#[test]
fn native_tool_schema_unsupported_detection_is_precise() {
    // The same error body counts as "tools unsupported" only under a 400;
    // a 401 must not be misread as a schema rejection.
    let body = "unknown parameter: tools";
    let detected_on_400 = OpenAiCompatibleModelProvider::is_native_tool_schema_unsupported(
        reqwest::StatusCode::BAD_REQUEST,
        body,
    );
    let detected_on_401 = OpenAiCompatibleModelProvider::is_native_tool_schema_unsupported(
        reqwest::StatusCode::UNAUTHORIZED,
        body,
    );
    assert!(detected_on_400);
    assert!(!detected_on_401);
}
3825
#[test]
fn native_tool_schema_unsupported_detects_groq_tool_validation_error() {
    // Groq reports a malformed native tool call as a 400 "tool call
    // validation failed" error; that must be classified as unsupported.
    let groq_body = r#"Groq API error (400 Bad Request): {"error":{"message":"tool call validation failed: attempted to call tool 'memory_recall={\"limit\":5}' which was not in request"}}"#;
    let status = reqwest::StatusCode::BAD_REQUEST;
    assert!(OpenAiCompatibleModelProvider::is_native_tool_schema_unsupported(
        status, groq_body
    ));
}
3835
#[test]
fn prompt_guided_tool_fallback_injects_system_instruction() {
    // When tools are described in the prompt instead of sent natively, a
    // leading system message listing them must be injected.
    let command_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "command": { "type": "string" }
        },
        "required": ["command"]
    });
    let tools = vec![zeroclaw_api::tool::ToolSpec {
        name: "shell_exec".to_string(),
        description: "Execute shell command".to_string(),
        parameters: command_schema,
    }];
    let history = vec![ChatMessage::user("check status")];

    let output = OpenAiCompatibleModelProvider::with_prompt_guided_tool_instructions(
        &history,
        Some(&tools),
    );

    assert!(!output.is_empty());
    let injected = &output[0];
    assert_eq!(injected.role, "system");
    assert!(injected.content.contains("Available Tools"));
    assert!(injected.content.contains("shell_exec"));
}
3860
#[test]
fn reasoning_effort_only_applies_to_openai_and_selected_codex_models() {
    // The configured effort is forwarded only for the OpenAI model names
    // listed below; other families (e.g. llama) get `None`.
    let provider = make_model_provider("test", "https://example.com", None)
        .with_reasoning_effort(Some("high".to_string()));

    let eligible = [
        "o1-preview",
        "openai/o3-mini",
        "o4-mini",
        "gpt-5",
        "gpt-5.3-codex",
        "openai/gpt-5",
        "gpt-4-codex",
    ];
    for model in eligible {
        assert_eq!(
            provider.reasoning_effort_for_model(model),
            Some("high".to_string())
        );
    }

    assert_eq!(
        provider.reasoning_effort_for_model("llama-3-codex"),
        None,
        "generic codex-like model names must not receive OpenAI-only reasoning_effort",
    );
    assert_eq!(provider.reasoning_effort_for_model("llama-3.3-70b"), None);
}
3904
3905 #[tokio::test]
3906 async fn warmup_without_key_attempts_connection() {
3907 let model_provider = make_model_provider("test", "http://127.0.0.1:1", None);
3908 let result = model_provider.warmup().await;
3909 assert!(result.is_err());
3910 let err_msg = result.unwrap_err().to_string();
3911 assert!(
3912 !err_msg.contains("API key not set"),
3913 "should not get credential error, got: {err_msg}"
3914 );
3915 }
3916
#[test]
fn capabilities_reports_native_tool_calling() {
    // Default construction: native tool calling on, vision off.
    let provider = make_model_provider("test", "https://example.com", None);
    let capabilities = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider);
    assert!(capabilities.native_tool_calling);
    assert!(!capabilities.vision);
}
3928
#[test]
fn capabilities_reports_vision_for_qwen_compatible_provider() {
    // `new_with_vision(..., true)` must report vision without losing native
    // tool calling.
    let vision_enabled = true;
    let provider = OpenAiCompatibleModelProvider::new_with_vision(
        "test",
        "Qwen",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
        Some("k"),
        AuthStyle::Bearer,
        vision_enabled,
    );
    let capabilities = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider);
    assert!(capabilities.native_tool_calling);
    assert!(capabilities.vision);
}
3943
#[test]
fn minimax_provider_supports_native_tool_calling_with_system_merge() {
    // Enabling system-into-user merging via the builder must leave the
    // native tool-calling capability intact.
    let base = OpenAiCompatibleModelProvider::new(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );
    let provider = base.with_merge_system_into_user();

    let capabilities = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider);
    assert!(
        capabilities.native_tool_calling,
        "MiniMax should preserve native tool calling when system messages are merged"
    );
    assert!(!capabilities.vision);
}
3961
#[test]
fn strip_native_tool_messages_removes_tool_and_tool_calls() {
    // The assistant tool-call turn, its tool result, and the following
    // assistant reply are expected to collapse into one plain assistant
    // message that keeps both narrations and drops the tool_calls JSON.
    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("search for cats"),
        ChatMessage::assistant(
            r#"{"content":"I'll search","tool_calls":[{"id":"chatcmpl-tool-abc","name":"web_search","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(
            r#"{"tool_call_id":"chatcmpl-tool-abc","content":"Found 10 results"}"#,
        ),
        ChatMessage::assistant("Here are the results about cats"),
        ChatMessage::user("thanks"),
    ];
    let provider = OpenAiCompatibleModelProvider::new_merge_system_into_user(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );

    let stripped = provider.strip_native_tool_messages(&history);

    let roles: Vec<&str> = stripped.iter().map(|msg| msg.role.as_str()).collect();
    assert_eq!(roles, ["system", "user", "assistant", "user"]);
    assert_eq!(stripped[1].content, "search for cats");

    let merged = &stripped[2].content;
    assert!(
        merged.starts_with("I'll search"),
        "coalesced assistant must preserve the pre-tool narration; got {:?}",
        merged
    );
    assert!(
        merged.contains("Here are the results about cats"),
        "coalesced assistant must preserve the post-tool reply; got {:?}",
        merged
    );
    assert!(
        !merged.contains("tool_calls"),
        "tool_calls structure must be stripped"
    );
}
4013
#[test]
fn strip_native_tool_messages_drops_empty_assistant_tool_calls() {
    // A tool-call turn with empty narration leaves nothing to keep, so only
    // the final "Done" reply survives as the assistant message.
    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("do it"),
        ChatMessage::assistant(
            r#"{"content":"","tool_calls":[{"id":"tc1","name":"shell","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(r#"{"tool_call_id":"tc1","content":"ok"}"#),
        ChatMessage::assistant("Done"),
    ];
    let provider = OpenAiCompatibleModelProvider::new_merge_system_into_user(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );

    let stripped = provider.strip_native_tool_messages(&history);

    let roles: Vec<&str> = stripped.iter().map(|msg| msg.role.as_str()).collect();
    assert_eq!(roles, ["system", "user", "assistant"]);
    assert_eq!(stripped[2].content, "Done");
}
4040
#[test]
fn strip_native_tool_messages_preserves_regular_messages() {
    // A history with no tool traffic must come back unchanged.
    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("hello"),
        ChatMessage::assistant("hi there"),
        ChatMessage::user("bye"),
    ];
    let provider = OpenAiCompatibleModelProvider::new_merge_system_into_user(
        "test",
        "MiniMax",
        "https://api.minimax.chat/v1",
        Some("k"),
        AuthStyle::Bearer,
    );

    let stripped = provider.strip_native_tool_messages(&history);

    assert_eq!(stripped.len(), 4);
    for (before, after) in history.iter().zip(&stripped) {
        assert_eq!(before.role, after.role);
        assert_eq!(before.content, after.content);
    }
}
4063
#[test]
fn strip_native_tool_messages_passthrough_when_native_tool_calling_enabled() {
    // A provider with native tool calling must not rewrite tool traffic;
    // every message is returned as-is.
    let provider = OpenAiCompatibleModelProvider::new(
        "test",
        "NativeToolProvider",
        "https://api.example.com/v1",
        Some("k"),
        AuthStyle::Bearer,
    );
    assert!(
        <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider).native_tool_calling,
        "model_provider must have native_tool_calling enabled for this test"
    );

    let history = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("search for cats"),
        ChatMessage::assistant(
            r#"{"content":"I'll search","tool_calls":[{"id":"chatcmpl-tool-abc","name":"web_search","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(
            r#"{"tool_call_id":"chatcmpl-tool-abc","content":"Found 10 results"}"#,
        ),
        ChatMessage::assistant("Here are the results about cats"),
    ];

    let untouched = provider.strip_native_tool_messages(&history);

    assert_eq!(untouched.len(), history.len());
    for (before, after) in history.iter().zip(&untouched) {
        assert_eq!(before.role, after.role);
        assert_eq!(before.content, after.content);
    }
}
4098
#[test]
fn user_agent_constructor_keeps_native_tool_calling_enabled() {
    // The user-agent constructor stores the agent string and keeps the
    // default capability flags (native tools on, vision off).
    let agent = "zeroclaw-test/1.0";
    let provider = OpenAiCompatibleModelProvider::new_with_user_agent(
        "test",
        "TestProvider",
        "https://example.com",
        Some("k"),
        AuthStyle::Bearer,
        agent,
    );

    let capabilities = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider);
    assert!(capabilities.native_tool_calling);
    assert!(!capabilities.vision);
    assert_eq!(provider.user_agent.as_deref(), Some(agent));
}
4114
#[test]
fn user_agent_and_vision_constructor_preserves_capability_flags() {
    // The combined constructor must honor both the user agent and the
    // vision flag while keeping native tool calling on.
    let agent = "zeroclaw-test/vision";
    let provider = OpenAiCompatibleModelProvider::new_with_user_agent_and_vision(
        "test",
        "VisionModelProvider",
        "https://example.com",
        Some("k"),
        AuthStyle::Bearer,
        agent,
        true,
    );

    let capabilities = <OpenAiCompatibleModelProvider as ModelProvider>::capabilities(&provider);
    assert!(capabilities.native_tool_calling);
    assert!(capabilities.vision);
    assert_eq!(provider.user_agent.as_deref(), Some(agent));
}
4131
#[test]
fn to_message_content_converts_image_markers_to_openai_parts() {
    // With image parts enabled, a user message splits into a text part and
    // an `image_url` part carrying the marker's data URI.
    let rendered = OpenAiCompatibleModelProvider::to_message_content(
        "user",
        "Describe this\n\n[IMAGE:data:image/png;base64,abcd]",
        true,
    );
    let value = serde_json::to_value(rendered).unwrap();

    let parts = value
        .as_array()
        .expect("multimodal content should be an array");
    assert_eq!(parts.len(), 2);

    let (text_part, image_part) = (&parts[0], &parts[1]);
    assert_eq!(text_part["type"], "text");
    assert_eq!(text_part["text"], "Describe this");
    assert_eq!(image_part["type"], "image_url");
    assert_eq!(image_part["image_url"]["url"], "data:image/png;base64,abcd");
}
4148
#[test]
fn to_message_content_keeps_markers_as_text_when_user_image_parts_disabled() {
    // With image parts disabled, the content serializes as the original
    // plain JSON string, marker included.
    let content = "Policy [IMAGE:data:image/png;base64,abcd]";
    let rendered = OpenAiCompatibleModelProvider::to_message_content("user", content, false);
    let value = serde_json::to_value(rendered).unwrap();
    assert_eq!(value, serde_json::json!(content));
}
4158
#[test]
fn to_message_content_keeps_plain_text_for_non_user_roles() {
    // A system-role message stays a plain string even when image parts are
    // enabled.
    let prompt = "You are a helpful assistant.";
    let rendered = OpenAiCompatibleModelProvider::to_message_content("system", prompt, true);
    let value = serde_json::to_value(rendered).unwrap();
    assert_eq!(value, serde_json::json!(prompt));
}
4169
4170 #[tokio::test]
4171 async fn normalize_messages_for_upstream_rewrites_local_image_path_to_data_uri() {
4172 let tmp = tempfile::TempDir::new().expect("tempdir");
4176 let path = tmp.path().join("pixel.png");
4177 let png: [u8; 67] = [
4179 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48,
4180 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x06, 0x00, 0x00,
4181 0x00, 0x1F, 0x15, 0xC4, 0x89, 0x00, 0x00, 0x00, 0x0A, 0x49, 0x44, 0x41, 0x54, 0x78,
4182 0x9C, 0x63, 0x00, 0x01, 0x00, 0x00, 0x05, 0x00, 0x01, 0x0D, 0x0A, 0x2D, 0xB4, 0x00,
4183 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
4184 ];
4185 std::fs::write(&path, png).expect("write pixel.png");
4186 let path_str = path.to_string_lossy().into_owned();
4187
4188 let msg = ChatMessage {
4189 role: "user".into(),
4190 content: format!("Caption please [IMAGE:{}]", path_str),
4191 };
4192
4193 let normalized = OpenAiCompatibleModelProvider::normalize_messages_for_upstream(
4194 std::slice::from_ref(&msg),
4195 )
4196 .await
4197 .expect("normalize ok");
4198
4199 assert_eq!(normalized.len(), 1);
4200 let content = &normalized[0].content;
4201 assert!(
4202 content.contains("[IMAGE:data:image/png;base64,"),
4203 "expected base64 data URI in normalized content, got: {content}"
4204 );
4205 assert!(
4206 !content.contains(&path_str),
4207 "raw local path must not leak to upstream, got: {content}"
4208 );
4209 }
4210
#[test]
fn request_serializes_with_tools() {
    // A request carrying tools must serialize the `tools` array and the
    // `tool_choice` string.
    let weather_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"}
                }
            }
        }
    });
    let question = Message {
        role: "user".to_string(),
        content: MessageContent::Text("What is the weather?".to_string()),
    };

    let request = ApiChatRequest {
        model: "test-model".to_string(),
        messages: vec![question],
        temperature: 0.7,
        stream: Some(false),
        stream_options: None,
        reasoning_effort: None,
        tool_stream: None,
        tools: Some(vec![weather_tool]),
        tool_choice: Some("auto".to_string()),
        max_tokens: None,
    };

    let body = serde_json::to_string(&request).unwrap();
    assert!(body.contains("\"tools\""));
    assert!(body.contains("get_weather"));
    assert!(body.contains("\"tool_choice\":\"auto\""));
}
4247
#[test]
fn zai_tool_requests_enable_tool_stream() {
    // For the z.ai endpoint, a tool-bearing request must include
    // `"tool_stream":true`.
    let provider = make_model_provider("zai", "https://api.z.ai/api/paas/v4", None);
    let shell_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run a shell command",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {"type": "string"}
                }
            }
        }
    });

    let request = ApiChatRequest {
        model: "glm-5".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: MessageContent::Text("List /tmp".to_string()),
        }],
        temperature: 0.7,
        stream: Some(false),
        stream_options: None,
        reasoning_effort: None,
        tool_stream: provider.tool_stream_for_tools(true),
        tools: Some(vec![shell_tool]),
        tool_choice: Some("auto".to_string()),
        max_tokens: None,
    };

    let body = serde_json::to_string(&request).unwrap();
    assert!(body.contains("\"tool_stream\":true"));
}
4282
#[test]
fn non_zai_tool_requests_omit_tool_stream() {
    // For any endpoint other than z.ai, the `tool_stream` key must be absent
    // from the serialized request.
    let provider = make_model_provider("test", "https://api.example.com/v1", None);
    let shell_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run a shell command",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {"type": "string"}
                }
            }
        }
    });

    let request = ApiChatRequest {
        model: "test-model".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: MessageContent::Text("List /tmp".to_string()),
        }],
        temperature: 0.7,
        stream: Some(false),
        stream_options: None,
        reasoning_effort: None,
        tool_stream: provider.tool_stream_for_tools(true),
        tools: Some(vec![shell_tool]),
        tool_choice: Some("auto".to_string()),
        max_tokens: None,
    };

    let body = serde_json::to_string(&request).unwrap();
    assert!(!body.contains("\"tool_stream\""));
}
4317
#[test]
fn non_zai_provider_omits_tool_stream_regardless_of_streaming() {
    // A non-z.ai host yields `None` whichever flag is passed.
    let provider = make_model_provider("custom", "https://proxy.example.com/v1", None);
    for flag in [true, false] {
        assert_eq!(provider.tool_stream_for_tools(flag), None);
    }
}
4325
#[test]
fn z_ai_host_enables_tool_stream_for_custom_profiles() {
    // Detection is host-based: a "custom" profile pointing at api.z.ai still
    // turns tool streaming on.
    let provider = make_model_provider("custom", "https://api.z.ai/api/coding/paas/v4", None);
    let tool_stream = provider.tool_stream_for_tools(true);
    assert_eq!(tool_stream, Some(true));
}
4332
#[test]
fn response_with_tool_calls_deserializes() {
    // A null `content` next to a single function call must deserialize with
    // the function's name and raw argument string intact.
    let payload = r#"{
        "choices": [{
            "message": {
                "content": null,
                "tool_calls": [{
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"location\":\"London\"}"
                    }
                }]
            }
        }]
    }"#;

    let parsed: ApiChatResponse = serde_json::from_str(payload).unwrap();
    let message = &parsed.choices[0].message;
    assert!(message.content.is_none());

    let calls = message.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    let function = calls[0].function.as_ref().unwrap();
    assert_eq!(function.name.as_deref(), Some("get_weather"));
    assert_eq!(
        function.arguments.as_deref(),
        Some("{\"location\":\"London\"}")
    );
}
4369
#[test]
fn response_with_multiple_tool_calls() {
    // Text content and several tool calls can coexist; call order must
    // match the payload.
    let payload = r#"{
        "choices": [{
            "message": {
                "content": "I'll check both.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\":\"London\"}"
                        }
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": "{\"timezone\":\"UTC\"}"
                        }
                    }
                ]
            }
        }]
    }"#;

    let parsed: ApiChatResponse = serde_json::from_str(payload).unwrap();
    let message = &parsed.choices[0].message;
    assert_eq!(message.content.as_deref(), Some("I'll check both."));

    let names: Vec<Option<&str>> = message
        .tool_calls
        .as_ref()
        .unwrap()
        .iter()
        .map(|call| call.function.as_ref().unwrap().name.as_deref())
        .collect();
    assert_eq!(names, [Some("get_weather"), Some("get_time")]);
}
4410
4411 #[tokio::test]
4412 async fn chat_with_tools_without_key_attempts_request() {
4413 let p = make_model_provider("TestProvider", "http://127.0.0.1:1", None);
4414 let messages = vec![ChatMessage {
4415 role: "user".to_string(),
4416 content: "hello".to_string(),
4417 }];
4418 let tools = vec![serde_json::json!({
4419 "type": "function",
4420 "function": {
4421 "name": "test_tool",
4422 "description": "A test tool",
4423 "parameters": {}
4424 }
4425 })];
4426
4427 let result = p
4428 .chat_with_tools(&messages, &tools, "model", Some(0.7))
4429 .await;
4430 assert!(result.is_err());
4431 let err_msg = result.unwrap_err().to_string();
4432 assert!(
4433 !err_msg.contains("API key not set"),
4434 "should not get credential error, got: {err_msg}"
4435 );
4436 }
4437
#[test]
fn chat_with_tools_request_preserves_reasoning_content_in_history() {
    // An assistant history turn carrying `reasoning_content` and native tool
    // calls must be serialized with both fields preserved, alongside the
    // request's tools and tool_choice.
    let reasoning = "Need to inspect the current files before answering.";
    let assistant_turn = serde_json::json!({
        "content": "I will inspect the workspace.",
        "tool_calls": [{
            "id": "call_1",
            "name": "shell",
            "arguments": "{\"cmd\":\"ls\"}"
        }],
        "reasoning_content": reasoning
    });
    let history = vec![
        ChatMessage::assistant(assistant_turn.to_string()),
        ChatMessage::tool(r#"{"tool_call_id":"call_1","content":"src\nCargo.toml"}"#),
        ChatMessage::user("continue"),
    ];
    let shell_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run a shell command",
            "parameters": {}
        }
    });
    let provider = make_model_provider("DeepSeek", "https://api.deepseek.example/v1", None);

    let request = provider.build_native_tool_chat_request(
        &history,
        Some(vec![shell_tool]),
        "deepseek-v4-flash",
        0.7,
        true,
    );
    let serialized = serde_json::to_value(&request).unwrap();
    let first = &serialized["messages"][0];

    assert_eq!(first["role"], "assistant");
    assert_eq!(first["reasoning_content"], reasoning);
    assert!(
        first["tool_calls"].is_array(),
        "assistant tool-call history must stay native in chat_with_tools requests"
    );
    assert_eq!(serialized["tools"][0]["function"]["name"], "shell");
    assert_eq!(serialized["tool_choice"], "auto");
}
4486
#[test]
fn response_with_no_tool_calls_has_empty_vec() {
    // A message without a `tool_calls` key deserializes the field as `None`
    // rather than `Some(vec![])`. The test name speaks of an "empty vec", so
    // pin both views: the raw field is absent, and treating absence as an
    // empty list (as callers would) yields zero calls.
    let json = r#"{"choices":[{"message":{"content":"Just text, no tools."}}]}"#;
    let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
    let msg = &resp.choices[0].message;
    assert_eq!(msg.content.as_deref(), Some("Just text, no tools."));
    assert!(msg.tool_calls.is_none());
    assert!(msg.tool_calls.as_deref().unwrap_or_default().is_empty());
}
4495
#[test]
fn flatten_system_messages_merges_into_first_user_and_removes_system_roles() {
    // Both system prompts end up prepended to the first user turn. The
    // earlier assistant turn and the trailing tool turn keep their positions.
    let history = vec![
        ChatMessage::system("System A"),
        ChatMessage::assistant("Earlier assistant turn"),
        ChatMessage::system("System B"),
        ChatMessage::user("User turn"),
        ChatMessage::tool(r#"{"ok":true}"#),
    ];

    let flattened = OpenAiCompatibleModelProvider::flatten_system_messages(&history, true);

    let roles: Vec<&str> = flattened.iter().map(|msg| msg.role.as_str()).collect();
    assert_eq!(roles, ["assistant", "user", "tool"]);
    assert_eq!(flattened[1].content, "System A\n\nSystem B\n\nUser turn");
    assert!(!flattened.iter().any(|msg| msg.role == "system"));
}
4517
#[test]
fn flatten_system_messages_keeps_system_only_at_start_without_user_merge() {
    // With user merging disabled, every system prompt is collected into a
    // single leading system message.
    let history = vec![
        ChatMessage::system("System A"),
        ChatMessage::user("User turn"),
        ChatMessage::assistant("Assistant turn"),
        ChatMessage::system("System B"),
        ChatMessage::user("Follow-up"),
    ];

    let flattened = OpenAiCompatibleModelProvider::flatten_system_messages(&history, false);

    let roles: Vec<&str> = flattened.iter().map(|msg| msg.role.as_str()).collect();
    assert_eq!(roles, vec!["system", "user", "assistant", "user"]);

    let system_count = flattened.iter().filter(|msg| msg.role == "system").count();
    assert_eq!(system_count, 1);

    let leading = &flattened[0].content;
    assert!(leading.contains("System A"));
    assert!(leading.contains("System B"));
}
4546
#[test]
fn flatten_system_messages_drops_empty_system_messages() {
    // Empty system prompts contribute nothing, so no system message is
    // emitted at all.
    let history = vec![
        ChatMessage::system(""),
        ChatMessage::user("User turn"),
        ChatMessage::system(""),
    ];

    let flattened = OpenAiCompatibleModelProvider::flatten_system_messages(&history, false);

    assert_eq!(flattened.len(), 1);
    let only = &flattened[0];
    assert_eq!(only.role, "user");
    assert_eq!(only.content, "User turn");
}
4561
#[test]
fn flatten_system_messages_inserts_synthetic_user_when_no_user_exists() {
    // A system prompt that appears after the only assistant turn still
    // becomes a leading synthetic user message.
    let history = vec![
        ChatMessage::assistant("Assistant only"),
        ChatMessage::system("Synthetic system"),
    ];

    let flattened = OpenAiCompatibleModelProvider::flatten_system_messages(&history, true);

    assert_eq!(flattened.len(), 2);
    let (head, tail) = (&flattened[0], &flattened[1]);
    assert_eq!(head.role, "user");
    assert_eq!(head.content, "Synthetic system");
    assert_eq!(tail.role, "assistant");
}
4575
#[test]
fn strip_think_tags_removes_multiple_blocks_with_surrounding_text() {
    // Every closed `<think>` block is removed; the expected output shows the
    // surrounding words joined by single spaces.
    let stripped =
        strip_think_tags("Answer A <think>hidden 1</think> and B <think>hidden 2</think> done");
    assert_eq!(stripped, "Answer A and B done");
}
4582
#[test]
fn strip_think_tags_drops_tail_for_unclosed_block() {
    // An unterminated `<think>` hides the rest of the input.
    let raw = "Visible<think>hidden tail";
    assert_eq!(strip_think_tags(raw), "Visible");
}
4589
#[test]
fn reasoning_content_fallback_when_content_empty() {
    // An empty `content` string falls back to `reasoning_content`.
    let payload =
        r#"{"choices":[{"message":{"content":"","reasoning_content":"Thinking output here"}}]}"#;
    let parsed: ApiChatResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(
        parsed.choices[0].message.effective_content(),
        "Thinking output here"
    );
}
4602
#[test]
fn reasoning_content_fallback_when_content_null() {
    // An explicit `null` content falls back to `reasoning_content`.
    let payload =
        r#"{"choices":[{"message":{"content":null,"reasoning_content":"Fallback text"}}]}"#;
    let parsed: ApiChatResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.choices[0].message.effective_content(), "Fallback text");
}
4612
#[test]
fn reasoning_content_fallback_when_content_missing() {
    // An absent `content` key falls back to `reasoning_content`.
    let payload = r#"{"choices":[{"message":{"reasoning_content":"Only reasoning"}}]}"#;
    let parsed: ApiChatResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.choices[0].message.effective_content(), "Only reasoning");
}
4621
#[test]
fn reasoning_content_not_used_when_content_present() {
    // Non-empty `content` wins over `reasoning_content`.
    let payload = r#"{"choices":[{"message":{"content":"Normal response","reasoning_content":"Should be ignored"}}]}"#;
    let parsed: ApiChatResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.choices[0].message.effective_content(), "Normal response");
}
4630
4631 #[test]
4632 fn reasoning_content_used_when_content_only_think_tags() {
4633 let json = r#"{"choices":[{"message":{"content":"<think>secret</think>","reasoning_content":"Fallback text"}}]}"#;
4634 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
4635 let msg = &resp.choices[0].message;
4636 assert_eq!(msg.effective_content(), "Fallback text");
4637 assert_eq!(
4638 msg.effective_content_optional().as_deref(),
4639 Some("Fallback text")
4640 );
4641 }
4642
4643 #[test]
4644 fn reasoning_content_both_absent_returns_empty() {
4645 let json = r#"{"choices":[{"message":{}}]}"#;
4647 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
4648 let msg = &resp.choices[0].message;
4649 assert_eq!(msg.effective_content(), "");
4650 }
4651
4652 #[test]
4653 fn reasoning_content_ignored_by_normal_models() {
4654 let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
4656 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
4657 let msg = &resp.choices[0].message;
4658 assert!(msg.reasoning_content.is_none());
4659 assert_eq!(msg.effective_content(), "Hello from Venice!");
4660 }
4661
4662 #[test]
4667 fn parse_sse_line_with_content() {
4668 let line = r#"data: {"choices":[{"delta":{"content":"hello"}}]}"#;
4669 let result = parse_sse_line(line).unwrap().unwrap();
4670 assert_eq!(result.delta, "hello");
4671 assert!(result.reasoning.is_none());
4672 }
4673
4674 #[test]
4675 fn parse_sse_line_with_reasoning_content() {
4676 let line = r#"data: {"choices":[{"delta":{"reasoning_content":"thinking..."}}]}"#;
4677 let result = parse_sse_line(line).unwrap().unwrap();
4678 assert!(result.delta.is_empty());
4679 assert_eq!(result.reasoning.as_deref(), Some("thinking..."));
4680 }
4681
4682 #[test]
4683 fn parse_sse_line_with_both_prefers_content() {
4684 let line = r#"data: {"choices":[{"delta":{"content":"real answer","reasoning_content":"thinking..."}}]}"#;
4685 let result = parse_sse_line(line).unwrap().unwrap();
4686 assert_eq!(result.delta, "real answer");
4687 assert!(result.reasoning.is_none());
4688 }
4689
4690 #[test]
4691 fn parse_sse_line_with_empty_content_falls_back_to_reasoning() {
4692 let line =
4693 r#"data: {"choices":[{"delta":{"content":"","reasoning_content":"thinking..."}}]}"#;
4694 let result = parse_sse_line(line).unwrap().unwrap();
4695 assert!(result.delta.is_empty());
4696 assert_eq!(result.reasoning.as_deref(), Some("thinking..."));
4697 }
4698
4699 #[test]
4703 fn parse_sse_line_accepts_reasoning_alias() {
4704 let line = r#"data: {"choices":[{"delta":{"reasoning":"thinking via vllm..."}}]}"#;
4705 let result = parse_sse_line(line).unwrap().unwrap();
4706 assert!(result.delta.is_empty());
4707 assert_eq!(result.reasoning.as_deref(), Some("thinking via vllm..."));
4708 }
4709
4710 #[test]
4711 fn parse_sse_line_with_empty_content_and_reasoning_alias() {
4712 let line = r#"data: {"choices":[{"delta":{"content":"","reasoning":"vllm thought"}}]}"#;
4713 let result = parse_sse_line(line).unwrap().unwrap();
4714 assert!(result.delta.is_empty());
4715 assert_eq!(result.reasoning.as_deref(), Some("vllm thought"));
4716 }
4717
4718 #[test]
4719 fn response_message_accepts_reasoning_alias_on_non_stream_path() {
4720 let json = r#"{"content":null,"reasoning":"chain-of-thought via vllm","tool_calls":null}"#;
4722 let msg: ResponseMessage = serde_json::from_str(json).unwrap();
4723 assert!(msg.content.is_none());
4724 assert_eq!(
4725 msg.reasoning_content.as_deref(),
4726 Some("chain-of-thought via vllm"),
4727 "the `reasoning` alias must populate the canonical reasoning_content field",
4728 );
4729 assert_eq!(msg.effective_content(), "chain-of-thought via vllm");
4731 }
4732
4733 #[test]
4734 fn response_message_canonical_reasoning_content_still_works() {
4735 let json = r#"{"content":null,"reasoning_content":"canonical thought","tool_calls":null}"#;
4737 let msg: ResponseMessage = serde_json::from_str(json).unwrap();
4738 assert_eq!(msg.reasoning_content.as_deref(), Some("canonical thought"));
4739 }
4740
4741 #[test]
4748 fn response_message_with_both_keys_prefers_canonical_reasoning_content() {
4749 let json = r#"{"content":null,"reasoning_content":"canonical","reasoning":"alias","tool_calls":null}"#;
4750 let msg: ResponseMessage = serde_json::from_str(json)
4751 .expect("payload with both reasoning_content and reasoning must deserialize");
4752 assert_eq!(
4753 msg.reasoning_content.as_deref(),
4754 Some("canonical"),
4755 "canonical reasoning_content must win when both fields are present",
4756 );
4757 }
4758
4759 #[test]
4760 fn response_message_with_only_alias_populates_canonical_field() {
4761 let json = r#"{"content":null,"reasoning":"alias only","tool_calls":null}"#;
4764 let msg: ResponseMessage = serde_json::from_str(json).unwrap();
4765 assert_eq!(msg.reasoning_content.as_deref(), Some("alias only"));
4766 }
4767
4768 #[test]
4769 fn stream_delta_with_both_keys_prefers_canonical_reasoning_content() {
4770 let chunk = r#"data: {"choices":[{"delta":{"reasoning_content":"canonical","reasoning":"alias"}}]}"#;
4773 let result = parse_sse_line(chunk)
4774 .expect("parse must succeed")
4775 .expect("non-empty chunk");
4776 assert_eq!(result.reasoning.as_deref(), Some("canonical"));
4777 }
4778
4779 #[test]
4782 fn round_trip_reasoning_extraction_accepts_alias() {
4783 fn extract_reasoning(value: &serde_json::Value) -> Option<String> {
4784 value
4785 .get("reasoning_content")
4786 .or_else(|| value.get("reasoning"))
4787 .and_then(serde_json::Value::as_str)
4788 .map(ToString::to_string)
4789 }
4790 let canonical: serde_json::Value =
4791 serde_json::from_str(r#"{"reasoning_content":"canonical","tool_calls":[]}"#).unwrap();
4792 let alias: serde_json::Value =
4793 serde_json::from_str(r#"{"reasoning":"vllm","tool_calls":[]}"#).unwrap();
4794 let neither: serde_json::Value = serde_json::from_str(r#"{"tool_calls":[]}"#).unwrap();
4795 let both: serde_json::Value = serde_json::from_str(
4796 r#"{"reasoning_content":"canonical","reasoning":"alias","tool_calls":[]}"#,
4797 )
4798 .unwrap();
4799 assert_eq!(extract_reasoning(&canonical).as_deref(), Some("canonical"));
4800 assert_eq!(extract_reasoning(&alias).as_deref(), Some("vllm"));
4801 assert_eq!(extract_reasoning(&neither), None);
4802 assert_eq!(extract_reasoning(&both).as_deref(), Some("canonical"));
4806 }
4807
4808 #[test]
4809 fn parse_sse_line_done_sentinel() {
4810 let line = "data: [DONE]";
4811 let result = parse_sse_line(line).unwrap();
4812 assert!(result.is_none());
4813 }
4814
4815 #[test]
4816 fn parse_sse_chunk_with_tool_call_delta() {
4817 let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"name":"shell","arguments":"{\"command\":\"date\"}"}}]}}]}"#;
4818 let chunk = parse_sse_chunk(line)
4819 .unwrap()
4820 .expect("chunk should be parsed");
4821 let choice = chunk.choices.first().expect("choice should exist");
4822 let tool_calls = choice
4823 .delta
4824 .tool_calls
4825 .as_ref()
4826 .expect("tool call deltas should exist");
4827 assert_eq!(tool_calls.len(), 1);
4828 assert_eq!(tool_calls[0].index, Some(0));
4829 assert_eq!(tool_calls[0].id.as_deref(), Some("call_1"));
4830 assert_eq!(
4831 tool_calls[0]
4832 .function
4833 .as_ref()
4834 .and_then(|function| function.name.as_deref()),
4835 Some("shell")
4836 );
4837 }
4838
4839 #[test]
4840 fn stream_tool_call_accumulator_combines_deltas() {
4841 let mut acc = StreamToolCallAccumulator::default();
4842 acc.apply_delta(&StreamToolCallDelta {
4843 index: Some(0),
4844 id: Some("call_1".to_string()),
4845 function: Some(StreamFunctionDelta {
4846 name: Some("shell".to_string()),
4847 arguments: Some("{\"command\":\"".to_string()),
4848 }),
4849 name: None,
4850 arguments: None,
4851 extra_content: None,
4852 });
4853 acc.apply_delta(&StreamToolCallDelta {
4854 index: Some(0),
4855 id: None,
4856 function: Some(StreamFunctionDelta {
4857 name: None,
4858 arguments: Some("date\"}".to_string()),
4859 }),
4860 name: None,
4861 arguments: None,
4862 extra_content: None,
4863 });
4864
4865 let mut used_tool_call_ids = std::collections::HashSet::new();
4866 let tool_call = acc
4867 .into_provider_tool_call(false, &mut used_tool_call_ids)
4868 .expect("accumulator should emit tool call");
4869 assert_eq!(tool_call.id, "call_1");
4870 assert_eq!(tool_call.name, "shell");
4871 assert_eq!(tool_call.arguments, r#"{"command":"date"}"#);
4872 }
4873
4874 #[test]
4875 fn stream_tool_call_accumulator_mistral_normalizes_invalid_id() {
4876 let mut acc = StreamToolCallAccumulator::default();
4877 acc.apply_delta(&StreamToolCallDelta {
4878 index: Some(0),
4879 id: Some("chatcmpl-tool-abc".to_string()),
4880 function: Some(StreamFunctionDelta {
4881 name: Some("shell".to_string()),
4882 arguments: Some(r#"{"command":"date"}"#.to_string()),
4883 }),
4884 name: None,
4885 arguments: None,
4886 extra_content: None,
4887 });
4888
4889 let mut used_tool_call_ids = std::collections::HashSet::new();
4890 let tool_call = acc
4891 .into_provider_tool_call(true, &mut used_tool_call_ids)
4892 .expect("accumulator should emit tool call");
4893
4894 assert_eq!(tool_call.id.len(), 9);
4895 assert!(tool_call.id.chars().all(|c| c.is_ascii_alphanumeric()));
4896 assert_ne!(tool_call.id, "chatcmpl-tool-abc");
4897 }
4898
4899 #[test]
4900 fn api_response_parses_usage() {
4901 let json = r#"{
4902 "choices": [{"message": {"content": "Hello"}}],
4903 "usage": {"prompt_tokens": 150, "completion_tokens": 60}
4904 }"#;
4905 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
4906 let usage = resp.usage.unwrap();
4907 assert_eq!(usage.prompt_tokens, Some(150));
4908 assert_eq!(usage.completion_tokens, Some(60));
4909 }
4910
4911 #[test]
4912 fn api_response_parses_without_usage() {
4913 let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#;
4914 let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
4915 assert!(resp.usage.is_none());
4916 }
4917
4918 #[test]
4923 fn parse_native_response_captures_reasoning_content() {
4924 let provider = make_model_provider("test", "https://example.com", None);
4925 let message = ResponseMessage {
4926 content: Some("answer".to_string()),
4927 reasoning_content: Some("thinking step".to_string()),
4928 tool_calls: Some(vec![ToolCall {
4929 id: Some("call_1".to_string()),
4930 kind: Some("function".to_string()),
4931 function: Some(Function {
4932 name: Some("shell".to_string()),
4933 arguments: Some(r#"{"cmd":"ls"}"#.to_string()),
4934 }),
4935 name: None,
4936 arguments: None,
4937 parameters: None,
4938 extra_content: None,
4939 }]),
4940 };
4941
4942 let parsed = provider.parse_native_response(message);
4943 assert_eq!(parsed.reasoning_content.as_deref(), Some("thinking step"));
4944 assert_eq!(parsed.text.as_deref(), Some("answer"));
4945 assert_eq!(parsed.tool_calls.len(), 1);
4946 }
4947
4948 #[test]
4949 fn parse_native_response_none_reasoning_content_for_normal_model() {
4950 let provider = make_model_provider("test", "https://example.com", None);
4951 let message = ResponseMessage {
4952 content: Some("hello".to_string()),
4953 reasoning_content: None,
4954 tool_calls: None,
4955 };
4956
4957 let parsed = provider.parse_native_response(message);
4958 assert!(parsed.reasoning_content.is_none());
4959 assert_eq!(parsed.text.as_deref(), Some("hello"));
4960 }
4961
4962 #[test]
4963 fn convert_messages_for_native_round_trips_reasoning_content() {
4964 let history_json = serde_json::json!({
4966 "content": "I will check",
4967 "tool_calls": [{
4968 "id": "tc_1",
4969 "name": "shell",
4970 "arguments": "{\"cmd\":\"ls\"}"
4971 }],
4972 "reasoning_content": "Let me think about this..."
4973 });
4974
4975 let messages = vec![ChatMessage::assistant(history_json.to_string())];
4976 let provider = make_model_provider("test", "https://example.com", None);
4977 let native = provider.convert_messages_for_native(&messages, true);
4978 assert_eq!(native.len(), 1);
4979 assert_eq!(native[0].role, "assistant");
4980 assert_eq!(
4981 native[0].reasoning_content.as_deref(),
4982 Some("Let me think about this...")
4983 );
4984 assert!(native[0].tool_calls.is_some());
4985 }
4986
4987 #[test]
4988 fn convert_messages_for_native_no_reasoning_content_when_absent() {
4989 let history_json = serde_json::json!({
4991 "content": "I will check",
4992 "tool_calls": [{
4993 "id": "tc_1",
4994 "name": "shell",
4995 "arguments": "{\"cmd\":\"ls\"}"
4996 }]
4997 });
4998
4999 let messages = vec![ChatMessage::assistant(history_json.to_string())];
5000 let provider = make_model_provider("test", "https://example.com", None);
5001 let native = provider.convert_messages_for_native(&messages, true);
5002 assert_eq!(native.len(), 1);
5003 assert!(native[0].reasoning_content.is_none());
5004 }
5005
5006 #[test]
5013 fn convert_messages_for_native_round_trips_reasoning_content_without_tool_calls() {
5014 let history_json = serde_json::json!({
5015 "content": "Direct answer.",
5016 "reasoning_content": "Let me think step by step..."
5017 });
5018
5019 let messages = vec![ChatMessage::assistant(history_json.to_string())];
5020 let provider = make_model_provider("test", "https://example.com", None);
5021 let native = provider.convert_messages_for_native(&messages, true);
5022 assert_eq!(native.len(), 1);
5023 assert_eq!(native[0].role, "assistant");
5024 assert!(
5025 native[0].tool_calls.is_none(),
5026 "no tool_calls on a plain-text turn"
5027 );
5028 assert_eq!(
5029 native[0].reasoning_content.as_deref(),
5030 Some("Let me think step by step...")
5031 );
5032 match &native[0].content {
5033 Some(MessageContent::Text(t)) => assert_eq!(t, "Direct answer."),
5034 other => panic!("expected text content, got {other:?}"),
5035 }
5036 }
5037
5038 #[test]
5041 fn convert_messages_for_native_content_only_json_falls_through() {
5042 let structured_answer = serde_json::json!({"content": "raw"});
5043 let raw_json = structured_answer.to_string();
5044 let messages = vec![ChatMessage::assistant(raw_json.clone())];
5045 let provider = make_model_provider("test", "https://example.com", None);
5046 let native = provider.convert_messages_for_native(&messages, true);
5047 assert_eq!(native.len(), 1);
5048 assert!(native[0].reasoning_content.is_none());
5049 assert!(native[0].tool_calls.is_none());
5050 match &native[0].content {
5051 Some(MessageContent::Text(t)) => assert_eq!(t.as_str(), raw_json.as_str()),
5052 other => panic!("expected text content from fallback, got {other:?}"),
5053 }
5054 }
5055
5056 #[test]
5059 fn convert_messages_for_native_non_string_reasoning_content_falls_through() {
5060 let structured_answer = serde_json::json!({
5061 "content": "raw",
5062 "reasoning_content": null
5063 });
5064 let raw_json = structured_answer.to_string();
5065 let messages = vec![ChatMessage::assistant(raw_json.clone())];
5066 let provider = make_model_provider("test", "https://example.com", None);
5067 let native = provider.convert_messages_for_native(&messages, true);
5068 assert_eq!(native.len(), 1);
5069 assert!(native[0].reasoning_content.is_none());
5070 assert!(native[0].tool_calls.is_none());
5071 match &native[0].content {
5072 Some(MessageContent::Text(t)) => assert_eq!(t.as_str(), raw_json.as_str()),
5073 other => panic!("expected text content from fallback, got {other:?}"),
5074 }
5075 }
5076
5077 #[test]
5082 fn convert_messages_for_native_unrelated_json_falls_through() {
5083 let unrelated = serde_json::json!({"foo": "bar"});
5084 let messages = vec![ChatMessage::assistant(unrelated.to_string())];
5085 let provider = make_model_provider("test", "https://example.com", None);
5086 let native = provider.convert_messages_for_native(&messages, true);
5087 assert_eq!(native.len(), 1);
5088 assert!(native[0].reasoning_content.is_none());
5089 assert!(native[0].tool_calls.is_none());
5090 match &native[0].content {
5091 Some(MessageContent::Text(t)) => {
5092 assert!(
5093 t.contains("\"foo\""),
5094 "expected raw JSON in fallback content, got {t:?}"
5095 );
5096 }
5097 other => panic!("expected text content from fallback, got {other:?}"),
5098 }
5099 }
5100
5101 #[test]
5102 fn convert_messages_for_native_reasoning_content_serialized_only_when_present() {
5103 let msg_without = NativeMessage {
5105 role: "assistant".to_string(),
5106 content: Some(MessageContent::Text("hi".to_string())),
5107 tool_call_id: None,
5108 tool_calls: None,
5109 reasoning_content: None,
5110 };
5111 let json = serde_json::to_string(&msg_without).unwrap();
5112 assert!(
5113 !json.contains("reasoning_content"),
5114 "reasoning_content should be omitted when None"
5115 );
5116
5117 let msg_with = NativeMessage {
5118 role: "assistant".to_string(),
5119 content: Some(MessageContent::Text("hi".to_string())),
5120 tool_call_id: None,
5121 tool_calls: None,
5122 reasoning_content: Some("thinking...".to_string()),
5123 };
5124 let json = serde_json::to_string(&msg_with).unwrap();
5125 assert!(
5126 json.contains("reasoning_content"),
5127 "reasoning_content should be present when Some"
5128 );
5129 assert!(json.contains("thinking..."));
5130 }
5131
5132 #[test]
5133 fn default_timeout_is_120s() {
5134 let p = make_model_provider("test", "https://example.com", None);
5135 assert_eq!(p.timeout_secs, 120);
5136 }
5137
5138 #[test]
5139 fn with_timeout_secs_overrides_default() {
5140 let p = make_model_provider("test", "https://example.com", None).with_timeout_secs(300);
5141 assert_eq!(p.timeout_secs, 300);
5142 }
5143
5144 #[test]
5145 fn extra_headers_default_empty() {
5146 let p = make_model_provider("test", "https://example.com", None);
5147 assert!(p.extra_headers.is_empty());
5148 }
5149
5150 #[test]
5151 fn with_extra_headers_sets_headers() {
5152 let mut headers = std::collections::HashMap::new();
5153 headers.insert("X-Title".to_string(), "zeroclaw".to_string());
5154 headers.insert(
5155 "HTTP-Referer".to_string(),
5156 "https://example.com".to_string(),
5157 );
5158 let p =
5159 make_model_provider("test", "https://example.com", None).with_extra_headers(headers);
5160 assert_eq!(p.extra_headers.len(), 2);
5161 assert_eq!(p.extra_headers.get("X-Title").unwrap(), "zeroclaw");
5162 assert_eq!(
5163 p.extra_headers.get("HTTP-Referer").unwrap(),
5164 "https://example.com"
5165 );
5166 }
5167
5168 #[test]
5169 fn http_client_with_extra_headers_builds_successfully() {
5170 let mut headers = std::collections::HashMap::new();
5171 headers.insert("X-Title".to_string(), "zeroclaw".to_string());
5172 headers.insert("User-Agent".to_string(), "TestAgent/1.0".to_string());
5173 let p =
5174 make_model_provider("test", "https://example.com", None).with_extra_headers(headers);
5175 let _client = p.http_client();
5177 }
5178
5179 #[test]
5180 fn http_client_without_extra_headers_or_user_agent() {
5181 let p = make_model_provider("test", "https://example.com", None);
5182 let _client = p.http_client();
5184 }
5185
5186 #[test]
5187 fn extra_headers_combined_with_user_agent() {
5188 let mut headers = std::collections::HashMap::new();
5189 headers.insert("X-Title".to_string(), "zeroclaw".to_string());
5190 let p = OpenAiCompatibleModelProvider::new_with_user_agent(
5191 "test",
5192 "test",
5193 "https://example.com",
5194 None,
5195 AuthStyle::Bearer,
5196 "CustomAgent/1.0",
5197 )
5198 .with_extra_headers(headers);
5199 assert_eq!(p.user_agent.as_deref(), Some("CustomAgent/1.0"));
5200 assert_eq!(p.extra_headers.len(), 1);
5201 let _client = p.http_client();
5203 }
5204
5205 #[test]
5206 fn tool_call_none_fields_omitted_from_json() {
5207 let tc = ToolCall {
5210 id: Some("call_1".to_string()),
5211 kind: Some("function".to_string()),
5212 function: Some(Function {
5213 name: Some("shell".to_string()),
5214 arguments: Some("{\"command\":\"ls\"}".to_string()),
5215 }),
5216 name: None,
5217 arguments: None,
5218 parameters: None,
5219 extra_content: None,
5220 };
5221 let json = serde_json::to_value(&tc).unwrap();
5222 assert!(!json.as_object().unwrap().contains_key("name"));
5223 assert!(!json.as_object().unwrap().contains_key("arguments"));
5224 assert!(!json.as_object().unwrap().contains_key("parameters"));
5225 assert!(json.as_object().unwrap().contains_key("id"));
5227 assert!(json.as_object().unwrap().contains_key("type"));
5228 assert!(json.as_object().unwrap().contains_key("function"));
5229 }
5230
5231 #[test]
5232 fn tool_call_with_compat_fields_serializes_them() {
5233 let tc = ToolCall {
5235 id: None,
5236 kind: None,
5237 function: None,
5238 name: Some("shell".to_string()),
5239 arguments: Some("{\"command\":\"ls\"}".to_string()),
5240 parameters: None,
5241 extra_content: None,
5242 };
5243 let json = serde_json::to_value(&tc).unwrap();
5244 assert_eq!(json["name"], "shell");
5245 assert_eq!(json["arguments"], "{\"command\":\"ls\"}");
5246 assert!(!json.as_object().unwrap().contains_key("id"));
5248 assert!(!json.as_object().unwrap().contains_key("type"));
5249 assert!(!json.as_object().unwrap().contains_key("function"));
5250 assert!(!json.as_object().unwrap().contains_key("parameters"));
5251 }
5252
5253 #[test]
5256 fn proxy_tool_start_valid() {
5257 let line = r#"data: {"x_tool_start":{"name":"bash","arguments":"{\"cmd\":\"ls\"}"}}"#;
5258 let event = parse_proxy_tool_event(line);
5259 assert!(matches!(
5260 event,
5261 Some(StreamEvent::PreExecutedToolCall { ref name, ref args })
5262 if name == "bash" && args == r#"{"cmd":"ls"}"#
5263 ));
5264 }
5265
5266 #[test]
5267 fn proxy_tool_start_missing_name_returns_none() {
5268 let line = r#"data: {"x_tool_start":{"arguments":"{}"}}"#;
5269 assert!(parse_proxy_tool_event(line).is_none());
5270 }
5271
5272 #[test]
5273 fn proxy_tool_start_missing_arguments_defaults() {
5274 let line = r#"data: {"x_tool_start":{"name":"read"}}"#;
5275 let event = parse_proxy_tool_event(line);
5276 assert!(matches!(
5277 event,
5278 Some(StreamEvent::PreExecutedToolCall { ref name, ref args })
5279 if name == "read" && args == "{}"
5280 ));
5281 }
5282
5283 #[test]
5284 fn proxy_tool_result_valid() {
5285 let line = r#"data: {"x_tool_result":{"name":"bash","output":"hello world"}}"#;
5286 let event = parse_proxy_tool_event(line);
5287 assert!(matches!(
5288 event,
5289 Some(StreamEvent::PreExecutedToolResult { ref name, ref output })
5290 if name == "bash" && output == "hello world"
5291 ));
5292 }
5293
5294 #[test]
5295 fn proxy_tool_result_missing_fields_uses_defaults() {
5296 let line = r#"data: {"x_tool_result":{}}"#;
5297 let event = parse_proxy_tool_event(line);
5298 assert!(matches!(
5299 event,
5300 Some(StreamEvent::PreExecutedToolResult { ref name, ref output })
5301 if name == "unknown" && output.is_empty()
5302 ));
5303 }
5304
5305 #[test]
5306 fn proxy_tool_event_non_json_returns_none() {
5307 assert!(parse_proxy_tool_event("data: not json").is_none());
5308 }
5309
5310 #[test]
5311 fn proxy_tool_event_no_data_prefix_returns_none() {
5312 let line = r#"{"x_tool_start":{"name":"bash"}}"#;
5313 assert!(parse_proxy_tool_event(line).is_none());
5314 }
5315
5316 #[test]
5317 fn proxy_tool_event_standard_openai_chunk_returns_none() {
5318 let line = r#"data: {"id":"chatcmpl-1","choices":[{"delta":{"content":"hi"}}]}"#;
5319 assert!(parse_proxy_tool_event(line).is_none());
5320 }
5321
5322 #[test]
5323 fn proxy_tool_event_done_sentinel_returns_none() {
5324 assert!(parse_proxy_tool_event("data: [DONE]").is_none());
5325 }
5326
5327 #[test]
5336 fn strip_native_tool_messages_coalesces_adjacent_assistants() {
5337 let messages = vec![
5338 ChatMessage::user("search for cats"),
5339 ChatMessage::assistant(
5340 r#"{"content":"I'll search","tool_calls":[{"id":"t1","name":"web_search","arguments":"{}"}]}"#,
5341 ),
5342 ChatMessage::tool(r#"{"tool_call_id":"t1","content":"Found 10 results"}"#),
5343 ChatMessage::assistant("Here are the results about cats"),
5344 ];
5345 let p = OpenAiCompatibleModelProvider::new_merge_system_into_user(
5346 "test",
5347 "MiniMax",
5348 "https://api.minimax.chat/v1",
5349 Some("k"),
5350 AuthStyle::Bearer,
5351 );
5352 let stripped = p.strip_native_tool_messages(&messages);
5353 let roles: Vec<&str> = stripped.iter().map(|m| m.role.as_str()).collect();
5354 assert!(
5355 !roles.windows(2).any(|w| w[0] == w[1]),
5356 "no two consecutive messages should share a role; got {roles:?}"
5357 );
5358 assert_eq!(roles, vec!["user", "assistant"]);
5360 assert_eq!(stripped[0].content, "search for cats");
5361 assert!(
5362 stripped[1].content.contains("I'll search")
5363 && stripped[1]
5364 .content
5365 .contains("Here are the results about cats"),
5366 "merged assistant should preserve both the pre-tool narration and the final reply; \
5367 got {:?}",
5368 stripped[1].content
5369 );
5370 }
5371
5372 #[test]
5377 fn strip_native_tool_messages_drops_empty_narration_cleanly() {
5378 let messages = vec![
5379 ChatMessage::user("search for cats"),
5380 ChatMessage::assistant(
5381 r#"{"content":"","tool_calls":[{"id":"t1","name":"web_search","arguments":"{}"}]}"#,
5382 ),
5383 ChatMessage::tool(r#"{"tool_call_id":"t1","content":"Found"}"#),
5384 ChatMessage::assistant("Here are the results"),
5385 ];
5386 let p = OpenAiCompatibleModelProvider::new_merge_system_into_user(
5387 "test",
5388 "MiniMax",
5389 "https://api.minimax.chat/v1",
5390 Some("k"),
5391 AuthStyle::Bearer,
5392 );
5393 let stripped = p.strip_native_tool_messages(&messages);
5394 assert_eq!(
5395 stripped.iter().map(|m| m.role.as_str()).collect::<Vec<_>>(),
5396 vec!["user", "assistant"]
5397 );
5398 assert_eq!(stripped[1].content, "Here are the results");
5399 }
5400}