1use std::fmt::Write;
9use std::sync::{Arc, Mutex};
10
11use async_trait::async_trait;
12
13use crate::mcp_deferred::{ActivatedToolSet, DeferredMcpToolSet};
14use zeroclaw_api::tool::{Tool, ToolResult};
15
/// Number of keyword-search results returned when the caller omits `max_results`.
/// Also advertised as the `default` of `max_results` in the tool's parameter schema.
const DEFAULT_MAX_RESULTS: usize = 5;
18
/// Allow/deny lists restricting which deferred tools `tool_search` may return
/// and activate.
///
/// A tool name is permitted when it appears in `allowed` (or `allowed` is
/// `None`) and does not appear in `denied`. The deny-list always wins over the
/// allow-list.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ToolAccessPolicy {
    /// Exhaustive allow-list of tool names; `None` means every name passes this check.
    pub allowed: Option<Vec<String>>,
    /// Tool names that are always rejected, even when also allow-listed.
    pub denied: Option<Vec<String>>,
}

impl ToolAccessPolicy {
    /// Builds a policy from security configuration plus an optional
    /// caller-supplied allow-list.
    ///
    /// * `allowed_tools` — configured allow-list. If `caller_allowed` is also
    ///   given, the resulting allow-list is their intersection, kept in
    ///   `allowed_tools` order.
    /// * `excluded_tools` — configured deny-list, copied verbatim into `denied`.
    /// * `caller_allowed` — the caller's allow-list. It is used on its own when
    ///   `allowed_tools` is `None`.
    ///
    /// Returns `None` when none of the lists is supplied, i.e. access is
    /// unrestricted and no policy needs to be installed.
    pub fn from_security(
        allowed_tools: Option<&[String]>,
        excluded_tools: Option<&[String]>,
        caller_allowed: Option<&[String]>,
    ) -> Option<Self> {
        let allowed = match (allowed_tools, caller_allowed) {
            // Both lists present: a tool must be admitted by both.
            (Some(configured), Some(caller)) => Some(
                configured
                    .iter()
                    .filter(|name| caller.contains(*name))
                    .cloned()
                    .collect(),
            ),
            (Some(only), None) | (None, Some(only)) => Some(only.to_vec()),
            (None, None) => None,
        };
        let denied = excluded_tools.map(|list| list.to_vec());

        if allowed.is_none() && denied.is_none() {
            return None;
        }
        Some(Self { allowed, denied })
    }

    /// Returns `true` when `name` passes the allow-list (if any) and is not on
    /// the deny-list.
    pub fn is_tool_allowed(&self, name: &str) -> bool {
        let in_allow = self
            .allowed
            .as_ref()
            .is_none_or(|list| list.iter().any(|t| t == name));
        let in_deny = self
            .denied
            .as_ref()
            .is_some_and(|list| list.iter().any(|t| t == name));
        in_allow && !in_deny
    }
}
71
72pub struct ToolSearchTool {
74 deferred: DeferredMcpToolSet,
75 activated: Arc<Mutex<ActivatedToolSet>>,
76 access_policy: Option<ToolAccessPolicy>,
77}
78
79impl ToolSearchTool {
80 pub fn new(deferred: DeferredMcpToolSet, activated: Arc<Mutex<ActivatedToolSet>>) -> Self {
81 Self {
82 deferred,
83 activated,
84 access_policy: None,
85 }
86 }
87
88 pub fn with_access_policy(mut self, policy: ToolAccessPolicy) -> Self {
89 self.access_policy = Some(policy);
90 self
91 }
92
93 fn is_allowed(&self, tool_name: &str) -> bool {
94 self.access_policy
95 .as_ref()
96 .is_none_or(|p| p.is_tool_allowed(tool_name))
97 }
98}
99
100#[async_trait]
101impl Tool for ToolSearchTool {
102 fn name(&self) -> &str {
103 "tool_search"
104 }
105
106 fn description(&self) -> &str {
107 "Fetch full schema definitions for deferred MCP tools so they can be called. \
108 Use \"select:name1,name2\" for exact match or keywords to search."
109 }
110
111 fn parameters_schema(&self) -> serde_json::Value {
112 serde_json::json!({
113 "type": "object",
114 "properties": {
115 "query": {
116 "description": "Query to find deferred tools. Use \"select:<tool_name>\" for direct selection, or keywords to search.",
117 "type": "string"
118 },
119 "max_results": {
120 "description": "Maximum number of results to return (default: 5)",
121 "type": "number",
122 "default": DEFAULT_MAX_RESULTS
123 }
124 },
125 "required": ["query"]
126 })
127 }
128
129 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
130 let query = args
131 .get("query")
132 .and_then(|v| v.as_str())
133 .unwrap_or_default()
134 .trim();
135
136 let max_results = args
137 .get("max_results")
138 .and_then(|v| v.as_u64())
139 .map(|v| usize::try_from(v).unwrap_or(DEFAULT_MAX_RESULTS))
140 .unwrap_or(DEFAULT_MAX_RESULTS);
141
142 if query.is_empty() {
143 return Ok(ToolResult {
144 success: false,
145 output: String::new(),
146 error: Some("query parameter is required".into()),
147 });
148 }
149
150 if let Some(names_str) = query.strip_prefix("select:") {
152 let names: Vec<&str> = names_str.split(',').map(str::trim).collect();
154 return self.select_tools(&names);
155 }
156
157 let search_limit = if self.access_policy.is_some() {
161 usize::MAX
162 } else {
163 max_results
164 };
165 let results = self.deferred.search(query, search_limit);
166 if results.is_empty() {
167 return Ok(ToolResult {
168 success: true,
169 output: "No matching deferred tools found.".into(),
170 error: None,
171 });
172 }
173
174 let mut output = String::from("<functions>\n");
176 let mut activated_count = 0;
177 let mut returned_count = 0;
178 let mut guard = self.activated.lock().unwrap();
179
180 for stub in &results {
181 if returned_count >= max_results {
182 break;
183 }
184 if !self.is_allowed(&stub.prefixed_name) {
185 ::zeroclaw_log::record!(
186 DEBUG,
187 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
188 &format!(
189 "tool_search: '{}' matched query but denied by access policy",
190 stub.prefixed_name
191 )
192 );
193 continue;
194 }
195 if let Some(spec) = self.deferred.tool_spec(&stub.prefixed_name) {
196 if !guard.is_activated(&stub.prefixed_name)
197 && let Some(tool) = self.deferred.activate(&stub.prefixed_name)
198 {
199 guard.activate(stub.prefixed_name.clone(), Arc::from(tool));
200 activated_count += 1;
201 }
202 let _ = writeln!(
203 output,
204 "<function>{{\"name\": \"{}\", \"description\": \"{}\", \"parameters\": {}}}</function>",
205 spec.name,
206 spec.description.replace('"', "\\\""),
207 spec.parameters
208 );
209 returned_count += 1;
210 }
211 }
212
213 output.push_str("</functions>\n");
214 drop(guard);
215
216 ::zeroclaw_log::record!(
217 DEBUG,
218 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
219 &format!(
220 "tool_search: query={query:?}, matched={}, activated={activated_count}",
221 results.len()
222 )
223 );
224
225 Ok(ToolResult {
226 success: true,
227 output,
228 error: None,
229 })
230 }
231}
232
233impl ToolSearchTool {
234 fn select_tools(&self, names: &[&str]) -> anyhow::Result<ToolResult> {
235 let mut output = String::from("<functions>\n");
236 let mut not_found = Vec::new();
237 let mut activated_count = 0;
238 let mut guard = self.activated.lock().unwrap();
239
240 for name in names {
241 if name.is_empty() {
242 continue;
243 }
244 if !self.is_allowed(name) {
245 ::zeroclaw_log::record!(
246 DEBUG,
247 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
248 &format!("tool_search select: '{}' denied by access policy", name)
249 );
250 not_found.push(*name);
251 continue;
252 }
253 match self.deferred.tool_spec(name) {
254 Some(spec) => {
255 if !guard.is_activated(name)
256 && let Some(tool) = self.deferred.activate(name)
257 {
258 guard.activate(String::from(*name), Arc::from(tool));
259 activated_count += 1;
260 }
261 let _ = writeln!(
262 output,
263 "<function>{{\"name\": \"{}\", \"description\": \"{}\", \"parameters\": {}}}</function>",
264 spec.name,
265 spec.description.replace('"', "\\\""),
266 spec.parameters
267 );
268 }
269 None => {
270 not_found.push(*name);
271 }
272 }
273 }
274
275 output.push_str("</functions>\n");
276 drop(guard);
277
278 if !not_found.is_empty() {
279 let _ = write!(output, "\nNot found: {}", not_found.join(", "));
280 }
281
282 ::zeroclaw_log::record!(
283 DEBUG,
284 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
285 &format!(
286 "tool_search select: requested={}, activated={activated_count}, not_found={}",
287 names.len(),
288 not_found.len()
289 )
290 );
291
292 Ok(ToolResult {
293 success: true,
294 output,
295 error: None,
296 })
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_client::McpRegistry;
    use crate::mcp_deferred::DeferredMcpToolStub;
    use crate::mcp_protocol::McpToolDef;
    use serde_json::json;

    /// Builds a deferred tool set over `stubs`, backed by a registry connected
    /// to no servers.
    async fn make_deferred_set(stubs: Vec<DeferredMcpToolStub>) -> DeferredMcpToolSet {
        let registry = McpRegistry::connect_all(&[]).await.unwrap();
        DeferredMcpToolSet {
            stubs,
            registry: Arc::new(registry),
        }
    }

    /// Creates a stub registered under `name` with an empty object input schema.
    fn make_stub(name: &str, desc: &str) -> DeferredMcpToolStub {
        DeferredMcpToolStub::new(
            name.to_owned(),
            McpToolDef {
                name: name.to_owned(),
                description: Some(desc.to_owned()),
                input_schema: json!({"type": "object", "properties": {}}),
            },
        )
    }

    /// A fresh, empty activation set to share between a tool and the test.
    fn empty_activated() -> Arc<Mutex<ActivatedToolSet>> {
        Arc::new(Mutex::new(ActivatedToolSet::new()))
    }

    /// Executes `tool` with `args`, panicking if it returns `Err`.
    async fn run(tool: &ToolSearchTool, args: serde_json::Value) -> ToolResult {
        tool.execute(args).await.unwrap()
    }

    #[tokio::test]
    async fn tool_metadata() {
        let tool = ToolSearchTool::new(make_deferred_set(Vec::new()).await, empty_activated());
        assert_eq!(tool.name(), "tool_search");
        assert!(!tool.description().is_empty());
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["query"].is_object());
    }

    #[tokio::test]
    async fn empty_query_returns_error() {
        let tool = ToolSearchTool::new(make_deferred_set(Vec::new()).await, empty_activated());
        let outcome = run(&tool, json!({"query": ""})).await;
        assert!(!outcome.success);
    }

    #[tokio::test]
    async fn select_nonexistent_tool_reports_not_found() {
        let tool = ToolSearchTool::new(make_deferred_set(Vec::new()).await, empty_activated());
        let outcome = run(&tool, json!({"query": "select:nonexistent"})).await;
        assert!(outcome.success);
        assert!(outcome.output.contains("Not found"));
    }

    #[tokio::test]
    async fn keyword_search_no_matches() {
        let deferred = make_deferred_set(vec![make_stub("fs__read", "Read file")]).await;
        let tool = ToolSearchTool::new(deferred, empty_activated());
        let outcome = run(&tool, json!({"query": "zzzzz_nonexistent"})).await;
        assert!(outcome.success);
        assert!(outcome.output.contains("No matching"));
    }

    #[tokio::test]
    async fn keyword_search_finds_match() {
        let activated = empty_activated();
        let deferred =
            make_deferred_set(vec![make_stub("fs__read", "Read a file from disk")]).await;
        let tool = ToolSearchTool::new(deferred, Arc::clone(&activated));

        let outcome = run(&tool, json!({"query": "read file"})).await;
        assert!(outcome.success);
        assert!(outcome.output.contains("<function>"));
        assert!(outcome.output.contains("fs__read"));
        // A keyword hit must also be activated, not merely listed.
        assert!(activated.lock().unwrap().is_activated("fs__read"));
    }

    /// Stubs from several servers live in one deferred set and all remain searchable.
    #[tokio::test]
    async fn multiple_servers_stubs_all_searchable() {
        let activated = empty_activated();
        let stubs = vec![
            make_stub("server_a__list_files", "List files on server A"),
            make_stub("server_a__read_file", "Read file on server A"),
            make_stub("server_b__query_db", "Query database on server B"),
            make_stub("server_b__insert_row", "Insert row on server B"),
        ];
        let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated));

        let cases: [(&str, &[&str]); 2] = [
            ("file", &["server_a__list_files", "server_a__read_file"]),
            ("database query", &["server_b__query_db"]),
        ];
        for (query, expected) in cases {
            let outcome = run(&tool, json!({ "query": query })).await;
            assert!(outcome.success, "query {query:?} failed");
            for &name in expected {
                assert!(
                    outcome.output.contains(name),
                    "{name} missing from results for {query:?}"
                );
            }
        }
    }

    /// Activations from separate `select:` calls accumulate in the shared set.
    #[tokio::test]
    async fn select_activates_and_persists_across_calls() {
        let activated = empty_activated();
        let stubs = vec![
            make_stub("srv__tool_a", "Tool A"),
            make_stub("srv__tool_b", "Tool B"),
        ];
        let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated));

        let first = run(&tool, json!({"query": "select:srv__tool_a"})).await;
        assert!(first.success);
        {
            let set = activated.lock().unwrap();
            assert!(set.is_activated("srv__tool_a"));
            assert!(!set.is_activated("srv__tool_b"));
        }

        let second = run(&tool, json!({"query": "select:srv__tool_b"})).await;
        assert!(second.success);

        let set = activated.lock().unwrap();
        assert!(set.is_activated("srv__tool_a"));
        assert!(set.is_activated("srv__tool_b"));
        assert_eq!(set.tool_specs().len(), 2);
    }

    /// Selecting the same tool twice must not register it twice.
    #[tokio::test]
    async fn reactivation_is_idempotent() {
        let activated = empty_activated();
        let deferred = make_deferred_set(vec![make_stub("srv__tool", "A tool")]).await;
        let tool = ToolSearchTool::new(deferred, Arc::clone(&activated));

        for _ in 0..2 {
            run(&tool, json!({"query": "select:srv__tool"})).await;
        }

        assert_eq!(activated.lock().unwrap().tool_specs().len(), 1);
    }

    #[test]
    fn policy_none_is_unrestricted() {
        let policy = ToolAccessPolicy::default();
        for name in ["shell", "anything"] {
            assert!(policy.is_tool_allowed(name));
        }
    }

    #[test]
    fn policy_allowlist_admits_only_listed() {
        let policy = ToolAccessPolicy {
            allowed: Some(vec!["shell".into(), "file_read".into()]),
            ..Default::default()
        };
        assert!(policy.is_tool_allowed("shell"));
        assert!(!policy.is_tool_allowed("file_write"));
    }

    #[test]
    fn policy_denylist_rejects_listed() {
        let policy = ToolAccessPolicy {
            denied: Some(vec!["shell".into()]),
            ..Default::default()
        };
        assert!(!policy.is_tool_allowed("shell"));
        assert!(policy.is_tool_allowed("file_read"));
    }

    #[test]
    fn policy_deny_overrides_allow() {
        let policy = ToolAccessPolicy {
            allowed: Some(vec!["shell".into(), "file_read".into()]),
            denied: Some(vec!["shell".into()]),
        };
        assert!(!policy.is_tool_allowed("shell"));
        assert!(policy.is_tool_allowed("file_read"));
    }

    #[tokio::test]
    async fn policy_filters_keyword_search_results() {
        let activated = empty_activated();
        let stubs = vec![
            make_stub("srv__allowed_tool", "An allowed tool"),
            make_stub("srv__blocked_tool", "A blocked tool"),
        ];
        let policy = ToolAccessPolicy {
            denied: Some(vec!["srv__blocked_tool".into()]),
            ..Default::default()
        };
        let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated))
            .with_access_policy(policy);

        let outcome = run(&tool, json!({"query": "tool"})).await;
        assert!(outcome.success);
        assert!(outcome.output.contains("srv__allowed_tool"));
        assert!(!outcome.output.contains("srv__blocked_tool"));
        assert!(!activated.lock().unwrap().is_activated("srv__blocked_tool"));
    }

    /// The denied stub matches the query more closely, so it would take the
    /// only `max_results` slot if filtering happened after truncation.
    #[tokio::test]
    async fn policy_denied_tool_does_not_consume_max_results_slot() {
        let activated = empty_activated();
        let stubs = vec![
            make_stub("srv__denied_tool", "tool for searching files"),
            make_stub("srv__allowed_tool", "tool for files"),
        ];
        let policy = ToolAccessPolicy {
            denied: Some(vec!["srv__denied_tool".into()]),
            ..Default::default()
        };
        let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated))
            .with_access_policy(policy);

        let outcome = run(&tool, json!({"query": "searching files", "max_results": 1})).await;
        assert!(outcome.success);
        assert!(outcome.output.contains("srv__allowed_tool"));
        assert!(!outcome.output.contains("srv__denied_tool"));
        assert!(activated.lock().unwrap().is_activated("srv__allowed_tool"));
    }

    #[tokio::test]
    async fn policy_filters_select_results() {
        let activated = empty_activated();
        let stubs = vec![
            make_stub("srv__ok", "OK tool"),
            make_stub("srv__nope", "Blocked tool"),
        ];
        let policy = ToolAccessPolicy {
            allowed: Some(vec!["srv__ok".into()]),
            ..Default::default()
        };
        let tool = ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(&activated))
            .with_access_policy(policy);

        let outcome = run(&tool, json!({"query": "select:srv__ok,srv__nope"})).await;
        assert!(outcome.success);
        assert!(outcome.output.contains("srv__ok"));
        // The blocked name may appear in the "Not found" trailer, but never as a function entry.
        assert!(!outcome.output.contains("\"name\": \"srv__nope\""));
        assert!(outcome.output.contains("Not found"));
        assert!(!activated.lock().unwrap().is_activated("srv__nope"));
    }
}