Skip to main content

zeroclaw_tools/
notion_tool.rs

1use async_trait::async_trait;
2use serde_json::json;
3use std::sync::Arc;
4use zeroclaw_api::tool::{Tool, ToolResult};
5use zeroclaw_config::policy::{SecurityPolicy, ToolOperation};
6
/// Root URL that every Notion REST endpoint used by this tool is built from.
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
/// Value sent in the `Notion-Version` header on every request made by this tool.
const NOTION_VERSION: &str = "2022-06-28";
/// Per-request timeout, in seconds, applied to each Notion API call.
const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30;
/// Upper bound on how many characters of a failed response's body are quoted in errors.
const MAX_ERROR_BODY_CHARS: usize = 500;
12
13/// Tool for interacting with the Notion API — query databases, read/create/update pages,
14/// and search the workspace. Each action is gated by the appropriate security operation
15/// (Read for queries, Act for mutations).
16pub struct NotionTool {
17    api_key: String,
18    http: reqwest::Client,
19    security: Arc<SecurityPolicy>,
20}
21
22impl NotionTool {
23    /// Create a new Notion tool with the given API key and security policy.
24    pub fn new(api_key: String, security: Arc<SecurityPolicy>) -> Self {
25        Self {
26            api_key,
27            http: reqwest::Client::new(),
28            security,
29        }
30    }
31
32    /// Build the standard Notion API headers (Authorization, version, content-type).
33    fn headers(&self) -> anyhow::Result<reqwest::header::HeaderMap> {
34        let mut headers = reqwest::header::HeaderMap::new();
35        headers.insert(
36            "Authorization",
37            format!("Bearer {}", self.api_key).parse().map_err(|e| {
38                ::zeroclaw_log::record!(
39                    WARN,
40                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
41                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
42                        .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
43                    "notion_tool: invalid API key header value"
44                );
45                anyhow::Error::msg(format!("Invalid Notion API key header value: {e}"))
46            })?,
47        );
48        headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
49        headers.insert("Content-Type", "application/json".parse().unwrap());
50        Ok(headers)
51    }
52
53    /// Query a Notion database with an optional filter.
54    async fn query_database(
55        &self,
56        database_id: &str,
57        filter: Option<&serde_json::Value>,
58    ) -> anyhow::Result<serde_json::Value> {
59        let url = format!("{NOTION_API_BASE}/databases/{database_id}/query");
60        let mut body = json!({});
61        if let Some(f) = filter {
62            body["filter"] = f.clone();
63        }
64        let resp = self
65            .http
66            .post(&url)
67            .headers(self.headers()?)
68            .json(&body)
69            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
70            .send()
71            .await?;
72        let status = resp.status();
73        if !status.is_success() {
74            let text = resp.text().await.unwrap_or_default();
75            let truncated =
76                crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
77            anyhow::bail!("Notion query_database failed ({status}): {truncated}");
78        }
79        resp.json().await.map_err(Into::into)
80    }
81
82    /// Read a single Notion page by ID.
83    async fn read_page(&self, page_id: &str) -> anyhow::Result<serde_json::Value> {
84        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
85        let resp = self
86            .http
87            .get(&url)
88            .headers(self.headers()?)
89            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
90            .send()
91            .await?;
92        let status = resp.status();
93        if !status.is_success() {
94            let text = resp.text().await.unwrap_or_default();
95            let truncated =
96                crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
97            anyhow::bail!("Notion read_page failed ({status}): {truncated}");
98        }
99        resp.json().await.map_err(Into::into)
100    }
101
102    /// Create a new Notion page, optionally within a database.
103    async fn create_page(
104        &self,
105        properties: &serde_json::Value,
106        database_id: Option<&str>,
107    ) -> anyhow::Result<serde_json::Value> {
108        let url = format!("{NOTION_API_BASE}/pages");
109        let mut body = json!({ "properties": properties });
110        if let Some(db_id) = database_id {
111            body["parent"] = json!({ "database_id": db_id });
112        }
113        let resp = self
114            .http
115            .post(&url)
116            .headers(self.headers()?)
117            .json(&body)
118            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
119            .send()
120            .await?;
121        let status = resp.status();
122        if !status.is_success() {
123            let text = resp.text().await.unwrap_or_default();
124            let truncated =
125                crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
126            anyhow::bail!("Notion create_page failed ({status}): {truncated}");
127        }
128        resp.json().await.map_err(Into::into)
129    }
130
131    /// Update an existing Notion page's properties.
132    async fn update_page(
133        &self,
134        page_id: &str,
135        properties: &serde_json::Value,
136    ) -> anyhow::Result<serde_json::Value> {
137        let url = format!("{NOTION_API_BASE}/pages/{page_id}");
138        let body = json!({ "properties": properties });
139        let resp = self
140            .http
141            .patch(&url)
142            .headers(self.headers()?)
143            .json(&body)
144            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
145            .send()
146            .await?;
147        let status = resp.status();
148        if !status.is_success() {
149            let text = resp.text().await.unwrap_or_default();
150            let truncated =
151                crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
152            anyhow::bail!("Notion update_page failed ({status}): {truncated}");
153        }
154        resp.json().await.map_err(Into::into)
155    }
156
157    /// Search the Notion workspace by query string.
158    async fn search(&self, query: &str) -> anyhow::Result<serde_json::Value> {
159        let url = format!("{NOTION_API_BASE}/search");
160        let body = json!({ "query": query });
161        let resp = self
162            .http
163            .post(&url)
164            .headers(self.headers()?)
165            .json(&body)
166            .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
167            .send()
168            .await?;
169        let status = resp.status();
170        if !status.is_success() {
171            let text = resp.text().await.unwrap_or_default();
172            let truncated =
173                crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
174            anyhow::bail!("Notion search failed ({status}): {truncated}");
175        }
176        resp.json().await.map_err(Into::into)
177    }
178}
179
180#[async_trait]
181impl Tool for NotionTool {
182    fn name(&self) -> &str {
183        "notion"
184    }
185
186    fn description(&self) -> &str {
187        "Interact with Notion: query databases, read/create/update pages, and search the workspace."
188    }
189
190    fn parameters_schema(&self) -> serde_json::Value {
191        json!({
192            "type": "object",
193            "properties": {
194                "action": {
195                    "type": "string",
196                    "enum": ["query_database", "read_page", "create_page", "update_page", "search"],
197                    "description": "The Notion API action to perform"
198                },
199                "database_id": {
200                    "type": "string",
201                    "description": "Database ID (required for query_database, optional for create_page)"
202                },
203                "page_id": {
204                    "type": "string",
205                    "description": "Page ID (required for read_page and update_page)"
206                },
207                "filter": {
208                    "type": "object",
209                    "description": "Notion filter object for query_database"
210                },
211                "properties": {
212                    "type": "object",
213                    "description": "Properties object for create_page and update_page"
214                },
215                "query": {
216                    "type": "string",
217                    "description": "Search query string for the search action"
218                }
219            },
220            "required": ["action"]
221        })
222    }
223
224    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
225        let action = match args.get("action").and_then(|v| v.as_str()) {
226            Some(a) => a,
227            None => {
228                return Ok(ToolResult {
229                    success: false,
230                    output: String::new(),
231                    error: Some("Missing required parameter: action".into()),
232                });
233            }
234        };
235
236        // Enforce granular security: Read for queries, Act for mutations
237        let operation = match action {
238            "query_database" | "read_page" | "search" => ToolOperation::Read,
239            "create_page" | "update_page" => ToolOperation::Act,
240            _ => {
241                return Ok(ToolResult {
242                    success: false,
243                    output: String::new(),
244                    error: Some(format!(
245                        "Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search"
246                    )),
247                });
248            }
249        };
250
251        if let Err(error) = self.security.enforce_tool_operation(operation, "notion") {
252            return Ok(ToolResult {
253                success: false,
254                output: String::new(),
255                error: Some(error),
256            });
257        }
258
259        let result = match action {
260            "query_database" => {
261                let database_id = match args.get("database_id").and_then(|v| v.as_str()) {
262                    Some(id) => id,
263                    None => {
264                        return Ok(ToolResult {
265                            success: false,
266                            output: String::new(),
267                            error: Some("query_database requires database_id parameter".into()),
268                        });
269                    }
270                };
271                let filter = args.get("filter");
272                self.query_database(database_id, filter).await
273            }
274            "read_page" => {
275                let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
276                    Some(id) => id,
277                    None => {
278                        return Ok(ToolResult {
279                            success: false,
280                            output: String::new(),
281                            error: Some("read_page requires page_id parameter".into()),
282                        });
283                    }
284                };
285                self.read_page(page_id).await
286            }
287            "create_page" => {
288                let properties = match args.get("properties") {
289                    Some(p) => p,
290                    None => {
291                        return Ok(ToolResult {
292                            success: false,
293                            output: String::new(),
294                            error: Some("create_page requires properties parameter".into()),
295                        });
296                    }
297                };
298                let database_id = args.get("database_id").and_then(|v| v.as_str());
299                self.create_page(properties, database_id).await
300            }
301            "update_page" => {
302                let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
303                    Some(id) => id,
304                    None => {
305                        return Ok(ToolResult {
306                            success: false,
307                            output: String::new(),
308                            error: Some("update_page requires page_id parameter".into()),
309                        });
310                    }
311                };
312                let properties = match args.get("properties") {
313                    Some(p) => p,
314                    None => {
315                        return Ok(ToolResult {
316                            success: false,
317                            output: String::new(),
318                            error: Some("update_page requires properties parameter".into()),
319                        });
320                    }
321                };
322                self.update_page(page_id, properties).await
323            }
324            "search" => {
325                let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
326                self.search(query).await
327            }
328            _ => unreachable!(), // Already handled above
329        };
330
331        match result {
332            Ok(value) => Ok(ToolResult {
333                success: true,
334                output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
335                error: None,
336            }),
337            Err(e) => Ok(ToolResult {
338                success: false,
339                output: String::new(),
340                error: Some(e.to_string()),
341            }),
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::policy::SecurityPolicy;

    /// Tool built with a dummy API key and the default security policy.
    fn test_tool() -> NotionTool {
        NotionTool::new("test-key".into(), Arc::new(SecurityPolicy::default()))
    }

    /// Execute `args` and assert the call fails with an error message containing `needle`.
    async fn assert_fails_mentioning(args: serde_json::Value, needle: &str) {
        let tool = test_tool();
        let outcome = tool.execute(args).await.unwrap();
        assert!(!outcome.success);
        let message = outcome.error.as_deref().unwrap();
        assert!(message.contains(needle));
    }

    #[test]
    fn tool_name_is_notion() {
        assert_eq!(test_tool().name(), "notion");
    }

    #[test]
    fn parameters_schema_has_required_action() {
        let schema = test_tool().parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required
            .iter()
            .filter_map(|entry| entry.as_str())
            .any(|name| name == "action"));
    }

    #[test]
    fn parameters_schema_defines_all_actions() {
        let schema = test_tool().parameters_schema();
        let declared: Vec<&str> = schema["properties"]["action"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(serde_json::Value::as_str)
            .collect();
        for expected in ["query_database", "read_page", "create_page", "update_page", "search"] {
            assert!(declared.contains(&expected));
        }
    }

    #[tokio::test]
    async fn execute_missing_action_returns_error() {
        assert_fails_mentioning(json!({}), "action").await;
    }

    #[tokio::test]
    async fn execute_unknown_action_returns_error() {
        assert_fails_mentioning(json!({"action": "invalid"}), "Unknown action").await;
    }

    #[tokio::test]
    async fn execute_query_database_missing_id_returns_error() {
        assert_fails_mentioning(json!({"action": "query_database"}), "database_id").await;
    }

    #[tokio::test]
    async fn execute_read_page_missing_id_returns_error() {
        assert_fails_mentioning(json!({"action": "read_page"}), "page_id").await;
    }

    #[tokio::test]
    async fn execute_create_page_missing_properties_returns_error() {
        assert_fails_mentioning(json!({"action": "create_page"}), "properties").await;
    }

    #[tokio::test]
    async fn execute_update_page_missing_page_id_returns_error() {
        assert_fails_mentioning(
            json!({"action": "update_page", "properties": {}}),
            "page_id",
        )
        .await;
    }

    #[tokio::test]
    async fn execute_update_page_missing_properties_returns_error() {
        assert_fails_mentioning(
            json!({"action": "update_page", "page_id": "test-id"}),
            "properties",
        )
        .await;
    }
}
446            .unwrap();
447        assert!(!result.success);
448        assert!(result.error.as_deref().unwrap().contains("properties"));
449    }
450}