1use crate::auth::AuthService;
8use crate::traits::{
9 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
10 ModelProvider, TokenUsage, ToolsPayload,
11};
12use async_trait::async_trait;
13use base64::Engine;
14use directories::UserDirs;
15use reqwest::Client;
16use serde::{Deserialize, Serialize};
17use std::path::PathBuf;
18use std::sync::Arc;
19
/// Model provider for Google Gemini.
///
/// Authenticates with one of the `GeminiAuth` methods: an explicit API key, a
/// Gemini CLI OAuth credential file, or a managed auth profile from `AuthService`.
pub struct GeminiModelProvider {
    /// Alias string passed to the constructor.
    alias: String,
    /// Resolved authentication method; `None` when no credential was found.
    auth: Option<GeminiAuth>,
    /// Cached Code Assist project ID resolved for OAuth requests; reset when the
    /// OAuth credential is rotated.
    oauth_project: Arc<tokio::sync::Mutex<Option<String>>>,
    /// Caller-supplied project ID sent to `loadCodeAssist` and used as a fallback
    /// when that call fails or returns no project.
    oauth_project_seed: Option<String>,
    /// Discovered Gemini CLI `oauth_creds.json` files, in rotation order.
    oauth_cred_paths: Vec<PathBuf>,
    /// Index into `oauth_cred_paths` of the credential currently in use.
    oauth_index: Arc<tokio::sync::Mutex<usize>>,
    /// Auth service used for managed-profile OAuth; only retained in that mode.
    auth_service: Option<AuthService>,
    /// Optional profile name passed to `auth_service` lookups.
    auth_profile_override: Option<String>,
    /// OAuth client ID passed to managed-profile token refresh.
    oauth_client_id: Option<String>,
    /// OAuth client secret passed to managed-profile token refresh.
    oauth_client_secret: Option<String>,
}
46
/// Mutable token state for a Gemini CLI OAuth credential.
struct OAuthTokenState {
    /// Bearer token sent to the Code Assist endpoint.
    access_token: String,
    /// Refresh token used to mint a new access token when this one expires.
    refresh_token: Option<String>,
    /// OAuth client ID included in refresh requests when present.
    client_id: Option<String>,
    /// OAuth client secret included in refresh requests when present.
    client_secret: Option<String>,
    /// Expiry as milliseconds since the Unix epoch; `None` means unknown.
    expiry_millis: Option<i64>,
}
56
/// How the provider authenticates its requests.
enum GeminiAuth {
    /// API key for the public Gemini API; sent in the request URL.
    ExplicitKey(String),
    /// Token loaded from a Gemini CLI credential file; refreshed in place.
    OAuthToken(Arc<tokio::sync::Mutex<OAuthTokenState>>),
    /// Token obtained on demand from a managed auth profile via `AuthService`.
    ManagedOAuth,
}
69
70impl GeminiAuth {
71 fn is_api_key(&self) -> bool {
73 matches!(self, GeminiAuth::ExplicitKey(_))
74 }
75
76 fn is_oauth(&self) -> bool {
78 matches!(self, GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth)
79 }
80
81 fn api_key_credential(&self) -> &str {
83 match self {
84 GeminiAuth::ExplicitKey(s) => s,
85 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => "",
86 }
87 }
88}
89
/// Request body for the public `generateContent` endpoint.
///
/// Also reused as the source of the `request` payload inside the OAuth
/// `InternalGenerateContentEnvelope`.
#[derive(Debug, Serialize, Clone)]
struct GenerateContentRequest {
    /// Conversation turns (`user` / `model`).
    contents: Vec<Content>,
    /// Merged system prompt; omitted when there is none.
    #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
    system_instruction: Option<Content>,
    /// Sampling settings; always serialized on this endpoint.
    #[serde(rename = "generationConfig")]
    generation_config: GenerationConfig,
}
102
/// Envelope posted to the Code Assist internal `generateContent` endpoint when
/// authenticating with OAuth.
#[derive(Debug, Serialize)]
struct InternalGenerateContentEnvelope {
    /// Model name without the `models/` prefix.
    model: String,
    /// Resolved Code Assist project ID, omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    project: Option<String>,
    /// Per-request identifier (the builder sets a fresh UUID v4).
    #[serde(skip_serializing_if = "Option::is_none")]
    user_prompt_id: Option<String>,
    /// The wrapped generation request.
    request: InternalGenerateContentRequest,
}
128
/// Inner request for the Code Assist endpoint.
///
/// Unlike `GenerateContentRequest`, `generationConfig` is optional so it can be
/// dropped when the endpoint rejects it.
#[derive(Debug, Serialize)]
struct InternalGenerateContentRequest {
    /// Conversation turns.
    contents: Vec<Content>,
    /// Merged system prompt; omitted when there is none.
    #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
    system_instruction: Option<Content>,
    /// Sampling settings; omitted when `None`.
    #[serde(rename = "generationConfig", skip_serializing_if = "Option::is_none")]
    generation_config: Option<GenerationConfig>,
}
138
/// One turn of content: an optional role plus its parts.
#[derive(Debug, Serialize, Clone)]
struct Content {
    /// `"user"` or `"model"` for chat turns; `None` for the system instruction.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
    /// Text and/or inline-data parts making up the turn.
    parts: Vec<Part>,
}
145
/// A single content part.
///
/// Untagged, so it serializes as either `{"text": ...}` or `{"inline_data": {...}}`.
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
enum Part {
    /// Plain text.
    Text { text: String },
    /// Base64-encoded inline binary data (e.g. an image).
    Inline { inline_data: InlineData },
}
152
153impl Part {
154 fn text(s: impl Into<String>) -> Self {
155 Part::Text { text: s.into() }
156 }
157}
158
/// Inline binary payload for a `Part::Inline`.
#[derive(Debug, Serialize, Clone)]
struct InlineData {
    /// MIME type taken from the source `data:` URI (e.g. `image/png`).
    mime_type: String,
    /// Base64-encoded payload, copied as-is from the `data:` URI.
    data: String,
}
164
165fn build_parts(content: &str) -> Vec<Part> {
170 let (text, image_refs) = crate::multimodal::parse_image_markers(content);
171 let mut parts = Vec::new();
172 let trimmed = text.trim();
173 if !trimmed.is_empty() {
174 parts.push(Part::text(trimmed));
175 }
176 for uri in &image_refs {
177 if let Some(rest) = uri.strip_prefix("data:")
178 && let Some(semi_pos) = rest.find(';')
179 {
180 let mime = &rest[..semi_pos];
181 if let Some(b64) = rest[semi_pos + 1..].strip_prefix("base64,") {
182 parts.push(Part::Inline {
183 inline_data: InlineData {
184 mime_type: mime.to_string(),
185 data: b64.to_string(),
186 },
187 });
188 }
189 }
190 }
191 if parts.is_empty() {
192 parts.push(Part::text(content));
193 }
194 parts
195}
196
/// Sampling settings sent as `generationConfig`.
#[derive(Debug, Serialize, Clone)]
struct GenerationConfig {
    /// Sampling temperature; omitted to use the server default.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Upper bound on generated tokens; always serialized.
    #[serde(rename = "maxOutputTokens")]
    max_output_tokens: u32,
}
204
/// Response from either `generateContent` endpoint.
///
/// The OAuth (Code Assist) endpoint wraps the payload in a nested `response`
/// object; `into_effective_response` unwraps it.
#[derive(Debug, Deserialize)]
struct GenerateContentResponse {
    /// Generated candidates; only the first is used by the caller.
    candidates: Option<Vec<Candidate>>,
    /// Error object returned in-band by the API.
    error: Option<ApiError>,
    /// Nested response from the OAuth endpoint envelope.
    #[serde(default)]
    response: Option<Box<GenerateContentResponse>>,
    /// Token accounting, when reported.
    #[serde(default, rename = "usageMetadata")]
    usage_metadata: Option<GeminiUsageMetadata>,
}
214
/// Token counts reported in `usageMetadata`.
#[derive(Debug, Deserialize)]
struct GeminiUsageMetadata {
    /// Tokens in the prompt (mapped to `TokenUsage::input_tokens`).
    #[serde(default, rename = "promptTokenCount")]
    prompt_token_count: Option<u64>,
    /// Tokens in the candidates (mapped to `TokenUsage::output_tokens`).
    #[serde(default, rename = "candidatesTokenCount")]
    candidates_token_count: Option<u64>,
}
222
/// One generated candidate.
#[derive(Debug, Deserialize)]
struct Candidate {
    /// Candidate content; may be absent, in which case no text is extracted.
    #[serde(default)]
    content: Option<CandidateContent>,
}
228
229#[derive(Debug, Deserialize)]
230struct CandidateContent {
231 parts: Vec<ResponsePart>,
232}
233
/// One part of a candidate's content.
#[derive(Debug, Deserialize)]
struct ResponsePart {
    /// Text of the part, if it is a text part.
    #[serde(default)]
    text: Option<String>,
    /// Marks a thinking/reasoning part; `effective_text` only falls back to
    /// these when no answer text exists. Defaults to `false`.
    #[serde(default)]
    thought: bool,
}
242
243impl CandidateContent {
244 fn effective_text(self) -> Option<String> {
254 let mut answer_parts: Vec<String> = Vec::new();
255 let mut first_thinking: Option<String> = None;
256
257 for part in self.parts {
258 if let Some(text) = part.text {
259 if text.is_empty() {
260 continue;
261 }
262 if !part.thought {
263 answer_parts.push(text);
264 } else if first_thinking.is_none() {
265 first_thinking = Some(text);
266 }
267 }
268 }
269
270 if answer_parts.is_empty() {
271 first_thinking
272 } else {
273 Some(answer_parts.join(""))
274 }
275 }
276}
277
/// In-band error object returned by the API.
#[derive(Debug, Deserialize)]
struct ApiError {
    /// Human-readable error message, surfaced in the returned `anyhow` error.
    message: String,
}
282
283impl GenerateContentResponse {
284 fn into_effective_response(self) -> Self {
286 match self {
287 Self {
288 response: Some(mut inner),
289 usage_metadata,
290 ..
291 } => {
292 if inner.usage_metadata.is_none() {
293 inner.usage_metadata = usage_metadata;
294 }
295 *inner
296 }
297 other => other,
298 }
299 }
300}
301
/// Contents of a Gemini CLI `oauth_creds.json` credential file.
///
/// Fields accept both snake_case and the camelCase aliases listed below.
#[derive(Debug, Deserialize)]
struct GeminiCliOAuthCreds {
    /// Current access token; a credential without one is ignored.
    access_token: Option<String>,
    /// JWT ID token; its `aud`/`azp` claim is a fallback source for the client ID.
    #[serde(alias = "idToken")]
    id_token: Option<String>,
    /// Refresh token used to renew the access token.
    refresh_token: Option<String>,
    /// OAuth client ID used for refresh.
    #[serde(alias = "clientId")]
    client_id: Option<String>,
    /// OAuth client secret used for refresh.
    #[serde(alias = "clientSecret")]
    client_secret: Option<String>,
    /// Expiry in milliseconds since the Unix epoch (preferred over `expiry`).
    #[serde(alias = "expiryDate")]
    expiry_date: Option<i64>,
    /// Expiry as an RFC 3339 timestamp; used when `expiry_date` is absent.
    expiry: Option<String>,
}
323
/// Google OAuth 2.0 token endpoint used to refresh Gemini CLI access tokens.
const GOOGLE_TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";

/// Base URL of the Code Assist internal API used for OAuth-authenticated requests.
const CLOUDCODE_PA_ENDPOINT: &str = "https://cloudcode-pa.googleapis.com/v1internal";

/// Code Assist `loadCodeAssist` method, used to resolve the OAuth project ID.
const LOAD_CODE_ASSIST_ENDPOINT: &str =
    "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist";

/// Base URL of the public Gemini API used with API-key auth.
pub(crate) const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
341
/// Result of a successful OAuth refresh.
struct RefreshedToken {
    /// Newly issued access token (guaranteed non-blank).
    access_token: String,
    /// New expiry in milliseconds since the Unix epoch. `None` when the server
    /// sent no `expires_in` or the computation overflowed.
    expiry_millis: Option<i64>,
}
352
353fn refresh_gemini_cli_token(
359 refresh_token: &str,
360 client_id: Option<&str>,
361 client_secret: Option<&str>,
362) -> anyhow::Result<RefreshedToken> {
363 let client = reqwest::blocking::Client::builder()
364 .timeout(std::time::Duration::from_secs(15))
365 .connect_timeout(std::time::Duration::from_secs(5))
366 .build()
367 .unwrap_or_else(|_| reqwest::blocking::Client::new());
368
369 let form = build_oauth_refresh_form(refresh_token, client_id, client_secret);
370
371 let response = client
372 .post(GOOGLE_TOKEN_ENDPOINT)
373 .header("Content-Type", "application/x-www-form-urlencoded")
374 .header("Accept", "application/json")
375 .form(&form)
376 .send()
377 .map_err(|error| {
378 ::zeroclaw_log::record!(
379 ERROR,
380 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
381 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
382 .with_attrs(::serde_json::json!({
383 "oauth_provider": "gemini_cli",
384 "phase": "refresh_request",
385 "error": format!("{}", error),
386 })),
387 "gemini: CLI OAuth refresh request failed"
388 );
389 anyhow::Error::msg(format!("Gemini CLI OAuth refresh request failed: {error}"))
390 })?;
391
392 let status = response.status();
393 let body = response
394 .text()
395 .unwrap_or_else(|_| "<failed to read response body>".to_string());
396
397 if !status.is_success() {
398 anyhow::bail!("Gemini CLI OAuth refresh failed (HTTP {status}): {body}");
399 }
400
401 #[derive(Deserialize)]
402 struct TokenResponse {
403 access_token: Option<String>,
404 expires_in: Option<i64>,
405 }
406
407 let parsed: TokenResponse = serde_json::from_str(&body).map_err(|_| {
408 ::zeroclaw_log::record!(
409 ERROR,
410 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
411 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
412 .with_attrs(::serde_json::json!({"oauth_provider": "gemini_cli"})),
413 "gemini: CLI OAuth refresh response is not valid JSON"
414 );
415 anyhow::Error::msg("Gemini CLI OAuth refresh response is not valid JSON")
416 })?;
417
418 let access_token = parsed
419 .access_token
420 .filter(|t| !t.trim().is_empty())
421 .ok_or_else(|| {
422 ::zeroclaw_log::record!(
423 ERROR,
424 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
425 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
426 .with_attrs(::serde_json::json!({
427 "oauth_provider": "gemini_cli",
428 "missing": "access_token",
429 })),
430 "gemini: CLI OAuth refresh missing access_token"
431 );
432 anyhow::Error::msg("Gemini CLI OAuth refresh response missing access_token")
433 })?;
434
435 let expiry_millis = parsed.expires_in.and_then(|secs| {
436 let now_millis = std::time::SystemTime::now()
437 .duration_since(std::time::UNIX_EPOCH)
438 .ok()
439 .and_then(|d| i64::try_from(d.as_millis()).ok())?;
440 now_millis.checked_add(secs.checked_mul(1000)?)
441 });
442
443 Ok(RefreshedToken {
444 access_token,
445 expiry_millis,
446 })
447}
448
449fn build_oauth_refresh_form(
450 refresh_token: &str,
451 client_id: Option<&str>,
452 client_secret: Option<&str>,
453) -> Vec<(&'static str, String)> {
454 let mut form = vec![
455 ("grant_type", "refresh_token".to_string()),
456 ("refresh_token", refresh_token.to_string()),
457 ];
458 if let Some(id) = client_id.and_then(GeminiModelProvider::normalize_non_empty) {
459 form.push(("client_id", id));
460 }
461 if let Some(secret) = client_secret.and_then(GeminiModelProvider::normalize_non_empty) {
462 form.push(("client_secret", secret));
463 }
464 form
465}
466
467fn extract_client_id_from_id_token(id_token: &str) -> Option<String> {
468 let payload = id_token.split('.').nth(1)?;
469 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
470 .decode(payload)
471 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(payload))
472 .ok()?;
473
474 #[derive(Deserialize)]
475 struct IdTokenClaims {
476 aud: Option<String>,
477 azp: Option<String>,
478 }
479
480 let claims: IdTokenClaims = serde_json::from_slice(&decoded).ok()?;
481 claims
482 .aud
483 .as_deref()
484 .and_then(GeminiModelProvider::normalize_non_empty)
485 .or_else(|| {
486 claims
487 .azp
488 .as_deref()
489 .and_then(GeminiModelProvider::normalize_non_empty)
490 })
491}
492
493async fn refresh_gemini_cli_token_async(
495 refresh_token: &str,
496 client_id: Option<&str>,
497 client_secret: Option<&str>,
498) -> anyhow::Result<RefreshedToken> {
499 let refresh_token = refresh_token.to_string();
500 let client_id = client_id.map(str::to_string);
501 let client_secret = client_secret.map(str::to_string);
502 tokio::task::spawn_blocking(move || {
503 refresh_gemini_cli_token(
504 &refresh_token,
505 client_id.as_deref(),
506 client_secret.as_deref(),
507 )
508 })
509 .await
510 .map_err(|e| {
511 ::zeroclaw_log::record!(
512 ERROR,
513 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
514 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
515 .with_attrs(::serde_json::json!({
516 "oauth_provider": "gemini_cli",
517 "phase": "task_join",
518 "error": format!("{}", e),
519 })),
520 "gemini: token refresh task panicked"
521 );
522 anyhow::Error::msg(format!("Token refresh task panicked: {e}"))
523 })?
524}
525
526impl GeminiModelProvider {
527 pub fn new(alias: &str, api_key: Option<&str>) -> Self {
534 let oauth_cred_paths = Self::discover_oauth_cred_paths();
535 let resolved_auth = api_key
536 .and_then(Self::normalize_non_empty)
537 .map(GeminiAuth::ExplicitKey)
538 .or_else(|| {
539 Self::try_load_gemini_cli_token(oauth_cred_paths.first())
540 .map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))))
541 });
542
543 Self {
544 alias: alias.to_string(),
545 auth: resolved_auth,
546 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
547 oauth_project_seed: None,
548 oauth_cred_paths,
549 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
550 auth_service: None,
551 auth_profile_override: None,
552 oauth_client_id: None,
553 oauth_client_secret: None,
554 }
555 }
556 pub fn new_with_auth(
563 alias: &str,
564 api_key: Option<&str>,
565 auth_service: AuthService,
566 profile_override: Option<String>,
567 oauth_project_seed: Option<String>,
568 oauth_client_id: Option<String>,
569 oauth_client_secret: Option<String>,
570 ) -> Self {
571 let oauth_cred_paths = Self::discover_oauth_cred_paths();
572
573 let resolved_auth = api_key
575 .and_then(Self::normalize_non_empty)
576 .map(GeminiAuth::ExplicitKey);
577
578 let (auth, use_managed) = if resolved_auth.is_some() {
581 (resolved_auth, false)
582 } else {
583 let has_managed = std::thread::scope(|s| {
586 s.spawn(|| {
587 let rt = tokio::runtime::Builder::new_current_thread()
588 .enable_all()
589 .build()
590 .ok()?;
591 rt.block_on(async {
592 auth_service
593 .get_gemini_profile(profile_override.as_deref())
594 .await
595 .ok()
596 .flatten()
597 })
598 })
599 .join()
600 .ok()
601 .flatten()
602 .is_some()
603 });
604
605 if has_managed {
606 (Some(GeminiAuth::ManagedOAuth), true)
607 } else {
608 let cli_auth = Self::try_load_gemini_cli_token(oauth_cred_paths.first())
610 .map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))));
611 (cli_auth, false)
612 }
613 };
614
615 Self {
616 alias: alias.to_string(),
617 auth,
618 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
619 oauth_project_seed,
620 oauth_cred_paths,
621 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
622 auth_service: if use_managed {
623 Some(auth_service)
624 } else {
625 None
626 },
627 auth_profile_override: profile_override,
628 oauth_client_id,
629 oauth_client_secret,
630 }
631 }
632
633 fn normalize_non_empty(value: &str) -> Option<String> {
634 let trimmed = value.trim();
635 if trimmed.is_empty() {
636 None
637 } else {
638 Some(trimmed.to_string())
639 }
640 }
641
642 fn load_gemini_cli_creds(creds_path: &PathBuf) -> Option<GeminiCliOAuthCreds> {
643 if !creds_path.exists() {
644 return None;
645 }
646 let content = std::fs::read_to_string(creds_path).ok()?;
647 serde_json::from_str(&content).ok()
648 }
649
650 fn discover_oauth_cred_paths() -> Vec<PathBuf> {
655 let home = match UserDirs::new() {
656 Some(u) => u.home_dir().to_path_buf(),
657 None => return Vec::new(),
658 };
659
660 let mut paths = Vec::new();
661
662 let primary = home.join(".gemini").join("oauth_creds.json");
663 if primary.exists() {
664 paths.push(primary);
665 }
666
667 if let Ok(entries) = std::fs::read_dir(&home) {
668 let mut extras: Vec<PathBuf> = entries
669 .filter_map(|e| e.ok())
670 .filter_map(|e| {
671 let name = e.file_name().to_string_lossy().to_string();
672 if name.starts_with(".gemini-") && name.ends_with("-home") {
673 let path = e.path().join(".gemini").join("oauth_creds.json");
674 if path.exists() {
675 return Some(path);
676 }
677 }
678 None
679 })
680 .collect();
681 extras.sort();
682 paths.extend(extras);
683 }
684
685 paths
686 }
687
688 fn try_load_gemini_cli_token(path: Option<&PathBuf>) -> Option<OAuthTokenState> {
693 let creds = Self::load_gemini_cli_creds(path?)?;
694
695 let expiry_millis = creds.expiry_date.or_else(|| {
697 creds.expiry.as_deref().and_then(|expiry| {
698 chrono::DateTime::parse_from_rfc3339(expiry)
699 .ok()
700 .map(|dt| dt.timestamp_millis())
701 })
702 });
703
704 let access_token = creds
705 .access_token
706 .and_then(|token| Self::normalize_non_empty(&token))?;
707
708 let id_token_client_id = creds
709 .id_token
710 .as_deref()
711 .and_then(extract_client_id_from_id_token);
712
713 let client_id = creds
714 .client_id
715 .as_deref()
716 .and_then(Self::normalize_non_empty)
717 .or(id_token_client_id);
718 let client_secret = creds
719 .client_secret
720 .as_deref()
721 .and_then(Self::normalize_non_empty);
722
723 Some(OAuthTokenState {
724 access_token,
725 refresh_token: creds.refresh_token,
726 client_id,
727 client_secret,
728 expiry_millis,
729 })
730 }
731
732 pub fn has_cli_credentials() -> bool {
734 Self::discover_oauth_cred_paths().iter().any(|path| {
735 Self::load_gemini_cli_creds(path)
736 .and_then(|creds| {
737 creds
738 .access_token
739 .as_deref()
740 .and_then(Self::normalize_non_empty)
741 })
742 .is_some()
743 })
744 }
745
746 pub fn has_any_auth() -> bool {
751 Self::has_cli_credentials()
752 }
753
754 pub fn auth_source(&self) -> &'static str {
757 match self.auth.as_ref() {
758 Some(GeminiAuth::ExplicitKey(_)) => "config",
759 Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth",
760 Some(GeminiAuth::ManagedOAuth) => "auth-profiles",
761 None => "none",
762 }
763 }
764
765 async fn get_valid_oauth_token(
768 state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
769 ) -> anyhow::Result<String> {
770 let mut guard = state.lock().await;
771
772 let now_millis = std::time::SystemTime::now()
773 .duration_since(std::time::UNIX_EPOCH)
774 .ok()
775 .and_then(|d| i64::try_from(d.as_millis()).ok())
776 .unwrap_or(i64::MAX);
777
778 let needs_refresh = guard
780 .expiry_millis
781 .is_none_or(|exp| exp <= now_millis.saturating_add(60_000));
782
783 if needs_refresh {
784 if let Some(ref refresh_token) = guard.refresh_token {
785 let refreshed = refresh_gemini_cli_token_async(
786 refresh_token,
787 guard.client_id.as_deref(),
788 guard.client_secret.as_deref(),
789 )
790 .await?;
791 ::zeroclaw_log::record!(
792 INFO,
793 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
794 "Gemini CLI OAuth token refreshed successfully (runtime)"
795 );
796 guard.access_token = refreshed.access_token;
797 guard.expiry_millis = refreshed.expiry_millis;
798 } else {
799 anyhow::bail!(
800 "Gemini CLI OAuth token expired and no refresh_token available — re-run `gemini` to authenticate"
801 );
802 }
803 }
804
805 Ok(guard.access_token.clone())
806 }
807
808 async fn rotate_oauth_credential(
811 &self,
812 state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
813 ) -> bool {
814 if self.oauth_cred_paths.len() <= 1 {
815 return false;
816 }
817
818 let mut idx = self.oauth_index.lock().await;
819 let start = *idx;
820
821 loop {
822 let next = (*idx + 1) % self.oauth_cred_paths.len();
823 *idx = next;
824
825 if next == start {
826 return false;
827 }
828
829 if let Some(next_state) =
830 Self::try_load_gemini_cli_token(self.oauth_cred_paths.get(next))
831 {
832 {
833 let mut guard = state.lock().await;
834 *guard = next_state;
835 }
836 {
837 let mut cached_project = self.oauth_project.lock().await;
838 *cached_project = None;
839 }
840 ::zeroclaw_log::record!(
841 WARN,
842 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
843 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
844 &format!(
845 "Gemini OAuth: rotated credential to {}",
846 self.oauth_cred_paths[next].display().to_string()
847 )
848 );
849 return true;
850 }
851 }
852 }
853
854 fn format_model_name(model: &str) -> String {
855 if model.starts_with("models/") {
856 model.to_string()
857 } else {
858 format!("models/{model}")
859 }
860 }
861
862 fn format_internal_model_name(model: &str) -> String {
863 model.strip_prefix("models/").unwrap_or(model).to_string()
864 }
865
866 fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String {
876 match auth {
877 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => {
878 format!("{CLOUDCODE_PA_ENDPOINT}:generateContent")
881 }
882 _ => {
883 let model_name = Self::format_model_name(model);
884 let base_url = format!("{BASE_URL}/{model_name}:generateContent");
885
886 if auth.is_api_key() {
887 format!("{base_url}?key={}", auth.api_key_credential())
888 } else {
889 base_url
890 }
891 }
892 }
893 }
894
895 fn http_client(&self) -> Client {
896 zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
897 "model_provider.gemini",
898 120,
899 10,
900 )
901 }
902
903 async fn resolve_oauth_project(&self, token: &str) -> anyhow::Result<String> {
906 let project_seed = self.oauth_project_seed.clone();
907 let project_seed_for_request = project_seed.clone();
908 let duet_project_for_request = project_seed.clone();
909
910 {
912 let cached = self.oauth_project.lock().await;
913 if let Some(ref project) = *cached {
914 return Ok(project.clone());
915 }
916 }
917
918 let client = self.http_client();
920 let response = client
921 .post(LOAD_CODE_ASSIST_ENDPOINT)
922 .bearer_auth(token)
923 .json(&serde_json::json!({
924 "cloudaicompanionProject": project_seed_for_request,
925 "metadata": {
926 "ideType": "GEMINI_CLI",
927 "platform": "PLATFORM_UNSPECIFIED",
928 "pluginType": "GEMINI",
929 "duetProject": duet_project_for_request,
930 }
931 }))
932 .send()
933 .await?;
934
935 if !response.status().is_success() {
936 let status = response.status();
937 let body = response.text().await.unwrap_or_default();
938 if let Some(seed) = project_seed {
939 ::zeroclaw_log::record!(
940 WARN,
941 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
942 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
943 .with_attrs(::serde_json::json!({"status": status.to_string()})),
944 "loadCodeAssist failed (HTTP ); using oauth_project seed fallback"
945 );
946 return Ok(seed);
947 }
948 anyhow::bail!("loadCodeAssist failed (HTTP {status}): {body}");
949 }
950
951 #[derive(Deserialize)]
952 struct LoadCodeAssistResponse {
953 #[serde(rename = "cloudaicompanionProject")]
954 cloudaicompanion_project: Option<String>,
955 }
956
957 let result: LoadCodeAssistResponse = response.json().await?;
958 let project = result
959 .cloudaicompanion_project
960 .filter(|p| !p.trim().is_empty())
961 .or(project_seed)
962 .ok_or_else(|| {
963 ::zeroclaw_log::record!(
964 ERROR,
965 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
966 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
967 .with_attrs(::serde_json::json!({
968 "missing": "cloudaicompanionProject",
969 })),
970 "gemini: loadCodeAssist missing project context"
971 );
972 anyhow::Error::msg("loadCodeAssist response missing project context")
973 })?;
974
975 {
977 let mut cached = self.oauth_project.lock().await;
978 *cached = Some(project.clone());
979 }
980
981 Ok(project)
982 }
983
984 fn build_generate_content_request(
989 &self,
990 auth: &GeminiAuth,
991 url: &str,
992 request: &GenerateContentRequest,
993 model: &str,
994 include_generation_config: bool,
995 project: Option<&str>,
996 oauth_token: Option<&str>,
997 ) -> reqwest::RequestBuilder {
998 let req = self.http_client().post(url).json(request);
999 match auth {
1000 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => {
1001 let token = oauth_token.unwrap_or_default();
1002 let internal_request = InternalGenerateContentEnvelope {
1005 model: Self::format_internal_model_name(model),
1006 project: project.map(|value| value.to_string()),
1007 user_prompt_id: Some(uuid::Uuid::new_v4().to_string()),
1008 request: InternalGenerateContentRequest {
1009 contents: request.contents.clone(),
1010 system_instruction: request.system_instruction.clone(),
1011 generation_config: if include_generation_config {
1012 Some(request.generation_config.clone())
1013 } else {
1014 None
1015 },
1016 },
1017 };
1018 self.http_client()
1019 .post(url)
1020 .json(&internal_request)
1021 .bearer_auth(token)
1022 }
1023 _ => req,
1024 }
1025 }
1026
1027 fn should_retry_oauth_without_generation_config(
1028 status: reqwest::StatusCode,
1029 error_text: &str,
1030 ) -> bool {
1031 if status != reqwest::StatusCode::BAD_REQUEST {
1032 return false;
1033 }
1034
1035 error_text.contains("Unknown name \"generationConfig\"")
1036 || error_text.contains("Unknown name 'generationConfig'")
1037 || error_text.contains(r#"Unknown name \"generationConfig\""#)
1038 }
1039
1040 fn should_rotate_oauth_on_error(status: reqwest::StatusCode, error_text: &str) -> bool {
1041 status == reqwest::StatusCode::TOO_MANY_REQUESTS
1042 || status == reqwest::StatusCode::SERVICE_UNAVAILABLE
1043 || status.is_server_error()
1044 || error_text.contains("RESOURCE_EXHAUSTED")
1045 }
1046}
1047
1048impl GeminiModelProvider {
1049 fn build_chat_contents(
1050 messages: &[ChatMessage],
1051 tool_instructions: Option<&str>,
1052 ) -> (Vec<Content>, Option<Content>) {
1053 let mut system_parts: Vec<&str> = Vec::new();
1054 let mut contents: Vec<Content> = Vec::new();
1055 for msg in messages {
1056 match msg.role.as_str() {
1057 "system" => system_parts.push(&msg.content),
1058 "user" => contents.push(Content {
1059 role: Some("user".to_string()),
1060 parts: build_parts(&msg.content),
1061 }),
1062 "assistant" => contents.push(Content {
1063 role: Some("model".to_string()),
1064 parts: vec![Part::text(&msg.content)],
1065 }),
1066 _ => {}
1067 }
1068 }
1069 if let Some(instructions) = tool_instructions {
1070 system_parts.push(instructions);
1071 }
1072 let system_instruction = if system_parts.is_empty() {
1073 None
1074 } else {
1075 Some(Content {
1076 role: None,
1077 parts: vec![Part::text(system_parts.join("\n\n"))],
1078 })
1079 };
1080 (contents, system_instruction)
1081 }
1082
1083 async fn chat_with_history_full(
1084 &self,
1085 messages: &[ChatMessage],
1086 model: &str,
1087 temperature: Option<f64>,
1088 ) -> anyhow::Result<(String, Option<TokenUsage>)> {
1089 let (contents, system_instruction) = Self::build_chat_contents(messages, None);
1090 self.send_generate_content(contents, system_instruction, model, temperature)
1091 .await
1092 }
1093
1094 fn token_usage_from_metadata(usage: GeminiUsageMetadata) -> Option<TokenUsage> {
1095 if usage.prompt_token_count.is_none() && usage.candidates_token_count.is_none() {
1096 return None;
1097 }
1098 Some(TokenUsage {
1099 input_tokens: usage.prompt_token_count,
1100 output_tokens: usage.candidates_token_count,
1101 cached_input_tokens: None,
1102 })
1103 }
1104
1105 async fn send_generate_content(
1106 &self,
1107 contents: Vec<Content>,
1108 system_instruction: Option<Content>,
1109 model: &str,
1110 temperature: Option<f64>,
1111 ) -> anyhow::Result<(String, Option<TokenUsage>)> {
1112 let auth = self.auth.as_ref().ok_or_else(|| {
1113 ::zeroclaw_log::record!(
1114 ERROR,
1115 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
1116 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1117 .with_attrs(::serde_json::json!({"missing": "auth"})),
1118 "gemini: no auth configured"
1119 );
1120 anyhow::Error::msg(
1121 "Gemini API key not found. Options:\n\
1122 1. Set GEMINI_API_KEY env var\n\
1123 2. Run `gemini` CLI to authenticate (tokens will be reused)\n\
1124 3. Run `zeroclaw auth login --model-provider gemini`\n\
1125 4. Get an API key from https://aistudio.google.com/app/apikey\n\
1126 5. Run `zeroclaw quickstart --model-provider gemini --api-key <key>` to configure",
1127 )
1128 })?;
1129
1130 let oauth_state = match auth {
1131 GeminiAuth::OAuthToken(state) => Some(state.clone()),
1132 _ => None,
1133 };
1134
1135 let (mut oauth_token, mut project) = match auth {
1137 GeminiAuth::OAuthToken(state) => {
1138 let token = Self::get_valid_oauth_token(state).await?;
1139 let proj = self.resolve_oauth_project(&token).await?;
1140 (Some(token), Some(proj))
1141 }
1142 GeminiAuth::ManagedOAuth => {
1143 let auth_service = self.auth_service.as_ref().ok_or_else(|| {
1144 ::zeroclaw_log::record!(
1145 ERROR,
1146 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
1147 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1148 .with_attrs(::serde_json::json!({"missing": "auth_service"})),
1149 "gemini: ManagedOAuth requires auth_service"
1150 );
1151 anyhow::Error::msg("ManagedOAuth requires auth_service")
1152 })?;
1153 let token = auth_service
1154 .get_valid_gemini_access_token(
1155 self.auth_profile_override.as_deref(),
1156 self.oauth_client_id.as_deref().unwrap_or(""),
1157 self.oauth_client_secret.as_deref().unwrap_or(""),
1158 )
1159 .await?
1160 .ok_or_else(|| {
1161 ::zeroclaw_log::record!(
1162 ERROR,
1163 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
1164 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1165 .with_attrs(::serde_json::json!({"oauth_provider": "gemini"})),
1166 "gemini: auth profile not found"
1167 );
1168 anyhow::Error::msg(
1169 "Gemini auth profile not found. Run `zeroclaw auth login --model-provider gemini`.",
1170 )
1171 })?;
1172 let proj = self.resolve_oauth_project(&token).await?;
1173 (Some(token), Some(proj))
1174 }
1175 _ => (None, None),
1176 };
1177
1178 let request = GenerateContentRequest {
1179 contents,
1180 system_instruction,
1181 generation_config: GenerationConfig {
1182 temperature,
1183 max_output_tokens: 8192,
1184 },
1185 };
1186
1187 let url = Self::build_generate_content_url(model, auth);
1188
1189 let mut response = self
1190 .build_generate_content_request(
1191 auth,
1192 &url,
1193 &request,
1194 model,
1195 true,
1196 project.as_deref(),
1197 oauth_token.as_deref(),
1198 )
1199 .send()
1200 .await?;
1201
1202 if !response.status().is_success() {
1203 let status = response.status();
1204 let error_text = response.text().await.unwrap_or_default();
1205
1206 if auth.is_oauth() && Self::should_rotate_oauth_on_error(status, &error_text) {
1207 let can_retry = match auth {
1210 GeminiAuth::OAuthToken(_) => {
1211 if let Some(state) = oauth_state.as_ref() {
1212 self.rotate_oauth_credential(state).await
1213 } else {
1214 false
1215 }
1216 }
1217 GeminiAuth::ManagedOAuth => true, _ => false,
1219 };
1220
1221 if can_retry {
1222 let (new_token, new_project) = match auth {
1224 GeminiAuth::OAuthToken(state) => {
1225 let token = Self::get_valid_oauth_token(state).await?;
1226 let proj = self.resolve_oauth_project(&token).await?;
1227 (token, proj)
1228 }
1229 GeminiAuth::ManagedOAuth => {
1230 let auth_service = self.auth_service.as_ref().unwrap();
1231 let token = auth_service
1232 .get_valid_gemini_access_token(
1233 self.auth_profile_override.as_deref(),
1234 self.oauth_client_id.as_deref().unwrap_or(""),
1235 self.oauth_client_secret.as_deref().unwrap_or(""),
1236 )
1237 .await?
1238 .ok_or_else(|| {
1239 ::zeroclaw_log::record!(
1240 ERROR,
1241 ::zeroclaw_log::Event::new(
1242 module_path!(),
1243 ::zeroclaw_log::Action::Reject
1244 )
1245 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1246 .with_attrs(
1247 ::serde_json::json!({"oauth_provider": "gemini"})
1248 ),
1249 "gemini: auth profile not found"
1250 );
1251 anyhow::Error::msg("Gemini auth profile not found")
1252 })?;
1253 let proj = self.resolve_oauth_project(&token).await?;
1254 (token, proj)
1255 }
1256 _ => unreachable!(),
1257 };
1258 oauth_token = Some(new_token);
1259 project = Some(new_project);
1260 response = self
1261 .build_generate_content_request(
1262 auth,
1263 &url,
1264 &request,
1265 model,
1266 true,
1267 project.as_deref(),
1268 oauth_token.as_deref(),
1269 )
1270 .send()
1271 .await?;
1272 } else {
1273 anyhow::bail!("Gemini API error ({status}): {error_text}");
1274 }
1275 } else if auth.is_oauth()
1276 && Self::should_retry_oauth_without_generation_config(status, &error_text)
1277 {
1278 ::zeroclaw_log::record!(
1279 WARN,
1280 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1281 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
1282 "Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
1283 );
1284 response = self
1285 .build_generate_content_request(
1286 auth,
1287 &url,
1288 &request,
1289 model,
1290 false,
1291 project.as_deref(),
1292 oauth_token.as_deref(),
1293 )
1294 .send()
1295 .await?;
1296 } else {
1297 anyhow::bail!("Gemini API error ({status}): {error_text}");
1298 }
1299 }
1300
1301 if !response.status().is_success() {
1302 let status = response.status();
1303 let error_text = response.text().await.unwrap_or_default();
1304 if auth.is_oauth()
1305 && Self::should_retry_oauth_without_generation_config(status, &error_text)
1306 {
1307 ::zeroclaw_log::record!(
1308 WARN,
1309 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1310 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
1311 "Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
1312 );
1313 response = self
1314 .build_generate_content_request(
1315 auth,
1316 &url,
1317 &request,
1318 model,
1319 false,
1320 project.as_deref(),
1321 oauth_token.as_deref(),
1322 )
1323 .send()
1324 .await?;
1325 } else {
1326 anyhow::bail!("Gemini API error ({status}): {error_text}");
1327 }
1328 }
1329
1330 if !response.status().is_success() {
1331 let status = response.status();
1332 let error_text = response.text().await.unwrap_or_default();
1333 anyhow::bail!("Gemini API error ({status}): {error_text}");
1334 }
1335
1336 let result: GenerateContentResponse = response.json().await?;
1337 if let Some(err) = &result.error {
1338 anyhow::bail!("Gemini API error: {}", err.message);
1339 }
1340 let result = result.into_effective_response();
1341 if let Some(err) = result.error {
1342 anyhow::bail!("Gemini API error: {}", err.message);
1343 }
1344
1345 let usage = result
1346 .usage_metadata
1347 .and_then(Self::token_usage_from_metadata);
1348
1349 let text = result
1350 .candidates
1351 .and_then(|c| c.into_iter().next())
1352 .and_then(|c| c.content)
1353 .and_then(|c| c.effective_text())
1354 .ok_or_else(|| {
1355 ::zeroclaw_log::record!(
1356 ERROR,
1357 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1358 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
1359 "gemini: empty response text"
1360 );
1361 anyhow::Error::msg("No response from Gemini")
1362 })?;
1363
1364 Ok((text, usage))
1365 }
1366}
1367
1368#[async_trait]
1369impl ModelProvider for GeminiModelProvider {
1370 fn default_base_url(&self) -> Option<&str> {
1372 Some(BASE_URL)
1373 }
1374
1375 fn capabilities(&self) -> zeroclaw_api::model_provider::ProviderCapabilities {
1376 zeroclaw_api::model_provider::ProviderCapabilities {
1377 vision: true,
1378 native_tool_calling: false,
1379 prompt_caching: false,
1380 extended_thinking: false,
1381 }
1382 }
1383
1384 async fn chat_with_system(
1385 &self,
1386 system_prompt: Option<&str>,
1387 message: &str,
1388 model: &str,
1389 temperature: Option<f64>,
1390 ) -> anyhow::Result<String> {
1391 let system_instruction = system_prompt.map(|sys| Content {
1392 role: None,
1393 parts: vec![Part::text(sys)],
1394 });
1395
1396 let contents = vec![Content {
1397 role: Some("user".to_string()),
1398 parts: build_parts(message),
1399 }];
1400
1401 let (text, _usage) = self
1402 .send_generate_content(contents, system_instruction, model, temperature)
1403 .await?;
1404 Ok(text)
1405 }
1406
1407 async fn chat_with_history(
1408 &self,
1409 messages: &[ChatMessage],
1410 model: &str,
1411 temperature: Option<f64>,
1412 ) -> anyhow::Result<String> {
1413 let (text, _usage) = self
1414 .chat_with_history_full(messages, model, temperature)
1415 .await?;
1416 Ok(text)
1417 }
1418
1419 async fn chat(
1420 &self,
1421 request: ProviderChatRequest<'_>,
1422 model: &str,
1423 temperature: Option<f64>,
1424 ) -> anyhow::Result<ProviderChatResponse> {
1425 let tool_instructions = if let Some(tools) = request.tools
1426 && !tools.is_empty()
1427 && !self.supports_native_tools()
1428 {
1429 Some(match self.convert_tools(tools) {
1430 ToolsPayload::PromptGuided { instructions } => instructions,
1431 payload => {
1432 anyhow::bail!(
1433 "Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
1434 )
1435 }
1436 })
1437 } else {
1438 None
1439 };
1440 let (contents, system_instruction) =
1441 Self::build_chat_contents(request.messages, tool_instructions.as_deref());
1442 let (text, usage) = self
1443 .send_generate_content(contents, system_instruction, model, temperature)
1444 .await?;
1445 Ok(ProviderChatResponse {
1446 text: Some(text),
1447 tool_calls: Vec::new(),
1448 usage,
1449 reasoning_content: None,
1450 })
1451 }
1452
1453 async fn warmup(&self) -> anyhow::Result<()> {
1454 if let Some(auth) = self.auth.as_ref() {
1455 match auth {
1456 GeminiAuth::ManagedOAuth => {
1457 let auth_service = self.auth_service.as_ref().ok_or_else(|| {
1460 ::zeroclaw_log::record!(
1461 ERROR,
1462 ::zeroclaw_log::Event::new(
1463 module_path!(),
1464 ::zeroclaw_log::Action::Reject
1465 )
1466 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1467 .with_attrs(::serde_json::json!({"missing": "auth_service"})),
1468 "gemini: ManagedOAuth requires auth_service"
1469 );
1470 anyhow::Error::msg("ManagedOAuth requires auth_service")
1471 })?;
1472
1473 let _token = auth_service
1474 .get_valid_gemini_access_token(
1475 self.auth_profile_override.as_deref(),
1476 self.oauth_client_id.as_deref().unwrap_or(""),
1477 self.oauth_client_secret.as_deref().unwrap_or(""),
1478 )
1479 .await?
1480 .ok_or_else(|| {
1481 ::zeroclaw_log::record!(
1482 ERROR,
1483 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
1484 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1485 .with_attrs(::serde_json::json!({"oauth_provider": "gemini"})),
1486 "gemini: auth profile not found or expired"
1487 );
1488 anyhow::Error::msg(
1489 "Gemini auth profile not found or expired. Run: zeroclaw auth login --model-provider gemini",
1490 )
1491 })?;
1492
1493 }
1497 GeminiAuth::OAuthToken(_) => {
1498 }
1501 _ => {
1502 let url = if auth.is_api_key() {
1504 format!(
1505 "https://generativelanguage.googleapis.com/v1beta/models?key={}",
1506 auth.api_key_credential()
1507 )
1508 } else {
1509 "https://generativelanguage.googleapis.com/v1beta/models".to_string()
1510 };
1511
1512 self.http_client()
1513 .get(&url)
1514 .send()
1515 .await?
1516 .error_for_status()?;
1517 }
1518 }
1519 }
1520 Ok(())
1521 }
1522
1523 async fn list_models(&self) -> anyhow::Result<Vec<String>> {
1524 crate::models_dev::list_models_for("google").await
1527 }
1528}
1529
1530impl ::zeroclaw_api::attribution::Attributable for GeminiModelProvider {
1531 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1532 ::zeroclaw_api::attribution::Role::Provider(
1533 ::zeroclaw_api::attribution::ProviderKind::Model(
1534 ::zeroclaw_api::attribution::ModelProviderKind::Gemini,
1535 ),
1536 )
1537 }
1538 fn alias(&self) -> &str {
1539 &self.alias
1540 }
1541}
1542
1543#[cfg(test)]
1544mod tests {
1545 use super::*;
1546 use reqwest::{StatusCode, header::AUTHORIZATION};
1547
1548 fn test_oauth_auth(token: &str) -> GeminiAuth {
1550 GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
1551 access_token: token.to_string(),
1552 refresh_token: None,
1553 client_id: None,
1554 client_secret: None,
1555 expiry_millis: None,
1556 })))
1557 }
1558
1559 fn test_model_provider(auth: Option<GeminiAuth>) -> GeminiModelProvider {
1560 GeminiModelProvider {
1561 alias: "test".to_string(),
1562 auth,
1563 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
1564 oauth_project_seed: None,
1565 oauth_cred_paths: Vec::new(),
1566 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
1567 auth_service: None,
1568 auth_profile_override: None,
1569 oauth_client_id: None,
1570 oauth_client_secret: None,
1571 }
1572 }
1573
1574 #[test]
1575 fn normalize_non_empty_trims_and_filters() {
1576 assert_eq!(
1577 GeminiModelProvider::normalize_non_empty(" value "),
1578 Some("value".into())
1579 );
1580 assert_eq!(GeminiModelProvider::normalize_non_empty(""), None);
1581 assert_eq!(GeminiModelProvider::normalize_non_empty(" \t\n"), None);
1582 }
1583
1584 #[test]
1585 fn oauth_refresh_form_uses_provided_client_credentials() {
1586 let form = build_oauth_refresh_form("refresh-token", Some("client-id"), Some("secret"));
1587 let map: std::collections::HashMap<_, _> = form.into_iter().collect();
1588 assert_eq!(map.get("grant_type"), Some(&"refresh_token".to_string()));
1589 assert_eq!(map.get("refresh_token"), Some(&"refresh-token".to_string()));
1590 assert_eq!(map.get("client_id"), Some(&"client-id".to_string()));
1591 assert_eq!(map.get("client_secret"), Some(&"secret".to_string()));
1592 }
1593
1594 #[test]
1595 fn oauth_refresh_form_omits_client_credentials_when_missing() {
1596 let form = build_oauth_refresh_form("refresh-token", None, None);
1597 let map: std::collections::HashMap<_, _> = form.into_iter().collect();
1598 assert!(!map.contains_key("client_id"));
1599 assert!(!map.contains_key("client_secret"));
1600 }
1601
1602 #[test]
1603 fn extract_client_id_from_id_token_prefers_aud_claim() {
1604 let payload = serde_json::json!({
1605 "aud": "aud-client-id",
1606 "azp": "azp-client-id"
1607 });
1608 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1609 .encode(serde_json::to_vec(&payload).unwrap());
1610 let token = format!("header.{payload_b64}.sig");
1611
1612 assert_eq!(
1613 extract_client_id_from_id_token(&token),
1614 Some("aud-client-id".to_string())
1615 );
1616 }
1617
1618 #[test]
1619 fn extract_client_id_from_id_token_uses_azp_when_aud_missing() {
1620 let payload = serde_json::json!({
1621 "azp": "azp-client-id"
1622 });
1623 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1624 .encode(serde_json::to_vec(&payload).unwrap());
1625 let token = format!("header.{payload_b64}.sig");
1626
1627 assert_eq!(
1628 extract_client_id_from_id_token(&token),
1629 Some("azp-client-id".to_string())
1630 );
1631 }
1632
1633 #[test]
1634 fn extract_client_id_from_id_token_returns_none_for_invalid_tokens() {
1635 assert_eq!(extract_client_id_from_id_token("invalid"), None);
1636 assert_eq!(extract_client_id_from_id_token("a.b.c"), None);
1637 }
1638
1639 #[test]
1640 fn try_load_cli_token_derives_client_id_from_id_token_when_missing() {
1641 let payload = serde_json::json!({ "aud": "derived-client-id" });
1642 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1643 .encode(serde_json::to_vec(&payload).unwrap());
1644 let id_token = format!("header.{payload_b64}.sig");
1645
1646 let file = tempfile::NamedTempFile::new().unwrap();
1647 let json = format!(
1648 r#"{{
1649 "access_token": "ya29.test-access",
1650 "refresh_token": "1//test-refresh",
1651 "id_token": "{id_token}"
1652 }}"#
1653 );
1654 std::fs::write(file.path(), json).unwrap();
1655
1656 let path = file.path().to_path_buf();
1657 let state = GeminiModelProvider::try_load_gemini_cli_token(Some(&path)).unwrap();
1658 assert_eq!(state.client_id.as_deref(), Some("derived-client-id"));
1659 assert_eq!(state.client_secret, None);
1660 }
1661
1662 #[test]
1663 fn provider_creates_without_key() {
1664 let model_provider = GeminiModelProvider::new("test", None);
1665 let _ = model_provider.auth_source();
1667 }
1668
1669 #[test]
1670 fn provider_creates_with_key() {
1671 let model_provider = GeminiModelProvider::new("test", Some("test-api-key"));
1672 assert!(matches!(
1673 model_provider.auth,
1674 Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key"
1675 ));
1676 }
1677
1678 #[test]
1679 fn provider_rejects_empty_key() {
1680 let model_provider = GeminiModelProvider::new("test", Some(""));
1681 assert!(!matches!(
1682 model_provider.auth,
1683 Some(GeminiAuth::ExplicitKey(_))
1684 ));
1685 }
1686
1687 #[test]
1688 fn auth_source_explicit_key() {
1689 let model_provider = test_model_provider(Some(GeminiAuth::ExplicitKey("key".into())));
1690 assert_eq!(model_provider.auth_source(), "config");
1691 }
1692
1693 #[test]
1694 fn auth_source_none_without_credentials() {
1695 let model_provider = test_model_provider(None);
1696 assert_eq!(model_provider.auth_source(), "none");
1697 }
1698
1699 #[test]
1700 fn auth_source_oauth() {
1701 let model_provider = test_model_provider(Some(test_oauth_auth("ya29.mock")));
1702 assert_eq!(model_provider.auth_source(), "Gemini CLI OAuth");
1703 }
1704
1705 #[test]
1706 fn model_name_formatting() {
1707 assert_eq!(
1708 GeminiModelProvider::format_model_name("gemini-2.0-flash"),
1709 "models/gemini-2.0-flash"
1710 );
1711 assert_eq!(
1712 GeminiModelProvider::format_model_name("models/gemini-1.5-pro"),
1713 "models/gemini-1.5-pro"
1714 );
1715 assert_eq!(
1716 GeminiModelProvider::format_internal_model_name("models/gemini-2.5-flash"),
1717 "gemini-2.5-flash"
1718 );
1719 assert_eq!(
1720 GeminiModelProvider::format_internal_model_name("gemini-2.5-flash"),
1721 "gemini-2.5-flash"
1722 );
1723 }
1724
1725 #[test]
1726 fn api_key_url_includes_key_query_param() {
1727 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1728 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1729 assert!(url.contains(":generateContent?key=api-key-123"));
1730 }
1731
1732 #[test]
1733 fn oauth_url_uses_internal_endpoint() {
1734 let auth = test_oauth_auth("ya29.test-token");
1735 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1736 assert!(url.starts_with("https://cloudcode-pa.googleapis.com/v1internal"));
1737 assert!(url.ends_with(":generateContent"));
1738 assert!(!url.contains("generativelanguage.googleapis.com"));
1739 assert!(!url.contains("?key="));
1740 }
1741
1742 #[test]
1743 fn api_key_url_uses_public_endpoint() {
1744 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1745 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1746 assert!(url.contains("generativelanguage.googleapis.com/v1beta"));
1747 assert!(url.contains("models/gemini-2.0-flash"));
1748 }
1749
1750 #[test]
1751 fn oauth_request_uses_bearer_auth_header() {
1752 let model_provider = test_model_provider(Some(test_oauth_auth("ya29.mock-token")));
1753 let auth = test_oauth_auth("ya29.mock-token");
1754 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1755 let body = GenerateContentRequest {
1756 contents: vec![Content {
1757 role: Some("user".into()),
1758 parts: vec![Part::text("hello")],
1759 }],
1760 system_instruction: None,
1761 generation_config: GenerationConfig {
1762 temperature: Some(0.7),
1763 max_output_tokens: 8192,
1764 },
1765 };
1766
1767 let request = model_provider
1768 .build_generate_content_request(
1769 &auth,
1770 &url,
1771 &body,
1772 "gemini-2.0-flash",
1773 true,
1774 Some("test-project"),
1775 Some("ya29.mock-token"),
1776 )
1777 .build()
1778 .unwrap();
1779
1780 assert_eq!(
1781 request
1782 .headers()
1783 .get(AUTHORIZATION)
1784 .and_then(|h| h.to_str().ok()),
1785 Some("Bearer ya29.mock-token")
1786 );
1787 }
1788
1789 #[test]
1790 fn oauth_request_wraps_payload_in_request_envelope() {
1791 let model_provider = test_model_provider(Some(test_oauth_auth("ya29.mock-token")));
1792 let auth = test_oauth_auth("ya29.mock-token");
1793 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1794 let body = GenerateContentRequest {
1795 contents: vec![Content {
1796 role: Some("user".into()),
1797 parts: vec![Part::text("hello")],
1798 }],
1799 system_instruction: None,
1800 generation_config: GenerationConfig {
1801 temperature: Some(0.7),
1802 max_output_tokens: 8192,
1803 },
1804 };
1805
1806 let request = model_provider
1807 .build_generate_content_request(
1808 &auth,
1809 &url,
1810 &body,
1811 "models/gemini-2.0-flash",
1812 true,
1813 Some("test-project"),
1814 Some("ya29.mock-token"),
1815 )
1816 .build()
1817 .unwrap();
1818
1819 let payload = request
1820 .body()
1821 .and_then(|b| b.as_bytes())
1822 .expect("json request body should be bytes");
1823 let json: serde_json::Value = serde_json::from_slice(payload).unwrap();
1824
1825 assert_eq!(json["model"], "gemini-2.0-flash");
1826 assert!(json.get("generationConfig").is_none());
1827 assert!(json.get("request").is_some());
1828 assert!(json["request"].get("generationConfig").is_some());
1829 }
1830
1831 #[test]
1832 fn api_key_request_does_not_set_bearer_header() {
1833 let model_provider =
1834 test_model_provider(Some(GeminiAuth::ExplicitKey("api-key-123".into())));
1835 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1836 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1837 let body = GenerateContentRequest {
1838 contents: vec![Content {
1839 role: Some("user".into()),
1840 parts: vec![Part::text("hello")],
1841 }],
1842 system_instruction: None,
1843 generation_config: GenerationConfig {
1844 temperature: Some(0.7),
1845 max_output_tokens: 8192,
1846 },
1847 };
1848
1849 let request = model_provider
1850 .build_generate_content_request(
1851 &auth,
1852 &url,
1853 &body,
1854 "gemini-2.0-flash",
1855 true,
1856 None,
1857 None,
1858 )
1859 .build()
1860 .unwrap();
1861
1862 assert!(request.headers().get(AUTHORIZATION).is_none());
1863 }
1864
1865 #[test]
1866 fn request_serialization() {
1867 let request = GenerateContentRequest {
1868 contents: vec![Content {
1869 role: Some("user".to_string()),
1870 parts: vec![Part::text("Hello")],
1871 }],
1872 system_instruction: Some(Content {
1873 role: None,
1874 parts: vec![Part::text("You are helpful")],
1875 }),
1876 generation_config: GenerationConfig {
1877 temperature: Some(0.7),
1878 max_output_tokens: 8192,
1879 },
1880 };
1881
1882 let json = serde_json::to_string(&request).unwrap();
1883 assert!(json.contains("\"role\":\"user\""));
1884 assert!(json.contains("\"text\":\"Hello\""));
1885 assert!(json.contains("\"systemInstruction\""));
1886 assert!(!json.contains("\"system_instruction\""));
1887 assert!(json.contains("\"temperature\":0.7"));
1888 assert!(json.contains("\"maxOutputTokens\":8192"));
1889 }
1890
1891 #[test]
1892 fn internal_request_includes_model() {
1893 let request = InternalGenerateContentEnvelope {
1894 model: "gemini-3-pro-preview".to_string(),
1895 project: Some("test-project".to_string()),
1896 user_prompt_id: Some("prompt-123".to_string()),
1897 request: InternalGenerateContentRequest {
1898 contents: vec![Content {
1899 role: Some("user".to_string()),
1900 parts: vec![Part::text("Hello")],
1901 }],
1902 system_instruction: None,
1903 generation_config: Some(GenerationConfig {
1904 temperature: Some(0.7),
1905 max_output_tokens: 8192,
1906 }),
1907 },
1908 };
1909
1910 let json = serde_json::to_string(&request).unwrap();
1911 assert!(json.contains("\"model\":\"gemini-3-pro-preview\""));
1912 assert!(json.contains("\"request\""));
1913 assert!(json.contains("\"generationConfig\""));
1914 assert!(json.contains("\"maxOutputTokens\":8192"));
1915 assert!(json.contains("\"user_prompt_id\":\"prompt-123\""));
1916 assert!(json.contains("\"project\":\"test-project\""));
1917 assert!(json.contains("\"role\":\"user\""));
1918 assert!(json.contains("\"temperature\":0.7"));
1919 }
1920
1921 #[test]
1922 fn internal_request_omits_generation_config_when_none() {
1923 let request = InternalGenerateContentEnvelope {
1924 model: "gemini-3-pro-preview".to_string(),
1925 project: Some("test-project".to_string()),
1926 user_prompt_id: None,
1927 request: InternalGenerateContentRequest {
1928 contents: vec![Content {
1929 role: Some("user".to_string()),
1930 parts: vec![Part::text("Hello")],
1931 }],
1932 system_instruction: None,
1933 generation_config: None,
1934 },
1935 };
1936
1937 let json = serde_json::to_string(&request).unwrap();
1938 assert!(!json.contains("generationConfig"));
1939 assert!(json.contains("\"model\":\"gemini-3-pro-preview\""));
1940 }
1941
1942 #[test]
1943 fn internal_request_includes_project() {
1944 let request = InternalGenerateContentEnvelope {
1945 model: "gemini-2.5-flash".to_string(),
1946 project: Some("my-gcp-project-id".to_string()),
1947 user_prompt_id: None,
1948 request: InternalGenerateContentRequest {
1949 contents: vec![Content {
1950 role: Some("user".to_string()),
1951 parts: vec![Part::text("Hello")],
1952 }],
1953 system_instruction: None,
1954 generation_config: None,
1955 },
1956 };
1957
1958 let json = serde_json::to_string(&request).unwrap();
1959 assert!(json.contains("\"project\":\"my-gcp-project-id\""));
1960 }
1961
1962 #[test]
1963 fn creds_deserialize_with_expiry_date() {
1964 let json = r#"{
1965 "access_token": "ya29.test-token",
1966 "refresh_token": "1//test-refresh",
1967 "expiry_date": 4102444800000
1968 }"#;
1969
1970 let creds: GeminiCliOAuthCreds = serde_json::from_str(json).unwrap();
1971 assert_eq!(creds.access_token.as_deref(), Some("ya29.test-token"));
1972 assert_eq!(creds.refresh_token.as_deref(), Some("1//test-refresh"));
1973 assert_eq!(creds.expiry_date, Some(4_102_444_800_000));
1974 assert!(creds.expiry.is_none());
1975 }
1976
1977 #[test]
1978 fn creds_deserialize_accepts_camel_case_fields() {
1979 let json = r#"{
1980 "access_token": "ya29.test-token",
1981 "idToken": "header.payload.sig",
1982 "refresh_token": "1//test-refresh",
1983 "clientId": "test-client-id",
1984 "clientSecret": "test-client-secret",
1985 "expiryDate": 4102444800000
1986 }"#;
1987
1988 let creds: GeminiCliOAuthCreds = serde_json::from_str(json).unwrap();
1989 assert_eq!(creds.id_token.as_deref(), Some("header.payload.sig"));
1990 assert_eq!(creds.client_id.as_deref(), Some("test-client-id"));
1991 assert_eq!(creds.client_secret.as_deref(), Some("test-client-secret"));
1992 assert_eq!(creds.expiry_date, Some(4_102_444_800_000));
1993 }
1994
1995 #[test]
1996 fn oauth_retry_detection_for_generation_config_rejection() {
1997 let err =
1999 "Invalid JSON payload received. Unknown name \"generationConfig\": Cannot find field.";
2000 assert!(
2001 GeminiModelProvider::should_retry_oauth_without_generation_config(
2002 StatusCode::BAD_REQUEST,
2003 err
2004 )
2005 );
2006 let err_json = r#"Invalid JSON payload received. Unknown name \"generationConfig\": Cannot find field."#;
2008 assert!(
2009 GeminiModelProvider::should_retry_oauth_without_generation_config(
2010 StatusCode::BAD_REQUEST,
2011 err_json
2012 )
2013 );
2014 assert!(
2015 !GeminiModelProvider::should_retry_oauth_without_generation_config(
2016 StatusCode::UNAUTHORIZED,
2017 err
2018 )
2019 );
2020 assert!(
2021 !GeminiModelProvider::should_retry_oauth_without_generation_config(
2022 StatusCode::BAD_REQUEST,
2023 "something else"
2024 )
2025 );
2026 }
2027
2028 #[test]
2029 fn response_deserialization() {
2030 let json = r#"{
2031 "candidates": [{
2032 "content": {
2033 "parts": [{"text": "Hello there!"}]
2034 }
2035 }]
2036 }"#;
2037
2038 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2039 assert!(response.candidates.is_some());
2040 let text = response
2041 .candidates
2042 .unwrap()
2043 .into_iter()
2044 .next()
2045 .unwrap()
2046 .content
2047 .unwrap()
2048 .parts
2049 .into_iter()
2050 .next()
2051 .unwrap()
2052 .text;
2053 assert_eq!(text, Some("Hello there!".to_string()));
2054 }
2055
2056 #[test]
2057 fn error_response_deserialization() {
2058 let json = r#"{
2059 "error": {
2060 "message": "Invalid API key"
2061 }
2062 }"#;
2063
2064 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2065 assert!(response.error.is_some());
2066 assert_eq!(response.error.unwrap().message, "Invalid API key");
2067 }
2068
2069 #[test]
2070 fn internal_response_deserialization() {
2071 let json = r#"{
2072 "response": {
2073 "candidates": [{
2074 "content": {
2075 "parts": [{"text": "Hello from internal"}]
2076 }
2077 }]
2078 }
2079 }"#;
2080
2081 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2082 let text = response
2083 .into_effective_response()
2084 .candidates
2085 .unwrap()
2086 .into_iter()
2087 .next()
2088 .unwrap()
2089 .content
2090 .unwrap()
2091 .parts
2092 .into_iter()
2093 .next()
2094 .unwrap()
2095 .text;
2096 assert_eq!(text, Some("Hello from internal".to_string()));
2097 }
2098
2099 #[test]
2102 fn thinking_response_extracts_non_thinking_text() {
2103 let json = r#"{
2104 "candidates": [{
2105 "content": {
2106 "parts": [
2107 {"thought": true, "text": "Let me think about this..."},
2108 {"text": "The answer is 42."},
2109 {"thoughtSignature": "c2lnbmF0dXJl"}
2110 ]
2111 }
2112 }]
2113 }"#;
2114
2115 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2116 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2117 let text = candidate.content.unwrap().effective_text();
2118 assert_eq!(text, Some("The answer is 42.".to_string()));
2119 }
2120
2121 #[test]
2122 fn non_thinking_response_unaffected() {
2123 let json = r#"{
2124 "candidates": [{
2125 "content": {
2126 "parts": [{"text": "Hello there!"}]
2127 }
2128 }]
2129 }"#;
2130
2131 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2132 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2133 let text = candidate.content.unwrap().effective_text();
2134 assert_eq!(text, Some("Hello there!".to_string()));
2135 }
2136
2137 #[test]
2138 fn thinking_only_response_falls_back_to_thinking_text() {
2139 let json = r#"{
2140 "candidates": [{
2141 "content": {
2142 "parts": [
2143 {"thought": true, "text": "I need more context..."},
2144 {"thoughtSignature": "c2lnbmF0dXJl"}
2145 ]
2146 }
2147 }]
2148 }"#;
2149
2150 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2151 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2152 let text = candidate.content.unwrap().effective_text();
2153 assert_eq!(text, Some("I need more context...".to_string()));
2154 }
2155
2156 #[test]
2157 fn empty_parts_returns_none() {
2158 let json = r#"{
2159 "candidates": [{
2160 "content": {
2161 "parts": []
2162 }
2163 }]
2164 }"#;
2165
2166 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2167 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2168 let text = candidate.content.unwrap().effective_text();
2169 assert_eq!(text, None);
2170 }
2171
2172 #[test]
2173 fn multiple_text_parts_concatenated() {
2174 let json = r#"{
2175 "candidates": [{
2176 "content": {
2177 "parts": [
2178 {"text": "Part one. "},
2179 {"text": "Part two."}
2180 ]
2181 }
2182 }]
2183 }"#;
2184
2185 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2186 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2187 let text = candidate.content.unwrap().effective_text();
2188 assert_eq!(text, Some("Part one. Part two.".to_string()));
2189 }
2190
2191 #[test]
2192 fn thought_signature_only_parts_skipped() {
2193 let json = r#"{
2194 "candidates": [{
2195 "content": {
2196 "parts": [
2197 {"thoughtSignature": "c2lnbmF0dXJl"}
2198 ]
2199 }
2200 }]
2201 }"#;
2202
2203 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2204 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2205 let text = candidate.content.unwrap().effective_text();
2206 assert_eq!(text, None);
2207 }
2208
2209 #[test]
2210 fn internal_response_thinking_model() {
2211 let json = r#"{
2212 "response": {
2213 "candidates": [{
2214 "content": {
2215 "parts": [
2216 {"thought": true, "text": "reasoning..."},
2217 {"text": "final answer"}
2218 ]
2219 }
2220 }]
2221 }
2222 }"#;
2223
2224 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2225 let effective = response.into_effective_response();
2226 let candidate = effective.candidates.unwrap().into_iter().next().unwrap();
2227 let text = candidate.content.unwrap().effective_text();
2228 assert_eq!(text, Some("final answer".to_string()));
2229 }
2230
2231 #[tokio::test]
2232 async fn warmup_without_key_is_noop() {
2233 let model_provider = test_model_provider(None);
2234 let result = model_provider.warmup().await;
2235 assert!(result.is_ok());
2236 }
2237
2238 #[tokio::test]
2239 async fn warmup_oauth_is_noop() {
2240 let model_provider = test_model_provider(Some(test_oauth_auth("ya29.mock-token")));
2241 let result = model_provider.warmup().await;
2242 assert!(result.is_ok());
2243 }
2244
2245 #[test]
2246 fn discover_oauth_cred_paths_does_not_panic() {
2247 let _paths = GeminiModelProvider::discover_oauth_cred_paths();
2248 }
2249
2250 #[tokio::test]
2251 async fn rotate_oauth_without_alternatives_returns_false() {
2252 let state = Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
2253 access_token: "ya29.mock".to_string(),
2254 refresh_token: None,
2255 client_id: None,
2256 client_secret: None,
2257 expiry_millis: None,
2258 }));
2259 let model_provider = test_model_provider(Some(GeminiAuth::OAuthToken(state.clone())));
2260 assert!(!model_provider.rotate_oauth_credential(&state).await);
2261 }
2262
2263 #[test]
2264 fn response_parses_usage_metadata() {
2265 let json = r#"{
2266 "candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
2267 "usageMetadata": {"promptTokenCount": 120, "candidatesTokenCount": 40}
2268 }"#;
2269 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2270 let usage = resp.usage_metadata.unwrap();
2271 assert_eq!(usage.prompt_token_count, Some(120));
2272 assert_eq!(usage.candidates_token_count, Some(40));
2273 }
2274
2275 #[test]
2276 fn response_usage_metadata_maps_to_token_usage() {
2277 let usage = GeminiUsageMetadata {
2278 prompt_token_count: Some(120),
2279 candidates_token_count: Some(40),
2280 };
2281
2282 let token_usage =
2283 GeminiModelProvider::token_usage_from_metadata(usage).expect("usage counts should map");
2284
2285 assert_eq!(token_usage.input_tokens, Some(120));
2286 assert_eq!(token_usage.output_tokens, Some(40));
2287 assert_eq!(token_usage.cached_input_tokens, None);
2288 }
2289
2290 #[test]
2291 fn empty_usage_metadata_maps_to_none() {
2292 let usage = GeminiUsageMetadata {
2293 prompt_token_count: None,
2294 candidates_token_count: None,
2295 };
2296
2297 assert!(GeminiModelProvider::token_usage_from_metadata(usage).is_none());
2298 }
2299
2300 #[test]
2301 fn wrapped_response_preserves_outer_usage_metadata() {
2302 let json = r#"{
2303 "usageMetadata": {"promptTokenCount": 120, "candidatesTokenCount": 40},
2304 "response": {
2305 "candidates": [{"content": {"parts": [{"text": "Hello"}]}}]
2306 }
2307 }"#;
2308
2309 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2310 let effective = resp.into_effective_response();
2311 let usage = effective.usage_metadata.unwrap();
2312
2313 assert_eq!(usage.prompt_token_count, Some(120));
2314 assert_eq!(usage.candidates_token_count, Some(40));
2315 }
2316
2317 #[test]
2318 fn wrapped_response_prefers_inner_usage_metadata() {
2319 let json = r#"{
2320 "usageMetadata": {"promptTokenCount": 120, "candidatesTokenCount": 40},
2321 "response": {
2322 "candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
2323 "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
2324 }
2325 }"#;
2326
2327 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2328 let effective = resp.into_effective_response();
2329 let usage = effective.usage_metadata.unwrap();
2330
2331 assert_eq!(usage.prompt_token_count, Some(5));
2332 assert_eq!(usage.candidates_token_count, Some(2));
2333 }
2334
2335 #[test]
2336 fn response_parses_without_usage_metadata() {
2337 let json = r#"{"candidates": [{"content": {"parts": [{"text": "Hello"}]}}]}"#;
2338 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2339 assert!(resp.usage_metadata.is_none());
2340 }
2341
2342 #[tokio::test]
2344 async fn warmup_managed_oauth_requires_auth_service() {
2345 let model_provider = GeminiModelProvider {
2346 alias: "test".to_string(),
2347 auth: Some(GeminiAuth::ManagedOAuth),
2348 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
2349 oauth_project_seed: None,
2350 oauth_cred_paths: Vec::new(),
2351 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
2352 auth_service: None, auth_profile_override: None,
2354 oauth_client_id: None,
2355 oauth_client_secret: None,
2356 };
2357
2358 let result = model_provider.warmup().await;
2359 assert!(result.is_err());
2360 assert!(
2361 result
2362 .unwrap_err()
2363 .to_string()
2364 .contains("ManagedOAuth requires auth_service")
2365 );
2366 }
2367
2368 #[tokio::test]
2370 async fn warmup_cli_oauth_skips_validation() {
2371 let model_provider = test_model_provider(Some(test_oauth_auth("fake_token")));
2372 let result = model_provider.warmup().await;
2373 assert!(result.is_ok());
2375 }
2376
2377 #[test]
2380 fn part_text_serializes_as_text_object() {
2381 let part = Part::text("hello");
2382 let json = serde_json::to_value(&part).unwrap();
2383 assert_eq!(json, serde_json::json!({"text": "hello"}));
2384 }
2385
2386 #[test]
2387 fn part_inline_serializes_as_inline_data_object() {
2388 let part = Part::Inline {
2389 inline_data: InlineData {
2390 mime_type: "image/png".to_string(),
2391 data: "iVBOR...".to_string(),
2392 },
2393 };
2394 let json = serde_json::to_value(&part).unwrap();
2395 assert_eq!(
2396 json,
2397 serde_json::json!({"inline_data": {"mime_type": "image/png", "data": "iVBOR..."}})
2398 );
2399 }
2400
2401 #[test]
2402 fn part_text_constructor_accepts_string_and_str() {
2403 let from_str = Part::text("hello");
2404 let from_string = Part::text(String::from("hello"));
2405 assert_eq!(
2407 serde_json::to_value(&from_str).unwrap(),
2408 serde_json::to_value(&from_string).unwrap(),
2409 );
2410 }
2411
2412 #[test]
2413 fn content_with_mixed_parts_serializes_correctly() {
2414 let content = Content {
2415 role: Some("user".to_string()),
2416 parts: vec![
2417 Part::text("Describe this image:"),
2418 Part::Inline {
2419 inline_data: InlineData {
2420 mime_type: "image/jpeg".to_string(),
2421 data: "/9j/4AAQ...".to_string(),
2422 },
2423 },
2424 ],
2425 };
2426 let json = serde_json::to_value(&content).unwrap();
2427 let parts = json["parts"].as_array().unwrap();
2428 assert_eq!(parts.len(), 2);
2429 assert!(parts[0].get("text").is_some());
2430 assert!(parts[1].get("inline_data").is_some());
2431 }
2432
2433 #[test]
2436 fn build_parts_plain_text_returns_single_text_part() {
2437 let parts = build_parts("Hello, world!");
2438 assert_eq!(parts.len(), 1);
2439 assert_eq!(
2440 serde_json::to_value(&parts[0]).unwrap(),
2441 serde_json::json!({"text": "Hello, world!"})
2442 );
2443 }
2444
2445 #[test]
2446 fn build_parts_empty_string_returns_single_text_part() {
2447 let parts = build_parts("");
2448 assert_eq!(parts.len(), 1);
2449 assert_eq!(
2451 serde_json::to_value(&parts[0]).unwrap(),
2452 serde_json::json!({"text": ""})
2453 );
2454 }
2455
2456 #[test]
2457 fn build_parts_extracts_data_uri_as_inline_part() {
2458 let content = "Check this [IMAGE:data:image/png;base64,iVBORw0KGgo=]";
2459 let parts = build_parts(content);
2460 assert_eq!(parts.len(), 2);
2461 assert_eq!(
2463 serde_json::to_value(&parts[0]).unwrap(),
2464 serde_json::json!({"text": "Check this"})
2465 );
2466 assert_eq!(
2468 serde_json::to_value(&parts[1]).unwrap(),
2469 serde_json::json!({"inline_data": {"mime_type": "image/png", "data": "iVBORw0KGgo="}})
2470 );
2471 }
2472
2473 #[test]
2474 fn build_parts_multiple_images() {
2475 let content = "Image A: [IMAGE:data:image/png;base64,AAAA] Image B: [IMAGE:data:image/jpeg;base64,BBBB]";
2476 let parts = build_parts(content);
2477 assert_eq!(parts.len(), 3); let inline_parts: Vec<_> = parts
2480 .iter()
2481 .filter(|p| matches!(p, Part::Inline { .. }))
2482 .collect();
2483 assert_eq!(inline_parts.len(), 2);
2484 }
2485
2486 #[test]
2487 fn build_parts_ignores_non_data_uri_markers() {
2488 let content = "Look [IMAGE:/tmp/photo.png]";
2491 let parts = build_parts(content);
2492 for part in &parts {
2495 assert!(matches!(part, Part::Text { .. }));
2496 }
2497 }
2498
2499 #[test]
2500 fn build_parts_image_only_still_produces_inline_part() {
2501 let content = "[IMAGE:data:image/gif;base64,R0lGODlh]";
2502 let parts = build_parts(content);
2503 assert_eq!(parts.len(), 1);
2505 assert!(matches!(&parts[0], Part::Inline { .. }));
2506 }
2507
2508 #[test]
2511 fn chat_with_history_maps_roles_correctly() {
2512 let messages = vec![
2513 ChatMessage::system("You are helpful"),
2514 ChatMessage::user("Hello [IMAGE:data:image/png;base64,AA==]"),
2515 ChatMessage::assistant("I see the image"),
2516 ];
2517
2518 let (contents, system_instruction) =
2519 GeminiModelProvider::build_chat_contents(&messages, None);
2520
2521 let system_instruction = system_instruction.expect("system prompt should be separated");
2522 assert_eq!(system_instruction.role, None);
2523 assert!(
2524 matches!(&system_instruction.parts[0], Part::Text { text } if text == "You are helpful")
2525 );
2526
2527 assert_eq!(contents.len(), 2);
2528 assert_eq!(contents[0].role.as_deref(), Some("user"));
2529 assert!(
2530 contents[0]
2531 .parts
2532 .iter()
2533 .any(|p| matches!(p, Part::Inline { .. }))
2534 );
2535 assert_eq!(contents[1].role.as_deref(), Some("model"));
2536 assert!(matches!(&contents[1].parts[0], Part::Text { text } if text == "I see the image"));
2537 }
2538
2539 #[test]
2540 fn chat_contents_append_tool_instructions_to_system_prompt() {
2541 let messages = vec![
2542 ChatMessage::system("You are helpful"),
2543 ChatMessage::user("Hello"),
2544 ];
2545
2546 let (_contents, system_instruction) =
2547 GeminiModelProvider::build_chat_contents(&messages, Some("Use tools carefully"));
2548
2549 let system_instruction = system_instruction.expect("system prompt should include tools");
2550 assert!(
2551 matches!(&system_instruction.parts[0], Part::Text { text } if text == "You are helpful\n\nUse tools carefully")
2552 );
2553 }
2554
2555 #[test]
2556 fn chat_contents_create_system_prompt_from_tool_instructions() {
2557 let messages = vec![ChatMessage::user("Hello")];
2558
2559 let (_contents, system_instruction) =
2560 GeminiModelProvider::build_chat_contents(&messages, Some("Use tools carefully"));
2561
2562 let system_instruction =
2563 system_instruction.expect("tool instructions should be system prompt");
2564 assert!(
2565 matches!(&system_instruction.parts[0], Part::Text { text } if text == "Use tools carefully")
2566 );
2567 }
2568}