1use super::history::canonicalize_tool_result_media_markers;
2use crate::tools::{Tool, ToolSpec};
3use serde_json::Value;
4use std::fmt::Write;
5use zeroclaw_providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage};
6
7#[derive(Debug, Clone)]
8pub struct ParsedToolCall {
9 pub name: String,
10 pub arguments: Value,
11 pub tool_call_id: Option<String>,
12}
13
/// Outcome of running one tool, ready to be formatted back to the model.
#[derive(Clone, Debug)]
pub struct ToolExecutionResult {
    /// Name of the tool that was executed.
    pub name: String,
    /// Raw textual output produced by the tool.
    pub output: String,
    /// Whether the tool run succeeded; the XML dispatcher renders this as an
    /// `ok` / `error` status attribute.
    pub success: bool,
    /// Identifier of the originating call, when one exists. The native
    /// dispatcher substitutes `"unknown"` when this is `None`.
    pub tool_call_id: Option<String>,
}
21
22pub trait ToolDispatcher: Send + Sync {
23 fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>);
24 fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage;
25 fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String;
26 fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage>;
27 fn should_send_tool_specs(&self) -> bool;
28}
29
/// Dispatcher that reads tool calls from `<tool_call>` tags embedded in the
/// response text and reports results as `<tool_result>` tags in a user
/// message.
#[derive(Default)]
pub struct XmlToolDispatcher;
32
33impl XmlToolDispatcher {
34 fn parse_xml_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
35 let cleaned = Self::strip_think_tags(response);
38 let mut text_parts = Vec::new();
39 let mut calls = Vec::new();
40 let mut remaining = cleaned.as_str();
41
42 while let Some(start) = remaining.find("<tool_call>") {
43 let before = &remaining[..start];
44 if !before.trim().is_empty() {
45 text_parts.push(before.trim().to_string());
46 }
47
48 if let Some(end) = remaining[start..].find("</tool_call>") {
49 let inner = &remaining[start + 11..start + end];
50 match serde_json::from_str::<Value>(inner.trim()) {
51 Ok(parsed) => {
52 let name = parsed
53 .get("name")
54 .and_then(Value::as_str)
55 .unwrap_or("")
56 .to_string();
57 if name.is_empty() {
58 remaining = &remaining[start + end + 12..];
59 continue;
60 }
61 let arguments = parsed
62 .get("arguments")
63 .cloned()
64 .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
65 calls.push(ParsedToolCall {
66 name,
67 arguments,
68 tool_call_id: None,
69 });
70 }
71 Err(e) => {
72 ::zeroclaw_log::record!(
73 WARN,
74 ::zeroclaw_log::Event::new(
75 module_path!(),
76 ::zeroclaw_log::Action::Note
77 )
78 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
79 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
80 "Malformed <tool_call> JSON"
81 );
82 }
83 }
84 remaining = &remaining[start + end + 12..];
85 } else {
86 break;
87 }
88 }
89
90 if !remaining.trim().is_empty() {
91 text_parts.push(remaining.trim().to_string());
92 }
93
94 (text_parts.join("\n"), calls)
95 }
96
97 fn strip_think_tags(s: &str) -> String {
99 let mut result = String::with_capacity(s.len());
100 let mut rest = s;
101 loop {
102 if let Some(start) = rest.find("<think>") {
103 result.push_str(&rest[..start]);
104 if let Some(end) = rest[start..].find("</think>") {
105 rest = &rest[start + end + "</think>".len()..];
106 } else {
107 break;
108 }
109 } else {
110 result.push_str(rest);
111 break;
112 }
113 }
114 result
115 }
116
117 pub fn tool_specs(tools: &[Box<dyn Tool>]) -> Vec<ToolSpec> {
118 tools.iter().map(|tool| tool.spec()).collect()
119 }
120}
121
122impl ToolDispatcher for XmlToolDispatcher {
123 fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
124 let text = response.text_or_empty();
125 Self::parse_xml_tool_calls(text)
126 }
127
128 fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
129 let mut content = String::new();
130 for result in results {
131 let status = if result.success { "ok" } else { "error" };
132 let output = canonicalize_tool_result_media_markers(&result.output);
133 let _ = writeln!(
134 content,
135 "<tool_result name=\"{}\" status=\"{}\">\n{}\n</tool_result>",
136 result.name, status, output
137 );
138 }
139 ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}")))
140 }
141
142 fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String {
143 if tools.is_empty() {
144 return String::new();
145 }
146
147 let mut instructions = String::new();
148 instructions.push_str("## Tool Use Protocol\n\n");
149 instructions
150 .push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
151 instructions.push_str(
152 "```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n",
153 );
154
155 instructions
156 }
157
158 fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
159 history
160 .iter()
161 .flat_map(|msg| match msg {
162 ConversationMessage::Chat(chat) => vec![chat.clone()],
163 ConversationMessage::AssistantToolCalls { text, .. } => {
164 vec![ChatMessage::assistant(text.clone().unwrap_or_default())]
165 }
166 ConversationMessage::ToolResults(results) => {
167 let mut content = String::new();
168 for result in results {
169 let output = canonicalize_tool_result_media_markers(&result.content);
170 let _ = writeln!(
171 content,
172 "<tool_result id=\"{}\">\n{}\n</tool_result>",
173 result.tool_call_id, output
174 );
175 }
176 vec![ChatMessage::user(format!("[Tool results]\n{content}"))]
177 }
178 })
179 .collect()
180 }
181
182 fn should_send_tool_specs(&self) -> bool {
183 false
184 }
185}
186
/// Dispatcher that relies on the provider's structured tool-call fields
/// rather than tags embedded in the response text.
///
/// Derives `Default` so it can be constructed the same way as
/// `XmlToolDispatcher`.
#[derive(Default)]
pub struct NativeToolDispatcher;
188
189impl ToolDispatcher for NativeToolDispatcher {
190 fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
191 let text = response.text.clone().unwrap_or_default();
192 let calls = response
193 .tool_calls
194 .iter()
195 .map(|tc| ParsedToolCall {
196 name: tc.name.clone(),
197 arguments: serde_json::from_str(&tc.arguments).unwrap_or_else(|e| {
198 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"tool": tc.name, "error": format!("{}", e)})), "Failed to parse native tool call arguments as JSON; defaulting to empty object");
199 Value::Object(serde_json::Map::new())
200 }),
201 tool_call_id: Some(tc.id.clone()),
202 })
203 .collect();
204 (text, calls)
205 }
206
207 fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
208 let messages = results
209 .iter()
210 .map(|result| ToolResultMessage {
211 tool_call_id: result
212 .tool_call_id
213 .clone()
214 .unwrap_or_else(|| "unknown".to_string()),
215 content: canonicalize_tool_result_media_markers(&result.output),
216 })
217 .collect();
218 ConversationMessage::ToolResults(messages)
219 }
220
221 fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
222 String::new()
223 }
224
225 fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
226 history
227 .iter()
228 .flat_map(|msg| match msg {
229 ConversationMessage::Chat(chat) => vec![chat.clone()],
230 ConversationMessage::AssistantToolCalls {
231 text,
232 tool_calls,
233 reasoning_content,
234 } => {
235 let mut payload = serde_json::json!({
236 "content": text,
237 "tool_calls": tool_calls,
238 });
239 if let Some(rc) = reasoning_content {
240 payload["reasoning_content"] = serde_json::json!(rc);
241 }
242 vec![ChatMessage::assistant(payload.to_string())]
243 }
244 ConversationMessage::ToolResults(results) => results
245 .iter()
246 .map(|result| {
247 ChatMessage::tool(
248 serde_json::json!({
249 "tool_call_id": result.tool_call_id,
250 "content": canonicalize_tool_result_media_markers(&result.content),
251 })
252 .to_string(),
253 )
254 })
255 .collect(),
256 })
257 .collect()
258 }
259
260 fn should_send_tool_specs(&self) -> bool {
261 true
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response carrying only `text`, with no structured tool calls.
    fn text_response(text: &str) -> ChatResponse {
        ChatResponse {
            text: Some(text.into()),
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        }
    }

    /// Builds a structured tool call with no extra content.
    fn tool_call(id: &str, name: &str, arguments: &str) -> zeroclaw_providers::ToolCall {
        zeroclaw_providers::ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: arguments.into(),
            extra_content: None,
        }
    }

    /// Builds a successful execution result.
    fn successful_result(
        name: &str,
        output: &str,
        tool_call_id: Option<&str>,
    ) -> ToolExecutionResult {
        ToolExecutionResult {
            name: name.into(),
            output: output.into(),
            success: true,
            tool_call_id: tool_call_id.map(Into::into),
        }
    }

    /// Builds a one-entry history holding an assistant tool-call turn.
    fn assistant_turn(reasoning: Option<&str>) -> Vec<ConversationMessage> {
        vec![ConversationMessage::AssistantToolCalls {
            text: Some("answer".into()),
            tool_calls: vec![tool_call("tc_1", "shell", "{}")],
            reasoning_content: reasoning.map(Into::into),
        }]
    }

    #[test]
    fn xml_dispatcher_parses_tool_calls() {
        let response = text_response(
            "Checking\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>",
        );
        let (_, calls) = XmlToolDispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "shell");
    }

    #[test]
    fn xml_dispatcher_strips_think_before_tool_call() {
        let response = text_response(
            "<think>I should list files</think>\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>",
        );
        let (text, calls) = XmlToolDispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "shell");
        assert!(
            !text.contains("<think>"),
            "think tags should be stripped from text"
        );
    }

    #[test]
    fn xml_dispatcher_think_only_returns_no_calls() {
        let response = text_response("<think>Just thinking</think>");
        let (_, calls) = XmlToolDispatcher.parse_response(&response);
        assert!(calls.is_empty());
    }

    #[test]
    fn native_dispatcher_roundtrip() {
        let response = ChatResponse {
            tool_calls: vec![tool_call("tc1", "file_read", "{\"path\":\"a.txt\"}")],
            ..text_response("ok")
        };
        let dispatcher = NativeToolDispatcher;

        let (_, calls) = dispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1"));

        let msg = dispatcher.format_results(&[successful_result("file_read", "hello", Some("tc1"))]);
        match msg {
            ConversationMessage::ToolResults(results) => {
                assert_eq!(results.len(), 1);
                assert_eq!(results[0].tool_call_id, "tc1");
            }
            _ => panic!("expected tool results"),
        }
    }

    #[test]
    fn xml_format_results_contains_tool_result_tags() {
        let msg = XmlToolDispatcher.format_results(&[successful_result("shell", "ok", None)]);
        let rendered = match msg {
            ConversationMessage::Chat(chat) => chat.content,
            _ => String::new(),
        };
        assert!(rendered.contains("<tool_result"));
        assert!(rendered.contains("shell"));
    }

    #[test]
    fn native_format_results_keeps_tool_call_id() {
        let msg = NativeToolDispatcher.format_results(&[successful_result("shell", "ok", Some("tc-1"))]);

        match msg {
            ConversationMessage::ToolResults(results) => {
                assert_eq!(results.len(), 1);
                assert_eq!(results[0].tool_call_id, "tc-1");
            }
            _ => panic!("expected ToolResults variant"),
        }
    }

    #[test]
    fn native_to_provider_messages_includes_reasoning_content() {
        let history = assistant_turn(Some("thinking step"));

        let messages = NativeToolDispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");

        let payload: serde_json::Value = serde_json::from_str(&messages[0].content).unwrap();
        assert_eq!(payload["reasoning_content"].as_str(), Some("thinking step"));
        assert_eq!(payload["content"].as_str(), Some("answer"));
        assert!(payload["tool_calls"].is_array());
    }

    #[test]
    fn native_to_provider_messages_omits_reasoning_content_when_none() {
        let history = assistant_turn(None);

        let messages = NativeToolDispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);

        let payload: serde_json::Value = serde_json::from_str(&messages[0].content).unwrap();
        assert!(payload.get("reasoning_content").is_none());
    }

    #[test]
    fn xml_to_provider_messages_ignores_reasoning_content() {
        let history = assistant_turn(Some("should be ignored"));

        let messages = XmlToolDispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");
        assert_eq!(messages[0].content, "answer");
        assert!(!messages[0].content.contains("reasoning_content"));
    }
}