1use async_trait::async_trait;
10use parking_lot::RwLock;
11use serde_json::json;
12use std::collections::HashMap;
13use std::sync::Arc;
14use zeroclaw_api::channel::{Channel, ChannelMessage, SendMessage};
15use zeroclaw_api::tool::{Tool, ToolResult};
16use zeroclaw_config::policy::SecurityPolicy;
17use zeroclaw_config::policy::ToolOperation;
18
/// Shared, lock-protected map from channel name to channel implementation.
///
/// The map sits behind an `Arc<RwLock<…>>`, so entries inserted through another
/// clone of the handle are visible to every holder the next time it takes the lock.
pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;

/// Seconds to wait for a reply when the caller does not supply `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
24
/// Tool that sends a question to a messaging channel and returns the user's reply.
pub struct AskUserTool {
    /// Security policy consulted by `execute` before any message is sent.
    security: Arc<SecurityPolicy>,
    /// Shared name → channel map, read on every `execute` call.
    channels: ChannelMapHandle,
}
30
31impl AskUserTool {
32 pub fn new(security: Arc<SecurityPolicy>, channels: ChannelMapHandle) -> Self {
34 Self { security, channels }
35 }
36}
37
/// Renders a question, plus optional choices, as the plain-text message sent to a channel.
///
/// The question is wrapped in `**…**`. When `choices` contains at least one entry the
/// output continues with a blank line, a 1-based numbered list, another blank line and a
/// reply hint, all joined with `\n`.
///
/// `None` and an empty slice both yield just the bold question, so the
/// "reply with a number" hint is never shown without a list to pick from.
fn format_question(question: &str, choices: Option<&[String]>) -> String {
    let mut lines = vec![format!("**{question}**")];

    // An empty list carries no options; treat it exactly like "no choices given".
    if let Some(choices) = choices.filter(|list| !list.is_empty()) {
        lines.push(String::new());
        lines.extend(
            choices
                .iter()
                .enumerate()
                .map(|(i, choice)| format!("{}. {choice}", i + 1)),
        );
        lines.push(String::new());
        lines.push("_Reply with a number or type your answer._".to_string());
    }

    lines.join("\n")
}
54
55#[async_trait]
56impl Tool for AskUserTool {
57 fn name(&self) -> &str {
58 "ask_user"
59 }
60
61 fn description(&self) -> &str {
62 "Ask the user a question and wait for their response. \
63 Sends the question to a messaging channel and blocks until the user replies \
64 or the timeout expires. Optionally provide choices for structured responses."
65 }
66
67 fn parameters_schema(&self) -> serde_json::Value {
68 json!({
69 "type": "object",
70 "properties": {
71 "question": {
72 "type": "string",
73 "description": "The question to ask the user"
74 },
75 "choices": {
76 "type": "array",
77 "items": { "type": "string" },
78 "description": "Optional list of choices (renders as buttons on Telegram, numbered list on CLI)"
79 },
80 "timeout_secs": {
81 "type": "integer",
82 "description": "Seconds to wait for a response (default: 300)"
83 },
84 "channel": {
85 "type": "string",
86 "description": "Target channel name. Defaults to the first available channel if omitted."
87 }
88 },
89 "required": ["question"]
90 })
91 }
92
93 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
94 if let Err(e) = self
96 .security
97 .enforce_tool_operation(ToolOperation::Act, "ask_user")
98 {
99 return Ok(ToolResult {
100 success: false,
101 output: String::new(),
102 error: Some(format!("Action blocked: {e}")),
103 });
104 }
105
106 let question = args
108 .get("question")
109 .and_then(|v| v.as_str())
110 .map(|s| s.trim())
111 .filter(|s| !s.is_empty())
112 .ok_or_else(|| {
113 ::zeroclaw_log::record!(
114 WARN,
115 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
116 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
117 .with_attrs(::serde_json::json!({"param": "question"})),
118 "ask_user: missing question parameter"
119 );
120 anyhow::Error::msg("Missing 'question' parameter")
121 })?
122 .to_string();
123
124 let choices: Option<Vec<String>> = args.get("choices").and_then(|v| {
125 v.as_array().map(|arr| {
126 arr.iter()
127 .filter_map(|item| item.as_str().map(|s| s.trim().to_string()))
128 .filter(|s| !s.is_empty())
129 .collect()
130 })
131 });
132
133 let timeout_secs = args
134 .get("timeout_secs")
135 .and_then(|v| v.as_u64())
136 .unwrap_or(DEFAULT_TIMEOUT_SECS);
137
138 let requested_channel = args
139 .get("channel")
140 .and_then(|v| v.as_str())
141 .map(|s| s.trim().to_string());
142
143 let (channel_name, channel): (String, Arc<dyn Channel>) = {
146 let channels = self.channels.read();
147 if channels.is_empty() {
148 return Ok(ToolResult {
149 success: false,
150 output: String::new(),
151 error: Some("No channels available yet (channels not initialized)".to_string()),
152 });
153 }
154 if let Some(ref name) = requested_channel {
155 let ch = channels.get(name.as_str()).cloned().ok_or_else(|| {
156 let available = channels.keys().cloned().collect::<Vec<_>>().join(", ");
157 ::zeroclaw_log::record!(
158 WARN,
159 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
160 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
161 .with_attrs(::serde_json::json!({
162 "channel_requested": name,
163 "available": &available,
164 })),
165 "ask_user: requested channel not found"
166 );
167 anyhow::Error::msg(format!(
168 "Channel '{name}' not found. Available: {available}"
169 ))
170 })?;
171 (name.clone(), ch)
172 } else {
173 let (name, ch) = channels.iter().next().ok_or_else(|| {
174 ::zeroclaw_log::record!(
175 ERROR,
176 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
177 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
178 .with_attrs(::serde_json::json!({"missing": "channels"})),
179 "ask_user: no channels configured"
180 );
181 anyhow::Error::msg("No channels available. Configure at least one channel.")
182 })?;
183 (name.clone(), ch.clone())
184 }
185 };
186
187 let timeout = std::time::Duration::from_secs(timeout_secs);
188
189 if let Some(ref choices_vec) = choices
194 && !choices_vec.is_empty()
195 {
196 match channel
197 .request_choice(&question, choices_vec, timeout)
198 .await
199 {
200 Ok(Some(answer)) => {
201 return Ok(ToolResult {
202 success: true,
203 output: answer,
204 error: None,
205 });
206 }
207 Ok(None) => { }
208 Err(e) => {
209 return Ok(ToolResult {
210 success: false,
211 output: String::new(),
212 error: Some(format!(
213 "Failed to ask question on channel '{channel_name}': {e}"
214 )),
215 });
216 }
217 }
218 } else if !channel.supports_free_form_ask() {
219 return Ok(ToolResult {
225 success: false,
226 output: String::new(),
227 error: Some(format!(
228 "Channel '{channel_name}' requires `choices` for ask_user \
229 (free-form questions await ACP elicitation RFD)"
230 )),
231 });
232 }
233
234 let text = format_question(&question, choices.as_deref());
236 let msg = SendMessage::new(&text, "");
237 if let Err(e) = channel.send(&msg).await {
238 return Ok(ToolResult {
239 success: false,
240 output: String::new(),
241 error: Some(format!(
242 "Failed to send question to channel '{channel_name}': {e}"
243 )),
244 });
245 }
246
247 let (tx, mut rx) = tokio::sync::mpsc::channel::<ChannelMessage>(1);
249
250 let listen_channel = Arc::clone(&channel);
252 let listen_handle = tokio::spawn(async move { listen_channel.listen(tx).await });
253
254 let response = tokio::time::timeout(timeout, rx.recv()).await;
255
256 listen_handle.abort();
258
259 match response {
260 Ok(Some(msg)) => Ok(ToolResult {
261 success: true,
262 output: msg.content,
263 error: None,
264 }),
265 Ok(None) => Ok(ToolResult {
266 success: false,
267 output: "TIMEOUT".to_string(),
268 error: Some("Channel closed before receiving a response".to_string()),
269 }),
270 Err(_) => Ok(ToolResult {
271 success: false,
272 output: "TIMEOUT".to_string(),
273 error: Some(format!(
274 "No response received within {timeout_secs} seconds"
275 )),
276 }),
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Test channel that records every sent message and never delivers a reply.
    struct SilentChannel {
        channel_name: String,
        /// Contents of the messages passed to `send`, in call order.
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl SilentChannel {
        /// Creates a silent channel reporting `name` from `Channel::name`.
        fn new(name: &str) -> Self {
            Self {
                channel_name: name.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for SilentChannel {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Channel(
                ::zeroclaw_api::attribution::ChannelKind::Webhook,
            )
        }
        fn alias(&self) -> &str {
            "test"
        }
    }

    // Only `name`, `send` and `listen` are implemented here; every other `Channel`
    // method used by `execute` comes from the trait's defaults (defined outside this file).
    #[async_trait]
    impl Channel for SilentChannel {
        fn name(&self) -> &str {
            &self.channel_name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            // Sleeps far longer than any timeout used in these tests and never
            // sends on `_tx`, so the caller's wait runs into its timeout.
            tokio::time::sleep(std::time::Duration::from_secs(600)).await;
            Ok(())
        }
    }

    /// Test channel that records sent messages and answers with one canned reply.
    struct RespondingChannel {
        channel_name: String,
        /// Text delivered as the content of the single reply from `listen`.
        response: String,
        /// Contents of the messages passed to `send`, in call order.
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl RespondingChannel {
        /// Creates a channel named `name` whose `listen` replies with `response`.
        fn new(name: &str, response: &str) -> Self {
            Self {
                channel_name: name.to_string(),
                response: response.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    impl ::zeroclaw_api::attribution::Attributable for RespondingChannel {
        fn role(&self) -> ::zeroclaw_api::attribution::Role {
            ::zeroclaw_api::attribution::Role::Channel(
                ::zeroclaw_api::attribution::ChannelKind::Webhook,
            )
        }
        fn alias(&self) -> &str {
            "test"
        }
    }

    // As with `SilentChannel`, methods not listed here fall back to the trait defaults.
    #[async_trait]
    impl Channel for RespondingChannel {
        fn name(&self) -> &str {
            &self.channel_name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            // Deliver exactly one message carrying the canned response, then return.
            let msg = ChannelMessage {
                id: "resp_1".to_string(),
                sender: "user".to_string(),
                reply_target: "user".to_string(),
                content: self.response.clone(),
                channel: self.channel_name.clone(),
                channel_alias: None,
                timestamp: 1000,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
                subject: None,
            };
            // A send failure (receiver already gone) is deliberately ignored.
            let _ = tx.send(msg).await;
            Ok(())
        }
    }

    /// Builds an `AskUserTool` with the default security policy and the given
    /// `(name, channel)` pairs pre-registered in its channel map.
    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> AskUserTool {
        let handle = Arc::new(RwLock::new(HashMap::new()));
        {
            let mut map = handle.write();
            for (name, ch) in channels {
                map.insert(name.to_string(), ch);
            }
        }
        AskUserTool::new(Arc::new(SecurityPolicy::default()), handle)
    }

    // --- Metadata ---------------------------------------------------------

    /// The tool reports the expected name and a non-empty description.
    #[test]
    fn tool_name_and_description() {
        let tool = AskUserTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );
        assert_eq!(tool.name(), "ask_user");
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("question"));
    }

    /// The schema declares all four properties and requires only `question`.
    #[test]
    fn parameter_schema_validation() {
        let tool = AskUserTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["question"].is_object());
        assert!(schema["properties"]["choices"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());
        assert!(schema["properties"]["channel"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v == "question"));
        // Everything except `question` must stay optional.
        assert!(!required.iter().any(|v| v == "choices"));
        assert!(!required.iter().any(|v| v == "timeout_secs"));
        assert!(!required.iter().any(|v| v == "channel"));
    }

    /// `Tool::spec` (provided by the trait) mirrors the name, description and schema.
    #[test]
    fn spec_matches_metadata() {
        let tool = AskUserTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );
        let spec = tool.spec();
        assert_eq!(spec.name, "ask_user");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }

    // --- Question formatting ------------------------------------------------

    /// Without choices the text contains the question and no numbered list.
    #[test]
    fn format_question_without_choices() {
        let text = format_question("Are you sure?", None);
        assert!(text.contains("Are you sure?"));
        assert!(!text.contains("1."));
    }

    /// With choices the text contains a 1-based numbered list and the reply hint.
    #[test]
    fn format_question_with_choices() {
        let choices = vec!["Yes".to_string(), "No".to_string(), "Maybe".to_string()];
        let text = format_question("Continue?", Some(&choices));
        assert!(text.contains("Continue?"));
        assert!(text.contains("1. Yes"));
        assert!(text.contains("2. No"));
        assert!(text.contains("3. Maybe"));
        assert!(text.contains("Reply with a number"));
    }

    // --- execute: argument and channel validation ---------------------------

    /// A call without `question` fails with `Err`, not a failed `ToolResult`.
    #[tokio::test]
    async fn execute_rejects_missing_question() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(SilentChannel::new("test")) as Arc<dyn Channel>,
        )]);
        let result = tool.execute(json!({})).await;
        assert!(result.is_err());
    }

    /// A whitespace-only `question` is rejected the same way as a missing one.
    #[tokio::test]
    async fn execute_rejects_empty_question() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(SilentChannel::new("test")) as Arc<dyn Channel>,
        )]);
        let result = tool.execute(json!({ "question": " " })).await;
        assert!(result.is_err());
    }

    /// With an empty channel map the call returns a failed result mentioning
    /// "not initialized" instead of an `Err`.
    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        let tool = AskUserTool::new(
            Arc::new(SecurityPolicy::default()),
            Arc::new(RwLock::new(HashMap::new())),
        );
        let result = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    /// Naming a channel that is not in the map yields `Err`.
    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let tool = make_tool_with_channels(vec![(
            "slack",
            Arc::new(SilentChannel::new("slack")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({ "question": "Hello?", "channel": "nonexistent" }))
            .await;
        assert!(result.is_err());
    }

    // --- execute: reply handling --------------------------------------------

    /// A channel that never replies produces a failed result with output `"TIMEOUT"`.
    /// Note: this test waits for the real one-second timeout.
    #[tokio::test]
    async fn timeout_returns_timeout_output() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(SilentChannel::new("test")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({
                "question": "Confirm?",
                "timeout_secs": 1
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert_eq!(result.output, "TIMEOUT");
        assert!(result.error.as_deref().unwrap().contains("1 seconds"));
    }

    /// A free-form question returns the channel's reply text as `output`.
    #[tokio::test]
    async fn successful_response_flow() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(RespondingChannel::new("test", "Yes, proceed!")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({
                "question": "Should we deploy?",
                "timeout_secs": 5
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert_eq!(result.output, "Yes, proceed!");
        assert!(result.error.is_none());
    }

    /// A question with choices on an explicitly named channel returns the reply text.
    // NOTE(review): `RespondingChannel` does not override `request_choice`, so this
    // presumably exercises the text fallback via the trait default — confirm against
    // the `Channel` trait.
    #[tokio::test]
    async fn successful_response_with_choices() {
        let tool = make_tool_with_channels(vec![(
            "telegram",
            Arc::new(RespondingChannel::new("telegram", "2")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({
                "question": "Pick an option",
                "choices": ["Option A", "Option B"],
                "channel": "telegram",
                "timeout_secs": 5
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert_eq!(result.output, "2");
    }

    /// Channels inserted into the shared map after the tool was built are picked up
    /// by later calls.
    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let handle = Arc::new(RwLock::new(HashMap::new()));
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()), handle.clone());

        // First call: the map is still empty, so the call fails.
        let result = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!result.success);

        {
            // Register a channel through the shared handle; the write guard is
            // dropped at the end of this block.
            let mut map = handle.write();
            map.insert(
                "cli".to_string(),
                Arc::new(RespondingChannel::new("cli", "ok")) as Arc<dyn Channel>,
            );
        }

        // Second call: the same tool instance now sees the new channel.
        let result = tool
            .execute(json!({ "question": "Hello?", "timeout_secs": 5 }))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "ok");
    }
}