Skip to main content

zeroclaw_providers/auth/
openai_oauth.rs

1use crate::auth::oauth_common::{parse_query_params, url_encode};
2
3use crate::auth::profiles::TokenSet;
4use anyhow::{Context, Result};
5use base64::Engine;
6use chrono::Utc;
7use reqwest::Client;
8use serde::Deserialize;
9use std::collections::BTreeMap;
10use std::time::{Duration, Instant};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpListener;
13
14// Re-export for external use (used by main.rs)
15#[allow(unused_imports)]
16pub use crate::auth::oauth_common::{PkceState, generate_pkce_state};
17
/// OAuth `client_id` sent with every authorize, token, and device-code
/// request made by this module.
pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Browser authorization endpoint; [`build_authorize_url`] appends the query string.
pub const OPENAI_OAUTH_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint used for code exchange, token refresh, and device-code polling.
pub const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Endpoint that starts the device-code flow (see [`start_device_code_flow`]).
pub const OPENAI_OAUTH_DEVICE_CODE_URL: &str = "https://auth.openai.com/oauth/device/code";
/// Redirect URI for the browser flow. Its port must match the loopback
/// listener bound by [`receive_loopback_code`] (`127.0.0.1:1455`).
pub const OPENAI_OAUTH_REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
23
/// State returned by [`start_device_code_flow`] and consumed by
/// [`poll_device_code_tokens`].
#[derive(Debug, Clone)]
pub struct DeviceCodeStart {
    /// Opaque code sent as `device_code` when polling the token endpoint.
    pub device_code: String,
    /// `user_code` from the device-code response (RFC 8628 §3.2).
    pub user_code: String,
    /// `verification_uri` from the device-code response.
    pub verification_uri: String,
    /// `verification_uri_complete`, if the server sent one.
    pub verification_uri_complete: Option<String>,
    /// Lifetime in seconds; [`poll_device_code_tokens`] gives up once this
    /// much time has elapsed since polling started.
    pub expires_in: u64,
    /// Seconds to wait between polls. [`start_device_code_flow`] fills in 5
    /// when the server omits it and clamps it to at least 1.
    pub interval: u64,
    /// Optional `message` from the device-code response.
    pub message: Option<String>,
}
34
/// JSON body of a successful token-endpoint response, deserialized by
/// `parse_token_response`. Optional fields absent from the JSON become `None`.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    access_token: String,
    #[serde(default)]
    refresh_token: Option<String>,
    #[serde(default)]
    id_token: Option<String>,
    /// Token lifetime in seconds; `parse_token_response` turns it into an
    /// absolute `expires_at` and ignores non-positive values.
    #[serde(default)]
    expires_in: Option<i64>,
    #[serde(default)]
    token_type: Option<String>,
    #[serde(default)]
    scope: Option<String>,
}
49
/// JSON body returned by the device-code endpoint; converted into the public
/// [`DeviceCodeStart`] by [`start_device_code_flow`].
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    device_code: String,
    user_code: String,
    verification_uri: String,
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Lifetime of the device code in seconds.
    expires_in: u64,
    /// Polling interval in seconds; replaced by 5 when absent and clamped to
    /// at least 1 during conversion to `DeviceCodeStart`.
    #[serde(default)]
    interval: Option<u64>,
    #[serde(default)]
    message: Option<String>,
}
63
/// OAuth error body parsed from non-success responses while polling the
/// device-code flow.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Error code, e.g. `authorization_pending`, `slow_down`, `access_denied`,
    /// or `expired_token` (the values `poll_device_code_tokens` branches on).
    error: String,
    /// Optional human-readable detail; preferred over `error` in messages.
    #[serde(default)]
    error_description: Option<String>,
}
70
71pub fn build_authorize_url(pkce: &PkceState) -> String {
72    let mut params = BTreeMap::new();
73    params.insert("response_type", "code");
74    params.insert("client_id", OPENAI_OAUTH_CLIENT_ID);
75    params.insert("redirect_uri", OPENAI_OAUTH_REDIRECT_URI);
76    params.insert("scope", "openid profile email offline_access");
77    params.insert("code_challenge", pkce.code_challenge.as_str());
78    params.insert("code_challenge_method", "S256");
79    params.insert("state", pkce.state.as_str());
80    params.insert("codex_cli_simplified_flow", "true");
81    params.insert("id_token_add_organizations", "true");
82
83    let mut encoded: Vec<String> = Vec::with_capacity(params.len());
84    for (k, v) in params {
85        encoded.push(format!("{}={}", url_encode(k), url_encode(v)));
86    }
87
88    format!("{OPENAI_OAUTH_AUTHORIZE_URL}?{}", encoded.join("&"))
89}
90
91pub async fn exchange_code_for_tokens(
92    client: &Client,
93    code: &str,
94    pkce: &PkceState,
95) -> Result<TokenSet> {
96    let form = [
97        ("grant_type", "authorization_code"),
98        ("code", code),
99        ("client_id", OPENAI_OAUTH_CLIENT_ID),
100        ("redirect_uri", OPENAI_OAUTH_REDIRECT_URI),
101        ("code_verifier", pkce.code_verifier.as_str()),
102    ];
103
104    let response = client
105        .post(OPENAI_OAUTH_TOKEN_URL)
106        .form(&form)
107        .send()
108        .await
109        .context("Failed to exchange OpenAI OAuth authorization code")?;
110
111    parse_token_response(response).await
112}
113
114pub async fn refresh_access_token(client: &Client, refresh_token: &str) -> Result<TokenSet> {
115    let form = [
116        ("grant_type", "refresh_token"),
117        ("refresh_token", refresh_token),
118        ("client_id", OPENAI_OAUTH_CLIENT_ID),
119    ];
120
121    let response = client
122        .post(OPENAI_OAUTH_TOKEN_URL)
123        .form(&form)
124        .send()
125        .await
126        .context("Failed to refresh OpenAI OAuth token")?;
127
128    parse_token_response(response).await
129}
130
131pub async fn start_device_code_flow(client: &Client) -> Result<DeviceCodeStart> {
132    let form = [
133        ("client_id", OPENAI_OAUTH_CLIENT_ID),
134        ("scope", "openid profile email offline_access"),
135    ];
136
137    let response = client
138        .post(OPENAI_OAUTH_DEVICE_CODE_URL)
139        .form(&form)
140        .send()
141        .await
142        .context("Failed to start OpenAI OAuth device-code flow")?;
143
144    if !response.status().is_success() {
145        let status = response.status();
146        let body = response.text().await.unwrap_or_default();
147        anyhow::bail!("OpenAI device-code start failed ({status}): {body}");
148    }
149
150    let parsed: DeviceCodeResponse = response
151        .json()
152        .await
153        .context("Failed to parse OpenAI device-code response")?;
154
155    Ok(DeviceCodeStart {
156        device_code: parsed.device_code,
157        user_code: parsed.user_code,
158        verification_uri: parsed.verification_uri,
159        verification_uri_complete: parsed.verification_uri_complete,
160        expires_in: parsed.expires_in,
161        interval: parsed.interval.unwrap_or(5).max(1),
162        message: parsed.message,
163    })
164}
165
166pub async fn poll_device_code_tokens(
167    client: &Client,
168    device: &DeviceCodeStart,
169) -> Result<TokenSet> {
170    let started = Instant::now();
171    let mut interval_secs = device.interval.max(1);
172
173    loop {
174        if started.elapsed() > Duration::from_secs(device.expires_in) {
175            anyhow::bail!("Device-code flow timed out before authorization completed");
176        }
177
178        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
179
180        let form = [
181            ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
182            ("device_code", device.device_code.as_str()),
183            ("client_id", OPENAI_OAUTH_CLIENT_ID),
184        ];
185
186        let response = client
187            .post(OPENAI_OAUTH_TOKEN_URL)
188            .form(&form)
189            .send()
190            .await
191            .context("Failed polling OpenAI device-code token endpoint")?;
192
193        if response.status().is_success() {
194            return parse_token_response(response).await;
195        }
196
197        let status = response.status();
198        let text = response.text().await.unwrap_or_default();
199
200        if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&text) {
201            match err.error.as_str() {
202                "authorization_pending" => {
203                    continue;
204                }
205                "slow_down" => {
206                    interval_secs = interval_secs.saturating_add(5);
207                    continue;
208                }
209                "access_denied" => {
210                    anyhow::bail!("OpenAI device-code authorization was denied")
211                }
212                "expired_token" => {
213                    anyhow::bail!("OpenAI device-code expired")
214                }
215                _ => {
216                    anyhow::bail!(
217                        "OpenAI device-code polling failed ({status}): {}",
218                        err.error_description.unwrap_or(err.error)
219                    )
220                }
221            }
222        }
223
224        anyhow::bail!("OpenAI device-code polling failed ({status}): {text}");
225    }
226}
227
228pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> Result<String> {
229    let listener = TcpListener::bind("127.0.0.1:1455")
230        .await
231        .context("Failed to bind callback listener at 127.0.0.1:1455")?;
232
233    let accepted = tokio::time::timeout(timeout, listener.accept())
234        .await
235        .context("Timed out waiting for browser callback")?
236        .context("Failed to accept callback connection")?;
237
238    let (mut stream, _) = accepted;
239    let mut buffer = vec![0_u8; 8192];
240    let bytes_read = stream
241        .read(&mut buffer)
242        .await
243        .context("Failed to read callback request")?;
244
245    let request = String::from_utf8_lossy(&buffer[..bytes_read]);
246    let first_line = request.lines().next().ok_or_else(|| {
247        ::zeroclaw_log::record!(
248            WARN,
249            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
250                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
251                .with_attrs(::serde_json::json!({"oauth_provider": "openai"})),
252            "openai_oauth: malformed callback request"
253        );
254        anyhow::Error::msg("Malformed callback request")
255    })?;
256
257    let path = first_line.split_whitespace().nth(1).ok_or_else(|| {
258        ::zeroclaw_log::record!(
259            WARN,
260            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
261                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
262                .with_attrs(::serde_json::json!({"oauth_provider": "openai"})),
263            "openai_oauth: callback request missing path"
264        );
265        anyhow::Error::msg("Callback request missing path")
266    })?;
267
268    let code = parse_code_from_redirect(path, Some(expected_state))?;
269
270    let body =
271        "<html><body><h2>ZeroClaw login complete</h2><p>You can close this tab.</p></body></html>";
272    let response = format!(
273        "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
274        body.len(),
275        body
276    );
277    let _ = stream.write_all(response.as_bytes()).await;
278
279    Ok(code)
280}
281
282pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
283    let trimmed = input.trim();
284    if trimmed.is_empty() {
285        anyhow::bail!("No OAuth code provided");
286    }
287
288    let query = if let Some((_, right)) = trimmed.split_once('?') {
289        right
290    } else {
291        trimmed
292    };
293
294    let params = parse_query_params(query);
295    let is_callback_payload = trimmed.contains('?')
296        || params.contains_key("code")
297        || params.contains_key("state")
298        || params.contains_key("error");
299
300    if let Some(err) = params.get("error") {
301        let desc = params
302            .get("error_description")
303            .cloned()
304            .unwrap_or_else(|| "OAuth authorization failed".to_string());
305        anyhow::bail!("OpenAI OAuth error: {err} ({desc})");
306    }
307
308    if let Some(expected_state) = expected_state {
309        if let Some(got) = params.get("state") {
310            if got != expected_state {
311                anyhow::bail!("OAuth state mismatch");
312            }
313        } else if is_callback_payload {
314            anyhow::bail!("Missing OAuth state in callback");
315        }
316    }
317
318    if let Some(code) = params.get("code").cloned() {
319        return Ok(code);
320    }
321
322    if !is_callback_payload {
323        return Ok(trimmed.to_string());
324    }
325
326    anyhow::bail!("Missing OAuth code in callback")
327}
328
329pub fn extract_account_id_from_jwt(token: &str) -> Option<String> {
330    let payload = token.split('.').nth(1)?;
331    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
332        .decode(payload)
333        .ok()?;
334    let claims: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
335
336    for key in [
337        "account_id",
338        "accountId",
339        "acct",
340        "sub",
341        "https://api.openai.com/account_id",
342    ] {
343        if let Some(value) = claims.get(key).and_then(|v| v.as_str())
344            && !value.trim().is_empty()
345        {
346            return Some(value.to_string());
347        }
348    }
349
350    None
351}
352
353pub fn extract_expiry_from_jwt(token: &str) -> Option<chrono::DateTime<Utc>> {
354    let payload = token.split('.').nth(1)?;
355    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
356        .decode(payload)
357        .ok()?;
358    let claims: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
359    let exp = claims.get("exp").and_then(|v| v.as_i64())?;
360    chrono::DateTime::<Utc>::from_timestamp(exp, 0)
361}
362
363async fn parse_token_response(response: reqwest::Response) -> Result<TokenSet> {
364    if !response.status().is_success() {
365        let status = response.status();
366        let body = response.text().await.unwrap_or_default();
367        anyhow::bail!("OpenAI OAuth token request failed ({status}): {body}");
368    }
369
370    let token: TokenResponse = response
371        .json()
372        .await
373        .context("Failed to parse OpenAI token response")?;
374
375    let expires_at = token.expires_in.and_then(|seconds| {
376        if seconds <= 0 {
377            None
378        } else {
379            Some(Utc::now() + chrono::Duration::seconds(seconds))
380        }
381    });
382
383    Ok(TokenSet {
384        access_token: token.access_token,
385        refresh_token: token.refresh_token,
386        id_token: token.id_token,
387        expires_at,
388        token_type: token.token_type,
389        scope: token.scope,
390    })
391}
392
393/// Import an existing OpenAI Codex auth-profile JSON (the file
394/// `~/.codex/auth.json` produced by the upstream Codex CLI) into
395/// ZeroClaw's auth store. Replaces the `import_openai_codex_auth_profile`
396/// helper formerly in `src/main.rs`.
397pub async fn import_codex_auth_profile(
398    auth_service: &super::AuthService,
399    profile: &str,
400    import_path: &std::path::Path,
401) -> anyhow::Result<()> {
402    use anyhow::Context;
403
404    #[derive(serde::Deserialize)]
405    struct CodexAuthTokens {
406        access_token: String,
407        #[serde(default)]
408        refresh_token: Option<String>,
409        #[serde(default)]
410        id_token: Option<String>,
411        #[serde(default)]
412        account_id: Option<String>,
413    }
414
415    #[derive(serde::Deserialize)]
416    struct CodexAuthFile {
417        tokens: CodexAuthTokens,
418    }
419
420    let raw = std::fs::read_to_string(import_path).with_context(|| {
421        format!(
422            "Failed to read import file {}",
423            import_path.display().to_string()
424        )
425    })?;
426    let imported: CodexAuthFile = serde_json::from_str(&raw).with_context(|| {
427        format!(
428            "Failed to parse import file {}",
429            import_path.display().to_string()
430        )
431    })?;
432    let expires_at = extract_expiry_from_jwt(&imported.tokens.access_token);
433
434    let token_set = crate::auth::profiles::TokenSet {
435        access_token: imported.tokens.access_token,
436        refresh_token: imported.tokens.refresh_token,
437        id_token: imported.tokens.id_token,
438        expires_at,
439        token_type: Some("Bearer".to_string()),
440        scope: None,
441    };
442
443    let account_id = imported
444        .tokens
445        .account_id
446        .or_else(|| extract_account_id_from_jwt(&token_set.access_token));
447    if account_id.is_none() {
448        ::zeroclaw_log::record!(
449            WARN,
450            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
451                .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
452            "Could not extract OpenAI account id from imported access token; \
453             requests may fail until re-authentication."
454        );
455    }
456
457    auth_service
458        .store_openai_tokens(profile, token_set, account_id, true)
459        .await?;
460
461    Ok(())
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode one JWT segment the way the `extract_*_from_jwt` helpers decode
    /// it: unpadded base64url.
    fn jwt_segment(json: &str) -> String {
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json)
    }

    #[test]
    fn pkce_generation_is_valid() {
        let generated = generate_pkce_state();
        // RFC 7636 requires a code verifier of at least 43 characters.
        assert!(generated.code_verifier.len() >= 43, "code verifier too short");
        assert!(!generated.code_challenge.is_empty(), "empty code challenge");
        assert!(!generated.state.is_empty(), "empty state");
    }

    #[test]
    fn parse_redirect_url_extracts_code() {
        let url = "http://127.0.0.1:1455/auth/callback?code=abc123&state=xyz";
        assert_eq!(parse_code_from_redirect(url, Some("xyz")).unwrap(), "abc123");
    }

    #[test]
    fn parse_redirect_accepts_raw_code() {
        assert_eq!(parse_code_from_redirect("raw-code", None).unwrap(), "raw-code");
    }

    #[test]
    fn parse_redirect_rejects_state_mismatch() {
        let message = parse_code_from_redirect("/auth/callback?code=x&state=a", Some("b"))
            .unwrap_err()
            .to_string();
        assert!(message.contains("state mismatch"), "unexpected error: {message}");
    }

    #[test]
    fn parse_redirect_rejects_error_without_code() {
        let message = parse_code_from_redirect(
            "/auth/callback?error=access_denied&error_description=user+cancelled",
            Some("xyz"),
        )
        .unwrap_err()
        .to_string();
        assert!(
            message.contains("OpenAI OAuth error: access_denied"),
            "unexpected error: {message}"
        );
    }

    #[test]
    fn extract_account_id_from_jwt_payload() {
        let token = format!(
            "{}.{}.sig",
            jwt_segment("{}"),
            jwt_segment(r#"{"account_id":"acct_123"}"#)
        );
        assert_eq!(extract_account_id_from_jwt(&token).as_deref(), Some("acct_123"));
    }
}