Skip to main content

zeroclaw_channels/
reddit.rs

1use anyhow::{Result, bail};
2use async_trait::async_trait;
3use parking_lot::Mutex;
4use serde::Deserialize;
5use std::time::{Duration, Instant};
6use zeroclaw_api::channel::{Channel, ChannelMessage, SendMessage};
7
/// Reddit channel — polls for mentions, DMs, and comment replies via Reddit OAuth2 API.
///
/// The handle owns the OAuth2 credentials and a small in-memory cache of the
/// current access token; no network traffic happens until a method needs a token.
pub struct RedditChannel {
    /// OAuth2 client id; sent as the basic-auth username on token refresh.
    client_id: String,
    /// OAuth2 client secret; sent as the basic-auth password on token refresh.
    client_secret: String,
    /// Refresh token submitted with the `refresh_token` grant to obtain access tokens.
    refresh_token: String,
    /// The bot's own account name. Inbox items authored by this name
    /// (ASCII case-insensitive) are dropped so the bot never answers itself.
    username: String,
    /// Empty = accept items from any subreddit the bot has access to.
    /// Otherwise items carrying a subreddit outside this list are ignored
    /// (compared ASCII case-insensitively); items without a subreddit still pass.
    subreddits: Vec<String>,
    /// The alias key under `[channels.reddit.<alias>]` this handle is
    /// bound to. Used for attribution.
    alias: String,
    /// Cached access token plus its expiry; locked only for short,
    /// non-awaiting sections.
    auth: Mutex<RedditAuth>,
}
21
/// Cached OAuth2 access token state, guarded by `RedditChannel::auth`.
struct RedditAuth {
    /// Current bearer token; empty until the first successful refresh.
    access_token: String,
    /// Moment after which the cached token is no longer used. Set 60 seconds
    /// ahead of the server-reported expiry so a token is not used right at its limit.
    expires_at: Instant,
}
26
/// The fields read from the JSON body returned by the token endpoint.
#[derive(Deserialize)]
struct RedditTokenResponse {
    /// Bearer token to attach to subsequent API requests.
    access_token: String,
    /// Token lifetime in seconds, counted from the time of the response.
    expires_in: u64,
}
32
/// Top-level envelope of the `/message/unread` response; only `data` is read.
#[derive(Deserialize)]
struct RedditListing {
    /// Payload holding the actual list of items.
    data: RedditListingData,
}
37
/// Body of a listing envelope: the wrapped inbox items.
#[derive(Deserialize)]
struct RedditListingData {
    /// One entry per inbox item, each wrapping its fields under `data`.
    children: Vec<RedditChild>,
}
42
/// Wrapper around a single listing entry; only the inner `data` object is read.
#[derive(Deserialize)]
struct RedditChild {
    /// The item's fields.
    data: RedditItemData,
}
47
/// Fields of one inbox item. Every field is optional so that items lacking a
/// key still deserialize.
// `dead_code` is allowed because several fields (`subject`, `link_id`, `new`,
// `context`) are deserialized but not read by the non-test code in this file.
#[allow(dead_code)]
#[derive(Deserialize)]
struct RedditItemData {
    /// Fullname of the item (e.g. `t1_…`, `t4_…`); used to build the message
    /// id and to mark the item as read.
    name: Option<String>,
    /// Account that wrote the item; items without an author are skipped.
    author: Option<String>,
    /// Text of the item; items with an empty or missing body are skipped.
    body: Option<String>,
    /// Subject line, when the item has one. Currently unused.
    subject: Option<String>,
    /// Fullname of the parent thing, when present; becomes the reply target
    /// and `thread_ts` of the parsed message.
    parent_id: Option<String>,
    /// Presumably the fullname of the enclosing post — confirm. Currently unused.
    link_id: Option<String>,
    /// Subreddit the item belongs to; absent for items such as DMs.
    subreddit: Option<String>,
    /// Creation time; truncated to whole units for the message timestamp.
    /// Presumably seconds since the Unix epoch — confirm.
    created_utc: Option<f64>,
    /// Presumably the unread flag reported by Reddit — confirm. Currently unused.
    new: Option<bool>,
    /// Value of the JSON `type` key; compared against `"comment_reply"`.
    #[serde(rename = "type")]
    message_type: Option<String>,
    /// Currently unused. Presumably a permalink with context — confirm.
    context: Option<String>,
}
64
/// Base URL for API calls authenticated with a bearer token.
const REDDIT_API_BASE: &str = "https://oauth.reddit.com";
/// Endpoint that exchanges the refresh token for an access token.
const REDDIT_TOKEN_URL: &str = "https://www.reddit.com/api/v1/access_token";
/// `User-Agent` header value sent with every request made by this channel.
const USER_AGENT: &str = "zeroclaw:channel:v0.1.0 (by /u/zeroclaw-bot)";
/// Delay between inbox polls.
///
/// Reddit enforces 60 requests per minute. One poll cycle issues an inbox
/// fetch and, when items came back, a mark-read call (plus an occasional
/// token refresh).
const POLL_INTERVAL: Duration = Duration::from_secs(5);
70
71impl RedditChannel {
72    pub fn new(
73        alias: impl Into<String>,
74        client_id: String,
75        client_secret: String,
76        refresh_token: String,
77        username: String,
78        subreddits: Vec<String>,
79    ) -> Self {
80        Self {
81            client_id,
82            client_secret,
83            refresh_token,
84            username,
85            subreddits,
86            alias: alias.into(),
87            auth: Mutex::new(RedditAuth {
88                access_token: String::new(),
89                expires_at: Instant::now(),
90            }),
91        }
92    }
93
94    fn http_client(&self) -> reqwest::Client {
95        zeroclaw_config::schema::build_runtime_proxy_client("channel.reddit")
96    }
97
98    /// Refresh the OAuth2 access token using the refresh token.
99    async fn refresh_access_token(&self) -> Result<()> {
100        let client = self.http_client();
101        let resp = client
102            .post(REDDIT_TOKEN_URL)
103            .basic_auth(&self.client_id, Some(&self.client_secret))
104            .header("User-Agent", USER_AGENT)
105            .form(&[
106                ("grant_type", "refresh_token"),
107                ("refresh_token", &self.refresh_token),
108            ])
109            .send()
110            .await?;
111
112        let status = resp.status();
113        if !status.is_success() {
114            let body = resp
115                .text()
116                .await
117                .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
118            bail!("token refresh failed ({status}): {body}");
119        }
120
121        let token_resp: RedditTokenResponse = resp.json().await?;
122        let mut auth = self.auth.lock();
123        auth.access_token = token_resp.access_token;
124        auth.expires_at =
125            Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(60));
126        Ok(())
127    }
128
129    /// Get a valid access token, refreshing if expired.
130    async fn get_access_token(&self) -> Result<String> {
131        {
132            let auth = self.auth.lock();
133            if !auth.access_token.is_empty() && Instant::now() < auth.expires_at {
134                return Ok(auth.access_token.clone());
135            }
136        }
137        self.refresh_access_token().await?;
138        let auth = self.auth.lock();
139        Ok(auth.access_token.clone())
140    }
141
142    /// Fetch unread inbox items (mentions, DMs, comment replies).
143    async fn fetch_inbox(&self) -> Result<Vec<RedditChild>> {
144        let token = self.get_access_token().await?;
145        let client = self.http_client();
146
147        let resp = client
148            .get(format!("{REDDIT_API_BASE}/message/unread"))
149            .bearer_auth(&token)
150            .header("User-Agent", USER_AGENT)
151            .query(&[("limit", "25")])
152            .send()
153            .await?;
154
155        let status = resp.status();
156        if !status.is_success() {
157            let body = resp
158                .text()
159                .await
160                .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
161            ::zeroclaw_log::record!(
162                WARN,
163                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
164                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
165                    .with_attrs(::serde_json::json!({"status": status.to_string(), "body": body})),
166                "inbox fetch failed"
167            );
168            return Ok(Vec::new());
169        }
170
171        let listing: RedditListing = resp.json().await?;
172        Ok(listing.data.children)
173    }
174
175    /// Mark inbox items as read.
176    async fn mark_read(&self, fullnames: &[String]) -> Result<()> {
177        if fullnames.is_empty() {
178            return Ok(());
179        }
180        let token = self.get_access_token().await?;
181        let client = self.http_client();
182
183        let ids = fullnames.join(",");
184        let resp = client
185            .post(format!("{REDDIT_API_BASE}/api/read_message"))
186            .bearer_auth(&token)
187            .header("User-Agent", USER_AGENT)
188            .form(&[("id", ids.as_str())])
189            .send()
190            .await?;
191
192        if !resp.status().is_success() {
193            ::zeroclaw_log::record!(
194                WARN,
195                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
196                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
197                &format!("mark_read failed: {}", resp.status())
198            );
199        }
200        Ok(())
201    }
202
203    /// Parse a Reddit inbox item into a ChannelMessage.
204    fn parse_item(&self, item: &RedditItemData) -> Option<ChannelMessage> {
205        let author = item.author.as_deref().unwrap_or("");
206        let body = item.body.as_deref().unwrap_or("");
207        let name = item.name.as_deref().unwrap_or("");
208
209        // Skip messages from ourselves
210        if author.eq_ignore_ascii_case(&self.username) || author.is_empty() || body.is_empty() {
211            return None;
212        }
213
214        // If a subreddit allowlist is set, skip items from other subreddits.
215        // Items without a subreddit (e.g. DMs) are always accepted.
216        if !self.subreddits.is_empty()
217            && let Some(ref item_sub) = item.subreddit
218            && !self
219                .subreddits
220                .iter()
221                .any(|allowed| allowed.eq_ignore_ascii_case(item_sub))
222        {
223            return None;
224        }
225
226        // Determine reply target: for comment replies use the parent thing name,
227        // for DMs reply to the author.
228        let reply_target =
229            if item.message_type.as_deref() == Some("comment_reply") || item.parent_id.is_some() {
230                // For comment replies, the recipient is the parent fullname
231                item.parent_id.clone().unwrap_or_else(|| name.to_string())
232            } else {
233                // For DMs, reply to the author
234                author.to_string()
235            };
236
237        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
238        let timestamp = item.created_utc.unwrap_or(0.0) as u64;
239
240        Some(ChannelMessage {
241            id: format!("reddit_{name}"),
242            sender: author.to_string(),
243            reply_target,
244            content: body.to_string(),
245            channel: "reddit".to_string(),
246            channel_alias: None,
247            timestamp,
248            thread_ts: item.parent_id.clone(),
249            interruption_scope_id: None,
250            attachments: vec![],
251            subject: None,
252        })
253    }
254}
255
256impl ::zeroclaw_api::attribution::Attributable for RedditChannel {
257    fn role(&self) -> ::zeroclaw_api::attribution::Role {
258        ::zeroclaw_api::attribution::Role::Channel(::zeroclaw_api::attribution::ChannelKind::Reddit)
259    }
260    fn alias(&self) -> &str {
261        &self.alias
262    }
263}
264
265#[async_trait]
266impl Channel for RedditChannel {
267    fn name(&self) -> &str {
268        "reddit"
269    }
270
271    async fn send(&self, message: &SendMessage) -> Result<()> {
272        let token = self.get_access_token().await?;
273        let client = self.http_client();
274
275        // If recipient looks like a Reddit fullname (t1_, t3_, t4_), it's a comment reply.
276        // Otherwise treat it as a DM to a username.
277        if message.recipient.starts_with("t1_")
278            || message.recipient.starts_with("t3_")
279            || message.recipient.starts_with("t4_")
280        {
281            // Comment reply
282            let resp = client
283                .post(format!("{REDDIT_API_BASE}/api/comment"))
284                .bearer_auth(&token)
285                .header("User-Agent", USER_AGENT)
286                .form(&[
287                    ("thing_id", message.recipient.as_str()),
288                    ("text", &message.content),
289                ])
290                .send()
291                .await?;
292
293            let status = resp.status();
294            if !status.is_success() {
295                let body = resp
296                    .text()
297                    .await
298                    .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
299                bail!("comment reply failed ({status}): {body}");
300            }
301        } else {
302            // Direct message
303            let subject = message
304                .subject
305                .as_deref()
306                .unwrap_or("Message from ZeroClaw");
307            let resp = client
308                .post(format!("{REDDIT_API_BASE}/api/compose"))
309                .bearer_auth(&token)
310                .header("User-Agent", USER_AGENT)
311                .form(&[
312                    ("to", message.recipient.as_str()),
313                    ("subject", subject),
314                    ("text", &message.content),
315                ])
316                .send()
317                .await?;
318
319            let status = resp.status();
320            if !status.is_success() {
321                let body = resp
322                    .text()
323                    .await
324                    .unwrap_or_else(|e| format!("<failed to read response: {e}>"));
325                bail!("DM failed ({status}): {body}");
326            }
327        }
328
329        Ok(())
330    }
331
332    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
333        // Initial auth
334        self.refresh_access_token().await?;
335
336        let scope = if self.subreddits.is_empty() {
337            String::new()
338        } else {
339            format!(
340                "in {}",
341                self.subreddits
342                    .iter()
343                    .map(|s| format!("r/{s}"))
344                    .collect::<Vec<_>>()
345                    .join(", ")
346            )
347        };
348        ::zeroclaw_log::record!(
349            INFO,
350            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note),
351            &format!("channel listening as u/{} {}...", self.username, scope)
352        );
353
354        loop {
355            tokio::time::sleep(POLL_INTERVAL).await;
356
357            let items = match self.fetch_inbox().await {
358                Ok(items) => items,
359                Err(e) => {
360                    ::zeroclaw_log::record!(
361                        WARN,
362                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
363                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
364                            .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
365                        "poll error"
366                    );
367                    continue;
368                }
369            };
370
371            let mut read_ids = Vec::new();
372            for child in &items {
373                if let Some(ref name) = child.data.name {
374                    read_ids.push(name.clone());
375                }
376                if let Some(msg) = self.parse_item(&child.data)
377                    && tx.send(msg).await.is_err()
378                {
379                    return Ok(());
380                }
381            }
382
383            if let Err(e) = self.mark_read(&read_ids).await {
384                ::zeroclaw_log::record!(
385                    WARN,
386                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
387                        .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
388                        .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
389                    "mark_read error"
390                );
391            }
392        }
393    }
394
395    async fn health_check(&self) -> bool {
396        self.get_access_token().await.is_ok()
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a channel for the account `testbot` with the given allowlist.
    fn channel_for(subreddits: Vec<String>) -> RedditChannel {
        RedditChannel::new(
            "testbot",
            "client_id".into(),
            "client_secret".into(),
            "refresh_token".into(),
            "testbot".into(),
            subreddits,
        )
    }

    /// Channel without any subreddit restriction.
    fn make_channel() -> RedditChannel {
        channel_for(Vec::new())
    }

    /// Channel restricted to the single subreddit `sub`.
    fn make_channel_with_sub(sub: &str) -> RedditChannel {
        channel_for(vec![sub.into()])
    }

    /// Baseline inbox item: no parent, subreddit, subject or type. Tests
    /// override individual fields with struct-update syntax.
    fn inbox_item(name: &str, author: &str, body: &str) -> RedditItemData {
        RedditItemData {
            name: Some(name.into()),
            author: Some(author.into()),
            body: Some(body.into()),
            subject: None,
            parent_id: None,
            link_id: None,
            subreddit: None,
            created_utc: Some(1_700_000_000.0),
            new: Some(true),
            message_type: None,
            context: None,
        }
    }

    #[test]
    fn parse_comment_reply() {
        let ch = make_channel();
        let item = RedditItemData {
            parent_id: Some("t1_parent1".into()),
            link_id: Some("t3_post1".into()),
            subreddit: Some("rust".into()),
            message_type: Some("comment_reply".into()),
            ..inbox_item("t1_abc123", "user1", "hello bot")
        };

        let msg = ch.parse_item(&item).unwrap();
        assert_eq!(msg.sender, "user1");
        assert_eq!(msg.content, "hello bot");
        assert_eq!(msg.reply_target, "t1_parent1");
        assert_eq!(msg.channel, "reddit");
        assert_eq!(msg.id, "reddit_t1_abc123");
    }

    #[test]
    fn parse_dm() {
        let ch = make_channel();
        let item = RedditItemData {
            subject: Some("Hello".into()),
            created_utc: Some(1_700_000_100.0),
            ..inbox_item("t4_dm456", "user2", "private message")
        };

        let msg = ch.parse_item(&item).unwrap();
        assert_eq!(msg.sender, "user2");
        assert_eq!(msg.content, "private message");
        assert_eq!(msg.reply_target, "user2"); // DM reply goes to author
    }

    #[test]
    fn skip_self_messages() {
        let ch = make_channel();
        // Author equals the channel's own username.
        let item = inbox_item("t1_self", "testbot", "my own message");

        assert!(ch.parse_item(&item).is_none());
    }

    #[test]
    fn skip_empty_body() {
        let ch = make_channel();
        let item = inbox_item("t1_empty", "user1", "");

        assert!(ch.parse_item(&item).is_none());
    }

    #[test]
    fn subreddit_filter() {
        let ch = make_channel_with_sub("rust");

        let other_sub = RedditItemData {
            subreddit: Some("python".into()),
            ..inbox_item("t1_other", "user1", "hello")
        };
        assert!(ch.parse_item(&other_sub).is_none());

        let matching_item = RedditItemData {
            subreddit: Some("rust".into()),
            ..inbox_item("t1_match", "user1", "hello")
        };
        assert!(ch.parse_item(&matching_item).is_some());
    }

    #[test]
    fn send_message_formatting() {
        // Verify SendMessage can be constructed for both DM and comment reply
        let dm = SendMessage::new("hello", "user1");
        assert_eq!(dm.recipient, "user1");
        assert_eq!(dm.content, "hello");

        let reply = SendMessage::new("response", "t1_abc123");
        assert!(reply.recipient.starts_with("t1_"));
    }
}