1use async_trait::async_trait;
26use std::sync::Arc;
27use zeroclaw_api::attribution::{Attributable, Role};
28use zeroclaw_api::tool::{Tool, ToolResult};
29use zeroclaw_config::policy::SecurityPolicy;
30
31type PathExtractor = dyn Fn(&serde_json::Value) -> Option<String> + Send + Sync;
33
34pub struct RateLimitedTool<T: Tool> {
62 inner: T,
63 security: Arc<SecurityPolicy>,
64}
65
66impl<T: Tool> RateLimitedTool<T> {
67 pub fn new(inner: T, security: Arc<SecurityPolicy>) -> Self {
68 Self { inner, security }
69 }
70}
71
72impl<T: Tool> Attributable for RateLimitedTool<T> {
73 fn role(&self) -> Role {
74 self.inner.role()
75 }
76 fn alias(&self) -> &str {
77 self.inner.alias()
78 }
79}
80
81#[async_trait]
82impl<T: Tool> Tool for RateLimitedTool<T> {
83 fn name(&self) -> &str {
84 self.inner.name()
85 }
86
87 fn description(&self) -> &str {
88 self.inner.description()
89 }
90
91 fn parameters_schema(&self) -> serde_json::Value {
92 self.inner.parameters_schema()
93 }
94
95 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
96 if self.security.is_rate_limited() {
97 return Ok(ToolResult {
98 success: false,
99 output: String::new(),
100 error: Some("Rate limit exceeded: too many actions in the last hour".into()),
101 });
102 }
103
104 let result = self.inner.execute(args).await?;
110
111 if result.success && !self.security.record_action() {
112 return Ok(ToolResult {
113 success: false,
114 output: String::new(),
115 error: Some("Rate limit exceeded: action budget exhausted".into()),
116 });
117 }
118
119 Ok(result)
120 }
121}
122
123pub struct PathGuardedTool<T: Tool> {
136 inner: T,
137 security: Arc<SecurityPolicy>,
138 extractor: Option<Box<PathExtractor>>,
140}
141
142impl<T: Tool> PathGuardedTool<T> {
143 pub fn new(inner: T, security: Arc<SecurityPolicy>) -> Self {
144 Self {
145 inner,
146 security,
147 extractor: None,
148 }
149 }
150
151 pub fn with_extractor<F>(mut self, f: F) -> Self
153 where
154 F: Fn(&serde_json::Value) -> Option<String> + Send + Sync + 'static,
155 {
156 self.extractor = Some(Box::new(f));
157 self
158 }
159
160 fn extract_path_string(&self, args: &serde_json::Value) -> Option<String> {
161 if let Some(ref f) = self.extractor {
162 return f(args);
163 }
164 for field in &["path", "command", "pattern", "query", "file"] {
166 if let Some(s) = args.get(field).and_then(|v| v.as_str()) {
167 return Some(s.to_string());
168 }
169 }
170 None
171 }
172}
173
174impl<T: Tool> Attributable for PathGuardedTool<T> {
175 fn role(&self) -> Role {
176 self.inner.role()
177 }
178 fn alias(&self) -> &str {
179 self.inner.alias()
180 }
181}
182
183#[async_trait]
184impl<T: Tool> Tool for PathGuardedTool<T> {
185 fn name(&self) -> &str {
186 self.inner.name()
187 }
188
189 fn description(&self) -> &str {
190 self.inner.description()
191 }
192
193 fn parameters_schema(&self) -> serde_json::Value {
194 self.inner.parameters_schema()
195 }
196
197 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
198 if let Some(arg) = self.extract_path_string(&args) {
199 let blocked = if self.extractor.is_none()
203 && args.get("command").and_then(|v| v.as_str()).is_some()
204 {
205 self.security.forbidden_path_argument(&arg)
206 } else if !self.security.is_path_allowed(&arg) {
207 Some(arg.clone())
208 } else {
209 None
210 };
211
212 if let Some(path) = blocked {
213 return Ok(ToolResult {
214 success: false,
215 output: String::new(),
216 error: Some(format!("Path blocked by security policy: {path}")),
217 });
218 }
219 }
220
221 self.inner.execute(args).await
222 }
223}
224
// Unit tests for `RateLimitedTool` and `PathGuardedTool`. Each test builds its
// own `SecurityPolicy` rooted at the OS temp dir. It then reads
// `CountingTool`'s shared counter to tell whether the wrapped tool actually ran.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use zeroclaw_config::autonomy::AutonomyLevel;
    use zeroclaw_config::policy::SecurityPolicy;

    zeroclaw_api::mock_tool_attribution!(CountingTool);

    /// Builds a shared policy with the given autonomy level and the OS temp dir
    /// as workspace. All other fields come from `SecurityPolicy::default()`.
    fn policy(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    /// Stub tool that always succeeds and counts its `execute` calls.
    struct CountingTool {
        // Shared with the test so the call count can be read afterwards.
        calls: Arc<AtomicUsize>,
    }

    impl CountingTool {
        /// Returns the tool plus a handle to its call counter, which starts at 0.
        fn new() -> (Self, Arc<AtomicUsize>) {
            let counter = Arc::new(AtomicUsize::new(0));
            (
                CountingTool {
                    calls: counter.clone(),
                },
                counter,
            )
        }
    }

    #[async_trait]
    impl Tool for CountingTool {
        fn name(&self) -> &str {
            "counting"
        }
        fn description(&self) -> &str {
            "counts calls"
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({})
        }
        async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(ToolResult {
                success: true,
                output: "ok".into(),
                error: None,
            })
        }
    }

    // Under the default policy budget, a single call succeeds and runs the
    // inner tool exactly once.
    #[tokio::test]
    async fn rate_limited_allows_call_within_budget() {
        let (inner, counter) = CountingTool::new();
        let tool = RateLimitedTool::new(inner, policy(AutonomyLevel::Full));
        let result = tool
            .execute(serde_json::json!({}))
            .await
            .expect("should succeed");
        assert!(result.success);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // Metadata is forwarded unchanged from the wrapped tool.
    #[tokio::test]
    async fn rate_limited_delegates_name_and_schema() {
        let (inner, _) = CountingTool::new();
        let tool = RateLimitedTool::new(inner, policy(AutonomyLevel::Full));
        assert_eq!(tool.name(), "counting");
        assert_eq!(tool.description(), "counts calls");
        assert!(tool.parameters_schema().is_object());
    }

    // With a budget of one action, the second call is refused without running
    // the inner tool: the counter stays at 1.
    #[tokio::test]
    async fn rate_limited_blocks_when_exhausted() {
        let sec = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            max_actions_per_hour: 1,
            ..SecurityPolicy::default()
        });
        let (inner, counter) = CountingTool::new();
        let tool = RateLimitedTool::new(inner, sec);

        let r1 = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(r1.success, "first call should succeed");

        let r2 = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(!r2.success, "second call should be rate-limited");
        assert!(r2.error.unwrap().contains("Rate limit exceeded"));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // A relative path is expected to be allowed by this policy (workspace =
    // temp dir), so the call reaches the inner tool.
    #[tokio::test]
    async fn path_guard_allows_safe_path() {
        let (inner, counter) = CountingTool::new();
        let tool = PathGuardedTool::new(inner, policy(AutonomyLevel::Full));
        let result = tool
            .execute(serde_json::json!({"path": "src/main.rs"}))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // A `command` argument is scanned for forbidden paths. The test expects the
    // default policy to reject `/etc/passwd`, and the inner tool must not run.
    #[tokio::test]
    async fn path_guard_blocks_forbidden_path() {
        let (inner, counter) = CountingTool::new();
        let tool = PathGuardedTool::new(inner, policy(AutonomyLevel::Full));
        let result = tool
            .execute(serde_json::json!({"command": "cat /etc/passwd"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Path blocked"));
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "inner must not be called"
        );
    }

    // Arguments without any of the recognised fields skip validation entirely.
    #[tokio::test]
    async fn path_guard_no_path_arg_passes_through() {
        let (inner, counter) = CountingTool::new();
        let tool = PathGuardedTool::new(inner, policy(AutonomyLevel::Full));
        let result = tool
            .execute(serde_json::json!({"value": "hello"}))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // A custom extractor chooses a non-default field (`target`). Its value is
    // checked with the path rules, and `/etc/shadow` is expected to be rejected.
    #[tokio::test]
    async fn path_guard_custom_extractor() {
        let (inner, counter) = CountingTool::new();
        let tool =
            PathGuardedTool::new(inner, policy(AutonomyLevel::Full)).with_extractor(|args| {
                args.get("target")
                    .and_then(|v| v.as_str())
                    .map(String::from)
            });
        let result = tool
            .execute(serde_json::json!({"target": "/etc/shadow"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Path blocked"));
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    // Path guard nested inside the rate limiter, both sharing one policy. A
    // blocked path never reaches the inner tool.
    #[tokio::test]
    async fn composed_wrappers_both_enforce() {
        let sec = policy(AutonomyLevel::Full);
        let (inner, counter) = CountingTool::new();
        let tool = RateLimitedTool::new(PathGuardedTool::new(inner, sec.clone()), sec);

        let blocked = tool
            .execute(serde_json::json!({"path": "/etc/passwd"}))
            .await
            .unwrap();
        assert!(!blocked.success);
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    // Unsuccessful results must not consume budget. Three failing calls against
    // a one-action budget still leave room for a successful call, made through
    // a second wrapper that shares the same policy.
    #[tokio::test]
    async fn rate_limited_does_not_consume_budget_on_failure() {
        // Stub tool whose every call returns an unsuccessful result.
        struct AlwaysFails;
        // Manual `Attributable` impl for the local stub: reports a plugin-tool
        // role and uses the tool name as its alias.
        impl ::zeroclaw_api::attribution::Attributable for AlwaysFails {
            fn role(&self) -> ::zeroclaw_api::attribution::Role {
                ::zeroclaw_api::attribution::Role::Tool(
                    ::zeroclaw_api::attribution::ToolKind::Plugin,
                )
            }
            fn alias(&self) -> &str {
                <Self as Tool>::name(self)
            }
        }
        #[async_trait]
        impl Tool for AlwaysFails {
            fn name(&self) -> &str {
                "always_fails"
            }
            fn description(&self) -> &str {
                ""
            }
            fn parameters_schema(&self) -> serde_json::Value {
                serde_json::json!({})
            }
            async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
                Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some("validation failed".into()),
                })
            }
        }

        let sec = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            max_actions_per_hour: 1,
            ..SecurityPolicy::default()
        });
        let failing = RateLimitedTool::new(AlwaysFails, sec.clone());

        // The inner failure message passes through, not a rate-limit error.
        for _ in 0..3 {
            let r = failing.execute(serde_json::json!({})).await.unwrap();
            assert!(!r.success);
            assert!(r.error.unwrap().contains("validation failed"));
        }

        // The single budget slot must still be available.
        let (success_inner, counter) = CountingTool::new();
        let succeeding = RateLimitedTool::new(success_inner, sec);
        let r = succeeding.execute(serde_json::json!({})).await.unwrap();
        assert!(r.success);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // A path rejected by the inner guard is an unsuccessful result, so the
    // outer rate limiter records nothing. The single budget slot therefore
    // remains for the next, allowed call.
    #[tokio::test]
    async fn composed_wrappers_path_block_preserves_budget() {
        let sec = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            max_actions_per_hour: 1,
            ..SecurityPolicy::default()
        });
        let (inner, counter) = CountingTool::new();
        let tool = RateLimitedTool::new(PathGuardedTool::new(inner, sec.clone()), sec);

        let blocked = tool
            .execute(serde_json::json!({"path": "/etc/passwd"}))
            .await
            .unwrap();
        assert!(!blocked.success);
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let allowed = tool
            .execute(serde_json::json!({"path": "src/main.rs"}))
            .await
            .unwrap();
        assert!(allowed.success, "budget should still have a slot");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}