Skip to main content

zeroclaw_runtime/hooks/
runner.rs

1use std::time::Duration;
2
3use futures_util::{FutureExt, future::join_all};
4use serde_json::Value;
5use std::panic::AssertUnwindSafe;
6
7use zeroclaw_api::channel::ChannelMessage;
8use zeroclaw_api::model_provider::{ChatMessage, ChatResponse};
9use zeroclaw_api::tool::ToolResult;
10
11use super::traits::{HookHandler, HookResult};
12
13/// Dispatcher that manages registered hook handlers.
14///
15/// Void hooks are dispatched in parallel via `join_all`.
16/// Modifying hooks run sequentially by priority (higher first), piping output
17/// and short-circuiting on `Cancel`.
18pub struct HookRunner {
19    handlers: Vec<Box<dyn HookHandler>>,
20}
21
22impl Default for HookRunner {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl HookRunner {
29    /// Create an empty runner with no handlers.
30    pub fn new() -> Self {
31        Self {
32            handlers: Vec::new(),
33        }
34    }
35
36    /// Register a handler and re-sort by descending priority.
37    pub fn register(&mut self, handler: Box<dyn HookHandler>) {
38        self.handlers.push(handler);
39        self.handlers
40            .sort_by_key(|h| std::cmp::Reverse(h.priority()));
41    }
42
43    // ---------------------------------------------------------------
44    // Void dispatchers (parallel, fire-and-forget)
45    // ---------------------------------------------------------------
46
47    pub async fn fire_gateway_start(&self, host: &str, port: u16) {
48        let futs: Vec<_> = self
49            .handlers
50            .iter()
51            .map(|h| h.on_gateway_start(host, port))
52            .collect();
53        join_all(futs).await;
54    }
55
56    pub async fn fire_gateway_stop(&self) {
57        let futs: Vec<_> = self.handlers.iter().map(|h| h.on_gateway_stop()).collect();
58        join_all(futs).await;
59    }
60
61    pub async fn fire_session_start(&self, session_id: &str, channel: &str) {
62        let futs: Vec<_> = self
63            .handlers
64            .iter()
65            .map(|h| h.on_session_start(session_id, channel))
66            .collect();
67        join_all(futs).await;
68    }
69
70    pub async fn fire_session_end(&self, session_id: &str, channel: &str) {
71        let futs: Vec<_> = self
72            .handlers
73            .iter()
74            .map(|h| h.on_session_end(session_id, channel))
75            .collect();
76        join_all(futs).await;
77    }
78
79    pub async fn fire_llm_input(&self, messages: &[ChatMessage], model: &str) {
80        let futs: Vec<_> = self
81            .handlers
82            .iter()
83            .map(|h| h.on_llm_input(messages, model))
84            .collect();
85        join_all(futs).await;
86    }
87
88    pub async fn fire_llm_output(&self, response: &ChatResponse) {
89        let futs: Vec<_> = self
90            .handlers
91            .iter()
92            .map(|h| h.on_llm_output(response))
93            .collect();
94        join_all(futs).await;
95    }
96
97    pub async fn fire_after_tool_call(&self, tool: &str, result: &ToolResult, duration: Duration) {
98        let futs: Vec<_> = self
99            .handlers
100            .iter()
101            .map(|h| h.on_after_tool_call(tool, result, duration))
102            .collect();
103        join_all(futs).await;
104    }
105
106    pub async fn fire_message_sent(&self, channel: &str, recipient: &str, content: &str) {
107        let futs: Vec<_> = self
108            .handlers
109            .iter()
110            .map(|h| h.on_message_sent(channel, recipient, content))
111            .collect();
112        join_all(futs).await;
113    }
114
115    pub async fn fire_heartbeat_tick(&self) {
116        let futs: Vec<_> = self
117            .handlers
118            .iter()
119            .map(|h| h.on_heartbeat_tick())
120            .collect();
121        join_all(futs).await;
122    }
123
124    // ---------------------------------------------------------------
125    // Modifying dispatchers (sequential by priority, short-circuit on Cancel)
126    // ---------------------------------------------------------------
127
128    pub async fn run_before_model_resolve(
129        &self,
130        mut model_provider: String,
131        mut model: String,
132    ) -> HookResult<(String, String)> {
133        for h in &self.handlers {
134            let hook_name = h.name();
135            match AssertUnwindSafe(h.before_model_resolve(model_provider.clone(), model.clone()))
136                .catch_unwind()
137                .await
138            {
139                Ok(HookResult::Continue((p, m))) => {
140                    model_provider = p;
141                    model = m;
142                }
143                Ok(HookResult::Cancel(reason)) => {
144                    ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "before_model_resolve cancelled by hook");
145                    return HookResult::Cancel(reason);
146                }
147                Err(_) => {
148                    ::zeroclaw_log::record!(
149                        ERROR,
150                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
151                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
152                            .with_attrs(::serde_json::json!({"hook": hook_name})),
153                        "before_model_resolve hook panicked; continuing with previous values"
154                    );
155                }
156            }
157        }
158        HookResult::Continue((model_provider, model))
159    }
160
161    pub async fn run_before_prompt_build(&self, mut prompt: String) -> HookResult<String> {
162        for h in &self.handlers {
163            let hook_name = h.name();
164            match AssertUnwindSafe(h.before_prompt_build(prompt.clone()))
165                .catch_unwind()
166                .await
167            {
168                Ok(HookResult::Continue(p)) => prompt = p,
169                Ok(HookResult::Cancel(reason)) => {
170                    ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "before_prompt_build cancelled by hook");
171                    return HookResult::Cancel(reason);
172                }
173                Err(_) => {
174                    ::zeroclaw_log::record!(
175                        ERROR,
176                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
177                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
178                            .with_attrs(::serde_json::json!({"hook": hook_name})),
179                        "before_prompt_build hook panicked; continuing with previous value"
180                    );
181                }
182            }
183        }
184        HookResult::Continue(prompt)
185    }
186
187    pub async fn run_before_llm_call(
188        &self,
189        mut messages: Vec<ChatMessage>,
190        mut model: String,
191    ) -> HookResult<(Vec<ChatMessage>, String)> {
192        for h in &self.handlers {
193            let hook_name = h.name();
194            match AssertUnwindSafe(h.before_llm_call(messages.clone(), model.clone()))
195                .catch_unwind()
196                .await
197            {
198                Ok(HookResult::Continue((m, mdl))) => {
199                    messages = m;
200                    model = mdl;
201                }
202                Ok(HookResult::Cancel(reason)) => {
203                    ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "before_llm_call cancelled by hook");
204                    return HookResult::Cancel(reason);
205                }
206                Err(_) => {
207                    ::zeroclaw_log::record!(
208                        ERROR,
209                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
210                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
211                            .with_attrs(::serde_json::json!({"hook": hook_name})),
212                        "before_llm_call hook panicked; continuing with previous values"
213                    );
214                }
215            }
216        }
217        HookResult::Continue((messages, model))
218    }
219
220    pub async fn run_before_tool_call(
221        &self,
222        mut name: String,
223        mut args: Value,
224    ) -> HookResult<(String, Value)> {
225        for h in &self.handlers {
226            let hook_name = h.name();
227            match AssertUnwindSafe(h.before_tool_call(name.clone(), args.clone()))
228                .catch_unwind()
229                .await
230            {
231                Ok(HookResult::Continue((n, a))) => {
232                    name = n;
233                    args = a;
234                }
235                Ok(HookResult::Cancel(reason)) => {
236                    ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "before_tool_call cancelled by hook");
237                    return HookResult::Cancel(reason);
238                }
239                Err(_) => {
240                    ::zeroclaw_log::record!(
241                        ERROR,
242                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
243                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
244                            .with_attrs(::serde_json::json!({"hook": hook_name})),
245                        "before_tool_call hook panicked; continuing with previous values"
246                    );
247                }
248            }
249        }
250        HookResult::Continue((name, args))
251    }
252
253    pub async fn run_on_message_received(
254        &self,
255        mut message: ChannelMessage,
256    ) -> HookResult<ChannelMessage> {
257        for h in &self.handlers {
258            let hook_name = h.name();
259            match AssertUnwindSafe(h.on_message_received(message.clone()))
260                .catch_unwind()
261                .await
262            {
263                Ok(HookResult::Continue(m)) => message = m,
264                Ok(HookResult::Cancel(reason)) => {
265                    ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "on_message_received cancelled by hook");
266                    return HookResult::Cancel(reason);
267                }
268                Err(_) => {
269                    ::zeroclaw_log::record!(
270                        ERROR,
271                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
272                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
273                            .with_attrs(::serde_json::json!({"hook": hook_name})),
274                        "on_message_received hook panicked; continuing with previous message"
275                    );
276                }
277            }
278        }
279        HookResult::Continue(message)
280    }
281
282    pub async fn run_on_message_sending(
283        &self,
284        mut channel: String,
285        mut recipient: String,
286        mut content: String,
287    ) -> HookResult<(String, String, String)> {
288        for h in &self.handlers {
289            let hook_name = h.name();
290            match AssertUnwindSafe(h.on_message_sending(
291                channel.clone(),
292                recipient.clone(),
293                content.clone(),
294            ))
295            .catch_unwind()
296            .await
297            {
298                Ok(HookResult::Continue((c, r, ct))) => {
299                    channel = c;
300                    recipient = r;
301                    content = ct;
302                }
303                Ok(HookResult::Cancel(reason)) => {
304                    ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "on_message_sending cancelled by hook");
305                    return HookResult::Cancel(reason);
306                }
307                Err(_) => {
308                    ::zeroclaw_log::record!(
309                        ERROR,
310                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
311                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
312                            .with_attrs(::serde_json::json!({"hook": hook_name})),
313                        "on_message_sending hook panicked; continuing with previous message"
314                    );
315                }
316            }
317        }
318        HookResult::Continue((channel, recipient, content))
319    }
320}
321
322#[cfg(test)]
323mod tests {
324    use super::*;
325    use async_trait::async_trait;
326    use std::sync::Arc;
327    use std::sync::atomic::{AtomicU32, Ordering};
328
329    /// A hook that records how many times void events fire.
330    struct CountingHook {
331        name: String,
332        priority: i32,
333        fire_count: Arc<AtomicU32>,
334    }
335
336    impl CountingHook {
337        fn new(name: &str, priority: i32) -> (Self, Arc<AtomicU32>) {
338            let count = Arc::new(AtomicU32::new(0));
339            (
340                Self {
341                    name: name.to_string(),
342                    priority,
343                    fire_count: count.clone(),
344                },
345                count,
346            )
347        }
348    }
349
350    #[async_trait]
351    impl HookHandler for CountingHook {
352        fn name(&self) -> &str {
353            &self.name
354        }
355        fn priority(&self) -> i32 {
356            self.priority
357        }
358        async fn on_heartbeat_tick(&self) {
359            self.fire_count.fetch_add(1, Ordering::SeqCst);
360        }
361    }
362
363    /// A modifying hook that uppercases the prompt.
364    struct UppercasePromptHook {
365        name: String,
366        priority: i32,
367    }
368
369    #[async_trait]
370    impl HookHandler for UppercasePromptHook {
371        fn name(&self) -> &str {
372            &self.name
373        }
374        fn priority(&self) -> i32 {
375            self.priority
376        }
377        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
378            HookResult::Continue(prompt.to_uppercase())
379        }
380    }
381
382    /// A modifying hook that cancels before_prompt_build.
383    struct CancelPromptHook {
384        name: String,
385        priority: i32,
386    }
387
388    #[async_trait]
389    impl HookHandler for CancelPromptHook {
390        fn name(&self) -> &str {
391            &self.name
392        }
393        fn priority(&self) -> i32 {
394            self.priority
395        }
396        async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
397            HookResult::Cancel("blocked by policy".into())
398        }
399    }
400
401    /// A modifying hook that appends a suffix to the prompt.
402    struct SuffixPromptHook {
403        name: String,
404        priority: i32,
405        suffix: String,
406    }
407
408    #[async_trait]
409    impl HookHandler for SuffixPromptHook {
410        fn name(&self) -> &str {
411            &self.name
412        }
413        fn priority(&self) -> i32 {
414            self.priority
415        }
416        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
417            HookResult::Continue(format!("{}{}", prompt, self.suffix))
418        }
419    }
420
421    #[test]
422    fn register_and_sort_by_priority() {
423        let mut runner = HookRunner::new();
424        let (low, _) = CountingHook::new("low", 1);
425        let (high, _) = CountingHook::new("high", 10);
426        let (mid, _) = CountingHook::new("mid", 5);
427
428        runner.register(Box::new(low));
429        runner.register(Box::new(high));
430        runner.register(Box::new(mid));
431
432        let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect();
433        assert_eq!(names, vec!["high", "mid", "low"]);
434    }
435
436    #[tokio::test]
437    async fn void_hooks_fire_all_handlers() {
438        let mut runner = HookRunner::new();
439        let (h1, c1) = CountingHook::new("hook_a", 0);
440        let (h2, c2) = CountingHook::new("hook_b", 0);
441
442        runner.register(Box::new(h1));
443        runner.register(Box::new(h2));
444
445        runner.fire_heartbeat_tick().await;
446
447        assert_eq!(c1.load(Ordering::SeqCst), 1);
448        assert_eq!(c2.load(Ordering::SeqCst), 1);
449    }
450
451    #[tokio::test]
452    async fn modifying_hook_can_cancel() {
453        let mut runner = HookRunner::new();
454        runner.register(Box::new(CancelPromptHook {
455            name: "blocker".into(),
456            priority: 10,
457        }));
458        runner.register(Box::new(UppercasePromptHook {
459            name: "upper".into(),
460            priority: 0,
461        }));
462
463        let result = runner.run_before_prompt_build("hello".into()).await;
464        assert!(result.is_cancel());
465    }
466
467    #[tokio::test]
468    async fn modifying_hook_pipelines_data() {
469        let mut runner = HookRunner::new();
470
471        // Priority 10 runs first: uppercases
472        runner.register(Box::new(UppercasePromptHook {
473            name: "upper".into(),
474            priority: 10,
475        }));
476        // Priority 0 runs second: appends suffix
477        runner.register(Box::new(SuffixPromptHook {
478            name: "suffix".into(),
479            priority: 0,
480            suffix: "_done".into(),
481        }));
482
483        match runner.run_before_prompt_build("hello".into()).await {
484            HookResult::Continue(result) => assert_eq!(result, "HELLO_done"),
485            HookResult::Cancel(_) => panic!("should not cancel"),
486        }
487    }
488}