Skip to main content

zeroclaw_gateway/
ws.rs

1//! WebSocket agent chat handler.
2//!
//! Connect: `ws://host:port/ws/chat?agent=ALIAS&session_id=ID&name=My+Session`
4//!
5//! Protocol:
6//! ```text
7//! Server -> Client: {"type":"session_start","session_id":"...","name":"...","resumed":true,"message_count":42}
8//! Client -> Server: {"type":"message","content":"Hello"}
9//! Server -> Client: {"type":"chunk","content":"Hi! "}
10//! Server -> Client: {"type":"tool_call","name":"shell","args":{...}}
11//! Server -> Client: {"type":"tool_result","name":"shell","output":"..."}
12//! Server -> Client: {"type":"done","full_response":"..."}
13//! ```
14//!
15//! ## Tool approvals
16//!
17//! When supervised-mode tool calls hit the `ApprovalManager`, the server
18//! emits an `approval_request` and pauses the tool loop until the client
19//! responds. Mirrors the Telegram inline-keyboard / CLI Y/N/A pattern,
20//! over the WS frame transport.
21//!
22//! ```text
23//! Server -> Client: {
24//!     "type": "approval_request",
25//!     "request_id": "<uuid>",
26//!     "tool": "shell",
27//!     "arguments_summary": "command: git status",
28//!     "timeout_secs": 120
29//! }
30//! Client -> Server: {
31//!     "type": "approval_response",
32//!     "request_id": "<uuid>",
33//!     "decision": "approve" | "deny" | "always"
34//! }
35//! ```
36//!
37//! `approve` runs the tool once, `always` adds the tool to the session
38//! allowlist for the rest of the conversation, `deny` returns a structured
39//! error to the model. When no client is connected, or the client
40//! disconnects mid-prompt, the tool call is auto-denied after `timeout_secs`.
41//!
42//! ### `arguments_summary` security boundary
43//!
44//! `arguments_summary` is a human-readable string the runtime synthesises
45//! for the operator (e.g. `"command: git status"`, `"path: /etc/hosts"`).
46//! It is render-only; the operator's approve/deny choice attaches to the
47//! `request_id`, never to the summary string. The runtime must not echo
48//! any `#[secret]` or `#[derived_from_secret]` field (auth tokens, API
49//! keys, OAuth secrets) into the summary. The agent's tool loop runs
50//! tool args through `zeroclaw_runtime::approval::summarize_args` before
51//! the request reaches this transport; do not stringify raw args here.
52//!
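//! Query params:
//! - `agent` — required alias of a configured agent (also accepted as
//!   `agentAlias` / `agent_alias`); the upgrade is rejected with 400 when it
//!   is missing or does not match a configured agent
//! - `session_id` — resume or create a session (default: new UUID)
//! - `name` — optional human-readable label for the session
//! - `token` — bearer auth token (alternative to Authorization header)
//! - `cwd` — optional working directory for the session; `workspace_dir` /
//!   `workspaceDir` is used when `cwd` is absent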
57
58use super::AppState;
59use crate::ws_approval::{PendingApprovals, WsApprovalChannel, new_pending_approvals};
60use axum::{
61    extract::{
62        Query, State, WebSocketUpgrade,
63        ws::{Message, WebSocket},
64    },
65    http::{HeaderMap, header},
66    response::IntoResponse,
67};
68use futures_util::{SinkExt, StreamExt};
69use serde::Deserialize;
70use std::path::{Path, PathBuf};
71use std::sync::Arc;
72use std::time::Duration;
73use zeroclaw_api::channel::ChannelApprovalResponse;
74
/// Default wall-clock budget, in seconds, for the operator to answer an
/// `approval_request` frame before the channel auto-denies. Mirrors the
/// channel-side default on `TelegramConfig::approval_timeout_secs`.
///
/// Converted to a `Duration` and handed to the per-connection
/// `WsApprovalChannel` when a socket is set up.
const WS_APPROVAL_TIMEOUT_SECS: u64 = 120;
79
/// Optional connection parameters sent as the first WebSocket message.
///
/// If the first message after upgrade is `{"type":"connect",...}`, these
/// parameters are extracted and an acknowledgement is sent back. Old clients
/// that send `{"type":"message",...}` as the first frame still work — the
/// message is processed normally (backward-compatible).
///
/// Only `type` is mandatory; a frame without it does not deserialize into
/// this struct and is therefore treated as a regular (non-connect) frame.
#[derive(Debug, Deserialize)]
struct ConnectParams {
    /// Frame discriminator; the handshake path is taken only when this is
    /// exactly `"connect"`.
    #[serde(rename = "type")]
    msg_type: String,
    /// Client-chosen session ID for memory persistence
    #[serde(default)]
    session_id: Option<String>,
    /// Device name for device registry tracking
    #[serde(default)]
    device_name: Option<String>,
    /// Client capabilities
    #[serde(default)]
    capabilities: Vec<String>,
    /// Project root / working directory for this session.
    /// Also accepted under the keys `workspaceDir` and `workspace_dir`.
    #[serde(default, alias = "workspaceDir", alias = "workspace_dir")]
    cwd: Option<String>,
}
103
/// The sub-protocol we support for the chat WebSocket.
///
/// Echoed back in the upgrade response only when the client lists it in
/// `Sec-WebSocket-Protocol`.
const WS_PROTOCOL: &str = "zeroclaw.v1";

/// Prefix used in `Sec-WebSocket-Protocol` to carry a bearer token.
///
/// A client offers the entry `bearer.<token>`; everything after the prefix
/// is treated as the token.
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
109
110#[derive(Deserialize)]
111pub struct WsQuery {
112    pub token: Option<String>,
113    pub session_id: Option<String>,
114    /// Optional human-readable name for the session.
115    pub name: Option<String>,
116    /// Configured agent alias to run as. Required — every WebSocket
117    /// session is bound to an explicit agent (no default agent exists).
118    #[serde(default, alias = "agentAlias", alias = "agent")]
119    pub agent_alias: Option<String>,
120    /// Project root / working directory for this session.
121    #[serde(default)]
122    pub cwd: Option<String>,
123    #[serde(default, alias = "workspaceDir", alias = "workspace_dir")]
124    pub workspace_dir: Option<String>,
125}
126
127/// Extract a bearer token from WebSocket-compatible sources.
128///
129/// Precedence (first non-empty wins):
130/// 1. `Authorization: Bearer <token>` header
131/// 2. `Sec-WebSocket-Protocol: bearer.<token>` subprotocol
132/// 3. `?token=<token>` query parameter
133///
134/// Browsers cannot set custom headers on `new WebSocket(url)`, so the query
135/// parameter and subprotocol paths are required for browser-based clients.
136fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
137    // 1. Authorization header
138    if let Some(t) = headers
139        .get(header::AUTHORIZATION)
140        .and_then(|v| v.to_str().ok())
141        .and_then(|auth| auth.strip_prefix("Bearer "))
142        && !t.is_empty()
143    {
144        return Some(t);
145    }
146
147    // 2. Sec-WebSocket-Protocol: bearer.<token>
148    if let Some(t) = headers
149        .get("sec-websocket-protocol")
150        .and_then(|v| v.to_str().ok())
151        .and_then(|protos| {
152            protos
153                .split(',')
154                .map(|p| p.trim())
155                .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
156        })
157        && !t.is_empty()
158    {
159        return Some(t);
160    }
161
162    // 3. ?token= query parameter
163    if let Some(t) = query_token
164        && !t.is_empty()
165    {
166        return Some(t);
167    }
168
169    None
170}
171
172/// GET /ws/chat — WebSocket upgrade for agent chat
173pub async fn handle_ws_chat(
174    State(state): State<AppState>,
175    Query(params): Query<WsQuery>,
176    headers: HeaderMap,
177    ws: WebSocketUpgrade,
178) -> impl IntoResponse {
179    // Auth: check header, subprotocol, then query param (precedence order)
180    if state.pairing.require_pairing() {
181        let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
182        if !state.pairing.is_authenticated(token) {
183            return (
184                axum::http::StatusCode::UNAUTHORIZED,
185                "Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
186            )
187                .into_response();
188        }
189    }
190
191    // Echo Sec-WebSocket-Protocol if the client requests our sub-protocol.
192    let ws = if headers
193        .get("sec-websocket-protocol")
194        .and_then(|v| v.to_str().ok())
195        .is_some_and(|protos| protos.split(',').any(|p| p.trim() == WS_PROTOCOL))
196    {
197        ws.protocols([WS_PROTOCOL])
198    } else {
199        ws
200    };
201
202    // Reject the upgrade up-front when the client didn't pick an agent.
203    // No default — every WS session is bound to an explicit agent.
204    let Some(agent_alias) = params.agent_alias.filter(|s| !s.trim().is_empty()) else {
205        return (
206            axum::http::StatusCode::BAD_REQUEST,
207            "Missing required `agent` query parameter — pass `?agent=<alias>` matching a configured [agents.<alias>] entry.",
208        )
209            .into_response();
210    };
211    {
212        let cfg = state.config.read();
213        if cfg.agent(&agent_alias).is_none() {
214            return (
215                axum::http::StatusCode::BAD_REQUEST,
216                format!(
217                    "Unknown agent `{agent_alias}` — no [agents.{agent_alias}] entry configured."
218                ),
219            )
220                .into_response();
221        }
222    }
223
224    let session_id = params.session_id;
225    let session_name = params.name;
226    let session_cwd = params.cwd.or(params.workspace_dir);
227    ws.on_upgrade(move |socket| {
228        handle_socket(
229            socket,
230            state,
231            agent_alias,
232            session_id,
233            session_name,
234            session_cwd,
235        )
236    })
237    .into_response()
238}
239
/// Gateway session key prefix to avoid collisions with channel sessions.
///
/// Prepended to the client-visible session id to form the key used with the
/// session backend and the session queue.
const GW_SESSION_PREFIX: &str = "gw_";
242
243async fn handle_socket(
244    socket: WebSocket,
245    state: AppState,
246    agent_alias: String,
247    session_id: Option<String>,
248    session_name: Option<String>,
249    session_cwd: Option<String>,
250) {
251    let (mut sender, mut receiver) = socket.split();
252
253    // Resolve session ID: use provided or generate a new UUID
254    let session_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
255    let session_key = format!("{GW_SESSION_PREFIX}{session_id}");
256    // Match the sanitized form persisted by memory backend migrations.
257    let mut memory_session_id = zeroclaw_api::session_keys::sanitize_session_key(&session_id);
258
259    // Hydrate session metadata from persistence (if available). Agent
260    // construction is deferred until after the optional `connect` frame so the
261    // client can provide a per-session cwd for the security sandbox root.
262    let config = state.config.read().clone();
263    let mut resumed = false;
264    let mut message_count: usize = 0;
265    let mut effective_name: Option<String> = None;
266    let mut stored_messages = Vec::new();
267    if let Some(ref backend) = state.session_backend {
268        let messages = backend.load(&session_key);
269        if !messages.is_empty() {
270            message_count = messages.len();
271            stored_messages = messages;
272            resumed = true;
273        }
274        // Set session name if provided (non-empty) on connect
275        if let Some(ref name) = session_name
276            && !name.is_empty()
277        {
278            let _ = backend.set_session_name(&session_key, name);
279            effective_name = Some(name.clone());
280        }
281        // If no name was provided via query param, load the stored name
282        if effective_name.is_none() {
283            effective_name = backend.get_session_name(&session_key).unwrap_or(None);
284        }
285        // Stamp the agent alias so future /api/sessions queries and
286        // per-agent filters can attribute this session to its agent.
287        let _ = backend.set_session_agent_alias(&session_key, &agent_alias);
288    }
289
290    // Send session_start message to client
291    let mut session_start = serde_json::json!({
292        "type": "session_start",
293        "session_id": session_id,
294        "resumed": resumed,
295        "message_count": message_count,
296    });
297    if let Some(ref name) = effective_name {
298        session_start["name"] = serde_json::Value::String(name.clone());
299    }
300    let _ = sender
301        .send(Message::Text(session_start.to_string().into()))
302        .await;
303
304    // ── Optional connect handshake ──────────────────────────────────
305    // The first message may be a `{"type":"connect",...}` frame carrying
306    // connection parameters.  If it is, we extract the params, send an
307    // ack, and proceed to the normal message loop.  If the first message
308    // is a regular `{"type":"message",...}` frame, we fall through and
309    // process it immediately (backward-compatible).
310    let mut first_msg_fallback: Option<String> = None;
311    let mut requested_cwd = session_cwd;
312
313    if let Some(first) = receiver.next().await {
314        match first {
315            Ok(Message::Text(text)) => {
316                if let Ok(cp) = serde_json::from_str::<ConnectParams>(&text) {
317                    if cp.msg_type == "connect" {
318                        ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"session_id": cp.session_id, "device_name": cp.device_name, "capabilities": cp.capabilities, "cwd": cp.cwd})), "WebSocket connect params received");
319                        if let Some(sid) = &cp.session_id {
320                            memory_session_id =
321                                zeroclaw_api::session_keys::sanitize_session_key(sid);
322                            ::zeroclaw_log::record!(
323                                DEBUG,
324                                ::zeroclaw_log::Event::new(
325                                    module_path!(),
326                                    ::zeroclaw_log::Action::Note
327                                )
328                                .with_attrs(::serde_json::json!({"session_id": sid})),
329                                "WebSocket connect session override received"
330                            );
331                        }
332                        if cp.cwd.is_some() {
333                            requested_cwd = cp.cwd;
334                        }
335                        let ack = serde_json::json!({
336                            "type": "connected",
337                            "message": "Connection established"
338                        });
339                        let _ = sender.send(Message::Text(ack.to_string().into())).await;
340                    } else {
341                        // Not a connect message — fall through to normal processing
342                        first_msg_fallback = Some(text.to_string());
343                    }
344                } else {
345                    // Not parseable as ConnectParams — fall through
346                    first_msg_fallback = Some(text.to_string());
347                }
348            }
349            Ok(Message::Close(_)) | Err(_) => return,
350            _ => {}
351        }
352    }
353
354    let session_cwd = match resolve_session_cwd(requested_cwd.as_deref(), &config.data_dir) {
355        Ok(cwd) => cwd,
356        Err(e) => {
357            let err = serde_json::json!({
358                "type": "error",
359                "message": e.to_string(),
360                "code": "INVALID_CWD"
361            });
362            let _ = sender.send(Message::Text(err.to_string().into())).await;
363            return;
364        }
365    };
366
367    if let Some(err) = needs_onboarding_ws_error(&config) {
368        let _ = sender.send(Message::Text(err.to_string().into())).await;
369        return;
370    }
371
372    // Build a persistent Agent for this connection so history is maintained
373    // across turns. The session cwd becomes the security sandbox root; config
374    // workspace remains the daemon data directory. Routes through the
375    // backchannel constructor so this WS session shares its tool-approval
376    // path with the operator-driven dashboard. The agent_alias was
377    // validated up-front in handle_ws_chat against the configured agents.
378    let mut agent =
379        match zeroclaw_runtime::agent::Agent::from_config_with_session_cwd_and_mcp_backchannel(
380            &config,
381            &agent_alias,
382            Some(&session_cwd),
383            true,
384            Some(state.canvas_store.clone()),
385        )
386        .await
387        {
388            Ok(a) => a,
389            Err(e) => {
390                ::zeroclaw_log::record!(
391                    ERROR,
392                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
393                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
394                        .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
395                    "Agent initialization failed"
396                );
397                let err = serde_json::json!({
398                    "type": "error",
399                    "message": format!("Failed to initialise agent: {e}"),
400                    "code": "AGENT_INIT_FAILED"
401                });
402                let _ = sender.send(Message::Text(err.to_string().into())).await;
403                let _ = sender
404                    .send(Message::Close(Some(axum::extract::ws::CloseFrame {
405                        code: 1011,
406                        reason: axum::extract::ws::Utf8Bytes::from_static(
407                            "Agent initialization failed",
408                        ),
409                    })))
410                    .await;
411                return;
412            }
413        };
414    agent.set_memory_session_id(Some(memory_session_id));
415    if !stored_messages.is_empty() {
416        agent.seed_history(&stored_messages);
417    }
418
419    // ── Tool-approval back-channel ─────────────────────────────────
420    // Connection-level event channel that the WsApprovalChannel shares
421    // with the per-turn forward task: it pushes ApprovalRequest frames
422    // here when the agent's tool loop pauses for consent, and the
423    // forward task drains them out the same WebSocket as the regular
424    // streaming events. The pending map is shared with the receive loop
425    // so inbound `approval_response` frames can resolve the matching
426    // oneshot waiter.
427    let (approval_event_tx, mut approval_event_rx) =
428        tokio::sync::mpsc::channel::<zeroclaw_api::agent::TurnEvent>(8);
429    let pending_approvals: PendingApprovals = new_pending_approvals();
430    let approval_channel = Arc::new(WsApprovalChannel::new(
431        approval_event_tx.clone(),
432        pending_approvals.clone(),
433        Duration::from_secs(WS_APPROVAL_TIMEOUT_SECS),
434    ));
435    agent
436        .channel_handles()
437        .register_channel("ws", approval_channel.clone());
438
439    // Seed agent's channel handles with configured channels (telegram,
440    // etc.) so the dashboard agent can deliver to external channels.
441    // The agent creates its own fresh handles in
442    // from_config_with_session_cwd_and_mcp_backchannel, so they need
443    // to be populated here — separate from the gateway boot-time seeding.
444    let ch = agent.channel_handles();
445    let channel_names = zeroclaw_channels::orchestrator::register_channels_for_tools(
446        &config,
447        &ch.ask_user,
448        &Some(ch.reaction.clone()),
449        &ch.poll,
450        &ch.escalate,
451        &ch.channel_send,
452    );
453    if !channel_names.is_empty() {
454        ::zeroclaw_log::record!(
455            INFO,
456            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(
457                ::serde_json::json!({"channels": channel_names, "session": session_key})
458            ),
459            "Seeded {} channel(s) into dashboard agent session",
460        );
461    }
462
463    // Process the first message if it was not a connect frame
464    if let Some(ref text) = first_msg_fallback {
465        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
466            if parsed["type"].as_str() == Some("message") {
467                let content = parsed["content"].as_str().unwrap_or("").to_string();
468                if !content.is_empty() {
469                    let _session_guard = match state.session_queue.acquire(&session_key).await {
470                        Ok(guard) => guard,
471                        Err(e) => {
472                            let err = serde_json::json!({
473                                "type": "error",
474                                "message": e.to_string(),
475                                "code": session_queue_ws_error_code(&e)
476                            });
477                            let _ = sender.send(Message::Text(err.to_string().into())).await;
478                            return;
479                        }
480                    };
481                    process_chat_message(
482                        &state,
483                        &mut agent,
484                        &mut sender,
485                        &mut receiver,
486                        &mut approval_event_rx,
487                        &pending_approvals,
488                        &content,
489                        &session_key,
490                    )
491                    .await;
492                }
493            } else {
494                let unknown_type = parsed["type"].as_str().unwrap_or("unknown");
495                let err = serde_json::json!({
496                    "type": "error",
497                    "message": format!(
498                        "Unsupported message type \"{unknown_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
499                    )
500                });
501                let _ = sender.send(Message::Text(err.to_string().into())).await;
502            }
503        } else {
504            let err = serde_json::json!({
505                "type": "error",
506                "message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}"
507            });
508            let _ = sender.send(Message::Text(err.to_string().into())).await;
509        }
510    }
511
512    // Subscribe to the shared broadcast channel so cron/heartbeat events
513    // are forwarded to this WebSocket client.
514    let mut broadcast_rx = state.event_tx.subscribe();
515
516    loop {
517        tokio::select! {
518            // ── Client message ────────────────────────────────────────
519            client_msg = receiver.next() => {
520                let Some(msg) = client_msg else { break };
521                let msg = match msg {
522                    Ok(Message::Text(text)) => text,
523                    Ok(Message::Close(_)) | Err(_) => break,
524                    _ => continue,
525                };
526
527                // Parse incoming message
528                let parsed: serde_json::Value = match serde_json::from_str(&msg) {
529                    Ok(v) => v,
530                    Err(e) => {
531                        let err = serde_json::json!({
532                            "type": "error",
533                            "message": format!("Invalid JSON: {}", e),
534                            "code": "INVALID_JSON"
535                        });
536                        let _ = sender.send(Message::Text(err.to_string().into())).await;
537                        continue;
538                    }
539                };
540
541                let msg_type = parsed["type"].as_str().unwrap_or("");
542
543                // ── Voice duplex event dispatch (gated by feature flag + runtime config) ──
544                #[cfg(feature = "gateway-voice-duplex")]
545                {
546                    // Multi-instance shape: presence in the map = enabled.
547                    let duplex_enabled = !state.config.read().channels.voice_duplex.is_empty();
548                    if duplex_enabled {
549                        if let Some(voice_event) = crate::voice_duplex::try_parse_voice_event(&msg) {
550                            if let Some(error_frame) = crate::voice_duplex::handle_voice_event(voice_event) {
551                                let _ = sender.send(Message::Text(error_frame.to_string().into())).await;
552                            }
553                            continue;
554                        }
555                    }
556                }
557
558                // ── approval_response (operator answered a tool prompt) ──
559                if msg_type == "approval_response" {
560                    let request_id = parsed["request_id"].as_str().unwrap_or("");
561                    let decision_str = parsed["decision"].as_str().unwrap_or("");
562                    let decision = match decision_str {
563                        "approve" => Some(ChannelApprovalResponse::Approve),
564                        "always" => Some(ChannelApprovalResponse::AlwaysApprove),
565                        "deny" => Some(ChannelApprovalResponse::Deny),
566                        _ => None,
567                    };
568                    if request_id.is_empty() || decision.is_none() {
569                        let err = serde_json::json!({
570                            "type": "error",
571                            "message": "approval_response requires request_id and decision in {approve,deny,always}",
572                            "code": "INVALID_APPROVAL_RESPONSE"
573                        });
574                        let _ = sender.send(Message::Text(err.to_string().into())).await;
575                        continue;
576                    }
577                    if let Some(tx) = pending_approvals.lock().remove(request_id) {
578                        let _ = tx.send(decision.expect("checked above"));
579                    } else {
580                        ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"request_id": request_id})), "approval_response with no matching pending request");
581                    }
582                    continue;
583                }
584
585                if msg_type != "message" {
586                    let err = serde_json::json!({
587                        "type": "error",
588                        "message": format!(
589                            "Unsupported message type \"{msg_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
590                        ),
591                        "code": "UNKNOWN_MESSAGE_TYPE"
592                    });
593                    let _ = sender.send(Message::Text(err.to_string().into())).await;
594                    continue;
595                }
596
597                let content = parsed["content"].as_str().unwrap_or("").to_string();
598                if content.is_empty() {
599                    let err = serde_json::json!({
600                        "type": "error",
601                        "message": "Message content cannot be empty",
602                        "code": "EMPTY_CONTENT"
603                    });
604                    let _ = sender.send(Message::Text(err.to_string().into())).await;
605                    continue;
606                }
607
608                // Acquire session lock to serialize concurrent turns
609                let _session_guard = match state.session_queue.acquire(&session_key).await {
610                    Ok(guard) => guard,
611                    Err(e) => {
612                        let err = serde_json::json!({
613                            "type": "error",
614                            "message": e.to_string(),
615                            "code": session_queue_ws_error_code(&e)
616                        });
617                        let _ = sender.send(Message::Text(err.to_string().into())).await;
618                        continue;
619                    }
620                };
621
622                process_chat_message(
623                    &state,
624                    &mut agent,
625                    &mut sender,
626                    &mut receiver,
627                    &mut approval_event_rx,
628                    &pending_approvals,
629                    &content,
630                    &session_key,
631                )
632                .await;
633            }
634
635            // ── Broadcast event (cron/heartbeat results) ──────────────
636            event = broadcast_rx.recv() => {
637                if let Ok(event) = event
638                    && event_matches_session(&event, &session_id)
639                {
640                    let _ = sender.send(Message::Text(event.to_string().into())).await;
641                }
642            }
643
644            // ── Approval request from the agent's tool loop ────────────
645            // The WsApprovalChannel emits these whenever a supervised tool
646            // call needs operator consent. Forwarded out the same socket
647            // as the regular streaming events; the matching response
648            // arrives via the `approval_response` arm above and resolves
649            // the channel's pending oneshot.
650            approval_event = approval_event_rx.recv() => {
651                let Some(event) = approval_event else { break };
652                let frame = match event {
653                    zeroclaw_api::agent::TurnEvent::ApprovalRequest {
654                        request_id,
655                        tool_name,
656                        arguments_summary,
657                        timeout_secs,
658                    } => serde_json::json!({
659                        "type": "approval_request",
660                        "request_id": request_id,
661                        "tool": tool_name,
662                        "arguments_summary": arguments_summary,
663                        "timeout_secs": timeout_secs,
664                    }),
665                    other => {
666                        ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"kind": format!("{:?}", other)})), "non-ApprovalRequest event leaked into approval channel");
667                        continue;
668                    }
669                };
670                let _ = sender.send(Message::Text(frame.to_string().into())).await;
671            }
672        }
673    }
674}
675
676fn resolve_session_cwd(
677    requested_cwd: Option<&str>,
678    default_workspace: &Path,
679) -> anyhow::Result<PathBuf> {
680    let cwd = requested_cwd
681        .map(PathBuf::from)
682        .unwrap_or_else(|| default_workspace.to_path_buf());
683    std::fs::canonicalize(&cwd).map_err(|e| {
684        ::zeroclaw_log::record!(
685            WARN,
686            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
687                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
688                .with_attrs(::serde_json::json!({
689                    "cwd": cwd.display().to_string(),
690                    "error": format!("{}", e),
691                })),
692            "ws session cwd rejected"
693        );
694        anyhow::Error::msg(format!(
695            "cwd is not a usable directory ({}): {e}",
696            cwd.display()
697        ))
698    })
699}
700
701fn session_queue_ws_error_code(error: &crate::session_queue::SessionQueueError) -> &'static str {
702    match error {
703        crate::session_queue::SessionQueueError::QueueFull { .. } => "SESSION_QUEUE_FULL",
704        crate::session_queue::SessionQueueError::Timeout { .. } => "SESSION_QUEUE_TIMEOUT",
705    }
706}
707
708fn persist_conversation_messages(
709    backend: &dyn zeroclaw_infra::session_backend::SessionBackend,
710    session_key: &str,
711    messages: &[zeroclaw_providers::ConversationMessage],
712) {
713    for message in messages {
714        let zeroclaw_providers::ConversationMessage::Chat(message) = message else {
715            continue;
716        };
717        if message.role == "system" {
718            continue;
719        }
720        let _ = backend.append(session_key, message);
721    }
722}
723
724fn has_assistant_chat_message(messages: &[zeroclaw_providers::ConversationMessage]) -> bool {
725    messages.iter().any(|message| {
726        matches!(
727            message,
728            zeroclaw_providers::ConversationMessage::Chat(message)
729                if message.role == "assistant"
730        )
731    })
732}
733
734fn needs_onboarding_ws_error(
735    config: &zeroclaw_config::schema::Config,
736) -> Option<serde_json::Value> {
737    let model = config.resolve_default_model().unwrap_or_default();
738    crate::needs_onboarding_for(&model)?;
739    Some(serde_json::json!({
740        "type": "error",
741        "error": "needs_onboarding",
742        "code": "NEEDS_ONBOARDING",
743        "message": crate::needs_onboarding_channel_reply(),
744        "url": "/onboard",
745    }))
746}
747
748fn event_matches_session(event: &serde_json::Value, session_id: &str) -> bool {
749    match event.get("session_id").and_then(|value| value.as_str()) {
750        Some(event_session_id) => event_session_id == session_id,
751        None => true,
752    }
753}
754
755/// Process a single chat message through the agent and send the response.
756///
757/// Uses [`Agent::turn_streamed`] so that intermediate text chunks, tool calls,
758/// and tool results are forwarded to the WebSocket client in real time.
759async fn process_chat_message(
760    state: &AppState,
761    agent: &mut zeroclaw_runtime::agent::Agent,
762    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
763    receiver: &mut futures_util::stream::SplitStream<WebSocket>,
764    approval_event_rx: &mut tokio::sync::mpsc::Receiver<zeroclaw_api::agent::TurnEvent>,
765    pending_approvals: &PendingApprovals,
766    content: &str,
767    session_key: &str,
768) {
769    use futures_util::StreamExt as _;
770    use zeroclaw_runtime::agent::TurnEvent;
771
772    let provider_label = state
773        .config
774        .read()
775        .first_model_provider_type()
776        .unwrap_or("unknown")
777        .to_string();
778
779    // Broadcast agent_start event
780    let _ = state.event_tx.send(serde_json::json!({
781        "type": "agent_start",
782        "model_provider": provider_label,
783        "model": state.model,
784    }));
785
786    // Set session state to running
787    let turn_id = uuid::Uuid::new_v4().to_string();
788    if let Some(ref backend) = state.session_backend {
789        let _ = backend.set_session_state(session_key, "running", Some(&turn_id));
790    }
791
792    // ── Cancellation token lifecycle ─────────────────────────────
793    // Create a token before the turn starts so the abort endpoint
794    // can cancel it. Remove it after the turn completes regardless
795    // of outcome (normal, error, or cancelled).
796    let cancel_token = tokio_util::sync::CancellationToken::new();
797    {
798        state
799            .cancel_tokens
800            .lock()
801            .expect("cancel_tokens lock poisoned")
802            .insert(session_key.to_string(), cancel_token.clone());
803    }
804
805    // Channel for streaming turn events from the agent.
806    let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<TurnEvent>(64);
807    let (steering_tx, mut steering_rx) = tokio::sync::mpsc::channel::<String>(32);
808
809    // Run the streamed turn concurrently: the agent produces events
810    // while we forward them to the WebSocket below.  We cannot move
811    // `agent` into a spawned task (it is `&mut`), so we use a join
812    // instead — `turn_streamed` writes to the channel and we drain it
813    // from the other branch.
814    let content_owned = content.to_string();
815    let session_key_owned = session_key.to_string();
816    let turn_fut = async {
817        zeroclaw_runtime::agent::loop_::scope_session_key(
818            Some(session_key_owned),
819            agent.turn_streamed_with_steering_state(
820                &content_owned,
821                event_tx,
822                Some(cancel_token.clone()),
823                Some(&mut steering_rx),
824            ),
825        )
826        .await
827    };
828
829    // Drive both futures concurrently: the agent turn produces events
830    // and we relay them over WebSocket. Track streamed chunks so we
831    // can reconstruct partial content on cancellation.
832    //
833    let mut accumulated_text = String::new();
834
835    // Aggregate token usage across all LLM calls in this turn.
836    // The agent emits TurnEvent::Usage once per LLM call when the provider
837    // surfaces usage; we sum to produce a single done-frame total.
838    let mut total_input_tokens: Option<u64> = None;
839    let mut total_output_tokens: Option<u64> = None;
840
841    // Routes the three concurrent streams that the running turn cares about:
842    //   1. inbound `approval_response` frames from the WebSocket client,
843    //   2. `TurnEvent::ApprovalRequest` events from `WsApprovalChannel`,
844    //   3. ordinary `TurnEvent`s from the agent loop.
845    // Without the multiplexed select, the loop draining only `event_rx`
846    // would block the approval back-channel for the whole turn, so a pending
847    // tool approval could neither be sent to the client nor answered before
848    // the timeout fired.
849    let forward_fut = async {
850        let mut cancel_drained = false;
851        loop {
852            tokio::select! {
853                biased;
854                // ── Cancellation arm ─────────────────────────────
855                // When `/abort` cancels the token, immediately drop every
856                // parked oneshot sender so any in-flight `request_approval`
857                // unblocks via the "sender dropped → deny" path in
858                // `WsApprovalChannel`. Without this, the approval future
859                // races only its own `timeout_secs` (default 120s) and
860                // ignores the cancel token, so the abort sits idle for up
861                // to two minutes before the tool loop even gets a chance
862                // to observe the cancellation.
863                _ = cancel_token.cancelled(), if !cancel_drained => {
864                    let drained: Vec<_> = pending_approvals.lock().drain().collect();
865                    drop(drained);
866                    cancel_drained = true;
867                    // Fall through; the agent loop will now wake from the
868                    // approval await, see the cancel token, and propagate
869                    // a ToolLoopCancelled error which closes event_rx and
870                    // breaks this loop on the `event_rx.recv()` arm below.
871                }
872                client_msg = receiver.next() => {
873                    // On client disconnect, `receiver.next()` returns `None`
874                    // (stream end) or `Err(_)` repeatedly. A bare `continue`
875                    // hot-loops the select; cancel the turn so `turn_fut`
876                    // resolves with `ToolLoopCancelled` and `tokio::join!`
877                    // below can return. See #6514.
878                    let text = match client_msg {
879                        Some(Ok(Message::Text(text))) => text,
880                        Some(Ok(Message::Close(_))) | Some(Err(_)) | None => {
881                            cancel_token.cancel();
882                            break;
883                        }
884                        _ => continue,
885                    };
886                    let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) else {
887                        let err = serde_json::json!({
888                            "type": "error",
889                            "message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}",
890                            "code": "INVALID_JSON"
891                        });
892                        let _ = sender.send(Message::Text(err.to_string().into())).await;
893                        continue;
894                    };
895                    match parsed["type"].as_str() {
896                        Some("approval_response") => {
897                            let request_id = parsed["request_id"].as_str().unwrap_or("");
898                            let decision = match parsed["decision"].as_str().unwrap_or("") {
899                                "approve" => Some(ChannelApprovalResponse::Approve),
900                                "always" => Some(ChannelApprovalResponse::AlwaysApprove),
901                                "deny" => Some(ChannelApprovalResponse::Deny),
902                                _ => None,
903                            };
904                            if request_id.is_empty() || decision.is_none() {
905                                continue;
906                            }
907                            if let Some(tx) = pending_approvals.lock().remove(request_id) {
908                                let _ = tx.send(decision.expect("checked above"));
909                            } else {
910                                ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"request_id": request_id})), "approval_response with no matching pending request (mid-turn)");
911                            }
912                        }
913                        Some("message") => {
914                            let content = parsed["content"].as_str().unwrap_or("").to_string();
915                            if content.is_empty() {
916                                let err = serde_json::json!({
917                                    "type": "error",
918                                    "message": "Message content cannot be empty",
919                                    "code": "EMPTY_CONTENT"
920                                });
921                                let _ = sender.send(Message::Text(err.to_string().into())).await;
922                                continue;
923                            }
924                            match steering_tx.try_send(content) {
925                                Ok(()) => {}
926                                Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
927                                    let err = serde_json::json!({
928                                        "type": "error",
929                                        "message": "Steering queue is full for the running turn",
930                                        "code": "STEERING_QUEUE_FULL"
931                                    });
932                                    let _ = sender.send(Message::Text(err.to_string().into())).await;
933                                }
934                                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
935                                    let err = serde_json::json!({
936                                        "type": "error",
937                                        "message": "Running turn is no longer accepting steering messages",
938                                        "code": "STEERING_CLOSED"
939                                    });
940                                    let _ = sender.send(Message::Text(err.to_string().into())).await;
941                                }
942                            }
943                        }
944                        _ => {}
945                    }
946                }
947                approval = approval_event_rx.recv() => {
948                    let Some(event) = approval else { continue };
949                    if let TurnEvent::ApprovalRequest {
950                        request_id,
951                        tool_name,
952                        arguments_summary,
953                        timeout_secs,
954                    } = event {
955                        let frame = serde_json::json!({
956                            "type": "approval_request",
957                            "request_id": request_id,
958                            "tool": tool_name,
959                            "arguments_summary": arguments_summary,
960                            "timeout_secs": timeout_secs,
961                        });
962                        let _ = sender.send(Message::Text(frame.to_string().into())).await;
963                    }
964                }
965                event_opt = event_rx.recv() => {
966                    let Some(event) = event_opt else { break };
967                    let ws_msg = match event {
968                        TurnEvent::Usage {
969                            input_tokens,
970                            output_tokens,
971                            cost_usd: _,
972                        } => {
973                            if let Some(it) = input_tokens {
974                                total_input_tokens = Some(total_input_tokens.unwrap_or(0) + it);
975                            }
976                            if let Some(ot) = output_tokens {
977                                total_output_tokens = Some(total_output_tokens.unwrap_or(0) + ot);
978                            }
979                            continue;
980                        }
981                        TurnEvent::Chunk { ref delta } => {
982                            accumulated_text.push_str(delta);
983                            serde_json::json!({ "type": "chunk", "content": delta })
984                        }
985                        TurnEvent::Thinking { delta } => {
986                            serde_json::json!({ "type": "thinking", "content": delta })
987                        }
988                        TurnEvent::ToolCall { id, name, args } => {
989                            serde_json::json!({ "type": "tool_call", "id": id, "name": name, "args": args })
990                        }
991                        TurnEvent::ToolResult { id, name, output } => {
992                            serde_json::json!({ "type": "tool_result", "id": id, "name": name, "output": output })
993                        }
994                        TurnEvent::ApprovalRequest {
995                            request_id,
996                            tool_name,
997                            arguments_summary,
998                            timeout_secs,
999                        } => serde_json::json!({
1000                            "type": "approval_request",
1001                            "request_id": request_id,
1002                            "tool": tool_name,
1003                            "arguments_summary": arguments_summary,
1004                            "timeout_secs": timeout_secs,
1005                        }),
1006                    };
1007                    let _ = sender.send(Message::Text(ws_msg.to_string().into())).await;
1008                }
1009            }
1010        }
1011    };
1012
1013    let (result, ()) = tokio::join!(turn_fut, forward_fut);
1014
1015    // ── Remove cancel token (turn finished) ──────────────────────
1016    {
1017        state
1018            .cancel_tokens
1019            .lock()
1020            .expect("cancel_tokens lock poisoned")
1021            .remove(session_key);
1022    }
1023
1024    // Check if this turn was cancelled. `turn_streamed` propagates
1025    // `ToolLoopCancelled` through anyhow, so we detect it here.
1026    let was_cancelled = match &result {
1027        Err(e) => zeroclaw_runtime::agent::loop_::is_tool_loop_cancelled(&e.error),
1028        Ok(_) => false,
1029    };
1030
1031    if was_cancelled {
1032        if let Some(ref backend) = state.session_backend {
1033            match &result {
1034                Err(error) if !error.new_messages.is_empty() => {
1035                    persist_conversation_messages(
1036                        backend.as_ref(),
1037                        session_key,
1038                        &error.new_messages,
1039                    );
1040                    if !has_assistant_chat_message(&error.new_messages) {
1041                        let truncated = if accumulated_text.is_empty() {
1042                            "[interrupted by user]".to_string()
1043                        } else {
1044                            format!("{accumulated_text}\n\n[interrupted by user]")
1045                        };
1046                        let assistant_msg = zeroclaw_providers::ChatMessage::assistant(&truncated);
1047                        let _ = backend.append(session_key, &assistant_msg);
1048                    }
1049                }
1050                _ => {
1051                    let truncated = if accumulated_text.is_empty() {
1052                        "[interrupted by user]".to_string()
1053                    } else {
1054                        format!("{accumulated_text}\n\n[interrupted by user]")
1055                    };
1056                    let assistant_msg = zeroclaw_providers::ChatMessage::assistant(&truncated);
1057                    let _ = backend.append(session_key, &assistant_msg);
1058                }
1059            }
1060        }
1061
1062        // Inform the client the turn was aborted
1063        let aborted = serde_json::json!({ "type": "aborted" });
1064        let _ = sender.send(Message::Text(aborted.to_string().into())).await;
1065
1066        // Set session state to idle
1067        if let Some(ref backend) = state.session_backend {
1068            let _ = backend.set_session_state(session_key, "idle", None);
1069        }
1070
1071        // Broadcast agent_end event
1072        let _ = state.event_tx.send(serde_json::json!({
1073            "type": "agent_end",
1074            "model_provider": provider_label,
1075            "model": state.model,
1076        }));
1077
1078        // Trace the cancelled turn so the doctor / replay tool sees it
1079        // alongside successful turns. #6001 follow-through.
1080        ::zeroclaw_log::record!(
1081            INFO,
1082            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Cancel)
1083                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1084                .with_attrs(::serde_json::json!({
1085                    "model_provider": provider_label,
1086                    "model": state.model,
1087                    "session_key": session_key,
1088                    "reason": "interrupted by user",
1089                    "cancelled": true,
1090                    "trace_id": turn_id,
1091                })),
1092            "gateway_ws_turn"
1093        );
1094
1095        return;
1096    }
1097
1098    match result {
1099        Ok(outcome) => {
1100            if let Some(ref backend) = state.session_backend {
1101                persist_conversation_messages(backend.as_ref(), session_key, &outcome.new_messages);
1102            }
1103
1104            // Fire-and-forget memory consolidation so facts from WS sessions
1105            // are extracted to long-term memory (Daily + Core categories).
1106            if state.auto_save {
1107                let mem = state.mem.clone();
1108                let model_provider = state.model_provider.clone();
1109                let model = state.model.clone();
1110                let temperature = state.temperature;
1111                let user_msg = content.to_string();
1112                let assistant_resp = outcome.response.clone();
1113                tokio::spawn(async move {
1114                    if let Err(e) = zeroclaw_memory::consolidation::consolidate_turn(
1115                        model_provider.as_ref(),
1116                        &model,
1117                        temperature,
1118                        mem.as_ref(),
1119                        &user_msg,
1120                        &assistant_resp,
1121                    )
1122                    .await
1123                    {
1124                        ::zeroclaw_log::record!(
1125                            DEBUG,
1126                            ::zeroclaw_log::Event::new(
1127                                module_path!(),
1128                                ::zeroclaw_log::Action::Note
1129                            )
1130                            .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
1131                            "WS memory consolidation skipped"
1132                        );
1133                    }
1134                });
1135            }
1136
1137            // Compute cost from accumulated tokens + configured pricing,
1138            // then write the cost record so /api/cost and costs.jsonl reflect
1139            // this turn. Done before the done frame so cost_usd can ride along.
1140            let total_tokens = match (total_input_tokens, total_output_tokens) {
1141                (Some(i), Some(o)) => Some(i.saturating_add(o)),
1142                (Some(i), None) => Some(i),
1143                (None, Some(o)) => Some(o),
1144                (None, None) => None,
1145            };
1146            let cost_usd = record_turn_cost(
1147                state,
1148                &provider_label,
1149                &state.model,
1150                total_input_tokens,
1151                total_output_tokens,
1152                None,
1153            );
1154
1155            let done = serde_json::json!({
1156                "type": "done",
1157                "full_response": outcome.response,
1158                "input_tokens": total_input_tokens,
1159                "output_tokens": total_output_tokens,
1160                "tokens_used": total_tokens,
1161                "cost_usd": cost_usd,
1162                "model": state.model,
1163                "provider": provider_label,
1164            });
1165            let _ = sender.send(Message::Text(done.to_string().into())).await;
1166
1167            // Set session state to idle
1168            if let Some(ref backend) = state.session_backend {
1169                let _ = backend.set_session_state(session_key, "idle", None);
1170            }
1171
1172            // Broadcast agent_end event
1173            let _ = state.event_tx.send(serde_json::json!({
1174                "type": "agent_end",
1175                "model_provider": provider_label,
1176                "model": state.model,
1177            }));
1178
1179            // Append a runtime-trace.jsonl record so a `zeroclaw doctor`
1180            // sweep sees gateway WS turns alongside channel and CLI turns.
1181            // Closes the gateway-side trace gap from #6001.
1182            ::zeroclaw_log::record!(
1183                INFO,
1184                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Complete)
1185                    .with_outcome(::zeroclaw_log::EventOutcome::Success)
1186                    .with_attrs(::serde_json::json!({
1187                        "model_provider": provider_label,
1188                        "model": state.model,
1189                        "session_key": session_key,
1190                        "input_tokens": total_input_tokens,
1191                        "output_tokens": total_output_tokens,
1192                        "tokens_used": total_tokens,
1193                        "cost_usd": cost_usd,
1194                        "trace_id": turn_id,
1195                    })),
1196                "gateway_ws_turn"
1197            );
1198        }
1199        Err(e) => {
1200            if let Some(ref backend) = state.session_backend
1201                && !e.new_messages.is_empty()
1202            {
1203                persist_conversation_messages(backend.as_ref(), session_key, &e.new_messages);
1204            }
1205
1206            // Set session state to error
1207            if let Some(ref backend) = state.session_backend {
1208                let _ = backend.set_session_state(session_key, "error", Some(&turn_id));
1209            }
1210
1211            ::zeroclaw_log::record!(
1212                ERROR,
1213                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1214                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1215                    .with_attrs(::serde_json::json!({"error": format!("{}", e.error)})),
1216                "Agent turn failed"
1217            );
1218            let sanitized = zeroclaw_providers::sanitize_api_error(&e.error.to_string());
1219            let error_code = if sanitized.to_lowercase().contains("api key")
1220                || sanitized.to_lowercase().contains("authentication")
1221                || sanitized.to_lowercase().contains("unauthorized")
1222            {
1223                "AUTH_ERROR"
1224            } else if sanitized.to_lowercase().contains("model_provider")
1225                || sanitized.to_lowercase().contains("model")
1226            {
1227                "PROVIDER_ERROR"
1228            } else {
1229                "AGENT_ERROR"
1230            };
1231            let err = serde_json::json!({
1232                "type": "error",
1233                "message": sanitized,
1234                "code": error_code,
1235            });
1236            let _ = sender.send(Message::Text(err.to_string().into())).await;
1237
1238            // Broadcast error event
1239            let _ = state.event_tx.send(serde_json::json!({
1240                "type": "error",
1241                "component": "ws_chat",
1242                "message": sanitized,
1243            }));
1244
1245            // Trace the failed turn so the doctor / replay tool sees the
1246            // failure mode and the turn_id can be cross-referenced with
1247            // costs.jsonl. #6001 follow-through.
1248            ::zeroclaw_log::record!(
1249                WARN,
1250                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1251                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1252                    .with_attrs(::serde_json::json!({
1253                        "model_provider": provider_label,
1254                        "model": state.model,
1255                        "session_key": session_key,
1256                        "error": sanitized,
1257                        "error_code": error_code,
1258                        "trace_id": turn_id,
1259                    })),
1260                "gateway_ws_turn"
1261            );
1262        }
1263    }
1264}
1265
1266/// Record token usage for the just-completed turn against the gateway's
1267/// cost tracker, returning the computed cost in USD (or `None` when no
1268/// tracker is configured or no usage was reported).
1269fn record_turn_cost(
1270    state: &AppState,
1271    provider_name: &str,
1272    model: &str,
1273    input_tokens: Option<u64>,
1274    output_tokens: Option<u64>,
1275    cached_input_tokens: Option<u64>,
1276) -> Option<f64> {
1277    let tracker = state.cost_tracker.as_ref()?;
1278    if input_tokens.is_none() && output_tokens.is_none() {
1279        return None;
1280    }
1281    let input = input_tokens.unwrap_or(0);
1282    let output = output_tokens.unwrap_or(0);
1283    let cached_input = cached_input_tokens.unwrap_or(0);
1284    if input == 0 && output == 0 {
1285        return None;
1286    }
1287    // V3 per-provider pricing lookup. Mirrors how the channels
1288    // orchestrator and the gateway lib.rs cost-tracking scope build
1289    // their `ModelProviderPricing`: walk every
1290    // `[model_providers.<type>.<alias>]` and key the per-profile
1291    // pricing map by `<type>.<alias>`. The streaming and non-streaming
1292    // paths derive identical costs because both bottom out in the same
1293    // `<type>.<alias>` key shape.
1294    let config = state.config.read();
1295    let pricing_map = config
1296        .providers
1297        .models
1298        .iter_entries()
1299        .filter(|(_, _, base)| !base.pricing.is_empty())
1300        .map(|(type_k, alias_k, base)| (format!("{type_k}.{alias_k}"), base.pricing.clone()))
1301        .collect::<std::collections::HashMap<String, std::collections::HashMap<String, f64>>>();
1302    drop(config);
1303    let model_pricing = pricing_map.get(provider_name);
1304    let try_lookup = |key: &str| -> (f64, f64, f64) {
1305        let Some(map) = model_pricing else {
1306            return (0.0, 0.0, 0.0);
1307        };
1308        let in_rate = map
1309            .get(&format!("{key}.input"))
1310            .copied()
1311            .or_else(|| map.get(key).copied())
1312            .unwrap_or(0.0);
1313        let out_rate = map
1314            .get(&format!("{key}.output"))
1315            .copied()
1316            .or_else(|| map.get(key).copied())
1317            .unwrap_or(0.0);
1318        let cached_rate = map
1319            .get(&format!("{key}.cached_input"))
1320            .copied()
1321            .unwrap_or(0.0);
1322        (in_rate, out_rate, cached_rate)
1323    };
1324    let (input_rate, output_rate, cached_rate) = match try_lookup(model) {
1325        (0.0, 0.0, 0.0) => model
1326            .rsplit_once('/')
1327            .map(|(_, suffix)| try_lookup(suffix))
1328            .unwrap_or((0.0, 0.0, 0.0)),
1329        rates => rates,
1330    };
1331    let usage = zeroclaw_runtime::cost::types::TokenUsage::new(
1332        model,
1333        input,
1334        output,
1335        cached_input,
1336        input_rate,
1337        output_rate,
1338        cached_rate,
1339    );
1340    let cost_usd = usage.cost_usd;
1341    if let Err(error) = tracker.record_usage(usage) {
1342        ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"provider": provider_name, "model": model, "error": format!("{}", error)})), "Failed to record gateway turn cost");
1343    }
1344    Some(cost_usd)
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    /// Builds a `HeaderMap` from `(name, value)` pairs.
    ///
    /// Panics when a value cannot be parsed into a header value, which in a
    /// test indicates a broken fixture rather than a runtime condition.
    fn header_map(entries: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in entries {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    /// A `Bearer` value in the `authorization` header is returned as the token.
    #[test]
    fn extract_ws_token_from_authorization_header() {
        let headers = header_map(&[("authorization", "Bearer zc_test123")]);

        let token = extract_ws_token(&headers, None);

        assert_eq!(token, Some("zc_test123"));
    }

    /// A `bearer.<token>` entry in `sec-websocket-protocol` is returned as the token.
    #[test]
    fn extract_ws_token_from_subprotocol() {
        let headers = header_map(&[("sec-websocket-protocol", "zeroclaw.v1, bearer.zc_sub456")]);

        let token = extract_ws_token(&headers, None);

        assert_eq!(token, Some("zc_sub456"));
    }

    /// With no headers set, the query-parameter token is returned.
    #[test]
    fn extract_ws_token_from_query_param() {
        let headers = header_map(&[]);

        let token = extract_ws_token(&headers, Some("zc_query789"));

        assert_eq!(token, Some("zc_query789"));
    }

    /// The `authorization` header wins over both the subprotocol and the query token.
    #[test]
    fn extract_ws_token_precedence_header_over_subprotocol() {
        let headers = header_map(&[
            ("authorization", "Bearer zc_header"),
            ("sec-websocket-protocol", "bearer.zc_sub"),
        ]);

        let token = extract_ws_token(&headers, Some("zc_query"));

        assert_eq!(token, Some("zc_header"));
    }

    /// The subprotocol token wins over the query token.
    #[test]
    fn extract_ws_token_precedence_subprotocol_over_query() {
        let headers = header_map(&[("sec-websocket-protocol", "bearer.zc_sub")]);

        let token = extract_ws_token(&headers, Some("zc_query"));

        assert_eq!(token, Some("zc_sub"));
    }

    /// No headers and no query token yields `None`.
    #[test]
    fn extract_ws_token_returns_none_when_empty() {
        let headers = header_map(&[]);

        let token = extract_ws_token(&headers, None);

        assert_eq!(token, None);
    }

    /// A `Bearer ` header with nothing after the scheme falls through to the query token.
    #[test]
    fn extract_ws_token_skips_empty_header_value() {
        let headers = header_map(&[("authorization", "Bearer ")]);

        let token = extract_ws_token(&headers, Some("zc_fallback"));

        assert_eq!(token, Some("zc_fallback"));
    }

    /// An empty query token is treated as absent.
    #[test]
    fn extract_ws_token_skips_empty_query_param() {
        let headers = header_map(&[]);

        let token = extract_ws_token(&headers, Some(""));

        assert_eq!(token, None);
    }

    /// The bearer entry is found even when surrounded by other subprotocols.
    #[test]
    fn extract_ws_token_subprotocol_with_multiple_entries() {
        let headers = header_map(&[("sec-websocket-protocol", "zeroclaw.v1, bearer.zc_tok, other")]);

        let token = extract_ws_token(&headers, None);

        assert_eq!(token, Some("zc_tok"));
    }

    /// Events carrying a `session_id` match only that session; events without
    /// one match any session.
    #[test]
    fn session_scoped_events_only_match_their_session() {
        let session = "operator-1";

        let same_session = serde_json::json!({
            "type": "message",
            "session_id": "operator-1",
            "content": "deploy finished"
        });
        assert!(event_matches_session(&same_session, session));

        let foreign_session = serde_json::json!({
            "type": "message",
            "session_id": "operator-2",
            "content": "different session"
        });
        assert!(!event_matches_session(&foreign_session, session));

        let unscoped = serde_json::json!({
            "type": "cron_result",
            "content": "global notification"
        });
        assert!(event_matches_session(&unscoped, session));
    }

    /// An explicitly requested, existing directory is resolved to its canonical path.
    #[test]
    fn resolve_session_cwd_uses_requested_cwd() {
        let requested_dir = tempfile::tempdir().unwrap();
        let default_dir = tempfile::tempdir().unwrap();
        let requested_path = requested_dir.path().to_str().unwrap();

        let actual = resolve_session_cwd(Some(requested_path), default_dir.path()).unwrap();

        let expected = requested_dir.path().canonicalize().unwrap();
        assert_eq!(actual, expected);
    }

    /// Without a request, the fallback workspace is resolved to its canonical path.
    #[test]
    fn resolve_session_cwd_uses_default_workspace_without_request() {
        let default_dir = tempfile::tempdir().unwrap();

        let actual = resolve_session_cwd(None, default_dir.path()).unwrap();

        let expected = default_dir.path().canonicalize().unwrap();
        assert_eq!(actual, expected);
    }

    /// A requested path that does not exist is rejected with a descriptive error.
    #[test]
    fn resolve_session_cwd_rejects_missing_directory() {
        let workspace = tempfile::tempdir().unwrap();
        let absent = workspace.path().join("missing");

        let outcome = resolve_session_cwd(Some(absent.to_str().unwrap()), workspace.path());

        let rendered = outcome
            .expect_err("missing cwd should be rejected")
            .to_string();
        assert!(rendered.contains("cwd is not a usable directory"));
    }

    /// A default config produces an error frame that points the client at `/onboard`.
    #[test]
    fn needs_onboarding_ws_error_points_to_onboard() {
        let config = zeroclaw_config::schema::Config::default();
        let frame = needs_onboarding_ws_error(&config)
            .expect("empty model must produce a WS onboarding error");

        // Fixed fields of the error frame.
        let expected_fields = [
            ("type", "error"),
            ("error", "needs_onboarding"),
            ("code", "NEEDS_ONBOARDING"),
            ("url", "/onboard"),
        ];
        for (field, expected) in expected_fields {
            assert_eq!(frame[field], expected);
        }

        let message = frame["message"]
            .as_str()
            .expect("onboarding WS error must include a message");
        // A message wrapped in braces would be an unresolved Fluent key.
        let looks_like_raw_key = message.starts_with('{') || message.ends_with('}');
        assert!(
            !looks_like_raw_key,
            "missing Fluent key fallback leaked into WS error message: {message:?}"
        );
        assert!(
            message.to_lowercase().contains("onboarding"),
            "WS onboarding message must explain the setup gap: {message:?}"
        );
    }

    /// Once an OpenAI provider with a model and API key is configured, no
    /// onboarding error is produced.
    #[test]
    fn needs_onboarding_ws_error_uses_current_configured_model() {
        let provider = zeroclaw_config::schema::OpenAIModelProviderConfig {
            base: zeroclaw_config::schema::ModelProviderConfig {
                model: Some(String::from("openai/gpt-4o-mini")),
                api_key: Some(String::from("sk-test")),
                ..Default::default()
            },
        };
        let mut config = zeroclaw_config::schema::Config::default();
        config
            .providers
            .models
            .openai
            .insert(String::from("default"), provider);

        let frame = needs_onboarding_ws_error(&config);

        assert!(
            frame.is_none(),
            "current configured model must allow WebSocket agent construction to continue"
        );
    }

    // Regression for #6514. The mid-turn `client_msg` arm in `forward_fut`
    // must (a) classify stream-end / close / error frames as "client gone"
    // and (b) cancel the turn token so `tokio::join!(turn_fut, forward_fut)`
    // can return — a bare `continue` hot-loops the select forever.
    //
    // `DisconnectAction` and `classify_client_msg` are test-local stand-ins
    // that spell out the expected classification.
    #[derive(Debug, PartialEq, Eq)]
    enum DisconnectAction {
        /// The client is gone: stop forwarding.
        Break,
        /// Nothing to act on: keep waiting for the next frame.
        Continue,
        /// A text frame arrived and should be handled.
        ProcessText,
    }

    /// Maps one polled client frame (or its absence) to the action the
    /// forwarding loop is expected to take.
    fn classify_client_msg(
        msg: Option<Result<axum::extract::ws::Message, &'static str>>,
    ) -> DisconnectAction {
        use axum::extract::ws::Message;

        // Stream end and transport errors both mean the client is gone.
        let frame = match msg {
            Some(Ok(frame)) => frame,
            Some(Err(_)) | None => return DisconnectAction::Break,
        };

        match frame {
            Message::Text(_) => DisconnectAction::ProcessText,
            Message::Close(_) => DisconnectAction::Break,
            _ => DisconnectAction::Continue,
        }
    }

    /// Stream end, close frames and errors break; pings continue; text is processed.
    #[test]
    fn mid_turn_client_msg_breaks_on_stream_end_close_or_err() {
        use axum::extract::ws::Message;

        let cases: [(Option<Result<Message, &'static str>>, DisconnectAction); 5] = [
            (None, DisconnectAction::Break),
            (Some(Ok(Message::Close(None))), DisconnectAction::Break),
            (Some(Err("io")), DisconnectAction::Break),
            (
                Some(Ok(Message::Ping(Default::default()))),
                DisconnectAction::Continue,
            ),
            (
                Some(Ok(Message::Text("{}".into()))),
                DisconnectAction::ProcessText,
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(classify_client_msg(input), expected);
        }
    }

    /// Cancelling a token is observed through a previously made clone.
    #[test]
    fn mid_turn_disconnect_cancel_unblocks_joined_turn() {
        let parent = tokio_util::sync::CancellationToken::new();
        let held_by_turn = parent.clone();
        assert!(!held_by_turn.is_cancelled());

        parent.cancel();

        assert!(
            held_by_turn.is_cancelled(),
            "cloned token (held by turn_fut via agent.turn_streamed) must observe cancellation"
        );
    }

    /// Each session-queue error variant maps to its own WebSocket error code.
    #[test]
    fn session_queue_errors_map_to_explicit_websocket_codes() {
        use crate::session_queue::SessionQueueError;

        let queue_full = SessionQueueError::QueueFull {
            session_id: "gw_test".into(),
            depth: 2,
        };
        assert_eq!(
            session_queue_ws_error_code(&queue_full),
            "SESSION_QUEUE_FULL"
        );

        let timed_out = SessionQueueError::Timeout {
            session_id: "gw_test".into(),
        };
        assert_eq!(
            session_queue_ws_error_code(&timed_out),
            "SESSION_QUEUE_TIMEOUT"
        );
    }
}