Skip to main content

zeroclaw_memory/
qdrant.rs

1use super::embeddings::EmbeddingProvider;
2use super::traits::{Memory, MemoryCategory, MemoryEntry, is_recent_recall_query};
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use chrono::Utc;
6use serde::{Deserialize, Serialize};
7use std::collections::HashSet;
8use std::sync::Arc;
9use tokio::sync::OnceCell;
10use uuid::Uuid;
11use zeroclaw_api::session_keys::sanitize_session_key;
12
/// Qdrant vector database memory backend.
///
/// Uses Qdrant's REST API for vector storage and semantic search.
/// Requires an embedding provider for converting text to vectors.
pub struct QdrantMemory {
    /// Alias supplied at construction, stored verbatim.
    alias: String,
    /// HTTP client used for every Qdrant REST call (built via
    /// `build_runtime_proxy_client("memory.qdrant")` in `new_lazy`).
    client: reqwest::Client,
    /// Server URL with any trailing `/` trimmed; request paths are appended to it.
    base_url: String,
    /// Name of the Qdrant collection that holds the memory points.
    collection: String,
    /// Optional key sent as the `api-key` header on every request.
    api_key: Option<String>,
    /// Produces the vectors stored with, and searched against, each point.
    embedder: Arc<dyn EmbeddingProvider>,
    /// Tracks whether collection has been initialized (lazy init for sync factory).
    initialized: OnceCell<()>,
}
27
28impl QdrantMemory {
29    /// Create a new Qdrant memory backend.
30    ///
31    /// # Arguments
32    /// * `url` - Qdrant server URL (e.g., `"http://localhost:6333"`)
33    /// * `collection` - Collection name for storing memories
34    /// * `api_key` - Optional API key for Qdrant Cloud
35    /// * `embedder` - Embedding model_provider for vector conversion
36    pub async fn new(
37        alias: &str,
38        url: &str,
39        collection: &str,
40        api_key: Option<String>,
41        embedder: Arc<dyn EmbeddingProvider>,
42    ) -> Result<Self> {
43        let mem = Self::new_lazy(alias, url, collection, api_key, embedder);
44
45        // Ensure collection exists with correct schema
46        mem.ensure_collection().await?;
47        if mem.embedder.dimensions() > 0 {
48            mem.migrate_session_ids_to_sanitized().await?;
49            zeroclaw_config::schema::v2::migrate_qdrant_collection_to_v3(
50                &mem.client,
51                &mem.base_url,
52                &mem.collection,
53                mem.api_key.as_deref(),
54            )
55            .await?;
56        }
57        mem.initialized.set(()).ok();
58
59        Ok(mem)
60    }
61
62    /// Create a Qdrant memory backend with lazy initialization.
63    ///
64    /// Collection will be created on first operation. Use this when calling
65    /// from a synchronous context (e.g., the memory factory).
66    pub fn new_lazy(
67        alias: &str,
68        url: &str,
69        collection: &str,
70        api_key: Option<String>,
71        embedder: Arc<dyn EmbeddingProvider>,
72    ) -> Self {
73        let base_url = url.trim_end_matches('/').to_string();
74        let client = zeroclaw_config::schema::build_runtime_proxy_client("memory.qdrant");
75
76        Self {
77            alias: alias.to_string(),
78            client,
79            base_url,
80            collection: collection.to_string(),
81            api_key,
82            embedder,
83            initialized: OnceCell::new(),
84        }
85    }
86
87    /// Ensure the collection is initialized (called lazily on first operation).
88    async fn ensure_initialized(&self) -> Result<()> {
89        self.initialized
90            .get_or_try_init(|| async {
91                self.ensure_collection().await?;
92                if self.embedder.dimensions() > 0 {
93                    self.migrate_session_ids_to_sanitized().await?;
94                    zeroclaw_config::schema::v2::migrate_qdrant_collection_to_v3(
95                        &self.client,
96                        &self.base_url,
97                        &self.collection,
98                        self.api_key.as_deref(),
99                    )
100                    .await?;
101                }
102                Ok::<(), anyhow::Error>(())
103            })
104            .await?;
105        Ok(())
106    }
107
108    fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
109        let url = format!("{}{}", self.base_url, path);
110        let mut req = self.client.request(method, &url);
111
112        if let Some(ref key) = self.api_key {
113            req = req.header("api-key", key);
114        }
115
116        req.header("Content-Type", "application/json")
117    }
118
119    /// Scroll all points whose payload `agent_id` is on the supplied
120    /// allowlist, optionally filtered by category and session_id.
121    /// Used by `recall_for_agents`'s recent/time-only branch and the
122    /// embedding-empty fallback so the agent_id check happens at the
123    /// query boundary, not after a broader fetch.
124    async fn list_for_agents(
125        &self,
126        allowed_agent_ids: &[&str],
127        category: Option<&MemoryCategory>,
128        session_id: Option<&str>,
129    ) -> Result<Vec<MemoryEntry>> {
130        self.ensure_initialized().await?;
131
132        let mut must_conditions: Vec<serde_json::Value> = Vec::new();
133        if let Some(cat) = category {
134            must_conditions.push(serde_json::json!({
135                "key": "category",
136                "match": { "value": Self::category_to_str(cat) }
137            }));
138        }
139        if let Some(sid) = session_id {
140            must_conditions.push(serde_json::json!({
141                "key": "session_id",
142                "match": { "value": sid }
143            }));
144        }
145        must_conditions.push(serde_json::json!({
146            "key": "agent_id",
147            "match": { "any": allowed_agent_ids }
148        }));
149
150        let scroll_body = serde_json::json!({
151            "limit": 1000,
152            "with_payload": true,
153            "filter": { "must": must_conditions }
154        });
155
156        let resp = self
157            .request(
158                reqwest::Method::POST,
159                &format!("/collections/{}/points/scroll", self.collection),
160            )
161            .json(&scroll_body)
162            .send()
163            .await
164            .context("failed to scroll Qdrant for allowed agent set")?;
165
166        if !resp.status().is_success() {
167            let status = resp.status();
168            let text = resp.text().await.unwrap_or_default();
169            anyhow::bail!("Qdrant scroll failed ({status}): {text}");
170        }
171
172        let result: QdrantScrollResult = resp.json().await?;
173
174        let entries = result
175            .result
176            .points
177            .into_iter()
178            .filter_map(|point| {
179                let payload = point.payload?;
180                let id = match &point.id {
181                    serde_json::Value::String(s) => s.clone(),
182                    serde_json::Value::Number(n) => n.to_string(),
183                    _ => return None,
184                };
185
186                Some(MemoryEntry {
187                    id,
188                    key: payload.key,
189                    content: payload.content,
190                    category: Self::parse_category(&payload.category),
191                    timestamp: payload.timestamp,
192                    session_id: payload.session_id,
193                    score: None,
194                    namespace: "default".into(),
195                    importance: None,
196                    superseded_by: None,
197                    agent_alias: payload.agent_id.clone(),
198                    agent_id: payload.agent_id,
199                })
200            })
201            .collect();
202
203        Ok(entries)
204    }
205
206    async fn ensure_collection(&self) -> Result<()> {
207        let dims = self.embedder.dimensions();
208        if dims == 0 {
209            // Noop embedder — skip vector collection setup
210            ::zeroclaw_log::record!(
211                WARN,
212                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
213                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
214                "Qdrant memory using noop embedder (0 dimensions); vector search disabled"
215            );
216            return Ok(());
217        }
218
219        // Check if collection exists
220        let resp = self
221            .request(
222                reqwest::Method::GET,
223                &format!("/collections/{}", self.collection),
224            )
225            .send()
226            .await;
227
228        match resp {
229            Ok(r) if r.status().is_success() => {
230                // Collection exists
231                return Ok(());
232            }
233            Ok(r) if r.status().as_u16() == 404 => {
234                // Collection doesn't exist, create it
235            }
236            Ok(r) => {
237                let status = r.status();
238                let text = r.text().await.unwrap_or_default();
239                anyhow::bail!("Qdrant collection check failed ({status}): {text}");
240            }
241            Err(e) => {
242                anyhow::bail!("Qdrant connection failed: {e}");
243            }
244        }
245
246        // Create collection with vector config
247        let create_body = serde_json::json!({
248            "vectors": {
249                "size": dims,
250                "distance": "Cosine"
251            }
252        });
253
254        let resp = self
255            .request(
256                reqwest::Method::PUT,
257                &format!("/collections/{}", self.collection),
258            )
259            .json(&create_body)
260            .send()
261            .await
262            .context("failed to create Qdrant collection")?;
263
264        if !resp.status().is_success() {
265            let status = resp.status();
266            let text = resp.text().await.unwrap_or_default();
267            anyhow::bail!("Qdrant collection creation failed ({status}): {text}");
268        }
269
270        ::zeroclaw_log::record!(
271            INFO,
272            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
273            &format!(
274                "Created Qdrant collection '{}' with {} dimensions",
275                self.collection, dims
276            )
277        );
278
279        Ok(())
280    }
281
282    /// One-shot, idempotent normalization of `payload.session_id`.
283    ///
284    /// Mirrors the SQLite-backed migration: rewrite rows that were persisted
285    /// before the orchestrator sanitized session keys at the source so the
286    /// new sanitized recall filter still matches them. Iterates the
287    /// collection with a paginated scroll, gathers distinct `session_id`
288    /// values, and issues one `set payload` per (old → new) pair where the
289    /// sanitized form differs from the stored one.
290    async fn migrate_session_ids_to_sanitized(&self) -> Result<()> {
291        let mut seen: HashSet<String> = HashSet::new();
292        let mut next_offset: Option<serde_json::Value> = None;
293
294        loop {
295            let mut scroll_body = serde_json::json!({
296                "limit": 1000,
297                "with_payload": true,
298                "with_vector": false,
299            });
300            if let Some(ref offset) = next_offset {
301                scroll_body["offset"] = offset.clone();
302            }
303
304            let resp = self
305                .request(
306                    reqwest::Method::POST,
307                    &format!("/collections/{}/points/scroll", self.collection),
308                )
309                .json(&scroll_body)
310                .send()
311                .await
312                .context("failed to scroll Qdrant for session_id migration")?;
313
314            if !resp.status().is_success() {
315                let status = resp.status();
316                let text = resp.text().await.unwrap_or_default();
317                anyhow::bail!("Qdrant scroll failed during migration ({status}): {text}");
318            }
319
320            let page: QdrantScrollResult = resp.json().await?;
321            for point in &page.result.points {
322                if let Some(ref payload) = point.payload
323                    && let Some(ref sid) = payload.session_id
324                {
325                    seen.insert(sid.clone());
326                }
327            }
328
329            match page.result.next_page_offset {
330                Some(offset) if !offset.is_null() => next_offset = Some(offset),
331                _ => break,
332            }
333        }
334
335        let mut rewritten = 0usize;
336        for old in &seen {
337            let new = sanitize_session_key(old);
338            if new == *old {
339                continue;
340            }
341
342            let body = serde_json::json!({
343                "payload": { "session_id": new },
344                "filter": {
345                    "must": [{
346                        "key": "session_id",
347                        "match": { "value": old }
348                    }]
349                }
350            });
351
352            let resp = self
353                .request(
354                    reqwest::Method::POST,
355                    &format!("/collections/{}/points/payload", self.collection),
356                )
357                .query(&[("wait", "true")])
358                .json(&body)
359                .send()
360                .await
361                .context("failed to set payload during Qdrant session_id migration")?;
362
363            if !resp.status().is_success() {
364                let status = resp.status();
365                let text = resp.text().await.unwrap_or_default();
366                anyhow::bail!("Qdrant set payload failed during migration ({status}): {text}");
367            }
368
369            rewritten += 1;
370        }
371
372        if rewritten > 0 {
373            ::zeroclaw_log::record!(
374                INFO,
375                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
376                    .with_attrs(
377                        ::serde_json::json!({"rewritten": rewritten, "collection": self.collection})
378                    ),
379                "Normalized session_id payload values in Qdrant collection to sanitized form"
380            );
381        }
382
383        Ok(())
384    }
385
386    fn category_to_str(category: &MemoryCategory) -> String {
387        match category {
388            MemoryCategory::Core => "core".to_string(),
389            MemoryCategory::Daily => "daily".to_string(),
390            MemoryCategory::Conversation => "conversation".to_string(),
391            MemoryCategory::Custom(name) => name.clone(),
392        }
393    }
394
395    fn parse_category(value: &str) -> MemoryCategory {
396        match value {
397            "core" => MemoryCategory::Core,
398            "daily" => MemoryCategory::Daily,
399            "conversation" => MemoryCategory::Conversation,
400            other => MemoryCategory::Custom(other.to_string()),
401        }
402    }
403
404    /// Build a Qdrant `must` payload filter from `(field, value)` pairs.
405    fn must_filter(fields: &[(&str, &str)]) -> serde_json::Value {
406        let must: Vec<serde_json::Value> = fields
407            .iter()
408            .map(|(field, value)| serde_json::json!({"key": field, "match": {"value": value}}))
409            .collect();
410        serde_json::json!({"must": must})
411    }
412
413    /// Scroll for the first point matching every `(field, value)` filter
414    /// pair, decoded into a `MemoryEntry`. Returns `None` when nothing
415    /// matches.
416    async fn scroll_first_matching(&self, fields: &[(&str, &str)]) -> Result<Option<MemoryEntry>> {
417        self.ensure_initialized().await?;
418
419        let scroll_body = serde_json::json!({
420            "filter": Self::must_filter(fields),
421            "limit": 1,
422            "with_payload": true,
423        });
424
425        let resp = self
426            .request(
427                reqwest::Method::POST,
428                &format!("/collections/{}/points/scroll", self.collection),
429            )
430            .json(&scroll_body)
431            .send()
432            .await
433            .context("failed to scroll Qdrant")?;
434
435        if !resp.status().is_success() {
436            let status = resp.status();
437            let text = resp.text().await.unwrap_or_default();
438            anyhow::bail!("Qdrant scroll failed ({status}): {text}");
439        }
440
441        let result: QdrantScrollResult = resp.json().await?;
442        let entry = result.result.points.into_iter().next().and_then(|point| {
443            let payload = point.payload?;
444            let id = match &point.id {
445                serde_json::Value::String(s) => s.clone(),
446                serde_json::Value::Number(n) => n.to_string(),
447                _ => return None,
448            };
449            Some(MemoryEntry {
450                id,
451                key: payload.key,
452                content: payload.content,
453                category: Self::parse_category(&payload.category),
454                timestamp: payload.timestamp,
455                session_id: payload.session_id,
456                score: None,
457                namespace: "default".into(),
458                importance: None,
459                superseded_by: None,
460                agent_alias: payload.agent_id.clone(),
461                agent_id: payload.agent_id,
462            })
463        });
464        Ok(entry)
465    }
466
467    /// Delete every point matching every `(field, value)` filter pair.
468    /// Qdrant's delete response does not expose a per-call match count,
469    /// so this returns `true` on success regardless of how many points
470    /// were touched.
471    async fn delete_points_matching(&self, fields: &[(&str, &str)]) -> Result<bool> {
472        self.ensure_initialized().await?;
473
474        let delete_body = serde_json::json!({"filter": Self::must_filter(fields)});
475        let resp = self
476            .request(
477                reqwest::Method::POST,
478                &format!("/collections/{}/points/delete", self.collection),
479            )
480            .query(&[("wait", "true")])
481            .json(&delete_body)
482            .send()
483            .await
484            .context("failed to delete from Qdrant")?;
485
486        if !resp.status().is_success() {
487            let status = resp.status();
488            let text = resp.text().await.unwrap_or_default();
489            anyhow::bail!("Qdrant delete failed ({status}): {text}");
490        }
491
492        Ok(true)
493    }
494}
495
/// Qdrant point payload structure.
///
/// Serialized as the `payload` of each upserted point and deserialized from
/// scroll/search responses. Optional fields that are `None` are omitted from
/// the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MemoryPayload {
    /// Memory key; `get`/`forget` filter on this field.
    key: String,
    /// Stored memory text.
    content: String,
    /// Category string in the form produced by `category_to_str`.
    category: String,
    /// Write time; `store_with_agent` sets it from `Utc::now().to_rfc3339()`.
    timestamp: String,
    /// Session the memory belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    session_id: Option<String>,
    /// Owning agent; `store_with_agent` falls back to `"default"` when unscoped.
    #[serde(skip_serializing_if = "Option::is_none")]
    agent_id: Option<String>,
}
508
/// Qdrant search result.
///
/// Response envelope of `POST /collections/{name}/points/search`; `result`
/// holds the scored hits.
#[derive(Debug, Deserialize)]
struct QdrantSearchResult {
    result: Vec<QdrantScoredPoint>,
}
514
/// A single vector-search hit.
#[derive(Debug, Deserialize)]
struct QdrantScoredPoint {
    /// Point id as raw JSON; callers accept string and numeric ids.
    id: serde_json::Value,
    /// Similarity score returned by Qdrant for this hit.
    score: f64,
    /// Decoded payload, or `None` if the response omitted it.
    payload: Option<MemoryPayload>,
}
521
/// Qdrant scroll result.
///
/// Response envelope of `POST /collections/{name}/points/scroll`.
#[derive(Debug, Deserialize)]
struct QdrantScrollResult {
    result: QdrantScrollPoints,
}
527
/// One page of scrolled points.
#[derive(Debug, Deserialize)]
struct QdrantScrollPoints {
    /// Points on this page.
    points: Vec<QdrantPoint>,
    /// Offset to pass as `offset` for the next page; `None` (missing) or a
    /// JSON `null` is treated by the paginating callers as "no more pages".
    #[serde(default)]
    next_page_offset: Option<serde_json::Value>,
}
534
/// A point returned by scroll (no score).
#[derive(Debug, Deserialize)]
struct QdrantPoint {
    /// Point id as raw JSON; callers accept string and numeric ids.
    id: serde_json::Value,
    /// Decoded payload, or `None` if the response omitted it.
    payload: Option<MemoryPayload>,
}
540
541#[async_trait]
542impl Memory for QdrantMemory {
543    fn name(&self) -> &str {
544        "qdrant"
545    }
546
547    async fn store(
548        &self,
549        key: &str,
550        content: &str,
551        category: MemoryCategory,
552        session_id: Option<&str>,
553    ) -> Result<()> {
554        self.store_with_agent(key, content, category, session_id, None, None, None)
555            .await
556    }
557
558    async fn recall(
559        &self,
560        query: &str,
561        limit: usize,
562        session_id: Option<&str>,
563        since: Option<&str>,
564        until: Option<&str>,
565    ) -> Result<Vec<MemoryEntry>> {
566        if is_recent_recall_query(query) {
567            let mut entries = self.list(None, session_id).await?;
568            if let Some(s) = since {
569                entries.retain(|e| e.timestamp.as_str() >= s);
570            }
571            if let Some(u) = until {
572                entries.retain(|e| e.timestamp.as_str() <= u);
573            }
574            entries.truncate(limit);
575            return Ok(entries);
576        }
577
578        self.ensure_initialized().await?;
579
580        // Generate embedding for the query
581        let embedding = self.embedder.embed_one(query).await?;
582
583        if embedding.is_empty() {
584            // Fallback to listing if embeddings aren't available
585            return self.list(None, session_id).await;
586        }
587
588        // Build filter for session_id if provided
589        let filter = session_id.map(|sid| {
590            serde_json::json!({
591                "must": [{
592                    "key": "session_id",
593                    "match": { "value": sid }
594                }]
595            })
596        });
597
598        let mut search_body = serde_json::json!({
599            "vector": embedding,
600            "limit": limit,
601            "with_payload": true
602        });
603
604        if let Some(f) = filter {
605            search_body["filter"] = f;
606        }
607
608        let resp = self
609            .request(
610                reqwest::Method::POST,
611                &format!("/collections/{}/points/search", self.collection),
612            )
613            .json(&search_body)
614            .send()
615            .await
616            .context("failed to search Qdrant")?;
617
618        if !resp.status().is_success() {
619            let status = resp.status();
620            let text = resp.text().await.unwrap_or_default();
621            anyhow::bail!("Qdrant search failed ({status}): {text}");
622        }
623
624        let result: QdrantSearchResult = resp.json().await?;
625
626        let mut entries: Vec<MemoryEntry> = result
627            .result
628            .into_iter()
629            .filter_map(|point| {
630                let payload = point.payload?;
631                let id = match &point.id {
632                    serde_json::Value::String(s) => s.clone(),
633                    serde_json::Value::Number(n) => n.to_string(),
634                    _ => return None,
635                };
636
637                Some(MemoryEntry {
638                    id,
639                    key: payload.key,
640                    content: payload.content,
641                    category: Self::parse_category(&payload.category),
642                    timestamp: payload.timestamp,
643                    session_id: payload.session_id,
644                    score: Some(point.score),
645                    namespace: "default".into(),
646                    importance: None,
647                    superseded_by: None,
648                    agent_alias: payload.agent_id.clone(),
649                    agent_id: payload.agent_id,
650                })
651            })
652            .collect();
653
654        // Filter by time range if specified
655        if let Some(s) = since {
656            entries.retain(|e| e.timestamp.as_str() >= s);
657        }
658        if let Some(u) = until {
659            entries.retain(|e| e.timestamp.as_str() <= u);
660        }
661
662        Ok(entries)
663    }
664
665    async fn get(&self, key: &str) -> Result<Option<MemoryEntry>> {
666        self.scroll_first_matching(&[("key", key)]).await
667    }
668
669    async fn get_for_agent(&self, key: &str, agent_id: &str) -> Result<Option<MemoryEntry>> {
670        self.scroll_first_matching(&[("key", key), ("agent_id", agent_id)])
671            .await
672    }
673
674    async fn list(
675        &self,
676        category: Option<&MemoryCategory>,
677        session_id: Option<&str>,
678    ) -> Result<Vec<MemoryEntry>> {
679        self.ensure_initialized().await?;
680
681        // Build filter conditions
682        let mut must_conditions = Vec::new();
683
684        if let Some(cat) = category {
685            must_conditions.push(serde_json::json!({
686                "key": "category",
687                "match": { "value": Self::category_to_str(cat) }
688            }));
689        }
690
691        if let Some(sid) = session_id {
692            must_conditions.push(serde_json::json!({
693                "key": "session_id",
694                "match": { "value": sid }
695            }));
696        }
697
698        let mut scroll_body = serde_json::json!({
699            "limit": 1000,
700            "with_payload": true
701        });
702
703        if !must_conditions.is_empty() {
704            scroll_body["filter"] = serde_json::json!({ "must": must_conditions });
705        }
706
707        let resp = self
708            .request(
709                reqwest::Method::POST,
710                &format!("/collections/{}/points/scroll", self.collection),
711            )
712            .json(&scroll_body)
713            .send()
714            .await
715            .context("failed to scroll Qdrant")?;
716
717        if !resp.status().is_success() {
718            let status = resp.status();
719            let text = resp.text().await.unwrap_or_default();
720            anyhow::bail!("Qdrant scroll failed ({status}): {text}");
721        }
722
723        let result: QdrantScrollResult = resp.json().await?;
724
725        let entries = result
726            .result
727            .points
728            .into_iter()
729            .filter_map(|point| {
730                let payload = point.payload?;
731                let id = match &point.id {
732                    serde_json::Value::String(s) => s.clone(),
733                    serde_json::Value::Number(n) => n.to_string(),
734                    _ => return None,
735                };
736
737                Some(MemoryEntry {
738                    id,
739                    key: payload.key,
740                    content: payload.content,
741                    category: Self::parse_category(&payload.category),
742                    timestamp: payload.timestamp,
743                    session_id: payload.session_id,
744                    score: None,
745                    namespace: "default".into(),
746                    importance: None,
747                    superseded_by: None,
748                    agent_alias: payload.agent_id.clone(),
749                    agent_id: payload.agent_id,
750                })
751            })
752            .collect();
753
754        Ok(entries)
755    }
756
757    async fn forget(&self, key: &str) -> Result<bool> {
758        self.delete_points_matching(&[("key", key)]).await
759    }
760
761    async fn forget_for_agent(&self, key: &str, agent_id: &str) -> Result<bool> {
762        // Qdrant's delete response does not expose a match count, so
763        // probe for a matching point first. Returning `false` when
764        // nothing exists keeps the bool meaningful for callers (absent
765        // and deleted are distinguishable).
766        if self
767            .scroll_first_matching(&[("key", key), ("agent_id", agent_id)])
768            .await?
769            .is_none()
770        {
771            return Ok(false);
772        }
773        self.delete_points_matching(&[("key", key), ("agent_id", agent_id)])
774            .await
775    }
776
777    async fn purge_session_for_agent(&self, session_id: &str, agent_id: &str) -> Result<usize> {
778        let matches = self
779            .list(None, Some(session_id))
780            .await?
781            .into_iter()
782            .filter(|entry| entry.agent_id.as_deref() == Some(agent_id))
783            .count();
784        if matches == 0 {
785            return Ok(0);
786        }
787        self.delete_points_matching(&[("session_id", session_id), ("agent_id", agent_id)])
788            .await?;
789        Ok(matches)
790    }
791
792    async fn count(&self) -> Result<usize> {
793        self.ensure_initialized().await?;
794
795        let resp = self
796            .request(
797                reqwest::Method::GET,
798                &format!("/collections/{}", self.collection),
799            )
800            .send()
801            .await
802            .context("failed to get Qdrant collection info")?;
803
804        if !resp.status().is_success() {
805            let status = resp.status();
806            let text = resp.text().await.unwrap_or_default();
807            anyhow::bail!("Qdrant collection info failed ({status}): {text}");
808        }
809
810        let json: serde_json::Value = resp.json().await?;
811
812        let count = json
813            .get("result")
814            .and_then(|r| r.get("points_count"))
815            .and_then(|c| c.as_u64())
816            .unwrap_or(0);
817
818        let count =
819            usize::try_from(count).context("Qdrant returned a points count that exceeds usize")?;
820        Ok(count)
821    }
822
823    async fn health_check(&self) -> bool {
824        let resp = self.request(reqwest::Method::GET, "/").send().await;
825
826        matches!(resp, Ok(r) if r.status().is_success())
827    }
828
829    async fn store_with_agent(
830        &self,
831        key: &str,
832        content: &str,
833        category: MemoryCategory,
834        session_id: Option<&str>,
835        _namespace: Option<&str>,
836        _importance: Option<f64>,
837        agent_id: Option<&str>,
838    ) -> Result<()> {
839        self.ensure_initialized().await?;
840
841        let combined_text = format!("{}\n{}", key, content);
842        let embedding = self.embedder.embed_one(&combined_text).await?;
843        if embedding.is_empty() {
844            anyhow::bail!("Qdrant requires non-zero dimensional embeddings");
845        }
846
847        let id = Uuid::new_v4().to_string();
848        let timestamp = Utc::now().to_rfc3339();
849
850        // Attribute un-scoped writes to the synthesized `default`
851        // agent so cross-agent recall's `must agent_id IN (...)` filter
852        // never sees a payload-less point as globally visible. Qdrant
853        // uses alias verbatim as agent_id (no UUID indirection at the
854        // storage layer; see `Memory::ensure_agent_uuid` default impl).
855        let resolved_agent_id = agent_id.unwrap_or("default").to_string();
856        let payload = MemoryPayload {
857            key: key.to_string(),
858            content: content.to_string(),
859            category: Self::category_to_str(&category),
860            timestamp,
861            session_id: session_id.map(str::to_string),
862            agent_id: Some(resolved_agent_id.clone()),
863        };
864
865        // Pre-upsert cleanup must scope to the writing agent so sibling
866        // points under the same key for other agents survive.
867        // Propagate failures so a cleanup error doesn't leave duplicate
868        // (agent_id, key) points after the upsert lands.
869        self.delete_points_matching(&[("key", key), ("agent_id", resolved_agent_id.as_str())])
870            .await
871            .context("qdrant pre-upsert cleanup failed")?;
872
873        let upsert_body = serde_json::json!({
874            "points": [{
875                "id": id,
876                "vector": embedding,
877                "payload": payload
878            }]
879        });
880
881        let resp = self
882            .request(
883                reqwest::Method::PUT,
884                &format!("/collections/{}/points", self.collection),
885            )
886            .query(&[("wait", "true")])
887            .json(&upsert_body)
888            .send()
889            .await
890            .context("failed to upsert point to Qdrant")?;
891
892        if !resp.status().is_success() {
893            let status = resp.status();
894            let text = resp.text().await.unwrap_or_default();
895            anyhow::bail!("Qdrant upsert failed ({status}): {text}");
896        }
897
898        Ok(())
899    }
900
901    async fn recall_for_agents(
902        &self,
903        allowed_agent_ids: &[&str],
904        query: &str,
905        limit: usize,
906        session_id: Option<&str>,
907        since: Option<&str>,
908        until: Option<&str>,
909    ) -> Result<Vec<MemoryEntry>> {
910        // Empty allowlist = no agent filter (matches the wrapper's
911        // semantics; see the SQL backends).
912        if allowed_agent_ids.is_empty() {
913            return self.recall(query, limit, session_id, since, until).await;
914        }
915
916        // Recent/time-only branch: scroll with a payload `must` filter
917        // on `agent_id` so unattributed points never reach the caller.
918        if is_recent_recall_query(query) {
919            let mut entries = self
920                .list_for_agents(allowed_agent_ids, None, session_id)
921                .await?;
922            if let Some(s) = since {
923                entries.retain(|e| e.timestamp.as_str() >= s);
924            }
925            if let Some(u) = until {
926                entries.retain(|e| e.timestamp.as_str() <= u);
927            }
928            entries.truncate(limit);
929            return Ok(entries);
930        }
931
932        self.ensure_initialized().await?;
933
934        let embedding = self.embedder.embed_one(query).await?;
935        if embedding.is_empty() {
936            // No embedding available: fall back to listing under the
937            // allowlist. Same surface as `recall`'s fallback.
938            return self
939                .list_for_agents(allowed_agent_ids, None, session_id)
940                .await;
941        }
942
943        // Build a `must` filter that combines the optional session_id
944        // with the agent_id allowlist. The agent_id filter lives in
945        // the search call, not in a post-fetch scroll: legacy points
946        // whose payload lacks `agent_id` are simply not returned (the
947        // V3 store path attributes everything to `default` if no agent
948        // is in scope, so no payload should be agent_id-less after
949        // upgrade).
950        let mut must: Vec<serde_json::Value> = Vec::new();
951        if let Some(sid) = session_id {
952            must.push(serde_json::json!({
953                "key": "session_id",
954                "match": { "value": sid }
955            }));
956        }
957        must.push(serde_json::json!({
958            "key": "agent_id",
959            "match": { "any": allowed_agent_ids }
960        }));
961
962        let search_body = serde_json::json!({
963            "vector": embedding,
964            "limit": limit,
965            "with_payload": true,
966            "filter": { "must": must }
967        });
968
969        let resp = self
970            .request(
971                reqwest::Method::POST,
972                &format!("/collections/{}/points/search", self.collection),
973            )
974            .json(&search_body)
975            .send()
976            .await
977            .context("failed to search Qdrant for allowed agent set")?;
978
979        if !resp.status().is_success() {
980            let status = resp.status();
981            let text = resp.text().await.unwrap_or_default();
982            anyhow::bail!("Qdrant search failed ({status}): {text}");
983        }
984
985        let result: QdrantSearchResult = resp.json().await?;
986
987        let mut entries: Vec<MemoryEntry> = result
988            .result
989            .into_iter()
990            .filter_map(|point| {
991                let payload = point.payload?;
992                let id = match &point.id {
993                    serde_json::Value::String(s) => s.clone(),
994                    serde_json::Value::Number(n) => n.to_string(),
995                    _ => return None,
996                };
997
998                Some(MemoryEntry {
999                    id,
1000                    key: payload.key,
1001                    content: payload.content,
1002                    category: Self::parse_category(&payload.category),
1003                    timestamp: payload.timestamp,
1004                    session_id: payload.session_id,
1005                    score: Some(point.score),
1006                    namespace: "default".into(),
1007                    importance: None,
1008                    superseded_by: None,
1009                    agent_alias: payload.agent_id.clone(),
1010                    agent_id: payload.agent_id,
1011                })
1012            })
1013            .collect();
1014
1015        if let Some(s) = since {
1016            entries.retain(|e| e.timestamp.as_str() >= s);
1017        }
1018        if let Some(u) = until {
1019            entries.retain(|e| e.timestamp.as_str() <= u);
1020        }
1021        Ok(entries)
1022    }
1023}
1024
1025impl ::zeroclaw_api::attribution::Attributable for QdrantMemory {
1026    fn role(&self) -> ::zeroclaw_api::attribution::Role {
1027        ::zeroclaw_api::attribution::Role::Memory(::zeroclaw_api::attribution::MemoryKind::Qdrant)
1028    }
1029    fn alias(&self) -> &str {
1030        &self.alias
1031    }
1032}
1033
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a payload with fixed key/content/category/timestamp so each
    /// test only varies the optional scoping fields.
    fn sample_payload(session_id: Option<&str>, agent_id: Option<&str>) -> MemoryPayload {
        MemoryPayload {
            key: "test_key".into(),
            content: "test content".into(),
            category: "core".into(),
            timestamp: "2026-02-20T00:00:00Z".into(),
            session_id: session_id.map(str::to_string),
            agent_id: agent_id.map(str::to_string),
        }
    }

    #[test]
    fn category_to_str_maps_known_categories() {
        assert_eq!(QdrantMemory::category_to_str(&MemoryCategory::Core), "core");
        assert_eq!(
            QdrantMemory::category_to_str(&MemoryCategory::Daily),
            "daily"
        );
        assert_eq!(
            QdrantMemory::category_to_str(&MemoryCategory::Conversation),
            "conversation"
        );
        assert_eq!(
            QdrantMemory::category_to_str(&MemoryCategory::Custom("notes".into())),
            "notes"
        );
    }

    #[test]
    fn parse_category_maps_known_and_custom_values() {
        assert_eq!(QdrantMemory::parse_category("core"), MemoryCategory::Core);
        assert_eq!(QdrantMemory::parse_category("daily"), MemoryCategory::Daily);
        assert_eq!(
            QdrantMemory::parse_category("conversation"),
            MemoryCategory::Conversation
        );
        assert_eq!(
            QdrantMemory::parse_category("custom_notes"),
            MemoryCategory::Custom("custom_notes".into())
        );
    }

    #[test]
    fn category_round_trips_through_string_form() {
        // Stored payloads hold the string form; recall parses it back, so
        // the two mappings must be inverses.
        for category in [
            MemoryCategory::Core,
            MemoryCategory::Daily,
            MemoryCategory::Conversation,
            MemoryCategory::Custom("notes".into()),
        ] {
            let as_str = QdrantMemory::category_to_str(&category);
            assert_eq!(QdrantMemory::parse_category(&as_str), category);
        }
    }

    #[test]
    fn memory_payload_serializes_correctly() {
        let payload = sample_payload(Some("session-1"), None);

        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("test_key"));
        assert!(json.contains("test content"));
        assert!(json.contains("session-1"));

        // `key` and `session_id` are payload keys Qdrant filters on, so pin
        // the exact field names, not just the presence of the values.
        let value = serde_json::to_value(&payload).unwrap();
        assert_eq!(value["key"], "test_key");
        assert_eq!(value["session_id"], "session-1");
    }

    #[test]
    fn memory_payload_skips_none_session_id() {
        let json = serde_json::to_string(&sample_payload(None, None)).unwrap();
        assert!(!json.contains("session_id"));
        assert!(!json.contains("agent_id"));
    }

    #[test]
    fn memory_payload_includes_agent_id_when_set() {
        // Cross-agent recall filters on the `agent_id` payload key, so a
        // populated agent must serialize under exactly that name.
        let value = serde_json::to_value(sample_payload(None, Some("agent-a"))).unwrap();
        assert_eq!(value["agent_id"], "agent-a");
        assert!(value.get("session_id").is_none());
    }
}