Skip to main content

zeroclaw_runtime/rpc/
approval_channel.rs

1//! RpcApprovalChannel — bridges Channel::request_approval() to the
2//! daemon Unix socket RPC stream.
3
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use serde_json::json;
9use uuid::Uuid;
10
11use zeroclaw_api::attribution::{Attributable, ChannelKind, Role};
12use zeroclaw_api::channel::{
13    Channel, ChannelApprovalRequest, ChannelApprovalResponse, ChannelMessage, SendMessage,
14};
15use zeroclaw_api::jsonrpc::RpcOutbound;
16
17use super::context::ApprovalPendingMap;
18
/// How long an approval request waits for a reply before it is auto-denied
/// (two minutes).
const DEFAULT_APPROVAL_TIMEOUT: Duration = Duration::from_secs(2 * 60);
20
21pub struct RpcApprovalChannel {
22    name: String,
23    session_id: String,
24    rpc: Arc<RpcOutbound>,
25    pending: Arc<ApprovalPendingMap>,
26    approval_timeout: Duration,
27}
28
29impl RpcApprovalChannel {
30    pub fn new(
31        name: impl Into<String>,
32        session_id: impl Into<String>,
33        rpc: Arc<RpcOutbound>,
34        pending: Arc<ApprovalPendingMap>,
35    ) -> Self {
36        Self {
37            name: name.into(),
38            session_id: session_id.into(),
39            rpc,
40            pending,
41            approval_timeout: DEFAULT_APPROVAL_TIMEOUT,
42        }
43    }
44}
45
46impl Attributable for RpcApprovalChannel {
47    fn role(&self) -> Role {
48        Role::Channel(ChannelKind::AcpChannel)
49    }
50
51    fn alias(&self) -> &str {
52        &self.name
53    }
54}
55
56#[async_trait]
57impl Channel for RpcApprovalChannel {
58    fn name(&self) -> &str {
59        &self.name
60    }
61
62    async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
63        Ok(())
64    }
65
66    async fn listen(&self, _tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
67        anyhow::bail!("RpcApprovalChannel.listen is not supported")
68    }
69
70    async fn request_approval(
71        &self,
72        recipient: &str,
73        request: &ChannelApprovalRequest,
74    ) -> anyhow::Result<Option<ChannelApprovalResponse>> {
75        self.request_approval_with_timeout(recipient, request, self.approval_timeout)
76            .await
77    }
78}
79
80impl RpcApprovalChannel {
81    pub async fn request_approval_with_timeout(
82        &self,
83        _recipient: &str,
84        request: &ChannelApprovalRequest,
85        timeout: Duration,
86    ) -> anyhow::Result<Option<ChannelApprovalResponse>> {
87        let request_id = Uuid::new_v4().to_string();
88        let (tx, rx) = tokio::sync::oneshot::channel::<ChannelApprovalResponse>();
89        self.pending.insert(request_id.clone(), tx);
90
91        self.rpc
92            .notify(
93                "session/update",
94                json!({
95                    "type": "approval_request",
96                    "session_id": self.session_id,
97                    "request_id": request_id,
98                    "tool_name": request.tool_name,
99                    "arguments_summary": request.arguments_summary,
100                    "timeout_secs": timeout.as_secs(),
101                }),
102            )
103            .await;
104
105        match tokio::time::timeout(timeout, rx).await {
106            Ok(Ok(response)) => Ok(Some(response)),
107            Ok(Err(_)) | Err(_) => Ok(Some(ChannelApprovalResponse::Deny)),
108        }
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use zeroclaw_api::channel::{ChannelApprovalRequest, ChannelApprovalResponse};
    use zeroclaw_api::jsonrpc::RpcOutbound;

    /// An `RpcOutbound` wired to an in-memory queue, plus the queue's receiver
    /// so a test can read back every line written to it.
    fn make_rpc() -> (Arc<RpcOutbound>, mpsc::Receiver<String>) {
        let (line_tx, line_rx) = mpsc::channel::<String>(16);
        let outbound = Arc::new(RpcOutbound::new(line_tx));
        (outbound, line_rx)
    }

    /// A fresh, empty pending-approval map.
    fn make_pending() -> Arc<crate::rpc::context::ApprovalPendingMap> {
        let map = crate::rpc::context::ApprovalPendingMap::default();
        Arc::new(map)
    }

    /// An approval request for the `shell` tool with the given summary.
    fn shell_request(summary: &str) -> ChannelApprovalRequest {
        ChannelApprovalRequest {
            tool_name: "shell".to_string(),
            arguments_summary: summary.to_string(),
            raw_arguments: None,
        }
    }

    #[tokio::test]
    async fn sends_approval_request_notification_and_awaits_response() {
        let (rpc, mut written) = make_rpc();
        let pending = make_pending();
        let resolver = Arc::clone(&pending);
        let channel =
            RpcApprovalChannel::new("rpc", "sess-1", Arc::clone(&rpc), Arc::clone(&pending));
        let request = shell_request("ls /tmp");

        // Run the approval on its own task so this test can answer it.
        let task = zeroclaw_spawn::spawn!(async move {
            channel.request_approval("", &request).await
        });

        let line = written.recv().await.unwrap();
        let msg: serde_json::Value = serde_json::from_str(&line).unwrap();
        let params = &msg["params"];
        assert_eq!(msg["method"], "session/update");
        assert_eq!(params["type"], "approval_request");
        assert_eq!(params["session_id"], "sess-1");
        assert_eq!(params["tool_name"], "shell");

        let request_id = params["request_id"].as_str().unwrap().to_string();
        resolver.resolve(&request_id, ChannelApprovalResponse::Approve);

        let outcome = task.await.unwrap().unwrap();
        assert_eq!(outcome, Some(ChannelApprovalResponse::Approve));
    }

    #[tokio::test]
    async fn times_out_and_auto_denies() {
        // Bind (rather than discard) the receiver so it stays alive for the
        // whole test.
        let (rpc, _written) = make_rpc();
        let pending = make_pending();
        let channel =
            RpcApprovalChannel::new("rpc", "sess-1", Arc::clone(&rpc), Arc::clone(&pending));
        let request = shell_request("rm -rf /");
        let short_timeout = std::time::Duration::from_millis(50);

        let outcome = channel
            .request_approval_with_timeout("", &request, short_timeout)
            .await
            .unwrap();
        assert_eq!(outcome, Some(ChannelApprovalResponse::Deny));
    }
}