diff --git a/src/workflow/planner.py b/src/workflow/planner.py index f3e1cb91..c6cf2dcf 100644 --- a/src/workflow/planner.py +++ b/src/workflow/planner.py @@ -77,22 +77,37 @@ def parse_plan(raw: str) -> Plan: except json.JSONDecodeError: args[n] = {} - steps = [ - PlanStep( - step_number=n, - task=tasks[n], - server=servers.get(n, ""), - tool=tools.get(n, ""), - tool_args=args.get(n, {}), - dependencies=( - [] - if deps_raw.get(n, "None").strip().lower() == "none" - else [int(x) for x in _DEP_NUM_RE.findall(deps_raw.get(n, ""))] - ), - expected_output=outputs.get(n, ""), + steps = [] + for n in sorted(tasks): + raw_dep = deps_raw.get(n, "None").strip() + + if raw_dep.lower() == "none": + dependencies = [] + else: + dependencies = [int(x) for x in _DEP_NUM_RE.findall(raw_dep)] + + # Make sure dependency references only point to earlier valid steps. + if not dependencies: + raise ValueError(f"Invalid dependency format for step {n}: {raw_dep}") + + for dep in dependencies: + if dep < 1 or dep >= n: + raise ValueError( + f"Invalid dependency reference for step {n}: #S{dep}" + ) + + steps.append( + PlanStep( + step_number=n, + task=tasks[n], + server=servers.get(n, ""), + tool=tools.get(n, ""), + tool_args=args.get(n, {}), + dependencies=dependencies, + expected_output=outputs.get(n, ""), + ) ) - for n in sorted(tasks) - ] + return Plan(steps=steps, raw=raw) @@ -121,4 +136,4 @@ def generate_plan( ) prompt = _PLAN_PROMPT.format(servers=servers_text, question=question) raw = self._llm.generate(prompt) - return parse_plan(raw) + return parse_plan(raw) \ No newline at end of file