Skip to main content

zeroclaw_memory/
response_cache.rs

//! Response cache — avoid burning tokens on repeated prompts.
//!
//! Stores LLM responses in a separate SQLite table keyed by a SHA-256 hash of
//! `(model, system_prompt, user_prompt)`, with recently used entries also kept
//! in a bounded in-memory hot tier. Entries expire after a configurable TTL
//! (default: 1 hour). The cache is optional and disabled by default — users opt
//! in via `[memory] response_cache_enabled = true`.
7
8use anyhow::Result;
9use chrono::{Duration, Local};
10use parking_lot::Mutex;
11use rusqlite::{Connection, params};
12use sha2::{Digest, Sha256};
13use std::collections::HashMap;
14use std::path::Path;
15
/// One record held in the in-memory (hot) tier of the response cache.
///
/// Both timestamps are monotonic `Instant`s: `created_at` is the reference
/// point for the TTL check, and `accessed_at` decides which entry is dropped
/// when the hot tier is full (least recently accessed goes first).
struct InMemoryEntry {
    /// Instant the entry's TTL is measured from.
    created_at: std::time::Instant,
    /// Last time the entry was read or written through the hot tier.
    accessed_at: std::time::Instant,
    /// Cached response text.
    response: String,
    /// Token count recorded alongside the response.
    token_count: u32,
}
23
24/// Two-tier response cache: in-memory LRU (hot) + SQLite (warm).
25///
26/// The hot cache avoids SQLite round-trips for frequently repeated prompts.
27/// On miss from hot cache, falls through to SQLite. On hit from SQLite,
28/// the entry is promoted to the hot cache.
29pub struct ResponseCache {
30    conn: Mutex<Connection>,
31    ttl_minutes: i64,
32    max_entries: usize,
33    hot_cache: Mutex<HashMap<String, InMemoryEntry>>,
34    hot_max_entries: usize,
35}
36
37impl ResponseCache {
38    /// Open (or create) the response cache database.
39    pub fn new(workspace_dir: &Path, ttl_minutes: u32, max_entries: usize) -> Result<Self> {
40        Self::with_hot_cache(workspace_dir, ttl_minutes, max_entries, 256)
41    }
42
43    /// Open (or create) the response cache database with a custom hot cache size.
44    pub fn with_hot_cache(
45        workspace_dir: &Path,
46        ttl_minutes: u32,
47        max_entries: usize,
48        hot_max_entries: usize,
49    ) -> Result<Self> {
50        let db_dir = workspace_dir.join("memory");
51        std::fs::create_dir_all(&db_dir)?;
52        let db_path = db_dir.join("response_cache.db");
53
54        let conn = Connection::open(&db_path)?;
55
56        conn.execute_batch(
57            "PRAGMA journal_mode = WAL;
58             PRAGMA synchronous  = NORMAL;
59             PRAGMA temp_store   = MEMORY;",
60        )?;
61
62        conn.execute_batch(
63            "CREATE TABLE IF NOT EXISTS response_cache (
64                prompt_hash TEXT PRIMARY KEY,
65                model       TEXT NOT NULL,
66                response    TEXT NOT NULL,
67                token_count INTEGER NOT NULL DEFAULT 0,
68                created_at  TEXT NOT NULL,
69                accessed_at TEXT NOT NULL,
70                hit_count   INTEGER NOT NULL DEFAULT 0
71            );
72            CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
73            CREATE INDEX IF NOT EXISTS idx_rc_created ON response_cache(created_at);",
74        )?;
75
76        Ok(Self {
77            conn: Mutex::new(conn),
78            ttl_minutes: i64::from(ttl_minutes),
79            max_entries,
80            hot_cache: Mutex::new(HashMap::new()),
81            hot_max_entries,
82        })
83    }
84
85    /// Build a deterministic cache key from model + system prompt + user prompt.
86    pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
87        let mut hasher = Sha256::new();
88        hasher.update(model.as_bytes());
89        hasher.update(b"|");
90        if let Some(sys) = system_prompt {
91            hasher.update(sys.as_bytes());
92        }
93        hasher.update(b"|");
94        hasher.update(user_prompt.as_bytes());
95        let hash = hasher.finalize();
96        format!("{:064x}", hash)
97    }
98
99    /// Look up a cached response. Returns `None` on miss or expired entry.
100    ///
101    /// Two-tier lookup: checks the in-memory hot cache first, then falls
102    /// through to SQLite. On a SQLite hit the entry is promoted to hot cache.
103    #[allow(clippy::cast_sign_loss)]
104    pub fn get(&self, key: &str) -> Result<Option<String>> {
105        // Tier 1: hot cache (with TTL check)
106        {
107            let mut hot = self.hot_cache.lock();
108            if let Some(entry) = hot.get_mut(key) {
109                let ttl = std::time::Duration::from_secs(self.ttl_minutes as u64 * 60);
110                if entry.created_at.elapsed() > ttl {
111                    hot.remove(key);
112                } else {
113                    entry.accessed_at = std::time::Instant::now();
114                    let response = entry.response.clone();
115                    drop(hot);
116                    // Still bump SQLite hit count for accurate stats
117                    let conn = self.conn.lock();
118                    let now_str = Local::now().to_rfc3339();
119                    conn.execute(
120                        "UPDATE response_cache
121                         SET accessed_at = ?1, hit_count = hit_count + 1
122                         WHERE prompt_hash = ?2",
123                        params![now_str, key],
124                    )?;
125                    return Ok(Some(response));
126                }
127            }
128        }
129
130        // Tier 2: SQLite (warm)
131        let result: Option<(String, u32)> = {
132            let conn = self.conn.lock();
133            let now = Local::now();
134            let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
135
136            let mut stmt = conn.prepare(
137                "SELECT response, token_count FROM response_cache
138                 WHERE prompt_hash = ?1 AND created_at > ?2",
139            )?;
140
141            let result: Option<(String, u32)> = stmt
142                .query_row(params![key, cutoff], |row| Ok((row.get(0)?, row.get(1)?)))
143                .ok();
144
145            if result.is_some() {
146                let now_str = now.to_rfc3339();
147                conn.execute(
148                    "UPDATE response_cache
149                     SET accessed_at = ?1, hit_count = hit_count + 1
150                     WHERE prompt_hash = ?2",
151                    params![now_str, key],
152                )?;
153            }
154
155            result
156        };
157
158        if let Some((ref response, token_count)) = result {
159            self.promote_to_hot(key, response, token_count);
160        }
161
162        Ok(result.map(|(r, _)| r))
163    }
164
165    /// Store a response in the cache (both hot and warm tiers).
166    pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
167        // Write to hot cache
168        self.promote_to_hot(key, response, token_count);
169
170        // Write to SQLite (warm)
171        let conn = self.conn.lock();
172
173        let now = Local::now().to_rfc3339();
174
175        conn.execute(
176            "INSERT OR REPLACE INTO response_cache
177             (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
178             VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
179            params![key, model, response, token_count, now, now],
180        )?;
181
182        // Evict expired entries
183        let cutoff = (Local::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
184        conn.execute(
185            "DELETE FROM response_cache WHERE created_at <= ?1",
186            params![cutoff],
187        )?;
188
189        // LRU eviction if over max_entries
190        #[allow(clippy::cast_possible_wrap)]
191        let max = self.max_entries as i64;
192        conn.execute(
193            "DELETE FROM response_cache WHERE prompt_hash IN (
194                SELECT prompt_hash FROM response_cache
195                ORDER BY accessed_at ASC
196                LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
197            )",
198            params![max],
199        )?;
200
201        Ok(())
202    }
203
204    /// Promote an entry to the in-memory hot cache, evicting the oldest if full.
205    fn promote_to_hot(&self, key: &str, response: &str, token_count: u32) {
206        let mut hot = self.hot_cache.lock();
207
208        // If already present, just update (keep original created_at for TTL)
209        if let Some(entry) = hot.get_mut(key) {
210            entry.response = response.to_string();
211            entry.token_count = token_count;
212            entry.accessed_at = std::time::Instant::now();
213            return;
214        }
215
216        // Evict oldest entry if at capacity
217        if self.hot_max_entries > 0
218            && hot.len() >= self.hot_max_entries
219            && let Some(oldest_key) = hot
220                .iter()
221                .min_by_key(|(_, v)| v.accessed_at)
222                .map(|(k, _)| k.clone())
223        {
224            hot.remove(&oldest_key);
225        }
226
227        if self.hot_max_entries > 0 {
228            let now = std::time::Instant::now();
229            hot.insert(
230                key.to_string(),
231                InMemoryEntry {
232                    response: response.to_string(),
233                    token_count,
234                    created_at: now,
235                    accessed_at: now,
236                },
237            );
238        }
239    }
240
241    /// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
242    pub fn stats(&self) -> Result<(usize, u64, u64)> {
243        let conn = self.conn.lock();
244
245        let count: i64 =
246            conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
247
248        let hits: i64 = conn.query_row(
249            "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache",
250            [],
251            |row| row.get(0),
252        )?;
253
254        let tokens_saved: i64 = conn.query_row(
255            "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache",
256            [],
257            |row| row.get(0),
258        )?;
259
260        #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
261        Ok((count as usize, hits as u64, tokens_saved as u64))
262    }
263
264    /// Wipe the entire response cache.
265    pub fn clear(&self) -> Result<usize> {
266        self.hot_cache.lock().clear();
267        let conn = self.conn.lock();
268        let affected = conn.execute("DELETE FROM response_cache", [])?;
269        Ok(affected)
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a cache with room for 1000 rows in a fresh temporary workspace.
    /// The `TempDir` is handed back so the directory outlives the cache.
    fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        let workspace = TempDir::new().unwrap();
        let cache = ResponseCache::new(workspace.path(), ttl_minutes, 1000).unwrap();
        (workspace, cache)
    }

    /// Cache key for `prompt` sent to `gpt-4` without a system prompt.
    fn gpt4_key(prompt: &str) -> String {
        ResponseCache::cache_key("gpt-4", None, prompt)
    }

    /// Store `"response {id}"` (10 tokens) under `"prompt {id}"` for each id.
    fn fill(cache: &ResponseCache, ids: std::ops::Range<u32>) {
        for id in ids {
            let key = gpt4_key(&format!("prompt {id}"));
            cache
                .put(&key, "gpt-4", &format!("response {id}"), 10)
                .unwrap();
        }
    }

    #[test]
    fn cache_key_deterministic() {
        let first = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let second = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(first.len(), 64); // hex-encoded SHA-256
        assert_eq!(first, second);
    }

    #[test]
    fn cache_key_varies_by_model() {
        let claude = ResponseCache::cache_key("claude-3", None, "hello");
        assert_ne!(gpt4_key("hello"), claude);
    }

    #[test]
    fn cache_key_varies_by_system_prompt() {
        let helpful = ResponseCache::cache_key("gpt-4", Some("You are helpful"), "hello");
        let rude = ResponseCache::cache_key("gpt-4", Some("You are rude"), "hello");
        assert_ne!(helpful, rude);
    }

    #[test]
    fn cache_key_varies_by_prompt() {
        assert_ne!(gpt4_key("hello"), gpt4_key("goodbye"));
    }

    #[test]
    fn put_and_get() {
        let (_workspace, cache) = temp_cache(60);
        let key = gpt4_key("What is Rust?");
        let answer = "Rust is a systems programming language.";

        cache.put(&key, "gpt-4", answer, 25).unwrap();

        assert_eq!(cache.get(&key).unwrap().as_deref(), Some(answer));
    }

    #[test]
    fn miss_returns_none() {
        let (_workspace, cache) = temp_cache(60);
        assert_eq!(cache.get("nonexistent_key").unwrap(), None);
    }

    #[test]
    fn expired_entry_returns_none() {
        // A zero-minute TTL puts the cutoff at "now", so an entry stored an
        // instant earlier is already stale by the time it is read.
        let (_workspace, cache) = temp_cache(0);
        let key = gpt4_key("test");

        cache.put(&key, "gpt-4", "response", 10).unwrap();

        assert!(cache.get(&key).unwrap().is_none());
    }

    #[test]
    fn hit_count_incremented() {
        let (_workspace, cache) = temp_cache(60);
        let key = gpt4_key("hello");
        cache.put(&key, "gpt-4", "Hi!", 5).unwrap();

        for _round in 0..3 {
            let _ = cache.get(&key).unwrap();
        }

        let (_, total_hits, _) = cache.stats().unwrap();
        assert_eq!(total_hits, 3);
    }

    #[test]
    fn tokens_saved_calculated() {
        let (_workspace, cache) = temp_cache(60);
        let key = gpt4_key("explain rust");
        cache.put(&key, "gpt-4", "Rust is...", 100).unwrap();

        // Five hits at 100 tokens each should account for 500 saved tokens.
        for _round in 0..5 {
            let _ = cache.get(&key).unwrap();
        }

        let (_, _, tokens_saved) = cache.stats().unwrap();
        assert_eq!(tokens_saved, 500);
    }

    #[test]
    fn lru_eviction() {
        let workspace = TempDir::new().unwrap();
        let cache = ResponseCache::new(workspace.path(), 60, 3).unwrap(); // max 3 rows

        fill(&cache, 0..5);

        let (rows, _, _) = cache.stats().unwrap();
        assert!(rows <= 3, "Should have at most 3 entries after eviction");
    }

    #[test]
    fn clear_wipes_all() {
        let (_workspace, cache) = temp_cache(60);
        fill(&cache, 0..10);

        assert_eq!(cache.clear().unwrap(), 10);

        let (rows, _, _) = cache.stats().unwrap();
        assert_eq!(rows, 0);
    }

    #[test]
    fn stats_empty_cache() {
        let (_workspace, cache) = temp_cache(60);
        assert_eq!(cache.stats().unwrap(), (0, 0, 0));
    }

    #[test]
    fn overwrite_same_key() {
        let (_workspace, cache) = temp_cache(60);
        let key = gpt4_key("question");

        cache.put(&key, "gpt-4", "answer v1", 20).unwrap();
        cache.put(&key, "gpt-4", "answer v2", 25).unwrap();

        assert_eq!(cache.get(&key).unwrap().as_deref(), Some("answer v2"));
        let (rows, _, _) = cache.stats().unwrap();
        assert_eq!(rows, 1);
    }

    #[test]
    fn unicode_prompt_handling() {
        let (_workspace, cache) = temp_cache(60);
        let key = gpt4_key("日本語のテスト 🦀");
        let answer = "はい、Rustは素晴らしい";

        cache.put(&key, "gpt-4", answer, 30).unwrap();

        assert_eq!(cache.get(&key).unwrap().as_deref(), Some(answer));
    }

    // ── §4.4 Cache eviction under pressure tests ─────────────

    #[test]
    fn lru_eviction_keeps_most_recent() {
        let workspace = TempDir::new().unwrap();
        let cache = ResponseCache::new(workspace.path(), 60, 3).unwrap();
        fill(&cache, 0..3);

        // Read entry 0 so it becomes the most recently used row.
        let key0 = gpt4_key("prompt 0");
        let _ = cache.get(&key0).unwrap();

        // A fourth insert pushes the table over its limit and triggers eviction.
        cache
            .put(&gpt4_key("prompt 3"), "gpt-4", "response 3", 10)
            .unwrap();

        let (rows, _, _) = cache.stats().unwrap();
        assert!(rows <= 3, "cache must not exceed max_entries");

        assert!(
            cache.get(&key0).unwrap().is_some(),
            "recently accessed entry should survive LRU eviction"
        );
    }

    #[test]
    fn cache_handles_zero_max_entries() {
        let workspace = TempDir::new().unwrap();
        let cache = ResponseCache::new(workspace.path(), 60, 0).unwrap();

        // Must not panic even though nothing may be retained.
        cache.put(&gpt4_key("test"), "gpt-4", "response", 10).unwrap();

        let (rows, _, _) = cache.stats().unwrap();
        assert_eq!(rows, 0, "cache with max_entries=0 should evict everything");
    }

    #[test]
    fn cache_concurrent_reads_no_panic() {
        let workspace = TempDir::new().unwrap();
        let cache = ResponseCache::new(workspace.path(), 60, 100).unwrap();
        let key = gpt4_key("concurrent");
        cache.put(&key, "gpt-4", "response", 10).unwrap();

        // Scoped threads borrow the cache directly; the scope joins every
        // reader and re-raises any reader panic.
        std::thread::scope(|scope| {
            for _ in 0..10 {
                scope.spawn(|| {
                    let _ = cache.get(&key).unwrap();
                });
            }
        });

        let (_, hits, _) = cache.stats().unwrap();
        assert_eq!(hits, 10, "all concurrent reads should register as hits");
    }
}