Skip to main content

zeroclaw_gateway/
node_tool.rs

1//! Wraps a node capability as a zeroclaw [`Tool`] so it can be dispatched
2//! through the existing tool registry and agent loop.
3//!
4//! Tool names are prefixed with the node ID: `node:<node_id>:<capability_name>`.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use tokio::time::Duration;
10
11use crate::nodes::{NodeInvocation, NodeRegistry};
12use zeroclaw_api::attribution::ToolKind;
13use zeroclaw_api::tool::{Tool, ToolResult};
14use zeroclaw_api::tool_attribution;
15use zeroclaw_tools::node_capabilities::requires_approval;
16
// Tag `NodeTool` with the `ToolKind::Plugin` attribution kind using the
// `zeroclaw_api::tool_attribution!` macro. NOTE(review): what consumes this
// attribution (telemetry, UI labelling, policy) is not visible here — confirm.
tool_attribution!(NodeTool, ToolKind::Plugin);
18
/// Time budget, in seconds, applied to a node invocation round trip
/// (currently 30 seconds).
const NODE_INVOKE_TIMEOUT_SECS: u64 = 30;
21
22/// A zeroclaw [`Tool`] backed by a node capability.
23///
24/// The `prefixed_name` (e.g. `node:phone-1:camera.snap`) is what the agent
25/// loop sees. Invocations are routed to the connected node via WebSocket.
26pub struct NodeTool {
27    /// Prefixed name: `node:<node_id>:<capability_name>`.
28    prefixed_name: String,
29    /// The node ID this tool belongs to.
30    node_id: String,
31    /// The original capability name.
32    capability_name: String,
33    /// Human-readable description.
34    description: String,
35    /// JSON schema for parameters.
36    parameters: serde_json::Value,
37    /// Node registry for routing invocations.
38    registry: Arc<NodeRegistry>,
39}
40
41impl NodeTool {
42    /// Create a new node tool wrapper.
43    pub fn new(
44        node_id: String,
45        capability_name: String,
46        description: String,
47        parameters: serde_json::Value,
48        registry: Arc<NodeRegistry>,
49    ) -> Self {
50        let prefixed_name = format!("node:{node_id}:{capability_name}");
51        Self {
52            prefixed_name,
53            node_id,
54            capability_name,
55            description,
56            parameters,
57            registry,
58        }
59    }
60
61    /// Build the prefixed tool name for a node capability.
62    pub fn tool_name(node_id: &str, capability_name: &str) -> String {
63        format!("node:{node_id}:{capability_name}")
64    }
65}
66
67#[async_trait]
68impl Tool for NodeTool {
69    fn name(&self) -> &str {
70        &self.prefixed_name
71    }
72
73    fn description(&self) -> &str {
74        &self.description
75    }
76
77    fn parameters_schema(&self) -> serde_json::Value {
78        self.parameters.clone()
79    }
80
81    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
82        // Check if this capability requires approval
83        if requires_approval(&self.capability_name) {
84            let approved = args
85                .get("approved")
86                .and_then(|v| v.as_bool())
87                .unwrap_or(false);
88            if !approved {
89                return Ok(ToolResult {
90                    success: false,
91                    output: String::new(),
92                    error: Some(format!(
93                        "Capability '{}' requires approval. Set approved=true to proceed.",
94                        self.capability_name
95                    )),
96                });
97            }
98        }
99
100        // Strip the `approved` field (same as MCP tools)
101        let args = match args {
102            serde_json::Value::Object(mut map) => {
103                map.remove("approved");
104                serde_json::Value::Object(map)
105            }
106            other => other,
107        };
108
109        let invoke_tx: tokio::sync::mpsc::Sender<NodeInvocation> =
110            match self.registry.invoke_tx(&self.node_id) {
111                Some(tx) => tx,
112                None => {
113                    return Ok(ToolResult {
114                        success: false,
115                        output: String::new(),
116                        error: Some(format!("Node '{}' is not connected", self.node_id)),
117                    });
118                }
119            };
120
121        let call_id = uuid::Uuid::new_v4().to_string();
122        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
123
124        let invocation = NodeInvocation {
125            call_id,
126            capability: self.capability_name.clone(),
127            args,
128            response_tx,
129        };
130
131        if invoke_tx.send(invocation).await.is_err() {
132            return Ok(ToolResult {
133                success: false,
134                output: String::new(),
135                error: Some(format!(
136                    "Failed to send invocation to node '{}'",
137                    self.node_id
138                )),
139            });
140        }
141
142        // Wait for response with timeout
143        match tokio::time::timeout(Duration::from_secs(NODE_INVOKE_TIMEOUT_SECS), response_rx).await
144        {
145            Ok(Ok(result)) => Ok(ToolResult {
146                success: result.success,
147                output: result.output,
148                error: result.error,
149            }),
150            Ok(Err(_)) => Ok(ToolResult {
151                success: false,
152                output: String::new(),
153                error: Some(format!(
154                    "Node '{}' dropped the invocation channel",
155                    self.node_id
156                )),
157            }),
158            Err(_) => Ok(ToolResult {
159                success: false,
160                output: String::new(),
161                error: Some(format!(
162                    "Node '{}' invocation timed out after {NODE_INVOKE_TIMEOUT_SECS}s",
163                    self.node_id
164                )),
165            }),
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nodes::{NodeCapability, NodeInfo, NodeInvocationResult, NodeRegistry};

    /// Register `node_id` with a single `capability`.
    ///
    /// Returns the receiving end of the node's invocation channel, which stands
    /// in for the node's live connection.
    fn register_node(
        registry: &NodeRegistry,
        node_id: &str,
        capability: &str,
    ) -> tokio::sync::mpsc::Receiver<NodeInvocation> {
        let (invoke_tx, invoke_rx) = tokio::sync::mpsc::channel(32);
        registry.register(NodeInfo {
            node_id: node_id.to_string(),
            capabilities: vec![NodeCapability {
                name: capability.to_string(),
                description: String::new(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }],
            invoke_tx,
        });
        invoke_rx
    }

    #[test]
    fn node_tool_name_format() {
        assert_eq!(
            NodeTool::tool_name("phone-1", "camera.snap"),
            "node:phone-1:camera.snap"
        );
    }

    #[test]
    fn node_tool_metadata() {
        let registry = Arc::new(NodeRegistry::new(10));
        let tool = NodeTool::new(
            "phone-1".to_string(),
            "camera.snap".to_string(),
            "Take a photo".to_string(),
            serde_json::json!({"type": "object", "properties": {"resolution": {"type": "string"}}}),
            registry,
        );

        assert_eq!(tool.name(), "node:phone-1:camera.snap");
        assert_eq!(tool.description(), "Take a photo");
        assert_eq!(tool.parameters_schema()["type"], "object");
    }

    #[tokio::test]
    async fn node_tool_execute_node_not_connected() {
        let registry = Arc::new(NodeRegistry::new(10));
        let tool = NodeTool::new(
            "missing-node".to_string(),
            "test".to_string(),
            "Test".to_string(),
            serde_json::json!({"type": "object", "properties": {}}),
            registry,
        );

        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("not connected"));
    }

    #[tokio::test]
    async fn node_tool_execute_success() {
        let registry = Arc::new(NodeRegistry::new(10));
        let mut invoke_rx = register_node(&registry, "test-node", "echo");

        let tool = NodeTool::new(
            "test-node".to_string(),
            "echo".to_string(),
            "Echo back".to_string(),
            serde_json::json!({"type": "object", "properties": {}}),
            Arc::clone(&registry),
        );

        // Spawn a task that simulates the node responding.
        tokio::spawn(async move {
            if let Some(invocation) = invoke_rx.recv().await {
                let _ = invocation.response_tx.send(NodeInvocationResult {
                    success: true,
                    output: "echoed".to_string(),
                    error: None,
                });
            }
        });

        let result = tool
            .execute(serde_json::json!({"msg": "hello"}))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "echoed");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn node_tool_forwards_approved_call_without_approved_flag() {
        let registry = Arc::new(NodeRegistry::new(10));
        let mut invoke_rx = register_node(&registry, "phone-1", "camera.snap");

        let tool = NodeTool::new(
            "phone-1".to_string(),
            "camera.snap".to_string(),
            "Take a photo".to_string(),
            serde_json::json!({"type": "object", "properties": {}}),
            Arc::clone(&registry),
        );

        // Report what the node actually received back to the test body, since a
        // panic inside the spawned task would not fail the test directly.
        let (seen_tx, seen_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            if let Some(invocation) = invoke_rx.recv().await {
                let _ = seen_tx.send((invocation.capability.clone(), invocation.args.clone()));
                let _ = invocation.response_tx.send(NodeInvocationResult {
                    success: true,
                    output: "snapped".to_string(),
                    error: None,
                });
            }
        });

        let result = tool
            .execute(serde_json::json!({"approved": true, "resolution": "high"}))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "snapped");

        let (capability, args) = seen_rx.await.unwrap();
        assert_eq!(capability, "camera.snap");
        assert!(args.get("approved").is_none());
        assert_eq!(args["resolution"], "high");
    }

    #[tokio::test]
    async fn node_tool_reports_dropped_response_channel() {
        let registry = Arc::new(NodeRegistry::new(10));
        let mut invoke_rx = register_node(&registry, "test-node", "echo");

        let tool = NodeTool::new(
            "test-node".to_string(),
            "echo".to_string(),
            "Echo back".to_string(),
            serde_json::json!({"type": "object", "properties": {}}),
            Arc::clone(&registry),
        );

        // Simulate a node that accepts the invocation but never answers.
        tokio::spawn(async move {
            drop(invoke_rx.recv().await);
        });

        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("dropped"));
    }

    #[test]
    fn node_tool_spec_generation() {
        let registry = Arc::new(NodeRegistry::new(10));
        let tool = NodeTool::new(
            "sensor-1".to_string(),
            "temp.read".to_string(),
            "Read temperature".to_string(),
            serde_json::json!({"type": "object", "properties": {"unit": {"type": "string"}}}),
            registry,
        );

        let spec = tool.spec();
        assert_eq!(spec.name, "node:sensor-1:temp.read");
        assert_eq!(spec.description, "Read temperature");
        assert_eq!(spec.parameters["properties"]["unit"]["type"], "string");
    }

    #[tokio::test]
    async fn node_tool_rejects_unapproved_sensitive_operation() {
        let registry = Arc::new(NodeRegistry::new(10));
        let tool = NodeTool::new(
            "phone-1".to_string(),
            "camera.snap".to_string(),
            "Take a photo".to_string(),
            serde_json::json!({
                "type": "object",
                "properties": {
                    "approved": { "type": "boolean" }
                },
                "required": ["approved"]
            }),
            registry,
        );

        // Without approved field
        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("requires approval"));

        // With approved=false
        let result = tool
            .execute(serde_json::json!({"approved": false}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("requires approval"));
    }
}