zeroclaw_runtime/rpc/
wss.rs1use super::context::RpcContext;
7use super::dispatch::RpcDispatcher;
8use super::transport::RpcTransport;
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use futures_util::{SinkExt, StreamExt};
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::time::Duration;
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::mpsc;
18use tokio_rustls::TlsAcceptor;
19use tokio_tungstenite::WebSocketStream;
20use tokio_tungstenite::tungstenite::Message;
21use tokio_util::sync::CancellationToken;
22
23type TlsStream = tokio_rustls::server::TlsStream<TcpStream>;
24
/// Read-side silence after which a heartbeat Ping is queued to the peer.
const HEARTBEAT_IDLE: Duration = Duration::from_secs(20);

/// Grace period, after a heartbeat Ping was queued, for any inbound frame to
/// arrive before the connection is treated as dead.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(10);
31
/// Out-of-band requests from the reading side to the writer task, carried
/// on a channel separate from the outbound RPC text lines.
enum Control {
    /// Ask the writer task to send a WebSocket Ping frame (empty payload).
    Ping,
}
39
40pub struct WssTransport {
41 reader: futures_util::stream::SplitStream<WebSocketStream<TlsStream>>,
42 writer_tx: mpsc::Sender<String>,
43 control_tx: mpsc::Sender<Control>,
44 peer_label: String,
45 awaiting_pong: bool,
48}
49
50impl WssTransport {
51 pub fn new(ws: WebSocketStream<TlsStream>, remote_addr: SocketAddr) -> Self {
52 let peer_label = format!("wss:{remote_addr}");
53 let (sink, stream) = ws.split();
54
55 let (writer_tx, mut writer_rx) = mpsc::channel::<String>(64);
56 let (control_tx, mut control_rx) = mpsc::channel::<Control>(8);
57 zeroclaw_spawn::spawn!(async move {
58 let mut sink = sink;
59 loop {
60 let msg = tokio::select! {
61 line = writer_rx.recv() => match line {
62 Some(line) => Message::Text(line.into()),
63 None => break,
64 },
65 ctrl = control_rx.recv() => match ctrl {
66 Some(Control::Ping) => Message::Ping(Vec::new().into()),
67 None => break,
68 },
69 };
70 if sink.send(msg).await.is_err() {
71 break;
72 }
73 }
74 });
75
76 Self {
77 reader: stream,
78 writer_tx,
79 control_tx,
80 peer_label,
81 awaiting_pong: false,
82 }
83 }
84}
85
86#[async_trait]
87impl RpcTransport for WssTransport {
88 fn writer(&self) -> mpsc::Sender<String> {
89 self.writer_tx.clone()
90 }
91
92 async fn next_frame(&mut self) -> Option<String> {
93 loop {
94 let idle = if self.awaiting_pong {
95 HEARTBEAT_TIMEOUT
96 } else {
97 HEARTBEAT_IDLE
98 };
99
100 match tokio::time::timeout(idle, self.reader.next()).await {
101 Err(_) if self.awaiting_pong => return None,
102 Err(_) => {
103 if self.control_tx.send(Control::Ping).await.is_err() {
104 return None;
105 }
106 self.awaiting_pong = true;
107 }
108 Ok(frame) => {
109 self.awaiting_pong = false;
110 match frame {
111 Some(Ok(Message::Text(text))) => return Some(text.to_string()),
112 Some(Ok(Message::Close(_))) | None => return None,
113 Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {
114 continue;
115 }
116 Some(Ok(Message::Binary(_))) => continue,
117 Some(Err(_)) => return None,
118 }
119 }
120 }
121 }
122 }
123
124 fn peer_label(&self) -> String {
125 self.peer_label.clone()
126 }
127}
128
129pub fn build_tls_acceptor(cert_path: &str, key_path: &str) -> Result<TlsAcceptor> {
133 use rustls::ServerConfig;
134 use rustls_pemfile::{certs, private_key};
135 use std::fs::File;
136 use std::io::BufReader;
137
138 let cert_file =
139 File::open(cert_path).with_context(|| format!("opening TLS cert: {cert_path}"))?;
140 let key_file = File::open(key_path).with_context(|| format!("opening TLS key: {key_path}"))?;
141
142 let certs: Vec<_> = certs(&mut BufReader::new(cert_file))
143 .collect::<Result<Vec<_>, _>>()
144 .context("parsing TLS certificates")?;
145
146 let key = private_key(&mut BufReader::new(key_file))
147 .context("parsing TLS private key")?
148 .context("no private key found in key file")?;
149
150 let config = ServerConfig::builder()
151 .with_no_client_auth()
152 .with_single_cert(certs, key)
153 .context("building TLS server config")?;
154
155 Ok(TlsAcceptor::from(Arc::new(config)))
156}
157
158pub async fn run_wss_listener(
165 ctx: Arc<RpcContext>,
166 cancel: CancellationToken,
167 client_count: Arc<AtomicUsize>,
168 tls_acceptor: TlsAcceptor,
169 bind_addr: SocketAddr,
170) -> Result<()> {
171 let listener = TcpListener::bind(bind_addr)
172 .await
173 .with_context(|| format!("binding WSS listener on {bind_addr}"))?;
174
175 ::zeroclaw_log::record!(
176 INFO,
177 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
178 .with_attrs(::serde_json::json!({"addr": bind_addr.to_string()})),
179 "RPC WSS listener started"
180 );
181
182 loop {
183 tokio::select! {
184 _ = cancel.cancelled() => {
185 ::zeroclaw_log::record!(
186 INFO,
187 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
188 "RPC WSS listener shutting down"
189 );
190 break;
191 }
192 accept = listener.accept() => {
193 let (tcp_stream, remote_addr) = match accept {
194 Ok(v) => v,
195 Err(e) => {
196 ::zeroclaw_log::record!(
197 WARN,
198 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
199 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
200 &format!("WSS accept error: {e}")
201 );
202 continue;
203 }
204 };
205
206 let ctx = ctx.clone();
207 let count = client_count.clone();
208 let acceptor = tls_acceptor.clone();
209
210 count.fetch_add(1, Ordering::Relaxed);
211
212 zeroclaw_spawn::spawn!(async move {
213 let tls_stream = match acceptor.accept(tcp_stream).await {
215 Ok(s) => s,
216 Err(e) => {
217 ::zeroclaw_log::record!(
218 WARN,
219 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
220 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
221 &format!("WSS TLS handshake failed from {remote_addr}: {e}")
222 );
223 count.fetch_sub(1, Ordering::Relaxed);
224 return;
225 }
226 };
227
228 let ws_stream = match tokio_tungstenite::accept_async(tls_stream).await {
230 Ok(ws) => ws,
231 Err(e) => {
232 ::zeroclaw_log::record!(
233 WARN,
234 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
235 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
236 &format!("WSS WebSocket upgrade failed from {remote_addr}: {e}")
237 );
238 count.fetch_sub(1, Ordering::Relaxed);
239 return;
240 }
241 };
242
243 let mut transport = WssTransport::new(ws_stream, remote_addr);
244 let peer = transport.peer_label();
245 let writer_tx = transport.writer();
246 let mut dispatcher = RpcDispatcher::new(ctx.clone(), writer_tx, peer);
247 dispatcher.run(&mut transport).await;
248
249 if let Some(tui_id) = dispatcher.tui_id() {
250 ctx.tui_registry.unregister(tui_id);
251 use ::zeroclaw_log::Instrument as _;
252 let span = ::zeroclaw_log::info_span!(
253 target: "zeroclaw_log_internal_scope",
254 "zeroclaw_scope",
255 owner_tui_id = %tui_id,
256 channel = "wss",
257 );
258 async {
259 ::zeroclaw_log::record!(
260 INFO,
261 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
262 .with_category(::zeroclaw_log::EventCategory::Agent),
263 "WSS TUI disconnected; sessions retained (persistent)"
264 );
265 }
266 .instrument(span)
267 .await;
268 }
269
270 count.fetch_sub(1, Ordering::Relaxed);
271 });
272 }
273 }
274 }
275
276 Ok(())
277}