1use super::AppState;
4use axum::{
5 extract::{
6 Query, State, WebSocketUpgrade,
7 ws::{Message, WebSocket},
8 },
9 http::HeaderMap,
10 response::IntoResponse,
11};
12use futures_util::{SinkExt, StreamExt};
13use serde::Deserialize;
14use std::sync::Arc;
15use tokio::sync::mpsc;
16use zeroclaw_channels::orchestrator::acp_server::{AcpServer, AcpServerConfig};
17use zeroclaw_infra::acp_session_store::AcpSessionStore;
18
/// WebSocket subprotocol identifier used for ACP connections.
///
/// It is echoed back in the handshake response only when the client lists it
/// in its `Sec-WebSocket-Protocol` request header.
const ACP_WS_PROTOCOL: &str = "zeroclaw.acp.v1";
20
21#[derive(Debug, Deserialize)]
22pub struct AcpQuery {
23 token: Option<String>,
24}
25
26pub async fn handle_ws_acp(
27 State(state): State<AppState>,
28 Query(params): Query<AcpQuery>,
29 headers: HeaderMap,
30 ws: WebSocketUpgrade,
31) -> impl IntoResponse {
32 if state.pairing.require_pairing() {
33 let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
34 if !state.pairing.is_authenticated(token) {
35 return (
36 axum::http::StatusCode::UNAUTHORIZED,
37 "Unauthorized - provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
38 )
39 .into_response();
40 }
41 }
42
43 let ws = if headers
44 .get("sec-websocket-protocol")
45 .and_then(|v| v.to_str().ok())
46 .is_some_and(|protos| protos.split(',').any(|p| p.trim() == ACP_WS_PROTOCOL))
47 {
48 ws.protocols([ACP_WS_PROTOCOL])
49 } else {
50 ws
51 };
52
53 ws.on_upgrade(move |socket| handle_socket(socket, state))
54 .into_response()
55}
56
57async fn handle_socket(socket: WebSocket, state: AppState) {
58 let (mut sender, mut receiver) = socket.split();
59 let (input_tx, input_rx) = mpsc::channel::<String>(256);
60 let (output_tx, mut output_rx) = mpsc::channel::<String>(256);
61
62 let config = state.config.read().clone();
63 let acp_config = AcpServerConfig {
64 max_sessions: config.acp.max_sessions,
65 session_timeout_secs: config.acp.session_timeout_secs,
66 };
67 let store = AcpSessionStore::new(&config.data_dir)
68 .map(Arc::new)
69 .inspect_err(|e| {
70 ::zeroclaw_log::record!(
71 WARN,
72 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
73 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
74 .with_attrs(::serde_json::json!({"error": e.to_string()})),
75 "Failed to open ACP session store"
76 );
77 })
78 .ok();
79 let canvas_store = state.canvas_store.clone();
80 let server = if let Some(store) = store {
81 Arc::new(
82 AcpServer::new_with_writer_and_store(config, acp_config, output_tx, store)
83 .with_canvas_store(canvas_store),
84 )
85 } else {
86 Arc::new(
87 AcpServer::new_with_writer(config, acp_config, output_tx)
88 .with_canvas_store(canvas_store),
89 )
90 };
91
92 let server_task = tokio::spawn(Arc::clone(&server).run_messages(input_rx));
93
94 let output_task = tokio::spawn(async move {
95 while let Some(line) = output_rx.recv().await {
96 if sender.send(Message::Text(line.into())).await.is_err() {
97 break;
98 }
99 }
100 });
101
102 while let Some(message) = receiver.next().await {
103 match message {
104 Ok(Message::Text(text)) => {
105 if input_tx.send(text.to_string()).await.is_err() {
106 break;
107 }
108 }
109 Ok(Message::Binary(bytes)) => match String::from_utf8(bytes.to_vec()) {
110 Ok(text) => {
111 if input_tx.send(text).await.is_err() {
112 break;
113 }
114 }
115 Err(e) => ::zeroclaw_log::record!(
116 WARN,
117 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
118 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
119 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
120 "ACP WebSocket received non-UTF-8 binary frame"
121 ),
122 },
123 Ok(Message::Close(_)) => break,
124 Ok(Message::Ping(_) | Message::Pong(_)) => {}
125 Err(e) => {
126 let msg = e.to_string();
127 if msg.contains("Connection reset without closing handshake")
128 || msg.contains("Connection closed normally")
129 {
130 ::zeroclaw_log::record!(
131 DEBUG,
132 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
133 "ACP WebSocket closed without handshake"
134 );
135 } else {
136 ::zeroclaw_log::record!(
137 WARN,
138 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
139 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
140 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
141 "ACP WebSocket receive error"
142 );
143 }
144 break;
145 }
146 }
147 }
148
149 drop(input_tx);
150
151 if let Err(e) = server_task.await {
152 ::zeroclaw_log::record!(
153 WARN,
154 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
155 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
156 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
157 "ACP WebSocket server task panicked"
158 );
159 }
160 output_task.abort();
161 ::zeroclaw_log::record!(
162 DEBUG,
163 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
164 "ACP WebSocket disconnected"
165 );
166}
167
168fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
169 headers
170 .get(axum::http::header::AUTHORIZATION)
171 .and_then(|v| v.to_str().ok())
172 .and_then(|auth| auth.strip_prefix("Bearer "))
173 .map(str::trim)
174 .filter(|token| !token.is_empty())
175 .or_else(|| {
176 headers
177 .get(axum::http::header::SEC_WEBSOCKET_PROTOCOL)
178 .and_then(|v| v.to_str().ok())
179 .and_then(|protocols| {
180 protocols
181 .split(',')
182 .map(str::trim)
183 .find_map(|p| p.strip_prefix("bearer."))
184 })
185 .filter(|token| !token.is_empty())
186 })
187 .or_else(|| query_token.filter(|token| !token.is_empty()))
188}