1use async_trait::async_trait;
8use serde_json::json;
9use std::sync::Arc;
10use zeroclaw_api::model_provider::ModelProvider;
11use zeroclaw_api::tool::{Tool, ToolResult};
12use zeroclaw_config::policy::SecurityPolicy;
13use zeroclaw_config::policy::ToolOperation;
14
15pub struct LlmTaskTool {
19 security: Arc<SecurityPolicy>,
20 default_model_provider: String,
22 default_model: String,
24 default_temperature: f64,
26 api_key: Option<String>,
28 provider_runtime_options: zeroclaw_providers::ModelProviderRuntimeOptions,
30}
31
32impl LlmTaskTool {
33 pub fn new(
34 security: Arc<SecurityPolicy>,
35 default_model_provider: String,
36 default_model: String,
37 default_temperature: f64,
38 api_key: Option<String>,
39 provider_runtime_options: zeroclaw_providers::ModelProviderRuntimeOptions,
40 ) -> Self {
41 Self {
42 security,
43 default_model_provider,
44 default_model,
45 default_temperature,
46 api_key,
47 provider_runtime_options,
48 }
49 }
50}
51
52#[async_trait]
53impl Tool for LlmTaskTool {
54 fn name(&self) -> &str {
55 "llm_task"
56 }
57
58 fn description(&self) -> &str {
59 "Run a prompt through an LLM with no tool access and return the response. \
60 Optionally validates the output against a JSON Schema. Ideal for structured \
61 data extraction, classification, summarization, and transformation tasks."
62 }
63
64 fn parameters_schema(&self) -> serde_json::Value {
65 json!({
66 "type": "object",
67 "properties": {
68 "prompt": {
69 "type": "string",
70 "description": "The prompt to send to the LLM."
71 },
72 "schema": {
73 "type": "object",
74 "description": "Optional JSON Schema to validate the LLM response against. \
75 When provided, the LLM is instructed to return valid JSON \
76 matching this schema."
77 },
78 "model": {
79 "type": "string",
80 "description": "Optional model override (e.g. 'anthropic/claude-sonnet-4-6'). \
81 Defaults to the configured default model."
82 },
83 "temperature": {
84 "type": "number",
85 "description": "Optional temperature override (0.0-2.0). \
86 Defaults to the configured default temperature."
87 }
88 },
89 "required": ["prompt"]
90 })
91 }
92
93 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
94 if let Err(error) = self
96 .security
97 .enforce_tool_operation(ToolOperation::Act, "llm_task")
98 {
99 return Ok(ToolResult {
100 success: false,
101 output: String::new(),
102 error: Some(error),
103 });
104 }
105
106 let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
108 Some(p) if !p.trim().is_empty() => p,
109 _ => {
110 return Ok(ToolResult {
111 success: false,
112 output: String::new(),
113 error: Some("Missing or empty required parameter: prompt".to_string()),
114 });
115 }
116 };
117
118 let schema = args.get("schema").and_then(|v| v.as_object());
120 let model = args
121 .get("model")
122 .and_then(|v| v.as_str())
123 .unwrap_or(&self.default_model);
124 let temperature = args
125 .get("temperature")
126 .and_then(|v| v.as_f64())
127 .unwrap_or(self.default_temperature);
128
129 let effective_prompt = if let Some(schema_obj) = schema {
131 let schema_json =
132 serde_json::to_string_pretty(&serde_json::Value::Object(schema_obj.clone()))
133 .unwrap_or_else(|_| "{}".to_string());
134 format!(
135 "{prompt}\n\n\
136 IMPORTANT: You MUST respond with valid JSON that conforms to this schema:\n\
137 ```json\n{schema_json}\n```\n\
138 Respond ONLY with the JSON object, no explanation or markdown."
139 )
140 } else {
141 prompt.to_string()
142 };
143
144 let api_key_ref = self.api_key.as_deref();
146 let model_provider: Box<dyn ModelProvider> =
147 match zeroclaw_providers::create_model_provider_with_options(
148 &self.default_model_provider,
149 api_key_ref,
150 &self.provider_runtime_options,
151 ) {
152 Ok(p) => p,
153 Err(e) => {
154 return Ok(ToolResult {
155 success: false,
156 output: String::new(),
157 error: Some(format!("Failed to create model_provider: {e}")),
158 });
159 }
160 };
161
162 let response = match model_provider
166 .simple_chat(&effective_prompt, model, Some(temperature))
167 .await
168 {
169 Ok(text) => text,
170 Err(e) => {
171 return Ok(ToolResult {
172 success: false,
173 output: String::new(),
174 error: Some(format!("LLM call failed: {e}")),
175 });
176 }
177 };
178
179 if let Some(schema_obj) = schema {
181 let schema_value = serde_json::Value::Object(schema_obj.clone());
182 match validate_json_response(&response, &schema_value) {
183 Ok(validated_json) => Ok(ToolResult {
184 success: true,
185 output: validated_json,
186 error: None,
187 }),
188 Err(validation_error) => Ok(ToolResult {
189 success: false,
190 output: response,
191 error: Some(format!("Schema validation failed: {validation_error}")),
192 }),
193 }
194 } else {
195 Ok(ToolResult {
196 success: true,
197 output: response,
198 error: None,
199 })
200 }
201 }
202}
203
204fn validate_json_response(response: &str, schema: &serde_json::Value) -> Result<String, String> {
210 let trimmed = response.trim();
212 let json_str = if trimmed.starts_with("```") {
213 trimmed
214 .trim_start_matches("```json")
215 .trim_start_matches("```")
216 .trim_end_matches("```")
217 .trim()
218 } else {
219 trimmed
220 };
221
222 let parsed: serde_json::Value =
224 serde_json::from_str(json_str).map_err(|e| format!("Invalid JSON: {e}"))?;
225
226 if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
228 for req in required {
229 if let Some(field_name) = req.as_str()
230 && parsed.get(field_name).is_none()
231 {
232 return Err(format!("Missing required field: {field_name}"));
233 }
234 }
235 }
236
237 if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
239 for (prop_name, prop_schema) in properties {
240 if let Some(value) = parsed.get(prop_name)
241 && let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str())
242 && !type_matches(value, expected_type)
243 {
244 return Err(format!(
245 "Field '{prop_name}' has wrong type: expected {expected_type}, \
246 got {}",
247 json_type_name(value)
248 ));
249 }
250 }
251 }
252
253 serde_json::to_string(&parsed).map_err(|e| format!("JSON serialization error: {e}"))
255}
256
257fn type_matches(value: &serde_json::Value, expected: &str) -> bool {
259 match expected {
260 "string" => value.is_string(),
261 "number" => value.is_number(),
262 "integer" => value.is_i64() || value.is_u64(),
263 "boolean" => value.is_boolean(),
264 "array" => value.is_array(),
265 "object" => value.is_object(),
266 "null" => value.is_null(),
267 _ => true, }
269}
270
271fn json_type_name(value: &serde_json::Value) -> &'static str {
273 match value {
274 serde_json::Value::Null => "null",
275 serde_json::Value::Bool(_) => "boolean",
276 serde_json::Value::Number(_) => "number",
277 serde_json::Value::String(_) => "string",
278 serde_json::Value::Array(_) => "array",
279 serde_json::Value::Object(_) => "object",
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool for `model_provider` with `SecurityPolicy::default()`, model `test-model`,
    /// temperature 0.7, no API key and default runtime options.
    fn tool_for(model_provider: &str) -> LlmTaskTool {
        LlmTaskTool::new(
            Arc::new(SecurityPolicy::default()),
            model_provider.to_string(),
            "test-model".to_string(),
            0.7,
            None,
            zeroclaw_providers::ModelProviderRuntimeOptions::default(),
        )
    }

    #[test]
    fn validate_valid_json_against_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });

        let validated = validate_json_response(r#"{"name": "Alice", "age": 30}"#, &schema)
            .expect("a conforming response must validate");

        let parsed: serde_json::Value = serde_json::from_str(&validated).unwrap();
        assert_eq!(parsed["name"], "Alice");
        assert_eq!(parsed["age"], 30);
    }

    #[test]
    fn validate_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "title": { "type": "string" },
                "score": { "type": "number" }
            },
            "required": ["title", "score"]
        });

        let error = validate_json_response(r#"{"title": "Test"}"#, &schema).unwrap_err();
        assert!(error.contains("Missing required field: score"));
    }

    #[test]
    fn validate_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });

        let error = validate_json_response(r#"{"count": "not_a_number"}"#, &schema).unwrap_err();
        assert!(error.contains("wrong type"));
    }

    #[test]
    fn validate_strips_markdown_code_fences() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            },
            "required": ["result"]
        });

        let fenced = "```json\n{\"result\": \"ok\"}\n```";
        assert!(validate_json_response(fenced, &schema).is_ok());
    }

    #[test]
    fn validate_invalid_json() {
        let schema = json!({ "type": "object" });

        let error = validate_json_response("this is not json at all", &schema).unwrap_err();
        assert!(error.contains("Invalid JSON"));
    }

    #[test]
    fn validate_optional_fields_accepted() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "bio": { "type": "string" }
            },
            "required": ["name"]
        });

        // `bio` is declared but not required, so omitting it must be accepted.
        assert!(validate_json_response(r#"{"name": "Bob"}"#, &schema).is_ok());
    }

    #[test]
    fn validate_all_type_checks() {
        let cases = [
            (json!("hello"), "string", true),
            (json!(42), "string", false),
            (json!(2.72), "number", true),
            (json!(42), "number", true),
            (json!("42"), "number", false),
            (json!(42), "integer", true),
            (json!(2.72), "integer", false),
            (json!(true), "boolean", true),
            (json!(1), "boolean", false),
            (json!([1, 2]), "array", true),
            (json!({}), "array", false),
            (json!({}), "object", true),
            (json!([]), "object", false),
            (json!(null), "null", true),
            // Unknown type names are permissive.
            (json!("anything"), "custom_type", true),
        ];

        for (value, type_name, expected) in &cases {
            assert_eq!(
                type_matches(value, type_name),
                *expected,
                "type_matches({value}, {type_name:?})"
            );
        }
    }

    #[test]
    fn tool_metadata() {
        let tool = tool_for("openrouter");

        assert_eq!(tool.name(), "llm_task");
        assert!(tool.description().contains("LLM"));

        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        for property in ["prompt", "schema", "model", "temperature"] {
            assert!(
                schema["properties"][property].is_object(),
                "missing property {property}"
            );
        }

        let required = schema["required"].as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0], "prompt");
    }

    #[tokio::test]
    async fn execute_missing_prompt_returns_error() {
        let tool = tool_for("openrouter");

        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_empty_prompt_returns_error() {
        let tool = tool_for("openrouter");

        let result = tool.execute(json!({"prompt": " "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_with_invalid_provider_returns_error() {
        let tool = tool_for("nonexistent_provider_xyz");

        let result = tool
            .execute(json!({"prompt": "Hello world"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("model_provider"));
    }
}