1use async_trait::async_trait;
2
3#[async_trait]
5pub trait EmbeddingProvider: Send + Sync {
6 fn name(&self) -> &str;
8
9 fn dimensions(&self) -> usize;
11
12 async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
14
15 async fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
17 let mut results = self.embed(&[text]).await?;
18 results.pop().ok_or_else(|| {
19 ::zeroclaw_log::record!(
20 ERROR,
21 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
22 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
23 "embed_one: provider returned no embedding"
24 );
25 anyhow::Error::msg("Empty embedding result")
26 })
27 }
28}
29
30pub struct NoopEmbedding;
33
34#[async_trait]
35impl EmbeddingProvider for NoopEmbedding {
36 fn name(&self) -> &str {
37 "none"
38 }
39
40 fn dimensions(&self) -> usize {
41 0
42 }
43
44 async fn embed(&self, _texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
45 Ok(Vec::new())
46 }
47}
48
49pub struct OpenAiEmbedding {
52 base_url: String,
53 api_key: String,
54 model: String,
55 dims: usize,
56}
57
58impl OpenAiEmbedding {
59 pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
60 Self {
61 base_url: base_url.trim_end_matches('/').to_string(),
62 api_key: api_key.to_string(),
63 model: model.to_string(),
64 dims,
65 }
66 }
67
68 fn http_client(&self) -> reqwest::Client {
69 zeroclaw_config::schema::build_runtime_proxy_client("memory.embeddings")
70 }
71
72 fn has_explicit_api_path(&self) -> bool {
73 let Ok(url) = reqwest::Url::parse(&self.base_url) else {
74 return false;
75 };
76
77 let path = url.path().trim_end_matches('/');
78 !path.is_empty() && path != "/"
79 }
80
81 fn has_embeddings_endpoint(&self) -> bool {
82 let Ok(url) = reqwest::Url::parse(&self.base_url) else {
83 return false;
84 };
85
86 url.path().trim_end_matches('/').ends_with("/embeddings")
87 }
88
89 fn embeddings_url(&self) -> String {
90 if self.has_embeddings_endpoint() {
91 return self.base_url.clone();
92 }
93
94 if self.has_explicit_api_path() {
95 format!("{}/embeddings", self.base_url)
96 } else {
97 format!("{}/v1/embeddings", self.base_url)
98 }
99 }
100}
101
102#[async_trait]
103impl EmbeddingProvider for OpenAiEmbedding {
104 fn name(&self) -> &str {
105 "openai"
106 }
107
108 fn dimensions(&self) -> usize {
109 self.dims
110 }
111
112 async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
113 if texts.is_empty() {
114 return Ok(Vec::new());
115 }
116
117 let body = serde_json::json!({
118 "model": self.model,
119 "input": texts,
120 });
121
122 let resp = self
123 .http_client()
124 .post(self.embeddings_url())
125 .header("Authorization", format!("Bearer {}", self.api_key))
126 .header("Content-Type", "application/json")
127 .json(&body)
128 .send()
129 .await?;
130
131 if !resp.status().is_success() {
132 let status = resp.status();
133 let text = resp.text().await.unwrap_or_default();
134 anyhow::bail!("Embedding API error {status}: {text}");
135 }
136
137 let json: serde_json::Value = resp.json().await?;
138 let data = json.get("data").and_then(|d| d.as_array()).ok_or_else(|| {
139 ::zeroclaw_log::record!(
140 ERROR,
141 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
142 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
143 "embedding response missing 'data' field"
144 );
145 anyhow::Error::msg("Invalid embedding response: missing 'data'")
146 })?;
147
148 let mut embeddings = Vec::with_capacity(data.len());
149 for item in data {
150 let embedding = item
151 .get("embedding")
152 .and_then(|e| e.as_array())
153 .ok_or_else(|| {
154 ::zeroclaw_log::record!(
155 ERROR,
156 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
157 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
158 "embedding response item missing 'embedding' array"
159 );
160 anyhow::Error::msg("Invalid embedding item")
161 })?;
162
163 #[allow(clippy::cast_possible_truncation)]
164 let vec: Vec<f32> = embedding
165 .iter()
166 .filter_map(|v| v.as_f64().map(|f| f as f32))
167 .collect();
168
169 embeddings.push(vec);
170 }
171
172 Ok(embeddings)
173 }
174}
175
176pub fn create_embedding_provider(
179 model_provider: &str,
180 api_key: Option<&str>,
181 model: &str,
182 dims: usize,
183) -> Box<dyn EmbeddingProvider> {
184 match model_provider {
185 "openai" => {
186 let key = api_key.unwrap_or("");
187 Box::new(OpenAiEmbedding::new(
188 "https://api.openai.com",
189 key,
190 model,
191 dims,
192 ))
193 }
194 "openrouter" => {
195 let key = api_key.unwrap_or("");
196 Box::new(OpenAiEmbedding::new(
197 "https://openrouter.ai/api/v1",
198 key,
199 model,
200 dims,
201 ))
202 }
203 name if name.starts_with("custom:") => {
204 let base_url = name.strip_prefix("custom:").unwrap_or("");
205 let key = api_key.unwrap_or("");
206 Box::new(OpenAiEmbedding::new(base_url, key, model, dims))
207 }
208 _ => Box::new(NoopEmbedding),
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn noop_name() {
        let p = NoopEmbedding;
        assert_eq!(p.name(), "none");
        assert_eq!(p.dimensions(), 0);
    }

    #[tokio::test]
    async fn noop_embed_returns_empty() {
        let p = NoopEmbedding;
        let result = p.embed(&["hello"]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn factory_none() {
        let p = create_embedding_provider("none", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_openai() {
        let p = create_embedding_provider("openai", Some("key"), "text-embedding-3-small", 1536);
        assert_eq!(p.name(), "openai");
        assert_eq!(p.dimensions(), 1536);
    }

    #[test]
    fn factory_openrouter() {
        let p = create_embedding_provider(
            "openrouter",
            Some("sk-or-test"),
            "openai/text-embedding-3-small",
            1536,
        );
        // OpenRouter is served by the OpenAI-compatible client.
        assert_eq!(p.name(), "openai");
        assert_eq!(p.dimensions(), 1536);
    }

    #[test]
    fn factory_custom_url() {
        let p = create_embedding_provider("custom:http://localhost:1234", None, "model", 768);
        assert_eq!(p.name(), "openai");
        assert_eq!(p.dimensions(), 768);
    }

    #[tokio::test]
    async fn noop_embed_one_returns_error() {
        // Noop yields an empty batch, so the default `embed_one` has nothing to return.
        let p = NoopEmbedding;
        let result = p.embed_one("hello").await;
        let err = result.expect_err("noop provider must not produce an embedding");
        assert_eq!(err.to_string(), "Empty embedding result");
    }

    #[tokio::test]
    async fn noop_embed_empty_batch() {
        let p = NoopEmbedding;
        let result = p.embed(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn noop_embed_multiple_texts() {
        let p = NoopEmbedding;
        let result = p.embed(&["a", "b", "c"]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn factory_empty_string_returns_noop() {
        let p = create_embedding_provider("", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_unknown_provider_returns_noop() {
        let p = create_embedding_provider("cohere", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_custom_empty_url() {
        // The factory does not validate custom URLs; an empty one still selects the
        // OpenAI-compatible client.
        let p = create_embedding_provider("custom:", None, "model", 768);
        assert_eq!(p.name(), "openai");
    }

    #[test]
    fn factory_openai_no_api_key() {
        let p = create_embedding_provider("openai", None, "text-embedding-3-small", 1536);
        assert_eq!(p.name(), "openai");
        assert_eq!(p.dimensions(), 1536);
    }

    #[test]
    fn openai_trailing_slash_stripped() {
        let p = OpenAiEmbedding::new("https://api.openai.com/", "key", "model", 1536);
        assert_eq!(p.base_url, "https://api.openai.com");
    }

    #[test]
    fn openai_dimensions_custom() {
        let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384);
        assert_eq!(p.dimensions(), 384);
    }

    #[test]
    fn embeddings_url_openrouter() {
        let p = OpenAiEmbedding::new(
            "https://openrouter.ai/api/v1",
            "key",
            "openai/text-embedding-3-small",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://openrouter.ai/api/v1/embeddings"
        );
    }

    #[test]
    fn embeddings_url_standard_openai() {
        let p = OpenAiEmbedding::new("https://api.openai.com", "key", "model", 1536);
        assert_eq!(p.embeddings_url(), "https://api.openai.com/v1/embeddings");
    }

    #[test]
    fn embeddings_url_base_with_v1_no_duplicate() {
        let p = OpenAiEmbedding::new("https://api.example.com/v1", "key", "model", 1536);
        assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings");
    }

    #[test]
    fn embeddings_url_trailing_slash_on_api_path() {
        let p = OpenAiEmbedding::new("https://api.example.com/v1/", "key", "model", 1536);
        assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings");
    }

    #[test]
    fn embeddings_url_non_v1_api_path_uses_raw_suffix() {
        let p = OpenAiEmbedding::new(
            "https://api.example.com/api/coding/v3",
            "key",
            "model",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://api.example.com/api/coding/v3/embeddings"
        );
    }

    #[test]
    fn embeddings_url_custom_full_endpoint() {
        let p = OpenAiEmbedding::new(
            "https://my-api.example.com/api/v2/embeddings",
            "key",
            "model",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://my-api.example.com/api/v2/embeddings"
        );
    }

    #[test]
    fn embeddings_url_full_endpoint_with_trailing_slash() {
        let p = OpenAiEmbedding::new(
            "https://my-api.example.com/api/v2/embeddings/",
            "key",
            "model",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://my-api.example.com/api/v2/embeddings"
        );
    }
}