Skip to main content

zeroclaw_tools/
llm_task.rs

1//! Lightweight LLM task tool for structured JSON-only sub-calls.
2//!
3//! Runs a single prompt through an LLM model_provider with no tool access and
4//! optionally validates the response against a caller-supplied JSON Schema.
5//! Ideal for structured data extraction in workflows.
6
7use async_trait::async_trait;
8use serde_json::json;
9use std::sync::Arc;
10use zeroclaw_api::model_provider::ModelProvider;
11use zeroclaw_api::tool::{Tool, ToolResult};
12use zeroclaw_config::policy::SecurityPolicy;
13use zeroclaw_config::policy::ToolOperation;
14
15/// Tool that runs a single prompt through an LLM and optionally validates
16/// the response against a JSON Schema. No tools are provided to the LLM —
17/// this is a pure text-in, text-out (or JSON-out) call.
18pub struct LlmTaskTool {
19    security: Arc<SecurityPolicy>,
20    /// Default model_provider name from root config (e.g. "openrouter").
21    default_model_provider: String,
22    /// Default model from root config.
23    default_model: String,
24    /// Default temperature from root config. `None` means no temperature
25    /// is sent on the wire; provider applies its own default.
26    default_temperature: Option<f64>,
27    /// API key for model_provider authentication.
28    api_key: Option<String>,
29    /// ModelProvider runtime options inherited from root config.
30    provider_runtime_options: zeroclaw_providers::ModelProviderRuntimeOptions,
31}
32
33impl LlmTaskTool {
34    pub fn new(
35        security: Arc<SecurityPolicy>,
36        default_model_provider: String,
37        default_model: String,
38        default_temperature: Option<f64>,
39        api_key: Option<String>,
40        provider_runtime_options: zeroclaw_providers::ModelProviderRuntimeOptions,
41    ) -> Self {
42        Self {
43            security,
44            default_model_provider,
45            default_model,
46            default_temperature,
47            api_key,
48            provider_runtime_options,
49        }
50    }
51}
52
53#[async_trait]
54impl Tool for LlmTaskTool {
55    fn name(&self) -> &str {
56        "llm_task"
57    }
58
59    fn description(&self) -> &str {
60        "Run a prompt through an LLM with no tool access and return the response. \
61         Optionally validates the output against a JSON Schema. Ideal for structured \
62         data extraction, classification, summarization, and transformation tasks."
63    }
64
65    fn parameters_schema(&self) -> serde_json::Value {
66        json!({
67            "type": "object",
68            "properties": {
69                "prompt": {
70                    "type": "string",
71                    "description": "The prompt to send to the LLM."
72                },
73                "schema": {
74                    "type": "object",
75                    "description": "Optional JSON Schema to validate the LLM response against. \
76                                    When provided, the LLM is instructed to return valid JSON \
77                                    matching this schema."
78                },
79                "model": {
80                    "type": "string",
81                    "description": "Optional model override (e.g. 'anthropic/claude-sonnet-4-6'). \
82                                    Defaults to the configured default model."
83                },
84                "temperature": {
85                    "type": "number",
86                    "description": "Optional temperature override (0.0-2.0). \
87                                    Defaults to the configured default temperature."
88                }
89            },
90            "required": ["prompt"]
91        })
92    }
93
94    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
95        // Security gate
96        if let Err(error) = self
97            .security
98            .enforce_tool_operation(ToolOperation::Act, "llm_task")
99        {
100            return Ok(ToolResult {
101                success: false,
102                output: String::new(),
103                error: Some(error),
104            });
105        }
106
107        // Extract required prompt
108        let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
109            Some(p) if !p.trim().is_empty() => p,
110            _ => {
111                return Ok(ToolResult {
112                    success: false,
113                    output: String::new(),
114                    error: Some("Missing or empty required parameter: prompt".to_string()),
115                });
116            }
117        };
118
119        // Extract optional overrides
120        let schema = args.get("schema").and_then(|v| v.as_object());
121        let model = args
122            .get("model")
123            .and_then(|v| v.as_str())
124            .unwrap_or(&self.default_model);
125        let temperature = args
126            .get("temperature")
127            .and_then(|v| v.as_f64())
128            .or(self.default_temperature);
129
130        // Build the effective prompt, adding JSON schema instructions when needed
131        let effective_prompt = if let Some(schema_obj) = schema {
132            let schema_json =
133                serde_json::to_string_pretty(&serde_json::Value::Object(schema_obj.clone()))
134                    .unwrap_or_else(|_| "{}".to_string());
135            format!(
136                "{prompt}\n\n\
137                 IMPORTANT: You MUST respond with valid JSON that conforms to this schema:\n\
138                 ```json\n{schema_json}\n```\n\
139                 Respond ONLY with the JSON object, no explanation or markdown."
140            )
141        } else {
142            prompt.to_string()
143        };
144
145        // Create model_provider
146        let api_key_ref = self.api_key.as_deref();
147        let model_provider: Box<dyn ModelProvider> =
148            match zeroclaw_providers::create_model_provider_with_options(
149                &self.default_model_provider,
150                api_key_ref,
151                &self.provider_runtime_options,
152            ) {
153                Ok(p) => p,
154                Err(e) => {
155                    return Ok(ToolResult {
156                        success: false,
157                        output: String::new(),
158                        error: Some(format!("Failed to create model_provider: {e}")),
159                    });
160                }
161            };
162
163        // Make the LLM call (no tools, no agent loop). `temperature` is
164        // already Option<f64>; pass straight through. None omits the field
165        // on the wire so the provider applies its own default.
166        let response = match model_provider
167            .simple_chat(&effective_prompt, model, temperature)
168            .await
169        {
170            Ok(text) => text,
171            Err(e) => {
172                return Ok(ToolResult {
173                    success: false,
174                    output: String::new(),
175                    error: Some(format!("LLM call failed: {e}")),
176                });
177            }
178        };
179
180        // If schema was provided, validate the response
181        if let Some(schema_obj) = schema {
182            let schema_value = serde_json::Value::Object(schema_obj.clone());
183            match validate_json_response(&response, &schema_value) {
184                Ok(validated_json) => Ok(ToolResult {
185                    success: true,
186                    output: validated_json,
187                    error: None,
188                }),
189                Err(validation_error) => Ok(ToolResult {
190                    success: false,
191                    output: response,
192                    error: Some(format!("Schema validation failed: {validation_error}")),
193                }),
194            }
195        } else {
196            Ok(ToolResult {
197                success: true,
198                output: response,
199                error: None,
200            })
201        }
202    }
203}
204
205/// Validate a JSON response string against a JSON Schema value.
206///
207/// Performs lightweight validation: parses the response as JSON, checks that
208/// required fields exist, and verifies basic type constraints (string, number,
209/// integer, boolean, array, object) for each declared property.
210fn validate_json_response(response: &str, schema: &serde_json::Value) -> Result<String, String> {
211    // Strip markdown code fences if the LLM wrapped the response
212    let trimmed = response.trim();
213    let json_str = if trimmed.starts_with("```") {
214        trimmed
215            .trim_start_matches("```json")
216            .trim_start_matches("```")
217            .trim_end_matches("```")
218            .trim()
219    } else {
220        trimmed
221    };
222
223    // Parse as JSON
224    let parsed: serde_json::Value =
225        serde_json::from_str(json_str).map_err(|e| format!("Invalid JSON: {e}"))?;
226
227    // Check required fields
228    if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
229        for req in required {
230            if let Some(field_name) = req.as_str()
231                && parsed.get(field_name).is_none()
232            {
233                return Err(format!("Missing required field: {field_name}"));
234            }
235        }
236    }
237
238    // Check property types
239    if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
240        for (prop_name, prop_schema) in properties {
241            if let Some(value) = parsed.get(prop_name)
242                && let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str())
243                && !type_matches(value, expected_type)
244            {
245                return Err(format!(
246                    "Field '{prop_name}' has wrong type: expected {expected_type}, \
247                             got {}",
248                    json_type_name(value)
249                ));
250            }
251        }
252    }
253
254    // Return the cleaned, re-serialized JSON
255    serde_json::to_string(&parsed).map_err(|e| format!("JSON serialization error: {e}"))
256}
257
258/// Check whether a JSON value matches an expected JSON Schema type string.
259fn type_matches(value: &serde_json::Value, expected: &str) -> bool {
260    match expected {
261        "string" => value.is_string(),
262        "number" => value.is_number(),
263        "integer" => value.is_i64() || value.is_u64(),
264        "boolean" => value.is_boolean(),
265        "array" => value.is_array(),
266        "object" => value.is_object(),
267        "null" => value.is_null(),
268        _ => true, // Unknown type — accept
269    }
270}
271
272/// Return a human-readable type name for a JSON value.
273fn json_type_name(value: &serde_json::Value) -> &'static str {
274    match value {
275        serde_json::Value::Null => "null",
276        serde_json::Value::Bool(_) => "boolean",
277        serde_json::Value::Number(_) => "number",
278        serde_json::Value::String(_) => "string",
279        serde_json::Value::Array(_) => "array",
280        serde_json::Value::Object(_) => "object",
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tool bound to `provider` with the fixed settings shared by all
    /// trait-level tests (model "test-model", temperature 0.7, no API key).
    fn tool_for_provider(provider: &str) -> LlmTaskTool {
        LlmTaskTool::new(
            Arc::new(SecurityPolicy::default()),
            provider.to_string(),
            "test-model".to_string(),
            Some(0.7),
            None,
            zeroclaw_providers::ModelProviderRuntimeOptions::default(),
        )
    }

    // ── Schema validation tests ──────────────────────────────────────

    /// A response satisfying both `required` and the property types passes
    /// and round-trips its values.
    #[test]
    fn validate_valid_json_against_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });

        let validated = validate_json_response(r#"{"name": "Alice", "age": 30}"#, &schema)
            .expect("response satisfies the schema");

        let parsed: serde_json::Value =
            serde_json::from_str(&validated).expect("validator output is JSON");
        assert_eq!(parsed["name"], "Alice");
        assert_eq!(parsed["age"], 30);
    }

    /// A required field absent from the response is reported by name.
    #[test]
    fn validate_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "title": { "type": "string" },
                "score": { "type": "number" }
            },
            "required": ["title", "score"]
        });

        let error = validate_json_response(r#"{"title": "Test"}"#, &schema)
            .expect_err("score is missing");
        assert!(error.contains("Missing required field: score"));
    }

    /// A string where an integer is declared is a type error.
    #[test]
    fn validate_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });

        let error = validate_json_response(r#"{"count": "not_a_number"}"#, &schema)
            .expect_err("count is a string");
        assert!(error.contains("wrong type"));
    }

    /// A response wrapped in a markdown json fence is unwrapped first.
    #[test]
    fn validate_strips_markdown_code_fences() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            },
            "required": ["result"]
        });

        let fenced = "```json\n{\"result\": \"ok\"}\n```";
        assert!(validate_json_response(fenced, &schema).is_ok());
    }

    /// Text that is not JSON is rejected with an "Invalid JSON" message.
    #[test]
    fn validate_invalid_json() {
        let schema = json!({ "type": "object" });

        let error = validate_json_response("this is not json at all", &schema)
            .expect_err("input is not JSON");
        assert!(error.contains("Invalid JSON"));
    }

    /// Declared but non-required properties may be omitted.
    #[test]
    fn validate_optional_fields_accepted() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "bio": { "type": "string" }
            },
            "required": ["name"]
        });

        // bio is optional, so this should pass
        assert!(validate_json_response(r#"{"name": "Bob"}"#, &schema).is_ok());
    }

    /// Table of (value, type name, expected verdict) for `type_matches`.
    #[test]
    fn validate_all_type_checks() {
        let cases = [
            (json!("hello"), "string", true),
            (json!(42), "string", false),
            (json!(2.72), "number", true),
            (json!(42), "number", true),
            (json!("42"), "number", false),
            (json!(42), "integer", true),
            (json!(2.72), "integer", false),
            (json!(true), "boolean", true),
            (json!(1), "boolean", false),
            (json!([1, 2]), "array", true),
            (json!({}), "array", false),
            (json!({}), "object", true),
            (json!([]), "object", false),
            (json!(null), "null", true),
            // Unknown types are accepted
            (json!("anything"), "custom_type", true),
        ];

        for (value, type_name, expected) in &cases {
            assert_eq!(
                type_matches(value, type_name),
                *expected,
                "value {value} checked against type {type_name}"
            );
        }
    }

    // ── Tool trait tests ─────────────────────────────────────────────

    /// Name, description and parameter schema expose the expected shape.
    #[test]
    fn tool_metadata() {
        let tool = tool_for_provider("openrouter");

        assert_eq!(tool.name(), "llm_task");
        assert!(tool.description().contains("LLM"));

        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        for property in ["prompt", "schema", "model", "temperature"] {
            assert!(
                schema["properties"][property].is_object(),
                "property {property} should be declared"
            );
        }

        let required = schema["required"].as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0], "prompt");
    }

    /// Omitting `prompt` yields a failed result that mentions the parameter.
    #[tokio::test]
    async fn execute_missing_prompt_returns_error() {
        let tool = tool_for_provider("openrouter");

        let outcome = tool.execute(json!({})).await.unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("prompt"));
    }

    /// A whitespace-only `prompt` is treated the same as a missing one.
    #[tokio::test]
    async fn execute_empty_prompt_returns_error() {
        let tool = tool_for_provider("openrouter");

        let outcome = tool.execute(json!({"prompt": "  "})).await.unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("prompt"));
    }

    /// An unknown provider name surfaces as a failed result, not a panic.
    #[tokio::test]
    async fn execute_with_invalid_provider_returns_error() {
        let tool = tool_for_provider("nonexistent_provider_xyz");

        let outcome = tool
            .execute(json!({"prompt": "Hello world"}))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("model_provider"));
    }
}