Skip to main content

zeroclaw_memory/
retrieval.rs

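//! Multi-stage retrieval pipeline.
//!
//! Wraps a `Memory` trait object with staged retrieval:
//! - **Stage 1 (Hot cache):** In-memory, TTL-bounded cache of recent recall results.
//! - **Stage 2 (FTS):** FTS5 keyword search with optional early-return.
//! - **Stage 3 (Vector):** Vector similarity search + hybrid merge.
//!
//! Stages 2 and 3 are both served by the backend's `recall` /
//! `recall_namespaced` call, which performs the keyword/vector merge itself;
//! the pipeline decides whether to consult the cache first and caches
//! non-empty backend results.
//!
//! Configurable via `[memory]` settings: `retrieval_stages`, `fts_early_return_score`.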
9
10use super::traits::{Memory, MemoryEntry};
11use parking_lot::Mutex;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16/// A cached recall result.
17struct CachedResult {
18    entries: Vec<MemoryEntry>,
19    created_at: Instant,
20}
21
/// Tuning knobs for the multi-stage retrieval pipeline.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Stage names, tried in order. Recognized values are `"cache"`, `"fts"`
    /// and `"vector"`; any other name is logged and skipped.
    pub stages: Vec<String>,
    /// Minimum top score (inclusive) at which an FTS hit is considered good
    /// enough to return without consulting later stages.
    pub fts_early_return_score: f64,
    /// Upper bound on the number of cached recall results.
    pub cache_max_entries: usize,
    /// How long a cached recall result stays valid.
    pub cache_ttl: Duration,
}

impl Default for RetrievalConfig {
    /// Cache, then FTS, then vector; early-return at a top score of 0.85;
    /// up to 256 cached results, each valid for five minutes.
    fn default() -> Self {
        let stages = ["cache", "fts", "vector"]
            .into_iter()
            .map(String::from)
            .collect();
        RetrievalConfig {
            stages,
            fts_early_return_score: 0.85,
            cache_max_entries: 256,
            cache_ttl: Duration::from_secs(5 * 60),
        }
    }
}
45
46/// Multi-stage retrieval pipeline wrapping a `Memory` backend.
47pub struct RetrievalPipeline {
48    memory: Arc<dyn Memory>,
49    config: RetrievalConfig,
50    hot_cache: Mutex<HashMap<String, CachedResult>>,
51}
52
53impl RetrievalPipeline {
54    pub fn new(memory: Arc<dyn Memory>, config: RetrievalConfig) -> Self {
55        Self {
56            memory,
57            config,
58            hot_cache: Mutex::new(HashMap::new()),
59        }
60    }
61
62    /// Build a cache key from query parameters.
63    fn cache_key(
64        query: &str,
65        limit: usize,
66        session_id: Option<&str>,
67        namespace: Option<&str>,
68    ) -> String {
69        format!(
70            "{}:{}:{}:{}",
71            query,
72            limit,
73            session_id.unwrap_or(""),
74            namespace.unwrap_or("")
75        )
76    }
77
78    /// Check the hot cache for a previous result.
79    fn check_cache(&self, key: &str) -> Option<Vec<MemoryEntry>> {
80        let cache = self.hot_cache.lock();
81        if let Some(cached) = cache.get(key)
82            && cached.created_at.elapsed() < self.config.cache_ttl
83        {
84            return Some(cached.entries.clone());
85        }
86        None
87    }
88
89    /// Store a result in the hot cache with LRU eviction.
90    fn store_in_cache(&self, key: String, entries: Vec<MemoryEntry>) {
91        let mut cache = self.hot_cache.lock();
92
93        // LRU eviction: remove oldest entries if at capacity
94        if cache.len() >= self.config.cache_max_entries {
95            let oldest_key = cache
96                .iter()
97                .min_by_key(|(_, v)| v.created_at)
98                .map(|(k, _)| k.clone());
99            if let Some(k) = oldest_key {
100                cache.remove(&k);
101            }
102        }
103
104        cache.insert(
105            key,
106            CachedResult {
107                entries,
108                created_at: Instant::now(),
109            },
110        );
111    }
112
113    /// Execute the multi-stage retrieval pipeline.
114    pub async fn recall(
115        &self,
116        query: &str,
117        limit: usize,
118        session_id: Option<&str>,
119        namespace: Option<&str>,
120        since: Option<&str>,
121        until: Option<&str>,
122    ) -> anyhow::Result<Vec<MemoryEntry>> {
123        let ck = Self::cache_key(query, limit, session_id, namespace);
124
125        for stage in &self.config.stages {
126            match stage.as_str() {
127                "cache" => {
128                    if let Some(cached) = self.check_cache(&ck) {
129                        ::zeroclaw_log::record!(
130                            DEBUG,
131                            ::zeroclaw_log::Event::new(
132                                module_path!(),
133                                ::zeroclaw_log::Action::Note
134                            )
135                            .with_attrs(::serde_json::json!({"query": query})),
136                            "retrieval pipeline: cache hit for ''"
137                        );
138                        return Ok(cached);
139                    }
140                }
141                "fts" | "vector" => {
142                    // Both FTS and vector are handled by the backend's recall method
143                    // which already does hybrid merge. We delegate to it.
144                    let results = if let Some(ns) = namespace {
145                        self.memory
146                            .recall_namespaced(ns, query, limit, session_id, since, until)
147                            .await?
148                    } else {
149                        self.memory
150                            .recall(query, limit, session_id, since, until)
151                            .await?
152                    };
153
154                    if !results.is_empty() {
155                        // Check for FTS early-return: if top score exceeds threshold
156                        // and we're in the FTS stage, we can skip further stages
157                        if stage == "fts"
158                            && let Some(top_score) = results.first().and_then(|e| e.score)
159                            && top_score >= self.config.fts_early_return_score
160                        {
161                            ::zeroclaw_log::record!(
162                                DEBUG,
163                                ::zeroclaw_log::Event::new(
164                                    module_path!(),
165                                    ::zeroclaw_log::Action::Note
166                                )
167                                .with_attrs(::serde_json::json!({"top_score": top_score})),
168                                "retrieval pipeline: FTS early return (score=)"
169                            );
170                            self.store_in_cache(ck, results.clone());
171                            return Ok(results);
172                        }
173
174                        self.store_in_cache(ck, results.clone());
175                        return Ok(results);
176                    }
177                }
178                other => {
179                    ::zeroclaw_log::record!(
180                        WARN,
181                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
182                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
183                            .with_attrs(::serde_json::json!({"other": other})),
184                        "retrieval pipeline: unknown stage '', skipping"
185                    );
186                }
187            }
188        }
189
190        // No results from any stage
191        Ok(Vec::new())
192    }
193
194    /// Invalidate the hot cache (e.g. after a store operation).
195    pub fn invalidate_cache(&self) {
196        self.hot_cache.lock().clear();
197    }
198
199    /// Get the number of entries in the hot cache.
200    pub fn cache_size(&self) -> usize {
201        self.hot_cache.lock().len()
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::none::NoneMemory;

    /// Build a fully-populated `MemoryEntry` whose content is `content`.
    fn sample_entry(content: &str) -> MemoryEntry {
        MemoryEntry {
            id: "1".into(),
            key: "k".into(),
            content: content.into(),
            category: crate::traits::MemoryCategory::Core,
            timestamp: "now".into(),
            session_id: None,
            score: Some(0.9),
            namespace: "default".into(),
            importance: None,
            superseded_by: None,
            agent_alias: None,
            agent_id: None,
        }
    }

    #[tokio::test]
    async fn pipeline_returns_empty_from_none_backend() {
        let memory = Arc::new(NoneMemory::new("none"));
        let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());

        let results = pipeline
            .recall("test", 10, None, None, None, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn pipeline_cache_invalidation() {
        let memory = Arc::new(NoneMemory::new("none"));
        let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());

        // Force a cache entry
        let ck = RetrievalPipeline::cache_key("test", 10, None, None);
        pipeline.store_in_cache(ck, vec![]);

        assert_eq!(pipeline.cache_size(), 1);
        pipeline.invalidate_cache();
        assert_eq!(pipeline.cache_size(), 0);
    }

    #[test]
    fn cache_key_includes_all_params() {
        let k1 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns1"));
        let k2 = RetrievalPipeline::cache_key("hello", 10, Some("sess-b"), Some("ns1"));
        let k3 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns2"));
        let k4 = RetrievalPipeline::cache_key("hello", 20, Some("sess-a"), Some("ns1"));
        let k5 = RetrievalPipeline::cache_key("world", 10, Some("sess-a"), Some("ns1"));

        assert_ne!(k1, k2);
        assert_ne!(k1, k3);
        assert_ne!(k1, k4);
        assert_ne!(k1, k5);
    }

    #[test]
    fn store_in_cache_respects_capacity() {
        let memory = Arc::new(NoneMemory::new("none"));
        let config = RetrievalConfig {
            cache_max_entries: 2,
            ..Default::default()
        };
        let pipeline = RetrievalPipeline::new(memory, config);

        for query in ["a", "b", "c"] {
            let ck = RetrievalPipeline::cache_key(query, 10, None, None);
            pipeline.store_in_cache(ck, vec![sample_entry(query)]);
        }

        assert_eq!(pipeline.cache_size(), 2);
    }

    #[tokio::test]
    async fn expired_cache_entries_are_not_returned() {
        let memory = Arc::new(NoneMemory::new("none"));
        let config = RetrievalConfig {
            stages: vec!["cache".into()],
            cache_ttl: Duration::ZERO,
            ..Default::default()
        };
        let pipeline = RetrievalPipeline::new(memory, config);

        let ck = RetrievalPipeline::cache_key("stale", 5, None, None);
        pipeline.store_in_cache(ck, vec![sample_entry("stale content")]);

        let results = pipeline
            .recall("stale", 5, None, None, None, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn pipeline_caches_results() {
        let memory = Arc::new(NoneMemory::new("none"));
        let config = RetrievalConfig {
            stages: vec!["cache".into()],
            ..Default::default()
        };
        let pipeline = RetrievalPipeline::new(memory, config);

        // First call: cache miss, no results
        let results = pipeline
            .recall("test", 10, None, None, None, None)
            .await
            .unwrap();
        assert!(results.is_empty());

        // Manually insert a cache entry
        let ck = RetrievalPipeline::cache_key("cached_query", 5, None, None);
        pipeline.store_in_cache(ck, vec![sample_entry("cached content")]);

        // Cache hit
        let results = pipeline
            .recall("cached_query", 5, None, None, None, None)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "cached content");
    }
}