1use super::AppState;
59use crate::ws_approval::{PendingApprovals, WsApprovalChannel, new_pending_approvals};
60use axum::{
61 extract::{
62 Query, State, WebSocketUpgrade,
63 ws::{Message, WebSocket},
64 },
65 http::{HeaderMap, header},
66 response::IntoResponse,
67};
68use futures_util::{SinkExt, StreamExt};
69use serde::Deserialize;
70use std::path::{Path, PathBuf};
71use std::sync::Arc;
72use std::time::Duration;
73use zeroclaw_api::channel::ChannelApprovalResponse;
74
/// How long (in seconds) a tool-approval request sent over the WebSocket waits for the
/// client's `approval_response` before the approval channel gives up on it (two minutes).
const WS_APPROVAL_TIMEOUT_SECS: u64 = 2 * 60;
79
/// Optional handshake frame a client may send as its first WebSocket message.
///
/// Only honored when `msg_type` is `"connect"`; any other first text frame is treated by
/// `handle_socket` as a regular chat frame instead.
#[derive(Debug, Deserialize)]
struct ConnectParams {
    /// Frame discriminator (JSON key `type`); must be `"connect"` for the handshake to apply.
    #[serde(rename = "type")]
    msg_type: String,
    /// When present, overrides the session id used for agent memory. The session-backend key
    /// chosen at upgrade time is not changed.
    #[serde(default)]
    session_id: Option<String>,
    /// Client-reported device name; currently only logged by `handle_socket`.
    #[serde(default)]
    device_name: Option<String>,
    /// Client-reported capability tags; currently only logged by `handle_socket`.
    #[serde(default)]
    capabilities: Vec<String>,
    /// Working directory for the session; when present it overrides the cwd given in the
    /// upgrade query string. Also accepted as `workspaceDir` / `workspace_dir`.
    #[serde(default, alias = "workspaceDir", alias = "workspace_dir")]
    cwd: Option<String>,
}
103
/// Application subprotocol selected for the upgrade when the client lists it in
/// `Sec-WebSocket-Protocol`.
const WS_PROTOCOL: &str = "zeroclaw.v1";

/// Prefix of a `Sec-WebSocket-Protocol` entry that carries a bearer token (`bearer.<token>`).
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
109
/// Query-string parameters accepted on the chat WebSocket upgrade request.
#[derive(Deserialize)]
pub struct WsQuery {
    /// Pairing token; consulted only when neither the `Authorization` header nor a
    /// `bearer.` subprotocol entry supplies one.
    pub token: Option<String>,
    /// Session to resume; a fresh UUID is generated when absent.
    pub session_id: Option<String>,
    /// Display name stored for the session when non-empty and a session backend exists.
    pub name: Option<String>,
    /// Required alias of a configured `[agents.<alias>]` entry. Accepted as `agent_alias`,
    /// `agentAlias` or `agent`.
    #[serde(default, alias = "agentAlias", alias = "agent")]
    pub agent_alias: Option<String>,
    /// Working directory for the session; takes precedence over `workspace_dir`.
    #[serde(default)]
    pub cwd: Option<String>,
    /// Alternative to `cwd` (also `workspaceDir`), used only when `cwd` is absent.
    #[serde(default, alias = "workspaceDir", alias = "workspace_dir")]
    pub workspace_dir: Option<String>,
}
126
127fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
137 if let Some(t) = headers
139 .get(header::AUTHORIZATION)
140 .and_then(|v| v.to_str().ok())
141 .and_then(|auth| auth.strip_prefix("Bearer "))
142 && !t.is_empty()
143 {
144 return Some(t);
145 }
146
147 if let Some(t) = headers
149 .get("sec-websocket-protocol")
150 .and_then(|v| v.to_str().ok())
151 .and_then(|protos| {
152 protos
153 .split(',')
154 .map(|p| p.trim())
155 .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
156 })
157 && !t.is_empty()
158 {
159 return Some(t);
160 }
161
162 if let Some(t) = query_token
164 && !t.is_empty()
165 {
166 return Some(t);
167 }
168
169 None
170}
171
172pub async fn handle_ws_chat(
174 State(state): State<AppState>,
175 Query(params): Query<WsQuery>,
176 headers: HeaderMap,
177 ws: WebSocketUpgrade,
178) -> impl IntoResponse {
179 if state.pairing.require_pairing() {
181 let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
182 if !state.pairing.is_authenticated(token) {
183 return (
184 axum::http::StatusCode::UNAUTHORIZED,
185 "Unauthorized — provide Authorization header, Sec-WebSocket-Protocol bearer, or ?token= query param",
186 )
187 .into_response();
188 }
189 }
190
191 let ws = if headers
193 .get("sec-websocket-protocol")
194 .and_then(|v| v.to_str().ok())
195 .is_some_and(|protos| protos.split(',').any(|p| p.trim() == WS_PROTOCOL))
196 {
197 ws.protocols([WS_PROTOCOL])
198 } else {
199 ws
200 };
201
202 let Some(agent_alias) = params.agent_alias.filter(|s| !s.trim().is_empty()) else {
205 return (
206 axum::http::StatusCode::BAD_REQUEST,
207 "Missing required `agent` query parameter — pass `?agent=<alias>` matching a configured [agents.<alias>] entry.",
208 )
209 .into_response();
210 };
211 {
212 let cfg = state.config.read();
213 if cfg.agent(&agent_alias).is_none() {
214 return (
215 axum::http::StatusCode::BAD_REQUEST,
216 format!(
217 "Unknown agent `{agent_alias}` — no [agents.{agent_alias}] entry configured."
218 ),
219 )
220 .into_response();
221 }
222 }
223
224 let session_id = params.session_id;
225 let session_name = params.name;
226 let session_cwd = params.cwd.or(params.workspace_dir);
227 ws.on_upgrade(move |socket| {
228 handle_socket(
229 socket,
230 state,
231 agent_alias,
232 session_id,
233 session_name,
234 session_cwd,
235 )
236 })
237 .into_response()
238}
239
/// Prefix prepended to the client-visible session id to form the session-backend key
/// (`gw_<session_id>`).
const GW_SESSION_PREFIX: &str = "gw_";
242
243async fn handle_socket(
244 socket: WebSocket,
245 state: AppState,
246 agent_alias: String,
247 session_id: Option<String>,
248 session_name: Option<String>,
249 session_cwd: Option<String>,
250) {
251 let (mut sender, mut receiver) = socket.split();
252
253 let session_id = session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
255 let session_key = format!("{GW_SESSION_PREFIX}{session_id}");
256 let mut memory_session_id = zeroclaw_api::session_keys::sanitize_session_key(&session_id);
258
259 let config = state.config.read().clone();
263 let mut resumed = false;
264 let mut message_count: usize = 0;
265 let mut effective_name: Option<String> = None;
266 let mut stored_messages = Vec::new();
267 if let Some(ref backend) = state.session_backend {
268 let messages = backend.load(&session_key);
269 if !messages.is_empty() {
270 message_count = messages.len();
271 stored_messages = messages;
272 resumed = true;
273 }
274 if let Some(ref name) = session_name
276 && !name.is_empty()
277 {
278 let _ = backend.set_session_name(&session_key, name);
279 effective_name = Some(name.clone());
280 }
281 if effective_name.is_none() {
283 effective_name = backend.get_session_name(&session_key).unwrap_or(None);
284 }
285 let _ = backend.set_session_agent_alias(&session_key, &agent_alias);
288 }
289
290 let mut session_start = serde_json::json!({
292 "type": "session_start",
293 "session_id": session_id,
294 "resumed": resumed,
295 "message_count": message_count,
296 });
297 if let Some(ref name) = effective_name {
298 session_start["name"] = serde_json::Value::String(name.clone());
299 }
300 let _ = sender
301 .send(Message::Text(session_start.to_string().into()))
302 .await;
303
304 let mut first_msg_fallback: Option<String> = None;
311 let mut requested_cwd = session_cwd;
312
313 if let Some(first) = receiver.next().await {
314 match first {
315 Ok(Message::Text(text)) => {
316 if let Ok(cp) = serde_json::from_str::<ConnectParams>(&text) {
317 if cp.msg_type == "connect" {
318 ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"session_id": cp.session_id, "device_name": cp.device_name, "capabilities": cp.capabilities, "cwd": cp.cwd})), "WebSocket connect params received");
319 if let Some(sid) = &cp.session_id {
320 memory_session_id =
321 zeroclaw_api::session_keys::sanitize_session_key(sid);
322 ::zeroclaw_log::record!(
323 DEBUG,
324 ::zeroclaw_log::Event::new(
325 module_path!(),
326 ::zeroclaw_log::Action::Note
327 )
328 .with_attrs(::serde_json::json!({"session_id": sid})),
329 "WebSocket connect session override received"
330 );
331 }
332 if cp.cwd.is_some() {
333 requested_cwd = cp.cwd;
334 }
335 let ack = serde_json::json!({
336 "type": "connected",
337 "message": "Connection established"
338 });
339 let _ = sender.send(Message::Text(ack.to_string().into())).await;
340 } else {
341 first_msg_fallback = Some(text.to_string());
343 }
344 } else {
345 first_msg_fallback = Some(text.to_string());
347 }
348 }
349 Ok(Message::Close(_)) | Err(_) => return,
350 _ => {}
351 }
352 }
353
354 let session_cwd = match resolve_session_cwd(requested_cwd.as_deref(), &config.data_dir) {
355 Ok(cwd) => cwd,
356 Err(e) => {
357 let err = serde_json::json!({
358 "type": "error",
359 "message": e.to_string(),
360 "code": "INVALID_CWD"
361 });
362 let _ = sender.send(Message::Text(err.to_string().into())).await;
363 return;
364 }
365 };
366
367 if let Some(err) = needs_onboarding_ws_error(&config) {
368 let _ = sender.send(Message::Text(err.to_string().into())).await;
369 return;
370 }
371
372 let mut agent =
379 match zeroclaw_runtime::agent::Agent::from_config_with_session_cwd_and_mcp_backchannel(
380 &config,
381 &agent_alias,
382 Some(&session_cwd),
383 true,
384 false,
385 )
386 .await
387 {
388 Ok(a) => a,
389 Err(e) => {
390 ::zeroclaw_log::record!(
391 ERROR,
392 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
393 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
394 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
395 "Agent initialization failed"
396 );
397 let err = serde_json::json!({
398 "type": "error",
399 "message": format!("Failed to initialise agent: {e}"),
400 "code": "AGENT_INIT_FAILED"
401 });
402 let _ = sender.send(Message::Text(err.to_string().into())).await;
403 let _ = sender
404 .send(Message::Close(Some(axum::extract::ws::CloseFrame {
405 code: 1011,
406 reason: axum::extract::ws::Utf8Bytes::from_static(
407 "Agent initialization failed",
408 ),
409 })))
410 .await;
411 return;
412 }
413 };
414 agent.set_memory_session_id(Some(memory_session_id));
415 if !stored_messages.is_empty() {
416 agent.seed_history(&stored_messages);
417 }
418
419 let (approval_event_tx, mut approval_event_rx) =
428 tokio::sync::mpsc::channel::<zeroclaw_api::agent::TurnEvent>(8);
429 let pending_approvals: PendingApprovals = new_pending_approvals();
430 let approval_channel = Arc::new(WsApprovalChannel::new(
431 approval_event_tx.clone(),
432 pending_approvals.clone(),
433 Duration::from_secs(WS_APPROVAL_TIMEOUT_SECS),
434 ));
435 agent
436 .channel_handles()
437 .register_channel("ws", approval_channel.clone());
438
439 let ch = agent.channel_handles();
445 let channel_names = zeroclaw_channels::orchestrator::register_channels_for_tools(
446 &config,
447 &ch.ask_user,
448 &Some(ch.reaction.clone()),
449 &ch.poll,
450 &ch.escalate,
451 );
452 if !channel_names.is_empty() {
453 ::zeroclaw_log::record!(
454 INFO,
455 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(
456 ::serde_json::json!({"channels": channel_names, "session": session_key})
457 ),
458 "Seeded {} channel(s) into dashboard agent session",
459 );
460 }
461
462 if let Some(ref text) = first_msg_fallback {
464 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
465 if parsed["type"].as_str() == Some("message") {
466 let content = parsed["content"].as_str().unwrap_or("").to_string();
467 if !content.is_empty() {
468 let _session_guard = match state.session_queue.acquire(&session_key).await {
469 Ok(guard) => guard,
470 Err(e) => {
471 let err = serde_json::json!({
472 "type": "error",
473 "message": e.to_string(),
474 "code": session_queue_ws_error_code(&e)
475 });
476 let _ = sender.send(Message::Text(err.to_string().into())).await;
477 return;
478 }
479 };
480 process_chat_message(
481 &state,
482 &mut agent,
483 &mut sender,
484 &mut receiver,
485 &mut approval_event_rx,
486 &pending_approvals,
487 &content,
488 &session_key,
489 )
490 .await;
491 }
492 } else {
493 let unknown_type = parsed["type"].as_str().unwrap_or("unknown");
494 let err = serde_json::json!({
495 "type": "error",
496 "message": format!(
497 "Unsupported message type \"{unknown_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
498 )
499 });
500 let _ = sender.send(Message::Text(err.to_string().into())).await;
501 }
502 } else {
503 let err = serde_json::json!({
504 "type": "error",
505 "message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}"
506 });
507 let _ = sender.send(Message::Text(err.to_string().into())).await;
508 }
509 }
510
511 let mut broadcast_rx = state.event_tx.subscribe();
514
515 loop {
516 tokio::select! {
517 client_msg = receiver.next() => {
519 let Some(msg) = client_msg else { break };
520 let msg = match msg {
521 Ok(Message::Text(text)) => text,
522 Ok(Message::Close(_)) | Err(_) => break,
523 _ => continue,
524 };
525
526 let parsed: serde_json::Value = match serde_json::from_str(&msg) {
528 Ok(v) => v,
529 Err(e) => {
530 let err = serde_json::json!({
531 "type": "error",
532 "message": format!("Invalid JSON: {}", e),
533 "code": "INVALID_JSON"
534 });
535 let _ = sender.send(Message::Text(err.to_string().into())).await;
536 continue;
537 }
538 };
539
540 let msg_type = parsed["type"].as_str().unwrap_or("");
541
542 #[cfg(feature = "gateway-voice-duplex")]
544 {
545 let duplex_enabled = !state.config.read().channels.voice_duplex.is_empty();
547 if duplex_enabled {
548 if let Some(voice_event) = crate::voice_duplex::try_parse_voice_event(&msg) {
549 if let Some(error_frame) = crate::voice_duplex::handle_voice_event(voice_event) {
550 let _ = sender.send(Message::Text(error_frame.to_string().into())).await;
551 }
552 continue;
553 }
554 }
555 }
556
557 if msg_type == "approval_response" {
559 let request_id = parsed["request_id"].as_str().unwrap_or("");
560 let decision_str = parsed["decision"].as_str().unwrap_or("");
561 let decision = match decision_str {
562 "approve" => Some(ChannelApprovalResponse::Approve),
563 "always" => Some(ChannelApprovalResponse::AlwaysApprove),
564 "deny" => Some(ChannelApprovalResponse::Deny),
565 _ => None,
566 };
567 if request_id.is_empty() || decision.is_none() {
568 let err = serde_json::json!({
569 "type": "error",
570 "message": "approval_response requires request_id and decision in {approve,deny,always}",
571 "code": "INVALID_APPROVAL_RESPONSE"
572 });
573 let _ = sender.send(Message::Text(err.to_string().into())).await;
574 continue;
575 }
576 if let Some(tx) = pending_approvals.lock().remove(request_id) {
577 let _ = tx.send(decision.expect("checked above"));
578 } else {
579 ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"request_id": request_id})), "approval_response with no matching pending request");
580 }
581 continue;
582 }
583
584 if msg_type != "message" {
585 let err = serde_json::json!({
586 "type": "error",
587 "message": format!(
588 "Unsupported message type \"{msg_type}\". Send {{\"type\":\"message\",\"content\":\"your text\"}}"
589 ),
590 "code": "UNKNOWN_MESSAGE_TYPE"
591 });
592 let _ = sender.send(Message::Text(err.to_string().into())).await;
593 continue;
594 }
595
596 let content = parsed["content"].as_str().unwrap_or("").to_string();
597 if content.is_empty() {
598 let err = serde_json::json!({
599 "type": "error",
600 "message": "Message content cannot be empty",
601 "code": "EMPTY_CONTENT"
602 });
603 let _ = sender.send(Message::Text(err.to_string().into())).await;
604 continue;
605 }
606
607 let _session_guard = match state.session_queue.acquire(&session_key).await {
609 Ok(guard) => guard,
610 Err(e) => {
611 let err = serde_json::json!({
612 "type": "error",
613 "message": e.to_string(),
614 "code": session_queue_ws_error_code(&e)
615 });
616 let _ = sender.send(Message::Text(err.to_string().into())).await;
617 continue;
618 }
619 };
620
621 process_chat_message(
622 &state,
623 &mut agent,
624 &mut sender,
625 &mut receiver,
626 &mut approval_event_rx,
627 &pending_approvals,
628 &content,
629 &session_key,
630 )
631 .await;
632 }
633
634 event = broadcast_rx.recv() => {
636 if let Ok(event) = event
637 && event_matches_session(&event, &session_id)
638 && !is_observability_telemetry(&event)
639 {
640 let _ = sender.send(Message::Text(event.to_string().into())).await;
641 }
642 }
643
644 approval_event = approval_event_rx.recv() => {
651 let Some(event) = approval_event else { break };
652 let frame = match event {
653 zeroclaw_api::agent::TurnEvent::ApprovalRequest {
654 request_id,
655 tool_name,
656 arguments_summary,
657 timeout_secs,
658 } => serde_json::json!({
659 "type": "approval_request",
660 "request_id": request_id,
661 "tool": tool_name,
662 "arguments_summary": arguments_summary,
663 "timeout_secs": timeout_secs,
664 }),
665 other => {
666 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"kind": format!("{:?}", other)})), "non-ApprovalRequest event leaked into approval channel");
667 continue;
668 }
669 };
670 let _ = sender.send(Message::Text(frame.to_string().into())).await;
671 }
672 }
673 }
674}
675
676fn resolve_session_cwd(
677 requested_cwd: Option<&str>,
678 default_workspace: &Path,
679) -> anyhow::Result<PathBuf> {
680 let cwd = requested_cwd
681 .map(PathBuf::from)
682 .unwrap_or_else(|| default_workspace.to_path_buf());
683 std::fs::canonicalize(&cwd).map_err(|e| {
684 ::zeroclaw_log::record!(
685 WARN,
686 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
687 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
688 .with_attrs(::serde_json::json!({
689 "cwd": cwd.display().to_string(),
690 "error": format!("{}", e),
691 })),
692 "ws session cwd rejected"
693 );
694 anyhow::Error::msg(format!(
695 "cwd is not a usable directory ({}): {e}",
696 cwd.display()
697 ))
698 })
699}
700
701fn session_queue_ws_error_code(error: &crate::session_queue::SessionQueueError) -> &'static str {
702 match error {
703 crate::session_queue::SessionQueueError::QueueFull { .. } => "SESSION_QUEUE_FULL",
704 crate::session_queue::SessionQueueError::Timeout { .. } => "SESSION_QUEUE_TIMEOUT",
705 }
706}
707
708fn persist_conversation_messages(
709 backend: &dyn zeroclaw_infra::session_backend::SessionBackend,
710 session_key: &str,
711 messages: &[zeroclaw_providers::ConversationMessage],
712) {
713 if !backend.session_exists(session_key) {
718 return;
719 }
720 for message in messages {
721 let zeroclaw_providers::ConversationMessage::Chat(message) = message else {
722 continue;
723 };
724 if message.role == "system" {
725 continue;
726 }
727 let _ = backend.append(session_key, message);
728 }
729}
730
731fn has_assistant_chat_message(messages: &[zeroclaw_providers::ConversationMessage]) -> bool {
732 messages.iter().any(|message| {
733 matches!(
734 message,
735 zeroclaw_providers::ConversationMessage::Chat(message)
736 if message.role == "assistant"
737 )
738 })
739}
740
741fn needs_onboarding_ws_error(
742 config: &zeroclaw_config::schema::Config,
743) -> Option<serde_json::Value> {
744 let model = config.resolve_default_model().unwrap_or_default();
745 crate::needs_quickstart_for(&model)?;
746 Some(serde_json::json!({
747 "type": "error",
748 "error": "needs_onboarding",
749 "code": "NEEDS_ONBOARDING",
750 "message": crate::needs_quickstart_channel_reply(),
751 "url": "/onboard",
752 }))
753}
754
755fn event_matches_session(event: &serde_json::Value, session_id: &str) -> bool {
774 match event.get("session_id").and_then(|value| value.as_str()) {
775 Some(event_session_id) => event_session_id == session_id,
776 None => is_global_chat_event(event),
777 }
778}
779
780fn is_global_chat_event(event: &serde_json::Value) -> bool {
789 matches!(
790 event.get("type").and_then(serde_json::Value::as_str),
791 Some("cron_result")
792 )
793}
794
795fn is_observability_telemetry(event: &serde_json::Value) -> bool {
806 event.get("source").and_then(serde_json::Value::as_str) == Some("observability")
807}
808
809async fn process_chat_message(
814 state: &AppState,
815 agent: &mut zeroclaw_runtime::agent::Agent,
816 sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
817 receiver: &mut futures_util::stream::SplitStream<WebSocket>,
818 approval_event_rx: &mut tokio::sync::mpsc::Receiver<zeroclaw_api::agent::TurnEvent>,
819 pending_approvals: &PendingApprovals,
820 content: &str,
821 session_key: &str,
822) {
823 use futures_util::StreamExt as _;
824 use zeroclaw_runtime::agent::TurnEvent;
825
826 let (turn_alias, turn_provider, turn_model) = agent.attribution_fields();
832 let provider_label = turn_provider.clone();
833 let model_label = turn_model.clone();
834
835 let _ = state.event_tx.send(serde_json::json!({
837 "type": "agent_start",
838 "model_provider": provider_label,
839 "model": model_label,
840 }));
841
842 let turn_id = uuid::Uuid::new_v4().to_string();
844 if let Some(ref backend) = state.session_backend {
845 let _ = backend.set_session_state(session_key, "running", Some(&turn_id));
846 }
847
848 let cancel_token = tokio_util::sync::CancellationToken::new();
853 {
854 state
855 .cancel_tokens
856 .lock()
857 .expect("cancel_tokens lock poisoned")
858 .insert(session_key.to_string(), cancel_token.clone());
859 }
860
861 let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<TurnEvent>(64);
863 let (steering_tx, mut steering_rx) = tokio::sync::mpsc::channel::<String>(32);
864
865 let content_owned = content.to_string();
871 let session_key_owned = session_key.to_string();
872 let turn_fut = async {
873 use ::zeroclaw_log::Instrument as _;
874 let span = ::zeroclaw_log::info_span!(
875 target: "zeroclaw_log_internal_scope",
876 "zeroclaw_scope",
877 session_key = %session_key_owned,
878 agent_alias = %turn_alias,
879 model_provider = %turn_provider,
880 model = %turn_model,
881 channel = "wss",
882 );
883 zeroclaw_runtime::agent::loop_::scope_session_key(
884 Some(session_key_owned.clone()),
885 agent
886 .turn_streamed_with_steering_state(
887 &content_owned,
888 event_tx,
889 Some(cancel_token.clone()),
890 Some(&mut steering_rx),
891 )
892 .instrument(span),
893 )
894 .await
895 };
896
897 let mut accumulated_text = String::new();
902
903 let mut total_input_tokens: Option<u64> = None;
907 let mut total_output_tokens: Option<u64> = None;
908
909 let forward_fut = async {
918 let mut cancel_drained = false;
919 loop {
920 tokio::select! {
921 biased;
922 _ = cancel_token.cancelled(), if !cancel_drained => {
932 let drained: Vec<_> = pending_approvals.lock().drain().collect();
933 drop(drained);
934 cancel_drained = true;
935 }
940 client_msg = receiver.next() => {
941 let text = match client_msg {
947 Some(Ok(Message::Text(text))) => text,
948 Some(Ok(Message::Close(_))) | Some(Err(_)) | None => {
949 cancel_token.cancel();
950 break;
951 }
952 _ => continue,
953 };
954 let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) else {
955 let err = serde_json::json!({
956 "type": "error",
957 "message": "Invalid JSON. Send {\"type\":\"message\",\"content\":\"your text\"}",
958 "code": "INVALID_JSON"
959 });
960 let _ = sender.send(Message::Text(err.to_string().into())).await;
961 continue;
962 };
963 match parsed["type"].as_str() {
964 Some("approval_response") => {
965 let request_id = parsed["request_id"].as_str().unwrap_or("");
966 let decision = match parsed["decision"].as_str().unwrap_or("") {
967 "approve" => Some(ChannelApprovalResponse::Approve),
968 "always" => Some(ChannelApprovalResponse::AlwaysApprove),
969 "deny" => Some(ChannelApprovalResponse::Deny),
970 _ => None,
971 };
972 if request_id.is_empty() || decision.is_none() {
973 continue;
974 }
975 if let Some(tx) = pending_approvals.lock().remove(request_id) {
976 let _ = tx.send(decision.expect("checked above"));
977 } else {
978 ::zeroclaw_log::record!(DEBUG, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"request_id": request_id})), "approval_response with no matching pending request (mid-turn)");
979 }
980 }
981 Some("message") => {
982 let content = parsed["content"].as_str().unwrap_or("").to_string();
983 if content.is_empty() {
984 let err = serde_json::json!({
985 "type": "error",
986 "message": "Message content cannot be empty",
987 "code": "EMPTY_CONTENT"
988 });
989 let _ = sender.send(Message::Text(err.to_string().into())).await;
990 continue;
991 }
992 match steering_tx.try_send(content) {
993 Ok(()) => {}
994 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
995 let err = serde_json::json!({
996 "type": "error",
997 "message": "Steering queue is full for the running turn",
998 "code": "STEERING_QUEUE_FULL"
999 });
1000 let _ = sender.send(Message::Text(err.to_string().into())).await;
1001 }
1002 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
1003 let err = serde_json::json!({
1004 "type": "error",
1005 "message": "Running turn is no longer accepting steering messages",
1006 "code": "STEERING_CLOSED"
1007 });
1008 let _ = sender.send(Message::Text(err.to_string().into())).await;
1009 }
1010 }
1011 }
1012 _ => {}
1013 }
1014 }
1015 approval = approval_event_rx.recv() => {
1016 let Some(event) = approval else { continue };
1017 if let TurnEvent::ApprovalRequest {
1018 request_id,
1019 tool_name,
1020 arguments_summary,
1021 timeout_secs,
1022 } = event {
1023 let frame = serde_json::json!({
1024 "type": "approval_request",
1025 "request_id": request_id,
1026 "tool": tool_name,
1027 "arguments_summary": arguments_summary,
1028 "timeout_secs": timeout_secs,
1029 });
1030 let _ = sender.send(Message::Text(frame.to_string().into())).await;
1031 }
1032 }
1033 event_opt = event_rx.recv() => {
1034 let Some(event) = event_opt else { break };
1035 let ws_msg = match event {
1036 TurnEvent::Usage {
1037 input_tokens,
1038 cached_input_tokens: _,
1039 output_tokens,
1040 cost_usd: _,
1041 } => {
1042 if let Some(it) = input_tokens {
1048 total_input_tokens = Some(total_input_tokens.unwrap_or(0) + it);
1049 }
1050 if let Some(ot) = output_tokens {
1051 total_output_tokens = Some(total_output_tokens.unwrap_or(0) + ot);
1052 }
1053 continue;
1054 }
1055 TurnEvent::Chunk { ref delta } => {
1056 accumulated_text.push_str(delta);
1057 serde_json::json!({ "type": "chunk", "content": delta })
1058 }
1059 TurnEvent::Thinking { delta } => {
1060 serde_json::json!({ "type": "thinking", "content": delta })
1061 }
1062 TurnEvent::ToolCall { id, name, args } => {
1063 serde_json::json!({ "type": "tool_call", "id": id, "name": name, "args": args })
1064 }
1065 TurnEvent::ToolResult { id, name, output } => {
1066 serde_json::json!({ "type": "tool_result", "id": id, "name": name, "output": output })
1067 }
1068 TurnEvent::ApprovalRequest {
1069 request_id,
1070 tool_name,
1071 arguments_summary,
1072 timeout_secs,
1073 } => serde_json::json!({
1074 "type": "approval_request",
1075 "request_id": request_id,
1076 "tool": tool_name,
1077 "arguments_summary": arguments_summary,
1078 "timeout_secs": timeout_secs,
1079 }),
1080 };
1081 let _ = sender.send(Message::Text(ws_msg.to_string().into())).await;
1082 }
1083 }
1084 }
1085 };
1086
1087 let (result, ()) = tokio::join!(turn_fut, forward_fut);
1088
1089 {
1091 state
1092 .cancel_tokens
1093 .lock()
1094 .expect("cancel_tokens lock poisoned")
1095 .remove(session_key);
1096 }
1097
1098 let was_cancelled = match &result {
1101 Err(e) => zeroclaw_runtime::agent::loop_::is_tool_loop_cancelled(&e.error),
1102 Ok(_) => false,
1103 };
1104
1105 if was_cancelled {
1106 if let Some(ref backend) = state.session_backend {
1107 let still_exists = backend.session_exists(session_key);
1117 if still_exists {
1118 match &result {
1119 Err(error) if !error.new_messages.is_empty() => {
1120 persist_conversation_messages(
1121 backend.as_ref(),
1122 session_key,
1123 &error.new_messages,
1124 );
1125 if !has_assistant_chat_message(&error.new_messages) {
1126 let truncated = if accumulated_text.is_empty() {
1127 "[interrupted by user]".to_string()
1128 } else {
1129 format!("{accumulated_text}\n\n[interrupted by user]")
1130 };
1131 let assistant_msg =
1132 zeroclaw_providers::ChatMessage::assistant(&truncated);
1133 if backend.session_exists(session_key) {
1138 let _ = backend.append(session_key, &assistant_msg);
1139 }
1140 }
1141 }
1142 _ => {
1143 let truncated = if accumulated_text.is_empty() {
1144 "[interrupted by user]".to_string()
1145 } else {
1146 format!("{accumulated_text}\n\n[interrupted by user]")
1147 };
1148 let assistant_msg = zeroclaw_providers::ChatMessage::assistant(&truncated);
1149 if backend.session_exists(session_key) {
1150 let _ = backend.append(session_key, &assistant_msg);
1151 }
1152 }
1153 }
1154 }
1155 }
1156
1157 let aborted = serde_json::json!({ "type": "aborted" });
1159 let _ = sender.send(Message::Text(aborted.to_string().into())).await;
1160
1161 if let Some(ref backend) = state.session_backend
1167 && backend.session_exists(session_key)
1168 {
1169 let _ = backend.set_session_state(session_key, "idle", None);
1170 }
1171
1172 let _ = state.event_tx.send(serde_json::json!({
1174 "type": "agent_end",
1175 "model_provider": provider_label,
1176 "model": model_label,
1177 }));
1178
1179 ::zeroclaw_log::record!(
1182 INFO,
1183 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Cancel)
1184 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1185 .with_attrs(::serde_json::json!({
1186 "model_provider": provider_label,
1187 "model": model_label,
1188 "session_key": session_key,
1189 "reason": "interrupted by user",
1190 "cancelled": true,
1191 "trace_id": turn_id,
1192 })),
1193 "gateway_ws_turn"
1194 );
1195
1196 return;
1197 }
1198
1199 match result {
1200 Ok(outcome) => {
1201 if let Some(ref backend) = state.session_backend {
1202 persist_conversation_messages(backend.as_ref(), session_key, &outcome.new_messages);
1203 }
1204
1205 if state.auto_save {
1208 let memory_strategy = state.memory_strategy.clone();
1209 let model_provider = state.model_provider.clone();
1210 let model = state.model.clone();
1211 let temperature = state.temperature;
1212 let user_msg = content.to_string();
1213 let assistant_resp = outcome.response.clone();
1214 zeroclaw_spawn::spawn!(async move {
1215 if let Err(e) = memory_strategy
1216 .consolidate_turn(
1217 &user_msg,
1218 &assistant_resp,
1219 model_provider.as_ref(),
1220 &model,
1221 temperature,
1222 )
1223 .await
1224 {
1225 ::zeroclaw_log::record!(
1226 DEBUG,
1227 ::zeroclaw_log::Event::new(
1228 module_path!(),
1229 ::zeroclaw_log::Action::Note
1230 )
1231 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
1232 "WS memory consolidation skipped"
1233 );
1234 }
1235 });
1236 }
1237
1238 let total_tokens = match (total_input_tokens, total_output_tokens) {
1242 (Some(i), Some(o)) => Some(i.saturating_add(o)),
1243 (Some(i), None) => Some(i),
1244 (None, Some(o)) => Some(o),
1245 (None, None) => None,
1246 };
1247 let cost_usd = record_turn_cost(
1248 state,
1249 &provider_label,
1250 &model_label,
1251 total_input_tokens,
1252 total_output_tokens,
1253 None,
1254 );
1255
1256 let done = serde_json::json!({
1257 "type": "done",
1258 "full_response": outcome.response,
1259 "input_tokens": total_input_tokens,
1260 "output_tokens": total_output_tokens,
1261 "tokens_used": total_tokens,
1262 "cost_usd": cost_usd,
1263 "model": model_label,
1264 "provider": provider_label,
1265 });
1266 let _ = sender.send(Message::Text(done.to_string().into())).await;
1267
1268 if let Some(ref backend) = state.session_backend {
1270 let _ = backend.set_session_state(session_key, "idle", None);
1271 }
1272
1273 let _ = state.event_tx.send(serde_json::json!({
1275 "type": "agent_end",
1276 "model_provider": provider_label,
1277 "model": model_label,
1278 }));
1279
1280 ::zeroclaw_log::record!(
1284 INFO,
1285 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Complete)
1286 .with_outcome(::zeroclaw_log::EventOutcome::Success)
1287 .with_attrs(::serde_json::json!({
1288 "model_provider": provider_label,
1289 "model": model_label,
1290 "session_key": session_key,
1291 "input_tokens": total_input_tokens,
1292 "output_tokens": total_output_tokens,
1293 "tokens_used": total_tokens,
1294 "cost_usd": cost_usd,
1295 "trace_id": turn_id,
1296 })),
1297 "gateway_ws_turn"
1298 );
1299 }
1300 Err(e) => {
1301 if let Some(ref backend) = state.session_backend
1302 && !e.new_messages.is_empty()
1303 {
1304 persist_conversation_messages(backend.as_ref(), session_key, &e.new_messages);
1305 }
1306
1307 if let Some(ref backend) = state.session_backend {
1309 let _ = backend.set_session_state(session_key, "error", Some(&turn_id));
1310 }
1311
1312 ::zeroclaw_log::record!(
1313 ERROR,
1314 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1315 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1316 .with_attrs(::serde_json::json!({"error": format!("{}", e.error)})),
1317 "Agent turn failed"
1318 );
1319 let sanitized = zeroclaw_providers::sanitize_api_error(&e.error.to_string());
1320 let error_code = if sanitized.to_lowercase().contains("api key")
1321 || sanitized.to_lowercase().contains("authentication")
1322 || sanitized.to_lowercase().contains("unauthorized")
1323 {
1324 "AUTH_ERROR"
1325 } else if sanitized.to_lowercase().contains("model_provider")
1326 || sanitized.to_lowercase().contains("model")
1327 {
1328 "PROVIDER_ERROR"
1329 } else {
1330 "AGENT_ERROR"
1331 };
1332 let err = serde_json::json!({
1333 "type": "error",
1334 "message": sanitized,
1335 "code": error_code,
1336 });
1337 let _ = sender.send(Message::Text(err.to_string().into())).await;
1338
1339 let _ = state.event_tx.send(serde_json::json!({
1341 "type": "error",
1342 "component": "ws_chat",
1343 "message": sanitized,
1344 }));
1345
1346 ::zeroclaw_log::record!(
1350 WARN,
1351 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
1352 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
1353 .with_attrs(::serde_json::json!({
1354 "model_provider": provider_label,
1355 "model": model_label,
1356 "session_key": session_key,
1357 "error": sanitized,
1358 "error_code": error_code,
1359 "trace_id": turn_id,
1360 })),
1361 "gateway_ws_turn"
1362 );
1363 }
1364 }
1365}
1366
1367fn record_turn_cost(
1371 state: &AppState,
1372 provider_name: &str,
1373 model: &str,
1374 input_tokens: Option<u64>,
1375 output_tokens: Option<u64>,
1376 cached_input_tokens: Option<u64>,
1377) -> Option<f64> {
1378 let tracker = state.cost_tracker.as_ref()?;
1379 if input_tokens.is_none() && output_tokens.is_none() {
1380 return None;
1381 }
1382 let input = input_tokens.unwrap_or(0);
1383 let output = output_tokens.unwrap_or(0);
1384 let cached_input = cached_input_tokens.unwrap_or(0);
1385 if input == 0 && output == 0 {
1386 return None;
1387 }
1388 let config = state.config.read();
1396 let pricing_map = config
1397 .providers
1398 .models
1399 .iter_entries()
1400 .filter(|(_, _, base)| !base.pricing.is_empty())
1401 .map(|(type_k, alias_k, base)| (format!("{type_k}.{alias_k}"), base.pricing.clone()))
1402 .collect::<std::collections::HashMap<String, std::collections::HashMap<String, f64>>>();
1403 drop(config);
1404 let model_pricing = pricing_map.get(provider_name);
1405 let try_lookup = |key: &str| -> (f64, f64, f64) {
1406 let Some(map) = model_pricing else {
1407 return (0.0, 0.0, 0.0);
1408 };
1409 let in_rate = map
1410 .get(&format!("{key}.input"))
1411 .copied()
1412 .or_else(|| map.get(key).copied())
1413 .unwrap_or(0.0);
1414 let out_rate = map
1415 .get(&format!("{key}.output"))
1416 .copied()
1417 .or_else(|| map.get(key).copied())
1418 .unwrap_or(0.0);
1419 let cached_rate = map
1420 .get(&format!("{key}.cached_input"))
1421 .copied()
1422 .unwrap_or(0.0);
1423 (in_rate, out_rate, cached_rate)
1424 };
1425 let (input_rate, output_rate, cached_rate) = match try_lookup(model) {
1426 (0.0, 0.0, 0.0) => model
1427 .rsplit_once('/')
1428 .map(|(_, suffix)| try_lookup(suffix))
1429 .unwrap_or((0.0, 0.0, 0.0)),
1430 rates => rates,
1431 };
1432 let usage = zeroclaw_runtime::cost::types::TokenUsage::new(
1433 model,
1434 input,
1435 output,
1436 cached_input,
1437 input_rate,
1438 output_rate,
1439 cached_rate,
1440 );
1441 let cost_usd = usage.cost_usd;
1442 if let Err(error) = tracker.record_usage(usage) {
1443 ::zeroclaw_log::record!(WARN, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_outcome(::zeroclaw_log::EventOutcome::Unknown).with_attrs(::serde_json::json!({"provider": provider_name, "model": model, "error": format!("{}", error)})), "Failed to record gateway turn cost");
1444 }
1445 Some(cost_usd)
1446}
1447
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    // `extract_ws_token` precedence, as pinned by the tests below:
    // 1. an `Authorization: Bearer <token>` header;
    // 2. a `bearer.<token>` entry in `Sec-WebSocket-Protocol`;
    // 3. the `token` query parameter.
    // Empty tokens at any level are skipped.

    #[test]
    fn extract_ws_token_from_authorization_header() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer zc_test123".parse().unwrap());
        assert_eq!(extract_ws_token(&headers, None), Some("zc_test123"));
    }

    #[test]
    fn extract_ws_token_from_subprotocol() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "zeroclaw.v1, bearer.zc_sub456".parse().unwrap(),
        );
        assert_eq!(extract_ws_token(&headers, None), Some("zc_sub456"));
    }

    #[test]
    fn extract_ws_token_from_query_param() {
        let headers = HeaderMap::new();
        assert_eq!(
            extract_ws_token(&headers, Some("zc_query789")),
            Some("zc_query789")
        );
    }

    #[test]
    fn extract_ws_token_precedence_header_over_subprotocol() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer zc_header".parse().unwrap());
        headers.insert("sec-websocket-protocol", "bearer.zc_sub".parse().unwrap());
        assert_eq!(
            extract_ws_token(&headers, Some("zc_query")),
            Some("zc_header")
        );
    }

    #[test]
    fn extract_ws_token_precedence_subprotocol_over_query() {
        let mut headers = HeaderMap::new();
        headers.insert("sec-websocket-protocol", "bearer.zc_sub".parse().unwrap());
        assert_eq!(extract_ws_token(&headers, Some("zc_query")), Some("zc_sub"));
    }

    #[test]
    fn extract_ws_token_returns_none_when_empty() {
        let headers = HeaderMap::new();
        assert_eq!(extract_ws_token(&headers, None), None);
    }

    // "Bearer " with nothing after it must not count as a token; extraction
    // falls through to the query parameter.
    #[test]
    fn extract_ws_token_skips_empty_header_value() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer ".parse().unwrap());
        assert_eq!(
            extract_ws_token(&headers, Some("zc_fallback")),
            Some("zc_fallback")
        );
    }

    #[test]
    fn extract_ws_token_skips_empty_query_param() {
        let headers = HeaderMap::new();
        assert_eq!(extract_ws_token(&headers, Some("")), None);
    }

    // The bearer entry may appear anywhere in a comma-separated subprotocol
    // list; surrounding entries are ignored.
    #[test]
    fn extract_ws_token_subprotocol_with_multiple_entries() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "sec-websocket-protocol",
            "zeroclaw.v1, bearer.zc_tok, other".parse().unwrap(),
        );
        assert_eq!(extract_ws_token(&headers, None), Some("zc_tok"));
    }

    // Events carrying a `session_id` reach only that session. Observability
    // frames with no session are dropped. A `cron_result` with no session is
    // still delivered.
    #[test]
    fn session_scoped_events_only_match_their_session() {
        let target_event = serde_json::json!({
            "type": "message",
            "session_id": "operator-1",
            "content": "deploy finished"
        });
        let other_event = serde_json::json!({
            "type": "message",
            "session_id": "operator-2",
            "content": "different session"
        });
        // Observability frame with no session_id: expected to be dropped.
        let nameless_observability = serde_json::json!({
            "type": "agent_start",
            "source": "observability",
            "model": "gpt-4o"
        });
        // Session-less cron result: expected to pass through to every session.
        let cron = serde_json::json!({
            "type": "cron_result",
            "output": "global notification"
        });

        assert!(event_matches_session(&target_event, "operator-1"));
        assert!(!event_matches_session(&other_event, "operator-1"));
        assert!(!event_matches_session(
            &nameless_observability,
            "operator-1"
        ));
        assert!(event_matches_session(&cron, "operator-1"));
    }

    // Observability frames of these types with no session_id must never be
    // delivered to a chat WebSocket.
    #[test]
    fn event_matches_session_defaults_drops_unwhitelisted_no_session_frames() {
        for ty in [
            "agent_start",
            "agent_end",
            "llm_request",
            "tool_call",
            "tool_call_start",
            "error",
        ] {
            let frame = serde_json::json!({
                "type": ty,
                "source": "observability",
                "timestamp": "2026-06-04T00:00:00Z",
            });
            assert!(
                !event_matches_session(&frame, "operator-1"),
                "{ty} observability frame must be dropped from chat WS"
            );
        }
    }

    // An assistant message injected for a session reaches that session only.
    #[test]
    fn event_matches_session_passes_session_scoped_chat_messages() {
        let assistant_inject = serde_json::json!({
            "type": "message",
            "session_id": "operator-1",
            "role": "assistant",
            "content": "hello",
        });
        assert!(event_matches_session(&assistant_inject, "operator-1"));
        assert!(!event_matches_session(&assistant_inject, "operator-2"));
    }

    // `is_observability_telemetry` keys off `"source": "observability"`. A
    // chat `tool_call` frame with the same `type` but no such source is not
    // telemetry.
    #[test]
    fn observability_tagged_frames_are_filtered() {
        let obs = serde_json::json!({
            "type": "tool_call",
            "source": "observability",
            "tool": "shell",
        });
        assert!(is_observability_telemetry(&obs));

        let chat = serde_json::json!({
            "type": "tool_call",
            "id": "call-1",
            "name": "file_write",
            "args": {"path": "/tmp/x"},
        });
        assert!(!is_observability_telemetry(&chat));
    }

    // A `source` that is null, empty, non-string, a different string, or an
    // object must not be treated as observability telemetry.
    #[test]
    fn observability_telemetry_filter_handles_malformed_source_field() {
        for source in [
            serde_json::Value::Null,
            serde_json::json!(""),
            serde_json::json!(42),
            serde_json::json!("api"),
            serde_json::json!({"nested": "x"}),
        ] {
            let frame = serde_json::json!({
                "type": "tool_call",
                "id": "call-1",
                "name": "file_write",
                "source": source,
            });
            assert!(
                !is_observability_telemetry(&frame),
                "frame with source={frame:?} must not be flagged as observability telemetry",
            );
        }
    }

    // A session-scoped chat tool_call frame passes the session filter and is
    // not flagged as telemetry.
    #[test]
    fn chat_tool_frames_pass_through_when_session_scoped() {
        let chat_tool_call = serde_json::json!({
            "type": "tool_call",
            "session_id": "operator-1",
            "id": "call-1",
            "name": "file_write",
            "args": {"path": "/tmp/x"},
        });
        assert!(event_matches_session(&chat_tool_call, "operator-1"));
        assert!(!is_observability_telemetry(&chat_tool_call));
    }

    // `resolve_session_cwd` returns the canonicalized requested directory.
    #[test]
    fn resolve_session_cwd_uses_requested_cwd() {
        let requested = tempfile::tempdir().unwrap();
        let fallback = tempfile::tempdir().unwrap();

        let resolved =
            resolve_session_cwd(Some(requested.path().to_str().unwrap()), fallback.path()).unwrap();

        assert_eq!(resolved, requested.path().canonicalize().unwrap());
    }

    // With no requested cwd, the canonicalized fallback workspace is used.
    #[test]
    fn resolve_session_cwd_uses_default_workspace_without_request() {
        let fallback = tempfile::tempdir().unwrap();

        let resolved = resolve_session_cwd(None, fallback.path()).unwrap();

        assert_eq!(resolved, fallback.path().canonicalize().unwrap());
    }

    // A requested cwd that does not exist is an error; there is no silent
    // fallback to the default workspace.
    #[test]
    fn resolve_session_cwd_rejects_missing_directory() {
        let fallback = tempfile::tempdir().unwrap();
        let missing = fallback.path().join("missing");

        let err = resolve_session_cwd(Some(missing.to_str().unwrap()), fallback.path())
            .expect_err("missing cwd should be rejected");

        assert!(err.to_string().contains("cwd is not a usable directory"));
    }

    // A default (unconfigured) config yields a structured onboarding error
    // frame. Its message must be real localized text, not a raw
    // `{fluent-key}` placeholder.
    #[test]
    fn needs_onboarding_ws_error_points_to_onboard() {
        let config = zeroclaw_config::schema::Config::default();
        let frame = needs_onboarding_ws_error(&config)
            .expect("empty model must produce a WS onboarding error");

        assert_eq!(frame["type"], "error");
        assert_eq!(frame["error"], "needs_onboarding");
        assert_eq!(frame["code"], "NEEDS_ONBOARDING");
        assert_eq!(frame["url"], "/onboard");
        let message = frame["message"]
            .as_str()
            .expect("onboarding WS error must include a message");
        assert!(
            !message.starts_with('{') && !message.ends_with('}'),
            "missing Fluent key fallback leaked into WS error message: {message:?}"
        );
        assert!(
            message.to_lowercase().contains("quickstart"),
            "WS setup-gap message must explain the setup gap: {message:?}"
        );
    }

    // Once an OpenAI provider with a model and API key is configured, no
    // onboarding error is produced.
    #[test]
    fn needs_onboarding_ws_error_uses_current_configured_model() {
        let mut config = zeroclaw_config::schema::Config::default();
        config.providers.models.openai.insert(
            "default".to_string(),
            zeroclaw_config::schema::OpenAIModelProviderConfig {
                base: zeroclaw_config::schema::ModelProviderConfig {
                    model: Some("openai/gpt-4o-mini".to_string()),
                    api_key: Some("sk-test".to_string()),
                    ..Default::default()
                },
            },
        );

        assert!(
            needs_onboarding_ws_error(&config).is_none(),
            "current configured model must allow WebSocket agent construction to continue"
        );
    }

    /// Outcome of classifying a client frame received while a turn runs.
    ///
    /// NOTE(review): this enum and `classify_client_msg` are test-local. They
    /// model, rather than call, the handler's mid-turn receive logic, so
    /// confirm they stay in sync with the production loop.
    #[derive(Debug, PartialEq, Eq)]
    enum DisconnectAction {
        Break,
        Continue,
        ProcessText,
    }

    /// Text frames are processed. Close, a stream error, or end of stream
    /// (`None`) break the loop. Anything else (e.g. Ping) is ignored.
    fn classify_client_msg(
        msg: Option<Result<axum::extract::ws::Message, &'static str>>,
    ) -> DisconnectAction {
        use axum::extract::ws::Message;
        match msg {
            Some(Ok(Message::Text(_))) => DisconnectAction::ProcessText,
            Some(Ok(Message::Close(_))) | Some(Err(_)) | None => DisconnectAction::Break,
            _ => DisconnectAction::Continue,
        }
    }

    #[test]
    fn mid_turn_client_msg_breaks_on_stream_end_close_or_err() {
        use axum::extract::ws::Message;
        assert_eq!(classify_client_msg(None), DisconnectAction::Break);
        assert_eq!(
            classify_client_msg(Some(Ok(Message::Close(None)))),
            DisconnectAction::Break,
        );
        assert_eq!(
            classify_client_msg(Some(Err("io"))),
            DisconnectAction::Break,
        );
        assert_eq!(
            classify_client_msg(Some(Ok(Message::Ping(Default::default())))),
            DisconnectAction::Continue,
        );
        assert_eq!(
            classify_client_msg(Some(Ok(Message::Text("{}".into())))),
            DisconnectAction::ProcessText,
        );
    }

    // Only checks that a cloned `CancellationToken` observes `cancel()` on the
    // original. This is the property the disconnect path relies on to stop an
    // in-flight turn; the handler itself is not exercised here.
    #[test]
    fn mid_turn_disconnect_cancel_unblocks_joined_turn() {
        let token = tokio_util::sync::CancellationToken::new();
        let clone_for_turn = token.clone();
        assert!(!clone_for_turn.is_cancelled());
        token.cancel();
        assert!(
            clone_for_turn.is_cancelled(),
            "cloned token (held by turn_fut via agent.turn_streamed) must observe cancellation"
        );
    }

    // Each session-queue failure maps to its own WebSocket error code.
    #[test]
    fn session_queue_errors_map_to_explicit_websocket_codes() {
        use crate::session_queue::SessionQueueError;

        assert_eq!(
            session_queue_ws_error_code(&SessionQueueError::QueueFull {
                session_id: "gw_test".into(),
                depth: 2,
            }),
            "SESSION_QUEUE_FULL"
        );
        assert_eq!(
            session_queue_ws_error_code(&SessionQueueError::Timeout {
                session_id: "gw_test".into(),
            }),
            "SESSION_QUEUE_TIMEOUT"
        );
    }

    /// Stub session backend that reports every session as nonexistent
    /// (`session_exists` returns `false`) and records every `append` call, so
    /// tests can assert that nothing was written.
    struct DeletedSessionBackend {
        // One "key:role:content" entry per `append` call.
        append_calls: std::sync::Mutex<Vec<String>>,
    }

    impl zeroclaw_infra::session_backend::SessionBackend for DeletedSessionBackend {
        fn load(&self, _session_key: &str) -> Vec<zeroclaw_providers::ChatMessage> {
            Vec::new()
        }
        fn append(
            &self,
            session_key: &str,
            message: &zeroclaw_providers::ChatMessage,
        ) -> std::io::Result<()> {
            self.append_calls.lock().unwrap().push(format!(
                "{}:{}:{}",
                session_key, message.role, message.content
            ));
            Ok(())
        }
        fn remove_last(&self, _session_key: &str) -> std::io::Result<bool> {
            Ok(false)
        }
        fn list_sessions(&self) -> Vec<String> {
            Vec::new()
        }
        fn session_exists(&self, _session_key: &str) -> bool {
            // Simulates a session deleted while a turn was in flight.
            false
        }
    }

    // Regression test for #7126: persisting a turn's messages must not
    // recreate a session whose backend reports it no longer exists.
    #[test]
    fn persist_conversation_messages_skips_deleted_session() {
        use zeroclaw_providers::{ChatMessage, ConversationMessage};
        let backend = DeletedSessionBackend {
            append_calls: std::sync::Mutex::new(Vec::new()),
        };
        let messages = vec![
            ConversationMessage::Chat(ChatMessage::user("hi")),
            ConversationMessage::Chat(ChatMessage::assistant("[interrupted by user]")),
        ];

        persist_conversation_messages(&backend, "gw_deleted", &messages);

        assert!(
            backend.append_calls.lock().unwrap().is_empty(),
            "persist_conversation_messages must not resurrect a session whose \
             session_exists() returned false (see #7126)"
        );
    }
}