diff --git a/pyproject.toml b/pyproject.toml index 6f805a7f..c2a2820f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,7 +190,6 @@ ignore = [ "A001", # Variable shadows built-in "A002", # Argument shadows built-in "FBT001", # Boolean positional arg - "FBT002", # Boolean default value "N801", # Class name casing "N802", # Function name casing "N806", # Variable casing diff --git a/scripts/migrate_to_jinja2.py b/scripts/migrate_to_jinja2.py index 43844f52..6a35d77f 100755 --- a/scripts/migrate_to_jinja2.py +++ b/scripts/migrate_to_jinja2.py @@ -21,7 +21,7 @@ class TemplateMigrator: """Migrates custom template syntax to Jinja2.""" - def __init__(self, dry_run: bool = False): + def __init__(self, dry_run: bool): self.dry_run = dry_run self.transformations: List[Tuple[str, str]] = [] @@ -133,7 +133,7 @@ def _show_diff(self, original: str, migrated: str): print(f" - {orig}") print(f" + {mig}") - def migrate_directory(self, directory: Path, recursive: bool = True) -> int: + def migrate_directory(self, directory: Path, recursive: bool) -> int: """Migrate all YAML files in directory. Returns: diff --git a/src/seclab_taskflow_agent/_stream.py b/src/seclab_taskflow_agent/_stream.py index b758a55f..d7f1b0ca 100644 --- a/src/seclab_taskflow_agent/_stream.py +++ b/src/seclab_taskflow_agent/_stream.py @@ -106,7 +106,7 @@ async def drive_backend_stream( watchdog_ping() if isinstance(event, TextDelta): await render_model_output( - event.text, async_task=async_task, task_id=task_id + event.text, log=True, async_task=async_task, task_id=task_id ) elif isinstance(event, ToolEnd): await bridge_copilot_tool_event(event, run_hooks) @@ -120,7 +120,7 @@ async def drive_backend_stream( await aclose() except Exception: # noqa: BLE001 - best-effort cleanup logging.exception("Failed to aclose backend stream iterator") - await render_model_output("\n\n", async_task=async_task, task_id=task_id) + await render_model_output("\n\n", log=True, async_task=async_task, task_id=task_id) return except BackendTimeoutError: if not max_retry: diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index a113a04f..a7a67481 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -151,7 +151,8 @@ def __init__( name: str = "TaskAgent", instructions: str = "", handoffs: list[Any] | None = None, - exclude_from_context: bool = False, + *, + exclude_from_context: bool, mcp_servers: list[Any] | None = None, model: str = DEFAULT_MODEL, model_settings: ModelSettings | None = None, diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index b5bf2ebc..8181d307 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -91,7 +91,7 @@ def main( list_models: Annotated[ bool, typer.Option("-l", "--list-models", help="List available tool-call models and exit."), - ] = False, + ] = typer.Option(False), globals_: Annotated[ list[str] | None, typer.Option("-g", "--global", help="Global variable as KEY=VALUE. Repeatable."), @@ -99,7 +99,7 @@ def main( debug: Annotated[ bool, typer.Option("-d", "--debug", help="Show full tracebacks on errors."), - ] = False, + ] = typer.Option(False), resume: Annotated[ str | None, typer.Option("--resume", help="Resume a previous session by its ID."), diff --git a/src/seclab_taskflow_agent/mcp_lifecycle.py b/src/seclab_taskflow_agent/mcp_lifecycle.py index f24fe74a..30eeec0c 100644 --- a/src/seclab_taskflow_agent/mcp_lifecycle.py +++ b/src/seclab_taskflow_agent/mcp_lifecycle.py @@ -45,7 +45,8 @@ def build_mcp_servers( available_tools: AvailableTools, toolboxes: list[str], blocked_tools: list[str] | None = None, - headless: bool = False, + *, + headless: bool, ) -> list[MCPServerEntry]: """Build MCP server instances for the given toolboxes. diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py index bba38a4e..ca658a39 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py @@ -44,7 +44,8 @@ def __init__( self, codeql_cli=os.getenv("CODEQL_CLI", default="codeql"), server_options=["--threads=0", "--quiet"], - log_stderr=False, + *, + log_stderr: bool, ): self.server_options = server_options.copy() if log_stderr: @@ -406,7 +407,7 @@ def _bqrs_to_sarif(self, bqrs_path, query_info, max_paths=10): class QueryServer(CodeQL): - def __init__(self, database: Path, keep_alive=False, log_stderr=False): + def __init__(self, database: Path, keep_alive: bool, log_stderr: bool): super().__init__(log_stderr=log_stderr) self.database = database self.keep_alive = keep_alive @@ -476,7 +477,7 @@ def _file_uri_to_path(uri): return path, region -def _get_source_prefix(database_path: Path, strip_leading_slash=True) -> str: +def _get_source_prefix(database_path: Path, strip_leading_slash: bool) -> str: # grab the source prefix from codeql-database.yml db_yml_path = Path(database_path) / Path("codeql-database.yml") with open(db_yml_path) as stream: @@ -491,10 +492,10 @@ def _get_source_prefix(database_path: Path, strip_leading_slash=True) -> str: raise -def list_src_files(database_path: str | Path, as_uri=False, strip_prefix=True): +def list_src_files(database_path: str | Path, as_uri: bool, strip_prefix: bool): src_path = Path(database_path) / Path("src.zip") files = shell_command_to_string(["zipinfo", "-1", src_path]).split("\n") - source_prefix = _get_source_prefix(Path(database_path)) + source_prefix = _get_source_prefix(Path(database_path), strip_leading_slash=True) # file:// uri are formatted absolute paths even if they're relative files = [ f"{'file:///' if as_uri else ''}{path.strip().removeprefix(source_prefix if strip_prefix else '')}" @@ -503,11 +504,11 @@ def list_src_files(database_path: str | Path, as_uri=False, strip_prefix=True): return files -def search_in_src_archive(database_path: str, search_term: str, as_uri=False, strip_prefix=True): +def search_in_src_archive(database_path: str, search_term: str, as_uri: bool, strip_prefix: bool): database_path = Path(database_path) src_path = database_path / Path("src.zip") results = {} - source_prefix = _get_source_prefix(database_path) + source_prefix = _get_source_prefix(database_path, strip_leading_slash=True) with zipfile.ZipFile(src_path) as z: for entry in z.infolist(): if entry.is_dir(): @@ -528,7 +529,7 @@ def _file_from_src_archive(relative_path: str | Path, database_path: str | Path, # our shell utility is Popen based, so no expansions occur database_path = Path(database_path) src_path = database_path / Path("src.zip") - source_prefix = _get_source_prefix(Path(database_path)) + source_prefix = _get_source_prefix(Path(database_path), strip_leading_slash=True) # normalize relative path relative_path = Path(str(relative_path).lstrip("/").removeprefix(source_prefix)) resolved_path = Path(source_prefix) / Path(relative_path) @@ -596,8 +597,9 @@ def run_query( progress_callback=None, template_values=None, # keep the query server alive if desired - keep_alive=True, - log_stderr=False, + *, + keep_alive: bool, + log_stderr: bool, ): result = "" query_path = Path(query_path) diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py b/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py index d245666a..e6d9b19d 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py @@ -115,6 +115,7 @@ def _run_query(query_name: str, database_path: str, language: str, template_valu database_path, fmt="csv", template_values=template_values, + keep_alive=True, log_stderr=True, ) return _csv_to_json_obj(csv) @@ -150,7 +151,7 @@ def list_source_files( ): """List the available source files in a CodeQL database using their file:// URI""" database_path = _resolve_db_path(database_path) - results = list_src_files(database_path, as_uri=True) + results = list_src_files(database_path, as_uri=True, strip_prefix=True) return json.dumps([{"uri": item} for item in results if re.search(regex_filter, item)], indent=2) @@ -163,7 +164,7 @@ def search_in_source_code( Search for a string in the source code. Returns the line number and file. """ resolved_database_path = _resolve_db_path(database_path) - results = search_in_src_archive(resolved_database_path, search_term) + results = search_in_src_archive(resolved_database_path, search_term, as_uri=False, strip_prefix=True) out = [] if isinstance(results, dict): for k, v in results.items(): diff --git a/src/seclab_taskflow_agent/render_utils.py b/src/seclab_taskflow_agent/render_utils.py index 7a018506..f19f4ebf 100644 --- a/src/seclab_taskflow_agent/render_utils.py +++ b/src/seclab_taskflow_agent/render_utils.py @@ -27,11 +27,11 @@ async def flush_async_output(task_id: str) -> None: # No buffered output (agent may have failed before producing any). return data = async_output.pop(task_id) - await render_model_output(f"** 🤖✏️ Output for async task: {task_id}\n\n") - await render_model_output(data) + await render_model_output(f"** 🤖✏️ Output for async task: {task_id}\n\n", log=True, async_task=False, task_id=None) + await render_model_output(data, log=True, async_task=False, task_id=None) -async def render_model_output(data: str, log: bool = True, async_task: bool = False, task_id: str | None = None) -> None: +async def render_model_output(data: str, log: bool, async_task: bool, task_id: str | None = None) -> None: """Print model output to the console, optionally buffering for async tasks.""" async with async_output_lock: if async_task and task_id: diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 12d36bd8..97dcd05b 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -220,7 +220,7 @@ async def _build_prompts_to_run( raise if not iterable_result: - await render_model_output("** 🤖❗MCP tool result iterable is empty!\n") + await render_model_output("** 🤖❗MCP tool result iterable is empty!\n", log=True, async_task=False, task_id=None) else: logging.debug("Rendering templated prompts for results: %s", iterable_result) for value in iterable_result: @@ -286,11 +286,21 @@ async def deploy_task_agents( blocked_tools = blocked_tools or [] task_id = str(uuid.uuid4()) - await render_model_output(f"** 🤖💪 Deploying Task Flow Agent(s): {list(agents.keys())}\n") - await render_model_output(f"** 🤖💪 Task ID : {task_id}\n") - await render_model_output(f"** 🤖💪 Model : {model}{', params: ' + str(model_par) if model_par else ''}\n") + await render_model_output( + f"** 🤖💪 Deploying Task Flow Agent(s): {list(agents.keys())}\n", + log=True, + async_task=False, + task_id=None, + ) + await render_model_output(f"** 🤖💪 Task ID : {task_id}\n", log=True, async_task=False, task_id=None) + await render_model_output( + f"** 🤖💪 Model : {model}{', params: ' + str(model_par) if model_par else ''}\n", + log=True, + async_task=False, + task_id=None, + ) if endpoint: - await render_model_output(f"** 🤖💪 Endpoint: {endpoint}\n") + await render_model_output(f"** 🤖💪 Endpoint: {endpoint}\n", log=True, async_task=False, task_id=None) # Resolve toolboxes from personality definitions or override toolboxes: list[str] = [] @@ -313,7 +323,7 @@ async def deploy_task_agents( model_params: dict[str, Any] = dict(model_par) # Build MCP servers and collect server prompts - entries = build_mcp_servers(available_tools, toolboxes, blocked_tools, headless) + entries = build_mcp_servers(available_tools, toolboxes, blocked_tools, headless=headless) mcp_params = mcp_client_params(available_tools, toolboxes) server_prompts = [sp for _, (_, _, sp, _) in mcp_params.items()] @@ -428,16 +438,24 @@ async def deploy_task_agents( complete = True except BackendMaxTurnsError as e: - await render_model_output(f"** 🤖❗ Max Turns Reached: {e}\n", async_task=async_task, task_id=task_id) + await render_model_output( + f"** 🤖❗ Max Turns Reached: {e}\n", log=True, async_task=async_task, task_id=task_id + ) logging.exception("Exceeded max_turns: %s", max_turns) except BackendUnexpectedError as e: - await render_model_output(f"** 🤖❗ Agent Exception: {e}\n", async_task=async_task, task_id=task_id) + await render_model_output( + f"** 🤖❗ Agent Exception: {e}\n", log=True, async_task=async_task, task_id=task_id + ) logging.exception("Agent Exception") except BackendBadRequestError as e: - await render_model_output(f"** 🤖❗ Request Error: {e}\n", async_task=async_task, task_id=task_id) + await render_model_output( + f"** 🤖❗ Request Error: {e}\n", log=True, async_task=async_task, task_id=task_id + ) logging.exception("Bad Request") except BackendTimeoutError as e: - await render_model_output(f"** 🤖❗ Timeout Error: {e}\n", async_task=async_task, task_id=task_id) + await render_model_output( + f"** 🤖❗ Timeout Error: {e}\n", log=True, async_task=async_task, task_id=task_id + ) logging.exception("API Timeout") if async_task: @@ -519,10 +537,12 @@ async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TC async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None: watchdog_ping() - await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n") + await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n", log=True, async_task=False, task_id=None) async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]) -> None: - await render_model_output(f"\n** 🤖🤝 Agent Handoff: {source.name} -> {agent.name}\n") + await render_model_output( + f"\n** 🤖🤝 Agent Handoff: {source.name} -> {agent.name}\n", log=True, async_task=False, task_id=None + ) if personality_path: personality = available_tools.get_personality(personality_path) @@ -539,7 +559,9 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo if resume_session_id: session = TaskflowSession.load(resume_session_id) if session.finished: - await render_model_output(f"** 🤖✅ Session {resume_session_id} already completed\n") + await render_model_output( + f"** 🤖✅ Session {resume_session_id} already completed\n", log=True, async_task=False, task_id=None + ) return taskflow_path = session.taskflow_path cli_globals = session.cli_globals @@ -549,11 +571,14 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo if not cli_model_config and session.cli_model_config: cli_model_config = session.cli_model_config await render_model_output( - f"** 🤖🔄 Resuming session {resume_session_id} from task {session.next_task_index}\n" + f"** 🤖🔄 Resuming session {resume_session_id} from task {session.next_task_index}\n", + log=True, + async_task=False, + task_id=None, ) taskflow_doc = available_tools.get_taskflow(taskflow_path) - await render_model_output(f"** 🤖💪 Running Task Flow: {taskflow_path}\n") + await render_model_output(f"** 🤖💪 Running Task Flow: {taskflow_path}\n", log=True, async_task=False, task_id=None) # Resolve global variables (file defaults + CLI overrides) global_variables = dict(taskflow_doc.globals or {}) @@ -584,13 +609,16 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo cli_model_config=cli_model_config or "", ) session.save() - await render_model_output(f"** 🤖📋 Session: {session.session_id}\n") + await render_model_output(f"** 🤖📋 Session: {session.session_id}\n", log=True, async_task=False, task_id=None) for task_index, task_wrapper in enumerate(taskflow_doc.taskflow): # Skip already-completed tasks on resume if task_index < session.next_task_index: await render_model_output( - f"** 🤖⏭️ Skipping completed task {task_index}\n" + f"** 🤖⏭️ Skipping completed task {task_index}\n", + log=True, + async_task=False, + task_id=None, ) continue @@ -642,15 +670,17 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo available_tools, global_variables, inputs, ) - async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) -> bool: + async def run_prompts(async_task: bool, max_concurrent_tasks: int = 5) -> bool: if run: - await render_model_output("** 🤖🐚 Executing Shell Task\n") + await render_model_output("** 🤖🐚 Executing Shell Task\n", log=True, async_task=False, task_id=None) try: result = shell_tool_call(run).content[0].model_dump_json() last_mcp_tool_results.append(result) return True except RuntimeError as e: - await render_model_output(f"** 🤖❗ Shell Task Exception: {e}\n") + await render_model_output( + f"** 🤖❗ Shell Task Exception: {e}\n", log=True, async_task=False, task_id=None + ) logging.exception("Shell task error") return False @@ -746,7 +776,10 @@ async def _deploy(ra: dict, pp: str) -> bool: backoff = TASK_RETRY_BACKOFF * (attempt + 1) await render_model_output( f"** 🤖🔄 Task {task_name!r} failed: {exc}\n" - f"** 🤖🔄 Retrying in {backoff}s ({remaining} attempts left)\n" + f"** 🤖🔄 Retrying in {backoff}s ({remaining} attempts left)\n", + log=True, + async_task=False, + task_id=None, ) logging.warning("Task %r attempt %s failed: %s", task_name, attempt + 1, exc) await asyncio.sleep(backoff) @@ -764,17 +797,23 @@ async def _deploy(ra: dict, pp: str) -> bool: session.mark_failed(f"Task {task_name!r}: {last_task_error}") await render_model_output( f"** 🤖💾 Session saved: {session.session_id}\n" - f"** 🤖💡 Resume with: --resume {session.session_id}\n" + f"** 🤖💡 Resume with: --resume {session.session_id}\n", + log=True, + async_task=False, + task_id=None, ) raise last_task_error if must_complete and not task_complete: logging.critical("Required task not completed ... aborting!") - await render_model_output("🤖💥 *Required task not completed ...\n") + await render_model_output("🤖💥 *Required task not completed ...\n", log=True, async_task=False, task_id=None) session.mark_failed(f"Required task {task_name!r} did not complete") await render_model_output( f"** 🤖💾 Session saved: {session.session_id}\n" - f"** 🤖💡 Resume with: --resume {session.session_id}\n" + f"** 🤖💡 Resume with: --resume {session.session_id}\n", + log=True, + async_task=False, + task_id=None, ) break @@ -790,4 +829,4 @@ async def _deploy(ra: dict, pp: str) -> bool: # All tasks completed successfully if session is not None and not session.error: session.mark_finished() - await render_model_output(f"** 🤖✅ Session {session.session_id} completed\n") + await render_model_output(f"** 🤖✅ Session {session.session_id} completed\n", log=True, async_task=False, task_id=None) diff --git a/src/seclab_taskflow_agent/template_utils.py b/src/seclab_taskflow_agent/template_utils.py index 2f21d4a6..ccd99da3 100644 --- a/src/seclab_taskflow_agent/template_utils.py +++ b/src/seclab_taskflow_agent/template_utils.py @@ -56,7 +56,7 @@ def get_source( raise jinja2.TemplateNotFound(template) -def env_function(var_name: str, default: Optional[str] = None, required: bool = True) -> str: +def env_function(var_name: str, default: Optional[str], required: bool) -> str: """Jinja2 function to access environment variables. Args: @@ -107,7 +107,13 @@ def create_jinja_environment(available_tools: "AvailableTools") -> jinja2.Enviro ) # Register custom functions - env.globals['env'] = env_function + env.globals["env"] = ( + lambda var_name, default=None, required=None: env_function( + var_name, + default, + required=True if required is None else required, + ) + ) return env diff --git a/tests/test_template_utils.py b/tests/test_template_utils.py index f2f812a5..7b4b6d97 100644 --- a/tests/test_template_utils.py +++ b/tests/test_template_utils.py @@ -22,14 +22,14 @@ def test_env_existing_var(self): """Test accessing existing environment variable.""" os.environ['TEST_VAR_JINJA'] = 'test_value' try: - assert env_function('TEST_VAR_JINJA') == 'test_value' + assert env_function('TEST_VAR_JINJA', default=None, required=True) == 'test_value' finally: del os.environ['TEST_VAR_JINJA'] def test_env_missing_required(self): """Test error on missing required variable.""" with pytest.raises(LookupError, match="Required environment variable"): - env_function('NONEXISTENT_VAR_JINJA') + env_function('NONEXISTENT_VAR_JINJA', default=None, required=True) def test_env_with_default(self): """Test default value for missing variable.""" @@ -38,14 +38,14 @@ def test_env_with_default(self): def test_env_optional_missing(self): """Test optional variable returns empty string.""" - result = env_function('NONEXISTENT_VAR_JINJA', required=False) + result = env_function('NONEXISTENT_VAR_JINJA', default=None, required=False) assert result == '' def test_env_with_default_exists(self): """Test that existing var takes precedence over default.""" os.environ['TEST_VAR_DEFAULT'] = 'actual_value' try: - result = env_function('TEST_VAR_DEFAULT', default='default_value') + result = env_function('TEST_VAR_DEFAULT', default='default_value', required=True) assert result == 'actual_value' finally: del os.environ['TEST_VAR_DEFAULT']