Skip to main content

zeroclaw_runtime/rpc/
turn.rs

1//! Shared turn execution. Single source of truth for spawn-drain-cancel.
2
3use crate::agent::agent::{Agent, StreamedTurnError, StreamedTurnSuccess, TurnEvent};
4use crate::agent::loop_::is_tool_loop_cancelled;
5use std::sync::Arc;
6use tokio::sync::{Mutex, mpsc};
7use tokio_util::sync::CancellationToken;
8use zeroclaw_api::model_provider::ConversationMessage;
9
10pub enum TurnOutcome {
11    Completed {
12        text: String,
13        messages: Vec<ConversationMessage>,
14    },
15    Cancelled {
16        partial_text: String,
17        messages: Vec<ConversationMessage>,
18    },
19}
20
/// Failure of [`execute_turn`] that is not a (cooperative) cancellation.
#[derive(Debug)]
pub enum TurnError {
    /// The spawned turn task could not be joined (for example, it panicked).
    /// Carries the rendered join error.
    Panicked(String),
    /// The agent turn returned an error other than a tool-loop cancellation.
    /// Carries the rendered error message.
    AgentError(String),
}

impl std::fmt::Display for TurnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TurnError::Panicked(detail) => write!(f, "Turn task panicked: {detail}"),
            TurnError::AgentError(detail) => write!(f, "Agent turn failed: {detail}"),
        }
    }
}

// No `source()`: both variants already carry a fully rendered message.
impl std::error::Error for TurnError {}
37
/// Attribution fields attached to the tracing span for the duration of a turn.
/// All fields appear on every `record!()` emitted inside the turn.
#[derive(Clone, Debug, Default)]
pub struct TurnAttribution {
    /// Session the turn belongs to, if any. It is also installed as the scoped
    /// session key for the turn task. It is recorded on the span as an empty
    /// string when `None`.
    pub session_key: Option<String>,
    /// Agent alias, recorded on the span as `agent_alias`.
    pub agent_alias: String,
    /// Model provider name, recorded on the span as `model_provider`.
    pub model_provider: String,
    /// Model identifier, recorded on the span as `model`.
    pub model: String,
    /// Channel label, recorded on the span as `channel`. Defaults to `""`.
    pub channel: &'static str,
}
48
49pub async fn execute_turn<F, Fut>(
50    agent: Arc<Mutex<Agent>>,
51    prompt: String,
52    cancel: CancellationToken,
53    attribution: TurnAttribution,
54    on_event: F,
55) -> Result<TurnOutcome, TurnError>
56where
57    F: Fn(TurnEvent) -> Fut + Send + 'static,
58    Fut: std::future::Future<Output = ()> + Send,
59{
60    let (event_tx, mut event_rx) = mpsc::channel::<TurnEvent>(64);
61    let cancel_clone = cancel.clone();
62    let session_key = attribution.session_key.clone();
63
64    let mut turn_handle = zeroclaw_spawn::spawn!(async move {
65        let mut guard = agent.lock().await;
66        let sk = attribution.session_key.clone();
67        crate::agent::loop_::scope_session_key(attribution.session_key, async move {
68            use ::zeroclaw_log::Instrument as _;
69            let span = ::zeroclaw_log::info_span!(
70                target: "zeroclaw_log_internal_scope",
71                "zeroclaw_scope",
72                session_key = %sk.as_deref().unwrap_or(""),
73                agent_alias = %attribution.agent_alias,
74                model_provider = %attribution.model_provider,
75                model = %attribution.model,
76                channel = %attribution.channel,
77            );
78            guard
79                .turn_streamed_with_steering_state(&prompt, event_tx, Some(cancel_clone), None)
80                .instrument(span)
81                .await
82        })
83        .await
84    });
85
86    let mut accumulated_text = String::new();
87
88    // Drive the turn by draining its event channel, but never let a turn task
89    // wedged inside a non-cancellable tool call (shell, HTTP, a stalled provider
90    // stream) hold the dispatch path hostage. The drain exits on channel close,
91    // explicit cancel, OR an idle-stall bound; the latter two return Cancelled
92    // and the in-flight task is aborted on drop.
93    let drain =
94        drain_until_done_or_cancelled(&mut event_rx, &cancel, &mut accumulated_text, &on_event)
95            .await;
96    let _ = session_key; // consumed above
97
98    match drain {
99        DrainOutcome::Completed => {
100            let joined = turn_handle
101                .await
102                .map_err(|e| TurnError::Panicked(format!("{e}")))?;
103            outcome_from_task_result(joined, accumulated_text)
104        }
105        DrainOutcome::ExplicitCancel => {
106            // The turn task races the same cancel token and unwinds
107            // cooperatively: it synthesizes results for any in-flight tool
108            // call, pushes the `[interrupted]` assistant message, and commits
109            // both into the agent history before returning. Persistence reads
110            // that committed history, so aborting the task mid-commit drops
111            // the cancelled turn's tool exchange and corrupts the next turn.
112            // Give the task a bounded grace window to land its own unwind;
113            // only abort if it is genuinely wedged in a non-cooperative call.
114            match tokio::time::timeout(CANCEL_GRACE, &mut turn_handle).await {
115                Ok(joined) => outcome_from_task_result(
116                    joined.map_err(|e| TurnError::Panicked(format!("{e}")))?,
117                    accumulated_text,
118                ),
119                Err(_) => {
120                    turn_handle.abort();
121                    Ok(TurnOutcome::Cancelled {
122                        partial_text: accumulated_text,
123                        messages: Vec::new(),
124                    })
125                }
126            }
127        }
128    }
129}
130
/// How long a cancelled turn task may keep running to commit its cooperative
/// unwind into the agent history before the dispatch path aborts it. The
/// unwind consists of synthesized tool results plus the `[interrupted]`
/// message.
const CANCEL_GRACE: std::time::Duration = std::time::Duration::from_millis(5_000);
135
136/// Map a finished turn task into a [`TurnOutcome`]. A successful turn yields
137/// `Completed`; a cooperative cancel yields `Cancelled` carrying the messages
138/// the task committed so persistence never depends on the abort/commit race.
139fn outcome_from_task_result(
140    joined: Result<StreamedTurnSuccess, StreamedTurnError>,
141    accumulated_text: String,
142) -> Result<TurnOutcome, TurnError> {
143    match joined {
144        Ok(StreamedTurnSuccess {
145            response,
146            new_messages,
147        }) => Ok(TurnOutcome::Completed {
148            text: response,
149            messages: new_messages,
150        }),
151        Err(StreamedTurnError {
152            error,
153            committed_response,
154            new_messages,
155        }) if is_tool_loop_cancelled(&error) => Ok(TurnOutcome::Cancelled {
156            partial_text: if committed_response.is_empty() {
157                accumulated_text
158            } else {
159                committed_response
160            },
161            messages: new_messages,
162        }),
163        Err(StreamedTurnError { error, .. }) => Err(TurnError::AgentError(format!("{error}"))),
164    }
165}
166
/// Reason [`drain_until_done_or_cancelled`] stopped.
///
/// Silence on the event channel is deliberately not a variant. A live turn
/// emits nothing for the whole duration of a tool call, so there is no
/// self-firing idle exit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum DrainOutcome {
    /// The turn task dropped its event sender, so the channel is closed.
    Completed,
    /// A cancel fired from outside (client RPC, reaper, session removal) and
    /// reached the drain.
    ExplicitCancel,
}
176
177/// Drain `event_rx` until the turn finishes or the cancel token fires. Chunk
178/// deltas accumulate in `accumulated` so partial text survives a cancel. The
179/// only terminals are the turn task dropping its sender (`recv` -> `None`,
180/// [`DrainOutcome::Completed`]) and an explicit cancel
181/// ([`DrainOutcome::ExplicitCancel`]). A wedged turn is bounded by the explicit
182/// layers — ownership-gated `session/cancel` and the reaper — never by guessing
183/// from channel quiet.
184async fn drain_until_done_or_cancelled<F, Fut>(
185    event_rx: &mut mpsc::Receiver<TurnEvent>,
186    cancel: &CancellationToken,
187    accumulated: &mut String,
188    on_event: &F,
189) -> DrainOutcome
190where
191    F: Fn(TurnEvent) -> Fut,
192    Fut: std::future::Future<Output = ()>,
193{
194    loop {
195        if cancel.is_cancelled() {
196            return DrainOutcome::ExplicitCancel;
197        }
198        tokio::select! {
199            biased;
200            _ = cancel.cancelled() => return DrainOutcome::ExplicitCancel,
201            maybe_event = event_rx.recv() => {
202                match maybe_event {
203                    Some(event) => {
204                        if let TurnEvent::Chunk { ref delta } = event {
205                            accumulated.push_str(delta);
206                        }
207                        on_event(event).await;
208                    }
209                    None => return DrainOutcome::Completed,
210                }
211            }
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// `on_event` stand-in that ignores every event.
    fn noop(_e: TurnEvent) -> std::future::Ready<()> {
        std::future::ready(())
    }

    // Regression: a live turn is silent for the whole duration of a tool call.
    // The drain must wait it out and report `Completed` once the sender drops,
    // never self-cancel on quiet. The 5 s gap is real wall-clock time.
    #[tokio::test]
    async fn drain_must_not_idle_cancel_a_live_turn_across_a_long_tool_gap() {
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(8);
        let cancel = CancellationToken::new();
        let mut acc = String::new();

        let sender = zeroclaw_spawn::spawn!(async move {
            let _ = tx
                .send(TurnEvent::ToolCall {
                    id: "c1".to_string(),
                    name: "shell".to_string(),
                    args: serde_json::json!({ "command": "cargo test" }),
                })
                .await;
            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            let _ = tx
                .send(TurnEvent::ToolResult {
                    id: "c1".to_string(),
                    name: "shell".to_string(),
                    output: "ok".to_string(),
                })
                .await;
            let _ = tx
                .send(TurnEvent::Chunk {
                    delta: "done".to_string(),
                })
                .await;
        });

        let outcome = tokio::time::timeout(
            std::time::Duration::from_secs(15),
            drain_until_done_or_cancelled(&mut rx, &cancel, &mut acc, &noop),
        )
        .await
        .expect("drain must terminate when the live turn task completes");

        sender.await.unwrap();
        assert_eq!(
            outcome,
            DrainOutcome::Completed,
            "a turn whose sender is alive but quiet during a long tool \
             execution is NOT stalled; silence during execute_tools is the \
             normal case. Killing it is the idle_stall regression that froze \
             the TUI mid-turn (sessions 102, 103)."
        );
        assert!(
            !cancel.is_cancelled(),
            "drain self-cancelled a healthy turn across a tool gap; the token \
             must stay clean so downstream records no cancel."
        );
        assert_eq!(
            acc, "done",
            "drain dropped the post-tool chunk after wrongly tripping an idle \
             bound mid-execution."
        );
    }

    // Steady streaming: every chunk is accumulated in order, and channel
    // closure maps to `Completed`, not a cancel.
    #[tokio::test]
    async fn drain_must_still_accumulate_chunks_when_events_arrive_steadily() {
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(8);
        let cancel = CancellationToken::new();
        let mut acc = String::new();

        let sender = zeroclaw_spawn::spawn!(async move {
            for delta in ["he", "llo", " ", "world"] {
                let _ = tx
                    .send(TurnEvent::Chunk {
                        delta: delta.to_string(),
                    })
                    .await;
                tokio::time::sleep(std::time::Duration::from_millis(250)).await;
            }
        });

        let cancelled = tokio::time::timeout(
            std::time::Duration::from_secs(10),
            drain_until_done_or_cancelled(&mut rx, &cancel, &mut acc, &noop),
        )
        .await
        .expect("drain must terminate after the sender drops");

        sender.await.unwrap();
        assert_eq!(
            cancelled,
            DrainOutcome::Completed,
            "channel closure is not a cancel; drain returned the wrong verdict"
        );
        assert_eq!(
            acc, "hello world",
            "drain dropped chunks instead of accumulating them; every chunk of \
             a healthy stream must be accumulated in arrival order regardless \
             of the gap between chunks."
        );
    }

    // An outside cancel must end the drain even though the sender is still
    // alive and silent. Nothing was streamed, so nothing is accumulated.
    #[tokio::test]
    async fn drain_returns_explicit_cancel_when_token_fires_while_sender_is_alive() {
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(8);
        let cancel = CancellationToken::new();
        let mut acc = String::new();

        let canceller = {
            let cancel = cancel.clone();
            zeroclaw_spawn::spawn!(async move {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                cancel.cancel();
            })
        };

        let outcome = tokio::time::timeout(
            std::time::Duration::from_secs(5),
            drain_until_done_or_cancelled(&mut rx, &cancel, &mut acc, &noop),
        )
        .await
        .expect("drain must observe an explicit cancel while the sender is alive");

        canceller.await.unwrap();
        drop(tx);
        assert_eq!(
            outcome,
            DrainOutcome::ExplicitCancel,
            "an outside cancel must end the drain with ExplicitCancel, not wait \
             for the sender to drop"
        );
        assert!(acc.is_empty(), "no chunk was sent, so nothing may be accumulated");
    }

    // The drain checks the token before receiving and polls it first
    // (`biased`), so an already-fired cancel wins over events still buffered
    // in the channel.
    #[tokio::test]
    async fn drain_prefers_a_pending_cancel_over_buffered_events() {
        let (tx, mut rx) = mpsc::channel::<TurnEvent>(8);
        assert!(
            tx.send(TurnEvent::Chunk {
                delta: "late".to_string(),
            })
            .await
            .is_ok(),
            "receiver is alive, so buffering one event must succeed"
        );
        let cancel = CancellationToken::new();
        cancel.cancel();
        let mut acc = String::new();

        let outcome = drain_until_done_or_cancelled(&mut rx, &cancel, &mut acc, &noop).await;

        drop(tx);
        assert_eq!(
            outcome,
            DrainOutcome::ExplicitCancel,
            "a pending cancel must win over buffered events"
        );
        assert!(
            acc.is_empty(),
            "drain consumed a buffered chunk after the cancel had already fired"
        );
    }

    #[test]
    fn cancel_outcome_carries_committed_messages_not_just_partial_text() {
        // A cooperative cancel returns StreamedTurnError whose new_messages
        // hold the synthesized tool results + `[interrupted]` message the task
        // already committed. The mapping must surface them, not drop them onto
        // the floor and fall back to bare accumulated text — that drop is what
        // truncated the cancelled turn's tool exchange from persisted history.
        let msgs = vec![ConversationMessage::Chat(
            zeroclaw_providers::ChatMessage::assistant("[interrupted by user]"),
        )];
        let err = StreamedTurnError {
            error: crate::agent::loop_::ToolLoopCancelled.into(),
            committed_response: "partial".to_string(),
            new_messages: msgs.clone(),
        };

        let outcome = outcome_from_task_result(Err(err), "accumulated".to_string())
            .expect("cooperative cancel maps to a Cancelled outcome, not an error");

        match outcome {
            TurnOutcome::Cancelled {
                partial_text,
                messages,
            } => {
                assert_eq!(
                    partial_text, "partial",
                    "committed_response from the task must win over the drain's \
                     accumulated text when present"
                );
                assert_eq!(
                    messages.len(),
                    msgs.len(),
                    "cancelled outcome dropped the messages the task committed"
                );
            }
            TurnOutcome::Completed { .. } => {
                panic!("a tool-loop cancel must not map to Completed")
            }
        }
    }

    #[test]
    fn non_cancel_agent_error_stays_an_error() {
        let err = StreamedTurnError {
            error: anyhow::Error::msg("provider exploded"),
            committed_response: String::new(),
            new_messages: Vec::new(),
        };
        let outcome = outcome_from_task_result(Err(err), String::new());
        assert!(
            matches!(outcome, Err(TurnError::AgentError(_))),
            "a genuine agent failure must surface as an error, not a silent \
             cancel"
        );
    }
}