Skip to main content

zeroclaw_tools/
poll.rs

1use async_trait::async_trait;
2use parking_lot::RwLock;
3use serde_json::json;
4use std::collections::HashMap;
5use std::sync::Arc;
6use zeroclaw_api::channel::{Channel, SendMessage};
7use zeroclaw_api::tool::{Tool, ToolResult};
8use zeroclaw_config::policy::SecurityPolicy;
9use zeroclaw_config::policy::ToolOperation;
10
11/// Shared handle giving tools late-bound access to the live channel map.
12pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
13
/// Number emojis used for text-based poll fallback voting.
///
/// Entry `i` labels option `i` (zero-based): option 1 is shown as 1️⃣ and the
/// tenth option as 🔟. Entries 1–9 are keycap sequences (digit, U+FE0F
/// VARIATION SELECTOR-16, U+20E3 COMBINING ENCLOSING KEYCAP). Unicode defines
/// no two-digit keycap sequence, so ten uses the dedicated U+1F51F KEYCAP TEN
/// code point. Writing "10" + U+FE0F + U+20E3 would render as a plain "1"
/// followed by a keycapped "0".
const VOTE_EMOJIS: &[&str] = &[
    "\u{0031}\u{FE0F}\u{20E3}", // 1️⃣
    "\u{0032}\u{FE0F}\u{20E3}", // 2️⃣
    "\u{0033}\u{FE0F}\u{20E3}", // 3️⃣
    "\u{0034}\u{FE0F}\u{20E3}", // 4️⃣
    "\u{0035}\u{FE0F}\u{20E3}", // 5️⃣
    "\u{0036}\u{FE0F}\u{20E3}", // 6️⃣
    "\u{0037}\u{FE0F}\u{20E3}", // 7️⃣
    "\u{0038}\u{FE0F}\u{20E3}", // 8️⃣
    "\u{0039}\u{FE0F}\u{20E3}", // 9️⃣
    "\u{1F51F}",                // 🔟
];
27
28const MIN_OPTIONS: usize = 2;
29const MAX_OPTIONS: usize = 10;
30const DEFAULT_DURATION_MINUTES: u64 = 60;
31
32pub struct PollTool {
33    security: Arc<SecurityPolicy>,
34    channels: ChannelMapHandle,
35}
36
37impl PollTool {
38    pub fn new(security: Arc<SecurityPolicy>, channels: ChannelMapHandle) -> Self {
39        Self { security, channels }
40    }
41}
42
43/// Format a poll as a numbered text message for channels without native poll support.
44pub fn format_text_poll(
45    question: &str,
46    options: &[String],
47    duration_minutes: u64,
48    multi_select: bool,
49) -> String {
50    let mut lines = Vec::with_capacity(options.len() + 4);
51    lines.push(format!("\u{1F4CA} **Poll: {question}**"));
52    lines.push(String::new());
53    for (i, option) in options.iter().enumerate() {
54        let emoji = VOTE_EMOJIS.get(i).copied().unwrap_or("  ");
55        lines.push(format!("{emoji}  {option}"));
56    }
57    lines.push(String::new());
58    let mode = if multi_select {
59        "multiple choices allowed"
60    } else {
61        "single choice"
62    };
63    lines.push(format!(
64        "_React with the corresponding number to vote ({mode}). Poll closes in {duration_minutes} min._"
65    ));
66    lines.join("\n")
67}
68
69/// Validate the options array: 2-10 non-empty strings.
70fn validate_options(args: &serde_json::Value) -> Result<Vec<String>, String> {
71    let arr = args
72        .get("options")
73        .and_then(|v| v.as_array())
74        .ok_or("Missing or invalid 'options' parameter (expected array of strings)")?;
75
76    if arr.len() < MIN_OPTIONS {
77        return Err(format!(
78            "Poll requires at least {MIN_OPTIONS} options, got {}",
79            arr.len()
80        ));
81    }
82    if arr.len() > MAX_OPTIONS {
83        return Err(format!(
84            "Poll allows at most {MAX_OPTIONS} options, got {}",
85            arr.len()
86        ));
87    }
88
89    let mut options = Vec::with_capacity(arr.len());
90    for (i, v) in arr.iter().enumerate() {
91        let s = v
92            .as_str()
93            .map(|s| s.trim().to_string())
94            .filter(|s| !s.is_empty())
95            .ok_or(format!("Option at index {i} must be a non-empty string"))?;
96        options.push(s);
97    }
98    Ok(options)
99}
100
/// Report whether `channel_name` refers to a platform with a native poll API.
///
/// The check is a case-insensitive (ASCII) substring match against
/// `"telegram"` and `"discord"`. Names such as `"my_telegram_bot"` or
/// `"Discord"` therefore qualify.
fn supports_native_poll(channel_name: &str) -> bool {
    let normalized = channel_name.to_ascii_lowercase();
    ["telegram", "discord"]
        .iter()
        .any(|&platform| normalized.contains(platform))
}
106
107#[async_trait]
108impl Tool for PollTool {
109    fn name(&self) -> &str {
110        "poll"
111    }
112
113    fn description(&self) -> &str {
114        "Create a poll in a messaging channel. For Telegram/Discord uses native polls; for other channels formats as a numbered text message with emoji reactions for voting."
115    }
116
117    fn parameters_schema(&self) -> serde_json::Value {
118        json!({
119            "type": "object",
120            "properties": {
121                "question": {
122                    "type": "string",
123                    "description": "The poll question"
124                },
125                "options": {
126                    "type": "array",
127                    "items": { "type": "string" },
128                    "minItems": 2,
129                    "maxItems": 10,
130                    "description": "Poll answer options (2-10 items)"
131                },
132                "channel": {
133                    "type": "string",
134                    "description": "Target channel name. Defaults to the first available channel if omitted."
135                },
136                "recipient": {
137                    "type": "string",
138                    "description": "Recipient/chat identifier within the channel (e.g., chat_id for Telegram, channel_id for Slack)"
139                },
140                "duration_minutes": {
141                    "type": "integer",
142                    "description": "Poll duration in minutes (default: 60)"
143                },
144                "multi_select": {
145                    "type": "boolean",
146                    "description": "Allow multiple selections (default: false)"
147                }
148            },
149            "required": ["question", "options"]
150        })
151    }
152
153    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
154        // Security gate: Act operation
155        if let Err(e) = self
156            .security
157            .enforce_tool_operation(ToolOperation::Act, "poll")
158        {
159            return Ok(ToolResult {
160                success: false,
161                output: String::new(),
162                error: Some(format!("Action blocked: {e}")),
163            });
164        }
165
166        // Parse required params
167        let question = args
168            .get("question")
169            .and_then(|v| v.as_str())
170            .map(|s| s.trim())
171            .filter(|s| !s.is_empty())
172            .ok_or_else(|| {
173                ::zeroclaw_log::record!(
174                    WARN,
175                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
176                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
177                        .with_attrs(::serde_json::json!({"param": "question"})),
178                    "poll: missing question parameter"
179                );
180                anyhow::Error::msg("Missing 'question' parameter")
181            })?
182            .to_string();
183
184        let options = match validate_options(&args) {
185            Ok(opts) => opts,
186            Err(msg) => {
187                return Ok(ToolResult {
188                    success: false,
189                    output: String::new(),
190                    error: Some(msg),
191                });
192            }
193        };
194
195        let duration_minutes = args
196            .get("duration_minutes")
197            .and_then(|v| v.as_u64())
198            .unwrap_or(DEFAULT_DURATION_MINUTES);
199
200        let multi_select = args
201            .get("multi_select")
202            .and_then(|v| v.as_bool())
203            .unwrap_or(false);
204
205        let requested_channel = args
206            .get("channel")
207            .and_then(|v| v.as_str())
208            .map(|s| s.trim().to_string());
209
210        let recipient = args
211            .get("recipient")
212            .and_then(|v| v.as_str())
213            .map(|s| s.trim().to_string());
214
215        // Resolve channel from handle — block-scoped to drop the RwLock guard
216        // before any `.await` (parking_lot guards are !Send).
217        let (channel_name, channel): (String, Arc<dyn Channel>) = {
218            let channels = self.channels.read();
219            if let Some(ref name) = requested_channel {
220                let ch = channels.get(name.as_str()).cloned().ok_or_else(|| {
221                    let available = channels.keys().cloned().collect::<Vec<_>>().join(", ");
222                    ::zeroclaw_log::record!(
223                        WARN,
224                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
225                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
226                            .with_attrs(::serde_json::json!({
227                                "channel_requested": name,
228                                "available": &available,
229                            })),
230                        "poll: requested channel not found"
231                    );
232                    anyhow::Error::msg(format!(
233                        "Channel '{name}' not found. Available: {available}"
234                    ))
235                })?;
236                (name.clone(), ch)
237            } else {
238                // Fall back to first available channel
239                let (name, ch) = channels.iter().next().ok_or_else(|| {
240                    ::zeroclaw_log::record!(
241                        ERROR,
242                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
243                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
244                            .with_attrs(::serde_json::json!({"missing": "channels"})),
245                        "poll: no channels configured"
246                    );
247                    anyhow::Error::msg("No channels available. Configure at least one channel.")
248                })?;
249                (name.clone(), ch.clone())
250            }
251        };
252
253        let recipient_id = recipient.unwrap_or_default();
254
255        // For channels with native poll support, we still send a formatted message.
256        // The Channel trait does not expose a create_poll method, so all channels
257        // receive a text-formatted poll. Native Telegram/Discord poll APIs would
258        // require a trait extension; for now we note the intent in the output.
259        let is_native = supports_native_poll(&channel_name);
260
261        let poll_text = format_text_poll(&question, &options, duration_minutes, multi_select);
262
263        let msg = SendMessage::new(&poll_text, &recipient_id);
264        if let Err(e) = channel.send(&msg).await {
265            return Ok(ToolResult {
266                success: false,
267                output: String::new(),
268                error: Some(format!(
269                    "Failed to send poll to channel '{channel_name}': {e}"
270                )),
271            });
272        }
273
274        let native_note = if is_native {
275            " (native poll API available — text fallback used; trait extension needed for native support)"
276        } else {
277            ""
278        };
279
280        Ok(ToolResult {
281            success: true,
282            output: format!(
283                "Poll created on '{channel_name}'{native_note}:\n\
284                 Question: {question}\n\
285                 Options: {}\n\
286                 Duration: {duration_minutes} min | Multi-select: {multi_select}",
287                options.join(", ")
288            ),
289            error: None,
290        })
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_api::channel::ChannelMessage;

    /// Channel double that records the content of every message it is asked to send.
    struct StubChannel {
        name: String,
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl StubChannel {
        fn new(name: &str) -> Self {
            let sent = Arc::new(RwLock::new(Vec::new()));
            Self {
                name: name.to_owned(),
                sent,
            }
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for StubChannel {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Channel(
                ::zeroclaw_api::attribution::ChannelKind::Webhook,
            )
        }
        fn alias(&self) -> &str {
            "test"
        }
    }

    #[async_trait]
    impl Channel for StubChannel {
        fn name(&self) -> &str {
            &self.name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            let mut log = self.sent.write();
            log.push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
    }

    /// Index the given channels by their own `name()` into a fresh shared handle.
    fn make_channel_map(channels: Vec<Arc<dyn Channel>>) -> ChannelMapHandle {
        let by_name: HashMap<String, Arc<dyn Channel>> = channels
            .into_iter()
            .map(|ch| (ch.name().to_owned(), ch))
            .collect();
        Arc::new(RwLock::new(by_name))
    }

    /// Tool with a default security policy and a single stub channel named "slack".
    fn default_tool() -> PollTool {
        let slack: Arc<dyn Channel> = Arc::new(StubChannel::new("slack"));
        PollTool::new(
            Arc::new(SecurityPolicy::default()),
            make_channel_map(vec![slack]),
        )
    }

    // ── Option validation tests ──

    #[test]
    fn validate_options_rejects_too_few() {
        let err = validate_options(&json!({ "options": ["only_one"] })).unwrap_err();
        assert!(err.contains("at least 2"), "got: {err}");
    }

    #[test]
    fn validate_options_rejects_too_many() {
        let eleven: Vec<String> = (0..11).map(|n| format!("opt{n}")).collect();
        let err = validate_options(&json!({ "options": eleven })).unwrap_err();
        assert!(err.contains("at most 10"), "got: {err}");
    }

    #[test]
    fn validate_options_rejects_empty_strings() {
        let err = validate_options(&json!({ "options": ["a", "  ", "b"] })).unwrap_err();
        assert!(err.contains("non-empty string"), "got: {err}");
    }

    #[test]
    fn validate_options_rejects_missing_field() {
        let err = validate_options(&json!({})).unwrap_err();
        assert!(err.contains("Missing"), "got: {err}");
    }

    #[test]
    fn validate_options_accepts_valid_range() {
        // Lower bound: exactly two options.
        assert_eq!(
            validate_options(&json!({ "options": ["yes", "no"] })).unwrap(),
            vec!["yes", "no"]
        );

        // Upper bound: exactly ten options.
        let ten: Vec<String> = (0..10).map(|n| format!("opt{n}")).collect();
        assert_eq!(validate_options(&json!({ "options": ten })).unwrap().len(), 10);
    }

    // ── Text-based poll formatting tests ──

    #[test]
    fn format_text_poll_contains_question_and_options() {
        let colors: Vec<String> = ["Red", "Blue", "Green"]
            .iter()
            .map(|c| c.to_string())
            .collect();
        let text = format_text_poll("Favorite color?", &colors, 30, false);
        for needle in ["Favorite color?", "Red", "Blue", "Green", "30 min", "single choice"] {
            assert!(text.contains(needle));
        }
    }

    #[test]
    fn format_text_poll_multi_select_label() {
        let picks = vec!["A".to_string(), "B".to_string()];
        let text = format_text_poll("Pick any", &picks, 60, true);
        assert!(text.contains("multiple choices allowed"));
    }

    #[test]
    fn format_text_poll_includes_emoji_per_option() {
        let options: Vec<String> = (1..=5).map(|i| format!("Option {i}")).collect();
        let text = format_text_poll("Q?", &options, 10, false);
        // Each of the five option lines should carry its number emoji.
        VOTE_EMOJIS
            .iter()
            .take(5)
            .for_each(|emoji| assert!(text.contains(emoji), "missing emoji {emoji}"));
    }

    // ── Missing parameters tests ──

    #[tokio::test]
    async fn execute_rejects_missing_question() {
        let outcome = default_tool().execute(json!({ "options": ["a", "b"] })).await;
        // Either an `Err` or a failed `ToolResult` counts as a rejection.
        let rejected = match outcome {
            Err(_) => true,
            Ok(r) => !r.success || r.error.is_some(),
        };
        assert!(rejected);
    }

    #[tokio::test]
    async fn execute_rejects_missing_options() {
        let result = default_tool()
            .execute(json!({ "question": "What?" }))
            .await
            .unwrap();
        assert!(!result.success);
        let error = result.error.as_deref().unwrap();
        assert!(error.contains("Missing"));
    }

    #[tokio::test]
    async fn execute_rejects_invalid_option_count() {
        let args = json!({ "question": "Q?", "options": ["only_one"] });
        let result = default_tool().execute(args).await.unwrap();
        assert!(!result.success);
        let error = result.error.as_deref().unwrap();
        assert!(error.contains("at least 2"));
    }

    #[tokio::test]
    async fn execute_succeeds_with_valid_args() {
        let args = json!({
            "question": "Lunch?",
            "options": ["Pizza", "Sushi"],
            "channel": "slack",
            "recipient": "general"
        });
        let result = default_tool().execute(args).await.unwrap();
        assert!(result.success, "error: {:?}", result.error);
        for expected in ["Lunch?", "Pizza"] {
            assert!(result.output.contains(expected));
        }
    }

    #[tokio::test]
    async fn execute_reports_unknown_channel() {
        let args = json!({
            "question": "Q?",
            "options": ["a", "b"],
            "channel": "nonexistent"
        });
        // An unknown channel surfaces as an `Err`, not as a failed `ToolResult`.
        let outcome = default_tool().execute(args).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn supports_native_poll_recognizes_telegram_and_discord() {
        for native in ["telegram", "Telegram", "my_telegram_bot", "discord", "Discord"] {
            assert!(supports_native_poll(native));
        }
        for text_only in ["slack", "whatsapp"] {
            assert!(!supports_native_poll(text_only));
        }
    }
}