1use anyhow::Context;
2use async_trait::async_trait;
3use serde_json::json;
4use std::path::PathBuf;
5use std::sync::Arc;
6use zeroclaw_api::tool::{Tool, ToolResult, with_ephemeral_workspace_warning};
7use zeroclaw_config::policy::SecurityPolicy;
8use zeroclaw_config::policy::ToolOperation;
9
10pub struct ImageGenTool {
16 security: Arc<SecurityPolicy>,
17 workspace_dir: PathBuf,
18 default_model: String,
19 api_key_env: String,
20 persistent_writes: bool,
27}
28
29impl ImageGenTool {
30 pub fn new(
31 security: Arc<SecurityPolicy>,
32 workspace_dir: PathBuf,
33 default_model: String,
34 api_key_env: String,
35 ) -> Self {
36 Self {
37 security,
38 workspace_dir,
39 default_model,
40 api_key_env,
41 persistent_writes: true,
42 }
43 }
44
45 pub fn new_with_persistence(
49 security: Arc<SecurityPolicy>,
50 workspace_dir: PathBuf,
51 default_model: String,
52 api_key_env: String,
53 persistent_writes: bool,
54 ) -> Self {
55 Self {
56 security,
57 workspace_dir,
58 default_model,
59 api_key_env,
60 persistent_writes,
61 }
62 }
63
64 fn http_client() -> reqwest::Client {
66 reqwest::Client::builder()
67 .timeout(std::time::Duration::from_secs(120))
68 .build()
69 .unwrap_or_default()
70 }
71
72 fn read_api_key(env_var: &str) -> Result<String, String> {
74 std::env::var(env_var)
75 .map(|v| v.trim().to_string())
76 .ok()
77 .filter(|v| !v.is_empty())
78 .ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
79 }
80
81 async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
83 let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
85 Some(p) if !p.trim().is_empty() => p.trim().to_string(),
86 _ => {
87 return Ok(ToolResult {
88 success: false,
89 output: String::new(),
90 error: Some("Missing required parameter: 'prompt'".into()),
91 });
92 }
93 };
94
95 let filename = args
96 .get("filename")
97 .and_then(|v| v.as_str())
98 .filter(|s| !s.trim().is_empty())
99 .unwrap_or("generated_image");
100
101 let safe_name = PathBuf::from(filename).file_name().map_or_else(
103 || "generated_image".to_string(),
104 |n| n.to_string_lossy().to_string(),
105 );
106
107 let size = args
108 .get("size")
109 .and_then(|v| v.as_str())
110 .unwrap_or("square_hd");
111
112 const VALID_SIZES: &[&str] = &[
114 "square_hd",
115 "landscape_4_3",
116 "portrait_4_3",
117 "landscape_16_9",
118 "portrait_16_9",
119 ];
120 if !VALID_SIZES.contains(&size) {
121 return Ok(ToolResult {
122 success: false,
123 output: String::new(),
124 error: Some(format!(
125 "Invalid size '{size}'. Valid values: {}",
126 VALID_SIZES.join(", ")
127 )),
128 });
129 }
130
131 let model = args
132 .get("model")
133 .and_then(|v| v.as_str())
134 .filter(|s| !s.trim().is_empty())
135 .unwrap_or(&self.default_model);
136
137 if model.contains("..")
141 || model.contains('?')
142 || model.contains('#')
143 || model.contains('\\')
144 || model.starts_with('/')
145 {
146 return Ok(ToolResult {
147 success: false,
148 output: String::new(),
149 error: Some(format!(
150 "Invalid model identifier '{model}'. \
151 Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
152 )),
153 });
154 }
155
156 let api_key = match Self::read_api_key(&self.api_key_env) {
158 Ok(k) => k,
159 Err(msg) => {
160 return Ok(ToolResult {
161 success: false,
162 output: String::new(),
163 error: Some(msg),
164 });
165 }
166 };
167
168 let client = Self::http_client();
170 let url = format!("https://fal.run/{model}");
171
172 let body = json!({
173 "prompt": prompt,
174 "image_size": size,
175 "num_images": 1
176 });
177
178 let resp = client
179 .post(&url)
180 .header("Authorization", format!("Key {api_key}"))
181 .header("Content-Type", "application/json")
182 .json(&body)
183 .send()
184 .await
185 .context("fal.ai request failed")?;
186
187 let status = resp.status();
188 if !status.is_success() {
189 let body_text = resp.text().await.unwrap_or_default();
190 return Ok(ToolResult {
191 success: false,
192 output: String::new(),
193 error: Some(format!("fal.ai API error ({status}): {body_text}")),
194 });
195 }
196
197 let resp_json: serde_json::Value = resp
198 .json()
199 .await
200 .context("Failed to parse fal.ai response as JSON")?;
201
202 let image_url = resp_json
203 .pointer("/images/0/url")
204 .and_then(|v| v.as_str())
205 .ok_or_else(|| {
206 ::zeroclaw_log::record!(
207 ERROR,
208 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Fail)
209 .with_outcome(::zeroclaw_log::EventOutcome::Failure),
210 "image_gen: fal.ai response missing image URL"
211 );
212 anyhow::Error::msg("No image URL in fal.ai response")
213 })?;
214
215 let img_resp = client
217 .get(image_url)
218 .send()
219 .await
220 .context("Failed to download generated image")?;
221
222 if !img_resp.status().is_success() {
223 return Ok(ToolResult {
224 success: false,
225 output: String::new(),
226 error: Some(format!(
227 "Failed to download image from {image_url} ({})",
228 img_resp.status()
229 )),
230 });
231 }
232
233 let bytes = img_resp
234 .bytes()
235 .await
236 .context("Failed to read image bytes")?;
237
238 let images_dir = self.workspace_dir.join("images");
240 tokio::fs::create_dir_all(&images_dir)
241 .await
242 .context("Failed to create images directory")?;
243
244 let output_path = images_dir.join(format!("{safe_name}.png"));
245 tokio::fs::write(&output_path, &bytes)
246 .await
247 .context("Failed to write image file")?;
248
249 let size_kb = bytes.len() / 1024;
250
251 Ok(ToolResult {
252 success: true,
253 output: format!(
254 "Image generated successfully.\n\
255 File: {}\n\
256 Size: {} KB\n\
257 Model: {}\n\
258 Prompt: {}",
259 output_path.display(),
260 size_kb,
261 model,
262 prompt,
263 ),
264 error: None,
265 })
266 }
267}
268
269#[async_trait]
270impl Tool for ImageGenTool {
271 fn name(&self) -> &str {
272 "image_gen"
273 }
274
275 fn description(&self) -> &str {
276 "Generate an image from a text prompt using fal.ai (Flux models). \
277 Saves the result to the workspace images directory and returns the file path."
278 }
279
280 fn parameters_schema(&self) -> serde_json::Value {
281 json!({
282 "type": "object",
283 "required": ["prompt"],
284 "properties": {
285 "prompt": {
286 "type": "string",
287 "description": "Text prompt describing the image to generate."
288 },
289 "filename": {
290 "type": "string",
291 "description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
292 },
293 "size": {
294 "type": "string",
295 "enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
296 "description": "Image aspect ratio / size preset (default: 'square_hd')."
297 },
298 "model": {
299 "type": "string",
300 "description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
301 }
302 }
303 })
304 }
305
306 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
307 if let Err(error) = self
309 .security
310 .enforce_tool_operation(ToolOperation::Act, "image_gen")
311 {
312 return Ok(ToolResult {
313 success: false,
314 output: String::new(),
315 error: Some(error),
316 });
317 }
318
319 let mut result = self.generate(args).await?;
320 if !self.persistent_writes && result.success {
323 result.output = with_ephemeral_workspace_warning(&result.output);
324 }
325 Ok(result)
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use zeroclaw_config::autonomy::AutonomyLevel;
    use zeroclaw_config::policy::SecurityPolicy;

    /// Full-autonomy policy rooted at the temp dir. The validation tests below
    /// rely on it letting execution reach argument checking.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    /// Tool with the default Flux model, reading its key from `FAL_API_KEY`.
    fn test_tool() -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        )
    }

    #[test]
    fn tool_name() {
        let tool = test_tool();
        assert_eq!(tool.name(), "image_gen");
    }

    #[test]
    fn tool_description_is_nonempty() {
        let tool = test_tool();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("image"));
    }

    #[test]
    fn tool_schema_has_required_prompt() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["prompt"]));
        assert!(schema["properties"]["prompt"].is_object());
    }

    #[test]
    fn tool_schema_has_optional_params() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["filename"].is_object());
        assert!(schema["properties"]["size"].is_object());
        assert!(schema["properties"]["model"].is_object());
    }

    #[test]
    fn tool_spec_roundtrip() {
        let tool = test_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "image_gen");
        assert!(spec.parameters.is_object());
    }

    #[tokio::test]
    async fn missing_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn empty_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"prompt": "   "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn missing_api_key_returns_error() {
        let original = std::env::var("FAL_API_KEY_TEST_IMAGE_GEN").ok();
        // SAFETY: `remove_var` is unsafe because another thread may read the
        // environment at the same time. This variable name is used only by this
        // test. NOTE(review): the harness runs tests on parallel threads, so
        // soundness still assumes no concurrent environment reads. Confirm this,
        // or serialise env-mutating tests.
        unsafe { std::env::remove_var("FAL_API_KEY_TEST_IMAGE_GEN") };

        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_IMAGE_GEN".into(),
        );
        let result = tool
            .execute(json!({"prompt": "a sunset over the ocean"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap()
                .contains("FAL_API_KEY_TEST_IMAGE_GEN")
        );

        if let Some(val) = original {
            // SAFETY: restores the value saved above; the same caveat about
            // concurrent environment reads applies.
            unsafe { std::env::set_var("FAL_API_KEY_TEST_IMAGE_GEN", val) };
        }
    }

    #[tokio::test]
    async fn invalid_size_returns_error() {
        // No API key is set on purpose. Size validation runs before the key
        // lookup, so this needs no (unsafe) environment mutation and no network.
        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_SIZE".into(),
        );
        let result = tool
            .execute(json!({"prompt": "test", "size": "invalid_size"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Invalid size"));
    }

    #[tokio::test]
    async fn read_only_autonomy_blocks_execution() {
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = ImageGenTool::new(
            security,
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        );
        let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(
            err.contains("read-only") || err.contains("image_gen"),
            "expected read-only or image_gen in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn invalid_model_with_traversal_returns_error() {
        // No API key is set on purpose. Model validation runs before the key
        // lookup, so the traversal must be rejected without touching the
        // environment or the network.
        let tool = ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY_TEST_MODEL".into(),
        );
        let result = tool
            .execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(
            result
                .error
                .as_deref()
                .unwrap()
                .contains("Invalid model identifier")
        );
    }

    #[test]
    fn read_api_key_missing() {
        let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .contains("DEFINITELY_NOT_SET_ZC_TEST_12345")
        );
    }

    /// Pins the `Path::file_name` behaviour that the filename sanitisation is
    /// built on: directory components are dropped, and `..` has no file name.
    #[test]
    fn filename_traversal_is_sanitized() {
        let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "passwd");

        let sanitized = PathBuf::from("..").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "generated_image");
    }

    #[test]
    fn read_api_key_present() {
        // SAFETY: `set_var`/`remove_var` are unsafe because another thread may
        // read the environment concurrently. This variable name is used only by
        // this test. NOTE(review): parallel test threads may still read other
        // variables concurrently. Confirm this is acceptable, or serialise
        // env-mutating tests.
        unsafe { std::env::set_var("ZC_IMAGE_GEN_TEST_KEY", "test_value_123") };
        let result = ImageGenTool::read_api_key("ZC_IMAGE_GEN_TEST_KEY");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "test_value_123");
        // SAFETY: same justification as the `set_var` above.
        unsafe { std::env::remove_var("ZC_IMAGE_GEN_TEST_KEY") };
    }
}