1use async_trait::async_trait;
2use serde_json::json;
3use std::sync::Arc;
4use zeroclaw_api::tool::{Tool, ToolResult};
5use zeroclaw_config::policy::{SecurityPolicy, ToolOperation};
6
/// Upper bound on how much of a failed response body is echoed into error messages.
const MAX_ERROR_BODY_CHARS: usize = 500;

/// Timeout, in seconds, applied to each individual Notion API request.
const NOTION_REQUEST_TIMEOUT_SECS: u64 = 30;

/// Value sent in the `Notion-Version` header of every request.
const NOTION_VERSION: &str = "2022-06-28";

/// Base URL of the Notion REST API; endpoint paths are appended to it.
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
12
13pub struct NotionTool {
17 api_key: String,
18 http: reqwest::Client,
19 security: Arc<SecurityPolicy>,
20}
21
22impl NotionTool {
23 pub fn new(api_key: String, security: Arc<SecurityPolicy>) -> Self {
25 Self {
26 api_key,
27 http: reqwest::Client::new(),
28 security,
29 }
30 }
31
32 fn headers(&self) -> anyhow::Result<reqwest::header::HeaderMap> {
34 let mut headers = reqwest::header::HeaderMap::new();
35 headers.insert(
36 "Authorization",
37 format!("Bearer {}", self.api_key).parse().map_err(|e| {
38 ::zeroclaw_log::record!(
39 WARN,
40 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
41 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
42 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
43 "notion_tool: invalid API key header value"
44 );
45 anyhow::Error::msg(format!("Invalid Notion API key header value: {e}"))
46 })?,
47 );
48 headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
49 headers.insert("Content-Type", "application/json".parse().unwrap());
50 Ok(headers)
51 }
52
53 async fn query_database(
55 &self,
56 database_id: &str,
57 filter: Option<&serde_json::Value>,
58 ) -> anyhow::Result<serde_json::Value> {
59 let url = format!("{NOTION_API_BASE}/databases/{database_id}/query");
60 let mut body = json!({});
61 if let Some(f) = filter {
62 body["filter"] = f.clone();
63 }
64 let resp = self
65 .http
66 .post(&url)
67 .headers(self.headers()?)
68 .json(&body)
69 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
70 .send()
71 .await?;
72 let status = resp.status();
73 if !status.is_success() {
74 let text = resp.text().await.unwrap_or_default();
75 let truncated =
76 crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
77 anyhow::bail!("Notion query_database failed ({status}): {truncated}");
78 }
79 resp.json().await.map_err(Into::into)
80 }
81
82 async fn read_page(&self, page_id: &str) -> anyhow::Result<serde_json::Value> {
84 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
85 let resp = self
86 .http
87 .get(&url)
88 .headers(self.headers()?)
89 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
90 .send()
91 .await?;
92 let status = resp.status();
93 if !status.is_success() {
94 let text = resp.text().await.unwrap_or_default();
95 let truncated =
96 crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
97 anyhow::bail!("Notion read_page failed ({status}): {truncated}");
98 }
99 resp.json().await.map_err(Into::into)
100 }
101
102 async fn create_page(
104 &self,
105 properties: &serde_json::Value,
106 database_id: Option<&str>,
107 ) -> anyhow::Result<serde_json::Value> {
108 let url = format!("{NOTION_API_BASE}/pages");
109 let mut body = json!({ "properties": properties });
110 if let Some(db_id) = database_id {
111 body["parent"] = json!({ "database_id": db_id });
112 }
113 let resp = self
114 .http
115 .post(&url)
116 .headers(self.headers()?)
117 .json(&body)
118 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
119 .send()
120 .await?;
121 let status = resp.status();
122 if !status.is_success() {
123 let text = resp.text().await.unwrap_or_default();
124 let truncated =
125 crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
126 anyhow::bail!("Notion create_page failed ({status}): {truncated}");
127 }
128 resp.json().await.map_err(Into::into)
129 }
130
131 async fn update_page(
133 &self,
134 page_id: &str,
135 properties: &serde_json::Value,
136 ) -> anyhow::Result<serde_json::Value> {
137 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
138 let body = json!({ "properties": properties });
139 let resp = self
140 .http
141 .patch(&url)
142 .headers(self.headers()?)
143 .json(&body)
144 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
145 .send()
146 .await?;
147 let status = resp.status();
148 if !status.is_success() {
149 let text = resp.text().await.unwrap_or_default();
150 let truncated =
151 crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
152 anyhow::bail!("Notion update_page failed ({status}): {truncated}");
153 }
154 resp.json().await.map_err(Into::into)
155 }
156
157 async fn search(&self, query: &str) -> anyhow::Result<serde_json::Value> {
159 let url = format!("{NOTION_API_BASE}/search");
160 let body = json!({ "query": query });
161 let resp = self
162 .http
163 .post(&url)
164 .headers(self.headers()?)
165 .json(&body)
166 .timeout(std::time::Duration::from_secs(NOTION_REQUEST_TIMEOUT_SECS))
167 .send()
168 .await?;
169 let status = resp.status();
170 if !status.is_success() {
171 let text = resp.text().await.unwrap_or_default();
172 let truncated =
173 crate::util_helpers::truncate_with_ellipsis(&text, MAX_ERROR_BODY_CHARS);
174 anyhow::bail!("Notion search failed ({status}): {truncated}");
175 }
176 resp.json().await.map_err(Into::into)
177 }
178}
179
180#[async_trait]
181impl Tool for NotionTool {
182 fn name(&self) -> &str {
183 "notion"
184 }
185
186 fn description(&self) -> &str {
187 "Interact with Notion: query databases, read/create/update pages, and search the workspace."
188 }
189
190 fn parameters_schema(&self) -> serde_json::Value {
191 json!({
192 "type": "object",
193 "properties": {
194 "action": {
195 "type": "string",
196 "enum": ["query_database", "read_page", "create_page", "update_page", "search"],
197 "description": "The Notion API action to perform"
198 },
199 "database_id": {
200 "type": "string",
201 "description": "Database ID (required for query_database, optional for create_page)"
202 },
203 "page_id": {
204 "type": "string",
205 "description": "Page ID (required for read_page and update_page)"
206 },
207 "filter": {
208 "type": "object",
209 "description": "Notion filter object for query_database"
210 },
211 "properties": {
212 "type": "object",
213 "description": "Properties object for create_page and update_page"
214 },
215 "query": {
216 "type": "string",
217 "description": "Search query string for the search action"
218 }
219 },
220 "required": ["action"]
221 })
222 }
223
224 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
225 let action = match args.get("action").and_then(|v| v.as_str()) {
226 Some(a) => a,
227 None => {
228 return Ok(ToolResult {
229 success: false,
230 output: String::new(),
231 error: Some("Missing required parameter: action".into()),
232 });
233 }
234 };
235
236 let operation = match action {
238 "query_database" | "read_page" | "search" => ToolOperation::Read,
239 "create_page" | "update_page" => ToolOperation::Act,
240 _ => {
241 return Ok(ToolResult {
242 success: false,
243 output: String::new(),
244 error: Some(format!(
245 "Unknown action: {action}. Valid actions: query_database, read_page, create_page, update_page, search"
246 )),
247 });
248 }
249 };
250
251 if let Err(error) = self.security.enforce_tool_operation(operation, "notion") {
252 return Ok(ToolResult {
253 success: false,
254 output: String::new(),
255 error: Some(error),
256 });
257 }
258
259 let result = match action {
260 "query_database" => {
261 let database_id = match args.get("database_id").and_then(|v| v.as_str()) {
262 Some(id) => id,
263 None => {
264 return Ok(ToolResult {
265 success: false,
266 output: String::new(),
267 error: Some("query_database requires database_id parameter".into()),
268 });
269 }
270 };
271 let filter = args.get("filter");
272 self.query_database(database_id, filter).await
273 }
274 "read_page" => {
275 let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
276 Some(id) => id,
277 None => {
278 return Ok(ToolResult {
279 success: false,
280 output: String::new(),
281 error: Some("read_page requires page_id parameter".into()),
282 });
283 }
284 };
285 self.read_page(page_id).await
286 }
287 "create_page" => {
288 let properties = match args.get("properties") {
289 Some(p) => p,
290 None => {
291 return Ok(ToolResult {
292 success: false,
293 output: String::new(),
294 error: Some("create_page requires properties parameter".into()),
295 });
296 }
297 };
298 let database_id = args.get("database_id").and_then(|v| v.as_str());
299 self.create_page(properties, database_id).await
300 }
301 "update_page" => {
302 let page_id = match args.get("page_id").and_then(|v| v.as_str()) {
303 Some(id) => id,
304 None => {
305 return Ok(ToolResult {
306 success: false,
307 output: String::new(),
308 error: Some("update_page requires page_id parameter".into()),
309 });
310 }
311 };
312 let properties = match args.get("properties") {
313 Some(p) => p,
314 None => {
315 return Ok(ToolResult {
316 success: false,
317 output: String::new(),
318 error: Some("update_page requires properties parameter".into()),
319 });
320 }
321 };
322 self.update_page(page_id, properties).await
323 }
324 "search" => {
325 let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
326 self.search(query).await
327 }
328 _ => unreachable!(), };
330
331 match result {
332 Ok(value) => Ok(ToolResult {
333 success: true,
334 output: serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()),
335 error: None,
336 }),
337 Err(e) => Ok(ToolResult {
338 success: false,
339 output: String::new(),
340 error: Some(e.to_string()),
341 }),
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::policy::SecurityPolicy;

    /// Tool instance with a dummy key and the default security policy.
    fn test_tool() -> NotionTool {
        NotionTool::new("test-key".into(), Arc::new(SecurityPolicy::default()))
    }

    /// Executes a fresh tool with `args` and unwraps the outer `Result`.
    async fn run(args: serde_json::Value) -> ToolResult {
        test_tool().execute(args).await.unwrap()
    }

    /// Asserts that `result` failed and that its error text contains `needle`.
    fn assert_failure_mentions(result: &ToolResult, needle: &str) {
        assert!(!result.success);
        let message = result.error.as_deref().unwrap();
        assert!(
            message.contains(needle),
            "error {message:?} should mention {needle:?}"
        );
    }

    #[test]
    fn tool_name_is_notion() {
        assert_eq!(test_tool().name(), "notion");
    }

    #[test]
    fn parameters_schema_has_required_action() {
        let schema = test_tool().parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&json!("action")));
    }

    #[test]
    fn parameters_schema_defines_all_actions() {
        let schema = test_tool().parameters_schema();
        let advertised: Vec<&str> = schema["properties"]["action"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(serde_json::Value::as_str)
            .collect();
        for expected in [
            "query_database",
            "read_page",
            "create_page",
            "update_page",
            "search",
        ] {
            assert!(advertised.contains(&expected), "missing action {expected}");
        }
    }

    #[tokio::test]
    async fn execute_missing_action_returns_error() {
        assert_failure_mentions(&run(json!({})).await, "action");
    }

    #[tokio::test]
    async fn execute_unknown_action_returns_error() {
        assert_failure_mentions(&run(json!({"action": "invalid"})).await, "Unknown action");
    }

    #[tokio::test]
    async fn execute_query_database_missing_id_returns_error() {
        assert_failure_mentions(
            &run(json!({"action": "query_database"})).await,
            "database_id",
        );
    }

    #[tokio::test]
    async fn execute_read_page_missing_id_returns_error() {
        assert_failure_mentions(&run(json!({"action": "read_page"})).await, "page_id");
    }

    #[tokio::test]
    async fn execute_create_page_missing_properties_returns_error() {
        assert_failure_mentions(
            &run(json!({"action": "create_page"})).await,
            "properties",
        );
    }

    #[tokio::test]
    async fn execute_update_page_missing_page_id_returns_error() {
        assert_failure_mentions(
            &run(json!({"action": "update_page", "properties": {}})).await,
            "page_id",
        );
    }

    #[tokio::test]
    async fn execute_update_page_missing_properties_returns_error() {
        assert_failure_mentions(
            &run(json!({"action": "update_page", "page_id": "test-id"})).await,
            "properties",
        );
    }
}