1use async_trait::async_trait;
9use parking_lot::RwLock;
10use serde_json::json;
11use std::collections::HashMap;
12use std::sync::Arc;
13use zeroclaw_api::channel::Channel;
14use zeroclaw_api::tool::{Tool, ToolResult};
15use zeroclaw_config::policy::SecurityPolicy;
16use zeroclaw_config::policy::ToolOperation;
17
18pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
20
21pub struct ReactionTool {
23 channels: ChannelMapHandle,
24 security: Arc<SecurityPolicy>,
25}
26
27impl ReactionTool {
28 pub fn new(security: Arc<SecurityPolicy>, channels: ChannelMapHandle) -> Self {
30 Self { channels, security }
31 }
32}
33
34#[async_trait]
35impl Tool for ReactionTool {
36 fn name(&self) -> &str {
37 "reaction"
38 }
39
40 fn description(&self) -> &str {
41 "Add or remove an emoji reaction on a message in any active channel. \
42 Provide the channel name (e.g. 'discord', 'slack'), the platform channel ID, \
43 the platform message ID, and the emoji (Unicode character or platform shortcode)."
44 }
45
46 fn parameters_schema(&self) -> serde_json::Value {
47 json!({
48 "type": "object",
49 "properties": {
50 "channel": {
51 "type": "string",
52 "description": "Name of the channel to react in (e.g. 'discord', 'slack', 'telegram')"
53 },
54 "channel_id": {
55 "type": "string",
56 "description": "Platform-specific channel/conversation identifier (e.g. Discord channel snowflake, Slack channel ID)"
57 },
58 "message_id": {
59 "type": "string",
60 "description": "Platform-scoped message identifier to react to"
61 },
62 "emoji": {
63 "type": "string",
64 "description": "Emoji to react with (Unicode character or platform shortcode)"
65 },
66 "action": {
67 "type": "string",
68 "enum": ["add", "remove"],
69 "description": "Whether to add or remove the reaction (default: 'add')"
70 }
71 },
72 "required": ["channel", "channel_id", "message_id", "emoji"]
73 })
74 }
75
76 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
77 if let Err(error) = self
79 .security
80 .enforce_tool_operation(ToolOperation::Act, "reaction")
81 {
82 return Ok(ToolResult {
83 success: false,
84 output: String::new(),
85 error: Some(error),
86 });
87 }
88
89 let channel_name = args
90 .get("channel")
91 .and_then(|v| v.as_str())
92 .ok_or_else(|| {
93 ::zeroclaw_log::record!(
94 WARN,
95 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
96 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
97 .with_attrs(::serde_json::json!({"param": "channel"})),
98 "reaction: missing channel parameter"
99 );
100 anyhow::Error::msg("Missing 'channel' parameter")
101 })?;
102
103 let channel_id = args
104 .get("channel_id")
105 .and_then(|v| v.as_str())
106 .ok_or_else(|| {
107 ::zeroclaw_log::record!(
108 WARN,
109 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
110 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
111 .with_attrs(::serde_json::json!({"param": "channel_id"})),
112 "reaction: missing channel_id parameter"
113 );
114 anyhow::Error::msg("Missing 'channel_id' parameter")
115 })?;
116
117 let message_id = args
118 .get("message_id")
119 .and_then(|v| v.as_str())
120 .ok_or_else(|| {
121 ::zeroclaw_log::record!(
122 WARN,
123 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
124 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
125 .with_attrs(::serde_json::json!({"param": "message_id"})),
126 "reaction: missing message_id parameter"
127 );
128 anyhow::Error::msg("Missing 'message_id' parameter")
129 })?;
130
131 let emoji = args.get("emoji").and_then(|v| v.as_str()).ok_or_else(|| {
132 ::zeroclaw_log::record!(
133 WARN,
134 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
135 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
136 .with_attrs(::serde_json::json!({"param": "emoji"})),
137 "reaction: missing emoji parameter"
138 );
139 anyhow::Error::msg("Missing 'emoji' parameter")
140 })?;
141
142 let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("add");
143
144 if action != "add" && action != "remove" {
145 return Ok(ToolResult {
146 success: false,
147 output: String::new(),
148 error: Some(format!(
149 "Invalid action '{action}': must be 'add' or 'remove'"
150 )),
151 });
152 }
153
154 let channel = {
156 let map = self.channels.read();
157 if map.is_empty() {
158 return Ok(ToolResult {
159 success: false,
160 output: String::new(),
161 error: Some("No channels available yet (channels not initialized)".to_string()),
162 });
163 }
164 match map.get(channel_name) {
165 Some(ch) => Arc::clone(ch),
166 None => {
167 let available: Vec<String> = map.keys().cloned().collect();
168 return Ok(ToolResult {
169 success: false,
170 output: String::new(),
171 error: Some(format!(
172 "Channel '{channel_name}' not found. Available channels: {}",
173 available.join(", ")
174 )),
175 });
176 }
177 }
178 };
179
180 let result = if action == "add" {
181 channel.add_reaction(channel_id, message_id, emoji).await
182 } else {
183 channel.remove_reaction(channel_id, message_id, emoji).await
184 };
185
186 let past_tense = if action == "remove" {
187 "removed"
188 } else {
189 "added"
190 };
191
192 match result {
193 Ok(()) => Ok(ToolResult {
194 success: true,
195 output: format!(
196 "Reaction {past_tense}: {emoji} on message {message_id} in {channel_name}"
197 ),
198 error: None,
199 }),
200 Err(e) => Ok(ToolResult {
201 success: false,
202 output: String::new(),
203 error: Some(format!("Failed to {action} reaction: {e}")),
204 }),
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use zeroclaw_api::channel::{ChannelMessage, SendMessage};

    /// In-memory [`Channel`] that records reaction calls instead of contacting a platform.
    struct MockChannel {
        /// Set when `add_reaction` succeeds.
        reaction_added: AtomicBool,
        /// Set when `remove_reaction` succeeds.
        reaction_removed: AtomicBool,
        /// `channel_id` argument of the most recent successful reaction call.
        last_channel_id: parking_lot::Mutex<Option<String>>,
        /// When true, `add_reaction` fails with a simulated rate-limit error.
        fail_on_add: bool,
    }

    impl MockChannel {
        fn with_failure(fail_on_add: bool) -> Self {
            Self {
                reaction_added: AtomicBool::new(false),
                reaction_removed: AtomicBool::new(false),
                last_channel_id: parking_lot::Mutex::new(None),
                fail_on_add,
            }
        }

        fn new() -> Self {
            Self::with_failure(false)
        }

        fn failing() -> Self {
            Self::with_failure(true)
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for MockChannel {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Channel(
                ::zeroclaw_api::attribution::ChannelKind::Webhook,
            )
        }
        fn alias(&self) -> &str {
            "test"
        }
    }

    #[async_trait]
    impl Channel for MockChannel {
        fn name(&self) -> &str {
            "mock"
        }

        async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn add_reaction(
            &self,
            channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            if self.fail_on_add {
                return Err(anyhow::Error::msg("API error: rate limited"));
            }
            *self.last_channel_id.lock() = Some(channel_id.to_string());
            self.reaction_added.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn remove_reaction(
            &self,
            channel_id: &str,
            _message_id: &str,
            _emoji: &str,
        ) -> anyhow::Result<()> {
            *self.last_channel_id.lock() = Some(channel_id.to_string());
            self.reaction_removed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Wraps `mock` in an `Arc` and returns two handles to the same instance.
    /// The concrete handle is used to inspect recorded calls; the trait object
    /// is registered with the tool.
    fn shared_mock(mock: MockChannel) -> (Arc<MockChannel>, Arc<dyn Channel>) {
        let concrete = Arc::new(mock);
        let erased = Arc::clone(&concrete) as Arc<dyn Channel>;
        (concrete, erased)
    }

    /// Builds a tool with the default security policy over a registry
    /// pre-populated with `channels`.
    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> ReactionTool {
        let map: HashMap<String, Arc<dyn Channel>> = channels
            .into_iter()
            .map(|(name, ch)| (name.to_string(), ch))
            .collect();
        ReactionTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(map)),
        )
    }

    /// Builds a tool with the default security policy over an empty registry.
    fn make_tool_without_channels() -> ReactionTool {
        make_tool_with_channels(Vec::new())
    }

    #[test]
    fn tool_metadata() {
        let tool = make_tool_without_channels();
        assert_eq!(tool.name(), "reaction");
        assert!(!tool.description().is_empty());
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        for property in ["channel", "channel_id", "message_id", "emoji", "action"] {
            assert!(schema["properties"][property].is_object());
        }
        let required = schema["required"].as_array().unwrap();
        for param in ["channel", "channel_id", "message_id", "emoji"] {
            assert!(required.iter().any(|v| v == param));
        }
        assert!(!required.iter().any(|v| v == "action"));
    }

    #[tokio::test]
    async fn add_reaction_success() {
        let (mock, mock_ch) = shared_mock(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_123",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("added"));
        assert!(result.error.is_none());
        assert!(mock.reaction_added.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn remove_reaction_success() {
        let (mock, mock_ch) = shared_mock(MockChannel::new());
        let tool = make_tool_with_channels(vec![("slack", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "slack",
                "channel_id": "C0123SLACK",
                "message_id": "msg_456",
                "emoji": "\u{1F440}",
                "action": "remove"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("removed"));
        // The remove path, not the add path, must reach the channel.
        assert!(mock.reaction_removed.load(Ordering::SeqCst));
        assert!(!mock.reaction_added.load(Ordering::SeqCst));
        assert_eq!(mock.last_channel_id.lock().as_deref(), Some("C0123SLACK"));
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let (_mock, mock_ch) = shared_mock(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "nonexistent",
                "channel_id": "ch_x",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(err.contains("not found"));
        assert!(err.contains("discord"));
    }

    #[tokio::test]
    async fn invalid_action_returns_error() {
        let (_mock, mock_ch) = shared_mock(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}",
                "action": "toggle"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("toggle"));
    }

    #[tokio::test]
    async fn channel_error_propagated() {
        let (mock, mock_ch) = shared_mock(MockChannel::failing());
        let tool = make_tool_with_channels(vec![("discord", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("rate limited"));
        assert!(!mock.reaction_added.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn missing_required_params() {
        let (_mock, mock_ch) = shared_mock(MockChannel::new());
        let tool = make_tool_with_channels(vec![("test", mock_ch)]);

        let cases = [
            (
                json!({"channel_id": "c1", "message_id": "1", "emoji": "x"}),
                "channel",
            ),
            (
                json!({"channel": "test", "message_id": "1", "emoji": "x"}),
                "channel_id",
            ),
            (
                json!({"channel": "a", "channel_id": "c1", "emoji": "x"}),
                "message_id",
            ),
            (
                json!({"channel": "a", "channel_id": "c1", "message_id": "1"}),
                "emoji",
            ),
        ];

        for (args, param) in cases {
            let err = match tool.execute(args).await {
                Ok(_) => panic!("missing '{param}' must be reported as an error"),
                Err(err) => err,
            };
            // The error must name the parameter that is actually missing.
            assert!(
                err.to_string().contains(&format!("'{param}'")),
                "error for missing '{param}' was: {err}"
            );
        }
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        let tool = make_tool_without_channels();
        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "ch_001",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    #[tokio::test]
    async fn default_action_is_add() {
        let (mock, mock_ch) = shared_mock(MockChannel::new());
        let tool = make_tool_with_channels(vec![("test", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "test",
                "channel_id": "ch_test",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(mock.reaction_added.load(Ordering::SeqCst));
        assert!(!mock.reaction_removed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn channel_id_passed_to_trait_not_channel_name() {
        let (mock, mock_ch) = shared_mock(MockChannel::new());
        let tool = make_tool_with_channels(vec![("discord", mock_ch)]);

        let result = tool
            .execute(json!({
                "channel": "discord",
                "channel_id": "123456789",
                "message_id": "msg_1",
                "emoji": "\u{2705}"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert_eq!(
            mock.last_channel_id.lock().as_deref(),
            Some("123456789"),
            "add_reaction must receive channel_id, not channel name"
        );
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let handle: ChannelMapHandle = Arc::new(RwLock::new(HashMap::new()));
        let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()), Arc::clone(&handle));
        let args = json!({
            "channel": "slack",
            "channel_id": "C0123",
            "message_id": "msg_1",
            "emoji": "\u{2705}"
        });

        // Before registration the call must fail...
        let result = tool.execute(args.clone()).await.unwrap();
        assert!(!result.success);

        // ...and succeed once the channel is inserted through the shared handle.
        handle.write().insert(
            "slack".to_string(),
            Arc::new(MockChannel::new()) as Arc<dyn Channel>,
        );

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = make_tool_without_channels();
        let spec = tool.spec();
        assert_eq!(spec.name, "reaction");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }
}
585}