Skip to main content

zeroclaw_memory/
embeddings.rs

1use async_trait::async_trait;
2
3/// Trait for embedding model_providers — convert text to vectors
4#[async_trait]
5pub trait EmbeddingProvider: Send + Sync {
6    /// ModelProvider name
7    fn name(&self) -> &str;
8
9    /// Embedding dimensions
10    fn dimensions(&self) -> usize;
11
12    /// Embed a batch of texts into vectors
13    async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
14
15    /// Embed a single text
16    async fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
17        let mut results = self.embed(&[text]).await?;
18        results.pop().ok_or_else(|| {
19            ::zeroclaw_log::record!(
20                ERROR,
21                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
22                    .with_outcome(::zeroclaw_log::EventOutcome::Failure),
23                "embed_one: provider returned no embedding"
24            );
25            anyhow::Error::msg("Empty embedding result")
26        })
27    }
28}
29
30// ── Noop model_provider (keyword-only fallback) ────────────────────
31
32pub struct NoopEmbedding;
33
34#[async_trait]
35impl EmbeddingProvider for NoopEmbedding {
36    fn name(&self) -> &str {
37        "none"
38    }
39
40    fn dimensions(&self) -> usize {
41        0
42    }
43
44    async fn embed(&self, _texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
45        Ok(Vec::new())
46    }
47}
48
49// ── OpenAI-compatible embedding model_provider ─────────────────────
50
51pub struct OpenAiEmbedding {
52    base_url: String,
53    api_key: String,
54    model: String,
55    dims: usize,
56}
57
58impl OpenAiEmbedding {
59    pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
60        Self {
61            base_url: base_url.trim_end_matches('/').to_string(),
62            api_key: api_key.to_string(),
63            model: model.to_string(),
64            dims,
65        }
66    }
67
68    fn http_client(&self) -> reqwest::Client {
69        zeroclaw_config::schema::build_runtime_proxy_client("memory.embeddings")
70    }
71
72    fn has_explicit_api_path(&self) -> bool {
73        let Ok(url) = reqwest::Url::parse(&self.base_url) else {
74            return false;
75        };
76
77        let path = url.path().trim_end_matches('/');
78        !path.is_empty() && path != "/"
79    }
80
81    fn has_embeddings_endpoint(&self) -> bool {
82        let Ok(url) = reqwest::Url::parse(&self.base_url) else {
83            return false;
84        };
85
86        url.path().trim_end_matches('/').ends_with("/embeddings")
87    }
88
89    fn embeddings_url(&self) -> String {
90        if self.has_embeddings_endpoint() {
91            return self.base_url.clone();
92        }
93
94        if self.has_explicit_api_path() {
95            format!("{}/embeddings", self.base_url)
96        } else {
97            format!("{}/v1/embeddings", self.base_url)
98        }
99    }
100}
101
102#[async_trait]
103impl EmbeddingProvider for OpenAiEmbedding {
104    fn name(&self) -> &str {
105        "openai"
106    }
107
108    fn dimensions(&self) -> usize {
109        self.dims
110    }
111
112    async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
113        if texts.is_empty() {
114            return Ok(Vec::new());
115        }
116
117        let body = serde_json::json!({
118            "model": self.model,
119            "input": texts,
120        });
121
122        let resp = self
123            .http_client()
124            .post(self.embeddings_url())
125            .header("Authorization", format!("Bearer {}", self.api_key))
126            .header("Content-Type", "application/json")
127            .json(&body)
128            .send()
129            .await?;
130
131        if !resp.status().is_success() {
132            let status = resp.status();
133            let text = resp.text().await.unwrap_or_default();
134            anyhow::bail!("Embedding API error {status}: {text}");
135        }
136
137        let json: serde_json::Value = resp.json().await?;
138        let data = json.get("data").and_then(|d| d.as_array()).ok_or_else(|| {
139            ::zeroclaw_log::record!(
140                ERROR,
141                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
142                    .with_outcome(::zeroclaw_log::EventOutcome::Failure),
143                "embedding response missing 'data' field"
144            );
145            anyhow::Error::msg("Invalid embedding response: missing 'data'")
146        })?;
147
148        let mut embeddings = Vec::with_capacity(data.len());
149        for item in data {
150            let embedding = item
151                .get("embedding")
152                .and_then(|e| e.as_array())
153                .ok_or_else(|| {
154                    ::zeroclaw_log::record!(
155                        ERROR,
156                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
157                            .with_outcome(::zeroclaw_log::EventOutcome::Failure),
158                        "embedding response item missing 'embedding' array"
159                    );
160                    anyhow::Error::msg("Invalid embedding item")
161                })?;
162
163            #[allow(clippy::cast_possible_truncation)]
164            let vec: Vec<f32> = embedding
165                .iter()
166                .filter_map(|v| v.as_f64().map(|f| f as f32))
167                .collect();
168
169            embeddings.push(vec);
170        }
171
172        Ok(embeddings)
173    }
174}
175
176// ── Factory ──────────────────────────────────────────────────
177
178pub fn create_embedding_provider(
179    model_provider: &str,
180    api_key: Option<&str>,
181    model: &str,
182    dims: usize,
183) -> Box<dyn EmbeddingProvider> {
184    match model_provider {
185        "openai" => {
186            let key = api_key.unwrap_or("");
187            Box::new(OpenAiEmbedding::new(
188                "https://api.openai.com",
189                key,
190                model,
191                dims,
192            ))
193        }
194        "openrouter" => {
195            let key = api_key.unwrap_or("");
196            Box::new(OpenAiEmbedding::new(
197                "https://openrouter.ai/api/v1",
198                key,
199                model,
200                dims,
201            ))
202        }
203        name if name.starts_with("custom:") => {
204            let base_url = name.strip_prefix("custom:").unwrap_or("");
205            let key = api_key.unwrap_or("");
206            Box::new(OpenAiEmbedding::new(base_url, key, model, dims))
207        }
208        _ => Box::new(NoopEmbedding),
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn noop_name() {
        let p = NoopEmbedding;
        assert_eq!(p.name(), "none");
        assert_eq!(p.dimensions(), 0);
    }

    #[tokio::test]
    async fn noop_embed_returns_empty() {
        let p = NoopEmbedding;
        let result = p.embed(&["hello"]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn factory_none() {
        let p = create_embedding_provider("none", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_openai() {
        let p = create_embedding_provider("openai", Some("key"), "text-embedding-3-small", 1536);
        assert_eq!(p.name(), "openai");
        assert_eq!(p.dimensions(), 1536);
    }

    #[test]
    fn factory_openrouter() {
        let p = create_embedding_provider(
            "openrouter",
            Some("sk-or-test"),
            "openai/text-embedding-3-small",
            1536,
        );
        assert_eq!(p.name(), "openai"); // uses OpenAiEmbedding internally
        assert_eq!(p.dimensions(), 1536);
    }

    #[test]
    fn factory_custom_url() {
        let p = create_embedding_provider("custom:http://localhost:1234", None, "model", 768);
        assert_eq!(p.name(), "openai"); // uses OpenAiEmbedding internally
        assert_eq!(p.dimensions(), 768);
    }

    // ── Edge cases ───────────────────────────────────────────────

    #[tokio::test]
    async fn noop_embed_one_returns_error() {
        let p = NoopEmbedding;
        // embed returns an empty vec, so embed_one has nothing to return -> error
        let result = p.embed_one("hello").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn noop_embed_empty_batch() {
        let p = NoopEmbedding;
        let result = p.embed(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn noop_embed_multiple_texts() {
        let p = NoopEmbedding;
        let result = p.embed(&["a", "b", "c"]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn factory_empty_string_returns_noop() {
        let p = create_embedding_provider("", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_unknown_provider_returns_noop() {
        let p = create_embedding_provider("cohere", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_custom_empty_url() {
        // "custom:" with no URL — should still construct without panic
        let p = create_embedding_provider("custom:", None, "model", 768);
        assert_eq!(p.name(), "openai");
    }

    #[test]
    fn factory_openai_no_api_key() {
        let p = create_embedding_provider("openai", None, "text-embedding-3-small", 1536);
        assert_eq!(p.name(), "openai");
        assert_eq!(p.dimensions(), 1536);
    }

    #[test]
    fn openai_trailing_slash_stripped() {
        let p = OpenAiEmbedding::new("https://api.openai.com/", "key", "model", 1536);
        assert_eq!(p.base_url, "https://api.openai.com");
    }

    #[test]
    fn openai_dimensions_custom() {
        let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384);
        assert_eq!(p.dimensions(), 384);
    }

    #[tokio::test]
    async fn openai_embed_empty_batch_skips_request() {
        // No server is expected at this address; the empty-batch early return
        // means no HTTP request is attempted at all.
        let p = OpenAiEmbedding::new("http://127.0.0.1:9", "key", "model", 1536);
        let result = p.embed(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn embeddings_url_openrouter() {
        let p = OpenAiEmbedding::new(
            "https://openrouter.ai/api/v1",
            "key",
            "openai/text-embedding-3-small",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://openrouter.ai/api/v1/embeddings"
        );
    }

    #[test]
    fn embeddings_url_standard_openai() {
        let p = OpenAiEmbedding::new("https://api.openai.com", "key", "model", 1536);
        assert_eq!(p.embeddings_url(), "https://api.openai.com/v1/embeddings");
    }

    #[test]
    fn embeddings_url_bare_host_trailing_slash_gets_v1() {
        let p = OpenAiEmbedding::new("http://localhost:1234/", "", "model", 768);
        assert_eq!(p.embeddings_url(), "http://localhost:1234/v1/embeddings");
    }

    #[test]
    fn embeddings_url_base_with_v1_no_duplicate() {
        let p = OpenAiEmbedding::new("https://api.example.com/v1", "key", "model", 1536);
        assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings");
    }

    #[test]
    fn embeddings_url_non_v1_api_path_uses_raw_suffix() {
        let p = OpenAiEmbedding::new(
            "https://api.example.com/api/coding/v3",
            "key",
            "model",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://api.example.com/api/coding/v3/embeddings"
        );
    }

    #[test]
    fn embeddings_url_custom_full_endpoint() {
        let p = OpenAiEmbedding::new(
            "https://my-api.example.com/api/v2/embeddings",
            "key",
            "model",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://my-api.example.com/api/v2/embeddings"
        );
    }

    #[test]
    fn embeddings_url_full_endpoint_trailing_slash_used_verbatim() {
        let p = OpenAiEmbedding::new(
            "https://api.example.com/v1/embeddings/",
            "key",
            "model",
            1536,
        );
        assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings");
    }
}