1use crate::tool::ToolSpec;
2use async_trait::async_trait;
3use futures_util::{StreamExt, stream};
4use serde::{Deserialize, Serialize};
5use std::fmt::Write;
6use std::sync::Arc;
7
/// Largest native-thinking token budget, in tokens.
///
/// NOTE(review): presumably the upper bound for `NativeThinkingParams::budget_tokens`;
/// nothing in this section clamps to it — confirm where it is enforced.
pub const MAX_BUDGET_TOKENS: u32 = 128_000;

/// Smallest native-thinking token budget, in tokens.
///
/// NOTE(review): presumably the lower bound for `NativeThinkingParams::budget_tokens`;
/// nothing in this section clamps to it — confirm where it is enforced.
pub const MIN_BUDGET_TOKENS: u32 = 1_024;
13
/// Parameters for a provider's native "thinking" / extended-reasoning mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NativeThinkingParams {
    /// Token budget granted to the model's reasoning phase.
    pub budget_tokens: u32,
}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ChatMessage {
23 pub role: String,
24 pub content: String,
25}
26
27impl ChatMessage {
28 pub fn system(content: impl Into<String>) -> Self {
29 Self {
30 role: "system".into(),
31 content: content.into(),
32 }
33 }
34
35 pub fn user(content: impl Into<String>) -> Self {
36 Self {
37 role: "user".into(),
38 content: content.into(),
39 }
40 }
41
42 pub fn assistant(content: impl Into<String>) -> Self {
43 Self {
44 role: "assistant".into(),
45 content: content.into(),
46 }
47 }
48
49 pub fn tool(content: impl Into<String>) -> Self {
50 Self {
51 role: "tool".into(),
52 content: content.into(),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct ToolCall {
60 pub id: String,
61 pub name: String,
62 pub arguments: String,
63 #[serde(default, skip_serializing_if = "Option::is_none")]
67 pub extra_content: Option<serde_json::Value>,
68}
69
/// Token accounting for one request; each count is `None` when not available.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Prompt-side tokens.
    pub input_tokens: Option<u64>,
    /// Completion-side tokens.
    pub output_tokens: Option<u64>,
    /// Prompt tokens served from a cache, if the provider reports it.
    pub cached_input_tokens: Option<u64>,
}
79
80#[derive(Debug, Clone)]
82pub struct ChatResponse {
83 pub text: Option<String>,
85 pub tool_calls: Vec<ToolCall>,
87 pub usage: Option<TokenUsage>,
89 pub reasoning_content: Option<String>,
94}
95
96impl ChatResponse {
97 pub fn has_tool_calls(&self) -> bool {
99 !self.tool_calls.is_empty()
100 }
101
102 pub fn text_or_empty(&self) -> &str {
104 self.text.as_deref().unwrap_or("")
105 }
106}
107
108#[derive(Debug, Clone, Copy)]
110pub struct ChatRequest<'a> {
111 pub messages: &'a [ChatMessage],
112 pub tools: Option<&'a [ToolSpec]>,
113 pub thinking: Option<NativeThinkingParams>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct ToolResultMessage {
122 pub tool_call_id: String,
123 pub content: String,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128#[serde(tag = "type", content = "data")]
129pub enum ConversationMessage {
130 Chat(ChatMessage),
132 AssistantToolCalls {
134 text: Option<String>,
135 tool_calls: Vec<ToolCall>,
136 reasoning_content: Option<String>,
139 },
140 ToolResults(Vec<ToolResultMessage>),
142}
143
/// One increment of a streamed chat response.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Incremental text (empty for reasoning-only and plain final chunks).
    pub delta: String,
    /// Incremental reasoning text, if this chunk carries any.
    pub reasoning: Option<String>,
    /// Marks the end of the stream (also set by [`StreamChunk::error`]).
    pub is_final: bool,
    /// Token estimate; `0` unless [`StreamChunk::with_token_estimate`] was applied.
    pub token_count: usize,
}

impl StreamChunk {
    /// Common constructor: every public builder starts with a zero token count.
    fn assemble(delta: String, reasoning: Option<String>, is_final: bool) -> Self {
        Self {
            delta,
            reasoning,
            is_final,
            token_count: 0,
        }
    }

    /// A non-final chunk carrying response text.
    pub fn delta(text: impl Into<String>) -> Self {
        Self::assemble(text.into(), None, false)
    }

    /// A non-final chunk carrying only reasoning text.
    pub fn reasoning(text: impl Into<String>) -> Self {
        Self::assemble(String::new(), Some(text.into()), false)
    }

    /// An empty chunk that terminates the stream.
    pub fn final_chunk() -> Self {
        Self::assemble(String::new(), None, true)
    }

    /// A terminating chunk whose `delta` holds an error message.
    pub fn error(message: impl Into<String>) -> Self {
        Self::assemble(message.into(), None, true)
    }

    /// Sets `token_count` to `ceil(delta.len() / 4)`.
    ///
    /// A rough heuristic on the UTF-8 byte length of `delta`, not a tokenizer;
    /// `reasoning` is not counted.
    pub fn with_token_estimate(self) -> Self {
        let token_count = usize::div_ceil(self.delta.len(), 4);
        Self { token_count, ..self }
    }
}
204
205#[derive(Debug, Clone)]
210pub enum StreamEvent {
211 TextDelta(StreamChunk),
213 ToolCall(ToolCall),
215 PreExecutedToolCall { name: String, args: String },
218 PreExecutedToolResult { name: String, output: String },
220 Usage(TokenUsage),
223 Final,
225}
226
227impl StreamEvent {
228 pub fn from_chunk(chunk: StreamChunk) -> Self {
229 if chunk.is_final {
230 Self::Final
231 } else {
232 Self::TextDelta(chunk)
233 }
234 }
235}
236
/// Caller-side streaming switches.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamOptions {
    /// Whether streaming is requested at all.
    pub enabled: bool,
    /// Whether token counts are requested (presumably filling
    /// `StreamChunk::token_count` — confirm per provider).
    pub count_tokens: bool,
}

impl StreamOptions {
    /// Options with the given `enabled` flag and token counting off.
    pub fn new(enabled: bool) -> Self {
        Self {
            enabled,
            ..Self::default()
        }
    }

    /// Returns a copy with token counting switched on.
    pub fn with_token_count(self) -> Self {
        Self {
            count_tokens: true,
            ..self
        }
    }
}
261
262pub type StreamResult<T> = std::result::Result<T, StreamError>;
264
265#[derive(Debug, thiserror::Error)]
267pub enum StreamError {
268 #[error("HTTP error: {0}")]
269 Http(String),
270
271 #[error("JSON parse error: {0}")]
272 Json(serde_json::Error),
273
274 #[error("Invalid SSE format: {0}")]
275 InvalidSse(String),
276
277 #[error("ModelProvider error: {0}")]
278 ModelProvider(String),
279
280 #[error("IO error: {0}")]
281 Io(#[from] std::io::Error),
282}
283
284#[derive(Debug, Clone, thiserror::Error)]
286#[error(
287 "provider_capability_error model_provider={model_provider} capability={capability} message={message}"
288)]
289pub struct ProviderCapabilityError {
290 pub model_provider: String,
291 pub capability: String,
292 pub message: String,
293}
294
/// Feature flags a provider advertises; `Default` reports every feature as absent.
// The flags are independent of one another, so a state enum would not make any
// combination clearer; silence clippy's excessive-bools lint deliberately.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Provider accepts tool definitions natively (no prompt-guided fallback).
    pub native_tool_calling: bool,
    /// Provider accepts image input.
    pub vision: bool,
    /// Provider supports prompt caching.
    pub prompt_caching: bool,
    /// Provider supports extended / native thinking.
    pub extended_thinking: bool,
}
311
312#[derive(Debug, Clone)]
314pub enum ToolsPayload {
315 Gemini {
317 function_declarations: Vec<serde_json::Value>,
318 },
319 Anthropic { tools: Vec<serde_json::Value> },
321 OpenAI { tools: Vec<serde_json::Value> },
323 PromptGuided { instructions: String },
325}
326
/// Fallback sampling temperature returned by `ModelProvider::default_temperature`.
pub const BASELINE_TEMPERATURE: f64 = 0.7;

/// Fallback completion-token cap returned by `ModelProvider::default_max_tokens`.
pub const BASELINE_MAX_TOKENS: u32 = 4096;

/// Fallback request timeout, in seconds, returned by `ModelProvider::default_timeout_secs`.
pub const BASELINE_TIMEOUT_SECS: u64 = 120;

/// Fallback wire API name returned by `ModelProvider::default_wire_api`.
pub const BASELINE_WIRE_API: &str = "chat_completions";
346
347#[async_trait]
348pub trait ModelProvider: Send + Sync + crate::attribution::Attributable {
349 fn capabilities(&self) -> ProviderCapabilities {
351 ProviderCapabilities::default()
352 }
353
354 fn default_temperature(&self) -> f64 {
363 BASELINE_TEMPERATURE
364 }
365
366 fn default_max_tokens(&self) -> u32 {
368 BASELINE_MAX_TOKENS
369 }
370
371 fn default_timeout_secs(&self) -> u64 {
373 BASELINE_TIMEOUT_SECS
374 }
375
376 fn default_base_url(&self) -> Option<&str> {
381 None
382 }
383
384 fn default_wire_api(&self) -> &str {
388 BASELINE_WIRE_API
389 }
390
391 fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
393 ToolsPayload::PromptGuided {
394 instructions: build_tool_instructions_text(tools),
395 }
396 }
397
398 async fn simple_chat(
403 &self,
404 message: &str,
405 model: &str,
406 temperature: Option<f64>,
407 ) -> anyhow::Result<String> {
408 self.chat_with_system(None, message, model, temperature)
409 .await
410 }
411
412 async fn chat_with_system(
415 &self,
416 system_prompt: Option<&str>,
417 message: &str,
418 model: &str,
419 temperature: Option<f64>,
420 ) -> anyhow::Result<String>;
421
422 async fn list_models(&self) -> anyhow::Result<Vec<String>> {
429 anyhow::bail!("live model listing is not supported for this model_provider")
430 }
431
432 async fn chat_with_history(
435 &self,
436 messages: &[ChatMessage],
437 model: &str,
438 temperature: Option<f64>,
439 ) -> anyhow::Result<String> {
440 let system = messages
441 .iter()
442 .find(|m| m.role == "system")
443 .map(|m| m.content.as_str());
444 let last_user = messages
445 .iter()
446 .rfind(|m| m.role == "user")
447 .map(|m| m.content.as_str())
448 .unwrap_or("");
449 self.chat_with_system(system, last_user, model, temperature)
450 .await
451 }
452
453 async fn chat(
456 &self,
457 request: ChatRequest<'_>,
458 model: &str,
459 temperature: Option<f64>,
460 ) -> anyhow::Result<ChatResponse> {
461 if let Some(tools) = request.tools
462 && !tools.is_empty()
463 && !self.supports_native_tools()
464 {
465 let tool_instructions = match self.convert_tools(tools) {
466 ToolsPayload::PromptGuided { instructions } => instructions,
467 payload => {
468 anyhow::bail!(
469 "ModelProvider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
470 )
471 }
472 };
473 let mut modified_messages = request.messages.to_vec();
474
475 if let Some(system_message) = modified_messages.iter_mut().find(|m| m.role == "system")
476 {
477 if !system_message.content.is_empty() {
478 system_message.content.push_str("\n\n");
479 }
480 system_message.content.push_str(&tool_instructions);
481 } else {
482 modified_messages.insert(0, ChatMessage::system(tool_instructions));
483 }
484
485 let text = self
486 .chat_with_history(&modified_messages, model, temperature)
487 .await?;
488 return Ok(ChatResponse {
489 text: Some(text),
490 tool_calls: Vec::new(),
491 usage: None,
492 reasoning_content: None,
493 });
494 }
495
496 let text = self
497 .chat_with_history(request.messages, model, temperature)
498 .await?;
499 Ok(ChatResponse {
500 text: Some(text),
501 tool_calls: Vec::new(),
502 usage: None,
503 reasoning_content: None,
504 })
505 }
506
507 fn supports_native_tools(&self) -> bool {
509 self.capabilities().native_tool_calling
510 }
511
512 fn supports_vision(&self) -> bool {
514 self.capabilities().vision
515 }
516
517 async fn warmup(&self) -> anyhow::Result<()> {
519 Ok(())
520 }
521
522 async fn chat_with_tools(
525 &self,
526 messages: &[ChatMessage],
527 _tools: &[serde_json::Value],
528 model: &str,
529 temperature: Option<f64>,
530 ) -> anyhow::Result<ChatResponse> {
531 let text = self.chat_with_history(messages, model, temperature).await?;
532 Ok(ChatResponse {
533 text: Some(text),
534 tool_calls: Vec::new(),
535 usage: None,
536 reasoning_content: None,
537 })
538 }
539
540 fn supports_streaming(&self) -> bool {
542 false
543 }
544
545 fn supports_streaming_tool_events(&self) -> bool {
547 false
548 }
549
550 fn stream_chat_with_system(
553 &self,
554 _system_prompt: Option<&str>,
555 _message: &str,
556 _model: &str,
557 _temperature: Option<f64>,
558 _options: StreamOptions,
559 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
560 stream::empty().boxed()
561 }
562
563 fn stream_chat_with_history(
566 &self,
567 messages: &[ChatMessage],
568 model: &str,
569 temperature: Option<f64>,
570 options: StreamOptions,
571 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
572 let system = messages
573 .iter()
574 .find(|m| m.role == "system")
575 .map(|m| m.content.as_str());
576 let last_user = messages
577 .iter()
578 .rfind(|m| m.role == "user")
579 .map(|m| m.content.as_str())
580 .unwrap_or("");
581 self.stream_chat_with_system(system, last_user, model, temperature, options)
582 }
583
584 fn stream_chat(
587 &self,
588 request: ChatRequest<'_>,
589 model: &str,
590 temperature: Option<f64>,
591 options: StreamOptions,
592 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
593 self.stream_chat_with_history(request.messages, model, temperature, options)
594 .map(|chunk_result| chunk_result.map(StreamEvent::from_chunk))
595 .boxed()
596 }
597}
598
599#[async_trait]
604impl<T: ModelProvider + ?Sized> ModelProvider for Arc<T> {
605 fn capabilities(&self) -> ProviderCapabilities {
606 self.as_ref().capabilities()
607 }
608
609 fn default_temperature(&self) -> f64 {
610 self.as_ref().default_temperature()
611 }
612
613 fn default_max_tokens(&self) -> u32 {
614 self.as_ref().default_max_tokens()
615 }
616
617 fn default_timeout_secs(&self) -> u64 {
618 self.as_ref().default_timeout_secs()
619 }
620
621 fn default_base_url(&self) -> Option<&str> {
622 self.as_ref().default_base_url()
623 }
624
625 fn default_wire_api(&self) -> &str {
626 self.as_ref().default_wire_api()
627 }
628
629 fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
630 self.as_ref().convert_tools(tools)
631 }
632
633 fn supports_native_tools(&self) -> bool {
634 self.as_ref().supports_native_tools()
635 }
636
637 fn supports_vision(&self) -> bool {
638 self.as_ref().supports_vision()
639 }
640
641 async fn chat_with_system(
642 &self,
643 system_prompt: Option<&str>,
644 message: &str,
645 model: &str,
646 temperature: Option<f64>,
647 ) -> anyhow::Result<String> {
648 self.as_ref()
649 .chat_with_system(system_prompt, message, model, temperature)
650 .await
651 }
652
653 async fn chat_with_history(
654 &self,
655 messages: &[ChatMessage],
656 model: &str,
657 temperature: Option<f64>,
658 ) -> anyhow::Result<String> {
659 self.as_ref()
660 .chat_with_history(messages, model, temperature)
661 .await
662 }
663
664 async fn chat(
665 &self,
666 request: ChatRequest<'_>,
667 model: &str,
668 temperature: Option<f64>,
669 ) -> anyhow::Result<ChatResponse> {
670 self.as_ref().chat(request, model, temperature).await
671 }
672
673 async fn warmup(&self) -> anyhow::Result<()> {
674 self.as_ref().warmup().await
675 }
676
677 async fn chat_with_tools(
678 &self,
679 messages: &[ChatMessage],
680 tools: &[serde_json::Value],
681 model: &str,
682 temperature: Option<f64>,
683 ) -> anyhow::Result<ChatResponse> {
684 self.as_ref()
685 .chat_with_tools(messages, tools, model, temperature)
686 .await
687 }
688
689 fn supports_streaming(&self) -> bool {
690 self.as_ref().supports_streaming()
691 }
692
693 fn supports_streaming_tool_events(&self) -> bool {
694 self.as_ref().supports_streaming_tool_events()
695 }
696
697 fn stream_chat_with_system(
698 &self,
699 system_prompt: Option<&str>,
700 message: &str,
701 model: &str,
702 temperature: Option<f64>,
703 options: StreamOptions,
704 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
705 self.as_ref()
706 .stream_chat_with_system(system_prompt, message, model, temperature, options)
707 }
708
709 fn stream_chat_with_history(
710 &self,
711 messages: &[ChatMessage],
712 model: &str,
713 temperature: Option<f64>,
714 options: StreamOptions,
715 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
716 self.as_ref()
717 .stream_chat_with_history(messages, model, temperature, options)
718 }
719
720 fn stream_chat(
721 &self,
722 request: ChatRequest<'_>,
723 model: &str,
724 temperature: Option<f64>,
725 options: StreamOptions,
726 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
727 self.as_ref()
728 .stream_chat(request, model, temperature, options)
729 }
730}
731
732pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
734 let mut instructions = String::new();
735
736 instructions.push_str("## Tool Use Protocol\n\n");
737 instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
738 instructions.push_str("<tool_call>\n");
739 instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#);
740 instructions.push_str("\n</tool_call>\n\n");
741 instructions.push_str("You may use multiple tool calls in a single response. ");
742 instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
743 instructions
744 .push_str("Continue reasoning with the results until you can give a final answer.\n\n");
745 instructions.push_str("### Available Tools\n\n");
746
747 for tool in tools {
748 writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description)
749 .expect("writing to String cannot fail");
750
751 let parameters =
752 serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string());
753 writeln!(&mut instructions, "Parameters: `{parameters}`")
754 .expect("writing to String cannot fail");
755 instructions.push('\n');
756 }
757
758 instructions
759}