1use async_trait::async_trait;
8use serde_json::json;
9use std::sync::Arc;
10use zeroclaw_api::model_provider::ModelProvider;
11use zeroclaw_api::tool::{Tool, ToolResult};
12use zeroclaw_config::policy::SecurityPolicy;
13use zeroclaw_config::policy::ToolOperation;
14
15pub struct LlmTaskTool {
19 security: Arc<SecurityPolicy>,
20 default_model_provider: String,
22 default_model: String,
24 default_temperature: Option<f64>,
27 api_key: Option<String>,
29 provider_runtime_options: zeroclaw_providers::ModelProviderRuntimeOptions,
31}
32
33impl LlmTaskTool {
34 pub fn new(
35 security: Arc<SecurityPolicy>,
36 default_model_provider: String,
37 default_model: String,
38 default_temperature: Option<f64>,
39 api_key: Option<String>,
40 provider_runtime_options: zeroclaw_providers::ModelProviderRuntimeOptions,
41 ) -> Self {
42 Self {
43 security,
44 default_model_provider,
45 default_model,
46 default_temperature,
47 api_key,
48 provider_runtime_options,
49 }
50 }
51}
52
53#[async_trait]
54impl Tool for LlmTaskTool {
55 fn name(&self) -> &str {
56 "llm_task"
57 }
58
59 fn description(&self) -> &str {
60 "Run a prompt through an LLM with no tool access and return the response. \
61 Optionally validates the output against a JSON Schema. Ideal for structured \
62 data extraction, classification, summarization, and transformation tasks."
63 }
64
65 fn parameters_schema(&self) -> serde_json::Value {
66 json!({
67 "type": "object",
68 "properties": {
69 "prompt": {
70 "type": "string",
71 "description": "The prompt to send to the LLM."
72 },
73 "schema": {
74 "type": "object",
75 "description": "Optional JSON Schema to validate the LLM response against. \
76 When provided, the LLM is instructed to return valid JSON \
77 matching this schema."
78 },
79 "model": {
80 "type": "string",
81 "description": "Optional model override (e.g. 'anthropic/claude-sonnet-4-6'). \
82 Defaults to the configured default model."
83 },
84 "temperature": {
85 "type": "number",
86 "description": "Optional temperature override (0.0-2.0). \
87 Defaults to the configured default temperature."
88 }
89 },
90 "required": ["prompt"]
91 })
92 }
93
94 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
95 if let Err(error) = self
97 .security
98 .enforce_tool_operation(ToolOperation::Act, "llm_task")
99 {
100 return Ok(ToolResult {
101 success: false,
102 output: String::new(),
103 error: Some(error),
104 });
105 }
106
107 let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
109 Some(p) if !p.trim().is_empty() => p,
110 _ => {
111 return Ok(ToolResult {
112 success: false,
113 output: String::new(),
114 error: Some("Missing or empty required parameter: prompt".to_string()),
115 });
116 }
117 };
118
119 let schema = args.get("schema").and_then(|v| v.as_object());
121 let model = args
122 .get("model")
123 .and_then(|v| v.as_str())
124 .unwrap_or(&self.default_model);
125 let temperature = args
126 .get("temperature")
127 .and_then(|v| v.as_f64())
128 .or(self.default_temperature);
129
130 let effective_prompt = if let Some(schema_obj) = schema {
132 let schema_json =
133 serde_json::to_string_pretty(&serde_json::Value::Object(schema_obj.clone()))
134 .unwrap_or_else(|_| "{}".to_string());
135 format!(
136 "{prompt}\n\n\
137 IMPORTANT: You MUST respond with valid JSON that conforms to this schema:\n\
138 ```json\n{schema_json}\n```\n\
139 Respond ONLY with the JSON object, no explanation or markdown."
140 )
141 } else {
142 prompt.to_string()
143 };
144
145 let api_key_ref = self.api_key.as_deref();
147 let model_provider: Box<dyn ModelProvider> =
148 match zeroclaw_providers::create_model_provider_with_options(
149 &self.default_model_provider,
150 api_key_ref,
151 &self.provider_runtime_options,
152 ) {
153 Ok(p) => p,
154 Err(e) => {
155 return Ok(ToolResult {
156 success: false,
157 output: String::new(),
158 error: Some(format!("Failed to create model_provider: {e}")),
159 });
160 }
161 };
162
163 let response = match model_provider
167 .simple_chat(&effective_prompt, model, temperature)
168 .await
169 {
170 Ok(text) => text,
171 Err(e) => {
172 return Ok(ToolResult {
173 success: false,
174 output: String::new(),
175 error: Some(format!("LLM call failed: {e}")),
176 });
177 }
178 };
179
180 if let Some(schema_obj) = schema {
182 let schema_value = serde_json::Value::Object(schema_obj.clone());
183 match validate_json_response(&response, &schema_value) {
184 Ok(validated_json) => Ok(ToolResult {
185 success: true,
186 output: validated_json,
187 error: None,
188 }),
189 Err(validation_error) => Ok(ToolResult {
190 success: false,
191 output: response,
192 error: Some(format!("Schema validation failed: {validation_error}")),
193 }),
194 }
195 } else {
196 Ok(ToolResult {
197 success: true,
198 output: response,
199 error: None,
200 })
201 }
202 }
203}
204
205fn validate_json_response(response: &str, schema: &serde_json::Value) -> Result<String, String> {
211 let trimmed = response.trim();
213 let json_str = if trimmed.starts_with("```") {
214 trimmed
215 .trim_start_matches("```json")
216 .trim_start_matches("```")
217 .trim_end_matches("```")
218 .trim()
219 } else {
220 trimmed
221 };
222
223 let parsed: serde_json::Value =
225 serde_json::from_str(json_str).map_err(|e| format!("Invalid JSON: {e}"))?;
226
227 if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
229 for req in required {
230 if let Some(field_name) = req.as_str()
231 && parsed.get(field_name).is_none()
232 {
233 return Err(format!("Missing required field: {field_name}"));
234 }
235 }
236 }
237
238 if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
240 for (prop_name, prop_schema) in properties {
241 if let Some(value) = parsed.get(prop_name)
242 && let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str())
243 && !type_matches(value, expected_type)
244 {
245 return Err(format!(
246 "Field '{prop_name}' has wrong type: expected {expected_type}, \
247 got {}",
248 json_type_name(value)
249 ));
250 }
251 }
252 }
253
254 serde_json::to_string(&parsed).map_err(|e| format!("JSON serialization error: {e}"))
256}
257
258fn type_matches(value: &serde_json::Value, expected: &str) -> bool {
260 match expected {
261 "string" => value.is_string(),
262 "number" => value.is_number(),
263 "integer" => value.is_i64() || value.is_u64(),
264 "boolean" => value.is_boolean(),
265 "array" => value.is_array(),
266 "object" => value.is_object(),
267 "null" => value.is_null(),
268 _ => true, }
270}
271
272fn json_type_name(value: &serde_json::Value) -> &'static str {
274 match value {
275 serde_json::Value::Null => "null",
276 serde_json::Value::Bool(_) => "boolean",
277 serde_json::Value::Number(_) => "number",
278 serde_json::Value::String(_) => "string",
279 serde_json::Value::Array(_) => "array",
280 serde_json::Value::Object(_) => "object",
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool with default security and runtime options for `provider`.
    fn make_tool(provider: &str) -> LlmTaskTool {
        LlmTaskTool::new(
            Arc::new(SecurityPolicy::default()),
            provider.to_string(),
            "test-model".to_string(),
            Some(0.7),
            None,
            zeroclaw_providers::ModelProviderRuntimeOptions::default(),
        )
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_err_contains(result: Result<String, String>, needle: &str) {
        match result {
            Ok(json) => panic!("expected an error containing {needle:?}, got Ok({json})"),
            Err(message) => assert!(
                message.contains(needle),
                "{message:?} does not contain {needle:?}"
            ),
        }
    }

    #[test]
    fn validate_valid_json_against_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });

        let output = validate_json_response(r#"{"name": "Alice", "age": 30}"#, &schema)
            .expect("a conforming response should validate");

        let reparsed: serde_json::Value = serde_json::from_str(&output).unwrap();
        assert_eq!(reparsed["name"], "Alice");
        assert_eq!(reparsed["age"], 30);
    }

    #[test]
    fn validate_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "title": { "type": "string" },
                "score": { "type": "number" }
            },
            "required": ["title", "score"]
        });

        assert_err_contains(
            validate_json_response(r#"{"title": "Test"}"#, &schema),
            "Missing required field: score",
        );
    }

    #[test]
    fn validate_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });

        assert_err_contains(
            validate_json_response(r#"{"count": "not_a_number"}"#, &schema),
            "wrong type",
        );
    }

    #[test]
    fn validate_strips_markdown_code_fences() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            },
            "required": ["result"]
        });

        let fenced = "```json\n{\"result\": \"ok\"}\n```";
        assert!(validate_json_response(fenced, &schema).is_ok());
    }

    #[test]
    fn validate_invalid_json() {
        let schema = json!({ "type": "object" });
        assert_err_contains(
            validate_json_response("this is not json at all", &schema),
            "Invalid JSON",
        );
    }

    #[test]
    fn validate_optional_fields_accepted() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "bio": { "type": "string" }
            },
            "required": ["name"]
        });

        // "bio" is declared but not required, so omitting it is fine.
        assert!(validate_json_response(r#"{"name": "Bob"}"#, &schema).is_ok());
    }

    #[test]
    fn validate_all_type_checks() {
        let cases = [
            (json!("hello"), "string", true),
            (json!(42), "string", false),
            (json!(2.72), "number", true),
            (json!(42), "number", true),
            (json!("42"), "number", false),
            (json!(42), "integer", true),
            (json!(2.72), "integer", false),
            (json!(true), "boolean", true),
            (json!(1), "boolean", false),
            (json!([1, 2]), "array", true),
            (json!({}), "array", false),
            (json!({}), "object", true),
            (json!([]), "object", false),
            (json!(null), "null", true),
            // Unknown type names are accepted.
            (json!("anything"), "custom_type", true),
        ];

        for (value, type_name, expected) in &cases {
            assert_eq!(
                type_matches(value, type_name),
                *expected,
                "type_matches({value}, {type_name:?})"
            );
        }
    }

    #[test]
    fn tool_metadata() {
        let tool = make_tool("openrouter");

        assert_eq!(tool.name(), "llm_task");
        assert!(tool.description().contains("LLM"));

        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        for property in ["prompt", "schema", "model", "temperature"] {
            assert!(
                schema["properties"][property].is_object(),
                "missing property {property}"
            );
        }

        let required = schema["required"].as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0], "prompt");
    }

    #[tokio::test]
    async fn execute_missing_prompt_returns_error() {
        let result = make_tool("openrouter").execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_empty_prompt_returns_error() {
        let result = make_tool("openrouter")
            .execute(json!({"prompt": " "}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_with_invalid_provider_returns_error() {
        let result = make_tool("nonexistent_provider_xyz")
            .execute(json!({"prompt": "Hello world"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("model_provider"));
    }
}