-
Notifications
You must be signed in to change notification settings - Fork 43
Add torch profiler support #245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
akaitsuki-ii
wants to merge
1
commit into
v1
Choose a base branch
from
dev/add-torch-profiler
base: v1
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| # Adapted from vLLM-Omni vllm_omni/diffusion/profiler/torch_profiler.py. | ||
|
|
||
| import os | ||
| import subprocess | ||
| from contextlib import nullcontext | ||
|
|
||
| import torch | ||
| from torch.profiler import ProfilerActivity, profile | ||
|
|
||
| from diffsynth_engine.utils import logging | ||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
||
| class TorchProfiler: | ||
| """ | ||
| Torch-based profiler configured for End-to-End continuous recording. | ||
| Uses 'on_trace_ready' to handle Trace export. | ||
| Compression is offloaded to a background subprocess to avoid blocking the worker loop. | ||
| """ | ||
|
|
||
| _profiler: profile | None = None | ||
| _trace_template: str = "" | ||
|
|
||
| @classmethod | ||
| def start(cls, trace_path_template: str, profile_rank0_only: bool = True) -> str: | ||
| """ | ||
| Start the profiler with the given trace path template. | ||
| """ | ||
| # 1. Cleanup any existing profiler | ||
| if cls._profiler is not None: | ||
| logger.warning("[Rank %s] Stopping existing Torch profiler", cls._get_rank()) | ||
| cls._profiler.stop() | ||
| cls._profiler = None | ||
|
|
||
| rank = cls._get_rank() | ||
|
|
||
| # 2. Make path absolute | ||
| trace_path_template = os.path.abspath(trace_path_template) | ||
| cls._trace_template = trace_path_template | ||
|
|
||
| if rank != 0 and profile_rank0_only: | ||
| return "" | ||
|
|
||
| # Expected paths | ||
| json_file = f"{trace_path_template}_rank{rank}.json" | ||
|
|
||
| os.makedirs(os.path.dirname(json_file), exist_ok=True) | ||
|
|
||
| logger.info(f"[Rank {rank}] Starting End-to-End Torch profiler") | ||
|
|
||
| # 3. Define the on_trace_ready handler | ||
| def trace_handler(p): | ||
| nonlocal json_file | ||
|
|
||
| # A. Export JSON Trace | ||
| try: | ||
| p.export_chrome_trace(json_file) | ||
| logger.info(f"[Rank {rank}] Trace exported to {json_file}") | ||
|
|
||
| try: | ||
| subprocess.Popen(["gzip", "-f", json_file]) | ||
| logger.info(f"[Rank {rank}] Triggered background compression for {json_file}") | ||
| # Update variable to point to the eventual file | ||
| json_file = f"{json_file}.gz" | ||
| except Exception as compress_err: | ||
| logger.warning(f"[Rank {rank}] Background gzip failed to start: {compress_err}") | ||
|
|
||
| except Exception as e: | ||
| logger.warning(f"[Rank {rank}] Failed to export trace: {e}") | ||
|
|
||
| # 4. Initialize profiler with long active period | ||
| cls._profiler = profile( | ||
| activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], | ||
| schedule=torch.profiler.schedule( | ||
| wait=0, | ||
| warmup=0, | ||
| active=100000, # long capture window | ||
| ), | ||
| on_trace_ready=trace_handler, | ||
| record_shapes=True, | ||
| profile_memory=True, | ||
| with_stack=True, | ||
| with_flops=True, | ||
| ) | ||
|
|
||
| # 5. Start profiling | ||
| cls._profiler.start() | ||
|
|
||
| # Return the expected final path | ||
| return f"{trace_path_template}_rank{rank}.json.gz" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| @classmethod | ||
| def stop(cls) -> dict | None: | ||
| if cls._profiler is None: | ||
| return None | ||
|
|
||
| rank = cls._get_rank() | ||
|
|
||
| # Determine expected paths | ||
| base_path = f"{cls._trace_template}_rank{rank}" | ||
| gz_path = f"{base_path}.json.gz" | ||
|
|
||
| try: | ||
| # This triggers trace_handler synchronously | ||
| # Since we removed table generation and backgrounded compression, this returns fast. | ||
| cls._profiler.stop() | ||
| except Exception as e: | ||
| logger.warning(f"[Rank {rank}] Profiler stop failed: {e}") | ||
|
|
||
| cls._profiler = None | ||
|
|
||
| # We return the .gz path assuming background compression will succeed. | ||
| return {"trace": gz_path, "table": None} | ||
|
|
||
| @classmethod | ||
| def step(cls): | ||
| if cls._profiler is not None: | ||
| cls._profiler.step() | ||
|
|
||
| @classmethod | ||
| def is_active(cls) -> bool: | ||
| return cls._profiler is not None | ||
|
|
||
| @classmethod | ||
| def get_step_context(cls): | ||
| return nullcontext() | ||
|
|
||
| @classmethod | ||
| def _get_rank(cls) -> int: | ||
| return int(os.getenv("RANK", "0")) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| ) | ||
| from diffsynth_engine.pipelines.utils import get_pipeline_class | ||
| from diffsynth_engine.utils import logging | ||
| from diffsynth_engine.utils.torch_profiler import TorchProfiler | ||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
@@ -59,6 +60,18 @@ def __init__( | |
| def __call__(self, **kwargs): | ||
| return self.pipeline(**kwargs) | ||
|
|
||
| @classmethod | ||
| def start_profile(cls, **kwargs): | ||
| path = kwargs.get("path", ".") | ||
| profile_rank0_only = kwargs.get("profile_rank0_only", True) | ||
| return TorchProfiler.start(path, profile_rank0_only=profile_rank0_only) | ||
|
|
||
| @classmethod | ||
| def stop_profile(cls, **kwargs): | ||
| result = TorchProfiler.stop() | ||
| get_world_group().barrier() | ||
| return result | ||
|
|
||
|
|
||
| def run_worker_loop( | ||
| local_rank: int, | ||
|
|
@@ -87,6 +100,7 @@ def run_worker_loop( | |
| world_group = get_world_group() | ||
|
|
||
| while True: | ||
| should_reply = rank == 0 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| try: | ||
| if rank == 0: | ||
| data = conn.recv() | ||
|
|
@@ -100,8 +114,11 @@ def run_worker_loop( | |
| if method == "shutdown": | ||
| break | ||
|
|
||
| output_rank = data.get("output_rank", 0) | ||
| should_reply = output_rank is None or output_rank == rank | ||
|
|
||
| output = getattr(worker, method)(**kwargs) | ||
| if rank == 0: | ||
| if should_reply: | ||
| conn.send( | ||
| { | ||
| "status": "success", | ||
|
|
@@ -111,7 +128,7 @@ def run_worker_loop( | |
| world_group.barrier() | ||
| except EOFError as e: | ||
| logger.error(f"Worker process {rank} connection closed: {e}", exc_info=True) | ||
| if rank == 0: | ||
| if should_reply: | ||
| conn.send( | ||
| { | ||
| "status": "error", | ||
|
|
@@ -121,7 +138,7 @@ def run_worker_loop( | |
| break | ||
| except Exception as e: | ||
| logger.error(f"Worker process {rank} error: {e}", exc_info=True) | ||
| if rank == 0: | ||
| if should_reply: | ||
| conn.send( | ||
| { | ||
| "status": "error", | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The background compression relies on the system
gziputility, which may not be available in all environments (e.g., Windows or minimal containers). Ifgzipis missing,subprocess.Popenwill raise aFileNotFoundError. While this is caught, the return values ofstartandstopwill still incorrectly indicate a.gzextension. Consider checking forgzipavailability usingshutil.whichand adjusting the returned file path accordingly, or use the built-ingzipmodule in a separate thread for better portability.