Skip to main content

zeroclaw_tools/
ask_user.rs

1//! Interactive user prompting tool for cross-channel confirmations.
2//!
3//! Exposes `ask_user` as an agent-callable tool that sends a question to a
4//! messaging channel and waits for the user's response. The tool holds a
5//! late-binding channel map handle that is populated once channels are
6//! initialized (after tool construction). This mirrors the pattern used by
7//! [`ReactionTool`](super::reaction::ReactionTool).
8
9use async_trait::async_trait;
10use parking_lot::RwLock;
11use serde_json::json;
12use std::collections::HashMap;
13use std::sync::Arc;
14use zeroclaw_api::channel::{Channel, ChannelMessage, SendMessage};
15use zeroclaw_api::tool::{Tool, ToolResult};
16use zeroclaw_config::policy::SecurityPolicy;
17use zeroclaw_config::policy::ToolOperation;
18
/// Shared handle giving tools late-bound access to the live channel map.
///
/// Keys are channel names (the values accepted by `ask_user`'s `channel`
/// argument). The map may still be empty when a tool is constructed and is
/// populated later, once channels are initialized, so tools should read it
/// at call time rather than caching its contents.
pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
21
/// How long `ask_user` waits for a reply when the caller omits `timeout_secs`
/// (or passes a value that is not a non-negative integer): five minutes.
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
24
25/// Agent-callable tool for sending a question to a user and waiting for their response.
26pub struct AskUserTool {
27    security: Arc<SecurityPolicy>,
28    channels: ChannelMapHandle,
29}
30
31impl AskUserTool {
32    /// Create a new ask_user tool using the given channel map.
33    pub fn new(security: Arc<SecurityPolicy>, channels: ChannelMapHandle) -> Self {
34        Self { security, channels }
35    }
36}
37
/// Render `question` as bold Markdown, optionally followed by a numbered list
/// of `choices` and a hint on how to answer.
///
/// `None` and an empty slice are treated the same way: the result is only the
/// bolded question. Emitting the "reply with a number" hint without any
/// numbered options would just confuse the user.
fn format_question(question: &str, choices: Option<&[String]>) -> String {
    let mut lines = vec![format!("**{question}**")];

    if let Some(choices) = choices.filter(|c| !c.is_empty()) {
        lines.push(String::new());
        lines.extend(
            choices
                .iter()
                .enumerate()
                .map(|(i, choice)| format!("{}. {choice}", i + 1)),
        );
        lines.push(String::new());
        lines.push("_Reply with a number or type your answer._".to_string());
    }

    lines.join("\n")
}
54
55#[async_trait]
56impl Tool for AskUserTool {
57    fn name(&self) -> &str {
58        "ask_user"
59    }
60
61    fn description(&self) -> &str {
62        "Ask the user a question and wait for their response. \
63         Sends the question to a messaging channel and blocks until the user replies \
64         or the timeout expires. Optionally provide choices for structured responses."
65    }
66
67    fn parameters_schema(&self) -> serde_json::Value {
68        json!({
69            "type": "object",
70            "properties": {
71                "question": {
72                    "type": "string",
73                    "description": "The question to ask the user"
74                },
75                "choices": {
76                    "type": "array",
77                    "items": { "type": "string" },
78                    "description": "Optional list of choices (renders as buttons on Telegram, numbered list on CLI)"
79                },
80                "timeout_secs": {
81                    "type": "integer",
82                    "description": "Seconds to wait for a response (default: 300)"
83                },
84                "channel": {
85                    "type": "string",
86                    "description": "Target channel name. Defaults to the first available channel if omitted."
87                }
88            },
89            "required": ["question"]
90        })
91    }
92
93    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
94        // Security gate: Act operation
95        if let Err(e) = self
96            .security
97            .enforce_tool_operation(ToolOperation::Act, "ask_user")
98        {
99            return Ok(ToolResult {
100                success: false,
101                output: String::new(),
102                error: Some(format!("Action blocked: {e}")),
103            });
104        }
105
106        // Parse required params
107        let question = args
108            .get("question")
109            .and_then(|v| v.as_str())
110            .map(|s| s.trim())
111            .filter(|s| !s.is_empty())
112            .ok_or_else(|| {
113                ::zeroclaw_log::record!(
114                    WARN,
115                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
116                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
117                        .with_attrs(::serde_json::json!({"param": "question"})),
118                    "ask_user: missing question parameter"
119                );
120                anyhow::Error::msg("Missing 'question' parameter")
121            })?
122            .to_string();
123
124        let choices: Option<Vec<String>> = args.get("choices").and_then(|v| {
125            v.as_array().map(|arr| {
126                arr.iter()
127                    .filter_map(|item| item.as_str().map(|s| s.trim().to_string()))
128                    .filter(|s| !s.is_empty())
129                    .collect()
130            })
131        });
132
133        let timeout_secs = args
134            .get("timeout_secs")
135            .and_then(|v| v.as_u64())
136            .unwrap_or(DEFAULT_TIMEOUT_SECS);
137
138        let requested_channel = args
139            .get("channel")
140            .and_then(|v| v.as_str())
141            .map(|s| s.trim().to_string());
142
143        // Resolve channel from handle — block-scoped to drop the RwLock guard
144        // before any `.await` (parking_lot guards are !Send).
145        let (channel_name, channel): (String, Arc<dyn Channel>) = {
146            let channels = self.channels.read();
147            if channels.is_empty() {
148                return Ok(ToolResult {
149                    success: false,
150                    output: String::new(),
151                    error: Some("No channels available yet (channels not initialized)".to_string()),
152                });
153            }
154            if let Some(ref name) = requested_channel {
155                let ch = channels.get(name.as_str()).cloned().ok_or_else(|| {
156                    let available = channels.keys().cloned().collect::<Vec<_>>().join(", ");
157                    ::zeroclaw_log::record!(
158                        WARN,
159                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
160                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
161                            .with_attrs(::serde_json::json!({
162                                "channel_requested": name,
163                                "available": &available,
164                            })),
165                        "ask_user: requested channel not found"
166                    );
167                    anyhow::Error::msg(format!(
168                        "Channel '{name}' not found. Available: {available}"
169                    ))
170                })?;
171                (name.clone(), ch)
172            } else {
173                let (name, ch) = channels.iter().next().ok_or_else(|| {
174                    ::zeroclaw_log::record!(
175                        ERROR,
176                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
177                            .with_outcome(::zeroclaw_log::EventOutcome::Failure)
178                            .with_attrs(::serde_json::json!({"missing": "channels"})),
179                        "ask_user: no channels configured"
180                    );
181                    anyhow::Error::msg("No channels available. Configure at least one channel.")
182                })?;
183                (name.clone(), ch.clone())
184            }
185        };
186
187        let timeout = std::time::Duration::from_secs(timeout_secs);
188
189        // Prefer the channel's native structured-choice flow when choices are
190        // present (e.g. ACP `session/request_permission`, Telegram inline
191        // keyboard). Channels that don't implement it return `Ok(None)` and
192        // we fall through to the generic send + listen path.
193        if let Some(ref choices_vec) = choices
194            && !choices_vec.is_empty()
195        {
196            match channel
197                .request_choice(&question, choices_vec, timeout)
198                .await
199            {
200                Ok(Some(answer)) => {
201                    return Ok(ToolResult {
202                        success: true,
203                        output: answer,
204                        error: None,
205                    });
206                }
207                Ok(None) => { /* fall through to send+listen */ }
208                Err(e) => {
209                    return Ok(ToolResult {
210                        success: false,
211                        output: String::new(),
212                        error: Some(format!(
213                            "Failed to ask question on channel '{channel_name}': {e}"
214                        )),
215                    });
216                }
217            }
218        } else if !channel.supports_free_form_ask() {
219            // Free-form ask_user has no first-class ACP method yet. The ACP
220            // elicitation RFD is the future fix — until it lands, agents
221            // talking to ACP clients must supply `choices` so we can route
222            // through `session/request_permission`.
223            // RFD: https://github.com/zed-industries/agent-client-protocol/blob/main/docs/rfds/elicitation.mdx
224            return Ok(ToolResult {
225                success: false,
226                output: String::new(),
227                error: Some(format!(
228                    "Channel '{channel_name}' requires `choices` for ask_user \
229                     (free-form questions await ACP elicitation RFD)"
230                )),
231            });
232        }
233
234        // Format and send the question
235        let text = format_question(&question, choices.as_deref());
236        let msg = SendMessage::new(&text, "");
237        if let Err(e) = channel.send(&msg).await {
238            return Ok(ToolResult {
239                success: false,
240                output: String::new(),
241                error: Some(format!(
242                    "Failed to send question to channel '{channel_name}': {e}"
243                )),
244            });
245        }
246
247        // Listen for user response with timeout
248        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChannelMessage>(1);
249
250        // Spawn a listener task on the channel
251        let listen_channel = Arc::clone(&channel);
252        let listen_handle = tokio::spawn(async move { listen_channel.listen(tx).await });
253
254        let response = tokio::time::timeout(timeout, rx.recv()).await;
255
256        // Abort the listener once we have a response or timeout
257        listen_handle.abort();
258
259        match response {
260            Ok(Some(msg)) => Ok(ToolResult {
261                success: true,
262                output: msg.content,
263                error: None,
264            }),
265            Ok(None) => Ok(ToolResult {
266                success: false,
267                output: "TIMEOUT".to_string(),
268                error: Some("Channel closed before receiving a response".to_string()),
269            }),
270            Err(_) => Ok(ToolResult {
271                success: false,
272                output: "TIMEOUT".to_string(),
273                error: Some(format!(
274                    "No response received within {timeout_secs} seconds"
275                )),
276            }),
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// A stub channel that records sent messages but never produces incoming messages.
    struct SilentChannel {
        channel_name: String,
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl SilentChannel {
        fn new(name: &str) -> Self {
            Self {
                channel_name: name.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for SilentChannel {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Channel(
                ::zeroclaw_api::attribution::ChannelKind::Webhook,
            )
        }
        fn alias(&self) -> &str {
            "test"
        }
    }

    #[async_trait]
    impl Channel for SilentChannel {
        fn name(&self) -> &str {
            &self.channel_name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            // Never sends anything — simulates no user response
            tokio::time::sleep(std::time::Duration::from_secs(600)).await;
            Ok(())
        }
    }

    /// A stub channel that immediately responds with a canned message.
    struct RespondingChannel {
        channel_name: String,
        response: String,
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl RespondingChannel {
        fn new(name: &str, response: &str) -> Self {
            Self {
                channel_name: name.to_string(),
                response: response.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for RespondingChannel {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Channel(
                ::zeroclaw_api::attribution::ChannelKind::Webhook,
            )
        }
        fn alias(&self) -> &str {
            "test"
        }
    }

    #[async_trait]
    impl Channel for RespondingChannel {
        fn name(&self) -> &str {
            &self.channel_name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            let msg = ChannelMessage {
                id: "resp_1".to_string(),
                sender: "user".to_string(),
                reply_target: "user".to_string(),
                content: self.response.clone(),
                channel: self.channel_name.clone(),
                channel_alias: None,
                timestamp: 1000,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
                subject: None,
            };
            let _ = tx.send(msg).await;
            Ok(())
        }
    }

    /// Build a tool whose channel map contains exactly `channels`.
    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> AskUserTool {
        let handle = Arc::new(RwLock::new(HashMap::new()));
        {
            let mut map = handle.write();
            for (name, ch) in channels {
                map.insert(name.to_string(), ch);
            }
        }
        AskUserTool::new(Arc::new(SecurityPolicy::default()), handle)
    }

    // ── Metadata tests ──

    #[test]
    fn tool_name_and_description() {
        let tool = AskUserTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );
        assert_eq!(tool.name(), "ask_user");
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("question"));
    }

    #[test]
    fn parameter_schema_validation() {
        let tool = AskUserTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["question"].is_object());
        assert!(schema["properties"]["choices"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());
        assert!(schema["properties"]["channel"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v == "question"));
        // choices, timeout_secs, channel are optional
        assert!(!required.iter().any(|v| v == "choices"));
        assert!(!required.iter().any(|v| v == "timeout_secs"));
        assert!(!required.iter().any(|v| v == "channel"));
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = AskUserTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );
        let spec = tool.spec();
        assert_eq!(spec.name, "ask_user");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }

    // ── Format question tests ──

    #[test]
    fn format_question_without_choices() {
        let text = format_question("Are you sure?", None);
        // Only the bolded question — no numbered list, no reply hint.
        assert_eq!(text, "**Are you sure?**");
    }

    #[test]
    fn format_question_with_choices() {
        let choices = vec!["Yes".to_string(), "No".to_string(), "Maybe".to_string()];
        let text = format_question("Continue?", Some(&choices));
        assert!(text.contains("Continue?"));
        assert!(text.contains("1. Yes"));
        assert!(text.contains("2. No"));
        assert!(text.contains("3. Maybe"));
        assert!(text.contains("Reply with a number"));
        assert_eq!(
            text,
            "**Continue?**\n\n1. Yes\n2. No\n3. Maybe\n\n_Reply with a number or type your answer._"
        );
    }

    // ── Execute tests ──

    #[tokio::test]
    async fn execute_rejects_missing_question() {
        let channel = Arc::new(SilentChannel::new("test"));
        let sent = Arc::clone(&channel.sent);
        let tool = make_tool_with_channels(vec![("test", channel as Arc<dyn Channel>)]);
        let result = tool.execute(json!({})).await;
        assert!(result.is_err());
        assert!(sent.read().is_empty(), "nothing should be sent without a question");
    }

    #[tokio::test]
    async fn execute_rejects_empty_question() {
        let channel = Arc::new(SilentChannel::new("test"));
        let sent = Arc::clone(&channel.sent);
        let tool = make_tool_with_channels(vec![("test", channel as Arc<dyn Channel>)]);
        let result = tool.execute(json!({ "question": "  " })).await;
        assert!(result.is_err());
        assert!(sent.read().is_empty(), "nothing should be sent for a blank question");
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        let tool = AskUserTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );
        let result = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let tool = make_tool_with_channels(vec![(
            "slack",
            Arc::new(SilentChannel::new("slack")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({ "question": "Hello?", "channel": "nonexistent" }))
            .await;
        let Err(err) = result else {
            panic!("expected an error for an unknown channel");
        };
        let message = err.to_string();
        assert!(message.contains("nonexistent"), "unexpected error: {message}");
        assert!(
            message.contains("slack"),
            "error should list available channels: {message}"
        );
    }

    #[tokio::test]
    async fn timeout_returns_timeout_output() {
        let channel = Arc::new(SilentChannel::new("test"));
        let sent = Arc::clone(&channel.sent);
        let tool = make_tool_with_channels(vec![("test", channel as Arc<dyn Channel>)]);
        let result = tool
            .execute(json!({
                "question": "Confirm?",
                "timeout_secs": 1
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert_eq!(result.output, "TIMEOUT");
        assert!(result.error.as_deref().unwrap().contains("1 seconds"));

        // The question must still have been delivered exactly once.
        let sent = sent.read();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].contains("Confirm?"));
    }

    #[tokio::test]
    async fn successful_response_flow() {
        let channel = Arc::new(RespondingChannel::new("test", "Yes, proceed!"));
        let sent = Arc::clone(&channel.sent);
        let tool = make_tool_with_channels(vec![("test", channel as Arc<dyn Channel>)]);
        let result = tool
            .execute(json!({
                "question": "Should we deploy?",
                "timeout_secs": 5
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert_eq!(result.output, "Yes, proceed!");
        assert!(result.error.is_none());

        let sent = sent.read();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].contains("Should we deploy?"));
    }

    #[tokio::test]
    async fn successful_response_with_choices() {
        let channel = Arc::new(RespondingChannel::new("telegram", "2"));
        let sent = Arc::clone(&channel.sent);
        let tool = make_tool_with_channels(vec![("telegram", channel as Arc<dyn Channel>)]);
        let result = tool
            .execute(json!({
                "question": "Pick an option",
                "choices": ["Option A", "Option B"],
                "channel": "telegram",
                "timeout_secs": 5
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert_eq!(result.output, "2");

        // The stub has no native choice flow, so the numbered list is sent.
        let sent = sent.read();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].contains("1. Option A"));
        assert!(sent[0].contains("2. Option B"));
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let handle = Arc::new(RwLock::new(HashMap::new()));
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()), handle.clone());

        // Initially empty — tool reports not initialized
        let result = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!result.success);

        // Populate via the shared handle
        {
            let mut map = handle.write();
            map.insert(
                "cli".to_string(),
                Arc::new(RespondingChannel::new("cli", "ok")) as Arc<dyn Channel>,
            );
        }

        // Now the tool can route to the channel
        let result = tool
            .execute(json!({ "question": "Hello?", "timeout_secs": 5 }))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "ok");
    }
}