1use anyhow::{Result, bail};
2use async_trait::async_trait;
3use std::collections::HashSet;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use zeroclaw_api::channel::{Channel, ChannelMessage, SendMessage};
7
/// Base URL that every Notion endpoint path used in this file is appended to.
const NOTION_API_BASE: &str = "https://api.notion.com/v1";
/// Value sent in the `Notion-Version` header on every request.
const NOTION_VERSION: &str = "2022-06-28";
/// Upper bound, in bytes, on text written to the result property; longer
/// text is cut and suffixed with a truncation marker.
const MAX_RESULT_LENGTH: usize = 2000;
/// Number of attempts `api_call` makes before giving up.
const MAX_RETRIES: u32 = 3;
/// Backoff delay in milliseconds for the first retry; doubled for each later attempt.
const RETRY_BASE_DELAY_MS: u64 = 2000;
/// Limit handed to `truncate_with_ellipsis` when an API error body is logged or reported.
const MAX_ERROR_BODY_CHARS: usize = 500;
15
/// Returns the largest index that is `<= max_bytes` and lies on a UTF-8
/// character boundary of `s`, clamped to `s.len()`.
///
/// The result is always valid as the end of a `&s[..idx]` slice.
fn floor_utf8_char_boundary(s: &str, max_bytes: usize) -> usize {
    if max_bytes >= s.len() {
        return s.len();
    }
    // Index 0 is always a boundary, so the backwards search cannot come up empty.
    (0..=max_bytes)
        .rev()
        .find(|&candidate| s.is_char_boundary(candidate))
        .unwrap_or(0)
}
27
/// Channel that treats a Notion database as a task queue: it polls for pages
/// whose status is `pending`, forwards their input text as channel messages,
/// and writes results back to the page.
pub struct NotionChannel {
    /// Token sent as `Authorization: Bearer <api_key>`.
    api_key: String,
    /// Identifier of the database that is inspected and queried.
    database_id: String,
    /// Seconds slept between polls in `listen`.
    poll_interval_secs: u64,
    /// Name of the property holding the task state (`pending` / `running` / `done`).
    status_property: String,
    /// Name of the title or rich-text property the task input is read from.
    input_property: String,
    /// Name of the property results and reset notices are written to.
    result_property: String,
    /// Maximum number of pages that may be claimed at the same time.
    max_concurrent: usize,
    /// Alias reported through the `Attributable` implementation.
    alias: String,
    /// Notion type of `status_property`; starts as `"select"` and is replaced
    /// by the detected type when `listen` starts.
    status_type: Arc<RwLock<String>>,
    /// IDs of pages that have been claimed and not yet released.
    inflight: Arc<RwLock<HashSet<String>>>,
    /// HTTP client used for every request.
    http: reqwest::Client,
    /// When true, `listen` resets pages found in `running` back to `pending` at startup.
    recover_stale: bool,
}
50
51impl NotionChannel {
52 pub fn new(
54 alias: impl Into<String>,
55 api_key: String,
56 database_id: String,
57 poll_interval_secs: u64,
58 status_property: String,
59 input_property: String,
60 result_property: String,
61 max_concurrent: usize,
62 recover_stale: bool,
63 ) -> Self {
64 Self {
65 api_key,
66 database_id,
67 poll_interval_secs,
68 status_property,
69 input_property,
70 result_property,
71 max_concurrent,
72 alias: alias.into(),
73 status_type: Arc::new(RwLock::new("select".to_string())),
74 inflight: Arc::new(RwLock::new(HashSet::new())),
75 http: reqwest::Client::new(),
76 recover_stale,
77 }
78 }
79
80 fn headers(&self) -> Result<reqwest::header::HeaderMap> {
82 let mut headers = reqwest::header::HeaderMap::new();
83 headers.insert(
84 "Authorization",
85 format!("Bearer {}", self.api_key).parse().map_err(|e| {
86 ::zeroclaw_log::record!(
87 WARN,
88 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
89 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
90 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
91 "Invalid Notion API key header value"
92 );
93 anyhow::Error::msg(format!("Invalid Notion API key header value: {e}"))
94 })?,
95 );
96 headers.insert("Notion-Version", NOTION_VERSION.parse().unwrap());
97 headers.insert("Content-Type", "application/json".parse().unwrap());
98 Ok(headers)
99 }
100
101 async fn api_call(
103 &self,
104 method: reqwest::Method,
105 url: &str,
106 body: Option<serde_json::Value>,
107 ) -> Result<serde_json::Value> {
108 let mut last_err = None;
109 for attempt in 0..MAX_RETRIES {
110 let mut req = self
111 .http
112 .request(method.clone(), url)
113 .headers(self.headers()?);
114 if let Some(ref b) = body {
115 req = req.json(b);
116 }
117 match req.send().await {
118 Ok(resp) => {
119 let status = resp.status();
120 if status.is_success() {
121 return resp.json().await.map_err(|e| {
122 ::zeroclaw_log::record!(
123 ERROR,
124 ::zeroclaw_log::Event::new(
125 module_path!(),
126 ::zeroclaw_log::Action::Fail
127 )
128 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
129 .with_attrs(::serde_json::json!({
130 "phase": "response_parse",
131 "error": format!("{}", e),
132 })),
133 "notion: failed to parse response JSON"
134 );
135 anyhow::Error::msg(format!("Failed to parse response: {e}"))
136 });
137 }
138 let status_code = status.as_u16();
139 if status_code != 429 && (400..500).contains(&status_code) {
141 let body_text = resp.text().await.unwrap_or_default();
142 let truncated =
143 crate::util::truncate_with_ellipsis(&body_text, MAX_ERROR_BODY_CHARS);
144 ::zeroclaw_log::record!(
145 ERROR,
146 ::zeroclaw_log::Event::new(
147 module_path!(),
148 ::zeroclaw_log::Action::Fail
149 )
150 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
151 .with_attrs(::serde_json::json!({
152 "status": status_code,
153 "body": truncated,
154 })),
155 "notion: API client error (no retry)"
156 );
157 bail!("API error {status_code}: {truncated}");
158 }
159 ::zeroclaw_log::record!(
160 WARN,
161 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
162 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
163 .with_attrs(::serde_json::json!({
164 "status": status_code,
165 "phase": "retryable_status",
166 })),
167 "notion: API returned retryable status"
168 );
169 last_err = Some(anyhow::Error::msg(format!("API error: {status_code}")));
170 }
171 Err(e) => {
172 ::zeroclaw_log::record!(
173 WARN,
174 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
175 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
176 .with_attrs(::serde_json::json!({
177 "phase": "transport",
178 "error": format!("{}", e),
179 })),
180 "notion: HTTP request failed"
181 );
182 last_err = Some(anyhow::Error::msg(format!("HTTP request failed: {e}")));
183 }
184 }
185 let delay = RETRY_BASE_DELAY_MS * 2u64.pow(attempt);
186 ::zeroclaw_log::record!(
187 WARN,
188 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
189 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
190 &format!(
191 "API call failed (attempt {}/{}), retrying in {}ms",
192 attempt + 1,
193 MAX_RETRIES,
194 delay
195 )
196 );
197 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
198 }
199 Err(last_err.unwrap_or_else(|| {
200 ::zeroclaw_log::record!(
201 ERROR,
202 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
203 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
204 "notion: API call exhausted retries"
205 );
206 anyhow::Error::msg("API call failed after retries")
207 }))
208 }
209
210 async fn detect_status_type(&self) -> Result<String> {
212 let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
213 let resp = self.api_call(reqwest::Method::GET, &url, None).await?;
214 let status_type = resp
215 .get("properties")
216 .and_then(|p| p.get(&self.status_property))
217 .and_then(|s| s.get("type"))
218 .and_then(|t| t.as_str())
219 .unwrap_or("select")
220 .to_string();
221 Ok(status_type)
222 }
223
224 async fn query_pending(&self) -> Result<Vec<serde_json::Value>> {
226 let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
227 let status_type = self.status_type.read().await.clone();
228 let filter = build_status_filter(&self.status_property, &status_type, "pending");
229 let resp = self
230 .api_call(
231 reqwest::Method::POST,
232 &url,
233 Some(serde_json::json!({ "filter": filter })),
234 )
235 .await?;
236 Ok(resp
237 .get("results")
238 .and_then(|r| r.as_array())
239 .cloned()
240 .unwrap_or_default())
241 }
242
243 async fn claim_task(&self, page_id: &str) -> bool {
245 let mut inflight = self.inflight.write().await;
246 if inflight.contains(page_id) {
247 return false;
248 }
249 if inflight.len() >= self.max_concurrent {
250 return false;
251 }
252 inflight.insert(page_id.to_string());
253 true
254 }
255
256 async fn release_task(&self, page_id: &str) {
258 let mut inflight = self.inflight.write().await;
259 inflight.remove(page_id);
260 }
261
262 async fn set_status(&self, page_id: &str, status_value: &str) -> Result<()> {
264 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
265 let status_type = self.status_type.read().await.clone();
266 let payload = serde_json::json!({
267 "properties": {
268 &self.status_property: build_status_payload(&status_type, status_value),
269 }
270 });
271 self.api_call(reqwest::Method::PATCH, &url, Some(payload))
272 .await?;
273 Ok(())
274 }
275
276 async fn recover_stale(&self) -> Result<()> {
278 let url = format!("{NOTION_API_BASE}/databases/{}/query", self.database_id);
279 let status_type = self.status_type.read().await.clone();
280 let filter = build_status_filter(&self.status_property, &status_type, "running");
281 let resp = self
282 .api_call(
283 reqwest::Method::POST,
284 &url,
285 Some(serde_json::json!({ "filter": filter })),
286 )
287 .await?;
288 let stale = resp
289 .get("results")
290 .and_then(|r| r.as_array())
291 .cloned()
292 .unwrap_or_default();
293 if stale.is_empty() {
294 return Ok(());
295 }
296 ::zeroclaw_log::record!(
297 WARN,
298 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
299 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
300 &format!(
301 "Found {} stale task(s) in 'running' state, resetting to 'pending'",
302 stale.len()
303 )
304 );
305 for task in &stale {
306 if let Some(page_id) = task.get("id").and_then(|v| v.as_str()) {
307 let page_url = format!("{NOTION_API_BASE}/pages/{page_id}");
308 let payload = serde_json::json!({
309 "properties": {
310 &self.status_property: build_status_payload(&status_type, "pending"),
311 &self.result_property: build_rich_text_payload(
312 "Reset: poller restarted while task was running"
313 ),
314 }
315 });
316 let short_id_end = floor_utf8_char_boundary(page_id, 8);
317 let short_id = &page_id[..short_id_end];
318 if let Err(e) = self
319 .api_call(reqwest::Method::PATCH, &page_url, Some(payload))
320 .await
321 {
322 ::zeroclaw_log::record!(
323 ERROR,
324 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
325 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
326 .with_attrs(
327 ::serde_json::json!({"error": format!("{}", e), "short_id": short_id})
328 ),
329 "Could not reset stale task"
330 );
331 } else {
332 ::zeroclaw_log::record!(
333 INFO,
334 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
335 .with_attrs(::serde_json::json!({"short_id": short_id})),
336 "Reset stale task to pending"
337 );
338 }
339 }
340 }
341 Ok(())
342 }
343}
344
345impl ::zeroclaw_api::attribution::Attributable for NotionChannel {
346 fn role(&self) -> ::zeroclaw_api::attribution::Role {
347 ::zeroclaw_api::attribution::Role::Channel(::zeroclaw_api::attribution::ChannelKind::Notion)
348 }
349 fn alias(&self) -> &str {
350 &self.alias
351 }
352}
353
354#[async_trait]
355impl Channel for NotionChannel {
    /// Returns the fixed channel identifier, `"notion"`.
    fn name(&self) -> &str {
        "notion"
    }
359
360 async fn send(&self, message: &SendMessage) -> Result<()> {
361 let page_id = &message.recipient;
363 let status_type = self.status_type.read().await.clone();
364 let url = format!("{NOTION_API_BASE}/pages/{page_id}");
365 let payload = serde_json::json!({
366 "properties": {
367 &self.status_property: build_status_payload(&status_type, "done"),
368 &self.result_property: build_rich_text_payload(&message.content),
369 }
370 });
371 self.api_call(reqwest::Method::PATCH, &url, Some(payload))
372 .await?;
373 self.release_task(page_id).await;
374 Ok(())
375 }
376
377 async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
378 match self.detect_status_type().await {
380 Ok(st) => {
381 ::zeroclaw_log::record!(
382 INFO,
383 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
384 .with_attrs(::serde_json::json!({"st": st})),
385 "status property type"
386 );
387 *self.status_type.write().await = st;
388 }
389 Err(e) => {
390 bail!("Failed to detect Notion database schema: {e}");
391 }
392 }
393
394 if self.recover_stale
396 && let Err(e) = self.recover_stale().await
397 {
398 ::zeroclaw_log::record!(
399 ERROR,
400 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
401 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
402 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
403 "stale task recovery failed"
404 );
405 }
406
407 loop {
409 match self.query_pending().await {
410 Ok(tasks) => {
411 if !tasks.is_empty() {
412 ::zeroclaw_log::record!(
413 INFO,
414 ::zeroclaw_log::Event::new(
415 module_path!(),
416 ::zeroclaw_log::Action::Note
417 ),
418 &format!("found {} pending task(s)", tasks.len())
419 );
420 }
421 for task in tasks {
422 let page_id = match task.get("id").and_then(|v| v.as_str()) {
423 Some(id) => id.to_string(),
424 None => continue,
425 };
426
427 let input_text = extract_text_from_property(
428 task.get("properties")
429 .and_then(|p| p.get(&self.input_property)),
430 );
431
432 if input_text.trim().is_empty() {
433 let short_end = floor_utf8_char_boundary(&page_id, 8);
434 ::zeroclaw_log::record!(
435 WARN,
436 ::zeroclaw_log::Event::new(
437 module_path!(),
438 ::zeroclaw_log::Action::Note
439 )
440 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
441 &format!(
442 "empty input for task {}, skipping",
443 &page_id[..short_end]
444 )
445 );
446 continue;
447 }
448
449 if !self.claim_task(&page_id).await {
450 continue;
451 }
452
453 if let Err(e) = self.set_status(&page_id, "running").await {
455 ::zeroclaw_log::record!(
456 ERROR,
457 ::zeroclaw_log::Event::new(
458 module_path!(),
459 ::zeroclaw_log::Action::Fail
460 )
461 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
462 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
463 "failed to set running status"
464 );
465 self.release_task(&page_id).await;
466 continue;
467 }
468
469 let timestamp = std::time::SystemTime::now()
470 .duration_since(std::time::UNIX_EPOCH)
471 .unwrap_or_default()
472 .as_secs();
473
474 if tx
475 .send(ChannelMessage {
476 id: page_id.clone(),
477 sender: "notion".into(),
478 reply_target: page_id,
479 content: input_text,
480 channel: "notion".into(),
481 channel_alias: None,
482 timestamp,
483 thread_ts: None,
484 interruption_scope_id: None,
485 attachments: vec![],
486 subject: None,
487 })
488 .await
489 .is_err()
490 {
491 ::zeroclaw_log::record!(
492 INFO,
493 ::zeroclaw_log::Event::new(
494 module_path!(),
495 ::zeroclaw_log::Action::Note
496 ),
497 "channel shutting down"
498 );
499 return Ok(());
500 }
501 }
502 }
503 Err(e) => {
504 ::zeroclaw_log::record!(
505 ERROR,
506 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
507 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
508 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
509 "poll error"
510 );
511 }
512 }
513
514 tokio::time::sleep(std::time::Duration::from_secs(self.poll_interval_secs)).await;
515 }
516 }
517
518 async fn health_check(&self) -> bool {
519 let url = format!("{NOTION_API_BASE}/databases/{}", self.database_id);
520 self.api_call(reqwest::Method::GET, &url, None)
521 .await
522 .is_ok()
523 }
524}
525
526fn build_status_filter(property: &str, status_type: &str, value: &str) -> serde_json::Value {
530 if status_type == "status" {
531 serde_json::json!({
532 "property": property,
533 "status": { "equals": value }
534 })
535 } else {
536 serde_json::json!({
537 "property": property,
538 "select": { "equals": value }
539 })
540 }
541}
542
543fn build_status_payload(status_type: &str, value: &str) -> serde_json::Value {
545 if status_type == "status" {
546 serde_json::json!({ "status": { "name": value } })
547 } else {
548 serde_json::json!({ "select": { "name": value } })
549 }
550}
551
552fn build_rich_text_payload(value: &str) -> serde_json::Value {
554 let truncated = truncate_result(value);
555 serde_json::json!({
556 "rich_text": [{
557 "text": { "content": truncated }
558 }]
559 })
560}
561
562fn truncate_result(value: &str) -> String {
564 if value.len() <= MAX_RESULT_LENGTH {
565 return value.to_string();
566 }
567 let cut = MAX_RESULT_LENGTH.saturating_sub(30);
568 let end = floor_utf8_char_boundary(value, cut);
570 format!("{}\n\n... [output truncated]", &value[..end])
571}
572
573fn extract_text_from_property(prop: Option<&serde_json::Value>) -> String {
575 let Some(prop) = prop else {
576 return String::new();
577 };
578 let ptype = prop.get("type").and_then(|t| t.as_str()).unwrap_or("");
579 let array_key = match ptype {
580 "title" => "title",
581 "rich_text" => "rich_text",
582 _ => return String::new(),
583 };
584 prop.get(array_key)
585 .and_then(|arr| arr.as_array())
586 .map(|items| {
587 items
588 .iter()
589 .filter_map(|item| item.get("plain_text").and_then(|t| t.as_str()))
590 .collect::<Vec<_>>()
591 .join("")
592 })
593 .unwrap_or_default()
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// Marker every truncated result must end with.
    const TRUNCATION_SUFFIX: &str = "... [output truncated]";

    /// Builds a channel with placeholder settings and the given concurrency cap.
    fn test_channel(max_concurrent: usize) -> NotionChannel {
        NotionChannel::new(
            "testbot",
            "test-key".into(),
            "test-db".into(),
            5,
            "Status".into(),
            "Input".into(),
            "Result".into(),
            max_concurrent,
            false,
        )
    }

    #[tokio::test]
    async fn claim_task_deduplication() {
        let channel = test_channel(4);

        // First claim wins, a repeat claim on the same page is refused.
        assert!(channel.claim_task("page-1").await);
        assert!(!channel.claim_task("page-1").await);
        // A different page is unaffected.
        assert!(channel.claim_task("page-2").await);

        // Once released, the page can be claimed again.
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-1").await);
    }

    #[test]
    fn result_truncation_within_limit() {
        let short = "hello world";
        assert_eq!(truncate_result(short), short);
    }

    #[test]
    fn result_truncation_over_limit() {
        let long = "a".repeat(MAX_RESULT_LENGTH + 100);
        let truncated = truncate_result(&long);
        assert!(truncated.len() <= MAX_RESULT_LENGTH);
        assert!(truncated.ends_with(TRUNCATION_SUFFIX));
    }

    #[test]
    fn result_truncation_multibyte_safe() {
        // 700 three-byte characters: over the limit in bytes.
        let s = "\u{6E2C}".repeat(700);
        let truncated = truncate_result(&s);
        assert!(truncated.len() <= MAX_RESULT_LENGTH);
        assert!(truncated.ends_with(TRUNCATION_SUFFIX));
    }

    #[test]
    fn status_payload_select_type() {
        let expected = serde_json::json!({ "select": { "name": "pending" } });
        assert_eq!(build_status_payload("select", "pending"), expected);
    }

    #[test]
    fn status_payload_status_type() {
        let expected = serde_json::json!({ "status": { "name": "done" } });
        assert_eq!(build_status_payload("status", "done"), expected);
    }

    #[test]
    fn rich_text_payload_construction() {
        let payload = build_rich_text_payload("test output");
        let text = payload["rich_text"][0]["text"]["content"].as_str().unwrap();
        assert_eq!(text, "test output");
    }

    #[test]
    fn status_filter_select_type() {
        let expected = serde_json::json!({
            "property": "Status",
            "select": { "equals": "pending" }
        });
        assert_eq!(build_status_filter("Status", "select", "pending"), expected);
    }

    #[test]
    fn status_filter_status_type() {
        let expected = serde_json::json!({
            "property": "Status",
            "status": { "equals": "running" }
        });
        assert_eq!(build_status_filter("Status", "status", "running"), expected);
    }

    #[test]
    fn extract_text_from_title_property() {
        let prop = serde_json::json!({
            "type": "title",
            "title": [
                { "plain_text": "Hello " },
                { "plain_text": "World" }
            ]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "Hello World");
    }

    #[test]
    fn extract_text_from_rich_text_property() {
        let prop = serde_json::json!({
            "type": "rich_text",
            "rich_text": [{ "plain_text": "task content" }]
        });
        assert_eq!(extract_text_from_property(Some(&prop)), "task content");
    }

    #[test]
    fn extract_text_from_none() {
        assert_eq!(extract_text_from_property(None), "");
    }

    #[test]
    fn extract_text_from_unknown_type() {
        let prop = serde_json::json!({ "type": "number", "number": 42 });
        assert_eq!(extract_text_from_property(Some(&prop)), "");
    }

    #[tokio::test]
    async fn claim_task_respects_max_concurrent() {
        let channel = test_channel(2);

        // The cap of two is reached after two claims.
        assert!(channel.claim_task("page-1").await);
        assert!(channel.claim_task("page-2").await);
        assert!(!channel.claim_task("page-3").await);

        // Releasing one page frees a slot for the third.
        channel.release_task("page-1").await;
        assert!(channel.claim_task("page-3").await);
    }
}