1use crate::auth::oauth_common::{parse_query_params, url_encode};
2
3use crate::auth::profiles::TokenSet;
4use anyhow::{Context, Result};
5use base64::Engine;
6use chrono::Utc;
7use reqwest::Client;
8use serde::Deserialize;
9use std::collections::BTreeMap;
10use std::time::{Duration, Instant};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpListener;
13
14#[allow(unused_imports)]
16pub use crate::auth::oauth_common::{PkceState, generate_pkce_state};
17
/// OAuth client identifier sent as `client_id` on every authorize, token and
/// device-code request built in this module.
pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Base URL of the authorization endpoint; `build_authorize_url` appends the
/// query string to it.
pub const OPENAI_OAUTH_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint POSTed to for code exchange, token refresh and device-code
/// polling.
pub const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Endpoint POSTed to by `start_device_code_flow` to begin a device-code login.
pub const OPENAI_OAUTH_DEVICE_CODE_URL: &str = "https://auth.openai.com/oauth/device/code";
/// Redirect URI sent with the authorize and code-exchange requests. Its port
/// (1455) is the one `receive_loopback_code` listens on.
pub const OPENAI_OAUTH_REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
23
/// Result of starting a device-code login: produced by
/// `start_device_code_flow` and consumed by `poll_device_code_tokens`.
#[derive(Debug, Clone)]
pub struct DeviceCodeStart {
    /// Code sent back to the token endpoint as `device_code` while polling.
    pub device_code: String,
    /// Value of the server's `user_code` field; presumably what the user
    /// enters on the verification page — confirm against the caller's UI.
    pub user_code: String,
    /// Value of the server's `verification_uri` field.
    pub verification_uri: String,
    /// Value of the server's optional `verification_uri_complete` field.
    pub verification_uri_complete: Option<String>,
    /// Lifetime in seconds; polling gives up once this much time has elapsed.
    pub expires_in: u64,
    /// Delay in seconds between polls (defaults to 5 when the server omits
    /// it, and is never below 1).
    pub interval: u64,
    /// Value of the server's optional `message` field.
    pub message: Option<String>,
}
34
/// JSON body of a successful token-endpoint response, as deserialized by
/// `parse_token_response`. Only `access_token` is required; every other field
/// falls back to `None` when absent.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    access_token: String,
    #[serde(default)]
    refresh_token: Option<String>,
    #[serde(default)]
    id_token: Option<String>,
    // Lifetime in seconds, relative to the time the response is parsed.
    #[serde(default)]
    expires_in: Option<i64>,
    #[serde(default)]
    token_type: Option<String>,
    #[serde(default)]
    scope: Option<String>,
}
49
/// JSON body returned by the device-code endpoint, as deserialized by
/// `start_device_code_flow` before being copied into `DeviceCodeStart`.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    device_code: String,
    user_code: String,
    verification_uri: String,
    #[serde(default)]
    verification_uri_complete: Option<String>,
    // Seconds; required in the response.
    expires_in: u64,
    // Seconds between polls; the caller substitutes 5 when this is missing.
    #[serde(default)]
    interval: Option<u64>,
    #[serde(default)]
    message: Option<String>,
}
63
/// JSON error body from the token endpoint, used by `poll_device_code_tokens`
/// to tell "keep polling" errors apart from terminal ones.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    // Machine-readable error code, e.g. `authorization_pending` or `slow_down`.
    error: String,
    // Optional human-readable detail; preferred over `error` in error messages.
    #[serde(default)]
    error_description: Option<String>,
}
70
71pub fn build_authorize_url(pkce: &PkceState) -> String {
72 let mut params = BTreeMap::new();
73 params.insert("response_type", "code");
74 params.insert("client_id", OPENAI_OAUTH_CLIENT_ID);
75 params.insert("redirect_uri", OPENAI_OAUTH_REDIRECT_URI);
76 params.insert("scope", "openid profile email offline_access");
77 params.insert("code_challenge", pkce.code_challenge.as_str());
78 params.insert("code_challenge_method", "S256");
79 params.insert("state", pkce.state.as_str());
80 params.insert("codex_cli_simplified_flow", "true");
81 params.insert("id_token_add_organizations", "true");
82
83 let mut encoded: Vec<String> = Vec::with_capacity(params.len());
84 for (k, v) in params {
85 encoded.push(format!("{}={}", url_encode(k), url_encode(v)));
86 }
87
88 format!("{OPENAI_OAUTH_AUTHORIZE_URL}?{}", encoded.join("&"))
89}
90
91pub async fn exchange_code_for_tokens(
92 client: &Client,
93 code: &str,
94 pkce: &PkceState,
95) -> Result<TokenSet> {
96 let form = [
97 ("grant_type", "authorization_code"),
98 ("code", code),
99 ("client_id", OPENAI_OAUTH_CLIENT_ID),
100 ("redirect_uri", OPENAI_OAUTH_REDIRECT_URI),
101 ("code_verifier", pkce.code_verifier.as_str()),
102 ];
103
104 let response = client
105 .post(OPENAI_OAUTH_TOKEN_URL)
106 .form(&form)
107 .send()
108 .await
109 .context("Failed to exchange OpenAI OAuth authorization code")?;
110
111 parse_token_response(response).await
112}
113
114pub async fn refresh_access_token(client: &Client, refresh_token: &str) -> Result<TokenSet> {
115 let form = [
116 ("grant_type", "refresh_token"),
117 ("refresh_token", refresh_token),
118 ("client_id", OPENAI_OAUTH_CLIENT_ID),
119 ];
120
121 let response = client
122 .post(OPENAI_OAUTH_TOKEN_URL)
123 .form(&form)
124 .send()
125 .await
126 .context("Failed to refresh OpenAI OAuth token")?;
127
128 parse_token_response(response).await
129}
130
131pub async fn start_device_code_flow(client: &Client) -> Result<DeviceCodeStart> {
132 let form = [
133 ("client_id", OPENAI_OAUTH_CLIENT_ID),
134 ("scope", "openid profile email offline_access"),
135 ];
136
137 let response = client
138 .post(OPENAI_OAUTH_DEVICE_CODE_URL)
139 .form(&form)
140 .send()
141 .await
142 .context("Failed to start OpenAI OAuth device-code flow")?;
143
144 if !response.status().is_success() {
145 let status = response.status();
146 let body = response.text().await.unwrap_or_default();
147 anyhow::bail!("OpenAI device-code start failed ({status}): {body}");
148 }
149
150 let parsed: DeviceCodeResponse = response
151 .json()
152 .await
153 .context("Failed to parse OpenAI device-code response")?;
154
155 Ok(DeviceCodeStart {
156 device_code: parsed.device_code,
157 user_code: parsed.user_code,
158 verification_uri: parsed.verification_uri,
159 verification_uri_complete: parsed.verification_uri_complete,
160 expires_in: parsed.expires_in,
161 interval: parsed.interval.unwrap_or(5).max(1),
162 message: parsed.message,
163 })
164}
165
166pub async fn poll_device_code_tokens(
167 client: &Client,
168 device: &DeviceCodeStart,
169) -> Result<TokenSet> {
170 let started = Instant::now();
171 let mut interval_secs = device.interval.max(1);
172
173 loop {
174 if started.elapsed() > Duration::from_secs(device.expires_in) {
175 anyhow::bail!("Device-code flow timed out before authorization completed");
176 }
177
178 tokio::time::sleep(Duration::from_secs(interval_secs)).await;
179
180 let form = [
181 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
182 ("device_code", device.device_code.as_str()),
183 ("client_id", OPENAI_OAUTH_CLIENT_ID),
184 ];
185
186 let response = client
187 .post(OPENAI_OAUTH_TOKEN_URL)
188 .form(&form)
189 .send()
190 .await
191 .context("Failed polling OpenAI device-code token endpoint")?;
192
193 if response.status().is_success() {
194 return parse_token_response(response).await;
195 }
196
197 let status = response.status();
198 let text = response.text().await.unwrap_or_default();
199
200 if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&text) {
201 match err.error.as_str() {
202 "authorization_pending" => {
203 continue;
204 }
205 "slow_down" => {
206 interval_secs = interval_secs.saturating_add(5);
207 continue;
208 }
209 "access_denied" => {
210 anyhow::bail!("OpenAI device-code authorization was denied")
211 }
212 "expired_token" => {
213 anyhow::bail!("OpenAI device-code expired")
214 }
215 _ => {
216 anyhow::bail!(
217 "OpenAI device-code polling failed ({status}): {}",
218 err.error_description.unwrap_or(err.error)
219 )
220 }
221 }
222 }
223
224 anyhow::bail!("OpenAI device-code polling failed ({status}): {text}");
225 }
226}
227
228pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> Result<String> {
229 let listener = TcpListener::bind("127.0.0.1:1455")
230 .await
231 .context("Failed to bind callback listener at 127.0.0.1:1455")?;
232
233 let accepted = tokio::time::timeout(timeout, listener.accept())
234 .await
235 .context("Timed out waiting for browser callback")?
236 .context("Failed to accept callback connection")?;
237
238 let (mut stream, _) = accepted;
239 let mut buffer = vec![0_u8; 8192];
240 let bytes_read = stream
241 .read(&mut buffer)
242 .await
243 .context("Failed to read callback request")?;
244
245 let request = String::from_utf8_lossy(&buffer[..bytes_read]);
246 let first_line = request.lines().next().ok_or_else(|| {
247 ::zeroclaw_log::record!(
248 WARN,
249 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
250 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
251 .with_attrs(::serde_json::json!({"oauth_provider": "openai"})),
252 "openai_oauth: malformed callback request"
253 );
254 anyhow::Error::msg("Malformed callback request")
255 })?;
256
257 let path = first_line.split_whitespace().nth(1).ok_or_else(|| {
258 ::zeroclaw_log::record!(
259 WARN,
260 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
261 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
262 .with_attrs(::serde_json::json!({"oauth_provider": "openai"})),
263 "openai_oauth: callback request missing path"
264 );
265 anyhow::Error::msg("Callback request missing path")
266 })?;
267
268 let code = parse_code_from_redirect(path, Some(expected_state))?;
269
270 let body =
271 "<html><body><h2>ZeroClaw login complete</h2><p>You can close this tab.</p></body></html>";
272 let response = format!(
273 "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
274 body.len(),
275 body
276 );
277 let _ = stream.write_all(response.as_bytes()).await;
278
279 Ok(code)
280}
281
282pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
283 let trimmed = input.trim();
284 if trimmed.is_empty() {
285 anyhow::bail!("No OAuth code provided");
286 }
287
288 let query = if let Some((_, right)) = trimmed.split_once('?') {
289 right
290 } else {
291 trimmed
292 };
293
294 let params = parse_query_params(query);
295 let is_callback_payload = trimmed.contains('?')
296 || params.contains_key("code")
297 || params.contains_key("state")
298 || params.contains_key("error");
299
300 if let Some(err) = params.get("error") {
301 let desc = params
302 .get("error_description")
303 .cloned()
304 .unwrap_or_else(|| "OAuth authorization failed".to_string());
305 anyhow::bail!("OpenAI OAuth error: {err} ({desc})");
306 }
307
308 if let Some(expected_state) = expected_state {
309 if let Some(got) = params.get("state") {
310 if got != expected_state {
311 anyhow::bail!("OAuth state mismatch");
312 }
313 } else if is_callback_payload {
314 anyhow::bail!("Missing OAuth state in callback");
315 }
316 }
317
318 if let Some(code) = params.get("code").cloned() {
319 return Ok(code);
320 }
321
322 if !is_callback_payload {
323 return Ok(trimmed.to_string());
324 }
325
326 anyhow::bail!("Missing OAuth code in callback")
327}
328
329pub fn extract_account_id_from_jwt(token: &str) -> Option<String> {
330 let payload = token.split('.').nth(1)?;
331 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
332 .decode(payload)
333 .ok()?;
334 let claims: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
335
336 for key in [
337 "account_id",
338 "accountId",
339 "acct",
340 "sub",
341 "https://api.openai.com/account_id",
342 ] {
343 if let Some(value) = claims.get(key).and_then(|v| v.as_str())
344 && !value.trim().is_empty()
345 {
346 return Some(value.to_string());
347 }
348 }
349
350 None
351}
352
353pub fn extract_expiry_from_jwt(token: &str) -> Option<chrono::DateTime<Utc>> {
354 let payload = token.split('.').nth(1)?;
355 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
356 .decode(payload)
357 .ok()?;
358 let claims: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
359 let exp = claims.get("exp").and_then(|v| v.as_i64())?;
360 chrono::DateTime::<Utc>::from_timestamp(exp, 0)
361}
362
363async fn parse_token_response(response: reqwest::Response) -> Result<TokenSet> {
364 if !response.status().is_success() {
365 let status = response.status();
366 let body = response.text().await.unwrap_or_default();
367 anyhow::bail!("OpenAI OAuth token request failed ({status}): {body}");
368 }
369
370 let token: TokenResponse = response
371 .json()
372 .await
373 .context("Failed to parse OpenAI token response")?;
374
375 let expires_at = token.expires_in.and_then(|seconds| {
376 if seconds <= 0 {
377 None
378 } else {
379 Some(Utc::now() + chrono::Duration::seconds(seconds))
380 }
381 });
382
383 Ok(TokenSet {
384 access_token: token.access_token,
385 refresh_token: token.refresh_token,
386 id_token: token.id_token,
387 expires_at,
388 token_type: token.token_type,
389 scope: token.scope,
390 })
391}
392
393pub async fn import_codex_auth_profile(
398 auth_service: &super::AuthService,
399 profile: &str,
400 import_path: &std::path::Path,
401) -> anyhow::Result<()> {
402 use anyhow::Context;
403
404 #[derive(serde::Deserialize)]
405 struct CodexAuthTokens {
406 access_token: String,
407 #[serde(default)]
408 refresh_token: Option<String>,
409 #[serde(default)]
410 id_token: Option<String>,
411 #[serde(default)]
412 account_id: Option<String>,
413 }
414
415 #[derive(serde::Deserialize)]
416 struct CodexAuthFile {
417 tokens: CodexAuthTokens,
418 }
419
420 let raw = std::fs::read_to_string(import_path).with_context(|| {
421 format!(
422 "Failed to read import file {}",
423 import_path.display().to_string()
424 )
425 })?;
426 let imported: CodexAuthFile = serde_json::from_str(&raw).with_context(|| {
427 format!(
428 "Failed to parse import file {}",
429 import_path.display().to_string()
430 )
431 })?;
432 let expires_at = extract_expiry_from_jwt(&imported.tokens.access_token);
433
434 let token_set = crate::auth::profiles::TokenSet {
435 access_token: imported.tokens.access_token,
436 refresh_token: imported.tokens.refresh_token,
437 id_token: imported.tokens.id_token,
438 expires_at,
439 token_type: Some("Bearer".to_string()),
440 scope: None,
441 };
442
443 let account_id = imported
444 .tokens
445 .account_id
446 .or_else(|| extract_account_id_from_jwt(&token_set.access_token));
447 if account_id.is_none() {
448 ::zeroclaw_log::record!(
449 WARN,
450 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
451 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
452 "Could not extract OpenAI account id from imported access token; \
453 requests may fail until re-authentication."
454 );
455 }
456
457 auth_service
458 .store_openai_tokens(profile, token_set, account_id, true)
459 .await?;
460
461 Ok(())
462}
463
// Unit tests for the pure helpers in this module; nothing here touches the
// network.
#[cfg(test)]
mod tests {
    use super::*;

    // The generated verifier must be at least 43 characters long, and the
    // challenge and state must be non-empty.
    #[test]
    fn pkce_generation_is_valid() {
        let pkce = generate_pkce_state();
        assert!(pkce.code_verifier.len() >= 43);
        assert!(!pkce.code_challenge.is_empty());
        assert!(!pkce.state.is_empty());
    }

    // A full redirect URL with a matching state yields the `code` parameter.
    #[test]
    fn parse_redirect_url_extracts_code() {
        let code = parse_code_from_redirect(
            "http://127.0.0.1:1455/auth/callback?code=abc123&state=xyz",
            Some("xyz"),
        )
        .unwrap();
        assert_eq!(code, "abc123");
    }

    // Input that is not a callback payload is returned unchanged as the code.
    #[test]
    fn parse_redirect_accepts_raw_code() {
        let code = parse_code_from_redirect("raw-code", None).unwrap();
        assert_eq!(code, "raw-code");
    }

    // A `state` that differs from the expected one is rejected.
    #[test]
    fn parse_redirect_rejects_state_mismatch() {
        let err = parse_code_from_redirect("/auth/callback?code=x&state=a", Some("b")).unwrap_err();
        assert!(err.to_string().contains("state mismatch"));
    }

    // An `error` parameter is reported before the state and code checks run.
    #[test]
    fn parse_redirect_rejects_error_without_code() {
        let err = parse_code_from_redirect(
            "/auth/callback?error=access_denied&error_description=user+cancelled",
            Some("xyz"),
        )
        .unwrap_err();
        assert!(
            err.to_string()
                .contains("OpenAI OAuth error: access_denied")
        );
    }

    // Builds an unsigned three-segment token by hand and checks that the
    // `account_id` claim is read from its payload.
    #[test]
    fn extract_account_id_from_jwt_payload() {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("{}");
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode("{\"account_id\":\"acct_123\"}");
        let token = format!("{header}.{payload}.sig");

        let account = extract_account_id_from_jwt(&token);
        assert_eq!(account.as_deref(), Some("acct_123"));
    }
}