Skip to main content

zeroclaw_tools/
pipeline.rs

// Pipeline tool: collapses multi-step tool chains into a single inference call.
//
// The agent invokes `execute_pipeline` with a JSON payload describing steps.
// By default this tool runs them sequentially, interpolating earlier results
// into later steps' arguments; with `parallel: true` it runs them concurrently
// without interpolation.
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashSet;
11use std::sync::Arc;
12use zeroclaw_api::tool::{Tool, ToolResult};
13use zeroclaw_config::schema::PipelineConfig;
14
15/// Errors specific to pipeline execution.
16#[derive(Debug, Clone, Serialize, thiserror::Error)]
17pub enum PipelineError {
18    #[error("Unknown tool '{0}' is not on the allowed list")]
19    UnknownTool(String),
20    #[error("Pipeline exceeds maximum of {0} steps")]
21    TooManySteps(usize),
22    #[error("Invalid template reference: {0}")]
23    InvalidTemplate(String),
24    #[error("Step {index} ({tool}) failed: {message}")]
25    StepFailed {
26        index: usize,
27        tool: String,
28        message: String,
29    },
30}
31
32/// A single step in a pipeline.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct PipelineStep {
35    pub tool: String,
36    pub args: serde_json::Value,
37}
38
39/// The pipeline request payload.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct PipelineRequest {
42    pub steps: Vec<PipelineStep>,
43    #[serde(default)]
44    pub parallel: bool,
45    /// What to include in the tool output. Defaults to every step's result.
46    #[serde(default)]
47    pub result: PipelineResultMode,
48}
49
50/// Controls what `execute_pipeline` returns to the caller.
51#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
52#[serde(rename_all = "snake_case")]
53pub enum PipelineResultMode {
54    /// Return every step's result as a JSON array (default; backward compatible).
55    #[default]
56    All,
57    /// Return only the final step's raw output. Use when earlier steps produce
58    /// large intermediate blobs (e.g. base64) that must not flow back into the
59    /// model context.
60    Last,
61}
62
63/// Result of a single pipeline step.
64#[derive(Debug, Clone, Serialize)]
65pub struct StepResult {
66    pub index: usize,
67    pub tool: String,
68    pub success: bool,
69    pub output: String,
70}
71
72/// The execute_pipeline tool that runs multi-step tool chains.
73pub struct PipelineTool {
74    config: PipelineConfig,
75    tools: Vec<Arc<dyn Tool>>,
76    allowed_set: HashSet<String>,
77}
78
79impl PipelineTool {
80    pub fn new(config: PipelineConfig, tools: Vec<Arc<dyn Tool>>) -> Self {
81        let allowed_set: HashSet<String> = config.allowed_tools.iter().cloned().collect();
82        Self {
83            config,
84            tools,
85            allowed_set,
86        }
87    }
88
89    /// Find a tool by name in the registry.
90    fn find_tool(&self, name: &str) -> Option<&dyn Tool> {
91        self.tools
92            .iter()
93            .find(|t| t.name() == name)
94            .map(|t| t.as_ref())
95    }
96
97    /// Validate the pipeline request before execution.
98    fn validate(&self, request: &PipelineRequest) -> std::result::Result<(), PipelineError> {
99        if request.steps.len() > self.config.max_steps {
100            return Err(PipelineError::TooManySteps(self.config.max_steps));
101        }
102
103        // Check all tools are on the allowlist before executing any.
104        for step in &request.steps {
105            if !self.allowed_set.contains(&step.tool) {
106                return Err(PipelineError::UnknownTool(step.tool.clone()));
107            }
108        }
109
110        Ok(())
111    }
112
113    /// Execute steps sequentially, interpolating results.
114    async fn execute_sequential(
115        &self,
116        steps: &[PipelineStep],
117    ) -> std::result::Result<Vec<StepResult>, PipelineError> {
118        let mut results: Vec<StepResult> = Vec::with_capacity(steps.len());
119
120        for (i, step) in steps.iter().enumerate() {
121            let tool = self
122                .find_tool(&step.tool)
123                .ok_or_else(|| PipelineError::UnknownTool(step.tool.clone()))?;
124
125            // Interpolate previous step results into args.
126            let interpolated_args = interpolate_args(&step.args, &results);
127
128            let tool_result =
129                tool.execute(interpolated_args)
130                    .await
131                    .map_err(|e| PipelineError::StepFailed {
132                        index: i,
133                        tool: step.tool.clone(),
134                        message: e.to_string(),
135                    })?;
136
137            if !tool_result.success {
138                return Err(PipelineError::StepFailed {
139                    index: i,
140                    tool: step.tool.clone(),
141                    message: tool_result
142                        .error
143                        .unwrap_or_else(|| tool_result.output.clone()),
144                });
145            }
146
147            results.push(StepResult {
148                index: i,
149                tool: step.tool.clone(),
150                success: true,
151                output: tool_result.output,
152            });
153        }
154
155        Ok(results)
156    }
157
158    /// Execute independent steps in parallel (no interpolation between them).
159    async fn execute_parallel(
160        &self,
161        steps: &[PipelineStep],
162    ) -> std::result::Result<Vec<StepResult>, PipelineError> {
163        use tokio::task::JoinSet;
164
165        let mut join_set = JoinSet::new();
166
167        for (i, step) in steps.iter().enumerate() {
168            let tool = self
169                .find_tool(&step.tool)
170                .ok_or_else(|| PipelineError::UnknownTool(step.tool.clone()))?;
171
172            // Clone what we need for the spawned task.
173            let tool_name = step.tool.clone();
174            let args = step.args.clone();
175
176            // We need a reference that lives long enough — use Arc.
177            let tool_arc = self.tools.iter().find(|t| t.name() == tool.name()).cloned();
178
179            if let Some(tool_arc) = tool_arc {
180                join_set.spawn(async move {
181                    let result = tool_arc.execute(args).await;
182                    (i, tool_name, result)
183                });
184            }
185        }
186
187        let mut results: Vec<StepResult> = Vec::with_capacity(steps.len());
188
189        while let Some(join_result) = join_set.join_next().await {
190            let (index, tool_name, tool_result) =
191                join_result.map_err(|e| PipelineError::StepFailed {
192                    index: 0,
193                    tool: "unknown".to_string(),
194                    message: format!("Task join error: {e}"),
195                })?;
196
197            let tool_result = tool_result.map_err(|e| PipelineError::StepFailed {
198                index,
199                tool: tool_name.clone(),
200                message: e.to_string(),
201            })?;
202
203            if !tool_result.success {
204                return Err(PipelineError::StepFailed {
205                    index,
206                    tool: tool_name,
207                    message: tool_result
208                        .error
209                        .unwrap_or_else(|| tool_result.output.clone()),
210                });
211            }
212
213            results.push(StepResult {
214                index,
215                tool: tool_name,
216                success: true,
217                output: tool_result.output,
218            });
219        }
220
221        // Sort by index for deterministic output.
222        results.sort_by_key(|r| r.index);
223        Ok(results)
224    }
225}
226
227#[async_trait]
228impl Tool for PipelineTool {
229    fn name(&self) -> &str {
230        "execute_pipeline"
231    }
232
233    fn description(&self) -> &str {
234        "Execute a multi-step tool pipeline in a single call. Steps run sequentially by default \
235         with result interpolation (use {{step[N].result}} to reference prior outputs), \
236         or in parallel when 'parallel: true' is set. Set 'result: \"last\"' to return only the \
237         final step's output (recommended when an earlier step yields a large blob, e.g. base64, \
238         that should not flow back into the context); the default 'all' returns every step's result."
239    }
240
241    fn parameters_schema(&self) -> serde_json::Value {
242        serde_json::json!({
243            "type": "object",
244            "properties": {
245                "steps": {
246                    "type": "array",
247                    "description": "Ordered list of tool invocations",
248                    "items": {
249                        "type": "object",
250                        "properties": {
251                            "tool": {
252                                "type": "string",
253                                "description": "Name of the tool to invoke"
254                            },
255                            "args": {
256                                "type": "object",
257                                "description": "Arguments to pass to the tool. Use {{step[N].result}} to interpolate prior step outputs."
258                            }
259                        },
260                        "required": ["tool", "args"]
261                    }
262                },
263                "parallel": {
264                    "type": "boolean",
265                    "description": "Run steps in parallel (no interpolation). Default: false",
266                    "default": false
267                },
268                "result": {
269                    "type": "string",
270                    "enum": ["all", "last"],
271                    "description": "What to return: 'all' (default) = every step's result as JSON; 'last' = only the final step's raw output. Use 'last' to keep large intermediate blobs (e.g. base64) out of the context.",
272                    "default": "all"
273                }
274            },
275            "required": ["steps"]
276        })
277    }
278
279    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
280        let request: PipelineRequest = serde_json::from_value(args).map_err(|e| {
281            ::zeroclaw_log::record!(
282                WARN,
283                ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
284                    .with_outcome(::zeroclaw_log::EventOutcome::Failure)
285                    .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
286                "pipeline: invalid request"
287            );
288            anyhow::Error::msg(format!("Invalid pipeline request: {e}"))
289        })?;
290
291        // Validate before execution.
292        if let Err(e) = self.validate(&request) {
293            return Ok(ToolResult {
294                success: false,
295                output: String::new(),
296                error: Some(e.to_string()),
297            });
298        }
299
300        let results = if request.parallel {
301            self.execute_parallel(&request.steps).await
302        } else {
303            self.execute_sequential(&request.steps).await
304        };
305
306        match results {
307            Ok(step_results) => {
308                let output = match request.result {
309                    PipelineResultMode::Last => step_results
310                        .last()
311                        .map(|s| s.output.clone())
312                        .unwrap_or_default(),
313                    PipelineResultMode::All => serde_json::to_string_pretty(&step_results)
314                        .unwrap_or_else(|_| "Pipeline completed".to_string()),
315                };
316                Ok(ToolResult {
317                    success: true,
318                    output,
319                    error: None,
320                })
321            }
322            Err(e) => Ok(ToolResult {
323                success: false,
324                output: String::new(),
325                error: Some(e.to_string()),
326            }),
327        }
328    }
329}
330
331/// Interpolate `{{step[N].result}}` references in tool arguments.
332///
333/// Single-pass replacement: values containing `{{` after substitution are stripped
334/// to prevent injection.
335pub fn interpolate_args(
336    args: &serde_json::Value,
337    prior_results: &[StepResult],
338) -> serde_json::Value {
339    match args {
340        serde_json::Value::String(s) => {
341            let interpolated = interpolate_string(s, prior_results);
342            serde_json::Value::String(interpolated)
343        }
344        serde_json::Value::Object(map) => {
345            let new_map: serde_json::Map<String, serde_json::Value> = map
346                .iter()
347                .map(|(k, v)| (k.clone(), interpolate_args(v, prior_results)))
348                .collect();
349            serde_json::Value::Object(new_map)
350        }
351        serde_json::Value::Array(arr) => {
352            let new_arr: Vec<serde_json::Value> = arr
353                .iter()
354                .map(|v| interpolate_args(v, prior_results))
355                .collect();
356            serde_json::Value::Array(new_arr)
357        }
358        other => other.clone(),
359    }
360}
361
362/// Perform single-pass interpolation of `{{step[N].result}}` in a string.
363fn interpolate_string(s: &str, prior_results: &[StepResult]) -> String {
364    let mut result = String::with_capacity(s.len());
365    let mut chars = s.char_indices().peekable();
366
367    while let Some((i, c)) = chars.next() {
368        if c == '{'
369            && let Some(&(_, '{')) = chars.peek()
370        {
371            // Found `{{` — try to match `{{step[N].result}}`
372            let rest = &s[i..];
373            if let Some(end) = find_template_end(rest) {
374                let template = &rest[2..end]; // strip {{ and }}
375                if let Some(value) = resolve_template(template, prior_results) {
376                    // Strip any `{{` in the resolved value to prevent injection.
377                    result.push_str(&value.replace("{{", ""));
378                    // Skip past the closing `}}`
379                    let skip_to = i + end + 2;
380                    while chars.peek().is_some_and(|&(idx, _)| idx < skip_to) {
381                        chars.next();
382                    }
383                    continue;
384                }
385            }
386        }
387        result.push(c);
388    }
389
390    result
391}
392
/// Find the closing `}}` of a template that starts at the beginning of `s`.
///
/// `s` is expected to begin with the opening `{{`. Returns the byte offset of the first
/// `}}` after that opener, measured from the start of `s` (so the template body is
/// `s[2..end]`), or `None` if there is no closing delimiter. Inputs shorter than two bytes,
/// or whose byte 2 is not a char boundary, yield `None` instead of panicking.
fn find_template_end(s: &str) -> Option<usize> {
    s.get(2..)?.find("}}").map(|pos| pos + 2)
}
397
398/// Resolve a template reference like `step[0].result`.
399fn resolve_template(template: &str, prior_results: &[StepResult]) -> Option<String> {
400    let template = template.trim();
401    if !template.starts_with("step[") || !template.ends_with(".result") {
402        return None;
403    }
404
405    let bracket_end = template.find(']')?;
406    let index_str = &template[5..bracket_end];
407    let index: usize = index_str.parse().ok()?;
408
409    prior_results
410        .iter()
411        .find(|r| r.index == index)
412        .map(|r| r.output.clone())
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a successful `StepResult`.
    fn ok_step(index: usize, tool: &str, output: &str) -> StepResult {
        StepResult {
            index,
            tool: tool.to_owned(),
            success: true,
            output: output.to_owned(),
        }
    }

    /// Fixture: an enabled `PipelineConfig` with the given step limit and allowlist.
    fn config_with(max_steps: usize, allowed: &[&str]) -> PipelineConfig {
        PipelineConfig {
            enabled: true,
            max_steps,
            allowed_tools: allowed.iter().map(|&name| name.to_owned()).collect(),
        }
    }

    /// Fixture: a sequential, default-result-mode request calling each tool with `{}` args.
    fn request_for(tools: &[&str]) -> PipelineRequest {
        PipelineRequest {
            steps: tools
                .iter()
                .map(|&name| PipelineStep {
                    tool: name.to_owned(),
                    args: serde_json::json!({}),
                })
                .collect(),
            parallel: false,
            result: PipelineResultMode::default(),
        }
    }

    // ── Interpolation ──────────────────────────────────────

    #[test]
    fn interpolate_simple_reference() {
        let results = [ok_step(0, "web_search", "search results here")];
        let args = serde_json::json!({"text": "Summarize: {{step[0].result}}"});

        let out = interpolate_args(&args, &results);

        assert_eq!(
            out["text"].as_str().unwrap(),
            "Summarize: search results here"
        );
    }

    #[test]
    fn interpolate_multiple_references() {
        let results = [ok_step(0, "a", "first"), ok_step(1, "b", "second")];
        let args = serde_json::json!({"text": "{{step[0].result}} and {{step[1].result}}"});

        let out = interpolate_args(&args, &results);

        assert_eq!(out["text"].as_str().unwrap(), "first and second");
    }

    #[test]
    fn interpolate_no_match_passes_through() {
        let args = serde_json::json!({"text": "no templates here"});

        let out = interpolate_args(&args, &[]);

        assert_eq!(out["text"].as_str().unwrap(), "no templates here");
    }

    #[test]
    fn interpolate_invalid_index_passes_through() {
        let args = serde_json::json!({"text": "{{step[99].result}}"});

        let out = interpolate_args(&args, &[]);

        // An unresolvable reference is kept verbatim.
        assert_eq!(out["text"].as_str().unwrap(), "{{step[99].result}}");
    }

    #[test]
    fn interpolate_strips_injection() {
        let results = [ok_step(0, "a", "value with {{step[1].result}} injection")];
        let args = serde_json::json!({"text": "{{step[0].result}}"});

        let out = interpolate_args(&args, &results);
        let text = out["text"].as_str().unwrap();

        // Opening braces smuggled in through a prior result must not survive substitution.
        assert!(!text.contains("{{"));
        assert!(text.contains("step[1].result}} injection"));
    }

    #[test]
    fn interpolate_nested_objects() {
        let results = [ok_step(0, "a", "data")];
        let args = serde_json::json!({
            "outer": {
                "inner": "prefix {{step[0].result}} suffix"
            }
        });

        let out = interpolate_args(&args, &results);

        assert_eq!(
            out["outer"]["inner"].as_str().unwrap(),
            "prefix data suffix"
        );
    }

    #[test]
    fn interpolate_array_values() {
        let results = [ok_step(0, "a", "item")];
        let args = serde_json::json!(["{{step[0].result}}", "static"]);

        let out = interpolate_args(&args, &results);

        assert_eq!(out[0].as_str().unwrap(), "item");
        assert_eq!(out[1].as_str().unwrap(), "static");
    }

    // ── Validation ─────────────────────────────────────────

    #[test]
    fn validate_too_many_steps() {
        let pipeline = PipelineTool::new(config_with(2, &["shell"]), vec![]);

        let err = pipeline
            .validate(&request_for(&["shell", "shell", "shell"]))
            .unwrap_err();

        assert!(matches!(err, PipelineError::TooManySteps(2)));
    }

    #[test]
    fn validate_unknown_tool() {
        let pipeline = PipelineTool::new(config_with(20, &["shell"]), vec![]);

        let err = pipeline
            .validate(&request_for(&["forbidden_tool"]))
            .unwrap_err();

        assert!(matches!(err, PipelineError::UnknownTool(_)));
    }

    #[test]
    fn validate_valid_request() {
        let pipeline = PipelineTool::new(config_with(20, &["shell", "file_read"]), vec![]);

        assert!(pipeline
            .validate(&request_for(&["shell", "file_read"]))
            .is_ok());
    }

    #[test]
    fn validate_empty_pipeline() {
        let pipeline = PipelineTool::new(config_with(20, &[]), vec![]);

        assert!(pipeline.validate(&request_for(&[])).is_ok());
    }

    // ── Template resolution ────────────────────────────────

    #[test]
    fn resolve_valid_template() {
        let results = [ok_step(0, "a", "hello")];

        assert_eq!(
            resolve_template("step[0].result", &results),
            Some("hello".to_string())
        );
    }

    #[test]
    fn resolve_invalid_template_format() {
        for malformed in ["invalid", "step.result", "step[abc].result"] {
            assert_eq!(resolve_template(malformed, &[]), None, "{malformed}");
        }
    }

    #[test]
    fn resolve_out_of_range_index() {
        assert_eq!(resolve_template("step[5].result", &[]), None);
    }

    // ── Result mode ────────────────────────────────────────

    /// Test double that ignores its arguments and always succeeds with a fixed output.
    struct EchoTool {
        name: String,
        output: String,
    }

    zeroclaw_api::mock_tool_attribution!(EchoTool);

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            "echo"
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult> {
            Ok(ToolResult {
                success: true,
                output: self.output.clone(),
                error: None,
            })
        }
    }

    /// Pipeline with two echo tools: `a` stands in for a large intermediate blob and `b`
    /// yields the final answer.
    fn echo_pipeline() -> PipelineTool {
        let echo = |name: &str, output: &str| -> Arc<dyn Tool> {
            Arc::new(EchoTool {
                name: name.to_owned(),
                output: output.to_owned(),
            })
        };
        PipelineTool::new(
            config_with(20, &["a", "b"]),
            vec![echo("a", "FIRST_BIG_BLOB"), echo("b", "final answer")],
        )
    }

    /// Arguments for a two-step `a` → `b` pipeline, optionally with an explicit result mode.
    fn two_step_args(result_mode: Option<&str>) -> serde_json::Value {
        let mut args = serde_json::json!({
            "steps": [
                {"tool": "a", "args": {}},
                {"tool": "b", "args": {}}
            ]
        });
        if let Some(mode) = result_mode {
            args["result"] = serde_json::json!(mode);
        }
        args
    }

    #[tokio::test]
    async fn result_last_returns_only_final_output() {
        let res = echo_pipeline()
            .execute(two_step_args(Some("last")))
            .await
            .unwrap();

        assert!(res.success);
        assert_eq!(res.output, "final answer");
        assert!(!res.output.contains("FIRST_BIG_BLOB"));
    }

    #[tokio::test]
    async fn result_all_is_default_and_includes_every_step() {
        let res = echo_pipeline().execute(two_step_args(None)).await.unwrap();

        assert!(res.success);
        for expected in ["FIRST_BIG_BLOB", "final answer"] {
            assert!(res.output.contains(expected), "missing {expected}");
        }
    }
}