1use std::time::Duration;
2
3use futures_util::{FutureExt, future::join_all};
4use serde_json::Value;
5use std::panic::AssertUnwindSafe;
6
7use zeroclaw_api::channel::ChannelMessage;
8use zeroclaw_api::model_provider::{ChatMessage, ChatResponse};
9use zeroclaw_api::tool::ToolResult;
10
11use super::traits::{HookHandler, HookResult};
12
13pub struct HookRunner {
19 handlers: Vec<Box<dyn HookHandler>>,
20}
21
22impl Default for HookRunner {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl HookRunner {
29 pub fn new() -> Self {
31 Self {
32 handlers: Vec::new(),
33 }
34 }
35
36 pub fn register(&mut self, handler: Box<dyn HookHandler>) {
38 self.handlers.push(handler);
39 self.handlers
40 .sort_by_key(|h| std::cmp::Reverse(h.priority()));
41 }
42
43 pub async fn fire_gateway_start(&self, host: &str, port: u16) {
48 let futs: Vec<_> = self
49 .handlers
50 .iter()
51 .map(|h| h.on_gateway_start(host, port))
52 .collect();
53 join_all(futs).await;
54 }
55
56 pub async fn fire_gateway_stop(&self) {
57 let futs: Vec<_> = self.handlers.iter().map(|h| h.on_gateway_stop()).collect();
58 join_all(futs).await;
59 }
60
61 pub async fn fire_session_start(&self, session_id: &str, channel: &str) {
62 let futs: Vec<_> = self
63 .handlers
64 .iter()
65 .map(|h| h.on_session_start(session_id, channel))
66 .collect();
67 join_all(futs).await;
68 }
69
70 pub async fn fire_session_end(&self, session_id: &str, channel: &str) {
71 let futs: Vec<_> = self
72 .handlers
73 .iter()
74 .map(|h| h.on_session_end(session_id, channel))
75 .collect();
76 join_all(futs).await;
77 }
78
79 pub async fn fire_llm_input(&self, messages: &[ChatMessage], model: &str) {
80 let futs: Vec<_> = self
81 .handlers
82 .iter()
83 .map(|h| h.on_llm_input(messages, model))
84 .collect();
85 join_all(futs).await;
86 }
87
88 pub async fn fire_llm_output(&self, response: &ChatResponse) {
89 let futs: Vec<_> = self
90 .handlers
91 .iter()
92 .map(|h| h.on_llm_output(response))
93 .collect();
94 join_all(futs).await;
95 }
96
97 pub async fn fire_after_tool_call(&self, tool: &str, result: &ToolResult, duration: Duration) {
98 let futs: Vec<_> = self
99 .handlers
100 .iter()
101 .map(|h| h.on_after_tool_call(tool, result, duration))
102 .collect();
103 join_all(futs).await;
104 }
105
106 pub async fn fire_message_sent(&self, channel: &str, recipient: &str, content: &str) {
107 let futs: Vec<_> = self
108 .handlers
109 .iter()
110 .map(|h| h.on_message_sent(channel, recipient, content))
111 .collect();
112 join_all(futs).await;
113 }
114
115 pub async fn fire_heartbeat_tick(&self) {
116 let futs: Vec<_> = self
117 .handlers
118 .iter()
119 .map(|h| h.on_heartbeat_tick())
120 .collect();
121 join_all(futs).await;
122 }
123
124 pub async fn run_before_model_resolve(
129 &self,
130 mut model_provider: String,
131 mut model: String,
132 ) -> HookResult<(String, String)> {
133 for h in &self.handlers {
134 let hook_name = h.name();
135 match AssertUnwindSafe(h.before_model_resolve(model_provider.clone(), model.clone()))
136 .catch_unwind()
137 .await
138 {
139 Ok(HookResult::Continue((p, m))) => {
140 model_provider = p;
141 model = m;
142 }
143 Ok(HookResult::Cancel(reason)) => {
144 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "before_model_resolve cancelled by hook");
145 return HookResult::Cancel(reason);
146 }
147 Err(_) => {
148 ::zeroclaw_log::record!(
149 ERROR,
150 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
151 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
152 .with_attrs(::serde_json::json!({"hook": hook_name})),
153 "before_model_resolve hook panicked; continuing with previous values"
154 );
155 }
156 }
157 }
158 HookResult::Continue((model_provider, model))
159 }
160
161 pub async fn run_before_prompt_build(&self, mut prompt: String) -> HookResult<String> {
162 for h in &self.handlers {
163 let hook_name = h.name();
164 match AssertUnwindSafe(h.before_prompt_build(prompt.clone()))
165 .catch_unwind()
166 .await
167 {
168 Ok(HookResult::Continue(p)) => prompt = p,
169 Ok(HookResult::Cancel(reason)) => {
170 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "before_prompt_build cancelled by hook");
171 return HookResult::Cancel(reason);
172 }
173 Err(_) => {
174 ::zeroclaw_log::record!(
175 ERROR,
176 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
177 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
178 .with_attrs(::serde_json::json!({"hook": hook_name})),
179 "before_prompt_build hook panicked; continuing with previous value"
180 );
181 }
182 }
183 }
184 HookResult::Continue(prompt)
185 }
186
187 pub async fn run_before_llm_call(
188 &self,
189 mut messages: Vec<ChatMessage>,
190 mut model: String,
191 ) -> HookResult<(Vec<ChatMessage>, String)> {
192 for h in &self.handlers {
193 let hook_name = h.name();
194 match AssertUnwindSafe(h.before_llm_call(messages.clone(), model.clone()))
195 .catch_unwind()
196 .await
197 {
198 Ok(HookResult::Continue((m, mdl))) => {
199 messages = m;
200 model = mdl;
201 }
202 Ok(HookResult::Cancel(reason)) => {
203 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "before_llm_call cancelled by hook");
204 return HookResult::Cancel(reason);
205 }
206 Err(_) => {
207 ::zeroclaw_log::record!(
208 ERROR,
209 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
210 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
211 .with_attrs(::serde_json::json!({"hook": hook_name})),
212 "before_llm_call hook panicked; continuing with previous values"
213 );
214 }
215 }
216 }
217 HookResult::Continue((messages, model))
218 }
219
220 pub async fn run_before_tool_call(
221 &self,
222 mut name: String,
223 mut args: Value,
224 ) -> HookResult<(String, Value)> {
225 for h in &self.handlers {
226 let hook_name = h.name();
227 match AssertUnwindSafe(h.before_tool_call(name.clone(), args.clone()))
228 .catch_unwind()
229 .await
230 {
231 Ok(HookResult::Continue((n, a))) => {
232 name = n;
233 args = a;
234 }
235 Ok(HookResult::Cancel(reason)) => {
236 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "before_tool_call cancelled by hook");
237 return HookResult::Cancel(reason);
238 }
239 Err(_) => {
240 ::zeroclaw_log::record!(
241 ERROR,
242 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
243 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
244 .with_attrs(::serde_json::json!({"hook": hook_name})),
245 "before_tool_call hook panicked; continuing with previous values"
246 );
247 }
248 }
249 }
250 HookResult::Continue((name, args))
251 }
252
253 pub async fn run_on_message_received(
254 &self,
255 mut message: ChannelMessage,
256 ) -> HookResult<ChannelMessage> {
257 for h in &self.handlers {
258 let hook_name = h.name();
259 match AssertUnwindSafe(h.on_message_received(message.clone()))
260 .catch_unwind()
261 .await
262 {
263 Ok(HookResult::Continue(m)) => message = m,
264 Ok(HookResult::Cancel(reason)) => {
265 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "on_message_received cancelled by hook");
266 return HookResult::Cancel(reason);
267 }
268 Err(_) => {
269 ::zeroclaw_log::record!(
270 ERROR,
271 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
272 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
273 .with_attrs(::serde_json::json!({"hook": hook_name})),
274 "on_message_received hook panicked; continuing with previous message"
275 );
276 }
277 }
278 }
279 HookResult::Continue(message)
280 }
281
282 pub async fn run_on_message_sending(
283 &self,
284 mut channel: String,
285 mut recipient: String,
286 mut content: String,
287 ) -> HookResult<(String, String, String)> {
288 for h in &self.handlers {
289 let hook_name = h.name();
290 match AssertUnwindSafe(h.on_message_sending(
291 channel.clone(),
292 recipient.clone(),
293 content.clone(),
294 ))
295 .catch_unwind()
296 .await
297 {
298 Ok(HookResult::Continue((c, r, ct))) => {
299 channel = c;
300 recipient = r;
301 content = ct;
302 }
303 Ok(HookResult::Cancel(reason)) => {
304 ::zeroclaw_log::record!(INFO, ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note).with_attrs(::serde_json::json!({"hook": hook_name, "reason": reason.to_string()})), "on_message_sending cancelled by hook");
305 return HookResult::Cancel(reason);
306 }
307 Err(_) => {
308 ::zeroclaw_log::record!(
309 ERROR,
310 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
311 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
312 .with_attrs(::serde_json::json!({"hook": hook_name})),
313 "on_message_sending hook panicked; continuing with previous message"
314 );
315 }
316 }
317 }
318 HookResult::Continue((channel, recipient, content))
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Handler that counts how many times `on_heartbeat_tick` was invoked.
    struct CountingHook {
        name: String,
        priority: i32,
        fire_count: Arc<AtomicU32>,
    }

    impl CountingHook {
        /// Builds the hook and returns it together with a shared handle to its
        /// counter, so the test can read the count after the hook was boxed.
        fn new(name: &str, priority: i32) -> (Self, Arc<AtomicU32>) {
            let count = Arc::new(AtomicU32::new(0));
            (
                Self {
                    name: name.to_string(),
                    priority,
                    fire_count: count.clone(),
                },
                count,
            )
        }
    }

    #[async_trait]
    impl HookHandler for CountingHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn on_heartbeat_tick(&self) {
            self.fire_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Handler whose `before_prompt_build` upper-cases the prompt.
    struct UppercasePromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for UppercasePromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
            HookResult::Continue(prompt.to_uppercase())
        }
    }

    /// Handler whose `before_prompt_build` always cancels.
    struct CancelPromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for CancelPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
            HookResult::Cancel("blocked by policy".into())
        }
    }

    /// Handler whose `before_prompt_build` appends `suffix` to the prompt.
    struct SuffixPromptHook {
        name: String,
        priority: i32,
        suffix: String,
    }

    #[async_trait]
    impl HookHandler for SuffixPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
            HookResult::Continue(format!("{}{}", prompt, self.suffix))
        }
    }

    /// Handler whose `before_prompt_build` always panics.
    struct PanickingPromptHook {
        name: String,
        priority: i32,
    }

    #[async_trait]
    impl HookHandler for PanickingPromptHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
        async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
            panic!("simulated hook failure")
        }
    }

    #[test]
    fn register_and_sort_by_priority() {
        let mut runner = HookRunner::new();
        let (low, _) = CountingHook::new("low", 1);
        let (high, _) = CountingHook::new("high", 10);
        let (mid, _) = CountingHook::new("mid", 5);

        runner.register(Box::new(low));
        runner.register(Box::new(high));
        runner.register(Box::new(mid));

        let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect();
        assert_eq!(names, vec!["high", "mid", "low"]);
    }

    #[test]
    fn equal_priorities_keep_registration_order() {
        // `register` uses a stable sort, so ties must not be reordered.
        let mut runner = HookRunner::new();
        let (first, _) = CountingHook::new("first", 3);
        let (second, _) = CountingHook::new("second", 3);
        let (third, _) = CountingHook::new("third", 3);

        runner.register(Box::new(first));
        runner.register(Box::new(second));
        runner.register(Box::new(third));

        let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect();
        assert_eq!(names, vec!["first", "second", "third"]);
    }

    #[tokio::test]
    async fn void_hooks_fire_all_handlers() {
        let mut runner = HookRunner::new();
        let (h1, c1) = CountingHook::new("hook_a", 0);
        let (h2, c2) = CountingHook::new("hook_b", 0);

        runner.register(Box::new(h1));
        runner.register(Box::new(h2));

        runner.fire_heartbeat_tick().await;

        assert_eq!(c1.load(Ordering::SeqCst), 1);
        assert_eq!(c2.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn empty_runner_passes_prompt_through() {
        // With no handlers the void dispatcher is a no-op and the modifying
        // dispatcher returns its input unchanged.
        let runner = HookRunner::new();

        runner.fire_heartbeat_tick().await;

        match runner.run_before_prompt_build("hello".into()).await {
            HookResult::Continue(result) => assert_eq!(result, "hello"),
            HookResult::Cancel(_) => panic!("should not cancel"),
        }
    }

    #[tokio::test]
    async fn modifying_hook_can_cancel() {
        let mut runner = HookRunner::new();
        runner.register(Box::new(CancelPromptHook {
            name: "blocker".into(),
            priority: 10,
        }));
        runner.register(Box::new(UppercasePromptHook {
            name: "upper".into(),
            priority: 0,
        }));

        let result = runner.run_before_prompt_build("hello".into()).await;
        assert!(result.is_cancel());
    }

    #[tokio::test]
    async fn modifying_hook_pipelines_data() {
        let mut runner = HookRunner::new();

        // Higher priority runs first, so the suffix is appended after
        // upper-casing and stays lower-case.
        runner.register(Box::new(UppercasePromptHook {
            name: "upper".into(),
            priority: 10,
        }));
        runner.register(Box::new(SuffixPromptHook {
            name: "suffix".into(),
            priority: 0,
            suffix: "_done".into(),
        }));

        match runner.run_before_prompt_build("hello".into()).await {
            HookResult::Continue(result) => assert_eq!(result, "HELLO_done"),
            HookResult::Cancel(_) => panic!("should not cancel"),
        }
    }

    #[tokio::test]
    async fn modifying_hook_panic_is_isolated() {
        let mut runner = HookRunner::new();

        // The panicking hook runs first; the runner must swallow the panic,
        // keep the untouched prompt, and still run the lower-priority hook.
        runner.register(Box::new(PanickingPromptHook {
            name: "panicker".into(),
            priority: 10,
        }));
        runner.register(Box::new(SuffixPromptHook {
            name: "suffix".into(),
            priority: 0,
            suffix: "_done".into(),
        }));

        match runner.run_before_prompt_build("hello".into()).await {
            HookResult::Continue(result) => assert_eq!(result, "hello_done"),
            HookResult::Cancel(_) => panic!("should not cancel"),
        }
    }
}