Skip to main content

zeroclaw_tools/
http_request.rs

1use async_trait::async_trait;
2use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
3use serde_json::json;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::Duration;
7use zeroclaw_api::tool::{Tool, ToolResult};
8use zeroclaw_config::policy::SecurityPolicy;
9
10/// HTTP request tool for API interactions.
11/// Supports GET, POST, PUT, DELETE methods with configurable security.
12pub struct HttpRequestTool {
13    security: Arc<SecurityPolicy>,
14    allowed_domains: Vec<String>,
15    max_response_size: usize,
16    timeout_secs: u64,
17    allow_private_hosts: bool,
18}
19
20impl HttpRequestTool {
21    pub fn new(
22        security: Arc<SecurityPolicy>,
23        allowed_domains: Vec<String>,
24        max_response_size: usize,
25        timeout_secs: u64,
26        allow_private_hosts: bool,
27    ) -> anyhow::Result<Self> {
28        Ok(Self {
29            security,
30            allowed_domains: normalize_allowed_domains(allowed_domains)?,
31            max_response_size,
32            timeout_secs,
33            allow_private_hosts,
34        })
35    }
36
37    fn validate_url(&self, raw_url: &str) -> anyhow::Result<String> {
38        let url = raw_url.trim();
39
40        if url.is_empty() {
41            anyhow::bail!("URL cannot be empty");
42        }
43
44        if url.chars().any(char::is_whitespace) {
45            anyhow::bail!("URL cannot contain whitespace");
46        }
47
48        if !url.starts_with("http://") && !url.starts_with("https://") {
49            anyhow::bail!("Only http:// and https:// URLs are allowed");
50        }
51
52        if self.allowed_domains.is_empty() {
53            anyhow::bail!(
54                "HTTP request tool is enabled but no allowed_domains are configured. Add [http_request].allowed_domains in config.toml"
55            );
56        }
57
58        let host = extract_host(url)?;
59
60        if !self.allow_private_hosts && is_private_or_local_host(&host) {
61            anyhow::bail!("Blocked local/private host: {host}");
62        }
63
64        if !host_matches_allowlist(&host, &self.allowed_domains) {
65            anyhow::bail!("Host '{host}' is not in http_request.allowed_domains");
66        }
67
68        Ok(url.to_string())
69    }
70
71    fn validate_method(&self, method: &str) -> anyhow::Result<reqwest::Method> {
72        match method.to_uppercase().as_str() {
73            "GET" => Ok(reqwest::Method::GET),
74            "POST" => Ok(reqwest::Method::POST),
75            "PUT" => Ok(reqwest::Method::PUT),
76            "DELETE" => Ok(reqwest::Method::DELETE),
77            "PATCH" => Ok(reqwest::Method::PATCH),
78            "HEAD" => Ok(reqwest::Method::HEAD),
79            "OPTIONS" => Ok(reqwest::Method::OPTIONS),
80            _ => anyhow::bail!(
81                "Unsupported HTTP method: {method}. Supported: GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS"
82            ),
83        }
84    }
85
86    fn parse_headers(&self, headers: &serde_json::Value) -> anyhow::Result<HeaderMap> {
87        let mut result = HeaderMap::new();
88        if let Some(obj) = headers.as_object() {
89            for (key, value) in obj {
90                let Some(str_val) = value.as_str() else {
91                    anyhow::bail!("Header '{key}' value must be a string, got: {}", value);
92                };
93                let header_name = HeaderName::from_str(key)
94                    .map_err(|e| anyhow::Error::msg(format!("Invalid header name '{key}': {e}")))?;
95                let header_value = HeaderValue::from_str(str_val).map_err(|e| {
96                    anyhow::Error::msg(format!("Invalid value for header '{key}': {e}"))
97                })?;
98                result.insert(header_name, header_value);
99            }
100        }
101        Ok(result)
102    }
103
104    #[cfg(test)]
105    fn redact_headers_for_display(headers: &[(String, String)]) -> Vec<(String, String)> {
106        headers
107            .iter()
108            .map(|(key, value)| {
109                let lower = key.to_lowercase();
110                let is_sensitive = lower.contains("authorization")
111                    || lower.contains("api-key")
112                    || lower.contains("apikey")
113                    || lower.contains("token")
114                    || lower.contains("secret");
115                if is_sensitive {
116                    (key.clone(), "***REDACTED***".into())
117                } else {
118                    (key.clone(), value.clone())
119                }
120            })
121            .collect()
122    }
123
124    async fn execute_request(
125        &self,
126        url: &str,
127        method: reqwest::Method,
128        headers: HeaderMap,
129        body: Option<&str>,
130    ) -> anyhow::Result<reqwest::Response> {
131        let timeout_secs = if self.timeout_secs == 0 {
132            ::zeroclaw_log::record!(
133                WARN,
134                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
135                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
136                "http_request: timeout_secs is 0, using safe default of 30s"
137            );
138            30
139        } else {
140            self.timeout_secs
141        };
142        let builder = reqwest::Client::builder()
143            .timeout(Duration::from_secs(timeout_secs))
144            .connect_timeout(Duration::from_secs(10))
145            .redirect(reqwest::redirect::Policy::none());
146        let builder =
147            zeroclaw_config::schema::apply_runtime_proxy_to_builder(builder, "tool.http_request");
148        let client = builder.build()?;
149
150        let mut request = client.request(method, url).headers(headers);
151
152        if let Some(body_str) = body {
153            request = request.body(body_str.to_string());
154        }
155
156        Ok(request.send().await?)
157    }
158
159    fn truncate_response(&self, text: &str) -> String {
160        // 0 means unlimited — no truncation.
161        if self.max_response_size == 0 {
162            return text.to_string();
163        }
164        if text.len() > self.max_response_size {
165            let mut truncated = text
166                .chars()
167                .take(self.max_response_size)
168                .collect::<String>();
169            truncated.push_str("\n\n... [Response truncated due to size limit] ...");
170            truncated
171        } else {
172            text.to_string()
173        }
174    }
175}
176
177#[async_trait]
178impl Tool for HttpRequestTool {
179    fn name(&self) -> &str {
180        "http_request"
181    }
182
183    fn description(&self) -> &str {
184        "Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS methods. \
185        Security constraints: allowlist-only domains, no local/private hosts, configurable timeout and response size limits."
186    }
187
188    fn parameters_schema(&self) -> serde_json::Value {
189        json!({
190            "type": "object",
191            "properties": {
192                "url": {
193                    "type": "string",
194                    "description": "HTTP or HTTPS URL to request"
195                },
196                "method": {
197                    "type": "string",
198                    "description": "HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)",
199                    "default": "GET"
200                },
201                "headers": {
202                    "type": "object",
203                    "description": "Optional HTTP headers as key-value pairs (e.g., {\"Authorization\": \"Bearer token\", \"Content-Type\": \"application/json\"})",
204                    "default": {}
205                },
206                "body": {
207                    "type": "string",
208                    "description": "Optional request body (for POST, PUT, PATCH requests)"
209                }
210            },
211            "required": ["url"]
212        })
213    }
214
215    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
216        let url = args.get("url").and_then(|v| v.as_str()).ok_or_else(|| {
217            ::zeroclaw_log::record!(
218                WARN,
219                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
220                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
221                    .with_attrs(::serde_json::json!({"param": "url"})),
222                "http_request: missing url parameter"
223            );
224            anyhow::Error::msg("Missing 'url' parameter")
225        })?;
226
227        let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
228        let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
229        let body = args.get("body").and_then(|v| v.as_str());
230
231        if !self.security.can_act() {
232            return Ok(ToolResult {
233                success: false,
234                output: String::new(),
235                error: Some("Action blocked: autonomy is read-only".into()),
236            });
237        }
238
239        // Rate limiting is applied by the RateLimitedTool wrapper at
240        // registration time (see zeroclaw-runtime::tools::mod).
241
242        let url = match self.validate_url(url) {
243            Ok(v) => v,
244            Err(e) => {
245                return Ok(ToolResult {
246                    success: false,
247                    output: String::new(),
248                    error: Some(e.to_string()),
249                });
250            }
251        };
252
253        let method = match self.validate_method(method_str) {
254            Ok(m) => m,
255            Err(e) => {
256                return Ok(ToolResult {
257                    success: false,
258                    output: String::new(),
259                    error: Some(e.to_string()),
260                });
261            }
262        };
263
264        let request_headers = match self.parse_headers(&headers_val) {
265            Ok(h) => h,
266            Err(e) => {
267                return Ok(ToolResult {
268                    success: false,
269                    output: String::new(),
270                    error: Some(e.to_string()),
271                });
272            }
273        };
274
275        match self
276            .execute_request(&url, method, request_headers, body)
277            .await
278        {
279            Ok(response) => {
280                let status = response.status();
281                let status_code = status.as_u16();
282
283                // Get response headers (redact sensitive ones)
284                let response_headers = response.headers().iter();
285                let headers_text = response_headers
286                    .map(|(k, _)| {
287                        let is_sensitive = k.as_str().to_lowercase().contains("set-cookie");
288                        if is_sensitive {
289                            format!("{}: ***REDACTED***", k.as_str())
290                        } else {
291                            format!("{}: {:?}", k.as_str(), k.as_str())
292                        }
293                    })
294                    .collect::<Vec<_>>()
295                    .join(", ");
296
297                // Get response body with size limit
298                let response_text = match response.text().await {
299                    Ok(text) => self.truncate_response(&text),
300                    Err(e) => format!("[Failed to read response body: {e}]"),
301                };
302
303                let output = format!(
304                    "Status: {} {}\nResponse Headers: {}\n\nResponse Body:\n{}",
305                    status_code,
306                    status.canonical_reason().unwrap_or("Unknown"),
307                    headers_text,
308                    response_text
309                );
310
311                Ok(ToolResult {
312                    success: status.is_success(),
313                    output,
314                    error: if status.is_client_error() || status.is_server_error() {
315                        Some(format!("HTTP {}", status_code))
316                    } else {
317                        None
318                    },
319                })
320            }
321            Err(e) => Ok(ToolResult {
322                success: false,
323                output: String::new(),
324                error: Some(format!("HTTP request failed: {e}")),
325            }),
326        }
327    }
328}
329
330// Helper functions similar to browser_open.rs
331
332fn normalize_allowed_domains(domains: Vec<String>) -> anyhow::Result<Vec<String>> {
333    let mut rejected = Vec::new();
334    let mut normalized = domains
335        .into_iter()
336        .filter_map(|d| {
337            normalize_domain(&d).or_else(|| {
338                rejected.push(d.clone());
339                None
340            })
341        })
342        .collect::<Vec<_>>();
343    if !rejected.is_empty() {
344        anyhow::bail!(
345            "Invalid http_request.allowed_domains entry(s): [{}]. Each entry must be a valid domain, hostname, IPv4, or IPv6 address.",
346            rejected.join(", ")
347        );
348    }
349    normalized.sort_unstable();
350    normalized.dedup();
351    Ok(normalized)
352}
353
354fn normalize_domain(raw: &str) -> Option<String> {
355    let input = raw.trim();
356    if input.is_empty() || input.chars().any(char::is_whitespace) {
357        return None;
358    }
359
360    let bare_ip = match (input.starts_with('['), input.ends_with(']')) {
361        (true, true) => &input[1..input.len() - 1],
362        (false, false) => input,
363        _ => return None,
364    };
365    if let Ok(ip) = bare_ip.parse::<std::net::IpAddr>() {
366        return Some(ip.to_string().to_lowercase());
367    }
368
369    let parsed = reqwest::Url::parse(input)
370        .or_else(|_| reqwest::Url::parse(&format!("https://{input}")))
371        .ok()?;
372
373    if !parsed.username().is_empty() || parsed.password().is_some() {
374        return None;
375    }
376
377    let host = parsed.host_str()?;
378    let trimmed = host.trim();
379    let host_no_brackets = match (trimmed.starts_with('['), trimmed.ends_with(']')) {
380        (true, true) => &trimmed[1..trimmed.len() - 1],
381        (false, false) => trimmed,
382        _ => return None,
383    };
384    let normalized = host_no_brackets
385        .trim_start_matches('.')
386        .trim_end_matches('.');
387    if normalized.is_empty() {
388        return None;
389    }
390
391    Some(normalized.to_lowercase())
392}
393
394fn extract_host(url: &str) -> anyhow::Result<String> {
395    if !url.starts_with("http://") && !url.starts_with("https://") {
396        ::zeroclaw_log::record!(
397            WARN,
398            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
399                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
400                .with_attrs(::serde_json::json!({"url": url})),
401            "http_request: non-http(s) URL rejected"
402        );
403        anyhow::bail!("Only http:// and https:// URLs are allowed");
404    }
405
406    let parsed = reqwest::Url::parse(url).map_err(|e| {
407        ::zeroclaw_log::record!(
408            WARN,
409            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
410                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
411                .with_attrs(::serde_json::json!({"url": url})),
412            "http_request: invalid URL"
413        );
414        anyhow::Error::msg(format!("Invalid URL format: {e}"))
415    })?;
416
417    if !parsed.username().is_empty() || parsed.password().is_some() {
418        anyhow::bail!("URL userinfo is not allowed");
419    }
420
421    let host = parsed
422        .host_str()
423        .ok_or_else(|| anyhow::Error::msg("URL must include a host"))?;
424
425    let trimmed = host.trim();
426    let host_no_brackets = match (trimmed.starts_with('['), trimmed.ends_with(']')) {
427        (true, true) => &trimmed[1..trimmed.len() - 1],
428        (false, false) => trimmed,
429        _ => {
430            anyhow::bail!("URL host has unmatched IPv6 brackets");
431        }
432    };
433    let host = host_no_brackets.trim_end_matches('.').to_lowercase();
434
435    if host.is_empty() {
436        anyhow::bail!("URL must include a valid host");
437    }
438
439    Ok(host)
440}
441
/// Reports whether `host` is covered by `allowed_domains`.
///
/// A `"*"` entry matches everything. IP addresses (on either side) must match
/// exactly; domain names match exactly or as a dot-separated subdomain of an
/// entry.
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
    let host_is_ip = host.parse::<std::net::IpAddr>().is_ok();

    for domain in allowed_domains {
        if domain == "*" || host == domain {
            return true;
        }

        // Suffix matching is only meaningful between two domain names.
        let ip_involved = host_is_ip || domain.parse::<std::net::IpAddr>().is_ok();
        if ip_involved {
            continue;
        }

        let is_subdomain = match host.strip_suffix(domain.as_str()) {
            Some(prefix) => prefix.ends_with('.'),
            None => false,
        };
        if is_subdomain {
            return true;
        }
    }

    false
}
459
460fn is_private_or_local_host(host: &str) -> bool {
461    // Strip brackets from IPv6 addresses like [::1]
462    let bare = host
463        .strip_prefix('[')
464        .and_then(|h| h.strip_suffix(']'))
465        .unwrap_or(host);
466
467    let has_local_tld = bare
468        .rsplit('.')
469        .next()
470        .is_some_and(|label| label == "local");
471
472    if bare == "localhost" || bare.ends_with(".localhost") || has_local_tld {
473        return true;
474    }
475
476    if let Ok(ip) = bare.parse::<std::net::IpAddr>() {
477        return match ip {
478            std::net::IpAddr::V4(v4) => is_non_global_v4(v4),
479            std::net::IpAddr::V6(v6) => is_non_global_v6(v6),
480        };
481    }
482
483    false
484}
485
/// Returns true if the IPv4 address is not globally routable.
///
/// Covers the "this network", loopback, private, link-local, broadcast,
/// multicast, shared (RFC 6598), reserved, IETF protocol assignment,
/// documentation (RFC 5737) and benchmarking ranges.
fn is_non_global_v4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, c, _] = v4.octets();
    a == 0                                 // "This network" (0.0.0.0/8), incl. 0.0.0.0
        || v4.is_loopback()                // 127.0.0.0/8
        || v4.is_private()                 // 10/8, 172.16/12, 192.168/16
        || v4.is_link_local()              // 169.254.0.0/16
        || v4.is_broadcast()               // 255.255.255.255
        || v4.is_multicast()               // 224.0.0.0/4
        || (a == 100 && (64..=127).contains(&b)) // Shared address space (RFC 6598)
        || a >= 240                        // Reserved (240.0.0.0/4, except broadcast)
        || (a == 192 && b == 0 && c == 0)  // IETF protocol assignments (192.0.0.0/24)
        // Exactly the three /24 documentation blocks (192.0.2.0, 198.51.100.0,
        // 203.0.113.0); the neighbouring public addresses must stay reachable.
        || v4.is_documentation()
        || (a == 198 && (18..=19).contains(&b)) // Benchmarking (198.18.0.0/15)
}
502
503/// Returns true if the IPv6 address is not globally routable.
504fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool {
505    let segs = v6.segments();
506    v6.is_loopback()                       // ::1
507        || v6.is_unspecified()             // ::
508        || v6.is_multicast()              // ff00::/8
509        || (segs[0] & 0xfe00) == 0xfc00   // Unique-local (fc00::/7)
510        || (segs[0] & 0xffc0) == 0xfe80   // Link-local (fe80::/10)
511        || (segs[0] == 0x2001 && segs[1] == 0x0db8) // Documentation (2001:db8::/32)
512        || v6.to_ipv4_mapped().is_some_and(is_non_global_v4)
513}
514
515#[cfg(test)]
516mod tests {
517    use super::*;
518    use zeroclaw_config::autonomy::AutonomyLevel;
519    use zeroclaw_config::policy::SecurityPolicy;
520
521    fn test_tool(allowed_domains: Vec<&str>) -> HttpRequestTool {
522        test_tool_with_private(allowed_domains, false)
523    }
524
525    fn test_tool_with_private(
526        allowed_domains: Vec<&str>,
527        allow_private_hosts: bool,
528    ) -> HttpRequestTool {
529        let security = Arc::new(SecurityPolicy {
530            autonomy: AutonomyLevel::Supervised,
531            ..SecurityPolicy::default()
532        });
533        HttpRequestTool::new(
534            security,
535            allowed_domains.into_iter().map(String::from).collect(),
536            1_000_000,
537            30,
538            allow_private_hosts,
539        )
540        .unwrap()
541    }
542
543    #[test]
544    fn normalize_domain_strips_scheme_path_and_case() {
545        let got = normalize_domain("  HTTPS://Docs.Example.com/path ").unwrap();
546        assert_eq!(got, "docs.example.com");
547    }
548
549    #[test]
550    fn normalize_domain_accepts_ipv6_literal() {
551        let got = normalize_domain("[2001:db8::1]").unwrap();
552        assert_eq!(got, "2001:db8::1");
553    }
554
555    #[test]
556    fn normalize_domain_rejects_userinfo() {
557        assert!(normalize_domain("https://user@example.com").is_none());
558        assert!(normalize_domain("user@example.com").is_none());
559        assert!(normalize_domain("https://user:pass@example.com").is_none());
560        assert!(normalize_domain("user:pass@example.com").is_none());
561    }
562
563    #[test]
564    fn normalize_domain_rejects_unmatched_brackets() {
565        assert!(normalize_domain("[::1").is_none());
566        assert!(normalize_domain("::1]").is_none());
567        assert!(normalize_domain("[127.0.0.1").is_none());
568        assert!(normalize_domain("127.0.0.1]").is_none());
569    }
570
571    #[test]
572    fn extract_host_normalizes_ipv6_without_brackets() {
573        let got = extract_host("https://[2001:db8::1]:443/path").unwrap();
574        assert_eq!(got, "2001:db8::1");
575    }
576
577    #[test]
578    fn normalize_allowed_domains_rejects_invalid_entries() {
579        let err = normalize_allowed_domains(vec![
580            "".into(),
581            "example.com".into(),
582            "   ".into(),
583            "api.example.com".into(),
584        ])
585        .unwrap_err();
586        let msg = err.to_string();
587        assert!(
588            msg.contains("Invalid http_request.allowed_domains entry"),
589            "got: {msg}"
590        );
591    }
592
593    #[test]
594    fn normalize_allowed_domains_accepts_all_valid() {
595        let got = normalize_allowed_domains(vec!["example.com".into(), "api.example.com".into()])
596            .unwrap();
597        assert_eq!(got.len(), 2);
598        assert!(got.contains(&"example.com".to_string()));
599        assert!(got.contains(&"api.example.com".to_string()));
600    }
601
602    #[test]
603    fn normalize_allowed_domains_deduplicates() {
604        let got = normalize_allowed_domains(vec![
605            "example.com".into(),
606            "EXAMPLE.COM".into(),
607            "https://example.com/".into(),
608        ])
609        .unwrap();
610        assert_eq!(got, vec!["example.com".to_string()]);
611    }
612
613    #[test]
614    fn validate_accepts_exact_domain() {
615        let tool = test_tool(vec!["example.com"]);
616        let got = tool.validate_url("https://example.com/docs").unwrap();
617        assert_eq!(got, "https://example.com/docs");
618    }
619
620    #[test]
621    fn validate_accepts_http() {
622        let tool = test_tool(vec!["example.com"]);
623        assert!(tool.validate_url("http://example.com").is_ok());
624    }
625
626    #[test]
627    fn validate_accepts_subdomain() {
628        let tool = test_tool(vec!["example.com"]);
629        assert!(tool.validate_url("https://api.example.com/v1").is_ok());
630    }
631
632    #[test]
633    fn validate_accepts_wildcard_allowlist_for_public_host() {
634        let tool = test_tool(vec!["*"]);
635        assert!(tool.validate_url("https://news.ycombinator.com").is_ok());
636    }
637
638    #[test]
639    fn validate_wildcard_allowlist_still_rejects_private_host() {
640        let tool = test_tool(vec!["*"]);
641        let err = tool
642            .validate_url("https://localhost:8080")
643            .unwrap_err()
644            .to_string();
645        assert!(err.contains("local/private"));
646    }
647
648    #[test]
649    fn validate_rejects_allowlist_miss() {
650        let tool = test_tool(vec!["example.com"]);
651        let err = tool
652            .validate_url("https://google.com")
653            .unwrap_err()
654            .to_string();
655        assert!(err.contains("allowed_domains"));
656    }
657
658    #[test]
659    fn validate_rejects_localhost() {
660        let tool = test_tool(vec!["localhost"]);
661        let err = tool
662            .validate_url("https://localhost:8080")
663            .unwrap_err()
664            .to_string();
665        assert!(err.contains("local/private"));
666    }
667
668    #[test]
669    fn validate_rejects_private_ipv4() {
670        let tool = test_tool(vec!["192.168.1.5"]);
671        let err = tool
672            .validate_url("https://192.168.1.5")
673            .unwrap_err()
674            .to_string();
675        assert!(err.contains("local/private"));
676    }
677
678    #[test]
679    fn validate_rejects_whitespace() {
680        let tool = test_tool(vec!["example.com"]);
681        let err = tool
682            .validate_url("https://example.com/hello world")
683            .unwrap_err()
684            .to_string();
685        assert!(err.contains("whitespace"));
686    }
687
688    #[test]
689    fn validate_rejects_userinfo() {
690        let tool = test_tool(vec!["example.com"]);
691        let err = tool
692            .validate_url("https://user@example.com")
693            .unwrap_err()
694            .to_string();
695        assert!(err.contains("userinfo"));
696    }
697
698    #[test]
699    fn validate_requires_allowlist() {
700        let security = Arc::new(SecurityPolicy::default());
701        let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30, false).unwrap();
702        let err = tool
703            .validate_url("https://example.com")
704            .unwrap_err()
705            .to_string();
706        assert!(err.contains("allowed_domains"));
707    }
708
709    #[test]
710    fn validate_accepts_valid_methods() {
711        let tool = test_tool(vec!["example.com"]);
712        assert!(tool.validate_method("GET").is_ok());
713        assert!(tool.validate_method("POST").is_ok());
714        assert!(tool.validate_method("PUT").is_ok());
715        assert!(tool.validate_method("DELETE").is_ok());
716        assert!(tool.validate_method("PATCH").is_ok());
717        assert!(tool.validate_method("HEAD").is_ok());
718        assert!(tool.validate_method("OPTIONS").is_ok());
719    }
720
721    #[test]
722    fn validate_rejects_invalid_method() {
723        let tool = test_tool(vec!["example.com"]);
724        let err = tool.validate_method("INVALID").unwrap_err().to_string();
725        assert!(err.contains("Unsupported HTTP method"));
726    }
727
728    #[test]
729    fn blocks_multicast_ipv4() {
730        assert!(is_private_or_local_host("224.0.0.1"));
731        assert!(is_private_or_local_host("239.255.255.255"));
732    }
733
734    #[test]
735    fn blocks_broadcast() {
736        assert!(is_private_or_local_host("255.255.255.255"));
737    }
738
739    #[test]
740    fn blocks_reserved_ipv4() {
741        assert!(is_private_or_local_host("240.0.0.1"));
742        assert!(is_private_or_local_host("250.1.2.3"));
743    }
744
745    #[test]
746    fn blocks_documentation_ranges() {
747        assert!(is_private_or_local_host("192.0.2.1")); // TEST-NET-1
748        assert!(is_private_or_local_host("198.51.100.1")); // TEST-NET-2
749        assert!(is_private_or_local_host("203.0.113.1")); // TEST-NET-3
750    }
751
752    #[test]
753    fn blocks_benchmarking_range() {
754        assert!(is_private_or_local_host("198.18.0.1"));
755        assert!(is_private_or_local_host("198.19.255.255"));
756    }
757
758    #[test]
759    fn blocks_ipv6_localhost() {
760        assert!(is_private_or_local_host("::1"));
761        assert!(is_private_or_local_host("[::1]"));
762    }
763
764    #[test]
765    fn blocks_ipv6_multicast() {
766        assert!(is_private_or_local_host("ff02::1"));
767    }
768
769    #[test]
770    fn blocks_ipv6_link_local() {
771        assert!(is_private_or_local_host("fe80::1"));
772    }
773
774    #[test]
775    fn blocks_ipv6_unique_local() {
776        assert!(is_private_or_local_host("fd00::1"));
777    }
778
779    #[test]
780    fn blocks_ipv4_mapped_ipv6() {
781        assert!(is_private_or_local_host("::ffff:127.0.0.1"));
782        assert!(is_private_or_local_host("::ffff:192.168.1.1"));
783        assert!(is_private_or_local_host("::ffff:10.0.0.1"));
784    }
785
786    #[test]
787    fn allows_public_ipv4() {
788        assert!(!is_private_or_local_host("8.8.8.8"));
789        assert!(!is_private_or_local_host("1.1.1.1"));
790        assert!(!is_private_or_local_host("93.184.216.34"));
791    }
792
793    #[test]
794    fn blocks_ipv6_documentation_range() {
795        assert!(is_private_or_local_host("2001:db8::1"));
796    }
797
798    #[test]
799    fn allows_public_ipv6() {
800        assert!(!is_private_or_local_host("2607:f8b0:4004:800::200e"));
801    }
802
803    #[test]
804    fn blocks_shared_address_space() {
805        assert!(is_private_or_local_host("100.64.0.1"));
806        assert!(is_private_or_local_host("100.127.255.255"));
807        assert!(!is_private_or_local_host("100.63.0.1")); // Just below range
808        assert!(!is_private_or_local_host("100.128.0.1")); // Just above range
809    }
810
811    #[tokio::test]
812    async fn execute_blocks_readonly_mode() {
813        let security = Arc::new(SecurityPolicy {
814            autonomy: AutonomyLevel::ReadOnly,
815            ..SecurityPolicy::default()
816        });
817        let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30, false)
818            .unwrap();
819        let result = tool
820            .execute(json!({"url": "https://example.com"}))
821            .await
822            .unwrap();
823        assert!(!result.success);
824        assert!(result.error.unwrap().contains("read-only"));
825    }
826
827    #[test]
828    fn truncate_response_within_limit() {
829        let tool = test_tool(vec!["example.com"]);
830        let text = "hello world";
831        assert_eq!(tool.truncate_response(text), "hello world");
832    }
833
834    #[test]
835    fn truncate_response_over_limit() {
836        let tool = HttpRequestTool::new(
837            Arc::new(SecurityPolicy::default()),
838            vec!["example.com".into()],
839            10,
840            30,
841            false,
842        )
843        .unwrap();
844        let text = "hello world this is long";
845        let truncated = tool.truncate_response(text);
846        assert!(truncated.len() <= 10 + 60); // limit + message
847        assert!(truncated.contains("[Response truncated"));
848    }
849
850    #[test]
851    fn truncate_response_zero_means_unlimited() {
852        let tool = HttpRequestTool::new(
853            Arc::new(SecurityPolicy::default()),
854            vec!["example.com".into()],
855            0, // max_response_size = 0 means no limit
856            30,
857            false,
858        )
859        .unwrap();
860        let text = "a".repeat(10_000_000);
861        assert_eq!(tool.truncate_response(&text), text);
862    }
863
864    #[test]
865    fn truncate_response_nonzero_still_truncates() {
866        let tool = HttpRequestTool::new(
867            Arc::new(SecurityPolicy::default()),
868            vec!["example.com".into()],
869            5,
870            30,
871            false,
872        )
873        .unwrap();
874        let text = "hello world";
875        let truncated = tool.truncate_response(text);
876        assert!(truncated.starts_with("hello"));
877        assert!(truncated.contains("[Response truncated"));
878    }
879
880    #[test]
881    fn parse_headers_rejects_non_string_values() {
882        let tool = test_tool(vec!["example.com"]);
883        let headers = json!({
884            "X-Number": 42,
885            "Content-Type": "application/json"
886        });
887        let err = tool.parse_headers(&headers).unwrap_err().to_string();
888        assert!(
889            err.contains("X-Number"),
890            "Should reject non-string header value, got: {err}"
891        );
892    }
893
894    #[test]
895    fn parse_headers_preserves_original_values() {
896        let tool = test_tool(vec!["example.com"]);
897        let headers = json!({
898            "Authorization": "Bearer secret",
899            "Content-Type": "application/json",
900            "X-API-Key": "my-key"
901        });
902        let parsed = tool.parse_headers(&headers).unwrap();
903        assert_eq!(parsed.len(), 3);
904        assert_eq!(parsed["authorization"], "Bearer secret");
905        assert_eq!(parsed["x-api-key"], "my-key");
906        assert_eq!(parsed["content-type"], "application/json");
907    }
908
909    #[test]
910    fn redact_headers_for_display_redacts_sensitive() {
911        let headers = vec![
912            ("Authorization".into(), "Bearer secret".into()),
913            ("Content-Type".into(), "application/json".into()),
914            ("X-API-Key".into(), "my-key".into()),
915            ("X-Secret-Token".into(), "tok-123".into()),
916        ];
917        let redacted = HttpRequestTool::redact_headers_for_display(&headers);
918        assert_eq!(redacted.len(), 4);
919        assert!(
920            redacted
921                .iter()
922                .any(|(k, v)| k == "Authorization" && v == "***REDACTED***")
923        );
924        assert!(
925            redacted
926                .iter()
927                .any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***")
928        );
929        assert!(
930            redacted
931                .iter()
932                .any(|(k, v)| k == "X-Secret-Token" && v == "***REDACTED***")
933        );
934        assert!(
935            redacted
936                .iter()
937                .any(|(k, v)| k == "Content-Type" && v == "application/json")
938        );
939    }
940
941    #[test]
942    fn redact_headers_does_not_alter_original() {
943        let headers = vec![("Authorization".into(), "Bearer real-token".into())];
944        let _ = HttpRequestTool::redact_headers_for_display(&headers);
945        assert_eq!(headers[0].1, "Bearer real-token");
946    }
947
    // ── SSRF: alternate IP notation bypass defense-in-depth ─────────
    //
    // Rust's `IpAddr` parser (`str::parse::<IpAddr>()`) rejects non-standard
    // notations (octal, hex, decimal integer, zero-padded). These tests document
    // that property so regressions are caught if the parsing strategy ever changes.
953
954    #[test]
955    fn ssrf_octal_loopback_not_parsed_as_ip() {
956        // 0177.0.0.1 is octal for 127.0.0.1 in some languages, but
957        // Rust's IpAddr rejects it — it falls through as a hostname.
958        assert!(!is_private_or_local_host("0177.0.0.1"));
959    }
960
961    #[test]
962    fn ssrf_hex_loopback_not_parsed_as_ip() {
963        // 0x7f000001 is hex for 127.0.0.1 in some languages.
964        assert!(!is_private_or_local_host("0x7f000001"));
965    }
966
967    #[test]
968    fn ssrf_decimal_loopback_not_parsed_as_ip() {
969        // 2130706433 is decimal for 127.0.0.1 in some languages.
970        assert!(!is_private_or_local_host("2130706433"));
971    }
972
973    #[test]
974    fn ssrf_zero_padded_loopback_not_parsed_as_ip() {
975        // 127.000.000.001 uses zero-padded octets.
976        assert!(!is_private_or_local_host("127.000.000.001"));
977    }
978
979    #[test]
980    fn ssrf_alternate_notations_rejected_by_validate_url() {
981        // Alternate notations must be blocked by validation.
982        // Depending on URL canonicalization, they may be rejected either as:
983        // - private/local hosts, or
984        // - allowlist mismatches.
985        let tool = test_tool(vec!["example.com"]);
986        for notation in [
987            "http://0177.0.0.1",
988            "http://0x7f000001",
989            "http://2130706433",
990            "http://127.000.000.001",
991        ] {
992            let err = tool.validate_url(notation).unwrap_err().to_string();
993            assert!(
994                err.contains("allowed_domains") || err.contains("local/private"),
995                "Expected secure rejection for {notation}, got: {err}"
996            );
997        }
998    }
999
1000    #[test]
1001    fn redirect_policy_is_none() {
1002        // Structural test: the tool should be buildable with redirect-safe config.
1003        // The actual Policy::none() enforcement is in execute_request's client builder.
1004        let tool = test_tool(vec!["example.com"]);
1005        assert_eq!(tool.name(), "http_request");
1006    }
1007
1008    // ── §1.4 DNS rebinding / SSRF defense-in-depth tests ─────
1009
1010    #[test]
1011    fn ssrf_blocks_loopback_127_range() {
1012        assert!(is_private_or_local_host("127.0.0.1"));
1013        assert!(is_private_or_local_host("127.0.0.2"));
1014        assert!(is_private_or_local_host("127.255.255.255"));
1015    }
1016
1017    #[test]
1018    fn ssrf_blocks_rfc1918_10_range() {
1019        assert!(is_private_or_local_host("10.0.0.1"));
1020        assert!(is_private_or_local_host("10.255.255.255"));
1021    }
1022
1023    #[test]
1024    fn ssrf_blocks_rfc1918_172_range() {
1025        assert!(is_private_or_local_host("172.16.0.1"));
1026        assert!(is_private_or_local_host("172.31.255.255"));
1027    }
1028
1029    #[test]
1030    fn ssrf_blocks_unspecified_address() {
1031        assert!(is_private_or_local_host("0.0.0.0"));
1032    }
1033
1034    #[test]
1035    fn ssrf_blocks_dot_localhost_subdomain() {
1036        assert!(is_private_or_local_host("evil.localhost"));
1037        assert!(is_private_or_local_host("a.b.localhost"));
1038    }
1039
1040    #[test]
1041    fn ssrf_blocks_dot_local_tld() {
1042        assert!(is_private_or_local_host("service.local"));
1043    }
1044
1045    #[test]
1046    fn ssrf_ipv6_unspecified() {
1047        assert!(is_private_or_local_host("::"));
1048    }
1049
1050    #[test]
1051    fn validate_rejects_ftp_scheme() {
1052        let tool = test_tool(vec!["example.com"]);
1053        let err = tool
1054            .validate_url("ftp://example.com")
1055            .unwrap_err()
1056            .to_string();
1057        assert!(err.contains("http://") || err.contains("https://"));
1058    }
1059
1060    #[test]
1061    fn validate_rejects_empty_url() {
1062        let tool = test_tool(vec!["example.com"]);
1063        let err = tool.validate_url("").unwrap_err().to_string();
1064        assert!(err.contains("empty"));
1065    }
1066
1067    #[test]
1068    fn validate_accepts_public_ipv6_host_when_allowlisted() {
1069        let tool = test_tool(vec!["2607:f8b0:4004:800::200e"]);
1070        assert!(
1071            tool.validate_url("https://[2607:f8b0:4004:800::200e]/path")
1072                .is_ok()
1073        );
1074    }
1075
1076    // ── allow_private_hosts opt-in tests ────────────────────────
1077
1078    #[test]
1079    fn default_blocks_private_hosts() {
1080        let tool = test_tool(vec!["localhost", "192.168.1.5", "*"]);
1081        assert!(
1082            tool.validate_url("https://localhost:8080")
1083                .unwrap_err()
1084                .to_string()
1085                .contains("local/private")
1086        );
1087        assert!(
1088            tool.validate_url("https://192.168.1.5")
1089                .unwrap_err()
1090                .to_string()
1091                .contains("local/private")
1092        );
1093        assert!(
1094            tool.validate_url("https://10.0.0.1")
1095                .unwrap_err()
1096                .to_string()
1097                .contains("local/private")
1098        );
1099    }
1100
1101    #[test]
1102    fn allow_private_hosts_permits_localhost() {
1103        let tool = test_tool_with_private(vec!["localhost"], true);
1104        assert!(tool.validate_url("https://localhost:8080").is_ok());
1105    }
1106
1107    #[test]
1108    fn allow_private_hosts_permits_private_ipv4() {
1109        let tool = test_tool_with_private(vec!["192.168.1.5"], true);
1110        assert!(tool.validate_url("https://192.168.1.5").is_ok());
1111    }
1112
1113    #[test]
1114    fn allow_private_hosts_permits_rfc1918_with_wildcard() {
1115        let tool = test_tool_with_private(vec!["*"], true);
1116        assert!(tool.validate_url("https://10.0.0.1").is_ok());
1117        assert!(tool.validate_url("https://172.16.0.1").is_ok());
1118        assert!(tool.validate_url("https://192.168.1.1").is_ok());
1119        assert!(tool.validate_url("http://localhost:8123").is_ok());
1120    }
1121
1122    #[test]
1123    fn allow_private_hosts_permits_ipv6_loopback_when_allowlisted() {
1124        let tool = test_tool_with_private(vec!["::1"], true);
1125        assert!(tool.validate_url("https://[::1]:8443").is_ok());
1126    }
1127
1128    #[test]
1129    fn allow_private_hosts_still_requires_allowlist() {
1130        let tool = test_tool_with_private(vec!["example.com"], true);
1131        let err = tool
1132            .validate_url("https://192.168.1.5")
1133            .unwrap_err()
1134            .to_string();
1135        assert!(
1136            err.contains("allowed_domains"),
1137            "Private host should still need allowlist match, got: {err}"
1138        );
1139    }
1140
1141    #[test]
1142    fn allow_private_hosts_false_still_blocks() {
1143        let tool = test_tool_with_private(vec!["*"], false);
1144        assert!(
1145            tool.validate_url("https://localhost:8080")
1146                .unwrap_err()
1147                .to_string()
1148                .contains("local/private")
1149        );
1150    }
1151
1152    // ── IPv6 end-to-end coverage ──────────────────────────────
1153
1154    #[test]
1155    fn ipv6_url_parse_variants_extract_correct_host() {
1156        assert_eq!(
1157            extract_host("https://[2001:db8::1]/api").unwrap(),
1158            "2001:db8::1"
1159        );
1160        assert_eq!(
1161            extract_host("https://[2001:db8::1]:8080/api?q=1").unwrap(),
1162            "2001:db8::1"
1163        );
1164        assert_eq!(
1165            extract_host("http://[2607:f8b0:4004:800::200e]:443/path#frag").unwrap(),
1166            "2607:f8b0:4004:800::200e"
1167        );
1168    }
1169
1170    #[test]
1171    fn ipv6_allowlist_handles_compressed_notation() {
1172        let tool = test_tool(vec!["::1", "fe80::1"]);
1173        assert!(tool.validate_url("https://[::1]:8443").is_err()); // blocked — local/private
1174        assert!(tool.validate_url("https://[fe80::1]").is_err()); // blocked — local/private
1175    }
1176
1177    #[test]
1178    fn ipv6_normalize_domain_handles_edge_cases() {
1179        assert_eq!(normalize_domain("::1").unwrap(), "::1");
1180        assert_eq!(normalize_domain("[::1]").unwrap(), "::1");
1181        assert_eq!(normalize_domain("2001:db8::1").unwrap(), "2001:db8::1");
1182        assert_eq!(normalize_domain("[2001:db8::1]").unwrap(), "2001:db8::1");
1183    }
1184
1185    #[test]
1186    fn ipv6_host_matches_allowlist_exact_only() {
1187        let domains = vec!["2001:db8::1".to_string()];
1188        // exact match
1189        assert!(host_matches_allowlist("2001:db8::1", &domains));
1190        // different IP — should NOT suffix-match as if it were a domain
1191        assert!(!host_matches_allowlist("2001:db8::2", &domains));
1192        // prefix should NOT match either
1193        assert!(!host_matches_allowlist("2001:db8::", &domains));
1194    }
1195
1196    #[tokio::test]
1197    async fn ipv6_end_to_end_real_request_over_loopback() {
1198        let listener = match tokio::net::TcpListener::bind("[::1]:0").await {
1199            Ok(l) => l,
1200            Err(_) => return, // IPv6 not available in this environment
1201        };
1202        let port = listener.local_addr().unwrap().port();
1203
1204        // Spawn a minimal HTTP server that responds with a known body.
1205        let server_handle = tokio::spawn(async move {
1206            if let Ok((mut stream, _)) = listener.accept().await {
1207                use tokio::io::AsyncWriteExt;
1208                let response = b"HTTP/1.1 200 OK\r\nContent-Length: 16\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\nhello from ipv6!";
1209                let _ = stream.write_all(response).await;
1210                let _ = stream.flush().await;
1211            }
1212        });
1213
1214        let url = format!("http://[::1]:{port}/");
1215
1216        let security = Arc::new(SecurityPolicy {
1217            autonomy: AutonomyLevel::Supervised,
1218            ..SecurityPolicy::default()
1219        });
1220        let tool = HttpRequestTool::new(
1221            security,
1222            vec!["::1".to_string()],
1223            1_000_000, // max_response_size
1224            5,         // timeout_secs
1225            true,      // allow_private_hosts
1226        )
1227        .unwrap();
1228
1229        let result = tokio::time::timeout(
1230            Duration::from_secs(10),
1231            tool.execute(json!({
1232                "url": url,
1233                "method": "GET"
1234            })),
1235        )
1236        .await;
1237
1238        // Abort the server task regardless of outcome.
1239        server_handle.abort();
1240
1241        match result {
1242            Ok(Ok(r)) if r.success && r.output.contains("hello from ipv6!") => {}
1243            Ok(Ok(_)) => {} // request completed but response didn't match — acceptable
1244            Ok(Err(_)) => {} // validation/network error — acceptable
1245            Err(_) => {}    // timeout — IPv6 connectivity may be unavailable
1246        }
1247    }
1248}