Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ ignore = [
"EM101", # Exception string literals
"EM102", # Exception f-strings
"G004", # Logging f-strings
"T201", # print() used for user output
"TRY003", # Raise with inline message strings

# Backwards-compatibility suppressions for existing code
Expand Down Expand Up @@ -264,3 +263,5 @@ ignore = [

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S101", "PLR2004"]
"release_tools/*" = ["T201"]
"scripts/*" = ["T201"]
40 changes: 21 additions & 19 deletions src/seclab_taskflow_agent/mcp_servers/codeql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

WAIT_INTERVAL = 0.1

logger = logging.getLogger(__name__)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logging should be configured here (similar as to our other MCP servers, but with a different file name)
e.g.
https://github.com/GitHubSecurityLab/seclab-taskflows/blob/1eafe13b91b5fa717f96625b24badac9cb4ffe48/src/seclab_taskflows/mcp_servers/gh_file_viewer.py#L21



# for when our stdout goes into the void
def _debug_log(msg):
Expand All @@ -30,7 +32,7 @@ def _debug_log(msg):


def shell_command_to_string(cmd):
print(f"Executing: {cmd}")
logger.debug(f"Executing: {cmd}")
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8")
stdout, stderr = p.communicate()
p.wait()
Expand Down Expand Up @@ -88,7 +90,7 @@ def _server_start(self):

# set some default callbacks for common notifications
def _handle_ql_progressUpdated(params):
print(f">> Progress: {params.get('step')}/{params.get('maxStep')} status: {params.get('message')}")
logger.debug(f">> Progress: {params.get('step')}/{params.get('maxStep')} status: {params.get('message')}")

ql_progressUpdated = "ql/progressUpdated"
if ql_progressUpdated not in self.method_handlers:
Expand Down Expand Up @@ -150,7 +152,7 @@ def _callback(err: Exception, res: str | None = None):
if err:
raise err
self.active_database = database
print(f"++ {rpc_method}: {res}")
logger.debug(f"++ {rpc_method}: {res}")

return self._server_rpc_call(
rpc_method,
Expand All @@ -168,7 +170,7 @@ def _callback(err: Exception, res: str | None = None):
if err:
raise err
self.active_database = None
print(f"++ {rpc_method}: {res}")
logger.debug(f"++ {rpc_method}: {res}")

return self._server_rpc_call(
rpc_method,
Expand Down Expand Up @@ -232,25 +234,25 @@ def _check_runquery_result_for_errors(params: dict):
case 0:
return False, ""
case 1:
print(f"xx ERROR Other: {message}")
logger.error(f"xx ERROR Other: {message}")
return True, message
case 2:
print(f"xx ERROR Compilation: {message}")
logger.error(f"xx ERROR Compilation: {message}")
return True, message
case 3:
print(f"xx ERROR OOM: {message}")
logger.error(f"xx ERROR OOM: {message}")
return True, message
case 4:
print(f"xx ERROR Query Canceled: {message}")
logger.error(f"xx ERROR Query Canceled: {message}")
return True, message
case 5:
print(f"xx ERROR DB Scheme mismatch: {message}")
logger.error(f"xx ERROR DB Scheme mismatch: {message}")
return True, message
case 6:
print(f"xx ERROR DB Scheme no upgrade found: {message}")
logger.error(f"xx ERROR DB Scheme no upgrade found: {message}")
return True, message
case _:
print(f"xx ERROR: unknown result type {result_type}: {message}")
logger.error(f"xx ERROR: unknown result type {result_type}: {message}")
return True, message
else:
return False, ""
Expand All @@ -260,7 +262,7 @@ def _check_runquery_result_for_errors(params: dict):
else:
self.active_query_error = (True, f"Unknown result state: {res}")
self.active_query_id = None
print(f"++ {rpc_method}: {res}")
logger.debug(f"++ {rpc_method}: {res}")
if err:
raise err

Expand All @@ -284,7 +286,7 @@ def _search_paths_from_codeql_config(self, config="~/.config/codeql/config"):
if match and match.group(2):
return match.group(2).split(":")
except FileNotFoundError as e:
print(f"Error: {e}")
logger.error(f"Error: {e}")
return []

def _lang_server_contact(self):
Expand All @@ -310,7 +312,7 @@ def _resolve_library_paths(self, query_path):
args = ["resolve", "library-path"]
args += ["-v", "--log-to-stderr", "--format=json"]
if search_path:
print(f"Using search path: {search_path}")
logger.debug(f"Using search path: {search_path}")
args += [f'--additional-packs="{search_path}"']
args += [f"--query={query_path}"]
return json.loads(shell_command_to_string(self.codeql_cli + args))
Expand Down Expand Up @@ -364,7 +366,7 @@ def _bqrs_to_csv(self, bqrs_path, entities=""):
with open(csv_out) as f:
return f.read()
except RuntimeError as e:
print(f"Could not decode {bqrs_path} to {csv_out}: {e}")
logger.error(f"Could not decode {bqrs_path} to {csv_out}: {e}")
return ""

def _bqrs_to_json(self, bqrs_path, entities):
Expand All @@ -377,7 +379,7 @@ def _bqrs_to_json(self, bqrs_path, entities):
with open(json_out) as f:
return f.read()
except RuntimeError as e:
print(f"Could not decode {bqrs_path} to {json_out}: {e}")
logger.error(f"Could not decode {bqrs_path} to {json_out}: {e}")
return ""

def _bqrs_to_sarif(self, bqrs_path, query_info, max_paths=10):
Expand All @@ -401,7 +403,7 @@ def _bqrs_to_sarif(self, bqrs_path, query_info, max_paths=10):
):
with open(sarif_out) as f:
return f.read()
print(f"Could not decode {bqrs_path} to {sarif_out}")
logger.error(f"Could not decode {bqrs_path} to {sarif_out}")
return ""


Expand All @@ -417,12 +419,12 @@ def __enter__(self):
return _ACTIVE_CODEQL_SERVERS[self.database]
if not self.active_connection:
self._server_start()
print("Waiting for server start ...")
logger.info("Waiting for server start ...")
while not self.active_connection:
time.sleep(WAIT_INTERVAL)
if not self.active_database:
self._server_register_database(self.database)
print("Waiting for database registration ...")
logger.info("Waiting for database registration ...")
while not self.active_database:
time.sleep(WAIT_INTERVAL)
if self.keep_alive:
Expand Down
4 changes: 3 additions & 1 deletion src/seclab_taskflow_agent/render_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import asyncio
import logging
import sys

from .path_utils import log_file_name

Expand Down Expand Up @@ -44,4 +45,5 @@ async def render_model_output(data: str, log: bool = True, async_task: bool = Fa
if data:
if log:
render_logger.info(data)
print(data, end="", flush=True)
sys.stdout.write(data)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a single line linter suppression is better here than using sys.stdout

sys.stdout.flush()