1use super::AppState;
59use crate::ws_approval::{PendingApprovals, WsApprovalChannel, new_pending_approvals};
60use axum::{
61 extract::{
62 Query, State, WebSocketUpgrade,
63 ws::{Message, WebSocket},
64 },
65 http::{HeaderMap, header},
66 response::IntoResponse,
67};
68use futures_util::{SinkExt, StreamExt};
69use serde::Deserialize;
70use std::path::{Path, PathBuf};
71use std::sync::Arc;
72use std::time::Duration;
73use zeroclaw_api::channel::ChannelApprovalResponse;
74
/// Timeout, in seconds, handed to `WsApprovalChannel` for every tool-approval
/// request raised on a WebSocket session (two minutes).
const WS_APPROVAL_TIMEOUT_SECS: u64 = 2 * 60;
79
80#[derive(Debug, Deserialize)]
87struct ConnectParams {
88 #[serde(rename = "type")]
89 msg_type: String,
90 #[serde(default)]
92 session_id: Option<String>,
93 #[serde(default)]
95 device_name: Option<String>,
96 #[serde(default)]
98 capabilities: Vec<String>,
99 #[serde(default, alias = "workspaceDir", alias = "workspace_dir")]
101 cwd: Option<String>,
102}
103
/// Application subprotocol; it is echoed back in `Sec-WebSocket-Protocol`
/// only when the client offered it during the handshake.
const WS_PROTOCOL: &str = "zeroclaw.v1";

/// Prefix of a `Sec-WebSocket-Protocol` entry that smuggles a bearer token,
/// i.e. `bearer.<token>`.
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
109
110#[derive(Deserialize)]
111pub struct WsQuery {
112 pub token: Option<String>,
113 pub session_id: Option<String>,
114 pub name: Option<String>,
116 #[serde(default, alias = "agentAlias", alias = "agent")]
119 pub agent_alias: Option<String>,
120 #[serde(default)]
122 pub cwd: Option<String>,
123 #[serde(default, alias = "workspaceDir", alias = "workspace_dir")]
124 pub workspace_dir: Option<String>,
125}
126
127fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
137 if let Some(t) = headers
139 .get(header::AUTHORIZATION)
140 .and_then(|v| v.to_str().ok())
141 .and_then(|auth| auth.strip_prefix("Bearer "))
142 && !t.is_empty()
143 {
144 return Some(t);
145 }
146
147 if let Some(t) = headers
149 .get("sec-websocket-protocol")
150 .and_then(|v| v.to_str().ok())
151 .and_then(|protos| {
152 protos
153 .split(',')
154 .map(|p| p.trim())
155 .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
156 })
157 && !t.is_empty()
158 {
159 return Some(t);
160 }
161
162 if let Some(t) = query_token
164 && !t.is_empty()
165 {
166 return Some(t);
167 }
168
169 None
170}
171
172pub async fn handle_ws_chat(
174 State(state): State<AppState>,
175 Query(params): Query<WsQuery>,
176 headers: HeaderMap,
177 ws: WebSocketUpgrade,
178) -> impl IntoResponse {
179 if state.pairing.require_pairing() {
181 let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
182 if !state.pairing.is_authenticated(token) {
183 return (
184 axum::http::StatusCode::UNAUTHORIZED,
185 "Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
186 )
187 .into_response();
188 }
189 }
190
191 let ws = if headers
193 .get("sec-websocket-protocol")
194 .and_then(|v| v.to_str().ok())
195 .is_some_and(|protos| protos.split(',').any(|p| p.trim() == WS_PROTOCOL))
196 {
197 ws.protocols([WS_PROTOCOL])
198 } else {
199 ws
200 };
201
202 let Some(agent_alias) = params.agent_alias.filter(|s| !s.trim().is_empty()) else {
205 return (
206 axum::http::StatusCode::BAD_REQUEST,
207 "Missing required `agent` query parameter — pass `?agent=<alias>` matching a configured [agents.<alias>] entry.",
208 )
209 .into_response();
210 };
211 {
212 let cfg = state.config.read();
213 if cfg.agent(&agent_alias).is_none() {
214 return (
215 axum::http::StatusCode::BAD_REQUEST,
216 format!(
217 "Unknown agent `{agent_alias}` — no [agents.{agent_alias}] entry configured."
218 ),
219 )
220 .into_response();
221 }
222 }
223
224 let session_id = params.session_id;
225 let session_name = params.name;
226 let session_cwd = params.cwd.or(params.workspace_dir);
227 ws.on_upgrade(move |socket| {
228 handle_socket(
229 socket,
230 state,
231 agent_alias,
232 session_id,
233 session_name,
234 session_cwd,
235 )
236 })
237 .into_response()
238}
239
/// Prefix prepended to a client session id to form the session-backend key
/// of a gateway WebSocket session (`gw_<session_id>`).
const GW_SESSION_PREFIX: &str = "gw_";
242
243async fn handle_socket(
244 socket: WebSocket,
245 state: AppState,
246 agent_alias: String,
247 session_id: Option<String>,
248 session_name: Option<String>,
249 session_cwd: Option<String>,
250) {
251 let (mut sender, mut receiver) = socket.split();
252
253 let session_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
255 let session_key = format!("{GW_SESSION_PREFIX}{session_id}");
256 let mut memory_session_id = zeroclaw_api::session_keys::sanitize_session_key(&session_id);
258
259 let config = state.config.read().clone();
263 let mut resumed = false;
264 let mut message_count: usize = 0;
265 let mut effective_name: Option<String> = None;
266 let mut stored_messages = Vec::new();
267 if let Some(ref backend) = state.session_backend {
268 let messages = backend.load(&session_key);
269 if !messages.is_empty() {
270 message_count = messages.len();
271 stored_messages = messages;
272 resumed = true;
273 }
274 if let Some(ref name) = session_name
276 && !name.is_empty()
277 {
278 let _ = backend.set_session_name(&session_key, name);
279 effective_name = Some(name.clone());
280 }
281 if effective_name.is_none() {
283 effective_name = backend.get_session_name(&session_key).unwrap_or(None);
284 }
285 let _ = backend.set_session_agent_alias(&session_key, &agent_alias);
288 }
289
290 let mut session_start = serde_json::json!({
292 "type": "session_start",
293 "session_id": session_id,
294 "resumed": resumed,
295 "message_count": message_count,
296 });
297 if let Some(ref name) = effective_name {
298 session_start["name"] = serde_json::Value::String(name.clone());
299 }
300 let _ = sender
301 .send(Message::Text(session_start.to_string().into()))
302 .await;
303
304 let mut first_msg_fallback: Option<String> = None;
311 let mut requested_cwd = session_cwd;
312
313 if let Some(first) = receiver.next().await {
314 match first {
315 Ok(Message::Text(text)) => {
316 if let Ok(cp) = serde_json::from_str::<ConnectParams>(&text) {
317 if cp.msg_type == "connect" {
318 ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"session_id": cp.session_id, "device_name": cp.device_name, "capabilities": cp.capabilities, "cwd": cp.cwd})), "WebSocket connect params received");
319 if let Some(sid) = &cp.session_id {
320 memory_session_id =
321 zeroclaw_api::session_keys::sanitize_session_key(sid);
322 ::zeroclaw_log::record!(
323 DEBUG,
324 ::zeroclaw_log::Event::new(
325 module_path!(),
326 ::zeroclaw_log::Action::Note
327 )
328 .with_attrs(::serde_json::json!({"session_id": sid})),
329 "WebSocket connect session override received"
330 );
331 }
332 if cp.cwd.is_some() {
333 requested_cwd = cp.cwd;
334 }
335 let ack = serde_json::json!({
336 "type": "connected",
337 "message": "Connection established"
338 });
339 let _ = sender.send(Message::Text(ack.to_string().into())).await;
340 } else {
341 first_msg_fallback = Some(text.to_string());
343 }
344 } else {
345 first_msg_fallback = Some(text.to_string());
347 }
348 }
349 Ok(Message::Close(_)) | Err(_) => return,
350 _ => {}
351 }
352 }
353
354 let session_cwd = match resolve_session_cwd(requested_cwd.as_deref(), &config.data_dir) {
355 Ok(cwd) => cwd,
356 Err(e) => {
357 let err = serde_json::json!({
358 "type": "error",
359 "message": e.to_string(),
360 "code": "INVALID_CWD"
361 });
362 let _ = sender.send(Message::Text(err.to_string().into())).await;
363 return;
364 }
365 };
366
367 if let Some(err) = needs_onboarding_ws_error(&config) {
368 let _ = sender.send(Message::Text(err.to_string().into())).await;
369 return;
370 }
371
372 let mut agent =
379 match zeroclaw_runtime::agent::Agent::from_config_with_session_cwd_and_mcp_backchannel(
380 &config,
381 &agent_alias,
382 Some(&session_cwd),
383 true,
384 Some(state.canvas_store.clone()),
385 )
386 .await
387 {
388 Ok(a) => a,
389 Err(e) => {
390 ::zeroclaw_log::record!(
391 ERROR,
392 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
393 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
394 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
395 "Agent initialization failed"
396 );
397 let err = serde_json::json!({
398 "type": "error",
399 "message": format!("Failed to initialise agent: {e}"),
400 "code": "AGENT_INIT_FAILED"
401 });
402 let _ = sender.send(Message::Text(err.to_string().into())).await;
403 let _ = sender
404 .send(Message::Close(Some(axum::extract::ws::CloseFrame {
405 code: 1011,
406 reason: axum::extract::ws::Utf8Bytes::from_static(
407 "Agent initialization failed",
408 ),
409 })))
410 .await;
411 return;
412 }
413 };
414 agent.set_memory_session_id(Some(memory_session_id));
415 if !stored_messages.is_empty() {
416 agent.seed_history(&stored_messages);
417 }
418
419 let (approval_event_tx, mut approval_event_rx) =
428 tokio::sync::mpsc::channel::<zeroclaw_api::agent::TurnEvent>(8);
429 let pending_approvals: PendingApprovals = new_pending_approvals();
430 let approval_channel = Arc::new(WsApprovalChannel::new(
431 approval_event_tx.clone(),
432 pending_approvals.clone(),
433 Duration::from_secs(WS_APPROVAL_TIMEOUT_SECS),
434 ));
435 agent
436 .channel_handles()
437 .register_channel("ws", approval_channel.clone());
438
439 let ch = agent.channel_handles();
445 let channel_names = zeroclaw_channels::orchestrator::register_channels_for_tools(
446 &config,
447 &ch.ask_user,
448 &Some(ch.reaction.clone()),
449 &ch.poll,
450 &ch.escalate,
451 &ch.channel_send,
452 );
453 if !channel_names.is_empty() {
454 ::zeroclaw_log::record!(
455 INFO,
456 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(
457 ::serde_json::json!({"channels": channel_names, "session": session_key})
458 ),
459 "Seeded {} channel(s) into dashboard agent session",
460 );
461 }
462
463 if let Some(ref text) = first_msg_fallback {
465 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
466 if parsed["type"].as_str() == Some("message") {
467 let content = parsed["content"].as_str().unwrap_or("").to_string();
468 if !content.is_empty() {
469 let _session_guard = match state.session_queue.acquire(&session_key).await {
470 Ok(guard) => guard,
471 Err(e) => {
472 let err = serde_json::json!({
473 "type": "error",
474 "message": e.to_string(),
475 "code": session_queue_ws_error_code(&e)
476 });
477 let _ = sender.send(Message::Text(err.to_string().into())).await;
478 return;
479 }
480 };
481 process_chat_message(
482 &state,
483 &mut agent,
484 &mut sender,
485 &mut receiver,
486 &mut approval_event_rx,
487 &pending_approvals,
488 &content,
489 &session_key,
490 )
491 .await;
492 }
493 } else {
494 let unknown_type = parsed["type"].as_str().unwrap_or("unknown");
495 let err = serde_json::json!({
496 "type": "error",
497 "message": format!(
498 "Unsupported message type \"{unknown_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
499 )
500 });
501 let _ = sender.send(Message::Text(err.to_string().into())).await;
502 }
503 } else {
504 let err = serde_json::json!({
505 "type": "error",
506 "message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}"
507 });
508 let _ = sender.send(Message::Text(err.to_string().into())).await;
509 }
510 }
511
512 let mut broadcast_rx = state.event_tx.subscribe();
515
516 loop {
517 tokio::select! {
518 client_msg = receiver.next() => {
520 let Some(msg) = client_msg else { break };
521 let msg = match msg {
522 Ok(Message::Text(text)) => text,
523 Ok(Message::Close(_)) | Err(_) => break,
524 _ => continue,
525 };
526
527 let parsed: serde_json::Value = match serde_json::from_str(&msg) {
529 Ok(v) => v,
530 Err(e) => {
531 let err = serde_json::json!({
532 "type": "error",
533 "message": format!("Invalid JSON: {}", e),
534 "code": "INVALID_JSON"
535 });
536 let _ = sender.send(Message::Text(err.to_string().into())).await;
537 continue;
538 }
539 };
540
541 let msg_type = parsed["type"].as_str().unwrap_or("");
542
543 #[cfg(feature = "gateway-voice-duplex")]
545 {
546 let duplex_enabled = !state.config.read().channels.voice_duplex.is_empty();
548 if duplex_enabled {
549 if let Some(voice_event) = crate::voice_duplex::try_parse_voice_event(&msg) {
550 if let Some(error_frame) = crate::voice_duplex::handle_voice_event(voice_event) {
551 let _ = sender.send(Message::Text(error_frame.to_string().into())).await;
552 }
553 continue;
554 }
555 }
556 }
557
558 if msg_type == "approval_response" {
560 let request_id = parsed["request_id"].as_str().unwrap_or("");
561 let decision_str = parsed["decision"].as_str().unwrap_or("");
562 let decision = match decision_str {
563 "approve" => Some(ChannelApprovalResponse::Approve),
564 "always" => Some(ChannelApprovalResponse::AlwaysApprove),
565 "deny" => Some(ChannelApprovalResponse::Deny),
566 _ => None,
567 };
568 if request_id.is_empty() || decision.is_none() {
569 let err = serde_json::json!({
570 "type": "error",
571 "message": "approval_response requires request_id and decision in {approve,deny,always}",
572 "code": "INVALID_APPROVAL_RESPONSE"
573 });
574 let _ = sender.send(Message::Text(err.to_string().into())).await;
575 continue;
576 }
577 if let Some(tx) = pending_approvals.lock().remove(request_id) {
578 let _ = tx.send(decision.expect("checked above"));
579 } else {
580 ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"request_id": request_id})), "approval_response with no matching pending request");
581 }
582 continue;
583 }
584
585 if msg_type != "message" {
586 let err = serde_json::json!({
587 "type": "error",
588 "message": format!(
589 "Unsupported message type \"{msg_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
590 ),
591 "code": "UNKNOWN_MESSAGE_TYPE"
592 });
593 let _ = sender.send(Message::Text(err.to_string().into())).await;
594 continue;
595 }
596
597 let content = parsed["content"].as_str().unwrap_or("").to_string();
598 if content.is_empty() {
599 let err = serde_json::json!({
600 "type": "error",
601 "message": "Message content cannot be empty",
602 "code": "EMPTY_CONTENT"
603 });
604 let _ = sender.send(Message::Text(err.to_string().into())).await;
605 continue;
606 }
607
608 let _session_guard = match state.session_queue.acquire(&session_key).await {
610 Ok(guard) => guard,
611 Err(e) => {
612 let err = serde_json::json!({
613 "type": "error",
614 "message": e.to_string(),
615 "code": session_queue_ws_error_code(&e)
616 });
617 let _ = sender.send(Message::Text(err.to_string().into())).await;
618 continue;
619 }
620 };
621
622 process_chat_message(
623 &state,
624 &mut agent,
625 &mut sender,
626 &mut receiver,
627 &mut approval_event_rx,
628 &pending_approvals,
629 &content,
630 &session_key,
631 )
632 .await;
633 }
634
635 event = broadcast_rx.recv() => {
637 if let Ok(event) = event
638 && event_matches_session(&event, &session_id)
639 {
640 let _ = sender.send(Message::Text(event.to_string().into())).await;
641 }
642 }
643
644 approval_event = approval_event_rx.recv() => {
651 let Some(event) = approval_event else { break };
652 let frame = match event {
653 zeroclaw_api::agent::TurnEvent::ApprovalRequest {
654 request_id,
655 tool_name,
656 arguments_summary,
657 timeout_secs,
658 } => serde_json::json!({
659 "type": "approval_request",
660 "request_id": request_id,
661 "tool": tool_name,
662 "arguments_summary": arguments_summary,
663 "timeout_secs": timeout_secs,
664 }),
665 other => {
666 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"kind": format!("{:?}", other)})), "non-ApprovalRequest event leaked into approval channel");
667 continue;
668 }
669 };
670 let _ = sender.send(Message::Text(frame.to_string().into())).await;
671 }
672 }
673 }
674}
675
676fn resolve_session_cwd(
677 requested_cwd: Option<&str>,
678 default_workspace: &Path,
679) -> anyhow::Result<PathBuf> {
680 let cwd = requested_cwd
681 .map(PathBuf::from)
682 .unwrap_or_else(|| default_workspace.to_path_buf());
683 std::fs::canonicalize(&cwd).map_err(|e| {
684 ::zeroclaw_log::record!(
685 WARN,
686 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
687 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
688 .with_attrs(::serde_json::json!({
689 "cwd": cwd.display().to_string(),
690 "error": format!("{}", e),
691 })),
692 "ws session cwd rejected"
693 );
694 anyhow::Error::msg(format!(
695 "cwd is not a usable directory ({}): {e}",
696 cwd.display()
697 ))
698 })
699}
700
701fn session_queue_ws_error_code(error: &crate::session_queue::SessionQueueError) -> &'static str {
702 match error {
703 crate::session_queue::SessionQueueError::QueueFull { .. } => "SESSION_QUEUE_FULL",
704 crate::session_queue::SessionQueueError::Timeout { .. } => "SESSION_QUEUE_TIMEOUT",
705 }
706}
707
708fn persist_conversation_messages(
709 backend: &dyn zeroclaw_infra::session_backend::SessionBackend,
710 session_key: &str,
711 messages: &[zeroclaw_providers::ConversationMessage],
712) {
713 for message in messages {
714 let zeroclaw_providers::ConversationMessage::Chat(message) = message else {
715 continue;
716 };
717 if message.role == "system" {
718 continue;
719 }
720 let _ = backend.append(session_key, message);
721 }
722}
723
724fn has_assistant_chat_message(messages: &[zeroclaw_providers::ConversationMessage]) -> bool {
725 messages.iter().any(|message| {
726 matches!(
727 message,
728 zeroclaw_providers::ConversationMessage::Chat(message)
729 if message.role == "assistant"
730 )
731 })
732}
733
734fn needs_onboarding_ws_error(
735 config: &zeroclaw_config::schema::Config,
736) -> Option<serde_json::Value> {
737 let model = config.resolve_default_model().unwrap_or_default();
738 crate::needs_onboarding_for(&model)?;
739 Some(serde_json::json!({
740 "type": "error",
741 "error": "needs_onboarding",
742 "code": "NEEDS_ONBOARDING",
743 "message": crate::needs_onboarding_channel_reply(),
744 "url": "/onboard",
745 }))
746}
747
748fn event_matches_session(event: &serde_json::Value, session_id: &str) -> bool {
749 match event.get("session_id").and_then(|value| value.as_str()) {
750 Some(event_session_id) => event_session_id == session_id,
751 None => true,
752 }
753}
754
755async fn process_chat_message(
760 state: &AppState,
761 agent: &mut zeroclaw_runtime::agent::Agent,
762 sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
763 receiver: &mut futures_util::stream::SplitStream<WebSocket>,
764 approval_event_rx: &mut tokio::sync::mpsc::Receiver<zeroclaw_api::agent::TurnEvent>,
765 pending_approvals: &PendingApprovals,
766 content: &str,
767 session_key: &str,
768) {
769 use futures_util::StreamExt as _;
770 use zeroclaw_runtime::agent::TurnEvent;
771
772 let provider_label = state
773 .config
774 .read()
775 .first_model_provider_type()
776 .unwrap_or("unknown")
777 .to_string();
778
779 let _ = state.event_tx.send(serde_json::json!({
781 "type": "agent_start",
782 "model_provider": provider_label,
783 "model": state.model,
784 }));
785
786 let turn_id = uuid::Uuid::new_v4().to_string();
788 if let Some(ref backend) = state.session_backend {
789 let _ = backend.set_session_state(session_key, "running", Some(&turn_id));
790 }
791
792 let cancel_token = tokio_util::sync::CancellationToken::new();
797 {
798 state
799 .cancel_tokens
800 .lock()
801 .expect("cancel_tokens lock poisoned")
802 .insert(session_key.to_string(), cancel_token.clone());
803 }
804
805 let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<TurnEvent>(64);
807 let (steering_tx, mut steering_rx) = tokio::sync::mpsc::channel::<String>(32);
808
809 let content_owned = content.to_string();
815 let session_key_owned = session_key.to_string();
816 let turn_fut = async {
817 zeroclaw_runtime::agent::loop_::scope_session_key(
818 Some(session_key_owned),
819 agent.turn_streamed_with_steering_state(
820 &content_owned,
821 event_tx,
822 Some(cancel_token.clone()),
823 Some(&mut steering_rx),
824 ),
825 )
826 .await
827 };
828
829 let mut accumulated_text = String::new();
834
835 let mut total_input_tokens: Option<u64> = None;
839 let mut total_output_tokens: Option<u64> = None;
840
841 let forward_fut = async {
850 let mut cancel_drained = false;
851 loop {
852 tokio::select! {
853 biased;
854 _ = cancel_token.cancelled(), if !cancel_drained => {
864 let drained: Vec<_> = pending_approvals.lock().drain().collect();
865 drop(drained);
866 cancel_drained = true;
867 }
872 client_msg = receiver.next() => {
873 let text = match client_msg {
879 Some(Ok(Message::Text(text))) => text,
880 Some(Ok(Message::Close(_))) | Some(Err(_)) | None => {
881 cancel_token.cancel();
882 break;
883 }
884 _ => continue,
885 };
886 let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) else {
887 let err = serde_json::json!({
888 "type": "error",
889 "message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}",
890 "code": "INVALID_JSON"
891 });
892 let _ = sender.send(Message::Text(err.to_string().into())).await;
893 continue;
894 };
895 match parsed["type"].as_str() {
896 Some("approval_response") => {
897 let request_id = parsed["request_id"].as_str().unwrap_or("");
898 let decision = match parsed["decision"].as_str().unwrap_or("") {
899 "approve" => Some(ChannelApprovalResponse::Approve),
900 "always" => Some(ChannelApprovalResponse::AlwaysApprove),
901 "deny" => Some(ChannelApprovalResponse::Deny),
902 _ => None,
903 };
904 if request_id.is_empty() || decision.is_none() {
905 continue;
906 }
907 if let Some(tx) = pending_approvals.lock().remove(request_id) {
908 let _ = tx.send(decision.expect("checked above"));
909 } else {
910 ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"request_id": request_id})), "approval_response with no matching pending request (mid-turn)");
911 }
912 }
913 Some("message") => {
914 let content = parsed["content"].as_str().unwrap_or("").to_string();
915 if content.is_empty() {
916 let err = serde_json::json!({
917 "type": "error",
918 "message": "Message content cannot be empty",
919 "code": "EMPTY_CONTENT"
920 });
921 let _ = sender.send(Message::Text(err.to_string().into())).await;
922 continue;
923 }
924 match steering_tx.try_send(content) {
925 Ok(()) => {}
926 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
927 let err = serde_json::json!({
928 "type": "error",
929 "message": "Steering queue is full for the running turn",
930 "code": "STEERING_QUEUE_FULL"
931 });
932 let _ = sender.send(Message::Text(err.to_string().into())).await;
933 }
934 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
935 let err = serde_json::json!({
936 "type": "error",
937 "message": "Running turn is no longer accepting steering messages",
938 "code": "STEERING_CLOSED"
939 });
940 let _ = sender.send(Message::Text(err.to_string().into())).await;
941 }
942 }
943 }
944 _ => {}
945 }
946 }
947 approval = approval_event_rx.recv() => {
948 let Some(event) = approval else { continue };
949 if let TurnEvent::ApprovalRequest {
950 request_id,
951 tool_name,
952 arguments_summary,
953 timeout_secs,
954 } = event {
955 let frame = serde_json::json!({
956 "type": "approval_request",
957 "request_id": request_id,
958 "tool": tool_name,
959 "arguments_summary": arguments_summary,
960 "timeout_secs": timeout_secs,
961 });
962 let _ = sender.send(Message::Text(frame.to_string().into())).await;
963 }
964 }
965 event_opt = event_rx.recv() => {
966 let Some(event) = event_opt else { break };
967 let ws_msg = match event {
968 TurnEvent::Usage {
969 input_tokens,
970 output_tokens,
971 cost_usd: _,
972 } => {
973 if let Some(it) = input_tokens {
974 total_input_tokens = Some(total_input_tokens.unwrap_or(0) + it);
975 }
976 if let Some(ot) = output_tokens {
977 total_output_tokens = Some(total_output_tokens.unwrap_or(0) + ot);
978 }
979 continue;
980 }
981 TurnEvent::Chunk { ref delta } => {
982 accumulated_text.push_str(delta);
983 serde_json::json!({ "type": "chunk", "content": delta })
984 }
985 TurnEvent::Thinking { delta } => {
986 serde_json::json!({ "type": "thinking", "content": delta })
987 }
988 TurnEvent::ToolCall { id, name, args } => {
989 serde_json::json!({ "type": "tool_call", "id": id, "name": name, "args": args })
990 }
991 TurnEvent::ToolResult { id, name, output } => {
992 serde_json::json!({ "type": "tool_result", "id": id, "name": name, "output": output })
993 }
994 TurnEvent::ApprovalRequest {
995 request_id,
996 tool_name,
997 arguments_summary,
998 timeout_secs,
999 } => serde_json::json!({
1000 "type": "approval_request",
1001 "request_id": request_id,
1002 "tool": tool_name,
1003 "arguments_summary": arguments_summary,
1004 "timeout_secs": timeout_secs,
1005 }),
1006 };
1007 let _ = sender.send(Message::Text(ws_msg.to_string().into())).await;
1008 }
1009 }
1010 }
1011 };
1012
1013 let (result, ()) = tokio::join!(turn_fut, forward_fut);
1014
1015 {
1017 state
1018 .cancel_tokens
1019 .lock()
1020 .expect("cancel_tokens lock poisoned")
1021 .remove(session_key);
1022 }
1023
1024 let was_cancelled = match &result {
1027 Err(e) => zeroclaw_runtime::agent::loop_::is_tool_loop_cancelled(&e.error),
1028 Ok(_) => false,
1029 };
1030
1031 if was_cancelled {
1032 if let Some(ref backend) = state.session_backend {
1033 match &result {
1034 Err(error) if !error.new_messages.is_empty() => {
1035 persist_conversation_messages(
1036 backend.as_ref(),
1037 session_key,
1038 &error.new_messages,
1039 );
1040 if !has_assistant_chat_message(&error.new_messages) {
1041 let truncated = if accumulated_text.is_empty() {
1042 "[interrupted by user]".to_string()
1043 } else {
1044 format!("{accumulated_text}\n\n[interrupted by user]")
1045 };
1046 let assistant_msg = zeroclaw_providers::ChatMessage::assistant(&truncated);
1047 let _ = backend.append(session_key, &assistant_msg);
1048 }
1049 }
1050 _ => {
1051 let truncated = if accumulated_text.is_empty() {
1052 "[interrupted by user]".to_string()
1053 } else {
1054 format!("{accumulated_text}\n\n[interrupted by user]")
1055 };
1056 let assistant_msg = zeroclaw_providers::ChatMessage::assistant(&truncated);
1057 let _ = backend.append(session_key, &assistant_msg);
1058 }
1059 }
1060 }
1061
1062 let aborted = serde_json::json!({ "type": "aborted" });
1064 let _ = sender.send(Message::Text(aborted.to_string().into())).await;
1065
1066 if let Some(ref backend) = state.session_backend {
1068 let _ = backend.set_session_state(session_key, "idle", None);
1069 }
1070
1071 let _ = state.event_tx.send(serde_json::json!({
1073 "type": "agent_end",
1074 "model_provider": provider_label,
1075 "model": state.model,
1076 }));
1077
1078 ::zeroclaw_log::record!(
1081 INFO,
1082 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Cancel)
1083 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1084 .with_attrs(::serde_json::json!({
1085 "model_provider": provider_label,
1086 "model": state.model,
1087 "session_key": session_key,
1088 "reason": "interrupted by user",
1089 "cancelled": true,
1090 "trace_id": turn_id,
1091 })),
1092 "gateway_ws_turn"
1093 );
1094
1095 return;
1096 }
1097
1098 match result {
1099 Ok(outcome) => {
1100 if let Some(ref backend) = state.session_backend {
1101 persist_conversation_messages(backend.as_ref(), session_key, &outcome.new_messages);
1102 }
1103
1104 if state.auto_save {
1107 let mem = state.mem.clone();
1108 let model_provider = state.model_provider.clone();
1109 let model = state.model.clone();
1110 let temperature = state.temperature;
1111 let user_msg = content.to_string();
1112 let assistant_resp = outcome.response.clone();
1113 tokio::spawn(async move {
1114 if let Err(e) = zeroclaw_memory::consolidation::consolidate_turn(
1115 model_provider.as_ref(),
1116 &model,
1117 temperature,
1118 mem.as_ref(),
1119 &user_msg,
1120 &assistant_resp,
1121 )
1122 .await
1123 {
1124 ::zeroclaw_log::record!(
1125 DEBUG,
1126 ::zeroclaw_log::Event::new(
1127 module_path!(),
1128 ::zeroclaw_log::Action::Note
1129 )
1130 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
1131 "WS memory consolidation skipped"
1132 );
1133 }
1134 });
1135 }
1136
1137 let total_tokens = match (total_input_tokens, total_output_tokens) {
1141 (Some(i), Some(o)) => Some(i.saturating_add(o)),
1142 (Some(i), None) => Some(i),
1143 (None, Some(o)) => Some(o),
1144 (None, None) => None,
1145 };
1146 let cost_usd = record_turn_cost(
1147 state,
1148 &provider_label,
1149 &state.model,
1150 total_input_tokens,
1151 total_output_tokens,
1152 None,
1153 );
1154
1155 let done = serde_json::json!({
1156 "type": "done",
1157 "full_response": outcome.response,
1158 "input_tokens": total_input_tokens,
1159 "output_tokens": total_output_tokens,
1160 "tokens_used": total_tokens,
1161 "cost_usd": cost_usd,
1162 "model": state.model,
1163 "provider": provider_label,
1164 });
1165 let _ = sender.send(Message::Text(done.to_string().into())).await;
1166
1167 if let Some(ref backend) = state.session_backend {
1169 let _ = backend.set_session_state(session_key, "idle", None);
1170 }
1171
1172 let _ = state.event_tx.send(serde_json::json!({
1174 "type": "agent_end",
1175 "model_provider": provider_label,
1176 "model": state.model,
1177 }));
1178
1179 ::zeroclaw_log::record!(
1183 INFO,
1184 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Complete)
1185 .with_outcome(::zeroclaw_log::EventOutcome::Success)
1186 .with_attrs(::serde_json::json!({
1187 "model_provider": provider_label,
1188 "model": state.model,
1189 "session_key": session_key,
1190 "input_tokens": total_input_tokens,
1191 "output_tokens": total_output_tokens,
1192 "tokens_used": total_tokens,
1193 "cost_usd": cost_usd,
1194 "trace_id": turn_id,
1195 })),
1196 "gateway_ws_turn"
1197 );
1198 }
1199 Err(e) => {
1200 if let Some(ref backend) = state.session_backend
1201 && !e.new_messages.is_empty()
1202 {
1203 persist_conversation_messages(backend.as_ref(), session_key, &e.new_messages);
1204 }
1205
1206 if let Some(ref backend) = state.session_backend {
1208 let _ = backend.set_session_state(session_key, "error", Some(&turn_id));
1209 }
1210
1211 ::zeroclaw_log::record!(
1212 ERROR,
1213 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1214 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1215 .with_attrs(::serde_json::json!({"error": format!("{}", e.error)})),
1216 "Agent turn failed"
1217 );
1218 let sanitized = zeroclaw_providers::sanitize_api_error(&e.error.to_string());
1219 let error_code = if sanitized.to_lowercase().contains("api key")
1220 || sanitized.to_lowercase().contains("authentication")
1221 || sanitized.to_lowercase().contains("unauthorized")
1222 {
1223 "AUTH_ERROR"
1224 } else if sanitized.to_lowercase().contains("model_provider")
1225 || sanitized.to_lowercase().contains("model")
1226 {
1227 "PROVIDER_ERROR"
1228 } else {
1229 "AGENT_ERROR"
1230 };
1231 let err = serde_json::json!({
1232 "type": "error",
1233 "message": sanitized,
1234 "code": error_code,
1235 });
1236 let _ = sender.send(Message::Text(err.to_string().into())).await;
1237
1238 let _ = state.event_tx.send(serde_json::json!({
1240 "type": "error",
1241 "component": "ws_chat",
1242 "message": sanitized,
1243 }));
1244
1245 ::zeroclaw_log::record!(
1249 WARN,
1250 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1251 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1252 .with_attrs(::serde_json::json!({
1253 "model_provider": provider_label,
1254 "model": state.model,
1255 "session_key": session_key,
1256 "error": sanitized,
1257 "error_code": error_code,
1258 "trace_id": turn_id,
1259 })),
1260 "gateway_ws_turn"
1261 );
1262 }
1263 }
1264}
1265
1266fn record_turn_cost(
1270 state: &AppState,
1271 provider_name: &str,
1272 model: &str,
1273 input_tokens: Option<u64>,
1274 output_tokens: Option<u64>,
1275 cached_input_tokens: Option<u64>,
1276) -> Option<f64> {
1277 let tracker = state.cost_tracker.as_ref()?;
1278 if input_tokens.is_none() && output_tokens.is_none() {
1279 return None;
1280 }
1281 let input = input_tokens.unwrap_or(0);
1282 let output = output_tokens.unwrap_or(0);
1283 let cached_input = cached_input_tokens.unwrap_or(0);
1284 if input == 0 && output == 0 {
1285 return None;
1286 }
1287 let config = state.config.read();
1295 let pricing_map = config
1296 .providers
1297 .models
1298 .iter_entries()
1299 .filter(|(_, _, base)| !base.pricing.is_empty())
1300 .map(|(type_k, alias_k, base)| (format!("{type_k}.{alias_k}"), base.pricing.clone()))
1301 .collect::<std::collections::HashMap<String, std::collections::HashMap<String, f64>>>();
1302 drop(config);
1303 let model_pricing = pricing_map.get(provider_name);
1304 let try_lookup = |key: &str| -> (f64, f64, f64) {
1305 let Some(map) = model_pricing else {
1306 return (0.0, 0.0, 0.0);
1307 };
1308 let in_rate = map
1309 .get(&format!("{key}.input"))
1310 .copied()
1311 .or_else(|| map.get(key).copied())
1312 .unwrap_or(0.0);
1313 let out_rate = map
1314 .get(&format!("{key}.output"))
1315 .copied()
1316 .or_else(|| map.get(key).copied())
1317 .unwrap_or(0.0);
1318 let cached_rate = map
1319 .get(&format!("{key}.cached_input"))
1320 .copied()
1321 .unwrap_or(0.0);
1322 (in_rate, out_rate, cached_rate)
1323 };
1324 let (input_rate, output_rate, cached_rate) = match try_lookup(model) {
1325 (0.0, 0.0, 0.0) => model
1326 .rsplit_once('/')
1327 .map(|(_, suffix)| try_lookup(suffix))
1328 .unwrap_or((0.0, 0.0, 0.0)),
1329 rates => rates,
1330 };
1331 let usage = zeroclaw_runtime::cost::types::TokenUsage::new(
1332 model,
1333 input,
1334 output,
1335 cached_input,
1336 input_rate,
1337 output_rate,
1338 cached_rate,
1339 );
1340 let cost_usd = usage.cost_usd;
1341 if let Err(error) = tracker.record_usage(usage) {
1342 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"provider": provider_name, "model": model, "error": format!("{}", error)})), "Failed to record gateway turn cost");
1343 }
1344 Some(cost_usd)
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    /// A non-empty `Authorization: Bearer <token>` header yields the token.
    #[test]
    fn extract_ws_token_from_authorization_header() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer zc_test123".parse().unwrap());
        assert_eq!(extract_ws_token(&headers, None), Some("zc_test123"));
    }

    /// A `bearer.<token>` entry in `Sec-WebSocket-Protocol` yields the token,
    /// even when it is not the first listed subprotocol.
    #[test]
    fn extract_ws_token_from_subprotocol() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "zeroclaw.v1, bearer.zc_sub456".parse().unwrap(),
        );
        assert_eq!(extract_ws_token(&headers, None), Some("zc_sub456"));
    }

    /// With no token-bearing headers, the query-string token is returned.
    #[test]
    fn extract_ws_token_from_query_param() {
        let headers = HeaderMap::new();
        assert_eq!(
            extract_ws_token(&headers, Some("zc_query789")),
            Some("zc_query789")
        );
    }

    /// The `Authorization` header takes precedence over both the subprotocol
    /// token and the query-string token.
    #[test]
    fn extract_ws_token_precedence_header_over_subprotocol() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer zc_header".parse().unwrap());
        headers.insert("sec-websocket-protocol", "bearer.zc_sub".parse().unwrap());
        assert_eq!(
            extract_ws_token(&headers, Some("zc_query")),
            Some("zc_header")
        );
    }

    /// The subprotocol token takes precedence over the query-string token.
    #[test]
    fn extract_ws_token_precedence_subprotocol_over_query() {
        let mut headers = HeaderMap::new();
        headers.insert("sec-websocket-protocol", "bearer.zc_sub".parse().unwrap());
        assert_eq!(extract_ws_token(&headers, Some("zc_query")), Some("zc_sub"));
    }

    /// No headers and no query token yields `None`.
    #[test]
    fn extract_ws_token_returns_none_when_empty() {
        let headers = HeaderMap::new();
        assert_eq!(extract_ws_token(&headers, None), None);
    }

    /// `Authorization: Bearer ` with an empty token is skipped, so the query
    /// token is used instead.
    #[test]
    fn extract_ws_token_skips_empty_header_value() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer ".parse().unwrap());
        assert_eq!(
            extract_ws_token(&headers, Some("zc_fallback")),
            Some("zc_fallback")
        );
    }

    /// An empty query-string token is treated as absent.
    #[test]
    fn extract_ws_token_skips_empty_query_param() {
        let headers = HeaderMap::new();
        assert_eq!(extract_ws_token(&headers, Some("")), None);
    }

    /// The `bearer.` entry is found among several comma-separated subprotocols,
    /// including one listed after it.
    #[test]
    fn extract_ws_token_subprotocol_with_multiple_entries() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "zeroclaw.v1, bearer.zc_tok, other".parse().unwrap(),
        );
        assert_eq!(extract_ws_token(&headers, None), Some("zc_tok"));
    }

    /// Events carrying a `session_id` match only that session; events without
    /// a `session_id` (here a `cron_result`) match every session.
    #[test]
    fn session_scoped_events_only_match_their_session() {
        let target_event = serde_json::json!({
            "type": "message",
            "session_id": "operator-1",
            "content": "deploy finished"
        });
        let other_event = serde_json::json!({
            "type": "message",
            "session_id": "operator-2",
            "content": "different session"
        });
        let global_event = serde_json::json!({
            "type": "cron_result",
            "content": "global notification"
        });

        assert!(event_matches_session(&target_event, "operator-1"));
        assert!(!event_matches_session(&other_event, "operator-1"));
        assert!(event_matches_session(&global_event, "operator-1"));
    }

    /// An existing requested `cwd` is returned in canonicalized form, ignoring
    /// the fallback.
    #[test]
    fn resolve_session_cwd_uses_requested_cwd() {
        let requested = tempfile::tempdir().unwrap();
        let fallback = tempfile::tempdir().unwrap();

        let resolved =
            resolve_session_cwd(Some(requested.path().to_str().unwrap()), fallback.path()).unwrap();

        assert_eq!(resolved, requested.path().canonicalize().unwrap());
    }

    /// Without a requested `cwd`, the canonicalized fallback workspace is used.
    #[test]
    fn resolve_session_cwd_uses_default_workspace_without_request() {
        let fallback = tempfile::tempdir().unwrap();

        let resolved = resolve_session_cwd(None, fallback.path()).unwrap();

        assert_eq!(resolved, fallback.path().canonicalize().unwrap());
    }

    /// A requested `cwd` that does not exist is rejected, and the error message
    /// says it is not a usable directory.
    #[test]
    fn resolve_session_cwd_rejects_missing_directory() {
        let fallback = tempfile::tempdir().unwrap();
        let missing = fallback.path().join("missing");

        let err = resolve_session_cwd(Some(missing.to_str().unwrap()), fallback.path())
            .expect_err("missing cwd should be rejected");

        assert!(err.to_string().contains("cwd is not a usable directory"));
    }

    /// A default `Config` yields an onboarding error frame. The frame points at
    /// `/onboard` and carries a localized message mentioning onboarding, not an
    /// unresolved `{...}` Fluent-key placeholder.
    #[test]
    fn needs_onboarding_ws_error_points_to_onboard() {
        let config = zeroclaw_config::schema::Config::default();
        let frame = needs_onboarding_ws_error(&config)
            .expect("empty model must produce a WS onboarding error");

        assert_eq!(frame["type"], "error");
        assert_eq!(frame["error"], "needs_onboarding");
        assert_eq!(frame["code"], "NEEDS_ONBOARDING");
        assert_eq!(frame["url"], "/onboard");
        let message = frame["message"]
            .as_str()
            .expect("onboarding WS error must include a message");
        assert!(
            !message.starts_with('{') && !message.ends_with('}'),
            "missing Fluent key fallback leaked into WS error message: {message:?}"
        );
        assert!(
            message.to_lowercase().contains("onboarding"),
            "WS onboarding message must explain the setup gap: {message:?}"
        );
    }

    /// Once an OpenAI provider entry with a model and an API key is configured,
    /// no onboarding error is produced.
    #[test]
    fn needs_onboarding_ws_error_uses_current_configured_model() {
        let mut config = zeroclaw_config::schema::Config::default();
        config.providers.models.openai.insert(
            "default".to_string(),
            zeroclaw_config::schema::OpenAIModelProviderConfig {
                base: zeroclaw_config::schema::ModelProviderConfig {
                    model: Some("openai/gpt-4o-mini".to_string()),
                    api_key: Some("sk-test".to_string()),
                    ..Default::default()
                },
            },
        );

        assert!(
            needs_onboarding_ws_error(&config).is_none(),
            "current configured model must allow WebSocket agent construction to continue"
        );
    }

    // Outcome of classifying a client frame received while a turn is running.
    // NOTE(review): this enum and `classify_client_msg` below are defined inside
    // the test module, so the test that uses them exercises this copy rather
    // than the production receive loop. Keep it in sync with the handler (or
    // move the classifier into production code) — confirm against the handler.
    #[derive(Debug, PartialEq, Eq)]
    enum DisconnectAction {
        Break,
        Continue,
        ProcessText,
    }

    // Text frames are processed. A Close frame, a receive error, or the end of
    // the stream (`None`) breaks the loop. Every other frame (e.g. Ping)
    // continues.
    fn classify_client_msg(
        msg: Option<Result<axum::extract::ws::Message, &'static str>>,
    ) -> DisconnectAction {
        use axum::extract::ws::Message;
        match msg {
            Some(Ok(Message::Text(_))) => DisconnectAction::ProcessText,
            Some(Ok(Message::Close(_))) | Some(Err(_)) | None => DisconnectAction::Break,
            _ => DisconnectAction::Continue,
        }
    }

    /// Covers every arm of the test-local `classify_client_msg`.
    #[test]
    fn mid_turn_client_msg_breaks_on_stream_end_close_or_err() {
        use axum::extract::ws::Message;
        assert_eq!(classify_client_msg(None), DisconnectAction::Break);
        assert_eq!(
            classify_client_msg(Some(Ok(Message::Close(None)))),
            DisconnectAction::Break,
        );
        assert_eq!(
            classify_client_msg(Some(Err("io"))),
            DisconnectAction::Break,
        );
        assert_eq!(
            classify_client_msg(Some(Ok(Message::Ping(Default::default())))),
            DisconnectAction::Continue,
        );
        assert_eq!(
            classify_client_msg(Some(Ok(Message::Text("{}".into())))),
            DisconnectAction::ProcessText,
        );
    }

    /// Cancelling a `CancellationToken` is observed through its clones.
    // NOTE(review): this only exercises `tokio_util`'s token, not the handler's
    // disconnect path that is supposed to cancel the in-flight turn.
    #[test]
    fn mid_turn_disconnect_cancel_unblocks_joined_turn() {
        let token = tokio_util::sync::CancellationToken::new();
        let clone_for_turn = token.clone();
        assert!(!clone_for_turn.is_cancelled());
        token.cancel();
        assert!(
            clone_for_turn.is_cancelled(),
            "cloned token (held by turn_fut via agent.turn_streamed) must observe cancellation"
        );
    }

    /// Each session-queue error variant maps to its own WebSocket error code.
    #[test]
    fn session_queue_errors_map_to_explicit_websocket_codes() {
        use crate::session_queue::SessionQueueError;

        assert_eq!(
            session_queue_ws_error_code(&SessionQueueError::QueueFull {
                session_id: "gw_test".into(),
                depth: 2,
            }),
            "SESSION_QUEUE_FULL"
        );
        assert_eq!(
            session_queue_ws_error_code(&SessionQueueError::Timeout {
                session_id: "gw_test".into(),
            }),
            "SESSION_QUEUE_TIMEOUT"
        );
    }
}