Skip to main content

zeroclaw_runtime/agent/
dispatcher.rs

1use super::history::canonicalize_tool_result_media_markers;
2use crate::tools::{Tool, ToolSpec};
3use serde_json::Value;
4use std::fmt::Write;
5use zeroclaw_providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage};
6
7#[derive(Debug, Clone)]
8pub struct ParsedToolCall {
9    pub name: String,
10    pub arguments: Value,
11    pub tool_call_id: Option<String>,
12}
13
/// Outcome of running one tool, ready to be rendered back to the model.
///
/// Consumed by [`ToolDispatcher::format_results`].
#[derive(Debug, Clone)]
pub struct ToolExecutionResult {
    /// Name of the tool that was run.
    pub name: String,
    /// Raw textual output of the tool.
    pub output: String,
    /// `true` when the run succeeded; the XML dispatcher renders this as
    /// `status="ok"` / `status="error"`.
    pub success: bool,
    /// Identifier of the originating call, if there was one. The native dispatcher
    /// falls back to `"unknown"` when this is `None`.
    pub tool_call_id: Option<String>,
}
21
22pub trait ToolDispatcher: Send + Sync {
23    fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>);
24    fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage;
25    fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String;
26    fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage>;
27    fn should_send_tool_specs(&self) -> bool;
28}
29
/// Dispatcher for text-only tool calling: tool calls travel as
/// `<tool_call>{…}</tool_call>` blocks inside the assistant text and results are
/// returned as `<tool_result>` blocks in a user message.
#[derive(Default)]
pub struct XmlToolDispatcher;
32
33impl XmlToolDispatcher {
34    fn parse_xml_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
35        // Strip `<think>...</think>` blocks before parsing tool calls.
36        // Qwen and other reasoning models may embed chain-of-thought inline.
37        let cleaned = Self::strip_think_tags(response);
38        let mut text_parts = Vec::new();
39        let mut calls = Vec::new();
40        let mut remaining = cleaned.as_str();
41
42        while let Some(start) = remaining.find("<tool_call>") {
43            let before = &remaining[..start];
44            if !before.trim().is_empty() {
45                text_parts.push(before.trim().to_string());
46            }
47
48            if let Some(end) = remaining[start..].find("</tool_call>") {
49                let inner = &remaining[start + 11..start + end];
50                match serde_json::from_str::<Value>(inner.trim()) {
51                    Ok(parsed) => {
52                        let name = parsed
53                            .get("name")
54                            .and_then(Value::as_str)
55                            .unwrap_or("")
56                            .to_string();
57                        if name.is_empty() {
58                            remaining = &remaining[start + end + 12..];
59                            continue;
60                        }
61                        let arguments = parsed
62                            .get("arguments")
63                            .cloned()
64                            .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
65                        calls.push(ParsedToolCall {
66                            name,
67                            arguments,
68                            tool_call_id: None,
69                        });
70                    }
71                    Err(e) => {
72                        ::zeroclaw_log::record!(
73                            WARN,
74                            ::zeroclaw_log::Event::new(
75                                module_path!(),
76                                ::zeroclaw_log::Action::Note
77                            )
78                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
79                            .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
80                            "Malformed <tool_call> JSON"
81                        );
82                    }
83                }
84                remaining = &remaining[start + end + 12..];
85            } else {
86                break;
87            }
88        }
89
90        if !remaining.trim().is_empty() {
91            text_parts.push(remaining.trim().to_string());
92        }
93
94        (text_parts.join("\n"), calls)
95    }
96
97    /// Remove `<think>...</think>` blocks from model output.
98    fn strip_think_tags(s: &str) -> String {
99        let mut result = String::with_capacity(s.len());
100        let mut rest = s;
101        loop {
102            if let Some(start) = rest.find("<think>") {
103                result.push_str(&rest[..start]);
104                if let Some(end) = rest[start..].find("</think>") {
105                    rest = &rest[start + end + "</think>".len()..];
106                } else {
107                    break;
108                }
109            } else {
110                result.push_str(rest);
111                break;
112            }
113        }
114        result
115    }
116
117    pub fn tool_specs(tools: &[Box<dyn Tool>]) -> Vec<ToolSpec> {
118        tools.iter().map(|tool| tool.spec()).collect()
119    }
120}
121
122impl ToolDispatcher for XmlToolDispatcher {
123    fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
124        let text = response.text_or_empty();
125        Self::parse_xml_tool_calls(text)
126    }
127
128    fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
129        let mut content = String::new();
130        for result in results {
131            let status = if result.success { "ok" } else { "error" };
132            let output = canonicalize_tool_result_media_markers(&result.output);
133            let _ = writeln!(
134                content,
135                "<tool_result name=\"{}\" status=\"{}\">\n{}\n</tool_result>",
136                result.name, status, output
137            );
138        }
139        ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}")))
140    }
141
142    fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String {
143        if tools.is_empty() {
144            return String::new();
145        }
146
147        let mut instructions = String::new();
148        instructions.push_str("## Tool Use Protocol\n\n");
149        instructions
150            .push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
151        instructions.push_str(
152            "```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n",
153        );
154
155        instructions
156    }
157
158    fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
159        history
160            .iter()
161            .flat_map(|msg| match msg {
162                ConversationMessage::Chat(chat) => vec![chat.clone()],
163                ConversationMessage::AssistantToolCalls { text, .. } => {
164                    vec![ChatMessage::assistant(text.clone().unwrap_or_default())]
165                }
166                ConversationMessage::ToolResults(results) => {
167                    let mut content = String::new();
168                    for result in results {
169                        let output = canonicalize_tool_result_media_markers(&result.content);
170                        let _ = writeln!(
171                            content,
172                            "<tool_result id=\"{}\">\n{}\n</tool_result>",
173                            result.tool_call_id, output
174                        );
175                    }
176                    vec![ChatMessage::user(format!("[Tool results]\n{content}"))]
177                }
178            })
179            .collect()
180    }
181
182    fn should_send_tool_specs(&self) -> bool {
183        false
184    }
185}
186
/// Dispatcher for native tool calling: tool calls are read from the structured
/// `tool_calls` on the response and results are returned as tool-result messages.
///
/// Derives `Default` so it can be constructed the same way as `XmlToolDispatcher`.
#[derive(Default)]
pub struct NativeToolDispatcher;
188
189impl ToolDispatcher for NativeToolDispatcher {
190    fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
191        let text = response.text.clone().unwrap_or_default();
192        let calls = response
193            .tool_calls
194            .iter()
195            .map(|tc| ParsedToolCall {
196                name: tc.name.clone(),
197                arguments: serde_json::from_str(&tc.arguments).unwrap_or_else(|e| {
198                    ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"tool": tc.name, "error": format!("{}", e)})), "Failed to parse native tool call arguments as JSON; defaulting to empty object");
199                    Value::Object(serde_json::Map::new())
200                }),
201                tool_call_id: Some(tc.id.clone()),
202            })
203            .collect();
204        (text, calls)
205    }
206
207    fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
208        let messages = results
209            .iter()
210            .map(|result| ToolResultMessage {
211                tool_call_id: result
212                    .tool_call_id
213                    .clone()
214                    .unwrap_or_else(|| "unknown".to_string()),
215                content: canonicalize_tool_result_media_markers(&result.output),
216            })
217            .collect();
218        ConversationMessage::ToolResults(messages)
219    }
220
221    fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
222        String::new()
223    }
224
225    fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
226        history
227            .iter()
228            .flat_map(|msg| match msg {
229                ConversationMessage::Chat(chat) => vec![chat.clone()],
230                ConversationMessage::AssistantToolCalls {
231                    text,
232                    tool_calls,
233                    reasoning_content,
234                } => {
235                    let mut payload = serde_json::json!({
236                        "content": text,
237                        "tool_calls": tool_calls,
238                    });
239                    if let Some(rc) = reasoning_content {
240                        payload["reasoning_content"] = serde_json::json!(rc);
241                    }
242                    vec![ChatMessage::assistant(payload.to_string())]
243                }
244                ConversationMessage::ToolResults(results) => results
245                    .iter()
246                    .map(|result| {
247                        ChatMessage::tool(
248                            serde_json::json!({
249                                "tool_call_id": result.tool_call_id,
250                                "content": canonicalize_tool_result_media_markers(&result.content),
251                            })
252                            .to_string(),
253                        )
254                    })
255                    .collect(),
256            })
257            .collect()
258    }
259
260    fn should_send_tool_specs(&self) -> bool {
261        true
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response that carries only text (no native tool calls).
    fn text_response(text: &str) -> ChatResponse {
        ChatResponse {
            text: Some(text.into()),
            tool_calls: vec![],
            usage: None,
            reasoning_content: None,
        }
    }

    /// Builds a provider tool call with no extra content.
    fn tool_call(id: &str, name: &str, arguments: &str) -> zeroclaw_providers::ToolCall {
        zeroclaw_providers::ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: arguments.into(),
            extra_content: None,
        }
    }

    /// Builds a successful tool execution result.
    fn ok_result(name: &str, output: &str, tool_call_id: Option<&str>) -> ToolExecutionResult {
        ToolExecutionResult {
            name: name.into(),
            output: output.into(),
            success: true,
            tool_call_id: tool_call_id.map(|id| id.into()),
        }
    }

    /// History holding one assistant entry with text "answer" and a single `shell` call.
    fn shell_call_history(reasoning: Option<&str>) -> Vec<ConversationMessage> {
        vec![ConversationMessage::AssistantToolCalls {
            text: Some("answer".into()),
            tool_calls: vec![tool_call("tc_1", "shell", "{}")],
            reasoning_content: reasoning.map(|r| r.into()),
        }]
    }

    #[test]
    fn xml_dispatcher_parses_tool_calls() {
        let response = text_response(
            "Checking\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>",
        );
        let (_, calls) = XmlToolDispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "shell");
    }

    #[test]
    fn xml_dispatcher_strips_think_before_tool_call() {
        let response = text_response(
            "<think>I should list files</think>\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>",
        );
        let (text, calls) = XmlToolDispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "shell");
        assert!(
            !text.contains("<think>"),
            "think tags should be stripped from text"
        );
    }

    #[test]
    fn xml_dispatcher_think_only_returns_no_calls() {
        let response = text_response("<think>Just thinking</think>");
        let (_, calls) = XmlToolDispatcher.parse_response(&response);
        assert!(calls.is_empty());
    }

    #[test]
    fn native_dispatcher_roundtrip() {
        let response = ChatResponse {
            text: Some("ok".into()),
            tool_calls: vec![tool_call("tc1", "file_read", "{\"path\":\"a.txt\"}")],
            usage: None,
            reasoning_content: None,
        };
        let dispatcher = NativeToolDispatcher;

        let (_, calls) = dispatcher.parse_response(&response);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1"));

        let msg = dispatcher.format_results(&[ok_result("file_read", "hello", Some("tc1"))]);
        if let ConversationMessage::ToolResults(results) = msg {
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].tool_call_id, "tc1");
        } else {
            panic!("expected tool results");
        }
    }

    #[test]
    fn xml_format_results_contains_tool_result_tags() {
        let msg = XmlToolDispatcher.format_results(&[ok_result("shell", "ok", None)]);
        let rendered = if let ConversationMessage::Chat(chat) = msg {
            chat.content
        } else {
            String::new()
        };
        assert!(rendered.contains("<tool_result"));
        assert!(rendered.contains("shell"));
    }

    #[test]
    fn native_format_results_keeps_tool_call_id() {
        let msg = NativeToolDispatcher.format_results(&[ok_result("shell", "ok", Some("tc-1"))]);

        if let ConversationMessage::ToolResults(results) = msg {
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].tool_call_id, "tc-1");
        } else {
            panic!("expected ToolResults variant");
        }
    }

    // ═══════════════════════════════════════════════════════════════════════
    // reasoning_content pass-through tests
    // ═══════════════════════════════════════════════════════════════════════

    #[test]
    fn native_to_provider_messages_includes_reasoning_content() {
        let history = shell_call_history(Some("thinking step"));

        let messages = NativeToolDispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");

        let payload: serde_json::Value = serde_json::from_str(&messages[0].content).unwrap();
        assert_eq!(payload["reasoning_content"].as_str(), Some("thinking step"));
        assert_eq!(payload["content"].as_str(), Some("answer"));
        assert!(payload["tool_calls"].is_array());
    }

    #[test]
    fn native_to_provider_messages_omits_reasoning_content_when_none() {
        let history = shell_call_history(None);

        let messages = NativeToolDispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);

        let payload: serde_json::Value = serde_json::from_str(&messages[0].content).unwrap();
        assert!(payload.get("reasoning_content").is_none());
    }

    #[test]
    fn xml_to_provider_messages_ignores_reasoning_content() {
        let history = shell_call_history(Some("should be ignored"));

        let messages = XmlToolDispatcher.to_provider_messages(&history);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "assistant");
        // The XML dispatcher emits the plain text only, never a JSON payload.
        assert_eq!(messages[0].content, "answer");
        assert!(!messages[0].content.contains("reasoning_content"));
    }
}