1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::BTreeMap;
5use std::io::Write;
6use std::path::{Path, PathBuf};
7use std::time::Duration;
8use tokio::fs::{self, OpenOptions};
9use tokio::io::AsyncWriteExt;
10use tokio::time::sleep;
11use zeroclaw_config::secrets::SecretStore;
12
/// Schema version written to `auth-profiles.json`; files declaring a higher version are
/// rejected when read.
const CURRENT_SCHEMA_VERSION: u32 = 1;
/// File name of the profile store inside the state directory.
const PROFILES_FILENAME: &str = "auth-profiles.json";
/// File name of the lock file (created with `create_new`) that serialises access to the store.
const LOCK_FILENAME: &str = "auth-profiles.lock";
/// Delay between lock acquisition attempts, in milliseconds.
const LOCK_WAIT_MS: u64 = 50;
/// Upper bound on the accumulated lock-wait sleep before giving up, in milliseconds.
const LOCK_TIMEOUT_MS: u64 = 10_000;
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "kebab-case")]
21pub enum AuthProfileKind {
22 OAuth,
23 Token,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct TokenSet {
28 pub access_token: String,
29 #[serde(default)]
30 pub refresh_token: Option<String>,
31 #[serde(default)]
32 pub id_token: Option<String>,
33 #[serde(default)]
34 pub expires_at: Option<DateTime<Utc>>,
35 #[serde(default)]
36 pub token_type: Option<String>,
37 #[serde(default)]
38 pub scope: Option<String>,
39}
40
41impl TokenSet {
42 pub fn is_expiring_within(&self, skew: Duration) -> bool {
43 match self.expires_at {
44 Some(expires_at) => {
45 let now_plus_skew =
46 Utc::now() + chrono::Duration::from_std(skew).unwrap_or_default();
47 expires_at <= now_plus_skew
48 }
49 None => false,
50 }
51 }
52}
53
/// A named set of credentials for one model provider.
///
/// `kind` selects the credential field in use. OAuth profiles carry `token_set` (required when
/// loading from the store); token profiles carry `token` (see `AuthProfile::new_oauth` /
/// `AuthProfile::new_token`).
///
/// `Debug` is implemented by hand and does not print `token_set`, `token` or `account_id`.
/// NOTE(review): the derived `Serialize` does not redact anything; `token_set` and `token` are
/// emitted as-is. Only `AuthProfilesStore`'s on-disk format routes them through `SecretStore`.
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthProfile {
    /// Profile id; the constructors build it as `"<model_provider>:<profile_name>"`.
    pub id: String,
    /// Provider this profile belongs to; used as the key in `AuthProfilesData::active_profiles`.
    pub model_provider: String,
    /// Name distinguishing this profile among the provider's profiles (second half of `id`).
    pub profile_name: String,
    /// Which credential field (`token_set` or `token`) this profile uses.
    pub kind: AuthProfileKind,
    /// Optional account identifier (semantics defined by callers — TODO confirm).
    #[serde(default)]
    pub account_id: Option<String>,
    /// Optional workspace identifier (semantics defined by callers — TODO confirm).
    #[serde(default)]
    pub workspace_id: Option<String>,
    /// OAuth token material; expected to be `Some` for `AuthProfileKind::OAuth`.
    #[serde(default)]
    pub token_set: Option<TokenSet>,
    /// Opaque token; set by `new_token` for `AuthProfileKind::Token` profiles.
    #[serde(default)]
    pub token: Option<String>,
    /// Free-form string key/value pairs, persisted unchanged.
    #[serde(default)]
    pub metadata: BTreeMap<String, String>,
    /// Creation time; preserved by `AuthProfilesStore::upsert_profile` across re-upserts.
    pub created_at: DateTime<Utc>,
    /// Last modification time; stamped by the store on upsert/update.
    pub updated_at: DateTime<Utc>,
}
73
74impl std::fmt::Debug for AuthProfile {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("AuthProfile")
77 .field("id", &self.id)
78 .field("model_provider", &self.model_provider)
79 .field("profile_name", &self.profile_name)
80 .field("kind", &self.kind)
81 .field("workspace_id", &self.workspace_id)
82 .field("metadata", &self.metadata)
83 .field("created_at", &self.created_at)
84 .field("updated_at", &self.updated_at)
85 .finish_non_exhaustive()
86 }
87}
88
89impl AuthProfile {
90 pub fn new_oauth(model_provider: &str, profile_name: &str, token_set: TokenSet) -> Self {
91 let now = Utc::now();
92 let id = profile_id(model_provider, profile_name);
93 Self {
94 id,
95 model_provider: model_provider.to_string(),
96 profile_name: profile_name.to_string(),
97 kind: AuthProfileKind::OAuth,
98 account_id: None,
99 workspace_id: None,
100 token_set: Some(token_set),
101 token: None,
102 metadata: BTreeMap::new(),
103 created_at: now,
104 updated_at: now,
105 }
106 }
107
108 pub fn new_token(model_provider: &str, profile_name: &str, token: String) -> Self {
109 let now = Utc::now();
110 let id = profile_id(model_provider, profile_name);
111 Self {
112 id,
113 model_provider: model_provider.to_string(),
114 profile_name: profile_name.to_string(),
115 kind: AuthProfileKind::Token,
116 account_id: None,
117 workspace_id: None,
118 token_set: None,
119 token: Some(token),
120 metadata: BTreeMap::new(),
121 created_at: now,
122 updated_at: now,
123 }
124 }
125}
126
/// Decrypted, in-memory view of the whole store, as returned by `AuthProfilesStore::load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthProfilesData {
    /// Schema version read from disk (or the current version for a fresh store).
    pub schema_version: u32,
    /// Time of the last store-level modification.
    pub updated_at: DateTime<Utc>,
    /// Active profile per provider: key is the model provider, value is a profile id.
    pub active_profiles: BTreeMap<String, String>,
    /// All profiles, keyed by profile id.
    pub profiles: BTreeMap<String, AuthProfile>,
}
134
135impl Default for AuthProfilesData {
136 fn default() -> Self {
137 Self {
138 schema_version: CURRENT_SCHEMA_VERSION,
139 updated_at: Utc::now(),
140 active_profiles: BTreeMap::new(),
141 profiles: BTreeMap::new(),
142 }
143 }
144}
145
/// File-backed store of [`AuthProfile`]s.
///
/// Each public operation holds a lock file next to the JSON file for its duration. Because the
/// lock is a file created with `create_new`, callers in other processes that share the state
/// directory are serialised as well.
#[derive(Debug, Clone)]
pub struct AuthProfilesStore {
    /// Path of `auth-profiles.json`.
    path: PathBuf,
    /// Path of `auth-profiles.lock`.
    lock_path: PathBuf,
    /// Encodes/decodes secret fields before they are written / after they are read.
    secret_store: SecretStore,
}
152
153impl AuthProfilesStore {
154 pub fn new(state_dir: &Path, encrypt_secrets: bool) -> Self {
155 Self {
156 path: state_dir.join(PROFILES_FILENAME),
157 lock_path: state_dir.join(LOCK_FILENAME),
158 secret_store: SecretStore::new(state_dir, encrypt_secrets),
159 }
160 }
161
162 pub fn path(&self) -> &Path {
163 &self.path
164 }
165
166 pub async fn load(&self) -> Result<AuthProfilesData> {
167 let _lock = self.acquire_lock().await?;
168 self.load_locked().await
169 }
170
171 pub async fn upsert_profile(&self, mut profile: AuthProfile, set_active: bool) -> Result<()> {
172 let _lock = self.acquire_lock().await?;
173 let mut data = self.load_locked().await?;
174
175 profile.updated_at = Utc::now();
176 if let Some(existing) = data.profiles.get(&profile.id) {
177 profile.created_at = existing.created_at;
178 }
179
180 if set_active {
181 data.active_profiles
182 .insert(profile.model_provider.clone(), profile.id.clone());
183 }
184
185 data.profiles.insert(profile.id.clone(), profile);
186 data.updated_at = Utc::now();
187
188 self.save_locked(&data).await
189 }
190
191 pub async fn remove_profile(&self, profile_id: &str) -> Result<bool> {
192 let _lock = self.acquire_lock().await?;
193 let mut data = self.load_locked().await?;
194
195 let removed = data.profiles.remove(profile_id).is_some();
196 if !removed {
197 return Ok(false);
198 }
199
200 data.active_profiles
201 .retain(|_, active| active != profile_id);
202 data.updated_at = Utc::now();
203 self.save_locked(&data).await?;
204 Ok(true)
205 }
206
207 pub async fn set_active_profile(&self, model_provider: &str, profile_id: &str) -> Result<()> {
208 let _lock = self.acquire_lock().await?;
209 let mut data = self.load_locked().await?;
210
211 if !data.profiles.contains_key(profile_id) {
212 anyhow::bail!("Auth profile not found: {profile_id}");
213 }
214
215 data.active_profiles
216 .insert(model_provider.to_string(), profile_id.to_string());
217 data.updated_at = Utc::now();
218 self.save_locked(&data).await
219 }
220
221 pub async fn clear_active_profile(&self, model_provider: &str) -> Result<()> {
222 let _lock = self.acquire_lock().await?;
223 let mut data = self.load_locked().await?;
224 data.active_profiles.remove(model_provider);
225 data.updated_at = Utc::now();
226 self.save_locked(&data).await
227 }
228
229 pub async fn update_profile<F>(&self, profile_id: &str, mut updater: F) -> Result<AuthProfile>
230 where
231 F: FnMut(&mut AuthProfile) -> Result<()>,
232 {
233 let _lock = self.acquire_lock().await?;
234 let mut data = self.load_locked().await?;
235
236 let profile = data.profiles.get_mut(profile_id).ok_or_else(|| {
237 ::zeroclaw_log::record!(
238 WARN,
239 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
240 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
241 .with_attrs(::serde_json::json!({"profile_id": profile_id})),
242 "auth_profiles: profile not found for update"
243 );
244 anyhow::Error::msg(format!("Auth profile not found: {profile_id}"))
245 })?;
246
247 updater(profile)?;
248 profile.updated_at = Utc::now();
249 let updated_profile = profile.clone();
250 data.updated_at = Utc::now();
251 self.save_locked(&data).await?;
252 Ok(updated_profile)
253 }
254
255 async fn load_locked(&self) -> Result<AuthProfilesData> {
256 let mut persisted = self.read_persisted_locked().await?;
257 let mut migrated = false;
258
259 let mut profiles = BTreeMap::new();
260 for (id, p) in &mut persisted.profiles {
261 let (access_token, access_migrated) =
262 self.decrypt_optional(p.access_token.as_deref())?;
263 let (refresh_token, refresh_migrated) =
264 self.decrypt_optional(p.refresh_token.as_deref())?;
265 let (id_token, id_migrated) = self.decrypt_optional(p.id_token.as_deref())?;
266 let (token, token_migrated) = self.decrypt_optional(p.token.as_deref())?;
267
268 if let Some(value) = access_migrated {
269 p.access_token = Some(value);
270 migrated = true;
271 }
272 if let Some(value) = refresh_migrated {
273 p.refresh_token = Some(value);
274 migrated = true;
275 }
276 if let Some(value) = id_migrated {
277 p.id_token = Some(value);
278 migrated = true;
279 }
280 if let Some(value) = token_migrated {
281 p.token = Some(value);
282 migrated = true;
283 }
284
285 let kind = parse_profile_kind(&p.kind)?;
286 let token_set = match kind {
287 AuthProfileKind::OAuth => {
288 let access = access_token.ok_or_else(|| {
289 ::zeroclaw_log::record!(
290 ERROR,
291 ::zeroclaw_log::Event::new(
292 module_path!(),
293 ::zeroclaw_log::Action::Reject
294 )
295 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
296 .with_attrs(::serde_json::json!({
297 "profile_id": id,
298 "missing": "access_token",
299 })),
300 "auth_profiles: OAuth profile missing access_token"
301 );
302 anyhow::Error::msg(format!("OAuth profile missing access_token: {id}"))
303 })?;
304 Some(TokenSet {
305 access_token: access,
306 refresh_token,
307 id_token,
308 expires_at: parse_optional_datetime(p.expires_at.as_deref())?,
309 token_type: p.token_type.clone(),
310 scope: p.scope.clone(),
311 })
312 }
313 AuthProfileKind::Token => None,
314 };
315
316 profiles.insert(
317 id.clone(),
318 AuthProfile {
319 id: id.clone(),
320 model_provider: p.model_provider.clone(),
321 profile_name: p.profile_name.clone(),
322 kind,
323 account_id: p.account_id.clone(),
324 workspace_id: p.workspace_id.clone(),
325 token_set,
326 token,
327 metadata: p.metadata.clone(),
328 created_at: parse_datetime_with_fallback(&p.created_at),
329 updated_at: parse_datetime_with_fallback(&p.updated_at),
330 },
331 );
332 }
333
334 if migrated {
335 self.write_persisted_locked(&persisted).await?;
336 }
337
338 Ok(AuthProfilesData {
339 schema_version: persisted.schema_version,
340 updated_at: parse_datetime_with_fallback(&persisted.updated_at),
341 active_profiles: persisted.active_profiles,
342 profiles,
343 })
344 }
345
346 async fn save_locked(&self, data: &AuthProfilesData) -> Result<()> {
347 let mut persisted = PersistedAuthProfiles {
348 schema_version: CURRENT_SCHEMA_VERSION,
349 updated_at: data.updated_at.to_rfc3339(),
350 active_profiles: data.active_profiles.clone(),
351 profiles: BTreeMap::new(),
352 };
353
354 for (id, profile) in &data.profiles {
355 let (access_token, refresh_token, id_token, expires_at, token_type, scope) =
356 match (&profile.kind, &profile.token_set) {
357 (AuthProfileKind::OAuth, Some(token_set)) => (
358 self.encrypt_optional(Some(&token_set.access_token))?,
359 self.encrypt_optional(token_set.refresh_token.as_deref())?,
360 self.encrypt_optional(token_set.id_token.as_deref())?,
361 token_set.expires_at.as_ref().map(DateTime::to_rfc3339),
362 token_set.token_type.clone(),
363 token_set.scope.clone(),
364 ),
365 _ => (None, None, None, None, None, None),
366 };
367
368 let token = self.encrypt_optional(profile.token.as_deref())?;
369
370 persisted.profiles.insert(
371 id.clone(),
372 PersistedAuthProfile {
373 model_provider: profile.model_provider.clone(),
374 profile_name: profile.profile_name.clone(),
375 kind: profile_kind_to_string(profile.kind).to_string(),
376 account_id: profile.account_id.clone(),
377 workspace_id: profile.workspace_id.clone(),
378 access_token,
379 refresh_token,
380 id_token,
381 token,
382 expires_at,
383 token_type,
384 scope,
385 metadata: profile.metadata.clone(),
386 created_at: profile.created_at.to_rfc3339(),
387 updated_at: profile.updated_at.to_rfc3339(),
388 },
389 );
390 }
391
392 self.write_persisted_locked(&persisted).await
393 }
394
395 async fn read_persisted_locked(&self) -> Result<PersistedAuthProfiles> {
396 if !self.path.exists() {
397 return Ok(PersistedAuthProfiles::default());
398 }
399
400 let bytes = fs::read(&self.path).await.with_context(|| {
401 format!(
402 "Failed to read auth profile store at {}",
403 self.path.display()
404 )
405 })?;
406
407 if bytes.is_empty() {
408 return Ok(PersistedAuthProfiles::default());
409 }
410
411 let mut persisted: PersistedAuthProfiles =
412 serde_json::from_slice(&bytes).with_context(|| {
413 format!(
414 "Failed to parse auth profile store at {}",
415 self.path.display()
416 )
417 })?;
418
419 if persisted.schema_version == 0 {
420 persisted.schema_version = CURRENT_SCHEMA_VERSION;
421 }
422
423 if persisted.schema_version > CURRENT_SCHEMA_VERSION {
424 anyhow::bail!(
425 "Unsupported auth profile schema version {} (max supported: {})",
426 persisted.schema_version,
427 CURRENT_SCHEMA_VERSION
428 );
429 }
430
431 Ok(persisted)
432 }
433
434 async fn write_persisted_locked(&self, persisted: &PersistedAuthProfiles) -> Result<()> {
435 if let Some(parent) = self.path.parent() {
436 fs::create_dir_all(parent).await.with_context(|| {
437 format!(
438 "Failed to create auth profile directory at {}",
439 parent.display()
440 )
441 })?;
442 }
443
444 let json =
445 serde_json::to_vec_pretty(persisted).context("Failed to serialize auth profiles")?;
446 let tmp_name = format!(
447 "{}.tmp.{}.{}",
448 PROFILES_FILENAME,
449 std::process::id(),
450 Utc::now().timestamp_nanos_opt().unwrap_or_default()
451 );
452 let tmp_path = self.path.with_file_name(tmp_name);
453
454 fs::write(&tmp_path, &json).await.with_context(|| {
455 format!(
456 "Failed to write temporary auth profile file at {}",
457 tmp_path.display()
458 )
459 })?;
460
461 fs::rename(&tmp_path, &self.path).await.with_context(|| {
462 format!(
463 "Failed to replace auth profile store at {}",
464 self.path.display()
465 )
466 })?;
467
468 Ok(())
469 }
470
471 fn encrypt_optional(&self, value: Option<&str>) -> Result<Option<String>> {
472 match value {
473 Some(value) if !value.is_empty() => self.secret_store.encrypt(value).map(Some),
474 Some(_) | None => Ok(None),
475 }
476 }
477
478 fn decrypt_optional(&self, value: Option<&str>) -> Result<(Option<String>, Option<String>)> {
479 match value {
480 Some(value) if !value.is_empty() => {
481 let (plaintext, migrated) = self.secret_store.decrypt_and_migrate(value)?;
482 Ok((Some(plaintext), migrated))
483 }
484 Some(_) | None => Ok((None, None)),
485 }
486 }
487
488 async fn acquire_lock(&self) -> Result<AuthProfileLockGuard> {
489 if let Some(parent) = self.lock_path.parent() {
490 fs::create_dir_all(parent).await.with_context(|| {
491 format!(
492 "Failed to create lock directory at {}",
493 parent.display().to_string()
494 )
495 })?;
496 }
497
498 let mut waited = 0_u64;
499 loop {
500 match OpenOptions::new()
501 .create_new(true)
502 .write(true)
503 .open(&self.lock_path)
504 .await
505 {
506 Ok(mut file) => {
507 let mut buffer = Vec::new();
508 writeln!(&mut buffer, "pid={}", std::process::id())?;
509 if let Err(e) = file.write_all(&buffer).await {
510 fs::remove_file(&self.lock_path)
511 .await
512 .inspect(|e| {
513 ::zeroclaw_log::record!(
514 ERROR,
515 ::zeroclaw_log::Event::new(
516 module_path!(),
517 ::zeroclaw_log::Action::Fail
518 )
519 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
520 .with_attrs(::serde_json::json!({"e": format!("{:?}", e)})),
521 "Failed to remove auth profile lock file: "
522 );
523 })
524 .ok();
525 return Err(e).with_context(|| {
526 format!(
527 "Failed to write auth profile lock at {}",
528 self.lock_path.display()
529 )
530 });
531 }
532 return Ok(AuthProfileLockGuard {
533 lock_path: self.lock_path.clone(),
534 });
535 }
536 Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
537 if waited >= LOCK_TIMEOUT_MS {
538 anyhow::bail!(
539 "Timed out waiting for auth profile lock at {}",
540 self.lock_path.display()
541 );
542 }
543 sleep(Duration::from_millis(LOCK_WAIT_MS)).await;
544 waited = waited.saturating_add(LOCK_WAIT_MS);
545 }
546 Err(e) => {
547 return Err(e).with_context(|| {
548 format!(
549 "Failed to create auth profile lock at {}",
550 self.lock_path.display()
551 )
552 });
553 }
554 }
555 }
556 }
557}
558
/// Guard returned by `AuthProfilesStore::acquire_lock`; owns the lock file at `lock_path`.
///
/// Dropping the guard deletes the lock file, releasing the lock for other callers.
struct AuthProfileLockGuard {
    lock_path: PathBuf,
}

impl Drop for AuthProfileLockGuard {
    fn drop(&mut self) {
        // `Drop` cannot report failures, so a removal error is deliberately discarded.
        std::fs::remove_file(&self.lock_path).ok();
    }
}
568
/// On-disk JSON layout of the store.
///
/// Every field has a serde default, so documents that omit a field still parse.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedAuthProfiles {
    /// Schema version; defaults to the current version when absent.
    #[serde(default = "default_schema_version")]
    schema_version: u32,
    /// RFC 3339 timestamp of the last store-level change; defaults to "now" when absent.
    #[serde(default = "default_now_rfc3339")]
    updated_at: String,
    /// Active profile id per model provider.
    #[serde(default)]
    active_profiles: BTreeMap<String, String>,
    /// Profiles keyed by profile id, with secrets in `SecretStore`-encoded form.
    #[serde(default)]
    profiles: BTreeMap<String, PersistedAuthProfile>,
}
580
581impl Default for PersistedAuthProfiles {
582 fn default() -> Self {
583 Self {
584 schema_version: CURRENT_SCHEMA_VERSION,
585 updated_at: default_now_rfc3339(),
586 active_profiles: BTreeMap::new(),
587 profiles: BTreeMap::new(),
588 }
589 }
590}
591
/// On-disk representation of one profile.
///
/// Secret fields (`access_token`, `refresh_token`, `id_token`, `token`) hold the output of
/// `SecretStore::encrypt`. Timestamps are RFC 3339 strings. Field order here is the order
/// written to JSON.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct PersistedAuthProfile {
    model_provider: String,
    profile_name: String,
    /// `"oauth"` or `"token"`; anything else is rejected by `parse_profile_kind`.
    kind: String,
    #[serde(default)]
    account_id: Option<String>,
    #[serde(default)]
    workspace_id: Option<String>,
    /// Encoded OAuth access token; required for `"oauth"` profiles when loading.
    #[serde(default)]
    access_token: Option<String>,
    /// Encoded OAuth refresh token.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Encoded OAuth ID token.
    #[serde(default)]
    id_token: Option<String>,
    /// Encoded opaque token.
    #[serde(default)]
    token: Option<String>,
    /// RFC 3339 expiry; a malformed value fails the whole load.
    #[serde(default)]
    expires_at: Option<String>,
    #[serde(default)]
    token_type: Option<String>,
    #[serde(default)]
    scope: Option<String>,
    /// RFC 3339; a malformed value is replaced by "now" when loading.
    #[serde(default = "default_now_rfc3339")]
    created_at: String,
    /// RFC 3339; a malformed value is replaced by "now" when loading.
    #[serde(default = "default_now_rfc3339")]
    updated_at: String,
    #[serde(default)]
    metadata: BTreeMap<String, String>,
}
622
623fn default_schema_version() -> u32 {
624 CURRENT_SCHEMA_VERSION
625}
626
627fn default_now_rfc3339() -> String {
628 Utc::now().to_rfc3339()
629}
630
631fn parse_profile_kind(value: &str) -> Result<AuthProfileKind> {
632 match value {
633 "oauth" => Ok(AuthProfileKind::OAuth),
634 "token" => Ok(AuthProfileKind::Token),
635 other => anyhow::bail!("Unsupported auth profile kind: {other}"),
636 }
637}
638
639fn profile_kind_to_string(kind: AuthProfileKind) -> &'static str {
640 match kind {
641 AuthProfileKind::OAuth => "oauth",
642 AuthProfileKind::Token => "token",
643 }
644}
645
646fn parse_optional_datetime(value: Option<&str>) -> Result<Option<DateTime<Utc>>> {
647 value.map(parse_datetime).transpose()
648}
649
650fn parse_datetime(value: &str) -> Result<DateTime<Utc>> {
651 DateTime::parse_from_rfc3339(value)
652 .map(|dt| dt.with_timezone(&Utc))
653 .with_context(|| format!("Invalid RFC3339 timestamp: {value}"))
654}
655
656fn parse_datetime_with_fallback(value: &str) -> DateTime<Utc> {
657 parse_datetime(value).unwrap_or_else(|_| Utc::now())
658}
659
/// Builds the canonical profile id `"<model_provider>:<profile_name>"`.
///
/// Leading and trailing whitespace is trimmed from both parts. NOTE(review): neither part is
/// escaped, so a `:` inside `model_provider` makes ids ambiguous.
pub fn profile_id(model_provider: &str, profile_name: &str) -> String {
    [model_provider.trim(), profile_name.trim()].join(":")
}
663
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn profile_id_format() {
        assert_eq!(
            profile_id("openai-codex", "default"),
            "openai-codex:default"
        );
    }

    #[test]
    fn token_expiry_math() {
        let token_set = TokenSet {
            access_token: "token".into(),
            refresh_token: Some("refresh".into()),
            id_token: None,
            expires_at: Some(Utc::now() + chrono::Duration::seconds(10)),
            token_type: Some("Bearer".into()),
            scope: None,
        };

        assert!(token_set.is_expiring_within(Duration::from_secs(15)));
        assert!(!token_set.is_expiring_within(Duration::from_secs(1)));
    }

    #[tokio::test]
    async fn store_roundtrip_with_encryption() {
        let tmp = TempDir::new().unwrap();
        let store = AuthProfilesStore::new(tmp.path(), true);

        let mut profile = AuthProfile::new_oauth(
            "openai-codex",
            "default",
            TokenSet {
                access_token: "access-123".into(),
                refresh_token: Some("refresh-123".into()),
                id_token: None,
                expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
                token_type: Some("Bearer".into()),
                scope: Some("openid offline_access".into()),
            },
        );
        profile.account_id = Some("acct_123".into());

        store.upsert_profile(profile.clone(), true).await.unwrap();

        let data = store.load().await.unwrap();
        let loaded = data.profiles.get(&profile.id).unwrap();

        assert_eq!(loaded.model_provider, "openai-codex");
        assert_eq!(loaded.profile_name, "default");
        assert_eq!(loaded.account_id.as_deref(), Some("acct_123"));
        assert_eq!(
            loaded
                .token_set
                .as_ref()
                .and_then(|t| t.refresh_token.as_deref()),
            Some("refresh-123")
        );

        // Secrets must only appear in encoded form on disk.
        let raw = tokio::fs::read_to_string(store.path()).await.unwrap();
        assert!(raw.contains("enc2:"));
        assert!(!raw.contains("refresh-123"));
        assert!(!raw.contains("access-123"));
    }

    #[tokio::test]
    async fn atomic_write_replaces_file() {
        let tmp = TempDir::new().unwrap();
        let store = AuthProfilesStore::new(tmp.path(), false);

        let profile = AuthProfile::new_token("anthropic", "default", "token-abc".into());
        store.upsert_profile(profile, true).await.unwrap();

        let path = store.path().to_path_buf();
        assert!(path.exists());

        let contents = tokio::fs::read_to_string(path).await.unwrap();
        assert!(contents.contains("\"schema_version\": 1"));
    }

    #[tokio::test]
    async fn upsert_preserves_created_at_and_releases_lock() {
        let tmp = TempDir::new().unwrap();
        let store = AuthProfilesStore::new(tmp.path(), true);

        store
            .upsert_profile(
                AuthProfile::new_token("anthropic", "default", "first".into()),
                false,
            )
            .await
            .unwrap();
        let first_created = store.load().await.unwrap().profiles["anthropic:default"].created_at;

        store
            .upsert_profile(
                AuthProfile::new_token("anthropic", "default", "second".into()),
                false,
            )
            .await
            .unwrap();

        let data = store.load().await.unwrap();
        let loaded = &data.profiles["anthropic:default"];
        assert_eq!(loaded.created_at, first_created);
        assert_eq!(loaded.token.as_deref(), Some("second"));
        assert!(data.active_profiles.is_empty());

        // Every operation must drop its lock file on completion.
        assert!(!tmp.path().join(LOCK_FILENAME).exists());
    }

    #[tokio::test]
    async fn remove_profile_clears_active_mapping() {
        let tmp = TempDir::new().unwrap();
        let store = AuthProfilesStore::new(tmp.path(), true);

        let profile = AuthProfile::new_token("anthropic", "default", "token-abc".into());
        let id = profile.id.clone();
        store.upsert_profile(profile, true).await.unwrap();
        assert_eq!(
            store.load().await.unwrap().active_profiles.get("anthropic"),
            Some(&id)
        );

        assert!(store.remove_profile(&id).await.unwrap());
        assert!(!store.remove_profile(&id).await.unwrap());

        let data = store.load().await.unwrap();
        assert!(data.profiles.is_empty());
        assert!(!data.active_profiles.contains_key("anthropic"));
    }

    #[tokio::test]
    async fn set_active_profile_rejects_unknown_id() {
        let tmp = TempDir::new().unwrap();
        let store = AuthProfilesStore::new(tmp.path(), true);

        let err = store
            .set_active_profile("anthropic", "anthropic:missing")
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Auth profile not found"));
    }

    #[tokio::test]
    async fn update_profile_applies_changes_and_rejects_unknown_ids() {
        let tmp = TempDir::new().unwrap();
        let store = AuthProfilesStore::new(tmp.path(), true);

        store
            .upsert_profile(
                AuthProfile::new_token("anthropic", "default", "token-abc".into()),
                true,
            )
            .await
            .unwrap();

        let updated = store
            .update_profile("anthropic:default", |p| {
                p.workspace_id = Some("ws_1".into());
                Ok(())
            })
            .await
            .unwrap();
        assert_eq!(updated.workspace_id.as_deref(), Some("ws_1"));

        let data = store.load().await.unwrap();
        assert_eq!(
            data.profiles["anthropic:default"].workspace_id.as_deref(),
            Some("ws_1")
        );

        assert!(store
            .update_profile("anthropic:missing", |_| Ok(()))
            .await
            .is_err());
    }
}