Skip to main content

zeroclaw_tools/
memory_recall.rs

1use async_trait::async_trait;
2use serde_json::json;
3use std::fmt::Write;
4use std::sync::Arc;
5use zeroclaw_api::tool::{Tool, ToolResult};
6use zeroclaw_memory::Memory;
7
8/// Let the agent search its own memory
9pub struct MemoryRecallTool {
10    memory: Arc<dyn Memory>,
11}
12
13impl MemoryRecallTool {
14    pub fn new(memory: Arc<dyn Memory>) -> Self {
15        Self { memory }
16    }
17}
18
19#[async_trait]
20impl Tool for MemoryRecallTool {
21    fn name(&self) -> &str {
22        "memory_recall"
23    }
24
25    fn description(&self) -> &str {
26        "Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance. Supports keyword search, recent recall with omitted query or bare '*', time-only query (since/until), or both."
27    }
28
29    fn parameters_schema(&self) -> serde_json::Value {
30        json!({
31            "type": "object",
32            "properties": {
33                "query": {
34                    "type": "string",
35                    "description": "Keywords or phrase to search for in memory. Omit or pass bare '*' to return recent memories; non-bare wildcard terms remain keyword searches."
36                },
37                "limit": {
38                    "type": "integer",
39                    "description": "Max results to return (default: 5)"
40                },
41                "since": {
42                    "type": "string",
43                    "description": "Filter memories created at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
44                },
45                "until": {
46                    "type": "string",
47                    "description": "Filter memories created at or before this time (RFC 3339)"
48                },
49                "search_mode": {
50                    "type": "string",
51                    "enum": ["bm25", "embedding", "hybrid"],
52                    "description": "Search strategy: bm25 (keyword), embedding (semantic), or hybrid (both). Defaults to config value."
53                }
54            }
55        })
56    }
57
58    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
59        let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
60        let since = args.get("since").and_then(|v| v.as_str());
61        let until = args.get("until").and_then(|v| v.as_str());
62
63        // Validate date strings
64        if let Some(s) = since
65            && chrono::DateTime::parse_from_rfc3339(s).is_err()
66        {
67            return Ok(ToolResult {
68                success: false,
69                output: String::new(),
70                error: Some(format!(
71                    "Invalid 'since' date: {s}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
72                )),
73            });
74        }
75        if let Some(u) = until
76            && chrono::DateTime::parse_from_rfc3339(u).is_err()
77        {
78            return Ok(ToolResult {
79                success: false,
80                output: String::new(),
81                error: Some(format!(
82                    "Invalid 'until' date: {u}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
83                )),
84            });
85        }
86        if let (Some(s), Some(u)) = (since, until)
87            && let (Ok(s_dt), Ok(u_dt)) = (
88                chrono::DateTime::parse_from_rfc3339(s),
89                chrono::DateTime::parse_from_rfc3339(u),
90            )
91            && s_dt >= u_dt
92        {
93            return Ok(ToolResult {
94                success: false,
95                output: String::new(),
96                error: Some("'since' must be before 'until'".into()),
97            });
98        }
99
100        #[allow(clippy::cast_possible_truncation)]
101        let limit = args
102            .get("limit")
103            .and_then(serde_json::Value::as_u64)
104            .map_or(5, |v| v as usize);
105
106        match self.memory.recall(query, limit, None, since, until).await {
107            Ok(entries) if entries.is_empty() => Ok(ToolResult {
108                success: true,
109                output: "No memories found.".into(),
110                error: None,
111            }),
112            Ok(entries) => {
113                let mut output = format!("Found {} memories:\n", entries.len());
114                for entry in &entries {
115                    let score = entry
116                        .score
117                        .map_or_else(String::new, |s| format!(" [{:.0}%]", s * 100.0));
118                    let _ = writeln!(
119                        output,
120                        "- [{}] {}: {}{score}",
121                        entry.category, entry.key, entry.content
122                    );
123                }
124                Ok(ToolResult {
125                    success: true,
126                    output,
127                    error: None,
128                })
129            }
130            Err(e) => Ok(ToolResult {
131                success: false,
132                output: String::new(),
133                error: Some(format!("Memory recall failed: {e}")),
134            }),
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use tempfile::TempDir;
    use zeroclaw_memory::{MemoryCategory, MemoryEntry, SqliteMemory, is_recent_recall_query};

    /// Fresh, empty SQLite-backed memory in a temp dir (the dir guard must outlive the memory).
    fn seeded_mem() -> (TempDir, Arc<dyn Memory>) {
        let tmp = TempDir::new().unwrap();
        let mem = SqliteMemory::new("test", tmp.path()).unwrap();
        (tmp, Arc::new(mem))
    }

    /// Test double that records the last query passed to `recall`.
    ///
    /// For recent-recall queries it returns one canned entry. When `scored` is set,
    /// that entry carries a score of 0.63, so score rendering can be checked through
    /// the real tool path.
    struct QueryEchoMemory {
        last_query: Arc<Mutex<Option<String>>>,
        scored: bool,
    }

    #[async_trait]
    impl Memory for QueryEchoMemory {
        fn name(&self) -> &str {
            "query_echo"
        }

        async fn store(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn recall(
            &self,
            query: &str,
            _limit: usize,
            _session_id: Option<&str>,
            _since: Option<&str>,
            _until: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            *self.last_query.lock().unwrap() = Some(query.to_string());
            if is_recent_recall_query(query) {
                Ok(vec![MemoryEntry {
                    id: "recent".into(),
                    key: "recent".into(),
                    content: "recent memory".into(),
                    category: MemoryCategory::Core,
                    timestamp: "2026-05-03T00:00:00Z".into(),
                    session_id: None,
                    // Literal lets the compiler pick the field's float type.
                    score: self.scored.then_some(0.63),
                    namespace: "default".into(),
                    importance: None,
                    superseded_by: None,
                    agent_alias: None,
                    agent_id: None,
                }])
            } else {
                Ok(Vec::new())
            }
        }

        async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }

        async fn list(
            &self,
            _category: Option<&MemoryCategory>,
            _session_id: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
            Ok(false)
        }

        async fn forget_for_agent(&self, _key: &str, _agent_id: &str) -> anyhow::Result<bool> {
            Ok(false)
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(0)
        }

        async fn health_check(&self) -> bool {
            true
        }

        async fn store_with_agent(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
            _namespace: Option<&str>,
            _importance: Option<f64>,
            _agent_id: Option<&str>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn recall_for_agents(
            &self,
            _allowed_agent_ids: &[&str],
            query: &str,
            limit: usize,
            session_id: Option<&str>,
            since: Option<&str>,
            until: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            self.recall(query, limit, session_id, since, until).await
        }
    }
    impl ::zeroclaw_api::attribution::Attributable for QueryEchoMemory {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Memory(
                ::zeroclaw_api::attribution::MemoryKind::InMemory,
            )
        }
        fn alias(&self) -> &str {
            "QueryEchoMemory"
        }
    }

    #[tokio::test]
    async fn recall_empty() {
        let (_tmp, mem) = seeded_mem();
        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({"query": "anything"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No memories found"));
    }

    #[tokio::test]
    async fn recall_finds_match() {
        let (_tmp, mem) = seeded_mem();
        mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
            .await
            .unwrap();
        mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
            .await
            .unwrap();

        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({"query": "Rust"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(result.output.contains("Found 1"));
    }

    #[tokio::test]
    async fn recall_respects_limit() {
        let (_tmp, mem) = seeded_mem();
        for i in 0..10 {
            mem.store(
                &format!("k{i}"),
                &format!("Rust fact {i}"),
                MemoryCategory::Core,
                None,
            )
            .await
            .unwrap();
        }

        let tool = MemoryRecallTool::new(mem);
        let result = tool
            .execute(json!({"query": "Rust", "limit": 3}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 3"));
    }

    #[tokio::test]
    async fn recall_defaults_to_five_results() {
        let (_tmp, mem) = seeded_mem();
        for i in 0..10 {
            mem.store(
                &format!("k{i}"),
                &format!("Rust fact {i}"),
                MemoryCategory::Core,
                None,
            )
            .await
            .unwrap();
        }

        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({"query": "Rust"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 5"));
    }

    #[tokio::test]
    async fn bare_recall_returns_recent_entries() {
        let (_tmp, mem) = seeded_mem();
        mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
            .await
            .unwrap();
        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 1"));
        assert!(result.output.contains("Rust"));
    }

    #[tokio::test]
    async fn recall_star_query_returns_recent_entries() {
        let (_tmp, mem) = seeded_mem();
        mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
            .await
            .unwrap();
        mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
            .await
            .unwrap();

        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({"query": "*"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 2"));
        assert!(result.output.contains("Rust"));
        assert!(result.output.contains("EST"));
    }

    #[tokio::test]
    async fn recall_star_query_uses_backend_recent_query_contract() {
        let last_query = Arc::new(Mutex::new(None));
        let mem = Arc::new(QueryEchoMemory {
            last_query: last_query.clone(),
            scored: false,
        });
        let tool = MemoryRecallTool::new(mem);

        let result = tool.execute(json!({"query": "*"})).await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("recent memory"));
        assert_eq!(*last_query.lock().unwrap(), Some("*".into()));
    }

    #[tokio::test]
    async fn recall_output_includes_score_as_percent() {
        let mem = Arc::new(QueryEchoMemory {
            last_query: Arc::new(Mutex::new(None)),
            scored: true,
        });
        let tool = MemoryRecallTool::new(mem);

        let result = tool.execute(json!({"query": "*"})).await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("recent: recent memory [63%]"));
    }

    #[tokio::test]
    async fn recall_time_only_returns_entries() {
        let (_tmp, mem) = seeded_mem();
        mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
            .await
            .unwrap();
        let tool = MemoryRecallTool::new(mem);
        // Time-only: since far in past
        let result = tool
            .execute(json!({"since": "2020-01-01T00:00:00Z", "limit": 5}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 1"));
        assert!(result.output.contains("Rust"));
    }

    #[tokio::test]
    async fn recall_rejects_invalid_since() {
        let (_tmp, mem) = seeded_mem();
        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({"since": "yesterday"})).await.unwrap();
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert!(
            result
                .error
                .as_deref()
                .is_some_and(|e| e.starts_with("Invalid 'since' date: yesterday."))
        );
    }

    #[tokio::test]
    async fn recall_rejects_invalid_until() {
        let (_tmp, mem) = seeded_mem();
        let tool = MemoryRecallTool::new(mem);
        let result = tool
            .execute(json!({"since": "2020-01-01T00:00:00Z", "until": "not-a-date"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .is_some_and(|e| e.starts_with("Invalid 'until' date: not-a-date."))
        );
    }

    #[tokio::test]
    async fn recall_rejects_since_after_until() {
        let (_tmp, mem) = seeded_mem();
        let tool = MemoryRecallTool::new(mem);
        let result = tool
            .execute(json!({
                "since": "2025-03-02T00:00:00Z",
                "until": "2025-03-01T00:00:00Z"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert_eq!(
            result.error.as_deref(),
            Some("'since' must be before 'until'")
        );
    }

    #[tokio::test]
    async fn recall_rejects_equal_since_and_until() {
        let (_tmp, mem) = seeded_mem();
        let tool = MemoryRecallTool::new(mem);
        let result = tool
            .execute(json!({
                "since": "2025-03-01T00:00:00Z",
                "until": "2025-03-01T00:00:00Z"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert_eq!(
            result.error.as_deref(),
            Some("'since' must be before 'until'")
        );
    }

    #[test]
    fn name_and_schema() {
        let (_tmp, mem) = seeded_mem();
        let tool = MemoryRecallTool::new(mem);
        assert_eq!(tool.name(), "memory_recall");
        assert!(tool.parameters_schema()["properties"]["query"].is_object());
    }

    // Pins the percent format itself; `recall_output_includes_score_as_percent`
    // checks that the tool actually renders scores this way.
    #[test]
    fn score_formatted_as_percent() {
        let score: Option<f64> = Some(0.63);
        let formatted = score.map_or_else(String::new, |s| format!(" [{:.0}%]", s * 100.0));
        assert_eq!(formatted, " [63%]");

        let score: Option<f64> = Some(0.42);
        let formatted = score.map_or_else(String::new, |s| format!(" [{:.0}%]", s * 100.0));
        assert_eq!(formatted, " [42%]");

        let score: Option<f64> = Some(1.0);
        let formatted = score.map_or_else(String::new, |s| format!(" [{:.0}%]", s * 100.0));
        assert_eq!(formatted, " [100%]");

        let score: Option<f64> = Some(0.0);
        let formatted = score.map_or_else(String::new, |s| format!(" [{:.0}%]", s * 100.0));
        assert_eq!(formatted, " [0%]");

        let score: Option<f64> = None;
        let formatted = score.map_or_else(String::new, |s| format!(" [{:.0}%]", s * 100.0));
        assert_eq!(formatted, "");
    }

    #[test]
    fn schema_includes_search_mode_parameter() {
        let (_tmp, mem) = seeded_mem();
        let tool = MemoryRecallTool::new(mem);
        let schema = tool.parameters_schema();
        let search_mode = &schema["properties"]["search_mode"];
        assert_eq!(search_mode["type"], "string");
        let enum_values = search_mode["enum"].as_array().unwrap();
        assert_eq!(enum_values.len(), 3);
        assert!(enum_values.contains(&json!("bm25")));
        assert!(enum_values.contains(&json!("embedding")));
        assert!(enum_values.contains(&json!("hybrid")));
    }
}