1use super::AppState;
17use axum::{
18 extract::{
19 Query, State, WebSocketUpgrade,
20 ws::{Message, WebSocket},
21 },
22 http::{HeaderMap, header},
23 response::IntoResponse,
24};
25use futures_util::{SinkExt, StreamExt};
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29use std::sync::Arc;
30use tokio::sync::{mpsc, oneshot};
31use zeroclaw_runtime::security::pairing::PairingGuard;
32
/// Prefix of a `Sec-WebSocket-Protocol` entry that carries a bearer token
/// (`bearer.<token>`), accepted as an alternative to the `Authorization` header.
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";

/// WebSocket subprotocol selected for the node connection when the client offers it.
const WS_NODE_PROTOCOL: &str = "zeroclaw.nodes.v1";
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct NodeCapability {
42 pub name: String,
43 pub description: String,
44 #[serde(default = "default_capability_parameters")]
45 pub parameters: serde_json::Value,
46}
47
48fn default_capability_parameters() -> serde_json::Value {
49 serde_json::json!({
50 "type": "object",
51 "properties": {}
52 })
53}
54
55#[derive(Debug, Clone)]
57pub struct NodeInfo {
58 pub node_id: String,
59 pub capabilities: Vec<NodeCapability>,
60 pub invoke_tx: mpsc::Sender<NodeInvocation>,
62}
63
64#[derive(Debug)]
66pub struct NodeInvocation {
67 pub call_id: String,
68 pub capability: String,
69 pub args: serde_json::Value,
70 pub response_tx: oneshot::Sender<NodeInvocationResult>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct NodeInvocationResult {
76 pub success: bool,
77 pub output: String,
78 pub error: Option<String>,
79}
80
81#[derive(Debug, Default, Clone)]
83pub struct NodeRegistry {
84 nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
85 max_nodes: usize,
86}
87
88impl NodeRegistry {
89 pub fn new(max_nodes: usize) -> Self {
91 Self {
92 nodes: Arc::new(RwLock::new(HashMap::new())),
93 max_nodes,
94 }
95 }
96
97 pub fn register(&self, info: NodeInfo) -> bool {
99 let mut nodes = self.nodes.write();
100 if nodes.len() >= self.max_nodes && !nodes.contains_key(&info.node_id) {
101 return false;
102 }
103 nodes.insert(info.node_id.clone(), info);
104 true
105 }
106
107 pub fn unregister(&self, node_id: &str) {
109 self.nodes.write().remove(node_id);
110 }
111
112 pub fn node_ids(&self) -> Vec<String> {
114 self.nodes.read().keys().cloned().collect()
115 }
116
117 pub fn all_capabilities(&self) -> Vec<(String, String, NodeCapability)> {
119 let nodes = self.nodes.read();
120 let mut caps = Vec::new();
121 for info in nodes.values() {
122 for cap in &info.capabilities {
123 caps.push((info.node_id.clone(), cap.name.clone(), cap.clone()));
124 }
125 }
126 caps
127 }
128
129 pub fn invoke_tx(&self, node_id: &str) -> Option<mpsc::Sender<NodeInvocation>> {
131 self.nodes.read().get(node_id).map(|n| n.invoke_tx.clone())
132 }
133
134 pub fn contains(&self, node_id: &str) -> bool {
136 self.nodes.read().contains_key(node_id)
137 }
138
139 pub fn len(&self) -> usize {
141 self.nodes.read().len()
142 }
143
144 pub fn is_empty(&self) -> bool {
146 self.nodes.read().is_empty()
147 }
148}
149
150#[derive(Debug, Deserialize)]
152#[serde(tag = "type", rename_all = "snake_case")]
153enum NodeMessage {
154 Register {
155 node_id: String,
156 capabilities: Vec<NodeCapability>,
157 },
158 Result {
159 call_id: String,
160 success: bool,
161 output: String,
162 #[serde(default)]
163 error: Option<String>,
164 },
165}
166
167#[derive(Debug, Serialize)]
169#[serde(tag = "type", rename_all = "snake_case")]
170enum GatewayMessage {
171 #[allow(dead_code)] Registered {
173 node_id: String,
174 capabilities_count: usize,
175 },
176 Invoke {
177 call_id: String,
178 capability: String,
179 args: serde_json::Value,
180 },
181}
182
183#[derive(Deserialize)]
185pub struct NodeWsQuery {
186 pub token: Option<String>,
187}
188
189fn extract_node_ws_token<'a>(
191 headers: &'a HeaderMap,
192 query_token: Option<&'a str>,
193) -> Option<&'a str> {
194 if let Some(t) = headers
196 .get(header::AUTHORIZATION)
197 .and_then(|v| v.to_str().ok())
198 .and_then(|auth| auth.strip_prefix("Bearer "))
199 && !t.is_empty()
200 {
201 return Some(t);
202 }
203
204 if let Some(t) = headers
206 .get("sec-websocket-protocol")
207 .and_then(|v| v.to_str().ok())
208 .and_then(|protos| {
209 protos
210 .split(',')
211 .map(|p| p.trim())
212 .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
213 })
214 && !t.is_empty()
215 {
216 return Some(t);
217 }
218
219 if let Some(t) = query_token
221 && !t.is_empty()
222 {
223 return Some(t);
224 }
225
226 None
227}
228
229pub(crate) fn check_node_auth(
237 nodes_config: &zeroclaw_config::schema::NodesConfig,
238 pairing: &PairingGuard,
239 headers: &HeaderMap,
240 query_token: Option<&str>,
241) -> Option<(axum::http::StatusCode, &'static str)> {
242 if !nodes_config.enabled {
243 return Some((
244 axum::http::StatusCode::NOT_FOUND,
245 "Not Found — node discovery is disabled (set nodes.enabled=true to enable)",
246 ));
247 }
248 if let Some(ref expected_token) = nodes_config.auth_token {
249 let token = extract_node_ws_token(headers, query_token).unwrap_or("");
250 if token != expected_token {
251 return Some((
252 axum::http::StatusCode::UNAUTHORIZED,
253 "Unauthorized — provide a valid node auth token",
254 ));
255 }
256 } else if pairing.require_pairing() {
257 let token = extract_node_ws_token(headers, query_token).unwrap_or("");
258 if !pairing.is_authenticated(token) {
259 return Some((
260 axum::http::StatusCode::UNAUTHORIZED,
261 "Unauthorized — provide Authorization header or ?token= query param",
262 ));
263 }
264 } else {
265 return Some((
266 axum::http::StatusCode::SERVICE_UNAVAILABLE,
267 "Service Unavailable — node registration is disabled because no auth method is configured. \
268 Set nodes.auth_token OR enable gateway.require_pairing.",
269 ));
270 }
271 None
272}
273
274pub async fn handle_ws_nodes(
275 State(state): State<AppState>,
276 Query(params): Query<NodeWsQuery>,
277 headers: HeaderMap,
278 ws: WebSocketUpgrade,
279) -> impl IntoResponse {
280 let nodes_config = state.config.read().nodes.clone();
281 if let Some((status, body)) = check_node_auth(
282 &nodes_config,
283 &state.pairing,
284 &headers,
285 params.token.as_deref(),
286 ) {
287 return (status, body).into_response();
288 }
289
290 let ws = if headers
292 .get("sec-websocket-protocol")
293 .and_then(|v| v.to_str().ok())
294 .is_some_and(|protos| protos.split(',').any(|p| p.trim() == WS_NODE_PROTOCOL))
295 {
296 ws.protocols([WS_NODE_PROTOCOL])
297 } else {
298 ws
299 };
300
301 let registry = state.node_registry.clone();
302 ws.on_upgrade(move |socket| handle_node_socket(socket, registry))
303 .into_response()
304}
305
306async fn handle_node_socket(socket: WebSocket, registry: Arc<NodeRegistry>) {
307 let (mut sender, mut receiver) = socket.split();
308 let mut registered_node_id: Option<String> = None;
309
310 let (invoke_tx, mut invoke_rx) = mpsc::channel::<NodeInvocation>(32);
312
313 let pending: Arc<RwLock<HashMap<String, oneshot::Sender<NodeInvocationResult>>>> =
315 Arc::new(RwLock::new(HashMap::new()));
316
317 let pending_clone = Arc::clone(&pending);
318
319 let send_task = tokio::spawn(async move {
321 while let Some(invocation) = invoke_rx.recv().await {
322 let msg = GatewayMessage::Invoke {
323 call_id: invocation.call_id.clone(),
324 capability: invocation.capability,
325 args: invocation.args,
326 };
327 if let Ok(json) = serde_json::to_string(&msg) {
328 if sender.send(Message::Text(json.into())).await.is_err() {
329 break;
330 }
331 pending_clone
332 .write()
333 .insert(invocation.call_id, invocation.response_tx);
334 }
335 }
336 });
337
338 while let Some(msg) = receiver.next().await {
340 let text = match msg {
341 Ok(Message::Text(text)) => text,
342 Ok(Message::Close(_)) | Err(_) => break,
343 _ => continue,
344 };
345
346 let parsed: serde_json::Value = match serde_json::from_str(&text) {
347 Ok(v) => v,
348 Err(_) => continue,
349 };
350
351 let node_msg: NodeMessage = match serde_json::from_value(parsed) {
353 Ok(m) => m,
354 Err(_) => continue,
355 };
356
357 match node_msg {
358 NodeMessage::Register {
359 node_id,
360 capabilities,
361 } => {
362 if node_id.is_empty() || node_id.len() > 128 {
364 ::zeroclaw_log::record!(
365 WARN,
366 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
367 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
368 "Node registration rejected: invalid node_id length"
369 );
370 continue;
371 }
372
373 let caps_count = capabilities.len();
374 let info = NodeInfo {
375 node_id: node_id.clone(),
376 capabilities,
377 invoke_tx: invoke_tx.clone(),
378 };
379
380 if registry.register(info) {
381 ::zeroclaw_log::record!(
382 INFO,
383 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
384 .with_attrs(
385 ::serde_json::json!({"node_id": node_id, "caps_count": caps_count})
386 ),
387 "Node registered: with capabilities"
388 );
389 registered_node_id = Some(node_id.clone());
390
391 } else {
398 ::zeroclaw_log::record!(
399 WARN,
400 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
401 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
402 .with_attrs(::serde_json::json!({"node_id": node_id})),
403 "Node registration rejected: registry at capacity for"
404 );
405 }
406 }
407 NodeMessage::Result {
408 call_id,
409 success,
410 output,
411 error,
412 } => {
413 if let Some(tx) = pending.write().remove(&call_id) {
414 let _ = tx.send(NodeInvocationResult {
415 success,
416 output,
417 error,
418 });
419 }
420 }
421 }
422 }
423
424 if let Some(node_id) = registered_node_id {
426 registry.unregister(&node_id);
427 ::zeroclaw_log::record!(
428 INFO,
429 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
430 .with_attrs(::serde_json::json!({"node_id": node_id})),
431 "Node disconnected and unregistered"
432 );
433 }
434
435 send_task.abort();
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, StatusCode};
    use zeroclaw_config::schema::NodesConfig;
    use zeroclaw_runtime::security::pairing::PairingGuard;

    /// A request with no headers at all.
    fn empty_headers() -> HeaderMap {
        HeaderMap::new()
    }

    /// A request carrying `Authorization: Bearer <token>`.
    fn bearer_headers(token: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert("authorization", format!("Bearer {token}").parse().unwrap());
        h
    }

    /// A request offering the given `Sec-WebSocket-Protocol` list.
    fn subprotocol_headers(protocols: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert("sec-websocket-protocol", protocols.parse().unwrap());
        h
    }

    /// Builds a guard via `PairingGuard::new(require, &[])`.
    fn make_pairing(require: bool) -> PairingGuard {
        PairingGuard::new(require, &[])
    }

    /// Node config with discovery enabled and the given node auth token.
    fn enabled_with_token(token: Option<&str>) -> NodesConfig {
        NodesConfig {
            enabled: true,
            auth_token: token.map(Into::into),
            ..NodesConfig::default()
        }
    }

    #[test]
    fn nodes_disabled_returns_404() {
        let cfg = NodesConfig {
            enabled: false,
            ..NodesConfig::default()
        };
        let result = check_node_auth(&cfg, &make_pairing(false), &empty_headers(), None);
        assert_eq!(result.map(|(s, _)| s), Some(StatusCode::NOT_FOUND));
    }

    #[test]
    fn nodes_enabled_no_auth_no_pairing_returns_503() {
        let cfg = enabled_with_token(None);
        let result = check_node_auth(&cfg, &make_pairing(false), &empty_headers(), None);
        assert_eq!(
            result.map(|(s, _)| s),
            Some(StatusCode::SERVICE_UNAVAILABLE)
        );
    }

    #[test]
    fn nodes_auth_token_wrong_token_returns_401() {
        let cfg = enabled_with_token(Some("secret"));
        let result = check_node_auth(&cfg, &make_pairing(false), &empty_headers(), None);
        assert_eq!(result.map(|(s, _)| s), Some(StatusCode::UNAUTHORIZED));
    }

    #[test]
    fn nodes_auth_token_wrong_bearer_returns_401() {
        let cfg = enabled_with_token(Some("secret"));
        let headers = bearer_headers("not-the-secret");
        let result = check_node_auth(&cfg, &make_pairing(false), &headers, None);
        assert_eq!(result.map(|(s, _)| s), Some(StatusCode::UNAUTHORIZED));
    }

    #[test]
    fn nodes_auth_token_correct_token_passes() {
        let cfg = enabled_with_token(Some("secret"));
        let headers = bearer_headers("secret");
        let result = check_node_auth(&cfg, &make_pairing(false), &headers, None);
        assert!(result.is_none(), "correct token must pass auth gate");
    }

    #[test]
    fn nodes_auth_token_via_query_param_passes() {
        let cfg = enabled_with_token(Some("secret"));
        let result = check_node_auth(&cfg, &make_pairing(false), &empty_headers(), Some("secret"));
        assert!(result.is_none(), "token in ?token= must pass auth gate");
    }

    #[test]
    fn nodes_auth_token_via_subprotocol_passes() {
        let cfg = enabled_with_token(Some("secret"));
        let headers = subprotocol_headers("zeroclaw.nodes.v1, bearer.secret");
        let result = check_node_auth(&cfg, &make_pairing(false), &headers, None);
        assert!(result.is_none(), "token in bearer.<token> subprotocol must pass auth gate");
    }

    #[test]
    fn nodes_pairing_required_wrong_token_returns_401() {
        let cfg = enabled_with_token(None);
        let result = check_node_auth(&cfg, &make_pairing(true), &empty_headers(), None);
        assert_eq!(result.map(|(s, _)| s), Some(StatusCode::UNAUTHORIZED));
    }

    #[test]
    fn node_registry_register_and_unregister() {
        let registry = NodeRegistry::new(10);
        let (tx, _rx) = mpsc::channel(1);

        let info = NodeInfo {
            node_id: "test-node".to_string(),
            capabilities: vec![NodeCapability {
                name: "ping".to_string(),
                description: "Ping test".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx: tx,
        };

        assert!(registry.register(info));
        assert!(registry.contains("test-node"));
        assert_eq!(registry.len(), 1);

        registry.unregister("test-node");
        assert!(!registry.contains("test-node"));
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn node_registry_unregister_unknown_is_noop() {
        let registry = NodeRegistry::new(1);
        registry.unregister("missing");
        assert!(registry.is_empty());
    }

    #[test]
    fn node_registry_capacity_limit() {
        let registry = NodeRegistry::new(2);

        for i in 0..2 {
            let (tx, _rx) = mpsc::channel(1);
            let info = NodeInfo {
                node_id: format!("node-{i}"),
                capabilities: vec![],
                invoke_tx: tx,
            };
            assert!(registry.register(info));
        }

        let (tx, _rx) = mpsc::channel(1);
        let info = NodeInfo {
            node_id: "node-overflow".to_string(),
            capabilities: vec![],
            invoke_tx: tx,
        };
        assert!(!registry.register(info));
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn node_registry_re_register_allowed_at_capacity() {
        let registry = NodeRegistry::new(1);
        let (tx1, _rx1) = mpsc::channel(1);
        let (tx2, _rx2) = mpsc::channel(1);

        assert!(registry.register(NodeInfo {
            node_id: "only".to_string(),
            capabilities: vec![],
            invoke_tx: tx1,
        }));
        // A full registry must still accept an update for an id it already holds.
        assert!(registry.register(NodeInfo {
            node_id: "only".to_string(),
            capabilities: vec![],
            invoke_tx: tx2,
        }));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn node_registry_re_register_same_id() {
        let registry = NodeRegistry::new(2);
        let (tx1, _rx1) = mpsc::channel(1);
        let (tx2, _rx2) = mpsc::channel(1);

        let info1 = NodeInfo {
            node_id: "node-1".to_string(),
            capabilities: vec![NodeCapability {
                name: "old".to_string(),
                description: "Old cap".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx: tx1,
        };
        assert!(registry.register(info1));

        let info2 = NodeInfo {
            node_id: "node-1".to_string(),
            capabilities: vec![NodeCapability {
                name: "new".to_string(),
                description: "New cap".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx: tx2,
        };
        assert!(registry.register(info2));
        // The second registration replaces the first rather than adding an entry.
        assert_eq!(registry.len(), 1);

        let caps = registry.all_capabilities();
        assert_eq!(caps.len(), 1);
        assert_eq!(caps[0].2.name, "new");
    }

    #[test]
    fn node_registry_all_capabilities() {
        let registry = NodeRegistry::new(10);
        let (tx1, _rx1) = mpsc::channel(1);
        let (tx2, _rx2) = mpsc::channel(1);

        registry.register(NodeInfo {
            node_id: "phone-1".to_string(),
            capabilities: vec![
                NodeCapability {
                    name: "camera.snap".to_string(),
                    description: "Take a photo".to_string(),
                    parameters: serde_json::json!({"type": "object", "properties": {}}),
                },
                NodeCapability {
                    name: "gps.location".to_string(),
                    description: "Get GPS location".to_string(),
                    parameters: serde_json::json!({"type": "object", "properties": {}}),
                },
            ],
            invoke_tx: tx1,
        });

        registry.register(NodeInfo {
            node_id: "sensor-1".to_string(),
            capabilities: vec![NodeCapability {
                name: "temp.read".to_string(),
                description: "Read temperature".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx: tx2,
        });

        let caps = registry.all_capabilities();
        assert_eq!(caps.len(), 3);
    }

    #[test]
    fn node_registry_node_ids_and_invoke_tx() {
        let registry = NodeRegistry::new(10);
        for id in ["b", "a"] {
            let (tx, _rx) = mpsc::channel(1);
            registry.register(NodeInfo {
                node_id: id.to_string(),
                capabilities: vec![],
                invoke_tx: tx,
            });
        }

        let mut ids = registry.node_ids();
        ids.sort();
        assert_eq!(ids, vec!["a".to_string(), "b".to_string()]);
        assert!(registry.invoke_tx("a").is_some());
        assert!(registry.invoke_tx("missing").is_none());
    }

    #[test]
    fn node_registry_is_empty() {
        let registry = NodeRegistry::new(10);
        assert!(registry.is_empty());

        let (tx, _rx) = mpsc::channel(1);
        registry.register(NodeInfo {
            node_id: "n".to_string(),
            capabilities: vec![],
            invoke_tx: tx,
        });
        assert!(!registry.is_empty());
    }

    #[test]
    fn node_capability_deserialize() {
        let json = r#"{"name":"camera.snap","description":"Take a photo"}"#;
        let cap: NodeCapability = serde_json::from_str(json).unwrap();
        assert_eq!(cap.name, "camera.snap");
        assert_eq!(cap.description, "Take a photo");
        // `parameters` was omitted, so the serde default schema applies.
        assert_eq!(cap.parameters["type"], "object");
    }

    #[test]
    fn node_message_register_deserialize() {
        let json = r#"{"type":"register","node_id":"phone-1","capabilities":[{"name":"camera.snap","description":"Take a photo","parameters":{"type":"object","properties":{"resolution":{"type":"string"}}}}]}"#;
        let msg: NodeMessage = serde_json::from_str(json).unwrap();
        match msg {
            NodeMessage::Register {
                node_id,
                capabilities,
            } => {
                assert_eq!(node_id, "phone-1");
                assert_eq!(capabilities.len(), 1);
                assert_eq!(capabilities[0].name, "camera.snap");
            }
            NodeMessage::Result { .. } => panic!("Expected Register message"),
        }
    }

    #[test]
    fn node_message_result_deserialize() {
        let json = r#"{"type":"result","call_id":"abc-123","success":true,"output":"photo taken"}"#;
        let msg: NodeMessage = serde_json::from_str(json).unwrap();
        match msg {
            NodeMessage::Result {
                call_id,
                success,
                output,
                error,
            } => {
                assert_eq!(call_id, "abc-123");
                assert!(success);
                assert_eq!(output, "photo taken");
                assert!(error.is_none());
            }
            NodeMessage::Register { .. } => panic!("Expected Result message"),
        }
    }

    #[test]
    fn gateway_message_serialize() {
        let msg = GatewayMessage::Registered {
            node_id: "phone-1".to_string(),
            capabilities_count: 3,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"registered\""));
        assert!(json.contains("\"node_id\":\"phone-1\""));
        assert!(json.contains("\"capabilities_count\":3"));
    }

    #[test]
    fn gateway_invoke_message_serialize() {
        let msg = GatewayMessage::Invoke {
            call_id: "call-1".to_string(),
            capability: "camera.snap".to_string(),
            args: serde_json::json!({"resolution": "1080p"}),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"type\":\"invoke\""));
        assert!(json.contains("\"capability\":\"camera.snap\""));
    }

    #[test]
    fn extract_node_ws_token_from_header() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer node_tok_123".parse().unwrap());
        assert_eq!(extract_node_ws_token(&headers, None), Some("node_tok_123"));
    }

    #[test]
    fn extract_node_ws_token_from_subprotocol() {
        let headers = subprotocol_headers("zeroclaw.nodes.v1, bearer.sub_tok");
        assert_eq!(extract_node_ws_token(&headers, None), Some("sub_tok"));
    }

    #[test]
    fn extract_node_ws_token_from_query() {
        let headers = HeaderMap::new();
        assert_eq!(
            extract_node_ws_token(&headers, Some("node_tok_456")),
            Some("node_tok_456")
        );
    }

    #[test]
    fn extract_node_ws_token_prefers_authorization_header() {
        let mut headers = bearer_headers("hdr_tok");
        headers.insert("sec-websocket-protocol", "bearer.sub_tok".parse().unwrap());
        assert_eq!(
            extract_node_ws_token(&headers, Some("query_tok")),
            Some("hdr_tok")
        );
    }

    #[test]
    fn extract_node_ws_token_skips_empty_sources() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer ".parse().unwrap());
        headers.insert("sec-websocket-protocol", "bearer.".parse().unwrap());
        assert_eq!(
            extract_node_ws_token(&headers, Some("query_tok")),
            Some("query_tok")
        );
        assert_eq!(extract_node_ws_token(&headers, Some("")), None);
    }

    #[test]
    fn extract_node_ws_token_none_when_empty() {
        let headers = HeaderMap::new();
        assert_eq!(extract_node_ws_token(&headers, None), None);
    }
}