1use anyhow::Result;
9use chrono::{Duration, Local};
10use parking_lot::Mutex;
11use rusqlite::{Connection, params};
12use sha2::{Digest, Sha256};
13use std::collections::HashMap;
14use std::path::Path;
15
/// One response held in the in-process hot cache.
///
/// Both timestamps come from the monotonic [`std::time::Instant`] clock, so
/// they only have meaning inside the current process.
struct InMemoryEntry {
    /// The cached response text.
    response: String,
    /// Token count stored alongside the response.
    token_count: u32,
    /// Instant the hot-cache TTL is measured from.
    created_at: std::time::Instant,
    /// Instant of the most recent read or refresh; the entry with the oldest
    /// value is the one evicted when the hot cache is full.
    accessed_at: std::time::Instant,
}
23
24pub struct ResponseCache {
30 conn: Mutex<Connection>,
31 ttl_minutes: i64,
32 max_entries: usize,
33 hot_cache: Mutex<HashMap<String, InMemoryEntry>>,
34 hot_max_entries: usize,
35}
36
37impl ResponseCache {
38 pub fn new(workspace_dir: &Path, ttl_minutes: u32, max_entries: usize) -> Result<Self> {
40 Self::with_hot_cache(workspace_dir, ttl_minutes, max_entries, 256)
41 }
42
43 pub fn with_hot_cache(
45 workspace_dir: &Path,
46 ttl_minutes: u32,
47 max_entries: usize,
48 hot_max_entries: usize,
49 ) -> Result<Self> {
50 let db_dir = workspace_dir.join("memory");
51 std::fs::create_dir_all(&db_dir)?;
52 let db_path = db_dir.join("response_cache.db");
53
54 let conn = Connection::open(&db_path)?;
55
56 conn.execute_batch(
57 "PRAGMA journal_mode = WAL;
58 PRAGMA synchronous = NORMAL;
59 PRAGMA temp_store = MEMORY;",
60 )?;
61
62 conn.execute_batch(
63 "CREATE TABLE IF NOT EXISTS response_cache (
64 prompt_hash TEXT PRIMARY KEY,
65 model TEXT NOT NULL,
66 response TEXT NOT NULL,
67 token_count INTEGER NOT NULL DEFAULT 0,
68 created_at TEXT NOT NULL,
69 accessed_at TEXT NOT NULL,
70 hit_count INTEGER NOT NULL DEFAULT 0
71 );
72 CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
73 CREATE INDEX IF NOT EXISTS idx_rc_created ON response_cache(created_at);",
74 )?;
75
76 Ok(Self {
77 conn: Mutex::new(conn),
78 ttl_minutes: i64::from(ttl_minutes),
79 max_entries,
80 hot_cache: Mutex::new(HashMap::new()),
81 hot_max_entries,
82 })
83 }
84
85 pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
87 let mut hasher = Sha256::new();
88 hasher.update(model.as_bytes());
89 hasher.update(b"|");
90 if let Some(sys) = system_prompt {
91 hasher.update(sys.as_bytes());
92 }
93 hasher.update(b"|");
94 hasher.update(user_prompt.as_bytes());
95 let hash = hasher.finalize();
96 format!("{:064x}", hash)
97 }
98
99 #[allow(clippy::cast_sign_loss)]
104 pub fn get(&self, key: &str) -> Result<Option<String>> {
105 {
107 let mut hot = self.hot_cache.lock();
108 if let Some(entry) = hot.get_mut(key) {
109 let ttl = std::time::Duration::from_secs(self.ttl_minutes as u64 * 60);
110 if entry.created_at.elapsed() > ttl {
111 hot.remove(key);
112 } else {
113 entry.accessed_at = std::time::Instant::now();
114 let response = entry.response.clone();
115 drop(hot);
116 let conn = self.conn.lock();
118 let now_str = Local::now().to_rfc3339();
119 conn.execute(
120 "UPDATE response_cache
121 SET accessed_at = ?1, hit_count = hit_count + 1
122 WHERE prompt_hash = ?2",
123 params![now_str, key],
124 )?;
125 return Ok(Some(response));
126 }
127 }
128 }
129
130 let result: Option<(String, u32)> = {
132 let conn = self.conn.lock();
133 let now = Local::now();
134 let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
135
136 let mut stmt = conn.prepare(
137 "SELECT response, token_count FROM response_cache
138 WHERE prompt_hash = ?1 AND created_at > ?2",
139 )?;
140
141 let result: Option<(String, u32)> = stmt
142 .query_row(params![key, cutoff], |row| Ok((row.get(0)?, row.get(1)?)))
143 .ok();
144
145 if result.is_some() {
146 let now_str = now.to_rfc3339();
147 conn.execute(
148 "UPDATE response_cache
149 SET accessed_at = ?1, hit_count = hit_count + 1
150 WHERE prompt_hash = ?2",
151 params![now_str, key],
152 )?;
153 }
154
155 result
156 };
157
158 if let Some((ref response, token_count)) = result {
159 self.promote_to_hot(key, response, token_count);
160 }
161
162 Ok(result.map(|(r, _)| r))
163 }
164
165 pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
167 self.promote_to_hot(key, response, token_count);
169
170 let conn = self.conn.lock();
172
173 let now = Local::now().to_rfc3339();
174
175 conn.execute(
176 "INSERT OR REPLACE INTO response_cache
177 (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
178 VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
179 params![key, model, response, token_count, now, now],
180 )?;
181
182 let cutoff = (Local::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
184 conn.execute(
185 "DELETE FROM response_cache WHERE created_at <= ?1",
186 params![cutoff],
187 )?;
188
189 #[allow(clippy::cast_possible_wrap)]
191 let max = self.max_entries as i64;
192 conn.execute(
193 "DELETE FROM response_cache WHERE prompt_hash IN (
194 SELECT prompt_hash FROM response_cache
195 ORDER BY accessed_at ASC
196 LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
197 )",
198 params![max],
199 )?;
200
201 Ok(())
202 }
203
204 fn promote_to_hot(&self, key: &str, response: &str, token_count: u32) {
206 let mut hot = self.hot_cache.lock();
207
208 if let Some(entry) = hot.get_mut(key) {
210 entry.response = response.to_string();
211 entry.token_count = token_count;
212 entry.accessed_at = std::time::Instant::now();
213 return;
214 }
215
216 if self.hot_max_entries > 0
218 && hot.len() >= self.hot_max_entries
219 && let Some(oldest_key) = hot
220 .iter()
221 .min_by_key(|(_, v)| v.accessed_at)
222 .map(|(k, _)| k.clone())
223 {
224 hot.remove(&oldest_key);
225 }
226
227 if self.hot_max_entries > 0 {
228 let now = std::time::Instant::now();
229 hot.insert(
230 key.to_string(),
231 InMemoryEntry {
232 response: response.to_string(),
233 token_count,
234 created_at: now,
235 accessed_at: now,
236 },
237 );
238 }
239 }
240
241 pub fn stats(&self) -> Result<(usize, u64, u64)> {
243 let conn = self.conn.lock();
244
245 let count: i64 =
246 conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
247
248 let hits: i64 = conn.query_row(
249 "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache",
250 [],
251 |row| row.get(0),
252 )?;
253
254 let tokens_saved: i64 = conn.query_row(
255 "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache",
256 [],
257 |row| row.get(0),
258 )?;
259
260 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
261 Ok((count as usize, hits as u64, tokens_saved as u64))
262 }
263
264 pub fn clear(&self) -> Result<usize> {
266 self.hot_cache.lock().clear();
267 let conn = self.conn.lock();
268 let affected = conn.execute("DELETE FROM response_cache", [])?;
269 Ok(affected)
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Model name used by the storage tests.
    const MODEL: &str = "gpt-4";

    /// Opens a cache in a fresh temporary directory. The directory guard is
    /// returned so the database file outlives the call.
    fn sized_cache(ttl_minutes: u32, max_entries: usize) -> (TempDir, ResponseCache) {
        let dir = TempDir::new().unwrap();
        let cache = ResponseCache::new(dir.path(), ttl_minutes, max_entries).unwrap();
        (dir, cache)
    }

    /// Cache with room for 1000 rows.
    fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        sized_cache(ttl_minutes, 1000)
    }

    /// Key for `prompt` under [`MODEL`] with no system prompt.
    fn key_for(prompt: &str) -> String {
        ResponseCache::cache_key(MODEL, None, prompt)
    }

    /// Stores `prompt {i}` -> `response {i}` (10 tokens each) for `i` in `0..count`.
    fn store_numbered(cache: &ResponseCache, count: usize) {
        for i in 0..count {
            let key = key_for(&format!("prompt {i}"));
            cache
                .put(&key, MODEL, &format!("response {i}"), 10)
                .unwrap();
        }
    }

    /// Reads `key` `times` times, discarding the results.
    fn read_repeatedly(cache: &ResponseCache, key: &str, times: usize) {
        for _ in 0..times {
            let _ = cache.get(key).unwrap();
        }
    }

    #[test]
    fn cache_key_deterministic() {
        let first = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let second = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(first, second);
        assert_eq!(first.len(), 64);
    }

    #[test]
    fn cache_key_varies_by_model() {
        assert_ne!(
            ResponseCache::cache_key("gpt-4", None, "hello"),
            ResponseCache::cache_key("claude-3", None, "hello")
        );
    }

    #[test]
    fn cache_key_varies_by_system_prompt() {
        assert_ne!(
            ResponseCache::cache_key("gpt-4", Some("You are helpful"), "hello"),
            ResponseCache::cache_key("gpt-4", Some("You are rude"), "hello")
        );
    }

    #[test]
    fn cache_key_varies_by_prompt() {
        assert_ne!(
            ResponseCache::cache_key("gpt-4", None, "hello"),
            ResponseCache::cache_key("gpt-4", None, "goodbye")
        );
    }

    #[test]
    fn put_and_get() {
        let (_dir, cache) = temp_cache(60);
        let key = key_for("What is Rust?");
        let answer = "Rust is a systems programming language.";

        cache.put(&key, MODEL, answer, 25).unwrap();

        let fetched = cache.get(&key).unwrap();
        assert_eq!(fetched.as_deref(), Some(answer));
    }

    #[test]
    fn miss_returns_none() {
        let (_dir, cache) = temp_cache(60);
        assert!(cache.get("nonexistent_key").unwrap().is_none());
    }

    #[test]
    fn expired_entry_returns_none() {
        // A TTL of zero minutes: the entry is stale as soon as it is written.
        let (_dir, cache) = temp_cache(0);
        let key = key_for("test");

        cache.put(&key, MODEL, "response", 10).unwrap();

        assert!(cache.get(&key).unwrap().is_none());
    }

    #[test]
    fn hit_count_incremented() {
        let (_dir, cache) = temp_cache(60);
        let key = key_for("hello");

        cache.put(&key, MODEL, "Hi!", 5).unwrap();
        read_repeatedly(&cache, &key, 3);

        let (_, total_hits, _) = cache.stats().unwrap();
        assert_eq!(total_hits, 3);
    }

    #[test]
    fn tokens_saved_calculated() {
        let (_dir, cache) = temp_cache(60);
        let key = key_for("explain rust");

        cache.put(&key, MODEL, "Rust is...", 100).unwrap();
        read_repeatedly(&cache, &key, 5);

        let (_, _, tokens_saved) = cache.stats().unwrap();
        assert_eq!(tokens_saved, 500);
    }

    #[test]
    fn lru_eviction() {
        let (_dir, cache) = sized_cache(60, 3);

        store_numbered(&cache, 5);

        let (count, _, _) = cache.stats().unwrap();
        assert!(count <= 3, "Should have at most 3 entries after eviction");
    }

    #[test]
    fn clear_wipes_all() {
        let (_dir, cache) = temp_cache(60);
        store_numbered(&cache, 10);

        assert_eq!(cache.clear().unwrap(), 10);

        let (count, _, _) = cache.stats().unwrap();
        assert_eq!(count, 0);
    }

    #[test]
    fn stats_empty_cache() {
        let (_dir, cache) = temp_cache(60);

        let (count, hits, tokens) = cache.stats().unwrap();

        assert_eq!(count, 0);
        assert_eq!(hits, 0);
        assert_eq!(tokens, 0);
    }

    #[test]
    fn overwrite_same_key() {
        let (_dir, cache) = temp_cache(60);
        let key = key_for("question");

        for (answer, tokens) in [("answer v1", 20), ("answer v2", 25)] {
            cache.put(&key, MODEL, answer, tokens).unwrap();
        }

        assert_eq!(cache.get(&key).unwrap().as_deref(), Some("answer v2"));

        let (count, _, _) = cache.stats().unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn unicode_prompt_handling() {
        let (_dir, cache) = temp_cache(60);
        let key = key_for("日本語のテスト 🦀");
        let answer = "はい、Rustは素晴らしい";

        cache.put(&key, MODEL, answer, 30).unwrap();

        assert_eq!(cache.get(&key).unwrap().as_deref(), Some(answer));
    }

    #[test]
    fn lru_eviction_keeps_most_recent() {
        let (_dir, cache) = sized_cache(60, 3);

        // Fill the cache exactly to its limit.
        store_numbered(&cache, 3);

        // Touch the oldest entry so it becomes the most recently accessed.
        let key0 = key_for("prompt 0");
        let _ = cache.get(&key0).unwrap();

        // A fourth entry forces one eviction.
        let key3 = key_for("prompt 3");
        cache.put(&key3, MODEL, "response 3", 10).unwrap();

        let (count, _, _) = cache.stats().unwrap();
        assert!(count <= 3, "cache must not exceed max_entries");

        let entry0 = cache.get(&key0).unwrap();
        assert!(
            entry0.is_some(),
            "recently accessed entry should survive LRU eviction"
        );
    }

    #[test]
    fn cache_handles_zero_max_entries() {
        let (_dir, cache) = sized_cache(60, 0);

        cache.put(&key_for("test"), MODEL, "response", 10).unwrap();

        let (count, _, _) = cache.stats().unwrap();
        assert_eq!(count, 0, "cache with max_entries=0 should evict everything");
    }

    #[test]
    fn cache_concurrent_reads_no_panic() {
        let (_dir, cache) = sized_cache(60, 100);
        let cache = std::sync::Arc::new(cache);

        let key = key_for("concurrent");
        cache.put(&key, MODEL, "response", 10).unwrap();

        let readers: Vec<_> = (0..10)
            .map(|_| {
                let cache = std::sync::Arc::clone(&cache);
                let key = key.clone();
                std::thread::spawn(move || {
                    let _ = cache.get(&key).unwrap();
                })
            })
            .collect();

        for reader in readers {
            reader.join().unwrap();
        }

        let (_, hits, _) = cache.stats().unwrap();
        assert_eq!(hits, 10, "all concurrent reads should register as hits");
    }
}