1use async_trait::async_trait;
2use parking_lot::RwLock;
3use serde_json::json;
4use std::collections::HashMap;
5use std::sync::Arc;
6use zeroclaw_api::channel::{Channel, SendMessage};
7use zeroclaw_api::tool::{Tool, ToolResult};
8use zeroclaw_config::policy::SecurityPolicy;
9use zeroclaw_config::policy::ToolOperation;
10
11pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
13
/// Reaction emojis that label poll options, indexed by option position.
///
/// Entries 0..=8 are the keycap sequences for the digits 1-9: the digit,
/// U+FE0F VARIATION SELECTOR-16, then U+20E3 COMBINING ENCLOSING KEYCAP.
/// Unicode has no keycap sequence for a two-digit number. Writing "1", "0",
/// U+FE0F, U+20E3 renders as a plain "1" followed by a keycapped "0", which
/// cannot be used as a single reaction. Option 10 therefore uses the dedicated
/// U+1F51F KEYCAP TEN emoji.
const VOTE_EMOJIS: &[&str] = &[
    "\u{0031}\u{FE0F}\u{20E3}",
    "\u{0032}\u{FE0F}\u{20E3}",
    "\u{0033}\u{FE0F}\u{20E3}",
    "\u{0034}\u{FE0F}\u{20E3}",
    "\u{0035}\u{FE0F}\u{20E3}",
    "\u{0036}\u{FE0F}\u{20E3}",
    "\u{0037}\u{FE0F}\u{20E3}",
    "\u{0038}\u{FE0F}\u{20E3}",
    "\u{0039}\u{FE0F}\u{20E3}",
    "\u{1F51F}",
];

/// Fewest answer options a poll may have.
const MIN_OPTIONS: usize = 2;
/// Most answer options a poll may have.
const MAX_OPTIONS: usize = 10;
/// Poll duration, in minutes, used when the caller omits `duration_minutes`.
const DEFAULT_DURATION_MINUTES: u64 = 60;

// Every allowed option must have its own vote emoji. Fail the build if the
// option limit is ever raised past the emoji table.
const _: () = assert!(MAX_OPTIONS <= VOTE_EMOJIS.len());
31
32pub struct PollTool {
33 security: Arc<SecurityPolicy>,
34 channels: ChannelMapHandle,
35}
36
37impl PollTool {
38 pub fn new(security: Arc<SecurityPolicy>, channels: ChannelMapHandle) -> Self {
39 Self { security, channels }
40 }
41}
42
43pub fn format_text_poll(
45 question: &str,
46 options: &[String],
47 duration_minutes: u64,
48 multi_select: bool,
49) -> String {
50 let mut lines = Vec::with_capacity(options.len() + 4);
51 lines.push(format!("\u{1F4CA} **Poll: {question}**"));
52 lines.push(String::new());
53 for (i, option) in options.iter().enumerate() {
54 let emoji = VOTE_EMOJIS.get(i).copied().unwrap_or(" ");
55 lines.push(format!("{emoji} {option}"));
56 }
57 lines.push(String::new());
58 let mode = if multi_select {
59 "multiple choices allowed"
60 } else {
61 "single choice"
62 };
63 lines.push(format!(
64 "_React with the corresponding number to vote ({mode}). Poll closes in {duration_minutes} min._"
65 ));
66 lines.join("\n")
67}
68
69fn validate_options(args: &serde_json::Value) -> Result<Vec<String>, String> {
71 let arr = args
72 .get("options")
73 .and_then(|v| v.as_array())
74 .ok_or("Missing or invalid 'options' parameter (expected array of strings)")?;
75
76 if arr.len() < MIN_OPTIONS {
77 return Err(format!(
78 "Poll requires at least {MIN_OPTIONS} options, got {}",
79 arr.len()
80 ));
81 }
82 if arr.len() > MAX_OPTIONS {
83 return Err(format!(
84 "Poll allows at most {MAX_OPTIONS} options, got {}",
85 arr.len()
86 ));
87 }
88
89 let mut options = Vec::with_capacity(arr.len());
90 for (i, v) in arr.iter().enumerate() {
91 let s = v
92 .as_str()
93 .map(|s| s.trim().to_string())
94 .filter(|s| !s.is_empty())
95 .ok_or(format!("Option at index {i} must be a non-empty string"))?;
96 options.push(s);
97 }
98 Ok(options)
99}
100
/// Reports whether `channel_name` looks like a Telegram or Discord channel,
/// i.e. a platform that offers a native poll API.
///
/// The check is an ASCII case-insensitive substring match, so names such as
/// `"my_telegram_bot"` also qualify.
fn supports_native_poll(channel_name: &str) -> bool {
    const NATIVE_POLL_PLATFORMS: [&str; 2] = ["telegram", "discord"];
    let normalized = channel_name.to_ascii_lowercase();
    NATIVE_POLL_PLATFORMS
        .iter()
        .any(|&platform| normalized.contains(platform))
}
106
107#[async_trait]
108impl Tool for PollTool {
109 fn name(&self) -> &str {
110 "poll"
111 }
112
113 fn description(&self) -> &str {
114 "Create a poll in a messaging channel. For Telegram/Discord uses native polls; for other channels formats as a numbered text message with emoji reactions for voting."
115 }
116
117 fn parameters_schema(&self) -> serde_json::Value {
118 json!({
119 "type": "object",
120 "properties": {
121 "question": {
122 "type": "string",
123 "description": "The poll question"
124 },
125 "options": {
126 "type": "array",
127 "items": { "type": "string" },
128 "minItems": 2,
129 "maxItems": 10,
130 "description": "Poll answer options (2-10 items)"
131 },
132 "channel": {
133 "type": "string",
134 "description": "Target channel name. Defaults to the first available channel if omitted."
135 },
136 "recipient": {
137 "type": "string",
138 "description": "Recipient/chat identifier within the channel (e.g., chat_id for Telegram, channel_id for Slack)"
139 },
140 "duration_minutes": {
141 "type": "integer",
142 "description": "Poll duration in minutes (default: 60)"
143 },
144 "multi_select": {
145 "type": "boolean",
146 "description": "Allow multiple selections (default: false)"
147 }
148 },
149 "required": ["question", "options"]
150 })
151 }
152
153 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
154 if let Err(e) = self
156 .security
157 .enforce_tool_operation(ToolOperation::Act, "poll")
158 {
159 return Ok(ToolResult {
160 success: false,
161 output: String::new(),
162 error: Some(format!("Action blocked: {e}")),
163 });
164 }
165
166 let question = args
168 .get("question")
169 .and_then(|v| v.as_str())
170 .map(|s| s.trim())
171 .filter(|s| !s.is_empty())
172 .ok_or_else(|| {
173 ::zeroclaw_log::record!(
174 WARN,
175 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
176 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
177 .with_attrs(::serde_json::json!({"param": "question"})),
178 "poll: missing question parameter"
179 );
180 anyhow::Error::msg("Missing 'question' parameter")
181 })?
182 .to_string();
183
184 let options = match validate_options(&args) {
185 Ok(opts) => opts,
186 Err(msg) => {
187 return Ok(ToolResult {
188 success: false,
189 output: String::new(),
190 error: Some(msg),
191 });
192 }
193 };
194
195 let duration_minutes = args
196 .get("duration_minutes")
197 .and_then(|v| v.as_u64())
198 .unwrap_or(DEFAULT_DURATION_MINUTES);
199
200 let multi_select = args
201 .get("multi_select")
202 .and_then(|v| v.as_bool())
203 .unwrap_or(false);
204
205 let requested_channel = args
206 .get("channel")
207 .and_then(|v| v.as_str())
208 .map(|s| s.trim().to_string());
209
210 let recipient = args
211 .get("recipient")
212 .and_then(|v| v.as_str())
213 .map(|s| s.trim().to_string());
214
215 let (channel_name, channel): (String, Arc<dyn Channel>) = {
218 let channels = self.channels.read();
219 if let Some(ref name) = requested_channel {
220 let ch = channels.get(name.as_str()).cloned().ok_or_else(|| {
221 let available = channels.keys().cloned().collect::<Vec<_>>().join(", ");
222 ::zeroclaw_log::record!(
223 WARN,
224 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
225 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
226 .with_attrs(::serde_json::json!({
227 "channel_requested": name,
228 "available": &available,
229 })),
230 "poll: requested channel not found"
231 );
232 anyhow::Error::msg(format!(
233 "Channel '{name}' not found. Available: {available}"
234 ))
235 })?;
236 (name.clone(), ch)
237 } else {
238 let (name, ch) = channels.iter().next().ok_or_else(|| {
240 ::zeroclaw_log::record!(
241 ERROR,
242 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
243 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
244 .with_attrs(::serde_json::json!({"missing": "channels"})),
245 "poll: no channels configured"
246 );
247 anyhow::Error::msg("No channels available. Configure at least one channel.")
248 })?;
249 (name.clone(), ch.clone())
250 }
251 };
252
253 let recipient_id = recipient.unwrap_or_default();
254
255 let is_native = supports_native_poll(&channel_name);
260
261 let poll_text = format_text_poll(&question, &options, duration_minutes, multi_select);
262
263 let msg = SendMessage::new(&poll_text, &recipient_id);
264 if let Err(e) = channel.send(&msg).await {
265 return Ok(ToolResult {
266 success: false,
267 output: String::new(),
268 error: Some(format!(
269 "Failed to send poll to channel '{channel_name}': {e}"
270 )),
271 });
272 }
273
274 let native_note = if is_native {
275 " (native poll API available — text fallback used; trait extension needed for native support)"
276 } else {
277 ""
278 };
279
280 Ok(ToolResult {
281 success: true,
282 output: format!(
283 "Poll created on '{channel_name}'{native_note}:\n\
284 Question: {question}\n\
285 Options: {}\n\
286 Duration: {duration_minutes} min | Multi-select: {multi_select}",
287 options.join(", ")
288 ),
289 error: None,
290 })
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_api::channel::ChannelMessage;

    /// In-memory channel that records the content of every message it is asked to send.
    struct StubChannel {
        name: String,
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl StubChannel {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for StubChannel {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Channel(
                ::zeroclaw_api::attribution::ChannelKind::Webhook,
            )
        }
        fn alias(&self) -> &str {
            "test"
        }
    }

    #[async_trait]
    impl Channel for StubChannel {
        fn name(&self) -> &str {
            &self.name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
    }

    /// Builds a channel registry keyed by each channel's `name()`.
    fn make_channel_map(channels: Vec<Arc<dyn Channel>>) -> ChannelMapHandle {
        let mut map = HashMap::new();
        for ch in channels {
            map.insert(ch.name().to_string(), ch);
        }
        Arc::new(RwLock::new(map))
    }

    /// Poll tool with the default security policy and a single `"slack"` stub channel.
    fn default_tool() -> PollTool {
        let security = Arc::new(SecurityPolicy::default());
        let stub: Arc<dyn Channel> = Arc::new(StubChannel::new("slack"));
        let channels = make_channel_map(vec![stub]);
        PollTool::new(security, channels)
    }

    #[test]
    fn validate_options_rejects_too_few() {
        let args = json!({ "options": ["only_one"] });
        let err = validate_options(&args).unwrap_err();
        assert!(err.contains("at least 2"), "got: {err}");
    }

    #[test]
    fn validate_options_rejects_too_many() {
        let opts: Vec<String> = (0..11).map(|i| format!("opt{i}")).collect();
        let args = json!({ "options": opts });
        let err = validate_options(&args).unwrap_err();
        assert!(err.contains("at most 10"), "got: {err}");
    }

    #[test]
    fn validate_options_rejects_empty_strings() {
        let args = json!({ "options": ["a", " ", "b"] });
        let err = validate_options(&args).unwrap_err();
        assert!(err.contains("non-empty string"), "got: {err}");
    }

    #[test]
    fn validate_options_rejects_missing_field() {
        let args = json!({});
        let err = validate_options(&args).unwrap_err();
        assert!(err.contains("Missing"), "got: {err}");
    }

    #[test]
    fn validate_options_accepts_valid_range() {
        let args = json!({ "options": ["yes", "no"] });
        let opts = validate_options(&args).unwrap();
        assert_eq!(opts, vec!["yes", "no"]);

        let opts10: Vec<String> = (0..10).map(|i| format!("opt{i}")).collect();
        let args10 = json!({ "options": opts10 });
        let result = validate_options(&args10).unwrap();
        assert_eq!(result.len(), 10);
    }

    #[test]
    fn format_text_poll_contains_question_and_options() {
        let text = format_text_poll(
            "Favorite color?",
            &["Red".into(), "Blue".into(), "Green".into()],
            30,
            false,
        );
        assert!(text.contains("Favorite color?"));
        assert!(text.contains("Red"));
        assert!(text.contains("Blue"));
        assert!(text.contains("Green"));
        assert!(text.contains("30 min"));
        assert!(text.contains("single choice"));
    }

    #[test]
    fn format_text_poll_multi_select_label() {
        let text = format_text_poll("Pick any", &["A".into(), "B".into()], 60, true);
        assert!(text.contains("multiple choices allowed"));
    }

    #[test]
    fn format_text_poll_includes_emoji_per_option() {
        let options: Vec<String> = (1..=5).map(|i| format!("Option {i}")).collect();
        let text = format_text_poll("Q?", &options, 10, false);
        for emoji in &VOTE_EMOJIS[..5] {
            assert!(text.contains(emoji), "missing emoji {emoji}");
        }
    }

    // Pins the full message layout (header, blank line, options, blank line, footer),
    // not just the presence of substrings.
    #[test]
    fn format_text_poll_exact_layout() {
        let text = format_text_poll("Lunch?", &["Pizza".into(), "Sushi".into()], 15, true);
        let expected = format!(
            "\u{1F4CA} **Poll: Lunch?**\n\n{} Pizza\n{} Sushi\n\n_React with the corresponding number to vote (multiple choices allowed). Poll closes in 15 min._",
            VOTE_EMOJIS[0], VOTE_EMOJIS[1]
        );
        assert_eq!(text, expected);
    }

    #[tokio::test]
    async fn execute_rejects_missing_question() {
        let tool = default_tool();
        let result = tool.execute(json!({ "options": ["a", "b"] })).await;
        assert!(
            result.is_err() || {
                let r = result.unwrap();
                !r.success || r.error.is_some()
            }
        );
    }

    #[tokio::test]
    async fn execute_rejects_missing_options() {
        let tool = default_tool();
        let result = tool.execute(json!({ "question": "What?" })).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Missing"));
    }

    #[tokio::test]
    async fn execute_rejects_invalid_option_count() {
        let tool = default_tool();
        let result = tool
            .execute(json!({ "question": "Q?", "options": ["only_one"] }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("at least 2"));
    }

    #[tokio::test]
    async fn execute_succeeds_with_valid_args() {
        let tool = default_tool();
        let result = tool
            .execute(json!({
                "question": "Lunch?",
                "options": ["Pizza", "Sushi"],
                "channel": "slack",
                "recipient": "general"
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert!(result.output.contains("Lunch?"));
        assert!(result.output.contains("Pizza"));
    }

    // Verifies that the formatted poll actually reaches the channel, using the
    // stub's recorded messages, when no channel is named explicitly.
    #[tokio::test]
    async fn execute_delivers_poll_text_to_channel() {
        let stub = StubChannel::new("slack");
        let sent = Arc::clone(&stub.sent);
        let channel: Arc<dyn Channel> = Arc::new(stub);
        let tool = PollTool::new(
            Arc::new(SecurityPolicy::default()),
            make_channel_map(vec![channel]),
        );

        let result = tool
            .execute(json!({ "question": "Lunch?", "options": ["Pizza", "Sushi"] }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);

        let delivered = sent.read();
        assert_eq!(delivered.len(), 1, "expected exactly one message");
        let expected = format_text_poll(
            "Lunch?",
            &["Pizza".into(), "Sushi".into()],
            DEFAULT_DURATION_MINUTES,
            false,
        );
        assert_eq!(delivered[0], expected);
    }

    #[tokio::test]
    async fn execute_reports_unknown_channel() {
        let tool = default_tool();
        let result = tool
            .execute(json!({
                "question": "Q?",
                "options": ["a", "b"],
                "channel": "nonexistent"
            }))
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn supports_native_poll_recognizes_telegram_and_discord() {
        assert!(supports_native_poll("telegram"));
        assert!(supports_native_poll("Telegram"));
        assert!(supports_native_poll("my_telegram_bot"));
        assert!(supports_native_poll("discord"));
        assert!(supports_native_poll("Discord"));
        assert!(!supports_native_poll("slack"));
        assert!(!supports_native_poll("whatsapp"));
    }
}