1use super::traits::{Memory, MemoryEntry};
11use parking_lot::Mutex;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
/// A cached recall result together with the moment it was stored.
struct CachedResult {
    /// Entries that were returned for the cached query.
    entries: Vec<MemoryEntry>,
    /// Insertion time; lookups compare its age against the configured TTL and
    /// eviction removes the entry with the smallest value.
    created_at: Instant,
}
21
/// Tunable settings for the retrieval pipeline.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Stage names, run in the order given. The pipeline recognises
    /// `"cache"`, `"fts"` and `"vector"`; any other name is logged and skipped.
    pub stages: Vec<String>,
    /// Score threshold for the `"fts"` stage: when its first result scores at
    /// least this much, the pipeline logs an FTS early return.
    pub fts_early_return_score: f64,
    /// Entry count at which storing a new result first evicts the oldest one.
    pub cache_max_entries: usize,
    /// Maximum age of a cached result; older entries are ignored on lookup.
    pub cache_ttl: Duration,
}
34
35impl Default for RetrievalConfig {
36 fn default() -> Self {
37 Self {
38 stages: vec!["cache".into(), "fts".into(), "vector".into()],
39 fts_early_return_score: 0.85,
40 cache_max_entries: 256,
41 cache_ttl: Duration::from_secs(300),
42 }
43 }
44}
45
/// Runs the configured retrieval stages against a memory backend and keeps
/// recent results in an in-process cache.
pub struct RetrievalPipeline {
    /// Backend that answers `recall` / `recall_namespaced` queries.
    memory: Arc<dyn Memory>,
    /// Stage order, early-return threshold and cache limits.
    config: RetrievalConfig,
    /// Cached results, keyed by a string derived from the query parameters.
    hot_cache: Mutex<HashMap<String, CachedResult>>,
}
52
53impl RetrievalPipeline {
54 pub fn new(memory: Arc<dyn Memory>, config: RetrievalConfig) -> Self {
55 Self {
56 memory,
57 config,
58 hot_cache: Mutex::new(HashMap::new()),
59 }
60 }
61
/// Builds the cache key for one combination of query parameters.
///
/// The string components are rendered with `Debug` formatting, which quotes
/// and escapes them. A `:` inside the query, session id or namespace
/// therefore cannot shift the field boundaries and make two different
/// parameter sets share a key. `None` and `Some("")` also render differently,
/// so an absent session or namespace is not confused with an empty one.
///
/// The key is only used for in-process lookups; its exact text is not part of
/// any contract.
fn cache_key(
    query: &str,
    limit: usize,
    session_id: Option<&str>,
    namespace: Option<&str>,
) -> String {
    format!("{:?}:{}:{:?}:{:?}", query, limit, session_id, namespace)
}
77
78 fn check_cache(&self, key: &str) -> Option<Vec<MemoryEntry>> {
80 let cache = self.hot_cache.lock();
81 if let Some(cached) = cache.get(key)
82 && cached.created_at.elapsed() < self.config.cache_ttl
83 {
84 return Some(cached.entries.clone());
85 }
86 None
87 }
88
89 fn store_in_cache(&self, key: String, entries: Vec<MemoryEntry>) {
91 let mut cache = self.hot_cache.lock();
92
93 if cache.len() >= self.config.cache_max_entries {
95 let oldest_key = cache
96 .iter()
97 .min_by_key(|(_, v)| v.created_at)
98 .map(|(k, _)| k.clone());
99 if let Some(k) = oldest_key {
100 cache.remove(&k);
101 }
102 }
103
104 cache.insert(
105 key,
106 CachedResult {
107 entries,
108 created_at: Instant::now(),
109 },
110 );
111 }
112
113 pub async fn recall(
115 &self,
116 query: &str,
117 limit: usize,
118 session_id: Option<&str>,
119 namespace: Option<&str>,
120 since: Option<&str>,
121 until: Option<&str>,
122 ) -> anyhow::Result<Vec<MemoryEntry>> {
123 let ck = Self::cache_key(query, limit, session_id, namespace);
124
125 for stage in &self.config.stages {
126 match stage.as_str() {
127 "cache" => {
128 if let Some(cached) = self.check_cache(&ck) {
129 ::zeroclaw_log::record!(
130 DEBUG,
131 ::zeroclaw_log::Event::new(
132 module_path!(),
133 ::zeroclaw_log::Action::Note
134 )
135 .with_attrs(::serde_json::json!({"query": query})),
136 "retrieval pipeline: cache hit for ''"
137 );
138 return Ok(cached);
139 }
140 }
141 "fts" | "vector" => {
142 let results = if let Some(ns) = namespace {
145 self.memory
146 .recall_namespaced(ns, query, limit, session_id, since, until)
147 .await?
148 } else {
149 self.memory
150 .recall(query, limit, session_id, since, until)
151 .await?
152 };
153
154 if !results.is_empty() {
155 if stage == "fts"
158 && let Some(top_score) = results.first().and_then(|e| e.score)
159 && top_score >= self.config.fts_early_return_score
160 {
161 ::zeroclaw_log::record!(
162 DEBUG,
163 ::zeroclaw_log::Event::new(
164 module_path!(),
165 ::zeroclaw_log::Action::Note
166 )
167 .with_attrs(::serde_json::json!({"top_score": top_score})),
168 "retrieval pipeline: FTS early return (score=)"
169 );
170 self.store_in_cache(ck, results.clone());
171 return Ok(results);
172 }
173
174 self.store_in_cache(ck, results.clone());
175 return Ok(results);
176 }
177 }
178 other => {
179 ::zeroclaw_log::record!(
180 WARN,
181 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
182 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
183 .with_attrs(::serde_json::json!({"other": other})),
184 "retrieval pipeline: unknown stage '', skipping"
185 );
186 }
187 }
188 }
189
190 Ok(Vec::new())
192 }
193
194 pub fn invalidate_cache(&self) {
196 self.hot_cache.lock().clear();
197 }
198
199 pub fn cache_size(&self) -> usize {
201 self.hot_cache.lock().len()
202 }
203}
204
205#[cfg(test)]
206mod tests {
207 use super::*;
208 use crate::none::NoneMemory;
209
210 #[tokio::test]
211 async fn pipeline_returns_empty_from_none_backend() {
212 let memory = Arc::new(NoneMemory::new("none"));
213 let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
214
215 let results = pipeline
216 .recall("test", 10, None, None, None, None)
217 .await
218 .unwrap();
219 assert!(results.is_empty());
220 }
221
222 #[tokio::test]
223 async fn pipeline_cache_invalidation() {
224 let memory = Arc::new(NoneMemory::new("none"));
225 let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
226
227 let ck = RetrievalPipeline::cache_key("test", 10, None, None);
229 pipeline.store_in_cache(ck, vec![]);
230
231 assert_eq!(pipeline.cache_size(), 1);
232 pipeline.invalidate_cache();
233 assert_eq!(pipeline.cache_size(), 0);
234 }
235
// Changing only the session id or only the namespace must change the key.
#[test]
fn cache_key_includes_all_params() {
    let base = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns1"));
    let other_session = RetrievalPipeline::cache_key("hello", 10, Some("sess-b"), Some("ns1"));
    let other_namespace = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns2"));

    assert_ne!(base, other_session);
    assert_ne!(base, other_namespace);
}
245
246 #[tokio::test]
247 async fn pipeline_caches_results() {
248 let memory = Arc::new(NoneMemory::new("none"));
249 let config = RetrievalConfig {
250 stages: vec!["cache".into()],
251 ..Default::default()
252 };
253 let pipeline = RetrievalPipeline::new(memory, config);
254
255 let results = pipeline
257 .recall("test", 10, None, None, None, None)
258 .await
259 .unwrap();
260 assert!(results.is_empty());
261
262 let ck = RetrievalPipeline::cache_key("cached_query", 5, None, None);
264 let fake_entry = MemoryEntry {
265 id: "1".into(),
266 key: "k".into(),
267 content: "cached content".into(),
268 category: crate::traits::MemoryCategory::Core,
269 timestamp: "now".into(),
270 session_id: None,
271 score: Some(0.9),
272 namespace: "default".into(),
273 importance: None,
274 superseded_by: None,
275 agent_alias: None,
276 agent_id: None,
277 };
278 pipeline.store_in_cache(ck, vec![fake_entry]);
279
280 let results = pipeline
282 .recall("cached_query", 5, None, None, None, None)
283 .await
284 .unwrap();
285 assert_eq!(results.len(), 1);
286 assert_eq!(results[0].content, "cached content");
287 }
288}