Skip to main content

zeroclaw_gateway/
acp.rs

1//! ACP-over-WebSocket gateway endpoint.
2
3use super::AppState;
4use axum::{
5    extract::{
6        Query, State, WebSocketUpgrade,
7        ws::{Message, WebSocket},
8    },
9    http::HeaderMap,
10    response::IntoResponse,
11};
12use futures_util::{SinkExt, StreamExt};
13use serde::Deserialize;
14use std::sync::Arc;
15use tokio::sync::mpsc;
16use zeroclaw_channels::orchestrator::acp_server::{AcpServer, AcpServerConfig};
17use zeroclaw_infra::acp_session_store::AcpSessionStore;
18
19const ACP_WS_PROTOCOL: &str = "zeroclaw.acp.v1";
20
21#[derive(Debug, Deserialize)]
22pub struct AcpQuery {
23    token: Option<String>,
24}
25
26pub async fn handle_ws_acp(
27    State(state): State<AppState>,
28    Query(params): Query<AcpQuery>,
29    headers: HeaderMap,
30    ws: WebSocketUpgrade,
31) -> impl IntoResponse {
32    if state.pairing.require_pairing() {
33        let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
34        if !state.pairing.is_authenticated(token) {
35            return (
36                axum::http::StatusCode::UNAUTHORIZED,
37                "Unauthorized - provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
38            )
39                .into_response();
40        }
41    }
42
43    let ws = if headers
44        .get("sec-websocket-protocol")
45        .and_then(|v| v.to_str().ok())
46        .is_some_and(|protos| protos.split(',').any(|p| p.trim() == ACP_WS_PROTOCOL))
47    {
48        ws.protocols([ACP_WS_PROTOCOL])
49    } else {
50        ws
51    };
52
53    ws.on_upgrade(move |socket| handle_socket(socket, state))
54        .into_response()
55}
56
57async fn handle_socket(socket: WebSocket, state: AppState) {
58    let (mut sender, mut receiver) = socket.split();
59    let (input_tx, input_rx) = mpsc::channel::<String>(256);
60    let (output_tx, mut output_rx) = mpsc::channel::<String>(256);
61
62    let config = state.config.read().clone();
63    let acp_config = AcpServerConfig {
64        max_sessions: config.acp.max_sessions,
65        session_timeout_secs: config.acp.session_timeout_secs,
66    };
67    let store = AcpSessionStore::new(&config.data_dir)
68        .map(Arc::new)
69        .inspect_err(|e| {
70            ::zeroclaw_log::record!(
71                WARN,
72                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
73                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
74                    .with_attrs(::serde_json::json!({"error": e.to_string()})),
75                "Failed to open ACP session store"
76            );
77        })
78        .ok();
79    let canvas_store = state.canvas_store.clone();
80    let server = if let Some(store) = store {
81        Arc::new(
82            AcpServer::new_with_writer_and_store(config, acp_config, output_tx, store)
83                .with_canvas_store(canvas_store),
84        )
85    } else {
86        Arc::new(
87            AcpServer::new_with_writer(config, acp_config, output_tx)
88                .with_canvas_store(canvas_store),
89        )
90    };
91
92    let server_task = tokio::spawn(Arc::clone(&server).run_messages(input_rx));
93
94    let output_task = tokio::spawn(async move {
95        while let Some(line) = output_rx.recv().await {
96            if sender.send(Message::Text(line.into())).await.is_err() {
97                break;
98            }
99        }
100    });
101
102    while let Some(message) = receiver.next().await {
103        match message {
104            Ok(Message::Text(text)) => {
105                if input_tx.send(text.to_string()).await.is_err() {
106                    break;
107                }
108            }
109            Ok(Message::Binary(bytes)) => match String::from_utf8(bytes.to_vec()) {
110                Ok(text) => {
111                    if input_tx.send(text).await.is_err() {
112                        break;
113                    }
114                }
115                Err(e) => ::zeroclaw_log::record!(
116                    WARN,
117                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
118                        .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
119                        .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
120                    "ACP WebSocket received non-UTF-8 binary frame"
121                ),
122            },
123            Ok(Message::Close(_)) => break,
124            Ok(Message::Ping(_) | Message::Pong(_)) => {}
125            Err(e) => {
126                let msg = e.to_string();
127                if msg.contains("Connection reset without closing handshake")
128                    || msg.contains("Connection closed normally")
129                {
130                    ::zeroclaw_log::record!(
131                        DEBUG,
132                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
133                        "ACP WebSocket closed without handshake"
134                    );
135                } else {
136                    ::zeroclaw_log::record!(
137                        WARN,
138                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
139                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
140                            .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
141                        "ACP WebSocket receive error"
142                    );
143                }
144                break;
145            }
146        }
147    }
148
149    drop(input_tx);
150
151    if let Err(e) = server_task.await {
152        ::zeroclaw_log::record!(
153            WARN,
154            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
155                .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
156                .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
157            "ACP WebSocket server task panicked"
158        );
159    }
160    output_task.abort();
161    ::zeroclaw_log::record!(
162        DEBUG,
163        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
164        "ACP WebSocket disconnected"
165    );
166}
167
168fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
169    headers
170        .get(axum::http::header::AUTHORIZATION)
171        .and_then(|v| v.to_str().ok())
172        .and_then(|auth| auth.strip_prefix("Bearer "))
173        .map(str::trim)
174        .filter(|token| !token.is_empty())
175        .or_else(|| {
176            headers
177                .get(axum::http::header::SEC_WEBSOCKET_PROTOCOL)
178                .and_then(|v| v.to_str().ok())
179                .and_then(|protocols| {
180                    protocols
181                        .split(',')
182                        .map(str::trim)
183                        .find_map(|p| p.strip_prefix("bearer."))
184                })
185                .filter(|token| !token.is_empty())
186        })
187        .or_else(|| query_token.filter(|token| !token.is_empty()))
188}