1use super::embeddings::EmbeddingProvider;
2use super::traits::{ExportFilter, Memory, MemoryCategory, MemoryEntry, is_recent_recall_query};
3use super::vector;
4use anyhow::Context;
5use async_trait::async_trait;
6use chrono::Local;
7use parking_lot::Mutex;
8use rusqlite::{Connection, params};
9use std::collections::HashSet;
10use std::fmt::Write as _;
11use std::path::Path;
12use std::sync::Arc;
13use std::sync::mpsc;
14use std::sync::{Mutex as StdMutex, MutexGuard};
15use std::thread;
16use std::time::Duration;
17use uuid::Uuid;
18use zeroclaw_api::session_keys::sanitize_session_key;
19use zeroclaw_config::schema::SearchMode;
20
/// Upper bound, in seconds, applied to any caller-supplied SQLite open timeout.
const SQLITE_OPEN_TIMEOUT_CAP_SECS: u64 = 5 * 60;

/// Process-wide lock held while a `SqliteMemory` opens and initialises its database.
static SQLITE_MEMORY_STARTUP_LOCK: StdMutex<()> = StdMutex::new(());

/// Takes the process-wide startup lock, recovering it if a previous holder panicked.
///
/// Poisoning is ignored on purpose: the lock guards `()`, so there is no
/// protected value a panicking holder could have left inconsistent.
fn acquire_sqlite_startup_lock() -> MutexGuard<'static, ()> {
    match SQLITE_MEMORY_STARTUP_LOCK.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
30
/// SQLite-backed memory store.
///
/// Entries live in a database under `<workspace_dir>/memory/` (`brain.db` or
/// `<db_name>.db`, depending on the constructor). Recall combines FTS5/BM25
/// keyword search with cosine similarity over cached embeddings.
pub struct SqliteMemory {
    /// Alias supplied by the caller at construction time.
    alias: String,
    /// Single shared connection; every query locks it, async paths do so
    /// inside `spawn_blocking`.
    conn: Arc<Mutex<Connection>>,
    /// Produces query/content embeddings; a provider reporting zero
    /// dimensions disables vector search.
    embedder: Arc<dyn EmbeddingProvider>,
    /// Weight of the vector score passed to `vector::hybrid_merge`.
    vector_weight: f32,
    /// Weight of the keyword (BM25) score passed to `vector::hybrid_merge`.
    keyword_weight: f32,
    /// Maximum number of rows kept in the `embedding_cache` table.
    cache_max: usize,
    /// `Bm25` skips embeddings, `Embedding` skips FTS5; other modes use both.
    search_mode: SearchMode,
}
48
49impl SqliteMemory {
50 pub fn new(alias: &str, workspace_dir: &Path) -> anyhow::Result<Self> {
51 Self::with_embedder(
52 alias,
53 workspace_dir,
54 Arc::new(super::embeddings::NoopEmbedding),
55 0.7,
56 0.3,
57 10_000,
58 None,
59 SearchMode::default(),
60 )
61 }
62
63 pub fn new_named(alias: &str, workspace_dir: &Path, db_name: &str) -> anyhow::Result<Self> {
65 let db_path = workspace_dir.join("memory").join(format!("{db_name}.db"));
66 let _startup_guard = acquire_sqlite_startup_lock();
67 if let Some(parent) = db_path.parent() {
68 std::fs::create_dir_all(parent)?;
69 }
70 let conn = Self::open_connection(&db_path, None)?;
71 conn.execute_batch(
72 "PRAGMA foreign_keys = ON;
77 PRAGMA journal_mode = WAL;
78 PRAGMA synchronous = NORMAL;
79 PRAGMA mmap_size = 8388608;
80 PRAGMA cache_size = -2000;
81 PRAGMA temp_store = MEMORY;",
82 )?;
83 Self::init_schema(&conn)?;
84 zeroclaw_config::schema::v2::migrate_sqlite_memory_to_v3(&db_path, &conn)?;
85 Ok(Self {
86 alias: alias.to_string(),
87 conn: Arc::new(Mutex::new(conn)),
88 embedder: Arc::new(super::embeddings::NoopEmbedding),
89 vector_weight: 0.7,
90 keyword_weight: 0.3,
91 cache_max: 10_000,
92 search_mode: SearchMode::default(),
93 })
94 }
95
96 pub fn with_embedder(
102 alias: &str,
103 workspace_dir: &Path,
104 embedder: Arc<dyn EmbeddingProvider>,
105 vector_weight: f32,
106 keyword_weight: f32,
107 cache_max: usize,
108 open_timeout_secs: Option<u64>,
109 search_mode: SearchMode,
110 ) -> anyhow::Result<Self> {
111 let db_path = workspace_dir.join("memory").join("brain.db");
112 let _startup_guard = acquire_sqlite_startup_lock();
113
114 if let Some(parent) = db_path.parent() {
115 std::fs::create_dir_all(parent)?;
116 }
117
118 let conn = Self::open_connection(&db_path, open_timeout_secs)?;
119
120 conn.execute_batch(
130 "PRAGMA foreign_keys = ON;
131 PRAGMA journal_mode = WAL;
132 PRAGMA synchronous = NORMAL;
133 PRAGMA mmap_size = 8388608;
134 PRAGMA cache_size = -2000;
135 PRAGMA temp_store = MEMORY;",
136 )?;
137
138 Self::init_schema(&conn)?;
139 zeroclaw_config::schema::v2::migrate_sqlite_memory_to_v3(&db_path, &conn)?;
140
141 Ok(Self {
142 alias: alias.to_string(),
143 conn: Arc::new(Mutex::new(conn)),
144 embedder,
145 vector_weight,
146 keyword_weight,
147 cache_max,
148 search_mode,
149 })
150 }
151
152 fn open_connection(
154 db_path: &Path,
155 open_timeout_secs: Option<u64>,
156 ) -> anyhow::Result<Connection> {
157 let path_buf = db_path.to_path_buf();
158
159 let conn = if let Some(secs) = open_timeout_secs {
160 let capped = secs.min(SQLITE_OPEN_TIMEOUT_CAP_SECS);
161 let (tx, rx) = mpsc::channel();
162 thread::spawn(move || {
163 let result = Connection::open(&path_buf);
164 let _ = tx.send(result);
165 });
166 match rx.recv_timeout(Duration::from_secs(capped)) {
167 Ok(Ok(c)) => c,
168 Ok(Err(e)) => return Err(e).context("SQLite failed to open database"),
169 Err(mpsc::RecvTimeoutError::Timeout) => {
170 anyhow::bail!("SQLite connection open timed out after {} seconds", capped);
171 }
172 Err(mpsc::RecvTimeoutError::Disconnected) => {
173 anyhow::bail!("SQLite open thread exited unexpectedly");
174 }
175 }
176 } else {
177 Connection::open(&path_buf).context("SQLite failed to open database")?
178 };
179
180 Ok(conn)
181 }
182
183 fn init_schema(conn: &Connection) -> anyhow::Result<()> {
185 fn is_db_locked_error(e: &rusqlite::Error) -> bool {
186 use rusqlite::ffi::ErrorCode;
187 matches!(
188 e,
189 rusqlite::Error::SqliteFailure(err, _)
190 if matches!(err.code, ErrorCode::DatabaseBusy | ErrorCode::DatabaseLocked)
191 )
192 }
193
194 fn execute_batch_retry(conn: &Connection, sql: &str) -> Result<(), rusqlite::Error> {
195 let mut backoff = Duration::from_millis(10);
199 let max_backoff = Duration::from_millis(250);
200 let max_attempts: usize = 24; for attempt in 1..=max_attempts {
203 match conn.execute_batch(sql) {
204 Ok(()) => return Ok(()),
205 Err(e) if is_db_locked_error(&e) && attempt < max_attempts => {
206 std::thread::sleep(backoff);
207 backoff = (backoff * 2).min(max_backoff);
208 }
209 Err(e) => return Err(e),
210 }
211 }
212
213 Ok(())
215 }
216
217 fn memories_has_column(conn: &Connection, name: &str) -> anyhow::Result<bool> {
218 let mut stmt = conn.prepare("PRAGMA table_info(memories)")?;
219 let mut rows = stmt.query([])?;
220 while let Some(row) = rows.next()? {
221 let col_name: String = row.get(1)?;
222 if col_name == name {
223 return Ok(true);
224 }
225 }
226 Ok(false)
227 }
228
229 fn is_duplicate_column_error(e: &rusqlite::Error) -> bool {
230 matches!(
231 e,
232 rusqlite::Error::SqliteFailure(_, Some(msg)) if msg.contains("duplicate column name")
233 )
234 }
235
236 fn add_memories_column_if_missing(
237 conn: &Connection,
238 name: &str,
239 alter_sql: &str,
240 ) -> anyhow::Result<()> {
241 if memories_has_column(conn, name)? {
242 return Ok(());
243 }
244
245 match execute_batch_retry(conn, alter_sql) {
246 Ok(()) => Ok(()),
247 Err(e) if is_duplicate_column_error(&e) => Ok(()),
248 Err(e) => Err(e)
249 .with_context(|| format!("SQLite migration failed adding memories.{name}")),
250 }
251 }
252
253 execute_batch_retry(
254 conn,
255 "-- Core memories table. This is an intermediate shape; the V3
256 -- migration in `zeroclaw_config::schema::v2::migrate_sqlite_memory_to_v3`
257 -- rebuilds it with the `agent_id` column and a composite
258 -- `UNIQUE (agent_id, key)` constraint immediately after init.
259 CREATE TABLE IF NOT EXISTS memories (
260 id TEXT PRIMARY KEY,
261 key TEXT NOT NULL UNIQUE,
262 content TEXT NOT NULL,
263 category TEXT NOT NULL DEFAULT 'core',
264 embedding BLOB,
265 created_at TEXT NOT NULL,
266 updated_at TEXT NOT NULL
267 );
268 CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
269 CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);
270
271 -- FTS5 full-text search (BM25 scoring)
272 CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
273 key, content, content=memories, content_rowid=rowid
274 );
275
276 -- FTS5 triggers: keep in sync with memories table
277 CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
278 INSERT INTO memories_fts(rowid, key, content)
279 VALUES (new.rowid, new.key, new.content);
280 END;
281 CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
282 INSERT INTO memories_fts(memories_fts, rowid, key, content)
283 VALUES ('delete', old.rowid, old.key, old.content);
284 END;
285 CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
286 INSERT INTO memories_fts(memories_fts, rowid, key, content)
287 VALUES ('delete', old.rowid, old.key, old.content);
288 INSERT INTO memories_fts(rowid, key, content)
289 VALUES (new.rowid, new.key, new.content);
290 END;
291
292 -- Embedding cache with LRU eviction
293 CREATE TABLE IF NOT EXISTS embedding_cache (
294 content_hash TEXT PRIMARY KEY,
295 embedding BLOB NOT NULL,
296 created_at TEXT NOT NULL,
297 accessed_at TEXT NOT NULL
298 );
299 CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
300 )
301 .with_context(|| "SQLite init_schema failed: CREATE base schema")?;
302
303 add_memories_column_if_missing(
304 conn,
305 "session_id",
306 "ALTER TABLE memories ADD COLUMN session_id TEXT;",
307 )?;
308 execute_batch_retry(
309 conn,
310 "CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
311 )
312 .with_context(|| "SQLite init_schema failed: CREATE INDEX idx_memories_session")?;
313
314 add_memories_column_if_missing(
315 conn,
316 "namespace",
317 "ALTER TABLE memories ADD COLUMN namespace TEXT DEFAULT 'default';",
318 )?;
319 execute_batch_retry(
320 conn,
321 "CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);",
322 )
323 .with_context(|| "SQLite init_schema failed: CREATE INDEX idx_memories_namespace")?;
324
325 add_memories_column_if_missing(
326 conn,
327 "importance",
328 "ALTER TABLE memories ADD COLUMN importance REAL DEFAULT 0.5;",
329 )?;
330
331 add_memories_column_if_missing(
332 conn,
333 "superseded_by",
334 "ALTER TABLE memories ADD COLUMN superseded_by TEXT;",
335 )?;
336
337 Self::migrate_session_ids_to_sanitized(conn)?;
338
339 Ok(())
340 }
341
342 fn migrate_session_ids_to_sanitized(conn: &Connection) -> anyhow::Result<()> {
351 let distinct: Vec<String> = {
352 let mut stmt = conn
353 .prepare("SELECT DISTINCT session_id FROM memories WHERE session_id IS NOT NULL")?;
354 let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
355 rows.collect::<Result<Vec<_>, _>>()?
356 };
357
358 let mut update =
359 conn.prepare("UPDATE memories SET session_id = ?1 WHERE session_id = ?2")?;
360 let mut rewritten = 0usize;
361 for old in &distinct {
362 let new = sanitize_session_key(old);
363 if new != *old {
364 update.execute(params![new, old])?;
365 rewritten += 1;
366 }
367 }
368
369 if rewritten > 0 {
370 ::zeroclaw_log::record!(
371 INFO,
372 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
373 .with_attrs(::serde_json::json!({"rewritten": rewritten})),
374 "Normalized session_id values in memories table to sanitized form"
375 );
376 }
377
378 Ok(())
379 }
380
381 fn category_to_str(cat: &MemoryCategory) -> String {
382 match cat {
383 MemoryCategory::Core => "core".into(),
384 MemoryCategory::Daily => "daily".into(),
385 MemoryCategory::Conversation => "conversation".into(),
386 MemoryCategory::Custom(name) => name.clone(),
387 }
388 }
389
390 fn str_to_category(s: &str) -> MemoryCategory {
391 match s {
392 "core" => MemoryCategory::Core,
393 "daily" => MemoryCategory::Daily,
394 "conversation" => MemoryCategory::Conversation,
395 other => MemoryCategory::Custom(other.to_string()),
396 }
397 }
398
399 fn content_hash(text: &str) -> String {
403 use sha2::{Digest, Sha256};
404 let hash = Sha256::digest(text.as_bytes());
405 format!(
407 "{:016x}",
408 u64::from_be_bytes(
409 hash[..8]
410 .try_into()
411 .expect("SHA-256 always produces >= 8 bytes")
412 )
413 )
414 }
415
416 pub fn connection(&self) -> &Arc<Mutex<Connection>> {
418 &self.conn
419 }
420
421 pub async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
423 if self.embedder.dimensions() == 0 {
424 return Ok(None); }
426
427 let hash = Self::content_hash(text);
428 let now = Local::now().to_rfc3339();
429
430 let conn = self.conn.clone();
432 let hash_c = hash.clone();
433 let now_c = now.clone();
434 let cached = tokio::task::spawn_blocking(move || -> anyhow::Result<Option<Vec<f32>>> {
435 let conn = conn.lock();
436 let mut stmt =
437 conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
438 let blob: Option<Vec<u8>> = stmt.query_row(params![hash_c], |row| row.get(0)).ok();
439 if let Some(bytes) = blob {
440 conn.execute(
441 "UPDATE embedding_cache SET accessed_at = ?1 WHERE content_hash = ?2",
442 params![now_c, hash_c],
443 )?;
444 return Ok(Some(vector::bytes_to_vec(&bytes)));
445 }
446 Ok(None)
447 })
448 .await??;
449
450 if cached.is_some() {
451 return Ok(cached);
452 }
453
454 let embedding = self.embedder.embed_one(text).await?;
456 let bytes = vector::vec_to_bytes(&embedding);
457
458 let conn = self.conn.clone();
460 #[allow(clippy::cast_possible_wrap)]
461 let cache_max = self.cache_max as i64;
462 tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
463 let conn = conn.lock();
464 conn.execute(
465 "INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
466 VALUES (?1, ?2, ?3, ?4)",
467 params![hash, bytes, now, now],
468 )?;
469 conn.execute(
470 "DELETE FROM embedding_cache WHERE content_hash IN (
471 SELECT content_hash FROM embedding_cache
472 ORDER BY accessed_at ASC
473 LIMIT MAX(0, (SELECT COUNT(*) FROM embedding_cache) - ?1)
474 )",
475 params![cache_max],
476 )?;
477 Ok(())
478 })
479 .await??;
480
481 Ok(Some(embedding))
482 }
483
484 pub fn fts5_search(
486 conn: &Connection,
487 query: &str,
488 limit: usize,
489 ) -> anyhow::Result<Vec<(String, f32)>> {
490 let fts_query: String = query
492 .split_whitespace()
493 .map(Self::fts5_term_query)
494 .collect::<Vec<_>>()
495 .join(" OR ");
496
497 if fts_query.is_empty() {
498 return Ok(Vec::new());
499 }
500
501 let sql = "SELECT m.id, bm25(memories_fts) as score
502 FROM memories_fts f
503 JOIN memories m ON m.rowid = f.rowid
504 WHERE memories_fts MATCH ?1
505 ORDER BY score
506 LIMIT ?2";
507
508 let mut stmt = conn.prepare(sql)?;
509 #[allow(clippy::cast_possible_wrap)]
510 let limit_i64 = limit as i64;
511
512 let rows = stmt.query_map(params![fts_query, limit_i64], |row| {
513 let id: String = row.get(0)?;
514 let score: f64 = row.get(1)?;
515 #[allow(clippy::cast_possible_truncation)]
517 Ok((id, (-score) as f32))
518 })?;
519
520 let mut results = Vec::new();
521 for row in rows {
522 results.push(row?);
523 }
524 Ok(results)
525 }
526
527 fn fts5_term_query(term: &str) -> String {
528 if let Some(prefix) = term.strip_suffix('*')
529 && !prefix.is_empty()
530 {
531 let escaped = prefix.replace('"', "\"\"");
532 format!("\"{escaped}\"*")
533 } else {
534 let escaped = term.replace('"', "\"\"");
535 format!("\"{escaped}\"")
536 }
537 }
538
539 fn like_search_pattern(term: &str) -> String {
540 if let Some(prefix) = term.strip_suffix('*')
541 && !prefix.is_empty()
542 {
543 return format!("%{}%", Self::escape_like_pattern(prefix));
544 }
545 format!("%{}%", Self::escape_like_pattern(term))
546 }
547
548 fn is_prefix_wildcard_term(term: &str) -> bool {
549 matches!(term.strip_suffix('*'), Some(prefix) if !prefix.is_empty())
550 }
551
552 fn escape_like_pattern(term: &str) -> String {
553 let mut escaped = String::with_capacity(term.len());
554 for ch in term.chars() {
555 if matches!(ch, '%' | '_' | '\\') {
556 escaped.push('\\');
557 }
558 escaped.push(ch);
559 }
560 escaped
561 }
562
563 fn like_fallback_matches(text: &str, term: &str) -> bool {
564 let text = text.to_lowercase();
565 if let Some(prefix) = term.strip_suffix('*')
566 && !prefix.is_empty()
567 {
568 let prefix = prefix.to_lowercase();
569 return text
570 .split(|ch: char| !ch.is_alphanumeric() && ch != '_')
571 .any(|token| token.starts_with(&prefix));
572 }
573 text.contains(&term.to_lowercase())
574 }
575
576 pub fn vector_search(
581 conn: &Connection,
582 query_embedding: &[f32],
583 limit: usize,
584 category: Option<&str>,
585 session_id: Option<&str>,
586 ) -> anyhow::Result<Vec<(String, f32)>> {
587 let mut sql = "SELECT id, embedding FROM memories WHERE embedding IS NOT NULL".to_string();
588 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
589 let mut idx = 1;
590
591 if let Some(cat) = category {
592 let _ = write!(sql, " AND category = ?{idx}");
593 param_values.push(Box::new(cat.to_string()));
594 idx += 1;
595 }
596 if let Some(sid) = session_id {
597 let _ = write!(sql, " AND session_id = ?{idx}");
598 param_values.push(Box::new(sid.to_string()));
599 }
600
601 let mut stmt = conn.prepare(&sql)?;
602 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
603 param_values.iter().map(AsRef::as_ref).collect();
604 let rows = stmt.query_map(params_ref.as_slice(), |row| {
605 let id: String = row.get(0)?;
606 let blob: Vec<u8> = row.get(1)?;
607 Ok((id, blob))
608 })?;
609
610 let mut scored: Vec<(String, f32)> = Vec::new();
611 for row in rows {
612 let (id, blob) = row?;
613 let emb = vector::bytes_to_vec(&blob);
614 let sim = vector::cosine_similarity(query_embedding, &emb);
615 if sim > 0.0 {
616 scored.push((id, sim));
617 }
618 }
619
620 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
621 scored.truncate(limit);
622 Ok(scored)
623 }
624
625 async fn recall_by_time_only(
627 &self,
628 limit: usize,
629 session_id: Option<&str>,
630 since: Option<&str>,
631 until: Option<&str>,
632 ) -> anyhow::Result<Vec<MemoryEntry>> {
633 let conn = self.conn.clone();
634 let sid = session_id.map(String::from);
635 let since_owned = since.map(String::from);
636 let until_owned = until.map(String::from);
637
638 tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
639 let conn = conn.lock();
640 let since_ref = since_owned.as_deref();
641 let until_ref = until_owned.as_deref();
642
643 let mut sql =
644 "SELECT m.id, m.key, m.content, m.category, m.created_at, m.session_id, m.namespace, m.importance, m.superseded_by, a.alias, m.agent_id \
645 FROM memories m LEFT JOIN agents a ON a.id = m.agent_id \
646 WHERE m.superseded_by IS NULL AND 1=1"
647 .to_string();
648 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
649 let mut idx = 1;
650
651 if let Some(sid) = sid.as_deref() {
652 let _ = write!(sql, " AND m.session_id = ?{idx}");
653 param_values.push(Box::new(sid.to_string()));
654 idx += 1;
655 }
656 if let Some(s) = since_ref {
657 let _ = write!(sql, " AND m.created_at >= ?{idx}");
658 param_values.push(Box::new(s.to_string()));
659 idx += 1;
660 }
661 if let Some(u) = until_ref {
662 let _ = write!(sql, " AND m.created_at <= ?{idx}");
663 param_values.push(Box::new(u.to_string()));
664 idx += 1;
665 }
666 let _ = write!(sql, " ORDER BY m.updated_at DESC LIMIT ?{idx}");
667 #[allow(clippy::cast_possible_wrap)]
668 param_values.push(Box::new(limit as i64));
669
670 let mut stmt = conn.prepare(&sql)?;
671 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
672 param_values.iter().map(AsRef::as_ref).collect();
673 let rows = stmt.query_map(params_ref.as_slice(), |row| {
674 Ok(MemoryEntry {
675 id: row.get(0)?,
676 key: row.get(1)?,
677 content: row.get(2)?,
678 category: Self::str_to_category(&row.get::<_, String>(3)?),
679 timestamp: row.get(4)?,
680 session_id: row.get(5)?,
681 score: None,
682 namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
683 importance: row.get(7)?,
684 superseded_by: row.get(8)?,
685 agent_alias: row.get(9)?,
686 agent_id: row.get(10)?,
687 })
688 })?;
689
690 let mut results = Vec::new();
691 for row in rows {
692 results.push(row?);
693 }
694 Ok(results)
695 })
696 .await?
697 }
698}
699
700#[async_trait]
701impl Memory for SqliteMemory {
702 fn name(&self) -> &str {
703 "sqlite"
704 }
705
706 async fn store(
707 &self,
708 key: &str,
709 content: &str,
710 category: MemoryCategory,
711 session_id: Option<&str>,
712 ) -> anyhow::Result<()> {
713 self.store_with_agent(key, content, category, session_id, None, None, None)
718 .await
719 }
720
721 async fn recall(
722 &self,
723 query: &str,
724 limit: usize,
725 session_id: Option<&str>,
726 since: Option<&str>,
727 until: Option<&str>,
728 ) -> anyhow::Result<Vec<MemoryEntry>> {
729 if is_recent_recall_query(query) {
733 return self
734 .recall_by_time_only(limit, session_id, since, until)
735 .await;
736 }
737
738 let query_embedding = if self.search_mode == SearchMode::Bm25 {
740 None
741 } else {
742 self.get_or_compute_embedding(query).await?
743 };
744
745 let conn = self.conn.clone();
746 let query = query.to_string();
747 let sid = session_id.map(String::from);
748 let since_owned = since.map(String::from);
749 let until_owned = until.map(String::from);
750 let vector_weight = self.vector_weight;
751 let keyword_weight = self.keyword_weight;
752 let search_mode = self.search_mode.clone();
753
754 tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
755 let conn = conn.lock();
756 let session_ref = sid.as_deref();
757 let since_ref = since_owned.as_deref();
758 let until_ref = until_owned.as_deref();
759
760 let keyword_results = if search_mode == SearchMode::Embedding {
762 Vec::new()
763 } else {
764 Self::fts5_search(&conn, &query, limit * 2).unwrap_or_default()
765 };
766
767 let vector_results = if search_mode == SearchMode::Bm25 {
769 Vec::new()
770 } else if let Some(ref qe) = query_embedding {
771 Self::vector_search(&conn, qe, limit * 2, None, session_ref).unwrap_or_default()
772 } else {
773 Vec::new()
774 };
775
776 let merged = if vector_results.is_empty() {
778 keyword_results
779 .iter()
780 .map(|(id, score)| vector::ScoredResult {
781 id: id.clone(),
782 vector_score: None,
783 keyword_score: Some(*score),
784 final_score: *score,
785 })
786 .collect::<Vec<_>>()
787 } else if keyword_results.is_empty() {
788 vector_results
789 .iter()
790 .map(|(id, score)| vector::ScoredResult {
791 id: id.clone(),
792 vector_score: Some(*score),
793 keyword_score: None,
794 final_score: *score,
795 })
796 .collect::<Vec<_>>()
797 } else {
798 vector::hybrid_merge(
799 &vector_results,
800 &keyword_results,
801 vector_weight,
802 keyword_weight,
803 limit,
804 )
805 };
806
807 let mut results = Vec::new();
810 if !merged.is_empty() {
811 let placeholders: String = (1..=merged.len())
812 .map(|i| format!("?{i}"))
813 .collect::<Vec<_>>()
814 .join(", ");
815 let sql = format!(
816 "SELECT m.id, m.key, m.content, m.category, m.created_at, m.session_id, m.namespace, m.importance, m.superseded_by, a.alias, m.agent_id \
817 FROM memories m LEFT JOIN agents a ON a.id = m.agent_id \
818 WHERE m.superseded_by IS NULL AND m.id IN ({placeholders})"
819 );
820 let mut stmt = conn.prepare(&sql)?;
821 let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
822 .iter()
823 .map(|s| Box::new(s.id.clone()) as Box<dyn rusqlite::types::ToSql>)
824 .collect();
825 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
826 id_params.iter().map(AsRef::as_ref).collect();
827 let rows = stmt.query_map(params_ref.as_slice(), |row| {
828 Ok((
829 row.get::<_, String>(0)?,
830 row.get::<_, String>(1)?,
831 row.get::<_, String>(2)?,
832 row.get::<_, String>(3)?,
833 row.get::<_, String>(4)?,
834 row.get::<_, Option<String>>(5)?,
835 row.get::<_, Option<String>>(6)?,
836 row.get::<_, Option<f64>>(7)?,
837 row.get::<_, Option<String>>(8)?,
838 row.get::<_, Option<String>>(9)?,
839 row.get::<_, Option<String>>(10)?,
840 ))
841 })?;
842
843 let mut entry_map = std::collections::HashMap::new();
844 for row in rows {
845 let (id, key, content, cat, ts, sid, ns, imp, sup, alias, aid) = row?;
846 entry_map.insert(id, (key, content, cat, ts, sid, ns, imp, sup, alias, aid));
847 }
848
849 for scored in &merged {
850 if let Some((key, content, cat, ts, sid, ns, imp, sup, alias, aid)) = entry_map.remove(&scored.id) {
851 if let Some(s) = since_ref
852 && ts.as_str() < s {
853 continue;
854 }
855 if let Some(u) = until_ref
856 && ts.as_str() > u {
857 continue;
858 }
859 let entry = MemoryEntry {
860 id: scored.id.clone(),
861 key,
862 content,
863 category: Self::str_to_category(&cat),
864 timestamp: ts,
865 session_id: sid,
866 score: Some(f64::from(scored.final_score)),
867 namespace: ns.unwrap_or_else(|| "default".into()),
868 importance: imp,
869 superseded_by: sup,
870 agent_alias: alias,
871 agent_id: aid,
872 };
873 if let Some(filter_sid) = session_ref
874 && entry.session_id.as_deref() != Some(filter_sid) {
875 continue;
876 }
877 results.push(entry);
878 }
879 }
880 }
881
882 if results.is_empty() {
884 const MAX_LIKE_KEYWORDS: usize = 8;
885 let raw_keywords: Vec<String> = query
886 .split_whitespace()
887 .take(MAX_LIKE_KEYWORDS)
888 .map(str::to_string)
889 .collect();
890 if !raw_keywords.is_empty() {
891 let needs_prefix_filter = raw_keywords
892 .iter()
893 .any(|keyword| Self::is_prefix_wildcard_term(keyword));
894 let sql_limit = if needs_prefix_filter {
895 limit.saturating_mul(8).min(limit.saturating_add(512))
896 } else {
897 limit
898 };
899 let patterns: Vec<String> = raw_keywords
900 .iter()
901 .map(|keyword| Self::like_search_pattern(keyword))
902 .collect();
903 let conditions: Vec<String> = patterns
904 .iter()
905 .enumerate()
906 .map(|(i, _)| {
907 format!(
908 "(m.content LIKE ?{} ESCAPE '\\' OR m.key LIKE ?{} ESCAPE '\\')",
909 i * 2 + 1,
910 i * 2 + 2
911 )
912 })
913 .collect();
914 let where_clause = conditions.join(" OR ");
915 let mut param_idx = patterns.len() * 2 + 1;
916 let mut time_conditions = String::new();
917 if since_ref.is_some() {
918 let _ = write!(time_conditions, " AND m.created_at >= ?{param_idx}");
919 param_idx += 1;
920 }
921 if until_ref.is_some() {
922 let _ = write!(time_conditions, " AND m.created_at <= ?{param_idx}");
923 param_idx += 1;
924 }
925 let sql = format!(
926 "SELECT m.id, m.key, m.content, m.category, m.created_at, m.session_id, m.namespace, m.importance, m.superseded_by, a.alias, m.agent_id
927 FROM memories m LEFT JOIN agents a ON a.id = m.agent_id
928 WHERE m.superseded_by IS NULL AND ({where_clause}){time_conditions}
929 ORDER BY m.updated_at DESC
930 LIMIT ?{param_idx}"
931 );
932 let mut stmt = conn.prepare(&sql)?;
933 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
934 for kw in &patterns {
935 param_values.push(Box::new(kw.clone()));
936 param_values.push(Box::new(kw.clone()));
937 }
938 if let Some(s) = since_ref {
939 param_values.push(Box::new(s.to_string()));
940 }
941 if let Some(u) = until_ref {
942 param_values.push(Box::new(u.to_string()));
943 }
944 #[allow(clippy::cast_possible_wrap)]
945 param_values.push(Box::new(sql_limit as i64));
946 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
947 param_values.iter().map(AsRef::as_ref).collect();
948 let rows = stmt.query_map(params_ref.as_slice(), |row| {
949 Ok(MemoryEntry {
950 id: row.get(0)?,
951 key: row.get(1)?,
952 content: row.get(2)?,
953 category: Self::str_to_category(&row.get::<_, String>(3)?),
954 timestamp: row.get(4)?,
955 session_id: row.get(5)?,
956 score: Some(1.0),
957 namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
958 importance: row.get(7)?,
959 superseded_by: row.get(8)?,
960 agent_alias: row.get(9)?,
961 agent_id: row.get(10)?,
962 })
963 })?;
964 for row in rows {
965 let entry = row?;
966 if let Some(sid) = session_ref
967 && entry.session_id.as_deref() != Some(sid) {
968 continue;
969 }
970 if needs_prefix_filter
971 && !raw_keywords.iter().any(|keyword| {
972 Self::like_fallback_matches(&entry.key, keyword)
973 || Self::like_fallback_matches(&entry.content, keyword)
974 })
975 {
976 continue;
977 }
978 results.push(entry);
979 if results.len() >= limit {
980 break;
981 }
982 }
983 }
984 }
985
986 results.truncate(limit);
987 Ok(results)
988 })
989 .await?
990 }
991
992 async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
993 let conn = self.conn.clone();
994 let key = key.to_string();
995
996 tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
997 let conn = conn.lock();
998 let mut stmt = conn.prepare(
999 "SELECT m.id, m.key, m.content, m.category, m.created_at, m.session_id, m.namespace, m.importance, m.superseded_by, a.alias, m.agent_id \
1000 FROM memories m LEFT JOIN agents a ON a.id = m.agent_id \
1001 WHERE m.key = ?1",
1002 )?;
1003
1004 let mut rows = stmt.query_map(params![key], |row| {
1005 Ok(MemoryEntry {
1006 id: row.get(0)?,
1007 key: row.get(1)?,
1008 content: row.get(2)?,
1009 category: Self::str_to_category(&row.get::<_, String>(3)?),
1010 timestamp: row.get(4)?,
1011 session_id: row.get(5)?,
1012 score: None,
1013 namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
1014 importance: row.get(7)?,
1015 superseded_by: row.get(8)?,
1016 agent_alias: row.get(9)?,
1017 agent_id: row.get(10)?,
1018 })
1019 })?;
1020
1021 match rows.next() {
1022 Some(Ok(entry)) => Ok(Some(entry)),
1023 _ => Ok(None),
1024 }
1025 })
1026 .await?
1027 }
1028
1029 async fn get_for_agent(
1030 &self,
1031 key: &str,
1032 agent_id: &str,
1033 ) -> anyhow::Result<Option<MemoryEntry>> {
1034 let conn = self.conn.clone();
1035 let key = key.to_string();
1036 let agent_id = agent_id.to_string();
1037
1038 tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
1039 let conn = conn.lock();
1040 let mut stmt = conn.prepare(
1041 "SELECT m.id, m.key, m.content, m.category, m.created_at, m.session_id, m.namespace, m.importance, m.superseded_by, a.alias, m.agent_id \
1042 FROM memories m LEFT JOIN agents a ON a.id = m.agent_id \
1043 WHERE m.key = ?1 AND m.agent_id = ?2",
1044 )?;
1045
1046 let mut rows = stmt.query_map(params![key, agent_id], |row| {
1047 Ok(MemoryEntry {
1048 id: row.get(0)?,
1049 key: row.get(1)?,
1050 content: row.get(2)?,
1051 category: Self::str_to_category(&row.get::<_, String>(3)?),
1052 timestamp: row.get(4)?,
1053 session_id: row.get(5)?,
1054 score: None,
1055 namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
1056 importance: row.get(7)?,
1057 superseded_by: row.get(8)?,
1058 agent_alias: row.get(9)?,
1059 agent_id: row.get(10)?,
1060 })
1061 })?;
1062
1063 match rows.next() {
1064 Some(Ok(entry)) => Ok(Some(entry)),
1065 _ => Ok(None),
1066 }
1067 })
1068 .await?
1069 }
1070
1071 async fn list(
1072 &self,
1073 category: Option<&MemoryCategory>,
1074 session_id: Option<&str>,
1075 ) -> anyhow::Result<Vec<MemoryEntry>> {
1076 const DEFAULT_LIST_LIMIT: i64 = 1000;
1077
1078 let conn = self.conn.clone();
1079 let category = category.cloned();
1080 let sid = session_id.map(String::from);
1081
1082 tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
1083 let conn = conn.lock();
1084 let session_ref = sid.as_deref();
1085 let mut results = Vec::new();
1086
1087 let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result<MemoryEntry> {
1088 Ok(MemoryEntry {
1089 id: row.get(0)?,
1090 key: row.get(1)?,
1091 content: row.get(2)?,
1092 category: Self::str_to_category(&row.get::<_, String>(3)?),
1093 timestamp: row.get(4)?,
1094 session_id: row.get(5)?,
1095 score: None,
1096 namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
1097 importance: row.get(7)?,
1098 superseded_by: row.get(8)?,
1099 agent_alias: row.get(9)?,
1100 agent_id: row.get(10)?,
1101 })
1102 };
1103
1104 if let Some(ref cat) = category {
1105 let cat_str = Self::category_to_str(cat);
1106 let mut stmt = conn.prepare(
1107 "SELECT m.id, m.key, m.content, m.category, m.created_at, m.session_id, m.namespace, m.importance, m.superseded_by, a.alias, m.agent_id
1108 FROM memories m LEFT JOIN agents a ON a.id = m.agent_id
1109 WHERE m.superseded_by IS NULL AND m.category = ?1 ORDER BY m.updated_at DESC LIMIT ?2",
1110 )?;
1111 let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
1112 for row in rows {
1113 let entry = row?;
1114 if let Some(sid) = session_ref
1115 && entry.session_id.as_deref() != Some(sid) {
1116 continue;
1117 }
1118 results.push(entry);
1119 }
1120 } else {
1121 let mut stmt = conn.prepare(
1122 "SELECT m.id, m.key, m.content, m.category, m.created_at, m.session_id, m.namespace, m.importance, m.superseded_by, a.alias, m.agent_id
1123 FROM memories m LEFT JOIN agents a ON a.id = m.agent_id
1124 WHERE m.superseded_by IS NULL ORDER BY m.updated_at DESC LIMIT ?1",
1125 )?;
1126 let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
1127 for row in rows {
1128 let entry = row?;
1129 if let Some(sid) = session_ref
1130 && entry.session_id.as_deref() != Some(sid) {
1131 continue;
1132 }
1133 results.push(entry);
1134 }
1135 }
1136
1137 Ok(results)
1138 })
1139 .await?
1140 }
1141
1142 async fn forget(&self, key: &str) -> anyhow::Result<bool> {
1143 let conn = self.conn.clone();
1144 let key = key.to_string();
1145
1146 tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
1147 let conn = conn.lock();
1148 let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
1149 Ok(affected > 0)
1150 })
1151 .await?
1152 }
1153
1154 async fn forget_for_agent(&self, key: &str, agent_id: &str) -> anyhow::Result<bool> {
1155 let conn = self.conn.clone();
1156 let key = key.to_string();
1157 let agent_id = agent_id.to_string();
1158
1159 tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
1160 let conn = conn.lock();
1161 let affected = conn.execute(
1162 "DELETE FROM memories WHERE key = ?1 AND agent_id = ?2",
1163 params![key, agent_id],
1164 )?;
1165 Ok(affected > 0)
1166 })
1167 .await?
1168 }
1169
1170 async fn purge_namespace(&self, namespace: &str) -> anyhow::Result<usize> {
1171 let conn = self.conn.clone();
1172 let namespace = namespace.to_string();
1173
1174 tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
1175 let conn = conn.lock();
1176 let affected = conn.execute(
1177 "DELETE FROM memories WHERE namespace = ?1",
1178 params![namespace],
1179 )?;
1180 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
1181 Ok(affected)
1182 })
1183 .await?
1184 }
1185
1186 async fn purge_session(&self, session_id: &str) -> anyhow::Result<usize> {
1187 let conn = self.conn.clone();
1188 let session_id = session_id.to_string();
1189
1190 tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
1191 let conn = conn.lock();
1192 let affected = conn.execute(
1193 "DELETE FROM memories WHERE session_id = ?1",
1194 params![session_id],
1195 )?;
1196 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
1197 Ok(affected)
1198 })
1199 .await?
1200 }
1201
1202 async fn purge_session_for_agent(
1203 &self,
1204 session_id: &str,
1205 agent_id: &str,
1206 ) -> anyhow::Result<usize> {
1207 let conn = self.conn.clone();
1208 let session_id = session_id.to_string();
1209 let agent_id = agent_id.to_string();
1210
1211 tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
1212 let conn = conn.lock();
1213 let affected = conn.execute(
1214 "DELETE FROM memories WHERE session_id = ?1 AND agent_id = ?2",
1215 params![session_id, agent_id],
1216 )?;
1217 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
1218 Ok(affected)
1219 })
1220 .await?
1221 }
1222
1223 async fn purge_agent(&self, agent_alias: &str) -> anyhow::Result<usize> {
1224 let conn = self.conn.clone();
1225 let agent_alias = agent_alias.to_string();
1226
1227 tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
1228 let conn = conn.lock();
1229 let affected = conn.execute(
1230 "DELETE FROM memories WHERE agent_id = ?1",
1231 params![agent_alias],
1232 )?;
1233 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
1234 Ok(affected)
1235 })
1236 .await?
1237 }
1238
1239 async fn count(&self) -> anyhow::Result<usize> {
1240 let conn = self.conn.clone();
1241
1242 tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
1243 let conn = conn.lock();
1244 let count: i64 =
1245 conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
1246 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
1247 Ok(count as usize)
1248 })
1249 .await?
1250 }
1251
1252 async fn health_check(&self) -> bool {
1253 let conn = self.conn.clone();
1254 tokio::task::spawn_blocking(move || conn.lock().execute_batch("SELECT 1").is_ok())
1255 .await
1256 .unwrap_or(false)
1257 }
1258
1259 async fn reindex(&self) -> anyhow::Result<usize> {
1271 {
1273 let conn = self.conn.clone();
1274 tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
1275 let conn = conn.lock();
1276 conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
1277 Ok(())
1278 })
1279 .await??;
1280 }
1281
1282 if self.embedder.dimensions() == 0 {
1284 return Ok(0);
1285 }
1286
1287 let conn = self.conn.clone();
1288 let entries: Vec<(String, String)> = tokio::task::spawn_blocking(move || {
1289 let conn = conn.lock();
1290 let mut stmt =
1291 conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
1292 let rows = stmt.query_map([], |row| {
1293 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
1294 })?;
1295 Ok::<_, anyhow::Error>(rows.filter_map(std::result::Result::ok).collect())
1296 })
1297 .await??;
1298
1299 let mut count = 0;
1300 for (id, content) in &entries {
1301 if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
1302 let bytes = vector::vec_to_bytes(&emb);
1303 let conn = self.conn.clone();
1304 let id = id.clone();
1305 tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
1306 let conn = conn.lock();
1307 conn.execute(
1308 "UPDATE memories SET embedding = ?1 WHERE id = ?2",
1309 params![bytes, id],
1310 )?;
1311 Ok(())
1312 })
1313 .await??;
1314 count += 1;
1315 }
1316 }
1317
1318 Ok(count)
1319 }
1320
1321 async fn export(&self, filter: &ExportFilter) -> anyhow::Result<Vec<MemoryEntry>> {
1322 let conn = self.conn.clone();
1323 let filter = filter.clone();
1324
1325 tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
1326 let conn = conn.lock();
1327 let mut sql =
1328 "SELECT m.id, m.key, m.content, m.category, m.created_at, m.session_id, m.namespace, m.importance, m.superseded_by, a.alias, m.agent_id \
1329 FROM memories m LEFT JOIN agents a ON a.id = m.agent_id \
1330 WHERE 1=1"
1331 .to_string();
1332 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
1333 let mut idx = 1;
1334
1335 if let Some(ref ns) = filter.namespace {
1336 let _ = write!(sql, " AND m.namespace = ?{idx}");
1337 param_values.push(Box::new(ns.clone()));
1338 idx += 1;
1339 }
1340 if let Some(ref sid) = filter.session_id {
1341 let _ = write!(sql, " AND m.session_id = ?{idx}");
1342 param_values.push(Box::new(sid.clone()));
1343 idx += 1;
1344 }
1345 if let Some(ref cat) = filter.category {
1346 let _ = write!(sql, " AND m.category = ?{idx}");
1347 param_values.push(Box::new(Self::category_to_str(cat)));
1348 idx += 1;
1349 }
1350 if let Some(ref since) = filter.since {
1351 let _ = write!(sql, " AND m.created_at >= ?{idx}");
1352 param_values.push(Box::new(since.clone()));
1353 idx += 1;
1354 }
1355 if let Some(ref until) = filter.until {
1356 let _ = write!(sql, " AND m.created_at <= ?{idx}");
1357 param_values.push(Box::new(until.clone()));
1358 let _ = idx;
1359 }
1360 sql.push_str(" ORDER BY m.created_at ASC");
1361
1362 let mut stmt = conn.prepare(&sql)?;
1363 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
1364 param_values.iter().map(AsRef::as_ref).collect();
1365 let rows = stmt.query_map(params_ref.as_slice(), |row| {
1366 Ok(MemoryEntry {
1367 id: row.get(0)?,
1368 key: row.get(1)?,
1369 content: row.get(2)?,
1370 category: Self::str_to_category(&row.get::<_, String>(3)?),
1371 timestamp: row.get(4)?,
1372 session_id: row.get(5)?,
1373 score: None,
1374 namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
1375 importance: row.get(7)?,
1376 superseded_by: row.get(8)?,
1377 agent_alias: row.get(9)?,
1378 agent_id: row.get(10)?,
1379 })
1380 })?;
1381
1382 let mut results = Vec::new();
1383 for row in rows {
1384 results.push(row?);
1385 }
1386 Ok(results)
1387 })
1388 .await?
1389 }
1390
1391 async fn recall_namespaced(
1392 &self,
1393 namespace: &str,
1394 query: &str,
1395 limit: usize,
1396 session_id: Option<&str>,
1397 since: Option<&str>,
1398 until: Option<&str>,
1399 ) -> anyhow::Result<Vec<MemoryEntry>> {
1400 let entries = self
1401 .recall(query, limit * 2, session_id, since, until)
1402 .await?;
1403 let filtered: Vec<MemoryEntry> = entries
1404 .into_iter()
1405 .filter(|e| e.namespace == namespace)
1406 .take(limit)
1407 .collect();
1408 Ok(filtered)
1409 }
1410
1411 async fn store_with_metadata(
1412 &self,
1413 key: &str,
1414 content: &str,
1415 category: MemoryCategory,
1416 session_id: Option<&str>,
1417 namespace: Option<&str>,
1418 importance: Option<f64>,
1419 ) -> anyhow::Result<()> {
1420 self.store_with_agent(
1424 key, content, category, session_id, namespace, importance, None,
1425 )
1426 .await
1427 }
1428
1429 async fn store_with_agent(
1430 &self,
1431 key: &str,
1432 content: &str,
1433 category: MemoryCategory,
1434 session_id: Option<&str>,
1435 namespace: Option<&str>,
1436 importance: Option<f64>,
1437 agent_id: Option<&str>,
1438 ) -> anyhow::Result<()> {
1439 let embedding_bytes = self
1440 .get_or_compute_embedding(content)
1441 .await?
1442 .map(|emb| vector::vec_to_bytes(&emb));
1443
1444 let conn = self.conn.clone();
1445 let key = key.to_string();
1446 let content = content.to_string();
1447 let sid = session_id.map(String::from);
1448 let ns = namespace.unwrap_or("default").to_string();
1449 let imp = importance.unwrap_or(0.5);
1450 let aid = agent_id.map(String::from);
1451
1452 tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
1453 let conn = conn.lock();
1454 let now = Local::now().to_rfc3339();
1455 let cat = Self::category_to_str(&category);
1456 let id = Uuid::new_v4().to_string();
1457
1458 conn.execute(
1464 "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance, agent_id)
1465 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, COALESCE(?11, (SELECT id FROM agents WHERE alias = 'default' LIMIT 1)))
1466 ON CONFLICT(agent_id, key) DO UPDATE SET
1467 content = excluded.content,
1468 category = excluded.category,
1469 embedding = excluded.embedding,
1470 updated_at = excluded.updated_at,
1471 session_id = excluded.session_id,
1472 namespace = excluded.namespace,
1473 importance = excluded.importance",
1474 params![id, key, content, cat, embedding_bytes, now, now, sid, ns, imp, aid],
1475 )?;
1476 Ok(())
1477 })
1478 .await?
1479 }
1480
1481 async fn recall_for_agents(
1482 &self,
1483 allowed_agent_ids: &[&str],
1484 query: &str,
1485 limit: usize,
1486 session_id: Option<&str>,
1487 since: Option<&str>,
1488 until: Option<&str>,
1489 ) -> anyhow::Result<Vec<MemoryEntry>> {
1490 if allowed_agent_ids.is_empty() {
1494 return self.recall(query, limit, session_id, since, until).await;
1495 }
1496
1497 let full_candidate_limit = self.count().await?.max(limit);
1498 let raw = self
1499 .recall(query, full_candidate_limit, session_id, since, until)
1500 .await?;
1501 if raw.is_empty() {
1502 return Ok(Vec::new());
1503 }
1504
1505 let conn = self.conn.clone();
1506 let ids: Vec<String> = raw.iter().map(|e| e.id.clone()).collect();
1507 let allowed: Vec<String> = allowed_agent_ids.iter().map(|s| (*s).to_string()).collect();
1508
1509 let kept: HashSet<String> =
1516 tokio::task::spawn_blocking(move || -> anyhow::Result<HashSet<String>> {
1517 let conn = conn.lock();
1518 let id_placeholders: String = (1..=ids.len())
1519 .map(|i| format!("?{i}"))
1520 .collect::<Vec<_>>()
1521 .join(", ");
1522 let agent_placeholders: String = (ids.len() + 1..=ids.len() + allowed.len())
1523 .map(|i| format!("?{i}"))
1524 .collect::<Vec<_>>()
1525 .join(", ");
1526 let sql = format!(
1527 "SELECT id FROM memories \
1528 WHERE id IN ({id_placeholders}) \
1529 AND agent_id IN ({agent_placeholders})"
1530 );
1531 let mut stmt = conn.prepare(&sql)?;
1532 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> =
1533 Vec::with_capacity(ids.len() + allowed.len());
1534 for id in &ids {
1535 params.push(Box::new(id.clone()) as Box<dyn rusqlite::types::ToSql>);
1536 }
1537 for aid in &allowed {
1538 params.push(Box::new(aid.clone()) as Box<dyn rusqlite::types::ToSql>);
1539 }
1540 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
1541 params.iter().map(AsRef::as_ref).collect();
1542 let rows = stmt.query_map(params_ref.as_slice(), |row| row.get::<_, String>(0))?;
1543 let mut set = HashSet::new();
1544 for row in rows {
1545 set.insert(row?);
1546 }
1547 Ok(set)
1548 })
1549 .await??;
1550
1551 Ok(raw
1552 .into_iter()
1553 .filter(|e| kept.contains(&e.id))
1554 .take(limit)
1555 .collect())
1556 }
1557
1558 async fn ensure_agent_uuid(&self, alias: &str) -> anyhow::Result<String> {
1559 let conn = self.conn.clone();
1560 let alias = alias.to_string();
1561 tokio::task::spawn_blocking(move || -> anyhow::Result<String> {
1562 let conn = conn.lock();
1563 zeroclaw_config::schema::v2::sqlite_ensure_agent_uuid(&conn, &alias)
1564 })
1565 .await?
1566 }
1567}
1568
1569impl ::zeroclaw_api::attribution::Attributable for SqliteMemory {
1570 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1571 ::zeroclaw_api::attribution::Role::Memory(::zeroclaw_api::attribution::MemoryKind::Sqlite)
1572 }
1573 fn alias(&self) -> &str {
1574 &self.alias
1575 }
1576}
1577
1578#[cfg(test)]
1579mod tests {
1580 use super::*;
1581 use tempfile::TempDir;
1582
1583 fn temp_sqlite() -> (TempDir, SqliteMemory) {
1584 let tmp = TempDir::new().unwrap();
1585 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
1586 (tmp, mem)
1587 }
1588
1589 #[tokio::test]
1590 async fn sqlite_name() {
1591 let (_tmp, mem) = temp_sqlite();
1592 assert_eq!(mem.name(), "sqlite");
1593 }
1594
1595 #[tokio::test]
1596 async fn sqlite_health() {
1597 let (_tmp, mem) = temp_sqlite();
1598 assert!(mem.health_check().await);
1599 }
1600
1601 #[tokio::test]
1602 async fn sqlite_store_and_get() {
1603 let (_tmp, mem) = temp_sqlite();
1604 mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
1605 .await
1606 .unwrap();
1607
1608 let entry = mem.get("user_lang").await.unwrap();
1609 assert!(entry.is_some());
1610 let entry = entry.unwrap();
1611 assert_eq!(entry.key, "user_lang");
1612 assert_eq!(entry.content, "Prefers Rust");
1613 assert_eq!(entry.category, MemoryCategory::Core);
1614 }
1615
1616 #[tokio::test]
1617 async fn sqlite_store_upsert() {
1618 let (_tmp, mem) = temp_sqlite();
1619 mem.store("pref", "likes Rust", MemoryCategory::Core, None)
1620 .await
1621 .unwrap();
1622 mem.store("pref", "loves Rust", MemoryCategory::Core, None)
1623 .await
1624 .unwrap();
1625
1626 let entry = mem.get("pref").await.unwrap().unwrap();
1627 assert_eq!(entry.content, "loves Rust");
1628 assert_eq!(mem.count().await.unwrap(), 1);
1629 }
1630
1631 #[tokio::test]
1632 async fn sqlite_recall_keyword() {
1633 let (_tmp, mem) = temp_sqlite();
1634 mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
1635 .await
1636 .unwrap();
1637 mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
1638 .await
1639 .unwrap();
1640 mem.store(
1641 "c",
1642 "Rust has zero-cost abstractions",
1643 MemoryCategory::Core,
1644 None,
1645 )
1646 .await
1647 .unwrap();
1648
1649 let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
1650 assert_eq!(results.len(), 2);
1651 assert!(
1652 results
1653 .iter()
1654 .all(|r| r.content.to_lowercase().contains("rust"))
1655 );
1656 }
1657
1658 #[tokio::test]
1659 async fn sqlite_recall_for_agents_does_not_lose_allowed_rows_behind_disallowed_matches() {
1660 let (_tmp, mem) = temp_sqlite();
1661 let alpha = mem.ensure_agent_uuid("alpha").await.unwrap();
1662 let rogue = mem.ensure_agent_uuid("rogue").await.unwrap();
1663
1664 for idx in 0..12 {
1665 mem.store_with_agent(
1666 &format!("rogue-{idx}"),
1667 "needle disallowed row",
1668 MemoryCategory::Core,
1669 None,
1670 None,
1671 None,
1672 Some(&rogue),
1673 )
1674 .await
1675 .unwrap();
1676 }
1677 mem.store_with_agent(
1678 "alpha-allowed",
1679 "needle allowed row",
1680 MemoryCategory::Core,
1681 None,
1682 None,
1683 None,
1684 Some(&alpha),
1685 )
1686 .await
1687 .unwrap();
1688
1689 let results = mem
1690 .recall_for_agents(&[alpha.as_str()], "needle", 1, None, None, None)
1691 .await
1692 .unwrap();
1693 assert_eq!(results.len(), 1);
1694 assert_eq!(results[0].key, "alpha-allowed");
1695 }
1696
1697 #[tokio::test]
1698 async fn sqlite_recall_multi_keyword() {
1699 let (_tmp, mem) = temp_sqlite();
1700 mem.store("a", "Rust is fast", MemoryCategory::Core, None)
1701 .await
1702 .unwrap();
1703 mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
1704 .await
1705 .unwrap();
1706
1707 let results = mem.recall("fast safe", 10, None, None, None).await.unwrap();
1708 assert!(!results.is_empty());
1709 assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
1711 }
1712
1713 #[tokio::test]
1714 async fn sqlite_recall_no_match() {
1715 let (_tmp, mem) = temp_sqlite();
1716 mem.store("a", "Rust rocks", MemoryCategory::Core, None)
1717 .await
1718 .unwrap();
1719 let results = mem
1720 .recall("javascript", 10, None, None, None)
1721 .await
1722 .unwrap();
1723 assert!(results.is_empty());
1724 }
1725
1726 #[tokio::test]
1727 async fn sqlite_forget() {
1728 let (_tmp, mem) = temp_sqlite();
1729 mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
1730 .await
1731 .unwrap();
1732 assert_eq!(mem.count().await.unwrap(), 1);
1733
1734 let removed = mem.forget("temp").await.unwrap();
1735 assert!(removed);
1736 assert_eq!(mem.count().await.unwrap(), 0);
1737 }
1738
1739 #[tokio::test]
1740 async fn sqlite_forget_nonexistent() {
1741 let (_tmp, mem) = temp_sqlite();
1742 let removed = mem.forget("nope").await.unwrap();
1743 assert!(!removed);
1744 }
1745
1746 #[tokio::test]
1747 async fn sqlite_list_all() {
1748 let (_tmp, mem) = temp_sqlite();
1749 mem.store("a", "one", MemoryCategory::Core, None)
1750 .await
1751 .unwrap();
1752 mem.store("b", "two", MemoryCategory::Daily, None)
1753 .await
1754 .unwrap();
1755 mem.store("c", "three", MemoryCategory::Conversation, None)
1756 .await
1757 .unwrap();
1758
1759 let all = mem.list(None, None).await.unwrap();
1760 assert_eq!(all.len(), 3);
1761 }
1762
1763 #[tokio::test]
1764 async fn sqlite_list_by_category() {
1765 let (_tmp, mem) = temp_sqlite();
1766 mem.store("a", "core1", MemoryCategory::Core, None)
1767 .await
1768 .unwrap();
1769 mem.store("b", "core2", MemoryCategory::Core, None)
1770 .await
1771 .unwrap();
1772 mem.store("c", "daily1", MemoryCategory::Daily, None)
1773 .await
1774 .unwrap();
1775
1776 let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
1777 assert_eq!(core.len(), 2);
1778
1779 let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
1780 assert_eq!(daily.len(), 1);
1781 }
1782
1783 #[tokio::test]
1784 async fn sqlite_count_empty() {
1785 let (_tmp, mem) = temp_sqlite();
1786 assert_eq!(mem.count().await.unwrap(), 0);
1787 }
1788
1789 #[tokio::test]
1790 async fn sqlite_get_nonexistent() {
1791 let (_tmp, mem) = temp_sqlite();
1792 assert!(mem.get("nope").await.unwrap().is_none());
1793 }
1794
1795 #[tokio::test]
1796 async fn sqlite_db_persists() {
1797 let tmp = TempDir::new().unwrap();
1798
1799 {
1800 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
1801 mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
1802 .await
1803 .unwrap();
1804 }
1805
1806 let mem2 = SqliteMemory::new("test", tmp.path()).unwrap();
1808 let entry = mem2.get("persist").await.unwrap();
1809 assert!(entry.is_some());
1810 assert_eq!(entry.unwrap().content, "I survive restarts");
1811 }
1812
1813 #[tokio::test]
1814 async fn sqlite_category_roundtrip() {
1815 let (_tmp, mem) = temp_sqlite();
1816 let categories = [
1817 MemoryCategory::Core,
1818 MemoryCategory::Daily,
1819 MemoryCategory::Conversation,
1820 MemoryCategory::Custom("project".into()),
1821 ];
1822
1823 for (i, cat) in categories.iter().enumerate() {
1824 mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
1825 .await
1826 .unwrap();
1827 }
1828
1829 for (i, cat) in categories.iter().enumerate() {
1830 let entry = mem.get(&format!("k{i}")).await.unwrap().unwrap();
1831 assert_eq!(&entry.category, cat);
1832 }
1833 }
1834
1835 #[tokio::test]
1838 async fn fts5_bm25_ranking() {
1839 let (_tmp, mem) = temp_sqlite();
1840 mem.store(
1841 "a",
1842 "Rust is a systems programming language",
1843 MemoryCategory::Core,
1844 None,
1845 )
1846 .await
1847 .unwrap();
1848 mem.store(
1849 "b",
1850 "Python is great for scripting",
1851 MemoryCategory::Core,
1852 None,
1853 )
1854 .await
1855 .unwrap();
1856 mem.store(
1857 "c",
1858 "Rust and Rust and Rust everywhere",
1859 MemoryCategory::Core,
1860 None,
1861 )
1862 .await
1863 .unwrap();
1864
1865 let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
1866 assert!(results.len() >= 2);
1867 for r in &results {
1869 assert!(
1870 r.content.to_lowercase().contains("rust"),
1871 "Expected 'rust' in: {}",
1872 r.content
1873 );
1874 }
1875 }
1876
1877 #[tokio::test]
1878 async fn fts5_multi_word_query() {
1879 let (_tmp, mem) = temp_sqlite();
1880 mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
1881 .await
1882 .unwrap();
1883 mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
1884 .await
1885 .unwrap();
1886 mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
1887 .await
1888 .unwrap();
1889
1890 let results = mem.recall("quick dog", 10, None, None, None).await.unwrap();
1891 assert!(!results.is_empty());
1892 assert!(results[0].content.contains("quick"));
1894 }
1895
1896 #[tokio::test]
1897 async fn recall_empty_query_returns_recent_entries() {
1898 let (_tmp, mem) = temp_sqlite();
1899 mem.store("a", "data", MemoryCategory::Core, None)
1900 .await
1901 .unwrap();
1902 let results = mem.recall("", 10, None, None, None).await.unwrap();
1904 assert_eq!(results.len(), 1);
1905 assert_eq!(results[0].key, "a");
1906 }
1907
1908 #[tokio::test]
1909 async fn recall_whitespace_query_returns_recent_entries() {
1910 let (_tmp, mem) = temp_sqlite();
1911 mem.store("a", "data", MemoryCategory::Core, None)
1912 .await
1913 .unwrap();
1914 let results = mem.recall(" ", 10, None, None, None).await.unwrap();
1916 assert_eq!(results.len(), 1);
1917 assert_eq!(results[0].key, "a");
1918 }
1919
1920 #[tokio::test]
1921 async fn recall_star_query_returns_recent_entries() {
1922 let (_tmp, mem) = temp_sqlite();
1923 mem.store("a", "first memory", MemoryCategory::Core, None)
1924 .await
1925 .unwrap();
1926 mem.store("b", "second memory", MemoryCategory::Core, None)
1927 .await
1928 .unwrap();
1929
1930 let results = mem.recall("*", 10, None, None, None).await.unwrap();
1931 assert_eq!(results.len(), 2);
1932 assert!(results.iter().any(|entry| entry.key == "a"));
1933 assert!(results.iter().any(|entry| entry.key == "b"));
1934 }
1935
1936 #[test]
1939 fn content_hash_deterministic() {
1940 let h1 = SqliteMemory::content_hash("hello world");
1941 let h2 = SqliteMemory::content_hash("hello world");
1942 assert_eq!(h1, h2);
1943 }
1944
1945 #[test]
1946 fn content_hash_different_inputs() {
1947 let h1 = SqliteMemory::content_hash("hello");
1948 let h2 = SqliteMemory::content_hash("world");
1949 assert_ne!(h1, h2);
1950 }
1951
1952 #[tokio::test]
1955 async fn schema_has_fts5_table() {
1956 let (_tmp, mem) = temp_sqlite();
1957 let conn = mem.conn.lock();
1958 let count: i64 = conn
1960 .query_row(
1961 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memories_fts'",
1962 [],
1963 |row| row.get(0),
1964 )
1965 .unwrap();
1966 assert_eq!(count, 1);
1967 }
1968
1969 #[tokio::test]
1970 async fn schema_has_embedding_cache() {
1971 let (_tmp, mem) = temp_sqlite();
1972 let conn = mem.conn.lock();
1973 let count: i64 = conn
1974 .query_row(
1975 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'",
1976 [],
1977 |row| row.get(0),
1978 )
1979 .unwrap();
1980 assert_eq!(count, 1);
1981 }
1982
1983 #[tokio::test]
1984 async fn schema_memories_has_embedding_column() {
1985 let (_tmp, mem) = temp_sqlite();
1986 let conn = mem.conn.lock();
1987 let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0");
1989 assert!(result.is_ok());
1990 }
1991
1992 #[tokio::test]
1995 async fn fts5_syncs_on_insert() {
1996 let (_tmp, mem) = temp_sqlite();
1997 mem.store(
1998 "test_key",
1999 "unique_searchterm_xyz",
2000 MemoryCategory::Core,
2001 None,
2002 )
2003 .await
2004 .unwrap();
2005
2006 let conn = mem.conn.lock();
2007 let count: i64 = conn
2008 .query_row(
2009 "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'",
2010 [],
2011 |row| row.get(0),
2012 )
2013 .unwrap();
2014 assert_eq!(count, 1);
2015 }
2016
2017 #[tokio::test]
2018 async fn fts5_syncs_on_delete() {
2019 let (_tmp, mem) = temp_sqlite();
2020 mem.store(
2021 "del_key",
2022 "deletable_content_abc",
2023 MemoryCategory::Core,
2024 None,
2025 )
2026 .await
2027 .unwrap();
2028 mem.forget("del_key").await.unwrap();
2029
2030 let conn = mem.conn.lock();
2031 let count: i64 = conn
2032 .query_row(
2033 "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'",
2034 [],
2035 |row| row.get(0),
2036 )
2037 .unwrap();
2038 assert_eq!(count, 0);
2039 }
2040
2041 #[tokio::test]
2042 async fn fts5_syncs_on_update() {
2043 let (_tmp, mem) = temp_sqlite();
2044 mem.store(
2045 "upd_key",
2046 "original_content_111",
2047 MemoryCategory::Core,
2048 None,
2049 )
2050 .await
2051 .unwrap();
2052 mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
2053 .await
2054 .unwrap();
2055
2056 let conn = mem.conn.lock();
2057 let old: i64 = conn
2059 .query_row(
2060 "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"original_content_111\"'",
2061 [],
2062 |row| row.get(0),
2063 )
2064 .unwrap();
2065 assert_eq!(old, 0);
2066
2067 let new: i64 = conn
2069 .query_row(
2070 "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"updated_content_222\"'",
2071 [],
2072 |row| row.get(0),
2073 )
2074 .unwrap();
2075 assert_eq!(new, 1);
2076 }
2077
2078 #[test]
2081 fn open_with_timeout_succeeds_when_fast() {
2082 let tmp = TempDir::new().unwrap();
2083 let embedder = Arc::new(super::super::embeddings::NoopEmbedding);
2084 let mem = SqliteMemory::with_embedder(
2085 "test",
2086 tmp.path(),
2087 embedder,
2088 0.7,
2089 0.3,
2090 1000,
2091 Some(5),
2092 SearchMode::default(),
2093 );
2094 assert!(
2095 mem.is_ok(),
2096 "open with 5s timeout should succeed on fast path"
2097 );
2098 assert_eq!(mem.unwrap().name(), "sqlite");
2099 }
2100
2101 #[tokio::test]
2102 async fn open_with_timeout_store_recall_unchanged() {
2103 let tmp = TempDir::new().unwrap();
2104 let mem = SqliteMemory::with_embedder(
2105 "test",
2106 tmp.path(),
2107 Arc::new(super::super::embeddings::NoopEmbedding),
2108 0.7,
2109 0.3,
2110 1000,
2111 Some(2),
2112 SearchMode::default(),
2113 )
2114 .unwrap();
2115 mem.store(
2116 "timeout_key",
2117 "value with timeout",
2118 MemoryCategory::Core,
2119 None,
2120 )
2121 .await
2122 .unwrap();
2123 let entry = mem.get("timeout_key").await.unwrap().unwrap();
2124 assert_eq!(entry.content, "value with timeout");
2125 }
2126
2127 #[test]
2130 fn with_embedder_noop() {
2131 let tmp = TempDir::new().unwrap();
2132 let embedder = Arc::new(super::super::embeddings::NoopEmbedding);
2133 let mem = SqliteMemory::with_embedder(
2134 "test",
2135 tmp.path(),
2136 embedder,
2137 0.7,
2138 0.3,
2139 1000,
2140 None,
2141 SearchMode::default(),
2142 );
2143 assert!(mem.is_ok());
2144 assert_eq!(mem.unwrap().name(), "sqlite");
2145 }
2146
2147 #[tokio::test]
2150 async fn reindex_rebuilds_fts() {
2151 let (_tmp, mem) = temp_sqlite();
2152 mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
2153 .await
2154 .unwrap();
2155 mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
2156 .await
2157 .unwrap();
2158
2159 let count = mem.reindex().await.unwrap();
2161 assert_eq!(count, 0);
2162
2163 let results = mem.recall("reindex", 10, None, None, None).await.unwrap();
2165 assert_eq!(results.len(), 2);
2166 }
2167
2168 #[tokio::test]
2171 async fn recall_respects_limit() {
2172 let (_tmp, mem) = temp_sqlite();
2173 for i in 0..20 {
2174 mem.store(
2175 &format!("k{i}"),
2176 &format!("common keyword item {i}"),
2177 MemoryCategory::Core,
2178 None,
2179 )
2180 .await
2181 .unwrap();
2182 }
2183
2184 let results = mem
2185 .recall("common keyword", 5, None, None, None)
2186 .await
2187 .unwrap();
2188 assert!(results.len() <= 5);
2189 }
2190
2191 #[tokio::test]
2194 async fn recall_results_have_scores() {
2195 let (_tmp, mem) = temp_sqlite();
2196 mem.store("s1", "scored result test", MemoryCategory::Core, None)
2197 .await
2198 .unwrap();
2199
2200 let results = mem.recall("scored", 10, None, None, None).await.unwrap();
2201 assert!(!results.is_empty());
2202 for r in &results {
2203 assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
2204 }
2205 }
2206
2207 #[tokio::test]
2210 async fn recall_with_quotes_in_query() {
2211 let (_tmp, mem) = temp_sqlite();
2212 mem.store("q1", "He said hello world", MemoryCategory::Core, None)
2213 .await
2214 .unwrap();
2215 let results = mem.recall("\"hello\"", 10, None, None, None).await.unwrap();
2217 assert!(results.len() <= 10);
2219 }
2220
2221 #[tokio::test]
2222 async fn recall_with_asterisk_in_query() {
2223 let (_tmp, mem) = temp_sqlite();
2224 mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
2225 .await
2226 .unwrap();
2227 mem.store("b1", "unrelated recent content", MemoryCategory::Core, None)
2228 .await
2229 .unwrap();
2230 let results = mem.recall("wild*", 10, None, None, None).await.unwrap();
2231 assert!(results.iter().any(|entry| entry.key == "a1"));
2232 assert!(results.iter().all(|entry| entry.key != "b1"));
2233 }
2234
2235 #[tokio::test]
2236 async fn recall_prefix_wildcard_like_fallback_keeps_token_prefix() {
2237 let tmp = TempDir::new().unwrap();
2238 let mem = SqliteMemory::with_embedder(
2239 "test",
2240 tmp.path(),
2241 Arc::new(super::super::embeddings::NoopEmbedding),
2242 0.7,
2243 0.3,
2244 1000,
2245 None,
2246 SearchMode::Embedding,
2247 )
2248 .unwrap();
2249 mem.store("a1", "fallback wildcard token", MemoryCategory::Core, None)
2250 .await
2251 .unwrap();
2252 mem.store("b1", "fallback unwild token", MemoryCategory::Core, None)
2253 .await
2254 .unwrap();
2255
2256 let results = mem.recall("wild*", 10, None, None, None).await.unwrap();
2257 assert!(results.iter().any(|entry| entry.key == "a1"));
2258 assert!(results.iter().all(|entry| entry.key != "b1"));
2259 }
2260
2261 #[tokio::test]
2262 async fn recall_prefix_wildcard_like_fallback_overfetches_filtered_rows() {
2263 let tmp = TempDir::new().unwrap();
2264 let mem = SqliteMemory::with_embedder(
2265 "test",
2266 tmp.path(),
2267 Arc::new(super::super::embeddings::NoopEmbedding),
2268 0.7,
2269 0.3,
2270 1000,
2271 None,
2272 SearchMode::Embedding,
2273 )
2274 .unwrap();
2275 mem.store(
2276 "real",
2277 "fallback wildcard token",
2278 MemoryCategory::Core,
2279 None,
2280 )
2281 .await
2282 .unwrap();
2283 for i in 0..3 {
2284 mem.store(
2285 &format!("noise{i}"),
2286 "fallback unwild token",
2287 MemoryCategory::Core,
2288 None,
2289 )
2290 .await
2291 .unwrap();
2292 }
2293 {
2294 let conn = mem.conn.lock();
2295 conn.execute(
2296 "UPDATE memories SET updated_at = ?1 WHERE key = ?2",
2297 rusqlite::params!["2026-05-03T00:00:00Z", "real"],
2298 )
2299 .unwrap();
2300 for i in 0..3 {
2301 conn.execute(
2302 "UPDATE memories SET updated_at = ?1 WHERE key = ?2",
2303 rusqlite::params![format!("2026-05-03T00:00:0{}Z", i + 1), format!("noise{i}")],
2304 )
2305 .unwrap();
2306 }
2307 }
2308
2309 let results = mem.recall("wild*", 1, None, None, None).await.unwrap();
2310 assert_eq!(results.len(), 1);
2311 assert_eq!(results[0].key, "real");
2312 }
2313
2314 #[tokio::test]
2315 async fn recall_with_parentheses_in_query() {
2316 let (_tmp, mem) = temp_sqlite();
2317 mem.store("p1", "function call test", MemoryCategory::Core, None)
2318 .await
2319 .unwrap();
2320 let results = mem
2321 .recall("function()", 10, None, None, None)
2322 .await
2323 .unwrap();
2324 assert!(results.len() <= 10);
2325 }
2326
2327 #[tokio::test]
2328 async fn recall_with_sql_injection_attempt() {
2329 let (_tmp, mem) = temp_sqlite();
2330 mem.store("safe", "normal content", MemoryCategory::Core, None)
2331 .await
2332 .unwrap();
2333 let results = mem
2335 .recall("'; DROP TABLE memories; --", 10, None, None, None)
2336 .await
2337 .unwrap();
2338 assert!(results.len() <= 10);
2339 assert_eq!(mem.count().await.unwrap(), 1);
2341 }
2342
2343 #[tokio::test]
2346 async fn store_empty_content() {
2347 let (_tmp, mem) = temp_sqlite();
2348 mem.store("empty", "", MemoryCategory::Core, None)
2349 .await
2350 .unwrap();
2351 let entry = mem.get("empty").await.unwrap().unwrap();
2352 assert_eq!(entry.content, "");
2353 }
2354
2355 #[tokio::test]
2356 async fn store_empty_key() {
2357 let (_tmp, mem) = temp_sqlite();
2358 mem.store("", "content for empty key", MemoryCategory::Core, None)
2359 .await
2360 .unwrap();
2361 let entry = mem.get("").await.unwrap().unwrap();
2362 assert_eq!(entry.content, "content for empty key");
2363 }
2364
2365 #[tokio::test]
2366 async fn store_very_long_content() {
2367 let (_tmp, mem) = temp_sqlite();
2368 let long_content = "x".repeat(100_000);
2369 mem.store("long", &long_content, MemoryCategory::Core, None)
2370 .await
2371 .unwrap();
2372 let entry = mem.get("long").await.unwrap().unwrap();
2373 assert_eq!(entry.content.len(), 100_000);
2374 }
2375
2376 #[tokio::test]
2377 async fn store_unicode_and_emoji() {
2378 let (_tmp, mem) = temp_sqlite();
2379 mem.store(
2380 "emoji_key_🦀",
2381 "こんにちは 🚀 Ñoño",
2382 MemoryCategory::Core,
2383 None,
2384 )
2385 .await
2386 .unwrap();
2387 let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
2388 assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
2389 }
2390
2391 #[tokio::test]
2392 async fn store_content_with_newlines_and_tabs() {
2393 let (_tmp, mem) = temp_sqlite();
2394 let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
2395 mem.store("whitespace", content, MemoryCategory::Core, None)
2396 .await
2397 .unwrap();
2398 let entry = mem.get("whitespace").await.unwrap().unwrap();
2399 assert_eq!(entry.content, content);
2400 }
2401
2402 #[tokio::test]
2405 async fn recall_single_character_query() {
2406 let (_tmp, mem) = temp_sqlite();
2407 mem.store("a", "x marks the spot", MemoryCategory::Core, None)
2408 .await
2409 .unwrap();
2410 let results = mem.recall("x", 10, None, None, None).await.unwrap();
2412 assert!(results.len() <= 10);
2414 }
2415
2416 #[tokio::test]
2417 async fn recall_limit_zero() {
2418 let (_tmp, mem) = temp_sqlite();
2419 mem.store("a", "some content", MemoryCategory::Core, None)
2420 .await
2421 .unwrap();
2422 let results = mem.recall("some", 0, None, None, None).await.unwrap();
2423 assert!(results.is_empty());
2424 }
2425
2426 #[tokio::test]
2427 async fn recall_limit_one() {
2428 let (_tmp, mem) = temp_sqlite();
2429 mem.store("a", "matching content alpha", MemoryCategory::Core, None)
2430 .await
2431 .unwrap();
2432 mem.store("b", "matching content beta", MemoryCategory::Core, None)
2433 .await
2434 .unwrap();
2435 let results = mem
2436 .recall("matching content", 1, None, None, None)
2437 .await
2438 .unwrap();
2439 assert_eq!(results.len(), 1);
2440 }
2441
2442 #[tokio::test]
2443 async fn recall_matches_by_key_not_just_content() {
2444 let (_tmp, mem) = temp_sqlite();
2445 mem.store(
2446 "rust_preferences",
2447 "User likes systems programming",
2448 MemoryCategory::Core,
2449 None,
2450 )
2451 .await
2452 .unwrap();
2453 let results = mem.recall("rust", 10, None, None, None).await.unwrap();
2455 assert!(!results.is_empty(), "Should match by key");
2456 }
2457
2458 #[tokio::test]
2459 async fn recall_unicode_query() {
2460 let (_tmp, mem) = temp_sqlite();
2461 mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
2462 .await
2463 .unwrap();
2464 let results = mem.recall("日本語", 10, None, None, None).await.unwrap();
2465 assert!(!results.is_empty());
2466 }
2467
2468 #[tokio::test]
2471 async fn schema_idempotent_reopen() {
2472 let tmp = TempDir::new().unwrap();
2473 {
2474 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
2475 mem.store("k1", "v1", MemoryCategory::Core, None)
2476 .await
2477 .unwrap();
2478 }
2479 let mem2 = SqliteMemory::new("test", tmp.path()).unwrap();
2481 let entry = mem2.get("k1").await.unwrap();
2482 assert!(entry.is_some());
2483 assert_eq!(entry.unwrap().content, "v1");
2484 mem2.store("k2", "v2", MemoryCategory::Daily, None)
2486 .await
2487 .unwrap();
2488 assert_eq!(mem2.count().await.unwrap(), 2);
2489 }
2490
2491 #[tokio::test]
2492 async fn schema_triple_open() {
2493 let tmp = TempDir::new().unwrap();
2494 let _m1 = SqliteMemory::new("test", tmp.path()).unwrap();
2495 let _m2 = SqliteMemory::new("test", tmp.path()).unwrap();
2496 let m3 = SqliteMemory::new("test", tmp.path()).unwrap();
2497 assert!(m3.health_check().await);
2498 }
2499
2500 #[tokio::test]
2503 async fn forget_then_recall_no_ghost_results() {
2504 let (_tmp, mem) = temp_sqlite();
2505 mem.store(
2506 "ghost",
2507 "phantom memory content",
2508 MemoryCategory::Core,
2509 None,
2510 )
2511 .await
2512 .unwrap();
2513 mem.forget("ghost").await.unwrap();
2514 let results = mem
2515 .recall("phantom memory", 10, None, None, None)
2516 .await
2517 .unwrap();
2518 assert!(
2519 results.is_empty(),
2520 "Deleted memory should not appear in recall"
2521 );
2522 }
2523
2524 #[tokio::test]
2525 async fn forget_and_re_store_same_key() {
2526 let (_tmp, mem) = temp_sqlite();
2527 mem.store("cycle", "version 1", MemoryCategory::Core, None)
2528 .await
2529 .unwrap();
2530 mem.forget("cycle").await.unwrap();
2531 mem.store("cycle", "version 2", MemoryCategory::Core, None)
2532 .await
2533 .unwrap();
2534 let entry = mem.get("cycle").await.unwrap().unwrap();
2535 assert_eq!(entry.content, "version 2");
2536 assert_eq!(mem.count().await.unwrap(), 1);
2537 }
2538
2539 #[tokio::test]
2542 async fn reindex_empty_db() {
2543 let (_tmp, mem) = temp_sqlite();
2544 let count = mem.reindex().await.unwrap();
2545 assert_eq!(count, 0);
2546 }
2547
2548 #[tokio::test]
2549 async fn reindex_twice_is_safe() {
2550 let (_tmp, mem) = temp_sqlite();
2551 mem.store("r1", "reindex data", MemoryCategory::Core, None)
2552 .await
2553 .unwrap();
2554 mem.reindex().await.unwrap();
2555 let count = mem.reindex().await.unwrap();
2556 assert_eq!(count, 0); let results = mem.recall("reindex", 10, None, None, None).await.unwrap();
2559 assert_eq!(results.len(), 1);
2560 }
2561
#[test]
fn content_hash_empty_string() {
    let h = SqliteMemory::content_hash("");
    // Even empty input yields a non-empty, 16-character digest.
    assert!(!h.is_empty());
    assert_eq!(h.len(), 16);
}
2570
#[test]
fn content_hash_unicode() {
    let crab_a = SqliteMemory::content_hash("🦀");
    let crab_b = SqliteMemory::content_hash("🦀");
    let rocket = SqliteMemory::content_hash("🚀");

    // Deterministic for identical input, distinct for different emoji.
    assert_eq!(crab_a, crab_b);
    assert_ne!(crab_a, rocket);
}
2579
#[test]
fn content_hash_long_input() {
    // Digest width is fixed regardless of input size.
    let digest = SqliteMemory::content_hash(&"a".repeat(1_000_000));
    assert_eq!(digest.len(), 16);
}
2586
#[test]
fn category_roundtrip_custom_with_spaces() {
    let original = MemoryCategory::Custom("my custom category".into());

    // Spaces survive encoding verbatim...
    let encoded = SqliteMemory::category_to_str(&original);
    assert_eq!(encoded, "my custom category");

    // ...and decoding restores the same custom category.
    assert_eq!(SqliteMemory::str_to_category(&encoded), original);
}
2597
#[test]
fn category_roundtrip_empty_custom() {
    let empty = MemoryCategory::Custom(String::new());

    // An empty custom name encodes to "" and decodes back to Custom("").
    let encoded = SqliteMemory::category_to_str(&empty);
    assert_eq!(encoded, "");
    let decoded = SqliteMemory::str_to_category(&encoded);
    assert_eq!(decoded, MemoryCategory::Custom(String::new()));
}
2606
2607 #[tokio::test]
2610 async fn list_custom_category() {
2611 let (_tmp, mem) = temp_sqlite();
2612 mem.store(
2613 "c1",
2614 "custom1",
2615 MemoryCategory::Custom("project".into()),
2616 None,
2617 )
2618 .await
2619 .unwrap();
2620 mem.store(
2621 "c2",
2622 "custom2",
2623 MemoryCategory::Custom("project".into()),
2624 None,
2625 )
2626 .await
2627 .unwrap();
2628 mem.store("c3", "other", MemoryCategory::Core, None)
2629 .await
2630 .unwrap();
2631
2632 let project = mem
2633 .list(Some(&MemoryCategory::Custom("project".into())), None)
2634 .await
2635 .unwrap();
2636 assert_eq!(project.len(), 2);
2637 }
2638
2639 #[tokio::test]
2640 async fn list_empty_db() {
2641 let (_tmp, mem) = temp_sqlite();
2642 let all = mem.list(None, None).await.unwrap();
2643 assert!(all.is_empty());
2644 }
2645
2646 #[tokio::test]
2649 async fn sqlite_purge_namespace_deletes_only_all_matching_entries() {
2650 let (_tmp, mem) = temp_sqlite();
2651
2652 mem.store_with_metadata("a", "data", MemoryCategory::Core, None, Some("ns1"), None)
2653 .await
2654 .unwrap();
2655 mem.store_with_metadata("b", "data", MemoryCategory::Core, None, Some("ns2"), None)
2656 .await
2657 .unwrap();
2658
2659 let in_ns1 =
2660 |entries: &[MemoryEntry]| entries.iter().filter(|e| e.namespace == "ns1").count();
2661
2662 let before = mem.list(None, None).await.unwrap();
2663 let deleted = mem.purge_namespace("ns1").await.unwrap();
2664 let after = mem.list(None, None).await.unwrap();
2665
2666 assert_eq!(in_ns1(&after), 0);
2667 assert_eq!(after.len() - in_ns1(&after), before.len() - in_ns1(&before));
2668 assert_eq!(deleted, in_ns1(&before));
2669 }
2670
2671 #[tokio::test]
2672 async fn sqlite_purge_session_removes_all_matching_entries() {
2673 let (_tmp, mem) = temp_sqlite();
2674 mem.store("a1", "data1", MemoryCategory::Core, Some("sess-a"))
2675 .await
2676 .unwrap();
2677 mem.store("a2", "data2", MemoryCategory::Core, Some("sess-a"))
2678 .await
2679 .unwrap();
2680 mem.store("b1", "data3", MemoryCategory::Core, Some("sess-b"))
2681 .await
2682 .unwrap();
2683
2684 let count = mem.purge_session("sess-a").await.unwrap();
2685 assert_eq!(count, 2);
2686 assert_eq!(mem.count().await.unwrap(), 1);
2687 }
2688
2689 #[tokio::test]
2690 async fn sqlite_purge_session_preserves_other_sessions() {
2691 let (_tmp, mem) = temp_sqlite();
2692 mem.store("a1", "data1", MemoryCategory::Core, Some("sess-a"))
2693 .await
2694 .unwrap();
2695 mem.store("b1", "data2", MemoryCategory::Core, Some("sess-b"))
2696 .await
2697 .unwrap();
2698 mem.store("c1", "data3", MemoryCategory::Core, None)
2699 .await
2700 .unwrap();
2701
2702 let count = mem.purge_session("sess-a").await.unwrap();
2703 assert_eq!(count, 1);
2704 assert_eq!(mem.count().await.unwrap(), 2);
2705
2706 let remaining = mem.list(None, None).await.unwrap();
2707 assert!(
2708 remaining
2709 .iter()
2710 .all(|e| e.session_id.as_deref() != Some("sess-a"))
2711 );
2712 }
2713
2714 #[tokio::test]
2715 async fn sqlite_purge_session_returns_count() {
2716 let (_tmp, mem) = temp_sqlite();
2717 for i in 0..3 {
2718 mem.store(
2719 &format!("k{i}"),
2720 "data",
2721 MemoryCategory::Core,
2722 Some("target-sess"),
2723 )
2724 .await
2725 .unwrap();
2726 }
2727
2728 let count = mem.purge_session("target-sess").await.unwrap();
2729 assert_eq!(count, 3);
2730 }
2731
2732 #[tokio::test]
2733 async fn sqlite_purge_session_empty_session_is_noop() {
2734 let (_tmp, mem) = temp_sqlite();
2735 mem.store("a", "data", MemoryCategory::Core, Some("sess"))
2736 .await
2737 .unwrap();
2738
2739 let count = mem.purge_session("").await.unwrap();
2740 assert_eq!(count, 0);
2741 assert_eq!(mem.count().await.unwrap(), 1);
2742 }
2743
2744 #[tokio::test]
2747 async fn store_and_recall_with_session_id() {
2748 let (_tmp, mem) = temp_sqlite();
2749 mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
2750 .await
2751 .unwrap();
2752 mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
2753 .await
2754 .unwrap();
2755 mem.store("k3", "no session fact", MemoryCategory::Core, None)
2756 .await
2757 .unwrap();
2758
2759 let results = mem
2761 .recall("fact", 10, Some("sess-a"), None, None)
2762 .await
2763 .unwrap();
2764 assert_eq!(results.len(), 1);
2765 assert_eq!(results[0].key, "k1");
2766 assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
2767 }
2768
2769 #[tokio::test]
2770 async fn recall_no_session_filter_returns_all() {
2771 let (_tmp, mem) = temp_sqlite();
2772 mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
2773 .await
2774 .unwrap();
2775 mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
2776 .await
2777 .unwrap();
2778 mem.store("k3", "gamma fact", MemoryCategory::Core, None)
2779 .await
2780 .unwrap();
2781
2782 let results = mem.recall("fact", 10, None, None, None).await.unwrap();
2784 assert_eq!(results.len(), 3);
2785 }
2786
2787 #[tokio::test]
2788 async fn cross_session_recall_isolation() {
2789 let (_tmp, mem) = temp_sqlite();
2790 mem.store(
2791 "secret",
2792 "session A secret data",
2793 MemoryCategory::Core,
2794 Some("sess-a"),
2795 )
2796 .await
2797 .unwrap();
2798
2799 let results = mem
2801 .recall("secret", 10, Some("sess-b"), None, None)
2802 .await
2803 .unwrap();
2804 assert!(results.is_empty());
2805
2806 let results = mem
2808 .recall("secret", 10, Some("sess-a"), None, None)
2809 .await
2810 .unwrap();
2811 assert_eq!(results.len(), 1);
2812 }
2813
2814 #[tokio::test]
2815 async fn list_with_session_filter() {
2816 let (_tmp, mem) = temp_sqlite();
2817 mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
2818 .await
2819 .unwrap();
2820 mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
2821 .await
2822 .unwrap();
2823 mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
2824 .await
2825 .unwrap();
2826 mem.store("k4", "none1", MemoryCategory::Core, None)
2827 .await
2828 .unwrap();
2829
2830 let results = mem.list(None, Some("sess-a")).await.unwrap();
2832 assert_eq!(results.len(), 2);
2833 assert!(
2834 results
2835 .iter()
2836 .all(|e| e.session_id.as_deref() == Some("sess-a"))
2837 );
2838
2839 let results = mem
2841 .list(Some(&MemoryCategory::Core), Some("sess-a"))
2842 .await
2843 .unwrap();
2844 assert_eq!(results.len(), 1);
2845 assert_eq!(results[0].key, "k1");
2846 }
2847
2848 #[tokio::test]
2849 async fn schema_migration_idempotent_on_reopen() {
2850 let tmp = TempDir::new().unwrap();
2851
2852 {
2854 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
2855 mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
2856 .await
2857 .unwrap();
2858 }
2859
2860 {
2862 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
2863 let results = mem
2864 .recall("reopen", 10, Some("sess-x"), None, None)
2865 .await
2866 .unwrap();
2867 assert_eq!(results.len(), 1);
2868 assert_eq!(results[0].key, "k1");
2869 assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
2870 }
2871 }
2872
2873 #[tokio::test]
2874 async fn schema_migration_tolerates_concurrent_initialization() {
2875 let tmp = TempDir::new().unwrap();
2876
2877 let db_path = tmp.path().join("memory").join("brain.db");
2880 std::fs::create_dir_all(db_path.parent().unwrap()).unwrap();
2881 {
2882 let conn = rusqlite::Connection::open(&db_path).unwrap();
2883 conn.execute_batch(
2884 "CREATE TABLE IF NOT EXISTS memories (
2885 id TEXT PRIMARY KEY,
2886 key TEXT NOT NULL UNIQUE,
2887 content TEXT NOT NULL,
2888 category TEXT NOT NULL DEFAULT 'core',
2889 embedding BLOB,
2890 created_at TEXT NOT NULL,
2891 updated_at TEXT NOT NULL
2892 );",
2893 )
2894 .unwrap();
2895 }
2896
2897 let workers = 12usize;
2898 let barrier = std::sync::Arc::new(std::sync::Barrier::new(workers));
2899 let mut handles = Vec::new();
2900 for _ in 0..workers {
2901 let dir = tmp.path().to_path_buf();
2902 let barrier = barrier.clone();
2903 handles.push(tokio::task::spawn_blocking(move || {
2904 barrier.wait();
2905 SqliteMemory::new("test", &dir)
2906 }));
2907 }
2908
2909 for h in handles {
2910 h.await.unwrap().unwrap();
2911 }
2912
2913 let conn = rusqlite::Connection::open(&db_path).unwrap();
2915 let mut stmt = conn.prepare("PRAGMA table_info(memories)").unwrap();
2916 let mut rows = stmt.query([]).unwrap();
2917 let mut cols = std::collections::HashSet::<String>::new();
2918 while let Some(row) = rows.next().unwrap() {
2919 cols.insert(row.get::<_, String>(1).unwrap());
2920 }
2921
2922 assert!(cols.contains("session_id"));
2923 assert!(cols.contains("namespace"));
2924 assert!(cols.contains("importance"));
2925 assert!(cols.contains("superseded_by"));
2926 }
2927
2928 #[tokio::test]
2931 async fn sqlite_concurrent_writes_no_data_loss() {
2932 let (_tmp, mem) = temp_sqlite();
2933 let mem = std::sync::Arc::new(mem);
2934
2935 let mut handles = Vec::new();
2936 for i in 0..10 {
2937 let mem = std::sync::Arc::clone(&mem);
2938 handles.push(tokio::spawn(async move {
2939 mem.store(
2940 &format!("concurrent_key_{i}"),
2941 &format!("value_{i}"),
2942 MemoryCategory::Core,
2943 None,
2944 )
2945 .await
2946 .unwrap();
2947 }));
2948 }
2949
2950 for handle in handles {
2951 handle.await.unwrap();
2952 }
2953
2954 let count = mem.count().await.unwrap();
2955 assert_eq!(
2956 count, 10,
2957 "all 10 concurrent writes must succeed without data loss"
2958 );
2959 }
2960
2961 #[tokio::test]
2962 async fn sqlite_concurrent_read_write_no_panic() {
2963 let (_tmp, mem) = temp_sqlite();
2964 let mem = std::sync::Arc::new(mem);
2965
2966 mem.store("shared_key", "initial", MemoryCategory::Core, None)
2968 .await
2969 .unwrap();
2970
2971 let mut handles = Vec::new();
2972
2973 for _ in 0..5 {
2975 let mem = std::sync::Arc::clone(&mem);
2976 handles.push(tokio::spawn(async move {
2977 let _ = mem.get("shared_key").await.unwrap();
2978 }));
2979 }
2980
2981 for i in 0..5 {
2983 let mem = std::sync::Arc::clone(&mem);
2984 handles.push(tokio::spawn(async move {
2985 mem.store(
2986 &format!("key_{i}"),
2987 &format!("val_{i}"),
2988 MemoryCategory::Core,
2989 None,
2990 )
2991 .await
2992 .unwrap();
2993 }));
2994 }
2995
2996 for handle in handles {
2997 handle.await.unwrap();
2998 }
2999
3000 assert_eq!(mem.count().await.unwrap(), 6);
3002 }
3003
3004 #[tokio::test]
3007 async fn export_no_filter_returns_all_entries() {
3008 let (_tmp, mem) = temp_sqlite();
3009 mem.store("a", "one", MemoryCategory::Core, None)
3010 .await
3011 .unwrap();
3012 mem.store("b", "two", MemoryCategory::Daily, None)
3013 .await
3014 .unwrap();
3015 mem.store("c", "three", MemoryCategory::Conversation, None)
3016 .await
3017 .unwrap();
3018
3019 let filter = ExportFilter::default();
3020 let results = mem.export(&filter).await.unwrap();
3021 assert_eq!(results.len(), 3);
3022 }
3023
3024 #[tokio::test]
3025 async fn export_with_namespace_filter() {
3026 let (_tmp, mem) = temp_sqlite();
3027 mem.store_with_metadata(
3028 "a",
3029 "ns1 data",
3030 MemoryCategory::Core,
3031 None,
3032 Some("ns1"),
3033 None,
3034 )
3035 .await
3036 .unwrap();
3037 mem.store_with_metadata(
3038 "b",
3039 "ns2 data",
3040 MemoryCategory::Core,
3041 None,
3042 Some("ns2"),
3043 None,
3044 )
3045 .await
3046 .unwrap();
3047
3048 let filter = ExportFilter {
3049 namespace: Some("ns1".into()),
3050 ..Default::default()
3051 };
3052 let results = mem.export(&filter).await.unwrap();
3053 assert_eq!(results.len(), 1);
3054 assert_eq!(results[0].namespace, "ns1");
3055 }
3056
3057 #[tokio::test]
3058 async fn export_with_session_id_filter() {
3059 let (_tmp, mem) = temp_sqlite();
3060 mem.store("a", "sess-a data", MemoryCategory::Core, Some("sess-a"))
3061 .await
3062 .unwrap();
3063 mem.store("b", "sess-b data", MemoryCategory::Core, Some("sess-b"))
3064 .await
3065 .unwrap();
3066
3067 let filter = ExportFilter {
3068 session_id: Some("sess-a".into()),
3069 ..Default::default()
3070 };
3071 let results = mem.export(&filter).await.unwrap();
3072 assert_eq!(results.len(), 1);
3073 assert_eq!(results[0].key, "a");
3074 }
3075
3076 #[tokio::test]
3077 async fn export_with_category_filter() {
3078 let (_tmp, mem) = temp_sqlite();
3079 mem.store("a", "core data", MemoryCategory::Core, None)
3080 .await
3081 .unwrap();
3082 mem.store("b", "daily data", MemoryCategory::Daily, None)
3083 .await
3084 .unwrap();
3085
3086 let filter = ExportFilter {
3087 category: Some(MemoryCategory::Core),
3088 ..Default::default()
3089 };
3090 let results = mem.export(&filter).await.unwrap();
3091 assert_eq!(results.len(), 1);
3092 assert_eq!(results[0].category, MemoryCategory::Core);
3093 }
3094
3095 #[tokio::test]
3096 async fn export_with_time_range() {
3097 let (_tmp, mem) = temp_sqlite();
3098 mem.store("a", "old data", MemoryCategory::Core, None)
3100 .await
3101 .unwrap();
3102 mem.store("b", "new data", MemoryCategory::Core, None)
3103 .await
3104 .unwrap();
3105
3106 let filter = ExportFilter {
3108 since: Some("2000-01-01T00:00:00Z".into()),
3109 until: Some("2099-12-31T23:59:59Z".into()),
3110 ..Default::default()
3111 };
3112 let results = mem.export(&filter).await.unwrap();
3113 assert_eq!(results.len(), 2);
3114
3115 let filter = ExportFilter {
3117 since: Some("2099-01-01T00:00:00Z".into()),
3118 ..Default::default()
3119 };
3120 let results = mem.export(&filter).await.unwrap();
3121 assert!(results.is_empty());
3122 }
3123
3124 #[tokio::test]
3125 async fn export_with_combined_filters() {
3126 let (_tmp, mem) = temp_sqlite();
3127 mem.store_with_metadata(
3128 "a",
3129 "match",
3130 MemoryCategory::Core,
3131 Some("sess-a"),
3132 Some("ns1"),
3133 None,
3134 )
3135 .await
3136 .unwrap();
3137 mem.store_with_metadata(
3138 "b",
3139 "no match ns",
3140 MemoryCategory::Core,
3141 Some("sess-a"),
3142 Some("ns2"),
3143 None,
3144 )
3145 .await
3146 .unwrap();
3147 mem.store_with_metadata(
3148 "c",
3149 "no match sess",
3150 MemoryCategory::Core,
3151 None,
3152 Some("ns1"),
3153 None,
3154 )
3155 .await
3156 .unwrap();
3157
3158 let filter = ExportFilter {
3159 namespace: Some("ns1".into()),
3160 session_id: Some("sess-a".into()),
3161 category: Some(MemoryCategory::Core),
3162 since: Some("2000-01-01T00:00:00Z".into()),
3163 until: Some("2099-12-31T23:59:59Z".into()),
3164 };
3165 let results = mem.export(&filter).await.unwrap();
3166 assert_eq!(results.len(), 1);
3167 assert_eq!(results[0].key, "a");
3168 }
3169
3170 #[tokio::test]
3171 async fn export_empty_database_returns_empty_vec() {
3172 let (_tmp, mem) = temp_sqlite();
3173 let filter = ExportFilter::default();
3174 let results = mem.export(&filter).await.unwrap();
3175 assert!(results.is_empty());
3176 }
3177
3178 #[tokio::test]
3179 async fn export_ordering_is_chronological() {
3180 let (_tmp, mem) = temp_sqlite();
3181 mem.store("first", "data1", MemoryCategory::Core, None)
3182 .await
3183 .unwrap();
3184 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
3186 mem.store("second", "data2", MemoryCategory::Core, None)
3187 .await
3188 .unwrap();
3189
3190 let filter = ExportFilter::default();
3191 let results = mem.export(&filter).await.unwrap();
3192 assert_eq!(results.len(), 2);
3193 assert!(
3194 results[0].timestamp <= results[1].timestamp,
3195 "Export must be ordered by created_at ASC"
3196 );
3197 }
3198
3199 #[tokio::test]
3200 async fn export_preserves_field_integrity() {
3201 let (_tmp, mem) = temp_sqlite();
3202 mem.store_with_metadata(
3203 "roundtrip_key",
3204 "roundtrip content",
3205 MemoryCategory::Custom("custom_cat".into()),
3206 Some("sess-rt"),
3207 Some("ns-rt"),
3208 Some(0.9),
3209 )
3210 .await
3211 .unwrap();
3212
3213 let filter = ExportFilter::default();
3214 let results = mem.export(&filter).await.unwrap();
3215 assert_eq!(results.len(), 1);
3216 let e = &results[0];
3217 assert_eq!(e.key, "roundtrip_key");
3218 assert_eq!(e.content, "roundtrip content");
3219 assert_eq!(e.category, MemoryCategory::Custom("custom_cat".into()));
3220 assert_eq!(e.session_id.as_deref(), Some("sess-rt"));
3221 assert_eq!(e.namespace, "ns-rt");
3222 assert_eq!(e.importance, Some(0.9));
3223 }
3224
3225 #[tokio::test]
3228 async fn sqlite_reindex_preserves_data() {
3229 let (_tmp, mem) = temp_sqlite();
3230 mem.store("a", "Rust is fast", MemoryCategory::Core, None)
3231 .await
3232 .unwrap();
3233 mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
3234 .await
3235 .unwrap();
3236
3237 mem.reindex().await.unwrap();
3238
3239 let count = mem.count().await.unwrap();
3240 assert_eq!(count, 2, "reindex must preserve all entries");
3241
3242 let entry = mem.get("a").await.unwrap();
3243 assert!(entry.is_some());
3244 assert_eq!(entry.unwrap().content, "Rust is fast");
3245 }
3246
3247 #[tokio::test]
3248 async fn sqlite_reindex_idempotent() {
3249 let (_tmp, mem) = temp_sqlite();
3250 mem.store("x", "test data", MemoryCategory::Core, None)
3251 .await
3252 .unwrap();
3253
3254 mem.reindex().await.unwrap();
3256 mem.reindex().await.unwrap();
3257 mem.reindex().await.unwrap();
3258
3259 assert_eq!(mem.count().await.unwrap(), 1);
3260 }
3261
3262 #[tokio::test]
3265 async fn search_mode_bm25_only() {
3266 let tmp = TempDir::new().unwrap();
3267 let mem = SqliteMemory::with_embedder(
3268 "test",
3269 tmp.path(),
3270 Arc::new(super::super::embeddings::NoopEmbedding),
3271 0.7,
3272 0.3,
3273 1000,
3274 None,
3275 SearchMode::Bm25,
3276 )
3277 .unwrap();
3278 mem.store(
3279 "lang",
3280 "User prefers Rust programming",
3281 MemoryCategory::Core,
3282 None,
3283 )
3284 .await
3285 .unwrap();
3286 mem.store("food", "User likes pizza", MemoryCategory::Core, None)
3287 .await
3288 .unwrap();
3289
3290 let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
3291 assert!(!results.is_empty(), "BM25 mode should find keyword matches");
3292 assert!(
3293 results.iter().any(|e| e.content.contains("Rust")),
3294 "BM25 should match on keyword 'Rust'"
3295 );
3296 }
3297
3298 #[tokio::test]
3299 async fn search_mode_embedding_only() {
3300 let tmp = TempDir::new().unwrap();
3301 let mem = SqliteMemory::with_embedder(
3303 "test",
3304 tmp.path(),
3305 Arc::new(super::super::embeddings::NoopEmbedding),
3306 0.7,
3307 0.3,
3308 1000,
3309 None,
3310 SearchMode::Embedding,
3311 )
3312 .unwrap();
3313 mem.store(
3314 "lang",
3315 "User prefers Rust programming",
3316 MemoryCategory::Core,
3317 None,
3318 )
3319 .await
3320 .unwrap();
3321
3322 let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
3325 assert!(
3327 results.iter().any(|e| e.content.contains("Rust")),
3328 "Embedding mode with noop should fall back to LIKE and still find results"
3329 );
3330 }
3331
3332 #[tokio::test]
3333 async fn search_mode_hybrid_default() {
3334 let tmp = TempDir::new().unwrap();
3335 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
3336 assert_eq!(mem.search_mode, SearchMode::Hybrid);
3338
3339 mem.store(
3340 "lang",
3341 "User prefers Rust programming",
3342 MemoryCategory::Core,
3343 None,
3344 )
3345 .await
3346 .unwrap();
3347
3348 let results = mem.recall("Rust", 10, None, None, None).await.unwrap();
3349 assert!(!results.is_empty(), "Hybrid mode should find results");
3350 }
3351
3352 #[tokio::test]
3360 async fn get_returns_alias_text_in_agent_alias_and_uuid_in_agent_id() {
3361 let (_tmp, mem) = temp_sqlite();
3362 let alpha_uuid = mem.ensure_agent_uuid("clamps").await.unwrap();
3363 mem.store_with_agent(
3364 "row1",
3365 "v",
3366 MemoryCategory::Core,
3367 None,
3368 None,
3369 None,
3370 Some(&alpha_uuid),
3371 )
3372 .await
3373 .unwrap();
3374
3375 let entry = mem.get("row1").await.unwrap().expect("row1 must exist");
3376 assert_eq!(
3377 entry.agent_alias.as_deref(),
3378 Some("clamps"),
3379 "agent_alias must carry the human-readable alias, not the UUID"
3380 );
3381 assert_eq!(
3382 entry.agent_id.as_deref(),
3383 Some(alpha_uuid.as_str()),
3384 "agent_id must carry the raw UUID FK so scoping equality works"
3385 );
3386 assert_ne!(
3387 entry.agent_alias, entry.agent_id,
3388 "alias and id must differ on a SQL backend"
3389 );
3390 }
3391
3392 #[tokio::test]
3393 async fn list_returns_alias_text_for_every_row() {
3394 let (_tmp, mem) = temp_sqlite();
3395 let a = mem.ensure_agent_uuid("clamps").await.unwrap();
3396 let b = mem.ensure_agent_uuid("glados").await.unwrap();
3397 for (key, owner) in [("r1", &a), ("r2", &b)] {
3398 mem.store_with_agent(
3399 key,
3400 "v",
3401 MemoryCategory::Core,
3402 None,
3403 None,
3404 None,
3405 Some(owner),
3406 )
3407 .await
3408 .unwrap();
3409 }
3410
3411 let mut rows = mem.list(None, None).await.unwrap();
3412 rows.sort_by(|x, y| x.key.cmp(&y.key));
3413 assert_eq!(rows.len(), 2);
3414 assert_eq!(rows[0].agent_alias.as_deref(), Some("clamps"));
3415 assert_eq!(rows[1].agent_alias.as_deref(), Some("glados"));
3416 assert!(
3417 rows.iter().all(|r| r.agent_id.is_some()),
3418 "every row should carry agent_id"
3419 );
3420 }
3421
3422 #[tokio::test]
3425 async fn migrates_legacy_session_ids_to_sanitized_form() {
3426 let tmp = TempDir::new().unwrap();
3427 let raw_sid = "slack_C123_1.2_user one";
3428 let sanitized = sanitize_session_key(raw_sid);
3429 assert_ne!(
3430 raw_sid, sanitized,
3431 "test only meaningful when sanitization changes the value"
3432 );
3433
3434 {
3435 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
3436 mem.store(
3437 "legacy_key",
3438 "stored before sanitize fix",
3439 MemoryCategory::Conversation,
3440 Some(raw_sid),
3441 )
3442 .await
3443 .unwrap();
3444 let pre = mem.list(None, Some(raw_sid)).await.unwrap();
3445 assert_eq!(pre.len(), 1, "raw session_id should match before migration");
3446 }
3447
3448 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
3449
3450 let by_sanitized = mem.list(None, Some(&sanitized)).await.unwrap();
3451 assert_eq!(
3452 by_sanitized.len(),
3453 1,
3454 "row must be discoverable via sanitized session_id"
3455 );
3456 assert_eq!(by_sanitized[0].key, "legacy_key");
3457
3458 let by_raw = mem.list(None, Some(raw_sid)).await.unwrap();
3459 assert!(
3460 by_raw.is_empty(),
3461 "raw form must no longer match after migration"
3462 );
3463 }
3464
3465 #[tokio::test]
3466 async fn session_id_migration_is_idempotent() {
3467 let tmp = TempDir::new().unwrap();
3468 let sanitized = sanitize_session_key("slack_C123_1.2_user");
3469
3470 {
3471 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
3472 mem.store("k", "v", MemoryCategory::Core, Some(&sanitized))
3473 .await
3474 .unwrap();
3475 }
3476
3477 for _ in 0..3 {
3478 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
3479 let entries = mem.list(None, Some(&sanitized)).await.unwrap();
3480 assert_eq!(entries.len(), 1);
3481 }
3482 }
3483
3484 #[tokio::test]
3485 async fn session_id_migration_leaves_null_rows_untouched() {
3486 let tmp = TempDir::new().unwrap();
3487
3488 {
3489 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
3490 mem.store("global", "no session", MemoryCategory::Core, None)
3491 .await
3492 .unwrap();
3493 }
3494
3495 let mem = SqliteMemory::new("test", tmp.path()).unwrap();
3496 let entry = mem.get("global").await.unwrap().expect("row should exist");
3497 assert!(entry.session_id.is_none());
3498 }
3499}