1use crate::tool::ToolSpec;
2use async_trait::async_trait;
3use futures_util::{StreamExt, stream};
4use serde::{Deserialize, Serialize};
5use std::fmt::Write;
6use std::sync::Arc;
7
/// Smallest thinking-token budget constant exposed by this module.
///
/// NOTE(review): presumably the lower bound for
/// `NativeThinkingParams::budget_tokens`; nothing visible here enforces it.
pub const MIN_BUDGET_TOKENS: u32 = 1_024;
/// Largest thinking-token budget constant exposed by this module.
///
/// NOTE(review): presumably the upper bound for
/// `NativeThinkingParams::budget_tokens`; nothing visible here enforces it.
pub const MAX_BUDGET_TOKENS: u32 = 128_000;
13
/// Settings attached to a request that asks for a provider's native thinking mode.
///
/// The struct is a plain `Copy` value holding one number; it performs no
/// validation of its own.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NativeThinkingParams {
    /// Token budget for thinking.
    ///
    /// NOTE(review): presumably expected to lie between `MIN_BUDGET_TOKENS`
    /// and `MAX_BUDGET_TOKENS` — confirm where the value is produced.
    pub budget_tokens: u32,
}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ChatMessage {
23 pub role: String,
24 pub content: String,
25}
26
27impl ChatMessage {
28 pub fn system(content: impl Into<String>) -> Self {
29 Self {
30 role: "system".into(),
31 content: content.into(),
32 }
33 }
34
35 pub fn user(content: impl Into<String>) -> Self {
36 Self {
37 role: "user".into(),
38 content: content.into(),
39 }
40 }
41
42 pub fn assistant(content: impl Into<String>) -> Self {
43 Self {
44 role: "assistant".into(),
45 content: content.into(),
46 }
47 }
48
49 pub fn tool(content: impl Into<String>) -> Self {
50 Self {
51 role: "tool".into(),
52 content: content.into(),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct ToolCall {
60 pub id: String,
61 pub name: String,
62 pub arguments: String,
63 #[serde(default, skip_serializing_if = "Option::is_none")]
67 pub extra_content: Option<serde_json::Value>,
68}
69
/// Token counters reported for a request.
///
/// Every counter is optional; `TokenUsage::default()` leaves all of them `None`.
#[derive(Clone, Debug, Default)]
pub struct TokenUsage {
    /// Input-token count, when known.
    pub input_tokens: Option<u64>,
    /// Output-token count, when known.
    pub output_tokens: Option<u64>,
    /// Cached-input-token count, when known.
    pub cached_input_tokens: Option<u64>,
}
98
99#[derive(Debug, Clone)]
101pub struct ChatResponse {
102 pub text: Option<String>,
104 pub tool_calls: Vec<ToolCall>,
106 pub usage: Option<TokenUsage>,
108 pub reasoning_content: Option<String>,
113}
114
115impl ChatResponse {
116 pub fn has_tool_calls(&self) -> bool {
118 !self.tool_calls.is_empty()
119 }
120
121 pub fn text_or_empty(&self) -> &str {
123 self.text.as_deref().unwrap_or("")
124 }
125}
126
127#[derive(Debug, Clone, Copy)]
129pub struct ChatRequest<'a> {
130 pub messages: &'a [ChatMessage],
131 pub tools: Option<&'a [ToolSpec]>,
132 pub thinking: Option<NativeThinkingParams>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct ToolResultMessage {
141 pub tool_call_id: String,
142 pub content: String,
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
147#[serde(tag = "type", content = "data")]
148pub enum ConversationMessage {
149 Chat(ChatMessage),
151 AssistantToolCalls {
153 text: Option<String>,
154 tool_calls: Vec<ToolCall>,
155 reasoning_content: Option<String>,
158 },
159 ToolResults(Vec<ToolResultMessage>),
161}
162
/// One increment of a streamed text response.
#[derive(Clone, Debug)]
pub struct StreamChunk {
    /// Text carried by this chunk (may be empty).
    pub delta: String,
    /// Reasoning text carried by this chunk, if any.
    pub reasoning: Option<String>,
    /// `true` when this chunk terminates the stream.
    pub is_final: bool,
    /// Token estimate for `delta`; stays `0` until
    /// [`StreamChunk::with_token_estimate`] is applied.
    pub token_count: usize,
}

impl StreamChunk {
    /// Chunk with no text, no reasoning and a zero token count.
    fn empty(is_final: bool) -> Self {
        Self {
            delta: String::new(),
            reasoning: None,
            is_final,
            token_count: 0,
        }
    }

    /// Non-final chunk carrying `text` as its delta.
    pub fn delta(text: impl Into<String>) -> Self {
        Self {
            delta: text.into(),
            ..Self::empty(false)
        }
    }

    /// Non-final chunk carrying `text` as reasoning; its delta stays empty.
    pub fn reasoning(text: impl Into<String>) -> Self {
        Self {
            reasoning: Some(text.into()),
            ..Self::empty(false)
        }
    }

    /// Final chunk with no payload.
    pub fn final_chunk() -> Self {
        Self::empty(true)
    }

    /// Final chunk whose delta holds `message`.
    pub fn error(message: impl Into<String>) -> Self {
        Self {
            delta: message.into(),
            ..Self::empty(true)
        }
    }

    /// Sets `token_count` to the byte length of `delta` divided by four,
    /// rounded up, and returns the chunk.
    pub fn with_token_estimate(self) -> Self {
        let token_count = self.delta.len().div_ceil(4);
        Self {
            token_count,
            ..self
        }
    }
}
223
224#[derive(Debug, Clone)]
229pub enum StreamEvent {
230 TextDelta(StreamChunk),
232 ToolCall(ToolCall),
234 PreExecutedToolCall { name: String, args: String },
237 PreExecutedToolResult { name: String, output: String },
239 Usage(TokenUsage),
242 Final,
244}
245
246impl StreamEvent {
247 pub fn from_chunk(chunk: StreamChunk) -> Self {
248 if chunk.is_final {
249 Self::Final
250 } else {
251 Self::TextDelta(chunk)
252 }
253 }
254}
255
/// Switches passed to the streaming entry points.
///
/// The `Default` value has both switches off.
#[derive(Clone, Copy, Debug, Default)]
pub struct StreamOptions {
    /// Whether streaming is requested.
    pub enabled: bool,
    /// Whether token counting is requested.
    pub count_tokens: bool,
}

impl StreamOptions {
    /// Options with the given `enabled` flag and token counting off.
    pub fn new(enabled: bool) -> Self {
        Self {
            enabled,
            ..Self::default()
        }
    }

    /// Returns the same options with token counting switched on.
    pub fn with_token_count(self) -> Self {
        Self {
            count_tokens: true,
            ..self
        }
    }
}
280
281pub type StreamResult<T> = std::result::Result<T, StreamError>;
283
284#[derive(Debug, thiserror::Error)]
286pub enum StreamError {
287 #[error("HTTP error: {0}")]
288 Http(String),
289
290 #[error("JSON parse error: {0}")]
291 Json(serde_json::Error),
292
293 #[error("Invalid SSE format: {0}")]
294 InvalidSse(String),
295
296 #[error("ModelProvider error: {0}")]
297 ModelProvider(String),
298
299 #[error("IO error: {0}")]
300 Io(#[from] std::io::Error),
301}
302
303#[derive(Debug, Clone, thiserror::Error)]
305#[error(
306 "provider_capability_error model_provider={model_provider} capability={capability} message={message}"
307)]
308pub struct ProviderCapabilityError {
309 pub model_provider: String,
310 pub capability: String,
311 pub message: String,
312}
313
/// Feature flags a provider advertises about itself.
///
/// `Default` yields every flag set to `false`.
// Each flag is an independent feature switch, so a plain struct of bools is intended.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ProviderCapabilities {
    /// Native tool calling is available.
    pub native_tool_calling: bool,
    /// Vision input is available.
    pub vision: bool,
    /// Prompt caching is available.
    pub prompt_caching: bool,
    /// Extended thinking is available.
    pub extended_thinking: bool,
}
330
331#[derive(Debug, Clone)]
333pub enum ToolsPayload {
334 Gemini {
336 function_declarations: Vec<serde_json::Value>,
337 },
338 Anthropic { tools: Vec<serde_json::Value> },
340 OpenAI { tools: Vec<serde_json::Value> },
342 PromptGuided { instructions: String },
344}
345
/// Temperature returned by `ModelProvider::default_temperature` unless overridden.
pub const BASELINE_TEMPERATURE: f64 = 0.7;

/// Max-token value returned by `ModelProvider::default_max_tokens` unless overridden.
pub const BASELINE_MAX_TOKENS: u32 = 4_096;

/// Timeout, in seconds, returned by `ModelProvider::default_timeout_secs` unless overridden.
pub const BASELINE_TIMEOUT_SECS: u64 = 120;

/// Wire-API name returned by `ModelProvider::default_wire_api` unless overridden.
pub const BASELINE_WIRE_API: &str = "chat_completions";
365
366#[derive(Debug, Clone, Deserialize, Serialize)]
372pub struct ModelPricing {
373 #[serde(default, skip_serializing_if = "Option::is_none")]
375 pub prompt: Option<String>,
376 #[serde(default, skip_serializing_if = "Option::is_none")]
378 pub completion: Option<String>,
379 #[serde(default, skip_serializing_if = "Option::is_none")]
382 pub input_cache_read: Option<String>,
383 #[serde(default, skip_serializing_if = "Option::is_none")]
386 pub input_cache_write: Option<String>,
387}
388
389#[derive(Debug, Clone, Serialize)]
391pub struct ModelInfo {
392 pub id: String,
393 #[serde(skip_serializing_if = "Option::is_none")]
394 pub pricing: Option<ModelPricing>,
395}
396
397#[async_trait]
398pub trait ModelProvider: Send + Sync + crate::attribution::Attributable {
399 fn capabilities(&self) -> ProviderCapabilities {
401 ProviderCapabilities::default()
402 }
403
404 fn default_temperature(&self) -> f64 {
415 BASELINE_TEMPERATURE
416 }
417
418 fn default_max_tokens(&self) -> u32 {
420 BASELINE_MAX_TOKENS
421 }
422
423 fn default_timeout_secs(&self) -> u64 {
425 BASELINE_TIMEOUT_SECS
426 }
427
428 fn default_base_url(&self) -> Option<&str> {
433 None
434 }
435
436 fn default_wire_api(&self) -> &str {
440 BASELINE_WIRE_API
441 }
442
443 fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
445 ToolsPayload::PromptGuided {
446 instructions: build_tool_instructions_text(tools),
447 }
448 }
449
450 async fn simple_chat(
454 &self,
455 message: &str,
456 model: &str,
457 temperature: Option<f64>,
458 ) -> anyhow::Result<String> {
459 self.chat_with_system(None, message, model, temperature)
460 .await
461 }
462
463 async fn chat_with_system(
466 &self,
467 system_prompt: Option<&str>,
468 message: &str,
469 model: &str,
470 temperature: Option<f64>,
471 ) -> anyhow::Result<String>;
472
473 async fn list_models(&self) -> anyhow::Result<Vec<String>> {
480 anyhow::bail!("live model listing is not supported for this model_provider")
481 }
482
483 async fn list_models_with_pricing(&self) -> anyhow::Result<Vec<ModelInfo>> {
488 Ok(self
489 .list_models()
490 .await?
491 .into_iter()
492 .map(|id| ModelInfo { id, pricing: None })
493 .collect())
494 }
495
496 async fn chat_with_history(
499 &self,
500 messages: &[ChatMessage],
501 model: &str,
502 temperature: Option<f64>,
503 ) -> anyhow::Result<String> {
504 let system = messages
505 .iter()
506 .find(|m| m.role == "system")
507 .map(|m| m.content.as_str());
508 let last_user = messages
509 .iter()
510 .rfind(|m| m.role == "user")
511 .map(|m| m.content.as_str())
512 .unwrap_or("");
513 self.chat_with_system(system, last_user, model, temperature)
514 .await
515 }
516
517 async fn chat(
520 &self,
521 request: ChatRequest<'_>,
522 model: &str,
523 temperature: Option<f64>,
524 ) -> anyhow::Result<ChatResponse> {
525 if let Some(tools) = request.tools
526 && !tools.is_empty()
527 && !self.supports_native_tools()
528 {
529 let tool_instructions = match self.convert_tools(tools) {
530 ToolsPayload::PromptGuided { instructions } => instructions,
531 payload => {
532 anyhow::bail!(
533 "ModelProvider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
534 )
535 }
536 };
537 let mut modified_messages = request.messages.to_vec();
538
539 if let Some(system_message) = modified_messages.iter_mut().find(|m| m.role == "system")
540 {
541 if !system_message.content.is_empty() {
542 system_message.content.push_str("\n\n");
543 }
544 system_message.content.push_str(&tool_instructions);
545 } else {
546 modified_messages.insert(0, ChatMessage::system(tool_instructions));
547 }
548
549 let text = self
550 .chat_with_history(&modified_messages, model, temperature)
551 .await?;
552 return Ok(ChatResponse {
553 text: Some(text),
554 tool_calls: Vec::new(),
555 usage: None,
556 reasoning_content: None,
557 });
558 }
559
560 let text = self
561 .chat_with_history(request.messages, model, temperature)
562 .await?;
563 Ok(ChatResponse {
564 text: Some(text),
565 tool_calls: Vec::new(),
566 usage: None,
567 reasoning_content: None,
568 })
569 }
570
571 fn supports_native_tools(&self) -> bool {
573 self.capabilities().native_tool_calling
574 }
575
576 fn supports_vision(&self) -> bool {
578 self.capabilities().vision
579 }
580
581 async fn warmup(&self) -> anyhow::Result<()> {
583 Ok(())
584 }
585
586 async fn chat_with_tools(
589 &self,
590 messages: &[ChatMessage],
591 _tools: &[serde_json::Value],
592 model: &str,
593 temperature: Option<f64>,
594 ) -> anyhow::Result<ChatResponse> {
595 let text = self.chat_with_history(messages, model, temperature).await?;
596 Ok(ChatResponse {
597 text: Some(text),
598 tool_calls: Vec::new(),
599 usage: None,
600 reasoning_content: None,
601 })
602 }
603
604 fn supports_streaming(&self) -> bool {
606 false
607 }
608
609 fn supports_streaming_tool_events(&self) -> bool {
611 false
612 }
613
614 fn stream_chat_with_system(
617 &self,
618 _system_prompt: Option<&str>,
619 _message: &str,
620 _model: &str,
621 _temperature: Option<f64>,
622 _options: StreamOptions,
623 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
624 stream::empty().boxed()
625 }
626
627 fn stream_chat_with_history(
630 &self,
631 messages: &[ChatMessage],
632 model: &str,
633 temperature: Option<f64>,
634 options: StreamOptions,
635 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
636 let system = messages
637 .iter()
638 .find(|m| m.role == "system")
639 .map(|m| m.content.as_str());
640 let last_user = messages
641 .iter()
642 .rfind(|m| m.role == "user")
643 .map(|m| m.content.as_str())
644 .unwrap_or("");
645 self.stream_chat_with_system(system, last_user, model, temperature, options)
646 }
647
648 fn stream_chat(
651 &self,
652 request: ChatRequest<'_>,
653 model: &str,
654 temperature: Option<f64>,
655 options: StreamOptions,
656 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
657 self.stream_chat_with_history(request.messages, model, temperature, options)
658 .map(|chunk_result| chunk_result.map(StreamEvent::from_chunk))
659 .boxed()
660 }
661}
662
663#[async_trait]
668impl<T: ModelProvider + ?Sized> ModelProvider for Arc<T> {
669 fn capabilities(&self) -> ProviderCapabilities {
670 self.as_ref().capabilities()
671 }
672
673 fn default_max_tokens(&self) -> u32 {
674 self.as_ref().default_max_tokens()
675 }
676
677 fn default_temperature(&self) -> f64 {
678 self.as_ref().default_temperature()
679 }
680
681 fn default_timeout_secs(&self) -> u64 {
682 self.as_ref().default_timeout_secs()
683 }
684
685 fn default_base_url(&self) -> Option<&str> {
686 self.as_ref().default_base_url()
687 }
688
689 fn default_wire_api(&self) -> &str {
690 self.as_ref().default_wire_api()
691 }
692
693 fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
694 self.as_ref().convert_tools(tools)
695 }
696
697 fn supports_native_tools(&self) -> bool {
698 self.as_ref().supports_native_tools()
699 }
700
701 fn supports_vision(&self) -> bool {
702 self.as_ref().supports_vision()
703 }
704
705 async fn chat_with_system(
706 &self,
707 system_prompt: Option<&str>,
708 message: &str,
709 model: &str,
710 temperature: Option<f64>,
711 ) -> anyhow::Result<String> {
712 self.as_ref()
713 .chat_with_system(system_prompt, message, model, temperature)
714 .await
715 }
716
717 async fn chat_with_history(
718 &self,
719 messages: &[ChatMessage],
720 model: &str,
721 temperature: Option<f64>,
722 ) -> anyhow::Result<String> {
723 self.as_ref()
724 .chat_with_history(messages, model, temperature)
725 .await
726 }
727
728 async fn chat(
729 &self,
730 request: ChatRequest<'_>,
731 model: &str,
732 temperature: Option<f64>,
733 ) -> anyhow::Result<ChatResponse> {
734 self.as_ref().chat(request, model, temperature).await
735 }
736
737 async fn warmup(&self) -> anyhow::Result<()> {
738 self.as_ref().warmup().await
739 }
740
741 async fn chat_with_tools(
742 &self,
743 messages: &[ChatMessage],
744 tools: &[serde_json::Value],
745 model: &str,
746 temperature: Option<f64>,
747 ) -> anyhow::Result<ChatResponse> {
748 self.as_ref()
749 .chat_with_tools(messages, tools, model, temperature)
750 .await
751 }
752
753 fn supports_streaming(&self) -> bool {
754 self.as_ref().supports_streaming()
755 }
756
757 fn supports_streaming_tool_events(&self) -> bool {
758 self.as_ref().supports_streaming_tool_events()
759 }
760
761 fn stream_chat_with_system(
762 &self,
763 system_prompt: Option<&str>,
764 message: &str,
765 model: &str,
766 temperature: Option<f64>,
767 options: StreamOptions,
768 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
769 self.as_ref()
770 .stream_chat_with_system(system_prompt, message, model, temperature, options)
771 }
772
773 fn stream_chat_with_history(
774 &self,
775 messages: &[ChatMessage],
776 model: &str,
777 temperature: Option<f64>,
778 options: StreamOptions,
779 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
780 self.as_ref()
781 .stream_chat_with_history(messages, model, temperature, options)
782 }
783
784 fn stream_chat(
785 &self,
786 request: ChatRequest<'_>,
787 model: &str,
788 temperature: Option<f64>,
789 options: StreamOptions,
790 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
791 self.as_ref()
792 .stream_chat(request, model, temperature, options)
793 }
794}
795
796pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
798 let mut instructions = String::new();
799
800 instructions.push_str("## Tool Use Protocol\n\n");
801 instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
802 instructions.push_str("<tool_call>\n");
803 instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#);
804 instructions.push_str("\n</tool_call>\n\n");
805 instructions.push_str("You may use multiple tool calls in a single response. ");
806 instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
807 instructions
808 .push_str("Continue reasoning with the results until you can give a final answer.\n\n");
809 instructions.push_str("### Available Tools\n\n");
810
811 for tool in tools {
812 writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description)
813 .expect("writing to String cannot fail");
814
815 let parameters =
816 serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string());
817 writeln!(&mut instructions, "Parameters: `{parameters}`")
818 .expect("writing to String cannot fail");
819 instructions.push('\n');
820 }
821
822 instructions
823}