1use super::traits::{ExportFilter, Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
24use anyhow::Result;
25use async_trait::async_trait;
26use std::collections::HashSet;
27use std::sync::Arc;
28
29pub struct AgentScopedMemory {
37 inner: Arc<dyn Memory>,
41 agent_id: String,
44 allowed_agent_ids: HashSet<String>,
49}
50
51impl AgentScopedMemory {
52 #[must_use]
61 pub fn new(
62 inner: Arc<dyn Memory>,
63 agent_id: impl Into<String>,
64 allowed_sibling_agent_ids: impl IntoIterator<Item = String>,
65 ) -> Self {
66 let agent_id = agent_id.into();
67 let mut allowed_agent_ids: HashSet<String> =
68 allowed_sibling_agent_ids.into_iter().collect();
69 allowed_agent_ids.insert(agent_id.clone());
70 Self {
71 inner,
72 agent_id,
73 allowed_agent_ids,
74 }
75 }
76
77 fn allowed_slice(&self) -> Vec<&str> {
81 self.allowed_agent_ids.iter().map(String::as_str).collect()
82 }
83}
84
85#[async_trait]
86impl Memory for AgentScopedMemory {
87 fn name(&self) -> &str {
88 self.inner.name()
93 }
94
95 async fn health_check(&self) -> bool {
96 self.inner.health_check().await
97 }
98
99 async fn store(
100 &self,
101 key: &str,
102 content: &str,
103 category: MemoryCategory,
104 session_id: Option<&str>,
105 ) -> Result<()> {
106 self.inner
113 .store_with_agent(
114 key,
115 content,
116 category,
117 session_id,
118 None,
119 None,
120 Some(&self.agent_id),
121 )
122 .await
123 }
124
125 async fn store_with_metadata(
126 &self,
127 key: &str,
128 content: &str,
129 category: MemoryCategory,
130 session_id: Option<&str>,
131 namespace: Option<&str>,
132 importance: Option<f64>,
133 ) -> Result<()> {
134 self.inner
135 .store_with_agent(
136 key,
137 content,
138 category,
139 session_id,
140 namespace,
141 importance,
142 Some(&self.agent_id),
143 )
144 .await
145 }
146
147 async fn store_with_agent(
148 &self,
149 key: &str,
150 content: &str,
151 category: MemoryCategory,
152 session_id: Option<&str>,
153 namespace: Option<&str>,
154 importance: Option<f64>,
155 agent_id: Option<&str>,
156 ) -> Result<()> {
157 if let Some(requested) = agent_id
163 && requested != self.agent_id
164 {
165 ::zeroclaw_log::record!(
166 WARN,
167 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
168 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
169 .with_attrs(::serde_json::json!({
170 "bound_agent": self.agent_id,
171 "requested_agent": requested,
172 "key": key,
173 })),
174 "store_with_agent refused: foreign agent_id"
175 );
176 anyhow::bail!(
177 "AgentScopedMemory refuses store_with_agent for foreign agent_id; use a wrapper bound to the target agent"
178 );
179 }
180 self.inner
181 .store_with_agent(
182 key,
183 content,
184 category,
185 session_id,
186 namespace,
187 importance,
188 Some(&self.agent_id),
189 )
190 .await
191 }
192
193 async fn recall(
194 &self,
195 query: &str,
196 limit: usize,
197 session_id: Option<&str>,
198 since: Option<&str>,
199 until: Option<&str>,
200 ) -> Result<Vec<MemoryEntry>> {
201 let allowed = self.allowed_slice();
202 self.inner
203 .recall_for_agents(&allowed, query, limit, session_id, since, until)
204 .await
205 }
206
207 async fn recall_for_agents(
208 &self,
209 caller_allowed: &[&str],
210 query: &str,
211 limit: usize,
212 session_id: Option<&str>,
213 since: Option<&str>,
214 until: Option<&str>,
215 ) -> Result<Vec<MemoryEntry>> {
216 if caller_allowed.is_empty() {
225 let bound: Vec<&str> = self.allowed_agent_ids.iter().map(String::as_str).collect();
226 return self
227 .inner
228 .recall_for_agents(&bound, query, limit, session_id, since, until)
229 .await;
230 }
231
232 let intersected: Vec<&str> = caller_allowed
233 .iter()
234 .copied()
235 .filter(|id| self.allowed_agent_ids.contains(*id))
236 .collect();
237 if intersected.is_empty() {
238 return Ok(Vec::new());
239 }
240 self.inner
241 .recall_for_agents(&intersected, query, limit, session_id, since, until)
242 .await
243 }
244
245 async fn get(&self, key: &str) -> Result<Option<MemoryEntry>> {
246 if let Some(own) = self.inner.get_for_agent(key, &self.agent_id).await? {
252 return Ok(Some(own));
253 }
254 for sibling in &self.allowed_agent_ids {
255 if sibling == &self.agent_id {
256 continue;
257 }
258 if let Some(hit) = self.inner.get_for_agent(key, sibling).await? {
259 return Ok(Some(hit));
260 }
261 }
262 Ok(None)
263 }
264
265 async fn get_for_agent(&self, key: &str, agent_id: &str) -> Result<Option<MemoryEntry>> {
266 if agent_id != self.agent_id && !self.allowed_agent_ids.iter().any(|a| a == agent_id) {
267 return Ok(None);
268 }
269 self.inner.get_for_agent(key, agent_id).await
270 }
271
272 async fn list(
273 &self,
274 category: Option<&MemoryCategory>,
275 session_id: Option<&str>,
276 ) -> Result<Vec<MemoryEntry>> {
277 let entries = self.inner.list(category, session_id).await?;
282 Ok(entries
283 .into_iter()
284 .filter(|e| {
285 e.agent_id
286 .as_deref()
287 .is_some_and(|aid| self.allowed_agent_ids.contains(aid))
288 })
289 .collect())
290 }
291
292 async fn forget(&self, key: &str) -> Result<bool> {
293 if self.inner.forget_for_agent(key, &self.agent_id).await? {
304 return Ok(true);
305 }
306 match self.inner.get(key).await? {
307 None => Ok(false),
308 Some(entry) => match entry.agent_id.as_deref() {
309 Some(other) => {
310 ::zeroclaw_log::record!(
311 WARN,
312 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
313 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
314 .with_attrs(::serde_json::json!({
315 "key": key,
316 "row_agent": other,
317 "bound_agent": self.agent_id,
318 })),
319 "forget refused: row attributed to a different agent"
320 );
321 anyhow::bail!(
322 "AgentScopedMemory refuses to forget cross-agent row: key attributed to agent other than the bound agent"
323 );
324 }
325 None => {
326 ::zeroclaw_log::record!(
327 WARN,
328 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
329 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
330 .with_attrs(::serde_json::json!({
331 "key": key,
332 "bound_agent": self.agent_id,
333 })),
334 "forget refused: row has no agent attribution"
335 );
336 anyhow::bail!(
337 "AgentScopedMemory refuses to forget unattributed row: legacy or backend without per-agent tracking; resolve via an admin Memory handle"
338 );
339 }
340 },
341 }
342 }
343
344 async fn forget_for_agent(&self, key: &str, agent_id: &str) -> Result<bool> {
345 if agent_id != self.agent_id {
348 ::zeroclaw_log::record!(
349 WARN,
350 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
351 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
352 .with_attrs(::serde_json::json!({
353 "key": key,
354 "row_agent": agent_id,
355 "bound_agent": self.agent_id,
356 })),
357 "forget_for_agent refused: cross-agent delete through wrapper"
358 );
359 anyhow::bail!(
360 "AgentScopedMemory refuses cross-agent forget_for_agent: bound agent and target agent differ"
361 );
362 }
363 self.inner.forget_for_agent(key, agent_id).await
364 }
365
366 async fn count(&self) -> Result<usize> {
367 let entries = self.inner.list(None, None).await?;
370 Ok(entries
371 .into_iter()
372 .filter(|e| {
373 e.agent_id
374 .as_deref()
375 .is_some_and(|aid| self.allowed_agent_ids.contains(aid))
376 })
377 .count())
378 }
379
380 async fn purge_namespace(&self, namespace: &str) -> Result<usize> {
381 ::zeroclaw_log::record!(
385 WARN,
386 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
387 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
388 .with_attrs(::serde_json::json!({
389 "namespace": namespace,
390 "bound_agent": self.agent_id,
391 })),
392 "purge_namespace refused: cross-agent bulk delete requires an admin Memory handle"
393 );
394 anyhow::bail!(
395 "AgentScopedMemory refuses purge_namespace: cross-agent bulk delete must run through an admin Memory handle"
396 );
397 }
398
399 async fn purge_session(&self, session_id: &str) -> Result<usize> {
400 self.inner
405 .purge_session_for_agent(session_id, &self.agent_id)
406 .await
407 }
408
409 async fn reindex(&self) -> Result<usize> {
410 self.inner.reindex().await
415 }
416
417 async fn store_procedural(
418 &self,
419 messages: &[ProceduralMessage],
420 session_id: Option<&str>,
421 ) -> Result<()> {
422 self.inner.store_procedural(messages, session_id).await
423 }
424
425 async fn recall_namespaced(
426 &self,
427 namespace: &str,
428 query: &str,
429 limit: usize,
430 session_id: Option<&str>,
431 since: Option<&str>,
432 until: Option<&str>,
433 ) -> Result<Vec<MemoryEntry>> {
434 let entries = self
441 .recall(query, limit * 2, session_id, since, until)
442 .await?;
443 Ok(entries
444 .into_iter()
445 .filter(|e| e.namespace == namespace)
446 .take(limit)
447 .collect())
448 }
449
450 async fn export(&self, filter: &ExportFilter) -> Result<Vec<MemoryEntry>> {
451 let entries = self
457 .list(filter.category.as_ref(), filter.session_id.as_deref())
458 .await?;
459 Ok(entries
460 .into_iter()
461 .filter(|e| {
462 if let Some(ref ns) = filter.namespace
463 && e.namespace != *ns
464 {
465 return false;
466 }
467 if let Some(ref since) = filter.since
468 && e.timestamp.as_str() < since.as_str()
469 {
470 return false;
471 }
472 if let Some(ref until) = filter.until
473 && e.timestamp.as_str() > until.as_str()
474 {
475 return false;
476 }
477 true
478 })
479 .collect())
480 }
481
482 async fn ensure_agent_uuid(&self, alias: &str) -> Result<String> {
483 self.inner.ensure_agent_uuid(alias).await
484 }
485}
486
487impl ::zeroclaw_api::attribution::Attributable for AgentScopedMemory {
488 fn role(&self) -> ::zeroclaw_api::attribution::Role {
489 ::zeroclaw_api::attribution::Role::Memory(
490 ::zeroclaw_api::attribution::MemoryKind::AgentScoped,
491 )
492 }
493 fn alias(&self) -> &str {
494 &self.agent_id
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteMemory;
    use tempfile::TempDir;

    /// Opens a SQLite-backed store in a fresh temp dir. The returned `TempDir`
    /// must stay alive for as long as the store is used.
    fn fresh_sqlite() -> (TempDir, Arc<SqliteMemory>) {
        let tmp = TempDir::new().unwrap();
        let mem = SqliteMemory::new("test", tmp.path()).unwrap();
        (tmp, Arc::new(mem))
    }

    /// Upcasts the concrete store to the trait object the wrapper expects.
    fn as_dyn(inner: Arc<SqliteMemory>) -> Arc<dyn Memory> {
        inner
    }

    /// Registers each alias with the backend and returns the UUIDs in input
    /// order.
    async fn provision_agents(inner: &Arc<SqliteMemory>, aliases: &[&str]) -> Vec<String> {
        let mut uuids = Vec::with_capacity(aliases.len());
        for alias in aliases {
            uuids.push(inner.ensure_agent_uuid(alias).await.unwrap());
        }
        uuids
    }

    /// Writes a `Core` row directly through the unscoped backend, attributed
    /// to `agent`. The wrapper under test is bypassed.
    async fn seed(
        inner: &Arc<SqliteMemory>,
        key: &str,
        content: &str,
        session_id: Option<&str>,
        agent: &str,
    ) {
        inner
            .store_with_agent(
                key,
                content,
                MemoryCategory::Core,
                session_id,
                None,
                None,
                Some(agent),
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn store_routes_through_store_with_agent_and_persists_attribution() {
        let (_tmp, inner) = fresh_sqlite();
        let alpha = inner.ensure_agent_uuid("alpha").await.unwrap();
        let wrapper = AgentScopedMemory::new(as_dyn(inner.clone()), &alpha, Vec::<String>::new());

        wrapper
            .store("k1", "v1", MemoryCategory::Core, None)
            .await
            .unwrap();

        let hits = wrapper.recall("k1", 10, None, None, None).await.unwrap();
        assert!(
            hits.iter().any(|e| e.key == "k1"),
            "wrapper recall must find rows it just stored"
        );

        // The test name promises persisted attribution, so check the raw
        // backend row rather than only reading it back through the wrapper.
        let raw = inner.list(None, None).await.unwrap();
        let stored = raw
            .iter()
            .find(|e| e.key == "k1")
            .expect("row stored through the wrapper must exist in the backend");
        assert_eq!(
            stored.agent_id.as_deref(),
            Some(alpha.as_str()),
            "row stored through the wrapper must be attributed to the bound agent"
        );
    }

    #[tokio::test]
    async fn recall_excludes_other_agent_rows_when_allowlist_omits_them() {
        let (_tmp, inner) = fresh_sqlite();
        let uuids = provision_agents(&inner, &["alpha", "other"]).await;
        let alpha_uuid = &uuids[0];
        let other_uuid = &uuids[1];

        seed(&inner, "other-key", "other-val", None, other_uuid).await;

        let wrapper = AgentScopedMemory::new(as_dyn(inner), alpha_uuid, Vec::<String>::new());

        let hits = wrapper
            .recall("other-key", 10, None, None, None)
            .await
            .unwrap();
        assert!(
            !hits.iter().any(|e| e.key == "other-key"),
            "rows attributed to a non-allowlisted agent must not surface"
        );
    }

    #[tokio::test]
    async fn recall_includes_allowlisted_sibling_rows() {
        let (_tmp, inner) = fresh_sqlite();
        let uuids = provision_agents(&inner, &["alpha", "beta"]).await;
        let alpha_uuid = &uuids[0];
        let beta_uuid = &uuids[1];

        seed(&inner, "sibling-key", "sibling-val", None, beta_uuid).await;

        let wrapper = AgentScopedMemory::new(as_dyn(inner), alpha_uuid, vec![beta_uuid.clone()]);

        let hits = wrapper
            .recall("sibling-key", 10, None, None, None)
            .await
            .unwrap();
        assert!(
            hits.iter().any(|e| e.key == "sibling-key"),
            "rows attributed to an allowlisted sibling must surface"
        );
    }

    #[tokio::test]
    async fn get_filters_cross_agent_rows_by_attribution() {
        let (_tmp, inner) = fresh_sqlite();
        let uuids = provision_agents(&inner, &["alpha", "beta"]).await;
        let alpha_uuid = &uuids[0];
        let beta_uuid = &uuids[1];

        seed(&inner, "beta-only", "secret", None, beta_uuid).await;

        let wrapper = AgentScopedMemory::new(as_dyn(inner), alpha_uuid, Vec::<String>::new());

        let hit = wrapper.get("beta-only").await.unwrap();
        assert!(
            hit.is_none(),
            "get must filter out rows attributed to non-allowlisted agents"
        );

        let targeted = wrapper.get_for_agent("beta-only", beta_uuid).await.unwrap();
        assert!(
            targeted.is_none(),
            "get_for_agent must deny agents outside the allowlist"
        );
    }

    #[tokio::test]
    async fn forget_refuses_to_delete_sibling_rows() {
        let (_tmp, inner) = fresh_sqlite();
        let uuids = provision_agents(&inner, &["alpha", "beta"]).await;
        let alpha_uuid = &uuids[0];
        let beta_uuid = &uuids[1];

        seed(&inner, "beta-row", "v", None, beta_uuid).await;

        let wrapper =
            AgentScopedMemory::new(as_dyn(inner.clone()), alpha_uuid, vec![beta_uuid.clone()]);

        let err = wrapper
            .forget("beta-row")
            .await
            .expect_err("forget must refuse cross-agent delete even with read allowlist");
        assert!(
            err.to_string().contains("attributed to agent"),
            "expected sibling-attribution refusal, got: {err}"
        );
        assert!(
            inner
                .get_for_agent("beta-row", beta_uuid)
                .await
                .unwrap()
                .is_some(),
            "a refused forget must leave the sibling row in place"
        );
    }

    #[tokio::test]
    async fn list_filters_to_bound_and_allowlisted_agents() {
        let (_tmp, inner) = fresh_sqlite();
        let uuids = provision_agents(&inner, &["alpha", "beta", "rogue"]).await;
        let alpha_uuid = &uuids[0];
        let beta_uuid = &uuids[1];
        let rogue_uuid = &uuids[2];

        for (key, owner) in [("alpha-row", alpha_uuid), ("rogue-row", rogue_uuid)] {
            seed(&inner, key, "v", None, owner).await;
        }

        let wrapper = AgentScopedMemory::new(as_dyn(inner), alpha_uuid, vec![beta_uuid.clone()]);

        let entries = wrapper.list(None, None).await.unwrap();
        assert!(entries.iter().any(|e| e.key == "alpha-row"));
        assert!(
            !entries.iter().any(|e| e.key == "rogue-row"),
            "list must drop rows attributed to non-allowlisted agents"
        );
    }

    #[tokio::test]
    async fn store_with_agent_refuses_foreign_agent_id() {
        let (_tmp, inner) = fresh_sqlite();
        let uuids = provision_agents(&inner, &["alpha", "rogue"]).await;
        let alpha_uuid = &uuids[0];
        let rogue_uuid = &uuids[1];

        let wrapper = AgentScopedMemory::new(as_dyn(inner), alpha_uuid, Vec::<String>::new());

        let err = wrapper
            .store_with_agent(
                "k",
                "v",
                MemoryCategory::Core,
                None,
                None,
                None,
                Some(rogue_uuid),
            )
            .await
            .expect_err(
                "store_with_agent must refuse a foreign agent_id rather than silently override",
            );
        assert!(
            err.to_string().contains("foreign agent_id"),
            "expected foreign-agent refusal, got: {err}"
        );
    }

    #[tokio::test]
    async fn purge_namespace_is_refused() {
        let (_tmp, inner) = fresh_sqlite();
        let alpha = inner.ensure_agent_uuid("alpha").await.unwrap();
        let wrapper = AgentScopedMemory::new(as_dyn(inner), &alpha, Vec::<String>::new());

        let err = wrapper
            .purge_namespace("default")
            .await
            .expect_err("purge_namespace must be refused on a wrapper");
        assert!(
            err.to_string().contains("admin Memory handle"),
            "expected admin-only refusal, got: {err}"
        );
    }

    #[tokio::test]
    async fn purge_session_deletes_only_bound_agent_rows_in_that_session() {
        let (_tmp, inner) = fresh_sqlite();
        let uuids = provision_agents(&inner, &["alpha", "beta"]).await;
        let alpha_uuid = &uuids[0];
        let beta_uuid = &uuids[1];

        seed(
            &inner,
            "shared-key",
            "alpha other session",
            Some("other-session"),
            alpha_uuid,
        )
        .await;
        seed(
            &inner,
            "shared-key",
            "beta target session",
            Some("target-session"),
            beta_uuid,
        )
        .await;
        seed(
            &inner,
            "alpha-target",
            "alpha target session",
            Some("target-session"),
            alpha_uuid,
        )
        .await;

        let wrapper =
            AgentScopedMemory::new(as_dyn(inner.clone()), alpha_uuid, vec![beta_uuid.clone()]);

        let purged = wrapper.purge_session("target-session").await.unwrap();
        assert_eq!(purged, 1, "only alpha's row in target-session is deleted");
        assert!(
            inner
                .get_for_agent("shared-key", alpha_uuid)
                .await
                .unwrap()
                .is_some(),
            "same-key alpha row in another session must survive"
        );
        assert!(
            inner
                .get_for_agent("shared-key", beta_uuid)
                .await
                .unwrap()
                .is_some(),
            "sibling row in target-session must survive"
        );
        assert!(
            inner
                .get_for_agent("alpha-target", alpha_uuid)
                .await
                .unwrap()
                .is_none(),
            "bound agent row in target-session must be deleted"
        );
    }

    #[tokio::test]
    async fn recall_for_agents_intersects_caller_allowlist_with_bound_allowlist() {
        let (_tmp, inner) = fresh_sqlite();
        let uuids = provision_agents(&inner, &["alpha", "beta", "rogue"]).await;
        let alpha_uuid = &uuids[0];
        let beta_uuid = &uuids[1];
        let rogue_uuid = &uuids[2];

        seed(&inner, "rogue-key", "rogue-val", None, rogue_uuid).await;

        let wrapper = AgentScopedMemory::new(as_dyn(inner), alpha_uuid, vec![beta_uuid.clone()]);

        let hits = wrapper
            .recall_for_agents(&[rogue_uuid.as_str()], "rogue-key", 10, None, None, None)
            .await
            .unwrap();
        assert!(
            !hits.iter().any(|e| e.key == "rogue-key"),
            "caller allowlist must be intersected, not unioned"
        );
    }
}