Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions diffsynth_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
get_pipeline_class_name,
)
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.torch_profiler import TorchProfiler
from diffsynth_engine.worker import run_worker_loop

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -86,18 +87,32 @@ def _init_pipeline(self):

def generate(self, **kwargs):
if self.workers is not None:
return self._generate(**kwargs)
return self._run_worker("__call__", kwargs, output_rank=0)
else:
return self.pipeline(**kwargs)

def _generate(self, **kwargs):
def _run_worker(self, method: str, kwargs: dict | None = None, output_rank: int | None = 0):
# TODO: health check and timeout
self.conns[0].send({"method": "__call__", "kwargs": kwargs})

result = self.conns[0].recv()

self.conns[0].send(
{
"method": method,
"kwargs": kwargs or {},
"output_rank": output_rank,
}
)

if output_rank is None:
outputs = []
for rank, conn in enumerate(self.conns):
result = conn.recv()
if result["status"] != "success":
raise RuntimeError(f"{method} failed on rank {rank}: {result.get('error', 'Unknown error')}")
outputs.append(result["output"])
return outputs

result = self.conns[output_rank].recv()
if result["status"] != "success":
raise RuntimeError(f"Generation failed: {result.get('error', 'Unknown error')}")
raise RuntimeError(f"{method} failed on rank {output_rank}: {result.get('error', 'Unknown error')}")

return result["output"]

Expand All @@ -124,3 +139,34 @@ def shutdown(self):

self.workers = None
self.conns = None

def start_profile(self, path: str = ".", profile_rank0_only: bool = True):
if self.workers is not None:
self._run_worker(
"start_profile",
{
"path": path,
"profile_rank0_only": profile_rank0_only,
},
output_rank=0,
)
else:
TorchProfiler.start(path, profile_rank0_only=profile_rank0_only)

def stop_profile(self):
if self.workers is not None:
results = self._run_worker("stop_profile", {}, output_rank=None)
else:
results = [TorchProfiler.stop()]

output_files = {"traces": []}
for result in results:
if not isinstance(result, dict):
continue

trace = result.get("trace")
if trace:
output_files["traces"].append(trace)

logger.info("Profile traces: %s", output_files["traces"])
return output_files
133 changes: 133 additions & 0 deletions diffsynth_engine/utils/torch_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from vLLM-Omni vllm_omni/diffusion/profiler/torch_profiler.py.

import os
import subprocess
from contextlib import nullcontext

import torch
from torch.profiler import ProfilerActivity, profile

from diffsynth_engine.utils import logging

logger = logging.get_logger(__name__)


class TorchProfiler:
"""
Torch-based profiler configured for End-to-End continuous recording.
Uses 'on_trace_ready' to handle Trace export.
Compression is offloaded to a background subprocess to avoid blocking the worker loop.
"""

_profiler: profile | None = None
_trace_template: str = ""

@classmethod
def start(cls, trace_path_template: str, profile_rank0_only: bool = True) -> str:
"""
Start the profiler with the given trace path template.
"""
# 1. Cleanup any existing profiler
if cls._profiler is not None:
logger.warning("[Rank %s] Stopping existing Torch profiler", cls._get_rank())
cls._profiler.stop()
cls._profiler = None

rank = cls._get_rank()

# 2. Make path absolute
trace_path_template = os.path.abspath(trace_path_template)
cls._trace_template = trace_path_template

if rank != 0 and profile_rank0_only:
return ""

# Expected paths
json_file = f"{trace_path_template}_rank{rank}.json"

os.makedirs(os.path.dirname(json_file), exist_ok=True)

logger.info(f"[Rank {rank}] Starting End-to-End Torch profiler")

# 3. Define the on_trace_ready handler
def trace_handler(p):
nonlocal json_file

# A. Export JSON Trace
try:
p.export_chrome_trace(json_file)
logger.info(f"[Rank {rank}] Trace exported to {json_file}")

try:
subprocess.Popen(["gzip", "-f", json_file])
logger.info(f"[Rank {rank}] Triggered background compression for {json_file}")
# Update variable to point to the eventual file
json_file = f"{json_file}.gz"
except Exception as compress_err:
logger.warning(f"[Rank {rank}] Background gzip failed to start: {compress_err}")
Comment on lines +63 to +69
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The background compression relies on the system gzip utility, which may not be available in all environments (e.g., Windows or minimal containers). If gzip is missing, subprocess.Popen will raise a FileNotFoundError. While this is caught, the return values of start and stop will still incorrectly indicate a .gz extension. Consider checking for gzip availability using shutil.which and adjusting the returned file path accordingly, or use the built-in gzip module in a separate thread for better portability.

Suggested change
try:
subprocess.Popen(["gzip", "-f", json_file])
logger.info(f"[Rank {rank}] Triggered background compression for {json_file}")
# Update variable to point to the eventual file
json_file = f"{json_file}.gz"
except Exception as compress_err:
logger.warning(f"[Rank {rank}] Background gzip failed to start: {compress_err}")
import shutil
if shutil.which("gzip"):
try:
subprocess.Popen(["gzip", "-f", json_file])
logger.info(f"[Rank {rank}] Triggered background compression for {json_file}")
json_file = f"{json_file}.gz"
except Exception as compress_err:
logger.warning(f"[Rank {rank}] Background gzip failed to start: {compress_err}")
else:
logger.warning(f"[Rank {rank}] gzip utility not found, skipping compression")


except Exception as e:
logger.warning(f"[Rank {rank}] Failed to export trace: {e}")

# 4. Initialize profiler with long active period
cls._profiler = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=0,
warmup=0,
active=100000, # long capture window
),
on_trace_ready=trace_handler,
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
)

# 5. Start profiling
cls._profiler.start()

# Return the expected final path
return f"{trace_path_template}_rank{rank}.json.gz"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The method returns a .gz path regardless of whether compression actually started or succeeded. This can lead to issues if the caller expects the file to exist at the returned path. It would be more robust to return the actual path based on the success of the compression trigger.


@classmethod
def stop(cls) -> dict | None:
if cls._profiler is None:
return None

rank = cls._get_rank()

# Determine expected paths
base_path = f"{cls._trace_template}_rank{rank}"
gz_path = f"{base_path}.json.gz"

try:
# This triggers trace_handler synchronously
# Since we removed table generation and backgrounded compression, this returns fast.
cls._profiler.stop()
except Exception as e:
logger.warning(f"[Rank {rank}] Profiler stop failed: {e}")

cls._profiler = None

# We return the .gz path assuming background compression will succeed.
return {"trace": gz_path, "table": None}

@classmethod
def step(cls):
if cls._profiler is not None:
cls._profiler.step()

@classmethod
def is_active(cls) -> bool:
return cls._profiler is not None

@classmethod
def get_step_context(cls):
return nullcontext()

@classmethod
def _get_rank(cls) -> int:
return int(os.getenv("RANK", "0"))
23 changes: 20 additions & 3 deletions diffsynth_engine/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from diffsynth_engine.pipelines.utils import get_pipeline_class
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.torch_profiler import TorchProfiler

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -59,6 +60,18 @@ def __init__(
def __call__(self, **kwargs):
return self.pipeline(**kwargs)

@classmethod
def start_profile(cls, **kwargs):
path = kwargs.get("path", ".")
profile_rank0_only = kwargs.get("profile_rank0_only", True)
return TorchProfiler.start(path, profile_rank0_only=profile_rank0_only)

@classmethod
def stop_profile(cls, **kwargs):
result = TorchProfiler.stop()
get_world_group().barrier()
return result


def run_worker_loop(
local_rank: int,
Expand Down Expand Up @@ -87,6 +100,7 @@ def run_worker_loop(
world_group = get_world_group()

while True:
should_reply = rank == 0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable should_reply is initialized here but is immediately overwritten a few lines later (line 118) before being used. This initialization is redundant and can be removed to improve clarity.

Suggested change
should_reply = rank == 0
try:

try:
if rank == 0:
data = conn.recv()
Expand All @@ -100,8 +114,11 @@ def run_worker_loop(
if method == "shutdown":
break

output_rank = data.get("output_rank", 0)
should_reply = output_rank is None or output_rank == rank

output = getattr(worker, method)(**kwargs)
if rank == 0:
if should_reply:
conn.send(
{
"status": "success",
Expand All @@ -111,7 +128,7 @@ def run_worker_loop(
world_group.barrier()
except EOFError as e:
logger.error(f"Worker process {rank} connection closed: {e}", exc_info=True)
if rank == 0:
if should_reply:
conn.send(
{
"status": "error",
Expand All @@ -121,7 +138,7 @@ def run_worker_loop(
break
except Exception as e:
logger.error(f"Worker process {rank} error: {e}", exc_info=True)
if rank == 0:
if should_reply:
conn.send(
{
"status": "error",
Expand Down