1use crate::auth::AuthService;
8use crate::traits::{
9 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
10 ModelProvider, TokenUsage, ToolsPayload,
11};
12use async_trait::async_trait;
13use base64::Engine;
14use directories::UserDirs;
15use reqwest::Client;
16use serde::{Deserialize, Serialize};
17use std::path::PathBuf;
18use std::sync::Arc;
19
/// Gemini model provider.
///
/// Authenticates either with an explicit API key (public Generative Language API,
/// see `BASE_URL`) or with OAuth — Gemini CLI credential files or a managed auth
/// profile — which routes requests through the Cloud Code `v1internal` endpoint.
pub struct GeminiModelProvider {
    /// Alias string supplied by the caller at construction.
    alias: String,
    /// Resolved credential; `None` when no usable auth was found at construction.
    auth: Option<GeminiAuth>,
    /// Cloud Code project id cached from `loadCodeAssist`; cleared when the
    /// CLI credential is rotated.
    oauth_project: Arc<tokio::sync::Mutex<Option<String>>>,
    /// Optional project id sent to `loadCodeAssist` and used as its fallback.
    oauth_project_seed: Option<String>,
    /// Discovered Gemini CLI `oauth_creds.json` files, in rotation order.
    oauth_cred_paths: Vec<PathBuf>,
    /// Index into `oauth_cred_paths` of the credential currently in use.
    oauth_index: Arc<tokio::sync::Mutex<usize>>,
    /// Auth-profile service; only retained when managed OAuth was selected.
    auth_service: Option<AuthService>,
    /// Auth profile name passed to `auth_service` instead of the default, if any.
    auth_profile_override: Option<String>,
    /// OAuth client id passed to the auth service when obtaining managed tokens.
    oauth_client_id: Option<String>,
    /// OAuth client secret passed alongside `oauth_client_id`.
    oauth_client_secret: Option<String>,
}
46
/// Mutable state of one Gemini CLI OAuth credential; refreshed in place.
struct OAuthTokenState {
    /// Current bearer token.
    access_token: String,
    /// Refresh token from the credential file, if any.
    refresh_token: Option<String>,
    /// OAuth client id sent with refresh requests (from the file, or derived
    /// from the file's `id_token`).
    client_id: Option<String>,
    /// OAuth client secret sent with refresh requests, if present in the file.
    client_secret: Option<String>,
    /// Access-token expiry in Unix-epoch milliseconds; `None` when unknown.
    expiry_millis: Option<i64>,
}
56
/// How requests to Gemini are authenticated.
enum GeminiAuth {
    /// API key, appended as `?key=` to the public Generative Language API URL.
    ExplicitKey(String),
    /// Gemini CLI OAuth credential loaded from an `oauth_creds.json` file.
    OAuthToken(Arc<tokio::sync::Mutex<OAuthTokenState>>),
    /// OAuth token obtained on demand from the provider's `AuthService` profile.
    ManagedOAuth,
}
69
70impl GeminiAuth {
71 fn is_api_key(&self) -> bool {
73 matches!(self, GeminiAuth::ExplicitKey(_))
74 }
75
76 fn is_oauth(&self) -> bool {
78 matches!(self, GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth)
79 }
80
81 fn api_key_credential(&self) -> &str {
83 match self {
84 GeminiAuth::ExplicitKey(s) => s,
85 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => "",
86 }
87 }
88}
89
/// Request body for `generateContent`.
///
/// Sent as-is to the public API (API-key auth); for OAuth its fields are
/// copied into an `InternalGenerateContentEnvelope`.
#[derive(Debug, Serialize, Clone)]
struct GenerateContentRequest {
    /// Conversation turns.
    contents: Vec<Content>,
    /// Optional system prompt; omitted from the JSON when `None`.
    #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
    system_instruction: Option<Content>,
    /// Sampling settings.
    #[serde(rename = "generationConfig")]
    generation_config: GenerationConfig,
}
102
/// Outer body for the Cloud Code `v1internal:generateContent` endpoint (OAuth).
///
/// Unlike the public API, the model name travels in the body, not the URL.
#[derive(Debug, Serialize)]
struct InternalGenerateContentEnvelope {
    /// Model name without the `models/` prefix.
    model: String,
    /// Cloud Code project id, omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    project: Option<String>,
    /// Per-request id (callers in this file send a fresh UUID). Serialized in
    /// snake_case as written — NOTE(review): assumed to match the endpoint's schema.
    #[serde(skip_serializing_if = "Option::is_none")]
    user_prompt_id: Option<String>,
    /// The wrapped generate-content request.
    request: InternalGenerateContentRequest,
}
128
/// Inner request of `InternalGenerateContentEnvelope`.
///
/// `generation_config` is optional so the request can be retried without it
/// when the internal endpoint rejects the field.
#[derive(Debug, Serialize)]
struct InternalGenerateContentRequest {
    /// Conversation turns.
    contents: Vec<Content>,
    /// Optional system prompt; omitted when `None`.
    #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
    system_instruction: Option<Content>,
    /// Optional sampling settings; omitted when `None`.
    #[serde(rename = "generationConfig", skip_serializing_if = "Option::is_none")]
    generation_config: Option<GenerationConfig>,
}
138
/// One conversation turn, or the system instruction.
#[derive(Debug, Serialize, Clone)]
struct Content {
    /// `"user"` or `"model"` for turns; `None` (omitted) for the system instruction.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
    /// Text and/or inline-data parts of the turn.
    parts: Vec<Part>,
}
145
/// A request part. `untagged`, so each variant serializes as just its fields:
/// `{"text": …}` or `{"inline_data": {…}}`.
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
enum Part {
    /// Plain text.
    Text { text: String },
    /// Inline binary data (images decoded from `data:` URIs).
    Inline { inline_data: InlineData },
}
152
153impl Part {
154 fn text(s: impl Into<String>) -> Self {
155 Part::Text { text: s.into() }
156 }
157}
158
/// Inline payload of a `Part::Inline`.
#[derive(Debug, Serialize, Clone)]
struct InlineData {
    /// Media type taken from the `data:` URI header (e.g. `image/png`).
    mime_type: String,
    /// Base64 payload copied verbatim from the `data:` URI.
    data: String,
}
164
165fn build_parts(content: &str) -> Vec<Part> {
170 let (text, image_refs) = crate::multimodal::parse_image_markers(content);
171 let mut parts = Vec::new();
172 let trimmed = text.trim();
173 if !trimmed.is_empty() {
174 parts.push(Part::text(trimmed));
175 }
176 for uri in &image_refs {
177 if let Some(rest) = uri.strip_prefix("data:")
178 && let Some(semi_pos) = rest.find(';')
179 {
180 let mime = &rest[..semi_pos];
181 if let Some(b64) = rest[semi_pos + 1..].strip_prefix("base64,") {
182 parts.push(Part::Inline {
183 inline_data: InlineData {
184 mime_type: mime.to_string(),
185 data: b64.to_string(),
186 },
187 });
188 }
189 }
190 }
191 if parts.is_empty() {
192 parts.push(Part::text(content));
193 }
194 parts
195}
196
/// Sampling settings sent as `generationConfig`.
#[derive(Debug, Serialize, Clone)]
struct GenerationConfig {
    /// Sampling temperature.
    temperature: f64,
    /// Upper bound on generated tokens.
    #[serde(rename = "maxOutputTokens")]
    max_output_tokens: u32,
}
203
/// Response body of `generateContent` (public API or Cloud Code internal endpoint).
#[derive(Debug, Deserialize)]
struct GenerateContentResponse {
    /// Generated candidates.
    candidates: Option<Vec<Candidate>>,
    /// Error object, when the body reports one.
    error: Option<ApiError>,
    /// Nested payload; when present, `into_effective_response` treats it as the
    /// real response.
    #[serde(default)]
    response: Option<Box<GenerateContentResponse>>,
    /// Token accounting.
    #[serde(default, rename = "usageMetadata")]
    usage_metadata: Option<GeminiUsageMetadata>,
}
213
/// Token counts reported in `usageMetadata`.
#[derive(Debug, Deserialize)]
struct GeminiUsageMetadata {
    /// Tokens in the prompt.
    #[serde(default, rename = "promptTokenCount")]
    prompt_token_count: Option<u64>,
    /// Tokens in the generated candidates.
    #[serde(default, rename = "candidatesTokenCount")]
    candidates_token_count: Option<u64>,
}
221
/// One generated candidate.
#[derive(Debug, Deserialize)]
struct Candidate {
    /// Candidate content; may be absent, in which case it is `None`.
    #[serde(default)]
    content: Option<CandidateContent>,
}
227
228#[derive(Debug, Deserialize)]
229struct CandidateContent {
230 parts: Vec<ResponsePart>,
231}
232
/// One part of a candidate's content.
#[derive(Debug, Deserialize)]
struct ResponsePart {
    /// Part text; `None` when the part has no `text` field.
    #[serde(default)]
    text: Option<String>,
    /// `true` when the part is marked as thinking output; defaults to `false`.
    #[serde(default)]
    thought: bool,
}
241
242impl CandidateContent {
243 fn effective_text(self) -> Option<String> {
253 let mut answer_parts: Vec<String> = Vec::new();
254 let mut first_thinking: Option<String> = None;
255
256 for part in self.parts {
257 if let Some(text) = part.text {
258 if text.is_empty() {
259 continue;
260 }
261 if !part.thought {
262 answer_parts.push(text);
263 } else if first_thinking.is_none() {
264 first_thinking = Some(text);
265 }
266 }
267 }
268
269 if answer_parts.is_empty() {
270 first_thinking
271 } else {
272 Some(answer_parts.join(""))
273 }
274 }
275}
276
/// Error object embedded in a response body.
#[derive(Debug, Deserialize)]
struct ApiError {
    /// Human-readable error message.
    message: String,
}
281
282impl GenerateContentResponse {
283 fn into_effective_response(self) -> Self {
285 match self {
286 Self {
287 response: Some(mut inner),
288 usage_metadata,
289 ..
290 } => {
291 if inner.usage_metadata.is_none() {
292 inner.usage_metadata = usage_metadata;
293 }
294 *inner
295 }
296 other => other,
297 }
298 }
299}
300
301#[derive(Debug, Deserialize)]
307struct GeminiCliOAuthCreds {
308 access_token: Option<String>,
309 #[serde(alias = "idToken")]
310 id_token: Option<String>,
311 refresh_token: Option<String>,
312 #[serde(alias = "clientId")]
313 client_id: Option<String>,
314 #[serde(alias = "clientSecret")]
315 client_secret: Option<String>,
316 #[serde(alias = "expiryDate")]
318 expiry_date: Option<i64>,
319 expiry: Option<String>,
321}
322
/// Google OAuth 2.0 token endpoint used to refresh Gemini CLI access tokens.
const GOOGLE_TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";

/// Cloud Code internal API base; OAuth requests call `{CLOUDCODE_PA_ENDPOINT}:generateContent`.
const CLOUDCODE_PA_ENDPOINT: &str = "https://cloudcode-pa.googleapis.com/v1internal";

/// Cloud Code `loadCodeAssist` endpoint, used to resolve the OAuth project id.
const LOAD_CODE_ASSIST_ENDPOINT: &str =
    "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist";

/// Public Generative Language API base, used with API-key auth.
pub(crate) const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
340
/// Result of a successful OAuth refresh.
struct RefreshedToken {
    /// New bearer token.
    access_token: String,
    /// Expiry in Unix-epoch milliseconds, derived from `expires_in`; `None` if not reported.
    expiry_millis: Option<i64>,
}
351
352fn refresh_gemini_cli_token(
358 refresh_token: &str,
359 client_id: Option<&str>,
360 client_secret: Option<&str>,
361) -> anyhow::Result<RefreshedToken> {
362 let client = reqwest::blocking::Client::builder()
363 .timeout(std::time::Duration::from_secs(15))
364 .connect_timeout(std::time::Duration::from_secs(5))
365 .build()
366 .unwrap_or_else(|_| reqwest::blocking::Client::new());
367
368 let form = build_oauth_refresh_form(refresh_token, client_id, client_secret);
369
370 let response = client
371 .post(GOOGLE_TOKEN_ENDPOINT)
372 .header("Content-Type", "application/x-www-form-urlencoded")
373 .header("Accept", "application/json")
374 .form(&form)
375 .send()
376 .map_err(|error| {
377 ::zeroclaw_log::record!(
378 ERROR,
379 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
380 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
381 .with_attrs(::serde_json::json!({
382 "oauth_provider": "gemini_cli",
383 "phase": "refresh_request",
384 "error": format!("{}", error),
385 })),
386 "gemini: CLI OAuth refresh request failed"
387 );
388 anyhow::Error::msg(format!("Gemini CLI OAuth refresh request failed: {error}"))
389 })?;
390
391 let status = response.status();
392 let body = response
393 .text()
394 .unwrap_or_else(|_| "<failed to read response body>".to_string());
395
396 if !status.is_success() {
397 anyhow::bail!("Gemini CLI OAuth refresh failed (HTTP {status}): {body}");
398 }
399
400 #[derive(Deserialize)]
401 struct TokenResponse {
402 access_token: Option<String>,
403 expires_in: Option<i64>,
404 }
405
406 let parsed: TokenResponse = serde_json::from_str(&body).map_err(|_| {
407 ::zeroclaw_log::record!(
408 ERROR,
409 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
410 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
411 .with_attrs(::serde_json::json!({"oauth_provider": "gemini_cli"})),
412 "gemini: CLI OAuth refresh response is not valid JSON"
413 );
414 anyhow::Error::msg("Gemini CLI OAuth refresh response is not valid JSON")
415 })?;
416
417 let access_token = parsed
418 .access_token
419 .filter(|t| !t.trim().is_empty())
420 .ok_or_else(|| {
421 ::zeroclaw_log::record!(
422 ERROR,
423 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
424 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
425 .with_attrs(::serde_json::json!({
426 "oauth_provider": "gemini_cli",
427 "missing": "access_token",
428 })),
429 "gemini: CLI OAuth refresh missing access_token"
430 );
431 anyhow::Error::msg("Gemini CLI OAuth refresh response missing access_token")
432 })?;
433
434 let expiry_millis = parsed.expires_in.and_then(|secs| {
435 let now_millis = std::time::SystemTime::now()
436 .duration_since(std::time::UNIX_EPOCH)
437 .ok()
438 .and_then(|d| i64::try_from(d.as_millis()).ok())?;
439 now_millis.checked_add(secs.checked_mul(1000)?)
440 });
441
442 Ok(RefreshedToken {
443 access_token,
444 expiry_millis,
445 })
446}
447
448fn build_oauth_refresh_form(
449 refresh_token: &str,
450 client_id: Option<&str>,
451 client_secret: Option<&str>,
452) -> Vec<(&'static str, String)> {
453 let mut form = vec![
454 ("grant_type", "refresh_token".to_string()),
455 ("refresh_token", refresh_token.to_string()),
456 ];
457 if let Some(id) = client_id.and_then(GeminiModelProvider::normalize_non_empty) {
458 form.push(("client_id", id));
459 }
460 if let Some(secret) = client_secret.and_then(GeminiModelProvider::normalize_non_empty) {
461 form.push(("client_secret", secret));
462 }
463 form
464}
465
466fn extract_client_id_from_id_token(id_token: &str) -> Option<String> {
467 let payload = id_token.split('.').nth(1)?;
468 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
469 .decode(payload)
470 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(payload))
471 .ok()?;
472
473 #[derive(Deserialize)]
474 struct IdTokenClaims {
475 aud: Option<String>,
476 azp: Option<String>,
477 }
478
479 let claims: IdTokenClaims = serde_json::from_slice(&decoded).ok()?;
480 claims
481 .aud
482 .as_deref()
483 .and_then(GeminiModelProvider::normalize_non_empty)
484 .or_else(|| {
485 claims
486 .azp
487 .as_deref()
488 .and_then(GeminiModelProvider::normalize_non_empty)
489 })
490}
491
492async fn refresh_gemini_cli_token_async(
494 refresh_token: &str,
495 client_id: Option<&str>,
496 client_secret: Option<&str>,
497) -> anyhow::Result<RefreshedToken> {
498 let refresh_token = refresh_token.to_string();
499 let client_id = client_id.map(str::to_string);
500 let client_secret = client_secret.map(str::to_string);
501 tokio::task::spawn_blocking(move || {
502 refresh_gemini_cli_token(
503 &refresh_token,
504 client_id.as_deref(),
505 client_secret.as_deref(),
506 )
507 })
508 .await
509 .map_err(|e| {
510 ::zeroclaw_log::record!(
511 ERROR,
512 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
513 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
514 .with_attrs(::serde_json::json!({
515 "oauth_provider": "gemini_cli",
516 "phase": "task_join",
517 "error": format!("{}", e),
518 })),
519 "gemini: token refresh task panicked"
520 );
521 anyhow::Error::msg(format!("Token refresh task panicked: {e}"))
522 })?
523}
524
525impl GeminiModelProvider {
526 pub fn new(alias: &str, api_key: Option<&str>) -> Self {
533 let oauth_cred_paths = Self::discover_oauth_cred_paths();
534 let resolved_auth = api_key
535 .and_then(Self::normalize_non_empty)
536 .map(GeminiAuth::ExplicitKey)
537 .or_else(|| {
538 Self::try_load_gemini_cli_token(oauth_cred_paths.first())
539 .map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))))
540 });
541
542 Self {
543 alias: alias.to_string(),
544 auth: resolved_auth,
545 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
546 oauth_project_seed: None,
547 oauth_cred_paths,
548 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
549 auth_service: None,
550 auth_profile_override: None,
551 oauth_client_id: None,
552 oauth_client_secret: None,
553 }
554 }
555 pub fn new_with_auth(
562 alias: &str,
563 api_key: Option<&str>,
564 auth_service: AuthService,
565 profile_override: Option<String>,
566 oauth_project_seed: Option<String>,
567 oauth_client_id: Option<String>,
568 oauth_client_secret: Option<String>,
569 ) -> Self {
570 let oauth_cred_paths = Self::discover_oauth_cred_paths();
571
572 let resolved_auth = api_key
574 .and_then(Self::normalize_non_empty)
575 .map(GeminiAuth::ExplicitKey);
576
577 let (auth, use_managed) = if resolved_auth.is_some() {
580 (resolved_auth, false)
581 } else {
582 let has_managed = std::thread::scope(|s| {
585 s.spawn(|| {
586 let rt = tokio::runtime::Builder::new_current_thread()
587 .enable_all()
588 .build()
589 .ok()?;
590 rt.block_on(async {
591 auth_service
592 .get_gemini_profile(profile_override.as_deref())
593 .await
594 .ok()
595 .flatten()
596 })
597 })
598 .join()
599 .ok()
600 .flatten()
601 .is_some()
602 });
603
604 if has_managed {
605 (Some(GeminiAuth::ManagedOAuth), true)
606 } else {
607 let cli_auth = Self::try_load_gemini_cli_token(oauth_cred_paths.first())
609 .map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))));
610 (cli_auth, false)
611 }
612 };
613
614 Self {
615 alias: alias.to_string(),
616 auth,
617 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
618 oauth_project_seed,
619 oauth_cred_paths,
620 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
621 auth_service: if use_managed {
622 Some(auth_service)
623 } else {
624 None
625 },
626 auth_profile_override: profile_override,
627 oauth_client_id,
628 oauth_client_secret,
629 }
630 }
631
632 fn normalize_non_empty(value: &str) -> Option<String> {
633 let trimmed = value.trim();
634 if trimmed.is_empty() {
635 None
636 } else {
637 Some(trimmed.to_string())
638 }
639 }
640
641 fn load_gemini_cli_creds(creds_path: &PathBuf) -> Option<GeminiCliOAuthCreds> {
642 if !creds_path.exists() {
643 return None;
644 }
645 let content = std::fs::read_to_string(creds_path).ok()?;
646 serde_json::from_str(&content).ok()
647 }
648
649 fn discover_oauth_cred_paths() -> Vec<PathBuf> {
654 let home = match UserDirs::new() {
655 Some(u) => u.home_dir().to_path_buf(),
656 None => return Vec::new(),
657 };
658
659 let mut paths = Vec::new();
660
661 let primary = home.join(".gemini").join("oauth_creds.json");
662 if primary.exists() {
663 paths.push(primary);
664 }
665
666 if let Ok(entries) = std::fs::read_dir(&home) {
667 let mut extras: Vec<PathBuf> = entries
668 .filter_map(|e| e.ok())
669 .filter_map(|e| {
670 let name = e.file_name().to_string_lossy().to_string();
671 if name.starts_with(".gemini-") && name.ends_with("-home") {
672 let path = e.path().join(".gemini").join("oauth_creds.json");
673 if path.exists() {
674 return Some(path);
675 }
676 }
677 None
678 })
679 .collect();
680 extras.sort();
681 paths.extend(extras);
682 }
683
684 paths
685 }
686
687 fn try_load_gemini_cli_token(path: Option<&PathBuf>) -> Option<OAuthTokenState> {
692 let creds = Self::load_gemini_cli_creds(path?)?;
693
694 let expiry_millis = creds.expiry_date.or_else(|| {
696 creds.expiry.as_deref().and_then(|expiry| {
697 chrono::DateTime::parse_from_rfc3339(expiry)
698 .ok()
699 .map(|dt| dt.timestamp_millis())
700 })
701 });
702
703 let access_token = creds
704 .access_token
705 .and_then(|token| Self::normalize_non_empty(&token))?;
706
707 let id_token_client_id = creds
708 .id_token
709 .as_deref()
710 .and_then(extract_client_id_from_id_token);
711
712 let client_id = creds
713 .client_id
714 .as_deref()
715 .and_then(Self::normalize_non_empty)
716 .or(id_token_client_id);
717 let client_secret = creds
718 .client_secret
719 .as_deref()
720 .and_then(Self::normalize_non_empty);
721
722 Some(OAuthTokenState {
723 access_token,
724 refresh_token: creds.refresh_token,
725 client_id,
726 client_secret,
727 expiry_millis,
728 })
729 }
730
731 pub fn has_cli_credentials() -> bool {
733 Self::discover_oauth_cred_paths().iter().any(|path| {
734 Self::load_gemini_cli_creds(path)
735 .and_then(|creds| {
736 creds
737 .access_token
738 .as_deref()
739 .and_then(Self::normalize_non_empty)
740 })
741 .is_some()
742 })
743 }
744
745 pub fn has_any_auth() -> bool {
750 Self::has_cli_credentials()
751 }
752
753 pub fn auth_source(&self) -> &'static str {
756 match self.auth.as_ref() {
757 Some(GeminiAuth::ExplicitKey(_)) => "config",
758 Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth",
759 Some(GeminiAuth::ManagedOAuth) => "auth-profiles",
760 None => "none",
761 }
762 }
763
764 async fn get_valid_oauth_token(
767 state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
768 ) -> anyhow::Result<String> {
769 let mut guard = state.lock().await;
770
771 let now_millis = std::time::SystemTime::now()
772 .duration_since(std::time::UNIX_EPOCH)
773 .ok()
774 .and_then(|d| i64::try_from(d.as_millis()).ok())
775 .unwrap_or(i64::MAX);
776
777 let needs_refresh = guard
779 .expiry_millis
780 .is_none_or(|exp| exp <= now_millis.saturating_add(60_000));
781
782 if needs_refresh {
783 if let Some(ref refresh_token) = guard.refresh_token {
784 let refreshed = refresh_gemini_cli_token_async(
785 refresh_token,
786 guard.client_id.as_deref(),
787 guard.client_secret.as_deref(),
788 )
789 .await?;
790 ::zeroclaw_log::record!(
791 INFO,
792 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
793 "Gemini CLI OAuth token refreshed successfully (runtime)"
794 );
795 guard.access_token = refreshed.access_token;
796 guard.expiry_millis = refreshed.expiry_millis;
797 } else {
798 anyhow::bail!(
799 "Gemini CLI OAuth token expired and no refresh_token available — re-run `gemini` to authenticate"
800 );
801 }
802 }
803
804 Ok(guard.access_token.clone())
805 }
806
807 async fn rotate_oauth_credential(
810 &self,
811 state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
812 ) -> bool {
813 if self.oauth_cred_paths.len() <= 1 {
814 return false;
815 }
816
817 let mut idx = self.oauth_index.lock().await;
818 let start = *idx;
819
820 loop {
821 let next = (*idx + 1) % self.oauth_cred_paths.len();
822 *idx = next;
823
824 if next == start {
825 return false;
826 }
827
828 if let Some(next_state) =
829 Self::try_load_gemini_cli_token(self.oauth_cred_paths.get(next))
830 {
831 {
832 let mut guard = state.lock().await;
833 *guard = next_state;
834 }
835 {
836 let mut cached_project = self.oauth_project.lock().await;
837 *cached_project = None;
838 }
839 ::zeroclaw_log::record!(
840 WARN,
841 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
842 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
843 &format!(
844 "Gemini OAuth: rotated credential to {}",
845 self.oauth_cred_paths[next].display().to_string()
846 )
847 );
848 return true;
849 }
850 }
851 }
852
853 fn format_model_name(model: &str) -> String {
854 if model.starts_with("models/") {
855 model.to_string()
856 } else {
857 format!("models/{model}")
858 }
859 }
860
861 fn format_internal_model_name(model: &str) -> String {
862 model.strip_prefix("models/").unwrap_or(model).to_string()
863 }
864
865 fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String {
875 match auth {
876 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => {
877 format!("{CLOUDCODE_PA_ENDPOINT}:generateContent")
880 }
881 _ => {
882 let model_name = Self::format_model_name(model);
883 let base_url = format!("{BASE_URL}/{model_name}:generateContent");
884
885 if auth.is_api_key() {
886 format!("{base_url}?key={}", auth.api_key_credential())
887 } else {
888 base_url
889 }
890 }
891 }
892 }
893
894 fn http_client(&self) -> Client {
895 zeroclaw_config::schema::build_runtime_proxy_client_with_timeouts(
896 "model_provider.gemini",
897 120,
898 10,
899 )
900 }
901
902 async fn resolve_oauth_project(&self, token: &str) -> anyhow::Result<String> {
905 let project_seed = self.oauth_project_seed.clone();
906 let project_seed_for_request = project_seed.clone();
907 let duet_project_for_request = project_seed.clone();
908
909 {
911 let cached = self.oauth_project.lock().await;
912 if let Some(ref project) = *cached {
913 return Ok(project.clone());
914 }
915 }
916
917 let client = self.http_client();
919 let response = client
920 .post(LOAD_CODE_ASSIST_ENDPOINT)
921 .bearer_auth(token)
922 .json(&serde_json::json!({
923 "cloudaicompanionProject": project_seed_for_request,
924 "metadata": {
925 "ideType": "GEMINI_CLI",
926 "platform": "PLATFORM_UNSPECIFIED",
927 "pluginType": "GEMINI",
928 "duetProject": duet_project_for_request,
929 }
930 }))
931 .send()
932 .await?;
933
934 if !response.status().is_success() {
935 let status = response.status();
936 let body = response.text().await.unwrap_or_default();
937 if let Some(seed) = project_seed {
938 ::zeroclaw_log::record!(
939 WARN,
940 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
941 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
942 .with_attrs(::serde_json::json!({"status": status.to_string()})),
943 "loadCodeAssist failed (HTTP ); using oauth_project seed fallback"
944 );
945 return Ok(seed);
946 }
947 anyhow::bail!("loadCodeAssist failed (HTTP {status}): {body}");
948 }
949
950 #[derive(Deserialize)]
951 struct LoadCodeAssistResponse {
952 #[serde(rename = "cloudaicompanionProject")]
953 cloudaicompanion_project: Option<String>,
954 }
955
956 let result: LoadCodeAssistResponse = response.json().await?;
957 let project = result
958 .cloudaicompanion_project
959 .filter(|p| !p.trim().is_empty())
960 .or(project_seed)
961 .ok_or_else(|| {
962 ::zeroclaw_log::record!(
963 ERROR,
964 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
965 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
966 .with_attrs(::serde_json::json!({
967 "missing": "cloudaicompanionProject",
968 })),
969 "gemini: loadCodeAssist missing project context"
970 );
971 anyhow::Error::msg("loadCodeAssist response missing project context")
972 })?;
973
974 {
976 let mut cached = self.oauth_project.lock().await;
977 *cached = Some(project.clone());
978 }
979
980 Ok(project)
981 }
982
983 fn build_generate_content_request(
988 &self,
989 auth: &GeminiAuth,
990 url: &str,
991 request: &GenerateContentRequest,
992 model: &str,
993 include_generation_config: bool,
994 project: Option<&str>,
995 oauth_token: Option<&str>,
996 ) -> reqwest::RequestBuilder {
997 let req = self.http_client().post(url).json(request);
998 match auth {
999 GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => {
1000 let token = oauth_token.unwrap_or_default();
1001 let internal_request = InternalGenerateContentEnvelope {
1004 model: Self::format_internal_model_name(model),
1005 project: project.map(|value| value.to_string()),
1006 user_prompt_id: Some(uuid::Uuid::new_v4().to_string()),
1007 request: InternalGenerateContentRequest {
1008 contents: request.contents.clone(),
1009 system_instruction: request.system_instruction.clone(),
1010 generation_config: if include_generation_config {
1011 Some(request.generation_config.clone())
1012 } else {
1013 None
1014 },
1015 },
1016 };
1017 self.http_client()
1018 .post(url)
1019 .json(&internal_request)
1020 .bearer_auth(token)
1021 }
1022 _ => req,
1023 }
1024 }
1025
1026 fn should_retry_oauth_without_generation_config(
1027 status: reqwest::StatusCode,
1028 error_text: &str,
1029 ) -> bool {
1030 if status != reqwest::StatusCode::BAD_REQUEST {
1031 return false;
1032 }
1033
1034 error_text.contains("Unknown name \"generationConfig\"")
1035 || error_text.contains("Unknown name 'generationConfig'")
1036 || error_text.contains(r#"Unknown name \"generationConfig\""#)
1037 }
1038
1039 fn should_rotate_oauth_on_error(status: reqwest::StatusCode, error_text: &str) -> bool {
1040 status == reqwest::StatusCode::TOO_MANY_REQUESTS
1041 || status == reqwest::StatusCode::SERVICE_UNAVAILABLE
1042 || status.is_server_error()
1043 || error_text.contains("RESOURCE_EXHAUSTED")
1044 }
1045}
1046
1047impl GeminiModelProvider {
1048 fn build_chat_contents(
1049 messages: &[ChatMessage],
1050 tool_instructions: Option<&str>,
1051 ) -> (Vec<Content>, Option<Content>) {
1052 let mut system_parts: Vec<&str> = Vec::new();
1053 let mut contents: Vec<Content> = Vec::new();
1054 for msg in messages {
1055 match msg.role.as_str() {
1056 "system" => system_parts.push(&msg.content),
1057 "user" => contents.push(Content {
1058 role: Some("user".to_string()),
1059 parts: build_parts(&msg.content),
1060 }),
1061 "assistant" => contents.push(Content {
1062 role: Some("model".to_string()),
1063 parts: vec![Part::text(&msg.content)],
1064 }),
1065 _ => {}
1066 }
1067 }
1068 if let Some(instructions) = tool_instructions {
1069 system_parts.push(instructions);
1070 }
1071 let system_instruction = if system_parts.is_empty() {
1072 None
1073 } else {
1074 Some(Content {
1075 role: None,
1076 parts: vec![Part::text(system_parts.join("\n\n"))],
1077 })
1078 };
1079 (contents, system_instruction)
1080 }
1081
1082 async fn chat_with_history_full(
1083 &self,
1084 messages: &[ChatMessage],
1085 model: &str,
1086 temperature: Option<f64>,
1087 ) -> anyhow::Result<(String, Option<TokenUsage>)> {
1088 let temperature = temperature.unwrap_or(self.default_temperature());
1089 let (contents, system_instruction) = Self::build_chat_contents(messages, None);
1090 self.send_generate_content(contents, system_instruction, model, temperature)
1091 .await
1092 }
1093
1094 fn token_usage_from_metadata(usage: GeminiUsageMetadata) -> Option<TokenUsage> {
1095 if usage.prompt_token_count.is_none() && usage.candidates_token_count.is_none() {
1096 return None;
1097 }
1098 Some(TokenUsage {
1099 input_tokens: usage.prompt_token_count,
1100 output_tokens: usage.candidates_token_count,
1101 cached_input_tokens: None,
1102 })
1103 }
1104
1105 async fn send_generate_content(
1106 &self,
1107 contents: Vec<Content>,
1108 system_instruction: Option<Content>,
1109 model: &str,
1110 temperature: f64,
1111 ) -> anyhow::Result<(String, Option<TokenUsage>)> {
1112 let auth = self.auth.as_ref().ok_or_else(|| {
1113 ::zeroclaw_log::record!(
1114 ERROR,
1115 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
1116 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1117 .with_attrs(::serde_json::json!({"missing": "auth"})),
1118 "gemini: no auth configured"
1119 );
1120 anyhow::Error::msg(
1121 "Gemini API key not found. Options:\n\
1122 1. Set GEMINI_API_KEY env var\n\
1123 2. Run `gemini` CLI to authenticate (tokens will be reused)\n\
1124 3. Run `zeroclaw auth login --model-provider gemini`\n\
1125 4. Get an API key from https://aistudio.google.com/app/apikey\n\
1126 5. Run `zeroclaw onboard` to configure",
1127 )
1128 })?;
1129
1130 let oauth_state = match auth {
1131 GeminiAuth::OAuthToken(state) => Some(state.clone()),
1132 _ => None,
1133 };
1134
1135 let (mut oauth_token, mut project) = match auth {
1137 GeminiAuth::OAuthToken(state) => {
1138 let token = Self::get_valid_oauth_token(state).await?;
1139 let proj = self.resolve_oauth_project(&token).await?;
1140 (Some(token), Some(proj))
1141 }
1142 GeminiAuth::ManagedOAuth => {
1143 let auth_service = self.auth_service.as_ref().ok_or_else(|| {
1144 ::zeroclaw_log::record!(
1145 ERROR,
1146 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
1147 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1148 .with_attrs(::serde_json::json!({"missing": "auth_service"})),
1149 "gemini: ManagedOAuth requires auth_service"
1150 );
1151 anyhow::Error::msg("ManagedOAuth requires auth_service")
1152 })?;
1153 let token = auth_service
1154 .get_valid_gemini_access_token(
1155 self.auth_profile_override.as_deref(),
1156 self.oauth_client_id.as_deref().unwrap_or(""),
1157 self.oauth_client_secret.as_deref().unwrap_or(""),
1158 )
1159 .await?
1160 .ok_or_else(|| {
1161 ::zeroclaw_log::record!(
1162 ERROR,
1163 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
1164 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1165 .with_attrs(::serde_json::json!({"oauth_provider": "gemini"})),
1166 "gemini: auth profile not found"
1167 );
1168 anyhow::Error::msg(
1169 "Gemini auth profile not found. Run `zeroclaw auth login --model-provider gemini`.",
1170 )
1171 })?;
1172 let proj = self.resolve_oauth_project(&token).await?;
1173 (Some(token), Some(proj))
1174 }
1175 _ => (None, None),
1176 };
1177
1178 let request = GenerateContentRequest {
1179 contents,
1180 system_instruction,
1181 generation_config: GenerationConfig {
1182 temperature,
1183 max_output_tokens: 8192,
1184 },
1185 };
1186
1187 let url = Self::build_generate_content_url(model, auth);
1188
1189 let mut response = self
1190 .build_generate_content_request(
1191 auth,
1192 &url,
1193 &request,
1194 model,
1195 true,
1196 project.as_deref(),
1197 oauth_token.as_deref(),
1198 )
1199 .send()
1200 .await?;
1201
1202 if !response.status().is_success() {
1203 let status = response.status();
1204 let error_text = response.text().await.unwrap_or_default();
1205
1206 if auth.is_oauth() && Self::should_rotate_oauth_on_error(status, &error_text) {
1207 let can_retry = match auth {
1210 GeminiAuth::OAuthToken(_) => {
1211 if let Some(state) = oauth_state.as_ref() {
1212 self.rotate_oauth_credential(state).await
1213 } else {
1214 false
1215 }
1216 }
1217 GeminiAuth::ManagedOAuth => true, _ => false,
1219 };
1220
1221 if can_retry {
1222 let (new_token, new_project) = match auth {
1224 GeminiAuth::OAuthToken(state) => {
1225 let token = Self::get_valid_oauth_token(state).await?;
1226 let proj = self.resolve_oauth_project(&token).await?;
1227 (token, proj)
1228 }
1229 GeminiAuth::ManagedOAuth => {
1230 let auth_service = self.auth_service.as_ref().unwrap();
1231 let token = auth_service
1232 .get_valid_gemini_access_token(
1233 self.auth_profile_override.as_deref(),
1234 self.oauth_client_id.as_deref().unwrap_or(""),
1235 self.oauth_client_secret.as_deref().unwrap_or(""),
1236 )
1237 .await?
1238 .ok_or_else(|| {
1239 ::zeroclaw_log::record!(
1240 ERROR,
1241 ::zeroclaw_log::Event::new(
1242 module_path!(),
1243 ::zeroclaw_log::Action::Reject
1244 )
1245 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1246 .with_attrs(
1247 ::serde_json::json!({"oauth_provider": "gemini"})
1248 ),
1249 "gemini: auth profile not found"
1250 );
1251 anyhow::Error::msg("Gemini auth profile not found")
1252 })?;
1253 let proj = self.resolve_oauth_project(&token).await?;
1254 (token, proj)
1255 }
1256 _ => unreachable!(),
1257 };
1258 oauth_token = Some(new_token);
1259 project = Some(new_project);
1260 response = self
1261 .build_generate_content_request(
1262 auth,
1263 &url,
1264 &request,
1265 model,
1266 true,
1267 project.as_deref(),
1268 oauth_token.as_deref(),
1269 )
1270 .send()
1271 .await?;
1272 } else {
1273 anyhow::bail!("Gemini API error ({status}): {error_text}");
1274 }
1275 } else if auth.is_oauth()
1276 && Self::should_retry_oauth_without_generation_config(status, &error_text)
1277 {
1278 ::zeroclaw_log::record!(
1279 WARN,
1280 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1281 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
1282 "Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
1283 );
1284 response = self
1285 .build_generate_content_request(
1286 auth,
1287 &url,
1288 &request,
1289 model,
1290 false,
1291 project.as_deref(),
1292 oauth_token.as_deref(),
1293 )
1294 .send()
1295 .await?;
1296 } else {
1297 anyhow::bail!("Gemini API error ({status}): {error_text}");
1298 }
1299 }
1300
1301 if !response.status().is_success() {
1302 let status = response.status();
1303 let error_text = response.text().await.unwrap_or_default();
1304 if auth.is_oauth()
1305 && Self::should_retry_oauth_without_generation_config(status, &error_text)
1306 {
1307 ::zeroclaw_log::record!(
1308 WARN,
1309 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
1310 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
1311 "Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
1312 );
1313 response = self
1314 .build_generate_content_request(
1315 auth,
1316 &url,
1317 &request,
1318 model,
1319 false,
1320 project.as_deref(),
1321 oauth_token.as_deref(),
1322 )
1323 .send()
1324 .await?;
1325 } else {
1326 anyhow::bail!("Gemini API error ({status}): {error_text}");
1327 }
1328 }
1329
1330 if !response.status().is_success() {
1331 let status = response.status();
1332 let error_text = response.text().await.unwrap_or_default();
1333 anyhow::bail!("Gemini API error ({status}): {error_text}");
1334 }
1335
1336 let result: GenerateContentResponse = response.json().await?;
1337 if let Some(err) = &result.error {
1338 anyhow::bail!("Gemini API error: {}", err.message);
1339 }
1340 let result = result.into_effective_response();
1341 if let Some(err) = result.error {
1342 anyhow::bail!("Gemini API error: {}", err.message);
1343 }
1344
1345 let usage = result
1346 .usage_metadata
1347 .and_then(Self::token_usage_from_metadata);
1348
1349 let text = result
1350 .candidates
1351 .and_then(|c| c.into_iter().next())
1352 .and_then(|c| c.content)
1353 .and_then(|c| c.effective_text())
1354 .ok_or_else(|| {
1355 ::zeroclaw_log::record!(
1356 ERROR,
1357 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1358 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
1359 "gemini: empty response text"
1360 );
1361 anyhow::Error::msg("No response from Gemini")
1362 })?;
1363
1364 Ok((text, usage))
1365 }
1366}
1367
1368#[async_trait]
1369impl ModelProvider for GeminiModelProvider {
1370 fn default_base_url(&self) -> Option<&str> {
1372 Some(BASE_URL)
1373 }
1374
1375 fn capabilities(&self) -> zeroclaw_api::model_provider::ProviderCapabilities {
1376 zeroclaw_api::model_provider::ProviderCapabilities {
1377 vision: true,
1378 native_tool_calling: false,
1379 prompt_caching: false,
1380 extended_thinking: false,
1381 }
1382 }
1383
1384 async fn chat_with_system(
1385 &self,
1386 system_prompt: Option<&str>,
1387 message: &str,
1388 model: &str,
1389 temperature: Option<f64>,
1390 ) -> anyhow::Result<String> {
1391 let temperature = temperature.unwrap_or(self.default_temperature());
1392 let system_instruction = system_prompt.map(|sys| Content {
1393 role: None,
1394 parts: vec![Part::text(sys)],
1395 });
1396
1397 let contents = vec![Content {
1398 role: Some("user".to_string()),
1399 parts: build_parts(message),
1400 }];
1401
1402 let (text, _usage) = self
1403 .send_generate_content(contents, system_instruction, model, temperature)
1404 .await?;
1405 Ok(text)
1406 }
1407
1408 async fn chat_with_history(
1409 &self,
1410 messages: &[ChatMessage],
1411 model: &str,
1412 temperature: Option<f64>,
1413 ) -> anyhow::Result<String> {
1414 let (text, _usage) = self
1415 .chat_with_history_full(messages, model, temperature)
1416 .await?;
1417 Ok(text)
1418 }
1419
1420 async fn chat(
1421 &self,
1422 request: ProviderChatRequest<'_>,
1423 model: &str,
1424 temperature: Option<f64>,
1425 ) -> anyhow::Result<ProviderChatResponse> {
1426 let tool_instructions = if let Some(tools) = request.tools
1427 && !tools.is_empty()
1428 && !self.supports_native_tools()
1429 {
1430 Some(match self.convert_tools(tools) {
1431 ToolsPayload::PromptGuided { instructions } => instructions,
1432 payload => {
1433 anyhow::bail!(
1434 "Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
1435 )
1436 }
1437 })
1438 } else {
1439 None
1440 };
1441
1442 let temperature = temperature.unwrap_or(self.default_temperature());
1443 let (contents, system_instruction) =
1444 Self::build_chat_contents(request.messages, tool_instructions.as_deref());
1445 let (text, usage) = self
1446 .send_generate_content(contents, system_instruction, model, temperature)
1447 .await?;
1448 Ok(ProviderChatResponse {
1449 text: Some(text),
1450 tool_calls: Vec::new(),
1451 usage,
1452 reasoning_content: None,
1453 })
1454 }
1455
1456 async fn warmup(&self) -> anyhow::Result<()> {
1457 if let Some(auth) = self.auth.as_ref() {
1458 match auth {
1459 GeminiAuth::ManagedOAuth => {
1460 let auth_service = self.auth_service.as_ref().ok_or_else(|| {
1463 ::zeroclaw_log::record!(
1464 ERROR,
1465 ::zeroclaw_log::Event::new(
1466 module_path!(),
1467 ::zeroclaw_log::Action::Reject
1468 )
1469 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1470 .with_attrs(::serde_json::json!({"missing": "auth_service"})),
1471 "gemini: ManagedOAuth requires auth_service"
1472 );
1473 anyhow::Error::msg("ManagedOAuth requires auth_service")
1474 })?;
1475
1476 let _token = auth_service
1477 .get_valid_gemini_access_token(
1478 self.auth_profile_override.as_deref(),
1479 self.oauth_client_id.as_deref().unwrap_or(""),
1480 self.oauth_client_secret.as_deref().unwrap_or(""),
1481 )
1482 .await?
1483 .ok_or_else(|| {
1484 ::zeroclaw_log::record!(
1485 ERROR,
1486 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
1487 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1488 .with_attrs(::serde_json::json!({"oauth_provider": "gemini"})),
1489 "gemini: auth profile not found or expired"
1490 );
1491 anyhow::Error::msg(
1492 "Gemini auth profile not found or expired. Run: zeroclaw auth login --model-provider gemini",
1493 )
1494 })?;
1495
1496 }
1500 GeminiAuth::OAuthToken(_) => {
1501 }
1504 _ => {
1505 let url = if auth.is_api_key() {
1507 format!(
1508 "https://generativelanguage.googleapis.com/v1beta/models?key={}",
1509 auth.api_key_credential()
1510 )
1511 } else {
1512 "https://generativelanguage.googleapis.com/v1beta/models".to_string()
1513 };
1514
1515 self.http_client()
1516 .get(&url)
1517 .send()
1518 .await?
1519 .error_for_status()?;
1520 }
1521 }
1522 }
1523 Ok(())
1524 }
1525
1526 async fn list_models(&self) -> anyhow::Result<Vec<String>> {
1527 crate::models_dev::list_models_for("google").await
1530 }
1531}
1532
1533impl ::zeroclaw_api::attribution::Attributable for GeminiModelProvider {
1534 fn role(&self) -> ::zeroclaw_api::attribution::Role {
1535 ::zeroclaw_api::attribution::Role::Provider(
1536 ::zeroclaw_api::attribution::ProviderKind::Model(
1537 ::zeroclaw_api::attribution::ModelProviderKind::Gemini,
1538 ),
1539 )
1540 }
1541 fn alias(&self) -> &str {
1542 &self.alias
1543 }
1544}
1545
1546#[cfg(test)]
1547mod tests {
1548 use super::*;
1549 use reqwest::{StatusCode, header::AUTHORIZATION};
1550
1551 fn test_oauth_auth(token: &str) -> GeminiAuth {
1553 GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
1554 access_token: token.to_string(),
1555 refresh_token: None,
1556 client_id: None,
1557 client_secret: None,
1558 expiry_millis: None,
1559 })))
1560 }
1561
1562 fn test_model_provider(auth: Option<GeminiAuth>) -> GeminiModelProvider {
1563 GeminiModelProvider {
1564 alias: "test".to_string(),
1565 auth,
1566 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
1567 oauth_project_seed: None,
1568 oauth_cred_paths: Vec::new(),
1569 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
1570 auth_service: None,
1571 auth_profile_override: None,
1572 oauth_client_id: None,
1573 oauth_client_secret: None,
1574 }
1575 }
1576
1577 #[test]
1578 fn normalize_non_empty_trims_and_filters() {
1579 assert_eq!(
1580 GeminiModelProvider::normalize_non_empty(" value "),
1581 Some("value".into())
1582 );
1583 assert_eq!(GeminiModelProvider::normalize_non_empty(""), None);
1584 assert_eq!(GeminiModelProvider::normalize_non_empty(" \t\n"), None);
1585 }
1586
1587 #[test]
1588 fn oauth_refresh_form_uses_provided_client_credentials() {
1589 let form = build_oauth_refresh_form("refresh-token", Some("client-id"), Some("secret"));
1590 let map: std::collections::HashMap<_, _> = form.into_iter().collect();
1591 assert_eq!(map.get("grant_type"), Some(&"refresh_token".to_string()));
1592 assert_eq!(map.get("refresh_token"), Some(&"refresh-token".to_string()));
1593 assert_eq!(map.get("client_id"), Some(&"client-id".to_string()));
1594 assert_eq!(map.get("client_secret"), Some(&"secret".to_string()));
1595 }
1596
1597 #[test]
1598 fn oauth_refresh_form_omits_client_credentials_when_missing() {
1599 let form = build_oauth_refresh_form("refresh-token", None, None);
1600 let map: std::collections::HashMap<_, _> = form.into_iter().collect();
1601 assert!(!map.contains_key("client_id"));
1602 assert!(!map.contains_key("client_secret"));
1603 }
1604
1605 #[test]
1606 fn extract_client_id_from_id_token_prefers_aud_claim() {
1607 let payload = serde_json::json!({
1608 "aud": "aud-client-id",
1609 "azp": "azp-client-id"
1610 });
1611 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1612 .encode(serde_json::to_vec(&payload).unwrap());
1613 let token = format!("header.{payload_b64}.sig");
1614
1615 assert_eq!(
1616 extract_client_id_from_id_token(&token),
1617 Some("aud-client-id".to_string())
1618 );
1619 }
1620
1621 #[test]
1622 fn extract_client_id_from_id_token_uses_azp_when_aud_missing() {
1623 let payload = serde_json::json!({
1624 "azp": "azp-client-id"
1625 });
1626 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1627 .encode(serde_json::to_vec(&payload).unwrap());
1628 let token = format!("header.{payload_b64}.sig");
1629
1630 assert_eq!(
1631 extract_client_id_from_id_token(&token),
1632 Some("azp-client-id".to_string())
1633 );
1634 }
1635
1636 #[test]
1637 fn extract_client_id_from_id_token_returns_none_for_invalid_tokens() {
1638 assert_eq!(extract_client_id_from_id_token("invalid"), None);
1639 assert_eq!(extract_client_id_from_id_token("a.b.c"), None);
1640 }
1641
1642 #[test]
1643 fn try_load_cli_token_derives_client_id_from_id_token_when_missing() {
1644 let payload = serde_json::json!({ "aud": "derived-client-id" });
1645 let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
1646 .encode(serde_json::to_vec(&payload).unwrap());
1647 let id_token = format!("header.{payload_b64}.sig");
1648
1649 let file = tempfile::NamedTempFile::new().unwrap();
1650 let json = format!(
1651 r#"{{
1652 "access_token": "ya29.test-access",
1653 "refresh_token": "1//test-refresh",
1654 "id_token": "{id_token}"
1655 }}"#
1656 );
1657 std::fs::write(file.path(), json).unwrap();
1658
1659 let path = file.path().to_path_buf();
1660 let state = GeminiModelProvider::try_load_gemini_cli_token(Some(&path)).unwrap();
1661 assert_eq!(state.client_id.as_deref(), Some("derived-client-id"));
1662 assert_eq!(state.client_secret, None);
1663 }
1664
1665 #[test]
1666 fn provider_creates_without_key() {
1667 let model_provider = GeminiModelProvider::new("test", None);
1668 let _ = model_provider.auth_source();
1670 }
1671
1672 #[test]
1673 fn provider_creates_with_key() {
1674 let model_provider = GeminiModelProvider::new("test", Some("test-api-key"));
1675 assert!(matches!(
1676 model_provider.auth,
1677 Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key"
1678 ));
1679 }
1680
1681 #[test]
1682 fn provider_rejects_empty_key() {
1683 let model_provider = GeminiModelProvider::new("test", Some(""));
1684 assert!(!matches!(
1685 model_provider.auth,
1686 Some(GeminiAuth::ExplicitKey(_))
1687 ));
1688 }
1689
1690 #[test]
1691 fn auth_source_explicit_key() {
1692 let model_provider = test_model_provider(Some(GeminiAuth::ExplicitKey("key".into())));
1693 assert_eq!(model_provider.auth_source(), "config");
1694 }
1695
1696 #[test]
1697 fn auth_source_none_without_credentials() {
1698 let model_provider = test_model_provider(None);
1699 assert_eq!(model_provider.auth_source(), "none");
1700 }
1701
1702 #[test]
1703 fn auth_source_oauth() {
1704 let model_provider = test_model_provider(Some(test_oauth_auth("ya29.mock")));
1705 assert_eq!(model_provider.auth_source(), "Gemini CLI OAuth");
1706 }
1707
1708 #[test]
1709 fn model_name_formatting() {
1710 assert_eq!(
1711 GeminiModelProvider::format_model_name("gemini-2.0-flash"),
1712 "models/gemini-2.0-flash"
1713 );
1714 assert_eq!(
1715 GeminiModelProvider::format_model_name("models/gemini-1.5-pro"),
1716 "models/gemini-1.5-pro"
1717 );
1718 assert_eq!(
1719 GeminiModelProvider::format_internal_model_name("models/gemini-2.5-flash"),
1720 "gemini-2.5-flash"
1721 );
1722 assert_eq!(
1723 GeminiModelProvider::format_internal_model_name("gemini-2.5-flash"),
1724 "gemini-2.5-flash"
1725 );
1726 }
1727
1728 #[test]
1729 fn api_key_url_includes_key_query_param() {
1730 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1731 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1732 assert!(url.contains(":generateContent?key=api-key-123"));
1733 }
1734
1735 #[test]
1736 fn oauth_url_uses_internal_endpoint() {
1737 let auth = test_oauth_auth("ya29.test-token");
1738 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1739 assert!(url.starts_with("https://cloudcode-pa.googleapis.com/v1internal"));
1740 assert!(url.ends_with(":generateContent"));
1741 assert!(!url.contains("generativelanguage.googleapis.com"));
1742 assert!(!url.contains("?key="));
1743 }
1744
1745 #[test]
1746 fn api_key_url_uses_public_endpoint() {
1747 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1748 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1749 assert!(url.contains("generativelanguage.googleapis.com/v1beta"));
1750 assert!(url.contains("models/gemini-2.0-flash"));
1751 }
1752
1753 #[test]
1754 fn oauth_request_uses_bearer_auth_header() {
1755 let model_provider = test_model_provider(Some(test_oauth_auth("ya29.mock-token")));
1756 let auth = test_oauth_auth("ya29.mock-token");
1757 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1758 let body = GenerateContentRequest {
1759 contents: vec![Content {
1760 role: Some("user".into()),
1761 parts: vec![Part::text("hello")],
1762 }],
1763 system_instruction: None,
1764 generation_config: GenerationConfig {
1765 temperature: 0.7,
1766 max_output_tokens: 8192,
1767 },
1768 };
1769
1770 let request = model_provider
1771 .build_generate_content_request(
1772 &auth,
1773 &url,
1774 &body,
1775 "gemini-2.0-flash",
1776 true,
1777 Some("test-project"),
1778 Some("ya29.mock-token"),
1779 )
1780 .build()
1781 .unwrap();
1782
1783 assert_eq!(
1784 request
1785 .headers()
1786 .get(AUTHORIZATION)
1787 .and_then(|h| h.to_str().ok()),
1788 Some("Bearer ya29.mock-token")
1789 );
1790 }
1791
1792 #[test]
1793 fn oauth_request_wraps_payload_in_request_envelope() {
1794 let model_provider = test_model_provider(Some(test_oauth_auth("ya29.mock-token")));
1795 let auth = test_oauth_auth("ya29.mock-token");
1796 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1797 let body = GenerateContentRequest {
1798 contents: vec![Content {
1799 role: Some("user".into()),
1800 parts: vec![Part::text("hello")],
1801 }],
1802 system_instruction: None,
1803 generation_config: GenerationConfig {
1804 temperature: 0.7,
1805 max_output_tokens: 8192,
1806 },
1807 };
1808
1809 let request = model_provider
1810 .build_generate_content_request(
1811 &auth,
1812 &url,
1813 &body,
1814 "models/gemini-2.0-flash",
1815 true,
1816 Some("test-project"),
1817 Some("ya29.mock-token"),
1818 )
1819 .build()
1820 .unwrap();
1821
1822 let payload = request
1823 .body()
1824 .and_then(|b| b.as_bytes())
1825 .expect("json request body should be bytes");
1826 let json: serde_json::Value = serde_json::from_slice(payload).unwrap();
1827
1828 assert_eq!(json["model"], "gemini-2.0-flash");
1829 assert!(json.get("generationConfig").is_none());
1830 assert!(json.get("request").is_some());
1831 assert!(json["request"].get("generationConfig").is_some());
1832 }
1833
1834 #[test]
1835 fn api_key_request_does_not_set_bearer_header() {
1836 let model_provider =
1837 test_model_provider(Some(GeminiAuth::ExplicitKey("api-key-123".into())));
1838 let auth = GeminiAuth::ExplicitKey("api-key-123".into());
1839 let url = GeminiModelProvider::build_generate_content_url("gemini-2.0-flash", &auth);
1840 let body = GenerateContentRequest {
1841 contents: vec![Content {
1842 role: Some("user".into()),
1843 parts: vec![Part::text("hello")],
1844 }],
1845 system_instruction: None,
1846 generation_config: GenerationConfig {
1847 temperature: 0.7,
1848 max_output_tokens: 8192,
1849 },
1850 };
1851
1852 let request = model_provider
1853 .build_generate_content_request(
1854 &auth,
1855 &url,
1856 &body,
1857 "gemini-2.0-flash",
1858 true,
1859 None,
1860 None,
1861 )
1862 .build()
1863 .unwrap();
1864
1865 assert!(request.headers().get(AUTHORIZATION).is_none());
1866 }
1867
1868 #[test]
1869 fn request_serialization() {
1870 let request = GenerateContentRequest {
1871 contents: vec![Content {
1872 role: Some("user".to_string()),
1873 parts: vec![Part::text("Hello")],
1874 }],
1875 system_instruction: Some(Content {
1876 role: None,
1877 parts: vec![Part::text("You are helpful")],
1878 }),
1879 generation_config: GenerationConfig {
1880 temperature: 0.7,
1881 max_output_tokens: 8192,
1882 },
1883 };
1884
1885 let json = serde_json::to_string(&request).unwrap();
1886 assert!(json.contains("\"role\":\"user\""));
1887 assert!(json.contains("\"text\":\"Hello\""));
1888 assert!(json.contains("\"systemInstruction\""));
1889 assert!(!json.contains("\"system_instruction\""));
1890 assert!(json.contains("\"temperature\":0.7"));
1891 assert!(json.contains("\"maxOutputTokens\":8192"));
1892 }
1893
1894 #[test]
1895 fn internal_request_includes_model() {
1896 let request = InternalGenerateContentEnvelope {
1897 model: "gemini-3-pro-preview".to_string(),
1898 project: Some("test-project".to_string()),
1899 user_prompt_id: Some("prompt-123".to_string()),
1900 request: InternalGenerateContentRequest {
1901 contents: vec![Content {
1902 role: Some("user".to_string()),
1903 parts: vec![Part::text("Hello")],
1904 }],
1905 system_instruction: None,
1906 generation_config: Some(GenerationConfig {
1907 temperature: 0.7,
1908 max_output_tokens: 8192,
1909 }),
1910 },
1911 };
1912
1913 let json = serde_json::to_string(&request).unwrap();
1914 assert!(json.contains("\"model\":\"gemini-3-pro-preview\""));
1915 assert!(json.contains("\"request\""));
1916 assert!(json.contains("\"generationConfig\""));
1917 assert!(json.contains("\"maxOutputTokens\":8192"));
1918 assert!(json.contains("\"user_prompt_id\":\"prompt-123\""));
1919 assert!(json.contains("\"project\":\"test-project\""));
1920 assert!(json.contains("\"role\":\"user\""));
1921 assert!(json.contains("\"temperature\":0.7"));
1922 }
1923
1924 #[test]
1925 fn internal_request_omits_generation_config_when_none() {
1926 let request = InternalGenerateContentEnvelope {
1927 model: "gemini-3-pro-preview".to_string(),
1928 project: Some("test-project".to_string()),
1929 user_prompt_id: None,
1930 request: InternalGenerateContentRequest {
1931 contents: vec![Content {
1932 role: Some("user".to_string()),
1933 parts: vec![Part::text("Hello")],
1934 }],
1935 system_instruction: None,
1936 generation_config: None,
1937 },
1938 };
1939
1940 let json = serde_json::to_string(&request).unwrap();
1941 assert!(!json.contains("generationConfig"));
1942 assert!(json.contains("\"model\":\"gemini-3-pro-preview\""));
1943 }
1944
1945 #[test]
1946 fn internal_request_includes_project() {
1947 let request = InternalGenerateContentEnvelope {
1948 model: "gemini-2.5-flash".to_string(),
1949 project: Some("my-gcp-project-id".to_string()),
1950 user_prompt_id: None,
1951 request: InternalGenerateContentRequest {
1952 contents: vec![Content {
1953 role: Some("user".to_string()),
1954 parts: vec![Part::text("Hello")],
1955 }],
1956 system_instruction: None,
1957 generation_config: None,
1958 },
1959 };
1960
1961 let json = serde_json::to_string(&request).unwrap();
1962 assert!(json.contains("\"project\":\"my-gcp-project-id\""));
1963 }
1964
1965 #[test]
1966 fn creds_deserialize_with_expiry_date() {
1967 let json = r#"{
1968 "access_token": "ya29.test-token",
1969 "refresh_token": "1//test-refresh",
1970 "expiry_date": 4102444800000
1971 }"#;
1972
1973 let creds: GeminiCliOAuthCreds = serde_json::from_str(json).unwrap();
1974 assert_eq!(creds.access_token.as_deref(), Some("ya29.test-token"));
1975 assert_eq!(creds.refresh_token.as_deref(), Some("1//test-refresh"));
1976 assert_eq!(creds.expiry_date, Some(4_102_444_800_000));
1977 assert!(creds.expiry.is_none());
1978 }
1979
1980 #[test]
1981 fn creds_deserialize_accepts_camel_case_fields() {
1982 let json = r#"{
1983 "access_token": "ya29.test-token",
1984 "idToken": "header.payload.sig",
1985 "refresh_token": "1//test-refresh",
1986 "clientId": "test-client-id",
1987 "clientSecret": "test-client-secret",
1988 "expiryDate": 4102444800000
1989 }"#;
1990
1991 let creds: GeminiCliOAuthCreds = serde_json::from_str(json).unwrap();
1992 assert_eq!(creds.id_token.as_deref(), Some("header.payload.sig"));
1993 assert_eq!(creds.client_id.as_deref(), Some("test-client-id"));
1994 assert_eq!(creds.client_secret.as_deref(), Some("test-client-secret"));
1995 assert_eq!(creds.expiry_date, Some(4_102_444_800_000));
1996 }
1997
1998 #[test]
1999 fn oauth_retry_detection_for_generation_config_rejection() {
2000 let err =
2002 "Invalid JSON payload received. Unknown name \"generationConfig\": Cannot find field.";
2003 assert!(
2004 GeminiModelProvider::should_retry_oauth_without_generation_config(
2005 StatusCode::BAD_REQUEST,
2006 err
2007 )
2008 );
2009 let err_json = r#"Invalid JSON payload received. Unknown name \"generationConfig\": Cannot find field."#;
2011 assert!(
2012 GeminiModelProvider::should_retry_oauth_without_generation_config(
2013 StatusCode::BAD_REQUEST,
2014 err_json
2015 )
2016 );
2017 assert!(
2018 !GeminiModelProvider::should_retry_oauth_without_generation_config(
2019 StatusCode::UNAUTHORIZED,
2020 err
2021 )
2022 );
2023 assert!(
2024 !GeminiModelProvider::should_retry_oauth_without_generation_config(
2025 StatusCode::BAD_REQUEST,
2026 "something else"
2027 )
2028 );
2029 }
2030
2031 #[test]
2032 fn response_deserialization() {
2033 let json = r#"{
2034 "candidates": [{
2035 "content": {
2036 "parts": [{"text": "Hello there!"}]
2037 }
2038 }]
2039 }"#;
2040
2041 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2042 assert!(response.candidates.is_some());
2043 let text = response
2044 .candidates
2045 .unwrap()
2046 .into_iter()
2047 .next()
2048 .unwrap()
2049 .content
2050 .unwrap()
2051 .parts
2052 .into_iter()
2053 .next()
2054 .unwrap()
2055 .text;
2056 assert_eq!(text, Some("Hello there!".to_string()));
2057 }
2058
2059 #[test]
2060 fn error_response_deserialization() {
2061 let json = r#"{
2062 "error": {
2063 "message": "Invalid API key"
2064 }
2065 }"#;
2066
2067 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2068 assert!(response.error.is_some());
2069 assert_eq!(response.error.unwrap().message, "Invalid API key");
2070 }
2071
2072 #[test]
2073 fn internal_response_deserialization() {
2074 let json = r#"{
2075 "response": {
2076 "candidates": [{
2077 "content": {
2078 "parts": [{"text": "Hello from internal"}]
2079 }
2080 }]
2081 }
2082 }"#;
2083
2084 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2085 let text = response
2086 .into_effective_response()
2087 .candidates
2088 .unwrap()
2089 .into_iter()
2090 .next()
2091 .unwrap()
2092 .content
2093 .unwrap()
2094 .parts
2095 .into_iter()
2096 .next()
2097 .unwrap()
2098 .text;
2099 assert_eq!(text, Some("Hello from internal".to_string()));
2100 }
2101
2102 #[test]
2105 fn thinking_response_extracts_non_thinking_text() {
2106 let json = r#"{
2107 "candidates": [{
2108 "content": {
2109 "parts": [
2110 {"thought": true, "text": "Let me think about this..."},
2111 {"text": "The answer is 42."},
2112 {"thoughtSignature": "c2lnbmF0dXJl"}
2113 ]
2114 }
2115 }]
2116 }"#;
2117
2118 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2119 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2120 let text = candidate.content.unwrap().effective_text();
2121 assert_eq!(text, Some("The answer is 42.".to_string()));
2122 }
2123
2124 #[test]
2125 fn non_thinking_response_unaffected() {
2126 let json = r#"{
2127 "candidates": [{
2128 "content": {
2129 "parts": [{"text": "Hello there!"}]
2130 }
2131 }]
2132 }"#;
2133
2134 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2135 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2136 let text = candidate.content.unwrap().effective_text();
2137 assert_eq!(text, Some("Hello there!".to_string()));
2138 }
2139
2140 #[test]
2141 fn thinking_only_response_falls_back_to_thinking_text() {
2142 let json = r#"{
2143 "candidates": [{
2144 "content": {
2145 "parts": [
2146 {"thought": true, "text": "I need more context..."},
2147 {"thoughtSignature": "c2lnbmF0dXJl"}
2148 ]
2149 }
2150 }]
2151 }"#;
2152
2153 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2154 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2155 let text = candidate.content.unwrap().effective_text();
2156 assert_eq!(text, Some("I need more context...".to_string()));
2157 }
2158
2159 #[test]
2160 fn empty_parts_returns_none() {
2161 let json = r#"{
2162 "candidates": [{
2163 "content": {
2164 "parts": []
2165 }
2166 }]
2167 }"#;
2168
2169 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2170 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2171 let text = candidate.content.unwrap().effective_text();
2172 assert_eq!(text, None);
2173 }
2174
2175 #[test]
2176 fn multiple_text_parts_concatenated() {
2177 let json = r#"{
2178 "candidates": [{
2179 "content": {
2180 "parts": [
2181 {"text": "Part one. "},
2182 {"text": "Part two."}
2183 ]
2184 }
2185 }]
2186 }"#;
2187
2188 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2189 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2190 let text = candidate.content.unwrap().effective_text();
2191 assert_eq!(text, Some("Part one. Part two.".to_string()));
2192 }
2193
2194 #[test]
2195 fn thought_signature_only_parts_skipped() {
2196 let json = r#"{
2197 "candidates": [{
2198 "content": {
2199 "parts": [
2200 {"thoughtSignature": "c2lnbmF0dXJl"}
2201 ]
2202 }
2203 }]
2204 }"#;
2205
2206 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2207 let candidate = response.candidates.unwrap().into_iter().next().unwrap();
2208 let text = candidate.content.unwrap().effective_text();
2209 assert_eq!(text, None);
2210 }
2211
2212 #[test]
2213 fn internal_response_thinking_model() {
2214 let json = r#"{
2215 "response": {
2216 "candidates": [{
2217 "content": {
2218 "parts": [
2219 {"thought": true, "text": "reasoning..."},
2220 {"text": "final answer"}
2221 ]
2222 }
2223 }]
2224 }
2225 }"#;
2226
2227 let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
2228 let effective = response.into_effective_response();
2229 let candidate = effective.candidates.unwrap().into_iter().next().unwrap();
2230 let text = candidate.content.unwrap().effective_text();
2231 assert_eq!(text, Some("final answer".to_string()));
2232 }
2233
2234 #[tokio::test]
2235 async fn warmup_without_key_is_noop() {
2236 let model_provider = test_model_provider(None);
2237 let result = model_provider.warmup().await;
2238 assert!(result.is_ok());
2239 }
2240
2241 #[tokio::test]
2242 async fn warmup_oauth_is_noop() {
2243 let model_provider = test_model_provider(Some(test_oauth_auth("ya29.mock-token")));
2244 let result = model_provider.warmup().await;
2245 assert!(result.is_ok());
2246 }
2247
2248 #[test]
2249 fn discover_oauth_cred_paths_does_not_panic() {
2250 let _paths = GeminiModelProvider::discover_oauth_cred_paths();
2251 }
2252
2253 #[tokio::test]
2254 async fn rotate_oauth_without_alternatives_returns_false() {
2255 let state = Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
2256 access_token: "ya29.mock".to_string(),
2257 refresh_token: None,
2258 client_id: None,
2259 client_secret: None,
2260 expiry_millis: None,
2261 }));
2262 let model_provider = test_model_provider(Some(GeminiAuth::OAuthToken(state.clone())));
2263 assert!(!model_provider.rotate_oauth_credential(&state).await);
2264 }
2265
2266 #[test]
2267 fn response_parses_usage_metadata() {
2268 let json = r#"{
2269 "candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
2270 "usageMetadata": {"promptTokenCount": 120, "candidatesTokenCount": 40}
2271 }"#;
2272 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2273 let usage = resp.usage_metadata.unwrap();
2274 assert_eq!(usage.prompt_token_count, Some(120));
2275 assert_eq!(usage.candidates_token_count, Some(40));
2276 }
2277
2278 #[test]
2279 fn response_usage_metadata_maps_to_token_usage() {
2280 let usage = GeminiUsageMetadata {
2281 prompt_token_count: Some(120),
2282 candidates_token_count: Some(40),
2283 };
2284
2285 let token_usage =
2286 GeminiModelProvider::token_usage_from_metadata(usage).expect("usage counts should map");
2287
2288 assert_eq!(token_usage.input_tokens, Some(120));
2289 assert_eq!(token_usage.output_tokens, Some(40));
2290 assert_eq!(token_usage.cached_input_tokens, None);
2291 }
2292
2293 #[test]
2294 fn empty_usage_metadata_maps_to_none() {
2295 let usage = GeminiUsageMetadata {
2296 prompt_token_count: None,
2297 candidates_token_count: None,
2298 };
2299
2300 assert!(GeminiModelProvider::token_usage_from_metadata(usage).is_none());
2301 }
2302
2303 #[test]
2304 fn wrapped_response_preserves_outer_usage_metadata() {
2305 let json = r#"{
2306 "usageMetadata": {"promptTokenCount": 120, "candidatesTokenCount": 40},
2307 "response": {
2308 "candidates": [{"content": {"parts": [{"text": "Hello"}]}}]
2309 }
2310 }"#;
2311
2312 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2313 let effective = resp.into_effective_response();
2314 let usage = effective.usage_metadata.unwrap();
2315
2316 assert_eq!(usage.prompt_token_count, Some(120));
2317 assert_eq!(usage.candidates_token_count, Some(40));
2318 }
2319
2320 #[test]
2321 fn wrapped_response_prefers_inner_usage_metadata() {
2322 let json = r#"{
2323 "usageMetadata": {"promptTokenCount": 120, "candidatesTokenCount": 40},
2324 "response": {
2325 "candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
2326 "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
2327 }
2328 }"#;
2329
2330 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2331 let effective = resp.into_effective_response();
2332 let usage = effective.usage_metadata.unwrap();
2333
2334 assert_eq!(usage.prompt_token_count, Some(5));
2335 assert_eq!(usage.candidates_token_count, Some(2));
2336 }
2337
2338 #[test]
2339 fn response_parses_without_usage_metadata() {
2340 let json = r#"{"candidates": [{"content": {"parts": [{"text": "Hello"}]}}]}"#;
2341 let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
2342 assert!(resp.usage_metadata.is_none());
2343 }
2344
2345 #[tokio::test]
2347 async fn warmup_managed_oauth_requires_auth_service() {
2348 let model_provider = GeminiModelProvider {
2349 alias: "test".to_string(),
2350 auth: Some(GeminiAuth::ManagedOAuth),
2351 oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
2352 oauth_project_seed: None,
2353 oauth_cred_paths: Vec::new(),
2354 oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
2355 auth_service: None, auth_profile_override: None,
2357 oauth_client_id: None,
2358 oauth_client_secret: None,
2359 };
2360
2361 let result = model_provider.warmup().await;
2362 assert!(result.is_err());
2363 assert!(
2364 result
2365 .unwrap_err()
2366 .to_string()
2367 .contains("ManagedOAuth requires auth_service")
2368 );
2369 }
2370
2371 #[tokio::test]
2373 async fn warmup_cli_oauth_skips_validation() {
2374 let model_provider = test_model_provider(Some(test_oauth_auth("fake_token")));
2375 let result = model_provider.warmup().await;
2376 assert!(result.is_ok());
2378 }
2379
2380 #[test]
2383 fn part_text_serializes_as_text_object() {
2384 let part = Part::text("hello");
2385 let json = serde_json::to_value(&part).unwrap();
2386 assert_eq!(json, serde_json::json!({"text": "hello"}));
2387 }
2388
2389 #[test]
2390 fn part_inline_serializes_as_inline_data_object() {
2391 let part = Part::Inline {
2392 inline_data: InlineData {
2393 mime_type: "image/png".to_string(),
2394 data: "iVBOR...".to_string(),
2395 },
2396 };
2397 let json = serde_json::to_value(&part).unwrap();
2398 assert_eq!(
2399 json,
2400 serde_json::json!({"inline_data": {"mime_type": "image/png", "data": "iVBOR..."}})
2401 );
2402 }
2403
2404 #[test]
2405 fn part_text_constructor_accepts_string_and_str() {
2406 let from_str = Part::text("hello");
2407 let from_string = Part::text(String::from("hello"));
2408 assert_eq!(
2410 serde_json::to_value(&from_str).unwrap(),
2411 serde_json::to_value(&from_string).unwrap(),
2412 );
2413 }
2414
2415 #[test]
2416 fn content_with_mixed_parts_serializes_correctly() {
2417 let content = Content {
2418 role: Some("user".to_string()),
2419 parts: vec![
2420 Part::text("Describe this image:"),
2421 Part::Inline {
2422 inline_data: InlineData {
2423 mime_type: "image/jpeg".to_string(),
2424 data: "/9j/4AAQ...".to_string(),
2425 },
2426 },
2427 ],
2428 };
2429 let json = serde_json::to_value(&content).unwrap();
2430 let parts = json["parts"].as_array().unwrap();
2431 assert_eq!(parts.len(), 2);
2432 assert!(parts[0].get("text").is_some());
2433 assert!(parts[1].get("inline_data").is_some());
2434 }
2435
2436 #[test]
2439 fn build_parts_plain_text_returns_single_text_part() {
2440 let parts = build_parts("Hello, world!");
2441 assert_eq!(parts.len(), 1);
2442 assert_eq!(
2443 serde_json::to_value(&parts[0]).unwrap(),
2444 serde_json::json!({"text": "Hello, world!"})
2445 );
2446 }
2447
2448 #[test]
2449 fn build_parts_empty_string_returns_single_text_part() {
2450 let parts = build_parts("");
2451 assert_eq!(parts.len(), 1);
2452 assert_eq!(
2454 serde_json::to_value(&parts[0]).unwrap(),
2455 serde_json::json!({"text": ""})
2456 );
2457 }
2458
2459 #[test]
2460 fn build_parts_extracts_data_uri_as_inline_part() {
2461 let content = "Check this [IMAGE:data:image/png;base64,iVBORw0KGgo=]";
2462 let parts = build_parts(content);
2463 assert_eq!(parts.len(), 2);
2464 assert_eq!(
2466 serde_json::to_value(&parts[0]).unwrap(),
2467 serde_json::json!({"text": "Check this"})
2468 );
2469 assert_eq!(
2471 serde_json::to_value(&parts[1]).unwrap(),
2472 serde_json::json!({"inline_data": {"mime_type": "image/png", "data": "iVBORw0KGgo="}})
2473 );
2474 }
2475
2476 #[test]
2477 fn build_parts_multiple_images() {
2478 let content = "Image A: [IMAGE:data:image/png;base64,AAAA] Image B: [IMAGE:data:image/jpeg;base64,BBBB]";
2479 let parts = build_parts(content);
2480 assert_eq!(parts.len(), 3); let inline_parts: Vec<_> = parts
2483 .iter()
2484 .filter(|p| matches!(p, Part::Inline { .. }))
2485 .collect();
2486 assert_eq!(inline_parts.len(), 2);
2487 }
2488
2489 #[test]
2490 fn build_parts_ignores_non_data_uri_markers() {
2491 let content = "Look [IMAGE:/tmp/photo.png]";
2494 let parts = build_parts(content);
2495 for part in &parts {
2498 assert!(matches!(part, Part::Text { .. }));
2499 }
2500 }
2501
2502 #[test]
2503 fn build_parts_image_only_still_produces_inline_part() {
2504 let content = "[IMAGE:data:image/gif;base64,R0lGODlh]";
2505 let parts = build_parts(content);
2506 assert_eq!(parts.len(), 1);
2508 assert!(matches!(&parts[0], Part::Inline { .. }));
2509 }
2510
2511 #[test]
2514 fn chat_with_history_maps_roles_correctly() {
2515 let messages = vec![
2516 ChatMessage::system("You are helpful"),
2517 ChatMessage::user("Hello [IMAGE:data:image/png;base64,AA==]"),
2518 ChatMessage::assistant("I see the image"),
2519 ];
2520
2521 let (contents, system_instruction) =
2522 GeminiModelProvider::build_chat_contents(&messages, None);
2523
2524 let system_instruction = system_instruction.expect("system prompt should be separated");
2525 assert_eq!(system_instruction.role, None);
2526 assert!(
2527 matches!(&system_instruction.parts[0], Part::Text { text } if text == "You are helpful")
2528 );
2529
2530 assert_eq!(contents.len(), 2);
2531 assert_eq!(contents[0].role.as_deref(), Some("user"));
2532 assert!(
2533 contents[0]
2534 .parts
2535 .iter()
2536 .any(|p| matches!(p, Part::Inline { .. }))
2537 );
2538 assert_eq!(contents[1].role.as_deref(), Some("model"));
2539 assert!(matches!(&contents[1].parts[0], Part::Text { text } if text == "I see the image"));
2540 }
2541
2542 #[test]
2543 fn chat_contents_append_tool_instructions_to_system_prompt() {
2544 let messages = vec![
2545 ChatMessage::system("You are helpful"),
2546 ChatMessage::user("Hello"),
2547 ];
2548
2549 let (_contents, system_instruction) =
2550 GeminiModelProvider::build_chat_contents(&messages, Some("Use tools carefully"));
2551
2552 let system_instruction = system_instruction.expect("system prompt should include tools");
2553 assert!(
2554 matches!(&system_instruction.parts[0], Part::Text { text } if text == "You are helpful\n\nUse tools carefully")
2555 );
2556 }
2557
2558 #[test]
2559 fn chat_contents_create_system_prompt_from_tool_instructions() {
2560 let messages = vec![ChatMessage::user("Hello")];
2561
2562 let (_contents, system_instruction) =
2563 GeminiModelProvider::build_chat_contents(&messages, Some("Use tools carefully"));
2564
2565 let system_instruction =
2566 system_instruction.expect("tool instructions should be system prompt");
2567 assert!(
2568 matches!(&system_instruction.parts[0], Part::Text { text } if text == "Use tools carefully")
2569 );
2570 }
2571}