Skip to main content

zeroclaw_providers/auth/
gemini_oauth.rs

1//! Google/Gemini OAuth2 authentication flow.
2//!
3//! Supports:
4//! - Authorization code flow with PKCE (loopback redirect)
5//! - Device code flow for headless environments
6//!
7//! Uses the same client credentials as Gemini CLI for compatibility.
8
9use crate::auth::oauth_common::{parse_query_params, url_decode, url_encode};
10use crate::auth::profiles::TokenSet;
11use anyhow::{Context, Result};
12use base64::Engine;
13use chrono::Utc;
14use reqwest::Client;
15use serde::Deserialize;
16use std::collections::BTreeMap;
17use std::time::Duration;
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tokio::net::TcpListener;
20
21// Re-export for external use (used by main.rs)
22#[allow(unused_imports)]
23pub use crate::auth::oauth_common::{PkceState, generate_pkce_state};
24
/// Google's OAuth 2.0 authorization (consent page) endpoint; used as the base
/// of the URL produced by `build_authorize_url`.
pub const GOOGLE_OAUTH_AUTHORIZE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
/// Token endpoint used for the code exchange, token refresh, and device-code polling.
pub const GOOGLE_OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
/// Endpoint that issues the device/user code pair for the headless device flow.
pub const GOOGLE_OAUTH_DEVICE_CODE_URL: &str = "https://oauth2.googleapis.com/device/code";
/// Loopback redirect URI sent with the authorization-code flow.
///
/// Its port must match the one `receive_loopback_code` binds (1456).
/// NOTE(review): this URI says `localhost`, while the listener binds the IPv4
/// address `127.0.0.1` only; a browser that resolves `localhost` to `::1`
/// may fail to reach it — confirm.
pub const GEMINI_OAUTH_REDIRECT_URI: &str = "http://localhost:1456/auth/callback";

/// Scopes required for Gemini API access.
pub const GEMINI_OAUTH_SCOPES: &str =
    "openid profile email https://www.googleapis.com/auth/cloud-platform";
33
/// Result of starting the device authorization flow: what to show the user,
/// plus the parameters `poll_device_code_tokens` needs to poll for tokens.
#[derive(Debug, Clone)]
pub struct DeviceCodeStart {
    /// Code sent back to the token endpoint on every poll.
    pub device_code: String,
    /// Code the user enters on the verification page.
    pub user_code: String,
    /// Verification page URL as returned by the server.
    pub verification_uri: String,
    /// `verification_uri` with `user_code` appended as a query parameter, if built.
    pub verification_uri_complete: Option<String>,
    /// Seconds until the device code expires (`start_device_code_flow` uses
    /// 1800 when the server omits it).
    pub expires_in: u64,
    /// Suggested seconds between polls (`start_device_code_flow` uses 5 when
    /// omitted; `poll_device_code_tokens` never polls faster than 5 s).
    pub interval: u64,
}
43
/// JSON body of a successful token-endpoint response (code exchange, refresh,
/// or device-code poll). Only `access_token` is required; every other field
/// defaults to `None` when absent.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    access_token: String,
    /// May be missing (presumably on refresh responses — confirm per grant).
    #[serde(default)]
    refresh_token: Option<String>,
    #[serde(default)]
    id_token: Option<String>,
    /// Access-token lifetime in seconds; callers convert it to an absolute `expires_at`.
    #[serde(default)]
    expires_in: Option<i64>,
    /// Callers substitute `"Bearer"` when this is absent.
    #[serde(default)]
    token_type: Option<String>,
    #[serde(default)]
    scope: Option<String>,
}
58
/// JSON body returned by the device-code endpoint.
///
/// The verification page is read from `verification_url` (presumably Google's
/// spelling; RFC 8628 calls it `verification_uri`). NOTE(review): a response
/// that only carries `verification_uri` would fail to deserialize because this
/// field is required — confirm against Google's current response format.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    device_code: String,
    user_code: String,
    verification_url: String,
    /// Seconds until `device_code` expires; optional in the response.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Suggested seconds between token polls; optional in the response.
    #[serde(default)]
    interval: Option<u64>,
}
69
/// Error body returned by the OAuth endpoints on failure: a machine-readable
/// `error` code plus an optional human-readable description.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Error code, matched against values such as `authorization_pending`,
    /// `slow_down`, `access_denied`, and `expired_token` while polling.
    error: String,
    #[serde(default)]
    error_description: Option<String>,
}
76
77pub fn build_authorize_url(client_id: &str, pkce: &PkceState) -> Result<String> {
78    let mut params = BTreeMap::new();
79    params.insert("response_type", "code");
80    params.insert("client_id", client_id);
81    params.insert("redirect_uri", GEMINI_OAUTH_REDIRECT_URI);
82    params.insert("scope", GEMINI_OAUTH_SCOPES);
83    params.insert("code_challenge", pkce.code_challenge.as_str());
84    params.insert("code_challenge_method", "S256");
85    params.insert("state", pkce.state.as_str());
86    params.insert("access_type", "offline");
87    params.insert("prompt", "consent");
88
89    let mut encoded: Vec<String> = Vec::with_capacity(params.len());
90    for (k, v) in params {
91        encoded.push(format!("{}={}", url_encode(k), url_encode(v)));
92    }
93
94    Ok(format!(
95        "{}?{}",
96        GOOGLE_OAUTH_AUTHORIZE_URL,
97        encoded.join("&")
98    ))
99}
100
101pub async fn exchange_code_for_tokens(
102    client: &Client,
103    client_id: &str,
104    client_secret: &str,
105    code: &str,
106    pkce: &PkceState,
107) -> Result<TokenSet> {
108    let form = [
109        ("grant_type", "authorization_code"),
110        ("code", code),
111        ("redirect_uri", GEMINI_OAUTH_REDIRECT_URI),
112        ("client_id", client_id),
113        ("client_secret", client_secret),
114        ("code_verifier", &pkce.code_verifier),
115    ];
116
117    let response = client
118        .post(GOOGLE_OAUTH_TOKEN_URL)
119        .form(&form)
120        .send()
121        .await
122        .context("Failed to send token exchange request")?;
123
124    let status = response.status();
125    let body = response
126        .text()
127        .await
128        .context("Failed to read token response body")?;
129
130    if !status.is_success() {
131        if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
132            anyhow::bail!(
133                "Google OAuth error: {} - {}",
134                err.error,
135                err.error_description.unwrap_or_default()
136            );
137        }
138        anyhow::bail!("Google OAuth token exchange failed ({}): {}", status, body);
139    }
140
141    let token_response: TokenResponse =
142        serde_json::from_str(&body).context("Failed to parse token response")?;
143
144    let expires_at = token_response
145        .expires_in
146        .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
147
148    Ok(TokenSet {
149        access_token: token_response.access_token,
150        refresh_token: token_response.refresh_token,
151        id_token: token_response.id_token,
152        expires_at,
153        token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
154        scope: token_response.scope,
155    })
156}
157
158pub async fn refresh_access_token(
159    client: &Client,
160    client_id: &str,
161    client_secret: &str,
162    refresh_token: &str,
163) -> Result<TokenSet> {
164    let form = [
165        ("grant_type", "refresh_token"),
166        ("refresh_token", refresh_token),
167        ("client_id", client_id),
168        ("client_secret", client_secret),
169    ];
170
171    let response = client
172        .post(GOOGLE_OAUTH_TOKEN_URL)
173        .form(&form)
174        .send()
175        .await
176        .context("Failed to send refresh token request")?;
177
178    let status = response.status();
179    let body = response
180        .text()
181        .await
182        .context("Failed to read refresh response body")?;
183
184    if !status.is_success() {
185        if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
186            anyhow::bail!(
187                "Google OAuth refresh error: {} - {}",
188                err.error,
189                err.error_description.unwrap_or_default()
190            );
191        }
192        anyhow::bail!("Google OAuth refresh failed ({}): {}", status, body);
193    }
194
195    let token_response: TokenResponse =
196        serde_json::from_str(&body).context("Failed to parse refresh response")?;
197
198    let expires_at = token_response
199        .expires_in
200        .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
201
202    Ok(TokenSet {
203        access_token: token_response.access_token,
204        refresh_token: token_response.refresh_token,
205        id_token: token_response.id_token,
206        expires_at,
207        token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
208        scope: token_response.scope,
209    })
210}
211
212pub async fn start_device_code_flow(client: &Client, client_id: &str) -> Result<DeviceCodeStart> {
213    let form = [("client_id", client_id), ("scope", GEMINI_OAUTH_SCOPES)];
214
215    let response = client
216        .post(GOOGLE_OAUTH_DEVICE_CODE_URL)
217        .form(&form)
218        .send()
219        .await
220        .context("Failed to start device code flow")?;
221
222    let status = response.status();
223    let body = response
224        .text()
225        .await
226        .context("Failed to read device code response")?;
227
228    if !status.is_success() {
229        if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
230            anyhow::bail!(
231                "Google device code error: {} - {}",
232                err.error,
233                err.error_description.unwrap_or_default()
234            );
235        }
236        anyhow::bail!("Google device code request failed ({}): {}", status, body);
237    }
238
239    let device_response: DeviceCodeResponse =
240        serde_json::from_str(&body).context("Failed to parse device code response")?;
241
242    let user_code = device_response.user_code;
243    let verification_url = device_response.verification_url;
244
245    Ok(DeviceCodeStart {
246        device_code: device_response.device_code,
247        verification_uri_complete: Some(format!("{}?user_code={}", &verification_url, &user_code)),
248        user_code,
249        verification_uri: verification_url,
250        expires_in: device_response.expires_in.unwrap_or(1800),
251        interval: device_response.interval.unwrap_or(5),
252    })
253}
254
255pub async fn poll_device_code_tokens(
256    client: &Client,
257    client_id: &str,
258    client_secret: &str,
259    device: &DeviceCodeStart,
260) -> Result<TokenSet> {
261    let deadline = std::time::Instant::now() + Duration::from_secs(device.expires_in);
262    let interval = Duration::from_secs(device.interval.max(5));
263
264    loop {
265        if std::time::Instant::now() > deadline {
266            anyhow::bail!("Device code expired before authorization was completed");
267        }
268
269        tokio::time::sleep(interval).await;
270
271        let form = [
272            ("client_id", client_id),
273            ("client_secret", client_secret),
274            ("device_code", device.device_code.as_str()),
275            ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
276        ];
277
278        let response = client
279            .post(GOOGLE_OAUTH_TOKEN_URL)
280            .form(&form)
281            .send()
282            .await
283            .context("Failed to poll device code")?;
284
285        let status = response.status();
286        let body = response.text().await.unwrap_or_default();
287
288        if status.is_success() {
289            let token_response: TokenResponse =
290                serde_json::from_str(&body).context("Failed to parse token response")?;
291
292            let expires_at = token_response
293                .expires_in
294                .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
295
296            return Ok(TokenSet {
297                access_token: token_response.access_token,
298                refresh_token: token_response.refresh_token,
299                id_token: token_response.id_token,
300                expires_at,
301                token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
302                scope: token_response.scope,
303            });
304        }
305
306        if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
307            match err.error.as_str() {
308                "authorization_pending" => {}
309                "slow_down" => {
310                    tokio::time::sleep(Duration::from_secs(5)).await;
311                }
312                "access_denied" => {
313                    anyhow::bail!("User denied authorization");
314                }
315                "expired_token" => {
316                    anyhow::bail!("Device code expired");
317                }
318                _ => {
319                    anyhow::bail!(
320                        "Google OAuth error: {} - {}",
321                        err.error,
322                        err.error_description.unwrap_or_default()
323                    );
324                }
325            }
326        }
327    }
328}
329
330/// Receive OAuth code via loopback callback OR manual stdin input.
331///
332/// If the callback server can't receive the redirect (e.g., remote/headless environment),
333/// the user can paste the full callback URL or just the code.
334pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> Result<String> {
335    // Try to bind to the callback port
336    let listener = match TcpListener::bind("127.0.0.1:1456").await {
337        Ok(l) => l,
338        Err(e) => {
339            eprintln!("Could not bind to localhost:1456: {e}");
340            eprintln!("Falling back to manual input.");
341            return receive_code_from_stdin(expected_state).await;
342        }
343    };
344
345    println!("Waiting for callback at http://localhost:1456/auth/callback ...");
346    println!("(Or paste the full callback URL / authorization code here if running remotely)");
347
348    // Race between: callback arriving OR stdin input
349    tokio::select! {
350        accept_result = async {
351            tokio::time::timeout(timeout, listener.accept()).await
352        } => {
353            match accept_result {
354                Ok(Ok((mut stream, _))) => {
355                    let mut buffer = vec![0u8; 4096];
356                    let n = stream
357                        .read(&mut buffer)
358                        .await
359                        .context("Failed to read from callback connection")?;
360
361                    let request = String::from_utf8_lossy(&buffer[..n]);
362                    let (code, state) = parse_callback_request(&request)?;
363
364                    if state != expected_state {
365                        let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
366                             <html><body><h1>State mismatch</h1><p>Please try again.</p></body></html>";
367                        let _ = stream.write_all(response.as_bytes()).await;
368                        anyhow::bail!("OAuth state mismatch");
369                    }
370
371                    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
372                         <html><body><h1>Success!</h1><p>You can close this window and return to the terminal.</p></body></html>";
373                    let _ = stream.write_all(response.as_bytes()).await;
374
375                    Ok(code)
376                }
377                Ok(Err(e)) => {
378                    ::zeroclaw_log::record!(
379                        ERROR,
380                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
381                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
382                            .with_attrs(::serde_json::json!({
383                                "oauth_provider": "gemini",
384                                "phase": "callback_accept",
385                                "error": format!("{}", e),
386                            })),
387                        "gemini_oauth: failed to accept callback connection"
388                    );
389                    Err(anyhow::Error::msg(format!("Failed to accept connection: {e}")))
390                }
391                Err(_) => {
392                    eprintln!("\nCallback timeout. Falling back to manual input.");
393                    receive_code_from_stdin(expected_state).await
394                }
395            }
396        }
397        stdin_result = receive_code_from_stdin(expected_state) => {
398            stdin_result
399        }
400    }
401}
402
403/// Read authorization code from stdin (supports full URL or raw code).
404async fn receive_code_from_stdin(expected_state: &str) -> Result<String> {
405    use std::io::{self, BufRead};
406
407    let expected = expected_state.to_string();
408    let input = tokio::task::spawn_blocking(move || {
409        let stdin = io::stdin();
410        let mut line = String::new();
411        stdin.lock().read_line(&mut line).ok();
412        let trimmed = line.trim().to_string();
413        if trimmed.is_empty() {
414            ::zeroclaw_log::record!(
415                WARN,
416                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
417                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
418                    .with_attrs(::serde_json::json!({"oauth_provider": "gemini"})),
419                "gemini_oauth: empty stdin input for OAuth code"
420            );
421            return Err(anyhow::Error::msg("No input received"));
422        }
423        parse_code_from_redirect(&trimmed, Some(&expected))
424    })
425    .await
426    .context("Failed to read from stdin")??;
427
428    Ok(input)
429}
430
431fn parse_callback_request(request: &str) -> Result<(String, String)> {
432    let first_line = request.lines().next().unwrap_or("");
433    let path = first_line
434        .split_whitespace()
435        .nth(1)
436        .unwrap_or("")
437        .to_string();
438
439    let query_start = path.find('?').map(|i| i + 1).unwrap_or(path.len());
440    let query = &path[query_start..];
441
442    let mut code = None;
443    let mut state = None;
444
445    for pair in query.split('&') {
446        if let Some((key, value)) = pair.split_once('=') {
447            match key {
448                "code" => code = Some(url_decode(value)),
449                "state" => state = Some(url_decode(value)),
450                _ => {}
451            }
452        }
453    }
454
455    let code = code.ok_or_else(|| {
456        ::zeroclaw_log::record!(
457            WARN,
458            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
459                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
460                .with_attrs(::serde_json::json!({
461                    "oauth_provider": "gemini",
462                    "missing": "code",
463                })),
464            "gemini_oauth: callback missing code parameter"
465        );
466        anyhow::Error::msg("No 'code' parameter in callback")
467    })?;
468    let state = state.ok_or_else(|| {
469        ::zeroclaw_log::record!(
470            WARN,
471            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
472                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
473                .with_attrs(::serde_json::json!({
474                    "oauth_provider": "gemini",
475                    "missing": "state",
476                })),
477            "gemini_oauth: callback missing state parameter"
478        );
479        anyhow::Error::msg("No 'state' parameter in callback")
480    })?;
481
482    Ok((code, state))
483}
484
485pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
486    let trimmed = input.trim();
487    if trimmed.is_empty() {
488        anyhow::bail!("No OAuth code provided");
489    }
490
491    // Extract query string
492    let query = if let Some((_, right)) = trimmed.split_once('?') {
493        right
494    } else {
495        trimmed
496    };
497
498    let params = parse_query_params(query);
499
500    // If we have code param, extract it
501    if let Some(code) = params.get("code") {
502        // Validate state if expected
503        if let Some(expected) = expected_state
504            && let Some(actual) = params.get("state")
505            && actual != expected
506        {
507            anyhow::bail!("OAuth state mismatch: expected {expected}, got {actual}");
508        }
509        return Ok(code.clone());
510    }
511
512    // Otherwise, assume it's the raw code (if long enough and no spaces)
513    if trimmed.len() > 10 && !trimmed.contains(' ') && !trimmed.contains('&') {
514        return Ok(trimmed.to_string());
515    }
516
517    anyhow::bail!("Could not parse OAuth code from input")
518}
519
520/// Extract account email from Google ID token.
521pub fn extract_account_email_from_id_token(id_token: &str) -> Option<String> {
522    let parts: Vec<&str> = id_token.split('.').collect();
523    if parts.len() != 3 {
524        return None;
525    }
526
527    let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
528        .decode(parts[1])
529        .ok()?;
530
531    #[derive(Deserialize)]
532    struct IdTokenPayload {
533        email: Option<String>,
534    }
535
536    let payload: IdTokenPayload = serde_json::from_slice(&payload).ok()?;
537    payload.email
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pkce_generates_valid_state() {
        let pkce = generate_pkce_state();
        assert!(!pkce.code_verifier.is_empty());
        assert!(!pkce.code_challenge.is_empty());
        assert!(!pkce.state.is_empty());
    }

    #[test]
    fn authorize_url_contains_required_params() {
        // `build_authorize_url` receives the client id as an argument and reads
        // no environment variables. So this test does not mutate process env,
        // which is unsound while `cargo test` runs tests on parallel threads.
        let pkce = generate_pkce_state();
        let url =
            build_authorize_url("test-client-id", &pkce).expect("Failed to build authorize URL");
        assert!(url.starts_with(GOOGLE_OAUTH_AUTHORIZE_URL));
        assert!(url.contains("accounts.google.com"));
        assert!(url.contains("client_id="));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("state="));
        assert!(url.contains("access_type=offline"));
        assert!(url.contains("prompt=consent"));
    }

    #[test]
    fn parse_code_from_url() {
        let url = "http://localhost:1456/auth/callback?code=4/0test&state=xyz";
        let code = parse_code_from_redirect(url, Some("xyz")).unwrap();
        assert_eq!(code, "4/0test");
    }

    #[test]
    fn parse_code_from_raw() {
        let raw = "4/0AcvDMrC1234567890abcdef";
        let code = parse_code_from_redirect(raw, None).unwrap();
        assert_eq!(code, raw);
    }

    #[test]
    fn parse_code_rejects_state_mismatch() {
        let url = "http://localhost:1456/auth/callback?code=4/0test&state=other";
        assert!(parse_code_from_redirect(url, Some("xyz")).is_err());
    }

    #[test]
    fn parse_code_rejects_blank_input() {
        assert!(parse_code_from_redirect("   ", None).is_err());
    }

    #[test]
    fn callback_request_yields_code_and_state() {
        let request =
            "GET /auth/callback?code=abc123&state=xyz HTTP/1.1\r\nHost: localhost:1456\r\n\r\n";
        let (code, state) = parse_callback_request(request).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "xyz");
    }

    #[test]
    fn callback_request_without_code_is_rejected() {
        let request = "GET /auth/callback?state=xyz HTTP/1.1\r\n\r\n";
        assert!(parse_callback_request(request).is_err());
    }

    #[test]
    fn extract_email_from_id_token() {
        // Minimal test JWT with email claim
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"email":"test@example.com"}"#);
        let token = format!("{}.{}.signature", header, payload);

        let email = extract_account_email_from_id_token(&token);
        assert_eq!(email, Some("test@example.com".to_string()));
    }

    #[test]
    fn malformed_id_token_yields_no_email() {
        assert_eq!(extract_account_email_from_id_token("not-a-jwt"), None);
        assert_eq!(extract_account_email_from_id_token("a.b.c.d"), None);
    }
}