Skip to main content

zeroclaw_tools/
tool_search.rs

//! Built-in `tool_search` tool for on-demand MCP tool schema loading.
//!
//! When `mcp.deferred_loading` is enabled, this tool lets the LLM discover and
//! activate deferred MCP tools. Supports two query modes:
//! - `select:name1,name2` — fetch exact tools by prefixed name.
//! - Free-text keyword search — returns the best-matching stubs.
7
8use std::fmt::Write;
9use std::sync::{Arc, Mutex};
10
11use async_trait::async_trait;
12
13use crate::mcp_deferred::{ActivatedToolSet, DeferredMcpToolSet};
14use zeroclaw_api::tool::{Tool, ToolResult};
15
/// Default maximum number of search results.
///
/// Used as the fallback when a keyword query does not supply a usable
/// `max_results` argument, and advertised as the default in the parameter
/// schema.
const DEFAULT_MAX_RESULTS: usize = 5;
18
/// Tool-level access policy applied at discovery time.
///
/// When set on `ToolSearchTool`, deferred tools that fail this check are
/// never surfaced to the LLM and never activated — keeping them out of
/// the context window entirely.
#[derive(Clone, Debug, Default)]
pub struct ToolAccessPolicy {
    /// Allowlist of tool names. `None` means the allowlist check passes for
    /// every tool; `Some(list)` admits only the listed names.
    pub allowed: Option<Vec<String>>,
    /// Denylist of tool names. A listed tool is rejected even when it also
    /// appears in `allowed`; `None` denies nothing.
    pub denied: Option<Vec<String>>,
}
29
30impl ToolAccessPolicy {
31    /// Construct from a `SecurityPolicy`'s tool fields and an optional
32    /// caller-supplied allowlist. Used by both `run()` and
33    /// `process_message()` to keep policy construction in sync.
34    pub fn from_security(
35        allowed_tools: Option<&[String]>,
36        excluded_tools: Option<&[String]>,
37        caller_allowed: Option<&[String]>,
38    ) -> Option<Self> {
39        let mut policy = Self::default();
40        if let Some(list) = allowed_tools {
41            let mut merged = list.to_vec();
42            if let Some(caller) = caller_allowed {
43                merged.retain(|t| caller.iter().any(|c| c == t));
44            }
45            policy.allowed = Some(merged);
46        } else if let Some(caller) = caller_allowed {
47            policy.allowed = Some(caller.to_vec());
48        }
49        if let Some(list) = excluded_tools {
50            policy.denied = Some(list.to_vec());
51        }
52        if policy.allowed.is_some() || policy.denied.is_some() {
53            Some(policy)
54        } else {
55            None
56        }
57    }
58
59    pub fn is_tool_allowed(&self, name: &str) -> bool {
60        let in_allow = self
61            .allowed
62            .as_ref()
63            .is_none_or(|list| list.iter().any(|t| t == name));
64        let in_deny = self
65            .denied
66            .as_ref()
67            .is_some_and(|list| list.iter().any(|t| t == name));
68        in_allow && !in_deny
69    }
70}
71
/// Built-in tool that fetches full schemas for deferred MCP tools.
pub struct ToolSearchTool {
    /// Deferred MCP tools that this tool searches, describes, and activates.
    deferred: DeferredMcpToolSet,
    /// Shared set recording which deferred tools have been activated.
    activated: Arc<Mutex<ActivatedToolSet>>,
    /// Optional allow/deny policy; `None` leaves discovery unrestricted.
    access_policy: Option<ToolAccessPolicy>,
}
78
79impl ToolSearchTool {
80    pub fn new(deferred: DeferredMcpToolSet, activated: Arc<Mutex<ActivatedToolSet>>) -> Self {
81        Self {
82            deferred,
83            activated,
84            access_policy: None,
85        }
86    }
87
88    pub fn with_access_policy(mut self, policy: ToolAccessPolicy) -> Self {
89        self.access_policy = Some(policy);
90        self
91    }
92
93    fn is_allowed(&self, tool_name: &str) -> bool {
94        self.access_policy
95            .as_ref()
96            .is_none_or(|p| p.is_tool_allowed(tool_name))
97    }
98}
99
100#[async_trait]
101impl Tool for ToolSearchTool {
102    fn name(&self) -> &str {
103        "tool_search"
104    }
105
106    fn description(&self) -> &str {
107        "Fetch full schema definitions for deferred MCP tools so they can be called. \
108         Use \"select:name1,name2\" for exact match or keywords to search."
109    }
110
111    fn parameters_schema(&self) -> serde_json::Value {
112        serde_json::json!({
113            "type": "object",
114            "properties": {
115                "query": {
116                    "description": "Query to find deferred tools. Use \"select:<tool_name>\" for direct selection, or keywords to search.",
117                    "type": "string"
118                },
119                "max_results": {
120                    "description": "Maximum number of results to return (default: 5)",
121                    "type": "number",
122                    "default": DEFAULT_MAX_RESULTS
123                }
124            },
125            "required": ["query"]
126        })
127    }
128
129    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
130        let query = args
131            .get("query")
132            .and_then(|v| v.as_str())
133            .unwrap_or_default()
134            .trim();
135
136        let max_results = args
137            .get("max_results")
138            .and_then(|v| v.as_u64())
139            .map(|v| usize::try_from(v).unwrap_or(DEFAULT_MAX_RESULTS))
140            .unwrap_or(DEFAULT_MAX_RESULTS);
141
142        if query.is_empty() {
143            return Ok(ToolResult {
144                success: false,
145                output: String::new(),
146                error: Some("query parameter is required".into()),
147            });
148        }
149
150        // Parse query mode
151        if let Some(names_str) = query.strip_prefix("select:") {
152            // Exact selection mode
153            let names: Vec<&str> = names_str.split(',').map(str::trim).collect();
154            return self.select_tools(&names);
155        }
156
157        // Keyword search mode.
158        // When a policy is active, fetch all matches so denied tools don't
159        // consume result slots. The max_results cap is applied after filtering.
160        let search_limit = if self.access_policy.is_some() {
161            usize::MAX
162        } else {
163            max_results
164        };
165        let results = self.deferred.search(query, search_limit);
166        if results.is_empty() {
167            return Ok(ToolResult {
168                success: true,
169                output: "No matching deferred tools found.".into(),
170                error: None,
171            });
172        }
173
174        // Activate and return full specs (policy-filtered, then capped)
175        let mut output = String::from("<functions>\n");
176        let mut activated_count = 0;
177        let mut returned_count = 0;
178        let mut guard = self.activated.lock().unwrap();
179
180        for stub in &results {
181            if returned_count >= max_results {
182                break;
183            }
184            if !self.is_allowed(&stub.prefixed_name) {
185                ::zeroclaw_log::record!(
186                    DEBUG,
187                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
188                    &format!(
189                        "tool_search: '{}' matched query but denied by access policy",
190                        stub.prefixed_name
191                    )
192                );
193                continue;
194            }
195            if let Some(spec) = self.deferred.tool_spec(&stub.prefixed_name) {
196                if !guard.is_activated(&stub.prefixed_name)
197                    && let Some(tool) = self.deferred.activate(&stub.prefixed_name)
198                {
199                    guard.activate(stub.prefixed_name.clone(), Arc::from(tool));
200                    activated_count += 1;
201                }
202                let _ = writeln!(
203                    output,
204                    "<function>{{\"name\": \"{}\", \"description\": \"{}\", \"parameters\": {}}}</function>",
205                    spec.name,
206                    spec.description.replace('"', "\\\""),
207                    spec.parameters
208                );
209                returned_count += 1;
210            }
211        }
212
213        output.push_str("</functions>\n");
214        drop(guard);
215
216        ::zeroclaw_log::record!(
217            DEBUG,
218            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
219            &format!(
220                "tool_search: query={query:?}, matched={}, activated={activated_count}",
221                results.len()
222            )
223        );
224
225        Ok(ToolResult {
226            success: true,
227            output,
228            error: None,
229        })
230    }
231}
232
233impl ToolSearchTool {
234    fn select_tools(&self, names: &[&str]) -> anyhow::Result<ToolResult> {
235        let mut output = String::from("<functions>\n");
236        let mut not_found = Vec::new();
237        let mut activated_count = 0;
238        let mut guard = self.activated.lock().unwrap();
239
240        for name in names {
241            if name.is_empty() {
242                continue;
243            }
244            if !self.is_allowed(name) {
245                ::zeroclaw_log::record!(
246                    DEBUG,
247                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
248                    &format!("tool_search select: '{}' denied by access policy", name)
249                );
250                not_found.push(*name);
251                continue;
252            }
253            match self.deferred.tool_spec(name) {
254                Some(spec) => {
255                    if !guard.is_activated(name)
256                        && let Some(tool) = self.deferred.activate(name)
257                    {
258                        guard.activate(String::from(*name), Arc::from(tool));
259                        activated_count += 1;
260                    }
261                    let _ = writeln!(
262                        output,
263                        "<function>{{\"name\": \"{}\", \"description\": \"{}\", \"parameters\": {}}}</function>",
264                        spec.name,
265                        spec.description.replace('"', "\\\""),
266                        spec.parameters
267                    );
268                }
269                None => {
270                    not_found.push(*name);
271                }
272            }
273        }
274
275        output.push_str("</functions>\n");
276        drop(guard);
277
278        if !not_found.is_empty() {
279            let _ = write!(output, "\nNot found: {}", not_found.join(", "));
280        }
281
282        ::zeroclaw_log::record!(
283            DEBUG,
284            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
285            &format!(
286                "tool_search select: requested={}, activated={activated_count}, not_found={}",
287                names.len(),
288                not_found.len()
289            )
290        );
291
292        Ok(ToolResult {
293            success: true,
294            output,
295            error: None,
296        })
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_client::McpRegistry;
    use crate::mcp_deferred::DeferredMcpToolStub;
    use crate::mcp_protocol::McpToolDef;

    /// Shared handle to the activated-tool set, as passed to the tool.
    type SharedActivated = Arc<Mutex<ActivatedToolSet>>;

    /// Wraps `stubs` in a deferred set backed by a registry connected to an
    /// empty server list.
    async fn make_deferred_set(stubs: Vec<DeferredMcpToolStub>) -> DeferredMcpToolSet {
        let registry = McpRegistry::connect_all(&[]).await.unwrap();
        DeferredMcpToolSet {
            stubs,
            registry: Arc::new(registry),
        }
    }

    /// Builds a stub named `name`, described by `desc`, with an empty object schema.
    fn make_stub(name: &str, desc: &str) -> DeferredMcpToolStub {
        DeferredMcpToolStub::new(
            name.to_owned(),
            McpToolDef {
                name: name.to_owned(),
                description: Some(desc.to_owned()),
                input_schema: serde_json::json!({"type": "object", "properties": {}}),
            },
        )
    }

    /// Creates an empty activated-tool set behind a shared handle.
    fn fresh_activated() -> SharedActivated {
        Arc::new(Mutex::new(ActivatedToolSet::new()))
    }

    /// Builds a policy-free search tool over `stubs` sharing `activated`.
    async fn build_tool(
        stubs: Vec<DeferredMcpToolStub>,
        activated: &SharedActivated,
    ) -> ToolSearchTool {
        ToolSearchTool::new(make_deferred_set(stubs).await, Arc::clone(activated))
    }

    /// Executes the tool with `args`, unwrapping the outer `Result`.
    async fn run(tool: &ToolSearchTool, args: serde_json::Value) -> ToolResult {
        tool.execute(args).await.unwrap()
    }

    /// Reports whether `name` is currently recorded as activated.
    fn is_active(activated: &SharedActivated, name: &str) -> bool {
        activated.lock().unwrap().is_activated(name)
    }

    /// Turns a list of literals into an owned policy list.
    fn names(list: &[&str]) -> Option<Vec<String>> {
        Some(list.iter().map(|entry| (*entry).to_owned()).collect())
    }

    #[tokio::test]
    async fn tool_metadata() {
        let tool = build_tool(vec![], &fresh_activated()).await;
        assert_eq!(tool.name(), "tool_search");
        assert!(!tool.description().is_empty());
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["query"].is_object());
    }

    #[tokio::test]
    async fn empty_query_returns_error() {
        let tool = build_tool(vec![], &fresh_activated()).await;
        let result = run(&tool, serde_json::json!({"query": ""})).await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn select_nonexistent_tool_reports_not_found() {
        let tool = build_tool(vec![], &fresh_activated()).await;
        let result = run(&tool, serde_json::json!({"query": "select:nonexistent"})).await;
        assert!(result.success);
        assert!(result.output.contains("Not found"));
    }

    #[tokio::test]
    async fn keyword_search_no_matches() {
        let stubs = vec![make_stub("fs__read", "Read file")];
        let tool = build_tool(stubs, &fresh_activated()).await;
        let result = run(&tool, serde_json::json!({"query": "zzzzz_nonexistent"})).await;
        assert!(result.success);
        assert!(result.output.contains("No matching"));
    }

    #[tokio::test]
    async fn keyword_search_finds_match() {
        let activated = fresh_activated();
        let stubs = vec![make_stub("fs__read", "Read a file from disk")];
        let tool = build_tool(stubs, &activated).await;

        let result = run(&tool, serde_json::json!({"query": "read file"})).await;
        assert!(result.success);
        assert!(result.output.contains("<function>"));
        assert!(result.output.contains("fs__read"));
        // The matched tool must have been activated as a side effect.
        assert!(is_active(&activated, "fs__read"));
    }

    /// Verify tool_search works with stubs from multiple MCP servers,
    /// simulating a daemon-mode setup where several servers are deferred.
    #[tokio::test]
    async fn multiple_servers_stubs_all_searchable() {
        let activated = fresh_activated();
        let tool = build_tool(
            vec![
                make_stub("server_a__list_files", "List files on server A"),
                make_stub("server_a__read_file", "Read file on server A"),
                make_stub("server_b__query_db", "Query database on server B"),
                make_stub("server_b__insert_row", "Insert row on server B"),
            ],
            &activated,
        )
        .await;

        // Both server A tools mention "file".
        let files = run(&tool, serde_json::json!({"query": "file"})).await;
        assert!(files.success);
        assert!(files.output.contains("server_a__list_files"));
        assert!(files.output.contains("server_a__read_file"));

        // Server B tools are reachable through the same tool instance.
        let database = run(&tool, serde_json::json!({"query": "database query"})).await;
        assert!(database.success);
        assert!(database.output.contains("server_b__query_db"));
    }

    /// Verify select mode activates tools and they stay activated across calls,
    /// matching the daemon-mode pattern where a single ActivatedToolSet persists.
    #[tokio::test]
    async fn select_activates_and_persists_across_calls() {
        let activated = fresh_activated();
        let tool = build_tool(
            vec![
                make_stub("srv__tool_a", "Tool A"),
                make_stub("srv__tool_b", "Tool B"),
            ],
            &activated,
        )
        .await;

        // First call activates only tool_a.
        let first = run(&tool, serde_json::json!({"query": "select:srv__tool_a"})).await;
        assert!(first.success);
        assert!(is_active(&activated, "srv__tool_a"));
        assert!(!is_active(&activated, "srv__tool_b"));

        // Second, separate call activates tool_b.
        let second = run(&tool, serde_json::json!({"query": "select:srv__tool_b"})).await;
        assert!(second.success);

        // Both activations must still be present.
        let guard = activated.lock().unwrap();
        assert!(guard.is_activated("srv__tool_a"));
        assert!(guard.is_activated("srv__tool_b"));
        assert_eq!(guard.tool_specs().len(), 2);
    }

    /// Verify re-activating an already-activated tool does not duplicate it.
    #[tokio::test]
    async fn reactivation_is_idempotent() {
        let activated = fresh_activated();
        let tool = build_tool(vec![make_stub("srv__tool", "A tool")], &activated).await;

        for _ in 0..2 {
            let _ = run(&tool, serde_json::json!({"query": "select:srv__tool"})).await;
        }

        assert_eq!(activated.lock().unwrap().tool_specs().len(), 1);
    }

    #[test]
    fn policy_none_is_unrestricted() {
        let unrestricted = ToolAccessPolicy::default();
        assert!(unrestricted.is_tool_allowed("shell"));
        assert!(unrestricted.is_tool_allowed("anything"));
    }

    #[test]
    fn policy_allowlist_admits_only_listed() {
        let allow_only = ToolAccessPolicy {
            allowed: names(&["shell", "file_read"]),
            denied: None,
        };
        assert!(allow_only.is_tool_allowed("shell"));
        assert!(!allow_only.is_tool_allowed("file_write"));
    }

    #[test]
    fn policy_denylist_rejects_listed() {
        let deny_only = ToolAccessPolicy {
            allowed: None,
            denied: names(&["shell"]),
        };
        assert!(!deny_only.is_tool_allowed("shell"));
        assert!(deny_only.is_tool_allowed("file_read"));
    }

    #[test]
    fn policy_deny_overrides_allow() {
        let both = ToolAccessPolicy {
            allowed: names(&["shell", "file_read"]),
            denied: names(&["shell"]),
        };
        assert!(!both.is_tool_allowed("shell"));
        assert!(both.is_tool_allowed("file_read"));
    }

    #[tokio::test]
    async fn policy_filters_keyword_search_results() {
        let activated = fresh_activated();
        let tool = build_tool(
            vec![
                make_stub("srv__allowed_tool", "An allowed tool"),
                make_stub("srv__blocked_tool", "A blocked tool"),
            ],
            &activated,
        )
        .await
        .with_access_policy(ToolAccessPolicy {
            allowed: None,
            denied: names(&["srv__blocked_tool"]),
        });

        let result = run(&tool, serde_json::json!({"query": "tool"})).await;
        assert!(result.success);
        assert!(result.output.contains("srv__allowed_tool"));
        assert!(!result.output.contains("srv__blocked_tool"));
        assert!(!is_active(&activated, "srv__blocked_tool"));
    }

    #[tokio::test]
    async fn policy_denied_tool_does_not_consume_max_results_slot() {
        let activated = fresh_activated();
        // "denied_tool" ranks higher (more keyword matches) but is blocked.
        // "allowed_tool" ranks lower but should still be returned with max_results=1.
        let tool = build_tool(
            vec![
                make_stub("srv__denied_tool", "tool for searching files"),
                make_stub("srv__allowed_tool", "tool for files"),
            ],
            &activated,
        )
        .await
        .with_access_policy(ToolAccessPolicy {
            allowed: None,
            denied: names(&["srv__denied_tool"]),
        });

        let args = serde_json::json!({"query": "searching files", "max_results": 1});
        let result = run(&tool, args).await;
        assert!(result.success);
        // The single slot goes to the allowed tool even though the denied
        // one ranked above it.
        assert!(result.output.contains("srv__allowed_tool"));
        assert!(!result.output.contains("srv__denied_tool"));
        assert!(is_active(&activated, "srv__allowed_tool"));
    }

    #[tokio::test]
    async fn policy_filters_select_results() {
        let activated = fresh_activated();
        let tool = build_tool(
            vec![
                make_stub("srv__ok", "OK tool"),
                make_stub("srv__nope", "Blocked tool"),
            ],
            &activated,
        )
        .await
        .with_access_policy(ToolAccessPolicy {
            allowed: names(&["srv__ok"]),
            denied: None,
        });

        let result = run(&tool, serde_json::json!({"query": "select:srv__ok,srv__nope"})).await;
        assert!(result.success);
        assert!(result.output.contains("srv__ok"));
        assert!(!result.output.contains("\"name\": \"srv__nope\""));
        assert!(result.output.contains("Not found"));
        assert!(!is_active(&activated, "srv__nope"));
    }
}