Skip to main content

zeroclaw_tools/
wrappers.rs

1//! Generic tool wrappers for crosscutting concerns.
2//!
3//! Each wrapper implements [`Tool`] by delegating to an inner tool while
4//! applying one crosscutting concern around the `execute` call.  Wrappers
5//! compose: stack them at construction time in `tools/mod.rs` rather than
6//! repeating the same guard blocks inside every tool's `execute` method.
7//!
8//! # Composition order (outermost first)
9//!
10//! ```text
11//! RateLimitedTool
12//!   └─ PathGuardedTool
13//!        └─ <concrete tool>
14//! ```
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! let tool = RateLimitedTool::new(
20//!     PathGuardedTool::new(ShellTool::new(security.clone(), runtime), security.clone()),
21//!     security.clone(),
22//! );
23//! ```
24
25use async_trait::async_trait;
26use std::sync::Arc;
27use zeroclaw_api::attribution::{Attributable, Role};
28use zeroclaw_api::tool::{Tool, ToolResult};
29use zeroclaw_config::policy::SecurityPolicy;
30
31/// Type alias for a path-extraction closure used by [`PathGuardedTool`].
32type PathExtractor = dyn Fn(&serde_json::Value) -> Option<String> + Send + Sync;
33
34// ── RateLimitedTool ───────────────────────────────────────────────────────────
35
36/// Wraps any [`Tool`] and enforces the [`SecurityPolicy`] rate limit.
37///
38/// Replaces the repeated `is_rate_limited()` / `record_action()` guard blocks
39/// previously inlined in every tool's `execute` method (~30 files, ~50 call
40/// sites).
41///
42/// # Budget semantics
43///
44/// `record_action()` runs **after** the inner tool returns and only when
45/// `ToolResult.success == true`.  This matches the pre-wrapper behaviour: only
46/// calls that actually performed work consumed the action budget.  Validation,
47/// policy, path-allowlist, read-only, and command-validation failures all
48/// surface as `success: false` from the inner tool (or inner wrapper) and do
49/// not consume a slot.
50///
51/// ## Read-tool exception (anti-probing)
52///
53/// `FileReadTool` (`zeroclaw-runtime::tools::file_read`) and `PdfReadTool` in
54/// this crate intentionally call `record_action()` *themselves* on the
55/// post-`PathGuardedTool` `resolve_candidate` / `canonicalize` failure paths.
56/// This prevents an attacker from probing path existence for free: each
57/// attempt — successful or failed — consumes exactly one slot.  The outer
58/// `RateLimitedTool` only records on `success: true`, so the totals stay at
59/// one slot per attempt.  When introducing a new read-style tool, follow the
60/// same pattern.
61pub struct RateLimitedTool<T: Tool> {
62    inner: T,
63    security: Arc<SecurityPolicy>,
64}
65
66impl<T: Tool> RateLimitedTool<T> {
67    pub fn new(inner: T, security: Arc<SecurityPolicy>) -> Self {
68        Self { inner, security }
69    }
70}
71
72impl<T: Tool> Attributable for RateLimitedTool<T> {
73    fn role(&self) -> Role {
74        self.inner.role()
75    }
76    fn alias(&self) -> &str {
77        self.inner.alias()
78    }
79}
80
81#[async_trait]
82impl<T: Tool> Tool for RateLimitedTool<T> {
83    fn name(&self) -> &str {
84        self.inner.name()
85    }
86
87    fn description(&self) -> &str {
88        self.inner.description()
89    }
90
91    fn parameters_schema(&self) -> serde_json::Value {
92        self.inner.parameters_schema()
93    }
94
95    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
96        if self.security.is_rate_limited() {
97            return Ok(ToolResult {
98                success: false,
99                output: String::new(),
100                error: Some("Rate limit exceeded: too many actions in the last hour".into()),
101            });
102        }
103
104        // Delegate first; only record against the budget when the inner tool
105        // actually performed work (ToolResult.success == true).  This preserves
106        // the pre-wrapper semantics where validation/policy failures (forbidden
107        // paths, malformed args, disabled config, read-only blocks, command
108        // validation) did not consume the action budget.
109        let result = self.inner.execute(args).await?;
110
111        if result.success && !self.security.record_action() {
112            return Ok(ToolResult {
113                success: false,
114                output: String::new(),
115                error: Some("Rate limit exceeded: action budget exhausted".into()),
116            });
117        }
118
119        Ok(result)
120    }
121}
122
123// ── PathGuardedTool ───────────────────────────────────────────────────────────
124
125/// Wraps any [`Tool`] and blocks calls whose arguments contain a forbidden path.
126///
127/// Replaces the `forbidden_path_argument()` guard blocks previously inlined in
128/// tools that accept a path-like argument (`shell`, `file_read`, `file_write`,
129/// `file_edit`, `pdf_read`, `content_search`, `glob_search`, `image_info`).
130///
131/// Path extraction is argument-name-driven: the wrapper inspects the `"path"`,
132/// `"command"`, `"pattern"`, and `"query"` fields of the JSON argument object.
133/// Tools whose path argument uses a different field name can pass a custom
134/// extractor at construction via [`PathGuardedTool::with_extractor`].
135pub struct PathGuardedTool<T: Tool> {
136    inner: T,
137    security: Arc<SecurityPolicy>,
138    /// Optional override: extract a path string from the args JSON.
139    extractor: Option<Box<PathExtractor>>,
140}
141
142impl<T: Tool> PathGuardedTool<T> {
143    pub fn new(inner: T, security: Arc<SecurityPolicy>) -> Self {
144        Self {
145            inner,
146            security,
147            extractor: None,
148        }
149    }
150
151    /// Supply a custom path-extraction closure for tools with non-standard arg names.
152    pub fn with_extractor<F>(mut self, f: F) -> Self
153    where
154        F: Fn(&serde_json::Value) -> Option<String> + Send + Sync + 'static,
155    {
156        self.extractor = Some(Box::new(f));
157        self
158    }
159
160    fn extract_path_string(&self, args: &serde_json::Value) -> Option<String> {
161        if let Some(ref f) = self.extractor {
162            return f(args);
163        }
164        // Default: check common argument names used across ZeroClaw tools.
165        for field in &["path", "command", "pattern", "query", "file"] {
166            if let Some(s) = args.get(field).and_then(|v| v.as_str()) {
167                return Some(s.to_string());
168            }
169        }
170        None
171    }
172}
173
174impl<T: Tool> Attributable for PathGuardedTool<T> {
175    fn role(&self) -> Role {
176        self.inner.role()
177    }
178    fn alias(&self) -> &str {
179        self.inner.alias()
180    }
181}
182
183#[async_trait]
184impl<T: Tool> Tool for PathGuardedTool<T> {
185    fn name(&self) -> &str {
186        self.inner.name()
187    }
188
189    fn description(&self) -> &str {
190        self.inner.description()
191    }
192
193    fn parameters_schema(&self) -> serde_json::Value {
194        self.inner.parameters_schema()
195    }
196
197    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
198        if let Some(arg) = self.extract_path_string(&args) {
199            // For shell command arguments, use the full token-aware scanner.
200            // For plain path values (e.g. "path" or custom extractor), fall back
201            // to the direct path check.
202            let blocked = if self.extractor.is_none()
203                && args.get("command").and_then(|v| v.as_str()).is_some()
204            {
205                self.security.forbidden_path_argument(&arg)
206            } else if !self.security.is_path_allowed(&arg) {
207                Some(arg.clone())
208            } else {
209                None
210            };
211
212            if let Some(path) = blocked {
213                return Ok(ToolResult {
214                    success: false,
215                    output: String::new(),
216                    error: Some(format!("Path blocked by security policy: {path}")),
217                });
218            }
219        }
220
221        self.inner.execute(args).await
222    }
223}
224
225// ── Tests ─────────────────────────────────────────────────────────────────────
226
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use zeroclaw_config::autonomy::AutonomyLevel;
    use zeroclaw_config::policy::SecurityPolicy;

    zeroclaw_api::mock_tool_attribution!(CountingTool);

    // ── Fixtures ──────────────────────────────────────────────────────────────

    /// Default policy with the given autonomy, rooted at the system temp dir.
    fn policy(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
        let base = SecurityPolicy {
            autonomy,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        };
        Arc::new(base)
    }

    /// Full-autonomy policy whose hourly budget is a single action.
    fn single_slot_policy() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            max_actions_per_hour: 1,
            ..SecurityPolicy::default()
        })
    }

    /// Number of times the tool behind `counter` has been executed.
    fn hits(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// True when `result` carries an error message containing `needle`.
    fn error_mentions(result: &ToolResult, needle: &str) -> bool {
        matches!(result.error.as_deref(), Some(msg) if msg.contains(needle))
    }

    /// A minimal tool that always succeeds and counts its `execute` calls.
    struct CountingTool {
        calls: Arc<AtomicUsize>,
    }

    impl CountingTool {
        /// Returns the tool together with a handle on its call counter.
        fn new() -> (Self, Arc<AtomicUsize>) {
            let calls = Arc::new(AtomicUsize::new(0));
            let handle = Arc::clone(&calls);
            (CountingTool { calls }, handle)
        }
    }

    #[async_trait]
    impl Tool for CountingTool {
        fn name(&self) -> &str {
            "counting"
        }
        fn description(&self) -> &str {
            "counts calls"
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({})
        }
        async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(ToolResult {
                success: true,
                output: "ok".into(),
                error: None,
            })
        }
    }

    /// A tool that always reports failure, standing in for e.g. a validation
    /// error raised by a real tool.
    struct AlwaysFails;

    impl ::zeroclaw_api::attribution::Attributable for AlwaysFails {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Tool(::zeroclaw_api::attribution::ToolKind::Plugin)
        }
        fn alias(&self) -> &str {
            <Self as Tool>::name(self)
        }
    }

    #[async_trait]
    impl Tool for AlwaysFails {
        fn name(&self) -> &str {
            "always_fails"
        }
        fn description(&self) -> &str {
            ""
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({})
        }
        async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
            Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("validation failed".into()),
            })
        }
    }

    // ── RateLimitedTool ───────────────────────────────────────────────────────

    #[tokio::test]
    async fn rate_limited_allows_call_within_budget() {
        let (probe, counter) = CountingTool::new();
        let wrapped = RateLimitedTool::new(probe, policy(AutonomyLevel::Full));

        let outcome = wrapped
            .execute(serde_json::json!({}))
            .await
            .expect("should succeed");

        assert!(outcome.success);
        assert_eq!(hits(&counter), 1);
    }

    #[tokio::test]
    async fn rate_limited_delegates_name_and_schema() {
        let (probe, _) = CountingTool::new();
        let wrapped = RateLimitedTool::new(probe, policy(AutonomyLevel::Full));

        assert_eq!(wrapped.name(), "counting");
        assert_eq!(wrapped.description(), "counts calls");
        assert!(wrapped.parameters_schema().is_object());
    }

    #[tokio::test]
    async fn rate_limited_blocks_when_exhausted() {
        // One action per window: the first call spends it, the second is refused.
        let (probe, counter) = CountingTool::new();
        let wrapped = RateLimitedTool::new(probe, single_slot_policy());

        let first = wrapped.execute(serde_json::json!({})).await.unwrap();
        assert!(first.success, "first call should succeed");

        let second = wrapped.execute(serde_json::json!({})).await.unwrap();
        assert!(!second.success, "second call should be rate-limited");
        assert!(error_mentions(&second, "Rate limit exceeded"));
        // The refused attempt must never reach the inner tool.
        assert_eq!(hits(&counter), 1);
    }

    #[tokio::test]
    async fn rate_limited_does_not_consume_budget_on_failure() {
        // Failed inner calls must not trigger record_action(), so the single
        // slot is still there for a later successful call.
        let shared = single_slot_policy();
        let failing = RateLimitedTool::new(AlwaysFails, shared.clone());

        // Three failures in a row, none of which may spend the slot.
        for _ in 0..3 {
            let outcome = failing.execute(serde_json::json!({})).await.unwrap();
            assert!(!outcome.success);
            assert!(error_mentions(&outcome, "validation failed"));
        }

        // A succeeding tool wrapped against the same policy still gets through.
        let (probe, counter) = CountingTool::new();
        let succeeding = RateLimitedTool::new(probe, shared);
        let outcome = succeeding.execute(serde_json::json!({})).await.unwrap();
        assert!(outcome.success);
        assert_eq!(hits(&counter), 1);
    }

    // ── PathGuardedTool ───────────────────────────────────────────────────────

    #[tokio::test]
    async fn path_guard_allows_safe_path() {
        let (probe, counter) = CountingTool::new();
        let guarded = PathGuardedTool::new(probe, policy(AutonomyLevel::Full));

        let outcome = guarded
            .execute(serde_json::json!({"path": "src/main.rs"}))
            .await
            .unwrap();

        assert!(outcome.success);
        assert_eq!(hits(&counter), 1);
    }

    #[tokio::test]
    async fn path_guard_blocks_forbidden_path() {
        let (probe, counter) = CountingTool::new();
        let guarded = PathGuardedTool::new(probe, policy(AutonomyLevel::Full));

        let outcome = guarded
            .execute(serde_json::json!({"command": "cat /etc/passwd"}))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(error_mentions(&outcome, "Path blocked"));
        assert_eq!(hits(&counter), 0, "inner must not be called");
    }

    #[tokio::test]
    async fn path_guard_no_path_arg_passes_through() {
        let (probe, counter) = CountingTool::new();
        let guarded = PathGuardedTool::new(probe, policy(AutonomyLevel::Full));

        // None of the recognised field names is present, so nothing is checked.
        let outcome = guarded
            .execute(serde_json::json!({"value": "hello"}))
            .await
            .unwrap();

        assert!(outcome.success);
        assert_eq!(hits(&counter), 1);
    }

    #[tokio::test]
    async fn path_guard_custom_extractor() {
        let (probe, counter) = CountingTool::new();
        let read_target = |args: &serde_json::Value| {
            args.get("target")
                .and_then(|v| v.as_str())
                .map(String::from)
        };
        let guarded =
            PathGuardedTool::new(probe, policy(AutonomyLevel::Full)).with_extractor(read_target);

        let outcome = guarded
            .execute(serde_json::json!({"target": "/etc/shadow"}))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(error_mentions(&outcome, "Path blocked"));
        assert_eq!(hits(&counter), 0);
    }

    // ── Composition ───────────────────────────────────────────────────────────

    #[tokio::test]
    async fn composed_wrappers_both_enforce() {
        // RateLimited(PathGuarded(CountingTool)): the path check sits inside
        // the rate limiter, and a forbidden path must still be refused.
        let shared = policy(AutonomyLevel::Full);
        let (probe, counter) = CountingTool::new();
        let stacked = RateLimitedTool::new(PathGuardedTool::new(probe, shared.clone()), shared);

        let refused = stacked
            .execute(serde_json::json!({"path": "/etc/passwd"}))
            .await
            .unwrap();

        assert!(!refused.success);
        assert_eq!(hits(&counter), 0);
    }

    #[tokio::test]
    async fn composed_wrappers_path_block_preserves_budget() {
        // RateLimited(PathGuarded(CountingTool)): a call refused by the path
        // guard must leave the single slot free for the next allowed call.
        let shared = single_slot_policy();
        let (probe, counter) = CountingTool::new();
        let stacked = RateLimitedTool::new(PathGuardedTool::new(probe, shared.clone()), shared);

        let refused = stacked
            .execute(serde_json::json!({"path": "/etc/passwd"}))
            .await
            .unwrap();
        assert!(!refused.success);
        assert_eq!(hits(&counter), 0);

        // The slot was not spent, so an allowed call still runs.
        let allowed = stacked
            .execute(serde_json::json!({"path": "src/main.rs"}))
            .await
            .unwrap();
        assert!(allowed.success, "budget should still have a slot");
        assert_eq!(hits(&counter), 1);
    }
}