Skip to main content

zeroclaw_channels/
transcription.rs

1use std::collections::HashMap;
2
3use anyhow::{Context, Result, bail};
4use async_trait::async_trait;
5use reqwest::multipart::{Form, Part};
6
7use zeroclaw_config::schema::TranscriptionConfig;
8
/// Maximum upload size accepted by most Whisper-compatible APIs (25 MB).
///
/// Checked by `validate_audio` before a cloud provider request is sent.
const MAX_AUDIO_BYTES: usize = 25 * 1024 * 1024;

/// Request timeout for transcription API calls (seconds).
///
/// Applied to each individual HTTP request rather than to a whole
/// `transcribe` call.
const TRANSCRIPTION_TIMEOUT_SECS: u64 = 120;
14
15// ── Audio utilities ─────────────────────────────────────────────
16
/// Look up the MIME type for an audio file extension.
///
/// `extension` is matched case-insensitively and must not include the
/// leading dot. Returns `None` when the extension is not one of the formats
/// accepted by Whisper-compatible transcription APIs.
fn mime_for_audio(extension: &str) -> Option<&'static str> {
    let lowered = extension.to_ascii_lowercase();
    let mime = match lowered.as_str() {
        "flac" => "audio/flac",
        "mp3" | "mpeg" | "mpga" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "ogg" | "oga" => "audio/ogg",
        "opus" => "audio/opus",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        _ => return None,
    };
    Some(mime)
}
30
/// Rewrite an audio file name into a form Whisper-compatible APIs accept.
///
/// Groq validates the file name extension and does not list `.oga`
/// (Opus-in-Ogg), so a trailing `.oga` (any case) is replaced with `.ogg`.
/// Every other name is returned unchanged.
fn normalize_audio_filename(file_name: &str) -> String {
    if let Some((stem, ext)) = file_name.rsplit_once('.') {
        if ext.eq_ignore_ascii_case("oga") {
            return format!("{stem}.ogg");
        }
    }
    file_name.to_string()
}
41
42/// Resolve MIME type and normalize filename from extension.
43///
44/// No size check — callers enforce their own limits.
45fn resolve_audio_format(file_name: &str) -> Result<(String, &'static str)> {
46    let normalized_name = normalize_audio_filename(file_name);
47    let extension = normalized_name
48        .rsplit_once('.')
49        .map(|(_, e)| e)
50        .unwrap_or("");
51    let mime = mime_for_audio(extension).ok_or_else(|| {
52        ::zeroclaw_log::record!(
53            WARN,
54            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
55                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
56                .with_attrs(::serde_json::json!({"extension": extension})),
57            "transcription: unsupported audio format"
58        );
59        anyhow::Error::msg(format!(
60            "Unsupported audio format '.{extension}'. \
61             accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm"
62        ))
63    })?;
64    Ok((normalized_name, mime))
65}
66
67/// Validate audio data and resolve MIME type from file name.
68///
69/// Enforces the 25 MB cloud API cap. Returns `(normalized_filename, mime_type)` on success.
70fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
71    if audio_data.len() > MAX_AUDIO_BYTES {
72        bail!(
73            "Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
74            audio_data.len()
75        );
76    }
77    resolve_audio_format(file_name)
78}
79
80// ── TranscriptionProvider trait ─────────────────────────────────
81
82/// Trait for speech-to-text transcription_provider implementations.
83#[async_trait]
84pub trait TranscriptionProvider: Send + Sync + ::zeroclaw_api::attribution::Attributable {
85    /// Human-readable transcription_provider name (e.g. "groq", "openai").
86    fn name(&self) -> &str;
87
88    /// Transcribe raw audio bytes. `file_name` includes the extension for
89    /// format detection (e.g. "voice.ogg").
90    async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String>;
91
92    /// List of supported audio file extensions.
93    fn supported_formats(&self) -> Vec<String> {
94        vec![
95            "flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
96        ]
97        .into_iter()
98        .map(String::from)
99        .collect()
100    }
101}
102
103// ── GroqProvider ────────────────────────────────────────────────
104
/// Groq Whisper API transcription_provider (default, backward-compatible with existing config).
pub struct GroqProvider {
    /// Alias passed to `from_config`.
    /// NOTE(review): not read in this part of the file — presumably consumed
    /// by the `Attributable` impl; confirm.
    alias: String,
    /// Endpoint the multipart transcription request is POSTed to.
    api_url: String,
    /// Model name sent as the `model` form field.
    model: String,
    /// Bearer token used to authenticate requests.
    api_key: String,
    /// Optional language hint; sent as the `language` form field when set.
    language: Option<String>,
}
113
114impl GroqProvider {
115    /// Build from the existing `TranscriptionConfig` fields.
116    ///
117    /// Credential resolution order:
118    /// Reads `config.api_key` (set via `[transcription].api_key` or the
119    /// schema-mirror env grammar `ZEROCLAW_transcription__api_key=...`).
120    /// The legacy `GROQ_API_KEY` env-var fallback was eradicated in V0.8.0.
121    pub fn from_config(alias: &str, config: &TranscriptionConfig) -> Result<Self> {
122        let api_key = config
123            .api_key
124            .as_deref()
125            .map(str::trim)
126            .filter(|v| !v.is_empty())
127            .map(ToOwned::to_owned)
128            .context(
129                "Missing transcription API key: set `[transcription].api_key` (or via the \
130                 schema-mirror grammar `ZEROCLAW_transcription__api_key=...`).",
131            )?;
132
133        Ok(Self {
134            alias: alias.to_string(),
135            api_url: config.api_url.clone(),
136            model: config.model.clone(),
137            api_key,
138            language: config.language.clone(),
139        })
140    }
141}
142
143#[async_trait]
144impl TranscriptionProvider for GroqProvider {
145    fn name(&self) -> &str {
146        "groq"
147    }
148
149    async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
150        let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
151
152        let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.groq");
153
154        let file_part = Part::bytes(audio_data.to_vec())
155            .file_name(normalized_name)
156            .mime_str(mime)?;
157
158        let mut form = Form::new()
159            .part("file", file_part)
160            .text("model", self.model.clone())
161            .text("response_format", "json");
162
163        if let Some(ref lang) = self.language {
164            form = form.text("language", lang.clone());
165        }
166
167        let resp = client
168            .post(&self.api_url)
169            .bearer_auth(&self.api_key)
170            .multipart(form)
171            .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
172            .send()
173            .await
174            .context("Failed to send transcription request to Groq")?;
175
176        parse_whisper_response(resp).await
177    }
178}
179
180// ── OpenAiWhisperProvider ───────────────────────────────────────
181
/// OpenAI Whisper API transcription_provider.
pub struct OpenAiWhisperProvider {
    /// Alias passed to `from_config`.
    /// NOTE(review): not read in this part of the file — presumably consumed
    /// by the `Attributable` impl; confirm.
    alias: String,
    /// Bearer token used to authenticate requests.
    api_key: String,
    /// Model name sent as the `model` form field.
    model: String,
}
188
189impl OpenAiWhisperProvider {
190    pub fn from_config(
191        alias: &str,
192        config: &zeroclaw_config::schema::OpenAiSttConfig,
193    ) -> Result<Self> {
194        let api_key = config
195            .api_key
196            .as_deref()
197            .map(str::trim)
198            .filter(|v| !v.is_empty())
199            .map(ToOwned::to_owned)
200            .context("Missing OpenAI STT API key: set [transcription.openai].api_key")?;
201
202        Ok(Self {
203            alias: alias.to_string(),
204            api_key,
205            model: config.model.clone(),
206        })
207    }
208}
209
210#[async_trait]
211impl TranscriptionProvider for OpenAiWhisperProvider {
212    fn name(&self) -> &str {
213        "openai"
214    }
215
216    async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
217        let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
218
219        let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.openai");
220
221        let file_part = Part::bytes(audio_data.to_vec())
222            .file_name(normalized_name)
223            .mime_str(mime)?;
224
225        let form = Form::new()
226            .part("file", file_part)
227            .text("model", self.model.clone())
228            .text("response_format", "json");
229
230        let resp = client
231            .post("https://api.openai.com/v1/audio/transcriptions")
232            .bearer_auth(&self.api_key)
233            .multipart(form)
234            .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
235            .send()
236            .await
237            .context("Failed to send transcription request to OpenAI")?;
238
239        parse_whisper_response(resp).await
240    }
241}
242
243// ── DeepgramProvider ────────────────────────────────────────────
244
/// Deepgram STT API transcription_provider.
pub struct DeepgramProvider {
    /// Alias passed to `from_config`.
    /// NOTE(review): not read in this part of the file — presumably consumed
    /// by the `Attributable` impl; confirm.
    alias: String,
    /// API key sent as `Authorization: Token <key>`.
    api_key: String,
    /// Model name placed in the `model` query parameter.
    model: String,
}
251
252impl DeepgramProvider {
253    pub fn from_config(
254        alias: &str,
255        config: &zeroclaw_config::schema::DeepgramSttConfig,
256    ) -> Result<Self> {
257        let api_key = config
258            .api_key
259            .as_deref()
260            .map(str::trim)
261            .filter(|v| !v.is_empty())
262            .map(ToOwned::to_owned)
263            .context("Missing Deepgram API key: set [transcription.deepgram].api_key")?;
264
265        Ok(Self {
266            alias: alias.to_string(),
267            api_key,
268            model: config.model.clone(),
269        })
270    }
271}
272
273#[async_trait]
274impl TranscriptionProvider for DeepgramProvider {
275    fn name(&self) -> &str {
276        "deepgram"
277    }
278
279    async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
280        let (_, mime) = validate_audio(audio_data, file_name)?;
281
282        let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.deepgram");
283
284        let url = format!(
285            "https://api.deepgram.com/v1/listen?model={}&punctuate=true",
286            self.model
287        );
288
289        let resp = client
290            .post(&url)
291            .header("Authorization", format!("Token {}", self.api_key))
292            .header("Content-Type", mime)
293            .body(audio_data.to_vec())
294            .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
295            .send()
296            .await
297            .context("Failed to send transcription request to Deepgram")?;
298
299        let status = resp.status();
300        let body: serde_json::Value = resp
301            .json()
302            .await
303            .context("Failed to parse Deepgram response")?;
304
305        if !status.is_success() {
306            let error_msg = body["err_msg"]
307                .as_str()
308                .or_else(|| body["error"].as_str())
309                .unwrap_or("unknown error");
310            bail!("Deepgram API error ({}): {}", status, error_msg);
311        }
312
313        let text = body["results"]["channels"][0]["alternatives"][0]["transcript"]
314            .as_str()
315            .context("Deepgram response missing transcript field")?
316            .to_string();
317
318        Ok(text)
319    }
320}
321
322// ── AssemblyAiProvider ──────────────────────────────────────────
323
/// AssemblyAI STT API transcription_provider.
pub struct AssemblyAiProvider {
    /// Alias passed to `from_config`.
    /// NOTE(review): not read in this part of the file — presumably consumed
    /// by the `Attributable` impl; confirm.
    alias: String,
    /// API key sent verbatim as the `Authorization` header value.
    api_key: String,
}
329
330impl AssemblyAiProvider {
331    pub fn from_config(
332        alias: &str,
333        config: &zeroclaw_config::schema::AssemblyAiSttConfig,
334    ) -> Result<Self> {
335        let api_key = config
336            .api_key
337            .as_deref()
338            .map(str::trim)
339            .filter(|v| !v.is_empty())
340            .map(ToOwned::to_owned)
341            .context("Missing AssemblyAI API key: set [transcription.assemblyai].api_key")?;
342
343        Ok(Self {
344            alias: alias.to_string(),
345            api_key,
346        })
347    }
348}
349
350#[async_trait]
351impl TranscriptionProvider for AssemblyAiProvider {
352    fn name(&self) -> &str {
353        "assemblyai"
354    }
355
356    async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
357        let (_, _) = validate_audio(audio_data, file_name)?;
358
359        let client =
360            zeroclaw_config::schema::build_runtime_proxy_client("transcription.assemblyai");
361
362        // Step 1: Upload the audio file.
363        let upload_resp = client
364            .post("https://api.assemblyai.com/v2/upload")
365            .header("Authorization", &self.api_key)
366            .header("Content-Type", "application/octet-stream")
367            .body(audio_data.to_vec())
368            .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
369            .send()
370            .await
371            .context("Failed to upload audio to AssemblyAI")?;
372
373        let upload_status = upload_resp.status();
374        let upload_body: serde_json::Value = upload_resp
375            .json()
376            .await
377            .context("Failed to parse AssemblyAI upload response")?;
378
379        if !upload_status.is_success() {
380            let error_msg = upload_body["error"].as_str().unwrap_or("unknown error");
381            bail!("AssemblyAI upload error ({}): {}", upload_status, error_msg);
382        }
383
384        let upload_url = upload_body["upload_url"]
385            .as_str()
386            .context("AssemblyAI upload response missing 'upload_url'")?;
387
388        // Step 2: Create transcription job.
389        let transcript_req = serde_json::json!({
390            "audio_url": upload_url,
391        });
392
393        let create_resp = client
394            .post("https://api.assemblyai.com/v2/transcript")
395            .header("Authorization", &self.api_key)
396            .json(&transcript_req)
397            .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
398            .send()
399            .await
400            .context("Failed to create AssemblyAI transcription")?;
401
402        let create_status = create_resp.status();
403        let create_body: serde_json::Value = create_resp
404            .json()
405            .await
406            .context("Failed to parse AssemblyAI create response")?;
407
408        if !create_status.is_success() {
409            let error_msg = create_body["error"].as_str().unwrap_or("unknown error");
410            bail!(
411                "AssemblyAI transcription error ({}): {}",
412                create_status,
413                error_msg
414            );
415        }
416
417        let transcript_id = create_body["id"]
418            .as_str()
419            .context("AssemblyAI response missing 'id'")?;
420
421        // Step 3: Poll for completion.
422        let poll_url = format!("https://api.assemblyai.com/v2/transcript/{transcript_id}");
423        let poll_interval = std::time::Duration::from_secs(3);
424        let poll_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(180);
425
426        while tokio::time::Instant::now() < poll_deadline {
427            tokio::time::sleep(poll_interval).await;
428
429            let poll_resp = client
430                .get(&poll_url)
431                .header("Authorization", &self.api_key)
432                .timeout(std::time::Duration::from_secs(30))
433                .send()
434                .await
435                .context("Failed to poll AssemblyAI transcription")?;
436
437            let poll_status = poll_resp.status();
438            let poll_body: serde_json::Value = poll_resp
439                .json()
440                .await
441                .context("Failed to parse AssemblyAI poll response")?;
442
443            if !poll_status.is_success() {
444                let error_msg = poll_body["error"].as_str().unwrap_or("unknown poll error");
445                bail!("AssemblyAI poll error ({}): {}", poll_status, error_msg);
446            }
447
448            let status_str = poll_body["status"].as_str().unwrap_or("unknown");
449
450            match status_str {
451                "completed" => {
452                    let text = poll_body["text"]
453                        .as_str()
454                        .context("AssemblyAI response missing 'text'")?
455                        .to_string();
456                    return Ok(text);
457                }
458                "error" => {
459                    let error_msg = poll_body["error"]
460                        .as_str()
461                        .unwrap_or("unknown transcription error");
462                    bail!("AssemblyAI transcription failed: {}", error_msg);
463                }
464                _ => {}
465            }
466        }
467
468        bail!("AssemblyAI transcription timed out after 180s")
469    }
470}
471
472// ── GoogleSttProvider ───────────────────────────────────────────
473
/// Google Cloud Speech-to-Text API transcription_provider.
pub struct GoogleSttProvider {
    /// Alias passed to `from_config`.
    /// NOTE(review): not read in this part of the file — presumably consumed
    /// by the `Attributable` impl; confirm.
    alias: String,
    /// Google API key used to authenticate `speech:recognize` calls.
    api_key: String,
    /// Value sent as `config.languageCode` in the recognize request.
    language_code: String,
}
480
481impl GoogleSttProvider {
482    pub fn from_config(
483        alias: &str,
484        config: &zeroclaw_config::schema::GoogleSttConfig,
485    ) -> Result<Self> {
486        let api_key = config
487            .api_key
488            .as_deref()
489            .map(str::trim)
490            .filter(|v| !v.is_empty())
491            .map(ToOwned::to_owned)
492            .context("Missing Google STT API key: set [transcription.google].api_key")?;
493
494        Ok(Self {
495            alias: alias.to_string(),
496            api_key,
497            language_code: config.language_code.clone(),
498        })
499    }
500}
501
502#[async_trait]
503impl TranscriptionProvider for GoogleSttProvider {
504    fn name(&self) -> &str {
505        "google"
506    }
507
508    fn supported_formats(&self) -> Vec<String> {
509        // Google Cloud STT supports a subset of formats.
510        vec!["flac", "wav", "ogg", "opus", "mp3", "webm"]
511            .into_iter()
512            .map(String::from)
513            .collect()
514    }
515
516    async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
517        let (normalized_name, _) = validate_audio(audio_data, file_name)?;
518
519        let client = zeroclaw_config::schema::build_runtime_proxy_client("transcription.google");
520
521        let encoding = match normalized_name
522            .rsplit_once('.')
523            .map(|(_, e)| e.to_ascii_lowercase())
524            .as_deref()
525        {
526            Some("flac") => "FLAC",
527            Some("wav") => "LINEAR16",
528            Some("ogg" | "opus") => "OGG_OPUS",
529            Some("mp3") => "MP3",
530            Some("webm") => "WEBM_OPUS",
531            Some(ext) => bail!("Google STT does not support '.{ext}' input"),
532            None => bail!("Google STT requires a file extension"),
533        };
534
535        let audio_content =
536            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, audio_data);
537
538        let request_body = serde_json::json!({
539            "config": {
540                "encoding": encoding,
541                "languageCode": &self.language_code,
542                "enableAutomaticPunctuation": true,
543            },
544            "audio": {
545                "content": audio_content,
546            }
547        });
548
549        let url = format!(
550            "https://speech.googleapis.com/v1/speech:recognize?key={}",
551            self.api_key
552        );
553
554        let resp = client
555            .post(&url)
556            .json(&request_body)
557            .timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
558            .send()
559            .await
560            .context("Failed to send transcription request to Google STT")?;
561
562        let status = resp.status();
563        let body: serde_json::Value = resp
564            .json()
565            .await
566            .context("Failed to parse Google STT response")?;
567
568        if !status.is_success() {
569            let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error");
570            bail!("Google STT API error ({}): {}", status, error_msg);
571        }
572
573        let text = body["results"][0]["alternatives"][0]["transcript"]
574            .as_str()
575            .unwrap_or("")
576            .to_string();
577
578        Ok(text)
579    }
580}
581
582// ── LocalWhisperProvider ────────────────────────────────────────
583
/// Self-hosted faster-whisper-compatible STT transcription_provider.
///
/// POSTs audio as `multipart/form-data` (field name `file`) to a configurable
/// HTTP endpoint (e.g. `http://localhost:8000` or a private network host). The endpoint
/// must return `{"text": "..."}`. No cloud API key required. Size limit is
/// configurable — not constrained by the 25 MB cloud API cap.
pub struct LocalWhisperProvider {
    /// Alias passed to `from_config`.
    /// NOTE(review): not read in this part of the file — presumably consumed
    /// by the `Attributable` impl; confirm.
    alias: String,
    /// Trimmed endpoint URL; validated as http/https in `from_config`.
    url: String,
    /// Bearer token sent with every request; never empty.
    bearer_token: String,
    /// Per-provider payload cap in bytes; always greater than zero.
    max_audio_bytes: usize,
    /// Per-request timeout in seconds; always greater than zero.
    timeout_secs: u64,
}
597
598impl LocalWhisperProvider {
599    /// Build from config. Fails if `url` or `bearer_token` is empty, if `url`
600    /// is not a valid HTTP/HTTPS URL (scheme must be `http` or `https`), if
601    /// `max_audio_bytes` is zero, or if `timeout_secs` is zero.
602    pub fn from_config(
603        alias: &str,
604        config: &zeroclaw_config::schema::LocalWhisperConfig,
605    ) -> Result<Self> {
606        let url = config.url.trim().to_string();
607        anyhow::ensure!(!url.is_empty(), "local_whisper: `url` must not be empty");
608        let parsed = url
609            .parse::<reqwest::Url>()
610            .with_context(|| format!("local_whisper: invalid `url`: {url:?}"))?;
611        anyhow::ensure!(
612            matches!(parsed.scheme(), "http" | "https"),
613            "local_whisper: `url` must use http or https scheme, got {:?}",
614            parsed.scheme()
615        );
616
617        let bearer_token = match config.bearer_token.as_deref().map(str::trim) {
618            None => anyhow::bail!("local_whisper: `bearer_token` must be set"),
619            Some("") => anyhow::bail!("local_whisper: `bearer_token` must not be empty"),
620            Some(t) => t.to_string(),
621        };
622
623        anyhow::ensure!(
624            config.max_audio_bytes > 0,
625            "local_whisper: `max_audio_bytes` must be greater than zero"
626        );
627
628        anyhow::ensure!(
629            config.timeout_secs > 0,
630            "local_whisper: `timeout_secs` must be greater than zero"
631        );
632
633        Ok(Self {
634            alias: alias.to_string(),
635            url,
636            bearer_token,
637            max_audio_bytes: config.max_audio_bytes,
638            timeout_secs: config.timeout_secs,
639        })
640    }
641}
642
643#[async_trait]
644impl TranscriptionProvider for LocalWhisperProvider {
645    fn name(&self) -> &str {
646        "local_whisper"
647    }
648
649    async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
650        if audio_data.len() > self.max_audio_bytes {
651            bail!(
652                "Audio file too large ({} bytes, local_whisper max {})",
653                audio_data.len(),
654                self.max_audio_bytes
655            );
656        }
657
658        let (normalized_name, mime) = resolve_audio_format(file_name)?;
659
660        let client =
661            zeroclaw_config::schema::build_runtime_proxy_client("transcription.local_whisper");
662
663        // to_vec() clones the buffer for the multipart payload; peak memory per
664        // call is ~2× max_audio_bytes. TODO: replace with streaming upload once
665        // reqwest supports body streaming in multipart parts.
666        let file_part = Part::bytes(audio_data.to_vec())
667            .file_name(normalized_name)
668            .mime_str(mime)?;
669
670        let resp = client
671            .post(&self.url)
672            .bearer_auth(&self.bearer_token)
673            .multipart(Form::new().part("file", file_part))
674            .timeout(std::time::Duration::from_secs(self.timeout_secs))
675            .send()
676            .await
677            .context("Failed to send audio to local Whisper endpoint")?;
678
679        parse_whisper_response(resp).await
680    }
681}
682
683// ── Shared response parsing ─────────────────────────────────────
684
685/// Parse a faster-whisper-compatible JSON response (`{ "text": "..." }`).
686///
687/// Checks HTTP status before attempting JSON parsing so that non-JSON error
688/// bodies (plain text, HTML, empty 5xx) produce a readable status error
689/// rather than a confusing "Failed to parse transcription response".
690async fn parse_whisper_response(resp: reqwest::Response) -> Result<String> {
691    let status = resp.status();
692    if !status.is_success() {
693        let body = resp.text().await.unwrap_or_default();
694        bail!("Transcription API error ({}): {}", status, body.trim());
695    }
696
697    let body: serde_json::Value = resp
698        .json()
699        .await
700        .context("Failed to parse transcription response")?;
701
702    let text = body["text"]
703        .as_str()
704        .context("Transcription response missing 'text' field")?
705        .to_string();
706
707    Ok(text)
708}
709
710// ── TranscriptionManager ────────────────────────────────────────
711
/// Manages multiple transcription / STT providers and routes transcription
/// requests. The manager is implicitly per-agent: the runtime-active
/// agent's `transcription_provider` reference is the resolved alias for
/// `transcribe()` calls. There is no global default-provider concept.
pub struct TranscriptionManager {
    /// Registered providers, keyed by the name they were registered under
    /// (e.g. "groq", "openai").
    transcription_providers: HashMap<String, Box<dyn TranscriptionProvider>>,
    /// Optional global payload cap in bytes, checked before dispatching to a
    /// provider. `None` means no global cap.
    max_audio_bytes: Option<usize>,
    /// Resolved alias for the agent that owns this manager. Empty when
    /// the agent has no transcription preference (opt-out).
    agent_transcription_provider: String,
}
723
724impl TranscriptionManager {
725    /// Build a `TranscriptionManager` from a `TranscriptionConfig`. The
726    /// resolved agent alias starts empty; orchestrators that wire the
727    /// manager to a specific agent should call
728    /// `with_agent_transcription_provider` to set it.
729    pub fn new(config: &TranscriptionConfig) -> Result<Self> {
730        if matches!(config.max_audio_bytes, Some(0)) {
731            bail!("transcription.max_audio_bytes must be greater than zero");
732        }
733
734        let mut transcription_providers: HashMap<String, Box<dyn TranscriptionProvider>> =
735            HashMap::new();
736
737        if let Ok(groq) = GroqProvider::from_config("groq", config) {
738            transcription_providers.insert("groq".to_string(), Box::new(groq));
739        }
740
741        if let Some(ref openai_cfg) = config.openai
742            && let Ok(p) = OpenAiWhisperProvider::from_config("openai", openai_cfg)
743        {
744            transcription_providers.insert("openai".to_string(), Box::new(p));
745        }
746
747        if let Some(ref deepgram_cfg) = config.deepgram
748            && let Ok(p) = DeepgramProvider::from_config("deepgram", deepgram_cfg)
749        {
750            transcription_providers.insert("deepgram".to_string(), Box::new(p));
751        }
752
753        if let Some(ref assemblyai_cfg) = config.assemblyai
754            && let Ok(p) = AssemblyAiProvider::from_config("assemblyai", assemblyai_cfg)
755        {
756            transcription_providers.insert("assemblyai".to_string(), Box::new(p));
757        }
758
759        if let Some(ref google_cfg) = config.google
760            && let Ok(p) = GoogleSttProvider::from_config("google", google_cfg)
761        {
762            transcription_providers.insert("google".to_string(), Box::new(p));
763        }
764
765        if let Some(ref local_cfg) = config.local_whisper {
766            match LocalWhisperProvider::from_config("local_whisper", local_cfg) {
767                Ok(p) => {
768                    transcription_providers.insert("local_whisper".to_string(), Box::new(p));
769                }
770                Err(e) => {
771                    ::zeroclaw_log::record!(
772                        WARN,
773                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
774                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
775                            .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
776                        "local_whisper config invalid, provider skipped"
777                    );
778                }
779            }
780        }
781
782        if config.enabled && transcription_providers.is_empty() {
783            bail!(
784                "Transcription is enabled but no transcription provider registered \
785                 successfully. Configure at least one of: [transcription] (Groq) \
786                 with api_key + api_url; [transcription.openai]; [transcription.deepgram]; \
787                 [transcription.assemblyai]; [transcription.google]; [transcription.local_whisper]."
788            );
789        }
790
791        Ok(Self {
792            transcription_providers,
793            max_audio_bytes: config.max_audio_bytes,
794            agent_transcription_provider: String::new(),
795        })
796    }
797
798    /// Set the resolved agent `transcription_provider` alias. Called by
799    /// orchestrators that bind this manager to a specific agent at startup.
800    /// Subsequent `transcribe` calls dispatch to this alias.
801    #[must_use]
802    pub fn with_agent_transcription_provider(mut self, alias: impl Into<String>) -> Self {
803        self.agent_transcription_provider = alias.into();
804        self
805    }
806
807    /// Transcribe audio using the runtime-active agent's resolved
808    /// `transcription_provider`. Fails loud when the agent has no
809    /// transcription_provider configured — there is no global default.
810    pub async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
811        let provider_alias = self.agent_transcription_provider.as_str();
812        if provider_alias.is_empty() {
813            bail!(
814                "Agent has no transcription_provider configured. Set \
815                 `agent.<alias>.transcription_provider = \"<type>.<alias>\"` \
816                 referencing a configured transcription provider."
817            );
818        }
819        self.transcribe_with_provider(audio_data, file_name, provider_alias)
820            .await
821    }
822
823    /// Transcribe audio using a specific named transcription_provider.
824    pub async fn transcribe_with_provider(
825        &self,
826        audio_data: &[u8],
827        file_name: &str,
828        transcription_provider: &str,
829    ) -> Result<String> {
830        let p = self.transcription_providers.get(transcription_provider).ok_or_else(|| {
831            let available: Vec<&str> = self.transcription_providers.keys().map(|k| k.as_str()).collect();
832            ::zeroclaw_log::record!(
833                ERROR,
834                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
835                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
836                    .with_attrs(::serde_json::json!({
837                        "transcription_provider": transcription_provider,
838                        "available": available,
839                    })),
840                "transcription: provider not configured"
841            );
842            anyhow::Error::msg(format!(
843                "Transcription transcription_provider '{transcription_provider}' not configured. Available: {available:?}"
844            ))
845        })?;
846
847        self.enforce_global_audio_limit(audio_data)?;
848
849        use ::zeroclaw_log::Instrument;
850        let span = ::zeroclaw_log::attribution_span!(p.as_ref());
851        p.transcribe(audio_data, file_name).instrument(span).await
852    }
853
854    fn enforce_global_audio_limit(&self, audio_data: &[u8]) -> Result<()> {
855        if let Some(max_audio_bytes) = self.max_audio_bytes
856            && audio_data.len() > max_audio_bytes
857        {
858            bail!(
859                "Audio file too large ({} bytes, global max {})",
860                audio_data.len(),
861                max_audio_bytes
862            );
863        }
864        Ok(())
865    }
866
867    /// List registered transcription_provider names.
868    pub fn available_providers(&self) -> Vec<&str> {
869        self.transcription_providers
870            .keys()
871            .map(|k| k.as_str())
872            .collect()
873    }
874}
875
876// `transcribe_audio` (the legacy free function that dispatched against
877// `config.default_transcription_provider`) was deleted in #6273. There is
878// no global default-provider concept anymore; transcription routes through
879// `TranscriptionManager` whose resolved alias comes from the per-agent
880// `transcription_provider` field (`agent.<X>.transcription_provider`).
881
882impl ::zeroclaw_api::attribution::Attributable for GroqProvider {
883    fn role(&self) -> ::zeroclaw_api::attribution::Role {
884        ::zeroclaw_api::attribution::Role::Provider(
885            ::zeroclaw_api::attribution::ProviderKind::Transcription(
886                ::zeroclaw_api::attribution::TranscriptionProviderKind::Groq,
887            ),
888        )
889    }
890    fn alias(&self) -> &str {
891        &self.alias
892    }
893}
894
895impl ::zeroclaw_api::attribution::Attributable for OpenAiWhisperProvider {
896    fn role(&self) -> ::zeroclaw_api::attribution::Role {
897        ::zeroclaw_api::attribution::Role::Provider(
898            ::zeroclaw_api::attribution::ProviderKind::Transcription(
899                ::zeroclaw_api::attribution::TranscriptionProviderKind::OpenAi,
900            ),
901        )
902    }
903    fn alias(&self) -> &str {
904        &self.alias
905    }
906}
907
908impl ::zeroclaw_api::attribution::Attributable for DeepgramProvider {
909    fn role(&self) -> ::zeroclaw_api::attribution::Role {
910        ::zeroclaw_api::attribution::Role::Provider(
911            ::zeroclaw_api::attribution::ProviderKind::Transcription(
912                ::zeroclaw_api::attribution::TranscriptionProviderKind::Deepgram,
913            ),
914        )
915    }
916    fn alias(&self) -> &str {
917        &self.alias
918    }
919}
920
921impl ::zeroclaw_api::attribution::Attributable for AssemblyAiProvider {
922    fn role(&self) -> ::zeroclaw_api::attribution::Role {
923        ::zeroclaw_api::attribution::Role::Provider(
924            ::zeroclaw_api::attribution::ProviderKind::Transcription(
925                ::zeroclaw_api::attribution::TranscriptionProviderKind::AssemblyAi,
926            ),
927        )
928    }
929    fn alias(&self) -> &str {
930        &self.alias
931    }
932}
933
934impl ::zeroclaw_api::attribution::Attributable for GoogleSttProvider {
935    fn role(&self) -> ::zeroclaw_api::attribution::Role {
936        ::zeroclaw_api::attribution::Role::Provider(
937            ::zeroclaw_api::attribution::ProviderKind::Transcription(
938                ::zeroclaw_api::attribution::TranscriptionProviderKind::Google,
939            ),
940        )
941    }
942    fn alias(&self) -> &str {
943        &self.alias
944    }
945}
946
947impl ::zeroclaw_api::attribution::Attributable for LocalWhisperProvider {
948    fn role(&self) -> ::zeroclaw_api::attribution::Role {
949        ::zeroclaw_api::attribution::Role::Provider(
950            ::zeroclaw_api::attribution::ProviderKind::Transcription(
951                ::zeroclaw_api::attribution::TranscriptionProviderKind::Whisper,
952            ),
953        )
954    }
955    fn alias(&self) -> &str {
956        &self.alias
957    }
958}
959
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use zeroclaw_config::schema::{DeepgramSttConfig, LocalWhisperConfig, OpenAiSttConfig};

    // ── Fixtures ────────────────────────────────────────────────

    /// Placeholder endpoint for LocalWhisper tests that do not use a mock server.
    const OFFLINE_WHISPER_URL: &str = "http://127.0.0.1:9999/v1/transcribe";

    /// Alias under which every LocalWhisper test provider is constructed.
    const LOCAL_WHISPER_ALIAS: &str = "local_whisper";

    /// Provider stub: bumps `calls` on every `transcribe` and returns a fixed
    /// transcript, so tests can tell whether dispatch happened.
    struct StaticTranscriptionProvider {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl TranscriptionProvider for StaticTranscriptionProvider {
        fn name(&self) -> &str {
            "static"
        }

        async fn transcribe(&self, _audio_data: &[u8], _file_name: &str) -> Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(String::from("under cap"))
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for StaticTranscriptionProvider {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            use ::zeroclaw_api::attribution::{ProviderKind, Role, TranscriptionProviderKind};

            // The stub reuses the Groq kind; the tests never inspect the role.
            Role::Provider(ProviderKind::Transcription(TranscriptionProviderKind::Groq))
        }

        fn alias(&self) -> &str {
            "static"
        }
    }

    /// Build a manager whose only provider is the counting stub registered as
    /// `"static"`, with the given global size cap and no agent alias bound.
    /// Returns the manager together with the stub's call counter.
    fn manager_with_static_provider(
        max_audio_bytes: Option<usize>,
    ) -> (TranscriptionManager, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let stub: Box<dyn TranscriptionProvider> = Box::new(StaticTranscriptionProvider {
            calls: Arc::clone(&calls),
        });
        let manager = TranscriptionManager {
            transcription_providers: HashMap::from([("static".to_string(), stub)]),
            max_audio_bytes,
            agent_transcription_provider: String::new(),
        };
        (manager, calls)
    }

    /// Remove `GROQ_API_KEY` from the process environment before a manager is built.
    fn clear_groq_api_key() {
        // SAFETY: test-only, single-threaded test runner.
        // NOTE(review): this only holds if tests do not run concurrently inside
        // one process — confirm against the runner configuration.
        unsafe { std::env::remove_var("GROQ_API_KEY") };
    }

    /// Default config plus a top-level (Groq) API key.
    fn groq_key_config() -> TranscriptionConfig {
        TranscriptionConfig {
            api_key: Some("test-groq-key".to_string()),
            ..TranscriptionConfig::default()
        }
    }

    /// OpenAI section with a test key and the `whisper-1` model.
    fn openai_section() -> OpenAiSttConfig {
        OpenAiSttConfig {
            api_key: Some("test-openai-key".to_string()),
            model: "whisper-1".to_string(),
        }
    }

    /// Assert that `err` is the manager's global size-cap rejection for `limit`.
    fn assert_global_cap_error(err: &anyhow::Error, limit: usize) {
        assert!(
            err.to_string().contains("Audio file too large"),
            "got: {err}"
        );
        assert!(
            err.to_string().contains(&format!("global max {limit}")),
            "got: {err}"
        );
    }

    // Tests for the deleted `transcribe_audio` free function were removed
    // alongside the function in #6273. Equivalent coverage lives on
    // `TranscriptionManager` (`manager_creation_with_default_config`,
    // `manager_registers_groq_with_key`, `manager_rejects_unconfigured_provider`).

    // ── Audio utility tests ─────────────────────────────────────

    #[test]
    fn mime_for_audio_maps_accepted_formats() {
        let cases = [
            ("flac", "audio/flac"),
            ("mp3", "audio/mpeg"),
            ("mpeg", "audio/mpeg"),
            ("mpga", "audio/mpeg"),
            ("mp4", "audio/mp4"),
            ("m4a", "audio/mp4"),
            ("ogg", "audio/ogg"),
            ("oga", "audio/ogg"),
            ("opus", "audio/opus"),
            ("wav", "audio/wav"),
            ("webm", "audio/webm"),
        ];
        cases.iter().for_each(|&(ext, expected)| {
            assert_eq!(
                mime_for_audio(ext),
                Some(expected),
                "failed for extension: {ext}"
            );
        });
    }

    #[test]
    fn mime_for_audio_case_insensitive() {
        for (ext, expected) in [
            ("OGG", "audio/ogg"),
            ("MP3", "audio/mpeg"),
            ("Opus", "audio/opus"),
        ] {
            assert_eq!(mime_for_audio(ext), Some(expected));
        }
    }

    #[test]
    fn mime_for_audio_rejects_unknown() {
        for ext in ["txt", "pdf", "aac", ""] {
            assert_eq!(mime_for_audio(ext), None);
        }
    }

    #[test]
    fn normalize_audio_filename_rewrites_oga() {
        for (input, expected) in [("voice.oga", "voice.ogg"), ("file.OGA", "file.ogg")] {
            assert_eq!(normalize_audio_filename(input), expected);
        }
    }

    #[test]
    fn normalize_audio_filename_preserves_accepted() {
        // Names that already carry an accepted extension come back unchanged.
        for name in ["voice.ogg", "track.mp3", "clip.opus"] {
            assert_eq!(normalize_audio_filename(name), name);
        }
    }

    #[test]
    fn normalize_audio_filename_no_extension() {
        let bare = "voice";
        assert_eq!(normalize_audio_filename(bare), "voice");
    }

    #[test]
    fn rejects_unsupported_audio_format() {
        // Without the legacy `transcribe_audio` free function, exercise the
        // format-rejection path directly via `validate_audio`.
        let data = vec![0u8; 100];
        let msg = validate_audio(&data, "recording.aac")
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("Unsupported audio format"),
            "expected unsupported-format error, got: {msg}"
        );
        assert!(
            msg.contains(".aac"),
            "error should mention the rejected extension, got: {msg}"
        );
    }

    // ── TranscriptionManager tests ──────────────────────────────

    #[test]
    fn manager_creation_with_default_config() {
        clear_groq_api_key();

        let manager = TranscriptionManager::new(&TranscriptionConfig::default()).unwrap();
        // the manager's agent_transcription_provider starts empty
        // until an orchestrator wires it via `with_agent_transcription_provider`.
        // No global default-provider concept.
        assert!(manager.agent_transcription_provider.is_empty());
        // Groq won't be registered without a key.
        assert!(manager.transcription_providers.is_empty());
    }

    #[test]
    fn manager_registers_groq_with_key() {
        clear_groq_api_key();

        let manager = TranscriptionManager::new(&groq_key_config()).unwrap();
        assert!(manager.transcription_providers.contains_key("groq"));
        assert_eq!(manager.transcription_providers["groq"].name(), "groq");
    }

    #[test]
    fn manager_registers_multiple_providers() {
        clear_groq_api_key();

        let config = TranscriptionConfig {
            openai: Some(openai_section()),
            deepgram: Some(DeepgramSttConfig {
                api_key: Some("test-deepgram-key".to_string()),
                model: "nova-2".to_string(),
            }),
            ..groq_key_config()
        };

        let manager = TranscriptionManager::new(&config).unwrap();
        for alias in ["groq", "openai", "deepgram"] {
            assert!(manager.transcription_providers.contains_key(alias));
        }
        assert_eq!(manager.available_providers().len(), 3);
    }

    #[tokio::test]
    async fn manager_rejects_unconfigured_provider() {
        clear_groq_api_key();

        let manager = TranscriptionManager::new(&groq_key_config()).unwrap();
        let outcome = manager
            .transcribe_with_provider(&[0u8; 100], "test.ogg", "nonexistent")
            .await;
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("not configured"),
            "expected not-configured error, got: {err}"
        );
    }

    #[test]
    fn manager_agent_transcription_provider_via_setter() {
        clear_groq_api_key();

        let config = TranscriptionConfig {
            openai: Some(openai_section()),
            ..TranscriptionConfig::default()
        };

        let built = TranscriptionManager::new(&config).unwrap();
        let manager = built.with_agent_transcription_provider("openai");
        assert_eq!(manager.agent_transcription_provider, "openai");
    }

    #[test]
    fn manager_rejects_zero_global_max_audio_bytes() {
        let config = TranscriptionConfig {
            max_audio_bytes: Some(0),
            ..TranscriptionConfig::default()
        };

        let Err(err) = TranscriptionManager::new(&config) else {
            panic!("expected zero max_audio_bytes to fail manager construction");
        };
        assert!(
            err.to_string()
                .contains("transcription.max_audio_bytes must be greater than zero"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn manager_global_max_audio_bytes_rejects_over_limit_before_provider_dispatch() {
        let (manager, calls) = manager_with_static_provider(Some(3));
        let err = manager
            .transcribe_with_provider(&[0u8; 4], "voice.ogg", "static")
            .await
            .unwrap_err();
        assert_global_cap_error(&err, 3);
        // The stub must not have been reached.
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn manager_global_max_audio_bytes_allows_exact_limit() {
        let (manager, calls) = manager_with_static_provider(Some(4));
        let transcript = manager
            .transcribe_with_provider(&[0u8; 4], "voice.ogg", "static")
            .await
            .unwrap();
        assert_eq!(transcript, "under cap");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn manager_transcribe_enforces_global_max_audio_bytes() {
        let (unbound, calls) = manager_with_static_provider(Some(2));
        let manager = unbound.with_agent_transcription_provider("static");
        let err = manager
            .transcribe(&[0u8; 3], "voice.ogg")
            .await
            .unwrap_err();
        assert_global_cap_error(&err, 2);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    // ── validate_audio tests ────────────────────────────────────

    #[test]
    fn validate_audio_rejects_oversized() {
        let big = vec![0u8; MAX_AUDIO_BYTES + 1];
        let err = validate_audio(&big, "test.ogg").unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn validate_audio_rejects_unsupported_format() {
        let data = vec![0u8; 100];
        let err = validate_audio(&data, "test.aac").unwrap_err();
        assert!(err.to_string().contains("Unsupported audio format"));
    }

    #[test]
    fn validate_audio_accepts_supported_format() {
        let data = vec![0u8; 100];
        let (name, mime) = validate_audio(&data, "test.ogg").unwrap();
        assert_eq!(name, "test.ogg");
        assert_eq!(mime, "audio/ogg");
    }

    #[test]
    fn validate_audio_normalizes_oga() {
        let data = vec![0u8; 100];
        let (name, mime) = validate_audio(&data, "voice.oga").unwrap();
        assert_eq!(name, "voice.ogg");
        assert_eq!(mime, "audio/ogg");
    }

    #[test]
    fn backward_compat_config_defaults_unchanged() {
        let defaults = TranscriptionConfig::default();
        assert!(!defaults.enabled);
        assert!(defaults.api_key.is_none());
        assert!(defaults.api_url.contains("groq.com"));
        assert_eq!(defaults.model, "whisper-large-v3-turbo");
        // TranscriptionConfig has no global default-provider field;
        // per-agent `transcription_provider` is the only selector.
        assert!(defaults.openai.is_none());
        assert!(defaults.deepgram.is_none());
        assert!(defaults.assemblyai.is_none());
        assert!(defaults.google.is_none());
        assert!(defaults.local_whisper.is_none());
        assert!(!defaults.transcribe_non_ptt_audio);
    }

    // ── LocalWhisperProvider tests (TDD — added below as red/green cycles) ──

    /// Valid LocalWhisper section pointing at `url`, with a test bearer token,
    /// a 10 MiB size cap and a 30 second timeout.
    fn local_whisper_config(url: &str) -> LocalWhisperConfig {
        LocalWhisperConfig {
            url: url.to_owned(),
            bearer_token: Some("test-token".to_owned()),
            max_audio_bytes: 10 * 1024 * 1024,
            timeout_secs: 30,
        }
    }

    /// Start from a valid config, apply `mutate`, and return the error that
    /// `LocalWhisperProvider::from_config` produces for the result.
    fn local_whisper_rejection(mutate: impl FnOnce(&mut LocalWhisperConfig)) -> anyhow::Error {
        let mut cfg = local_whisper_config(OFFLINE_WHISPER_URL);
        mutate(&mut cfg);
        LocalWhisperProvider::from_config(LOCAL_WHISPER_ALIAS, &cfg)
            .err()
            .unwrap()
    }

    /// Provider built from a valid config aimed at the mock server's
    /// `/v1/transcribe` path.
    fn mock_backed_provider(server: &wiremock::MockServer) -> LocalWhisperProvider {
        let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri()));
        LocalWhisperProvider::from_config(LOCAL_WHISPER_ALIAS, &cfg).unwrap()
    }

    #[test]
    fn local_whisper_rejects_empty_url() {
        let err = local_whisper_rejection(|cfg| cfg.url = String::new());
        assert!(
            err.to_string().contains("`url` must not be empty"),
            "got: {err}"
        );
    }

    #[test]
    fn local_whisper_rejects_invalid_url() {
        let err = local_whisper_rejection(|cfg| cfg.url = "not-a-url".to_string());
        assert!(err.to_string().contains("invalid `url`"), "got: {err}");
    }

    #[test]
    fn local_whisper_rejects_non_http_url() {
        let err = local_whisper_rejection(|cfg| {
            cfg.url = "ftp://10.10.0.1:8001/v1/transcribe".to_string();
        });
        assert!(err.to_string().contains("http or https"), "got: {err}");
    }

    #[test]
    fn local_whisper_rejects_empty_bearer_token() {
        let err = local_whisper_rejection(|cfg| cfg.bearer_token = Some(String::new()));
        assert!(
            err.to_string().contains("`bearer_token` must not be empty"),
            "got: {err}"
        );
    }

    #[test]
    fn local_whisper_rejects_missing_bearer_token() {
        let err = local_whisper_rejection(|cfg| cfg.bearer_token = None);
        assert!(
            err.to_string().contains("`bearer_token` must be set"),
            "got: {err}"
        );
    }

    #[test]
    fn local_whisper_rejects_zero_max_audio_bytes() {
        let err = local_whisper_rejection(|cfg| cfg.max_audio_bytes = 0);
        assert!(
            err.to_string()
                .contains("`max_audio_bytes` must be greater than zero"),
            "got: {err}"
        );
    }

    #[test]
    fn local_whisper_rejects_zero_timeout() {
        let err = local_whisper_rejection(|cfg| cfg.timeout_secs = 0);
        assert!(
            err.to_string()
                .contains("`timeout_secs` must be greater than zero"),
            "got: {err}"
        );
    }

    #[test]
    fn local_whisper_registered_when_config_present() {
        let config = TranscriptionConfig {
            local_whisper: Some(local_whisper_config(OFFLINE_WHISPER_URL)),
            ..TranscriptionConfig::default()
        };

        let manager = TranscriptionManager::new(&config).unwrap();
        let registered = manager.available_providers();
        assert!(
            registered.contains(&"local_whisper"),
            "expected local_whisper in {:?}",
            registered
        );
    }

    #[test]
    fn local_whisper_misconfigured_section_fails_manager_construction() {
        // A misconfigured local_whisper section logs a warning and skips
        // registration. When transcription is enabled and no other provider
        // section is set, the safety net in TranscriptionManager surfaces
        // the error rather than returning a useless empty manager.
        let broken_section = LocalWhisperConfig {
            bearer_token: Some(String::new()),
            ..local_whisper_config(OFFLINE_WHISPER_URL)
        };
        let config = TranscriptionConfig {
            local_whisper: Some(broken_section),
            enabled: true,
            ..TranscriptionConfig::default()
        };

        let err = TranscriptionManager::new(&config).err().unwrap();
        assert!(
            err.to_string()
                .contains("no transcription provider registered"),
            "expected 'no transcription provider registered' from manager safety net, got: {err}"
        );
    }

    #[test]
    fn validate_audio_still_enforces_25mb_cap() {
        // Regression: extracting resolve_audio_format() must not weaken validate_audio().
        let at_limit = vec![0u8; MAX_AUDIO_BYTES];
        assert!(validate_audio(&at_limit, "test.ogg").is_ok());

        let over_limit = vec![0u8; MAX_AUDIO_BYTES + 1];
        let err = validate_audio(&over_limit, "test.ogg").unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[tokio::test]
    async fn local_whisper_rejects_oversized_audio() {
        let cfg = local_whisper_config(OFFLINE_WHISPER_URL);
        let transcription_provider =
            LocalWhisperProvider::from_config(LOCAL_WHISPER_ALIAS, &cfg).unwrap();
        // One byte past the provider's own cap.
        let oversized = vec![0u8; cfg.max_audio_bytes + 1];
        let err = transcription_provider
            .transcribe(&oversized, "voice.ogg")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("too large"), "got: {err}");
    }

    #[tokio::test]
    async fn local_whisper_rejects_unsupported_format() {
        let cfg = local_whisper_config(OFFLINE_WHISPER_URL);
        let transcription_provider =
            LocalWhisperProvider::from_config(LOCAL_WHISPER_ALIAS, &cfg).unwrap();
        let data = vec![0u8; 100];
        let err = transcription_provider
            .transcribe(&data, "voice.aiff")
            .await
            .unwrap_err();
        assert!(
            err.to_string().contains("Unsupported audio format"),
            "got: {err}"
        );
    }

    // ── LocalWhisperProvider HTTP mock tests ────────────────────

    #[tokio::test]
    async fn local_whisper_returns_text_from_response() {
        use wiremock::matchers::{header_exists, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        let reply =
            ResponseTemplate::new(200).set_body_json(serde_json::json!({"text": "hello world"}));
        Mock::given(method("POST"))
            .and(path("/v1/transcribe"))
            .and(header_exists("authorization"))
            .respond_with(reply)
            .mount(&server)
            .await;

        let transcription_provider = mock_backed_provider(&server);
        let transcript = transcription_provider
            .transcribe(b"fake-audio", "voice.ogg")
            .await
            .unwrap();
        assert_eq!(transcript, "hello world");
    }

    #[tokio::test]
    async fn local_whisper_sends_bearer_auth_header() {
        use wiremock::matchers::{header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        // The mock only matches when the exact bearer header is present, so a
        // successful transcript proves the header was sent.
        let reply =
            ResponseTemplate::new(200).set_body_json(serde_json::json!({"text": "auth ok"}));
        Mock::given(method("POST"))
            .and(path("/v1/transcribe"))
            .and(header("authorization", "Bearer test-token"))
            .respond_with(reply)
            .mount(&server)
            .await;

        let transcription_provider = mock_backed_provider(&server);
        let transcript = transcription_provider
            .transcribe(b"fake-audio", "voice.ogg")
            .await
            .unwrap();
        assert_eq!(transcript, "auth ok");
    }

    #[tokio::test]
    async fn local_whisper_propagates_http_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        let reply = ResponseTemplate::new(503)
            .set_body_json(serde_json::json!({"error": {"message": "service unavailable"}}));
        Mock::given(method("POST"))
            .and(path("/v1/transcribe"))
            .respond_with(reply)
            .mount(&server)
            .await;

        let transcription_provider = mock_backed_provider(&server);
        let err = transcription_provider
            .transcribe(b"fake-audio", "voice.ogg")
            .await
            .unwrap_err();
        let text = err.to_string();
        assert!(
            text.contains("503") || text.contains("service unavailable"),
            "expected HTTP error, got: {err}"
        );
    }

    #[tokio::test]
    async fn local_whisper_propagates_non_json_http_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        let reply = ResponseTemplate::new(502)
            .set_body_string("Bad Gateway")
            .insert_header("content-type", "text/plain");
        Mock::given(method("POST"))
            .and(path("/v1/transcribe"))
            .respond_with(reply)
            .mount(&server)
            .await;

        let transcription_provider = mock_backed_provider(&server);
        let err = transcription_provider
            .transcribe(b"fake-audio", "voice.ogg")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("502"), "got: {err}");
        assert!(
            err.to_string().contains("Bad Gateway"),
            "expected plain-text body in error, got: {err}"
        );
    }
}