1use std::collections::HashMap;
2
3use anyhow::{Context, Result, bail};
4use async_trait::async_trait;
5use reqwest::multipart::{Form, Part};
6
7use zeroclaw_config::schema::TranscriptionConfig;
8
/// Upper bound on the audio payload accepted by the hosted providers (25 MiB);
/// enforced by `validate_audio` before any request is sent.
const MAX_AUDIO_BYTES: usize = 25 << 20;

/// Per-request timeout, in seconds, applied to the hosted transcription calls.
const TRANSCRIPTION_TIMEOUT_SECS: u64 = 120;
14
/// Maps an audio file extension (without the leading dot, compared
/// case-insensitively) to the MIME type sent with the upload.
///
/// Returns `None` when the extension is not an accepted transcription input.
fn mime_for_audio(extension: &str) -> Option<&'static str> {
    let lowered = extension.to_ascii_lowercase();
    let mime = match lowered.as_str() {
        "flac" => "audio/flac",
        "mp3" | "mpeg" | "mpga" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "ogg" | "oga" => "audio/ogg",
        "opus" => "audio/opus",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        _ => return None,
    };
    Some(mime)
}
30
/// Rewrites a trailing `.oga` extension (any letter case) to `.ogg`.
///
/// Only the text after the last `.` is inspected; every other name, including
/// one without a dot, is returned unchanged.
fn normalize_audio_filename(file_name: &str) -> String {
    file_name
        .rsplit_once('.')
        .filter(|&(_, ext)| ext.eq_ignore_ascii_case("oga"))
        .map_or_else(|| file_name.to_string(), |(stem, _)| format!("{stem}.ogg"))
}
41
42fn resolve_audio_format(file_name: &str) -> Result<(String, &'static str)> {
46 let normalized_name = normalize_audio_filename(file_name);
47 let extension = normalized_name
48 .rsplit_once('.')
49 .map(|(_, e)| e)
50 .unwrap_or("");
51 let mime = mime_for_audio(extension).ok_or_else(|| {
52 ::zeroclaw_log::record!(
53 WARN,
54 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
55 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
56 .with_attrs(::serde_json::json!({"extension": extension})),
57 "transcription: unsupported audio format"
58 );
59 anyhow::Error::msg(format!(
60 "Unsupported audio format '.{extension}'. \
61 accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
62 ))
63 })?;
64 Ok((normalized_name, mime))
65}
66
67fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
71 if audio_data.len() > MAX_AUDIO_BYTES {
72 bail!(
73 "Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
74 audio_data.len()
75 );
76 }
77 resolve_audio_format(file_name)
78}
79
80#[async_trait]
84pub trait TranscriptionProvider: Send + Sync + ::zeroclaw_api::attribution::Attributable {
85 fn name(&self) -> &str;
87
88 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String>;
91
92 fn supported_formats(&self) -> Vec<String> {
94 vec![
95 "flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
96 ]
97 .into_iter()
98 .map(String::from)
99 .collect()
100 }
101}
102
103pub struct GroqProvider {
107 alias: String,
108 api_url: String,
109 model: String,
110 api_key: String,
111 language: Option<String>,
112}
113
114impl GroqProvider {
115 pub fn from_config(alias: &str, config: &TranscriptionConfig) -> Result<Self> {
122 let api_key = config
123 .api_key
124 .as_deref()
125 .map(str::trim)
126 .filter(|v| !v.is_empty())
127 .map(ToOwned::to_owned)
128 .context(
129 "Missing transcription API key: set `[transcription].api_key` (or via the \
130 schema-mirror grammar `ZEROCLAW_transcription__api_key=...`).",
131 )?;
132
133 Ok(Self {
134 alias: alias.to_string(),
135 api_url: config.api_url.clone(),
136 model: config.model.clone(),
137 api_key,
138 language: config.language.clone(),
139 })
140 }
141}
142
143#[async_trait]
144impl TranscriptionProvider for GroqProvider {
145 fn name(&self) -> &str {
146 "groq"
147 }
148
149 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
150 let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
151
152 let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.groq");
153
154 let file_part = Part::bytes(audio_data.to_vec())
155 .file_name(normalized_name)
156 .mime_str(mime)?;
157
158 let mut form = Form::new()
159 .part("file", file_part)
160 .text("model", self.model.clone())
161 .text("response_format", "json");
162
163 if let Some(ref lang) = self.language {
164 form = form.text("language", lang.clone());
165 }
166
167 let resp = client
168 .post(&self.api_url)
169 .bearer_auth(&self.api_key)
170 .multipart(form)
171 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
172 .send()
173 .await
174 .context("Failed to send transcription request to Groq")?;
175
176 parse_whisper_response(resp).await
177 }
178}
179
180pub struct OpenAiWhisperProvider {
184 alias: String,
185 api_key: String,
186 model: String,
187}
188
189impl OpenAiWhisperProvider {
190 pub fn from_config(
191 alias: &str,
192 config: &zeroclaw_config::schema::OpenAiSttConfig,
193 ) -> Result<Self> {
194 let api_key = config
195 .api_key
196 .as_deref()
197 .map(str::trim)
198 .filter(|v| !v.is_empty())
199 .map(ToOwned::to_owned)
200 .context("Missing OpenAI STT API key: set [transcription.openai].api_key")?;
201
202 Ok(Self {
203 alias: alias.to_string(),
204 api_key,
205 model: config.model.clone(),
206 })
207 }
208}
209
210#[async_trait]
211impl TranscriptionProvider for OpenAiWhisperProvider {
212 fn name(&self) -> &str {
213 "openai"
214 }
215
216 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
217 let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
218
219 let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.openai");
220
221 let file_part = Part::bytes(audio_data.to_vec())
222 .file_name(normalized_name)
223 .mime_str(mime)?;
224
225 let form = Form::new()
226 .part("file", file_part)
227 .text("model", self.model.clone())
228 .text("response_format", "json");
229
230 let resp = client
231 .post("https://api.openai.com/v1/audio/transcriptions")
232 .bearer_auth(&self.api_key)
233 .multipart(form)
234 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
235 .send()
236 .await
237 .context("Failed to send transcription request to OpenAI")?;
238
239 parse_whisper_response(resp).await
240 }
241}
242
243pub struct DeepgramProvider {
247 alias: String,
248 api_key: String,
249 model: String,
250}
251
252impl DeepgramProvider {
253 pub fn from_config(
254 alias: &str,
255 config: &zeroclaw_config::schema::DeepgramSttConfig,
256 ) -> Result<Self> {
257 let api_key = config
258 .api_key
259 .as_deref()
260 .map(str::trim)
261 .filter(|v| !v.is_empty())
262 .map(ToOwned::to_owned)
263 .context("Missing Deepgram API key: set [transcription.deepgram].api_key")?;
264
265 Ok(Self {
266 alias: alias.to_string(),
267 api_key,
268 model: config.model.clone(),
269 })
270 }
271}
272
273#[async_trait]
274impl TranscriptionProvider for DeepgramProvider {
275 fn name(&self) -> &str {
276 "deepgram"
277 }
278
279 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
280 let (_, mime) = validate_audio(audio_data, file_name)?;
281
282 let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.deepgram");
283
284 let url = format!(
285 "https://api.deepgram.com/v1/listen?model={}&punctuate=true",
286 self.model
287 );
288
289 let resp = client
290 .post(&url)
291 .header("Authorization", format!("Token {}", self.api_key))
292 .header("Content-Type", mime)
293 .body(audio_data.to_vec())
294 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
295 .send()
296 .await
297 .context("Failed to send transcription request to Deepgram")?;
298
299 let status = resp.status();
300 let body: serde_json::Value = resp
301 .json()
302 .await
303 .context("Failed to parse Deepgram response")?;
304
305 if !status.is_success() {
306 let error_msg = body["err_msg"]
307 .as_str()
308 .or_else(|| body["error"].as_str())
309 .unwrap_or("unknown error");
310 bail!("Deepgram API error ({}): {}", status, error_msg);
311 }
312
313 let text = body["results"]["channels"][0]["alternatives"][0]["transcript"]
314 .as_str()
315 .context("Deepgram response missing transcript field")?
316 .to_string();
317
318 Ok(text)
319 }
320}
321
322pub struct AssemblyAiProvider {
326 alias: String,
327 api_key: String,
328}
329
330impl AssemblyAiProvider {
331 pub fn from_config(
332 alias: &str,
333 config: &zeroclaw_config::schema::AssemblyAiSttConfig,
334 ) -> Result<Self> {
335 let api_key = config
336 .api_key
337 .as_deref()
338 .map(str::trim)
339 .filter(|v| !v.is_empty())
340 .map(ToOwned::to_owned)
341 .context("Missing AssemblyAI API key: set [transcription.assemblyai].api_key")?;
342
343 Ok(Self {
344 alias: alias.to_string(),
345 api_key,
346 })
347 }
348}
349
350#[async_trait]
351impl TranscriptionProvider for AssemblyAiProvider {
352 fn name(&self) -> &str {
353 "assemblyai"
354 }
355
356 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
357 let (_, _) = validate_audio(audio_data, file_name)?;
358
359 let client =
360 zeroclaw_config::schema::build_runtime_proxy_client("transcription.assemblyai");
361
362 let upload_resp = client
364 .post("https://api.assemblyai.com/v2/upload")
365 .header("Authorization", &self.api_key)
366 .header("Content-Type", "application/octet-stream")
367 .body(audio_data.to_vec())
368 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
369 .send()
370 .await
371 .context("Failed to upload audio to AssemblyAI")?;
372
373 let upload_status = upload_resp.status();
374 let upload_body: serde_json::Value = upload_resp
375 .json()
376 .await
377 .context("Failed to parse AssemblyAI upload response")?;
378
379 if !upload_status.is_success() {
380 let error_msg = upload_body["error"].as_str().unwrap_or("unknown error");
381 bail!("AssemblyAI upload error ({}): {}", upload_status, error_msg);
382 }
383
384 let upload_url = upload_body["upload_url"]
385 .as_str()
386 .context("AssemblyAI upload response missing 'upload_url'")?;
387
388 let transcript_req = serde_json::json!({
390 "audio_url": upload_url,
391 });
392
393 let create_resp = client
394 .post("https://api.assemblyai.com/v2/transcript")
395 .header("Authorization", &self.api_key)
396 .json(&transcript_req)
397 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
398 .send()
399 .await
400 .context("Failed to create AssemblyAI transcription")?;
401
402 let create_status = create_resp.status();
403 let create_body: serde_json::Value = create_resp
404 .json()
405 .await
406 .context("Failed to parse AssemblyAI create response")?;
407
408 if !create_status.is_success() {
409 let error_msg = create_body["error"].as_str().unwrap_or("unknown error");
410 bail!(
411 "AssemblyAI transcription error ({}): {}",
412 create_status,
413 error_msg
414 );
415 }
416
417 let transcript_id = create_body["id"]
418 .as_str()
419 .context("AssemblyAI response missing 'id'")?;
420
421 let poll_url = format!("https://api.assemblyai.com/v2/transcript/{transcript_id}");
423 let poll_interval = std::time::Duration::from_secs(3);
424 let poll_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(180);
425
426 while tokio::time::Instant::now() < poll_deadline {
427 tokio::time::sleep(poll_interval).await;
428
429 let poll_resp = client
430 .get(&poll_url)
431 .header("Authorization", &self.api_key)
432 .timeout(std::time::Duration::from_secs(30))
433 .send()
434 .await
435 .context("Failed to poll AssemblyAI transcription")?;
436
437 let poll_status = poll_resp.status();
438 let poll_body: serde_json::Value = poll_resp
439 .json()
440 .await
441 .context("Failed to parse AssemblyAI poll response")?;
442
443 if !poll_status.is_success() {
444 let error_msg = poll_body["error"].as_str().unwrap_or("unknown poll error");
445 bail!("AssemblyAI poll error ({}): {}", poll_status, error_msg);
446 }
447
448 let status_str = poll_body["status"].as_str().unwrap_or("unknown");
449
450 match status_str {
451 "completed" => {
452 let text = poll_body["text"]
453 .as_str()
454 .context("AssemblyAI response missing 'text'")?
455 .to_string();
456 return Ok(text);
457 }
458 "error" => {
459 let error_msg = poll_body["error"]
460 .as_str()
461 .unwrap_or("unknown transcription error");
462 bail!("AssemblyAI transcription failed: {}", error_msg);
463 }
464 _ => {}
465 }
466 }
467
468 bail!("AssemblyAI transcription timed out after 180s")
469 }
470}
471
472pub struct GoogleSttProvider {
476 alias: String,
477 api_key: String,
478 language_code: String,
479}
480
481impl GoogleSttProvider {
482 pub fn from_config(
483 alias: &str,
484 config: &zeroclaw_config::schema::GoogleSttConfig,
485 ) -> Result<Self> {
486 let api_key = config
487 .api_key
488 .as_deref()
489 .map(str::trim)
490 .filter(|v| !v.is_empty())
491 .map(ToOwned::to_owned)
492 .context("Missing Google STT API key: set [transcription.google].api_key")?;
493
494 Ok(Self {
495 alias: alias.to_string(),
496 api_key,
497 language_code: config.language_code.clone(),
498 })
499 }
500}
501
502#[async_trait]
503impl TranscriptionProvider for GoogleSttProvider {
504 fn name(&self) -> &str {
505 "google"
506 }
507
508 fn supported_formats(&self) -> Vec<String> {
509 vec!["flac", "wav", "ogg", "opus", "mp3", "webm"]
511 .into_iter()
512 .map(String::from)
513 .collect()
514 }
515
516 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
517 let (normalized_name, _) = validate_audio(audio_data, file_name)?;
518
519 let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.google");
520
521 let encoding = match normalized_name
522 .rsplit_once('.')
523 .map(|(_, e)| e.to_ascii_lowercase())
524 .as_deref()
525 {
526 Some("flac") => "FLAC",
527 Some("wav") => "LINEAR16",
528 Some("ogg" | "opus") => "OGG_OPUS",
529 Some("mp3") => "MP3",
530 Some("webm") => "WEBM_OPUS",
531 Some(ext) => bail!("Google STT does not support '.{ext}' input"),
532 None => bail!("Google STT requires a file extension"),
533 };
534
535 let audio_content =
536 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, audio_data);
537
538 let request_body = serde_json::json!({
539 "config": {
540 "encoding": encoding,
541 "languageCode": &self.language_code,
542 "enableAutomaticPunctuation": true,
543 },
544 "audio": {
545 "content": audio_content,
546 }
547 });
548
549 let url = format!(
550 "https://speech.googleapis.com/v1/speech:recognize?key={}",
551 self.api_key
552 );
553
554 let resp = client
555 .post(&url)
556 .json(&request_body)
557 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
558 .send()
559 .await
560 .context("Failed to send transcription request to Google STT")?;
561
562 let status = resp.status();
563 let body: serde_json::Value = resp
564 .json()
565 .await
566 .context("Failed to parse Google STT response")?;
567
568 if !status.is_success() {
569 let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error");
570 bail!("Google STT API error ({}): {}", status, error_msg);
571 }
572
573 let text = body["results"][0]["alternatives"][0]["transcript"]
574 .as_str()
575 .unwrap_or("")
576 .to_string();
577
578 Ok(text)
579 }
580}
581
582pub struct LocalWhisperProvider {
591 alias: String,
592 url: String,
593 bearer_token: String,
594 max_audio_bytes: usize,
595 timeout_secs: u64,
596}
597
598impl LocalWhisperProvider {
599 pub fn from_config(
603 alias: &str,
604 config: &zeroclaw_config::schema::LocalWhisperConfig,
605 ) -> Result<Self> {
606 let url = config.url.trim().to_string();
607 anyhow::ensure!(!url.is_empty(), "local_whisper: `url` must not be empty");
608 let parsed = url
609 .parse::<reqwest::Url>()
610 .with_context(|| format!("local_whisper: invalid `url`: {url:?}"))?;
611 anyhow::ensure!(
612 matches!(parsed.scheme(), "http" | "https"),
613 "local_whisper: `url` must use http or https scheme, got {:?}",
614 parsed.scheme()
615 );
616
617 let bearer_token = match config.bearer_token.as_deref().map(str::trim) {
618 None => anyhow::bail!("local_whisper: `bearer_token` must be set"),
619 Some("") => anyhow::bail!("local_whisper: `bearer_token` must not be empty"),
620 Some(t) => t.to_string(),
621 };
622
623 anyhow::ensure!(
624 config.max_audio_bytes > 0,
625 "local_whisper: `max_audio_bytes` must be greater than zero"
626 );
627
628 anyhow::ensure!(
629 config.timeout_secs > 0,
630 "local_whisper: `timeout_secs` must be greater than zero"
631 );
632
633 Ok(Self {
634 alias: alias.to_string(),
635 url,
636 bearer_token,
637 max_audio_bytes: config.max_audio_bytes,
638 timeout_secs: config.timeout_secs,
639 })
640 }
641}
642
643#[async_trait]
644impl TranscriptionProvider for LocalWhisperProvider {
645 fn name(&self) -> &str {
646 "local_whisper"
647 }
648
649 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
650 if audio_data.len() > self.max_audio_bytes {
651 bail!(
652 "Audio file too large ({} bytes, local_whisper max {})",
653 audio_data.len(),
654 self.max_audio_bytes
655 );
656 }
657
658 let (normalized_name, mime) = resolve_audio_format(file_name)?;
659
660 let client =
661 zeroclaw_config::schema::build_runtime_proxy_client("transcription.local_whisper");
662
663 let file_part = Part::bytes(audio_data.to_vec())
667 .file_name(normalized_name)
668 .mime_str(mime)?;
669
670 let resp = client
671 .post(&self.url)
672 .bearer_auth(&self.bearer_token)
673 .multipart(Form::new().part("file", file_part))
674 .timeout(std::time::Duration::from_secs(self.timeout_secs))
675 .send()
676 .await
677 .context("Failed to send audio to local Whisper endpoint")?;
678
679 parse_whisper_response(resp).await
680 }
681}
682
683async fn parse_whisper_response(resp: reqwest::Response) -> Result<String> {
691 let status = resp.status();
692 if !status.is_success() {
693 let body = resp.text().await.unwrap_or_default();
694 bail!("Transcription API error ({}): {}", status, body.trim());
695 }
696
697 let body: serde_json::Value = resp
698 .json()
699 .await
700 .context("Failed to parse transcription response")?;
701
702 let text = body["text"]
703 .as_str()
704 .context("Transcription response missing 'text' field")?
705 .to_string();
706
707 Ok(text)
708}
709
710pub struct TranscriptionManager {
717 transcription_providers: HashMap<String, Box<dyn TranscriptionProvider>>,
718 agent_transcription_provider: String,
721}
722
723impl TranscriptionManager {
724 pub fn new(config: &TranscriptionConfig) -> Result<Self> {
729 let mut transcription_providers: HashMap<String, Box<dyn TranscriptionProvider>> =
730 HashMap::new();
731
732 if let Ok(groq) = GroqProvider::from_config("groq", config) {
733 transcription_providers.insert("groq".to_string(), Box::new(groq));
734 }
735
736 if let Some(ref openai_cfg) = config.openai
737 && let Ok(p) = OpenAiWhisperProvider::from_config("openai", openai_cfg)
738 {
739 transcription_providers.insert("openai".to_string(), Box::new(p));
740 }
741
742 if let Some(ref deepgram_cfg) = config.deepgram
743 && let Ok(p) = DeepgramProvider::from_config("deepgram", deepgram_cfg)
744 {
745 transcription_providers.insert("deepgram".to_string(), Box::new(p));
746 }
747
748 if let Some(ref assemblyai_cfg) = config.assemblyai
749 && let Ok(p) = AssemblyAiProvider::from_config("assemblyai", assemblyai_cfg)
750 {
751 transcription_providers.insert("assemblyai".to_string(), Box::new(p));
752 }
753
754 if let Some(ref google_cfg) = config.google
755 && let Ok(p) = GoogleSttProvider::from_config("google", google_cfg)
756 {
757 transcription_providers.insert("google".to_string(), Box::new(p));
758 }
759
760 if let Some(ref local_cfg) = config.local_whisper {
761 match LocalWhisperProvider::from_config("local_whisper", local_cfg) {
762 Ok(p) => {
763 transcription_providers.insert("local_whisper".to_string(), Box::new(p));
764 }
765 Err(e) => {
766 ::zeroclaw_log::record!(
767 WARN,
768 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
769 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
770 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
771 "local_whisper config invalid, provider skipped"
772 );
773 }
774 }
775 }
776
777 if config.enabled && transcription_providers.is_empty() {
778 bail!(
779 "Transcription is enabled but no transcription provider registered \
780 successfully. Configure at least one of: [transcription] (Groq) \
781 with api_key + api_url; [transcription.openai]; [transcription.deepgram]; \
782 [transcription.assemblyai]; [transcription.google]; [transcription.local_whisper]."
783 );
784 }
785
786 Ok(Self {
787 transcription_providers,
788 agent_transcription_provider: String::new(),
789 })
790 }
791
792 #[must_use]
796 pub fn with_agent_transcription_provider(mut self, alias: impl Into<String>) -> Self {
797 self.agent_transcription_provider = alias.into();
798 self
799 }
800
801 pub async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
805 let provider_alias = self.agent_transcription_provider.as_str();
806 if provider_alias.is_empty() {
807 bail!(
808 "Agent has no transcription_provider configured. Set \
809 `agent.<alias>.transcription_provider = \"<type>.<alias>\"` \
810 referencing a configured transcription provider."
811 );
812 }
813 self.transcribe_with_provider(audio_data, file_name, provider_alias)
814 .await
815 }
816
817 pub async fn transcribe_with_provider(
819 &self,
820 audio_data: &[u8],
821 file_name: &str,
822 transcription_provider: &str,
823 ) -> Result<String> {
824 let p = self.transcription_providers.get(transcription_provider).ok_or_else(|| {
825 let available: Vec<&str> = self.transcription_providers.keys().map(|k| k.as_str()).collect();
826 ::zeroclaw_log::record!(
827 ERROR,
828 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
829 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
830 .with_attrs(::serde_json::json!({
831 "transcription_provider": transcription_provider,
832 "available": available,
833 })),
834 "transcription: provider not configured"
835 );
836 anyhow::Error::msg(format!(
837 "Transcription transcription_provider '{transcription_provider}' not configured. Available: {available:?}"
838 ))
839 })?;
840
841 use ::zeroclaw_log::Instrument;
842 let span = ::zeroclaw_log::attribution_span!(p.as_ref());
843 p.transcribe(audio_data, file_name).instrument(span).await
844 }
845
846 pub fn available_providers(&self) -> Vec<&str> {
848 self.transcription_providers
849 .keys()
850 .map(|k| k.as_str())
851 .collect()
852 }
853}
854
855impl ::zeroclaw_api::attribution::Attributable for GroqProvider {
862 fn role(&self) -> ::zeroclaw_api::attribution::Role {
863 ::zeroclaw_api::attribution::Role::Provider(
864 ::zeroclaw_api::attribution::ProviderKind::Transcription(
865 ::zeroclaw_api::attribution::TranscriptionProviderKind::Groq,
866 ),
867 )
868 }
869 fn alias(&self) -> &str {
870 &self.alias
871 }
872}
873
874impl ::zeroclaw_api::attribution::Attributable for OpenAiWhisperProvider {
875 fn role(&self) -> ::zeroclaw_api::attribution::Role {
876 ::zeroclaw_api::attribution::Role::Provider(
877 ::zeroclaw_api::attribution::ProviderKind::Transcription(
878 ::zeroclaw_api::attribution::TranscriptionProviderKind::OpenAi,
879 ),
880 )
881 }
882 fn alias(&self) -> &str {
883 &self.alias
884 }
885}
886
887impl ::zeroclaw_api::attribution::Attributable for DeepgramProvider {
888 fn role(&self) -> ::zeroclaw_api::attribution::Role {
889 ::zeroclaw_api::attribution::Role::Provider(
890 ::zeroclaw_api::attribution::ProviderKind::Transcription(
891 ::zeroclaw_api::attribution::TranscriptionProviderKind::Deepgram,
892 ),
893 )
894 }
895 fn alias(&self) -> &str {
896 &self.alias
897 }
898}
899
900impl ::zeroclaw_api::attribution::Attributable for AssemblyAiProvider {
901 fn role(&self) -> ::zeroclaw_api::attribution::Role {
902 ::zeroclaw_api::attribution::Role::Provider(
903 ::zeroclaw_api::attribution::ProviderKind::Transcription(
904 ::zeroclaw_api::attribution::TranscriptionProviderKind::AssemblyAi,
905 ),
906 )
907 }
908 fn alias(&self) -> &str {
909 &self.alias
910 }
911}
912
913impl ::zeroclaw_api::attribution::Attributable for GoogleSttProvider {
914 fn role(&self) -> ::zeroclaw_api::attribution::Role {
915 ::zeroclaw_api::attribution::Role::Provider(
916 ::zeroclaw_api::attribution::ProviderKind::Transcription(
917 ::zeroclaw_api::attribution::TranscriptionProviderKind::Google,
918 ),
919 )
920 }
921 fn alias(&self) -> &str {
922 &self.alias
923 }
924}
925
926impl ::zeroclaw_api::attribution::Attributable for LocalWhisperProvider {
927 fn role(&self) -> ::zeroclaw_api::attribution::Role {
928 ::zeroclaw_api::attribution::Role::Provider(
929 ::zeroclaw_api::attribution::ProviderKind::Transcription(
930 ::zeroclaw_api::attribution::TranscriptionProviderKind::Whisper,
931 ),
932 )
933 }
934 fn alias(&self) -> &str {
935 &self.alias
936 }
937}
938
939#[cfg(test)]
940mod tests {
941 use super::*;
942
943 #[test]
949 fn mime_for_audio_maps_accepted_formats() {
950 let cases = [
951 ("flac", "audio/flac"),
952 ("mp3", "audio/mpeg"),
953 ("mpeg", "audio/mpeg"),
954 ("mpga", "audio/mpeg"),
955 ("mp4", "audio/mp4"),
956 ("m4a", "audio/mp4"),
957 ("ogg", "audio/ogg"),
958 ("oga", "audio/ogg"),
959 ("opus", "audio/opus"),
960 ("wav", "audio/wav"),
961 ("webm", "audio/webm"),
962 ];
963 for (ext, expected) in cases {
964 assert_eq!(
965 mime_for_audio(ext),
966 Some(expected),
967 "failed for extension: {ext}"
968 );
969 }
970 }
971
972 #[test]
973 fn mime_for_audio_case_insensitive() {
974 assert_eq!(mime_for_audio("OGG"), Some("audio/ogg"));
975 assert_eq!(mime_for_audio("MP3"), Some("audio/mpeg"));
976 assert_eq!(mime_for_audio("Opus"), Some("audio/opus"));
977 }
978
979 #[test]
980 fn mime_for_audio_rejects_unknown() {
981 assert_eq!(mime_for_audio("txt"), None);
982 assert_eq!(mime_for_audio("pdf"), None);
983 assert_eq!(mime_for_audio("aac"), None);
984 assert_eq!(mime_for_audio(""), None);
985 }
986
987 #[test]
988 fn normalize_audio_filename_rewrites_oga() {
989 assert_eq!(normalize_audio_filename("voice.oga"), "voice.ogg");
990 assert_eq!(normalize_audio_filename("file.OGA"), "file.ogg");
991 }
992
993 #[test]
994 fn normalize_audio_filename_preserves_accepted() {
995 assert_eq!(normalize_audio_filename("voice.ogg"), "voice.ogg");
996 assert_eq!(normalize_audio_filename("track.mp3"), "track.mp3");
997 assert_eq!(normalize_audio_filename("clip.opus"), "clip.opus");
998 }
999
1000 #[test]
1001 fn normalize_audio_filename_no_extension() {
1002 assert_eq!(normalize_audio_filename("voice"), "voice");
1003 }
1004
1005 #[test]
1006 fn rejects_unsupported_audio_format() {
1007 let data = vec![0u8; 100];
1010 let err = validate_audio(&data, "recording.aac").unwrap_err();
1011 let msg = err.to_string();
1012 assert!(
1013 msg.contains("Unsupported audio format"),
1014 "expected unsupported-format error, got: {msg}"
1015 );
1016 assert!(
1017 msg.contains(".aac"),
1018 "error should mention the rejected extension, got: {msg}"
1019 );
1020 }
1021
1022 #[test]
1025 fn manager_creation_with_default_config() {
1026 unsafe { std::env::remove_var("GROQ_API_KEY") };
1028
1029 let config = TranscriptionConfig::default();
1030 let manager = TranscriptionManager::new(&config).unwrap();
1031 assert!(manager.agent_transcription_provider.is_empty());
1035 assert!(manager.transcription_providers.is_empty());
1037 }
1038
1039 #[test]
1040 fn manager_registers_groq_with_key() {
1041 unsafe { std::env::remove_var("GROQ_API_KEY") };
1043
1044 let config = TranscriptionConfig {
1045 api_key: Some("test-groq-key".to_string()),
1046 ..TranscriptionConfig::default()
1047 };
1048
1049 let manager = TranscriptionManager::new(&config).unwrap();
1050 assert!(manager.transcription_providers.contains_key("groq"));
1051 assert_eq!(manager.transcription_providers["groq"].name(), "groq");
1052 }
1053
1054 #[test]
1055 fn manager_registers_multiple_providers() {
1056 unsafe { std::env::remove_var("GROQ_API_KEY") };
1058
1059 let config = TranscriptionConfig {
1060 api_key: Some("test-groq-key".to_string()),
1061 openai: Some(zeroclaw_config::schema::OpenAiSttConfig {
1062 api_key: Some("test-openai-key".to_string()),
1063 model: "whisper-1".to_string(),
1064 }),
1065 deepgram: Some(zeroclaw_config::schema::DeepgramSttConfig {
1066 api_key: Some("test-deepgram-key".to_string()),
1067 model: "nova-2".to_string(),
1068 }),
1069 ..TranscriptionConfig::default()
1070 };
1071
1072 let manager = TranscriptionManager::new(&config).unwrap();
1073 assert!(manager.transcription_providers.contains_key("groq"));
1074 assert!(manager.transcription_providers.contains_key("openai"));
1075 assert!(manager.transcription_providers.contains_key("deepgram"));
1076 assert_eq!(manager.available_providers().len(), 3);
1077 }
1078
1079 #[tokio::test]
1080 async fn manager_rejects_unconfigured_provider() {
1081 unsafe { std::env::remove_var("GROQ_API_KEY") };
1083
1084 let config = TranscriptionConfig {
1085 api_key: Some("test-groq-key".to_string()),
1086 ..TranscriptionConfig::default()
1087 };
1088
1089 let manager = TranscriptionManager::new(&config).unwrap();
1090 let err = manager
1091 .transcribe_with_provider(&[0u8; 100], "test.ogg", "nonexistent")
1092 .await
1093 .unwrap_err();
1094 assert!(
1095 err.to_string().contains("not configured"),
1096 "expected not-configured error, got: {err}"
1097 );
1098 }
1099
1100 #[test]
1101 fn manager_agent_transcription_provider_via_setter() {
1102 unsafe { std::env::remove_var("GROQ_API_KEY") };
1104
1105 let config = TranscriptionConfig {
1106 openai: Some(zeroclaw_config::schema::OpenAiSttConfig {
1107 api_key: Some("test-openai-key".to_string()),
1108 model: "whisper-1".to_string(),
1109 }),
1110 ..TranscriptionConfig::default()
1111 };
1112
1113 let manager = TranscriptionManager::new(&config)
1114 .unwrap()
1115 .with_agent_transcription_provider("openai");
1116 assert_eq!(manager.agent_transcription_provider, "openai");
1117 }
1118
1119 #[test]
1120 fn validate_audio_rejects_oversized() {
1121 let big = vec![0u8; MAX_AUDIO_BYTES + 1];
1122 let err = validate_audio(&big, "test.ogg").unwrap_err();
1123 assert!(err.to_string().contains("too large"));
1124 }
1125
1126 #[test]
1127 fn validate_audio_rejects_unsupported_format() {
1128 let data = vec![0u8; 100];
1129 let err = validate_audio(&data, "test.aac").unwrap_err();
1130 assert!(err.to_string().contains("Unsupported audio format"));
1131 }
1132
1133 #[test]
1134 fn validate_audio_accepts_supported_format() {
1135 let data = vec![0u8; 100];
1136 let (name, mime) = validate_audio(&data, "test.ogg").unwrap();
1137 assert_eq!(name, "test.ogg");
1138 assert_eq!(mime, "audio/ogg");
1139 }
1140
1141 #[test]
1142 fn validate_audio_normalizes_oga() {
1143 let data = vec![0u8; 100];
1144 let (name, mime) = validate_audio(&data, "voice.oga").unwrap();
1145 assert_eq!(name, "voice.ogg");
1146 assert_eq!(mime, "audio/ogg");
1147 }
1148
1149 #[test]
1150 fn backward_compat_config_defaults_unchanged() {
1151 let config = TranscriptionConfig::default();
1152 assert!(!config.enabled);
1153 assert!(config.api_key.is_none());
1154 assert!(config.api_url.contains("groq.com"));
1155 assert_eq!(config.model, "whisper-large-v3-turbo");
1156 assert!(config.openai.is_none());
1159 assert!(config.deepgram.is_none());
1160 assert!(config.assemblyai.is_none());
1161 assert!(config.google.is_none());
1162 assert!(config.local_whisper.is_none());
1163 assert!(!config.transcribe_non_ptt_audio);
1164 }
1165
1166 fn local_whisper_config(url: &str) -> zeroclaw_config::schema::LocalWhisperConfig {
1169 zeroclaw_config::schema::LocalWhisperConfig {
1170 url: url.to_string(),
1171 bearer_token: Some("test-token".to_string()),
1172 max_audio_bytes: 10 * 1024 * 1024,
1173 timeout_secs: 30,
1174 }
1175 }
1176
1177 #[test]
1178 fn local_whisper_rejects_empty_url() {
1179 let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1180 cfg.url = String::new();
1181 let err = LocalWhisperProvider::from_config("local_whisper", &cfg)
1182 .err()
1183 .unwrap();
1184 assert!(
1185 err.to_string().contains("`url` must not be empty"),
1186 "got: {err}"
1187 );
1188 }
1189
#[test]
fn local_whisper_rejects_invalid_url() {
    // A string that does not parse as a URL must be refused.
    let cfg = zeroclaw_config::schema::LocalWhisperConfig {
        url: "not-a-url".to_string(),
        ..local_whisper_config("http://127.0.0.1:9999/v1/transcribe")
    };
    let outcome = LocalWhisperProvider::from_config("local_whisper", &cfg);
    let err = outcome.err().unwrap();
    let message = err.to_string();
    assert!(message.contains("invalid `url`"), "got: {err}");
}
1199
#[test]
fn local_whisper_rejects_non_http_url() {
    // A well-formed URL with a non-HTTP(S) scheme must still be refused.
    let cfg = zeroclaw_config::schema::LocalWhisperConfig {
        url: "ftp://10.10.0.1:8001/v1/transcribe".to_string(),
        ..local_whisper_config("http://127.0.0.1:9999/v1/transcribe")
    };
    let outcome = LocalWhisperProvider::from_config("local_whisper", &cfg);
    let err = outcome.err().unwrap();
    let message = err.to_string();
    assert!(message.contains("http or https"), "got: {err}");
}
1209
#[test]
fn local_whisper_rejects_empty_bearer_token() {
    // A present-but-empty bearer token is refused.
    let cfg = zeroclaw_config::schema::LocalWhisperConfig {
        bearer_token: Some(String::new()),
        ..local_whisper_config("http://127.0.0.1:9999/v1/transcribe")
    };
    let outcome = LocalWhisperProvider::from_config("local_whisper", &cfg);
    let err = outcome.err().unwrap();
    let message = err.to_string();
    assert!(
        message.contains("`bearer_token` must not be empty"),
        "got: {err}"
    );
}
1222
#[test]
fn local_whisper_rejects_missing_bearer_token() {
    // Omitting the bearer token entirely is also refused.
    let cfg = zeroclaw_config::schema::LocalWhisperConfig {
        bearer_token: None,
        ..local_whisper_config("http://127.0.0.1:9999/v1/transcribe")
    };
    let outcome = LocalWhisperProvider::from_config("local_whisper", &cfg);
    let err = outcome.err().unwrap();
    let message = err.to_string();
    assert!(message.contains("`bearer_token` must be set"), "got: {err}");
}
1235
#[test]
fn local_whisper_rejects_zero_max_audio_bytes() {
    // A zero upload limit is refused.
    let cfg = zeroclaw_config::schema::LocalWhisperConfig {
        max_audio_bytes: 0,
        ..local_whisper_config("http://127.0.0.1:9999/v1/transcribe")
    };
    let outcome = LocalWhisperProvider::from_config("local_whisper", &cfg);
    let err = outcome.err().unwrap();
    let message = err.to_string();
    assert!(
        message.contains("`max_audio_bytes` must be greater than zero"),
        "got: {err}"
    );
}
1249
#[test]
fn local_whisper_rejects_zero_timeout() {
    // A zero request timeout is refused.
    let cfg = zeroclaw_config::schema::LocalWhisperConfig {
        timeout_secs: 0,
        ..local_whisper_config("http://127.0.0.1:9999/v1/transcribe")
    };
    let outcome = LocalWhisperProvider::from_config("local_whisper", &cfg);
    let err = outcome.err().unwrap();
    let message = err.to_string();
    assert!(
        message.contains("`timeout_secs` must be greater than zero"),
        "got: {err}"
    );
}
1263
#[test]
fn local_whisper_registered_when_config_present() {
    // A valid local_whisper section alone is enough for the manager to list it.
    let whisper = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
    let config = TranscriptionConfig {
        local_whisper: Some(whisper),
        ..TranscriptionConfig::default()
    };

    let manager = TranscriptionManager::new(&config).unwrap();
    let providers = manager.available_providers();
    assert!(
        providers.contains(&"local_whisper"),
        "expected local_whisper in {:?}",
        providers
    );
}
1278
#[test]
fn local_whisper_misconfigured_section_fails_manager_construction() {
    // An empty bearer token makes the local_whisper section invalid. With
    // transcription enabled, the test expects manager construction to fail
    // with the "no provider registered" error rather than succeed silently.
    let broken = zeroclaw_config::schema::LocalWhisperConfig {
        bearer_token: Some(String::new()),
        ..local_whisper_config("http://127.0.0.1:9999/v1/transcribe")
    };
    let config = TranscriptionConfig {
        local_whisper: Some(broken),
        enabled: true,
        ..TranscriptionConfig::default()
    };

    let err = TranscriptionManager::new(&config).err().unwrap();
    let rendered = err.to_string();
    assert!(
        rendered.contains("no transcription provider registered"),
        "expected 'no transcription provider registered' from manager safety net, got: {err}"
    );
}
1300
#[test]
fn validate_audio_still_enforces_25mb_cap() {
    // Pin the constant itself. Without this, the test name's "25mb" claim is
    // not checked and the test would keep passing if the cap were changed.
    assert_eq!(MAX_AUDIO_BYTES, 25 * 1024 * 1024);

    // Exactly at the cap must be accepted. The buffer is scoped so it is freed
    // before the over-limit buffer is allocated, halving peak memory.
    {
        let at_limit = vec![0u8; MAX_AUDIO_BYTES];
        let (name, mime) = validate_audio(&at_limit, "test.ogg")
            .expect("audio exactly at MAX_AUDIO_BYTES must be accepted");
        assert_eq!((name.as_str(), mime), ("test.ogg", "audio/ogg"));
    }

    // One byte over the cap must be rejected with the size error.
    let over_limit = vec![0u8; MAX_AUDIO_BYTES + 1];
    let err = validate_audio(&over_limit, "test.ogg").unwrap_err();
    assert!(err.to_string().contains("too large"), "got: {err}");
}
1310
1311 #[tokio::test]
1312 async fn local_whisper_rejects_oversized_audio() {
1313 let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1314 let transcription_provider =
1315 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1316 let big = vec![0u8; cfg.max_audio_bytes + 1];
1317 let err = transcription_provider
1318 .transcribe(&big, "voice.ogg")
1319 .await
1320 .unwrap_err();
1321 assert!(err.to_string().contains("too large"), "got: {err}");
1322 }
1323
1324 #[tokio::test]
1325 async fn local_whisper_rejects_unsupported_format() {
1326 let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1327 let transcription_provider =
1328 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1329 let data = vec![0u8; 100];
1330 let err = transcription_provider
1331 .transcribe(&data, "voice.aiff")
1332 .await
1333 .unwrap_err();
1334 assert!(
1335 err.to_string().contains("Unsupported audio format"),
1336 "got: {err}"
1337 );
1338 }
1339
1340 #[tokio::test]
1343 async fn local_whisper_returns_text_from_response() {
1344 use wiremock::matchers::{header_exists, method, path};
1345 use wiremock::{Mock, MockServer, ResponseTemplate};
1346
1347 let server = MockServer::start().await;
1348
1349 Mock::given(method("POST"))
1350 .and(path("/v1/transcribe"))
1351 .and(header_exists("authorization"))
1352 .respond_with(
1353 ResponseTemplate::new(200)
1354 .set_body_json(serde_json::json!({"text": "hello world"})),
1355 )
1356 .mount(&server)
1357 .await;
1358
1359 let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
1360 let transcription_provider =
1361 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1362
1363 let result = transcription_provider
1364 .transcribe(b"fake-audio", "voice.ogg")
1365 .await
1366 .unwrap();
1367 assert_eq!(result, "hello world");
1368 }
1369
1370 #[tokio::test]
1371 async fn local_whisper_sends_bearer_auth_header() {
1372 use wiremock::matchers::{header, method, path};
1373 use wiremock::{Mock, MockServer, ResponseTemplate};
1374
1375 let server = MockServer::start().await;
1376
1377 Mock::given(method("POST"))
1378 .and(path("/v1/transcribe"))
1379 .and(header("authorization", "Bearer test-token"))
1380 .respond_with(
1381 ResponseTemplate::new(200).set_body_json(serde_json::json!({"text": "auth ok"})),
1382 )
1383 .mount(&server)
1384 .await;
1385
1386 let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
1387 let transcription_provider =
1388 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1389
1390 let result = transcription_provider
1391 .transcribe(b"fake-audio", "voice.ogg")
1392 .await
1393 .unwrap();
1394 assert_eq!(result, "auth ok");
1395 }
1396
1397 #[tokio::test]
1398 async fn local_whisper_propagates_http_error() {
1399 use wiremock::matchers::{method, path};
1400 use wiremock::{Mock, MockServer, ResponseTemplate};
1401
1402 let server = MockServer::start().await;
1403
1404 Mock::given(method("POST"))
1405 .and(path("/v1/transcribe"))
1406 .respond_with(
1407 ResponseTemplate::new(503).set_body_json(
1408 serde_json::json!({"error": {"message": "service unavailable"}}),
1409 ),
1410 )
1411 .mount(&server)
1412 .await;
1413
1414 let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
1415 let transcription_provider =
1416 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1417
1418 let err = transcription_provider
1419 .transcribe(b"fake-audio", "voice.ogg")
1420 .await
1421 .unwrap_err();
1422 assert!(
1423 err.to_string().contains("503") || err.to_string().contains("service unavailable"),
1424 "expected HTTP error, got: {err}"
1425 );
1426 }
1427
1428 #[tokio::test]
1429 async fn local_whisper_propagates_non_json_http_error() {
1430 use wiremock::matchers::{method, path};
1431 use wiremock::{Mock, MockServer, ResponseTemplate};
1432
1433 let server = MockServer::start().await;
1434
1435 Mock::given(method("POST"))
1436 .and(path("/v1/transcribe"))
1437 .respond_with(
1438 ResponseTemplate::new(502)
1439 .set_body_string("Bad Gateway")
1440 .insert_header("content-type", "text/plain"),
1441 )
1442 .mount(&server)
1443 .await;
1444
1445 let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
1446 let transcription_provider =
1447 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1448
1449 let err = transcription_provider
1450 .transcribe(b"fake-audio", "voice.ogg")
1451 .await
1452 .unwrap_err();
1453 assert!(err.to_string().contains("502"), "got: {err}");
1454 assert!(
1455 err.to_string().contains("Bad Gateway"),
1456 "expected plain-text body in error, got: {err}"
1457 );
1458 }
1459}