Skip to main content

zeroclaw_runtime/tunnel/
custom.rs

1use super::{SharedProcess, Tunnel, TunnelProcess, kill_shared, new_shared_process};
2use anyhow::{Result, bail};
3use tokio::io::AsyncBufReadExt;
4use tokio::process::Command;
5
/// Custom Tunnel — bring your own tunnel binary.
///
/// Provide a `start_command` with `{port}` and `{host}` placeholders; both are
/// substituted with the local port and host before the command is spawned.
/// The command is split on whitespace and executed directly (no shell), so
/// quoting and shell operators are not interpreted.
///
/// Optionally provide a `url_pattern` to have `start` look for the public URL
/// on the process's stdout, and a `health_url` to poll for liveness.
///
/// Note: `url_pattern` is not compiled as a regex. When it is set, `start`
/// takes the first `https://` (or, failing that, `http://`) URL it sees on a
/// stdout line; when no URL is found the local `http://{host}:{port}` address
/// is reported instead.
///
/// Examples:
/// - `bore local {port} --to bore.pub`
/// - `frp -c /etc/frp/frpc.ini`
/// - `ssh -R 80:localhost:{port} serveo.net`
pub struct CustomTunnel {
    /// Command line template; `{port}` and `{host}` are replaced in `start`.
    start_command: String,
    /// URL requested by `health_check`; when `None`, the health check looks at
    /// the spawned process instead.
    health_url: Option<String>,
    /// When `Some`, `start` scans the child's stdout for a public URL.
    url_pattern: Option<String>,
    /// Shared slot holding the running child process and its public URL once
    /// `start` has succeeded.
    proc: SharedProcess,
}
22
23impl CustomTunnel {
24    pub fn new(
25        start_command: String,
26        health_url: Option<String>,
27        url_pattern: Option<String>,
28    ) -> Self {
29        Self {
30            start_command,
31            health_url,
32            url_pattern,
33            proc: new_shared_process(),
34        }
35    }
36}
37
38#[async_trait::async_trait]
39impl Tunnel for CustomTunnel {
40    fn name(&self) -> &str {
41        "custom"
42    }
43
44    async fn start(&self, local_host: &str, local_port: u16) -> Result<String> {
45        let cmd = self
46            .start_command
47            .replace("{port}", &local_port.to_string())
48            .replace("{host}", local_host);
49
50        let parts: Vec<&str> = cmd.split_whitespace().collect();
51        if parts.is_empty() {
52            bail!("Custom tunnel start_command is empty");
53        }
54
55        let mut child = Command::new(parts[0])
56            .args(&parts[1..])
57            .stdout(std::process::Stdio::piped())
58            .stderr(std::process::Stdio::piped())
59            .kill_on_drop(true)
60            .spawn()?;
61
62        let mut public_url = format!("http://{local_host}:{local_port}");
63
64        // If a URL pattern is provided, try to extract the public URL from stdout
65        if let Some(ref pattern) = self.url_pattern
66            && let Some(stdout) = child.stdout.take()
67        {
68            let mut reader = tokio::io::BufReader::new(stdout).lines();
69            let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(15);
70
71            while tokio::time::Instant::now() < deadline {
72                let line =
73                    tokio::time::timeout(tokio::time::Duration::from_secs(3), reader.next_line())
74                        .await;
75
76                match line {
77                    Ok(Ok(Some(l))) => {
78                        ::zeroclaw_log::record!(
79                            DEBUG,
80                            ::zeroclaw_log::Event::new(
81                                module_path!(),
82                                ::zeroclaw_log::Action::Note
83                            )
84                            .with_attrs(::serde_json::json!({"l": l})),
85                            "custom-tunnel: "
86                        );
87                        // Simple substring match on the pattern
88                        if l.contains(pattern) || l.contains("https://") || l.contains("http://") {
89                            // Extract URL from the line
90                            if let Some(idx) = l.find("https://") {
91                                let url_part = &l[idx..];
92                                let end = url_part
93                                    .find(|c: char| c.is_whitespace())
94                                    .unwrap_or(url_part.len());
95                                public_url = url_part[..end].to_string();
96                                break;
97                            } else if let Some(idx) = l.find("http://") {
98                                let url_part = &l[idx..];
99                                let end = url_part
100                                    .find(|c: char| c.is_whitespace())
101                                    .unwrap_or(url_part.len());
102                                public_url = url_part[..end].to_string();
103                                break;
104                            }
105                        }
106                    }
107                    Ok(Ok(None) | Err(_)) => break,
108                    Err(_) => {}
109                }
110            }
111        }
112
113        let mut guard = self.proc.lock().await;
114        *guard = Some(TunnelProcess {
115            child,
116            public_url: public_url.clone(),
117        });
118
119        Ok(public_url)
120    }
121
122    async fn stop(&self) -> Result<()> {
123        kill_shared(&self.proc).await
124    }
125
126    async fn health_check(&self) -> bool {
127        // If a health URL is configured, try to reach it
128        if let Some(ref url) = self.health_url {
129            return zeroclaw_config::schema::build_runtime_proxy_client("tunnel.custom")
130                .get(url)
131                .timeout(std::time::Duration::from_secs(5))
132                .send()
133                .await
134                .is_ok();
135        }
136
137        // Otherwise check if the process is still alive
138        let guard = self.proc.lock().await;
139        guard.as_ref().is_some_and(|tp| tp.child.id().is_some())
140    }
141
142    fn public_url(&self) -> Option<String> {
143        self.proc
144            .try_lock()
145            .ok()
146            .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone()))
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// A whitespace-only command has no program to run and must be rejected.
    #[tokio::test]
    async fn start_with_empty_command_returns_error() {
        let blank = CustomTunnel::new("   ".into(), None, None);

        let err = blank
            .start("127.0.0.1", 8080)
            .await
            .expect_err("whitespace-only start_command should fail");

        assert!(err.to_string().contains("start_command is empty"));
    }

    /// Without a URL pattern the local address is reported and remembered.
    #[tokio::test]
    async fn start_without_pattern_returns_local_url() {
        const LOCAL: &str = "http://127.0.0.1:4455";
        let tunnel = CustomTunnel::new("sleep 1".into(), None, None);

        let started = tunnel.start("127.0.0.1", 4455).await.unwrap();

        assert_eq!(started, LOCAL);
        assert_eq!(tunnel.public_url(), Some(LOCAL.to_string()));

        tunnel.stop().await.unwrap();
    }

    /// A URL printed on stdout replaces the local fallback address.
    #[tokio::test]
    async fn start_with_pattern_extracts_url() {
        const PUBLIC: &str = "https://public.example";
        let pattern = Some("public.example".into());
        let tunnel = CustomTunnel::new("echo https://public.example".into(), None, pattern);

        let started = tunnel.start("localhost", 9999).await.unwrap();

        assert_eq!(started, PUBLIC);
        assert_eq!(tunnel.public_url(), Some(PUBLIC.to_string()));

        tunnel.stop().await.unwrap();
    }

    /// `{host}` and `{port}` are substituted before the command is spawned.
    #[tokio::test]
    async fn start_replaces_host_and_port_placeholders() {
        let pattern = Some("http://".into());
        let tunnel = CustomTunnel::new("echo http://{host}:{port}".into(), None, pattern);

        let started = tunnel.start("10.1.2.3", 4321).await.unwrap();

        assert_eq!(started, "http://10.1.2.3:4321");
        tunnel.stop().await.unwrap();
    }

    /// A health URL that cannot be reached makes the health check fail.
    #[tokio::test]
    async fn health_check_with_unreachable_health_url_returns_false() {
        let health_url = Some("http://127.0.0.1:9/healthz".into());
        let tunnel = CustomTunnel::new("sleep 1".into(), health_url, None);

        let healthy = tunnel.health_check().await;

        assert!(!healthy);
    }
}