1use super::AppState;
4use axum::{
5 extract::State,
6 http::{HeaderMap, StatusCode, header},
7 response::{IntoResponse, Json},
8};
9use chrono::{DateTime, Utc};
10use parking_lot::Mutex;
11use rusqlite::Connection;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct DeviceInfo {
19 pub id: String,
20 pub name: Option<String>,
21 pub device_type: Option<String>,
22 pub paired_at: DateTime<Utc>,
23 pub last_seen: DateTime<Utc>,
24 pub ip_address: Option<String>,
25 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub capabilities: Option<Vec<String>>,
29}
30
31#[derive(Debug)]
33pub struct DeviceRegistry {
34 cache: Mutex<HashMap<String, DeviceInfo>>,
35 db_path: PathBuf,
36}
37
38impl DeviceRegistry {
39 pub fn new(workspace_dir: &Path) -> Self {
40 let db_path = workspace_dir.join("devices.db");
41 let conn = Connection::open(&db_path).expect("Failed to open device registry database");
42 conn.execute_batch(
43 "CREATE TABLE IF NOT EXISTS devices (
44 token_hash TEXT PRIMARY KEY,
45 id TEXT NOT NULL,
46 name TEXT,
47 device_type TEXT,
48 paired_at TEXT NOT NULL,
49 last_seen TEXT NOT NULL,
50 ip_address TEXT,
51 capabilities TEXT
52 )",
53 )
54 .expect("Failed to create devices table");
55
56 let _ = conn.execute("ALTER TABLE devices ADD COLUMN capabilities TEXT", []);
59
60 let mut cache = HashMap::new();
62 let mut stmt = conn
63 .prepare("SELECT token_hash, id, name, device_type, paired_at, last_seen, ip_address, capabilities FROM devices")
64 .expect("Failed to prepare device select");
65 let rows = stmt
66 .query_map([], |row| {
67 let token_hash: String = row.get(0)?;
68 let id: String = row.get(1)?;
69 let name: Option<String> = row.get(2)?;
70 let device_type: Option<String> = row.get(3)?;
71 let paired_at_str: String = row.get(4)?;
72 let last_seen_str: String = row.get(5)?;
73 let ip_address: Option<String> = row.get(6)?;
74 let capabilities_json: Option<String> = row.get(7)?;
75 let paired_at = DateTime::parse_from_rfc3339(&paired_at_str)
76 .map(|dt| dt.with_timezone(&Utc))
77 .unwrap_or_else(|_| Utc::now());
78 let last_seen = DateTime::parse_from_rfc3339(&last_seen_str)
79 .map(|dt| dt.with_timezone(&Utc))
80 .unwrap_or_else(|_| Utc::now());
81 let capabilities = capabilities_json
82 .as_deref()
83 .and_then(|s| serde_json::from_str::<Vec<String>>(s).ok());
84 Ok((
85 token_hash,
86 DeviceInfo {
87 id,
88 name,
89 device_type,
90 paired_at,
91 last_seen,
92 ip_address,
93 capabilities,
94 },
95 ))
96 })
97 .expect("Failed to query devices");
98 for (hash, info) in rows.flatten() {
99 cache.insert(hash, info);
100 }
101
102 Self {
103 cache: Mutex::new(cache),
104 db_path,
105 }
106 }
107
108 fn open_db(&self) -> Connection {
109 Connection::open(&self.db_path).expect("Failed to open device registry database")
110 }
111
112 pub fn register(&self, token_hash: String, info: DeviceInfo) {
113 let capabilities_json = info
114 .capabilities
115 .as_ref()
116 .and_then(|c| serde_json::to_string(c).ok());
117 let conn = self.open_db();
118 conn.execute(
119 "INSERT OR REPLACE INTO devices (token_hash, id, name, device_type, paired_at, last_seen, ip_address, capabilities) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
120 rusqlite::params![
121 token_hash,
122 info.id,
123 info.name,
124 info.device_type,
125 info.paired_at.to_rfc3339(),
126 info.last_seen.to_rfc3339(),
127 info.ip_address,
128 capabilities_json,
129 ],
130 )
131 .expect("Failed to insert device");
132 self.cache.lock().insert(token_hash, info);
133 }
134
135 pub fn list(&self) -> Vec<DeviceInfo> {
136 let conn = self.open_db();
137 let mut stmt = conn
138 .prepare("SELECT token_hash, id, name, device_type, paired_at, last_seen, ip_address, capabilities FROM devices")
139 .expect("Failed to prepare device select");
140 let rows = stmt
141 .query_map([], |row| {
142 let id: String = row.get(1)?;
143 let name: Option<String> = row.get(2)?;
144 let device_type: Option<String> = row.get(3)?;
145 let paired_at_str: String = row.get(4)?;
146 let last_seen_str: String = row.get(5)?;
147 let ip_address: Option<String> = row.get(6)?;
148 let capabilities_json: Option<String> = row.get(7)?;
149 let paired_at = DateTime::parse_from_rfc3339(&paired_at_str)
150 .map(|dt| dt.with_timezone(&Utc))
151 .unwrap_or_else(|_| Utc::now());
152 let last_seen = DateTime::parse_from_rfc3339(&last_seen_str)
153 .map(|dt| dt.with_timezone(&Utc))
154 .unwrap_or_else(|_| Utc::now());
155 let capabilities = capabilities_json
156 .as_deref()
157 .and_then(|s| serde_json::from_str::<Vec<String>>(s).ok());
158 Ok(DeviceInfo {
159 id,
160 name,
161 device_type,
162 paired_at,
163 last_seen,
164 ip_address,
165 capabilities,
166 })
167 })
168 .expect("Failed to query devices");
169 rows.filter_map(|r| r.ok()).collect()
170 }
171
172 pub fn revoke(&self, device_id: &str) -> bool {
173 let conn = self.open_db();
174 let deleted = conn
175 .execute(
176 "DELETE FROM devices WHERE id = ?1",
177 rusqlite::params![device_id],
178 )
179 .unwrap_or(0);
180 if deleted > 0 {
181 let mut cache = self.cache.lock();
182 let key = cache
183 .iter()
184 .find(|(_, v)| v.id == device_id)
185 .map(|(k, _)| k.clone());
186 if let Some(key) = key {
187 cache.remove(&key);
188 }
189 true
190 } else {
191 false
192 }
193 }
194
195 pub fn update_last_seen(&self, token_hash: &str) {
196 let now = Utc::now();
197 let conn = self.open_db();
198 conn.execute(
199 "UPDATE devices SET last_seen = ?1 WHERE token_hash = ?2",
200 rusqlite::params![now.to_rfc3339(), token_hash],
201 )
202 .ok();
203 if let Some(device) = self.cache.lock().get_mut(token_hash) {
204 device.last_seen = now;
205 }
206 }
207
208 pub fn update_capabilities(&self, token_hash: &str, capabilities: Vec<String>) -> bool {
211 let json = serde_json::to_string(&capabilities).unwrap_or_else(|_| "[]".into());
212 let conn = self.open_db();
213 let updated = conn
214 .execute(
215 "UPDATE devices SET capabilities = ?1, last_seen = ?2 WHERE token_hash = ?3",
216 rusqlite::params![json, Utc::now().to_rfc3339(), token_hash],
217 )
218 .unwrap_or(0);
219 if updated > 0
220 && let Some(device) = self.cache.lock().get_mut(token_hash)
221 {
222 device.capabilities = Some(capabilities);
223 device.last_seen = Utc::now();
224 }
225 updated > 0
226 }
227
228 pub fn device_count(&self) -> usize {
229 self.cache.lock().len()
230 }
231}
232
233#[derive(Debug, Default)]
235pub struct PairingStore {
236 pending: Mutex<Vec<PendingPairing>>,
237}
238
239#[derive(Debug, Clone, Serialize)]
240struct PendingPairing {
241 code: String,
242 created_at: DateTime<Utc>,
243 expires_at: DateTime<Utc>,
244 client_ip: Option<String>,
245 attempts: u32,
246}
247
248impl PairingStore {
249 pub fn new() -> Self {
250 Self::default()
251 }
252
253 pub fn pending_count(&self) -> usize {
254 let mut pending = self.pending.lock();
255 pending.retain(|p| p.expires_at > Utc::now());
256 pending.len()
257 }
258}
259
260fn extract_bearer(headers: &HeaderMap) -> Option<&str> {
261 headers
262 .get(header::AUTHORIZATION)
263 .and_then(|v| v.to_str().ok())
264 .and_then(|auth| auth.strip_prefix("Bearer "))
265}
266
267fn require_auth(state: &AppState, headers: &HeaderMap) -> Result<(), (StatusCode, &'static str)> {
268 if state.pairing.require_pairing() {
269 let token = extract_bearer(headers).unwrap_or("");
270 if !state.pairing.is_authenticated(token) {
271 return Err((StatusCode::UNAUTHORIZED, "Unauthorized"));
272 }
273 }
274 Ok(())
275}
276
277pub async fn initiate_pairing(
279 State(state): State<AppState>,
280 headers: HeaderMap,
281) -> impl IntoResponse {
282 if let Err(e) = require_auth(&state, &headers) {
283 return e.into_response();
284 }
285
286 match state.pairing.generate_new_pairing_code() {
287 Some(code) => Json(serde_json::json!({
288 "pairing_code": code,
289 "message": "New pairing code generated"
290 }))
291 .into_response(),
292 None => (
293 StatusCode::SERVICE_UNAVAILABLE,
294 "Pairing is disabled or not available",
295 )
296 .into_response(),
297 }
298}
299
300pub async fn submit_pairing_enhanced(
302 State(state): State<AppState>,
303 headers: HeaderMap,
304 Json(body): Json<serde_json::Value>,
305) -> impl IntoResponse {
306 let code = body["code"].as_str().unwrap_or("");
307 let device_name = body["device_name"].as_str().map(String::from);
308 let device_type = body["device_type"].as_str().map(String::from);
309
310 let client_id = headers
311 .get("X-Forwarded-For")
312 .and_then(|v| v.to_str().ok())
313 .unwrap_or("unknown")
314 .to_string();
315
316 match state.pairing.try_pair(code, &client_id).await {
317 Ok(Some(token)) => {
318 let token_hash = {
320 use sha2::{Digest, Sha256};
321 let hash = Sha256::digest(token.as_bytes());
322 hex::encode(hash)
323 };
324 if let Some(ref registry) = state.device_registry {
325 registry.register(
326 token_hash,
327 DeviceInfo {
328 id: uuid::Uuid::new_v4().to_string(),
329 name: device_name,
330 device_type,
331 paired_at: Utc::now(),
332 last_seen: Utc::now(),
333 ip_address: Some(client_id),
334 capabilities: None,
335 },
336 );
337 }
338 Json(serde_json::json!({
339 "token": token,
340 "message": "Pairing successful"
341 }))
342 .into_response()
343 }
344 Ok(None) => (StatusCode::BAD_REQUEST, "Invalid or expired pairing code").into_response(),
345 Err(lockout_secs) => (
346 StatusCode::TOO_MANY_REQUESTS,
347 format!("Too many attempts. Locked out for {lockout_secs}s"),
348 )
349 .into_response(),
350 }
351}
352
353pub async fn list_devices(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
355 if let Err(e) = require_auth(&state, &headers) {
356 return e.into_response();
357 }
358
359 let devices = state
360 .device_registry
361 .as_ref()
362 .map(|r| r.list())
363 .unwrap_or_default();
364
365 let count = devices.len();
366 Json(serde_json::json!({
367 "devices": devices,
368 "count": count
369 }))
370 .into_response()
371}
372
373pub async fn revoke_device(
375 State(state): State<AppState>,
376 headers: HeaderMap,
377 axum::extract::Path(device_id): axum::extract::Path<String>,
378) -> impl IntoResponse {
379 if let Err(e) = require_auth(&state, &headers) {
380 return e.into_response();
381 }
382
383 let revoked = state
384 .device_registry
385 .as_ref()
386 .map(|r| r.revoke(&device_id))
387 .unwrap_or(false);
388
389 if revoked {
390 Json(serde_json::json!({
391 "message": "Device revoked",
392 "device_id": device_id
393 }))
394 .into_response()
395 } else {
396 (StatusCode::NOT_FOUND, "Device not found").into_response()
397 }
398}
399
400pub async fn update_my_capabilities(
405 State(state): State<AppState>,
406 headers: HeaderMap,
407 Json(body): Json<serde_json::Value>,
408) -> impl IntoResponse {
409 if let Err(e) = require_auth(&state, &headers) {
410 return e.into_response();
411 }
412
413 let token = match extract_bearer(&headers) {
414 Some(t) => t,
415 None => return (StatusCode::UNAUTHORIZED, "Missing bearer token").into_response(),
416 };
417 let token_hash = {
418 use sha2::{Digest, Sha256};
419 let hash = Sha256::digest(token.as_bytes());
420 hex::encode(hash)
421 };
422
423 let capabilities: Vec<String> = body
424 .get("capabilities")
425 .and_then(|v| v.as_array())
426 .map(|arr| {
427 arr.iter()
428 .filter_map(|v| v.as_str().map(String::from))
429 .collect()
430 })
431 .unwrap_or_default();
432
433 let registry = match state.device_registry.as_ref() {
434 Some(r) => r,
435 None => {
436 return (
437 StatusCode::SERVICE_UNAVAILABLE,
438 "Device registry is disabled",
439 )
440 .into_response();
441 }
442 };
443
444 if registry.update_capabilities(&token_hash, capabilities.clone()) {
445 Json(serde_json::json!({
446 "message": "Capabilities updated",
447 "capabilities": capabilities,
448 }))
449 .into_response()
450 } else {
451 (StatusCode::NOT_FOUND, "Device not found for this token").into_response()
452 }
453}
454
455pub async fn rotate_token(
457 State(state): State<AppState>,
458 headers: HeaderMap,
459 axum::extract::Path(device_id): axum::extract::Path<String>,
460) -> impl IntoResponse {
461 if let Err(e) = require_auth(&state, &headers) {
462 return e.into_response();
463 }
464
465 match state.pairing.generate_new_pairing_code() {
467 Some(code) => Json(serde_json::json!({
468 "device_id": device_id,
469 "pairing_code": code,
470 "message": "Use this code to re-pair the device"
471 }))
472 .into_response(),
473 None => (
474 StatusCode::SERVICE_UNAVAILABLE,
475 "Cannot generate new pairing code",
476 )
477 .into_response(),
478 }
479}