Skip to main content

zeroclaw_runtime/rpc/
wss.rs

1//! WebSocket Secure (WSS) transport for the RPC layer.
2//!
3//! Mirrors the Unix socket transport (`unix.rs`) but uses TLS-encrypted
4//! WebSocket connections, enabling remote TUI-to-daemon connectivity.
5
6use super::context::RpcContext;
7use super::dispatch::RpcDispatcher;
8use super::transport::RpcTransport;
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use futures_util::{SinkExt, StreamExt};
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::time::Duration;
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::mpsc;
18use tokio_rustls::TlsAcceptor;
19use tokio_tungstenite::WebSocketStream;
20use tokio_tungstenite::tungstenite::Message;
21use tokio_util::sync::CancellationToken;
22
23type TlsStream = tokio_rustls::server::TlsStream<TcpStream>;
24
/// Quiet period on the read side: if no frame of any kind arrives within
/// this window, a WebSocket Ping is sent to probe the peer.
const HEARTBEAT_IDLE: Duration = Duration::from_millis(20_000);

/// Grace period once a Ping is outstanding: if nothing at all arrives in
/// this window (a Pong or any other frame), the peer is considered dead —
/// e.g. a half-open TCP connection — and the connection is torn down.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(10_000);
31
32// ── Transport ────────────────────────────────────────────────────
33
/// Out-of-band requests from the read half to the writer task. These become
/// WebSocket control frames instead of JSON-RPC text messages.
enum Control {
    /// Ask the writer task to emit a liveness Ping frame.
    Ping,
}
39
40pub struct WssTransport {
41    reader: futures_util::stream::SplitStream<WebSocketStream<TlsStream>>,
42    writer_tx: mpsc::Sender<String>,
43    control_tx: mpsc::Sender<Control>,
44    peer_label: String,
45    /// Set once a Ping has been sent and we are awaiting any reply. Detects a
46    /// peer that went silent on a half-open TCP connection (no FIN/RST).
47    awaiting_pong: bool,
48}
49
50impl WssTransport {
51    pub fn new(ws: WebSocketStream<TlsStream>, remote_addr: SocketAddr) -> Self {
52        let peer_label = format!("wss:{remote_addr}");
53        let (sink, stream) = ws.split();
54
55        let (writer_tx, mut writer_rx) = mpsc::channel::<String>(64);
56        let (control_tx, mut control_rx) = mpsc::channel::<Control>(8);
57        zeroclaw_spawn::spawn!(async move {
58            let mut sink = sink;
59            loop {
60                let msg = tokio::select! {
61                    line = writer_rx.recv() => match line {
62                        Some(line) => Message::Text(line.into()),
63                        None => break,
64                    },
65                    ctrl = control_rx.recv() => match ctrl {
66                        Some(Control::Ping) => Message::Ping(Vec::new().into()),
67                        None => break,
68                    },
69                };
70                if sink.send(msg).await.is_err() {
71                    break;
72                }
73            }
74        });
75
76        Self {
77            reader: stream,
78            writer_tx,
79            control_tx,
80            peer_label,
81            awaiting_pong: false,
82        }
83    }
84}
85
86#[async_trait]
87impl RpcTransport for WssTransport {
88    fn writer(&self) -> mpsc::Sender<String> {
89        self.writer_tx.clone()
90    }
91
92    async fn next_frame(&mut self) -> Option<String> {
93        loop {
94            let idle = if self.awaiting_pong {
95                HEARTBEAT_TIMEOUT
96            } else {
97                HEARTBEAT_IDLE
98            };
99
100            match tokio::time::timeout(idle, self.reader.next()).await {
101                Err(_) if self.awaiting_pong => return None,
102                Err(_) => {
103                    if self.control_tx.send(Control::Ping).await.is_err() {
104                        return None;
105                    }
106                    self.awaiting_pong = true;
107                }
108                Ok(frame) => {
109                    self.awaiting_pong = false;
110                    match frame {
111                        Some(Ok(Message::Text(text))) => return Some(text.to_string()),
112                        Some(Ok(Message::Close(_))) | None => return None,
113                        Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {
114                            continue;
115                        }
116                        Some(Ok(Message::Binary(_))) => continue,
117                        Some(Err(_)) => return None,
118                    }
119                }
120            }
121        }
122    }
123
124    fn peer_label(&self) -> String {
125        self.peer_label.clone()
126    }
127}
128
129// ── TLS acceptor ─────────────────────────────────────────────────
130
131/// Build a `TlsAcceptor` from PEM-encoded cert and key files.
132pub fn build_tls_acceptor(cert_path: &str, key_path: &str) -> Result<TlsAcceptor> {
133    use rustls::ServerConfig;
134    use rustls_pemfile::{certs, private_key};
135    use std::fs::File;
136    use std::io::BufReader;
137
138    let cert_file =
139        File::open(cert_path).with_context(|| format!("opening TLS cert: {cert_path}"))?;
140    let key_file = File::open(key_path).with_context(|| format!("opening TLS key: {key_path}"))?;
141
142    let certs: Vec<_> = certs(&mut BufReader::new(cert_file))
143        .collect::<Result<Vec<_>, _>>()
144        .context("parsing TLS certificates")?;
145
146    let key = private_key(&mut BufReader::new(key_file))
147        .context("parsing TLS private key")?
148        .context("no private key found in key file")?;
149
150    let config = ServerConfig::builder()
151        .with_no_client_auth()
152        .with_single_cert(certs, key)
153        .context("building TLS server config")?;
154
155    Ok(TlsAcceptor::from(Arc::new(config)))
156}
157
158// ── Listener ─────────────────────────────────────────────────────
159
160/// Run the WSS RPC listener as a daemon subsystem.
161///
162/// `client_count` is incremented on connect, decremented on disconnect —
163/// shared with the Unix socket listener for `--ephemeral` shutdown logic.
164pub async fn run_wss_listener(
165    ctx: Arc<RpcContext>,
166    cancel: CancellationToken,
167    client_count: Arc<AtomicUsize>,
168    tls_acceptor: TlsAcceptor,
169    bind_addr: SocketAddr,
170) -> Result<()> {
171    let listener = TcpListener::bind(bind_addr)
172        .await
173        .with_context(|| format!("binding WSS listener on {bind_addr}"))?;
174
175    ::zeroclaw_log::record!(
176        INFO,
177        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
178            .with_attrs(::serde_json::json!({"addr": bind_addr.to_string()})),
179        "RPC WSS listener started"
180    );
181
182    loop {
183        tokio::select! {
184            _ = cancel.cancelled() => {
185                ::zeroclaw_log::record!(
186                    INFO,
187                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
188                    "RPC WSS listener shutting down"
189                );
190                break;
191            }
192            accept = listener.accept() => {
193                let (tcp_stream, remote_addr) = match accept {
194                    Ok(v) => v,
195                    Err(e) => {
196                        ::zeroclaw_log::record!(
197                            WARN,
198                            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
199                                .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
200                            &format!("WSS accept error: {e}")
201                        );
202                        continue;
203                    }
204                };
205
206                let ctx = ctx.clone();
207                let count = client_count.clone();
208                let acceptor = tls_acceptor.clone();
209
210                count.fetch_add(1, Ordering::Relaxed);
211
212                zeroclaw_spawn::spawn!(async move {
213                    // TLS handshake.
214                    let tls_stream = match acceptor.accept(tcp_stream).await {
215                        Ok(s) => s,
216                        Err(e) => {
217                            ::zeroclaw_log::record!(
218                                WARN,
219                                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
220                                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
221                                &format!("WSS TLS handshake failed from {remote_addr}: {e}")
222                            );
223                            count.fetch_sub(1, Ordering::Relaxed);
224                            return;
225                        }
226                    };
227
228                    // WebSocket upgrade.
229                    let ws_stream = match tokio_tungstenite::accept_async(tls_stream).await {
230                        Ok(ws) => ws,
231                        Err(e) => {
232                            ::zeroclaw_log::record!(
233                                WARN,
234                                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
235                                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
236                                &format!("WSS WebSocket upgrade failed from {remote_addr}: {e}")
237                            );
238                            count.fetch_sub(1, Ordering::Relaxed);
239                            return;
240                        }
241                    };
242
243                    let mut transport = WssTransport::new(ws_stream, remote_addr);
244                    let peer = transport.peer_label();
245                    let writer_tx = transport.writer();
246                    let mut dispatcher = RpcDispatcher::new(ctx.clone(), writer_tx, peer);
247                    dispatcher.run(&mut transport).await;
248
249                    if let Some(tui_id) = dispatcher.tui_id() {
250                        ctx.tui_registry.unregister(tui_id);
251                        use ::zeroclaw_log::Instrument as _;
252                        let span = ::zeroclaw_log::info_span!(
253                            target: "zeroclaw_log_internal_scope",
254                            "zeroclaw_scope",
255                            owner_tui_id = %tui_id,
256                            channel = "wss",
257                        );
258                        async {
259                            ::zeroclaw_log::record!(
260                                INFO,
261                                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
262                                    .with_category(::zeroclaw_log::EventCategory::Agent),
263                                "WSS TUI disconnected; sessions retained (persistent)"
264                            );
265                        }
266                        .instrument(span)
267                        .await;
268                    }
269
270                    count.fetch_sub(1, Ordering::Relaxed);
271                });
272            }
273        }
274    }
275
276    Ok(())
277}