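//! zeroclaw_tools/microsoft365/auth.rs
//!
//! OAuth2 token acquisition (client credentials, device code, refresh) and
//! on-disk token caching for the Microsoft 365 tools.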

1use anyhow::Context;
2use parking_lot::RwLock;
3use serde::{Deserialize, Serialize};
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use tokio::sync::Mutex;
8
9/// Cached OAuth2 token state persisted to disk between runs.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct CachedTokenState {
12    pub access_token: String,
13    pub refresh_token: Option<String>,
14    /// Unix timestamp (seconds) when the access token expires.
15    pub expires_at: i64,
16}
17
18impl CachedTokenState {
19    /// Returns `true` when the token is expired or will expire within 60 seconds.
20    pub fn is_expired(&self) -> bool {
21        let now = chrono::Utc::now().timestamp();
22        self.expires_at <= now + 60
23    }
24}
25
26/// Thread-safe token cache with disk persistence.
27pub struct TokenCache {
28    inner: RwLock<Option<CachedTokenState>>,
29    /// Serialises the slow acquire/refresh path so only one caller performs the
30    /// network round-trip while others wait and then read the updated cache.
31    acquire_lock: Mutex<()>,
32    config: super::types::Microsoft365ResolvedConfig,
33    cache_path: PathBuf,
34}
35
36impl TokenCache {
37    pub fn new(
38        config: super::types::Microsoft365ResolvedConfig,
39        zeroclaw_dir: &std::path::Path,
40    ) -> anyhow::Result<Self> {
41        if config.token_cache_encrypted {
42            anyhow::bail!(
43                "microsoft365: token_cache_encrypted is enabled but encryption is not yet \
44                 implemented; refusing to store tokens in plaintext. Set token_cache_encrypted \
45                 to false or wait for encryption support."
46            );
47        }
48
49        // Scope cache file to (tenant_id, client_id, auth_flow) so config
50        // changes never reuse tokens from a different account/flow.
51        let mut hasher = DefaultHasher::new();
52        config.tenant_id.hash(&mut hasher);
53        config.client_id.hash(&mut hasher);
54        config.auth_flow.hash(&mut hasher);
55        let fingerprint = format!("{:016x}", hasher.finish());
56
57        let cache_path = zeroclaw_dir.join(format!("ms365_token_cache_{fingerprint}.json"));
58        let cached = Self::load_from_disk(&cache_path);
59        Ok(Self {
60            inner: RwLock::new(cached),
61            acquire_lock: Mutex::new(()),
62            config,
63            cache_path,
64        })
65    }
66
67    /// Get a valid access token, refreshing or re-authenticating as needed.
68    pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result<String> {
69        // Fast path: cached and not expired.
70        {
71            let guard = self.inner.read();
72            if let Some(ref state) = *guard
73                && !state.is_expired()
74            {
75                return Ok(state.access_token.clone());
76            }
77        }
78
79        // Slow path: serialise through a mutex so only one caller performs the
80        // network round-trip while concurrent callers wait and re-check.
81        let _lock = self.acquire_lock.lock().await;
82
83        // Re-check after acquiring the lock — another caller may have refreshed
84        // while we were waiting.
85        {
86            let guard = self.inner.read();
87            if let Some(ref state) = *guard
88                && !state.is_expired()
89            {
90                return Ok(state.access_token.clone());
91            }
92        }
93
94        let new_state = self.acquire_token(client).await?;
95        let token = new_state.access_token.clone();
96        self.persist_to_disk(&new_state);
97        *self.inner.write() = Some(new_state);
98        Ok(token)
99    }
100
101    async fn acquire_token(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
102        // Try refresh first if we have a refresh token and the flow supports it.
103        // Client credentials flow does not issue refresh tokens, so skip the
104        // attempt entirely to avoid a wasted round-trip.
105        if self.config.auth_flow.as_str() != "client_credentials" {
106            // Clone the token out so the RwLock guard is dropped before the await.
107            let refresh_token_copy = {
108                let guard = self.inner.read();
109                guard.as_ref().and_then(|state| state.refresh_token.clone())
110            };
111            if let Some(refresh_tok) = refresh_token_copy {
112                match self.refresh_token(client, &refresh_tok).await {
113                    Ok(new_state) => return Ok(new_state),
114                    Err(e) => {
115                        ::zeroclaw_log::record!(
116                            DEBUG,
117                            ::zeroclaw_log::Event::new(
118                                module_path!(),
119                                ::zeroclaw_log::Action::Note
120                            )
121                            .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
122                            "ms365: refresh token failed, re-authenticating"
123                        );
124                    }
125                }
126            }
127        }
128
129        match self.config.auth_flow.as_str() {
130            "client_credentials" => self.client_credentials_flow(client).await,
131            "device_code" => self.device_code_flow(client).await,
132            other => anyhow::bail!("Unsupported auth flow: {other}"),
133        }
134    }
135
136    async fn client_credentials_flow(
137        &self,
138        client: &reqwest::Client,
139    ) -> anyhow::Result<CachedTokenState> {
140        let client_secret = self
141            .config
142            .client_secret
143            .as_deref()
144            .context("client_credentials flow requires client_secret")?;
145
146        let token_url = format!(
147            "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
148            self.config.tenant_id
149        );
150
151        let scope = self.config.scopes.join(" ");
152
153        let resp = client
154            .post(&token_url)
155            .form(&[
156                ("grant_type", "client_credentials"),
157                ("client_id", &self.config.client_id),
158                ("client_secret", client_secret),
159                ("scope", &scope),
160            ])
161            .send()
162            .await
163            .context("ms365: failed to request client_credentials token")?;
164
165        if !resp.status().is_success() {
166            let status = resp.status();
167            let body = resp.text().await.unwrap_or_default();
168            ::zeroclaw_log::record!(
169                DEBUG,
170                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
171                    .with_attrs(::serde_json::json!({"body": body})),
172                "ms365: client_credentials raw OAuth error"
173            );
174            anyhow::bail!("ms365: client_credentials token request failed ({status})");
175        }
176
177        let token_resp: TokenResponse = resp
178            .json()
179            .await
180            .context("ms365: failed to parse token response")?;
181
182        Ok(CachedTokenState {
183            access_token: token_resp.access_token,
184            refresh_token: token_resp.refresh_token,
185            expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
186        })
187    }
188
189    async fn device_code_flow(&self, client: &reqwest::Client) -> anyhow::Result<CachedTokenState> {
190        let device_code_url = format!(
191            "https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
192            self.config.tenant_id
193        );
194        let scope = self.config.scopes.join(" ");
195
196        let resp = client
197            .post(&device_code_url)
198            .form(&[
199                ("client_id", self.config.client_id.as_str()),
200                ("scope", &scope),
201            ])
202            .send()
203            .await
204            .context("ms365: failed to request device code")?;
205
206        if !resp.status().is_success() {
207            let status = resp.status();
208            let body = resp.text().await.unwrap_or_default();
209            ::zeroclaw_log::record!(
210                DEBUG,
211                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
212                    .with_attrs(::serde_json::json!({"body": body})),
213                "ms365: device_code initiation raw error"
214            );
215            anyhow::bail!("ms365: device code request failed ({status})");
216        }
217
218        let device_resp: DeviceCodeResponse = resp
219            .json()
220            .await
221            .context("ms365: failed to parse device code response")?;
222
223        // Log only a generic prompt; the full device_resp.message may contain
224        // sensitive verification URIs or codes that should not appear in logs.
225        ::zeroclaw_log::record!(
226            INFO,
227            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
228            "ms365: device code auth required — follow the instructions shown to the user"
229        );
230        // Print the user-facing message to stderr so the operator can act on it
231        // without it being captured in structured log sinks.
232        eprintln!("ms365: {}", device_resp.message);
233
234        let token_url = format!(
235            "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
236            self.config.tenant_id
237        );
238
239        let interval = device_resp.interval.max(5);
240        let max_polls = u32::try_from(
241            (device_resp.expires_in / i64::try_from(interval).unwrap_or(i64::MAX)).max(1),
242        )
243        .unwrap_or(u32::MAX);
244
245        for _ in 0..max_polls {
246            tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
247
248            let poll_resp = client
249                .post(&token_url)
250                .form(&[
251                    ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
252                    ("client_id", self.config.client_id.as_str()),
253                    ("device_code", &device_resp.device_code),
254                ])
255                .send()
256                .await
257                .context("ms365: failed to poll device code token")?;
258
259            if poll_resp.status().is_success() {
260                let token_resp: TokenResponse = poll_resp
261                    .json()
262                    .await
263                    .context("ms365: failed to parse token response")?;
264                return Ok(CachedTokenState {
265                    access_token: token_resp.access_token,
266                    refresh_token: token_resp.refresh_token,
267                    expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
268                });
269            }
270
271            let body = poll_resp.text().await.unwrap_or_default();
272            if body.contains("authorization_pending") {
273                continue;
274            }
275            if body.contains("slow_down") {
276                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
277                continue;
278            }
279            ::zeroclaw_log::record!(
280                DEBUG,
281                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
282                    .with_attrs(::serde_json::json!({"body": body})),
283                "ms365: device code polling raw error"
284            );
285            anyhow::bail!("ms365: device code polling failed");
286        }
287
288        anyhow::bail!("ms365: device code flow timed out waiting for user authorization")
289    }
290
291    async fn refresh_token(
292        &self,
293        client: &reqwest::Client,
294        refresh_token: &str,
295    ) -> anyhow::Result<CachedTokenState> {
296        let token_url = format!(
297            "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
298            self.config.tenant_id
299        );
300
301        let mut params = vec![
302            ("grant_type", "refresh_token"),
303            ("client_id", self.config.client_id.as_str()),
304            ("refresh_token", refresh_token),
305        ];
306
307        let secret_ref;
308        if let Some(ref secret) = self.config.client_secret {
309            secret_ref = secret.as_str();
310            params.push(("client_secret", secret_ref));
311        }
312
313        let resp = client
314            .post(&token_url)
315            .form(&params)
316            .send()
317            .await
318            .context("ms365: failed to refresh token")?;
319
320        if !resp.status().is_success() {
321            let status = resp.status();
322            let body = resp.text().await.unwrap_or_default();
323            ::zeroclaw_log::record!(
324                DEBUG,
325                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
326                    .with_attrs(::serde_json::json!({"body": body})),
327                "ms365: token refresh raw error"
328            );
329            anyhow::bail!("ms365: token refresh failed ({status})");
330        }
331
332        let token_resp: TokenResponse = resp
333            .json()
334            .await
335            .context("ms365: failed to parse refresh token response")?;
336
337        Ok(CachedTokenState {
338            access_token: token_resp.access_token,
339            refresh_token: token_resp
340                .refresh_token
341                .or_else(|| Some(refresh_token.to_string())),
342            expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
343        })
344    }
345
346    fn load_from_disk(path: &std::path::Path) -> Option<CachedTokenState> {
347        let data = std::fs::read_to_string(path).ok()?;
348        serde_json::from_str(&data).ok()
349    }
350
351    fn persist_to_disk(&self, state: &CachedTokenState) {
352        if let Ok(json) = serde_json::to_string_pretty(state)
353            && let Err(e) = std::fs::write(&self.cache_path, json)
354        {
355            ::zeroclaw_log::record!(
356                WARN,
357                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
358                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
359                    .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
360                "ms365: failed to persist token cache"
361            );
362        }
363    }
364}
365
366#[derive(Deserialize)]
367struct TokenResponse {
368    access_token: String,
369    #[serde(default)]
370    refresh_token: Option<String>,
371    #[serde(default = "default_expires_in")]
372    expires_in: i64,
373}
374
/// Fallback access-token lifetime, in seconds (one hour), used when `expires_in` is absent.
fn default_expires_in() -> i64 {
    60 * 60
}
378
379#[derive(Deserialize)]
380struct DeviceCodeResponse {
381    device_code: String,
382    message: String,
383    #[serde(default = "default_device_interval")]
384    interval: u64,
385    #[serde(default = "default_device_expires_in")]
386    expires_in: i64,
387}
388
/// Fallback device-code polling interval, in seconds, when `interval` is absent.
fn default_device_interval() -> u64 {
    5_u64
}
392
/// Fallback device-code lifetime, in seconds (15 minutes), when `expires_in` is absent.
fn default_device_expires_in() -> i64 {
    15 * 60
}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a token state whose access token expires `offset_secs` from now.
    fn state_expiring_in(offset_secs: i64) -> CachedTokenState {
        CachedTokenState {
            access_token: "test".into(),
            refresh_token: None,
            expires_at: chrono::Utc::now().timestamp() + offset_secs,
        }
    }

    /// Returns a temp-dir path that is unique per process and per test name,
    /// so parallel tests never collide.
    fn temp_cache_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "ms365_auth_test_{}_{name}.json",
            std::process::id()
        ))
    }

    #[test]
    fn token_is_expired_when_past_deadline() {
        assert!(state_expiring_in(-10).is_expired());
    }

    #[test]
    fn token_is_expired_within_buffer() {
        assert!(state_expiring_in(30).is_expired());
    }

    #[test]
    fn token_is_valid_when_far_from_expiry() {
        assert!(!state_expiring_in(3600).is_expired());
    }

    #[test]
    fn load_from_disk_returns_none_for_missing_file() {
        let path = std::path::Path::new("/nonexistent/ms365_token_cache.json");
        assert!(TokenCache::load_from_disk(path).is_none());
    }

    #[test]
    fn load_from_disk_returns_none_for_corrupt_file() {
        let path = temp_cache_path("corrupt");
        std::fs::write(&path, "{ not json").unwrap();
        let loaded = TokenCache::load_from_disk(&path);
        let _ = std::fs::remove_file(&path);
        assert!(loaded.is_none());
    }

    #[test]
    fn load_from_disk_round_trips_serialized_state() {
        let path = temp_cache_path("round_trip");
        let state = CachedTokenState {
            access_token: "abc".into(),
            refresh_token: Some("refresh".into()),
            expires_at: 1_700_000_000,
        };
        std::fs::write(&path, serde_json::to_string_pretty(&state).unwrap()).unwrap();
        let loaded = TokenCache::load_from_disk(&path);
        let _ = std::fs::remove_file(&path);

        let loaded = loaded.expect("a valid cache file should load");
        assert_eq!(loaded.access_token, "abc");
        assert_eq!(loaded.refresh_token.as_deref(), Some("refresh"));
        assert_eq!(loaded.expires_at, 1_700_000_000);
    }

    #[test]
    fn token_response_applies_defaults() {
        let resp: TokenResponse = serde_json::from_str(r#"{"access_token":"a"}"#).unwrap();
        assert_eq!(resp.access_token, "a");
        assert!(resp.refresh_token.is_none());
        assert_eq!(resp.expires_in, 3600);
    }

    #[test]
    fn device_code_response_applies_defaults() {
        let resp: DeviceCodeResponse =
            serde_json::from_str(r#"{"device_code":"d","message":"m"}"#).unwrap();
        assert_eq!(resp.device_code, "d");
        assert_eq!(resp.message, "m");
        assert_eq!(resp.interval, 5);
        assert_eq!(resp.expires_in, 900);
    }
}