Skip to main content

zeroclaw/providers/
traits.rs

1pub use zeroclaw_api::model_provider::*;
2
3#[cfg(test)]
4mod tests {
5    use super::*;
6    use crate::tools::ToolSpec;
7    use async_trait::async_trait;
8    use futures_util::StreamExt;
9    use futures_util::stream::{self, BoxStream};
10
11    /// Representative non-zero temperature for default-path chat tests;
12    /// mocks ignore it, so any plausible in-range value is fine — this
13    /// matches the historical default used across the codebase.
14    const TEST_DEFAULT_TEMPERATURE: f64 = 0.7;
15
16    /// Zero = greedy sampling; used by streaming tests where we want
17    /// deterministic replays from the mock stream.
18    const TEST_GREEDY_TEMPERATURE: f64 = 0.0;
19
20    struct CapabilityMockModelProvider;
21
22    #[async_trait]
23    impl ModelProvider for CapabilityMockModelProvider {
24        fn capabilities(&self) -> ProviderCapabilities {
25            ProviderCapabilities {
26                native_tool_calling: true,
27                vision: true,
28                prompt_caching: false,
29                extended_thinking: false,
30            }
31        }
32
33        async fn chat_with_system(
34            &self,
35            _system_prompt: Option<&str>,
36            _message: &str,
37            _model: &str,
38            _temperature: Option<f64>,
39        ) -> anyhow::Result<String> {
40            Ok("ok".into())
41        }
42    }
43    impl ::zeroclaw_api::attribution::Attributable for CapabilityMockModelProvider {
44        fn role(&self) -> ::zeroclaw_api::attribution::Role {
45            ::zeroclaw_api::attribution::Role::Provider(
46                ::zeroclaw_api::attribution::ProviderKind::Model(
47                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
48                ),
49            )
50        }
51        fn alias(&self) -> &str {
52            "CapabilityMockModelProvider"
53        }
54    }
55
56    #[test]
57    fn chat_message_constructors() {
58        let sys = ChatMessage::system("Be helpful");
59        assert_eq!(sys.role, "system");
60        assert_eq!(sys.content, "Be helpful");
61
62        let user = ChatMessage::user("Hello");
63        assert_eq!(user.role, "user");
64
65        let asst = ChatMessage::assistant("Hi there");
66        assert_eq!(asst.role, "assistant");
67
68        let tool = ChatMessage::tool("{}");
69        assert_eq!(tool.role, "tool");
70    }
71
72    #[test]
73    fn chat_response_helpers() {
74        let empty = ChatResponse {
75            text: None,
76            tool_calls: vec![],
77            usage: None,
78            reasoning_content: None,
79        };
80        assert!(!empty.has_tool_calls());
81        assert_eq!(empty.text_or_empty(), "");
82
83        let with_tools = ChatResponse {
84            text: Some("Let me check".into()),
85            tool_calls: vec![ToolCall {
86                id: "1".into(),
87                name: "shell".into(),
88                arguments: "{}".into(),
89                extra_content: None,
90            }],
91            usage: None,
92            reasoning_content: None,
93        };
94        assert!(with_tools.has_tool_calls());
95        assert_eq!(with_tools.text_or_empty(), "Let me check");
96    }
97
98    #[test]
99    fn token_usage_default_is_none() {
100        let usage = TokenUsage::default();
101        assert!(usage.input_tokens.is_none());
102        assert!(usage.output_tokens.is_none());
103    }
104
105    #[test]
106    fn chat_response_with_usage() {
107        let resp = ChatResponse {
108            text: Some("Hello".into()),
109            tool_calls: vec![],
110            usage: Some(TokenUsage {
111                input_tokens: Some(100),
112                output_tokens: Some(50),
113                cached_input_tokens: None,
114            }),
115            reasoning_content: None,
116        };
117        assert_eq!(resp.usage.as_ref().unwrap().input_tokens, Some(100));
118        assert_eq!(resp.usage.as_ref().unwrap().output_tokens, Some(50));
119    }
120
121    #[test]
122    fn tool_call_serialization() {
123        let tc = ToolCall {
124            id: "call_123".into(),
125            name: "file_read".into(),
126            arguments: r#"{"path":"test.txt"}"#.into(),
127            extra_content: None,
128        };
129        let json = serde_json::to_string(&tc).unwrap();
130        assert!(json.contains("call_123"));
131        assert!(json.contains("file_read"));
132    }
133
134    #[test]
135    fn conversation_message_variants() {
136        let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
137        let json = serde_json::to_string(&chat).unwrap();
138        assert!(json.contains("\"type\":\"Chat\""));
139
140        let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
141            tool_call_id: "1".into(),
142            content: "done".into(),
143        }]);
144        let json = serde_json::to_string(&tool_result).unwrap();
145        assert!(json.contains("\"type\":\"ToolResults\""));
146    }
147
148    #[test]
149    fn provider_capabilities_default() {
150        let caps = ProviderCapabilities::default();
151        assert!(!caps.native_tool_calling);
152        assert!(!caps.vision);
153    }
154
155    #[test]
156    fn provider_capabilities_equality() {
157        let caps1 = ProviderCapabilities {
158            native_tool_calling: true,
159            vision: false,
160            prompt_caching: false,
161            extended_thinking: false,
162        };
163        let caps2 = ProviderCapabilities {
164            native_tool_calling: true,
165            vision: false,
166            prompt_caching: false,
167            extended_thinking: false,
168        };
169        let caps3 = ProviderCapabilities {
170            native_tool_calling: false,
171            vision: false,
172            prompt_caching: false,
173            extended_thinking: false,
174        };
175
176        assert_eq!(caps1, caps2);
177        assert_ne!(caps1, caps3);
178    }
179
180    #[test]
181    fn supports_native_tools_reflects_capabilities_default_mapping() {
182        let model_provider = CapabilityMockModelProvider;
183        assert!(model_provider.supports_native_tools());
184    }
185
186    #[test]
187    fn supports_vision_reflects_capabilities_default_mapping() {
188        let model_provider = CapabilityMockModelProvider;
189        assert!(model_provider.supports_vision());
190    }
191
192    #[test]
193    fn tools_payload_variants() {
194        let gemini = ToolsPayload::Gemini {
195            function_declarations: vec![serde_json::json!({"name": "test"})],
196        };
197        assert!(matches!(gemini, ToolsPayload::Gemini { .. }));
198
199        let anthropic = ToolsPayload::Anthropic {
200            tools: vec![serde_json::json!({"name": "test"})],
201        };
202        assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));
203
204        let openai = ToolsPayload::OpenAI {
205            tools: vec![serde_json::json!({"type": "function"})],
206        };
207        assert!(matches!(openai, ToolsPayload::OpenAI { .. }));
208
209        let prompt_guided = ToolsPayload::PromptGuided {
210            instructions: "Use tools...".to_string(),
211        };
212        assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));
213    }
214
215    #[test]
216    fn build_tool_instructions_text_format() {
217        let tools = vec![
218            ToolSpec {
219                name: "shell".to_string(),
220                description: "Execute commands".to_string(),
221                parameters: serde_json::json!({
222                    "type": "object",
223                    "properties": {
224                        "command": {"type": "string"}
225                    }
226                }),
227            },
228            ToolSpec {
229                name: "file_read".to_string(),
230                description: "Read files".to_string(),
231                parameters: serde_json::json!({
232                    "type": "object",
233                    "properties": {
234                        "path": {"type": "string"}
235                    }
236                }),
237            },
238        ];
239
240        let instructions = build_tool_instructions_text(&tools);
241
242        assert!(instructions.contains("Tool Use Protocol"));
243        assert!(instructions.contains("<tool_call>"));
244        assert!(instructions.contains("</tool_call>"));
245        assert!(instructions.contains("**shell**"));
246        assert!(instructions.contains("Execute commands"));
247        assert!(instructions.contains("**file_read**"));
248        assert!(instructions.contains("Read files"));
249        assert!(instructions.contains("Parameters:"));
250        assert!(instructions.contains(r#""type":"object""#));
251    }
252
253    #[test]
254    fn build_tool_instructions_text_empty() {
255        let instructions = build_tool_instructions_text(&[]);
256        assert!(instructions.contains("Tool Use Protocol"));
257        assert!(instructions.contains("Available Tools"));
258    }
259
260    struct MockModelProvider {
261        supports_native: bool,
262    }
263
264    #[async_trait]
265    impl ModelProvider for MockModelProvider {
266        fn supports_native_tools(&self) -> bool {
267            self.supports_native
268        }
269
270        async fn chat_with_system(
271            &self,
272            _system: Option<&str>,
273            _message: &str,
274            _model: &str,
275            _temperature: Option<f64>,
276        ) -> anyhow::Result<String> {
277            Ok("response".to_string())
278        }
279    }
280    impl ::zeroclaw_api::attribution::Attributable for MockModelProvider {
281        fn role(&self) -> ::zeroclaw_api::attribution::Role {
282            ::zeroclaw_api::attribution::Role::Provider(
283                ::zeroclaw_api::attribution::ProviderKind::Model(
284                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
285                ),
286            )
287        }
288        fn alias(&self) -> &str {
289            "MockModelProvider"
290        }
291    }
292
293    #[test]
294    fn provider_convert_tools_default() {
295        let model_provider = MockModelProvider {
296            supports_native: false,
297        };
298
299        let tools = vec![ToolSpec {
300            name: "test_tool".to_string(),
301            description: "A test tool".to_string(),
302            parameters: serde_json::json!({"type": "object"}),
303        }];
304
305        let payload = model_provider.convert_tools(&tools);
306        assert!(matches!(payload, ToolsPayload::PromptGuided { .. }));
307
308        if let ToolsPayload::PromptGuided { instructions } = payload {
309            assert!(instructions.contains("test_tool"));
310            assert!(instructions.contains("A test tool"));
311        }
312    }
313
314    #[tokio::test]
315    async fn provider_chat_prompt_guided_fallback() {
316        let model_provider = MockModelProvider {
317            supports_native: false,
318        };
319
320        let tools = vec![ToolSpec {
321            name: "shell".to_string(),
322            description: "Run commands".to_string(),
323            parameters: serde_json::json!({"type": "object"}),
324        }];
325
326        let request = ChatRequest {
327            messages: &[ChatMessage::user("Hello")],
328            tools: Some(&tools),
329            thinking: None,
330        };
331
332        let response = model_provider
333            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
334            .await
335            .unwrap();
336        assert!(response.text.is_some());
337    }
338
339    #[tokio::test]
340    async fn provider_chat_without_tools() {
341        let model_provider = MockModelProvider {
342            supports_native: true,
343        };
344
345        let request = ChatRequest {
346            messages: &[ChatMessage::user("Hello")],
347            tools: None,
348            thinking: None,
349        };
350
351        let response = model_provider
352            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
353            .await
354            .unwrap();
355        assert!(response.text.is_some());
356    }
357
358    struct EchoSystemModelProvider {
359        supports_native: bool,
360    }
361
362    #[async_trait]
363    impl ModelProvider for EchoSystemModelProvider {
364        fn supports_native_tools(&self) -> bool {
365            self.supports_native
366        }
367
368        async fn chat_with_system(
369            &self,
370            system: Option<&str>,
371            _message: &str,
372            _model: &str,
373            _temperature: Option<f64>,
374        ) -> anyhow::Result<String> {
375            Ok(system.unwrap_or_default().to_string())
376        }
377    }
378    impl ::zeroclaw_api::attribution::Attributable for EchoSystemModelProvider {
379        fn role(&self) -> ::zeroclaw_api::attribution::Role {
380            ::zeroclaw_api::attribution::Role::Provider(
381                ::zeroclaw_api::attribution::ProviderKind::Model(
382                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
383                ),
384            )
385        }
386        fn alias(&self) -> &str {
387            "EchoSystemModelProvider"
388        }
389    }
390
391    struct CustomConvertModelProvider;
392
393    #[async_trait]
394    impl ModelProvider for CustomConvertModelProvider {
395        fn supports_native_tools(&self) -> bool {
396            false
397        }
398
399        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
400            ToolsPayload::PromptGuided {
401                instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(),
402            }
403        }
404
405        async fn chat_with_system(
406            &self,
407            system: Option<&str>,
408            _message: &str,
409            _model: &str,
410            _temperature: Option<f64>,
411        ) -> anyhow::Result<String> {
412            Ok(system.unwrap_or_default().to_string())
413        }
414    }
415    impl ::zeroclaw_api::attribution::Attributable for CustomConvertModelProvider {
416        fn role(&self) -> ::zeroclaw_api::attribution::Role {
417            ::zeroclaw_api::attribution::Role::Provider(
418                ::zeroclaw_api::attribution::ProviderKind::Model(
419                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
420                ),
421            )
422        }
423        fn alias(&self) -> &str {
424            "CustomConvertModelProvider"
425        }
426    }
427
428    struct InvalidConvertModelProvider;
429
430    #[async_trait]
431    impl ModelProvider for InvalidConvertModelProvider {
432        fn supports_native_tools(&self) -> bool {
433            false
434        }
435
436        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
437            ToolsPayload::OpenAI {
438                tools: vec![serde_json::json!({"type": "function"})],
439            }
440        }
441
442        async fn chat_with_system(
443            &self,
444            _system: Option<&str>,
445            _message: &str,
446            _model: &str,
447            _temperature: Option<f64>,
448        ) -> anyhow::Result<String> {
449            Ok("should_not_reach".to_string())
450        }
451    }
452    impl ::zeroclaw_api::attribution::Attributable for InvalidConvertModelProvider {
453        fn role(&self) -> ::zeroclaw_api::attribution::Role {
454            ::zeroclaw_api::attribution::Role::Provider(
455                ::zeroclaw_api::attribution::ProviderKind::Model(
456                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
457                ),
458            )
459        }
460        fn alias(&self) -> &str {
461            "InvalidConvertModelProvider"
462        }
463    }
464
465    #[tokio::test]
466    async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
467        let model_provider = EchoSystemModelProvider {
468            supports_native: false,
469        };
470
471        let tools = vec![ToolSpec {
472            name: "shell".to_string(),
473            description: "Run commands".to_string(),
474            parameters: serde_json::json!({"type": "object"}),
475        }];
476
477        let request = ChatRequest {
478            messages: &[
479                ChatMessage::user("Hello"),
480                ChatMessage::system("BASE_SYSTEM_PROMPT"),
481            ],
482            tools: Some(&tools),
483            thinking: None,
484        };
485
486        let response = model_provider
487            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
488            .await
489            .unwrap();
490        let text = response.text.unwrap_or_default();
491
492        assert!(text.contains("BASE_SYSTEM_PROMPT"));
493        assert!(text.contains("Tool Use Protocol"));
494    }
495
496    #[tokio::test]
497    async fn provider_chat_prompt_guided_uses_convert_tools_override() {
498        let model_provider = CustomConvertModelProvider;
499
500        let tools = vec![ToolSpec {
501            name: "shell".to_string(),
502            description: "Run commands".to_string(),
503            parameters: serde_json::json!({"type": "object"}),
504        }];
505
506        let request = ChatRequest {
507            messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
508            tools: Some(&tools),
509            thinking: None,
510        };
511
512        let response = model_provider
513            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
514            .await
515            .unwrap();
516        let text = response.text.unwrap_or_default();
517
518        assert!(text.contains("BASE"));
519        assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS"));
520    }
521
522    #[tokio::test]
523    async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
524        let model_provider = InvalidConvertModelProvider;
525
526        let tools = vec![ToolSpec {
527            name: "shell".to_string(),
528            description: "Run commands".to_string(),
529            parameters: serde_json::json!({"type": "object"}),
530        }];
531
532        let request = ChatRequest {
533            messages: &[ChatMessage::user("Hello")],
534            tools: Some(&tools),
535            thinking: None,
536        };
537
538        let err = model_provider
539            .chat(request, "model", Some(TEST_DEFAULT_TEMPERATURE))
540            .await
541            .unwrap_err();
542        let message = err.to_string();
543
544        assert!(message.contains("non-prompt-guided"));
545    }
546
547    struct StreamingChunkOnlyModelProvider;
548
549    #[async_trait]
550    impl ModelProvider for StreamingChunkOnlyModelProvider {
551        async fn chat_with_system(
552            &self,
553            _system_prompt: Option<&str>,
554            _message: &str,
555            _model: &str,
556            _temperature: Option<f64>,
557        ) -> anyhow::Result<String> {
558            Ok("ok".to_string())
559        }
560
561        fn supports_streaming(&self) -> bool {
562            true
563        }
564
565        fn stream_chat_with_history(
566            &self,
567            _messages: &[ChatMessage],
568            _model: &str,
569            _temperature: Option<f64>,
570            _options: StreamOptions,
571        ) -> BoxStream<'static, StreamResult<StreamChunk>> {
572            stream::iter(vec![
573                Ok(StreamChunk::delta("hello")),
574                Ok(StreamChunk::final_chunk()),
575            ])
576            .boxed()
577        }
578    }
579    impl ::zeroclaw_api::attribution::Attributable for StreamingChunkOnlyModelProvider {
580        fn role(&self) -> ::zeroclaw_api::attribution::Role {
581            ::zeroclaw_api::attribution::Role::Provider(
582                ::zeroclaw_api::attribution::ProviderKind::Model(
583                    ::zeroclaw_api::attribution::ModelProviderKind::Custom,
584                ),
585            )
586        }
587        fn alias(&self) -> &str {
588            "StreamingChunkOnlyModelProvider"
589        }
590    }
591
592    #[tokio::test]
593    async fn provider_stream_chat_default_maps_legacy_chunks_to_events() {
594        let model_provider = StreamingChunkOnlyModelProvider;
595        let mut stream = model_provider.stream_chat(
596            ChatRequest {
597                messages: &[ChatMessage::user("hi")],
598                tools: None,
599                thinking: None,
600            },
601            "model",
602            Some(TEST_GREEDY_TEMPERATURE),
603            StreamOptions::new(true),
604        );
605
606        let first = stream.next().await.unwrap().unwrap();
607        let second = stream.next().await.unwrap().unwrap();
608        assert!(stream.next().await.is_none());
609
610        match first {
611            StreamEvent::TextDelta(chunk) => assert_eq!(chunk.delta, "hello"),
612            other => panic!("expected text delta event, got {other:?}"),
613        }
614        assert!(matches!(second, StreamEvent::Final));
615    }
616}