Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
0c80165
feat: Add simulation mode POC Phase 0
venkywonka Mar 31, 2026
cb48558
fix: rc7/rc10 compatibility patches for sim mode dev container
venkywonka Mar 31, 2026
fb317dc
fix: Add start_worker call and restore priority kwarg for sim mode e2e
venkywonka Mar 31, 2026
bf97adb
feat: Add InferTimePredictor ABC and ConstantPredictor
venkywonka Mar 31, 2026
f95e1b7
feat: Add SimConfig Pydantic model, replace simulation_mode bool
venkywonka Mar 31, 2026
21bb4a1
fix: Move SimConfig import out of TYPE_CHECKING for Pydantic validation
venkywonka Mar 31, 2026
5e944a6
feat: Wire ConstantPredictor into SimModelEngine.forward()
venkywonka Mar 31, 2026
707437a
test: Add unit tests for InferTimePredictor and ConstantPredictor
venkywonka Mar 31, 2026
1717f66
test: Add unit tests for SimConfig and PredictorConfig
venkywonka Mar 31, 2026
2216e3d
feat: Add SimBatchRequest and requests field to SimBatch
venkywonka Mar 31, 2026
20c5e65
feat: Extend PredictorConfig with aiconfigurator fields and validator
venkywonka Mar 31, 2026
74a20bc
feat: Add AIConfiguratorPredictor with H100 database tests
venkywonka Mar 31, 2026
fea757e
feat: Wire AIConfiguratorPredictor into sim executor with predictor f…
venkywonka Mar 31, 2026
09e2115
feat: Add SimClock to replace time.sleep in simulation mode (Phase 3)
venkywonka Apr 1, 2026
2f70180
docs: Add Deep-Sim vision doc and update v1 roadmap
venkywonka Apr 1, 2026
99763a8
feat: Force single-process sim mode with SimDistributed (Phase 3.5)
venkywonka Apr 1, 2026
07f1c88
docs: Add Phase 3/3.5 what-was-built doc
venkywonka Apr 1, 2026
e108139
docs: Add Phase 4 metrics output design spec
venkywonka Apr 2, 2026
9560d20
feat: Add per-request and per-iteration metrics to sim mode (Phase 4)
venkywonka Apr 2, 2026
21d0e29
docs: Add Phase 4 what-was-built doc and update CLAUDE.local.md
venkywonka Apr 2, 2026
4e6df3b
docs: Reorder roadmap — CLI (Phase 5) before arrival modeling (Phase 6)
venkywonka Apr 2, 2026
b6cb121
docs: Add Phase 5 CLI integration design spec
venkywonka Apr 2, 2026
2a76950
feat: Add trtllm-bench --sim CLI integration (Phase 5)
venkywonka Apr 2, 2026
ba68f49
docs: Add Phase 5 what-was-built doc and update CLAUDE.local.md
venkywonka Apr 2, 2026
7638faa
docs: Update roadmap (Phase 5 done) and add fidelity/limitations doc
venkywonka Apr 2, 2026
a747880
docs: Update roadmap with post-v1 findings and refined future phases
venkywonka Apr 2, 2026
84a9c20
chore: Remove slop/ from git tracking, add to .gitignore
venkywonka Apr 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__pycache__/
.vscode
.cursor
slop/
*.engine
*.engine.config
*.cache
Expand Down
5 changes: 4 additions & 1 deletion tensorrt_llm/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def _init(log_level: object = None) -> None:
"\nFATAL: Decoding operators failed to load. This may be caused by an incompatibility "
"between PyTorch and TensorRT-LLM. Please rebuild and install TensorRT-LLM."
)
raise ImportError(str(e) + msg)
if os.getenv("TRTLLM_ALLOW_MISSING_OPS", "0") != "0":
logger.warning(str(e) + msg)
else:
raise ImportError(str(e) + msg)

MpiComm.local_init()

Expand Down
171 changes: 171 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from .model_engine import PyTorchModelEngine
from .model_loader import ModelLoader, _construct_checkpoint_loader
from .py_executor import PyExecutor
from .sim_model_engine import SimModelEngine
from .sim_sampler import SimSampler


class _ExecutorMemoryMonitor:
Expand Down Expand Up @@ -224,6 +226,171 @@ def get_guided_decoding_config(guided_decoding_backend: str,
return guided_decoding_config


def _build_sim_predictor(sim_config, checkpoint_dir: str, mapping):
    """Instantiate the inference-time predictor selected by ``sim_config``.

    Predictor modules are imported lazily so the dependencies of one
    predictor (e.g. the aiconfigurator performance database) are not
    required when a different predictor is configured.

    Args:
        sim_config: Simulation config; ``sim_config.predictor`` selects and
            parameterizes the predictor.
        checkpoint_dir: Model checkpoint path, forwarded to predictors that
            need model metadata (aiconfigurator).
        mapping: Parallel mapping; only ``tp_size`` is consumed here.

    Returns:
        A predictor instance (ConstantPredictor or AIConfiguratorPredictor).

    Raises:
        ValueError: If ``sim_config.predictor.name`` is not a known predictor.
    """
    pc = sim_config.predictor
    if pc.name == "constant":
        from .sim_predictor import ConstantPredictor
        return ConstantPredictor(
            prefill_time_ms=pc.constant_prefill_time_ms,
            decode_time_ms=pc.constant_decode_time_ms)
    if pc.name == "aiconfigurator":
        from .sim_predictor_aic import AIConfiguratorPredictor
        return AIConfiguratorPredictor(
            model_path=checkpoint_dir,
            device_name=pc.device_name,
            backend_version=pc.backend_version,
            database_path=pc.database_path,
            tp_size=mapping.tp_size,
            prefill_scale_factor=pc.prefill_scale_factor,
            decode_scale_factor=pc.decode_scale_factor)
    raise ValueError(f"Unknown predictor name: {pc.name}")


def _create_sim_py_executor(
    llm_args: TorchLlmArgs,
    checkpoint_dir: str,
    checkpoint_loader,
) -> PyExecutor:
    """Create a PyExecutor in simulation mode.

    Loads only the HF model config (no weights), creates SimModelEngine
    and SimSampler, but uses the real KV cache manager and scheduler so
    that capacity/scheduling decisions match a real run.

    Args:
        llm_args: Torch LLM args; ``llm_args.sim_config`` must be set.
        checkpoint_dir: Path to the model checkpoint (config only is read).
        checkpoint_loader: Loader used to read the model config.

    Returns:
        A started PyExecutor wired with simulation components.
    """
    from ._util import KvCacheCreator, create_py_executor_instance
    from ..attention_backend.interface import AttentionRuntimeFeatures
    from .sim_distributed import SimDistributed

    # Sim mode always skips KV cache estimation — we don't need precise
    # sizing, and the estimation warmup triggers an executor shutdown/restart
    # cycle that is unnecessary overhead for simulation.
    skip_est = True

    # Use the mapping from config but force rank=0. In sim mode we run
    # single-process; TP/PP are config parameters, not distributed runtime.
    mapping = copy.deepcopy(llm_args.parallel_config.to_mapping())
    mapping.rank = 0
    dist = SimDistributed(mapping)

    # Load model config to get vocab_size and model-specific params.
    config_kwargs = {
        'trust_remote_code': True,
        'mm_encoder_only': llm_args.mm_encoder_only,
    }
    if llm_args.parallel_config:
        config_kwargs['mapping'] = llm_args.parallel_config.to_mapping()
    model_config = checkpoint_loader.load_config(checkpoint_dir,
                                                 **config_kwargs)
    vocab_size = model_config.pretrained_config.vocab_size

    (
        max_beam_width,
        max_num_tokens,
        max_seq_len,
        max_batch_size,
    ) = llm_args.get_runtime_sizes()

    # max_seq_len may be None if not set by user — derive from model config.
    if max_seq_len is None:
        max_seq_len = model_config.pretrained_config.max_position_embeddings

    max_num_sequences = max_batch_size * mapping.pp_size

    kv_cache_config = llm_args.kv_cache_config
    tokens_per_block = kv_cache_config.tokens_per_block

    # Sim engine and sampler — no model weights loaded.
    sim_config = llm_args.sim_config
    predictor = _build_sim_predictor(sim_config, checkpoint_dir, mapping)

    from .sim_clock import SimClock
    clock = SimClock()

    model_engine = SimModelEngine(llm_args, vocab_size, max_num_sequences,
                                  time_predictor=predictor, clock=clock)
    sampler = SimSampler(clock=clock)

    # We need a minimal model shim so KvCacheCreator can read model_config
    # to determine layer count, num_kv_heads, head_size, etc.
    class _SimModelShim:

        def __init__(self, config):
            self.model_config = config  # Already a ModelConfig

        def named_modules(self):
            # No real submodules exist in sim mode.
            return iter([])

    # Populate the engine attributes that KvCacheCreator and the executor
    # read off a real PyTorchModelEngine.
    model_engine.model = _SimModelShim(model_config)
    model_engine.max_seq_len = max_seq_len
    model_engine.max_num_tokens = max_num_tokens
    model_engine.batch_size = max_batch_size
    model_engine.max_beam_width = max_beam_width
    model_engine.mapping = mapping
    model_engine.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
    model_engine.attn_runtime_features = AttentionRuntimeFeatures(
        chunked_prefill=llm_args.enable_chunked_prefill,
        cache_reuse=kv_cache_config.enable_block_reuse,
    )

    # Real KV cache — scheduler needs it for capacity decisions.
    resources = {}
    execution_stream = torch.cuda.Stream()

    kv_cache_creator = KvCacheCreator(
        model_engine=model_engine,
        draft_model_engine=None,
        mapping=mapping,
        net_max_seq_len=max_seq_len,
        kv_connector_manager=None,
        max_num_tokens=max_num_tokens,
        max_beam_width=max_beam_width,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        kv_cache_config=kv_cache_config,
        llm_args=llm_args,
        speculative_config=None,
        profiling_stage_data=None,
        sparse_attention_config=None,
        execution_stream=execution_stream,
        draft_config=None,
        skip_est=skip_est,
    )
    estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
    kv_cache_creator.build_managers(resources, estimating_kv_cache)
    # The creator may clamp max_seq_len to what the KV cache can hold.
    max_seq_len = kv_cache_creator._max_seq_len

    scheduler_config = llm_args.scheduler_config

    ctx_chunk_config = None
    if llm_args.enable_chunked_prefill:
        ctx_chunk_config = ContextChunkingPolicy.FIRST_COME_FIRST_SERVED

    # start_worker=False so we can finish wiring before the worker thread
    # starts pulling requests; started explicitly below.
    py_executor = create_py_executor_instance(
        dist=dist,
        resources=resources,
        mapping=mapping,
        llm_args=llm_args,
        ctx_chunk_config=ctx_chunk_config,
        model_engine=model_engine,
        start_worker=False,
        sampler=sampler,
        drafter=None,
        guided_decoder=None,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        max_beam_width=max_beam_width,
        max_num_tokens=max_num_tokens,
        scheduler_config=scheduler_config,
        execution_stream=execution_stream,
    )

    # With skip_est=True this is not expected to trigger; warn if it does.
    if estimating_kv_cache:
        logger.warning("[SimMode] KV cache estimation requested but skipped "
                       "in sim mode")

    py_executor.start_worker()
    # Expose the clock on the config so callers can read simulated metrics.
    sim_config._clock = clock
    logger.info("[SimMode] PyExecutor created in simulation mode (clock enabled)")
    return py_executor


def create_py_executor(
llm_args: TorchLlmArgs,
checkpoint_dir: Optional[str] = None,
Expand Down Expand Up @@ -257,6 +424,10 @@ def create_py_executor(
llm_args = ModelLoader.load_config_and_apply_defaults(
checkpoint_dir, llm_args, checkpoint_loader)

if llm_args.sim_config is not None:
return _create_sim_py_executor(llm_args, checkpoint_dir,
checkpoint_loader)

garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
lora_config = llm_args.lora_config
kv_connector_config = llm_args.kv_connector_config
Expand Down
131 changes: 131 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/sim_clock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simulated clock for simulation mode."""

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .sim_metrics import SimIterationRecord, SimRequestStats


class SimClock:
    """Virtual clock for simulation mode.

    Accumulates simulated elapsed time by summing predicted iteration
    durations, and collects per-iteration and per-request records for
    later metrics output. Deliberately not a singleton — each
    SimModelEngine owns its own instance.
    """

    def __init__(self):
        # Simulated wall-clock seconds accumulated so far.
        self._total_time_s: float = 0.0
        # How many times step() has been called.
        self._num_iterations: int = 0
        # Per-iteration breakdown records (SimIterationRecord).
        self._iterations: list[SimIterationRecord] = []
        # request_id -> SimRequestStats for every registered request.
        self._request_stats: dict[int, SimRequestStats] = {}

    def step(self, duration_s: float) -> None:
        """Advance the simulated clock by one iteration's predicted duration."""
        self._num_iterations += 1
        self._total_time_s += duration_s

    def record_iteration(self, predicted_duration_s: float, num_ctx_req: int,
                         num_ctx_tokens: int, num_gen_req: int) -> None:
        """Append a per-iteration record reflecting the state after a step."""
        # Lazy import keeps sim_metrics out of the module import graph.
        from .sim_metrics import SimIterationRecord
        record = SimIterationRecord(
            iteration=self._num_iterations,
            sim_time_s=self._total_time_s,
            predicted_duration_s=predicted_duration_s,
            num_context_requests=num_ctx_req,
            num_context_tokens=num_ctx_tokens,
            num_generation_requests=num_gen_req,
        )
        self._iterations.append(record)

    def register_request(self, request_id: int, input_length: int,
                         created_time: float = 0.0) -> None:
        """Start tracking a request; re-registering an id is a no-op."""
        from .sim_metrics import SimRequestStats
        if request_id in self._request_stats:
            return
        self._request_stats[request_id] = SimRequestStats(
            request_id=request_id,
            input_length=input_length,
            created_time=created_time,
        )

    def record_token(self, request_id: int) -> None:
        """Stamp one generated token at the current simulated time.

        The request must have been registered first (KeyError otherwise).
        """
        entry = self._request_stats[request_id]
        entry.output_length += 1
        entry.gen_token_times.append(self._total_time_s)

    @property
    def total_time_s(self) -> float:
        """Total simulated seconds elapsed so far."""
        return self._total_time_s

    @property
    def num_iterations(self) -> int:
        """Number of iterations stepped so far."""
        return self._num_iterations

    @property
    def iterations(self) -> list:
        """Recorded SimIterationRecord objects, in step order."""
        return self._iterations

    @property
    def request_stats(self) -> dict:
        """Mapping of request_id to its SimRequestStats."""
        return self._request_stats

    @property
    def metrics(self) -> dict:
        """HiSim-compatible summary metrics derived from recorded data."""
        from .sim_metrics import calc_sim_metrics
        return calc_sim_metrics(self._request_stats, self._iterations)

    def write_metrics(self, output_dir: str) -> None:
        """Dump metrics.json, request.jsonl and iteration.jsonl into output_dir."""
        os.makedirs(output_dir, exist_ok=True)

        with open(os.path.join(output_dir, "metrics.json"), "w") as f:
            json.dump(self.metrics, f, indent=2)

        request_lines = []
        for stats in self._request_stats.values():
            row = {
                "request_id": stats.request_id,
                "input_length": stats.input_length,
                "output_length": stats.output_length,
                "created_time": stats.created_time,
                "gen_token_times": stats.gen_token_times,
                "ttft_ms": stats.ttft_s * 1000,
                "tpot_ms": stats.tpot_s * 1000,
                "e2e_ms": stats.e2e_s * 1000,
            }
            request_lines.append(json.dumps(row) + "\n")
        with open(os.path.join(output_dir, "request.jsonl"), "w") as f:
            f.writelines(request_lines)

        iteration_lines = []
        for rec in self._iterations:
            row = {
                "iteration": rec.iteration,
                "sim_time_s": rec.sim_time_s,
                "predicted_duration_s": rec.predicted_duration_s,
                "num_context_requests": rec.num_context_requests,
                "num_context_tokens": rec.num_context_tokens,
                "num_generation_requests": rec.num_generation_requests,
            }
            iteration_lines.append(json.dumps(row) + "\n")
        with open(os.path.join(output_dir, "iteration.jsonl"), "w") as f:
            f.writelines(iteration_lines)

    def reset(self) -> None:
        """Discard all accumulated time and records."""
        self._total_time_s = 0.0
        self._num_iterations = 0
        self._iterations = []
        self._request_stats = {}
Loading
Loading