Skip to main content

zeroclaw_tools/
memory_store.rs

1use async_trait::async_trait;
2use serde_json::json;
3use std::sync::Arc;
4use zeroclaw_api::tool::{Tool, ToolResult};
5use zeroclaw_config::policy::SecurityPolicy;
6use zeroclaw_config::policy::ToolOperation;
7use zeroclaw_memory::{Memory, MemoryCategory};
8
9/// Let the agent store memories — its own brain writes
10pub struct MemoryStoreTool {
11    memory: Arc<dyn Memory>,
12    security: Arc<SecurityPolicy>,
13}
14
15impl MemoryStoreTool {
16    pub fn new(memory: Arc<dyn Memory>, security: Arc<SecurityPolicy>) -> Self {
17        Self { memory, security }
18    }
19}
20
21#[async_trait]
22impl Tool for MemoryStoreTool {
23    fn name(&self) -> &str {
24        "memory_store"
25    }
26
27    fn description(&self) -> &str {
28        "Store a fact, preference, or note in long-term memory. Use category 'core' for permanent facts, 'daily' for session notes, 'conversation' for chat context, or a custom category name."
29    }
30
31    fn parameters_schema(&self) -> serde_json::Value {
32        json!({
33            "type": "object",
34            "properties": {
35                "key": {
36                    "type": "string",
37                    "description": "Unique key for this memory (e.g. 'user_lang', 'project_stack')"
38                },
39                "content": {
40                    "type": "string",
41                    "description": "The information to remember"
42                },
43                "category": {
44                    "type": "string",
45                    "description": "Memory category: 'core' (permanent), 'daily' (session), 'conversation' (chat), or a custom category name. Defaults to 'core'."
46                }
47            },
48            "required": ["key", "content"]
49        })
50    }
51
52    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
53        let key = args.get("key").and_then(|v| v.as_str()).ok_or_else(|| {
54            ::zeroclaw_log::record!(
55                WARN,
56                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
57                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
58                    .with_attrs(::serde_json::json!({"param": "key"})),
59                "memory_store: missing key parameter"
60            );
61            anyhow::Error::msg("Missing 'key' parameter")
62        })?;
63
64        let content = args
65            .get("content")
66            .and_then(|v| v.as_str())
67            .ok_or_else(|| {
68                ::zeroclaw_log::record!(
69                    WARN,
70                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
71                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
72                        .with_attrs(::serde_json::json!({"param": "content"})),
73                    "memory_store: missing content parameter"
74                );
75                anyhow::Error::msg("Missing 'content' parameter")
76            })?;
77
78        let category = match args.get("category").and_then(|v| v.as_str()) {
79            Some("core") | None => MemoryCategory::Core,
80            Some("daily") => MemoryCategory::Daily,
81            Some("conversation") => MemoryCategory::Conversation,
82            Some(other) => MemoryCategory::Custom(other.to_string()),
83        };
84
85        if let Err(error) = self
86            .security
87            .enforce_tool_operation(ToolOperation::Act, "memory_store")
88        {
89            return Ok(ToolResult {
90                success: false,
91                output: String::new(),
92                error: Some(error),
93            });
94        }
95
96        match self.memory.store(key, content, category, None).await {
97            Ok(()) => Ok(ToolResult {
98                success: true,
99                output: format!("Stored memory: {key}"),
100                error: None,
101            }),
102            Err(e) => Ok(ToolResult {
103                success: false,
104                output: String::new(),
105                error: Some(format!("Failed to store memory: {e}")),
106            }),
107        }
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use zeroclaw_config::autonomy::AutonomyLevel;
    use zeroclaw_config::policy::SecurityPolicy;
    use zeroclaw_memory::SqliteMemory;

    /// Default security policy; the happy-path tests below rely on it
    /// permitting `memory_store` writes.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy::default())
    }

    /// Fresh SQLite-backed memory rooted in a temporary directory.
    ///
    /// Keep the returned `TempDir` guard alive for the whole test: dropping it
    /// removes the directory that backs the database.
    fn test_mem() -> (TempDir, Arc<dyn Memory>) {
        let tmp = TempDir::new().unwrap();
        let mem = SqliteMemory::new("test", tmp.path()).unwrap();
        (tmp, Arc::new(mem))
    }

    #[test]
    fn name_and_schema() {
        let (_tmp, mem) = test_mem();
        let tool = MemoryStoreTool::new(mem, test_security());
        assert_eq!(tool.name(), "memory_store");
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["key"].is_object());
        assert!(schema["properties"]["content"].is_object());
        assert!(schema["properties"]["category"].is_object());
        // `category` is optional; only key and content are mandatory.
        assert_eq!(schema["required"], json!(["key", "content"]));
    }

    #[tokio::test]
    async fn store_core() {
        let (_tmp, mem) = test_mem();
        let tool = MemoryStoreTool::new(mem.clone(), test_security());
        let result = tool
            .execute(json!({"key": "lang", "content": "Prefers Rust"}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("lang"));

        // Omitting `category` must fall back to the documented default, 'core'.
        let entry = mem.get("lang").await.unwrap().expect("entry should be stored");
        assert_eq!(entry.content, "Prefers Rust");
        assert_eq!(entry.category, MemoryCategory::Core);
    }

    #[tokio::test]
    async fn store_with_category() {
        let (_tmp, mem) = test_mem();
        let tool = MemoryStoreTool::new(mem.clone(), test_security());
        let result = tool
            .execute(json!({"key": "note", "content": "Fixed bug", "category": "daily"}))
            .await
            .unwrap();
        assert!(result.success);

        let entry = mem.get("note").await.unwrap().expect("entry should be stored");
        assert_eq!(entry.content, "Fixed bug");
        assert_eq!(entry.category, MemoryCategory::Daily);
    }

    #[tokio::test]
    async fn store_with_conversation_category() {
        let (_tmp, mem) = test_mem();
        let tool = MemoryStoreTool::new(mem.clone(), test_security());
        let result = tool
            .execute(json!({
                "key": "ctx",
                "content": "Discussing deploys",
                "category": "conversation"
            }))
            .await
            .unwrap();
        assert!(result.success);

        let entry = mem.get("ctx").await.unwrap().expect("entry should be stored");
        assert_eq!(entry.content, "Discussing deploys");
        assert_eq!(entry.category, MemoryCategory::Conversation);
    }

    #[tokio::test]
    async fn store_with_custom_category() {
        let (_tmp, mem) = test_mem();
        let tool = MemoryStoreTool::new(mem.clone(), test_security());
        let result = tool
            .execute(
                json!({"key": "proj_note", "content": "Uses async runtime", "category": "project"}),
            )
            .await
            .unwrap();
        assert!(result.success);

        let entry = mem.get("proj_note").await.unwrap().expect("entry should be stored");
        assert_eq!(entry.content, "Uses async runtime");
        assert_eq!(entry.category, MemoryCategory::Custom("project".into()));
    }

    #[tokio::test]
    async fn store_missing_key() {
        let (_tmp, mem) = test_mem();
        let tool = MemoryStoreTool::new(mem, test_security());
        let result = tool.execute(json!({"content": "no key"})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn store_missing_content() {
        let (_tmp, mem) = test_mem();
        let tool = MemoryStoreTool::new(mem, test_security());
        let result = tool.execute(json!({"key": "no_content"})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn store_blocked_in_readonly_mode() {
        let (_tmp, mem) = test_mem();
        let readonly = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            ..SecurityPolicy::default()
        });
        let tool = MemoryStoreTool::new(mem.clone(), readonly);
        let result = tool
            .execute(json!({"key": "lang", "content": "Prefers Rust"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap_or("")
                .contains("read-only mode")
        );
        // A denied call must not leave anything behind in memory.
        assert!(mem.get("lang").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn store_blocked_when_rate_limited() {
        let (_tmp, mem) = test_mem();
        let limited = Arc::new(SecurityPolicy {
            max_actions_per_hour: 0,
            ..SecurityPolicy::default()
        });
        let tool = MemoryStoreTool::new(mem.clone(), limited);
        let result = tool
            .execute(json!({"key": "lang", "content": "Prefers Rust"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap_or("")
                .contains("Rate limit exceeded")
        );
        // A denied call must not leave anything behind in memory.
        assert!(mem.get("lang").await.unwrap().is_none());
    }
}