Skip to main content

zeroclaw_tools/
mcp_client.rs

1//! MCP (Model Context Protocol) client — connects to external tool servers.
2//!
3//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7#[cfg(not(target_has_atomic = "64"))]
8use std::sync::atomic::AtomicU32;
9#[cfg(target_has_atomic = "64")]
10use std::sync::atomic::AtomicU64;
11use std::sync::atomic::Ordering;
12
13use anyhow::{Context, Result, bail};
14use serde_json::json;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, timeout};
17
18use crate::mcp_protocol::{JsonRpcRequest, MCP_PROTOCOL_VERSION, McpToolDef, McpToolsListResult};
19use crate::mcp_transport::{McpTransportConn, create_transport};
20use zeroclaw_config::schema::McpServerConfig;
21
/// Upper bound, in seconds, on the wait for the `initialize` and `tools/list`
/// responses during `McpServer::connect`, so an unresponsive server cannot
/// stall the daemon forever.
const RECV_TIMEOUT_SECS: u64 = 30;

/// Tool-call timeout, in seconds, used when a server's config leaves
/// `tool_timeout_secs` unset (three minutes).
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 3 * 60;

/// Hard ceiling, in seconds, applied to every tool-call timeout, configured
/// values included (ten minutes).
const MAX_TOOL_TIMEOUT_SECS: u64 = 10 * 60;
31
32// ── Internal server state ──────────────────────────────────────────────────
33
34struct McpServerInner {
35    config: McpServerConfig,
36    transport: Box<dyn McpTransportConn>,
37    #[cfg(target_has_atomic = "64")]
38    next_id: AtomicU64,
39    #[cfg(not(target_has_atomic = "64"))]
40    next_id: AtomicU32,
41    tools: Vec<McpToolDef>,
42}
43
44// ── McpServer ──────────────────────────────────────────────────────────────
45
46/// A live connection to one MCP server (any transport).
47#[derive(Clone)]
48pub struct McpServer {
49    inner: Arc<Mutex<McpServerInner>>,
50}
51
52impl McpServer {
53    /// Connect to the server, perform the initialize handshake, and fetch the tool list.
54    pub async fn connect(config: McpServerConfig) -> Result<Self> {
55        // Create transport based on config
56        let mut transport = create_transport(&config).with_context(|| {
57            format!(
58                "failed to create transport for MCP server `{}`",
59                config.name
60            )
61        })?;
62
63        // Initialize handshake
64        let id = 1u64;
65        let init_req = JsonRpcRequest::new(
66            id,
67            "initialize",
68            json!({
69                "protocolVersion": MCP_PROTOCOL_VERSION,
70                "capabilities": {},
71                "clientInfo": {
72                    "name": "zeroclaw",
73                    "version": env!("CARGO_PKG_VERSION")
74                }
75            }),
76        );
77
78        let init_resp = timeout(
79            Duration::from_secs(RECV_TIMEOUT_SECS),
80            transport.send_and_recv(&init_req),
81        )
82        .await
83        .with_context(|| {
84            format!(
85                "MCP server `{}` timed out after {}s waiting for initialize response",
86                config.name, RECV_TIMEOUT_SECS
87            )
88        })??;
89
90        if init_resp.error.is_some() {
91            bail!(
92                "MCP server `{}` rejected initialize: {:?}",
93                config.name,
94                init_resp.error
95            );
96        }
97
98        // Notify server that client is initialized (no response expected for notifications)
99        // For notifications, we send but don't wait for response
100        let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
101        // Best effort - ignore errors for notifications
102        let _ = transport.send_and_recv(&notif).await;
103
104        // Fetch available tools
105        let id = 2u64;
106        let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
107
108        let list_resp = timeout(
109            Duration::from_secs(RECV_TIMEOUT_SECS),
110            transport.send_and_recv(&list_req),
111        )
112        .await
113        .with_context(|| {
114            format!(
115                "MCP server `{}` timed out after {}s waiting for tools/list response",
116                config.name, RECV_TIMEOUT_SECS
117            )
118        })??;
119
120        let result = list_resp.result.ok_or_else(|| {
121            ::zeroclaw_log::record!(
122                ERROR,
123                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
124                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
125                    .with_attrs(::serde_json::json!({"mcp_server": &config.name})),
126                "mcp_client: tools/list returned no result"
127            );
128            anyhow::Error::msg(format!(
129                "tools/list returned no result from `{}`",
130                config.name
131            ))
132        })?;
133        let tool_list: McpToolsListResult = serde_json::from_value(result)
134            .with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
135
136        let tool_count = tool_list.tools.len();
137
138        let inner = McpServerInner {
139            config,
140            transport,
141            #[cfg(target_has_atomic = "64")]
142            next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2
143            #[cfg(not(target_has_atomic = "64"))]
144            next_id: AtomicU32::new(3), // Start at 3 since we used 1 and 2
145            tools: tool_list.tools,
146        };
147
148        ::zeroclaw_log::record!(
149            INFO,
150            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
151            &format!(
152                "MCP server `{}` connected — {} tool(s) available",
153                inner.config.name, tool_count
154            )
155        );
156
157        Ok(Self {
158            inner: Arc::new(Mutex::new(inner)),
159        })
160    }
161
162    /// Tools advertised by this server.
163    pub async fn tools(&self) -> Vec<McpToolDef> {
164        self.inner.lock().await.tools.clone()
165    }
166
167    /// Server display name.
168    pub async fn name(&self) -> String {
169        self.inner.lock().await.config.name.clone()
170    }
171
172    /// Call a tool on this server. Returns the raw JSON result.
173    pub async fn call_tool(
174        &self,
175        tool_name: &str,
176        arguments: serde_json::Value,
177    ) -> Result<serde_json::Value> {
178        let mut inner = self.inner.lock().await;
179        let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
180        let req = JsonRpcRequest::new(
181            id,
182            "tools/call",
183            json!({ "name": tool_name, "arguments": arguments }),
184        );
185
186        // Use per-server tool timeout if configured, otherwise default.
187        // Cap at MAX_TOOL_TIMEOUT_SECS for safety.
188        let tool_timeout = inner
189            .config
190            .tool_timeout_secs
191            .unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
192            .min(MAX_TOOL_TIMEOUT_SECS);
193
194        let resp = timeout(
195            Duration::from_secs(tool_timeout),
196            inner.transport.send_and_recv(&req),
197        )
198        .await
199        .map_err(|_| {
200            ::zeroclaw_log::record!(
201                WARN,
202                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Timeout)
203                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
204                    .with_attrs(::serde_json::json!({
205                        "mcp_server": &inner.config.name,
206                        "tool": tool_name,
207                        "timeout_secs": tool_timeout,
208                    })),
209                "mcp_client: tool call timed out"
210            );
211            anyhow::Error::msg(format!(
212                "MCP server `{}` timed out after {}s during tool call `{tool_name}`",
213                inner.config.name, tool_timeout
214            ))
215        })?
216        .with_context(|| {
217            format!(
218                "MCP server `{}` error during tool call `{tool_name}`",
219                inner.config.name
220            )
221        })?;
222
223        if let Some(err) = resp.error {
224            bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
225        }
226        Ok(resp.result.unwrap_or(serde_json::Value::Null))
227    }
228}
229
230// ── McpRegistry ───────────────────────────────────────────────────────────
231
232/// Registry of all connected MCP servers, with a flat tool index.
233pub struct McpRegistry {
234    servers: Vec<McpServer>,
235    /// prefixed_name → (server_index, original_tool_name)
236    tool_index: HashMap<String, (usize, String)>,
237}
238
239impl McpRegistry {
240    /// Connect to all configured servers. Non-fatal: failures are logged and skipped.
241    pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
242        let mut servers = Vec::new();
243        let mut tool_index = HashMap::new();
244
245        for config in configs {
246            match McpServer::connect(config.clone()).await {
247                Ok(server) => {
248                    let server_idx = servers.len();
249                    // Collect tools while holding the lock once, then release
250                    let tools = server.tools().await;
251                    for tool in &tools {
252                        // Prefix prevents name collisions across servers
253                        let prefixed = format!("{}__{}", config.name, tool.name);
254                        tool_index.insert(prefixed, (server_idx, tool.name.clone()));
255                    }
256                    servers.push(server);
257                }
258                // Non-fatal — log and continue with remaining servers
259                Err(e) => {
260                    ::zeroclaw_log::record!(
261                        ERROR,
262                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
263                            .with_outcome(::zeroclaw_log::EventOutcome::Failure),
264                        &format!("Failed to connect to MCP server `{}`: {:#}", config.name, e)
265                    );
266                }
267            }
268        }
269
270        Ok(Self {
271            servers,
272            tool_index,
273        })
274    }
275
276    /// All prefixed tool names across all connected servers.
277    pub fn tool_names(&self) -> Vec<String> {
278        self.tool_index.keys().cloned().collect()
279    }
280
281    /// Tool definition for a given prefixed name (cloned).
282    pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
283        let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
284        let inner = self.servers[*server_idx].inner.lock().await;
285        inner
286            .tools
287            .iter()
288            .find(|t| &t.name == original_name)
289            .cloned()
290    }
291
292    /// Execute a tool by prefixed name.
293    pub async fn call_tool(
294        &self,
295        prefixed_name: &str,
296        arguments: serde_json::Value,
297    ) -> Result<String> {
298        let (server_idx, original_name) = self.tool_index.get(prefixed_name).ok_or_else(|| {
299            ::zeroclaw_log::record!(
300                WARN,
301                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
302                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
303                    .with_attrs(::serde_json::json!({"tool": prefixed_name})),
304                "mcp_client: unknown MCP tool"
305            );
306            anyhow::Error::msg(format!("unknown MCP tool `{prefixed_name}`"))
307        })?;
308        let result = self.servers[*server_idx]
309            .call_tool(original_name, arguments)
310            .await?;
311        serde_json::to_string_pretty(&result)
312            .with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
313    }
314
315    pub fn is_empty(&self) -> bool {
316        self.servers.is_empty()
317    }
318
319    pub fn server_count(&self) -> usize {
320        self.servers.len()
321    }
322
323    pub fn tool_count(&self) -> usize {
324        self.tool_index.len()
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::schema::McpTransport;

    /// Stdio server config whose `command` points at a binary that is not
    /// expected to exist, so connecting should fail before any protocol traffic.
    fn stdio_config_for_missing_binary(name: &str, command: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            command: command.to_string(),
            args: vec![],
            env: std::collections::HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: std::collections::HashMap::default(),
        }
    }

    /// Registry built from an empty config slice; panics with `msg` on failure.
    async fn registry_with_no_servers(msg: &str) -> McpRegistry {
        McpRegistry::connect_all(&[]).await.expect(msg)
    }

    /// Asserts that a URL-based transport without a URL is rejected up front.
    fn assert_transport_without_url_fails(transport: McpTransport) {
        let config = McpServerConfig {
            name: "test".into(),
            transport,
            ..Default::default()
        };
        assert!(create_transport(&config).is_err());
    }

    #[test]
    fn tool_name_prefix_format() {
        let server = "filesystem";
        let tool = "read_file";
        let prefixed = format!("{server}__{tool}");
        assert_eq!(prefixed, "filesystem__read_file");
    }

    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        // Spawning a missing binary must come back as an `Err`, never a panic.
        let config = stdio_config_for_missing_binary(
            "nonexistent",
            "/usr/bin/this_binary_does_not_exist_zeroclaw_test",
        );
        let result = McpServer::connect(config).await;
        assert!(result.is_err());
        let msg = result.err().unwrap().to_string();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        // One bad config must not fail connect_all; it simply yields no servers.
        let configs = [stdio_config_for_missing_binary(
            "bad",
            "/usr/bin/does_not_exist_zc_test",
        )];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }

    #[test]
    fn http_transport_requires_url() {
        assert_transport_without_url_fails(McpTransport::Http);
    }

    #[test]
    fn sse_transport_requires_url() {
        assert_transport_without_url_fails(McpTransport::Sse);
    }

    // ── Registry with no servers ───────────────────────────────────────────

    #[tokio::test]
    async fn empty_registry_is_empty() {
        let registry = registry_with_no_servers("connect_all on empty slice should succeed").await;
        assert!(registry.is_empty());
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
    }

    #[tokio::test]
    async fn empty_registry_tool_names_is_empty() {
        let registry = registry_with_no_servers("connect_all should succeed").await;
        assert!(registry.tool_names().is_empty());
    }

    #[tokio::test]
    async fn empty_registry_get_tool_def_returns_none() {
        let registry = registry_with_no_servers("connect_all should succeed").await;
        let result = registry.get_tool_def("nonexistent__tool").await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn empty_registry_call_tool_unknown_name_returns_error() {
        let registry = registry_with_no_servers("connect_all should succeed").await;
        let err = registry
            .call_tool("nonexistent__tool", serde_json::json!({}))
            .await
            .expect_err("should fail for unknown tool");
        assert!(err.to_string().contains("unknown MCP tool"), "got: {err}");
    }

    #[tokio::test]
    async fn connect_all_empty_gives_zero_servers() {
        let registry = registry_with_no_servers("connect_all should succeed").await;
        // The three count methods must all agree on zero.
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
        assert!(registry.is_empty());
    }
}