Skip to main content

zeroclaw_tools/
web_fetch.rs

1use async_trait::async_trait;
2use futures_util::StreamExt;
3use serde_json::json;
4use std::sync::Arc;
5use std::time::Duration;
6use zeroclaw_api::tool::{Tool, ToolResult};
7use zeroclaw_config::policy::SecurityPolicy;
8use zeroclaw_config::schema::FirecrawlConfig;
9
/// Byte length (after trimming whitespace) below which a *successful* standard
/// fetch is considered too thin to be useful — typically a JavaScript-rendered
/// shell page — and is retried through Firecrawl when Firecrawl is enabled.
const FIRECRAWL_MIN_BODY_LEN: usize = 100;
13
14/// Web fetch tool: fetches a web page and converts HTML to plain text for LLM consumption.
15///
16/// Unlike `http_request` (an API client returning raw responses), this tool:
17/// - Only supports GET
18/// - Follows redirects (up to 10)
19/// - Converts HTML to clean plain text via `nanohtml2text`
20/// - Passes through text/plain, text/markdown, and application/json as-is
21/// - Sets a descriptive User-Agent
22/// - Falls back to Firecrawl API when standard fetch fails (if enabled)
23pub struct WebFetchTool {
24    security: Arc<SecurityPolicy>,
25    allowed_domains: Vec<String>,
26    blocked_domains: Vec<String>,
27    allowed_private_hosts: Vec<String>,
28    max_response_size: usize,
29    timeout_secs: u64,
30    firecrawl: FirecrawlConfig,
31}
32
33impl WebFetchTool {
34    pub fn new(
35        security: Arc<SecurityPolicy>,
36        allowed_domains: Vec<String>,
37        blocked_domains: Vec<String>,
38        max_response_size: usize,
39        timeout_secs: u64,
40        firecrawl: FirecrawlConfig,
41        allowed_private_hosts: Vec<String>,
42    ) -> anyhow::Result<Self> {
43        Ok(Self {
44            security,
45            allowed_domains: normalize_allowed_domains(
46                allowed_domains,
47                "web_fetch.allowed_domains",
48            )?,
49            blocked_domains: normalize_allowed_domains(
50                blocked_domains,
51                "web_fetch.blocked_domains",
52            )?,
53            allowed_private_hosts: normalize_allowed_domains(
54                allowed_private_hosts,
55                "web_fetch.allowed_private_hosts",
56            )?,
57            max_response_size,
58            timeout_secs,
59            firecrawl,
60        })
61    }
62
63    fn validate_url(&self, raw_url: &str) -> anyhow::Result<String> {
64        validate_target_url(
65            raw_url,
66            &self.allowed_domains,
67            &self.blocked_domains,
68            &self.allowed_private_hosts,
69            "web_fetch",
70        )
71    }
72
73    fn truncate_response(&self, text: &str) -> String {
74        // max_response_size == 0 means "unlimited" (matches the
75        // http_request tool's documented semantics + tests at
76        // crates/zeroclaw-tools/src/http_request.rs:151). Without this
77        // branch, the unsigned-arithmetic path below would truncate
78        // every response to zero bytes, then append the truncation
79        // marker — useless content + spurious Firecrawl fallback.
80        if self.max_response_size == 0 {
81            return text.to_string();
82        }
83        if text.len() > self.max_response_size {
84            let mut truncated = text
85                .chars()
86                .take(self.max_response_size)
87                .collect::<String>();
88            truncated.push_str("\n\n... [Response truncated due to size limit] ...");
89            truncated
90        } else {
91            text.to_string()
92        }
93    }
94
95    async fn read_response_text_limited(
96        &self,
97        response: reqwest::Response,
98    ) -> anyhow::Result<String> {
99        let mut bytes_stream = response.bytes_stream();
100        // max_response_size == 0 → unlimited. Without this branch, the
101        // existing saturating_add(1) made hard_cap = 1 byte, so the
102        // entire stream was truncated after one byte. Use usize::MAX as
103        // the effective hard_cap when unlimited so append_chunk_with_cap
104        // never stops early on size grounds.
105        let hard_cap = if self.max_response_size == 0 {
106            usize::MAX
107        } else {
108            self.max_response_size.saturating_add(1)
109        };
110        let mut bytes = Vec::new();
111
112        while let Some(chunk_result) = bytes_stream.next().await {
113            let chunk = chunk_result?;
114            if append_chunk_with_cap(&mut bytes, &chunk, hard_cap) {
115                break;
116            }
117        }
118
119        Ok(String::from_utf8_lossy(&bytes).into_owned())
120    }
121
122    /// Whether the standard fetch result should trigger a Firecrawl fallback.
123    fn should_fallback_to_firecrawl(&self, result: &ToolResult) -> bool {
124        if !self.firecrawl.enabled {
125            return false;
126        }
127        // Fallback on failure (HTTP error, network error, etc.)
128        if !result.success {
129            return true;
130        }
131        // Fallback on empty or very short body (JS-only pages)
132        if result.output.trim().len() < FIRECRAWL_MIN_BODY_LEN {
133            return true;
134        }
135        false
136    }
137
138    /// Fetch content via the Firecrawl API.
139    async fn fetch_via_firecrawl(&self, url: &str) -> anyhow::Result<ToolResult> {
140        let api_key = std::env::var(&self.firecrawl.api_key_env).map_err(|_| {
141            ::zeroclaw_log::record!(
142                ERROR,
143                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
144                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
145                    .with_attrs(::serde_json::json!({
146                        "env_var": &self.firecrawl.api_key_env,
147                    })),
148                "web_fetch: Firecrawl API key missing from env"
149            );
150            anyhow::Error::msg(format!(
151                "Firecrawl API key not found in environment variable '{}'",
152                self.firecrawl.api_key_env
153            ))
154        })?;
155
156        let endpoint = format!("{}/scrape", self.firecrawl.api_url.trim_end_matches('/'));
157
158        let client = reqwest::Client::builder()
159            .timeout(Duration::from_secs(60))
160            .build()
161            .map_err(|e| {
162                ::zeroclaw_log::record!(
163                    ERROR,
164                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
165                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
166                        .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
167                    "web_fetch: failed to build Firecrawl HTTP client"
168                );
169                anyhow::Error::msg(format!("Failed to build Firecrawl HTTP client: {e}"))
170            })?;
171
172        let body = json!({
173            "url": url,
174            "formats": ["markdown"]
175        });
176
177        let response = client
178            .post(&endpoint)
179            .header("Authorization", format!("Bearer {api_key}"))
180            .header("Content-Type", "application/json")
181            .json(&body)
182            .send()
183            .await
184            .map_err(|e| {
185                ::zeroclaw_log::record!(
186                    ERROR,
187                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
188                        .with_outcome(::zeroclaw_log::EventOutcome::Failure)
189                        .with_attrs(::serde_json::json!({
190                            "phase": "firecrawl_request",
191                            "error": format!("{}", e),
192                        })),
193                    "web_fetch: Firecrawl request failed"
194                );
195                anyhow::Error::msg(format!("Firecrawl request failed: {e}"))
196            })?;
197
198        let status = response.status();
199        if !status.is_success() {
200            let error_body = response.text().await.unwrap_or_default();
201            return Ok(ToolResult {
202                success: false,
203                output: String::new(),
204                error: Some(format!(
205                    "Firecrawl API error: HTTP {} - {}",
206                    status.as_u16(),
207                    error_body
208                )),
209            });
210        }
211
212        let resp_json: serde_json::Value = response.json().await.map_err(|e| {
213            ::zeroclaw_log::record!(
214                ERROR,
215                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
216                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
217                    .with_attrs(::serde_json::json!({
218                        "phase": "firecrawl_response_parse",
219                        "error": format!("{}", e),
220                    })),
221                "web_fetch: failed to parse Firecrawl response"
222            );
223            anyhow::Error::msg(format!("Failed to parse Firecrawl response: {e}"))
224        })?;
225
226        let markdown = resp_json
227            .get("data")
228            .and_then(|d| d.get("markdown"))
229            .and_then(|m| m.as_str())
230            .unwrap_or("");
231
232        if markdown.is_empty() {
233            return Ok(ToolResult {
234                success: false,
235                output: String::new(),
236                error: Some("Firecrawl returned empty markdown content".into()),
237            });
238        }
239
240        let output = self.truncate_response(markdown);
241
242        Ok(ToolResult {
243            success: true,
244            output,
245            error: None,
246        })
247    }
248
249    /// Perform the standard HTTP GET fetch and convert to text.
250    async fn standard_fetch(&self, client: &reqwest::Client, url: &str) -> ToolResult {
251        let response = match client.get(url).send().await {
252            Ok(r) => r,
253            Err(e) => {
254                return ToolResult {
255                    success: false,
256                    output: String::new(),
257                    error: Some(format!("HTTP request failed: {e}")),
258                };
259            }
260        };
261
262        let status = response.status();
263        if !status.is_success() {
264            return ToolResult {
265                success: false,
266                output: String::new(),
267                error: Some(format!(
268                    "HTTP {} {}",
269                    status.as_u16(),
270                    status.canonical_reason().unwrap_or("Unknown")
271                )),
272            };
273        }
274
275        // Determine content type for processing strategy
276        let content_type = response
277            .headers()
278            .get(reqwest::header::CONTENT_TYPE)
279            .and_then(|v| v.to_str().ok())
280            .unwrap_or("")
281            .to_lowercase();
282
283        let body_mode = if content_type.contains("text/html") || content_type.is_empty() {
284            "html"
285        } else if content_type.contains("text/plain")
286            || content_type.contains("text/markdown")
287            || content_type.contains("application/json")
288        {
289            "plain"
290        } else {
291            return ToolResult {
292                success: false,
293                output: String::new(),
294                error: Some(format!(
295                    "Unsupported content type: {content_type}. \
296                     web_fetch supports text/html, text/plain, text/markdown, and application/json."
297                )),
298            };
299        };
300
301        let body = match self.read_response_text_limited(response).await {
302            Ok(t) => t,
303            Err(e) => {
304                return ToolResult {
305                    success: false,
306                    output: String::new(),
307                    error: Some(format!("Failed to read response body: {e}")),
308                };
309            }
310        };
311
312        let text = if body_mode == "html" {
313            nanohtml2text::html2text(&body)
314        } else {
315            body
316        };
317
318        let output = self.truncate_response(&text);
319
320        ToolResult {
321            success: true,
322            output,
323            error: None,
324        }
325    }
326}
327
328#[async_trait]
329impl Tool for WebFetchTool {
330    fn name(&self) -> &str {
331        "web_fetch"
332    }
333
334    fn description(&self) -> &str {
335        "Fetch a web page and return its content as clean plain text. \
336         HTML pages are automatically converted to readable text. \
337         JSON and plain text responses are returned as-is. \
338         Only GET requests; follows redirects. \
339         Falls back to Firecrawl for JS-heavy/bot-blocked sites (if enabled). \
340         Security: allowlist-only domains, no local/private hosts."
341    }
342
343    fn parameters_schema(&self) -> serde_json::Value {
344        json!({
345            "type": "object",
346            "properties": {
347                "url": {
348                    "type": "string",
349                    "description": "The HTTP or HTTPS URL to fetch"
350                }
351            },
352            "required": ["url"]
353        })
354    }
355
356    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
357        let url = args.get("url").and_then(|v| v.as_str()).ok_or_else(|| {
358            ::zeroclaw_log::record!(
359                WARN,
360                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
361                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
362                    .with_attrs(::serde_json::json!({"param": "url"})),
363                "web_fetch: missing url parameter"
364            );
365            anyhow::Error::msg("Missing 'url' parameter")
366        })?;
367
368        if !self.security.can_act() {
369            return Ok(ToolResult {
370                success: false,
371                output: String::new(),
372                error: Some("Action blocked: autonomy is read-only".into()),
373            });
374        }
375
376        // Rate limiting is applied by the RateLimitedTool wrapper at
377        // registration time (see zeroclaw-runtime::tools::mod).
378
379        let url = match self.validate_url(url) {
380            Ok(v) => v,
381            Err(e) => {
382                return Ok(ToolResult {
383                    success: false,
384                    output: String::new(),
385                    error: Some(e.to_string()),
386                });
387            }
388        };
389
390        // Build client: follow redirects, set timeout, set User-Agent
391        let timeout_secs = if self.timeout_secs == 0 {
392            ::zeroclaw_log::record!(
393                WARN,
394                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
395                    .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
396                "web_fetch: timeout_secs is 0, using safe default of 30s"
397            );
398            30
399        } else {
400            self.timeout_secs
401        };
402
403        let allowed_domains = self.allowed_domains.clone();
404        let blocked_domains = self.blocked_domains.clone();
405        let allowed_private_hosts = self.allowed_private_hosts.clone();
406        let redirect_policy = reqwest::redirect::Policy::custom(move |attempt| {
407            if attempt.previous().len() >= 10 {
408                return attempt.error(std::io::Error::other("Too many redirects (max 10)"));
409            }
410
411            if let Err(err) = validate_target_url(
412                attempt.url().as_str(),
413                &allowed_domains,
414                &blocked_domains,
415                &allowed_private_hosts,
416                "web_fetch",
417            ) {
418                return attempt.error(std::io::Error::new(
419                    std::io::ErrorKind::PermissionDenied,
420                    format!("Blocked redirect target: {err}"),
421                ));
422            }
423
424            attempt.follow()
425        });
426
427        let builder = reqwest::Client::builder()
428            .timeout(Duration::from_secs(timeout_secs))
429            .connect_timeout(Duration::from_secs(10))
430            .redirect(redirect_policy)
431            .user_agent("ZeroClaw/0.1 (web_fetch)");
432        let builder =
433            zeroclaw_config::schema::apply_runtime_proxy_to_builder(builder, "tool.web_fetch");
434        let client = match builder.build() {
435            Ok(c) => c,
436            Err(e) => {
437                return Ok(ToolResult {
438                    success: false,
439                    output: String::new(),
440                    error: Some(format!("Failed to build HTTP client: {e}")),
441                });
442            }
443        };
444
445        let standard_result = self.standard_fetch(&client, &url).await;
446
447        // If standard fetch succeeded well enough, return it directly.
448        // Otherwise, try Firecrawl fallback if enabled.
449        if self.should_fallback_to_firecrawl(&standard_result) {
450            ::zeroclaw_log::record!(
451                INFO,
452                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
453                    .with_attrs(::serde_json::json!({"url": url})),
454                "web_fetch: standard fetch insufficient for , attempting Firecrawl fallback"
455            );
456            match Box::pin(self.fetch_via_firecrawl(&url)).await {
457                Ok(firecrawl_result) if firecrawl_result.success => {
458                    return Ok(firecrawl_result);
459                }
460                Ok(firecrawl_result) => {
461                    ::zeroclaw_log::record!(
462                        WARN,
463                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
464                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
465                        &format!(
466                            "web_fetch: Firecrawl fallback also failed: {:?}",
467                            firecrawl_result.error
468                        )
469                    );
470                    // Return original standard result if Firecrawl also failed
471                }
472                Err(e) => {
473                    ::zeroclaw_log::record!(
474                        WARN,
475                        ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
476                            .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
477                            .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
478                        "web_fetch: Firecrawl fallback error"
479                    );
480                }
481            }
482        }
483
484        Ok(standard_result)
485    }
486}
487
488// ── Helper functions (independent from http_request.rs per DRY rule-of-three) ──
489
490fn validate_target_url(
491    raw_url: &str,
492    allowed_domains: &[String],
493    blocked_domains: &[String],
494    allowed_private_hosts: &[String],
495    tool_name: &str,
496) -> anyhow::Result<String> {
497    validate_target_url_with_dns_check(
498        raw_url,
499        allowed_domains,
500        blocked_domains,
501        allowed_private_hosts,
502        tool_name,
503        validate_resolved_host_is_public,
504    )
505}
506
507fn validate_target_url_with_dns_check(
508    raw_url: &str,
509    allowed_domains: &[String],
510    blocked_domains: &[String],
511    allowed_private_hosts: &[String],
512    tool_name: &str,
513    validate_dns: impl FnOnce(&str) -> anyhow::Result<()>,
514) -> anyhow::Result<String> {
515    let url = raw_url.trim();
516
517    if url.is_empty() {
518        anyhow::bail!("URL cannot be empty");
519    }
520
521    if url.chars().any(char::is_whitespace) {
522        anyhow::bail!("URL cannot contain whitespace");
523    }
524
525    if !url.starts_with("http://") && !url.starts_with("https://") {
526        anyhow::bail!("Only http:// and https:// URLs are allowed");
527    }
528
529    if allowed_domains.is_empty() {
530        anyhow::bail!(
531            "{tool_name} tool is enabled but no allowed_domains are configured. \
532             Add [{tool_name}].allowed_domains in config.toml"
533        );
534    }
535
536    let host = extract_host(url)?;
537
538    // blocked_domains always takes precedence
539    if host_matches_allowlist(&host, blocked_domains) {
540        anyhow::bail!("Host '{host}' is in {tool_name}.blocked_domains");
541    }
542
543    let host_is_private_or_local = is_private_or_local_host(&host);
544    let private_host_allowed =
545        host_matches_private_allowlist(&host, allowed_private_hosts, host_is_private_or_local);
546
547    if host_is_private_or_local && !private_host_allowed {
548        anyhow::bail!(
549            "Blocked local/private host: {host}. \
550             To allow this host, add it to {tool_name}.allowed_private_hosts in config.toml"
551        );
552    }
553
554    if private_host_allowed {
555        ::zeroclaw_log::record!(
556            WARN,
557            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
558                .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
559                .with_attrs(::serde_json::json!({"tool_name": tool_name, "host": host})),
560            "web_fetch: allowing host via allowed_private_hosts"
561        );
562    }
563
564    if !private_host_allowed && !host_matches_allowlist(&host, allowed_domains) {
565        anyhow::bail!("Host '{host}' is not in {tool_name}.allowed_domains");
566    }
567
568    if !private_host_allowed {
569        validate_dns(&host)?;
570    }
571
572    Ok(url.to_string())
573}
574
/// Appends as much of `chunk` to `buffer` as fits under `hard_cap` bytes.
///
/// Returns `true` once the buffer has reached the cap (the caller should stop
/// reading), including when it was already full and nothing was appended.
fn append_chunk_with_cap(buffer: &mut Vec<u8>, chunk: &[u8], hard_cap: usize) -> bool {
    let Some(room) = hard_cap.checked_sub(buffer.len()).filter(|&r| r > 0) else {
        return true;
    };
    let take = chunk.len().min(room);
    buffer.extend_from_slice(&chunk[..take]);
    buffer.len() >= hard_cap
}
589
590fn normalize_allowed_domains(domains: Vec<String>, label: &str) -> anyhow::Result<Vec<String>> {
591    let mut rejected = Vec::new();
592    let mut normalized = domains
593        .into_iter()
594        .filter_map(|d| {
595            normalize_domain(&d).or_else(|| {
596                rejected.push(d.clone());
597                None
598            })
599        })
600        .collect::<Vec<_>>();
601    if !rejected.is_empty() {
602        anyhow::bail!(
603            "Invalid {label} entry(s): [{}]. Each entry must be a valid domain, hostname, IPv4, or IPv6 address.",
604            rejected.join(", ")
605        );
606    }
607    normalized.sort_unstable();
608    normalized.dedup();
609    Ok(normalized)
610}
611
612fn normalize_domain(raw: &str) -> Option<String> {
613    let input = raw.trim();
614    if input.is_empty() || input.chars().any(char::is_whitespace) {
615        return None;
616    }
617
618    let bare_ip = match (input.starts_with('['), input.ends_with(']')) {
619        (true, true) => &input[1..input.len() - 1],
620        (false, false) => input,
621        _ => return None,
622    };
623    if let Ok(ip) = bare_ip.parse::<std::net::IpAddr>() {
624        return Some(ip.to_string().to_lowercase());
625    }
626
627    let parsed = reqwest::Url::parse(input)
628        .or_else(|_| reqwest::Url::parse(&format!("https://{input}")))
629        .ok()?;
630
631    if !parsed.username().is_empty() || parsed.password().is_some() {
632        return None;
633    }
634
635    let host = parsed.host_str()?;
636    let trimmed = host.trim();
637    let host_no_brackets = match (trimmed.starts_with('['), trimmed.ends_with(']')) {
638        (true, true) => &trimmed[1..trimmed.len() - 1],
639        (false, false) => trimmed,
640        _ => return None,
641    };
642    let normalized = host_no_brackets
643        .trim_start_matches('.')
644        .trim_end_matches('.');
645    if normalized.is_empty() {
646        return None;
647    }
648
649    Some(normalized.to_lowercase())
650}
651
652fn extract_host(url: &str) -> anyhow::Result<String> {
653    let rest = url
654        .strip_prefix("http://")
655        .or_else(|| url.strip_prefix("https://"))
656        .ok_or_else(|| {
657            ::zeroclaw_log::record!(
658                WARN,
659                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
660                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
661                    .with_attrs(::serde_json::json!({"url": url})),
662                "web_fetch: non-http(s) URL rejected"
663            );
664            anyhow::Error::msg("Only http:// and https:// URLs are allowed")
665        })?;
666
667    let authority = rest.split(['/', '?', '#']).next().ok_or_else(|| {
668        ::zeroclaw_log::record!(
669            WARN,
670            ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
671                .with_outcome(::zeroclaw_log::EventOutcome::Failure)
672                .with_attrs(::serde_json::json!({"url": url})),
673            "web_fetch: invalid URL"
674        );
675        anyhow::Error::msg("Invalid URL")
676    })?;
677
678    if authority.is_empty() {
679        anyhow::bail!("URL must include a host");
680    }
681
682    if authority.contains('@') {
683        anyhow::bail!("URL userinfo is not allowed");
684    }
685
686    if authority.starts_with('[') {
687        anyhow::bail!("IPv6 hosts are not supported in web_fetch");
688    }
689
690    let host = authority
691        .split(':')
692        .next()
693        .unwrap_or_default()
694        .trim()
695        .trim_end_matches('.')
696        .to_lowercase();
697
698    if host.is_empty() {
699        anyhow::bail!("URL must include a valid host");
700    }
701
702    Ok(host)
703}
704
/// Returns `true` if `host` equals, or is a subdomain of, any entry in
/// `allowed_domains`; a `*` entry matches every host.
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
    let wildcard = allowed_domains.iter().any(|d| d == "*");
    wildcard
        || allowed_domains
            .iter()
            .any(|domain| match host.strip_suffix(domain.as_str()) {
                // Exact match.
                Some("") => true,
                // Subdomain only if the remaining prefix ends on a label boundary.
                Some(prefix) => prefix.ends_with('.'),
                None => false,
            })
}
717
/// Returns `true` if `host` is admitted by `allowed_private_hosts`.
///
/// A `*` entry admits the host only when `host_is_private_or_local` is set;
/// any other entry matches the host exactly or as a parent domain.
fn host_matches_private_allowlist(
    host: &str,
    allowed_private_hosts: &[String],
    host_is_private_or_local: bool,
) -> bool {
    allowed_private_hosts
        .iter()
        .any(|entry| match entry.as_str() {
            "*" => host_is_private_or_local,
            domain => match host.strip_suffix(domain) {
                Some("") => true,
                Some(prefix) => prefix.ends_with('.'),
                None => false,
            },
        })
}
734
735fn is_private_or_local_host(host: &str) -> bool {
736    let bare = host
737        .strip_prefix('[')
738        .and_then(|h| h.strip_suffix(']'))
739        .unwrap_or(host);
740
741    let has_local_tld = bare
742        .rsplit('.')
743        .next()
744        .is_some_and(|label| label == "local");
745
746    if bare == "localhost" || bare.ends_with(".localhost") || has_local_tld {
747        return true;
748    }
749
750    if let Ok(ip) = bare.parse::<std::net::IpAddr>() {
751        return match ip {
752            std::net::IpAddr::V4(v4) => is_non_global_v4(v4),
753            std::net::IpAddr::V6(v6) => is_non_global_v6(v6),
754        };
755    }
756
757    false
758}
759
760#[cfg(not(test))]
761fn validate_resolved_host_is_public(host: &str) -> anyhow::Result<()> {
762    use std::net::ToSocketAddrs;
763
764    let ips = (host, 0)
765        .to_socket_addrs()
766        .map_err(|e| {
767            ::zeroclaw_log::record!(
768                ERROR,
769                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
770                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
771                    .with_attrs(::serde_json::json!({
772                        "host": host,
773                        "error": format!("{}", e),
774                    })),
775                "web_fetch: failed to resolve host"
776            );
777            anyhow::Error::msg(format!("Failed to resolve host '{host}': {e}"))
778        })?
779        .map(|addr| addr.ip())
780        .collect::<Vec<_>>();
781
782    validate_resolved_ips_are_public(host, &ips)
783}
784
/// Test stand-in: unit tests must not touch the network, so DNS always
/// "passes". Address classification is exercised directly through
/// `validate_resolved_ips_are_public`.
#[cfg(test)]
fn validate_resolved_host_is_public(_host: &str) -> anyhow::Result<()> {
    Ok(())
}
790
791fn validate_resolved_ips_are_public(host: &str, ips: &[std::net::IpAddr]) -> anyhow::Result<()> {
792    if ips.is_empty() {
793        anyhow::bail!("Failed to resolve host '{host}'");
794    }
795
796    for ip in ips {
797        let non_global = match ip {
798            std::net::IpAddr::V4(v4) => is_non_global_v4(*v4),
799            std::net::IpAddr::V6(v6) => is_non_global_v6(*v6),
800        };
801        if non_global {
802            anyhow::bail!("Blocked host '{host}' resolved to non-global address {ip}");
803        }
804    }
805
806    Ok(())
807}
808
/// Returns `true` if `v4` is not a globally routable unicast address.
///
/// Covers: loopback, RFC 1918 private, link-local, `0.0.0.0/8` ("this
/// network"; connecting to `0.x` can reach the local host), broadcast,
/// multicast, RFC 5737 documentation blocks (`192.0.2/24`, `198.51.100/24`,
/// `203.0.113/24`), shared/CGNAT `100.64/10`, reserved `240/4`, IETF protocol
/// assignments `192.0.0/24`, and benchmarking `198.18/15`.
///
/// The documentation ranges are matched exactly (`/24`s). Earlier versions
/// blocked all of `198.51.0.0/16` and `203.0.0.0/16`, which includes real
/// public address space.
fn is_non_global_v4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, c, _] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_unspecified()
        || v4.is_broadcast()
        || v4.is_multicast()
        || v4.is_documentation()
        || a == 0
        || (a == 100 && (64..=127).contains(&b))
        || a >= 240
        || (a == 192 && b == 0 && c == 0)
        || (a == 198 && (18..=19).contains(&b))
}
824
825fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool {
826    let segs = v6.segments();
827    v6.is_loopback()
828        || v6.is_unspecified()
829        || v6.is_multicast()
830        || (segs[0] & 0xfe00) == 0xfc00
831        || (segs[0] & 0xffc0) == 0xfe80
832        || (segs[0] == 0x2001 && segs[1] == 0x0db8)
833        || v6.to_ipv4_mapped().is_some_and(is_non_global_v4)
834}
835
836#[cfg(test)]
837mod tests {
838    use super::*;
839    use zeroclaw_config::autonomy::AutonomyLevel;
840    use zeroclaw_config::policy::SecurityPolicy;
841    use zeroclaw_config::schema::FirecrawlConfig;
842
843    fn test_tool(allowed_domains: Vec<&str>) -> WebFetchTool {
844        test_tool_with_blocklist(allowed_domains, vec![])
845    }
846
847    fn test_tool_with_blocklist(
848        allowed_domains: Vec<&str>,
849        blocked_domains: Vec<&str>,
850    ) -> WebFetchTool {
851        let security = Arc::new(SecurityPolicy {
852            autonomy: AutonomyLevel::Supervised,
853            ..SecurityPolicy::default()
854        });
855        WebFetchTool::new(
856            security,
857            allowed_domains.into_iter().map(String::from).collect(),
858            blocked_domains.into_iter().map(String::from).collect(),
859            500_000,
860            30,
861            FirecrawlConfig::default(),
862            vec![],
863        )
864        .unwrap()
865    }
866
867    fn test_tool_with_private_hosts(
868        allowed_domains: Vec<&str>,
869        blocked_domains: Vec<&str>,
870        allowed_private_hosts: Vec<&str>,
871    ) -> WebFetchTool {
872        let security = Arc::new(SecurityPolicy {
873            autonomy: AutonomyLevel::Supervised,
874            ..SecurityPolicy::default()
875        });
876        WebFetchTool::new(
877            security,
878            allowed_domains.into_iter().map(String::from).collect(),
879            blocked_domains.into_iter().map(String::from).collect(),
880            500_000,
881            30,
882            FirecrawlConfig::default(),
883            allowed_private_hosts
884                .into_iter()
885                .map(String::from)
886                .collect(),
887        )
888        .unwrap()
889    }
890
891    fn test_tool_with_firecrawl(firecrawl: FirecrawlConfig) -> WebFetchTool {
892        let security = Arc::new(SecurityPolicy {
893            autonomy: AutonomyLevel::Supervised,
894            ..SecurityPolicy::default()
895        });
896        WebFetchTool::new(
897            security,
898            vec!["*".into()],
899            vec![],
900            500_000,
901            30,
902            firecrawl,
903            vec![],
904        )
905        .unwrap()
906    }
907
908    // ── Name and schema ──────────────────────────────────────────
909
910    #[test]
911    fn name_is_web_fetch() {
912        let tool = test_tool(vec!["example.com"]);
913        assert_eq!(tool.name(), "web_fetch");
914    }
915
916    #[test]
917    fn parameters_schema_requires_url() {
918        let tool = test_tool(vec!["example.com"]);
919        let schema = tool.parameters_schema();
920        assert!(schema["properties"]["url"].is_object());
921        let required = schema["required"].as_array().unwrap();
922        assert!(required.iter().any(|v| v.as_str() == Some("url")));
923    }
924
925    // ── HTML to text conversion ──────────────────────────────────
926
927    #[test]
928    fn html_to_text_conversion() {
929        let html = "<html><body><h1>Title</h1><p>Hello <b>world</b></p></body></html>";
930        let text = nanohtml2text::html2text(html);
931        assert!(text.contains("Title"));
932        assert!(text.contains("Hello"));
933        assert!(text.contains("world"));
934        assert!(!text.contains("<h1>"));
935        assert!(!text.contains("<p>"));
936    }
937
938    // ── URL validation ───────────────────────────────────────────
939
940    #[test]
941    fn validate_accepts_exact_domain() {
942        let tool = test_tool(vec!["example.com"]);
943        let got = tool.validate_url("https://example.com/page").unwrap();
944        assert_eq!(got, "https://example.com/page");
945    }
946
947    #[test]
948    fn validate_accepts_subdomain() {
949        let tool = test_tool(vec!["example.com"]);
950        assert!(tool.validate_url("https://docs.example.com/guide").is_ok());
951    }
952
953    #[test]
954    fn validate_accepts_wildcard() {
955        let tool = test_tool(vec!["*"]);
956        assert!(tool.validate_url("https://news.ycombinator.com").is_ok());
957    }
958
959    #[test]
960    fn validate_rejects_empty_url() {
961        let tool = test_tool(vec!["example.com"]);
962        let err = tool.validate_url("").unwrap_err().to_string();
963        assert!(err.contains("empty"));
964    }
965
966    #[test]
967    fn validate_rejects_missing_url() {
968        let tool = test_tool(vec!["example.com"]);
969        let err = tool.validate_url("  ").unwrap_err().to_string();
970        assert!(err.contains("empty"));
971    }
972
973    #[test]
974    fn validate_rejects_ftp_scheme() {
975        let tool = test_tool(vec!["example.com"]);
976        let err = tool
977            .validate_url("ftp://example.com")
978            .unwrap_err()
979            .to_string();
980        assert!(err.contains("http://") || err.contains("https://"));
981    }
982
983    #[test]
984    fn validate_rejects_allowlist_miss() {
985        let tool = test_tool(vec!["example.com"]);
986        let err = tool
987            .validate_url("https://google.com")
988            .unwrap_err()
989            .to_string();
990        assert!(err.contains("allowed_domains"));
991    }
992
993    #[test]
994    fn validate_requires_allowlist() {
995        let security = Arc::new(SecurityPolicy::default());
996        let tool = WebFetchTool::new(
997            security,
998            vec![],
999            vec![],
1000            500_000,
1001            30,
1002            FirecrawlConfig::default(),
1003            vec![],
1004        )
1005        .unwrap();
1006        let err = tool
1007            .validate_url("https://example.com")
1008            .unwrap_err()
1009            .to_string();
1010        assert!(err.contains("allowed_domains"));
1011    }
1012
1013    // ── SSRF protection ──────────────────────────────────────────
1014
1015    #[test]
1016    fn ssrf_blocks_localhost() {
1017        let tool = test_tool(vec!["localhost"]);
1018        let err = tool
1019            .validate_url("https://localhost:8080")
1020            .unwrap_err()
1021            .to_string();
1022        assert!(err.contains("local/private"));
1023    }
1024
1025    #[test]
1026    fn ssrf_blocks_private_ipv4() {
1027        let tool = test_tool(vec!["192.168.1.5"]);
1028        let err = tool
1029            .validate_url("https://192.168.1.5")
1030            .unwrap_err()
1031            .to_string();
1032        assert!(err.contains("local/private"));
1033    }
1034
1035    #[test]
1036    fn ssrf_blocks_loopback() {
1037        assert!(is_private_or_local_host("127.0.0.1"));
1038        assert!(is_private_or_local_host("127.0.0.2"));
1039    }
1040
1041    #[test]
1042    fn ssrf_blocks_rfc1918() {
1043        assert!(is_private_or_local_host("10.0.0.1"));
1044        assert!(is_private_or_local_host("172.16.0.1"));
1045        assert!(is_private_or_local_host("192.168.1.1"));
1046    }
1047
1048    #[test]
1049    fn ssrf_wildcard_still_blocks_private() {
1050        let tool = test_tool(vec!["*"]);
1051        let err = tool
1052            .validate_url("https://localhost:8080")
1053            .unwrap_err()
1054            .to_string();
1055        assert!(err.contains("local/private"));
1056    }
1057
1058    #[test]
1059    fn redirect_target_validation_allows_permitted_host() {
1060        let allowed = vec!["example.com".to_string()];
1061        let blocked = vec![];
1062        assert!(
1063            validate_target_url(
1064                "https://docs.example.com/page",
1065                &allowed,
1066                &blocked,
1067                &[],
1068                "web_fetch"
1069            )
1070            .is_ok()
1071        );
1072    }
1073
1074    #[test]
1075    fn redirect_target_validation_blocks_private_host() {
1076        let allowed = vec!["example.com".to_string()];
1077        let blocked = vec![];
1078        let err = validate_target_url(
1079            "https://127.0.0.1/admin",
1080            &allowed,
1081            &blocked,
1082            &[],
1083            "web_fetch",
1084        )
1085        .unwrap_err()
1086        .to_string();
1087        assert!(err.contains("local/private"));
1088    }
1089
1090    #[test]
1091    fn redirect_target_validation_blocks_blocklisted_host() {
1092        let allowed = vec!["*".to_string()];
1093        let blocked = vec!["evil.com".to_string()];
1094        let err = validate_target_url(
1095            "https://evil.com/phish",
1096            &allowed,
1097            &blocked,
1098            &[],
1099            "web_fetch",
1100        )
1101        .unwrap_err()
1102        .to_string();
1103        assert!(err.contains("blocked_domains"));
1104    }
1105
1106    // ── Security policy ──────────────────────────────────────────
1107
1108    #[tokio::test]
1109    async fn blocks_readonly_mode() {
1110        let security = Arc::new(SecurityPolicy {
1111            autonomy: AutonomyLevel::ReadOnly,
1112            ..SecurityPolicy::default()
1113        });
1114        let tool = WebFetchTool::new(
1115            security,
1116            vec!["example.com".into()],
1117            vec![],
1118            500_000,
1119            30,
1120            FirecrawlConfig::default(),
1121            vec![],
1122        )
1123        .unwrap();
1124        let result = tool
1125            .execute(json!({"url": "https://example.com"}))
1126            .await
1127            .unwrap();
1128        assert!(!result.success);
1129        assert!(result.error.unwrap().contains("read-only"));
1130    }
1131
1132    // ── Response truncation ──────────────────────────────────────
1133
1134    #[test]
1135    fn truncate_within_limit() {
1136        let tool = test_tool(vec!["example.com"]);
1137        let text = "hello world";
1138        assert_eq!(tool.truncate_response(text), "hello world");
1139    }
1140
1141    #[test]
1142    fn truncate_response_zero_means_unlimited() {
1143        // max_response_size == 0 must be treated as unlimited — no truncation
1144        // marker, full text returned regardless of length.
1145        let tool = WebFetchTool::new(
1146            Arc::new(SecurityPolicy::default()),
1147            vec!["example.com".into()],
1148            vec![],
1149            0, // unlimited
1150            30,
1151            FirecrawlConfig::default(),
1152            vec![],
1153        )
1154        .unwrap();
1155        let long_text = "x".repeat(10_000);
1156        let result = tool.truncate_response(&long_text);
1157        assert_eq!(result.len(), 10_000, "zero limit must not truncate");
1158        assert!(
1159            !result.contains("[Response truncated"),
1160            "must not append truncation marker"
1161        );
1162    }
1163
1164    /// Drives the actual streamed-read path (standard_fetch +
1165    /// read_response_text_limited) via wiremock to lock in the
1166    /// max_response_size=0 behaviour. Audacity88 review (PR #6884)
1167    /// flagged the direct-helper test as insufficient because it
1168    /// did not exercise the saturating_add(1) cap that previously
1169    /// stopped streaming after 1 byte and triggered spurious
1170    /// Firecrawl fallback.
1171    #[tokio::test]
1172    async fn standard_fetch_with_zero_limit_returns_full_body_and_skips_firecrawl_fallback() {
1173        use wiremock::matchers::method;
1174        use wiremock::{Mock, MockServer, ResponseTemplate};
1175
1176        let server = MockServer::start().await;
1177        let addr = server.address();
1178
1179        // Body must exceed FIRECRAWL_MIN_BODY_LEN (100 bytes) so any
1180        // truncation to <100 bytes would (incorrectly) trigger fallback.
1181        let body = "a".repeat(500);
1182        Mock::given(method("GET"))
1183            .respond_with(ResponseTemplate::new(200).set_body_string(body.clone()))
1184            .mount(&server)
1185            .await;
1186
1187        let tool = WebFetchTool::new(
1188            Arc::new(SecurityPolicy {
1189                autonomy: AutonomyLevel::Supervised,
1190                ..SecurityPolicy::default()
1191            }),
1192            vec!["*".into()],
1193            vec![],
1194            0, // max_response_size = unlimited
1195            30,
1196            FirecrawlConfig {
1197                enabled: true,
1198                ..FirecrawlConfig::default()
1199            },
1200            vec![],
1201        )
1202        .unwrap();
1203
1204        // Bypass SSRF-guarded execute() — call standard_fetch directly so
1205        // wiremock on 127.0.0.1 is reachable.
1206        let url = format!("http://{}:{}/", addr.ip(), addr.port());
1207        let client = reqwest::Client::builder()
1208            .timeout(std::time::Duration::from_secs(5))
1209            .build()
1210            .expect("reqwest client");
1211        let standard_result = tool.standard_fetch(&client, &url).await;
1212
1213        // (a) standard result IS the full body — proves streamed read did
1214        // not stop after 1 byte under the zero-limit path.
1215        assert!(
1216            standard_result.success,
1217            "standard_fetch must succeed, got error={:?}",
1218            standard_result.error
1219        );
1220        assert_eq!(
1221            standard_result.output.len(),
1222            body.len(),
1223            "streamed body length under zero-limit must equal full body"
1224        );
1225        assert_eq!(
1226            standard_result.output, body,
1227            "streamed body content must equal full body"
1228        );
1229        assert!(
1230            !standard_result.output.contains("[Response truncated"),
1231            "must not append truncation marker under zero limit"
1232        );
1233
1234        // (b) result does NOT trip should_fallback_to_firecrawl — proves
1235        // the regression (1-byte short body) is locked out.
1236        assert!(
1237            !tool.should_fallback_to_firecrawl(&standard_result),
1238            "500-byte body under zero limit must not trigger Firecrawl fallback"
1239        );
1240    }
1241
1242    #[test]
1243    fn truncate_over_limit() {
1244        let tool = WebFetchTool::new(
1245            Arc::new(SecurityPolicy::default()),
1246            vec!["example.com".into()],
1247            vec![],
1248            10,
1249            30,
1250            FirecrawlConfig::default(),
1251            vec![],
1252        )
1253        .unwrap();
1254        let text = "hello world this is long";
1255        let truncated = tool.truncate_response(text);
1256        assert!(truncated.contains("[Response truncated"));
1257    }
1258
1259    // ── Domain normalization ─────────────────────────────────────
1260
1261    #[test]
1262    fn normalize_domain_strips_scheme_and_case() {
1263        let got = normalize_domain("  HTTPS://Docs.Example.com/path ").unwrap();
1264        assert_eq!(got, "docs.example.com");
1265    }
1266
1267    #[test]
1268    fn normalize_domain_rejects_userinfo() {
1269        assert!(normalize_domain("https://user@example.com").is_none());
1270        assert!(normalize_domain("user@example.com").is_none());
1271        assert!(normalize_domain("https://user:pass@example.com").is_none());
1272        assert!(normalize_domain("user:pass@example.com").is_none());
1273    }
1274
1275    #[test]
1276    fn normalize_domain_rejects_unmatched_brackets() {
1277        assert!(normalize_domain("[::1").is_none());
1278        assert!(normalize_domain("::1]").is_none());
1279        assert!(normalize_domain("[127.0.0.1").is_none());
1280        assert!(normalize_domain("127.0.0.1]").is_none());
1281    }
1282
1283    #[test]
1284    fn normalize_deduplicates() {
1285        let got = normalize_allowed_domains(
1286            vec![
1287                "example.com".into(),
1288                "EXAMPLE.COM".into(),
1289                "https://example.com/".into(),
1290            ],
1291            "test",
1292        )
1293        .unwrap();
1294        assert_eq!(got, vec!["example.com".to_string()]);
1295    }
1296
1297    // ── Blocked domains ──────────────────────────────────────────
1298
1299    #[test]
1300    fn blocklist_rejects_exact_match() {
1301        let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
1302        let err = tool
1303            .validate_url("https://evil.com/page")
1304            .unwrap_err()
1305            .to_string();
1306        assert!(err.contains("blocked_domains"));
1307    }
1308
1309    #[test]
1310    fn blocklist_rejects_subdomain() {
1311        let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
1312        let err = tool
1313            .validate_url("https://api.evil.com/v1")
1314            .unwrap_err()
1315            .to_string();
1316        assert!(err.contains("blocked_domains"));
1317    }
1318
1319    #[test]
1320    fn blocklist_wins_over_allowlist() {
1321        let tool = test_tool_with_blocklist(vec!["evil.com"], vec!["evil.com"]);
1322        let err = tool
1323            .validate_url("https://evil.com")
1324            .unwrap_err()
1325            .to_string();
1326        assert!(err.contains("blocked_domains"));
1327    }
1328
1329    #[test]
1330    fn blocklist_allows_non_blocked() {
1331        let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
1332        assert!(tool.validate_url("https://example.com").is_ok());
1333    }
1334
1335    #[test]
1336    fn append_chunk_with_cap_truncates_and_stops() {
1337        let mut buffer = Vec::new();
1338        assert!(!append_chunk_with_cap(&mut buffer, b"hello", 8));
1339        assert!(append_chunk_with_cap(&mut buffer, b"world", 8));
1340        assert_eq!(buffer, b"hellowor");
1341    }
1342
1343    #[test]
1344    fn resolved_private_ip_is_rejected() {
1345        let ips = vec!["127.0.0.1".parse().unwrap()];
1346        let err = validate_resolved_ips_are_public("example.com", &ips)
1347            .unwrap_err()
1348            .to_string();
1349        assert!(err.contains("non-global address"));
1350    }
1351
1352    #[test]
1353    fn resolved_mixed_ips_are_rejected() {
1354        let ips = vec![
1355            "93.184.216.34".parse().unwrap(),
1356            "10.0.0.1".parse().unwrap(),
1357        ];
1358        let err = validate_resolved_ips_are_public("example.com", &ips)
1359            .unwrap_err()
1360            .to_string();
1361        assert!(err.contains("non-global address"));
1362    }
1363
1364    #[test]
1365    fn resolved_public_ips_are_allowed() {
1366        let ips = vec!["93.184.216.34".parse().unwrap(), "1.1.1.1".parse().unwrap()];
1367        assert!(validate_resolved_ips_are_public("example.com", &ips).is_ok());
1368    }
1369
1370    // ── Firecrawl config parsing ────────────────────────────────────
1371
1372    #[test]
1373    fn firecrawl_config_defaults() {
1374        let cfg = FirecrawlConfig::default();
1375        assert!(!cfg.enabled);
1376        assert_eq!(cfg.api_key_env, "FIRECRAWL_API_KEY");
1377        assert_eq!(cfg.api_url, "https://api.firecrawl.dev/v1");
1378        assert_eq!(cfg.mode, zeroclaw_config::schema::FirecrawlMode::Scrape);
1379    }
1380
1381    #[test]
1382    fn firecrawl_config_deserializes_from_toml() {
1383        let toml_str = r#"
1384            enabled = true
1385            api_key_env = "MY_FC_KEY"
1386            api_url = "https://custom.firecrawl.io/v2"
1387            mode = "crawl"
1388        "#;
1389        let cfg: FirecrawlConfig = toml::from_str(toml_str).unwrap();
1390        assert!(cfg.enabled);
1391        assert_eq!(cfg.api_key_env, "MY_FC_KEY");
1392        assert_eq!(cfg.api_url, "https://custom.firecrawl.io/v2");
1393        assert_eq!(cfg.mode, zeroclaw_config::schema::FirecrawlMode::Crawl);
1394    }
1395
1396    #[test]
1397    fn firecrawl_config_deserializes_defaults_from_empty_toml() {
1398        let cfg: FirecrawlConfig = toml::from_str("").unwrap();
1399        assert!(!cfg.enabled);
1400        assert_eq!(cfg.api_key_env, "FIRECRAWL_API_KEY");
1401    }
1402
1403    #[test]
1404    fn web_fetch_config_with_firecrawl_section() {
1405        use zeroclaw_config::schema::WebFetchConfig;
1406        let toml_str = r#"
1407            enabled = true
1408            [firecrawl]
1409            enabled = true
1410            api_key_env = "FC_KEY"
1411        "#;
1412        let cfg: WebFetchConfig = toml::from_str(toml_str).unwrap();
1413        assert!(cfg.enabled);
1414        assert!(cfg.firecrawl.enabled);
1415        assert_eq!(cfg.firecrawl.api_key_env, "FC_KEY");
1416    }
1417
1418    // ── Firecrawl fallback trigger conditions ───────────────────────
1419
1420    #[test]
1421    fn fallback_disabled_when_firecrawl_not_enabled() {
1422        let tool = test_tool_with_firecrawl(FirecrawlConfig::default());
1423        let result = ToolResult {
1424            success: false,
1425            output: String::new(),
1426            error: Some("HTTP 403 Forbidden".into()),
1427        };
1428        assert!(!tool.should_fallback_to_firecrawl(&result));
1429    }
1430
1431    #[test]
1432    fn fallback_triggers_on_http_error() {
1433        let tool = test_tool_with_firecrawl(FirecrawlConfig {
1434            enabled: true,
1435            ..FirecrawlConfig::default()
1436        });
1437        let result = ToolResult {
1438            success: false,
1439            output: String::new(),
1440            error: Some("HTTP 403 Forbidden".into()),
1441        };
1442        assert!(tool.should_fallback_to_firecrawl(&result));
1443    }
1444
1445    #[test]
1446    fn fallback_triggers_on_empty_body() {
1447        let tool = test_tool_with_firecrawl(FirecrawlConfig {
1448            enabled: true,
1449            ..FirecrawlConfig::default()
1450        });
1451        let result = ToolResult {
1452            success: true,
1453            output: String::new(),
1454            error: None,
1455        };
1456        assert!(tool.should_fallback_to_firecrawl(&result));
1457    }
1458
1459    #[test]
1460    fn fallback_triggers_on_short_body() {
1461        let tool = test_tool_with_firecrawl(FirecrawlConfig {
1462            enabled: true,
1463            ..FirecrawlConfig::default()
1464        });
1465        let result = ToolResult {
1466            success: true,
1467            output: "Loading...".into(), // < 100 chars, JS-only page
1468            error: None,
1469        };
1470        assert!(tool.should_fallback_to_firecrawl(&result));
1471    }
1472
1473    #[test]
1474    fn fallback_skipped_on_good_response() {
1475        let tool = test_tool_with_firecrawl(FirecrawlConfig {
1476            enabled: true,
1477            ..FirecrawlConfig::default()
1478        });
1479        let result = ToolResult {
1480            success: true,
1481            output: "A".repeat(200), // well above 100 chars
1482            error: None,
1483        };
1484        assert!(!tool.should_fallback_to_firecrawl(&result));
1485    }
1486
1487    // ── Firecrawl response parsing ──────────────────────────────────
1488
1489    #[test]
1490    fn firecrawl_response_parses_markdown() {
1491        let response_json = json!({
1492            "success": true,
1493            "data": {
1494                "markdown": "# Hello World\n\nThis is extracted content from Firecrawl.",
1495                "metadata": {
1496                    "title": "Test Page"
1497                }
1498            }
1499        });
1500        let markdown = response_json
1501            .get("data")
1502            .and_then(|d| d.get("markdown"))
1503            .and_then(|m| m.as_str())
1504            .unwrap_or("");
1505        assert!(markdown.contains("Hello World"));
1506        assert!(markdown.contains("extracted content"));
1507    }
1508
1509    #[test]
1510    fn firecrawl_response_handles_missing_markdown() {
1511        let response_json = json!({
1512            "success": true,
1513            "data": {}
1514        });
1515        let markdown = response_json
1516            .get("data")
1517            .and_then(|d| d.get("markdown"))
1518            .and_then(|m| m.as_str())
1519            .unwrap_or("");
1520        assert!(markdown.is_empty());
1521    }
1522
1523    #[test]
1524    fn firecrawl_response_handles_missing_data() {
1525        let response_json = json!({
1526            "success": false,
1527            "error": "Rate limit exceeded"
1528        });
1529        let markdown = response_json
1530            .get("data")
1531            .and_then(|d| d.get("markdown"))
1532            .and_then(|m| m.as_str())
1533            .unwrap_or("");
1534        assert!(markdown.is_empty());
1535    }
1536
1537    // ── Boundary test: FIRECRAWL_MIN_BODY_LEN (100 chars) ────────────
1538
1539    #[test]
1540    fn fallback_triggers_at_exactly_99_chars() {
1541        let tool = test_tool_with_firecrawl(FirecrawlConfig {
1542            enabled: true,
1543            ..FirecrawlConfig::default()
1544        });
1545        let result = ToolResult {
1546            success: true,
1547            output: "A".repeat(99),
1548            error: None,
1549        };
1550        assert!(
1551            tool.should_fallback_to_firecrawl(&result),
1552            "99-char body (below threshold) should trigger fallback"
1553        );
1554    }
1555
1556    #[test]
1557    fn fallback_skipped_at_exactly_100_chars() {
1558        let tool = test_tool_with_firecrawl(FirecrawlConfig {
1559            enabled: true,
1560            ..FirecrawlConfig::default()
1561        });
1562        let result = ToolResult {
1563            success: true,
1564            output: "A".repeat(100),
1565            error: None,
1566        };
1567        assert!(
1568            !tool.should_fallback_to_firecrawl(&result),
1569            "100-char body (at threshold) should NOT trigger fallback"
1570        );
1571    }
1572
1573    // ── Item 1: missing API key env var falls back gracefully ─────────
1574
1575    #[tokio::test]
1576    async fn firecrawl_missing_api_key_returns_error() {
1577        // Ensure the env var is unset for this test
1578        // SAFETY: test-only, single-threaded test runner.
1579        unsafe { std::env::remove_var("FIRECRAWL_TEST_MISSING_KEY") };
1580
1581        let tool = test_tool_with_firecrawl(FirecrawlConfig {
1582            enabled: true,
1583            api_key_env: "FIRECRAWL_TEST_MISSING_KEY".into(),
1584            ..FirecrawlConfig::default()
1585        });
1586
1587        let result = tool.fetch_via_firecrawl("https://example.com").await;
1588        assert!(
1589            result.is_err(),
1590            "fetch_via_firecrawl should return Err when API key env var is missing"
1591        );
1592        let err_msg = result.unwrap_err().to_string();
1593        assert!(
1594            err_msg.contains("FIRECRAWL_TEST_MISSING_KEY"),
1595            "Error should mention the missing env var name, got: {err_msg}"
1596        );
1597    }
1598
1599    // ── Item 2: double-failure returns original standard result ───────
1600
1601    #[tokio::test]
1602    async fn execute_double_failure_returns_original_result() {
1603        use wiremock::matchers::method;
1604        use wiremock::{Mock, MockServer, ResponseTemplate};
1605
1606        let server = MockServer::start().await;
1607        let addr = server.address();
1608
1609        // Standard fetch returns 403 (failure)
1610        Mock::given(method("GET"))
1611            .respond_with(ResponseTemplate::new(403))
1612            .mount(&server)
1613            .await;
1614
1615        // Ensure Firecrawl API key env is missing so fallback also fails
1616        // SAFETY: test-only, single-threaded test runner.
1617        unsafe { std::env::remove_var("FIRECRAWL_DOUBLE_FAIL_KEY") };
1618
1619        let security = Arc::new(SecurityPolicy {
1620            autonomy: AutonomyLevel::Supervised,
1621            ..SecurityPolicy::default()
1622        });
1623        let tool = WebFetchTool::new(
1624            security,
1625            vec!["*".into()],
1626            vec![],
1627            500_000,
1628            30,
1629            FirecrawlConfig {
1630                enabled: true,
1631                api_key_env: "FIRECRAWL_DOUBLE_FAIL_KEY".into(),
1632                api_url: format!("http://{addr}"),
1633                ..FirecrawlConfig::default()
1634            },
1635            vec![],
1636        )
1637        .unwrap();
1638
1639        // Bypass SSRF-guarded execute() — call standard_fetch + fallback
1640        // logic directly so wiremock on 127.0.0.1 is reachable.
1641        let client = reqwest::Client::builder()
1642            .timeout(Duration::from_secs(30))
1643            .build()
1644            .unwrap();
1645
1646        let url = format!("http://{addr}/page");
1647        let standard_result = tool.standard_fetch(&client, &url).await;
1648
1649        // standard_fetch should fail with 403
1650        assert!(!standard_result.success);
1651        assert!(tool.should_fallback_to_firecrawl(&standard_result));
1652
1653        // Firecrawl fallback should also fail (missing API key)
1654        let firecrawl_result = Box::pin(tool.fetch_via_firecrawl(&url)).await;
1655        assert!(
1656            firecrawl_result.is_err() || !firecrawl_result.as_ref().unwrap().success,
1657            "Expected Firecrawl fallback to fail without API key"
1658        );
1659
1660        // The orchestration should return the original 403 error
1661        assert!(
1662            standard_result
1663                .error
1664                .as_deref()
1665                .unwrap_or("")
1666                .contains("403"),
1667            "Expected original HTTP 403 error, got: {:?}",
1668            standard_result.error
1669        );
1670    }
1671
1672    // ── Item 3: end-to-end fallback orchestration in execute() ───────
1673
1674    #[tokio::test]
1675    async fn execute_falls_back_to_firecrawl_on_short_body() {
1676        use wiremock::matchers::{method, path};
1677        use wiremock::{Mock, MockServer, ResponseTemplate};
1678
1679        // Standard-fetch server: returns a very short body (JS-only placeholder)
1680        let standard_server = MockServer::start().await;
1681        Mock::given(method("GET"))
1682            .respond_with(
1683                ResponseTemplate::new(200)
1684                    .set_body_string("<html><body>Loading...</body></html>")
1685                    .insert_header("content-type", "text/html"),
1686            )
1687            .mount(&standard_server)
1688            .await;
1689
1690        // Firecrawl server: returns rich markdown content
1691        let firecrawl_server = MockServer::start().await;
1692        Mock::given(method("POST"))
1693            .and(path("/scrape"))
1694            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1695                "success": true,
1696                "data": {
1697                    "markdown": "# Real Content\n\nThis is the full page content extracted by Firecrawl, with enough text to be clearly above the minimum body length threshold."
1698                }
1699            })))
1700            .mount(&firecrawl_server)
1701            .await;
1702
1703        // Set up API key env var for this test
1704        // SAFETY: test-only, single-threaded test runner.
1705        unsafe { std::env::set_var("FIRECRAWL_E2E_TEST_KEY", "test-key-12345") };
1706
1707        let security = Arc::new(SecurityPolicy {
1708            autonomy: AutonomyLevel::Supervised,
1709            ..SecurityPolicy::default()
1710        });
1711        let standard_addr = standard_server.address();
1712        let firecrawl_addr = firecrawl_server.address();
1713        let tool = WebFetchTool::new(
1714            security,
1715            vec!["*".into()],
1716            vec![],
1717            500_000,
1718            30,
1719            FirecrawlConfig {
1720                enabled: true,
1721                api_key_env: "FIRECRAWL_E2E_TEST_KEY".into(),
1722                api_url: format!("http://{firecrawl_addr}"),
1723                ..FirecrawlConfig::default()
1724            },
1725            vec![],
1726        )
1727        .unwrap();
1728
1729        // Bypass SSRF-guarded execute() — call standard_fetch + fallback
1730        // logic directly so wiremock on 127.0.0.1 is reachable.
1731        let client = reqwest::Client::builder()
1732            .timeout(Duration::from_secs(30))
1733            .build()
1734            .unwrap();
1735
1736        let url = format!("http://{standard_addr}/page");
1737        let standard_result = tool.standard_fetch(&client, &url).await;
1738
1739        // Standard fetch returns short body, should trigger fallback
1740        assert!(tool.should_fallback_to_firecrawl(&standard_result));
1741
1742        // Firecrawl fallback should succeed with rich content
1743        let result = Box::pin(tool.fetch_via_firecrawl(&url)).await.unwrap();
1744
1745        assert!(result.success, "Expected successful Firecrawl fallback");
1746        assert!(
1747            result.output.contains("Real Content"),
1748            "Expected Firecrawl markdown content, got: {}",
1749            result.output
1750        );
1751
1752        // Clean up env var
1753        // SAFETY: test-only, single-threaded test runner.
1754        unsafe { std::env::remove_var("FIRECRAWL_E2E_TEST_KEY") };
1755    }
1756
1757    // ── Allowed private hosts ─────────────────────────────────────
1758
1759    #[test]
1760    fn allowed_private_host_bypasses_ssrf_block() {
1761        let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
1762        assert!(tool.validate_url("https://192.168.1.5/api").is_ok());
1763    }
1764
1765    #[test]
1766    fn allowed_private_domain_skips_dns_public_check() {
1767        let allowed_domains = vec!["*".to_string()];
1768        let blocked_domains = vec![];
1769        let allowed_private_hosts = vec!["local.internal".to_string()];
1770
1771        let result = validate_target_url_with_dns_check(
1772            "https://local.internal/api",
1773            &allowed_domains,
1774            &blocked_domains,
1775            &allowed_private_hosts,
1776            "web_fetch",
1777            |_| {
1778                panic!("DNS public-host validation should be skipped");
1779            },
1780        );
1781
1782        assert!(
1783            result.is_ok(),
1784            "allowlisted private domain was rejected: {result:?}"
1785        );
1786    }
1787
1788    #[test]
1789    fn unallowed_domain_resolving_private_ip_still_blocked() {
1790        let allowed_domains = vec!["*".to_string()];
1791        let blocked_domains = vec![];
1792        let allowed_private_hosts = vec![];
1793
1794        let err = validate_target_url_with_dns_check(
1795            "https://local.internal/api",
1796            &allowed_domains,
1797            &blocked_domains,
1798            &allowed_private_hosts,
1799            "web_fetch",
1800            |host| {
1801                validate_resolved_ips_are_public(
1802                    host,
1803                    &[std::net::IpAddr::V4(std::net::Ipv4Addr::new(
1804                        192, 168, 1, 5,
1805                    ))],
1806                )
1807            },
1808        )
1809        .unwrap_err()
1810        .to_string();
1811
1812        assert!(
1813            err.contains("non-global address"),
1814            "unexpected error: {err}"
1815        );
1816    }
1817
1818    #[test]
1819    fn private_allowlist_wildcard_does_not_allow_public_domain_miss() {
1820        let allowed_domains = vec!["example.com".to_string()];
1821        let blocked_domains = vec![];
1822        let allowed_private_hosts = vec!["*".to_string()];
1823
1824        let err = validate_target_url_with_dns_check(
1825            "https://not-example.com/api",
1826            &allowed_domains,
1827            &blocked_domains,
1828            &allowed_private_hosts,
1829            "web_fetch",
1830            |_| anyhow::Ok(()),
1831        )
1832        .unwrap_err()
1833        .to_string();
1834
1835        assert!(err.contains("allowed_domains"), "unexpected error: {err}");
1836    }
1837
1838    #[test]
1839    fn blocklist_overrides_allowed_private_domain() {
1840        let allowed_domains = vec!["*".to_string()];
1841        let blocked_domains = vec!["local.internal".to_string()];
1842        let allowed_private_hosts = vec!["local.internal".to_string()];
1843
1844        let err = validate_target_url_with_dns_check(
1845            "https://local.internal/api",
1846            &allowed_domains,
1847            &blocked_domains,
1848            &allowed_private_hosts,
1849            "web_fetch",
1850            |_| anyhow::bail!("blocklist should run before DNS validation"),
1851        )
1852        .unwrap_err()
1853        .to_string();
1854
1855        assert!(err.contains("blocked_domains"), "unexpected error: {err}");
1856    }
1857
1858    #[test]
1859    fn unallowed_private_host_still_blocked() {
1860        let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
1861        let err = tool
1862            .validate_url("https://10.0.0.1/admin")
1863            .unwrap_err()
1864            .to_string();
1865        assert!(err.contains("local/private"));
1866        assert!(err.contains("allowed_private_hosts"));
1867    }
1868
1869    #[test]
1870    fn blocklist_overrides_allowed_private_host() {
1871        let tool =
1872            test_tool_with_private_hosts(vec!["*"], vec!["192.168.1.5"], vec!["192.168.1.5"]);
1873        let err = tool
1874            .validate_url("https://192.168.1.5/secret")
1875            .unwrap_err()
1876            .to_string();
1877        assert!(err.contains("blocked_domains"));
1878    }
1879
1880    #[test]
1881    fn allowed_private_host_with_port() {
1882        let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
1883        assert!(tool.validate_url("https://192.168.1.5:8080/api").is_ok());
1884    }
1885}