1use anyhow::Context;
2use async_trait::async_trait;
3use serde_json::json;
4use std::path::PathBuf;
5use std::sync::Arc;
6use zeroclaw_api::tool::{Tool, ToolResult};
7use zeroclaw_config::policy::SecurityPolicy;
8use zeroclaw_config::policy::ToolOperation;
9
/// Tool that generates an image from a text prompt through the fal.ai HTTP API
/// and stores the downloaded result under `<workspace_dir>/images/`.
pub struct ImageGenTool {
    /// Security policy; `execute` enforces `ToolOperation::Act` for "image_gen"
    /// against it before any generation work happens.
    security: Arc<SecurityPolicy>,
    /// Workspace root; generated images are written to its `images` subdirectory.
    workspace_dir: PathBuf,
    /// fal.ai model path used when the caller omits `model` (or passes a blank one).
    default_model: String,
    /// Name of the environment variable that holds the fal.ai API key.
    api_key_env: String,
}
21
22impl ImageGenTool {
23 pub fn new(
24 security: Arc<SecurityPolicy>,
25 workspace_dir: PathBuf,
26 default_model: String,
27 api_key_env: String,
28 ) -> Self {
29 Self {
30 security,
31 workspace_dir,
32 default_model,
33 api_key_env,
34 }
35 }
36
37 fn http_client() -> reqwest::Client {
39 reqwest::Client::builder()
40 .timeout(std::time::Duration::from_secs(120))
41 .build()
42 .unwrap_or_default()
43 }
44
45 fn read_api_key(env_var: &str) -> Result<String, String> {
47 std::env::var(env_var)
48 .map(|v| v.trim().to_string())
49 .ok()
50 .filter(|v| !v.is_empty())
51 .ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
52 }
53
54 async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
56 let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
58 Some(p) if !p.trim().is_empty() => p.trim().to_string(),
59 _ => {
60 return Ok(ToolResult {
61 success: false,
62 output: String::new(),
63 error: Some("Missing required parameter: 'prompt'".into()),
64 });
65 }
66 };
67
68 let filename = args
69 .get("filename")
70 .and_then(|v| v.as_str())
71 .filter(|s| !s.trim().is_empty())
72 .unwrap_or("generated_image");
73
74 let safe_name = PathBuf::from(filename).file_name().map_or_else(
76 || "generated_image".to_string(),
77 |n| n.to_string_lossy().to_string(),
78 );
79
80 let size = args
81 .get("size")
82 .and_then(|v| v.as_str())
83 .unwrap_or("square_hd");
84
85 const VALID_SIZES: &[&str] = &[
87 "square_hd",
88 "landscape_4_3",
89 "portrait_4_3",
90 "landscape_16_9",
91 "portrait_16_9",
92 ];
93 if !VALID_SIZES.contains(&size) {
94 return Ok(ToolResult {
95 success: false,
96 output: String::new(),
97 error: Some(format!(
98 "Invalid size '{size}'. Valid values: {}",
99 VALID_SIZES.join(", ")
100 )),
101 });
102 }
103
104 let model = args
105 .get("model")
106 .and_then(|v| v.as_str())
107 .filter(|s| !s.trim().is_empty())
108 .unwrap_or(&self.default_model);
109
110 if model.contains("..")
114 || model.contains('?')
115 || model.contains('#')
116 || model.contains('\\')
117 || model.starts_with('/')
118 {
119 return Ok(ToolResult {
120 success: false,
121 output: String::new(),
122 error: Some(format!(
123 "Invalid model identifier '{model}'. \
124 Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
125 )),
126 });
127 }
128
129 let api_key = match Self::read_api_key(&self.api_key_env) {
131 Ok(k) => k,
132 Err(msg) => {
133 return Ok(ToolResult {
134 success: false,
135 output: String::new(),
136 error: Some(msg),
137 });
138 }
139 };
140
141 let client = Self::http_client();
143 let url = format!("https://fal.run/{model}");
144
145 let body = json!({
146 "prompt": prompt,
147 "image_size": size,
148 "num_images": 1
149 });
150
151 let resp = client
152 .post(&url)
153 .header("Authorization", format!("Key {api_key}"))
154 .header("Content-Type", "application/json")
155 .json(&body)
156 .send()
157 .await
158 .context("fal.ai request failed")?;
159
160 let status = resp.status();
161 if !status.is_success() {
162 let body_text = resp.text().await.unwrap_or_default();
163 return Ok(ToolResult {
164 success: false,
165 output: String::new(),
166 error: Some(format!("fal.ai API error ({status}): {body_text}")),
167 });
168 }
169
170 let resp_json: serde_json::Value = resp
171 .json()
172 .await
173 .context("Failed to parse fal.ai response as JSON")?;
174
175 let image_url = resp_json
176 .pointer("/images/0/url")
177 .and_then(|v| v.as_str())
178 .ok_or_else(|| {
179 ::zeroclaw_log::record!(
180 ERROR,
181 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
182 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
183 "image_gen: fal.ai response missing image URL"
184 );
185 anyhow::Error::msg("No image URL in fal.ai response")
186 })?;
187
188 let img_resp = client
190 .get(image_url)
191 .send()
192 .await
193 .context("Failed to download generated image")?;
194
195 if !img_resp.status().is_success() {
196 return Ok(ToolResult {
197 success: false,
198 output: String::new(),
199 error: Some(format!(
200 "Failed to download image from {image_url} ({})",
201 img_resp.status()
202 )),
203 });
204 }
205
206 let bytes = img_resp
207 .bytes()
208 .await
209 .context("Failed to read image bytes")?;
210
211 let images_dir = self.workspace_dir.join("images");
213 tokio::fs::create_dir_all(&images_dir)
214 .await
215 .context("Failed to create images directory")?;
216
217 let output_path = images_dir.join(format!("{safe_name}.png"));
218 tokio::fs::write(&output_path, &bytes)
219 .await
220 .context("Failed to write image file")?;
221
222 let size_kb = bytes.len() / 1024;
223
224 Ok(ToolResult {
225 success: true,
226 output: format!(
227 "Image generated successfully.\n\
228 File: {}\n\
229 Size: {} KB\n\
230 Model: {}\n\
231 Prompt: {}",
232 output_path.display(),
233 size_kb,
234 model,
235 prompt,
236 ),
237 error: None,
238 })
239 }
240}
241
242#[async_trait]
243impl Tool for ImageGenTool {
244 fn name(&self) -> &str {
245 "image_gen"
246 }
247
248 fn description(&self) -> &str {
249 "Generate an image from a text prompt using fal.ai (Flux models). \
250 Saves the result to the workspace images directory and returns the file path."
251 }
252
253 fn parameters_schema(&self) -> serde_json::Value {
254 json!({
255 "type": "object",
256 "required": ["prompt"],
257 "properties": {
258 "prompt": {
259 "type": "string",
260 "description": "Text prompt describing the image to generate."
261 },
262 "filename": {
263 "type": "string",
264 "description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
265 },
266 "size": {
267 "type": "string",
268 "enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
269 "description": "Image aspect ratio / size preset (default: 'square_hd')."
270 },
271 "model": {
272 "type": "string",
273 "description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
274 }
275 }
276 })
277 }
278
279 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
280 if let Err(error) = self
282 .security
283 .enforce_tool_operation(ToolOperation::Act, "image_gen")
284 {
285 return Ok(ToolResult {
286 success: false,
287 output: String::new(),
288 error: Some(error),
289 });
290 }
291
292 self.generate(args).await
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::autonomy::AutonomyLevel;
    use zeroclaw_config::policy::SecurityPolicy;

    /// Full-autonomy policy rooted in the system temp directory.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    /// Tool on the default model whose API key is read from `key_env`.
    fn tool_with_key_env(key_env: &str) -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            key_env.into(),
        )
    }

    /// Tool that reads its API key from `FAL_API_KEY`.
    fn test_tool() -> ImageGenTool {
        tool_with_key_env("FAL_API_KEY")
    }

    /// Sets a test-only environment variable.
    fn set_test_env(key: &str, value: &str) {
        // SAFETY: `set_var` is unsafe because it can race with other threads
        // reading or writing the process environment. Every caller uses a
        // variable name unique to its own test. The harness still runs tests
        // in parallel, so soundness relies on no other thread touching the
        // environment at this moment; accepted here because it is test-only.
        unsafe { std::env::set_var(key, value) };
    }

    /// Removes a test-only environment variable.
    fn remove_test_env(key: &str) {
        // SAFETY: same reasoning and caveat as in `set_test_env`.
        unsafe { std::env::remove_var(key) };
    }

    #[test]
    fn tool_name() {
        let tool = test_tool();
        assert_eq!(tool.name(), "image_gen");
    }

    #[test]
    fn tool_description_is_nonempty() {
        let tool = test_tool();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("image"));
    }

    #[test]
    fn tool_schema_has_required_prompt() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["prompt"]));
        assert!(schema["properties"]["prompt"].is_object());
    }

    #[test]
    fn tool_schema_has_optional_params() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["filename"].is_object());
        assert!(schema["properties"]["size"].is_object());
        assert!(schema["properties"]["model"].is_object());
    }

    #[test]
    fn tool_spec_roundtrip() {
        let tool = test_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "image_gen");
        assert!(spec.parameters.is_object());
    }

    #[tokio::test]
    async fn missing_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn empty_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"prompt": " "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn missing_api_key_returns_error() {
        const KEY: &str = "FAL_API_KEY_TEST_IMAGE_GEN";
        let original = std::env::var(KEY).ok();
        remove_test_env(KEY);

        let result = tool_with_key_env(KEY)
            .execute(json!({"prompt": "a sunset over the ocean"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains(KEY));

        if let Some(val) = original {
            set_test_env(KEY, &val);
        }
    }

    #[tokio::test]
    async fn invalid_size_returns_error() {
        const KEY: &str = "FAL_API_KEY_TEST_SIZE";
        set_test_env(KEY, "dummy_key");

        let result = tool_with_key_env(KEY)
            .execute(json!({"prompt": "test", "size": "invalid_size"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Invalid size"));

        remove_test_env(KEY);
    }

    #[tokio::test]
    async fn read_only_autonomy_blocks_execution() {
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = ImageGenTool::new(
            security,
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        );
        let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(
            err.contains("read-only") || err.contains("image_gen"),
            "expected read-only or image_gen in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn invalid_model_with_traversal_returns_error() {
        const KEY: &str = "FAL_API_KEY_TEST_MODEL";
        set_test_env(KEY, "dummy_key");

        let result = tool_with_key_env(KEY)
            .execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap()
                .contains("Invalid model identifier")
        );

        remove_test_env(KEY);
    }

    #[tokio::test]
    async fn invalid_model_with_query_returns_error() {
        // The key variable is never set. Model validation runs before the key
        // lookup, so no environment mutation or network access is needed.
        let tool = tool_with_key_env("FAL_API_KEY_TEST_MODEL_QUERY_UNSET");
        let result = tool
            .execute(json!({"prompt": "test", "model": "fal-ai/flux?redirect=1"}))
            .await
            .unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(err.contains("Invalid model identifier"), "got: {err}");
    }

    #[tokio::test]
    async fn invalid_model_with_leading_slash_returns_error() {
        let tool = tool_with_key_env("FAL_API_KEY_TEST_MODEL_SLASH_UNSET");
        let result = tool
            .execute(json!({"prompt": "test", "model": "/fal-ai/flux/schnell"}))
            .await
            .unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(err.contains("Invalid model identifier"), "got: {err}");
    }

    #[test]
    fn read_api_key_missing() {
        let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .contains("DEFINITELY_NOT_SET_ZC_TEST_12345")
        );
    }

    /// Mirrors the filename sanitization performed in `generate`.
    #[test]
    fn filename_traversal_is_sanitized() {
        let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "passwd");

        let sanitized = PathBuf::from("..").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "generated_image");
    }

    #[test]
    fn read_api_key_present() {
        const KEY: &str = "ZC_IMAGE_GEN_TEST_KEY";
        set_test_env(KEY, "test_value_123");
        let result = ImageGenTool::read_api_key(KEY);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "test_value_123");
        remove_test_env(KEY);
    }
}