1use std::collections::HashMap;
2
3use anyhow::{Context, Result, bail};
4use async_trait::async_trait;
5use reqwest::multipart::{Form, Part};
6
7use zeroclaw_config::schema::TranscriptionConfig;
8
/// Largest audio payload (25 MiB) accepted by the hosted providers, which
/// enforce it through `validate_audio`.
const MAX_AUDIO_BYTES: usize = 25 << 20;

/// Per-request HTTP timeout, in seconds, used by the hosted providers.
const TRANSCRIPTION_TIMEOUT_SECS: u64 = 2 * 60;
14
/// Maps an audio file extension (without the leading dot, matched
/// case-insensitively) to the MIME type sent with the upload.
///
/// Returns `None` for extensions that are not accepted.
fn mime_for_audio(extension: &str) -> Option<&'static str> {
    // Each row lists the extensions that share one MIME type.
    const TABLE: &[(&[&str], &str)] = &[
        (&["flac"], "audio/flac"),
        (&["mp3", "mpeg", "mpga"], "audio/mpeg"),
        (&["mp4", "m4a"], "audio/mp4"),
        (&["ogg", "oga"], "audio/ogg"),
        (&["opus"], "audio/opus"),
        (&["wav"], "audio/wav"),
        (&["webm"], "audio/webm"),
    ];

    TABLE
        .iter()
        .find(|(exts, _)| exts.iter().any(|e| e.eq_ignore_ascii_case(extension)))
        .map(|&(_, mime)| mime)
}
30
/// Rewrites a trailing `.oga` extension (any ASCII case) to `.ogg`.
///
/// Every other name, including names without a dot, is returned unchanged.
fn normalize_audio_filename(file_name: &str) -> String {
    file_name
        .rsplit_once('.')
        .filter(|(_, ext)| ext.eq_ignore_ascii_case("oga"))
        .map_or_else(|| file_name.to_owned(), |(stem, _)| format!("{stem}.ogg"))
}
41
42fn resolve_audio_format(file_name: &str) -> Result<(String, &'static str)> {
46 let normalized_name = normalize_audio_filename(file_name);
47 let extension = normalized_name
48 .rsplit_once('.')
49 .map(|(_, e)| e)
50 .unwrap_or("");
51 let mime = mime_for_audio(extension).ok_or_else(|| {
52 ::zeroclaw_log::record!(
53 WARN,
54 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
55 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
56 .with_attrs(::serde_json::json!({"extension": extension})),
57 "transcription: unsupported audio format"
58 );
59 anyhow::Error::msg(format!(
60 "Unsupported audio format '.{extension}'. \
61 accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
62 ))
63 })?;
64 Ok((normalized_name, mime))
65}
66
67fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
71 if audio_data.len() > MAX_AUDIO_BYTES {
72 bail!(
73 "Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
74 audio_data.len()
75 );
76 }
77 resolve_audio_format(file_name)
78}
79
80#[async_trait]
84pub trait TranscriptionProvider: Send + Sync + ::zeroclaw_api::attribution::Attributable {
85 fn name(&self) -> &str;
87
88 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String>;
91
92 fn supported_formats(&self) -> Vec<String> {
94 vec![
95 "flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
96 ]
97 .into_iter()
98 .map(String::from)
99 .collect()
100 }
101}
102
103pub struct GroqProvider {
107 alias: String,
108 api_url: String,
109 model: String,
110 api_key: String,
111 language: Option<String>,
112}
113
114impl GroqProvider {
115 pub fn from_config(alias: &str, config: &TranscriptionConfig) -> Result<Self> {
122 let api_key = config
123 .api_key
124 .as_deref()
125 .map(str::trim)
126 .filter(|v| !v.is_empty())
127 .map(ToOwned::to_owned)
128 .context(
129 "Missing transcription API key: set `[transcription].api_key` (or via the \
130 schema-mirror grammar `ZEROCLAW_transcription__api_key=...`).",
131 )?;
132
133 Ok(Self {
134 alias: alias.to_string(),
135 api_url: config.api_url.clone(),
136 model: config.model.clone(),
137 api_key,
138 language: config.language.clone(),
139 })
140 }
141}
142
143#[async_trait]
144impl TranscriptionProvider for GroqProvider {
145 fn name(&self) -> &str {
146 "groq"
147 }
148
149 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
150 let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
151
152 let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.groq");
153
154 let file_part = Part::bytes(audio_data.to_vec())
155 .file_name(normalized_name)
156 .mime_str(mime)?;
157
158 let mut form = Form::new()
159 .part("file", file_part)
160 .text("model", self.model.clone())
161 .text("response_format", "json");
162
163 if let Some(ref lang) = self.language {
164 form = form.text("language", lang.clone());
165 }
166
167 let resp = client
168 .post(&self.api_url)
169 .bearer_auth(&self.api_key)
170 .multipart(form)
171 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
172 .send()
173 .await
174 .context("Failed to send transcription request to Groq")?;
175
176 parse_whisper_response(resp).await
177 }
178}
179
180pub struct OpenAiWhisperProvider {
184 alias: String,
185 api_key: String,
186 model: String,
187}
188
189impl OpenAiWhisperProvider {
190 pub fn from_config(
191 alias: &str,
192 config: &zeroclaw_config::schema::OpenAiSttConfig,
193 ) -> Result<Self> {
194 let api_key = config
195 .api_key
196 .as_deref()
197 .map(str::trim)
198 .filter(|v| !v.is_empty())
199 .map(ToOwned::to_owned)
200 .context("Missing OpenAI STT API key: set [transcription.openai].api_key")?;
201
202 Ok(Self {
203 alias: alias.to_string(),
204 api_key,
205 model: config.model.clone(),
206 })
207 }
208}
209
210#[async_trait]
211impl TranscriptionProvider for OpenAiWhisperProvider {
212 fn name(&self) -> &str {
213 "openai"
214 }
215
216 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
217 let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
218
219 let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.openai");
220
221 let file_part = Part::bytes(audio_data.to_vec())
222 .file_name(normalized_name)
223 .mime_str(mime)?;
224
225 let form = Form::new()
226 .part("file", file_part)
227 .text("model", self.model.clone())
228 .text("response_format", "json");
229
230 let resp = client
231 .post("https://api.openai.com/v1/audio/transcriptions")
232 .bearer_auth(&self.api_key)
233 .multipart(form)
234 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
235 .send()
236 .await
237 .context("Failed to send transcription request to OpenAI")?;
238
239 parse_whisper_response(resp).await
240 }
241}
242
243pub struct DeepgramProvider {
247 alias: String,
248 api_key: String,
249 model: String,
250}
251
252impl DeepgramProvider {
253 pub fn from_config(
254 alias: &str,
255 config: &zeroclaw_config::schema::DeepgramSttConfig,
256 ) -> Result<Self> {
257 let api_key = config
258 .api_key
259 .as_deref()
260 .map(str::trim)
261 .filter(|v| !v.is_empty())
262 .map(ToOwned::to_owned)
263 .context("Missing Deepgram API key: set [transcription.deepgram].api_key")?;
264
265 Ok(Self {
266 alias: alias.to_string(),
267 api_key,
268 model: config.model.clone(),
269 })
270 }
271}
272
273#[async_trait]
274impl TranscriptionProvider for DeepgramProvider {
275 fn name(&self) -> &str {
276 "deepgram"
277 }
278
279 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
280 let (_, mime) = validate_audio(audio_data, file_name)?;
281
282 let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.deepgram");
283
284 let url = format!(
285 "https://api.deepgram.com/v1/listen?model={}&punctuate=true",
286 self.model
287 );
288
289 let resp = client
290 .post(&url)
291 .header("Authorization", format!("Token {}", self.api_key))
292 .header("Content-Type", mime)
293 .body(audio_data.to_vec())
294 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
295 .send()
296 .await
297 .context("Failed to send transcription request to Deepgram")?;
298
299 let status = resp.status();
300 let body: serde_json::Value = resp
301 .json()
302 .await
303 .context("Failed to parse Deepgram response")?;
304
305 if !status.is_success() {
306 let error_msg = body["err_msg"]
307 .as_str()
308 .or_else(|| body["error"].as_str())
309 .unwrap_or("unknown error");
310 bail!("Deepgram API error ({}): {}", status, error_msg);
311 }
312
313 let text = body["results"]["channels"][0]["alternatives"][0]["transcript"]
314 .as_str()
315 .context("Deepgram response missing transcript field")?
316 .to_string();
317
318 Ok(text)
319 }
320}
321
322pub struct AssemblyAiProvider {
326 alias: String,
327 api_key: String,
328}
329
330impl AssemblyAiProvider {
331 pub fn from_config(
332 alias: &str,
333 config: &zeroclaw_config::schema::AssemblyAiSttConfig,
334 ) -> Result<Self> {
335 let api_key = config
336 .api_key
337 .as_deref()
338 .map(str::trim)
339 .filter(|v| !v.is_empty())
340 .map(ToOwned::to_owned)
341 .context("Missing AssemblyAI API key: set [transcription.assemblyai].api_key")?;
342
343 Ok(Self {
344 alias: alias.to_string(),
345 api_key,
346 })
347 }
348}
349
350#[async_trait]
351impl TranscriptionProvider for AssemblyAiProvider {
352 fn name(&self) -> &str {
353 "assemblyai"
354 }
355
356 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
357 let (_, _) = validate_audio(audio_data, file_name)?;
358
359 let client =
360 zeroclaw_config::schema::build_runtime_proxy_client("transcription.assemblyai");
361
362 let upload_resp = client
364 .post("https://api.assemblyai.com/v2/upload")
365 .header("Authorization", &self.api_key)
366 .header("Content-Type", "application/octet-stream")
367 .body(audio_data.to_vec())
368 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
369 .send()
370 .await
371 .context("Failed to upload audio to AssemblyAI")?;
372
373 let upload_status = upload_resp.status();
374 let upload_body: serde_json::Value = upload_resp
375 .json()
376 .await
377 .context("Failed to parse AssemblyAI upload response")?;
378
379 if !upload_status.is_success() {
380 let error_msg = upload_body["error"].as_str().unwrap_or("unknown error");
381 bail!("AssemblyAI upload error ({}): {}", upload_status, error_msg);
382 }
383
384 let upload_url = upload_body["upload_url"]
385 .as_str()
386 .context("AssemblyAI upload response missing 'upload_url'")?;
387
388 let transcript_req = serde_json::json!({
390 "audio_url": upload_url,
391 });
392
393 let create_resp = client
394 .post("https://api.assemblyai.com/v2/transcript")
395 .header("Authorization", &self.api_key)
396 .json(&transcript_req)
397 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
398 .send()
399 .await
400 .context("Failed to create AssemblyAI transcription")?;
401
402 let create_status = create_resp.status();
403 let create_body: serde_json::Value = create_resp
404 .json()
405 .await
406 .context("Failed to parse AssemblyAI create response")?;
407
408 if !create_status.is_success() {
409 let error_msg = create_body["error"].as_str().unwrap_or("unknown error");
410 bail!(
411 "AssemblyAI transcription error ({}): {}",
412 create_status,
413 error_msg
414 );
415 }
416
417 let transcript_id = create_body["id"]
418 .as_str()
419 .context("AssemblyAI response missing 'id'")?;
420
421 let poll_url = format!("https://api.assemblyai.com/v2/transcript/{transcript_id}");
423 let poll_interval = std::time::Duration::from_secs(3);
424 let poll_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(180);
425
426 while tokio::time::Instant::now() < poll_deadline {
427 tokio::time::sleep(poll_interval).await;
428
429 let poll_resp = client
430 .get(&poll_url)
431 .header("Authorization", &self.api_key)
432 .timeout(std::time::Duration::from_secs(30))
433 .send()
434 .await
435 .context("Failed to poll AssemblyAI transcription")?;
436
437 let poll_status = poll_resp.status();
438 let poll_body: serde_json::Value = poll_resp
439 .json()
440 .await
441 .context("Failed to parse AssemblyAI poll response")?;
442
443 if !poll_status.is_success() {
444 let error_msg = poll_body["error"].as_str().unwrap_or("unknown poll error");
445 bail!("AssemblyAI poll error ({}): {}", poll_status, error_msg);
446 }
447
448 let status_str = poll_body["status"].as_str().unwrap_or("unknown");
449
450 match status_str {
451 "completed" => {
452 let text = poll_body["text"]
453 .as_str()
454 .context("AssemblyAI response missing 'text'")?
455 .to_string();
456 return Ok(text);
457 }
458 "error" => {
459 let error_msg = poll_body["error"]
460 .as_str()
461 .unwrap_or("unknown transcription error");
462 bail!("AssemblyAI transcription failed: {}", error_msg);
463 }
464 _ => {}
465 }
466 }
467
468 bail!("AssemblyAI transcription timed out after 180s")
469 }
470}
471
472pub struct GoogleSttProvider {
476 alias: String,
477 api_key: String,
478 language_code: String,
479}
480
481impl GoogleSttProvider {
482 pub fn from_config(
483 alias: &str,
484 config: &zeroclaw_config::schema::GoogleSttConfig,
485 ) -> Result<Self> {
486 let api_key = config
487 .api_key
488 .as_deref()
489 .map(str::trim)
490 .filter(|v| !v.is_empty())
491 .map(ToOwned::to_owned)
492 .context("Missing Google STT API key: set [transcription.google].api_key")?;
493
494 Ok(Self {
495 alias: alias.to_string(),
496 api_key,
497 language_code: config.language_code.clone(),
498 })
499 }
500}
501
502#[async_trait]
503impl TranscriptionProvider for GoogleSttProvider {
504 fn name(&self) -> &str {
505 "google"
506 }
507
508 fn supported_formats(&self) -> Vec<String> {
509 vec!["flac", "wav", "ogg", "opus", "mp3", "webm"]
511 .into_iter()
512 .map(String::from)
513 .collect()
514 }
515
516 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
517 let (normalized_name, _) = validate_audio(audio_data, file_name)?;
518
519 let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.google");
520
521 let encoding = match normalized_name
522 .rsplit_once('.')
523 .map(|(_, e)| e.to_ascii_lowercase())
524 .as_deref()
525 {
526 Some("flac") => "FLAC",
527 Some("wav") => "LINEAR16",
528 Some("ogg" | "opus") => "OGG_OPUS",
529 Some("mp3") => "MP3",
530 Some("webm") => "WEBM_OPUS",
531 Some(ext) => bail!("Google STT does not support '.{ext}' input"),
532 None => bail!("Google STT requires a file extension"),
533 };
534
535 let audio_content =
536 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, audio_data);
537
538 let request_body = serde_json::json!({
539 "config": {
540 "encoding": encoding,
541 "languageCode": &self.language_code,
542 "enableAutomaticPunctuation": true,
543 },
544 "audio": {
545 "content": audio_content,
546 }
547 });
548
549 let url = format!(
550 "https://speech.googleapis.com/v1/speech:recognize?key={}",
551 self.api_key
552 );
553
554 let resp = client
555 .post(&url)
556 .json(&request_body)
557 .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
558 .send()
559 .await
560 .context("Failed to send transcription request to Google STT")?;
561
562 let status = resp.status();
563 let body: serde_json::Value = resp
564 .json()
565 .await
566 .context("Failed to parse Google STT response")?;
567
568 if !status.is_success() {
569 let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error");
570 bail!("Google STT API error ({}): {}", status, error_msg);
571 }
572
573 let text = body["results"][0]["alternatives"][0]["transcript"]
574 .as_str()
575 .unwrap_or("")
576 .to_string();
577
578 Ok(text)
579 }
580}
581
582pub struct LocalWhisperProvider {
591 alias: String,
592 url: String,
593 bearer_token: String,
594 max_audio_bytes: usize,
595 timeout_secs: u64,
596}
597
598impl LocalWhisperProvider {
599 pub fn from_config(
603 alias: &str,
604 config: &zeroclaw_config::schema::LocalWhisperConfig,
605 ) -> Result<Self> {
606 let url = config.url.trim().to_string();
607 anyhow::ensure!(!url.is_empty(), "local_whisper: `url` must not be empty");
608 let parsed = url
609 .parse::<reqwest::Url>()
610 .with_context(|| format!("local_whisper: invalid `url`: {url:?}"))?;
611 anyhow::ensure!(
612 matches!(parsed.scheme(), "http" | "https"),
613 "local_whisper: `url` must use http or https scheme, got {:?}",
614 parsed.scheme()
615 );
616
617 let bearer_token = match config.bearer_token.as_deref().map(str::trim) {
618 None => anyhow::bail!("local_whisper: `bearer_token` must be set"),
619 Some("") => anyhow::bail!("local_whisper: `bearer_token` must not be empty"),
620 Some(t) => t.to_string(),
621 };
622
623 anyhow::ensure!(
624 config.max_audio_bytes > 0,
625 "local_whisper: `max_audio_bytes` must be greater than zero"
626 );
627
628 anyhow::ensure!(
629 config.timeout_secs > 0,
630 "local_whisper: `timeout_secs` must be greater than zero"
631 );
632
633 Ok(Self {
634 alias: alias.to_string(),
635 url,
636 bearer_token,
637 max_audio_bytes: config.max_audio_bytes,
638 timeout_secs: config.timeout_secs,
639 })
640 }
641}
642
643#[async_trait]
644impl TranscriptionProvider for LocalWhisperProvider {
645 fn name(&self) -> &str {
646 "local_whisper"
647 }
648
649 async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
650 if audio_data.len() > self.max_audio_bytes {
651 bail!(
652 "Audio file too large ({} bytes, local_whisper max {})",
653 audio_data.len(),
654 self.max_audio_bytes
655 );
656 }
657
658 let (normalized_name, mime) = resolve_audio_format(file_name)?;
659
660 let client =
661 zeroclaw_config::schema::build_runtime_proxy_client("transcription.local_whisper");
662
663 let file_part = Part::bytes(audio_data.to_vec())
667 .file_name(normalized_name)
668 .mime_str(mime)?;
669
670 let resp = client
671 .post(&self.url)
672 .bearer_auth(&self.bearer_token)
673 .multipart(Form::new().part("file", file_part))
674 .timeout(std::time::Duration::from_secs(self.timeout_secs))
675 .send()
676 .await
677 .context("Failed to send audio to local Whisper endpoint")?;
678
679 parse_whisper_response(resp).await
680 }
681}
682
683async fn parse_whisper_response(resp: reqwest::Response) -> Result<String> {
691 let status = resp.status();
692 if !status.is_success() {
693 let body = resp.text().await.unwrap_or_default();
694 bail!("Transcription API error ({}): {}", status, body.trim());
695 }
696
697 let body: serde_json::Value = resp
698 .json()
699 .await
700 .context("Failed to parse transcription response")?;
701
702 let text = body["text"]
703 .as_str()
704 .context("Transcription response missing 'text' field")?
705 .to_string();
706
707 Ok(text)
708}
709
710pub struct TranscriptionManager {
717 transcription_providers: HashMap<String, Box<dyn TranscriptionProvider>>,
718 max_audio_bytes: Option<usize>,
719 agent_transcription_provider: String,
722}
723
724impl TranscriptionManager {
725 pub fn new(config: &TranscriptionConfig) -> Result<Self> {
730 if matches!(config.max_audio_bytes, Some(0)) {
731 bail!("transcription.max_audio_bytes must be greater than zero");
732 }
733
734 let mut transcription_providers: HashMap<String, Box<dyn TranscriptionProvider>> =
735 HashMap::new();
736
737 if let Ok(groq) = GroqProvider::from_config("groq", config) {
738 transcription_providers.insert("groq".to_string(), Box::new(groq));
739 }
740
741 if let Some(ref openai_cfg) = config.openai
742 && let Ok(p) = OpenAiWhisperProvider::from_config("openai", openai_cfg)
743 {
744 transcription_providers.insert("openai".to_string(), Box::new(p));
745 }
746
747 if let Some(ref deepgram_cfg) = config.deepgram
748 && let Ok(p) = DeepgramProvider::from_config("deepgram", deepgram_cfg)
749 {
750 transcription_providers.insert("deepgram".to_string(), Box::new(p));
751 }
752
753 if let Some(ref assemblyai_cfg) = config.assemblyai
754 && let Ok(p) = AssemblyAiProvider::from_config("assemblyai", assemblyai_cfg)
755 {
756 transcription_providers.insert("assemblyai".to_string(), Box::new(p));
757 }
758
759 if let Some(ref google_cfg) = config.google
760 && let Ok(p) = GoogleSttProvider::from_config("google", google_cfg)
761 {
762 transcription_providers.insert("google".to_string(), Box::new(p));
763 }
764
765 if let Some(ref local_cfg) = config.local_whisper {
766 match LocalWhisperProvider::from_config("local_whisper", local_cfg) {
767 Ok(p) => {
768 transcription_providers.insert("local_whisper".to_string(), Box::new(p));
769 }
770 Err(e) => {
771 ::zeroclaw_log::record!(
772 WARN,
773 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
774 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
775 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
776 "local_whisper config invalid, provider skipped"
777 );
778 }
779 }
780 }
781
782 if config.enabled && transcription_providers.is_empty() {
783 bail!(
784 "Transcription is enabled but no transcription provider registered \
785 successfully. Configure at least one of: [transcription] (Groq) \
786 with api_key + api_url; [transcription.openai]; [transcription.deepgram]; \
787 [transcription.assemblyai]; [transcription.google]; [transcription.local_whisper]."
788 );
789 }
790
791 Ok(Self {
792 transcription_providers,
793 max_audio_bytes: config.max_audio_bytes,
794 agent_transcription_provider: String::new(),
795 })
796 }
797
798 #[must_use]
802 pub fn with_agent_transcription_provider(mut self, alias: impl Into<String>) -> Self {
803 self.agent_transcription_provider = alias.into();
804 self
805 }
806
807 pub async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
811 let provider_alias = self.agent_transcription_provider.as_str();
812 if provider_alias.is_empty() {
813 bail!(
814 "Agent has no transcription_provider configured. Set \
815 `agent.<alias>.transcription_provider = \"<type>.<alias>\"` \
816 referencing a configured transcription provider."
817 );
818 }
819 self.transcribe_with_provider(audio_data, file_name, provider_alias)
820 .await
821 }
822
823 pub async fn transcribe_with_provider(
825 &self,
826 audio_data: &[u8],
827 file_name: &str,
828 transcription_provider: &str,
829 ) -> Result<String> {
830 let p = self.transcription_providers.get(transcription_provider).ok_or_else(|| {
831 let available: Vec<&str> = self.transcription_providers.keys().map(|k| k.as_str()).collect();
832 ::zeroclaw_log::record!(
833 ERROR,
834 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
835 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
836 .with_attrs(::serde_json::json!({
837 "transcription_provider": transcription_provider,
838 "available": available,
839 })),
840 "transcription: provider not configured"
841 );
842 anyhow::Error::msg(format!(
843 "Transcription transcription_provider '{transcription_provider}' not configured. Available: {available:?}"
844 ))
845 })?;
846
847 self.enforce_global_audio_limit(audio_data)?;
848
849 use ::zeroclaw_log::Instrument;
850 let span = ::zeroclaw_log::attribution_span!(p.as_ref());
851 p.transcribe(audio_data, file_name).instrument(span).await
852 }
853
854 fn enforce_global_audio_limit(&self, audio_data: &[u8]) -> Result<()> {
855 if let Some(max_audio_bytes) = self.max_audio_bytes
856 && audio_data.len() > max_audio_bytes
857 {
858 bail!(
859 "Audio file too large ({} bytes, global max {})",
860 audio_data.len(),
861 max_audio_bytes
862 );
863 }
864 Ok(())
865 }
866
867 pub fn available_providers(&self) -> Vec<&str> {
869 self.transcription_providers
870 .keys()
871 .map(|k| k.as_str())
872 .collect()
873 }
874}
875
876impl ::zeroclaw_api::attribution::Attributable for GroqProvider {
883 fn role(&self) -> ::zeroclaw_api::attribution::Role {
884 ::zeroclaw_api::attribution::Role::Provider(
885 ::zeroclaw_api::attribution::ProviderKind::Transcription(
886 ::zeroclaw_api::attribution::TranscriptionProviderKind::Groq,
887 ),
888 )
889 }
890 fn alias(&self) -> &str {
891 &self.alias
892 }
893}
894
895impl ::zeroclaw_api::attribution::Attributable for OpenAiWhisperProvider {
896 fn role(&self) -> ::zeroclaw_api::attribution::Role {
897 ::zeroclaw_api::attribution::Role::Provider(
898 ::zeroclaw_api::attribution::ProviderKind::Transcription(
899 ::zeroclaw_api::attribution::TranscriptionProviderKind::OpenAi,
900 ),
901 )
902 }
903 fn alias(&self) -> &str {
904 &self.alias
905 }
906}
907
908impl ::zeroclaw_api::attribution::Attributable for DeepgramProvider {
909 fn role(&self) -> ::zeroclaw_api::attribution::Role {
910 ::zeroclaw_api::attribution::Role::Provider(
911 ::zeroclaw_api::attribution::ProviderKind::Transcription(
912 ::zeroclaw_api::attribution::TranscriptionProviderKind::Deepgram,
913 ),
914 )
915 }
916 fn alias(&self) -> &str {
917 &self.alias
918 }
919}
920
921impl ::zeroclaw_api::attribution::Attributable for AssemblyAiProvider {
922 fn role(&self) -> ::zeroclaw_api::attribution::Role {
923 ::zeroclaw_api::attribution::Role::Provider(
924 ::zeroclaw_api::attribution::ProviderKind::Transcription(
925 ::zeroclaw_api::attribution::TranscriptionProviderKind::AssemblyAi,
926 ),
927 )
928 }
929 fn alias(&self) -> &str {
930 &self.alias
931 }
932}
933
934impl ::zeroclaw_api::attribution::Attributable for GoogleSttProvider {
935 fn role(&self) -> ::zeroclaw_api::attribution::Role {
936 ::zeroclaw_api::attribution::Role::Provider(
937 ::zeroclaw_api::attribution::ProviderKind::Transcription(
938 ::zeroclaw_api::attribution::TranscriptionProviderKind::Google,
939 ),
940 )
941 }
942 fn alias(&self) -> &str {
943 &self.alias
944 }
945}
946
947impl ::zeroclaw_api::attribution::Attributable for LocalWhisperProvider {
948 fn role(&self) -> ::zeroclaw_api::attribution::Role {
949 ::zeroclaw_api::attribution::Role::Provider(
950 ::zeroclaw_api::attribution::ProviderKind::Transcription(
951 ::zeroclaw_api::attribution::TranscriptionProviderKind::Whisper,
952 ),
953 )
954 }
955 fn alias(&self) -> &str {
956 &self.alias
957 }
958}
959
960#[cfg(test)]
961mod tests {
962 use super::*;
963 use std::sync::{
964 Arc,
965 atomic::{AtomicUsize, Ordering},
966 };
967
968 struct StaticTranscriptionProvider {
969 calls: Arc<AtomicUsize>,
970 }
971
972 #[async_trait]
973 impl TranscriptionProvider for StaticTranscriptionProvider {
974 fn name(&self) -> &str {
975 "static"
976 }
977
978 async fn transcribe(&self, _audio_data: &[u8], _file_name: &str) -> Result<String> {
979 self.calls.fetch_add(1, Ordering::SeqCst);
980 Ok("under cap".to_string())
981 }
982 }
983
984 impl ::zeroclaw_api::attribution::Attributable for StaticTranscriptionProvider {
985 fn role(&self) -> ::zeroclaw_api::attribution::Role {
986 ::zeroclaw_api::attribution::Role::Provider(
987 ::zeroclaw_api::attribution::ProviderKind::Transcription(
988 ::zeroclaw_api::attribution::TranscriptionProviderKind::Groq,
989 ),
990 )
991 }
992
993 fn alias(&self) -> &str {
994 "static"
995 }
996 }
997
998 fn manager_with_static_provider(
999 max_audio_bytes: Option<usize>,
1000 ) -> (TranscriptionManager, Arc<AtomicUsize>) {
1001 let calls = Arc::new(AtomicUsize::new(0));
1002 let mut transcription_providers: HashMap<String, Box<dyn TranscriptionProvider>> =
1003 HashMap::new();
1004 transcription_providers.insert(
1005 "static".to_string(),
1006 Box::new(StaticTranscriptionProvider {
1007 calls: Arc::clone(&calls),
1008 }),
1009 );
1010 (
1011 TranscriptionManager {
1012 transcription_providers,
1013 max_audio_bytes,
1014 agent_transcription_provider: String::new(),
1015 },
1016 calls,
1017 )
1018 }
1019
1020 #[test]
1026 fn mime_for_audio_maps_accepted_formats() {
1027 let cases = [
1028 ("flac", "audio/flac"),
1029 ("mp3", "audio/mpeg"),
1030 ("mpeg", "audio/mpeg"),
1031 ("mpga", "audio/mpeg"),
1032 ("mp4", "audio/mp4"),
1033 ("m4a", "audio/mp4"),
1034 ("ogg", "audio/ogg"),
1035 ("oga", "audio/ogg"),
1036 ("opus", "audio/opus"),
1037 ("wav", "audio/wav"),
1038 ("webm", "audio/webm"),
1039 ];
1040 for (ext, expected) in cases {
1041 assert_eq!(
1042 mime_for_audio(ext),
1043 Some(expected),
1044 "failed for extension: {ext}"
1045 );
1046 }
1047 }
1048
1049 #[test]
1050 fn mime_for_audio_case_insensitive() {
1051 assert_eq!(mime_for_audio("OGG"), Some("audio/ogg"));
1052 assert_eq!(mime_for_audio("MP3"), Some("audio/mpeg"));
1053 assert_eq!(mime_for_audio("Opus"), Some("audio/opus"));
1054 }
1055
1056 #[test]
1057 fn mime_for_audio_rejects_unknown() {
1058 assert_eq!(mime_for_audio("txt"), None);
1059 assert_eq!(mime_for_audio("pdf"), None);
1060 assert_eq!(mime_for_audio("aac"), None);
1061 assert_eq!(mime_for_audio(""), None);
1062 }
1063
1064 #[test]
1065 fn normalize_audio_filename_rewrites_oga() {
1066 assert_eq!(normalize_audio_filename("voice.oga"), "voice.ogg");
1067 assert_eq!(normalize_audio_filename("file.OGA"), "file.ogg");
1068 }
1069
1070 #[test]
1071 fn normalize_audio_filename_preserves_accepted() {
1072 assert_eq!(normalize_audio_filename("voice.ogg"), "voice.ogg");
1073 assert_eq!(normalize_audio_filename("track.mp3"), "track.mp3");
1074 assert_eq!(normalize_audio_filename("clip.opus"), "clip.opus");
1075 }
1076
1077 #[test]
1078 fn normalize_audio_filename_no_extension() {
1079 assert_eq!(normalize_audio_filename("voice"), "voice");
1080 }
1081
1082 #[test]
1083 fn rejects_unsupported_audio_format() {
1084 let data = vec![0u8; 100];
1087 let err = validate_audio(&data, "recording.aac").unwrap_err();
1088 let msg = err.to_string();
1089 assert!(
1090 msg.contains("Unsupported audio format"),
1091 "expected unsupported-format error, got: {msg}"
1092 );
1093 assert!(
1094 msg.contains(".aac"),
1095 "error should mention the rejected extension, got: {msg}"
1096 );
1097 }
1098
1099 #[test]
1102 fn manager_creation_with_default_config() {
1103 unsafe { std::env::remove_var("GROQ_API_KEY") };
1105
1106 let config = TranscriptionConfig::default();
1107 let manager = TranscriptionManager::new(&config).unwrap();
1108 assert!(manager.agent_transcription_provider.is_empty());
1112 assert!(manager.transcription_providers.is_empty());
1114 }
1115
1116 #[test]
1117 fn manager_registers_groq_with_key() {
1118 unsafe { std::env::remove_var("GROQ_API_KEY") };
1120
1121 let config = TranscriptionConfig {
1122 api_key: Some("test-groq-key".to_string()),
1123 ..TranscriptionConfig::default()
1124 };
1125
1126 let manager = TranscriptionManager::new(&config).unwrap();
1127 assert!(manager.transcription_providers.contains_key("groq"));
1128 assert_eq!(manager.transcription_providers["groq"].name(), "groq");
1129 }
1130
1131 #[test]
1132 fn manager_registers_multiple_providers() {
1133 unsafe { std::env::remove_var("GROQ_API_KEY") };
1135
1136 let config = TranscriptionConfig {
1137 api_key: Some("test-groq-key".to_string()),
1138 openai: Some(zeroclaw_config::schema::OpenAiSttConfig {
1139 api_key: Some("test-openai-key".to_string()),
1140 model: "whisper-1".to_string(),
1141 }),
1142 deepgram: Some(zeroclaw_config::schema::DeepgramSttConfig {
1143 api_key: Some("test-deepgram-key".to_string()),
1144 model: "nova-2".to_string(),
1145 }),
1146 ..TranscriptionConfig::default()
1147 };
1148
1149 let manager = TranscriptionManager::new(&config).unwrap();
1150 assert!(manager.transcription_providers.contains_key("groq"));
1151 assert!(manager.transcription_providers.contains_key("openai"));
1152 assert!(manager.transcription_providers.contains_key("deepgram"));
1153 assert_eq!(manager.available_providers().len(), 3);
1154 }
1155
1156 #[tokio::test]
1157 async fn manager_rejects_unconfigured_provider() {
1158 unsafe { std::env::remove_var("GROQ_API_KEY") };
1160
1161 let config = TranscriptionConfig {
1162 api_key: Some("test-groq-key".to_string()),
1163 ..TranscriptionConfig::default()
1164 };
1165
1166 let manager = TranscriptionManager::new(&config).unwrap();
1167 let err = manager
1168 .transcribe_with_provider(&[0u8; 100], "test.ogg", "nonexistent")
1169 .await
1170 .unwrap_err();
1171 assert!(
1172 err.to_string().contains("not configured"),
1173 "expected not-configured error, got: {err}"
1174 );
1175 }
1176
1177 #[test]
1178 fn manager_agent_transcription_provider_via_setter() {
1179 unsafe { std::env::remove_var("GROQ_API_KEY") };
1181
1182 let config = TranscriptionConfig {
1183 openai: Some(zeroclaw_config::schema::OpenAiSttConfig {
1184 api_key: Some("test-openai-key".to_string()),
1185 model: "whisper-1".to_string(),
1186 }),
1187 ..TranscriptionConfig::default()
1188 };
1189
1190 let manager = TranscriptionManager::new(&config)
1191 .unwrap()
1192 .with_agent_transcription_provider("openai");
1193 assert_eq!(manager.agent_transcription_provider, "openai");
1194 }
1195
1196 #[test]
1197 fn manager_rejects_zero_global_max_audio_bytes() {
1198 let config = TranscriptionConfig {
1199 max_audio_bytes: Some(0),
1200 ..TranscriptionConfig::default()
1201 };
1202
1203 let err = match TranscriptionManager::new(&config) {
1204 Ok(_) => panic!("expected zero max_audio_bytes to fail manager construction"),
1205 Err(err) => err,
1206 };
1207 assert!(
1208 err.to_string()
1209 .contains("transcription.max_audio_bytes must be greater than zero"),
1210 "got: {err}"
1211 );
1212 }
1213
1214 #[tokio::test]
1215 async fn manager_global_max_audio_bytes_rejects_over_limit_before_provider_dispatch() {
1216 let (manager, calls) = manager_with_static_provider(Some(3));
1217 let err = manager
1218 .transcribe_with_provider(&[0u8; 4], "voice.ogg", "static")
1219 .await
1220 .unwrap_err();
1221 assert!(
1222 err.to_string().contains("Audio file too large"),
1223 "got: {err}"
1224 );
1225 assert!(err.to_string().contains("global max 3"), "got: {err}");
1226 assert_eq!(calls.load(Ordering::SeqCst), 0);
1227 }
1228
1229 #[tokio::test]
1230 async fn manager_global_max_audio_bytes_allows_exact_limit() {
1231 let (manager, calls) = manager_with_static_provider(Some(4));
1232 let result = manager
1233 .transcribe_with_provider(&[0u8; 4], "voice.ogg", "static")
1234 .await
1235 .unwrap();
1236 assert_eq!(result, "under cap");
1237 assert_eq!(calls.load(Ordering::SeqCst), 1);
1238 }
1239
1240 #[tokio::test]
1241 async fn manager_transcribe_enforces_global_max_audio_bytes() {
1242 let (manager, calls) = manager_with_static_provider(Some(2));
1243 let manager = manager.with_agent_transcription_provider("static");
1244 let err = manager
1245 .transcribe(&[0u8; 3], "voice.ogg")
1246 .await
1247 .unwrap_err();
1248 assert!(
1249 err.to_string().contains("Audio file too large"),
1250 "got: {err}"
1251 );
1252 assert!(err.to_string().contains("global max 2"), "got: {err}");
1253 assert_eq!(calls.load(Ordering::SeqCst), 0);
1254 }
1255
1256 #[test]
1257 fn validate_audio_rejects_oversized() {
1258 let big = vec![0u8; MAX_AUDIO_BYTES + 1];
1259 let err = validate_audio(&big, "test.ogg").unwrap_err();
1260 assert!(err.to_string().contains("too large"));
1261 }
1262
1263 #[test]
1264 fn validate_audio_rejects_unsupported_format() {
1265 let data = vec![0u8; 100];
1266 let err = validate_audio(&data, "test.aac").unwrap_err();
1267 assert!(err.to_string().contains("Unsupported audio format"));
1268 }
1269
1270 #[test]
1271 fn validate_audio_accepts_supported_format() {
1272 let data = vec![0u8; 100];
1273 let (name, mime) = validate_audio(&data, "test.ogg").unwrap();
1274 assert_eq!(name, "test.ogg");
1275 assert_eq!(mime, "audio/ogg");
1276 }
1277
1278 #[test]
1279 fn validate_audio_normalizes_oga() {
1280 let data = vec![0u8; 100];
1281 let (name, mime) = validate_audio(&data, "voice.oga").unwrap();
1282 assert_eq!(name, "voice.ogg");
1283 assert_eq!(mime, "audio/ogg");
1284 }
1285
1286 #[test]
1287 fn backward_compat_config_defaults_unchanged() {
1288 let config = TranscriptionConfig::default();
1289 assert!(!config.enabled);
1290 assert!(config.api_key.is_none());
1291 assert!(config.api_url.contains("groq.com"));
1292 assert_eq!(config.model, "whisper-large-v3-turbo");
1293 assert!(config.openai.is_none());
1296 assert!(config.deepgram.is_none());
1297 assert!(config.assemblyai.is_none());
1298 assert!(config.google.is_none());
1299 assert!(config.local_whisper.is_none());
1300 assert!(!config.transcribe_non_ptt_audio);
1301 }
1302
1303 fn local_whisper_config(url: &str) -> zeroclaw_config::schema::LocalWhisperConfig {
1306 zeroclaw_config::schema::LocalWhisperConfig {
1307 url: url.to_string(),
1308 bearer_token: Some("test-token".to_string()),
1309 max_audio_bytes: 10 * 1024 * 1024,
1310 timeout_secs: 30,
1311 }
1312 }
1313
1314 #[test]
1315 fn local_whisper_rejects_empty_url() {
1316 let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1317 cfg.url = String::new();
1318 let err = LocalWhisperProvider::from_config("local_whisper", &cfg)
1319 .err()
1320 .unwrap();
1321 assert!(
1322 err.to_string().contains("`url` must not be empty"),
1323 "got: {err}"
1324 );
1325 }
1326
1327 #[test]
1328 fn local_whisper_rejects_invalid_url() {
1329 let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1330 cfg.url = "not-a-url".to_string();
1331 let err = LocalWhisperProvider::from_config("local_whisper", &cfg)
1332 .err()
1333 .unwrap();
1334 assert!(err.to_string().contains("invalid `url`"), "got: {err}");
1335 }
1336
1337 #[test]
1338 fn local_whisper_rejects_non_http_url() {
1339 let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1340 cfg.url = "ftp://10.10.0.1:8001/v1/transcribe".to_string();
1341 let err = LocalWhisperProvider::from_config("local_whisper", &cfg)
1342 .err()
1343 .unwrap();
1344 assert!(err.to_string().contains("http or https"), "got: {err}");
1345 }
1346
1347 #[test]
1348 fn local_whisper_rejects_empty_bearer_token() {
1349 let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1350 cfg.bearer_token = Some(String::new());
1351 let err = LocalWhisperProvider::from_config("local_whisper", &cfg)
1352 .err()
1353 .unwrap();
1354 assert!(
1355 err.to_string().contains("`bearer_token` must not be empty"),
1356 "got: {err}"
1357 );
1358 }
1359
1360 #[test]
1361 fn local_whisper_rejects_missing_bearer_token() {
1362 let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1363 cfg.bearer_token = None;
1364 let err = LocalWhisperProvider::from_config("local_whisper", &cfg)
1365 .err()
1366 .unwrap();
1367 assert!(
1368 err.to_string().contains("`bearer_token` must be set"),
1369 "got: {err}"
1370 );
1371 }
1372
1373 #[test]
1374 fn local_whisper_rejects_zero_max_audio_bytes() {
1375 let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1376 cfg.max_audio_bytes = 0;
1377 let err = LocalWhisperProvider::from_config("local_whisper", &cfg)
1378 .err()
1379 .unwrap();
1380 assert!(
1381 err.to_string()
1382 .contains("`max_audio_bytes` must be greater than zero"),
1383 "got: {err}"
1384 );
1385 }
1386
1387 #[test]
1388 fn local_whisper_rejects_zero_timeout() {
1389 let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1390 cfg.timeout_secs = 0;
1391 let err = LocalWhisperProvider::from_config("local_whisper", &cfg)
1392 .err()
1393 .unwrap();
1394 assert!(
1395 err.to_string()
1396 .contains("`timeout_secs` must be greater than zero"),
1397 "got: {err}"
1398 );
1399 }
1400
1401 #[test]
1402 fn local_whisper_registered_when_config_present() {
1403 let config = TranscriptionConfig {
1404 local_whisper: Some(local_whisper_config("http://127.0.0.1:9999/v1/transcribe")),
1405 ..TranscriptionConfig::default()
1406 };
1407
1408 let manager = TranscriptionManager::new(&config).unwrap();
1409 assert!(
1410 manager.available_providers().contains(&"local_whisper"),
1411 "expected local_whisper in {:?}",
1412 manager.available_providers()
1413 );
1414 }
1415
1416 #[test]
1417 fn local_whisper_misconfigured_section_fails_manager_construction() {
1418 let mut bad_cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1423 bad_cfg.bearer_token = Some(String::new());
1424 let config = TranscriptionConfig {
1425 local_whisper: Some(bad_cfg),
1426 enabled: true,
1427 ..TranscriptionConfig::default()
1428 };
1429
1430 let err = TranscriptionManager::new(&config).err().unwrap();
1431 assert!(
1432 err.to_string()
1433 .contains("no transcription provider registered"),
1434 "expected 'no transcription provider registered' from manager safety net, got: {err}"
1435 );
1436 }
1437
1438 #[test]
1439 fn validate_audio_still_enforces_25mb_cap() {
1440 let at_limit = vec![0u8; MAX_AUDIO_BYTES];
1442 assert!(validate_audio(&at_limit, "test.ogg").is_ok());
1443 let over_limit = vec![0u8; MAX_AUDIO_BYTES + 1];
1444 let err = validate_audio(&over_limit, "test.ogg").unwrap_err();
1445 assert!(err.to_string().contains("too large"));
1446 }
1447
1448 #[tokio::test]
1449 async fn local_whisper_rejects_oversized_audio() {
1450 let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1451 let transcription_provider =
1452 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1453 let big = vec![0u8; cfg.max_audio_bytes + 1];
1454 let err = transcription_provider
1455 .transcribe(&big, "voice.ogg")
1456 .await
1457 .unwrap_err();
1458 assert!(err.to_string().contains("too large"), "got: {err}");
1459 }
1460
1461 #[tokio::test]
1462 async fn local_whisper_rejects_unsupported_format() {
1463 let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe");
1464 let transcription_provider =
1465 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1466 let data = vec![0u8; 100];
1467 let err = transcription_provider
1468 .transcribe(&data, "voice.aiff")
1469 .await
1470 .unwrap_err();
1471 assert!(
1472 err.to_string().contains("Unsupported audio format"),
1473 "got: {err}"
1474 );
1475 }
1476
1477 #[tokio::test]
1480 async fn local_whisper_returns_text_from_response() {
1481 use wiremock::matchers::{header_exists, method, path};
1482 use wiremock::{Mock, MockServer, ResponseTemplate};
1483
1484 let server = MockServer::start().await;
1485
1486 Mock::given(method("POST"))
1487 .and(path("/v1/transcribe"))
1488 .and(header_exists("authorization"))
1489 .respond_with(
1490 ResponseTemplate::new(200)
1491 .set_body_json(serde_json::json!({"text": "hello world"})),
1492 )
1493 .mount(&server)
1494 .await;
1495
1496 let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
1497 let transcription_provider =
1498 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1499
1500 let result = transcription_provider
1501 .transcribe(b"fake-audio", "voice.ogg")
1502 .await
1503 .unwrap();
1504 assert_eq!(result, "hello world");
1505 }
1506
1507 #[tokio::test]
1508 async fn local_whisper_sends_bearer_auth_header() {
1509 use wiremock::matchers::{header, method, path};
1510 use wiremock::{Mock, MockServer, ResponseTemplate};
1511
1512 let server = MockServer::start().await;
1513
1514 Mock::given(method("POST"))
1515 .and(path("/v1/transcribe"))
1516 .and(header("authorization", "Bearer test-token"))
1517 .respond_with(
1518 ResponseTemplate::new(200).set_body_json(serde_json::json!({"text": "auth ok"})),
1519 )
1520 .mount(&server)
1521 .await;
1522
1523 let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
1524 let transcription_provider =
1525 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1526
1527 let result = transcription_provider
1528 .transcribe(b"fake-audio", "voice.ogg")
1529 .await
1530 .unwrap();
1531 assert_eq!(result, "auth ok");
1532 }
1533
1534 #[tokio::test]
1535 async fn local_whisper_propagates_http_error() {
1536 use wiremock::matchers::{method, path};
1537 use wiremock::{Mock, MockServer, ResponseTemplate};
1538
1539 let server = MockServer::start().await;
1540
1541 Mock::given(method("POST"))
1542 .and(path("/v1/transcribe"))
1543 .respond_with(
1544 ResponseTemplate::new(503).set_body_json(
1545 serde_json::json!({"error": {"message": "service unavailable"}}),
1546 ),
1547 )
1548 .mount(&server)
1549 .await;
1550
1551 let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
1552 let transcription_provider =
1553 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1554
1555 let err = transcription_provider
1556 .transcribe(b"fake-audio", "voice.ogg")
1557 .await
1558 .unwrap_err();
1559 assert!(
1560 err.to_string().contains("503") || err.to_string().contains("service unavailable"),
1561 "expected HTTP error, got: {err}"
1562 );
1563 }
1564
1565 #[tokio::test]
1566 async fn local_whisper_propagates_non_json_http_error() {
1567 use wiremock::matchers::{method, path};
1568 use wiremock::{Mock, MockServer, ResponseTemplate};
1569
1570 let server = MockServer::start().await;
1571
1572 Mock::given(method("POST"))
1573 .and(path("/v1/transcribe"))
1574 .respond_with(
1575 ResponseTemplate::new(502)
1576 .set_body_string("Bad Gateway")
1577 .insert_header("content-type", "text/plain"),
1578 )
1579 .mount(&server)
1580 .await;
1581
1582 let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
1583 let transcription_provider =
1584 LocalWhisperProvider::from_config("local_whisper", &cfg).unwrap();
1585
1586 let err = transcription_provider
1587 .transcribe(b"fake-audio", "voice.ogg")
1588 .await
1589 .unwrap_err();
1590 assert!(err.to_string().contains("502"), "got: {err}");
1591 assert!(
1592 err.to_string().contains("Bad Gateway"),
1593 "expected plain-text body in error, got: {err}"
1594 );
1595 }
1596}