Skip to main content

zeroclaw_tools/
image_gen.rs

1use anyhow::Context;
2use async_trait::async_trait;
3use serde_json::json;
4use std::path::PathBuf;
5use std::sync::Arc;
6use zeroclaw_api::tool::{Tool, ToolResult};
7use zeroclaw_config::policy::SecurityPolicy;
8use zeroclaw_config::policy::ToolOperation;
9
10/// Standalone image generation tool using fal.ai (Flux / Nano Banana models).
11///
12/// Reads the API key from an environment variable (default: `FAL_API_KEY`),
13/// calls the fal.ai synchronous endpoint, downloads the resulting image,
14/// and saves it to `{workspace}/images/{filename}.png`.
15pub struct ImageGenTool {
16    security: Arc<SecurityPolicy>,
17    workspace_dir: PathBuf,
18    default_model: String,
19    api_key_env: String,
20}
21
22impl ImageGenTool {
23    pub fn new(
24        security: Arc<SecurityPolicy>,
25        workspace_dir: PathBuf,
26        default_model: String,
27        api_key_env: String,
28    ) -> Self {
29        Self {
30            security,
31            workspace_dir,
32            default_model,
33            api_key_env,
34        }
35    }
36
37    /// Build a reusable HTTP client with reasonable timeouts.
38    fn http_client() -> reqwest::Client {
39        reqwest::Client::builder()
40            .timeout(std::time::Duration::from_secs(120))
41            .build()
42            .unwrap_or_default()
43    }
44
45    /// Read an API key from the environment.
46    fn read_api_key(env_var: &str) -> Result<String, String> {
47        std::env::var(env_var)
48            .map(|v| v.trim().to_string())
49            .ok()
50            .filter(|v| !v.is_empty())
51            .ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
52    }
53
54    /// Core generation logic: call fal.ai, download image, save to disk.
55    async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
56        // ── Parse parameters ───────────────────────────────────────
57        let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
58            Some(p) if !p.trim().is_empty() => p.trim().to_string(),
59            _ => {
60                return Ok(ToolResult {
61                    success: false,
62                    output: String::new(),
63                    error: Some("Missing required parameter: 'prompt'".into()),
64                });
65            }
66        };
67
68        let filename = args
69            .get("filename")
70            .and_then(|v| v.as_str())
71            .filter(|s| !s.trim().is_empty())
72            .unwrap_or("generated_image");
73
74        // Sanitize filename — strip path components to prevent traversal.
75        let safe_name = PathBuf::from(filename).file_name().map_or_else(
76            || "generated_image".to_string(),
77            |n| n.to_string_lossy().to_string(),
78        );
79
80        let size = args
81            .get("size")
82            .and_then(|v| v.as_str())
83            .unwrap_or("square_hd");
84
85        // Validate size enum.
86        const VALID_SIZES: &[&str] = &[
87            "square_hd",
88            "landscape_4_3",
89            "portrait_4_3",
90            "landscape_16_9",
91            "portrait_16_9",
92        ];
93        if !VALID_SIZES.contains(&size) {
94            return Ok(ToolResult {
95                success: false,
96                output: String::new(),
97                error: Some(format!(
98                    "Invalid size '{size}'. Valid values: {}",
99                    VALID_SIZES.join(", ")
100                )),
101            });
102        }
103
104        let model = args
105            .get("model")
106            .and_then(|v| v.as_str())
107            .filter(|s| !s.trim().is_empty())
108            .unwrap_or(&self.default_model);
109
110        // Validate model identifier: must look like a fal.ai model path
111        // (e.g. "fal-ai/flux/schnell"). Reject values with "..", query
112        // strings, or fragments that could redirect the HTTP request.
113        if model.contains("..")
114            || model.contains('?')
115            || model.contains('#')
116            || model.contains('\\')
117            || model.starts_with('/')
118        {
119            return Ok(ToolResult {
120                success: false,
121                output: String::new(),
122                error: Some(format!(
123                    "Invalid model identifier '{model}'. \
124                     Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
125                )),
126            });
127        }
128
129        // ── Read API key ───────────────────────────────────────────
130        let api_key = match Self::read_api_key(&self.api_key_env) {
131            Ok(k) => k,
132            Err(msg) => {
133                return Ok(ToolResult {
134                    success: false,
135                    output: String::new(),
136                    error: Some(msg),
137                });
138            }
139        };
140
141        // ── Call fal.ai ────────────────────────────────────────────
142        let client = Self::http_client();
143        let url = format!("https://fal.run/{model}");
144
145        let body = json!({
146            "prompt": prompt,
147            "image_size": size,
148            "num_images": 1
149        });
150
151        let resp = client
152            .post(&url)
153            .header("Authorization", format!("Key {api_key}"))
154            .header("Content-Type", "application/json")
155            .json(&body)
156            .send()
157            .await
158            .context("fal.ai request failed")?;
159
160        let status = resp.status();
161        if !status.is_success() {
162            let body_text = resp.text().await.unwrap_or_default();
163            return Ok(ToolResult {
164                success: false,
165                output: String::new(),
166                error: Some(format!("fal.ai API error ({status}): {body_text}")),
167            });
168        }
169
170        let resp_json: serde_json::Value = resp
171            .json()
172            .await
173            .context("Failed to parse fal.ai response as JSON")?;
174
175        let image_url = resp_json
176            .pointer("/images/0/url")
177            .and_then(|v| v.as_str())
178            .ok_or_else(|| {
179                ::zeroclaw_log::record!(
180                    ERROR,
181                    ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
182                        .with_outcome(::zeroclaw_log::EventOutcome::Failure),
183                    "image_gen: fal.ai response missing image URL"
184                );
185                anyhow::Error::msg("No image URL in fal.ai response")
186            })?;
187
188        // ── Download image ─────────────────────────────────────────
189        let img_resp = client
190            .get(image_url)
191            .send()
192            .await
193            .context("Failed to download generated image")?;
194
195        if !img_resp.status().is_success() {
196            return Ok(ToolResult {
197                success: false,
198                output: String::new(),
199                error: Some(format!(
200                    "Failed to download image from {image_url} ({})",
201                    img_resp.status()
202                )),
203            });
204        }
205
206        let bytes = img_resp
207            .bytes()
208            .await
209            .context("Failed to read image bytes")?;
210
211        // ── Save to disk ───────────────────────────────────────────
212        let images_dir = self.workspace_dir.join("images");
213        tokio::fs::create_dir_all(&images_dir)
214            .await
215            .context("Failed to create images directory")?;
216
217        let output_path = images_dir.join(format!("{safe_name}.png"));
218        tokio::fs::write(&output_path, &bytes)
219            .await
220            .context("Failed to write image file")?;
221
222        let size_kb = bytes.len() / 1024;
223
224        Ok(ToolResult {
225            success: true,
226            output: format!(
227                "Image generated successfully.\n\
228                 File: {}\n\
229                 Size: {} KB\n\
230                 Model: {}\n\
231                 Prompt: {}",
232                output_path.display(),
233                size_kb,
234                model,
235                prompt,
236            ),
237            error: None,
238        })
239    }
240}
241
242#[async_trait]
243impl Tool for ImageGenTool {
244    fn name(&self) -> &str {
245        "image_gen"
246    }
247
248    fn description(&self) -> &str {
249        "Generate an image from a text prompt using fal.ai (Flux models). \
250         Saves the result to the workspace images directory and returns the file path."
251    }
252
253    fn parameters_schema(&self) -> serde_json::Value {
254        json!({
255            "type": "object",
256            "required": ["prompt"],
257            "properties": {
258                "prompt": {
259                    "type": "string",
260                    "description": "Text prompt describing the image to generate."
261                },
262                "filename": {
263                    "type": "string",
264                    "description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
265                },
266                "size": {
267                    "type": "string",
268                    "enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
269                    "description": "Image aspect ratio / size preset (default: 'square_hd')."
270                },
271                "model": {
272                    "type": "string",
273                    "description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
274                }
275            }
276        })
277    }
278
279    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
280        // Security: image generation is a side-effecting action (HTTP + file write).
281        if let Err(error) = self
282            .security
283            .enforce_tool_operation(ToolOperation::Act, "image_gen")
284        {
285            return Ok(ToolResult {
286                success: false,
287                output: String::new(),
288                error: Some(error),
289            });
290        }
291
292        self.generate(args).await
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::autonomy::AutonomyLevel;
    use zeroclaw_config::policy::SecurityPolicy;

    /// Environment variable name that no test (and presumably no CI
    /// environment) sets. Tools built with it always fail the API-key lookup,
    /// without the test having to mutate the process environment.
    const UNSET_KEY_VAR: &str = "ZC_IMAGE_GEN_TEST_UNSET_API_KEY";

    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    fn test_tool() -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        )
    }

    /// Tool whose API-key lookup is guaranteed to fail.
    fn tool_without_key() -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            UNSET_KEY_VAR.into(),
        )
    }

    #[test]
    fn tool_name() {
        let tool = test_tool();
        assert_eq!(tool.name(), "image_gen");
    }

    #[test]
    fn tool_description_is_nonempty() {
        let tool = test_tool();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("image"));
    }

    #[test]
    fn tool_schema_has_required_prompt() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["prompt"]));
        assert!(schema["properties"]["prompt"].is_object());
    }

    #[test]
    fn tool_schema_has_optional_params() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["filename"].is_object());
        assert!(schema["properties"]["size"].is_object());
        assert!(schema["properties"]["model"].is_object());
    }

    #[test]
    fn tool_spec_roundtrip() {
        let tool = test_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "image_gen");
        assert!(spec.parameters.is_object());
    }

    #[tokio::test]
    async fn missing_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn empty_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"prompt": "   "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn missing_api_key_returns_error() {
        // Use a variable name nothing sets, rather than removing/restoring a
        // real variable. `cargo test` runs tests on multiple threads, where
        // `set_var`/`remove_var` are not guaranteed sound.
        let tool = tool_without_key();
        let result = tool
            .execute(json!({"prompt": "a sunset over the ocean"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains(UNSET_KEY_VAR));
    }

    #[tokio::test]
    async fn invalid_size_returns_error() {
        // Size is validated before the API key is read, so no key is needed.
        // Using a keyless tool also makes a future reordering visible: the
        // error would then mention the missing key instead.
        let tool = tool_without_key();
        let result = tool
            .execute(json!({"prompt": "test", "size": "invalid_size"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Invalid size"));
    }

    #[tokio::test]
    async fn read_only_autonomy_blocks_execution() {
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = ImageGenTool::new(
            security,
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        );
        let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(
            err.contains("read-only") || err.contains("image_gen"),
            "expected read-only or image_gen in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn invalid_model_with_traversal_returns_error() {
        // Model is validated before the API key is read, so no key is needed.
        let tool = tool_without_key();
        let result = tool
            .execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap()
                .contains("Invalid model identifier")
        );
    }

    #[test]
    fn read_api_key_missing() {
        let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .contains("DEFINITELY_NOT_SET_ZC_TEST_12345")
        );
    }

    #[test]
    fn filename_traversal_is_sanitized() {
        // Verify that path traversal in filenames is stripped to just the final component.
        let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "passwd");

        // ".." alone has no file_name, falls back to default.
        let sanitized = PathBuf::from("..").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "generated_image");
    }

    #[test]
    fn read_api_key_present_is_trimmed() {
        const VAR: &str = "ZC_IMAGE_GEN_TEST_KEY";
        // SAFETY: NOT strictly sound on non-Windows targets, because the
        // default `cargo test` harness is multi-threaded. The risk is
        // confined by:
        // - the variable name being unique to this test;
        // - this being the only test in the module that mutates the
        //   environment.
        // Move this test to a single-threaded harness if that ever changes.
        unsafe { std::env::set_var(VAR, "  test_value_123  ") };
        let result = ImageGenTool::read_api_key(VAR);
        // SAFETY: see the note on `set_var` above.
        unsafe { std::env::remove_var(VAR) };

        assert_eq!(result.as_deref(), Ok("test_value_123"));
    }
}