// zeroclaw_tools/image_gen.rs

use anyhow::Context;
use async_trait::async_trait;
use serde_json::json;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use zeroclaw_api::tool::{Tool, ToolResult, with_ephemeral_workspace_warning};
use zeroclaw_config::policy::SecurityPolicy;
use zeroclaw_config::policy::ToolOperation;

10/// Standalone image generation tool using fal.ai (Flux / Nano Banana models).
11///
12/// Reads the API key from an environment variable (default: `FAL_API_KEY`),
13/// calls the fal.ai synchronous endpoint, downloads the resulting image,
14/// and saves it to `{workspace}/images/{filename}.png`.
15pub struct ImageGenTool {
16    security: Arc<SecurityPolicy>,
17    workspace_dir: PathBuf,
18    default_model: String,
19    api_key_env: String,
20    /// Whether the saved image persists on the host filesystem. `false` on an
21    /// ephemeral runtime (Docker tmpfs / no volume mount), where the PNG is
22    /// written inside the container but invisible on the host and discarded at
23    /// session end. When `false`, a successful generation carries a loud
24    /// ephemeral-workspace warning. Mirrors
25    /// [`super::file_write::FileWriteTool`]. See issue #4627.
26    persistent_writes: bool,
27}
28
29impl ImageGenTool {
30    pub fn new(
31        security: Arc<SecurityPolicy>,
32        workspace_dir: PathBuf,
33        default_model: String,
34        api_key_env: String,
35    ) -> Self {
36        Self {
37            security,
38            workspace_dir,
39            default_model,
40            api_key_env,
41            persistent_writes: true,
42        }
43    }
44
45    /// Construct with an explicit persistence flag derived from the active
46    /// runtime adapter's `has_filesystem_access()`. Mirrors
47    /// [`super::file_write::FileWriteTool::new_with_persistence`].
48    pub fn new_with_persistence(
49        security: Arc<SecurityPolicy>,
50        workspace_dir: PathBuf,
51        default_model: String,
52        api_key_env: String,
53        persistent_writes: bool,
54    ) -> Self {
55        Self {
56            security,
57            workspace_dir,
58            default_model,
59            api_key_env,
60            persistent_writes,
61        }
62    }
63
64    /// Build a reusable HTTP client with reasonable timeouts.
65    fn http_client() -> reqwest::Client {
66        reqwest::Client::builder()
67            .timeout(std::time::Duration::from_secs(120))
68            .build()
69            .unwrap_or_default()
70    }
71
72    /// Read an API key from the environment.
73    fn read_api_key(env_var: &str) -> Result<String, String> {
74        std::env::var(env_var)
75            .map(|v| v.trim().to_string())
76            .ok()
77            .filter(|v| !v.is_empty())
78            .ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
79    }
80
81    /// Core generation logic: call fal.ai, download image, save to disk.
82    async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
83        // ── Parse parameters ───────────────────────────────────────
84        let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
85            Some(p) if !p.trim().is_empty() => p.trim().to_string(),
86            _ => {
87                return Ok(ToolResult {
88                    success: false,
89                    output: String::new(),
90                    error: Some("Missing required parameter: 'prompt'".into()),
91                });
92            }
93        };
94
95        let filename = args
96            .get("filename")
97            .and_then(|v| v.as_str())
98            .filter(|s| !s.trim().is_empty())
99            .unwrap_or("generated_image");
100
101        // Sanitize filename — strip path components to prevent traversal.
102        let safe_name = PathBuf::from(filename).file_name().map_or_else(
103            || "generated_image".to_string(),
104            |n| n.to_string_lossy().to_string(),
105        );
106
107        let size = args
108            .get("size")
109            .and_then(|v| v.as_str())
110            .unwrap_or("square_hd");
111
112        // Validate size enum.
113        const VALID_SIZES: &[&str] = &[
114            "square_hd",
115            "landscape_4_3",
116            "portrait_4_3",
117            "landscape_16_9",
118            "portrait_16_9",
119        ];
120        if !VALID_SIZES.contains(&size) {
121            return Ok(ToolResult {
122                success: false,
123                output: String::new(),
124                error: Some(format!(
125                    "Invalid size '{size}'. Valid values: {}",
126                    VALID_SIZES.join(", ")
127                )),
128            });
129        }
130
131        let model = args
132            .get("model")
133            .and_then(|v| v.as_str())
134            .filter(|s| !s.trim().is_empty())
135            .unwrap_or(&self.default_model);
136
137        // Validate model identifier: must look like a fal.ai model path
138        // (e.g. "fal-ai/flux/schnell"). Reject values with "..", query
139        // strings, or fragments that could redirect the HTTP request.
140        if model.contains("..")
141            || model.contains('?')
142            || model.contains('#')
143            || model.contains('\\')
144            || model.starts_with('/')
145        {
146            return Ok(ToolResult {
147                success: false,
148                output: String::new(),
149                error: Some(format!(
150                    "Invalid model identifier '{model}'. \
151                     Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
152                )),
153            });
154        }
155
156        // ── Read API key ───────────────────────────────────────────
157        let api_key = match Self::read_api_key(&self.api_key_env) {
158            Ok(k) => k,
159            Err(msg) => {
160                return Ok(ToolResult {
161                    success: false,
162                    output: String::new(),
163                    error: Some(msg),
164                });
165            }
166        };
167
168        // ── Call fal.ai ────────────────────────────────────────────
169        let client = Self::http_client();
170        let url = format!("https://fal.run/{model}");
171
172        let body = json!({
173            "prompt": prompt,
174            "image_size": size,
175            "num_images": 1
176        });
177
178        let resp = client
179            .post(&url)
180            .header("Authorization", format!("Key {api_key}"))
181            .header("Content-Type", "application/json")
182            .json(&body)
183            .send()
184            .await
185            .context("fal.ai request failed")?;
186
187        let status = resp.status();
188        if !status.is_success() {
189            let body_text = resp.text().await.unwrap_or_default();
190            return Ok(ToolResult {
191                success: false,
192                output: String::new(),
193                error: Some(format!("fal.ai API error ({status}): {body_text}")),
194            });
195        }
196
197        let resp_json: serde_json::Value = resp
198            .json()
199            .await
200            .context("Failed to parse fal.ai response as JSON")?;
201
202        let image_url = resp_json
203            .pointer("/images/0/url")
204            .and_then(|v| v.as_str())
205            .ok_or_else(|| {
206                ::zeroclaw_log::record!(
207                    ERROR,
208                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
209                        .with_outcome(::zeroclaw_log::EventOutcome::Failure),
210                    "image_gen: fal.ai response missing image URL"
211                );
212                anyhow::Error::msg("No image URL in fal.ai response")
213            })?;
214
215        // ── Download image ─────────────────────────────────────────
216        let img_resp = client
217            .get(image_url)
218            .send()
219            .await
220            .context("Failed to download generated image")?;
221
222        if !img_resp.status().is_success() {
223            return Ok(ToolResult {
224                success: false,
225                output: String::new(),
226                error: Some(format!(
227                    "Failed to download image from {image_url} ({})",
228                    img_resp.status()
229                )),
230            });
231        }
232
233        let bytes = img_resp
234            .bytes()
235            .await
236            .context("Failed to read image bytes")?;
237
238        // ── Save to disk ───────────────────────────────────────────
239        let images_dir = self.workspace_dir.join("images");
240        tokio::fs::create_dir_all(&images_dir)
241            .await
242            .context("Failed to create images directory")?;
243
244        let output_path = images_dir.join(format!("{safe_name}.png"));
245        tokio::fs::write(&output_path, &bytes)
246            .await
247            .context("Failed to write image file")?;
248
249        let size_kb = bytes.len() / 1024;
250
251        Ok(ToolResult {
252            success: true,
253            output: format!(
254                "Image generated successfully.\n\
255                 File: {}\n\
256                 Size: {} KB\n\
257                 Model: {}\n\
258                 Prompt: {}",
259                output_path.display(),
260                size_kb,
261                model,
262                prompt,
263            ),
264            error: None,
265        })
266    }
267}
268
269#[async_trait]
270impl Tool for ImageGenTool {
271    fn name(&self) -> &str {
272        "image_gen"
273    }
274
275    fn description(&self) -> &str {
276        "Generate an image from a text prompt using fal.ai (Flux models). \
277         Saves the result to the workspace images directory and returns the file path."
278    }
279
280    fn parameters_schema(&self) -> serde_json::Value {
281        json!({
282            "type": "object",
283            "required": ["prompt"],
284            "properties": {
285                "prompt": {
286                    "type": "string",
287                    "description": "Text prompt describing the image to generate."
288                },
289                "filename": {
290                    "type": "string",
291                    "description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
292                },
293                "size": {
294                    "type": "string",
295                    "enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
296                    "description": "Image aspect ratio / size preset (default: 'square_hd')."
297                },
298                "model": {
299                    "type": "string",
300                    "description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
301                }
302            }
303        })
304    }
305
306    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
307        // Security: image generation is a side-effecting action (HTTP + file write).
308        if let Err(error) = self
309            .security
310            .enforce_tool_operation(ToolOperation::Act, "image_gen")
311        {
312            return Ok(ToolResult {
313                success: false,
314                output: String::new(),
315                error: Some(error),
316            });
317        }
318
319        let mut result = self.generate(args).await?;
320        // A generated image saved to an ephemeral workspace never reaches the
321        // host and is lost at session end; warn loudly on success (issue #4627).
322        if !self.persistent_writes && result.success {
323            result.output = with_ephemeral_workspace_warning(&result.output);
324        }
325        Ok(result)
326    }
327}
328
// Unit tests for the image generation tool. None of them is expected to reach
// the network: each either inspects metadata or fails validation first.
//
// NOTE(review): several tests mutate the process environment through
// `std::env::set_var` / `remove_var`. Each uses a variable name that appears
// in no other test of this module, but the default test harness runs tests on
// multiple threads — confirm that no other thread accesses the environment
// concurrently, or serialise these tests.
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::autonomy::AutonomyLevel;
    use zeroclaw_config::policy::SecurityPolicy;

    /// Security policy with full autonomy, rooted at the system temp directory.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    /// Tool using the temp directory as workspace and `FAL_API_KEY` as key variable.
    fn test_tool() -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        )
    }

    #[test]
    fn tool_name() {
        let tool = test_tool();
        assert_eq!(tool.name(), "image_gen");
    }

    #[test]
    fn tool_description_is_nonempty() {
        let tool = test_tool();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("image"));
    }

    #[test]
    fn tool_schema_has_required_prompt() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["prompt"]));
        assert!(schema["properties"]["prompt"].is_object());
    }

    #[test]
    fn tool_schema_has_optional_params() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["filename"].is_object());
        assert!(schema["properties"]["size"].is_object());
        assert!(schema["properties"]["model"].is_object());
    }

    #[test]
    fn tool_spec_roundtrip() {
        let tool = test_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "image_gen");
        assert!(spec.parameters.is_object());
    }

    #[tokio::test]
    async fn missing_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn empty_prompt_returns_error() {
        // A whitespace-only prompt counts as missing.
        let tool = test_tool();
        let result = tool.execute(json!({"prompt": "   "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn missing_api_key_returns_error() {
        // Temporarily ensure the env var is unset.
        let original = std::env::var("FAL_API_KEY_TEST_IMAGE_GEN").ok();
        // SAFETY: test-only; this variable name is used by no other test in
        // this module (see the review note at the top of the module).
        unsafe { std::env::remove_var("FAL_API_KEY_TEST_IMAGE_GEN") };

        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_IMAGE_GEN".into(),
        );
        let result = tool
            .execute(json!({"prompt": "a sunset over the ocean"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap()
                .contains("FAL_API_KEY_TEST_IMAGE_GEN")
        );

        // Restore if it was set.
        if let Some(val) = original {
            // SAFETY: test-only; this variable name is used by no other test
            // in this module.
            unsafe { std::env::set_var("FAL_API_KEY_TEST_IMAGE_GEN", val) };
        }
    }

    #[tokio::test]
    async fn invalid_size_returns_error() {
        // A dummy key is set defensively; the size check is expected to fail
        // before the key is read.
        // SAFETY: test-only; this variable name is used by no other test in
        // this module.
        unsafe { std::env::set_var("FAL_API_KEY_TEST_SIZE", "dummy_key") };

        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_SIZE".into(),
        );
        let result = tool
            .execute(json!({"prompt": "test", "size": "invalid_size"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Invalid size"));

        // SAFETY: test-only; this variable name is used by no other test in
        // this module.
        unsafe { std::env::remove_var("FAL_API_KEY_TEST_SIZE") };
    }

    #[tokio::test]
    async fn read_only_autonomy_blocks_execution() {
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = ImageGenTool::new(
            security,
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        );
        let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(
            err.contains("read-only") || err.contains("image_gen"),
            "expected read-only or image_gen in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn invalid_model_with_traversal_returns_error() {
        // SAFETY: test-only; this variable name is used by no other test in
        // this module.
        unsafe { std::env::set_var("FAL_API_KEY_TEST_MODEL", "dummy_key") };

        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_MODEL".into(),
        );
        let result = tool
            .execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap()
                .contains("Invalid model identifier")
        );

        // SAFETY: test-only; this variable name is used by no other test in
        // this module.
        unsafe { std::env::remove_var("FAL_API_KEY_TEST_MODEL") };
    }

    #[test]
    fn read_api_key_missing() {
        let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .contains("DEFINITELY_NOT_SET_ZC_TEST_12345")
        );
    }

    #[test]
    fn filename_traversal_is_sanitized() {
        // Verify that path traversal in filenames is stripped to just the final component.
        let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "passwd");

        // ".." alone has no file_name, falls back to default.
        let sanitized = PathBuf::from("..").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "generated_image");
    }

    #[test]
    fn read_api_key_present() {
        // SAFETY: test-only; this variable name is used by no other test in
        // this module.
        unsafe { std::env::set_var("ZC_IMAGE_GEN_TEST_KEY", "test_value_123") };
        let result = ImageGenTool::read_api_key("ZC_IMAGE_GEN_TEST_KEY");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "test_value_123");
        // SAFETY: test-only; this variable name is used by no other test in
        // this module.
        unsafe { std::env::remove_var("ZC_IMAGE_GEN_TEST_KEY") };
    }
}