Skip to main content

zeroclaw_tools/
memory_forget.rs

1use async_trait::async_trait;
2use serde_json::json;
3use std::sync::Arc;
4use zeroclaw_api::tool::{Tool, ToolResult};
5use zeroclaw_config::policy::SecurityPolicy;
6use zeroclaw_config::policy::ToolOperation;
7use zeroclaw_memory::Memory;
8
9/// Let the agent forget/delete a memory entry
10pub struct MemoryForgetTool {
11    memory: Arc<dyn Memory>,
12    security: Arc<SecurityPolicy>,
13}
14
15impl MemoryForgetTool {
16    pub fn new(memory: Arc<dyn Memory>, security: Arc<SecurityPolicy>) -> Self {
17        Self { memory, security }
18    }
19}
20
21#[async_trait]
22impl Tool for MemoryForgetTool {
23    fn name(&self) -> &str {
24        "memory_forget"
25    }
26
27    fn description(&self) -> &str {
28        "Remove a memory by key. Use to delete outdated facts or sensitive data. Returns whether the memory was found and removed."
29    }
30
31    fn parameters_schema(&self) -> serde_json::Value {
32        json!({
33            "type": "object",
34            "properties": {
35                "key": {
36                    "type": "string",
37                    "description": "The key of the memory to forget"
38                }
39            },
40            "required": ["key"]
41        })
42    }
43
44    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
45        let key = args.get("key").and_then(|v| v.as_str()).ok_or_else(|| {
46            ::zeroclaw_log::record!(
47                WARN,
48                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
49                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
50                    .with_attrs(::serde_json::json!({"param": "key"})),
51                "memory_forget: missing key parameter"
52            );
53            anyhow::Error::msg("Missing 'key' parameter")
54        })?;
55
56        if let Err(error) = self
57            .security
58            .enforce_tool_operation(ToolOperation::Act, "memory_forget")
59        {
60            return Ok(ToolResult {
61                success: false,
62                output: String::new(),
63                error: Some(error),
64            });
65        }
66
67        match self.memory.forget(key).await {
68            Ok(true) => Ok(ToolResult {
69                success: true,
70                output: format!("Forgot memory: {key}"),
71                error: None,
72            }),
73            Ok(false) => Ok(ToolResult {
74                success: true,
75                output: format!("No memory found with key: {key}"),
76                error: None,
77            }),
78            Err(e) => Ok(ToolResult {
79                success: false,
80                output: String::new(),
81                error: Some(format!("Failed to forget memory: {e}")),
82            }),
83        }
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use zeroclaw_config::autonomy::AutonomyLevel;
    use zeroclaw_config::policy::SecurityPolicy;
    use zeroclaw_memory::{MemoryCategory, SqliteMemory};

    /// Default security policy behind an `Arc`.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy::default())
    }

    /// Fresh SQLite-backed memory in a temporary directory. The `TempDir`
    /// is returned so the caller keeps the directory alive.
    fn test_mem() -> (TempDir, Arc<dyn Memory>) {
        let dir = TempDir::new().unwrap();
        let store: Arc<dyn Memory> = Arc::new(SqliteMemory::new("test", dir.path()).unwrap());
        (dir, store)
    }

    /// Like `test_mem`, but with the entry `temp` already stored.
    async fn seeded_mem() -> (TempDir, Arc<dyn Memory>) {
        let (dir, store) = test_mem();
        store
            .store("temp", "temporary", MemoryCategory::Conversation, None)
            .await
            .unwrap();
        (dir, store)
    }

    /// Asserts that `result` is a failure whose error text contains `needle`.
    #[track_caller]
    fn assert_refused(result: &ToolResult, needle: &str) {
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .is_some_and(|message| message.contains(needle))
        );
    }

    #[test]
    fn name_and_schema() {
        let (_dir, store) = test_mem();
        let tool = MemoryForgetTool::new(store, test_security());
        assert_eq!(tool.name(), "memory_forget");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["key"].is_object());
    }

    #[tokio::test]
    async fn forget_existing() {
        let (_dir, store) = seeded_mem().await;

        let tool = MemoryForgetTool::new(Arc::clone(&store), test_security());
        let result = tool.execute(json!({"key": "temp"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Forgot"));

        // The entry must be gone from the backend afterwards.
        assert!(store.get("temp").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn forget_nonexistent() {
        let (_dir, store) = test_mem();
        let tool = MemoryForgetTool::new(store, test_security());
        let result = tool.execute(json!({"key": "nope"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No memory found"));
    }

    #[tokio::test]
    async fn forget_missing_key() {
        let (_dir, store) = test_mem();
        let tool = MemoryForgetTool::new(store, test_security());
        let outcome = tool.execute(json!({})).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn forget_blocked_in_readonly_mode() {
        let (_dir, store) = seeded_mem().await;
        let readonly = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            ..SecurityPolicy::default()
        });
        let tool = MemoryForgetTool::new(Arc::clone(&store), readonly);
        let result = tool.execute(json!({"key": "temp"})).await.unwrap();
        assert_refused(&result, "read-only mode");
        // A refused deletion must leave the entry in place.
        assert!(store.get("temp").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn forget_blocked_when_rate_limited() {
        let (_dir, store) = seeded_mem().await;
        let limited = Arc::new(SecurityPolicy {
            max_actions_per_hour: 0,
            ..SecurityPolicy::default()
        });
        let tool = MemoryForgetTool::new(Arc::clone(&store), limited);
        let result = tool.execute(json!({"key": "temp"})).await.unwrap();
        assert_refused(&result, "Rate limit exceeded");
        // A refused deletion must leave the entry in place.
        assert!(store.get("temp").await.unwrap().is_some());
    }
}