1use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashSet;
11use std::sync::Arc;
12use zeroclaw_api::tool::{Tool, ToolResult};
13use zeroclaw_config::schema::PipelineConfig;
14
/// Reasons a pipeline request is rejected or fails while running.
///
/// The `Display` text produced by each `#[error]` attribute is what ends up in
/// the `error` field of the failed `ToolResult` returned by `execute_pipeline`.
#[derive(Debug, Clone, Serialize, thiserror::Error)]
pub enum PipelineError {
    /// A step names a tool that is not on the allow-list (raised by `validate`),
    /// or that is allowed but not registered with the pipeline (raised when the
    /// step is about to run). Both cases share this message.
    #[error("Unknown tool '{0}' is not on the allowed list")]
    UnknownTool(String),
    /// The request has more steps than the configured maximum, which is the payload.
    #[error("Pipeline exceeds maximum of {0} steps")]
    TooManySteps(usize),
    /// A malformed template reference.
    /// NOTE(review): not constructed by the code in this file — unresolvable
    /// references are left in place instead; confirm whether it is still used.
    #[error("Invalid template reference: {0}")]
    InvalidTemplate(String),
    /// A step's tool call returned an error or reported `success: false`.
    /// `index` is the step's 0-based position and `tool` its name. For a
    /// parallel task join failure they are the placeholders `0` / `"unknown"`.
    #[error("Step {index} ({tool}) failed: {message}")]
    StepFailed {
        index: usize,
        tool: String,
        message: String,
    },
}
31
/// One tool invocation inside a pipeline request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineStep {
    /// Name of the tool to call; must be on the pipeline's allow-list.
    pub tool: String,
    /// Arguments handed to the tool. In sequential mode, string values anywhere
    /// in this JSON may contain `{{step[N].result}}` references to earlier
    /// step outputs. Parallel mode passes them through unchanged.
    pub args: serde_json::Value,
}
38
/// Deserialized arguments of the `execute_pipeline` tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineRequest {
    /// Steps to run; their order defines the `N` in `step[N]`.
    pub steps: Vec<PipelineStep>,
    /// Run every step concurrently, without interpolation. Defaults to `false`.
    #[serde(default)]
    pub parallel: bool,
    /// What a successful run returns; defaults to `PipelineResultMode::All`.
    #[serde(default)]
    pub result: PipelineResultMode,
}
49
/// Shape of the output returned by a successful pipeline run.
///
/// Serialized in `snake_case`, i.e. `"all"` / `"last"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PipelineResultMode {
    /// Every step's `StepResult`, as pretty-printed JSON (the default).
    #[default]
    All,
    /// Only the final step's raw output string, which keeps large
    /// intermediate outputs out of the caller's context.
    Last,
}
62
/// Outcome of one executed step; serialized as-is in `all` result mode.
#[derive(Debug, Clone, Serialize)]
pub struct StepResult {
    /// 0-based position of the step in the request; what `step[N]` refers to.
    pub index: usize,
    /// Name of the tool that ran.
    pub tool: String,
    /// Whether the step succeeded. A failing step aborts the pipeline instead
    /// of being recorded, so the executors in this file always set `true`.
    pub success: bool,
    /// The tool's output text; substituted for `{{step[N].result}}`.
    pub output: String,
}
71
/// The `execute_pipeline` tool: runs several other tools within a single call.
pub struct PipelineTool {
    /// Step limit (`max_steps`) and tool allow-list enforced by `validate`.
    config: PipelineConfig,
    /// Tools the pipeline can dispatch to, matched against `Tool::name()`.
    tools: Vec<Arc<dyn Tool>>,
    /// `config.allowed_tools` collected into a set for membership checks.
    allowed_set: HashSet<String>,
}
78
79impl PipelineTool {
80 pub fn new(config: PipelineConfig, tools: Vec<Arc<dyn Tool>>) -> Self {
81 let allowed_set: HashSet<String> = config.allowed_tools.iter().cloned().collect();
82 Self {
83 config,
84 tools,
85 allowed_set,
86 }
87 }
88
89 fn find_tool(&self, name: &str) -> Option<&dyn Tool> {
91 self.tools
92 .iter()
93 .find(|t| t.name() == name)
94 .map(|t| t.as_ref())
95 }
96
97 fn validate(&self, request: &PipelineRequest) -> std::result::Result<(), PipelineError> {
99 if request.steps.len() > self.config.max_steps {
100 return Err(PipelineError::TooManySteps(self.config.max_steps));
101 }
102
103 for step in &request.steps {
105 if !self.allowed_set.contains(&step.tool) {
106 return Err(PipelineError::UnknownTool(step.tool.clone()));
107 }
108 }
109
110 Ok(())
111 }
112
113 async fn execute_sequential(
115 &self,
116 steps: &[PipelineStep],
117 ) -> std::result::Result<Vec<StepResult>, PipelineError> {
118 let mut results: Vec<StepResult> = Vec::with_capacity(steps.len());
119
120 for (i, step) in steps.iter().enumerate() {
121 let tool = self
122 .find_tool(&step.tool)
123 .ok_or_else(|| PipelineError::UnknownTool(step.tool.clone()))?;
124
125 let interpolated_args = interpolate_args(&step.args, &results);
127
128 let tool_result =
129 tool.execute(interpolated_args)
130 .await
131 .map_err(|e| PipelineError::StepFailed {
132 index: i,
133 tool: step.tool.clone(),
134 message: e.to_string(),
135 })?;
136
137 if !tool_result.success {
138 return Err(PipelineError::StepFailed {
139 index: i,
140 tool: step.tool.clone(),
141 message: tool_result
142 .error
143 .unwrap_or_else(|| tool_result.output.clone()),
144 });
145 }
146
147 results.push(StepResult {
148 index: i,
149 tool: step.tool.clone(),
150 success: true,
151 output: tool_result.output,
152 });
153 }
154
155 Ok(results)
156 }
157
158 async fn execute_parallel(
160 &self,
161 steps: &[PipelineStep],
162 ) -> std::result::Result<Vec<StepResult>, PipelineError> {
163 use tokio::task::JoinSet;
164
165 let mut join_set = JoinSet::new();
166
167 for (i, step) in steps.iter().enumerate() {
168 let tool = self
169 .find_tool(&step.tool)
170 .ok_or_else(|| PipelineError::UnknownTool(step.tool.clone()))?;
171
172 let tool_name = step.tool.clone();
174 let args = step.args.clone();
175
176 let tool_arc = self.tools.iter().find(|t| t.name() == tool.name()).cloned();
178
179 if let Some(tool_arc) = tool_arc {
180 join_set.spawn(async move {
181 let result = tool_arc.execute(args).await;
182 (i, tool_name, result)
183 });
184 }
185 }
186
187 let mut results: Vec<StepResult> = Vec::with_capacity(steps.len());
188
189 while let Some(join_result) = join_set.join_next().await {
190 let (index, tool_name, tool_result) =
191 join_result.map_err(|e| PipelineError::StepFailed {
192 index: 0,
193 tool: "unknown".to_string(),
194 message: format!("Task join error: {e}"),
195 })?;
196
197 let tool_result = tool_result.map_err(|e| PipelineError::StepFailed {
198 index,
199 tool: tool_name.clone(),
200 message: e.to_string(),
201 })?;
202
203 if !tool_result.success {
204 return Err(PipelineError::StepFailed {
205 index,
206 tool: tool_name,
207 message: tool_result
208 .error
209 .unwrap_or_else(|| tool_result.output.clone()),
210 });
211 }
212
213 results.push(StepResult {
214 index,
215 tool: tool_name,
216 success: true,
217 output: tool_result.output,
218 });
219 }
220
221 results.sort_by_key(|r| r.index);
223 Ok(results)
224 }
225}
226
227#[async_trait]
228impl Tool for PipelineTool {
229 fn name(&self) -> &str {
230 "execute_pipeline"
231 }
232
233 fn description(&self) -> &str {
234 "Execute a multi-step tool pipeline in a single call. Steps run sequentially by default \
235 with result interpolation (use {{step[N].result}} to reference prior outputs), \
236 or in parallel when 'parallel: true' is set. Set 'result: \"last\"' to return only the \
237 final step's output (recommended when an earlier step yields a large blob, e.g. base64, \
238 that should not flow back into the context); the default 'all' returns every step's result."
239 }
240
241 fn parameters_schema(&self) -> serde_json::Value {
242 serde_json::json!({
243 "type": "object",
244 "properties": {
245 "steps": {
246 "type": "array",
247 "description": "Ordered list of tool invocations",
248 "items": {
249 "type": "object",
250 "properties": {
251 "tool": {
252 "type": "string",
253 "description": "Name of the tool to invoke"
254 },
255 "args": {
256 "type": "object",
257 "description": "Arguments to pass to the tool. Use {{step[N].result}} to interpolate prior step outputs."
258 }
259 },
260 "required": ["tool", "args"]
261 }
262 },
263 "parallel": {
264 "type": "boolean",
265 "description": "Run steps in parallel (no interpolation). Default: false",
266 "default": false
267 },
268 "result": {
269 "type": "string",
270 "enum": ["all", "last"],
271 "description": "What to return: 'all' (default) = every step's result as JSON; 'last' = only the final step's raw output. Use 'last' to keep large intermediate blobs (e.g. base64) out of the context.",
272 "default": "all"
273 }
274 },
275 "required": ["steps"]
276 })
277 }
278
279 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
280 let request: PipelineRequest = serde_json::from_value(args).map_err(|e| {
281 ::zeroclaw_log::record!(
282 WARN,
283 ::zeroclaw_log::Event::new(module_path!(), ::zeroclaw_log::Action::Reject)
284 .with_outcome(::zeroclaw_log::EventOutcome::Failure)
285 .with_attrs(::serde_json::json!({"error": format!("{}", e)})),
286 "pipeline: invalid request"
287 );
288 anyhow::Error::msg(format!("Invalid pipeline request: {e}"))
289 })?;
290
291 if let Err(e) = self.validate(&request) {
293 return Ok(ToolResult {
294 success: false,
295 output: String::new(),
296 error: Some(e.to_string()),
297 });
298 }
299
300 let results = if request.parallel {
301 self.execute_parallel(&request.steps).await
302 } else {
303 self.execute_sequential(&request.steps).await
304 };
305
306 match results {
307 Ok(step_results) => {
308 let output = match request.result {
309 PipelineResultMode::Last => step_results
310 .last()
311 .map(|s| s.output.clone())
312 .unwrap_or_default(),
313 PipelineResultMode::All => serde_json::to_string_pretty(&step_results)
314 .unwrap_or_else(|_| "Pipeline completed".to_string()),
315 };
316 Ok(ToolResult {
317 success: true,
318 output,
319 error: None,
320 })
321 }
322 Err(e) => Ok(ToolResult {
323 success: false,
324 output: String::new(),
325 error: Some(e.to_string()),
326 }),
327 }
328 }
329}
330
331pub fn interpolate_args(
336 args: &serde_json::Value,
337 prior_results: &[StepResult],
338) -> serde_json::Value {
339 match args {
340 serde_json::Value::String(s) => {
341 let interpolated = interpolate_string(s, prior_results);
342 serde_json::Value::String(interpolated)
343 }
344 serde_json::Value::Object(map) => {
345 let new_map: serde_json::Map<String, serde_json::Value> = map
346 .iter()
347 .map(|(k, v)| (k.clone(), interpolate_args(v, prior_results)))
348 .collect();
349 serde_json::Value::Object(new_map)
350 }
351 serde_json::Value::Array(arr) => {
352 let new_arr: Vec<serde_json::Value> = arr
353 .iter()
354 .map(|v| interpolate_args(v, prior_results))
355 .collect();
356 serde_json::Value::Array(new_arr)
357 }
358 other => other.clone(),
359 }
360}
361
362fn interpolate_string(s: &str, prior_results: &[StepResult]) -> String {
364 let mut result = String::with_capacity(s.len());
365 let mut chars = s.char_indices().peekable();
366
367 while let Some((i, c)) = chars.next() {
368 if c == '{'
369 && let Some(&(_, '{')) = chars.peek()
370 {
371 let rest = &s[i..];
373 if let Some(end) = find_template_end(rest) {
374 let template = &rest[2..end]; if let Some(value) = resolve_template(template, prior_results) {
376 result.push_str(&value.replace("{{", ""));
378 let skip_to = i + end + 2;
380 while chars.peek().is_some_and(|&(idx, _)| idx < skip_to) {
381 chars.next();
382 }
383 continue;
384 }
385 }
386 }
387 result.push(c);
388 }
389
390 result
391}
392
/// Returns the byte offset of the closing `}}` in `s`, where `s` is expected to
/// start with an opening `{{`.
///
/// The search starts after the two opening bytes, so `"{{}}"` yields `Some(2)`.
/// Returns `None` in two cases:
/// - there is no closing `}}`;
/// - `s` is too short to hold an opener, or byte 2 is not a char boundary.
fn find_template_end(s: &str) -> Option<usize> {
    // `get` instead of indexing, so short or misaligned input cannot panic.
    s.get(2..)?.find("}}").map(|pos| pos + 2)
}
397
398fn resolve_template(template: &str, prior_results: &[StepResult]) -> Option<String> {
400 let template = template.trim();
401 if !template.starts_with("step[") || !template.ends_with(".result") {
402 return None;
403 }
404
405 let bracket_end = template.find(']')?;
406 let index_str = &template[5..bracket_end];
407 let index: usize = index_str.parse().ok()?;
408
409 prior_results
410 .iter()
411 .find(|r| r.index == index)
412 .map(|r| r.output.clone())
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful earlier step with the given output.
    fn prior(index: usize, tool: &str, output: &str) -> StepResult {
        StepResult {
            index,
            tool: tool.to_string(),
            success: true,
            output: output.to_string(),
        }
    }

    /// A step calling `tool` with empty arguments.
    fn empty_step(tool: &str) -> PipelineStep {
        PipelineStep {
            tool: tool.into(),
            args: serde_json::json!({}),
        }
    }

    /// An enabled config with the given step cap and allow-list.
    fn config(max_steps: usize, allowed: &[&str]) -> PipelineConfig {
        PipelineConfig {
            enabled: true,
            max_steps,
            allowed_tools: allowed.iter().map(|name| (*name).to_string()).collect(),
        }
    }

    /// A sequential request using the default result mode.
    fn sequential(steps: Vec<PipelineStep>) -> PipelineRequest {
        PipelineRequest {
            steps,
            parallel: false,
            result: PipelineResultMode::default(),
        }
    }

    // ---- interpolation ---------------------------------------------------

    #[test]
    fn interpolate_simple_reference() {
        let results = [prior(0, "web_search", "search results here")];
        let out = interpolate_args(
            &serde_json::json!({"text": "Summarize: {{step[0].result}}"}),
            &results,
        );
        assert_eq!(
            out["text"].as_str().unwrap(),
            "Summarize: search results here"
        );
    }

    #[test]
    fn interpolate_multiple_references() {
        let results = [prior(0, "a", "first"), prior(1, "b", "second")];
        let out = interpolate_args(
            &serde_json::json!({"text": "{{step[0].result}} and {{step[1].result}}"}),
            &results,
        );
        assert_eq!(out["text"].as_str().unwrap(), "first and second");
    }

    #[test]
    fn interpolate_no_match_passes_through() {
        let out = interpolate_args(&serde_json::json!({"text": "no templates here"}), &[]);
        assert_eq!(out["text"].as_str().unwrap(), "no templates here");
    }

    #[test]
    fn interpolate_invalid_index_passes_through() {
        // There is no step 99, so the reference must survive verbatim.
        let out = interpolate_args(&serde_json::json!({"text": "{{step[99].result}}"}), &[]);
        assert_eq!(out["text"].as_str().unwrap(), "{{step[99].result}}");
    }

    #[test]
    fn interpolate_strips_injection() {
        let results = [prior(0, "a", "value with {{step[1].result}} injection")];
        let out = interpolate_args(&serde_json::json!({"text": "{{step[0].result}}"}), &results);
        let text = out["text"].as_str().unwrap();
        // The substituted output must not carry a live `{{` opener.
        assert!(!text.contains("{{"));
        assert!(text.contains("step[1].result}} injection"));
    }

    #[test]
    fn interpolate_nested_objects() {
        let results = [prior(0, "a", "data")];
        let out = interpolate_args(
            &serde_json::json!({
                "outer": {
                    "inner": "prefix {{step[0].result}} suffix"
                }
            }),
            &results,
        );
        assert_eq!(
            out["outer"]["inner"].as_str().unwrap(),
            "prefix data suffix"
        );
    }

    #[test]
    fn interpolate_array_values() {
        let results = [prior(0, "a", "item")];
        let out = interpolate_args(&serde_json::json!(["{{step[0].result}}", "static"]), &results);
        assert_eq!(out[0].as_str().unwrap(), "item");
        assert_eq!(out[1].as_str().unwrap(), "static");
    }

    // ---- validation ------------------------------------------------------

    #[test]
    fn validate_too_many_steps() {
        let pipeline = PipelineTool::new(config(2, &["shell"]), vec![]);
        let request = sequential(vec![empty_step("shell"); 3]);

        let err = pipeline.validate(&request).unwrap_err();
        assert!(matches!(err, PipelineError::TooManySteps(2)));
    }

    #[test]
    fn validate_unknown_tool() {
        let pipeline = PipelineTool::new(config(20, &["shell"]), vec![]);
        let request = sequential(vec![empty_step("forbidden_tool")]);

        let err = pipeline.validate(&request).unwrap_err();
        assert!(matches!(err, PipelineError::UnknownTool(_)));
    }

    #[test]
    fn validate_valid_request() {
        let pipeline = PipelineTool::new(config(20, &["shell", "file_read"]), vec![]);
        let request = sequential(vec![empty_step("shell"), empty_step("file_read")]);

        assert!(pipeline.validate(&request).is_ok());
    }

    #[test]
    fn validate_empty_pipeline() {
        let pipeline = PipelineTool::new(config(20, &[]), vec![]);
        let request = sequential(Vec::new());

        assert!(pipeline.validate(&request).is_ok());
    }

    // ---- template resolution ---------------------------------------------

    #[test]
    fn resolve_valid_template() {
        let results = [prior(0, "a", "hello")];
        assert_eq!(
            resolve_template("step[0].result", &results),
            Some("hello".to_string())
        );
    }

    #[test]
    fn resolve_invalid_template_format() {
        for bad in ["invalid", "step.result", "step[abc].result"] {
            assert_eq!(resolve_template(bad, &[]), None);
        }
    }

    #[test]
    fn resolve_out_of_range_index() {
        assert_eq!(resolve_template("step[5].result", &[]), None);
    }

    // ---- end-to-end with stub tools --------------------------------------

    /// A stub tool that ignores its arguments and always succeeds with `output`.
    struct EchoTool {
        name: String,
        output: String,
    }

    zeroclaw_api::mock_tool_attribution!(EchoTool);

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            &self.name
        }
        fn description(&self) -> &str {
            "echo"
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult> {
            Ok(ToolResult {
                success: true,
                output: self.output.clone(),
                error: None,
            })
        }
    }

    /// Pipeline over two echo tools: `a` (a stand-in for a large intermediate
    /// blob) and `b` (the final answer).
    fn echo_pipeline() -> PipelineTool {
        let echo = |name: &str, output: &str| -> Arc<dyn Tool> {
            Arc::new(EchoTool {
                name: name.into(),
                output: output.into(),
            })
        };
        PipelineTool::new(
            config(20, &["a", "b"]),
            vec![echo("a", "FIRST_BIG_BLOB"), echo("b", "final answer")],
        )
    }

    #[tokio::test]
    async fn result_last_returns_only_final_output() {
        let args = serde_json::json!({
            "steps": [
                {"tool": "a", "args": {}},
                {"tool": "b", "args": {}}
            ],
            "result": "last"
        });
        let res = echo_pipeline().execute(args).await.unwrap();
        assert!(res.success);
        assert_eq!(res.output, "final answer");
        assert!(!res.output.contains("FIRST_BIG_BLOB"));
    }

    #[tokio::test]
    async fn result_all_is_default_and_includes_every_step() {
        let args = serde_json::json!({
            "steps": [
                {"tool": "a", "args": {}},
                {"tool": "b", "args": {}}
            ]
        });
        let res = echo_pipeline().execute(args).await.unwrap();
        assert!(res.success);
        assert!(res.output.contains("FIRST_BIG_BLOB"));
        assert!(res.output.contains("final answer"));
    }
}