Skip to main content

zeroclaw_runtime/security/
otp.rs

1use crate::security::secrets::SecretStore;
2use anyhow::{Context, Result};
3use parking_lot::Mutex;
4use ring::hmac;
5use std::collections::HashMap;
6use std::fs;
7use std::path::{Path, PathBuf};
8use std::time::{SystemTime, UNIX_EPOCH};
9use zeroclaw_config::schema::OtpConfig;
10
/// File name, joined onto the ZeroClaw directory, under which the encrypted OTP secret is stored.
const OTP_SECRET_FILE: &str = "otp-secret";
/// Number of decimal digits in a one-time code; used both to validate input length and as the
/// modulus (10^digits) when deriving a code.
const OTP_DIGITS: u32 = 6;
/// Issuer label embedded in the `otpauth://` provisioning URI.
const OTP_ISSUER: &str = "ZeroClaw";
14
15#[derive(Debug)]
16pub struct OtpValidator {
17    config: OtpConfig,
18    secret: Vec<u8>,
19    cached_codes: Mutex<HashMap<String, u64>>,
20}
21
22impl OtpValidator {
23    pub fn from_config(
24        config: &OtpConfig,
25        zeroclaw_dir: &Path,
26        store: &SecretStore,
27    ) -> Result<(Self, Option<String>)> {
28        let secret_path = secret_file_path(zeroclaw_dir);
29        let (secret, generated) = if secret_path.exists() {
30            let encoded = fs::read_to_string(&secret_path).with_context(|| {
31                format!(
32                    "Failed to read OTP secret file {}",
33                    secret_path.display().to_string()
34                )
35            })?;
36            let decrypted = store
37                .decrypt(encoded.trim())
38                .context("Failed to decrypt OTP secret file")?;
39            (decode_base32_secret(&decrypted)?, false)
40        } else {
41            let raw: [u8; 20] = rand::random();
42            let encoded_secret = encode_base32_secret(&raw);
43            let encrypted = store
44                .encrypt(&encoded_secret)
45                .context("Failed to encrypt OTP secret")?;
46            write_secret_file(&secret_path, &encrypted)?;
47            (raw.to_vec(), true)
48        };
49
50        let validator = Self {
51            config: config.clone(),
52            secret,
53            cached_codes: Mutex::new(HashMap::new()),
54        };
55        let uri = if generated {
56            Some(validator.otpauth_uri())
57        } else {
58            None
59        };
60        Ok((validator, uri))
61    }
62
63    pub fn validate(&self, code: &str) -> Result<bool> {
64        self.validate_at(code, unix_timestamp_now())
65    }
66
67    fn validate_at(&self, code: &str, now_secs: u64) -> Result<bool> {
68        let normalized = code.trim();
69        if normalized.len() != OTP_DIGITS as usize
70            || !normalized.chars().all(|ch| ch.is_ascii_digit())
71        {
72            return Ok(false);
73        }
74
75        {
76            let mut cache = self.cached_codes.lock();
77            cache.retain(|_, expiry| *expiry >= now_secs);
78            if cache
79                .get(normalized)
80                .is_some_and(|expiry| *expiry >= now_secs)
81            {
82                return Ok(true);
83            }
84        }
85
86        let step = self.config.token_ttl_secs.max(1);
87        let counter = now_secs / step;
88        let counters = [
89            counter.saturating_sub(1),
90            counter,
91            counter.saturating_add(1),
92        ];
93
94        let is_valid = counters
95            .iter()
96            .map(|c| compute_totp_code(&self.secret, *c))
97            .any(|candidate| candidate == normalized);
98
99        if is_valid {
100            let mut cache = self.cached_codes.lock();
101            cache.insert(
102                normalized.to_string(),
103                now_secs.saturating_add(self.config.cache_valid_secs),
104            );
105        }
106
107        Ok(is_valid)
108    }
109
110    pub fn otpauth_uri(&self) -> String {
111        let secret = encode_base32_secret(&self.secret);
112        let account = "zeroclaw";
113        format!(
114            "otpauth://totp/{issuer}:{account}?secret={secret}&issuer={issuer}&period={period}",
115            issuer = OTP_ISSUER,
116            period = self.config.token_ttl_secs.max(1)
117        )
118    }
119
120    #[cfg(test)]
121    pub fn code_for_timestamp(&self, timestamp: u64) -> String {
122        let counter = timestamp / self.config.token_ttl_secs.max(1);
123        compute_totp_code(&self.secret, counter)
124    }
125}
126
127pub fn secret_file_path(zeroclaw_dir: &Path) -> PathBuf {
128    zeroclaw_dir.join(OTP_SECRET_FILE)
129}
130
131fn write_secret_file(path: &Path, value: &str) -> Result<()> {
132    if let Some(parent) = path.parent() {
133        fs::create_dir_all(parent).with_context(|| {
134            format!(
135                "Failed to create directory {}",
136                parent.display().to_string()
137            )
138        })?;
139    }
140
141    let temp_path = path.with_extension(format!("tmp-{}", uuid::Uuid::new_v4()));
142    fs::write(&temp_path, value).with_context(|| {
143        format!(
144            "Failed to write temporary OTP secret {}",
145            temp_path.display()
146        )
147    })?;
148
149    #[cfg(unix)]
150    {
151        use std::os::unix::fs::PermissionsExt;
152        let _ = fs::set_permissions(&temp_path, fs::Permissions::from_mode(0o600));
153    }
154
155    fs::rename(&temp_path, path).with_context(|| {
156        format!(
157            "Failed to atomically replace OTP secret file {}",
158            path.display()
159        )
160    })?;
161    Ok(())
162}
163
/// Returns the current time as whole seconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn unix_timestamp_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
170
171fn compute_totp_code(secret: &[u8], counter: u64) -> String {
172    let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, secret);
173    let counter_bytes = counter.to_be_bytes();
174    let digest = hmac::sign(&key, &counter_bytes);
175    let hash = digest.as_ref();
176
177    let offset = (hash[19] & 0x0f) as usize;
178    let binary = ((u32::from(hash[offset]) & 0x7f) << 24)
179        | (u32::from(hash[offset + 1]) << 16)
180        | (u32::from(hash[offset + 2]) << 8)
181        | u32::from(hash[offset + 3]);
182
183    let code = binary % 10_u32.pow(OTP_DIGITS);
184    format!("{code:0>6}")
185}
186
/// Encodes `input` as unpadded base32 using the `A-Z2-7` alphabet.
///
/// Empty input produces an empty string. A trailing partial group is zero-filled on the right
/// and no `=` padding is emitted.
fn encode_base32_secret(input: &[u8]) -> String {
    const ALPHABET: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";

    let mut encoded = String::new();
    // Five input bytes form one 40-bit group that maps to eight 5-bit symbols.
    for group in input.chunks(5) {
        let mut block = [0u8; 5];
        block[..group.len()].copy_from_slice(group);
        let word = block
            .iter()
            .fold(0u64, |acc, byte| (acc << 8) | u64::from(*byte));

        // A short final group only yields the symbols that contain real input bits.
        let symbol_count = (group.len() * 8 + 4) / 5;
        for position in 0..symbol_count {
            let shift = 35 - 5 * position;
            let index = ((word >> shift) & 0x1f) as usize;
            encoded.push(ALPHABET[index] as char);
        }
    }
    encoded
}
215
216fn decode_base32_secret(raw: &str) -> Result<Vec<u8>> {
217    fn decode_char(ch: char) -> Option<u8> {
218        match ch {
219            'A'..='Z' => Some((ch as u8) - b'A'),
220            '2'..='7' => Some((ch as u8) - b'2' + 26),
221            _ => None,
222        }
223    }
224
225    let mut cleaned = raw
226        .chars()
227        .filter(|ch| !matches!(ch, ' ' | '\t' | '\n' | '\r' | '-'))
228        .collect::<String>()
229        .to_ascii_uppercase();
230    while cleaned.ends_with('=') {
231        cleaned.pop();
232    }
233    if cleaned.is_empty() {
234        anyhow::bail!("OTP secret is empty");
235    }
236
237    let mut output = Vec::new();
238    let mut buffer = 0u32;
239    let mut bits_left = 0u8;
240
241    for ch in cleaned.chars() {
242        let value = decode_char(ch)
243            .with_context(|| format!("OTP secret contains invalid base32 character '{ch}'"))?;
244        buffer = (buffer << 5) | u32::from(value);
245        bits_left += 5;
246
247        if bits_left >= 8 {
248            let byte = ((buffer >> (bits_left - 8)) & 0xff) as u8;
249            output.push(byte);
250            bits_left -= 8;
251        }
252    }
253
254    if output.is_empty() {
255        anyhow::bail!("OTP secret did not decode to any bytes");
256    }
257    Ok(output)
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Fixed reference instant so the tests do not depend on the wall clock.
    const REFERENCE_TS: u64 = 1_700_000_000;

    /// OTP configuration shared by every test: 30-second steps, 120-second cache.
    fn test_config() -> OtpConfig {
        OtpConfig {
            enabled: true,
            token_ttl_secs: 30,
            cache_valid_secs: 120,
            ..OtpConfig::default()
        }
    }

    /// Builds a validator rooted at `dir` using `store` and the shared test configuration.
    fn build_validator(dir: &Path, store: &SecretStore) -> (OtpValidator, Option<String>) {
        OtpValidator::from_config(&test_config(), dir, store).unwrap()
    }

    #[test]
    fn valid_totp_code_is_accepted() {
        let dir = tempdir().unwrap();
        let store = SecretStore::new(dir.path(), true);
        let (validator, _) = build_validator(dir.path(), &store);

        let code = validator.code_for_timestamp(REFERENCE_TS);
        assert!(validator.validate_at(&code, REFERENCE_TS).unwrap());
    }

    #[test]
    fn expired_totp_code_is_rejected() {
        let dir = tempdir().unwrap();
        let store = SecretStore::new(dir.path(), true);
        let (validator, _) = build_validator(dir.path(), &store);

        // 300 seconds later is well outside the accepted +/- one step window.
        let later = REFERENCE_TS + 300;
        let stale_code = validator.code_for_timestamp(REFERENCE_TS);
        assert!(!validator.validate_at(&stale_code, later).unwrap());
    }

    #[test]
    fn wrong_totp_code_is_rejected() {
        let dir = tempdir().unwrap();
        let store = SecretStore::new(dir.path(), true);
        let (validator, _) = build_validator(dir.path(), &store);

        assert!(!validator.validate_at("123456", REFERENCE_TS).unwrap());
    }

    #[test]
    fn secret_is_generated_and_reused() {
        let dir = tempdir().unwrap();
        let store = SecretStore::new(dir.path(), true);

        // First construction creates the secret and returns a provisioning URI.
        let (first, first_uri) = build_validator(dir.path(), &store);
        assert!(first_uri.is_some());

        // The secret must be stored encrypted on disk.
        let stored = fs::read_to_string(secret_file_path(dir.path())).unwrap();
        assert!(SecretStore::is_encrypted(stored.trim()));

        // Second construction reuses the stored secret and returns no URI.
        let (second, second_uri) = build_validator(dir.path(), &store);
        assert!(second_uri.is_none());

        assert_eq!(
            first.code_for_timestamp(REFERENCE_TS),
            second.code_for_timestamp(REFERENCE_TS)
        );
    }
}
325}