1use std::collections::HashMap;
6use std::sync::Arc;
7#[cfg(not(target_has_atomic = "64"))]
8use std::sync::atomic::AtomicU32;
9#[cfg(target_has_atomic = "64")]
10use std::sync::atomic::AtomicU64;
11use std::sync::atomic::Ordering;
12
13use anyhow::{Context, Result, bail};
14use serde_json::json;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, timeout};
17
18use crate::mcp_protocol::{JsonRpcRequest, MCP_PROTOCOL_VERSION, McpToolDef, McpToolsListResult};
19use crate::mcp_transport::{McpTransportConn, create_transport};
20use zeroclaw_config::schema::McpServerConfig;
21
/// Seconds to wait for a response to each handshake request (`initialize`
/// and `tools/list`) while connecting to a server.
const RECV_TIMEOUT_SECS: u64 = 30;

/// Tool-call timeout, in seconds, used when the server config does not set
/// `tool_timeout_secs`.
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180;

/// Upper bound, in seconds, applied to any tool-call timeout; larger
/// configured values are clamped down to this.
const MAX_TOOL_TIMEOUT_SECS: u64 = 600;
31
/// State of one connected MCP server, shared behind the mutex held by
/// [`McpServer`].
struct McpServerInner {
    /// Configuration this server was connected with (name, timeouts, ...).
    config: McpServerConfig,
    /// Transport connection used to send requests and receive responses.
    transport: Box<dyn McpTransportConn>,
    /// Counter handing out the `id` of the next JSON-RPC request.
    /// `connect` initialises it to 3 because ids 1 and 2 are used by the
    /// handshake. 64-bit atomic where the target supports it.
    #[cfg(target_has_atomic = "64")]
    next_id: AtomicU64,
    /// 32-bit fallback for targets without 64-bit atomics.
    #[cfg(not(target_has_atomic = "64"))]
    next_id: AtomicU32,
    /// Tool definitions parsed from the `tools/list` response at connect time.
    tools: Vec<McpToolDef>,
}
43
/// Handle to a connected MCP server.
///
/// Cloning only clones the `Arc`, so every clone refers to the same
/// underlying connection state; access to that state goes through an async
/// mutex.
#[derive(Clone)]
pub struct McpServer {
    /// Shared, mutex-guarded connection state.
    inner: Arc<Mutex<McpServerInner>>,
}
51
52impl McpServer {
53 pub async fn connect(config: McpServerConfig) -> Result<Self> {
55 let mut transport = create_transport(&config).with_context(|| {
57 format!(
58 "failed to create transport for MCP server `{}`",
59 config.name
60 )
61 })?;
62
63 let id = 1u64;
65 let init_req = JsonRpcRequest::new(
66 id,
67 "initialize",
68 json!({
69 "protocolVersion": MCP_PROTOCOL_VERSION,
70 "capabilities": {},
71 "clientInfo": {
72 "name": "zeroclaw",
73 "version": env!("CARGO_PKG_VERSION")
74 }
75 }),
76 );
77
78 let init_resp = timeout(
79 Duration::from_secs(RECV_TIMEOUT_SECS),
80 transport.send_and_recv(&init_req),
81 )
82 .await
83 .with_context(|| {
84 format!(
85 "MCP server `{}` timed out after {}s waiting for initialize response",
86 config.name, RECV_TIMEOUT_SECS
87 )
88 })??;
89
90 if init_resp.error.is_some() {
91 bail!(
92 "MCP server `{}` rejected initialize: {:?}",
93 config.name,
94 init_resp.error
95 );
96 }
97
98 let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
101 let _ = transport.send_and_recv(¬if).await;
103
104 let id = 2u64;
106 let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
107
108 let list_resp = timeout(
109 Duration::from_secs(RECV_TIMEOUT_SECS),
110 transport.send_and_recv(&list_req),
111 )
112 .await
113 .with_context(|| {
114 format!(
115 "MCP server `{}` timed out after {}s waiting for tools/list response",
116 config.name, RECV_TIMEOUT_SECS
117 )
118 })??;
119
120 let result = list_resp.result.ok_or_else(|| {
121 ::zeroclaw_log::record!(
122 ERROR,
123 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
124 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
125 .with_attrs(::serde_json::json!({"mcp_server": &config.name})),
126 "mcp_client: tools/list returned no result"
127 );
128 anyhow::Error::msg(format!(
129 "tools/list returned no result from `{}`",
130 config.name
131 ))
132 })?;
133 let tool_list: McpToolsListResult = serde_json::from_value(result)
134 .with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
135
136 let tool_count = tool_list.tools.len();
137
138 let inner = McpServerInner {
139 config,
140 transport,
141 #[cfg(target_has_atomic = "64")]
142 next_id: AtomicU64::new(3), #[cfg(not(target_has_atomic = "64"))]
144 next_id: AtomicU32::new(3), tools: tool_list.tools,
146 };
147
148 ::zeroclaw_log::record!(
149 INFO,
150 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
151 &format!(
152 "MCP server `{}` connected — {} tool(s) available",
153 inner.config.name, tool_count
154 )
155 );
156
157 Ok(Self {
158 inner: Arc::new(Mutex::new(inner)),
159 })
160 }
161
162 pub async fn tools(&self) -> Vec<McpToolDef> {
164 self.inner.lock().await.tools.clone()
165 }
166
167 pub async fn name(&self) -> String {
169 self.inner.lock().await.config.name.clone()
170 }
171
172 pub async fn call_tool(
174 &self,
175 tool_name: &str,
176 arguments: serde_json::Value,
177 ) -> Result<serde_json::Value> {
178 let mut inner = self.inner.lock().await;
179 let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
180 let req = JsonRpcRequest::new(
181 id,
182 "tools/call",
183 json!({ "name": tool_name, "arguments": arguments }),
184 );
185
186 let tool_timeout = inner
189 .config
190 .tool_timeout_secs
191 .unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
192 .min(MAX_TOOL_TIMEOUT_SECS);
193
194 let resp = timeout(
195 Duration::from_secs(tool_timeout),
196 inner.transport.send_and_recv(&req),
197 )
198 .await
199 .map_err(|_| {
200 ::zeroclaw_log::record!(
201 WARN,
202 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Timeout)
203 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
204 .with_attrs(::serde_json::json!({
205 "mcp_server": &inner.config.name,
206 "tool": tool_name,
207 "timeout_secs": tool_timeout,
208 })),
209 "mcp_client: tool call timed out"
210 );
211 anyhow::Error::msg(format!(
212 "MCP server `{}` timed out after {}s during tool call `{tool_name}`",
213 inner.config.name, tool_timeout
214 ))
215 })?
216 .with_context(|| {
217 format!(
218 "MCP server `{}` error during tool call `{tool_name}`",
219 inner.config.name
220 )
221 })?;
222
223 if let Some(err) = resp.error {
224 bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
225 }
226 Ok(resp.result.unwrap_or(serde_json::Value::Null))
227 }
228}
229
/// Collection of connected MCP servers plus an index of every tool they
/// expose.
pub struct McpRegistry {
    /// Servers that connected successfully, in connection order.
    servers: Vec<McpServer>,
    /// Maps a prefixed tool name (`<server name>__<tool name>`) to the index
    /// of its server in `servers` and the tool's original, unprefixed name.
    tool_index: HashMap<String, (usize, String)>,
}
238
239impl McpRegistry {
240 pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
242 let mut servers = Vec::new();
243 let mut tool_index = HashMap::new();
244
245 for config in configs {
246 match McpServer::connect(config.clone()).await {
247 Ok(server) => {
248 let server_idx = servers.len();
249 let tools = server.tools().await;
251 for tool in &tools {
252 let prefixed = format!("{}__{}", config.name, tool.name);
254 tool_index.insert(prefixed, (server_idx, tool.name.clone()));
255 }
256 servers.push(server);
257 }
258 Err(e) => {
260 ::zeroclaw_log::record!(
261 ERROR,
262 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
263 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
264 &format!("Failed to connect to MCP server `{}`: {:#}", config.name, e)
265 );
266 }
267 }
268 }
269
270 Ok(Self {
271 servers,
272 tool_index,
273 })
274 }
275
276 pub fn tool_names(&self) -> Vec<String> {
278 self.tool_index.keys().cloned().collect()
279 }
280
281 pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
283 let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
284 let inner = self.servers[*server_idx].inner.lock().await;
285 inner
286 .tools
287 .iter()
288 .find(|t| &t.name == original_name)
289 .cloned()
290 }
291
292 pub async fn call_tool(
294 &self,
295 prefixed_name: &str,
296 arguments: serde_json::Value,
297 ) -> Result<String> {
298 let (server_idx, original_name) = self.tool_index.get(prefixed_name).ok_or_else(|| {
299 ::zeroclaw_log::record!(
300 WARN,
301 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
302 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
303 .with_attrs(::serde_json::json!({"tool": prefixed_name})),
304 "mcp_client: unknown MCP tool"
305 );
306 anyhow::Error::msg(format!("unknown MCP tool `{prefixed_name}`"))
307 })?;
308 let result = self.servers[*server_idx]
309 .call_tool(original_name, arguments)
310 .await?;
311 serde_json::to_string_pretty(&result)
312 .with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
313 }
314
315 pub fn is_empty(&self) -> bool {
316 self.servers.is_empty()
317 }
318
    /// Returns the number of servers held by the registry.
    pub fn server_count(&self) -> usize {
        self.servers.len()
    }
322
    /// Returns the number of entries in the prefixed tool index.
    pub fn tool_count(&self) -> usize {
        self.tool_index.len()
    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::schema::McpTransport;

    /// Builds a stdio server config named `name` that launches `command`
    /// with no arguments, environment, timeout override, URL or headers.
    fn stdio_config(name: &str, command: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            command: command.to_string(),
            args: vec![],
            env: std::collections::HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: std::collections::HashMap::default(),
        }
    }

    /// Builds an otherwise-default config named "test" using `transport`.
    fn default_config_with(transport: McpTransport) -> McpServerConfig {
        McpServerConfig {
            name: "test".into(),
            transport,
            ..Default::default()
        }
    }

    #[test]
    fn tool_name_prefix_format() {
        let server = "filesystem";
        let tool = "read_file";
        assert_eq!(format!("{}__{}", server, tool), "filesystem__read_file");
    }

    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        let config = stdio_config(
            "nonexistent",
            "/usr/bin/this_binary_does_not_exist_zeroclaw_test",
        );
        let outcome = McpServer::connect(config).await;
        assert!(outcome.is_err());
        let msg = outcome.err().unwrap().to_string();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        let configs = vec![stdio_config("bad", "/usr/bin/does_not_exist_zc_test")];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }

    #[test]
    fn http_transport_requires_url() {
        let config = default_config_with(McpTransport::Http);
        assert!(create_transport(&config).is_err());
    }

    #[test]
    fn sse_transport_requires_url() {
        let config = default_config_with(McpTransport::Sse);
        assert!(create_transport(&config).is_err());
    }

    #[tokio::test]
    async fn empty_registry_is_empty() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all on empty slice should succeed");
        assert!(registry.is_empty());
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
    }

    #[tokio::test]
    async fn empty_registry_tool_names_is_empty() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        let names = registry.tool_names();
        assert!(names.is_empty());
    }

    #[tokio::test]
    async fn empty_registry_get_tool_def_returns_none() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        let lookup = registry.get_tool_def("nonexistent__tool").await;
        assert!(lookup.is_none());
    }

    #[tokio::test]
    async fn empty_registry_call_tool_unknown_name_returns_error() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        let call = registry.call_tool("nonexistent__tool", serde_json::json!({}));
        let err = call.await.expect_err("should fail for unknown tool");
        assert!(err.to_string().contains("unknown MCP tool"), "got: {err}");
    }

    #[tokio::test]
    async fn connect_all_empty_gives_zero_servers() {
        let registry = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
        assert!(registry.is_empty());
    }
}