From f170728d7ee67d08b9616a77dd594559573b1e66 Mon Sep 17 00:00:00 2001
From: "zhuguoxuan.zgx"
Date: Tue, 19 May 2026 20:43:32 +0800
Subject: [PATCH] Add torch profiler support

---
Reviewer notes (this section is ignored when the patch is applied):
 * Line structure restored; the patch had been collapsed onto five lines.
   Context-line indentation and blank context lines were inferred and
   should be checked against the base files.
 * The "index" blob ids are those of the original submission and no longer
   describe the post-image.
 * Engine._run_worker drains every worker reply before raising.
 * TorchProfiler.start places traces inside an existing directory, and
   TorchProfiler.stop reports the file that was actually written.

 diffsynth_engine/engine.py               |  77 ++++++++++-
 diffsynth_engine/utils/torch_profiler.py | 155 +++++++++++++++++++++++
 diffsynth_engine/worker.py               |  27 +++-
 3 files changed, 249 insertions(+), 10 deletions(-)
 create mode 100644 diffsynth_engine/utils/torch_profiler.py

diff --git a/diffsynth_engine/engine.py b/diffsynth_engine/engine.py
index 503e98b..0d9744e 100644
--- a/diffsynth_engine/engine.py
+++ b/diffsynth_engine/engine.py
@@ -7,6 +7,7 @@
     get_pipeline_class_name,
 )
 from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.torch_profiler import TorchProfiler
 from diffsynth_engine.worker import run_worker_loop
 
 logger = logging.get_logger(__name__)
@@ -86,18 +87,42 @@ def _init_pipeline(self):
 
     def generate(self, **kwargs):
         if self.workers is not None:
-            return self._generate(**kwargs)
+            return self._run_worker("__call__", kwargs, output_rank=0)
         else:
             return self.pipeline(**kwargs)
 
-    def _generate(self, **kwargs):
+    def _run_worker(self, method: str, kwargs: dict | None = None, output_rank: int | None = 0):
+        """Dispatch ``method`` to the worker processes and collect the reply.
+
+        The request is sent on the rank 0 connection. ``output_rank`` selects
+        which worker's reply is awaited; ``None`` awaits a reply from every
+        worker and returns the outputs as a list ordered by rank.
+
+        Raises:
+            RuntimeError: if an awaited worker reports a non-success status.
+        """
         # TODO: health check and timeout
-        self.conns[0].send({"method": "__call__", "kwargs": kwargs})
-
-        result = self.conns[0].recv()
-
+        self.conns[0].send(
+            {
+                "method": method,
+                "kwargs": kwargs or {},
+                "output_rank": output_rank,
+            }
+        )
+
+        if output_rank is None:
+            # Receive from every worker before checking any status: raising on
+            # the first failure would leave the remaining replies queued in
+            # their pipes, where the next all-rank request would read them.
+            results = [conn.recv() for conn in self.conns]
+            for rank, result in enumerate(results):
+                if result["status"] != "success":
+                    raise RuntimeError(f"{method} failed on rank {rank}: {result.get('error', 'Unknown error')}")
+            return [result["output"] for result in results]
+
+        result = self.conns[output_rank].recv()
         if result["status"] != "success":
-            raise RuntimeError(f"Generation failed: {result.get('error', 'Unknown error')}")
+            raise RuntimeError(f"{method} failed on rank {output_rank}: {result.get('error', 'Unknown error')}")
 
         return result["output"]
 
@@ -124,3 +149,41 @@ def shutdown(self):
 
         self.workers = None
         self.conns = None
+
+    def start_profile(self, path: str = ".", profile_rank0_only: bool = True):
+        """Start torch profiling in the workers, or in-process without workers.
+
+        Args:
+            path: trace path template forwarded to ``TorchProfiler.start``.
+            profile_rank0_only: forwarded to ``TorchProfiler.start``.
+        """
+        if self.workers is not None:
+            self._run_worker(
+                "start_profile",
+                {
+                    "path": path,
+                    "profile_rank0_only": profile_rank0_only,
+                },
+                output_rank=0,
+            )
+        else:
+            TorchProfiler.start(path, profile_rank0_only=profile_rank0_only)
+
+    def stop_profile(self):
+        """Stop profiling and return ``{"traces": [...]}`` listing the reported trace paths."""
+        if self.workers is not None:
+            results = self._run_worker("stop_profile", {}, output_rank=None)
+        else:
+            results = [TorchProfiler.stop()]
+
+        output_files = {"traces": []}
+        for result in results:
+            if not isinstance(result, dict):
+                continue
+
+            trace = result.get("trace")
+            if trace:
+                output_files["traces"].append(trace)
+
+        logger.info("Profile traces: %s", output_files["traces"])
+        return output_files
diff --git a/diffsynth_engine/utils/torch_profiler.py b/diffsynth_engine/utils/torch_profiler.py
new file mode 100644
index 0000000..06dc330
--- /dev/null
+++ b/diffsynth_engine/utils/torch_profiler.py
@@ -0,0 +1,155 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from vLLM-Omni vllm_omni/diffusion/profiler/torch_profiler.py.
+
+import os
+import subprocess
+from contextlib import nullcontext
+
+import torch
+from torch.profiler import ProfilerActivity, profile
+
+from diffsynth_engine.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class TorchProfiler:
+    """
+    Torch-based profiler configured for End-to-End continuous recording.
+    Uses 'on_trace_ready' to handle Trace export.
+    Compression is offloaded to a background subprocess to avoid blocking the worker loop.
+    """
+
+    _profiler: profile | None = None
+    _trace_template: str = ""
+    # Path recorded by the last successful export; None until one succeeds.
+    _trace_file: str | None = None
+
+    @classmethod
+    def start(cls, trace_path_template: str, profile_rank0_only: bool = True) -> str:
+        """
+        Start the profiler with the given trace path template.
+
+        The trace is exported to "<template>_rank<rank>.json" and then gzipped
+        in the background. A template that names an existing directory is
+        expanded to "<directory>/trace" so the files land inside it.
+
+        Returns the expected ".json.gz" path, or "" when this rank is skipped
+        because profile_rank0_only is set.
+        """
+        # 1. Cleanup any existing profiler
+        if cls._profiler is not None:
+            logger.warning("[Rank %s] Stopping existing Torch profiler", cls._get_rank())
+            cls._profiler.stop()
+            cls._profiler = None
+
+        rank = cls._get_rank()
+
+        # 2. Make path absolute
+        trace_path_template = os.path.abspath(trace_path_template)
+        if os.path.isdir(trace_path_template):
+            # A bare directory such as "." would otherwise become
+            # "<directory>_rank0.json", a sibling of the directory.
+            trace_path_template = os.path.join(trace_path_template, "trace")
+        cls._trace_template = trace_path_template
+        cls._trace_file = None
+
+        if rank != 0 and profile_rank0_only:
+            return ""
+
+        # Expected paths
+        json_file = f"{trace_path_template}_rank{rank}.json"
+
+        os.makedirs(os.path.dirname(json_file), exist_ok=True)
+
+        logger.info(f"[Rank {rank}] Starting End-to-End Torch profiler")
+
+        # 3. Define the on_trace_ready handler
+        def trace_handler(p):
+            # A. Export JSON Trace
+            try:
+                p.export_chrome_trace(json_file)
+                logger.info(f"[Rank {rank}] Trace exported to {json_file}")
+                cls._trace_file = json_file
+
+                try:
+                    subprocess.Popen(["gzip", "-f", json_file])
+                    logger.info(f"[Rank {rank}] Triggered background compression for {json_file}")
+                    # Report the eventual file; json_file itself stays fixed
+                    # so a later export still targets the ".json" path.
+                    cls._trace_file = f"{json_file}.gz"
+                except Exception as compress_err:
+                    logger.warning(f"[Rank {rank}] Background gzip failed to start: {compress_err}")
+
+            except Exception as e:
+                logger.warning(f"[Rank {rank}] Failed to export trace: {e}")
+
+        # 4. Initialize profiler with long active period
+        cls._profiler = profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            schedule=torch.profiler.schedule(
+                wait=0,
+                warmup=0,
+                active=100000,  # long capture window
+            ),
+            on_trace_ready=trace_handler,
+            record_shapes=True,
+            profile_memory=True,
+            with_stack=True,
+            with_flops=True,
+        )
+
+        # 5. Start profiling
+        cls._profiler.start()
+
+        # Return the expected final path
+        return f"{trace_path_template}_rank{rank}.json.gz"
+
+    @classmethod
+    def stop(cls) -> dict | None:
+        """
+        Stop the profiler and report the trace file.
+
+        Returns None when no profiler is running in this process. Otherwise
+        returns {"trace": path, "table": None}; path is the ".json.gz" file
+        once compression was launched, the ".json" file if gzip could not be
+        started, and None if the export failed.
+        """
+        if cls._profiler is None:
+            return None
+
+        rank = cls._get_rank()
+
+        try:
+            # This triggers trace_handler synchronously
+            # Since we removed table generation and backgrounded compression, this returns fast.
+            cls._profiler.stop()
+        except Exception as e:
+            logger.warning(f"[Rank {rank}] Profiler stop failed: {e}")
+
+        cls._profiler = None
+
+        # gzip runs in the background, so a ".json.gz" path may not exist yet.
+        return {"trace": cls._trace_file, "table": None}
+
+    @classmethod
+    def step(cls):
+        """Advance the profiler by one step if a profiler is running."""
+        if cls._profiler is not None:
+            cls._profiler.step()
+
+    @classmethod
+    def is_active(cls) -> bool:
+        """Return True while a profiler instance is held."""
+        return cls._profiler is not None
+
+    @classmethod
+    def get_step_context(cls):
+        """Return a no-op context manager."""
+        return nullcontext()
+
+    @classmethod
+    def _get_rank(cls) -> int:
+        """Return the rank read from the RANK environment variable (default 0)."""
+        return int(os.getenv("RANK", "0"))
diff --git a/diffsynth_engine/worker.py b/diffsynth_engine/worker.py
index 4e3e645..31cc5f2 100644
--- a/diffsynth_engine/worker.py
+++ b/diffsynth_engine/worker.py
@@ -11,6 +11,7 @@
 )
 from diffsynth_engine.pipelines.utils import get_pipeline_class
 from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.torch_profiler import TorchProfiler
 
 logger = logging.get_logger(__name__)
 
@@ -59,6 +60,20 @@ def __init__(
     def __call__(self, **kwargs):
         return self.pipeline(**kwargs)
 
+    @classmethod
+    def start_profile(cls, **kwargs):
+        """Start the torch profiler in this worker process."""
+        path = kwargs.get("path", ".")
+        profile_rank0_only = kwargs.get("profile_rank0_only", True)
+        return TorchProfiler.start(path, profile_rank0_only=profile_rank0_only)
+
+    @classmethod
+    def stop_profile(cls, **kwargs):
+        """Stop the profiler, then join a world-group barrier before returning the result."""
+        result = TorchProfiler.stop()
+        get_world_group().barrier()
+        return result
+
 
 def run_worker_loop(
     local_rank: int,
@@ -87,6 +102,8 @@ def run_worker_loop(
     world_group = get_world_group()
 
     while True:
+        # Until a request names its output rank, only rank 0 reports errors.
+        should_reply = rank == 0
         try:
             if rank == 0:
                 data = conn.recv()
@@ -100,8 +117,12 @@ def run_worker_loop(
             if method == "shutdown":
                 break
 
+            # output_rank None means every rank replies; otherwise only that rank.
+            output_rank = data.get("output_rank", 0)
+            should_reply = output_rank is None or output_rank == rank
+
             output = getattr(worker, method)(**kwargs)
-            if rank == 0:
+            if should_reply:
                 conn.send(
                     {
                         "status": "success",
@@ -111,7 +130,7 @@ def run_worker_loop(
             world_group.barrier()
         except EOFError as e:
             logger.error(f"Worker process {rank} connection closed: {e}", exc_info=True)
-            if rank == 0:
+            if should_reply:
                 conn.send(
                     {
                         "status": "error",
@@ -121,7 +140,7 @@ def run_worker_loop(
                 break
         except Exception as e:
             logger.error(f"Worker process {rank} error: {e}", exc_info=True)
-            if rank == 0:
+            if should_reply:
                 conn.send(
                     {
                         "status": "error",