Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/mldebug/backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
"""

import importlib
import sys

from dataclasses import dataclass, field
from typing import Any

from mldebug.utils import cleanup_and_exit


@dataclass
class BackendConfig:
Expand Down Expand Up @@ -55,10 +56,10 @@ def create_backend(backend_type, config):
xrt_mod = importlib.import_module("mldebug.backend.xrt_impl")
except ModuleNotFoundError:
print("Unable to import Backend. Python 3.10 is required on Win/Linux and 3.12 on Embedded Linux.")
sys.exit(1)
cleanup_and_exit(config.args, 1)
except ImportError:
print("Unable to import XRT. Please check install.")
sys.exit(1)
cleanup_and_exit(config.args, 1)
return xrt_mod.XRTImpl(config.tiles, config.ctx_id, config.pid, config.device)

if backend_type == "test":
Expand Down
5 changes: 3 additions & 2 deletions src/mldebug/batch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from concurrent.futures import ThreadPoolExecutor, as_completed

from mldebug.utils import LOGGER, timeit
from mldebug.utils import LOGGER, cleanup_and_exit, timeit

# 16 byte pm, we assume 2 clock cycle delay
COMBO_EVENT_MAX_DELAY_CYCLES = 32
Expand Down Expand Up @@ -311,7 +311,7 @@ def _process_err(self):
else:
self.status_handle.get("aie_status_error.txt")
self._write_run_summary("FAIL")
sys.exit(1)
cleanup_and_exit(self.args, 1)

def _process_end_breakpoint(self, layer, it, sid):
"""
Expand Down Expand Up @@ -511,6 +511,7 @@ def _write_run_summary(self, status):
summary = {"status": status, "run_flags": flags_dict}

try:
pathlib.Path(self.args.top_output_dir).mkdir(parents=True, exist_ok=True)
with open(rsf, "w", encoding="utf-8") as fh:
json.dump(summary, fh, indent=2, default=str)
except (IOError, OSError) as e:
Expand Down
9 changes: 7 additions & 2 deletions src/mldebug/client_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from mldebug.interactive_controller import InteractiveController
from mldebug.layer_info import LayerInfo
from mldebug.memory_dumper import MemoryDumper
from mldebug.utils import LOGGER
from mldebug.utils import LOGGER, register_debug_server


class ClientDebug:
Expand Down Expand Up @@ -51,7 +51,12 @@ def __init__(self, args, ctx_id, pid, output_dir):

# Create this first so that connection will be aborted in case of crash
if self.args.automated_debug or self.args.l3:
debug_server = DebugServer(self.args.subgraph_name, self.output_dir, self.args.backend == "test")
debug_server = DebugServer(
self.output_dir, self.args.backend == "test", subgraph_name=self.args.subgraph_name,
)
# Track the live server so cleanup_and_exit() at unplanned exit points
# can send TERMINATE_CONNECTION to flexmlrt.
register_debug_server(debug_server)

try:
self.design_info = LayerInfo(args)
Expand Down
19 changes: 18 additions & 1 deletion src/mldebug/debug_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ class DebugServer:
and communication with flexmlrt for buffer dump and termination requests.
"""

def __init__(self, subgraph_name, output_dir, is_testmode, bind_addr=("127.0.0.1", 9000)) -> None:
def __init__(
self, output_dir, is_testmode, subgraph_name="subgraph",
bind_addr=("127.0.0.1", 9000), connect_timeout=None,
) -> None:
"""
Initialize the DebugServer instance.

Expand All @@ -28,11 +31,14 @@ def __init__(self, subgraph_name, output_dir, is_testmode, bind_addr=("127.0.0.1
output_dir (str): Directory where buffer dumps will be stored.
is_testmode (bool): Enables test mode, which disables socket operations for CI/testing.
bind_addr (tuple): Address and port to bind the debug server socket.
connect_timeout (float, optional): If set, accept() gives up after this
many seconds; used by cleanup paths to avoid hanging forever.
"""
self.bind_addr = bind_addr
self.subgraph_name = subgraph_name
self.output_dir = output_dir
self.is_testmode = is_testmode
self.connect_timeout = connect_timeout
self.server_socket = None
self.client_socket = None
self.start()
Expand Down Expand Up @@ -64,9 +70,20 @@ def start(self):
self.server_socket.listen(1)
LOGGER.verbose_print(f"Listening on {self.bind_addr}...")

if self.connect_timeout is not None:
self.server_socket.settimeout(self.connect_timeout)
self.client_socket, client_address = self.server_socket.accept()
# Reset to blocking mode for subsequent send/recv.
if self.connect_timeout is not None:
self.server_socket.settimeout(None)
self.client_socket.settimeout(None)
LOGGER.log(f"[INFO] Connected to FlexmlRT on {client_address}")
return True
except socket.timeout:
LOGGER.verbose_print(
f"Timed out after {self.connect_timeout}s waiting for flexmlrt to connect."
)
return False
except socket.error as e:
LOGGER.verbose_print(f"Socket error during setup or connection: {e}")
return False
Expand Down
59 changes: 38 additions & 21 deletions src/mldebug/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
import importlib
import os
import subprocess
import sys
import re

from mldebug.arch import load_aie_arch, AIE_DEV_PHX, AIE_DEV_STX, AIE_DEV_TEL
from mldebug.backend.core_dump_impl import CoreDumpFallbackReader
from mldebug.utils import LOGGER, is_aarch64, is_windows
from mldebug.utils import LOGGER, cleanup_and_exit, input_with_timeout, is_aarch64, is_windows

# Seconds to wait at interactive prompts before giving up and exiting.
HW_CONTEXT_INPUT_TIMEOUT_S = 60

@dataclass
class RunFlags:
Expand Down Expand Up @@ -126,13 +128,15 @@ def get_flag(s, default=False):
)


def check_registry_keys(npu3=False) -> None:
def check_registry_keys(args, npu3=False) -> None:
"""
Checks if specific registry keys are correctly configured on Windows,
and sets values if necessary for MLDebug operation. Exits on failure
or after making modifications.

Args:
args: Argument namespace. Used to drive flexmlrt cleanup on exit
(only when ``args.l3`` is set).
npu3 (bool): Whether to check npu3-specific registry keys.

Returns:
Expand Down Expand Up @@ -174,16 +178,16 @@ def check_registry_keys(npu3=False) -> None:
f"Error: Unable to access or create registry key:"
f" HKEY_LOCAL_MACHINE\\{key_path}. Please run tool with admin privileges."
)
sys.exit(1)
cleanup_and_exit(args, 1)
except ValueError:
LOGGER.log(f"Error: Invalid registry key format: {key_path}")
sys.exit(1)
cleanup_and_exit(args, 1)

if modified:
LOGGER.log(
"\nRegistry settings to enable MlDebug were modified. Please restart your machine for the changes to take effect."
)
sys.exit(1)
cleanup_and_exit(args, 1)
else:
LOGGER.log("\nRegistry settings check passed. No modifications were necessary.")

Expand Down Expand Up @@ -252,18 +256,13 @@ def print_hw_context_table(current_contexts: dict[str, dict[str, str]]) -> None:
LOGGER.log(f"{context:<12} {columns_str:<30} {context_data['pid']:<12} {context_data['status']:<12}")


def check_hw_context(device: str) -> tuple[int, int]:
def check_hw_context(args) -> tuple[int, int]:
"""
Finds and returns the hardware context and process ID from the xrt-smi command output.

If xrt-smi fails or no application is running, prompts the user to input ctx and pid manually.

Args:
device (str): Device identifier.

Returns:
Tuple[int, int]: Selected context ID and PID.
Returns (ctx_id, pid) from xrt-smi, prompting the user as a fallback.
Manual prompts time out after ``HW_CONTEXT_INPUT_TIMEOUT_S`` seconds and
call ``cleanup_and_exit(args, 1)`` on failure / timeout.
"""
device = args.device
filename = "xrt-smi_output.json"
use_shell = is_windows()

Expand Down Expand Up @@ -297,17 +296,35 @@ def check_hw_context(device: str) -> tuple[int, int]:
else:
print_hw_context_table(current_contexts)
# Ask user
selected_context_id = input("Multiple Contexts Found. Please enter the Context ID you want to select: ")
selected_context_id = input_with_timeout(
"Multiple Contexts Found. Please enter the Context ID you want to select: ",
HW_CONTEXT_INPUT_TIMEOUT_S,
)
if selected_context_id in current_contexts:
ctx = int(selected_context_id)
pid = int(current_contexts[selected_context_id]["pid"])
else:
LOGGER.log("Could not find the provided context, Exiting now.")
sys.exit(1)
cleanup_and_exit(args, 1)
except (FileNotFoundError, subprocess.CalledProcessError, json.JSONDecodeError):
LOGGER.log("Error with xrt-smi. Please enter ctx, pid manually.")
pid = int(input("Enter PID > "))
ctx = int(input("Enter CTX ID > "))
LOGGER.log(
f"Error with xrt-smi. Please enter ctx, pid manually "
f"(waiting up to {HW_CONTEXT_INPUT_TIMEOUT_S}s for each value)."
)
pid_str = input_with_timeout("Enter PID > ", HW_CONTEXT_INPUT_TIMEOUT_S)
if pid_str is None:
LOGGER.log("\nTimed out waiting for PID input. Exiting.")
cleanup_and_exit(args, 1)
ctx_str = input_with_timeout("Enter CTX ID > ", HW_CONTEXT_INPUT_TIMEOUT_S)
if ctx_str is None:
LOGGER.log("\nTimed out waiting for CTX ID input. Exiting.")
cleanup_and_exit(args, 1)
try:
pid = int(pid_str)
ctx = int(ctx_str)
except ValueError:
LOGGER.log("Invalid PID/CTX ID input. Exiting.")
cleanup_and_exit(args, 1)
return ctx, pid


Expand Down
2 changes: 1 addition & 1 deletion src/mldebug/memory_dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _ensure_debug_server(self):
"""
if not self.debug_server:
LOGGER.log("[INFO] Starting L3 debug server...")
self.debug_server = DebugServer(None, self.output_dir, self.args.backend == "test")
self.debug_server = DebugServer(self.output_dir, self.args.backend == "test")
if not self.debug_server.client_socket and self.args.backend != "test":
LOGGER.log(
"[ERROR] Failed to connect to FlexML runtime. Make sure FlexML is running and waiting for debugger connection."
Expand Down
4 changes: 2 additions & 2 deletions src/mldebug/mldebug_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def launch_debug(args, output_dir):
context_id = 0
pid = 0
if args.backend == "xrt":
context_id, pid = check_hw_context(args.device)
context_id, pid = check_hw_context(args)
# Top debug handle
_apply_unsupported_kernels_from_args(args)
handle = ClientDebug(args, context_id, pid, output_dir)
Expand Down Expand Up @@ -370,7 +370,7 @@ def app():
for fsp in fsp_execution_order:
create_run_flags(args, subgraph_folder_path, fsp, fsp_execution_order)
if not registry_checked and args.backend == "xrt" and is_windows():
check_registry_keys(args.device == AIE_DEV_NPU3)
check_registry_keys(args, args.device == AIE_DEV_NPU3)
registry_checked = True
debug(args, timestamp, subgraph_name, fsp, model_folder_name)
if args.dump_aie_status:
Expand Down
71 changes: 71 additions & 0 deletions src/mldebug/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import os
import platform
import sys
import threading
import time


Expand Down Expand Up @@ -307,6 +309,75 @@ def print_tile_grid(title, tiles, register_values=None, format_type="hex"):

print(f"{'=' * total_width}")

def input_with_timeout(prompt, timeout):
"""
Read a line from stdin, or return None after ``timeout`` seconds.
Uses a daemon thread so it works on Windows (no signal.alarm).
"""
result = []

def _reader():
try:
result.append(input(prompt))
except EOFError:
pass

t = threading.Thread(target=_reader, daemon=True)
t.start()
t.join(timeout)
if t.is_alive():
return None
return result[0] if result else None


# Tracks the live DebugServer so cleanup_and_exit can close it on exit.
_active_debug_server = None


def register_debug_server(server):
"""Register the live DebugServer (or None to clear)."""
global _active_debug_server # pylint: disable=global-statement
_active_debug_server = server


def terminate_flexml_connection(timeout=5):
"""
Spin up a brief DebugServer, send TERMINATE_CONNECTION, and close.
Best-effort cleanup used on unplanned exit; all errors are swallowed.
"""
# Import lazily to avoid a circular import (debug_server imports LOGGER).
from mldebug.debug_server import DebugServer # pylint: disable=import-outside-toplevel

try:
server = DebugServer(
output_dir="",
is_testmode=False,
connect_timeout=timeout,
)
server.close()
except Exception as e: # pylint: disable=broad-except
LOGGER.log(f"[WARN] flexmlrt cleanup failed: {e}")


def cleanup_and_exit(args, code=1):
"""
Exit, first tearing down the flexmlrt connection when ``args.l3`` is set.
Closes the registered DebugServer if any, else starts a brief one to send
TERMINATE_CONNECTION (covers exits that happen before ClientDebug runs).
"""
global _active_debug_server # pylint: disable=global-statement
if args is not None and getattr(args, "l3", False):
if _active_debug_server is not None:
try:
_active_debug_server.close()
except Exception as e: # pylint: disable=broad-except
LOGGER.log(f"[WARN] Failed to close active debug server: {e}")
_active_debug_server = None
else:
terminate_flexml_connection()
sys.exit(code)


def is_aarch64():
"""
ARM
Expand Down
Loading