Skip to main content

zeroclaw_tools/
llm_task.rs

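//! Lightweight LLM task tool for one-shot, tool-free sub-calls.
//!
//! Runs a single prompt through an LLM model_provider with no tool access and,
//! when the caller supplies a JSON Schema, performs a lightweight check of the
//! response against it (required fields and basic type constraints — not full
//! JSON Schema evaluation). Without a schema the raw text response is returned.
//! Ideal for structured data extraction in workflows.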
6
7use async_trait::async_trait;
8use serde_json::json;
9use std::sync::Arc;
10use zeroclaw_api::model_provider::ModelProvider;
11use zeroclaw_api::tool::{Tool, ToolResult};
12use zeroclaw_config::policy::SecurityPolicy;
13use zeroclaw_config::policy::ToolOperation;
14
/// Tool that runs a single prompt through an LLM and optionally validates
/// the response against a JSON Schema. No tools are provided to the LLM —
/// this is a pure text-in, text-out (or JSON-out) call.
///
/// All fields are fixed at construction time; per-call overrides for model
/// and temperature arrive through the arguments passed to `execute`.
pub struct LlmTaskTool {
    /// Security policy consulted at the start of every `execute` call.
    security: Arc<SecurityPolicy>,
    /// Default model_provider name from root config (e.g. "openrouter").
    /// This is the provider used for every call; there is no per-call override.
    default_model_provider: String,
    /// Default model from root config, used when the call has no `model` argument.
    default_model: String,
    /// Default temperature from root config, used when the call has no
    /// `temperature` argument.
    default_temperature: f64,
    /// API key for model_provider authentication; `None` is passed through
    /// to the provider factory unchanged.
    api_key: Option<String>,
    /// ModelProvider runtime options inherited from root config.
    provider_runtime_options: zeroclaw_providers::ModelProviderRuntimeOptions,
}
31
impl LlmTaskTool {
    /// Creates a tool from the security policy and the root-config defaults.
    ///
    /// Every argument is stored as-is; no validation is performed and no
    /// model_provider is created here — that happens on each `execute` call.
    pub fn new(
        security: Arc<SecurityPolicy>,
        default_model_provider: String,
        default_model: String,
        default_temperature: f64,
        api_key: Option<String>,
        provider_runtime_options: zeroclaw_providers::ModelProviderRuntimeOptions,
    ) -> Self {
        Self {
            security,
            default_model_provider,
            default_model,
            default_temperature,
            api_key,
            provider_runtime_options,
        }
    }
}
51
52#[async_trait]
53impl Tool for LlmTaskTool {
54    fn name(&self) -> &str {
55        "llm_task"
56    }
57
58    fn description(&self) -> &str {
59        "Run a prompt through an LLM with no tool access and return the response. \
60         Optionally validates the output against a JSON Schema. Ideal for structured \
61         data extraction, classification, summarization, and transformation tasks."
62    }
63
64    fn parameters_schema(&self) -> serde_json::Value {
65        json!({
66            "type": "object",
67            "properties": {
68                "prompt": {
69                    "type": "string",
70                    "description": "The prompt to send to the LLM."
71                },
72                "schema": {
73                    "type": "object",
74                    "description": "Optional JSON Schema to validate the LLM response against. \
75                                    When provided, the LLM is instructed to return valid JSON \
76                                    matching this schema."
77                },
78                "model": {
79                    "type": "string",
80                    "description": "Optional model override (e.g. 'anthropic/claude-sonnet-4-6'). \
81                                    Defaults to the configured default model."
82                },
83                "temperature": {
84                    "type": "number",
85                    "description": "Optional temperature override (0.0-2.0). \
86                                    Defaults to the configured default temperature."
87                }
88            },
89            "required": ["prompt"]
90        })
91    }
92
93    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
94        // Security gate
95        if let Err(error) = self
96            .security
97            .enforce_tool_operation(ToolOperation::Act, "llm_task")
98        {
99            return Ok(ToolResult {
100                success: false,
101                output: String::new(),
102                error: Some(error),
103            });
104        }
105
106        // Extract required prompt
107        let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
108            Some(p) if !p.trim().is_empty() => p,
109            _ => {
110                return Ok(ToolResult {
111                    success: false,
112                    output: String::new(),
113                    error: Some("Missing or empty required parameter: prompt".to_string()),
114                });
115            }
116        };
117
118        // Extract optional overrides
119        let schema = args.get("schema").and_then(|v| v.as_object());
120        let model = args
121            .get("model")
122            .and_then(|v| v.as_str())
123            .unwrap_or(&self.default_model);
124        let temperature = args
125            .get("temperature")
126            .and_then(|v| v.as_f64())
127            .unwrap_or(self.default_temperature);
128
129        // Build the effective prompt, adding JSON schema instructions when needed
130        let effective_prompt = if let Some(schema_obj) = schema {
131            let schema_json =
132                serde_json::to_string_pretty(&serde_json::Value::Object(schema_obj.clone()))
133                    .unwrap_or_else(|_| "{}".to_string());
134            format!(
135                "{prompt}\n\n\
136                 IMPORTANT: You MUST respond with valid JSON that conforms to this schema:\n\
137                 ```json\n{schema_json}\n```\n\
138                 Respond ONLY with the JSON object, no explanation or markdown."
139            )
140        } else {
141            prompt.to_string()
142        };
143
144        // Create model_provider
145        let api_key_ref = self.api_key.as_deref();
146        let model_provider: Box<dyn ModelProvider> =
147            match zeroclaw_providers::create_model_provider_with_options(
148                &self.default_model_provider,
149                api_key_ref,
150                &self.provider_runtime_options,
151            ) {
152                Ok(p) => p,
153                Err(e) => {
154                    return Ok(ToolResult {
155                        success: false,
156                        output: String::new(),
157                        error: Some(format!("Failed to create model_provider: {e}")),
158                    });
159                }
160            };
161
162        // Make the LLM call (no tools, no agent loop). `temperature` is
163        // already resolved to an f64 (tool arg → config default), so wrap
164        // it back into Some for the model_provider trait's Option<f64> contract.
165        let response = match model_provider
166            .simple_chat(&effective_prompt, model, Some(temperature))
167            .await
168        {
169            Ok(text) => text,
170            Err(e) => {
171                return Ok(ToolResult {
172                    success: false,
173                    output: String::new(),
174                    error: Some(format!("LLM call failed: {e}")),
175                });
176            }
177        };
178
179        // If schema was provided, validate the response
180        if let Some(schema_obj) = schema {
181            let schema_value = serde_json::Value::Object(schema_obj.clone());
182            match validate_json_response(&response, &schema_value) {
183                Ok(validated_json) => Ok(ToolResult {
184                    success: true,
185                    output: validated_json,
186                    error: None,
187                }),
188                Err(validation_error) => Ok(ToolResult {
189                    success: false,
190                    output: response,
191                    error: Some(format!("Schema validation failed: {validation_error}")),
192                }),
193            }
194        } else {
195            Ok(ToolResult {
196                success: true,
197                output: response,
198                error: None,
199            })
200        }
201    }
202}
203
204/// Validate a JSON response string against a JSON Schema value.
205///
206/// Performs lightweight validation: parses the response as JSON, checks that
207/// required fields exist, and verifies basic type constraints (string, number,
208/// integer, boolean, array, object) for each declared property.
209fn validate_json_response(response: &str, schema: &serde_json::Value) -> Result<String, String> {
210    // Strip markdown code fences if the LLM wrapped the response
211    let trimmed = response.trim();
212    let json_str = if trimmed.starts_with("```") {
213        trimmed
214            .trim_start_matches("```json")
215            .trim_start_matches("```")
216            .trim_end_matches("```")
217            .trim()
218    } else {
219        trimmed
220    };
221
222    // Parse as JSON
223    let parsed: serde_json::Value =
224        serde_json::from_str(json_str).map_err(|e| format!("Invalid JSON: {e}"))?;
225
226    // Check required fields
227    if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
228        for req in required {
229            if let Some(field_name) = req.as_str()
230                && parsed.get(field_name).is_none()
231            {
232                return Err(format!("Missing required field: {field_name}"));
233            }
234        }
235    }
236
237    // Check property types
238    if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
239        for (prop_name, prop_schema) in properties {
240            if let Some(value) = parsed.get(prop_name)
241                && let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str())
242                && !type_matches(value, expected_type)
243            {
244                return Err(format!(
245                    "Field '{prop_name}' has wrong type: expected {expected_type}, \
246                             got {}",
247                    json_type_name(value)
248                ));
249            }
250        }
251    }
252
253    // Return the cleaned, re-serialized JSON
254    serde_json::to_string(&parsed).map_err(|e| format!("JSON serialization error: {e}"))
255}
256
257/// Check whether a JSON value matches an expected JSON Schema type string.
258fn type_matches(value: &serde_json::Value, expected: &str) -> bool {
259    match expected {
260        "string" => value.is_string(),
261        "number" => value.is_number(),
262        "integer" => value.is_i64() || value.is_u64(),
263        "boolean" => value.is_boolean(),
264        "array" => value.is_array(),
265        "object" => value.is_object(),
266        "null" => value.is_null(),
267        _ => true, // Unknown type — accept
268    }
269}
270
271/// Return a human-readable type name for a JSON value.
272fn json_type_name(value: &serde_json::Value) -> &'static str {
273    match value {
274        serde_json::Value::Null => "null",
275        serde_json::Value::Bool(_) => "boolean",
276        serde_json::Value::Number(_) => "number",
277        serde_json::Value::String(_) => "string",
278        serde_json::Value::Array(_) => "array",
279        serde_json::Value::Object(_) => "object",
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool with the default security policy and fixed test defaults,
    /// varying only the model_provider name.
    fn tool_with_provider(provider: &str) -> LlmTaskTool {
        LlmTaskTool::new(
            Arc::new(SecurityPolicy::default()),
            provider.to_string(),
            "test-model".to_string(),
            0.7,
            None,
            zeroclaw_providers::ModelProviderRuntimeOptions::default(),
        )
    }

    // ── Schema validation tests ──────────────────────────────────────

    #[test]
    fn validate_valid_json_against_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });

        let validated = validate_json_response(r#"{"name": "Alice", "age": 30}"#, &schema)
            .expect("response should satisfy the schema");

        let parsed: serde_json::Value =
            serde_json::from_str(&validated).expect("validator should emit valid JSON");
        assert_eq!(parsed["name"], "Alice");
        assert_eq!(parsed["age"], 30);
    }

    #[test]
    fn validate_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "title": { "type": "string" },
                "score": { "type": "number" }
            },
            "required": ["title", "score"]
        });

        let error = validate_json_response(r#"{"title": "Test"}"#, &schema)
            .expect_err("a missing required field should be rejected");
        assert!(error.contains("Missing required field: score"));
    }

    #[test]
    fn validate_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });

        let error = validate_json_response(r#"{"count": "not_a_number"}"#, &schema)
            .expect_err("a string should not satisfy an integer property");
        assert!(error.contains("wrong type"));
    }

    #[test]
    fn validate_strips_markdown_code_fences() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            },
            "required": ["result"]
        });

        let fenced = "```json\n{\"result\": \"ok\"}\n```";
        assert!(validate_json_response(fenced, &schema).is_ok());
    }

    #[test]
    fn validate_invalid_json() {
        let schema = json!({ "type": "object" });

        let error = validate_json_response("this is not json at all", &schema)
            .expect_err("non-JSON text should be rejected");
        assert!(error.contains("Invalid JSON"));
    }

    #[test]
    fn validate_optional_fields_accepted() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "bio": { "type": "string" }
            },
            "required": ["name"]
        });

        // `bio` is declared but not required, so omitting it must pass.
        assert!(validate_json_response(r#"{"name": "Bob"}"#, &schema).is_ok());
    }

    #[test]
    fn validate_all_type_checks() {
        // string
        assert!(type_matches(&json!("hello"), "string"));
        assert!(!type_matches(&json!(42), "string"));

        // number
        assert!(type_matches(&json!(2.72), "number"));
        assert!(type_matches(&json!(42), "number"));
        assert!(!type_matches(&json!("42"), "number"));

        // integer
        assert!(type_matches(&json!(42), "integer"));
        assert!(!type_matches(&json!(2.72), "integer"));

        // boolean
        assert!(type_matches(&json!(true), "boolean"));
        assert!(!type_matches(&json!(1), "boolean"));

        // array
        assert!(type_matches(&json!([1, 2]), "array"));
        assert!(!type_matches(&json!({}), "array"));

        // object
        assert!(type_matches(&json!({}), "object"));
        assert!(!type_matches(&json!([]), "object"));

        // null
        assert!(type_matches(&json!(null), "null"));

        // Unknown types are accepted
        assert!(type_matches(&json!("anything"), "custom_type"));
    }

    // ── Tool trait tests ─────────────────────────────────────────────

    #[test]
    fn tool_metadata() {
        let tool = tool_with_provider("openrouter");

        assert_eq!(tool.name(), "llm_task");
        assert!(tool.description().contains("LLM"));

        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        for parameter in ["prompt", "schema", "model", "temperature"] {
            assert!(schema["properties"][parameter].is_object());
        }

        // `prompt` is the one and only required parameter.
        assert_eq!(schema["required"], json!(["prompt"]));
    }

    #[tokio::test]
    async fn execute_missing_prompt_returns_error() {
        let outcome = tool_with_provider("openrouter")
            .execute(json!({}))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_empty_prompt_returns_error() {
        let outcome = tool_with_provider("openrouter")
            .execute(json!({"prompt": "  "}))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_with_invalid_provider_returns_error() {
        let outcome = tool_with_provider("nonexistent_provider_xyz")
            .execute(json!({"prompt": "Hello world"}))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("model_provider"));
    }
}