zeroclaw_runtime/rpc/
approval_channel.rs1use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use serde_json::json;
9use uuid::Uuid;
10
11use zeroclaw_api::attribution::{Attributable, ChannelKind, Role};
12use zeroclaw_api::channel::{
13 Channel, ChannelApprovalRequest, ChannelApprovalResponse, ChannelMessage, SendMessage,
14};
15use zeroclaw_api::jsonrpc::RpcOutbound;
16
17use super::context::ApprovalPendingMap;
18
/// How long an approval prompt stays open (two minutes) before the channel
/// gives up waiting and denies it.
const DEFAULT_APPROVAL_TIMEOUT: Duration = Duration::from_secs(2 * 60);
20
21pub struct RpcApprovalChannel {
22 name: String,
23 session_id: String,
24 rpc: Arc<RpcOutbound>,
25 pending: Arc<ApprovalPendingMap>,
26 approval_timeout: Duration,
27}
28
29impl RpcApprovalChannel {
30 pub fn new(
31 name: impl Into<String>,
32 session_id: impl Into<String>,
33 rpc: Arc<RpcOutbound>,
34 pending: Arc<ApprovalPendingMap>,
35 ) -> Self {
36 Self {
37 name: name.into(),
38 session_id: session_id.into(),
39 rpc,
40 pending,
41 approval_timeout: DEFAULT_APPROVAL_TIMEOUT,
42 }
43 }
44}
45
46impl Attributable for RpcApprovalChannel {
47 fn role(&self) -> Role {
48 Role::Channel(ChannelKind::AcpChannel)
49 }
50
51 fn alias(&self) -> &str {
52 &self.name
53 }
54}
55
56#[async_trait]
57impl Channel for RpcApprovalChannel {
58 fn name(&self) -> &str {
59 &self.name
60 }
61
62 async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
63 Ok(())
64 }
65
66 async fn listen(&self, _tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
67 anyhow::bail!("RpcApprovalChannel.listen is not supported")
68 }
69
70 async fn request_approval(
71 &self,
72 recipient: &str,
73 request: &ChannelApprovalRequest,
74 ) -> anyhow::Result<Option<ChannelApprovalResponse>> {
75 self.request_approval_with_timeout(recipient, request, self.approval_timeout)
76 .await
77 }
78}
79
80impl RpcApprovalChannel {
81 pub async fn request_approval_with_timeout(
82 &self,
83 _recipient: &str,
84 request: &ChannelApprovalRequest,
85 timeout: Duration,
86 ) -> anyhow::Result<Option<ChannelApprovalResponse>> {
87 let request_id = Uuid::new_v4().to_string();
88 let (tx, rx) = tokio::sync::oneshot::channel::<ChannelApprovalResponse>();
89 self.pending.insert(request_id.clone(), tx);
90
91 self.rpc
92 .notify(
93 "session/update",
94 json!({
95 "type": "approval_request",
96 "session_id": self.session_id,
97 "request_id": request_id,
98 "tool_name": request.tool_name,
99 "arguments_summary": request.arguments_summary,
100 "timeout_secs": timeout.as_secs(),
101 }),
102 )
103 .await;
104
105 match tokio::time::timeout(timeout, rx).await {
106 Ok(Ok(response)) => Ok(Some(response)),
107 Ok(Err(_)) | Err(_) => Ok(Some(ChannelApprovalResponse::Deny)),
108 }
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use zeroclaw_api::channel::{ChannelApprovalRequest, ChannelApprovalResponse};
    use zeroclaw_api::jsonrpc::RpcOutbound;

    /// Builds an `RpcOutbound` whose written lines can be read back from the
    /// returned receiver.
    fn make_rpc() -> (Arc<RpcOutbound>, mpsc::Receiver<String>) {
        let (tx, rx) = mpsc::channel::<String>(16);
        (Arc::new(RpcOutbound::new(tx)), rx)
    }

    fn make_pending() -> Arc<crate::rpc::context::ApprovalPendingMap> {
        Arc::new(crate::rpc::context::ApprovalPendingMap::default())
    }

    #[tokio::test]
    async fn sends_approval_request_notification_and_awaits_response() {
        let (rpc, mut write_rx) = make_rpc();
        let pending = make_pending();
        let ch = RpcApprovalChannel::new("rpc", "sess-1", Arc::clone(&rpc), Arc::clone(&pending));

        let request = ChannelApprovalRequest {
            tool_name: "shell".to_string(),
            arguments_summary: "ls /tmp".to_string(),
            raw_arguments: None,
        };

        let pending_for_resolve = Arc::clone(&pending);
        let task = zeroclaw_spawn::spawn!(async move { ch.request_approval("", &request).await });

        let line = write_rx.recv().await.unwrap();
        let v: serde_json::Value = serde_json::from_str(&line).unwrap();
        assert_eq!(v["method"], "session/update");

        let params = &v["params"];
        assert_eq!(params["type"], "approval_request");
        assert_eq!(params["session_id"], "sess-1");
        assert_eq!(params["tool_name"], "shell");
        assert_eq!(params["arguments_summary"], "ls /tmp");
        assert_eq!(params["timeout_secs"], DEFAULT_APPROVAL_TIMEOUT.as_secs());

        let request_id = params["request_id"].as_str().unwrap().to_string();
        assert!(
            Uuid::parse_str(&request_id).is_ok(),
            "request_id should be a UUID, got {request_id}"
        );
        pending_for_resolve.resolve(&request_id, ChannelApprovalResponse::Approve);

        let result = task.await.unwrap().unwrap();
        assert_eq!(result, Some(ChannelApprovalResponse::Approve));
    }

    #[tokio::test]
    async fn times_out_and_auto_denies() {
        // The receiver is bound (not `_`) so the outbound channel stays open
        // while the notification is written.
        let (rpc, _write_rx) = make_rpc();
        let pending = make_pending();
        let ch = RpcApprovalChannel::new("rpc", "sess-1", Arc::clone(&rpc), Arc::clone(&pending));
        let request = ChannelApprovalRequest {
            tool_name: "shell".to_string(),
            arguments_summary: "rm -rf /".to_string(),
            raw_arguments: None,
        };
        let result = ch
            .request_approval_with_timeout("", &request, std::time::Duration::from_millis(50))
            .await
            .unwrap();
        assert_eq!(result, Some(ChannelApprovalResponse::Deny));
    }

    #[tokio::test]
    async fn listen_is_unsupported() {
        let (rpc, _write_rx) = make_rpc();
        let ch = RpcApprovalChannel::new("rpc", "sess-1", rpc, make_pending());
        let (tx, _rx) = mpsc::channel::<ChannelMessage>(1);

        let err = ch.listen(tx).await.unwrap_err();
        assert!(err.to_string().contains("not supported"));
    }

    #[test]
    fn exposes_name_as_channel_name_and_alias() {
        let (rpc, _write_rx) = make_rpc();
        let ch = RpcApprovalChannel::new("rpc-approvals", "sess-1", rpc, make_pending());

        assert_eq!(Channel::name(&ch), "rpc-approvals");
        assert_eq!(Attributable::alias(&ch), "rpc-approvals");
    }
}