Skip to main content

zeroclaw_memory/
consolidation.rs

1//! LLM-driven memory consolidation.
2//!
3//! After each conversation turn, extracts structured information:
4//! - `history_entry`: A timestamped summary for the daily conversation log.
5//! - `memory_update`: New facts, preferences, or decisions worth remembering
6//!   long-term (or `null` if nothing new was learned).
7//!
8//! This two-phase approach replaces the naive raw-message auto-save with
9//! semantic extraction, similar to Nanobot's `save_memory` tool call pattern.
10
11use crate::conflict;
12use crate::importance;
13use crate::traits::{Memory, MemoryCategory};
14use zeroclaw_api::model_provider::ModelProvider;
15
/// Output of consolidation extraction.
///
/// Deserialized (via `serde`) from the JSON object the LLM returns for a
/// conversation turn; see `parse_consolidation_response` for the parsing and
/// fallback rules.
#[derive(Debug, serde::Deserialize)]
pub struct ConsolidationResult {
    /// Brief timestamped summary for the conversation history log.
    ///
    /// Required: a JSON object without this key fails to deserialize.
    pub history_entry: String,
    /// New facts/preferences/decisions to store long-term, or None.
    ///
    /// The system prompt asks the model to return JSON `null` here when nothing
    /// new was learned.
    pub memory_update: Option<String>,
    /// Atomic facts extracted from the turn (when consolidation_extract_facts is enabled).
    ///
    /// `#[serde(default)]`: an absent key deserializes to an empty list.
    /// NOTE(review): the system prompt in this file does not request this
    /// field — confirm which prompt/caller populates it.
    #[serde(default)]
    pub facts: Vec<String>,
    /// Observed trend or pattern (when consolidation_extract_facts is enabled).
    ///
    /// `#[serde(default)]`: an absent key deserializes to `None`.
    /// NOTE(review): the system prompt in this file does not request this
    /// field — confirm which prompt/caller populates it.
    #[serde(default)]
    pub trend: Option<String>,
}
30
/// System prompt passed to the model for every consolidation request.
///
/// It instructs the model to reply with nothing but a JSON object holding
/// `history_entry` (a 1-2 sentence summary) and `memory_update` (new long-term
/// facts, or `null`). The text is sent verbatim, so any edit here changes
/// model behavior.
const CONSOLIDATION_SYSTEM_PROMPT: &str = r#"You are a memory consolidation engine. Given a conversation turn, extract:
1. "history_entry": A brief summary of what happened in this turn (1-2 sentences). Include the key topic or action.
2. "memory_update": Any NEW facts, preferences, decisions, or commitments worth remembering long-term. Return null if nothing new was learned.

Respond ONLY with valid JSON: {"history_entry": "...", "memory_update": "..." or null}
Do not include any text outside the JSON object."#;
37
38/// Run two-phase LLM-driven consolidation on a conversation turn.
39///
40/// Phase 1: Write a history entry to the Daily memory category.
41/// Phase 2: Write a memory update to the Core category (if the LLM identified new facts).
42///
43/// This function is designed to be called fire-and-forget via `tokio::spawn`.
44/// Strip channel media markers (e.g. `[IMAGE:/local/path]`, `[DOCUMENT:...]`)
45/// that contain local filesystem paths.  These must never be forwarded to
46/// upstream model_provider APIs — they would leak local paths and cause API errors.
47fn strip_media_markers(text: &str) -> String {
48    // Matches [IMAGE:...], [DOCUMENT:...], [FILE:...], [VIDEO:...], [VOICE:...], [AUDIO:...]
49    static RE: std::sync::LazyLock<regex::Regex> = std::sync::LazyLock::new(|| {
50        regex::Regex::new(r"\[(?:IMAGE|DOCUMENT|FILE|VIDEO|VOICE|AUDIO):[^\]]*\]").unwrap()
51    });
52    RE.replace_all(text, "[media attachment]").into_owned()
53}
54
55pub async fn consolidate_turn(
56    model_provider: &dyn ModelProvider,
57    model: &str,
58    temperature: Option<f64>,
59    memory: &dyn Memory,
60    user_message: &str,
61    assistant_response: &str,
62) -> anyhow::Result<()> {
63    let turn_text = format!(
64        "User: {}\nAssistant: {}",
65        strip_media_markers(user_message),
66        strip_media_markers(assistant_response),
67    );
68
69    // Truncate very long turns to avoid wasting tokens on consolidation.
70    // Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8 (e.g. CJK text).
71    let truncated = if turn_text.len() > 4000 {
72        let end = turn_text
73            .char_indices()
74            .map(|(i, _)| i)
75            .take_while(|&i| i <= 4000)
76            .last()
77            .unwrap_or(0);
78        format!("{}…", &turn_text[..end])
79    } else {
80        turn_text.clone()
81    };
82
83    let raw = model_provider
84        .chat_with_system(
85            Some(CONSOLIDATION_SYSTEM_PROMPT),
86            &truncated,
87            model,
88            temperature,
89        )
90        .await?;
91
92    let result: ConsolidationResult = parse_consolidation_response(&raw, &turn_text);
93
94    // Phase 1: Write history entry to Daily category.
95    let date = chrono::Local::now().format("%Y-%m-%d").to_string();
96    let history_key = format!("daily_{date}_{}", uuid::Uuid::new_v4());
97    memory
98        .store(
99            &history_key,
100            &result.history_entry,
101            MemoryCategory::Daily,
102            None,
103        )
104        .await?;
105
106    // Phase 2: Write memory update to Core category (if present).
107    if let Some(ref update) = result.memory_update
108        && !update.trim().is_empty()
109    {
110        let mem_key = format!("core_{}", uuid::Uuid::new_v4());
111
112        // Compute importance score heuristically.
113        let imp = importance::compute_importance(update, &MemoryCategory::Core);
114
115        // Check for conflicts with existing Core memories.
116        if let Err(e) = conflict::check_and_resolve_conflicts(
117            memory,
118            &mem_key,
119            update,
120            &MemoryCategory::Core,
121            0.85,
122        )
123        .await
124        {
125            ::zeroclaw_log::record!(
126                DEBUG,
127                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
128                    .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
129                "conflict check skipped"
130            );
131        }
132
133        // Store with importance metadata.
134        memory
135            .store_with_metadata(
136                &mem_key,
137                update,
138                MemoryCategory::Core,
139                None,
140                None,
141                Some(imp),
142            )
143            .await?;
144    }
145
146    Ok(())
147}
148
149/// Parse the LLM's consolidation response, with fallback for malformed JSON.
150fn parse_consolidation_response(raw: &str, fallback_text: &str) -> ConsolidationResult {
151    // Try to extract JSON from the response (LLM may wrap in markdown code blocks).
152    let cleaned = raw
153        .trim()
154        .trim_start_matches("```json")
155        .trim_start_matches("```")
156        .trim_end_matches("```")
157        .trim();
158
159    serde_json::from_str(cleaned).unwrap_or_else(|_| {
160        // Fallback: use truncated turn text as history entry.
161        // Use char-boundary-safe slicing to prevent panic on multi-byte UTF-8.
162        let summary = if fallback_text.len() > 200 {
163            let end = fallback_text
164                .char_indices()
165                .map(|(i, _)| i)
166                .take_while(|&i| i <= 200)
167                .last()
168                .unwrap_or(0);
169            format!("{}…", &fallback_text[..end])
170        } else {
171            fallback_text.to_string()
172        };
173        ConsolidationResult {
174            history_entry: summary,
175            memory_update: None,
176            facts: Vec::new(),
177            trend: None,
178        }
179    })
180}
181
/// Unit tests for `parse_consolidation_response`: the happy paths (plain and
/// code-fenced JSON) and the fallback path, including UTF-8-safe truncation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_valid_json_response() {
        let raw = r#"{"history_entry": "User asked about Rust.", "memory_update": "User prefers Rust over Go."}"#;
        let result = parse_consolidation_response(raw, "fallback");
        assert_eq!(result.history_entry, "User asked about Rust.");
        assert_eq!(
            result.memory_update.as_deref(),
            Some("User prefers Rust over Go.")
        );
        // Keys marked `#[serde(default)]` may be omitted by the model.
        assert!(result.facts.is_empty());
        assert!(result.trend.is_none());
    }

    #[test]
    fn parse_json_with_null_memory() {
        let raw = r#"{"history_entry": "Routine greeting.", "memory_update": null}"#;
        let result = parse_consolidation_response(raw, "fallback");
        assert_eq!(result.history_entry, "Routine greeting.");
        assert!(result.memory_update.is_none());
    }

    #[test]
    fn parse_json_wrapped_in_code_block() {
        let raw =
            "```json\n{\"history_entry\": \"Discussed deployment.\", \"memory_update\": null}\n```";
        let result = parse_consolidation_response(raw, "fallback");
        assert_eq!(result.history_entry, "Discussed deployment.");
        assert!(result.memory_update.is_none());
    }

    #[test]
    fn fallback_on_malformed_response() {
        let raw = "I'm sorry, I can't do that.";
        let result = parse_consolidation_response(raw, "User: hello\nAssistant: hi");
        assert_eq!(result.history_entry, "User: hello\nAssistant: hi");
        assert!(result.memory_update.is_none());
    }

    #[test]
    fn fallback_truncates_long_text() {
        let long_text = "x".repeat(500);
        let result = parse_consolidation_response("invalid", &long_text);
        // Exactly the first 200 bytes followed by the ellipsis marker.
        assert_eq!(result.history_entry, format!("{}…", "x".repeat(200)));
        assert!(result.memory_update.is_none());
    }

    #[test]
    fn fallback_truncates_cjk_text_without_panic() {
        // Each CJK character is 3 bytes in UTF-8, so byte index 200 lands
        // inside a character. The cut must move back to byte 198, i.e. keep
        // exactly 66 whole characters, and must not panic.
        let cjk_text = "二手书项目".repeat(50); // 250 chars = 750 bytes
        let result = parse_consolidation_response("invalid", &cjk_text);
        let expected: String = cjk_text
            .chars()
            .take(66)
            .chain(std::iter::once('…'))
            .collect();
        assert_eq!(result.history_entry, expected);
    }
}