1pub use zeroclaw_api::model_provider::*;
2
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolSpec;
    use async_trait::async_trait;
    use futures_util::StreamExt;
    use futures_util::stream::{self, BoxStream};

    /// Temperature passed to `chat` calls in tests where the value is irrelevant
    /// to the mock under test.
    const TEST_DEFAULT_TEMPERATURE: f64 = 0.7;

    /// Zero temperature, passed by the streaming test.
    const TEST_GREEDY_TEMPERATURE: f64 = 0.0;

    // Implements `Attributable` for a test mock: every mock reports itself as a
    // custom model provider, and its alias is exactly its type name. This replaces
    // six copies of the same hand-written impl.
    macro_rules! impl_custom_model_attribution {
        ($provider:ident) => {
            impl ::zeroclaw_api::attribution::Attributable for $provider {
                fn role(&self) -> ::zeroclaw_api::attribution::Role {
                    ::zeroclaw_api::attribution::Role::Provider(
                        ::zeroclaw_api::attribution::ProviderKind::Model(
                            ::zeroclaw_api::attribution::ModelProviderKind::Custom,
                        ),
                    )
                }

                fn alias(&self) -> &str {
                    stringify!($provider)
                }
            }
        };
    }

    /// Single-entry tool list (`shell` / "Run commands") shared by the
    /// prompt-guided chat tests.
    fn shell_tool_specs() -> Vec<ToolSpec> {
        vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }]
    }

    /// Mock that only overrides `capabilities`, so the default
    /// `supports_native_tools` / `supports_vision` mappings can be checked.
    struct CapabilityMockModelProvider;

    #[async_trait]
    impl ModelProvider for CapabilityMockModelProvider {
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                native_tool_calling: true,
                vision: true,
                prompt_caching: false,
                extended_thinking: false,
            }
        }

        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: Option<f64>,
        ) -> anyhow::Result<String> {
            Ok("ok".into())
        }
    }
    impl_custom_model_attribution!(CapabilityMockModelProvider);

    #[test]
    fn chat_message_constructors() {
        let sys = ChatMessage::system("Be helpful");
        assert_eq!(sys.role, "system");
        assert_eq!(sys.content, "Be helpful");

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, "user");

        let asst = ChatMessage::assistant("Hi there");
        assert_eq!(asst.role, "assistant");

        let tool = ChatMessage::tool("{}");
        assert_eq!(tool.role, "tool");
    }

    #[test]
    fn chat_response_helpers() {
        let empty = ChatResponse {
            text: None,
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        };
        assert!(!empty.has_tool_calls());
        assert_eq!(empty.text_or_empty(), "");

        let with_tools = ChatResponse {
            text: Some("Let me check".into()),
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
                extra_content: None,
            }],
            usage: None,
            reasoning_content: None,
        };
        assert!(with_tools.has_tool_calls());
        assert_eq!(with_tools.text_or_empty(), "Let me check");
    }

    #[test]
    fn token_usage_default_is_none() {
        let usage = TokenUsage::default();
        assert!(usage.input_tokens.is_none());
        assert!(usage.output_tokens.is_none());
    }

    #[test]
    fn chat_response_with_usage() {
        let resp = ChatResponse {
            text: Some("Hello".into()),
            tool_calls: vec![],
            usage: Some(TokenUsage {
                input_tokens: Some(100),
                output_tokens: Some(50),
                cached_input_tokens: None,
            }),
            reasoning_content: None,
        };
        // Bind once instead of unwrapping the same Option per assertion.
        let usage = resp.usage.as_ref().expect("usage was set above");
        assert_eq!(usage.input_tokens, Some(100));
        assert_eq!(usage.output_tokens, Some(50));
    }

    #[test]
    fn tool_call_serialization() {
        let tc = ToolCall {
            id: "call_123".into(),
            name: "file_read".into(),
            arguments: r#"{"path":"test.txt"}"#.into(),
            extra_content: None,
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("call_123"));
        assert!(json.contains("file_read"));
    }

    #[test]
    fn conversation_message_variants() {
        let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
        let json = serde_json::to_string(&chat).unwrap();
        assert!(json.contains("\"type\":\"Chat\""));

        let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
            tool_call_id: "1".into(),
            content: "done".into(),
        }]);
        let json = serde_json::to_string(&tool_result).unwrap();
        assert!(json.contains("\"type\":\"ToolResults\""));
    }

    #[test]
    fn provider_capabilities_default() {
        let caps = ProviderCapabilities::default();
        assert!(!caps.native_tool_calling);
        assert!(!caps.vision);
    }

    #[test]
    fn provider_capabilities_equality() {
        let caps1 = ProviderCapabilities {
            native_tool_calling: true,
            vision: false,
            prompt_caching: false,
            extended_thinking: false,
        };
        let caps2 = ProviderCapabilities {
            native_tool_calling: true,
            vision: false,
            prompt_caching: false,
            extended_thinking: false,
        };
        let caps3 = ProviderCapabilities {
            native_tool_calling: false,
            vision: false,
            prompt_caching: false,
            extended_thinking: false,
        };

        assert_eq!(caps1, caps2);
        assert_ne!(caps1, caps3);
    }

    #[test]
    fn supports_native_tools_reflects_capabilities_default_mapping() {
        let model_provider = CapabilityMockModelProvider;
        assert!(model_provider.supports_native_tools());
    }

    #[test]
    fn supports_vision_reflects_capabilities_default_mapping() {
        let model_provider = CapabilityMockModelProvider;
        assert!(model_provider.supports_vision());
    }

    #[test]
    fn tools_payload_variants() {
        let gemini = ToolsPayload::Gemini {
            function_declarations: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(gemini, ToolsPayload::Gemini { .. }));

        let anthropic = ToolsPayload::Anthropic {
            tools: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));

        let openai = ToolsPayload::OpenAI {
            tools: vec![serde_json::json!({"type": "function"})],
        };
        assert!(matches!(openai, ToolsPayload::OpenAI { .. }));

        let prompt_guided = ToolsPayload::PromptGuided {
            instructions: "Use tools...".to_string(),
        };
        assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));
    }

    #[test]
    fn build_tool_instructions_text_format() {
        let tools = vec![
            ToolSpec {
                name: "shell".to_string(),
                description: "Execute commands".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "command": {"type": "string"}
                    }
                }),
            },
            ToolSpec {
                name: "file_read".to_string(),
                description: "Read files".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": {"type": "string"}
                    }
                }),
            },
        ];

        let instructions = build_tool_instructions_text(&tools);

        assert!(instructions.contains("Tool Use Protocol"));
        assert!(instructions.contains("<tool_call>"));
        assert!(instructions.contains("</tool_call>"));
        assert!(instructions.contains("**shell**"));
        assert!(instructions.contains("Execute commands"));
        assert!(instructions.contains("**file_read**"));
        assert!(instructions.contains("Read files"));
        assert!(instructions.contains("Parameters:"));
        assert!(instructions.contains(r#""type":"object""#));
    }

    #[test]
    fn build_tool_instructions_text_empty() {
        let instructions = build_tool_instructions_text(&[]);
        assert!(instructions.contains("Tool Use Protocol"));
        assert!(instructions.contains("Available Tools"));
    }

    /// Mock whose native-tool support is configurable per test; always
    /// answers `"response"`.
    struct MockModelProvider {
        supports_native: bool,
    }

    #[async_trait]
    impl ModelProvider for MockModelProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }

        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: Option<f64>,
        ) -> anyhow::Result<String> {
            Ok("response".to_string())
        }
    }
    impl_custom_model_attribution!(MockModelProvider);

    #[test]
    fn provider_convert_tools_default() {
        let model_provider = MockModelProvider {
            supports_native: false,
        };

        let tools = vec![ToolSpec {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        // A single match both checks the variant and inspects its payload, so the
        // content assertions can never be silently skipped.
        match model_provider.convert_tools(&tools) {
            ToolsPayload::PromptGuided { instructions } => {
                assert!(instructions.contains("test_tool"));
                assert!(instructions.contains("A test tool"));
            }
            _ => panic!("expected ToolsPayload::PromptGuided for a non-native provider"),
        }
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_fallback() {
        let model_provider = MockModelProvider {
            supports_native: false,
        };

        let tools = shell_tool_specs();

        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
            thinking: None,
        };

        let response = model_provider
            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
            .await
            .unwrap();
        assert!(response.text.is_some());
    }

    #[tokio::test]
    async fn provider_chat_without_tools() {
        let model_provider = MockModelProvider {
            supports_native: true,
        };

        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: None,
            thinking: None,
        };

        let response = model_provider
            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
            .await
            .unwrap();
        assert!(response.text.is_some());
    }

    /// Mock that echoes back the system prompt it receives (or `""` when there
    /// is none), letting tests inspect what `chat` put into it.
    struct EchoSystemModelProvider {
        supports_native: bool,
    }

    #[async_trait]
    impl ModelProvider for EchoSystemModelProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }

        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: Option<f64>,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }
    impl_custom_model_attribution!(EchoSystemModelProvider);

    /// Non-native mock that overrides `convert_tools` with a fixed prompt-guided
    /// payload and echoes the system prompt it receives.
    struct CustomConvertModelProvider;

    #[async_trait]
    impl ModelProvider for CustomConvertModelProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }

        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::PromptGuided {
                instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(),
            }
        }

        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: Option<f64>,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }
    impl_custom_model_attribution!(CustomConvertModelProvider);

    /// Non-native mock whose `convert_tools` wrongly returns a native (OpenAI)
    /// payload, used to check that `chat` rejects it.
    struct InvalidConvertModelProvider;

    #[async_trait]
    impl ModelProvider for InvalidConvertModelProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }

        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::OpenAI {
                tools: vec![serde_json::json!({"type": "function"})],
            }
        }

        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: Option<f64>,
        ) -> anyhow::Result<String> {
            Ok("should_not_reach".to_string())
        }
    }
    impl_custom_model_attribution!(InvalidConvertModelProvider);

    #[tokio::test]
    async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
        let model_provider = EchoSystemModelProvider {
            supports_native: false,
        };

        let tools = shell_tool_specs();

        let request = ChatRequest {
            messages: &[
                ChatMessage::user("Hello"),
                ChatMessage::system("BASE_SYSTEM_PROMPT"),
            ],
            tools: Some(&tools),
            thinking: None,
        };

        let response = model_provider
            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
            .await
            .unwrap();
        let text = response.text.unwrap_or_default();

        assert!(text.contains("BASE_SYSTEM_PROMPT"));
        assert!(text.contains("Tool Use Protocol"));
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_uses_convert_tools_override() {
        let model_provider = CustomConvertModelProvider;

        let tools = shell_tool_specs();

        let request = ChatRequest {
            messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
            tools: Some(&tools),
            thinking: None,
        };

        let response = model_provider
            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
            .await
            .unwrap();
        let text = response.text.unwrap_or_default();

        assert!(text.contains("BASE"));
        assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS"));
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
        let model_provider = InvalidConvertModelProvider;

        let tools = shell_tool_specs();

        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
            thinking: None,
        };

        let err = model_provider
            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
            .await
            .unwrap_err();
        let message = err.to_string();

        assert!(message.contains("non-prompt-guided"));
    }

    /// Streaming mock that only implements the legacy chunk-based
    /// `stream_chat_with_history`, yielding one text delta and a final chunk.
    struct StreamingChunkOnlyModelProvider;

    #[async_trait]
    impl ModelProvider for StreamingChunkOnlyModelProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: Option<f64>,
        ) -> anyhow::Result<String> {
            Ok("ok".to_string())
        }

        fn supports_streaming(&self) -> bool {
            true
        }

        fn stream_chat_with_history(
            &self,
            _messages: &[ChatMessage],
            _model: &str,
            _temperature: Option<f64>,
            _options: StreamOptions,
        ) -> BoxStream<'static, StreamResult<StreamChunk>> {
            stream::iter(vec![
                Ok(StreamChunk::delta("hello")),
                Ok(StreamChunk::final_chunk()),
            ])
            .boxed()
        }
    }
    impl_custom_model_attribution!(StreamingChunkOnlyModelProvider);

    #[tokio::test]
    async fn provider_stream_chat_default_maps_legacy_chunks_to_events() {
        let model_provider = StreamingChunkOnlyModelProvider;
        let mut stream = model_provider.stream_chat(
            ChatRequest {
                messages: &[ChatMessage::user("hi")],
                tools: None,
                thinking: None,
            },
            "model",
            Some(TEST_GREEDY_TEMPERATURE),
            StreamOptions::new(true),
        );

        let first = stream.next().await.unwrap().unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert!(stream.next().await.is_none());

        match first {
            StreamEvent::TextDelta(chunk) => assert_eq!(chunk.delta, "hello"),
            other => panic!("expected text delta event, got {other:?}"),
        }
        assert!(matches!(second, StreamEvent::Final));
    }
}