Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ ignore = [
"A001", # Variable shadows built-in
"A002", # Argument shadows built-in
"FBT001", # Boolean positional arg
"FBT002", # Boolean default value
"N801", # Class name casing
"N802", # Function name casing
"N806", # Variable casing
Expand Down
4 changes: 2 additions & 2 deletions scripts/migrate_to_jinja2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class TemplateMigrator:
"""Migrates custom template syntax to Jinja2."""

def __init__(self, dry_run: bool = False):
def __init__(self, dry_run: bool):
self.dry_run = dry_run
self.transformations: List[Tuple[str, str]] = []

Expand Down Expand Up @@ -133,7 +133,7 @@ def _show_diff(self, original: str, migrated: str):
print(f" - {orig}")
print(f" + {mig}")

def migrate_directory(self, directory: Path, recursive: bool = True) -> int:
def migrate_directory(self, directory: Path, recursive: bool) -> int:
"""Migrate all YAML files in directory.

Returns:
Expand Down
4 changes: 2 additions & 2 deletions src/seclab_taskflow_agent/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def drive_backend_stream(
watchdog_ping()
if isinstance(event, TextDelta):
await render_model_output(
event.text, async_task=async_task, task_id=task_id
event.text, log=True, async_task=async_task, task_id=task_id
)
elif isinstance(event, ToolEnd):
await bridge_copilot_tool_event(event, run_hooks)
Expand All @@ -120,7 +120,7 @@ async def drive_backend_stream(
await aclose()
except Exception: # noqa: BLE001 - best-effort cleanup
logging.exception("Failed to aclose backend stream iterator")
await render_model_output("\n\n", async_task=async_task, task_id=task_id)
await render_model_output("\n\n", log=True, async_task=async_task, task_id=task_id)
return
except BackendTimeoutError:
if not max_retry:
Expand Down
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def __init__(
name: str = "TaskAgent",
instructions: str = "",
handoffs: list[Any] | None = None,
exclude_from_context: bool = False,
*,
exclude_from_context: bool,
mcp_servers: list[Any] | None = None,
model: str = DEFAULT_MODEL,
model_settings: ModelSettings | None = None,
Expand Down
4 changes: 2 additions & 2 deletions src/seclab_taskflow_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ def main(
list_models: Annotated[
bool,
typer.Option("-l", "--list-models", help="List available tool-call models and exit."),
] = False,
] = typer.Option(False),
globals_: Annotated[
list[str] | None,
typer.Option("-g", "--global", help="Global variable as KEY=VALUE. Repeatable."),
] = None,
debug: Annotated[
bool,
typer.Option("-d", "--debug", help="Show full tracebacks on errors."),
] = False,
] = typer.Option(False),
resume: Annotated[
str | None,
typer.Option("--resume", help="Resume a previous session by its ID."),
Expand Down
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/mcp_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def build_mcp_servers(
available_tools: AvailableTools,
toolboxes: list[str],
blocked_tools: list[str] | None = None,
headless: bool = False,
*,
headless: bool,
) -> list[MCPServerEntry]:
"""Build MCP server instances for the given toolboxes.

Expand Down
22 changes: 12 additions & 10 deletions src/seclab_taskflow_agent/mcp_servers/codeql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def __init__(
self,
codeql_cli=os.getenv("CODEQL_CLI", default="codeql"),
server_options=["--threads=0", "--quiet"],
log_stderr=False,
*,
log_stderr: bool,
):
self.server_options = server_options.copy()
if log_stderr:
Expand Down Expand Up @@ -406,7 +407,7 @@ def _bqrs_to_sarif(self, bqrs_path, query_info, max_paths=10):


class QueryServer(CodeQL):
def __init__(self, database: Path, keep_alive=False, log_stderr=False):
def __init__(self, database: Path, keep_alive: bool, log_stderr: bool):
super().__init__(log_stderr=log_stderr)
self.database = database
self.keep_alive = keep_alive
Expand Down Expand Up @@ -476,7 +477,7 @@ def _file_uri_to_path(uri):
return path, region


def _get_source_prefix(database_path: Path, strip_leading_slash=True) -> str:
def _get_source_prefix(database_path: Path, strip_leading_slash: bool) -> str:
# grab the source prefix from codeql-database.yml
db_yml_path = Path(database_path) / Path("codeql-database.yml")
with open(db_yml_path) as stream:
Expand All @@ -491,10 +492,10 @@ def _get_source_prefix(database_path: Path, strip_leading_slash=True) -> str:
raise


def list_src_files(database_path: str | Path, as_uri=False, strip_prefix=True):
def list_src_files(database_path: str | Path, as_uri: bool, strip_prefix: bool):
src_path = Path(database_path) / Path("src.zip")
files = shell_command_to_string(["zipinfo", "-1", src_path]).split("\n")
source_prefix = _get_source_prefix(Path(database_path))
source_prefix = _get_source_prefix(Path(database_path), strip_leading_slash=True)
# file:// uri are formatted absolute paths even if they're relative
files = [
f"{'file:///' if as_uri else ''}{path.strip().removeprefix(source_prefix if strip_prefix else '')}"
Expand All @@ -503,11 +504,11 @@ def list_src_files(database_path: str | Path, as_uri=False, strip_prefix=True):
return files


def search_in_src_archive(database_path: str, search_term: str, as_uri=False, strip_prefix=True):
def search_in_src_archive(database_path: str, search_term: str, as_uri: bool, strip_prefix: bool):
database_path = Path(database_path)
src_path = database_path / Path("src.zip")
results = {}
source_prefix = _get_source_prefix(database_path)
source_prefix = _get_source_prefix(database_path, strip_leading_slash=True)
with zipfile.ZipFile(src_path) as z:
for entry in z.infolist():
if entry.is_dir():
Expand All @@ -528,7 +529,7 @@ def _file_from_src_archive(relative_path: str | Path, database_path: str | Path,
# our shell utility is Popen based, so no expansions occur
database_path = Path(database_path)
src_path = database_path / Path("src.zip")
source_prefix = _get_source_prefix(Path(database_path))
source_prefix = _get_source_prefix(Path(database_path), strip_leading_slash=True)
# normalize relative path
relative_path = Path(str(relative_path).lstrip("/").removeprefix(source_prefix))
resolved_path = Path(source_prefix) / Path(relative_path)
Expand Down Expand Up @@ -596,8 +597,9 @@ def run_query(
progress_callback=None,
template_values=None,
# keep the query server alive if desired
keep_alive=True,
log_stderr=False,
*,
keep_alive: bool,
log_stderr: bool,
):
result = ""
query_path = Path(query_path)
Expand Down
5 changes: 3 additions & 2 deletions src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _run_query(query_name: str, database_path: str, language: str, template_valu
database_path,
fmt="csv",
template_values=template_values,
keep_alive=True,
log_stderr=True,
)
return _csv_to_json_obj(csv)
Expand Down Expand Up @@ -150,7 +151,7 @@ def list_source_files(
):
"""List the available source files in a CodeQL database using their file:// URI"""
database_path = _resolve_db_path(database_path)
results = list_src_files(database_path, as_uri=True)
results = list_src_files(database_path, as_uri=True, strip_prefix=True)
return json.dumps([{"uri": item} for item in results if re.search(regex_filter, item)], indent=2)


Expand All @@ -163,7 +164,7 @@ def search_in_source_code(
Search for a string in the source code. Returns the line number and file.
"""
resolved_database_path = _resolve_db_path(database_path)
results = search_in_src_archive(resolved_database_path, search_term)
results = search_in_src_archive(resolved_database_path, search_term, as_uri=False, strip_prefix=True)
out = []
if isinstance(results, dict):
for k, v in results.items():
Expand Down
6 changes: 3 additions & 3 deletions src/seclab_taskflow_agent/render_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ async def flush_async_output(task_id: str) -> None:
# No buffered output (agent may have failed before producing any).
return
data = async_output.pop(task_id)
await render_model_output(f"** 🤖✏️ Output for async task: {task_id}\n\n")
await render_model_output(data)
await render_model_output(f"** 🤖✏️ Output for async task: {task_id}\n\n", log=True, async_task=False, task_id=None)
await render_model_output(data, log=True, async_task=False, task_id=None)


async def render_model_output(data: str, log: bool = True, async_task: bool = False, task_id: str | None = None) -> None:
async def render_model_output(data: str, log: bool, async_task: bool, task_id: str | None = None) -> None:
"""Print model output to the console, optionally buffering for async tasks."""
async with async_output_lock:
if async_task and task_id:
Expand Down
89 changes: 64 additions & 25 deletions src/seclab_taskflow_agent/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ async def _build_prompts_to_run(
raise

if not iterable_result:
await render_model_output("** 🤖❗MCP tool result iterable is empty!\n")
await render_model_output("** 🤖❗MCP tool result iterable is empty!\n", log=True, async_task=False, task_id=None)
else:
logging.debug("Rendering templated prompts for results: %s", iterable_result)
for value in iterable_result:
Expand Down Expand Up @@ -286,11 +286,21 @@ async def deploy_task_agents(
blocked_tools = blocked_tools or []

task_id = str(uuid.uuid4())
await render_model_output(f"** 🤖💪 Deploying Task Flow Agent(s): {list(agents.keys())}\n")
await render_model_output(f"** 🤖💪 Task ID : {task_id}\n")
await render_model_output(f"** 🤖💪 Model : {model}{', params: ' + str(model_par) if model_par else ''}\n")
await render_model_output(
f"** 🤖💪 Deploying Task Flow Agent(s): {list(agents.keys())}\n",
log=True,
async_task=False,
task_id=None,
)
await render_model_output(f"** 🤖💪 Task ID : {task_id}\n", log=True, async_task=False, task_id=None)
await render_model_output(
f"** 🤖💪 Model : {model}{', params: ' + str(model_par) if model_par else ''}\n",
log=True,
async_task=False,
task_id=None,
)
if endpoint:
await render_model_output(f"** 🤖💪 Endpoint: {endpoint}\n")
await render_model_output(f"** 🤖💪 Endpoint: {endpoint}\n", log=True, async_task=False, task_id=None)

# Resolve toolboxes from personality definitions or override
toolboxes: list[str] = []
Expand All @@ -313,7 +323,7 @@ async def deploy_task_agents(
model_params: dict[str, Any] = dict(model_par)

# Build MCP servers and collect server prompts
entries = build_mcp_servers(available_tools, toolboxes, blocked_tools, headless)
entries = build_mcp_servers(available_tools, toolboxes, blocked_tools, headless=headless)
mcp_params = mcp_client_params(available_tools, toolboxes)
server_prompts = [sp for _, (_, _, sp, _) in mcp_params.items()]

Expand Down Expand Up @@ -428,16 +438,24 @@ async def deploy_task_agents(
complete = True

except BackendMaxTurnsError as e:
await render_model_output(f"** 🤖❗ Max Turns Reached: {e}\n", async_task=async_task, task_id=task_id)
await render_model_output(
f"** 🤖❗ Max Turns Reached: {e}\n", log=True, async_task=async_task, task_id=task_id
)
logging.exception("Exceeded max_turns: %s", max_turns)
except BackendUnexpectedError as e:
await render_model_output(f"** 🤖❗ Agent Exception: {e}\n", async_task=async_task, task_id=task_id)
await render_model_output(
f"** 🤖❗ Agent Exception: {e}\n", log=True, async_task=async_task, task_id=task_id
)
logging.exception("Agent Exception")
except BackendBadRequestError as e:
await render_model_output(f"** 🤖❗ Request Error: {e}\n", async_task=async_task, task_id=task_id)
await render_model_output(
f"** 🤖❗ Request Error: {e}\n", log=True, async_task=async_task, task_id=task_id
)
logging.exception("Bad Request")
except BackendTimeoutError as e:
await render_model_output(f"** 🤖❗ Timeout Error: {e}\n", async_task=async_task, task_id=task_id)
await render_model_output(
f"** 🤖❗ Timeout Error: {e}\n", log=True, async_task=async_task, task_id=task_id
)
logging.exception("API Timeout")

if async_task:
Expand Down Expand Up @@ -519,10 +537,12 @@ async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TC

async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None:
watchdog_ping()
await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n")
await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n", log=True, async_task=False, task_id=None)

async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]) -> None:
await render_model_output(f"\n** 🤖🤝 Agent Handoff: {source.name} -> {agent.name}\n")
await render_model_output(
f"\n** 🤖🤝 Agent Handoff: {source.name} -> {agent.name}\n", log=True, async_task=False, task_id=None
)

if personality_path:
personality = available_tools.get_personality(personality_path)
Expand All @@ -539,7 +559,9 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo
if resume_session_id:
session = TaskflowSession.load(resume_session_id)
if session.finished:
await render_model_output(f"** 🤖✅ Session {resume_session_id} already completed\n")
await render_model_output(
f"** 🤖✅ Session {resume_session_id} already completed\n", log=True, async_task=False, task_id=None
)
return
taskflow_path = session.taskflow_path
cli_globals = session.cli_globals
Expand All @@ -549,11 +571,14 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo
if not cli_model_config and session.cli_model_config:
cli_model_config = session.cli_model_config
await render_model_output(
f"** 🤖🔄 Resuming session {resume_session_id} from task {session.next_task_index}\n"
f"** 🤖🔄 Resuming session {resume_session_id} from task {session.next_task_index}\n",
log=True,
async_task=False,
task_id=None,
)

taskflow_doc = available_tools.get_taskflow(taskflow_path)
await render_model_output(f"** 🤖💪 Running Task Flow: {taskflow_path}\n")
await render_model_output(f"** 🤖💪 Running Task Flow: {taskflow_path}\n", log=True, async_task=False, task_id=None)

# Resolve global variables (file defaults + CLI overrides)
global_variables = dict(taskflow_doc.globals or {})
Expand Down Expand Up @@ -584,13 +609,16 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo
cli_model_config=cli_model_config or "",
)
session.save()
await render_model_output(f"** 🤖📋 Session: {session.session_id}\n")
await render_model_output(f"** 🤖📋 Session: {session.session_id}\n", log=True, async_task=False, task_id=None)

for task_index, task_wrapper in enumerate(taskflow_doc.taskflow):
# Skip already-completed tasks on resume
if task_index < session.next_task_index:
await render_model_output(
f"** 🤖⏭️ Skipping completed task {task_index}\n"
f"** 🤖⏭️ Skipping completed task {task_index}\n",
log=True,
async_task=False,
task_id=None,
)
continue

Expand Down Expand Up @@ -642,15 +670,17 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo
available_tools, global_variables, inputs,
)

async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) -> bool:
async def run_prompts(async_task: bool, max_concurrent_tasks: int = 5) -> bool:
if run:
await render_model_output("** 🤖🐚 Executing Shell Task\n")
await render_model_output("** 🤖🐚 Executing Shell Task\n", log=True, async_task=False, task_id=None)
try:
result = shell_tool_call(run).content[0].model_dump_json()
last_mcp_tool_results.append(result)
return True
except RuntimeError as e:
await render_model_output(f"** 🤖❗ Shell Task Exception: {e}\n")
await render_model_output(
f"** 🤖❗ Shell Task Exception: {e}\n", log=True, async_task=False, task_id=None
)
logging.exception("Shell task error")
return False

Expand Down Expand Up @@ -746,7 +776,10 @@ async def _deploy(ra: dict, pp: str) -> bool:
backoff = TASK_RETRY_BACKOFF * (attempt + 1)
await render_model_output(
f"** 🤖🔄 Task {task_name!r} failed: {exc}\n"
f"** 🤖🔄 Retrying in {backoff}s ({remaining} attempts left)\n"
f"** 🤖🔄 Retrying in {backoff}s ({remaining} attempts left)\n",
log=True,
async_task=False,
task_id=None,
)
logging.warning("Task %r attempt %s failed: %s", task_name, attempt + 1, exc)
await asyncio.sleep(backoff)
Expand All @@ -764,17 +797,23 @@ async def _deploy(ra: dict, pp: str) -> bool:
session.mark_failed(f"Task {task_name!r}: {last_task_error}")
await render_model_output(
f"** 🤖💾 Session saved: {session.session_id}\n"
f"** 🤖💡 Resume with: --resume {session.session_id}\n"
f"** 🤖💡 Resume with: --resume {session.session_id}\n",
log=True,
async_task=False,
task_id=None,
)
raise last_task_error

if must_complete and not task_complete:
logging.critical("Required task not completed ... aborting!")
await render_model_output("🤖💥 *Required task not completed ...\n")
await render_model_output("🤖💥 *Required task not completed ...\n", log=True, async_task=False, task_id=None)
session.mark_failed(f"Required task {task_name!r} did not complete")
await render_model_output(
f"** 🤖💾 Session saved: {session.session_id}\n"
f"** 🤖💡 Resume with: --resume {session.session_id}\n"
f"** 🤖💡 Resume with: --resume {session.session_id}\n",
log=True,
async_task=False,
task_id=None,
)
break

Expand All @@ -790,4 +829,4 @@ async def _deploy(ra: dict, pp: str) -> bool:
# All tasks completed successfully
if session is not None and not session.error:
session.mark_finished()
await render_model_output(f"** 🤖✅ Session {session.session_id} completed\n")
await render_model_output(f"** 🤖✅ Session {session.session_id} completed\n", log=True, async_task=False, task_id=None)
Loading