1use std::collections::HashMap;
6use std::sync::Arc;
7#[cfg(not(target_has_atomic = "64"))]
8use std::sync::atomic::AtomicU32;
9#[cfg(target_has_atomic = "64")]
10use std::sync::atomic::AtomicU64;
11use std::sync::atomic::Ordering;
12
13use anyhow::{Context, Result, bail};
14use serde_json::json;
15use tokio::sync::Mutex;
16use tokio::time::{Duration, timeout};
17
18use crate::mcp_protocol::{JsonRpcRequest, MCP_PROTOCOL_VERSION, McpToolDef, McpToolsListResult};
19use crate::mcp_transport::{McpTransportConn, create_transport};
20use zeroclaw_config::schema::McpServerConfig;
21
/// Seconds to wait for the server's reply to each handshake request
/// (`initialize` and `tools/list`) while connecting.
const RECV_TIMEOUT_SECS: u64 = 30;

/// Seconds allowed for a `tools/call` round trip when the server config does
/// not set `tool_timeout_secs` (three minutes).
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 3 * 60;

/// Upper bound, in seconds, applied to any configured tool timeout
/// (ten minutes).
const MAX_TOOL_TIMEOUT_SECS: u64 = 10 * 60;
31
/// Connection state for one MCP server; `McpServer` keeps it behind an
/// `Arc<Mutex<..>>`.
struct McpServerInner {
    /// Configuration the connection was created from.
    config: McpServerConfig,
    /// Transport used to exchange JSON-RPC messages with the server.
    transport: Box<dyn McpTransportConn>,
    /// Id handed to the next JSON-RPC request (64-bit where the target has
    /// 64-bit atomics). `McpServer::connect` uses ids 1 and 2 for the
    /// handshake and starts this counter at 3.
    #[cfg(target_has_atomic = "64")]
    next_id: AtomicU64,
    /// 32-bit variant of the request id counter for targets without 64-bit
    /// atomics.
    #[cfg(not(target_has_atomic = "64"))]
    next_id: AtomicU32,
    /// Tool definitions parsed from the `tools/list` response at connect time.
    tools: Vec<McpToolDef>,
}
43
/// Handle to one connected MCP server.
///
/// Clones share the same connection state through the inner `Arc`; access to
/// that state goes through a Tokio mutex.
#[derive(Clone)]
pub struct McpServer {
    /// Shared, mutex-guarded connection state.
    inner: Arc<Mutex<McpServerInner>>,
}
51
52impl McpServer {
53 pub async fn connect(config: McpServerConfig) -> Result<Self> {
55 let mut transport = create_transport(&config).with_context(|| {
57 format!(
58 "failed to create transport for MCP server `{}`",
59 config.name
60 )
61 })?;
62
63 let id = 1u64;
65 let init_req = JsonRpcRequest::new(
66 id,
67 "initialize",
68 json!({
69 "protocolVersion": MCP_PROTOCOL_VERSION,
70 "capabilities": {},
71 "clientInfo": {
72 "name": "zeroclaw",
73 "version": env!("CARGO_PKG_VERSION")
74 }
75 }),
76 );
77
78 let init_resp = timeout(
79 Duration::from_secs(RECV_TIMEOUT_SECS),
80 transport.send_and_recv(&init_req),
81 )
82 .await
83 .with_context(|| {
84 format!(
85 "MCP server `{}` timed out after {}s waiting for initialize response",
86 config.name, RECV_TIMEOUT_SECS
87 )
88 })??;
89
90 if init_resp.error.is_some() {
91 bail!(
92 "MCP server `{}` rejected initialize: {:?}",
93 config.name,
94 init_resp.error
95 );
96 }
97
98 let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
101 let _ = transport.send_and_recv(¬if).await;
103
104 let id = 2u64;
106 let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
107
108 let list_resp = timeout(
109 Duration::from_secs(RECV_TIMEOUT_SECS),
110 transport.send_and_recv(&list_req),
111 )
112 .await
113 .with_context(|| {
114 format!(
115 "MCP server `{}` timed out after {}s waiting for tools/list response",
116 config.name, RECV_TIMEOUT_SECS
117 )
118 })??;
119
120 let result = list_resp.result.ok_or_else(|| {
121 ::zeroclaw_log::record!(
122 ERROR,
123 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
124 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
125 .with_attrs(::serde_json::json!({"mcp_server": &config.name})),
126 "mcp_client: tools/list returned no result"
127 );
128 anyhow::Error::msg(format!(
129 "tools/list returned no result from `{}`",
130 config.name
131 ))
132 })?;
133 let tool_list: McpToolsListResult = serde_json::from_value(result)
134 .with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
135
136 let tool_count = tool_list.tools.len();
137
138 let inner = McpServerInner {
139 config,
140 transport,
141 #[cfg(target_has_atomic = "64")]
142 next_id: AtomicU64::new(3), #[cfg(not(target_has_atomic = "64"))]
144 next_id: AtomicU32::new(3), tools: tool_list.tools,
146 };
147
148 ::zeroclaw_log::record!(
149 INFO,
150 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
151 &format!(
152 "MCP server `{}` connected — {} tool(s) available",
153 inner.config.name, tool_count
154 )
155 );
156
157 Ok(Self {
158 inner: Arc::new(Mutex::new(inner)),
159 })
160 }
161
162 pub async fn tools(&self) -> Vec<McpToolDef> {
164 self.inner.lock().await.tools.clone()
165 }
166
167 pub async fn name(&self) -> String {
169 self.inner.lock().await.config.name.clone()
170 }
171
172 pub async fn call_tool(
174 &self,
175 tool_name: &str,
176 arguments: serde_json::Value,
177 ) -> Result<serde_json::Value> {
178 let mut inner = self.inner.lock().await;
179 let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
180 let req = JsonRpcRequest::new(
181 id,
182 "tools/call",
183 json!({ "name": tool_name, "arguments": arguments }),
184 );
185
186 let tool_timeout = inner
189 .config
190 .tool_timeout_secs
191 .unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
192 .min(MAX_TOOL_TIMEOUT_SECS);
193
194 let resp = timeout(
195 Duration::from_secs(tool_timeout),
196 inner.transport.send_and_recv(&req),
197 )
198 .await
199 .map_err(|_| {
200 ::zeroclaw_log::record!(
201 WARN,
202 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Timeout)
203 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
204 .with_attrs(::serde_json::json!({
205 "mcp_server": &inner.config.name,
206 "tool": tool_name,
207 "timeout_secs": tool_timeout,
208 })),
209 "mcp_client: tool call timed out"
210 );
211 anyhow::Error::msg(format!(
212 "MCP server `{}` timed out after {}s during tool call `{tool_name}`",
213 inner.config.name, tool_timeout
214 ))
215 })?
216 .with_context(|| {
217 format!(
218 "MCP server `{}` error during tool call `{tool_name}`",
219 inner.config.name
220 )
221 })?;
222
223 if let Some(err) = resp.error {
224 bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
225 }
226
227 let result = resp.result.unwrap_or(serde_json::Value::Null);
228
229 if result.get("isError").and_then(serde_json::Value::as_bool) == Some(true) {
236 let detail = result
237 .get("content")
238 .and_then(|c| c.as_array())
239 .map(|arr| {
240 arr.iter()
241 .filter_map(|item| item.get("text").and_then(|t| t.as_str()))
242 .collect::<Vec<_>>()
243 .join("\n")
244 })
245 .filter(|s: &String| !s.is_empty())
246 .unwrap_or_else(|| "(no error detail returned by server)".to_string());
247 let detail = zeroclaw_providers::sanitize_api_error(&detail);
251 ::zeroclaw_log::record!(
252 WARN,
253 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
254 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
255 .with_attrs(::serde_json::json!({
256 "mcp_server": &inner.config.name,
257 "tool": tool_name,
258 "detail": &detail,
259 })),
260 "mcp_client: tool returned isError:true"
261 );
262 bail!(
263 "MCP tool `{tool_name}` (server `{}`) returned isError: {detail}",
264 inner.config.name
265 );
266 }
267
268 Ok(result)
269 }
270}
271
/// Set of connected MCP servers plus a lookup table from prefixed tool names
/// to the server that owns each tool.
pub struct McpRegistry {
    /// Servers that connected successfully, in connection order.
    servers: Vec<McpServer>,
    /// Maps `<server name>__<tool name>` to `(index into servers, original
    /// tool name)`.
    tool_index: HashMap<String, (usize, String)>,
}
280
281impl McpRegistry {
282 pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
284 let mut servers = Vec::new();
285 let mut tool_index = HashMap::new();
286
287 for config in configs {
288 match McpServer::connect(config.clone()).await {
289 Ok(server) => {
290 let server_idx = servers.len();
291 let tools = server.tools().await;
293 for tool in &tools {
294 let prefixed = format!("{}__{}", config.name, tool.name);
296 tool_index.insert(prefixed, (server_idx, tool.name.clone()));
297 }
298 servers.push(server);
299 }
300 Err(e) => {
302 ::zeroclaw_log::record!(
303 ERROR,
304 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
305 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
306 &format!("Failed to connect to MCP server `{}`: {:#}", config.name, e)
307 );
308 }
309 }
310 }
311
312 Ok(Self {
313 servers,
314 tool_index,
315 })
316 }
317
318 pub fn tool_names(&self) -> Vec<String> {
320 self.tool_index.keys().cloned().collect()
321 }
322
323 pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
325 let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
326 let inner = self.servers[*server_idx].inner.lock().await;
327 inner
328 .tools
329 .iter()
330 .find(|t| &t.name == original_name)
331 .cloned()
332 }
333
334 pub async fn call_tool(
336 &self,
337 prefixed_name: &str,
338 arguments: serde_json::Value,
339 ) -> Result<String> {
340 let (server_idx, original_name) = self.tool_index.get(prefixed_name).ok_or_else(|| {
341 ::zeroclaw_log::record!(
342 WARN,
343 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
344 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
345 .with_attrs(::serde_json::json!({"tool": prefixed_name})),
346 "mcp_client: unknown MCP tool"
347 );
348 anyhow::Error::msg(format!("unknown MCP tool `{prefixed_name}`"))
349 })?;
350 let result = self.servers[*server_idx]
351 .call_tool(original_name, arguments)
352 .await?;
353 serde_json::to_string_pretty(&result)
354 .with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
355 }
356
357 pub fn is_empty(&self) -> bool {
358 self.servers.is_empty()
359 }
360
361 pub fn server_count(&self) -> usize {
362 self.servers.len()
363 }
364
365 pub fn tool_count(&self) -> usize {
366 self.tool_index.len()
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::schema::McpTransport;

    /// Builds a stdio server config with the given name, command and
    /// arguments; every other field is empty or unset.
    fn stdio_config(name: &str, command: String, args: Vec<String>) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            command,
            args,
            env: std::collections::HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: std::collections::HashMap::default(),
        }
    }

    /// The registry key format is `<server>__<tool>`.
    #[test]
    fn tool_name_prefix_format() {
        let (server, tool) = ("filesystem", "read_file");
        let joined = format!("{}__{}", server, tool);
        assert_eq!(joined, "filesystem__read_file");
    }

    /// A command that does not exist must fail at transport creation.
    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        let cfg = stdio_config(
            "nonexistent",
            "/usr/bin/this_binary_does_not_exist_zeroclaw_test".to_string(),
            vec![],
        );
        let outcome = McpServer::connect(cfg).await;
        assert!(outcome.is_err());
        let msg = outcome.err().unwrap().to_string();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    /// A server that cannot connect is skipped rather than failing the whole
    /// registry.
    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        let cfgs = vec![stdio_config(
            "bad",
            "/usr/bin/does_not_exist_zc_test".to_string(),
            vec![],
        )];
        let reg = McpRegistry::connect_all(&cfgs)
            .await
            .expect("connect_all should not fail");
        assert!(reg.is_empty());
        assert_eq!(reg.tool_count(), 0);
    }

    /// The HTTP transport cannot be created without a URL.
    #[test]
    fn http_transport_requires_url() {
        let cfg = McpServerConfig {
            name: "test".into(),
            transport: McpTransport::Http,
            ..Default::default()
        };
        assert!(create_transport(&cfg).is_err());
    }

    /// The SSE transport cannot be created without a URL.
    #[test]
    fn sse_transport_requires_url() {
        let cfg = McpServerConfig {
            name: "test".into(),
            transport: McpTransport::Sse,
            ..Default::default()
        };
        assert!(create_transport(&cfg).is_err());
    }

    /// An empty config slice yields an empty registry.
    #[tokio::test]
    async fn empty_registry_is_empty() {
        let reg = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all on empty slice should succeed");
        assert!(reg.is_empty());
        assert_eq!(reg.server_count(), 0);
        assert_eq!(reg.tool_count(), 0);
    }

    /// An empty registry lists no tool names.
    #[tokio::test]
    async fn empty_registry_tool_names_is_empty() {
        let reg = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        assert!(reg.tool_names().is_empty());
    }

    /// Looking up any tool in an empty registry returns `None`.
    #[tokio::test]
    async fn empty_registry_get_tool_def_returns_none() {
        let reg = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        let found = reg.get_tool_def("nonexistent__tool").await;
        assert!(found.is_none());
    }

    /// Calling an unknown tool reports an "unknown MCP tool" error.
    #[tokio::test]
    async fn empty_registry_call_tool_unknown_name_returns_error() {
        let reg = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        let err = reg
            .call_tool("nonexistent__tool", serde_json::json!({}))
            .await
            .expect_err("should fail for unknown tool");
        assert!(err.to_string().contains("unknown MCP tool"), "got: {err}");
    }

    /// All counters are zero for an empty registry.
    #[tokio::test]
    async fn connect_all_empty_gives_zero_servers() {
        let reg = McpRegistry::connect_all(&[])
            .await
            .expect("connect_all should succeed");
        assert_eq!(reg.server_count(), 0);
        assert_eq!(reg.tool_count(), 0);
        assert!(reg.is_empty());
    }

    /// Transport double that answers every request with the same canned
    /// `result` value and never reports an error.
    struct FakeTransport {
        result: serde_json::Value,
    }

    #[async_trait::async_trait]
    impl McpTransportConn for FakeTransport {
        async fn send_and_recv(
            &mut self,
            _request: &JsonRpcRequest,
        ) -> Result<crate::mcp_protocol::JsonRpcResponse> {
            let canned = crate::mcp_protocol::JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                id: Some(serde_json::json!(1)),
                result: Some(self.result.clone()),
                error: None,
            };
            Ok(canned)
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Builds an `McpServer` named "fake" that is backed by `FakeTransport`
    /// and skips the connect handshake.
    fn server_returning(result: serde_json::Value) -> McpServer {
        let config = McpServerConfig {
            name: "fake".into(),
            ..Default::default()
        };
        let state = McpServerInner {
            config,
            transport: Box::new(FakeTransport { result }),
            #[cfg(target_has_atomic = "64")]
            next_id: AtomicU64::new(3),
            #[cfg(not(target_has_atomic = "64"))]
            next_id: AtomicU32::new(3),
            tools: vec![],
        };
        McpServer {
            inner: Arc::new(Mutex::new(state)),
        }
    }

    /// `isError: true` becomes an `Err` whose text is scrubbed of secrets and
    /// truncated.
    #[tokio::test]
    async fn call_tool_iserror_err_is_sanitized_and_bounded() {
        // Part 1: a token-looking secret must not reach the error text.
        let leaky = server_returning(serde_json::json!({
            "isError": true,
            "content": [{ "type": "text", "text": "auth failed using sk-supersecrettoken12345abcdef" }],
        }));
        let msg = leaky
            .call_tool("do_thing", serde_json::json!({}))
            .await
            .expect_err("isError:true must map to Err")
            .to_string();
        assert!(msg.contains("returned isError"), "got: {msg}");
        assert!(msg.contains("[REDACTED]"), "secret not scrubbed: {msg}");
        assert!(
            !msg.contains("supersecrettoken"),
            "raw secret leaked: {msg}"
        );

        // Part 2: an oversized payload must be cut down.
        let huge = "A".repeat(5000);
        let noisy = server_returning(serde_json::json!({
            "isError": true,
            "content": [{ "type": "text", "text": huge }],
        }));
        let msg = noisy
            .call_tool("do_thing", serde_json::json!({}))
            .await
            .expect_err("isError:true must map to Err")
            .to_string();
        assert!(
            msg.contains("..."),
            "bounded detail should be truncated: {msg}"
        );
        assert!(
            msg.len() < 1000,
            "5000-char payload not bounded: len={}",
            msg.len()
        );
    }

    /// A result without `isError`, or with `isError: false`, is returned
    /// unchanged.
    #[tokio::test]
    async fn call_tool_success_returns_ok_result() {
        let without_flag = serde_json::json!({
            "content": [{ "type": "text", "text": "all good" }],
        });
        let got = server_returning(without_flag.clone())
            .call_tool("do_thing", serde_json::json!({}))
            .await
            .expect("absent isError must be Ok");
        assert_eq!(got, without_flag);

        let with_false = serde_json::json!({ "isError": false, "value": 42 });
        let got = server_returning(with_false.clone())
            .call_tool("do_thing", serde_json::json!({}))
            .await
            .expect("isError:false must be Ok");
        assert_eq!(got, with_false);
    }

    /// Missing or empty error text is replaced by the fixed placeholder.
    #[tokio::test]
    async fn call_tool_iserror_empty_detail_falls_back() {
        let no_content = server_returning(serde_json::json!({ "isError": true }));
        let msg = no_content
            .call_tool("do_thing", serde_json::json!({}))
            .await
            .expect_err("isError:true must map to Err")
            .to_string();
        assert!(
            msg.contains("(no error detail returned by server)"),
            "got: {msg}"
        );

        let empty_text = server_returning(serde_json::json!({
            "isError": true,
            "content": [{ "type": "text", "text": "" }],
        }));
        let msg = empty_text
            .call_tool("do_thing", serde_json::json!({}))
            .await
            .expect_err("isError:true must map to Err")
            .to_string();
        assert!(
            msg.contains("(no error detail returned by server)"),
            "got: {msg}"
        );
    }

    /// Dropping the registry must terminate the stdio child process it
    /// spawned.
    #[cfg(unix)]
    #[tokio::test]
    async fn dropping_stdio_registry_reaps_child_process() {
        use std::io::Write;
        use std::os::unix::fs::PermissionsExt;
        use std::path::Path;
        use tokio::time::{Duration, sleep};

        /// Probes the process with `kill -0`; true only when the command ran
        /// and exited successfully.
        fn process_is_alive(pid: u32) -> bool {
            let probe = std::process::Command::new("kill")
                .arg("-0")
                .arg(pid.to_string())
                .stdout(std::process::Stdio::null())
                .stderr(std::process::Stdio::null())
                .status();
            matches!(probe, Ok(status) if status.success())
        }

        /// Polls `path` (up to 50 times, 20 ms apart) until it holds a
        /// parseable pid.
        async fn read_pid(path: &Path) -> u32 {
            for _ in 0..50 {
                let parsed = tokio::fs::read_to_string(path)
                    .await
                    .ok()
                    .and_then(|raw| raw.trim().parse::<u32>().ok());
                if let Some(pid) = parsed {
                    return pid;
                }
                sleep(Duration::from_millis(20)).await;
            }
            panic!("stdio MCP test server did not write its pid");
        }

        // Minimal MCP server: records its pid, answers the two handshake
        // requests, then turns into a process that never exits on its own.
        let temp = tempfile::tempdir().expect("tempdir");
        let server_path = temp.path().join("echo-mcp.sh");
        let pid_path = temp.path().join("echo-mcp.pid");
        {
            let mut script = std::fs::File::create(&server_path).expect("script");
            script
                .write_all(
                    br#"#!/bin/sh
echo "$$" > "$1"
while IFS= read -r line; do
  case "$line" in
    *'"method":"initialize"'*)
      printf '%s\n' '{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{"tools":{}},"serverInfo":{"name":"echo-mcp","version":"0.1.0"}}}'
      ;;
    *'"method":"tools/list"'*)
      printf '%s\n' '{"jsonrpc":"2.0","id":2,"result":{"tools":[]}}'
      exec tail -f /dev/null
      ;;
  esac
done
"#,
                )
                .expect("write script");
            // The file handle is closed at the end of this scope, before the
            // script is made executable and run.
        }
        let mut perms = std::fs::metadata(&server_path)
            .expect("metadata")
            .permissions();
        perms.set_mode(0o755);
        std::fs::set_permissions(&server_path, perms).expect("chmod");

        let cfg = stdio_config(
            "echo",
            server_path.display().to_string(),
            vec![pid_path.display().to_string()],
        );

        let reg = McpRegistry::connect_all(&[cfg])
            .await
            .expect("connect_all should not fail");
        assert_eq!(reg.server_count(), 1);
        assert_eq!(reg.tool_count(), 0);
        let child_pid = read_pid(&pid_path).await;
        assert!(
            process_is_alive(child_pid),
            "stdio MCP child should be alive while the registry is alive"
        );

        drop(reg);

        // Give the child up to ~1s to disappear after the registry is gone.
        for _ in 0..50 {
            if !process_is_alive(child_pid) {
                return;
            }
            sleep(Duration::from_millis(20)).await;
        }
        panic!("stdio MCP child process {child_pid} survived after registry drop");
    }
}