1use anyhow::{Result, bail};
2use async_trait::async_trait;
3use parking_lot::Mutex;
4use serde::Deserialize;
5use std::time::{Duration, Instant};
6use zeroclaw_api::channel::{Channel, ChannelMessage, SendMessage};
7
8pub struct RedditChannel {
10 client_id: String,
11 client_secret: String,
12 refresh_token: String,
13 username: String,
14 subreddits: Vec<String>,
16 alias: String,
19 auth: Mutex<RedditAuth>,
20}
21
/// Cached OAuth2 access token plus the instant after which it must be refreshed.
struct RedditAuth {
    /// Local deadline; the token is only reused while `Instant::now()` is before it.
    expires_at: Instant,
    /// Bearer token; an empty string forces a refresh on next use.
    access_token: String,
}
26
/// The subset of the token endpoint's JSON response that this module reads.
#[derive(Deserialize)]
struct RedditTokenResponse {
    /// Bearer token used for subsequent `oauth.reddit.com` requests.
    access_token: String,
    /// Token lifetime in seconds (converted with `Duration::from_secs` when cached).
    expires_in: u64,
}
32
/// Top-level envelope of the `/message/unread` listing response.
#[derive(Deserialize)]
struct RedditListing {
    /// Listing payload; only its `children` are consumed.
    data: RedditListingData,
}
37
/// Body of a listing: the entries returned by the inbox query.
#[derive(Deserialize)]
struct RedditListingData {
    /// One entry per inbox item, in the order returned by the API.
    children: Vec<RedditChild>,
}
42
/// A single listing entry; the actual item fields live under `data`.
#[derive(Deserialize)]
struct RedditChild {
    /// The inbox item (comment reply, mention, or private message).
    data: RedditItemData,
}
47
/// Fields of one inbox item. Every field is optional so that partially
/// populated items still deserialize.
// `subject`, `link_id`, `new` and `context` are deserialized but never read by
// this module; `allow(dead_code)` keeps them without compiler warnings.
#[allow(dead_code)]
#[derive(Deserialize)]
struct RedditItemData {
    /// Fullname of the item (e.g. `t1_…`, `t4_…`); used for the message id and mark-read.
    name: Option<String>,
    /// Sender username; empty/missing or the bot's own name causes the item to be skipped.
    author: Option<String>,
    /// Message text; empty/missing causes the item to be skipped.
    body: Option<String>,
    /// Private-message subject (currently unused).
    subject: Option<String>,
    /// Fullname of the parent thing; used as reply target and thread id when present.
    parent_id: Option<String>,
    /// Fullname of the containing post (currently unused).
    link_id: Option<String>,
    /// Subreddit name, checked against the channel's allow-list when present.
    subreddit: Option<String>,
    /// Creation time, truncated to integer seconds for the message timestamp
    /// (presumably Unix epoch seconds — confirm against the API).
    created_utc: Option<f64>,
    /// Unread flag as reported by the API (currently unused).
    new: Option<bool>,
    /// Item kind from the JSON key `type`; `"comment_reply"` selects a threaded reply target.
    #[serde(rename = "type")]
    message_type: Option<String>,
    /// Permalink context (currently unused).
    context: Option<String>,
}
64
/// Base URL for all bearer-token (OAuth) API requests.
const REDDIT_API_BASE: &str = "https://oauth.reddit.com";
/// Endpoint that exchanges the refresh token for a fresh access token.
const REDDIT_TOKEN_URL: &str = "https://www.reddit.com/api/v1/access_token";
/// User-Agent sent on every request, shaped `<platform>:<app id>:<version> (by /u/<user>)`.
const USER_AGENT: &str = "zeroclaw:channel:v0.1.0 (by /u/zeroclaw-bot)";
/// Pause between consecutive inbox polls in `listen` (five seconds).
const POLL_INTERVAL: Duration = Duration::from_millis(5_000);
70
71impl RedditChannel {
72 pub fn new(
73 alias: impl Into<String>,
74 client_id: String,
75 client_secret: String,
76 refresh_token: String,
77 username: String,
78 subreddits: Vec<String>,
79 ) -> Self {
80 Self {
81 client_id,
82 client_secret,
83 refresh_token,
84 username,
85 subreddits,
86 alias: alias.into(),
87 auth: Mutex::new(RedditAuth {
88 access_token: String::new(),
89 expires_at: Instant::now(),
90 }),
91 }
92 }
93
94 fn http_client(&self) -> reqwest::Client {
95 zeroclaw_config::schema::build_runtime_proxy_client("channel.reddit")
96 }
97
98 async fn refresh_access_token(&self) -> Result<()> {
100 let client = self.http_client();
101 let resp = client
102 .post(REDDIT_TOKEN_URL)
103 .basic_auth(&self.client_id, Some(&self.client_secret))
104 .header("User-Agent", USER_AGENT)
105 .form(&[
106 ("grant_type", "refresh_token"),
107 ("refresh_token", &self.refresh_token),
108 ])
109 .send()
110 .await?;
111
112 let status = resp.status();
113 if !status.is_success() {
114 let body = resp
115 .text()
116 .await
117 .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
118 bail!("token refresh failed ({status}): {body}");
119 }
120
121 let token_resp: RedditTokenResponse = resp.json().await?;
122 let mut auth = self.auth.lock();
123 auth.access_token = token_resp.access_token;
124 auth.expires_at =
125 Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(60));
126 Ok(())
127 }
128
129 async fn get_access_token(&self) -> Result<String> {
131 {
132 let auth = self.auth.lock();
133 if !auth.access_token.is_empty() && Instant::now() < auth.expires_at {
134 return Ok(auth.access_token.clone());
135 }
136 }
137 self.refresh_access_token().await?;
138 let auth = self.auth.lock();
139 Ok(auth.access_token.clone())
140 }
141
142 async fn fetch_inbox(&self) -> Result<Vec<RedditChild>> {
144 let token = self.get_access_token().await?;
145 let client = self.http_client();
146
147 let resp = client
148 .get(format!("{REDDIT_API_BASE}/message/unread"))
149 .bearer_auth(&token)
150 .header("User-Agent", USER_AGENT)
151 .query(&[("limit", "25")])
152 .send()
153 .await?;
154
155 let status = resp.status();
156 if !status.is_success() {
157 let body = resp
158 .text()
159 .await
160 .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
161 ::zeroclaw_log::record!(
162 WARN,
163 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
164 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
165 .with_attrs(::serde_json::json!({"status": status.to_string(), "body": body})),
166 "inbox fetch failed"
167 );
168 return Ok(Vec::new());
169 }
170
171 let listing: RedditListing = resp.json().await?;
172 Ok(listing.data.children)
173 }
174
175 async fn mark_read(&self, fullnames: &[String]) -> Result<()> {
177 if fullnames.is_empty() {
178 return Ok(());
179 }
180 let token = self.get_access_token().await?;
181 let client = self.http_client();
182
183 let ids = fullnames.join(",");
184 let resp = client
185 .post(format!("{REDDIT_API_BASE}/api/read_message"))
186 .bearer_auth(&token)
187 .header("User-Agent", USER_AGENT)
188 .form(&[("id", ids.as_str())])
189 .send()
190 .await?;
191
192 if !resp.status().is_success() {
193 ::zeroclaw_log::record!(
194 WARN,
195 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
196 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
197 &format!("mark_read failed: {}", resp.status())
198 );
199 }
200 Ok(())
201 }
202
203 fn parse_item(&self, item: &RedditItemData) -> Option<ChannelMessage> {
205 let author = item.author.as_deref().unwrap_or("");
206 let body = item.body.as_deref().unwrap_or("");
207 let name = item.name.as_deref().unwrap_or("");
208
209 if author.eq_ignore_ascii_case(&self.username) || author.is_empty() || body.is_empty() {
211 return None;
212 }
213
214 if !self.subreddits.is_empty()
217 && let Some(ref item_sub) = item.subreddit
218 && !self
219 .subreddits
220 .iter()
221 .any(|allowed| allowed.eq_ignore_ascii_case(item_sub))
222 {
223 return None;
224 }
225
226 let reply_target =
229 if item.message_type.as_deref() == Some("comment_reply") || item.parent_id.is_some() {
230 item.parent_id.clone().unwrap_or_else(|| name.to_string())
232 } else {
233 author.to_string()
235 };
236
237 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
238 let timestamp = item.created_utc.unwrap_or(0.0) as u64;
239
240 Some(ChannelMessage {
241 id: format!("reddit_{name}"),
242 sender: author.to_string(),
243 reply_target,
244 content: body.to_string(),
245 channel: "reddit".to_string(),
246 channel_alias: None,
247 timestamp,
248 thread_ts: item.parent_id.clone(),
249 interruption_scope_id: None,
250 attachments: vec![],
251 subject: None,
252 })
253 }
254}
255
256impl ::zeroclaw_api::attribution::Attributable for RedditChannel {
257 fn role(&self) -> ::zeroclaw_api::attribution::Role {
258 ::zeroclaw_api::attribution::Role::Channel(::zeroclaw_api::attribution::ChannelKind::Reddit)
259 }
260 fn alias(&self) -> &str {
261 &self.alias
262 }
263}
264
265#[async_trait]
266impl Channel for RedditChannel {
267 fn name(&self) -> &str {
268 "reddit"
269 }
270
271 async fn send(&self, message: &SendMessage) -> Result<()> {
272 let token = self.get_access_token().await?;
273 let client = self.http_client();
274
275 if message.recipient.starts_with("t1_")
278 || message.recipient.starts_with("t3_")
279 || message.recipient.starts_with("t4_")
280 {
281 let resp = client
283 .post(format!("{REDDIT_API_BASE}/api/comment"))
284 .bearer_auth(&token)
285 .header("User-Agent", USER_AGENT)
286 .form(&[
287 ("thing_id", message.recipient.as_str()),
288 ("text", &message.content),
289 ])
290 .send()
291 .await?;
292
293 let status = resp.status();
294 if !status.is_success() {
295 let body = resp
296 .text()
297 .await
298 .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
299 bail!("comment reply failed ({status}): {body}");
300 }
301 } else {
302 let subject = message
304 .subject
305 .as_deref()
306 .unwrap_or("Message from ZeroClaw");
307 let resp = client
308 .post(format!("{REDDIT_API_BASE}/api/compose"))
309 .bearer_auth(&token)
310 .header("User-Agent", USER_AGENT)
311 .form(&[
312 ("to", message.recipient.as_str()),
313 ("subject", subject),
314 ("text", &message.content),
315 ])
316 .send()
317 .await?;
318
319 let status = resp.status();
320 if !status.is_success() {
321 let body = resp
322 .text()
323 .await
324 .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
325 bail!("DM failed ({status}): {body}");
326 }
327 }
328
329 Ok(())
330 }
331
332 async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
333 self.refresh_access_token().await?;
335
336 let scope = if self.subreddits.is_empty() {
337 String::new()
338 } else {
339 format!(
340 "in {}",
341 self.subreddits
342 .iter()
343 .map(|s| format!("r/{s}"))
344 .collect::<Vec<_>>()
345 .join(", ")
346 )
347 };
348 ::zeroclaw_log::record!(
349 INFO,
350 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
351 &format!("channel listening as u/{} {}...", self.username, scope)
352 );
353
354 loop {
355 tokio::time::sleep(POLL_INTERVAL).await;
356
357 let items = match self.fetch_inbox().await {
358 Ok(items) => items,
359 Err(e) => {
360 ::zeroclaw_log::record!(
361 WARN,
362 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
363 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
364 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
365 "poll error"
366 );
367 continue;
368 }
369 };
370
371 let mut read_ids = Vec::new();
372 for child in &items {
373 if let Some(ref name) = child.data.name {
374 read_ids.push(name.clone());
375 }
376 if let Some(msg) = self.parse_item(&child.data)
377 && tx.send(msg).await.is_err()
378 {
379 return Ok(());
380 }
381 }
382
383 if let Err(e) = self.mark_read(&read_ids).await {
384 ::zeroclaw_log::record!(
385 WARN,
386 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
387 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
388 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
389 "mark_read error"
390 );
391 }
392 }
393 }
394
395 async fn health_check(&self) -> bool {
396 self.get_access_token().await.is_ok()
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel running as `testbot` with the given subreddit allow-list.
    fn make_channel_with_subs(subreddits: Vec<String>) -> RedditChannel {
        RedditChannel::new(
            "testbot",
            "client_id".into(),
            "client_secret".into(),
            "refresh_token".into(),
            "testbot".into(),
            subreddits,
        )
    }

    /// Channel without a subreddit filter.
    fn make_channel() -> RedditChannel {
        make_channel_with_subs(Vec::new())
    }

    /// Channel restricted to a single subreddit.
    fn make_channel_with_sub(sub: &str) -> RedditChannel {
        make_channel_with_subs(vec![sub.into()])
    }

    /// Baseline unread item with no parent, subreddit, subject or type.
    /// Tests override individual fields with struct-update syntax.
    fn item(name: &str, author: &str, body: &str) -> RedditItemData {
        RedditItemData {
            name: Some(name.into()),
            author: Some(author.into()),
            body: Some(body.into()),
            subject: None,
            parent_id: None,
            link_id: None,
            subreddit: None,
            created_utc: Some(1_700_000_000.0),
            new: Some(true),
            message_type: None,
            context: None,
        }
    }

    #[test]
    fn parse_comment_reply() {
        let ch = make_channel();
        let it = RedditItemData {
            parent_id: Some("t1_parent1".into()),
            link_id: Some("t3_post1".into()),
            subreddit: Some("rust".into()),
            message_type: Some("comment_reply".into()),
            ..item("t1_abc123", "user1", "hello bot")
        };

        let msg = ch.parse_item(&it).unwrap();
        assert_eq!(msg.sender, "user1");
        assert_eq!(msg.content, "hello bot");
        assert_eq!(msg.reply_target, "t1_parent1");
        assert_eq!(msg.channel, "reddit");
        assert_eq!(msg.id, "reddit_t1_abc123");
        assert_eq!(msg.thread_ts.as_deref(), Some("t1_parent1"));
        assert_eq!(msg.timestamp, 1_700_000_000);
    }

    #[test]
    fn parse_dm() {
        let ch = make_channel();
        let it = RedditItemData {
            subject: Some("Hello".into()),
            created_utc: Some(1_700_000_100.0),
            ..item("t4_dm456", "user2", "private message")
        };

        let msg = ch.parse_item(&it).unwrap();
        assert_eq!(msg.sender, "user2");
        assert_eq!(msg.content, "private message");
        assert_eq!(msg.reply_target, "user2");
        assert!(msg.thread_ts.is_none());
        assert_eq!(msg.timestamp, 1_700_000_100);
    }

    #[test]
    fn comment_reply_without_parent_targets_item_itself() {
        let ch = make_channel();
        let it = RedditItemData {
            message_type: Some("comment_reply".into()),
            ..item("t1_lonely", "user1", "hi")
        };

        let msg = ch.parse_item(&it).unwrap();
        assert_eq!(msg.reply_target, "t1_lonely");
    }

    #[test]
    fn skip_self_messages() {
        let ch = make_channel();
        assert!(ch.parse_item(&item("t1_self", "testbot", "my own message")).is_none());
    }

    #[test]
    fn skip_self_messages_case_insensitive() {
        let ch = make_channel();
        assert!(ch.parse_item(&item("t1_self", "TestBot", "my own message")).is_none());
    }

    #[test]
    fn skip_empty_body() {
        let ch = make_channel();
        assert!(ch.parse_item(&item("t1_empty", "user1", "")).is_none());
    }

    #[test]
    fn skip_missing_author() {
        let ch = make_channel();
        let it = RedditItemData {
            author: None,
            ..item("t1_anon", "user1", "hello")
        };
        assert!(ch.parse_item(&it).is_none());
    }

    #[test]
    fn subreddit_filter() {
        let ch = make_channel_with_sub("rust");

        let other = RedditItemData {
            subreddit: Some("python".into()),
            ..item("t1_other", "user1", "hello")
        };
        assert!(ch.parse_item(&other).is_none());

        let matching = RedditItemData {
            subreddit: Some("rust".into()),
            ..item("t1_match", "user1", "hello")
        };
        assert!(ch.parse_item(&matching).is_some());
    }

    #[test]
    fn subreddit_filter_is_case_insensitive() {
        let ch = make_channel_with_sub("Rust");
        let it = RedditItemData {
            subreddit: Some("rust".into()),
            ..item("t1_match", "user1", "hello")
        };
        assert!(ch.parse_item(&it).is_some());
    }

    #[test]
    fn subreddit_filter_allows_items_without_subreddit() {
        let ch = make_channel_with_sub("rust");
        let dm = item("t4_dm", "user1", "private hello");
        assert!(ch.parse_item(&dm).is_some());
    }

    #[test]
    fn missing_timestamp_defaults_to_zero() {
        let ch = make_channel();
        let it = RedditItemData {
            created_utc: None,
            ..item("t4_nots", "user1", "hello")
        };
        assert_eq!(ch.parse_item(&it).unwrap().timestamp, 0);
    }

    #[test]
    fn send_message_formatting() {
        let dm = SendMessage::new("hello", "user1");
        assert_eq!(dm.recipient, "user1");
        assert_eq!(dm.content, "hello");

        let reply = SendMessage::new("response", "t1_abc123");
        assert!(reply.recipient.starts_with("t1_"));
    }
}