1use crate::auth::oauth_common::{parse_query_params, url_decode, url_encode};
10use crate::auth::profiles::TokenSet;
11use anyhow::{Context, Result};
12use base64::Engine;
13use chrono::Utc;
14use reqwest::Client;
15use serde::Deserialize;
16use std::collections::BTreeMap;
17use std::time::Duration;
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tokio::net::TcpListener;
20
21#[allow(unused_imports)]
23pub use crate::auth::oauth_common::{PkceState, generate_pkce_state};
24
/// Google OAuth 2.0 authorization endpoint; used as the base of the URL produced by
/// `build_authorize_url`.
pub const GOOGLE_OAUTH_AUTHORIZE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
/// Google OAuth 2.0 token endpoint; receives the code exchange, refresh and device-code
/// polling POSTs in this module.
pub const GOOGLE_OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
/// Endpoint that `start_device_code_flow` POSTs to in order to obtain a device/user code pair.
pub const GOOGLE_OAUTH_DEVICE_CODE_URL: &str = "https://oauth2.googleapis.com/device/code";
/// Redirect URI sent in the authorize and code-exchange requests. Its port matches the
/// `127.0.0.1:1456` address that `receive_loopback_code` binds.
pub const GEMINI_OAUTH_REDIRECT_URI: &str = "http://localhost:1456/auth/callback";

/// Space-separated scope list sent in the authorize and device-code requests.
pub const GEMINI_OAUTH_SCOPES: &str =
    "openid profile email https://www.googleapis.com/auth/cloud-platform";
33
/// Result of starting a device-code flow, as returned by `start_device_code_flow` and
/// consumed by `poll_device_code_tokens`.
#[derive(Debug, Clone)]
pub struct DeviceCodeStart {
    /// Opaque code sent back to the token endpoint while polling.
    pub device_code: String,
    /// Code appended to the verification URL as `user_code`.
    pub user_code: String,
    /// Verification URL taken from the device-code response.
    pub verification_uri: String,
    /// Verification URL with the user code already appended as a query parameter.
    pub verification_uri_complete: Option<String>,
    /// Lifetime of the device code in seconds (polling gives up after this long).
    pub expires_in: u64,
    /// Polling interval in seconds; the poller never uses less than 5 seconds.
    pub interval: u64,
}
43
/// JSON body of a successful token endpoint response.
///
/// Only `access_token` is required; every other field deserializes to `None` when absent.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    access_token: String,
    #[serde(default)]
    refresh_token: Option<String>,
    #[serde(default)]
    id_token: Option<String>,
    /// Token lifetime in seconds, converted to an absolute expiry by the callers.
    #[serde(default)]
    expires_in: Option<i64>,
    /// Callers substitute `"Bearer"` when this is missing.
    #[serde(default)]
    token_type: Option<String>,
    #[serde(default)]
    scope: Option<String>,
}
58
/// JSON body of a successful device-code endpoint response.
///
/// Note the field is named `verification_url` here; it is surfaced to callers as
/// `DeviceCodeStart::verification_uri`.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    device_code: String,
    user_code: String,
    verification_url: String,
    /// Seconds; `start_device_code_flow` falls back to 1800 when absent.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Seconds; `start_device_code_flow` falls back to 5 when absent.
    #[serde(default)]
    interval: Option<u64>,
}
69
/// JSON error body returned by the OAuth endpoints on non-success responses.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Machine-readable error code (e.g. `authorization_pending`, `slow_down`).
    error: String,
    /// Optional human-readable detail; rendered as an empty string when missing.
    #[serde(default)]
    error_description: Option<String>,
}
76
77pub fn build_authorize_url(client_id: &str, pkce: &PkceState) -> Result<String> {
78 let mut params = BTreeMap::new();
79 params.insert("response_type", "code");
80 params.insert("client_id", client_id);
81 params.insert("redirect_uri", GEMINI_OAUTH_REDIRECT_URI);
82 params.insert("scope", GEMINI_OAUTH_SCOPES);
83 params.insert("code_challenge", pkce.code_challenge.as_str());
84 params.insert("code_challenge_method", "S256");
85 params.insert("state", pkce.state.as_str());
86 params.insert("access_type", "offline");
87 params.insert("prompt", "consent");
88
89 let mut encoded: Vec<String> = Vec::with_capacity(params.len());
90 for (k, v) in params {
91 encoded.push(format!("{}={}", url_encode(k), url_encode(v)));
92 }
93
94 Ok(format!(
95 "{}?{}",
96 GOOGLE_OAUTH_AUTHORIZE_URL,
97 encoded.join("&")
98 ))
99}
100
101pub async fn exchange_code_for_tokens(
102 client: &Client,
103 client_id: &str,
104 client_secret: &str,
105 code: &str,
106 pkce: &PkceState,
107) -> Result<TokenSet> {
108 let form = [
109 ("grant_type", "authorization_code"),
110 ("code", code),
111 ("redirect_uri", GEMINI_OAUTH_REDIRECT_URI),
112 ("client_id", client_id),
113 ("client_secret", client_secret),
114 ("code_verifier", &pkce.code_verifier),
115 ];
116
117 let response = client
118 .post(GOOGLE_OAUTH_TOKEN_URL)
119 .form(&form)
120 .send()
121 .await
122 .context("Failed to send token exchange request")?;
123
124 let status = response.status();
125 let body = response
126 .text()
127 .await
128 .context("Failed to read token response body")?;
129
130 if !status.is_success() {
131 if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
132 anyhow::bail!(
133 "Google OAuth error: {} - {}",
134 err.error,
135 err.error_description.unwrap_or_default()
136 );
137 }
138 anyhow::bail!("Google OAuth token exchange failed ({}): {}", status, body);
139 }
140
141 let token_response: TokenResponse =
142 serde_json::from_str(&body).context("Failed to parse token response")?;
143
144 let expires_at = token_response
145 .expires_in
146 .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
147
148 Ok(TokenSet {
149 access_token: token_response.access_token,
150 refresh_token: token_response.refresh_token,
151 id_token: token_response.id_token,
152 expires_at,
153 token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
154 scope: token_response.scope,
155 })
156}
157
158pub async fn refresh_access_token(
159 client: &Client,
160 client_id: &str,
161 client_secret: &str,
162 refresh_token: &str,
163) -> Result<TokenSet> {
164 let form = [
165 ("grant_type", "refresh_token"),
166 ("refresh_token", refresh_token),
167 ("client_id", client_id),
168 ("client_secret", client_secret),
169 ];
170
171 let response = client
172 .post(GOOGLE_OAUTH_TOKEN_URL)
173 .form(&form)
174 .send()
175 .await
176 .context("Failed to send refresh token request")?;
177
178 let status = response.status();
179 let body = response
180 .text()
181 .await
182 .context("Failed to read refresh response body")?;
183
184 if !status.is_success() {
185 if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
186 anyhow::bail!(
187 "Google OAuth refresh error: {} - {}",
188 err.error,
189 err.error_description.unwrap_or_default()
190 );
191 }
192 anyhow::bail!("Google OAuth refresh failed ({}): {}", status, body);
193 }
194
195 let token_response: TokenResponse =
196 serde_json::from_str(&body).context("Failed to parse refresh response")?;
197
198 let expires_at = token_response
199 .expires_in
200 .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
201
202 Ok(TokenSet {
203 access_token: token_response.access_token,
204 refresh_token: token_response.refresh_token,
205 id_token: token_response.id_token,
206 expires_at,
207 token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
208 scope: token_response.scope,
209 })
210}
211
212pub async fn start_device_code_flow(client: &Client, client_id: &str) -> Result<DeviceCodeStart> {
213 let form = [("client_id", client_id), ("scope", GEMINI_OAUTH_SCOPES)];
214
215 let response = client
216 .post(GOOGLE_OAUTH_DEVICE_CODE_URL)
217 .form(&form)
218 .send()
219 .await
220 .context("Failed to start device code flow")?;
221
222 let status = response.status();
223 let body = response
224 .text()
225 .await
226 .context("Failed to read device code response")?;
227
228 if !status.is_success() {
229 if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
230 anyhow::bail!(
231 "Google device code error: {} - {}",
232 err.error,
233 err.error_description.unwrap_or_default()
234 );
235 }
236 anyhow::bail!("Google device code request failed ({}): {}", status, body);
237 }
238
239 let device_response: DeviceCodeResponse =
240 serde_json::from_str(&body).context("Failed to parse device code response")?;
241
242 let user_code = device_response.user_code;
243 let verification_url = device_response.verification_url;
244
245 Ok(DeviceCodeStart {
246 device_code: device_response.device_code,
247 verification_uri_complete: Some(format!("{}?user_code={}", &verification_url, &user_code)),
248 user_code,
249 verification_uri: verification_url,
250 expires_in: device_response.expires_in.unwrap_or(1800),
251 interval: device_response.interval.unwrap_or(5),
252 })
253}
254
255pub async fn poll_device_code_tokens(
256 client: &Client,
257 client_id: &str,
258 client_secret: &str,
259 device: &DeviceCodeStart,
260) -> Result<TokenSet> {
261 let deadline = std::time::Instant::now() + Duration::from_secs(device.expires_in);
262 let interval = Duration::from_secs(device.interval.max(5));
263
264 loop {
265 if std::time::Instant::now() > deadline {
266 anyhow::bail!("Device code expired before authorization was completed");
267 }
268
269 tokio::time::sleep(interval).await;
270
271 let form = [
272 ("client_id", client_id),
273 ("client_secret", client_secret),
274 ("device_code", device.device_code.as_str()),
275 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
276 ];
277
278 let response = client
279 .post(GOOGLE_OAUTH_TOKEN_URL)
280 .form(&form)
281 .send()
282 .await
283 .context("Failed to poll device code")?;
284
285 let status = response.status();
286 let body = response.text().await.unwrap_or_default();
287
288 if status.is_success() {
289 let token_response: TokenResponse =
290 serde_json::from_str(&body).context("Failed to parse token response")?;
291
292 let expires_at = token_response
293 .expires_in
294 .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
295
296 return Ok(TokenSet {
297 access_token: token_response.access_token,
298 refresh_token: token_response.refresh_token,
299 id_token: token_response.id_token,
300 expires_at,
301 token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
302 scope: token_response.scope,
303 });
304 }
305
306 if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
307 match err.error.as_str() {
308 "authorization_pending" => {}
309 "slow_down" => {
310 tokio::time::sleep(Duration::from_secs(5)).await;
311 }
312 "access_denied" => {
313 anyhow::bail!("User denied authorization");
314 }
315 "expired_token" => {
316 anyhow::bail!("Device code expired");
317 }
318 _ => {
319 anyhow::bail!(
320 "Google OAuth error: {} - {}",
321 err.error,
322 err.error_description.unwrap_or_default()
323 );
324 }
325 }
326 }
327 }
328}
329
330pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> Result<String> {
335 let listener = match TcpListener::bind("127.0.0.1:1456").await {
337 Ok(l) => l,
338 Err(e) => {
339 eprintln!("Could not bind to localhost:1456: {e}");
340 eprintln!("Falling back to manual input.");
341 return receive_code_from_stdin(expected_state).await;
342 }
343 };
344
345 println!("Waiting for callback at http://localhost:1456/auth/callback ...");
346 println!("(Or paste the full callback URL / authorization code here if running remotely)");
347
348 tokio::select! {
350 accept_result = async {
351 tokio::time::timeout(timeout, listener.accept()).await
352 } => {
353 match accept_result {
354 Ok(Ok((mut stream, _))) => {
355 let mut buffer = vec![0u8; 4096];
356 let n = stream
357 .read(&mut buffer)
358 .await
359 .context("Failed to read from callback connection")?;
360
361 let request = String::from_utf8_lossy(&buffer[..n]);
362 let (code, state) = parse_callback_request(&request)?;
363
364 if state != expected_state {
365 let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
366 <html><body><h1>State mismatch</h1><p>Please try again.</p></body></html>";
367 let _ = stream.write_all(response.as_bytes()).await;
368 anyhow::bail!("OAuth state mismatch");
369 }
370
371 let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
372 <html><body><h1>Success!</h1><p>You can close this window and return to the terminal.</p></body></html>";
373 let _ = stream.write_all(response.as_bytes()).await;
374
375 Ok(code)
376 }
377 Ok(Err(e)) => {
378 ::zeroclaw_log::record!(
379 ERROR,
380 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
381 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
382 .with_attrs(::serde_json::json!({
383 "oauth_provider": "gemini",
384 "phase": "callback_accept",
385 "error": format!("{}", e),
386 })),
387 "gemini_oauth: failed to accept callback connection"
388 );
389 Err(anyhow::Error::msg(format!("Failed to accept connection: {e}")))
390 }
391 Err(_) => {
392 eprintln!("\nCallback timeout. Falling back to manual input.");
393 receive_code_from_stdin(expected_state).await
394 }
395 }
396 }
397 stdin_result = receive_code_from_stdin(expected_state) => {
398 stdin_result
399 }
400 }
401}
402
403async fn receive_code_from_stdin(expected_state: &str) -> Result<String> {
405 use std::io::{self, BufRead};
406
407 let expected = expected_state.to_string();
408 let input = tokio::task::spawn_blocking(move || {
409 let stdin = io::stdin();
410 let mut line = String::new();
411 stdin.lock().read_line(&mut line).ok();
412 let trimmed = line.trim().to_string();
413 if trimmed.is_empty() {
414 ::zeroclaw_log::record!(
415 WARN,
416 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
417 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
418 .with_attrs(::serde_json::json!({"oauth_provider": "gemini"})),
419 "gemini_oauth: empty stdin input for OAuth code"
420 );
421 return Err(anyhow::Error::msg("No input received"));
422 }
423 parse_code_from_redirect(&trimmed, Some(&expected))
424 })
425 .await
426 .context("Failed to read from stdin")??;
427
428 Ok(input)
429}
430
431fn parse_callback_request(request: &str) -> Result<(String, String)> {
432 let first_line = request.lines().next().unwrap_or("");
433 let path = first_line
434 .split_whitespace()
435 .nth(1)
436 .unwrap_or("")
437 .to_string();
438
439 let query_start = path.find('?').map(|i| i + 1).unwrap_or(path.len());
440 let query = &path[query_start..];
441
442 let mut code = None;
443 let mut state = None;
444
445 for pair in query.split('&') {
446 if let Some((key, value)) = pair.split_once('=') {
447 match key {
448 "code" => code = Some(url_decode(value)),
449 "state" => state = Some(url_decode(value)),
450 _ => {}
451 }
452 }
453 }
454
455 let code = code.ok_or_else(|| {
456 ::zeroclaw_log::record!(
457 WARN,
458 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
459 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
460 .with_attrs(::serde_json::json!({
461 "oauth_provider": "gemini",
462 "missing": "code",
463 })),
464 "gemini_oauth: callback missing code parameter"
465 );
466 anyhow::Error::msg("No 'code' parameter in callback")
467 })?;
468 let state = state.ok_or_else(|| {
469 ::zeroclaw_log::record!(
470 WARN,
471 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
472 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
473 .with_attrs(::serde_json::json!({
474 "oauth_provider": "gemini",
475 "missing": "state",
476 })),
477 "gemini_oauth: callback missing state parameter"
478 );
479 anyhow::Error::msg("No 'state' parameter in callback")
480 })?;
481
482 Ok((code, state))
483}
484
485pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
486 let trimmed = input.trim();
487 if trimmed.is_empty() {
488 anyhow::bail!("No OAuth code provided");
489 }
490
491 let query = if let Some((_, right)) = trimmed.split_once('?') {
493 right
494 } else {
495 trimmed
496 };
497
498 let params = parse_query_params(query);
499
500 if let Some(code) = params.get("code") {
502 if let Some(expected) = expected_state
504 && let Some(actual) = params.get("state")
505 && actual != expected
506 {
507 anyhow::bail!("OAuth state mismatch: expected {expected}, got {actual}");
508 }
509 return Ok(code.clone());
510 }
511
512 if trimmed.len() > 10 && !trimmed.contains(' ') && !trimmed.contains('&') {
514 return Ok(trimmed.to_string());
515 }
516
517 anyhow::bail!("Could not parse OAuth code from input")
518}
519
520pub fn extract_account_email_from_id_token(id_token: &str) -> Option<String> {
522 let parts: Vec<&str> = id_token.split('.').collect();
523 if parts.len() != 3 {
524 return None;
525 }
526
527 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
528 .decode(parts[1])
529 .ok()?;
530
531 #[derive(Deserialize)]
532 struct IdTokenPayload {
533 email: Option<String>,
534 }
535
536 let payload: IdTokenPayload = serde_json::from_slice(&payload).ok()?;
537 payload.email
538}
539
/// Unit tests for URL building, redirect parsing and ID-token email extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard that sets an environment variable and restores its previous value (or removes
    /// it) when dropped.
    struct EnvVarRestore {
        key: &'static str,
        // Value the variable had before `set` ran; `None` when it was unset or unreadable.
        original: Option<String>,
    }

    impl EnvVarRestore {
        /// Records the current value of `key`, then overwrites it with `value`.
        fn set(key: &'static str, value: &str) -> Self {
            let original = std::env::var(key).ok();
            // SAFETY: mutating the process environment is only sound while no other thread
            // reads or writes it. NOTE(review): tests run on multiple threads by default —
            // confirm nothing else touches these variables concurrently.
            unsafe { std::env::set_var(key, value) };
            Self { key, original }
        }
    }

    impl Drop for EnvVarRestore {
        fn drop(&mut self) {
            if let Some(ref original) = self.original {
                // SAFETY: same single-writer assumption as in `set`.
                unsafe { std::env::set_var(self.key, original) };
            } else {
                // SAFETY: same single-writer assumption as in `set`.
                unsafe { std::env::remove_var(self.key) };
            }
        }
    }

    /// A generated PKCE state has a non-empty verifier, challenge and state.
    #[test]
    fn pkce_generates_valid_state() {
        let pkce = generate_pkce_state();
        assert!(!pkce.code_verifier.is_empty());
        assert!(!pkce.code_challenge.is_empty());
        assert!(!pkce.state.is_empty());
    }

    /// The authorize URL points at Google and carries the expected query parameters.
    #[test]
    fn authorize_url_contains_required_params() {
        // NOTE(review): `build_authorize_url` takes the client id as an argument and, as
        // written in this file, does not read these variables — the guards may be leftovers.
        let _client_id_guard = EnvVarRestore::set("GEMINI_OAUTH_CLIENT_ID", "test-client-id");
        let _client_secret_guard =
            EnvVarRestore::set("GEMINI_OAUTH_CLIENT_SECRET", "test-client-secret");

        let pkce = generate_pkce_state();
        let url =
            build_authorize_url("test-client-id", &pkce).expect("Failed to build authorize URL");
        assert!(url.contains("accounts.google.com"));
        assert!(url.contains("client_id="));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("access_type=offline"));
    }

    /// A full callback URL with a matching state yields the `code` parameter.
    #[test]
    fn parse_code_from_url() {
        let url = "http://localhost:1456/auth/callback?code=4/0test&state=xyz";
        let code = parse_code_from_redirect(url, Some("xyz")).unwrap();
        assert_eq!(code, "4/0test");
    }

    /// A bare code (no query syntax) is returned unchanged.
    #[test]
    fn parse_code_from_raw() {
        let raw = "4/0AcvDMrC1234567890abcdef";
        let code = parse_code_from_redirect(raw, None).unwrap();
        assert_eq!(code, raw);
    }

    /// The email claim is read from the payload segment of an unsigned test token.
    #[test]
    fn extract_email_from_id_token() {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"email":"test@example.com"}"#);
        let token = format!("{}.{}.signature", header, payload);

        let email = extract_account_email_from_id_token(&token);
        assert_eq!(email, Some("test@example.com".to_string()));
    }
}