1use crate::agent::agent::{Agent, StreamedTurnError, StreamedTurnSuccess, TurnEvent};
4use crate::agent::loop_::is_tool_loop_cancelled;
5use std::sync::Arc;
6use tokio::sync::{Mutex, mpsc};
7use tokio_util::sync::CancellationToken;
8use zeroclaw_api::model_provider::ConversationMessage;
9
/// Result of a turn that ended without an error.
pub enum TurnOutcome {
    /// The agent finished the turn and returned its result.
    Completed {
        /// Response text returned by the agent.
        text: String,
        /// New conversation messages the agent reported for this turn.
        messages: Vec<ConversationMessage>,
    },
    /// The turn stopped because it was cancelled.
    Cancelled {
        /// The task's committed response when it reported a non-empty one,
        /// otherwise the chunk text streamed before the cancel.
        partial_text: String,
        /// Messages the task reported before stopping; empty when the task had
        /// to be aborted after the cancel grace period.
        messages: Vec<ConversationMessage>,
    },
}
20
/// Failure of a turn that is not a cooperative cancellation.
#[derive(Debug)]
pub enum TurnError {
    /// Joining the spawned turn task failed; carries the join error's text.
    Panicked(String),
    /// The agent returned an error other than a tool-loop cancellation; carries
    /// the error's display text.
    AgentError(String),
}

impl std::fmt::Display for TurnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both variants render as "<context>: <detail>"; only the context differs.
        let (context, detail) = match self {
            TurnError::Panicked(detail) => ("Turn task panicked", detail),
            TurnError::AgentError(detail) => ("Agent turn failed", detail),
        };
        write!(f, "{context}: {detail}")
    }
}

impl std::error::Error for TurnError {}
37
/// Attribution fields that `execute_turn` attaches to a turn.
///
/// `session_key` scopes the turn via `scope_session_key`. Every field is also
/// recorded on the turn's `zeroclaw_scope` tracing span.
/// `Debug` is derived so the attribution can be logged or inspected directly.
#[derive(Clone, Debug, Default)]
pub struct TurnAttribution {
    /// Session scope for the turn; logged as `""` on the span when absent.
    pub session_key: Option<String>,
    /// Recorded as `agent_alias` on the turn's tracing span.
    pub agent_alias: String,
    /// Recorded as `model_provider` on the turn's tracing span.
    pub model_provider: String,
    /// Recorded as `model` on the turn's tracing span.
    pub model: String,
    /// Recorded as `channel` on the turn's tracing span.
    pub channel: &'static str,
}
48
49pub async fn execute_turn<F, Fut>(
50 agent: Arc<Mutex<Agent>>,
51 prompt: String,
52 cancel: CancellationToken,
53 attribution: TurnAttribution,
54 on_event: F,
55) -> Result<TurnOutcome, TurnError>
56where
57 F: Fn(TurnEvent) -> Fut + Send + 'static,
58 Fut: std::future::Future<Output = ()> + Send,
59{
60 let (event_tx, mut event_rx) = mpsc::channel::<TurnEvent>(64);
61 let cancel_clone = cancel.clone();
62 let session_key = attribution.session_key.clone();
63
64 let mut turn_handle = zeroclaw_spawn::spawn!(async move {
65 let mut guard = agent.lock().await;
66 let sk = attribution.session_key.clone();
67 crate::agent::loop_::scope_session_key(attribution.session_key, async move {
68 use ::zeroclaw_log::Instrument as _;
69 let span = ::zeroclaw_log::info_span!(
70 target: "zeroclaw_log_internal_scope",
71 "zeroclaw_scope",
72 session_key = %sk.as_deref().unwrap_or(""),
73 agent_alias = %attribution.agent_alias,
74 model_provider = %attribution.model_provider,
75 model = %attribution.model,
76 channel = %attribution.channel,
77 );
78 guard
79 .turn_streamed_with_steering_state(&prompt, event_tx, Some(cancel_clone), None)
80 .instrument(span)
81 .await
82 })
83 .await
84 });
85
86 let mut accumulated_text = String::new();
87
88 let drain =
94 drain_until_done_or_cancelled(&mut event_rx, &cancel, &mut accumulated_text, &on_event)
95 .await;
96 let _ = session_key; match drain {
99 DrainOutcome::Completed => {
100 let joined = turn_handle
101 .await
102 .map_err(|e| TurnError::Panicked(format!("{e}")))?;
103 outcome_from_task_result(joined, accumulated_text)
104 }
105 DrainOutcome::ExplicitCancel => {
106 match tokio::time::timeout(CANCEL_GRACE, &mut turn_handle).await {
115 Ok(joined) => outcome_from_task_result(
116 joined.map_err(|e| TurnError::Panicked(format!("{e}")))?,
117 accumulated_text,
118 ),
119 Err(_) => {
120 turn_handle.abort();
121 Ok(TurnOutcome::Cancelled {
122 partial_text: accumulated_text,
123 messages: Vec::new(),
124 })
125 }
126 }
127 }
128 }
129}
130
/// How long `execute_turn` waits after an explicit cancel for the turn task to
/// finish on its own before aborting it (5 seconds).
const CANCEL_GRACE: std::time::Duration = std::time::Duration::from_millis(5_000);
135
136fn outcome_from_task_result(
140 joined: Result<StreamedTurnSuccess, StreamedTurnError>,
141 accumulated_text: String,
142) -> Result<TurnOutcome, TurnError> {
143 match joined {
144 Ok(StreamedTurnSuccess {
145 response,
146 new_messages,
147 }) => Ok(TurnOutcome::Completed {
148 text: response,
149 messages: new_messages,
150 }),
151 Err(StreamedTurnError {
152 error,
153 committed_response,
154 new_messages,
155 }) if is_tool_loop_cancelled(&error) => Ok(TurnOutcome::Cancelled {
156 partial_text: if committed_response.is_empty() {
157 accumulated_text
158 } else {
159 committed_response
160 },
161 messages: new_messages,
162 }),
163 Err(StreamedTurnError { error, .. }) => Err(TurnError::AgentError(format!("{error}"))),
164 }
165}
166
/// Why `drain_until_done_or_cancelled` stopped reading events.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DrainOutcome {
    /// The event channel closed (`recv` returned `None`) before any cancel.
    Completed,
    /// The cancellation token fired; the channel may still be open.
    ExplicitCancel,
}
176
177async fn drain_until_done_or_cancelled<F, Fut>(
185 event_rx: &mut mpsc::Receiver<TurnEvent>,
186 cancel: &CancellationToken,
187 accumulated: &mut String,
188 on_event: &F,
189) -> DrainOutcome
190where
191 F: Fn(TurnEvent) -> Fut,
192 Fut: std::future::Future<Output = ()>,
193{
194 loop {
195 if cancel.is_cancelled() {
196 return DrainOutcome::ExplicitCancel;
197 }
198 tokio::select! {
199 biased;
200 _ = cancel.cancelled() => return DrainOutcome::ExplicitCancel,
201 maybe_event = event_rx.recv() => {
202 match maybe_event {
203 Some(event) => {
204 if let TurnEvent::Chunk { ref delta } = event {
205 accumulated.push_str(delta);
206 }
207 on_event(event).await;
208 }
209 None => return DrainOutcome::Completed,
210 }
211 }
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Event sink that ignores every event.
    fn noop(_e: TurnEvent) -> std::future::Ready<()> {
        std::future::ready(())
    }

    #[tokio::test]
    async fn drain_must_not_idle_cancel_a_live_turn_across_a_long_tool_gap() {
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(8);
        let cancel = CancellationToken::new();
        let mut acc = String::new();

        let sender = zeroclaw_spawn::spawn!(async move {
            let _ = tx
                .send(TurnEvent::ToolCall {
                    id: "c1".to_string(),
                    name: "shell".to_string(),
                    args: serde_json::json!({ "command": "cargo test" }),
                })
                .await;
            // Simulates a long, silent tool execution while the sender stays alive.
            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            let _ = tx
                .send(TurnEvent::ToolResult {
                    id: "c1".to_string(),
                    name: "shell".to_string(),
                    output: "ok".to_string(),
                })
                .await;
            let _ = tx
                .send(TurnEvent::Chunk {
                    delta: "done".to_string(),
                })
                .await;
        });

        let outcome = tokio::time::timeout(
            std::time::Duration::from_secs(15),
            drain_until_done_or_cancelled(&mut rx, &cancel, &mut acc, &noop),
        )
        .await
        .expect("drain must terminate when the live turn task completes");

        sender.await.unwrap();
        assert_eq!(
            outcome,
            DrainOutcome::Completed,
            "a turn whose sender is alive but quiet during a long tool \
             execution is NOT stalled; silence during execute_tools is the \
             normal case. Killing it is the idle_stall regression that froze \
             the TUI mid-turn (sessions 102, 103)."
        );
        assert!(
            !cancel.is_cancelled(),
            "drain self-cancelled a healthy turn across a tool gap; the token \
             must stay clean so downstream records no cancel."
        );
        assert_eq!(
            acc, "done",
            "drain dropped the post-tool chunk after wrongly tripping an idle \
             bound mid-execution."
        );
    }

    #[tokio::test]
    async fn drain_must_still_accumulate_chunks_when_events_arrive_steadily() {
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(8);
        let cancel = CancellationToken::new();
        let mut acc = String::new();

        let sender = zeroclaw_spawn::spawn!(async move {
            for delta in ["he", "llo", " ", "world"] {
                let _ = tx
                    .send(TurnEvent::Chunk {
                        delta: delta.to_string(),
                    })
                    .await;
                tokio::time::sleep(std::time::Duration::from_millis(250)).await;
            }
        });

        let outcome = tokio::time::timeout(
            std::time::Duration::from_secs(10),
            drain_until_done_or_cancelled(&mut rx, &cancel, &mut acc, &noop),
        )
        .await
        .expect("drain must terminate after the sender drops");

        sender.await.unwrap();
        assert_eq!(
            outcome,
            DrainOutcome::Completed,
            "channel closure is not a cancel; drain returned the wrong verdict"
        );
        assert_eq!(
            acc, "hello world",
            "drain dropped chunks instead of accumulating them; a fix that \
             short-circuits with too-aggressive an idle window (e.g. <250ms) \
             would corrupt legitimate streaming turns. The production idle \
             window must sit comfortably between the inter-chunk gap of a \
             healthy stream (~hundreds of ms) and the user-perceptible hang \
             threshold (~seconds)."
        );
    }

    #[tokio::test]
    async fn drain_stops_on_cancel_while_sender_is_alive_and_silent() {
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(8);
        let cancel = CancellationToken::new();
        let mut acc = String::new();

        let trigger = cancel.clone();
        let canceller = zeroclaw_spawn::spawn!(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            trigger.cancel();
        });

        let outcome = tokio::time::timeout(
            std::time::Duration::from_secs(5),
            drain_until_done_or_cancelled(&mut rx, &cancel, &mut acc, &noop),
        )
        .await
        .expect("drain must return promptly once the token is cancelled");

        canceller.await.unwrap();
        // Held until now so channel closure cannot be what ended the drain.
        drop(tx);
        assert_eq!(
            outcome,
            DrainOutcome::ExplicitCancel,
            "an explicit cancel must end the drain even while the sender is open"
        );
        assert!(acc.is_empty(), "no chunks were sent, so nothing may accumulate");
    }

    #[tokio::test]
    async fn drain_prefers_cancel_over_already_queued_events() {
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(8);
        let cancel = CancellationToken::new();
        let mut acc = String::new();

        assert!(
            tx.send(TurnEvent::Chunk {
                delta: "late".to_string(),
            })
            .await
            .is_ok(),
            "receiver is alive, so the send must succeed"
        );
        cancel.cancel();

        let outcome = drain_until_done_or_cancelled(&mut rx, &cancel, &mut acc, &noop).await;

        drop(tx);
        assert_eq!(
            outcome,
            DrainOutcome::ExplicitCancel,
            "a token cancelled before draining must win over queued events"
        );
        assert!(
            acc.is_empty(),
            "drain consumed a queued chunk after the token was already cancelled"
        );
    }

    #[test]
    fn cancel_outcome_carries_committed_messages_not_just_partial_text() {
        let msgs = vec![ConversationMessage::Chat(
            zeroclaw_providers::ChatMessage::assistant("[interrupted by user]"),
        )];
        let err = StreamedTurnError {
            error: crate::agent::loop_::ToolLoopCancelled.into(),
            committed_response: "partial".to_string(),
            new_messages: msgs.clone(),
        };

        let outcome = outcome_from_task_result(Err(err), "accumulated".to_string())
            .expect("cooperative cancel maps to a Cancelled outcome, not an error");

        match outcome {
            TurnOutcome::Cancelled {
                partial_text,
                messages,
            } => {
                assert_eq!(
                    partial_text, "partial",
                    "committed_response from the task must win over the drain's \
                     accumulated text when present"
                );
                assert_eq!(
                    messages.len(),
                    msgs.len(),
                    "cancelled outcome dropped the messages the task committed"
                );
            }
            TurnOutcome::Completed { .. } => {
                panic!("a tool-loop cancel must not map to Completed")
            }
        }
    }

    #[test]
    fn non_cancel_agent_error_stays_an_error() {
        let err = StreamedTurnError {
            error: anyhow::Error::msg("provider exploded"),
            committed_response: String::new(),
            new_messages: Vec::new(),
        };
        let outcome = outcome_from_task_result(Err(err), String::new());
        assert!(
            matches!(outcome, Err(TurnError::AgentError(_))),
            "a genuine agent failure must surface as an error, not a silent \
             cancel"
        );
    }
}
378}