Skip to main content

zeroclaw_tools/
reaction.rs

//! Emoji reaction tool for cross-channel message reactions.
//!
//! Exposes `add_reaction` and `remove_reaction` from the [`Channel`] trait as an
//! agent-callable tool. The tool holds a late-binding channel map handle that is
//! populated once channels are initialized (after tool construction). This mirrors
//! the pattern used by `DelegateTool` for its parent-tools handle.
7
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use serde_json::json;
11use std::collections::HashMap;
12use std::sync::Arc;
13use zeroclaw_api::channel::Channel;
14use zeroclaw_api::tool::{Tool, ToolResult};
15use zeroclaw_config::policy::SecurityPolicy;
16use zeroclaw_config::policy::ToolOperation;
17
18/// Shared handle to the channel map. Starts empty; populated once channels boot.
19pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
20
21/// Agent-callable tool for adding or removing emoji reactions on messages.
22pub struct ReactionTool {
23    channels: ChannelMapHandle,
24    security: Arc<SecurityPolicy>,
25}
26
27impl ReactionTool {
28    /// Create a new reaction tool using the given channel map.
29    pub fn new(security: Arc<SecurityPolicy>, channels: ChannelMapHandle) -> Self {
30        Self { channels, security }
31    }
32}
33
34#[async_trait]
35impl Tool for ReactionTool {
36    fn name(&self) -> &str {
37        "reaction"
38    }
39
40    fn description(&self) -> &str {
41        "Add or remove an emoji reaction on a message in any active channel. \
42         Provide the channel name (e.g. 'discord', 'slack'), the platform channel ID, \
43         the platform message ID, and the emoji (Unicode character or platform shortcode)."
44    }
45
46    fn parameters_schema(&self) -> serde_json::Value {
47        json!({
48            "type": "object",
49            "properties": {
50                "channel": {
51                    "type": "string",
52                    "description": "Name of the channel to react in (e.g. 'discord', 'slack', 'telegram')"
53                },
54                "channel_id": {
55                    "type": "string",
56                    "description": "Platform-specific channel/conversation identifier (e.g. Discord channel snowflake, Slack channel ID)"
57                },
58                "message_id": {
59                    "type": "string",
60                    "description": "Platform-scoped message identifier to react to"
61                },
62                "emoji": {
63                    "type": "string",
64                    "description": "Emoji to react with (Unicode character or platform shortcode)"
65                },
66                "action": {
67                    "type": "string",
68                    "enum": ["add", "remove"],
69                    "description": "Whether to add or remove the reaction (default: 'add')"
70                }
71            },
72            "required": ["channel", "channel_id", "message_id", "emoji"]
73        })
74    }
75
76    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
77        // Security gate
78        if let Err(error) = self
79            .security
80            .enforce_tool_operation(ToolOperation::Act, "reaction")
81        {
82            return Ok(ToolResult {
83                success: false,
84                output: String::new(),
85                error: Some(error),
86            });
87        }
88
89        let channel_name = args
90            .get("channel")
91            .and_then(|v| v.as_str())
92            .ok_or_else(|| {
93                ::zeroclaw_log::record!(
94                    WARN,
95                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
96                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
97                        .with_attrs(::serde_json::json!({"param": "channel"})),
98                    "reaction: missing channel parameter"
99                );
100                anyhow::Error::msg("Missing 'channel' parameter")
101            })?;
102
103        let channel_id = args
104            .get("channel_id")
105            .and_then(|v| v.as_str())
106            .ok_or_else(|| {
107                ::zeroclaw_log::record!(
108                    WARN,
109                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
110                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
111                        .with_attrs(::serde_json::json!({"param": "channel_id"})),
112                    "reaction: missing channel_id parameter"
113                );
114                anyhow::Error::msg("Missing 'channel_id' parameter")
115            })?;
116
117        let message_id = args
118            .get("message_id")
119            .and_then(|v| v.as_str())
120            .ok_or_else(|| {
121                ::zeroclaw_log::record!(
122                    WARN,
123                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
124                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
125                        .with_attrs(::serde_json::json!({"param": "message_id"})),
126                    "reaction: missing message_id parameter"
127                );
128                anyhow::Error::msg("Missing 'message_id' parameter")
129            })?;
130
131        let emoji = args.get("emoji").and_then(|v| v.as_str()).ok_or_else(|| {
132            ::zeroclaw_log::record!(
133                WARN,
134                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
135                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
136                    .with_attrs(::serde_json::json!({"param": "emoji"})),
137                "reaction: missing emoji parameter"
138            );
139            anyhow::Error::msg("Missing 'emoji' parameter")
140        })?;
141
142        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("add");
143
144        if action != "add" && action != "remove" {
145            return Ok(ToolResult {
146                success: false,
147                output: String::new(),
148                error: Some(format!(
149                    "Invalid action '{action}': must be 'add' or 'remove'"
150                )),
151            });
152        }
153
154        // Read-lock the channel map to find the target channel.
155        let channel = {
156            let map = self.channels.read();
157            if map.is_empty() {
158                return Ok(ToolResult {
159                    success: false,
160                    output: String::new(),
161                    error: Some("No channels available yet (channels not initialized)".to_string()),
162                });
163            }
164            match map.get(channel_name) {
165                Some(ch) => Arc::clone(ch),
166                None => {
167                    let available: Vec<String> = map.keys().cloned().collect();
168                    return Ok(ToolResult {
169                        success: false,
170                        output: String::new(),
171                        error: Some(format!(
172                            "Channel '{channel_name}' not found. Available channels: {}",
173                            available.join(", ")
174                        )),
175                    });
176                }
177            }
178        };
179
180        let result = if action == "add" {
181            channel.add_reaction(channel_id, message_id, emoji).await
182        } else {
183            channel.remove_reaction(channel_id, message_id, emoji).await
184        };
185
186        let past_tense = if action == "remove" {
187            "removed"
188        } else {
189            "added"
190        };
191
192        match result {
193            Ok(()) => Ok(ToolResult {
194                success: true,
195                output: format!(
196                    "Reaction {past_tense}: {emoji} on message {message_id} in {channel_name}"
197                ),
198                error: None,
199            }),
200            Err(e) => Ok(ToolResult {
201                success: false,
202                output: String::new(),
203                error: Some(format!("Failed to {action} reaction: {e}")),
204            }),
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use zeroclaw_api::channel::{ChannelMessage, SendMessage};

    /// Test double that records which reaction method ran and with which channel ID.
    struct MockChannel {
        reaction_added: AtomicBool,
        reaction_removed: AtomicBool,
        last_channel_id: parking_lot::Mutex<Option<String>>,
        /// When true, `add_reaction` fails with a simulated platform error.
        fail_on_add: bool,
    }

    impl MockChannel {
        fn with_failure(fail_on_add: bool) -> Self {
            Self {
                reaction_added: AtomicBool::new(false),
                reaction_removed: AtomicBool::new(false),
                last_channel_id: parking_lot::Mutex::new(None),
                fail_on_add,
            }
        }

        fn new() -> Self {
            Self::with_failure(false)
        }

        fn failing() -> Self {
            Self::with_failure(true)
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for MockChannel {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Channel(
                ::zeroclaw_api::attribution::ChannelKind::Webhook,
            )
        }
        fn alias(&self) -> &str {
            "test"
        }
    }

    #[async_trait]
    impl Channel for MockChannel {
        fn name(&self) -> &str {
            "mock"
        }

        async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn add_reaction(
            &self,
            channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            if self.fail_on_add {
                return Err(anyhow::Error::msg("API error: rate limited"));
            }
            *self.last_channel_id.lock() = Some(channel_id.to_string());
            self.reaction_added.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn remove_reaction(
            &self,
            channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            *self.last_channel_id.lock() = Some(channel_id.to_string());
            self.reaction_removed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a tool whose channel map is pre-populated with `channels`.
    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> ReactionTool {
        let handle: ChannelMapHandle = Arc::new(RwLock::new(HashMap::new()));
        handle
            .write()
            .extend(channels.into_iter().map(|(name, ch)| (name.to_string(), ch)));
        ReactionTool::new(Arc::new(SecurityPolicy::default()), handle)
    }

    /// Builds a tool whose channel map has not been populated yet.
    fn make_empty_tool() -> ReactionTool {
        ReactionTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        )
    }

    #[test]
    fn tool_metadata() {
        let tool = make_empty_tool();
        assert_eq!(tool.name(), "reaction");
        assert!(!tool.description().is_empty());
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["channel"].is_object());
        assert!(schema["properties"]["channel_id"].is_object());
        assert!(schema["properties"]["message_id"].is_object());
        assert!(schema["properties"]["emoji"].is_object());
        assert!(schema["properties"]["action"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v == "channel"));
        assert!(required.iter().any(|v| v == "channel_id"));
        assert!(required.iter().any(|v| v == "message_id"));
        assert!(required.iter().any(|v| v == "emoji"));
        // action is optional (defaults to "add")
        assert!(!required.iter().any(|v| v == "action"));
    }

    #[tokio::test]
    async fn add_reaction_success() {
        let mock = Arc::new(MockChannel::new());
        let tool =
            make_tool_with_channels(vec![("discord", Arc::clone(&mock) as Arc<dyn Channel>)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_123",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("added"));
        assert!(result.error.is_none());
        assert!(mock.reaction_added.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn remove_reaction_success() {
        let mock = Arc::new(MockChannel::new());
        let tool = make_tool_with_channels(vec![("slack", Arc::clone(&mock) as Arc<dyn Channel>)]);

        let result = tool
            .execute(json!({
                "channel": "slack",
                "channel_id": "C0123SLACK",
                "message_id": "msg_456",
                "emoji": "\u{1F440}",
                "action": "remove"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("removed"));
        // The remove path must route to `remove_reaction`, not `add_reaction`.
        assert!(mock.reaction_removed.load(Ordering::SeqCst));
        assert!(!mock.reaction_added.load(Ordering::SeqCst));
        assert_eq!(mock.last_channel_id.lock().as_deref(), Some("C0123SLACK"));
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let tool = make_tool_with_channels(vec![(
            "discord",
            Arc::new(MockChannel::new()) as Arc<dyn Channel>,
        )]);

        let result = tool
            .execute(json!({
                "channel": "nonexistent",
                "channel_id": "ch_x",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(err.contains("not found"));
        assert!(err.contains("discord"));
    }

    #[tokio::test]
    async fn unknown_channel_lists_every_available_channel() {
        let tool = make_tool_with_channels(vec![
            ("slack", Arc::new(MockChannel::new()) as Arc<dyn Channel>),
            ("discord", Arc::new(MockChannel::new()) as Arc<dyn Channel>),
        ]);

        let result = tool
            .execute(json!({
                "channel": "nonexistent",
                "channel_id": "ch_x",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(err.contains("discord"));
        assert!(err.contains("slack"));
    }

    #[tokio::test]
    async fn invalid_action_returns_error() {
        let mock = Arc::new(MockChannel::new());
        let tool =
            make_tool_with_channels(vec![("discord", Arc::clone(&mock) as Arc<dyn Channel>)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}",
                "action": "toggle"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("toggle"));
        // A rejected action must never reach the channel.
        assert!(!mock.reaction_added.load(Ordering::SeqCst));
        assert!(!mock.reaction_removed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn channel_error_propagated() {
        let mock: Arc<dyn Channel> = Arc::new(MockChannel::failing());
        let tool = make_tool_with_channels(vec![("discord", mock)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("rate limited"));
    }

    #[tokio::test]
    async fn missing_required_params() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(MockChannel::new()) as Arc<dyn Channel>,
        )]);

        // Each case differs from a fully valid call only by the removed key.
        let valid = json!({"channel": "test", "channel_id": "c1", "message_id": "1", "emoji": "x"});
        for key in ["channel", "channel_id", "message_id", "emoji"] {
            let mut args = valid.clone();
            args.as_object_mut().unwrap().remove(key);

            let err = tool
                .execute(args)
                .await
                .err()
                .unwrap_or_else(|| panic!("missing '{key}' must be rejected"));
            assert!(
                err.to_string().contains(key),
                "error for missing '{key}' should name it: {err}"
            );
        }
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        // No channels populated
        let tool = make_empty_tool();

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    #[tokio::test]
    async fn default_action_is_add() {
        let mock = Arc::new(MockChannel::new());
        let tool = make_tool_with_channels(vec![("test", Arc::clone(&mock) as Arc<dyn Channel>)]);

        let result = tool
            .execute(json!({
                "channel": "test",
                "channel_id": "ch_test",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(mock.reaction_added.load(Ordering::SeqCst));
        assert!(!mock.reaction_removed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn null_action_defaults_to_add() {
        let mock = Arc::new(MockChannel::new());
        let tool = make_tool_with_channels(vec![("test", Arc::clone(&mock) as Arc<dyn Channel>)]);

        let result = tool
            .execute(json!({
                "channel": "test",
                "channel_id": "ch_test",
                "message_id": "msg_1",
                "emoji": "\u{2705}",
                "action": null
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(mock.reaction_added.load(Ordering::SeqCst));
        assert!(!mock.reaction_removed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn channel_id_passed_to_trait_not_channel_name() {
        let mock = Arc::new(MockChannel::new());
        let tool =
            make_tool_with_channels(vec![("discord", Arc::clone(&mock) as Arc<dyn Channel>)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "123456789",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        // The trait must receive the platform channel_id, not the channel name
        assert_eq!(
            mock.last_channel_id.lock().as_deref(),
            Some("123456789"),
            "add_reaction must receive channel_id, not channel name"
        );
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let handle: ChannelMapHandle = Arc::new(RwLock::new(HashMap::new()));
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()), Arc::clone(&handle));
        let args = json!({
            "channel": "slack",
            "channel_id": "C0123",
            "message_id": "msg_1",
            "emoji": "\u{2705}"
        });

        // Initially empty — tool reports not initialized
        let result = tool.execute(args.clone()).await.unwrap();
        assert!(!result.success);

        // Populate via the shared handle
        handle.write().insert(
            "slack".to_string(),
            Arc::new(MockChannel::new()) as Arc<dyn Channel>,
        );

        // Now the tool can route to the channel
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = make_empty_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "reaction");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }
}