1use async_trait::async_trait;
2use serde_json::json;
3use std::fmt::Write;
4use std::sync::Arc;
5use zeroclaw_api::tool::{Tool, ToolResult};
6use zeroclaw_memory::Memory;
7
8pub struct MemoryRecallTool {
10 memory: Arc<dyn Memory>,
11}
12
13impl MemoryRecallTool {
14 pub fn new(memory: Arc<dyn Memory>) -> Self {
15 Self { memory }
16 }
17}
18
19#[async_trait]
20impl Tool for MemoryRecallTool {
21 fn name(&self) -> &str {
22 "memory_recall"
23 }
24
25 fn description(&self) -> &str {
26 "Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance. Supports keyword search, recent recall with omitted query or bare '*', time-only query (since/until), or both."
27 }
28
29 fn parameters_schema(&self) -> serde_json::Value {
30 json!({
31 "type": "object",
32 "properties": {
33 "query": {
34 "type": "string",
35 "description": "Keywords or phrase to search for in memory. Omit or pass bare '*' to return recent memories; non-bare wildcard terms remain keyword searches."
36 },
37 "limit": {
38 "type": "integer",
39 "description": "Max results to return (default: 5)"
40 },
41 "since": {
42 "type": "string",
43 "description": "Filter memories created at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
44 },
45 "until": {
46 "type": "string",
47 "description": "Filter memories created at or before this time (RFC 3339)"
48 },
49 "search_mode": {
50 "type": "string",
51 "enum": ["bm25", "embedding", "hybrid"],
52 "description": "Search strategy: bm25 (keyword), embedding (semantic), or hybrid (both). Defaults to config value."
53 }
54 }
55 })
56 }
57
58 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
59 let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
60 let since = args.get("since").and_then(|v| v.as_str());
61 let until = args.get("until").and_then(|v| v.as_str());
62
63 if let Some(s) = since
65 && chrono::DateTime::parse_from_rfc3339(s).is_err()
66 {
67 return Ok(ToolResult {
68 success: false,
69 output: String::new(),
70 error: Some(format!(
71 "Invalid 'since' date: {s}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
72 )),
73 });
74 }
75 if let Some(u) = until
76 && chrono::DateTime::parse_from_rfc3339(u).is_err()
77 {
78 return Ok(ToolResult {
79 success: false,
80 output: String::new(),
81 error: Some(format!(
82 "Invalid 'until' date: {u}. Expected RFC 3339 format, e.g. 2025-03-01T00:00:00Z"
83 )),
84 });
85 }
86 if let (Some(s), Some(u)) = (since, until)
87 && let (Ok(s_dt), Ok(u_dt)) = (
88 chrono::DateTime::parse_from_rfc3339(s),
89 chrono::DateTime::parse_from_rfc3339(u),
90 )
91 && s_dt >= u_dt
92 {
93 return Ok(ToolResult {
94 success: false,
95 output: String::new(),
96 error: Some("'since' must be before 'until'".into()),
97 });
98 }
99
100 #[allow(clippy::cast_possible_truncation)]
101 let limit = args
102 .get("limit")
103 .and_then(serde_json::Value::as_u64)
104 .map_or(5, |v| v as usize);
105
106 match self.memory.recall(query, limit, None, since, until).await {
107 Ok(entries) if entries.is_empty() => Ok(ToolResult {
108 success: true,
109 output: "No memories found.".into(),
110 error: None,
111 }),
112 Ok(entries) => {
113 let mut output = format!("Found {} memories:\n", entries.len());
114 for entry in &entries {
115 let score = entry
116 .score
117 .map_or_else(String::new, |s| format!(" [{:.0}%]", s * 100.0));
118 let _ = writeln!(
119 output,
120 "- [{}] {}: {}{score}",
121 entry.category, entry.key, entry.content
122 );
123 }
124 Ok(ToolResult {
125 success: true,
126 output,
127 error: None,
128 })
129 }
130 Err(e) => Ok(ToolResult {
131 success: false,
132 output: String::new(),
133 error: Some(format!("Memory recall failed: {e}")),
134 }),
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use tempfile::TempDir;
    use zeroclaw_memory::{MemoryCategory, MemoryEntry, SqliteMemory, is_recent_recall_query};

    /// Query for which [`QueryEchoMemory`] returns one entry carrying a known score.
    const SCORED_QUERY: &str = "scored";

    /// Creates an empty SQLite-backed memory in a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive for the
    /// whole test.
    fn sqlite_mem() -> (TempDir, Arc<dyn Memory>) {
        let tmp = TempDir::new().unwrap();
        let mem = SqliteMemory::new("test", tmp.path()).unwrap();
        (tmp, Arc::new(mem))
    }

    /// Builds a tool over [`QueryEchoMemory`].
    ///
    /// Also returns the slot where the mock records the last query it received. The
    /// slot stays `None` until `recall` is called.
    fn echo_tool() -> (Arc<Mutex<Option<String>>>, MemoryRecallTool) {
        let last_query = Arc::new(Mutex::new(None));
        let mem = Arc::new(QueryEchoMemory {
            last_query: Arc::clone(&last_query),
        });
        (last_query, MemoryRecallTool::new(mem))
    }

    /// Builds an unscored core-category entry for the mock backend.
    fn echo_entry(key: &'static str, content: &'static str) -> MemoryEntry {
        MemoryEntry {
            id: key.into(),
            key: key.into(),
            content: content.into(),
            category: MemoryCategory::Core,
            timestamp: "2026-05-03T00:00:00Z".into(),
            session_id: None,
            score: None,
            namespace: "default".into(),
            importance: None,
            superseded_by: None,
            agent_alias: None,
            agent_id: None,
        }
    }

    /// Mock backend that records the query it was given.
    ///
    /// It answers recent-recall queries with one unscored entry and [`SCORED_QUERY`]
    /// with one entry scored 0.63. Any other query gets no entries.
    struct QueryEchoMemory {
        last_query: Arc<Mutex<Option<String>>>,
    }

    #[async_trait]
    impl Memory for QueryEchoMemory {
        fn name(&self) -> &str {
            "query_echo"
        }

        async fn store(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn recall(
            &self,
            query: &str,
            _limit: usize,
            _session_id: Option<&str>,
            _since: Option<&str>,
            _until: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            *self.last_query.lock().unwrap() = Some(query.to_string());
            if is_recent_recall_query(query) {
                Ok(vec![echo_entry("recent", "recent memory")])
            } else if query == SCORED_QUERY {
                let mut entry = echo_entry("scored", "scored memory");
                entry.score = Some(0.63);
                Ok(vec![entry])
            } else {
                Ok(Vec::new())
            }
        }

        async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }

        async fn list(
            &self,
            _category: Option<&MemoryCategory>,
            _session_id: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
            Ok(false)
        }

        async fn forget_for_agent(&self, _key: &str, _agent_id: &str) -> anyhow::Result<bool> {
            Ok(false)
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(0)
        }

        async fn health_check(&self) -> bool {
            true
        }

        async fn store_with_agent(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
            _session_id: Option<&str>,
            _namespace: Option<&str>,
            _importance: Option<f64>,
            _agent_id: Option<&str>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn recall_for_agents(
            &self,
            _allowed_agent_ids: &[&str],
            query: &str,
            limit: usize,
            session_id: Option<&str>,
            since: Option<&str>,
            until: Option<&str>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            self.recall(query, limit, session_id, since, until).await
        }
    }
    impl ::zeroclaw_api::attribution::Attributable for QueryEchoMemory {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Memory(
                ::zeroclaw_api::attribution::MemoryKind::InMemory,
            )
        }
        fn alias(&self) -> &str {
            "QueryEchoMemory"
        }
    }

    #[tokio::test]
    async fn recall_empty() {
        let (_tmp, mem) = sqlite_mem();
        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({"query": "anything"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("No memories found"));
    }

    #[tokio::test]
    async fn recall_finds_match() {
        let (_tmp, mem) = sqlite_mem();
        mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
            .await
            .unwrap();
        mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
            .await
            .unwrap();

        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({"query": "Rust"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(result.output.contains("Found 1"));
    }

    #[tokio::test]
    async fn recall_respects_limit() {
        let (_tmp, mem) = sqlite_mem();
        for i in 0..10 {
            mem.store(
                &format!("k{i}"),
                &format!("Rust fact {i}"),
                MemoryCategory::Core,
                None,
            )
            .await
            .unwrap();
        }

        let tool = MemoryRecallTool::new(mem);
        let result = tool
            .execute(json!({"query": "Rust", "limit": 3}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 3"));
    }

    #[tokio::test]
    async fn bare_recall_returns_recent_entries() {
        let (_tmp, mem) = sqlite_mem();
        mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
            .await
            .unwrap();
        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 1"));
        assert!(result.output.contains("Rust"));
    }

    #[tokio::test]
    async fn recall_star_query_returns_recent_entries() {
        let (_tmp, mem) = sqlite_mem();
        mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
            .await
            .unwrap();
        mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
            .await
            .unwrap();

        let tool = MemoryRecallTool::new(mem);
        let result = tool.execute(json!({"query": "*"})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 2"));
        assert!(result.output.contains("Rust"));
        assert!(result.output.contains("EST"));
    }

    #[tokio::test]
    async fn recall_star_query_uses_backend_recent_query_contract() {
        let (last_query, tool) = echo_tool();

        let result = tool.execute(json!({"query": "*"})).await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("recent memory"));
        assert_eq!(*last_query.lock().unwrap(), Some("*".into()));
    }

    #[tokio::test]
    async fn recall_time_only_returns_entries() {
        let (_tmp, mem) = sqlite_mem();
        mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
            .await
            .unwrap();
        let tool = MemoryRecallTool::new(mem);
        let result = tool
            .execute(json!({"since": "2020-01-01T00:00:00Z", "limit": 5}))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Found 1"));
        assert!(result.output.contains("Rust"));
    }

    /// Exercises the tool's own score rendering rather than a copy of the formatter.
    #[tokio::test]
    async fn recall_output_includes_score_percent() {
        let (last_query, tool) = echo_tool();

        let result = tool.execute(json!({"query": SCORED_QUERY})).await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("Found 1"));
        assert!(result.output.contains("scored memory [63%]"));
        assert_eq!(*last_query.lock().unwrap(), Some(SCORED_QUERY.into()));
    }

    #[tokio::test]
    async fn recall_rejects_invalid_since() {
        let (last_query, tool) = echo_tool();

        let result = tool.execute(json!({"since": "yesterday"})).await.unwrap();

        assert!(!result.success);
        assert!(result.output.is_empty());
        let error = result.error.expect("validation failure must carry an error");
        assert!(error.starts_with("Invalid 'since' date: yesterday."), "{error}");
        assert!(last_query.lock().unwrap().is_none(), "backend must not be queried");
    }

    #[tokio::test]
    async fn recall_rejects_invalid_until() {
        let (last_query, tool) = echo_tool();

        let result = tool
            .execute(json!({"since": "2025-03-01T00:00:00Z", "until": "soon"}))
            .await
            .unwrap();

        assert!(!result.success);
        let error = result.error.expect("validation failure must carry an error");
        assert!(error.starts_with("Invalid 'until' date: soon."), "{error}");
        assert!(last_query.lock().unwrap().is_none(), "backend must not be queried");
    }

    #[tokio::test]
    async fn recall_rejects_since_after_until() {
        let (last_query, tool) = echo_tool();

        let plain = tool
            .execute(json!({
                "since": "2025-03-02T00:00:00Z",
                "until": "2025-03-01T00:00:00Z"
            }))
            .await
            .unwrap();
        assert!(!plain.success);
        assert_eq!(plain.error.as_deref(), Some("'since' must be before 'until'"));

        // Lexically `until` sorts later, but as an instant it is 2025-02-28T23:00Z,
        // which is before `since`. The check must compare instants, not strings.
        let offset = tool
            .execute(json!({
                "since": "2025-03-01T01:00:00Z",
                "until": "2025-03-01T02:00:00+03:00"
            }))
            .await
            .unwrap();
        assert!(!offset.success);
        assert_eq!(offset.error.as_deref(), Some("'since' must be before 'until'"));

        assert!(last_query.lock().unwrap().is_none(), "backend must not be queried");
    }

    #[test]
    fn name_and_schema() {
        let (_tmp, mem) = sqlite_mem();
        let tool = MemoryRecallTool::new(mem);
        assert_eq!(tool.name(), "memory_recall");
        assert!(tool.parameters_schema()["properties"]["query"].is_object());
    }

    /// Pins the percentage format used for scores, including the edge values.
    #[test]
    fn score_formatted_as_percent() {
        let cases: [(Option<f64>, &str); 5] = [
            (Some(0.63), " [63%]"),
            (Some(0.42), " [42%]"),
            (Some(1.0), " [100%]"),
            (Some(0.0), " [0%]"),
            (None, ""),
        ];
        for (score, expected) in cases {
            let formatted = score.map_or_else(String::new, |s| format!(" [{:.0}%]", s * 100.0));
            assert_eq!(formatted, expected, "score {score:?}");
        }
    }

    #[test]
    fn schema_includes_search_mode_parameter() {
        let (_tmp, mem) = sqlite_mem();
        let tool = MemoryRecallTool::new(mem);
        let schema = tool.parameters_schema();
        let search_mode = &schema["properties"]["search_mode"];
        assert_eq!(search_mode["type"], "string");
        let enum_values = search_mode["enum"].as_array().unwrap();
        assert_eq!(enum_values.len(), 3);
        assert!(enum_values.contains(&json!("bm25")));
        assert!(enum_values.contains(&json!("embedding")));
        assert!(enum_values.contains(&json!("hybrid")));
    }
}
421}