diff --git a/.gitignore b/.gitignore
index a68417f916f..c07a834e868 100644
--- a/.gitignore
+++ b/.gitignore
@@ -229,3 +229,6 @@ sweep.timestamp
 
 # CUDA
 *.ptx
+
+# Perfetto
+trace.json
diff --git a/bench-orchestrator/bench_orchestrator/cli.py b/bench-orchestrator/bench_orchestrator/cli.py
index a2974b9e809..eba7521fe9f 100644
--- a/bench-orchestrator/bench_orchestrator/cli.py
+++ b/bench-orchestrator/bench_orchestrator/cli.py
@@ -78,6 +78,7 @@ def run(
     track_memory: Annotated[bool, typer.Option("--track-memory", help="Track memory usage")] = False,
     samply: Annotated[bool, typer.Option("--samply", help="Record a profile using samply")] = False,
     sample_rate: Annotated[int, typer.Option("--sample-rate", help="Sample rate to run samply with")] = None,
+    tracing: Annotated[bool, typer.Option("--tracing", help="Record a trace for use with perfetto")] = False,
     build: Annotated[bool, typer.Option("--build/--no-build", help="Build binaries before running")] = True,
     verbose: Annotated[bool, typer.Option("--verbose", "-v", help="Log underlying commands")] = False,
     options: Annotated[list[str] | None, typer.Option("--opt", help="Engine or benchmark specific options")] = None,
@@ -162,6 +163,7 @@ def run(
                         track_memory=track_memory,
                         samply=samply,
                         sample_rate=sample_rate,
+                        tracing=tracing,
                         on_result=ctx.write_raw_json,
                     )
                     console.print(f"[green]{eng.value}: {len(results)} results[/green]")
@@ -185,6 +187,64 @@ def run(
             # Not enough combinations to compare
             pass
 
+    # If tracing was enabled, serve the trace file (./trace.json) from a localhost server and open the
+    # Perfetto UI in the browser, pointed at that server.
+    if tracing:
+        import http.server
+        import socketserver
+        import webbrowser
+
+        # 127.0.0.1:9001 is the only localhost origin allowed by Perfetto's CSP.
+        HOST = "127.0.0.1"
+        PORT = 9001
+        PERFETTO_ORIGIN = "https://ui.perfetto.dev"
+
+        class TraceRequestHandler(http.server.SimpleHTTPRequestHandler):
+            """Serve ./trace.json to the Perfetto UI and nothing else."""
+
+            def _reject_unless_trace(self) -> bool:
+                # SimpleHTTPRequestHandler would otherwise expose the whole working directory.
+                if self.path.split("?", 1)[0] != "/trace.json":
+                    self.send_error(404, "File not found")
+                    return True
+                self.path = "/trace.json"
+                return False
+
+            def do_GET(self):
+                if not self._reject_unless_trace():
+                    super().do_GET()
+
+            def do_HEAD(self):
+                if not self._reject_unless_trace():
+                    super().do_HEAD()
+
+            def do_POST(self):
+                self.send_error(404, "File not found")
+
+            def end_headers(self):
+                # Only the Perfetto UI needs cross-origin read access to the trace.
+                self.send_header("Access-Control-Allow-Origin", PERFETTO_ORIGIN)
+                super().end_headers()
+
+        class TraceServer(socketserver.TCPServer):
+            # Set on a subclass so the stdlib TCPServer class is not mutated process-wide.
+            allow_reuse_address = True
+
+        try:
+            # Bind to loopback only, and before opening the browser so the UI's fetch cannot race the bind.
+            httpd = TraceServer((HOST, PORT), TraceRequestHandler)
+        except OSError as err:
+            console.print(f"[red]Could not serve trace on http://{HOST}:{PORT}: {err}[/red]")
+            raise typer.Exit(code=1) from err
+
+        with httpd:
+            console.print(f"[green]Serving trace on http://{HOST}:{PORT}/trace.json[/green]")
+            webbrowser.open_new_tab(f"{PERFETTO_ORIGIN}/#!/?url=http://{HOST}:{PORT}/trace.json")
+            try:
+                httpd.serve_forever()
+            except KeyboardInterrupt:
+                console.print("[yellow]Trace server stopped[/yellow]")
+
 
 @app.command()
 def compare(
diff --git a/bench-orchestrator/bench_orchestrator/runner/executor.py b/bench-orchestrator/bench_orchestrator/runner/executor.py
index 59fc58d55cf..40eced50b9a 100644
--- a/bench-orchestrator/bench_orchestrator/runner/executor.py
+++ b/bench-orchestrator/bench_orchestrator/runner/executor.py
@@ -36,6 +36,7 @@ def run(
         track_memory: bool = False,
         samply: bool = False,
         sample_rate: int | None = None,
+        tracing: bool = False,
         on_result: Callable[[str], None] | None = None,
     ) -> list[str]:
         """
@@ -72,6 +73,8 @@ def run(
             cmd.extend(["--exclude-queries", ",".join(map(str, exclude_queries))])
         if track_memory:
             cmd.append("--track-memory")
+        if tracing:
+            cmd.append("--tracing")
         if options:
             for k, v in options.items():
                 cmd.extend(["--opt", f"{k}={v}"])