1use async_trait::async_trait;
2use futures_util::StreamExt;
3use serde_json::json;
4use std::sync::Arc;
5use std::time::Duration;
6use zeroclaw_api::tool::{Tool, ToolResult};
7use zeroclaw_config::policy::SecurityPolicy;
8use zeroclaw_config::schema::FirecrawlConfig;
9
/// Minimum length (in bytes, after trimming surrounding whitespace) that a
/// successful standard fetch must produce. Shorter output triggers the
/// Firecrawl fallback when Firecrawl is enabled.
const FIRECRAWL_MIN_BODY_LEN: usize = 100;
13
14pub struct WebFetchTool {
24 security: Arc<SecurityPolicy>,
25 allowed_domains: Vec<String>,
26 blocked_domains: Vec<String>,
27 allowed_private_hosts: Vec<String>,
28 max_response_size: usize,
29 timeout_secs: u64,
30 firecrawl: FirecrawlConfig,
31}
32
33impl WebFetchTool {
34 pub fn new(
35 security: Arc<SecurityPolicy>,
36 allowed_domains: Vec<String>,
37 blocked_domains: Vec<String>,
38 max_response_size: usize,
39 timeout_secs: u64,
40 firecrawl: FirecrawlConfig,
41 allowed_private_hosts: Vec<String>,
42 ) -> anyhow::Result<Self> {
43 Ok(Self {
44 security,
45 allowed_domains: normalize_allowed_domains(
46 allowed_domains,
47 "web_fetch.allowed_domains",
48 )?,
49 blocked_domains: normalize_allowed_domains(
50 blocked_domains,
51 "web_fetch.blocked_domains",
52 )?,
53 allowed_private_hosts: normalize_allowed_domains(
54 allowed_private_hosts,
55 "web_fetch.allowed_private_hosts",
56 )?,
57 max_response_size,
58 timeout_secs,
59 firecrawl,
60 })
61 }
62
63 fn validate_url(&self, raw_url: &str) -> anyhow::Result<String> {
64 validate_target_url(
65 raw_url,
66 &self.allowed_domains,
67 &self.blocked_domains,
68 &self.allowed_private_hosts,
69 "web_fetch",
70 )
71 }
72
73 fn truncate_response(&self, text: &str) -> String {
74 if self.max_response_size == 0 {
81 return text.to_string();
82 }
83 if text.len() > self.max_response_size {
84 let mut truncated = text
85 .chars()
86 .take(self.max_response_size)
87 .collect::<String>();
88 truncated.push_str("\n\n... [Response truncated due to size limit] ...");
89 truncated
90 } else {
91 text.to_string()
92 }
93 }
94
95 async fn read_response_text_limited(
96 &self,
97 response: reqwest::Response,
98 ) -> anyhow::Result<String> {
99 let mut bytes_stream = response.bytes_stream();
100 let hard_cap = if self.max_response_size == 0 {
106 usize::MAX
107 } else {
108 self.max_response_size.saturating_add(1)
109 };
110 let mut bytes = Vec::new();
111
112 while let Some(chunk_result) = bytes_stream.next().await {
113 let chunk = chunk_result?;
114 if append_chunk_with_cap(&mut bytes, &chunk, hard_cap) {
115 break;
116 }
117 }
118
119 Ok(String::from_utf8_lossy(&bytes).into_owned())
120 }
121
122 fn should_fallback_to_firecrawl(&self, result: &ToolResult) -> bool {
124 if !self.firecrawl.enabled {
125 return false;
126 }
127 if !result.success {
129 return true;
130 }
131 if result.output.trim().len() < FIRECRAWL_MIN_BODY_LEN {
133 return true;
134 }
135 false
136 }
137
138 async fn fetch_via_firecrawl(&self, url: &str) -> anyhow::Result<ToolResult> {
140 let api_key = std::env::var(&self.firecrawl.api_key_env).map_err(|_| {
141 ::zeroclaw_log::record!(
142 ERROR,
143 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
144 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
145 .with_attrs(::serde_json::json!({
146 "env_var": &self.firecrawl.api_key_env,
147 })),
148 "web_fetch: Firecrawl API key missing from env"
149 );
150 anyhow::Error::msg(format!(
151 "Firecrawl API key not found in environment variable '{}'",
152 self.firecrawl.api_key_env
153 ))
154 })?;
155
156 let endpoint = format!("{}/scrape", self.firecrawl.api_url.trim_end_matches('/'));
157
158 let client = reqwest::Client::builder()
159 .timeout(Duration::from_secs(60))
160 .build()
161 .map_err(|e| {
162 ::zeroclaw_log::record!(
163 ERROR,
164 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
165 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
166 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
167 "web_fetch: failed to build Firecrawl HTTP client"
168 );
169 anyhow::Error::msg(format!("Failed to build Firecrawl HTTP client: {e}"))
170 })?;
171
172 let body = json!({
173 "url": url,
174 "formats": ["markdown"]
175 });
176
177 let response = client
178 .post(&endpoint)
179 .header("Authorization", format!("Bearer {api_key}"))
180 .header("Content-Type", "application/json")
181 .json(&body)
182 .send()
183 .await
184 .map_err(|e| {
185 ::zeroclaw_log::record!(
186 ERROR,
187 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
188 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
189 .with_attrs(::serde_json::json!({
190 "phase": "firecrawl_request",
191 "error": format!("{}", e),
192 })),
193 "web_fetch: Firecrawl request failed"
194 );
195 anyhow::Error::msg(format!("Firecrawl request failed: {e}"))
196 })?;
197
198 let status = response.status();
199 if !status.is_success() {
200 let error_body = response.text().await.unwrap_or_default();
201 return Ok(ToolResult {
202 success: false,
203 output: String::new(),
204 error: Some(format!(
205 "Firecrawl API error: HTTP {} - {}",
206 status.as_u16(),
207 error_body
208 )),
209 });
210 }
211
212 let resp_json: serde_json::Value = response.json().await.map_err(|e| {
213 ::zeroclaw_log::record!(
214 ERROR,
215 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
216 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
217 .with_attrs(::serde_json::json!({
218 "phase": "firecrawl_response_parse",
219 "error": format!("{}", e),
220 })),
221 "web_fetch: failed to parse Firecrawl response"
222 );
223 anyhow::Error::msg(format!("Failed to parse Firecrawl response: {e}"))
224 })?;
225
226 let markdown = resp_json
227 .get("data")
228 .and_then(|d| d.get("markdown"))
229 .and_then(|m| m.as_str())
230 .unwrap_or("");
231
232 if markdown.is_empty() {
233 return Ok(ToolResult {
234 success: false,
235 output: String::new(),
236 error: Some("Firecrawl returned empty markdown content".into()),
237 });
238 }
239
240 let output = self.truncate_response(markdown);
241
242 Ok(ToolResult {
243 success: true,
244 output,
245 error: None,
246 })
247 }
248
249 async fn standard_fetch(&self, client: &reqwest::Client, url: &str) -> ToolResult {
251 let response = match client.get(url).send().await {
252 Ok(r) => r,
253 Err(e) => {
254 return ToolResult {
255 success: false,
256 output: String::new(),
257 error: Some(format!("HTTP request failed: {e}")),
258 };
259 }
260 };
261
262 let status = response.status();
263 if !status.is_success() {
264 return ToolResult {
265 success: false,
266 output: String::new(),
267 error: Some(format!(
268 "HTTP {} {}",
269 status.as_u16(),
270 status.canonical_reason().unwrap_or("Unknown")
271 )),
272 };
273 }
274
275 let content_type = response
277 .headers()
278 .get(reqwest::header::CONTENT_TYPE)
279 .and_then(|v| v.to_str().ok())
280 .unwrap_or("")
281 .to_lowercase();
282
283 let body_mode = if content_type.contains("text/html") || content_type.is_empty() {
284 "html"
285 } else if content_type.contains("text/plain")
286 || content_type.contains("text/markdown")
287 || content_type.contains("application/json")
288 {
289 "plain"
290 } else {
291 return ToolResult {
292 success: false,
293 output: String::new(),
294 error: Some(format!(
295 "Unsupported content type: {content_type}. \
296 web_fetch supports text/html, text/plain, text/markdown, and application/json."
297 )),
298 };
299 };
300
301 let body = match self.read_response_text_limited(response).await {
302 Ok(t) => t,
303 Err(e) => {
304 return ToolResult {
305 success: false,
306 output: String::new(),
307 error: Some(format!("Failed to read response body: {e}")),
308 };
309 }
310 };
311
312 let text = if body_mode == "html" {
313 nanohtml2text::html2text(&body)
314 } else {
315 body
316 };
317
318 let output = self.truncate_response(&text);
319
320 ToolResult {
321 success: true,
322 output,
323 error: None,
324 }
325 }
326}
327
328#[async_trait]
329impl Tool for WebFetchTool {
330 fn name(&self) -> &str {
331 "web_fetch"
332 }
333
334 fn description(&self) -> &str {
335 "Fetch a web page and return its content as clean plain text. \
336 HTML pages are automatically converted to readable text. \
337 JSON and plain text responses are returned as-is. \
338 Only GET requests; follows redirects. \
339 Falls back to Firecrawl for JS-heavy/bot-blocked sites (if enabled). \
340 Security: allowlist-only domains, no local/private hosts."
341 }
342
343 fn parameters_schema(&self) -> serde_json::Value {
344 json!({
345 "type": "object",
346 "properties": {
347 "url": {
348 "type": "string",
349 "description": "The HTTP or HTTPS URL to fetch"
350 }
351 },
352 "required": ["url"]
353 })
354 }
355
356 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
357 let url = args.get("url").and_then(|v| v.as_str()).ok_or_else(|| {
358 ::zeroclaw_log::record!(
359 WARN,
360 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
361 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
362 .with_attrs(::serde_json::json!({"param": "url"})),
363 "web_fetch: missing url parameter"
364 );
365 anyhow::Error::msg("Missing 'url' parameter")
366 })?;
367
368 if !self.security.can_act() {
369 return Ok(ToolResult {
370 success: false,
371 output: String::new(),
372 error: Some("Action blocked: autonomy is read-only".into()),
373 });
374 }
375
376 let url = match self.validate_url(url) {
380 Ok(v) => v,
381 Err(e) => {
382 return Ok(ToolResult {
383 success: false,
384 output: String::new(),
385 error: Some(e.to_string()),
386 });
387 }
388 };
389
390 let timeout_secs = if self.timeout_secs == 0 {
392 ::zeroclaw_log::record!(
393 WARN,
394 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
395 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
396 "web_fetch: timeout_secs is 0, using safe default of 30s"
397 );
398 30
399 } else {
400 self.timeout_secs
401 };
402
403 let allowed_domains = self.allowed_domains.clone();
404 let blocked_domains = self.blocked_domains.clone();
405 let allowed_private_hosts = self.allowed_private_hosts.clone();
406 let redirect_policy = reqwest::redirect::Policy::custom(move |attempt| {
407 if attempt.previous().len() >= 10 {
408 return attempt.error(std::io::Error::other("Too many redirects (max 10)"));
409 }
410
411 if let Err(err) = validate_target_url(
412 attempt.url().as_str(),
413 &allowed_domains,
414 &blocked_domains,
415 &allowed_private_hosts,
416 "web_fetch",
417 ) {
418 return attempt.error(std::io::Error::new(
419 std::io::ErrorKind::PermissionDenied,
420 format!("Blocked redirect target: {err}"),
421 ));
422 }
423
424 attempt.follow()
425 });
426
427 let builder = reqwest::Client::builder()
428 .timeout(Duration::from_secs(timeout_secs))
429 .connect_timeout(Duration::from_secs(10))
430 .redirect(redirect_policy)
431 .user_agent("ZeroClaw/0.1 (web_fetch)");
432 let builder =
433 zeroclaw_config::schema::apply_runtime_proxy_to_builder(builder, "tool.web_fetch");
434 let client = match builder.build() {
435 Ok(c) => c,
436 Err(e) => {
437 return Ok(ToolResult {
438 success: false,
439 output: String::new(),
440 error: Some(format!("Failed to build HTTP client: {e}")),
441 });
442 }
443 };
444
445 let standard_result = self.standard_fetch(&client, &url).await;
446
447 if self.should_fallback_to_firecrawl(&standard_result) {
450 ::zeroclaw_log::record!(
451 INFO,
452 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
453 .with_attrs(::serde_json::json!({"url": url})),
454 "web_fetch: standard fetch insufficient for , attempting Firecrawl fallback"
455 );
456 match Box::pin(self.fetch_via_firecrawl(&url)).await {
457 Ok(firecrawl_result) if firecrawl_result.success => {
458 return Ok(firecrawl_result);
459 }
460 Ok(firecrawl_result) => {
461 ::zeroclaw_log::record!(
462 WARN,
463 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
464 .with_outcome(::zeroclaw_log::EventOutcome::Unknown),
465 &format!(
466 "web_fetch: Firecrawl fallback also failed: {:?}",
467 firecrawl_result.error
468 )
469 );
470 }
472 Err(e) => {
473 ::zeroclaw_log::record!(
474 WARN,
475 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
476 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
477 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
478 "web_fetch: Firecrawl fallback error"
479 );
480 }
481 }
482 }
483
484 Ok(standard_result)
485 }
486}
487
488fn validate_target_url(
491 raw_url: &str,
492 allowed_domains: &[String],
493 blocked_domains: &[String],
494 allowed_private_hosts: &[String],
495 tool_name: &str,
496) -> anyhow::Result<String> {
497 let url = raw_url.trim();
498
499 if url.is_empty() {
500 anyhow::bail!("URL cannot be empty");
501 }
502
503 if url.chars().any(char::is_whitespace) {
504 anyhow::bail!("URL cannot contain whitespace");
505 }
506
507 if !url.starts_with("http://") && !url.starts_with("https://") {
508 anyhow::bail!("Only http:// and https:// URLs are allowed");
509 }
510
511 if allowed_domains.is_empty() {
512 anyhow::bail!(
513 "{tool_name} tool is enabled but no allowed_domains are configured. \
514 Add [{tool_name}].allowed_domains in config.toml"
515 );
516 }
517
518 let host = extract_host(url)?;
519
520 if host_matches_allowlist(&host, blocked_domains) {
522 anyhow::bail!("Host '{host}' is in {tool_name}.blocked_domains");
523 }
524
525 let private_host_allowed =
526 is_private_or_local_host(&host) && host_matches_allowlist(&host, allowed_private_hosts);
527
528 if is_private_or_local_host(&host) && !private_host_allowed {
529 anyhow::bail!(
530 "Blocked local/private host: {host}. \
531 To allow this host, add it to {tool_name}.allowed_private_hosts in config.toml"
532 );
533 }
534
535 if private_host_allowed {
536 ::zeroclaw_log::record!(
537 WARN,
538 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Note)
539 .with_outcome(::zeroclaw_log::EventOutcome::Unknown)
540 .with_attrs(::serde_json::json!({"tool_name": tool_name, "host": host})),
541 ": allowing private/local host '' via allowed_private_hosts"
542 );
543 }
544
545 if !private_host_allowed && !host_matches_allowlist(&host, allowed_domains) {
546 anyhow::bail!("Host '{host}' is not in {tool_name}.allowed_domains");
547 }
548
549 if !private_host_allowed {
550 validate_resolved_host_is_public(&host)?;
551 }
552
553 Ok(url.to_string())
554}
555
/// Appends as much of `chunk` to `buffer` as fits under `hard_cap` bytes.
///
/// Returns `true` once the buffer has reached the cap, meaning the caller
/// should stop reading. Bytes beyond the cap are discarded.
fn append_chunk_with_cap(buffer: &mut Vec<u8>, chunk: &[u8], hard_cap: usize) -> bool {
    let room = hard_cap.saturating_sub(buffer.len());
    let take = chunk.len().min(room);
    buffer.extend_from_slice(&chunk[..take]);
    buffer.len() >= hard_cap
}
570
571fn normalize_allowed_domains(domains: Vec<String>, label: &str) -> anyhow::Result<Vec<String>> {
572 let mut rejected = Vec::new();
573 let mut normalized = domains
574 .into_iter()
575 .filter_map(|d| {
576 normalize_domain(&d).or_else(|| {
577 rejected.push(d.clone());
578 None
579 })
580 })
581 .collect::<Vec<_>>();
582 if !rejected.is_empty() {
583 anyhow::bail!(
584 "Invalid {label} entry(s): [{}]. Each entry must be a valid domain, hostname, IPv4, or IPv6 address.",
585 rejected.join(", ")
586 );
587 }
588 normalized.sort_unstable();
589 normalized.dedup();
590 Ok(normalized)
591}
592
593fn normalize_domain(raw: &str) -> Option<String> {
594 let input = raw.trim();
595 if input.is_empty() || input.chars().any(char::is_whitespace) {
596 return None;
597 }
598
599 let bare_ip = match (input.starts_with('['), input.ends_with(']')) {
600 (true, true) => &input[1..input.len() - 1],
601 (false, false) => input,
602 _ => return None,
603 };
604 if let Ok(ip) = bare_ip.parse::<std::net::IpAddr>() {
605 return Some(ip.to_string().to_lowercase());
606 }
607
608 let parsed = reqwest::Url::parse(input)
609 .or_else(|_| reqwest::Url::parse(&format!("https://{input}")))
610 .ok()?;
611
612 if !parsed.username().is_empty() || parsed.password().is_some() {
613 return None;
614 }
615
616 let host = parsed.host_str()?;
617 let trimmed = host.trim();
618 let host_no_brackets = match (trimmed.starts_with('['), trimmed.ends_with(']')) {
619 (true, true) => &trimmed[1..trimmed.len() - 1],
620 (false, false) => trimmed,
621 _ => return None,
622 };
623 let normalized = host_no_brackets
624 .trim_start_matches('.')
625 .trim_end_matches('.');
626 if normalized.is_empty() {
627 return None;
628 }
629
630 Some(normalized.to_lowercase())
631}
632
633fn extract_host(url: &str) -> anyhow::Result<String> {
634 let rest = url
635 .strip_prefix("http://")
636 .or_else(|| url.strip_prefix("https://"))
637 .ok_or_else(|| {
638 ::zeroclaw_log::record!(
639 WARN,
640 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
641 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
642 .with_attrs(::serde_json::json!({"url": url})),
643 "web_fetch: non-http(s) URL rejected"
644 );
645 anyhow::Error::msg("Only http:// and https:// URLs are allowed")
646 })?;
647
648 let authority = rest.split(['/', '?', '#']).next().ok_or_else(|| {
649 ::zeroclaw_log::record!(
650 WARN,
651 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
652 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
653 .with_attrs(::serde_json::json!({"url": url})),
654 "web_fetch: invalid URL"
655 );
656 anyhow::Error::msg("Invalid URL")
657 })?;
658
659 if authority.is_empty() {
660 anyhow::bail!("URL must include a host");
661 }
662
663 if authority.contains('@') {
664 anyhow::bail!("URL userinfo is not allowed");
665 }
666
667 if authority.starts_with('[') {
668 anyhow::bail!("IPv6 hosts are not supported in web_fetch");
669 }
670
671 let host = authority
672 .split(':')
673 .next()
674 .unwrap_or_default()
675 .trim()
676 .trim_end_matches('.')
677 .to_lowercase();
678
679 if host.is_empty() {
680 anyhow::bail!("URL must include a valid host");
681 }
682
683 Ok(host)
684}
685
/// Reports whether `host` is covered by `allowed_domains`.
///
/// An entry covers a host when any of these holds:
/// - the entry is `"*"`;
/// - the entry equals the host;
/// - the host is a dot-separated subdomain of the entry.
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
    allowed_domains.iter().any(|entry| {
        if entry == "*" {
            return true;
        }
        match host.strip_suffix(entry.as_str()) {
            Some(rest) => rest.is_empty() || rest.ends_with('.'),
            None => false,
        }
    })
}
698
699fn is_private_or_local_host(host: &str) -> bool {
700 let bare = host
701 .strip_prefix('[')
702 .and_then(|h| h.strip_suffix(']'))
703 .unwrap_or(host);
704
705 let has_local_tld = bare
706 .rsplit('.')
707 .next()
708 .is_some_and(|label| label == "local");
709
710 if bare == "localhost" || bare.ends_with(".localhost") || has_local_tld {
711 return true;
712 }
713
714 if let Ok(ip) = bare.parse::<std::net::IpAddr>() {
715 return match ip {
716 std::net::IpAddr::V4(v4) => is_non_global_v4(v4),
717 std::net::IpAddr::V6(v6) => is_non_global_v6(v6),
718 };
719 }
720
721 false
722}
723
724#[cfg(not(test))]
725fn validate_resolved_host_is_public(host: &str) -> anyhow::Result<()> {
726 use std::net::ToSocketAddrs;
727
728 let ips = (host, 0)
729 .to_socket_addrs()
730 .map_err(|e| {
731 ::zeroclaw_log::record!(
732 ERROR,
733 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
734 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
735 .with_attrs(::serde_json::json!({
736 "host": host,
737 "error": format!("{}", e),
738 })),
739 "web_fetch: failed to resolve host"
740 );
741 anyhow::Error::msg(format!("Failed to resolve host '{host}': {e}"))
742 })?
743 .map(|addr| addr.ip())
744 .collect::<Vec<_>>();
745
746 validate_resolved_ips_are_public(host, &ips)
747}
748
749#[cfg(test)]
750fn validate_resolved_host_is_public(_host: &str) -> anyhow::Result<()> {
751 Ok(())
753}
754
755fn validate_resolved_ips_are_public(host: &str, ips: &[std::net::IpAddr]) -> anyhow::Result<()> {
756 if ips.is_empty() {
757 anyhow::bail!("Failed to resolve host '{host}'");
758 }
759
760 for ip in ips {
761 let non_global = match ip {
762 std::net::IpAddr::V4(v4) => is_non_global_v4(*v4),
763 std::net::IpAddr::V6(v6) => is_non_global_v6(*v6),
764 };
765 if non_global {
766 anyhow::bail!("Blocked host '{host}' resolved to non-global address {ip}");
767 }
768 }
769
770 Ok(())
771}
772
/// Reports whether an IPv4 address is outside the globally routable space and
/// must not be fetched.
///
/// Covers, per the IANA IPv4 Special-Purpose Address Registry:
/// - loopback, RFC 1918 private, link-local, unspecified, broadcast and
///   multicast;
/// - `0.0.0.0/8` ("this network");
/// - `100.64.0.0/10` (shared/CGNAT);
/// - `192.0.0.0/24` (IETF protocol assignments);
/// - the three TEST-NET documentation blocks;
/// - `198.18.0.0/15` (benchmarking);
/// - `240.0.0.0/4` (reserved).
///
/// TEST-NET-2/3 are matched as exact /24s. Matching the whole /16 would block
/// real public hosts.
fn is_non_global_v4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, c, _] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_unspecified()
        || v4.is_broadcast()
        || v4.is_multicast()
        || a == 0 // 0.0.0.0/8 "this network"
        || (a == 100 && (64..=127).contains(&b)) // 100.64.0.0/10 shared address space
        || a >= 240 // 240.0.0.0/4 reserved
        || (a == 192 && b == 0 && (c == 0 || c == 2)) // 192.0.0.0/24, TEST-NET-1 192.0.2.0/24
        || (a == 198 && b == 51 && c == 100) // TEST-NET-2 198.51.100.0/24
        || (a == 203 && b == 0 && c == 113) // TEST-NET-3 203.0.113.0/24
        || (a == 198 && (18..=19).contains(&b)) // 198.18.0.0/15 benchmarking
}
788
789fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool {
790 let segs = v6.segments();
791 v6.is_loopback()
792 || v6.is_unspecified()
793 || v6.is_multicast()
794 || (segs[0] & 0xfe00) == 0xfc00
795 || (segs[0] & 0xffc0) == 0xfe80
796 || (segs[0] == 0x2001 && segs[1] == 0x0db8)
797 || v6.to_ipv4_mapped().is_some_and(is_non_global_v4)
798}
799
800#[cfg(test)]
801mod tests {
802 use super::*;
803 use zeroclaw_config::autonomy::AutonomyLevel;
804 use zeroclaw_config::policy::SecurityPolicy;
805 use zeroclaw_config::schema::FirecrawlConfig;
806
807 fn test_tool(allowed_domains: Vec<&str>) -> WebFetchTool {
808 test_tool_with_blocklist(allowed_domains, vec![])
809 }
810
811 fn test_tool_with_blocklist(
812 allowed_domains: Vec<&str>,
813 blocked_domains: Vec<&str>,
814 ) -> WebFetchTool {
815 let security = Arc::new(SecurityPolicy {
816 autonomy: AutonomyLevel::Supervised,
817 ..SecurityPolicy::default()
818 });
819 WebFetchTool::new(
820 security,
821 allowed_domains.into_iter().map(String::from).collect(),
822 blocked_domains.into_iter().map(String::from).collect(),
823 500_000,
824 30,
825 FirecrawlConfig::default(),
826 vec![],
827 )
828 .unwrap()
829 }
830
831 fn test_tool_with_private_hosts(
832 allowed_domains: Vec<&str>,
833 blocked_domains: Vec<&str>,
834 allowed_private_hosts: Vec<&str>,
835 ) -> WebFetchTool {
836 let security = Arc::new(SecurityPolicy {
837 autonomy: AutonomyLevel::Supervised,
838 ..SecurityPolicy::default()
839 });
840 WebFetchTool::new(
841 security,
842 allowed_domains.into_iter().map(String::from).collect(),
843 blocked_domains.into_iter().map(String::from).collect(),
844 500_000,
845 30,
846 FirecrawlConfig::default(),
847 allowed_private_hosts
848 .into_iter()
849 .map(String::from)
850 .collect(),
851 )
852 .unwrap()
853 }
854
855 fn test_tool_with_firecrawl(firecrawl: FirecrawlConfig) -> WebFetchTool {
856 let security = Arc::new(SecurityPolicy {
857 autonomy: AutonomyLevel::Supervised,
858 ..SecurityPolicy::default()
859 });
860 WebFetchTool::new(
861 security,
862 vec!["*".into()],
863 vec![],
864 500_000,
865 30,
866 firecrawl,
867 vec![],
868 )
869 .unwrap()
870 }
871
872 #[test]
875 fn name_is_web_fetch() {
876 let tool = test_tool(vec!["example.com"]);
877 assert_eq!(tool.name(), "web_fetch");
878 }
879
880 #[test]
881 fn parameters_schema_requires_url() {
882 let tool = test_tool(vec!["example.com"]);
883 let schema = tool.parameters_schema();
884 assert!(schema["properties"]["url"].is_object());
885 let required = schema["required"].as_array().unwrap();
886 assert!(required.iter().any(|v| v.as_str() == Some("url")));
887 }
888
889 #[test]
892 fn html_to_text_conversion() {
893 let html = "<html><body><h1>Title</h1><p>Hello <b>world</b></p></body></html>";
894 let text = nanohtml2text::html2text(html);
895 assert!(text.contains("Title"));
896 assert!(text.contains("Hello"));
897 assert!(text.contains("world"));
898 assert!(!text.contains("<h1>"));
899 assert!(!text.contains("<p>"));
900 }
901
902 #[test]
905 fn validate_accepts_exact_domain() {
906 let tool = test_tool(vec!["example.com"]);
907 let got = tool.validate_url("https://example.com/page").unwrap();
908 assert_eq!(got, "https://example.com/page");
909 }
910
911 #[test]
912 fn validate_accepts_subdomain() {
913 let tool = test_tool(vec!["example.com"]);
914 assert!(tool.validate_url("https://docs.example.com/guide").is_ok());
915 }
916
917 #[test]
918 fn validate_accepts_wildcard() {
919 let tool = test_tool(vec!["*"]);
920 assert!(tool.validate_url("https://news.ycombinator.com").is_ok());
921 }
922
923 #[test]
924 fn validate_rejects_empty_url() {
925 let tool = test_tool(vec!["example.com"]);
926 let err = tool.validate_url("").unwrap_err().to_string();
927 assert!(err.contains("empty"));
928 }
929
930 #[test]
931 fn validate_rejects_missing_url() {
932 let tool = test_tool(vec!["example.com"]);
933 let err = tool.validate_url(" ").unwrap_err().to_string();
934 assert!(err.contains("empty"));
935 }
936
937 #[test]
938 fn validate_rejects_ftp_scheme() {
939 let tool = test_tool(vec!["example.com"]);
940 let err = tool
941 .validate_url("ftp://example.com")
942 .unwrap_err()
943 .to_string();
944 assert!(err.contains("http://") || err.contains("https://"));
945 }
946
947 #[test]
948 fn validate_rejects_allowlist_miss() {
949 let tool = test_tool(vec!["example.com"]);
950 let err = tool
951 .validate_url("https://google.com")
952 .unwrap_err()
953 .to_string();
954 assert!(err.contains("allowed_domains"));
955 }
956
957 #[test]
958 fn validate_requires_allowlist() {
959 let security = Arc::new(SecurityPolicy::default());
960 let tool = WebFetchTool::new(
961 security,
962 vec![],
963 vec![],
964 500_000,
965 30,
966 FirecrawlConfig::default(),
967 vec![],
968 )
969 .unwrap();
970 let err = tool
971 .validate_url("https://example.com")
972 .unwrap_err()
973 .to_string();
974 assert!(err.contains("allowed_domains"));
975 }
976
977 #[test]
980 fn ssrf_blocks_localhost() {
981 let tool = test_tool(vec!["localhost"]);
982 let err = tool
983 .validate_url("https://localhost:8080")
984 .unwrap_err()
985 .to_string();
986 assert!(err.contains("local/private"));
987 }
988
989 #[test]
990 fn ssrf_blocks_private_ipv4() {
991 let tool = test_tool(vec!["192.168.1.5"]);
992 let err = tool
993 .validate_url("https://192.168.1.5")
994 .unwrap_err()
995 .to_string();
996 assert!(err.contains("local/private"));
997 }
998
999 #[test]
1000 fn ssrf_blocks_loopback() {
1001 assert!(is_private_or_local_host("127.0.0.1"));
1002 assert!(is_private_or_local_host("127.0.0.2"));
1003 }
1004
1005 #[test]
1006 fn ssrf_blocks_rfc1918() {
1007 assert!(is_private_or_local_host("10.0.0.1"));
1008 assert!(is_private_or_local_host("172.16.0.1"));
1009 assert!(is_private_or_local_host("192.168.1.1"));
1010 }
1011
1012 #[test]
1013 fn ssrf_wildcard_still_blocks_private() {
1014 let tool = test_tool(vec!["*"]);
1015 let err = tool
1016 .validate_url("https://localhost:8080")
1017 .unwrap_err()
1018 .to_string();
1019 assert!(err.contains("local/private"));
1020 }
1021
1022 #[test]
1023 fn redirect_target_validation_allows_permitted_host() {
1024 let allowed = vec!["example.com".to_string()];
1025 let blocked = vec![];
1026 assert!(
1027 validate_target_url(
1028 "https://docs.example.com/page",
1029 &allowed,
1030 &blocked,
1031 &[],
1032 "web_fetch"
1033 )
1034 .is_ok()
1035 );
1036 }
1037
1038 #[test]
1039 fn redirect_target_validation_blocks_private_host() {
1040 let allowed = vec!["example.com".to_string()];
1041 let blocked = vec![];
1042 let err = validate_target_url(
1043 "https://127.0.0.1/admin",
1044 &allowed,
1045 &blocked,
1046 &[],
1047 "web_fetch",
1048 )
1049 .unwrap_err()
1050 .to_string();
1051 assert!(err.contains("local/private"));
1052 }
1053
1054 #[test]
1055 fn redirect_target_validation_blocks_blocklisted_host() {
1056 let allowed = vec!["*".to_string()];
1057 let blocked = vec!["evil.com".to_string()];
1058 let err = validate_target_url(
1059 "https://evil.com/phish",
1060 &allowed,
1061 &blocked,
1062 &[],
1063 "web_fetch",
1064 )
1065 .unwrap_err()
1066 .to_string();
1067 assert!(err.contains("blocked_domains"));
1068 }
1069
1070 #[tokio::test]
1073 async fn blocks_readonly_mode() {
1074 let security = Arc::new(SecurityPolicy {
1075 autonomy: AutonomyLevel::ReadOnly,
1076 ..SecurityPolicy::default()
1077 });
1078 let tool = WebFetchTool::new(
1079 security,
1080 vec!["example.com".into()],
1081 vec![],
1082 500_000,
1083 30,
1084 FirecrawlConfig::default(),
1085 vec![],
1086 )
1087 .unwrap();
1088 let result = tool
1089 .execute(json!({"url": "https://example.com"}))
1090 .await
1091 .unwrap();
1092 assert!(!result.success);
1093 assert!(result.error.unwrap().contains("read-only"));
1094 }
1095
1096 #[test]
1099 fn truncate_within_limit() {
1100 let tool = test_tool(vec!["example.com"]);
1101 let text = "hello world";
1102 assert_eq!(tool.truncate_response(text), "hello world");
1103 }
1104
1105 #[test]
1106 fn truncate_response_zero_means_unlimited() {
1107 let tool = WebFetchTool::new(
1110 Arc::new(SecurityPolicy::default()),
1111 vec!["example.com".into()],
1112 vec![],
1113 0, 30,
1115 FirecrawlConfig::default(),
1116 vec![],
1117 )
1118 .unwrap();
1119 let long_text = "x".repeat(10_000);
1120 let result = tool.truncate_response(&long_text);
1121 assert_eq!(result.len(), 10_000, "zero limit must not truncate");
1122 assert!(
1123 !result.contains("[Response truncated"),
1124 "must not append truncation marker"
1125 );
1126 }
1127
1128 #[tokio::test]
1136 async fn standard_fetch_with_zero_limit_returns_full_body_and_skips_firecrawl_fallback() {
1137 use wiremock::matchers::method;
1138 use wiremock::{Mock, MockServer, ResponseTemplate};
1139
1140 let server = MockServer::start().await;
1141 let addr = server.address();
1142
1143 let body = "a".repeat(500);
1146 Mock::given(method("GET"))
1147 .respond_with(ResponseTemplate::new(200).set_body_string(body.clone()))
1148 .mount(&server)
1149 .await;
1150
1151 let tool = WebFetchTool::new(
1152 Arc::new(SecurityPolicy {
1153 autonomy: AutonomyLevel::Supervised,
1154 ..SecurityPolicy::default()
1155 }),
1156 vec!["*".into()],
1157 vec![],
1158 0, 30,
1160 FirecrawlConfig {
1161 enabled: true,
1162 ..FirecrawlConfig::default()
1163 },
1164 vec![],
1165 )
1166 .unwrap();
1167
1168 let url = format!("http://{}:{}/", addr.ip(), addr.port());
1171 let client = reqwest::Client::builder()
1172 .timeout(std::time::Duration::from_secs(5))
1173 .build()
1174 .expect("reqwest client");
1175 let standard_result = tool.standard_fetch(&client, &url).await;
1176
1177 assert!(
1180 standard_result.success,
1181 "standard_fetch must succeed, got error={:?}",
1182 standard_result.error
1183 );
1184 assert_eq!(
1185 standard_result.output.len(),
1186 body.len(),
1187 "streamed body length under zero-limit must equal full body"
1188 );
1189 assert_eq!(
1190 standard_result.output, body,
1191 "streamed body content must equal full body"
1192 );
1193 assert!(
1194 !standard_result.output.contains("[Response truncated"),
1195 "must not append truncation marker under zero limit"
1196 );
1197
1198 assert!(
1201 !tool.should_fallback_to_firecrawl(&standard_result),
1202 "500-byte body under zero limit must not trigger Firecrawl fallback"
1203 );
1204 }
1205
1206 #[test]
1207 fn truncate_over_limit() {
1208 let tool = WebFetchTool::new(
1209 Arc::new(SecurityPolicy::default()),
1210 vec!["example.com".into()],
1211 vec![],
1212 10,
1213 30,
1214 FirecrawlConfig::default(),
1215 vec![],
1216 )
1217 .unwrap();
1218 let text = "hello world this is long";
1219 let truncated = tool.truncate_response(text);
1220 assert!(truncated.contains("[Response truncated"));
1221 }
1222
1223 #[test]
1226 fn normalize_domain_strips_scheme_and_case() {
1227 let got = normalize_domain(" HTTPS://Docs.Example.com/path ").unwrap();
1228 assert_eq!(got, "docs.example.com");
1229 }
1230
1231 #[test]
1232 fn normalize_domain_rejects_userinfo() {
1233 assert!(normalize_domain("https://user@example.com").is_none());
1234 assert!(normalize_domain("user@example.com").is_none());
1235 assert!(normalize_domain("https://user:pass@example.com").is_none());
1236 assert!(normalize_domain("user:pass@example.com").is_none());
1237 }
1238
/// A lone `[` or `]` (IPv6-literal style brackets that don't pair up) is rejected.
#[test]
fn normalize_domain_rejects_unmatched_brackets() {
    let malformed = ["[::1", "::1]", "[127.0.0.1", "127.0.0.1]"];
    for host in malformed {
        assert!(
            normalize_domain(host).is_none(),
            "{host:?} has unbalanced brackets and must be rejected"
        );
    }
}
1246
/// Several spellings of the same host collapse into a single normalized entry.
#[test]
fn normalize_deduplicates() {
    let spellings: Vec<String> = ["example.com", "EXAMPLE.COM", "https://example.com/"]
        .into_iter()
        .map(String::from)
        .collect();
    let deduped = normalize_allowed_domains(spellings, "test").unwrap();
    assert_eq!(deduped, vec!["example.com".to_string()]);
}
1260
/// Even with a wildcard allowlist, an exact blocklist hit is refused.
#[test]
fn blocklist_rejects_exact_match() {
    let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
    let rejection = tool.validate_url("https://evil.com/page").unwrap_err();
    // The message should name the setting responsible for the refusal.
    assert!(rejection.to_string().contains("blocked_domains"));
}
1272
/// Blocking a domain also blocks its subdomains.
#[test]
fn blocklist_rejects_subdomain() {
    let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
    let rejection = tool.validate_url("https://api.evil.com/v1").unwrap_err();
    assert!(rejection.to_string().contains("blocked_domains"));
}
1282
/// A host that is both allowed and blocked is treated as blocked.
#[test]
fn blocklist_wins_over_allowlist() {
    let both_lists = vec!["evil.com"];
    let tool = test_tool_with_blocklist(both_lists.clone(), both_lists);
    let rejection = tool.validate_url("https://evil.com").unwrap_err();
    assert!(rejection.to_string().contains("blocked_domains"));
}
1292
/// Hosts not on the blocklist pass through a wildcard allowlist.
#[test]
fn blocklist_allows_non_blocked() {
    let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
    let verdict = tool.validate_url("https://example.com");
    assert!(verdict.is_ok());
}
1298
/// Appending past the byte cap keeps only the bytes that fit and signals the caller to stop.
#[test]
fn append_chunk_with_cap_truncates_and_stops() {
    let mut collected = Vec::new();
    // "hello" fits under the 8-byte cap, so streaming should continue.
    let stop_after_first = append_chunk_with_cap(&mut collected, b"hello", 8);
    assert!(!stop_after_first);
    // "world" overflows: only the first 3 of its bytes fit, and the cap is reported.
    let stop_after_second = append_chunk_with_cap(&mut collected, b"world", 8);
    assert!(stop_after_second);
    assert_eq!(collected, b"hellowor");
}
1306
/// A hostname that resolves to loopback fails the post-resolution public-address check.
#[test]
fn resolved_private_ip_is_rejected() {
    let resolved = vec!["127.0.0.1".parse().unwrap()];
    let message = validate_resolved_ips_are_public("example.com", &resolved)
        .unwrap_err()
        .to_string();
    assert!(message.contains("non-global address"));
}
1315
/// A single private address among otherwise public ones rejects the whole resolution.
#[test]
fn resolved_mixed_ips_are_rejected() {
    let public = "93.184.216.34".parse().unwrap();
    let private = "10.0.0.1".parse().unwrap();
    let resolved = vec![public, private];
    let message = validate_resolved_ips_are_public("example.com", &resolved)
        .unwrap_err()
        .to_string();
    assert!(message.contains("non-global address"));
}
1327
/// When every resolved address is public, validation succeeds.
#[test]
fn resolved_public_ips_are_allowed() {
    let resolved = vec!["93.184.216.34".parse().unwrap(), "1.1.1.1".parse().unwrap()];
    let outcome = validate_resolved_ips_are_public("example.com", &resolved);
    assert!(outcome.is_ok());
}
1333
/// The default Firecrawl config is disabled, scrape-mode, and points at the hosted v1 API.
#[test]
fn firecrawl_config_defaults() {
    let defaults = FirecrawlConfig::default();
    // Firecrawl must be explicitly opted into.
    assert!(!defaults.enabled);
    assert_eq!(defaults.mode, zeroclaw_config::schema::FirecrawlMode::Scrape);
    assert_eq!(defaults.api_url, "https://api.firecrawl.dev/v1");
    assert_eq!(defaults.api_key_env, "FIRECRAWL_API_KEY");
}
1344
/// Each TOML key overrides its corresponding `FirecrawlConfig` field.
#[test]
fn firecrawl_config_deserializes_from_toml() {
    // All four keys differ from the defaults, so a silently ignored key would show up.
    let toml_str = r#"
enabled = true
api_key_env = "MY_FC_KEY"
api_url = "https://custom.firecrawl.io/v2"
mode = "crawl"
"#;
    let parsed: FirecrawlConfig = toml::from_str(toml_str).unwrap();
    assert!(parsed.enabled);
    assert_eq!(parsed.api_key_env, "MY_FC_KEY");
    assert_eq!(parsed.api_url, "https://custom.firecrawl.io/v2");
    assert_eq!(parsed.mode, zeroclaw_config::schema::FirecrawlMode::Crawl);
}
1359
/// An empty document deserializes to the same defaults as `FirecrawlConfig::default()`.
#[test]
fn firecrawl_config_deserializes_defaults_from_empty_toml() {
    let parsed = toml::from_str::<FirecrawlConfig>("").unwrap();
    assert!(!parsed.enabled);
    assert_eq!(parsed.api_key_env, "FIRECRAWL_API_KEY");
}
1366
/// A nested `[firecrawl]` table populates `WebFetchConfig::firecrawl`.
#[test]
fn web_fetch_config_with_firecrawl_section() {
    use zeroclaw_config::schema::WebFetchConfig;

    let parsed: WebFetchConfig = toml::from_str(
        r#"
enabled = true
[firecrawl]
enabled = true
api_key_env = "FC_KEY"
"#,
    )
    .unwrap();
    assert!(parsed.enabled);
    let firecrawl = &parsed.firecrawl;
    assert!(firecrawl.enabled);
    assert_eq!(firecrawl.api_key_env, "FC_KEY");
}
1381
/// With Firecrawl disabled (the default), even an HTTP failure does not fall back.
#[test]
fn fallback_disabled_when_firecrawl_not_enabled() {
    let failed = ToolResult {
        error: Some("HTTP 403 Forbidden".into()),
        success: false,
        output: String::new(),
    };
    let tool = test_tool_with_firecrawl(FirecrawlConfig::default());
    assert!(!tool.should_fallback_to_firecrawl(&failed));
}
1394
/// With Firecrawl enabled, a failed standard fetch (HTTP 403) falls back.
#[test]
fn fallback_triggers_on_http_error() {
    let enabled = FirecrawlConfig {
        enabled: true,
        ..FirecrawlConfig::default()
    };
    let tool = test_tool_with_firecrawl(enabled);
    let failed = ToolResult {
        error: Some("HTTP 403 Forbidden".into()),
        success: false,
        output: String::new(),
    };
    assert!(tool.should_fallback_to_firecrawl(&failed));
}
1408
/// A "successful" fetch with an empty body still falls back when Firecrawl is enabled.
#[test]
fn fallback_triggers_on_empty_body() {
    let enabled = FirecrawlConfig {
        enabled: true,
        ..FirecrawlConfig::default()
    };
    let tool = test_tool_with_firecrawl(enabled);
    let empty = ToolResult {
        error: None,
        success: true,
        output: String::default(),
    };
    assert!(tool.should_fallback_to_firecrawl(&empty));
}
1422
/// A successful fetch whose body is only a placeholder ("Loading...", far below
/// the 100-char minimum body length) falls back to Firecrawl.
#[test]
fn fallback_triggers_on_short_body() {
    let tool = test_tool_with_firecrawl(FirecrawlConfig {
        enabled: true,
        ..FirecrawlConfig::default()
    });
    let result = ToolResult {
        success: true,
        output: "Loading...".into(),
        error: None,
    };
    assert!(tool.should_fallback_to_firecrawl(&result));
}
1436
/// A successful fetch with a comfortably long body (200 chars) does not fall back.
#[test]
fn fallback_skipped_on_good_response() {
    let tool = test_tool_with_firecrawl(FirecrawlConfig {
        enabled: true,
        ..FirecrawlConfig::default()
    });
    let result = ToolResult {
        success: true,
        output: "A".repeat(200),
        error: None,
    };
    assert!(!tool.should_fallback_to_firecrawl(&result));
}
1450
/// `data.markdown` is extracted from a well-formed Firecrawl-style response.
///
/// NOTE(review): this only runs serde_json lookups on a literal and calls no code
/// from this module, so it cannot catch regressions in the real response handling.
#[test]
fn firecrawl_response_parses_markdown() {
    let response_json = json!({
        "success": true,
        "data": {
            "markdown": "# Hello World\n\nThis is extracted content from Firecrawl.",
            "metadata": {
                "title": "Test Page"
            }
        }
    });
    let markdown = response_json
        .pointer("/data/markdown")
        .and_then(|value| value.as_str())
        .unwrap_or("");
    assert!(markdown.contains("Hello World"));
    assert!(markdown.contains("extracted content"));
}
1472
/// A `data` object without `markdown` yields an empty string rather than a panic.
///
/// NOTE(review): exercises serde_json only; no code from this module is called.
#[test]
fn firecrawl_response_handles_missing_markdown() {
    let response_json = json!({
        "success": true,
        "data": {}
    });
    let markdown = response_json
        .pointer("/data/markdown")
        .and_then(|value| value.as_str())
        .unwrap_or("");
    assert!(markdown.is_empty());
}
1486
/// An error response with no `data` key yields an empty string rather than a panic.
///
/// NOTE(review): exercises serde_json only; no code from this module is called.
#[test]
fn firecrawl_response_handles_missing_data() {
    let response_json = json!({
        "success": false,
        "error": "Rate limit exceeded"
    });
    let markdown = response_json
        .pointer("/data/markdown")
        .and_then(|value| value.as_str())
        .unwrap_or("");
    assert!(markdown.is_empty());
}
1500
/// One character under the 100-char threshold still counts as too short.
#[test]
fn fallback_triggers_at_exactly_99_chars() {
    let enabled = FirecrawlConfig {
        enabled: true,
        ..FirecrawlConfig::default()
    };
    let tool = test_tool_with_firecrawl(enabled);
    let just_under = ToolResult {
        success: true,
        output: "A".repeat(99),
        error: None,
    };
    let should_fallback = tool.should_fallback_to_firecrawl(&just_under);
    assert!(
        should_fallback,
        "99-char body (below threshold) should trigger fallback"
    );
}
1519
/// A body exactly at the 100-char threshold is long enough; no fallback.
#[test]
fn fallback_skipped_at_exactly_100_chars() {
    let enabled = FirecrawlConfig {
        enabled: true,
        ..FirecrawlConfig::default()
    };
    let tool = test_tool_with_firecrawl(enabled);
    let at_threshold = ToolResult {
        success: true,
        output: "A".repeat(100),
        error: None,
    };
    let should_fallback = tool.should_fallback_to_firecrawl(&at_threshold);
    assert!(
        !should_fallback,
        "100-char body (at threshold) should NOT trigger fallback"
    );
}
1536
1537 #[tokio::test]
1540 async fn firecrawl_missing_api_key_returns_error() {
1541 unsafe { std::env::remove_var("FIRECRAWL_TEST_MISSING_KEY") };
1544
1545 let tool = test_tool_with_firecrawl(FirecrawlConfig {
1546 enabled: true,
1547 api_key_env: "FIRECRAWL_TEST_MISSING_KEY".into(),
1548 ..FirecrawlConfig::default()
1549 });
1550
1551 let result = tool.fetch_via_firecrawl("https://example.com").await;
1552 assert!(
1553 result.is_err(),
1554 "fetch_via_firecrawl should return Err when API key env var is missing"
1555 );
1556 let err_msg = result.unwrap_err().to_string();
1557 assert!(
1558 err_msg.contains("FIRECRAWL_TEST_MISSING_KEY"),
1559 "Error should mention the missing env var name, got: {err_msg}"
1560 );
1561 }
1562
1563 #[tokio::test]
1566 async fn execute_double_failure_returns_original_result() {
1567 use wiremock::matchers::method;
1568 use wiremock::{Mock, MockServer, ResponseTemplate};
1569
1570 let server = MockServer::start().await;
1571 let addr = server.address();
1572
1573 Mock::given(method("GET"))
1575 .respond_with(ResponseTemplate::new(403))
1576 .mount(&server)
1577 .await;
1578
1579 unsafe { std::env::remove_var("FIRECRAWL_DOUBLE_FAIL_KEY") };
1582
1583 let security = Arc::new(SecurityPolicy {
1584 autonomy: AutonomyLevel::Supervised,
1585 ..SecurityPolicy::default()
1586 });
1587 let tool = WebFetchTool::new(
1588 security,
1589 vec!["*".into()],
1590 vec![],
1591 500_000,
1592 30,
1593 FirecrawlConfig {
1594 enabled: true,
1595 api_key_env: "FIRECRAWL_DOUBLE_FAIL_KEY".into(),
1596 api_url: format!("http://{addr}"),
1597 ..FirecrawlConfig::default()
1598 },
1599 vec![],
1600 )
1601 .unwrap();
1602
1603 let client = reqwest::Client::builder()
1606 .timeout(Duration::from_secs(30))
1607 .build()
1608 .unwrap();
1609
1610 let url = format!("http://{addr}/page");
1611 let standard_result = tool.standard_fetch(&client, &url).await;
1612
1613 assert!(!standard_result.success);
1615 assert!(tool.should_fallback_to_firecrawl(&standard_result));
1616
1617 let firecrawl_result = Box::pin(tool.fetch_via_firecrawl(&url)).await;
1619 assert!(
1620 firecrawl_result.is_err() || !firecrawl_result.as_ref().unwrap().success,
1621 "Expected Firecrawl fallback to fail without API key"
1622 );
1623
1624 assert!(
1626 standard_result
1627 .error
1628 .as_deref()
1629 .unwrap_or("")
1630 .contains("403"),
1631 "Expected original HTTP 403 error, got: {:?}",
1632 standard_result.error
1633 );
1634 }
1635
1636 #[tokio::test]
1639 async fn execute_falls_back_to_firecrawl_on_short_body() {
1640 use wiremock::matchers::{method, path};
1641 use wiremock::{Mock, MockServer, ResponseTemplate};
1642
1643 let standard_server = MockServer::start().await;
1645 Mock::given(method("GET"))
1646 .respond_with(
1647 ResponseTemplate::new(200)
1648 .set_body_string("<html><body>Loading...</body></html>")
1649 .insert_header("content-type", "text/html"),
1650 )
1651 .mount(&standard_server)
1652 .await;
1653
1654 let firecrawl_server = MockServer::start().await;
1656 Mock::given(method("POST"))
1657 .and(path("/scrape"))
1658 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1659 "success": true,
1660 "data": {
1661 "markdown": "# Real Content\n\nThis is the full page content extracted by Firecrawl, with enough text to be clearly above the minimum body length threshold."
1662 }
1663 })))
1664 .mount(&firecrawl_server)
1665 .await;
1666
1667 unsafe { std::env::set_var("FIRECRAWL_E2E_TEST_KEY", "test-key-12345") };
1670
1671 let security = Arc::new(SecurityPolicy {
1672 autonomy: AutonomyLevel::Supervised,
1673 ..SecurityPolicy::default()
1674 });
1675 let standard_addr = standard_server.address();
1676 let firecrawl_addr = firecrawl_server.address();
1677 let tool = WebFetchTool::new(
1678 security,
1679 vec!["*".into()],
1680 vec![],
1681 500_000,
1682 30,
1683 FirecrawlConfig {
1684 enabled: true,
1685 api_key_env: "FIRECRAWL_E2E_TEST_KEY".into(),
1686 api_url: format!("http://{firecrawl_addr}"),
1687 ..FirecrawlConfig::default()
1688 },
1689 vec![],
1690 )
1691 .unwrap();
1692
1693 let client = reqwest::Client::builder()
1696 .timeout(Duration::from_secs(30))
1697 .build()
1698 .unwrap();
1699
1700 let url = format!("http://{standard_addr}/page");
1701 let standard_result = tool.standard_fetch(&client, &url).await;
1702
1703 assert!(tool.should_fallback_to_firecrawl(&standard_result));
1705
1706 let result = Box::pin(tool.fetch_via_firecrawl(&url)).await.unwrap();
1708
1709 assert!(result.success, "Expected successful Firecrawl fallback");
1710 assert!(
1711 result.output.contains("Real Content"),
1712 "Expected Firecrawl markdown content, got: {}",
1713 result.output
1714 );
1715
1716 unsafe { std::env::remove_var("FIRECRAWL_E2E_TEST_KEY") };
1719 }
1720
/// A private address listed in `allowed_private_hosts` is permitted despite SSRF protection.
#[test]
fn allowed_private_host_bypasses_ssrf_block() {
    let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
    let verdict = tool.validate_url("https://192.168.1.5/api");
    assert!(verdict.is_ok());
}
1728
/// Allow-listing one private host does not open up other private addresses.
#[test]
fn unallowed_private_host_still_blocked() {
    let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
    let message = tool
        .validate_url("https://10.0.0.1/admin")
        .unwrap_err()
        .to_string();
    // The error should explain the rejection and point at the opt-in setting.
    for needle in ["local/private", "allowed_private_hosts"] {
        assert!(message.contains(needle));
    }
}
1739
/// The blocklist takes precedence over `allowed_private_hosts` for the same host.
#[test]
fn blocklist_overrides_allowed_private_host() {
    let host = "192.168.1.5";
    let tool = test_tool_with_private_hosts(vec!["*"], vec![host], vec![host]);
    let rejection = tool.validate_url("https://192.168.1.5/secret").unwrap_err();
    assert!(rejection.to_string().contains("blocked_domains"));
}
1750
/// An explicit port does not stop an allow-listed private host from matching.
#[test]
fn allowed_private_host_with_port() {
    let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
    let verdict = tool.validate_url("https://192.168.1.5:8080/api");
    assert!(verdict.is_ok());
}
1756}