// zeroclaw_tools/backup_tool.rs

1use async_trait::async_trait;
2use serde_json::json;
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tokio::fs;
7use zeroclaw_api::tool::{Tool, ToolResult};
8
/// Workspace backup tool: create, list, verify, and restore timestamped backups
/// with SHA-256 manifest integrity checking.
#[derive(Debug, Clone)]
pub struct BackupTool {
    /// Workspace root; backups are stored under `<workspace_dir>/backups`.
    workspace_dir: PathBuf,
    /// Workspace subdirectories (relative to `workspace_dir`) copied into each backup.
    include_dirs: Vec<String>,
    /// Number of backups retained; older ones are removed after each create.
    max_keep: usize,
}
16
17impl BackupTool {
18    pub fn new(workspace_dir: PathBuf, include_dirs: Vec<String>, max_keep: usize) -> Self {
19        Self {
20            workspace_dir,
21            include_dirs,
22            max_keep,
23        }
24    }
25
26    fn backups_dir(&self) -> PathBuf {
27        self.workspace_dir.join("backups")
28    }
29
30    async fn cmd_create(&self) -> anyhow::Result<ToolResult> {
31        let ts = chrono::Utc::now().format("%Y%m%dT%H%M%SZ");
32        let name = format!("backup-{ts}");
33        let backup_dir = self.backups_dir().join(&name);
34        fs::create_dir_all(&backup_dir).await?;
35
36        for sub in &self.include_dirs {
37            let src = self.workspace_dir.join(sub);
38            if src.is_dir() {
39                let dst = backup_dir.join(sub);
40                copy_dir_recursive(&src, &dst).await?;
41            }
42        }
43
44        let checksums = compute_checksums(&backup_dir).await?;
45        let file_count = checksums.len();
46        let manifest = serde_json::to_string_pretty(&checksums)?;
47        fs::write(backup_dir.join("manifest.json"), &manifest).await?;
48
49        // Enforce max_keep: remove oldest backups beyond the limit.
50        self.enforce_max_keep().await?;
51
52        Ok(ToolResult {
53            success: true,
54            output: json!({
55                "backup": name,
56                "file_count": file_count,
57            })
58            .to_string(),
59            error: None,
60        })
61    }
62
63    async fn enforce_max_keep(&self) -> anyhow::Result<()> {
64        let mut backups = self.list_backup_dirs().await?;
65        // Sorted newest-first; drop excess from the tail.
66        while backups.len() > self.max_keep {
67            if let Some(old) = backups.pop() {
68                fs::remove_dir_all(old).await?;
69            }
70        }
71        Ok(())
72    }
73
74    async fn list_backup_dirs(&self) -> anyhow::Result<Vec<PathBuf>> {
75        let dir = self.backups_dir();
76        if !dir.is_dir() {
77            return Ok(Vec::new());
78        }
79        let mut entries = Vec::new();
80        let mut rd = fs::read_dir(&dir).await?;
81        while let Some(e) = rd.next_entry().await? {
82            let p = e.path();
83            if p.is_dir() && e.file_name().to_string_lossy().starts_with("backup-") {
84                entries.push(p);
85            }
86        }
87        entries.sort();
88        entries.reverse(); // newest first
89        Ok(entries)
90    }
91
92    async fn cmd_list(&self) -> anyhow::Result<ToolResult> {
93        let dirs = self.list_backup_dirs().await?;
94        let mut items = Vec::new();
95        for d in &dirs {
96            let name = d
97                .file_name()
98                .map(|n| n.to_string_lossy().to_string())
99                .unwrap_or_default();
100            let manifest_path = d.join("manifest.json");
101            let file_count = if manifest_path.is_file() {
102                let data = fs::read_to_string(&manifest_path).await?;
103                let map: HashMap<String, String> = serde_json::from_str(&data).unwrap_or_default();
104                map.len()
105            } else {
106                0
107            };
108            let meta = fs::metadata(d).await?;
109            let created = meta
110                .created()
111                .or_else(|_| meta.modified())
112                .unwrap_or(std::time::SystemTime::UNIX_EPOCH);
113            let dt: chrono::DateTime<chrono::Utc> = created.into();
114            items.push(json!({
115                "name": name,
116                "file_count": file_count,
117                "created": dt.to_rfc3339(),
118            }));
119        }
120        Ok(ToolResult {
121            success: true,
122            output: serde_json::to_string_pretty(&items)?,
123            error: None,
124        })
125    }
126
127    async fn cmd_verify(&self, backup_name: &str) -> anyhow::Result<ToolResult> {
128        let backup_dir = self.backups_dir().join(backup_name);
129        if !backup_dir.is_dir() {
130            return Ok(ToolResult {
131                success: false,
132                output: String::new(),
133                error: Some(format!("Backup not found: {backup_name}")),
134            });
135        }
136        let manifest_path = backup_dir.join("manifest.json");
137        let data = fs::read_to_string(&manifest_path).await?;
138        let expected: HashMap<String, String> = serde_json::from_str(&data)?;
139        let actual = compute_checksums(&backup_dir).await?;
140
141        let mut mismatches = Vec::new();
142        for (path, expected_hash) in &expected {
143            match actual.get(path) {
144                Some(actual_hash) if actual_hash == expected_hash => {}
145                Some(actual_hash) => mismatches.push(json!({
146                    "file": path,
147                    "expected": expected_hash,
148                    "actual": actual_hash,
149                })),
150                None => mismatches.push(json!({
151                    "file": path,
152                    "error": "missing",
153                })),
154            }
155        }
156        let pass = mismatches.is_empty();
157        Ok(ToolResult {
158            success: pass,
159            output: json!({
160                "backup": backup_name,
161                "pass": pass,
162                "checked": expected.len(),
163                "mismatches": mismatches,
164            })
165            .to_string(),
166            error: if pass {
167                None
168            } else {
169                Some("Integrity check failed".into())
170            },
171        })
172    }
173
174    async fn cmd_restore(&self, backup_name: &str, confirm: bool) -> anyhow::Result<ToolResult> {
175        let backup_dir = self.backups_dir().join(backup_name);
176        if !backup_dir.is_dir() {
177            return Ok(ToolResult {
178                success: false,
179                output: String::new(),
180                error: Some(format!("Backup not found: {backup_name}")),
181            });
182        }
183
184        // Collect restorable subdirectories (skip manifest.json).
185        let mut restore_items: Vec<String> = Vec::new();
186        let mut rd = fs::read_dir(&backup_dir).await?;
187        while let Some(e) = rd.next_entry().await? {
188            let name = e.file_name().to_string_lossy().to_string();
189            if name == "manifest.json" {
190                continue;
191            }
192            if e.path().is_dir() {
193                restore_items.push(name);
194            }
195        }
196
197        if !confirm {
198            return Ok(ToolResult {
199                success: true,
200                output: json!({
201                    "dry_run": true,
202                    "backup": backup_name,
203                    "would_restore": restore_items,
204                })
205                .to_string(),
206                error: None,
207            });
208        }
209
210        for sub in &restore_items {
211            let src = backup_dir.join(sub);
212            let dst = self.workspace_dir.join(sub);
213            copy_dir_recursive(&src, &dst).await?;
214        }
215        Ok(ToolResult {
216            success: true,
217            output: json!({
218                "restored": backup_name,
219                "directories": restore_items,
220            })
221            .to_string(),
222            error: None,
223        })
224    }
225}
226
227#[async_trait]
228impl Tool for BackupTool {
229    fn name(&self) -> &str {
230        "backup"
231    }
232
233    fn description(&self) -> &str {
234        "Create, list, verify, and restore workspace backups"
235    }
236
237    fn parameters_schema(&self) -> serde_json::Value {
238        json!({
239            "type": "object",
240            "properties": {
241                "command": {
242                    "type": "string",
243                    "enum": ["create", "list", "verify", "restore"],
244                    "description": "Backup command to execute"
245                },
246                "backup_name": {
247                    "type": "string",
248                    "description": "Name of backup (for verify/restore)"
249                },
250                "confirm": {
251                    "type": "boolean",
252                    "description": "Confirm restore (required for actual restore, default false)"
253                }
254            },
255            "required": ["command"]
256        })
257    }
258
259    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
260        let command = match args.get("command").and_then(|v| v.as_str()) {
261            Some(c) => c,
262            None => {
263                return Ok(ToolResult {
264                    success: false,
265                    output: String::new(),
266                    error: Some("Missing 'command' parameter".into()),
267                });
268            }
269        };
270
271        match command {
272            "create" => self.cmd_create().await,
273            "list" => self.cmd_list().await,
274            "verify" => {
275                let name = args
276                    .get("backup_name")
277                    .and_then(|v| v.as_str())
278                    .ok_or_else(|| {
279                        ::zeroclaw_log::record!(
280                            WARN,
281                            ::zeroclaw_log::Event::new(
282                                module_path!(),
283                                ::zeroclaw_log::Action::Reject
284                            )
285                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
286                            .with_attrs(::serde_json::json!({
287                                "param": "backup_name",
288                                "command": "verify",
289                            })),
290                            "backup_tool: missing backup_name for verify"
291                        );
292                        anyhow::Error::msg("Missing 'backup_name' for verify")
293                    })?;
294                self.cmd_verify(name).await
295            }
296            "restore" => {
297                let name = args
298                    .get("backup_name")
299                    .and_then(|v| v.as_str())
300                    .ok_or_else(|| {
301                        ::zeroclaw_log::record!(
302                            WARN,
303                            ::zeroclaw_log::Event::new(
304                                module_path!(),
305                                ::zeroclaw_log::Action::Reject
306                            )
307                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
308                            .with_attrs(::serde_json::json!({
309                                "param": "backup_name",
310                                "command": "restore",
311                            })),
312                            "backup_tool: missing backup_name for restore"
313                        );
314                        anyhow::Error::msg("Missing 'backup_name' for restore")
315                    })?;
316                let confirm = args
317                    .get("confirm")
318                    .and_then(|v| v.as_bool())
319                    .unwrap_or(false);
320                self.cmd_restore(name, confirm).await
321            }
322            other => Ok(ToolResult {
323                success: false,
324                output: String::new(),
325                error: Some(format!("Unknown command: {other}")),
326            }),
327        }
328    }
329}
330
331// -- Helpers ------------------------------------------------------------------
332
333async fn copy_dir_recursive(src: &Path, dst: &Path) -> anyhow::Result<()> {
334    fs::create_dir_all(dst).await?;
335    let mut rd = fs::read_dir(src).await?;
336    while let Some(entry) = rd.next_entry().await? {
337        let src_path = entry.path();
338        let dst_path = dst.join(entry.file_name());
339        if src_path.is_dir() {
340            Box::pin(copy_dir_recursive(&src_path, &dst_path)).await?;
341        } else {
342            fs::copy(&src_path, &dst_path).await?;
343        }
344    }
345    Ok(())
346}
347
348async fn compute_checksums(dir: &Path) -> anyhow::Result<HashMap<String, String>> {
349    let mut map = HashMap::new();
350    let base = dir.to_path_buf();
351    walk_and_hash(&base, dir, &mut map).await?;
352    Ok(map)
353}
354
355async fn walk_and_hash(
356    base: &Path,
357    dir: &Path,
358    map: &mut HashMap<String, String>,
359) -> anyhow::Result<()> {
360    let mut rd = fs::read_dir(dir).await?;
361    while let Some(entry) = rd.next_entry().await? {
362        let path = entry.path();
363        if path.is_dir() {
364            Box::pin(walk_and_hash(base, &path, map)).await?;
365        } else {
366            let rel = path
367                .strip_prefix(base)
368                .unwrap_or(&path)
369                .to_string_lossy()
370                .replace('\\', "/");
371            if rel == "manifest.json" {
372                continue;
373            }
374            let bytes = fs::read(&path).await?;
375            let hash = hex::encode(Sha256::digest(&bytes));
376            map.insert(rel, hash);
377        }
378    }
379    Ok(())
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Tool over the temp workspace that backs up `config` and `memory` and
    /// keeps up to 10 backups.
    fn make_tool(tmp: &TempDir) -> BackupTool {
        let include: Vec<String> = vec!["config".into(), "memory".into()];
        BackupTool::new(tmp.path().to_path_buf(), include, 10)
    }

    /// Writes `config/a.toml` with `contents` inside the temp workspace.
    fn seed_config(tmp: &TempDir, contents: &str) {
        let cfg_dir = tmp.path().join("config");
        std::fs::create_dir_all(&cfg_dir).unwrap();
        std::fs::write(cfg_dir.join("a.toml"), contents).unwrap();
    }

    /// Parses a tool result's `output` as JSON.
    fn output_json(res: &ToolResult) -> serde_json::Value {
        serde_json::from_str(&res.output).unwrap()
    }

    #[tokio::test]
    async fn create_backup_produces_manifest() {
        let tmp = TempDir::new().unwrap();
        seed_config(&tmp, "key = 1");

        let tool = make_tool(&tmp);
        let created = tool.execute(json!({"command": "create"})).await.unwrap();
        assert!(created.success, "create failed: {:?}", created.error);

        let report = output_json(&created);
        assert_eq!(report["file_count"], 1);

        // The manifest lives inside the freshly created backup directory.
        let backup_name = report["backup"].as_str().unwrap();
        let manifest_path = tmp
            .path()
            .join("backups")
            .join(backup_name)
            .join("manifest.json");
        assert!(manifest_path.exists());
    }

    #[tokio::test]
    async fn verify_backup_detects_corruption() {
        let tmp = TempDir::new().unwrap();
        seed_config(&tmp, "original");

        let tool = make_tool(&tmp);
        let created = tool.execute(json!({"command": "create"})).await.unwrap();
        let report = output_json(&created);
        let backup_name = report["backup"].as_str().unwrap();

        // Tamper with the backed-up copy so its hash no longer matches the manifest.
        let backed_up = tmp
            .path()
            .join("backups")
            .join(backup_name)
            .join("config/a.toml");
        std::fs::write(&backed_up, "corrupted").unwrap();

        let verified = tool
            .execute(json!({"command": "verify", "backup_name": backup_name}))
            .await
            .unwrap();
        assert!(!verified.success);
        let verdict = output_json(&verified);
        assert!(!verdict["mismatches"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn restore_requires_confirmation() {
        let tmp = TempDir::new().unwrap();
        seed_config(&tmp, "v1");

        let tool = make_tool(&tmp);
        let created = tool.execute(json!({"command": "create"})).await.unwrap();
        let report = output_json(&created);
        let backup_name = report["backup"].as_str().unwrap();

        // No `confirm` argument: the tool must only describe what it would do.
        let dry_run = tool
            .execute(json!({"command": "restore", "backup_name": backup_name}))
            .await
            .unwrap();
        assert!(dry_run.success);
        assert_eq!(output_json(&dry_run)["dry_run"], true);

        // `confirm: true`: the restore is actually carried out.
        let restored = tool
            .execute(json!({"command": "restore", "backup_name": backup_name, "confirm": true}))
            .await
            .unwrap();
        assert!(restored.success);
        assert!(output_json(&restored).get("restored").is_some());
    }

    #[tokio::test]
    async fn list_backups_sorted_newest_first() {
        let tmp = TempDir::new().unwrap();
        seed_config(&tmp, "v1");

        let tool = make_tool(&tmp);
        tool.execute(json!({"command": "create"})).await.unwrap();
        // Backup names have one-second resolution; wait so the second name differs.
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        tool.execute(json!({"command": "create"})).await.unwrap();

        let listed = tool.execute(json!({"command": "list"})).await.unwrap();
        assert!(listed.success);
        let items: Vec<serde_json::Value> = serde_json::from_str(&listed.output).unwrap();
        assert_eq!(items.len(), 2);
        // Timestamped names sort lexicographically, so the first entry is the newest.
        let first = items[0]["name"].as_str().unwrap();
        let second = items[1]["name"].as_str().unwrap();
        assert!(first >= second);
    }
}