From c0e738b1178f1c7f8e231d49c8b5f49b6d99a10a Mon Sep 17 00:00:00 2001 From: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:34:27 -0700 Subject: [PATCH 1/5] paperclip infra changes (extracted from main paperclip branch) Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> --- .claude/agents/ad-onboard-reviewer.md | 9 + AGENTS.md | 8 + examples/auto_deploy/build_and_run_ad.py | 229 +++++++++++------- .../_torch/auto_deploy/config/default.yaml | 4 +- .../custom_ops/attention/trtllm_attention.py | 3 +- .../custom_ops/attention_interface.py | 57 +++-- .../custom_ops/utils/torch_gather_logits.py | 14 +- tensorrt_llm/_torch/auto_deploy/llm.py | 19 +- tensorrt_llm/_torch/auto_deploy/llm_args.py | 11 +- .../_torch/auto_deploy/shim/interface.py | 5 +- .../transform/library/_onnx_schemas.py | 16 +- .../smoke/test_ad_build_small_multi.py | 14 +- .../custom_ops/test_resource_handlers.py | 14 +- .../shim/test_cached_sequence_interface.py | 55 ++++- .../auto_deploy/singlegpu/shim/test_engine.py | 10 + .../singlegpu/shim/test_llm_config.py | 22 +- .../smoke/test_ad_guided_decoding_regex.py | 2 +- .../singlegpu/smoke/test_ad_trtllm_sampler.py | 2 +- .../test_fuse_trtllm_attention_quant_fp8.py | 2 + .../library/test_gated_delta_rule_cache.py | 1 + .../transformations/library/test_kv_cache.py | 5 + .../library/test_mrope_delta_cache.py | 1 + .../test_torch_gated_delta_rule_cache.py | 1 + 23 files changed, 329 insertions(+), 175 deletions(-) diff --git a/.claude/agents/ad-onboard-reviewer.md b/.claude/agents/ad-onboard-reviewer.md index feedc4aeef4..660379fbbb7 100644 --- a/.claude/agents/ad-onboard-reviewer.md +++ b/.claude/agents/ad-onboard-reviewer.md @@ -44,6 +44,15 @@ Read the actual source code for each check. Cite `file:line_number` for every PA Note: BB1–BB2 only apply if the HF source indicates the model is multi-modal (has image/audio inputs). Mark N/A with justification for pure language models. +### BB. Vision / Multi-Modal Support + +| # | Check | How to verify | +|---|-------|---------------| +| BB1 | If the model has a vision tower (multi-modal), the full `nn.Module` hierarchy for the vision component is present in the modeling file — it is NOT omitted, stubbed out, or replaced with a `pass` body | Grep for vision-related class names (e.g., `VisionTower`, `ViT`, `CLIPVision`, `SiglipVision`) from the HF source. If the model is multi-modal and none appear, flag as FAIL. | +| BB2 | The test file asserts that vision-related weight keys are present in the model's `state_dict` after `load_state_dict` | Grep the test file for assertions on vision weight key names (or a check that vision-prefixed keys are in the loaded state_dict). Absence of any such assertion is a FAIL for multi-modal models. | + +Note: BB1–BB2 only apply if the HF source indicates the model is multi-modal (has image/audio inputs). Mark N/A with justification for pure language models. + ### C. Ops & Compatibility (STRICT — canonical ops are the backbone of AD) | # | Check | How to verify | diff --git a/AGENTS.md b/AGENTS.md index 39ea7e4fb6f..fd8938edaab 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -123,6 +123,14 @@ HuggingFace Model → LLM API → Executor (PyTorch/AutoDeploy/TensorRT) - Target `main` unless fixing a release branch bug - See `CONTRIBUTING.md` for full PR policies +### GitHub CLI authentication (`GH_CONFIG_DIR`) + +The `gh` CLI uses `~/.config/gh` by default for authentication. Different GitHub hosts or forks may require a different config directory. **Before running any `gh` command** (e.g., `gh pr create`, `gh api`, `gh pr comment`): + +1. Check if the user has specified a custom `GH_CONFIG_DIR` (e.g., in `CLAUDE.local.md` or environment). If so, use it. +2. If not explicitly set, **ask the user** whether the default `~/.config/gh` is correct or if a different directory should be used. This is especially relevant when the PR target is a fork (e.g., `nv-auto-deploy/TensorRT-LLM`) rather than `NVIDIA/TensorRT-LLM`. +3. Prefix all `gh` commands with the resolved config dir: `GH_CONFIG_DIR= gh ...` + ## CI / Testing See [CI overview](docs/source/developer-guide/ci-overview.md) for full details. diff --git a/examples/auto_deploy/build_and_run_ad.py b/examples/auto_deploy/build_and_run_ad.py index 9a6088df386..34eaf461679 100644 --- a/examples/auto_deploy/build_and_run_ad.py +++ b/examples/auto_deploy/build_and_run_ad.py @@ -1,5 +1,8 @@ """Main entrypoint to build, test, and prompt AutoDeploy inference models.""" +import json +import sys +from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Union import torch @@ -20,16 +23,21 @@ DynamicYamlMixInForSettings, deep_merge_dicts, ) -from tensorrt_llm._torch.auto_deploy.utils.benchmark import benchmark, store_benchmark_results from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.sampling_params import SamplingParams +# Registry paths +_REGISTRY_DIR = Path(__file__).resolve().parent / "model_registry" +_REGISTRY_YAML = _REGISTRY_DIR / "models.yaml" +_REGISTRY_CONFIGS_DIR = _REGISTRY_DIR / "configs" + # Global torch config, set the torch compile cache to fix up to llama 405B torch._dynamo.config.cache_size_limit = 20 -# simple string, TRT-LLM style text-only prompt or full-scale HF message template -PromptInput = Union[str, Dict, List[Dict]] + +# A single query is either a plain string or a full HF chat message template. +PromptInput = Union[str, List[Dict]] class PromptConfig(BaseModel): @@ -37,30 +45,28 @@ class PromptConfig(BaseModel): This configuration class can be used for this example script to configure the example prompts and the sampling parameters. + + Queries can be plain strings or HF-style chat message lists + (``[{"role": "user", "content": "..."}]``). Plain-string queries are automatically wrapped in + a chat template when the model's tokenizer supports one. """ - batch_size: int = Field(default=2, description="Number of queries") + batch_size: int = Field(default=10, description="Number of queries") queries: Union[PromptInput, List[PromptInput]] = Field( default_factory=lambda: [ - # OPTION 1: simple text prompt "How big is the universe? ", - # OPTION 2: wrapped text prompt for TRT-LLM - {"prompt": "In simple words and a single sentence, explain the concept of gravity: "}, - # OPTION 3: a full-scale HF message template (this one works for text-only models!) - # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating - # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal - [ - { - "role": "user", - "content": "How to fix slicing in golf?", - } - ], - # More prompts... - {"prompt": "Where is the capital of Iceland? "}, + "In simple words and a single sentence, explain the concept of gravity: ", + "How to fix slicing in golf? ", + "Where is the capital of Iceland? ", + "What are the three laws of thermodynamics? ", + "Summarize the plot of Romeo and Juliet in two sentences: ", + "Write a Python function that checks if a number is prime.", + "Explain the difference between a compiler and an interpreter: ", + "What causes the northern lights? ", + "What are the health benefits of drinking green tea?", ], - description="Example queries to prompt the model with. We support both TRT-LLM text-only " - "queries via the 'prompt' key and full-scale HF message template called via " - "apply_chat_template.", + description="Plain-text queries or HF-style chat message lists. Plain strings are " + "automatically wrapped as chat messages when the model's tokenizer has a chat template.", ) sp_kwargs: Dict[str, Any] = Field( default_factory=lambda: {"max_tokens": 100, "top_k": None, "temperature": 1.0}, @@ -69,33 +75,19 @@ class PromptConfig(BaseModel): ) def model_post_init(self, __context: Any): - """Cut queries to batch_size. + """Repeat and truncate queries to match batch_size. NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field validators are only run if a value is provided. """ - queries = self.queries if isinstance(self.queries, list) else [self.queries] - batch_size = self.batch_size - queries = queries * (batch_size // len(queries) + 1) - queries = queries[:batch_size] - - # now let's standardize the queries for the LLM api to understand them - queries_processed = [] - for query in queries: - if isinstance(query, str): - queries_processed.append({"prompt": query}) - elif isinstance(query, dict): - queries_processed.append(query) - elif isinstance(query, list): - queries_processed.append( - { - "prompt": "Fake prompt. Check out messages field for the HF chat template.", - "messages": query, # contains the actual HF chat template - } - ) - else: - raise ValueError(f"Invalid query type: {type(query)}") - self.queries = queries_processed + queries = self.queries + if isinstance(queries, str): + queries = [queries] + elif isinstance(queries, list) and queries and isinstance(queries[0], dict): + # single HF message template, e.g. [{"role": "user", "content": "..."}] + queries = [queries] + queries = queries * (self.batch_size // len(queries) + 1) + self.queries = queries[: self.batch_size] @field_validator("sp_kwargs", mode="after") @classmethod @@ -106,21 +98,10 @@ def validate_sp_kwargs(cls, sp_kwargs): class BenchmarkConfig(BaseModel): - """Benchmark configuration. + """Configuration for storing results.""" - This configuration class can be used for this example script to configure the simple - benchmarking we run at the end of the script. - """ - - enabled: bool = Field(default=False, description="If true, run simple benchmark") - num: int = Field(default=10, ge=1, description="By default run 10 times and get average") - isl: int = Field(default=2048, ge=1, description="Input seq length for benchmarking") - osl: int = Field(default=128, ge=1, description="Output seq length for benchmarking") - bs: int = Field(default=1, ge=1, description="Batch size for benchmarking") results_path: Optional[str] = Field(default="./benchmark_results.json") - store_results: bool = Field( - default=False, description="If True, store benchmark res in benchmark_results_path" - ) + store_results: bool = Field(default=False, description="If True, store results to results_path") class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings): @@ -225,17 +206,30 @@ def sync_prompt_batch_size_with_args_max_batch_size(cls, prompt: PromptConfig, i args.max_batch_size = prompt.batch_size return prompt - @field_validator("benchmark", mode="after") - @classmethod - def adjust_args_for_benchmark(cls, benchmark: BenchmarkConfig, info): - if "args" not in info.data: - return benchmark - args: LlmArgs = info.data["args"] - if benchmark.enabled: - # propagate benchmark settings to args - args.max_batch_size = max(benchmark.bs, args.max_batch_size) - args.max_seq_len = max(args.max_seq_len, benchmark.isl + benchmark.osl) - return benchmark + +def get_registry_yaml_extra(model_name: str) -> List[str]: + """Look up a model in the registry and return its resolved yaml_extra config paths. + + Args: + model_name: HuggingFace model id as listed in the registry (e.g. ``meta-llama/Llama-3.1-8B-Instruct``). + + Returns: + List of absolute paths to the yaml config files for the model. + + Raises: + KeyError: If the model is not found in the registry. + """ + with open(_REGISTRY_YAML) as f: + registry = yaml.safe_load(f) + + for entry in registry.get("models", []): + if entry["name"] == model_name: + return [str(_REGISTRY_CONFIGS_DIR / cfg) for cfg in entry.get("yaml_extra", [])] + + raise KeyError( + f"Model '{model_name}' not found in the AutoDeploy model registry ({_REGISTRY_YAML}). " + "Either add it to the registry or provide --yaml-extra directly." + ) def build_llm_from_config(config: ExperimentConfig) -> LLM: @@ -249,6 +243,36 @@ def build_llm_from_config(config: ExperimentConfig) -> LLM: return llm +def prepare_queries(queries: List[PromptInput], tokenizer=None) -> List[Dict]: + """Prepare queries for the LLM API. + + Queries that are already HF-style message lists (``List[Dict]``) are passed through directly. + Plain-string queries are wrapped as HF chat messages when the tokenizer has a chat template, + or passed as plain text prompts otherwise. + """ + has_chat_template = getattr(tokenizer, "chat_template", None) is not None + + prepared = [] + for query in queries: + if isinstance(query, list): + prepared.append( + { + "prompt": query[0].get("content", "") if query else "", + "messages": query, + } + ) + elif has_chat_template: + prepared.append( + { + "prompt": query, + "messages": [{"role": "user", "content": query}], + } + ) + else: + prepared.append({"prompt": query}) + return prepared + + def print_outputs(outs: Union[RequestOutput, List[RequestOutput]]) -> List[List[str]]: prompts_and_outputs: List[List[str]] = [] if isinstance(outs, RequestOutput): @@ -260,8 +284,45 @@ def print_outputs(outs: Union[RequestOutput, List[RequestOutput]]) -> List[List[ return prompts_and_outputs +def _inject_registry_yaml_extra() -> None: + """If ``--use-registry`` is in sys.argv, replace it with the resolved ``--yaml-extra`` entries. + + This allows callers to simply run:: + + python build_and_run_ad.py --model --use-registry + + instead of manually specifying every ``--yaml-extra`` path. The flag is consumed here and the + resolved paths are injected back into ``sys.argv`` before pydantic-settings parses them. + """ + if "--use-registry" not in sys.argv: + return + + # Extract model name from argv (support both --model=X and --model X forms) + model_name: Optional[str] = None + for i, arg in enumerate(sys.argv): + if arg.startswith("--model="): + model_name = arg.split("=", 1)[1] + break + if arg == "--model" and i + 1 < len(sys.argv): + model_name = sys.argv[i + 1] + break + + if model_name is None: + raise ValueError("--use-registry requires --model to be specified.") + + yaml_extra_paths = get_registry_yaml_extra(model_name) + + # Remove --use-registry and inject --yaml-extra --yaml-extra ... + # Each path needs its own flag because pydantic-settings CLI only captures one value per flag. + argv = [a for a in sys.argv if a != "--use-registry"] + for path in yaml_extra_paths: + argv += ["--args.yaml-extra", path] + sys.argv = argv + + def main(config: Optional[ExperimentConfig] = None): if config is None: + _inject_registry_yaml_extra() config: ExperimentConfig = CliApp.run(ExperimentConfig) ad_logger.info(f"AutoDeploy Experiment Config:\n{yaml.dump(config.model_dump())}") @@ -272,8 +333,10 @@ def main(config: Optional[ExperimentConfig] = None): # prompt the model and print its output ad_logger.info("Running example prompts...") + hf_tokenizer = getattr(llm.tokenizer, "tokenizer", None) + queries = prepare_queries(config.prompt.queries, hf_tokenizer) outs = llm.generate( - config.prompt.queries, + queries, sampling_params=SamplingParams(**config.prompt.sp_kwargs), ) results = { @@ -282,31 +345,11 @@ def main(config: Optional[ExperimentConfig] = None): # Add config values so they get logged to JET extra results.update(config.model_dump(mode="json")) - # run a benchmark for the model with batch_size == config.benchmark_bs - if config.benchmark.enabled and config.args.runtime != "trtllm": - ad_logger.info("Running benchmark...") - keys_from_args = [] - fields_to_show = [f"benchmark={config.benchmark}"] - fields_to_show.extend([f"{k}={getattr(config.args, k)}" for k in keys_from_args]) - results["benchmark_results"] = benchmark( - func=lambda: llm.generate( - torch.randint(0, 100, (config.benchmark.bs, config.benchmark.isl)).tolist(), - sampling_params=SamplingParams( - max_tokens=config.benchmark.osl, - top_k=None, - ignore_eos=True, - ), - use_tqdm=False, - ), - num_runs=config.benchmark.num, - log_prefix="Benchmark with " + ", ".join(fields_to_show), - results_path=config.benchmark.results_path, - ) - elif config.benchmark.enabled: - ad_logger.info("Skipping simple benchmarking for trtllm...") - if config.benchmark.store_results: - store_benchmark_results(results, config.benchmark.results_path) + results_path = Path(config.benchmark.results_path) + results_path.parent.mkdir(parents=True, exist_ok=True) + with results_path.open("w") as f: + json.dump(results, f, indent=2) llm.shutdown() return results diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index dbc2dc8b4e8..0be815bd1ee 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -228,7 +228,7 @@ transforms: ########################################################################################### insert_cached_attention: stage: cache_init - backend: flashinfer + backend: trtllm insert_cached_mla_attention: stage: cache_init requires_shape_prop: true @@ -280,6 +280,6 @@ transforms: expect_mem_change: true run_per_gm: false cuda_graph_batch_sizes: null - backend: torch-compile + backend: torch-cudagraph piecewise_enabled: false piecewise_num_tokens: null diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py index d0c8a1dd8c0..437eb8b0de1 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py @@ -25,6 +25,7 @@ - All possible "constants" inferred from tensor shapes at runtime """ +import math from typing import List, Optional, Tuple import torch @@ -462,7 +463,7 @@ def trtllm_mha_with_cache( 1, # beam_width int(AttentionMaskType.causal), # mask_type quant_mode, # quant_mode - 1.0, # q_scaling + scale * math.sqrt(head_dim) if scale is not None else 1.0, # q_scaling 0, # position_embedding_type 0, # rotary_embedding_dim 10000.0, # rotary_embedding_base diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 748ce06ba22..c6ca285f6a5 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -403,16 +403,18 @@ def __init__(self, batch_info_host: Optional[torch.Tensor] = None): BatchInfo._NUM_ELEMENTS, dtype=torch.int, pin_memory=prefer_pinned() ) self._batch_info_host = batch_info_host - self._batch_info_np = batch_info_host.numpy() # same storage! + # Use the tensor view directly so fake tensors can flow through + # torch.compile metadata tracing without requiring a real .numpy() view. + self._batch_info = batch_info_host def serialize(self) -> torch.Tensor: return self._batch_info_host def update(self, batch_info: List[int]) -> None: - self._batch_info_np[:6] = batch_info + self._batch_info[:6] = torch.as_tensor(batch_info, dtype=self._batch_info.dtype) def is_generate_only(self) -> bool: - return self._batch_info_np[:4].sum().item() == 0 + return self._batch_info[:4].sum().item() == 0 def get_total_num_sequences(self) -> int: return sum(self.get_num_sequences()) @@ -429,14 +431,14 @@ def get_absorbed_info(self) -> Tuple[int, int, int]: def get_num_sequences(self) -> Tuple[int, int, int]: """Get the number of prefill, extend, and decode sequences.""" - num_prefill, num_extend, num_decode = self._batch_info_np[:6:2].tolist() + num_prefill, num_extend, num_decode = self._batch_info[:6:2].tolist() return num_prefill, num_extend, num_decode def get_total_num_tokens(self) -> int: return sum(self.get_num_tokens()) def get_num_tokens(self) -> Tuple[int, int, int]: - prefill_tokens, extend_tokens, decode_tokens = self._batch_info_np[1:6:2].tolist() + prefill_tokens, extend_tokens, decode_tokens = self._batch_info[1:6:2].tolist() return prefill_tokens, extend_tokens, decode_tokens # --- max sequence info (slots 6-9) writers --- @@ -448,42 +450,48 @@ def update_max_seq_info( block_offset_multiplier: int, max_batch_size: int, ) -> None: - self._batch_info_np[6:10] = [ - max_context_length, - max_blocks_per_seq, - block_offset_multiplier, - max_batch_size, - ] + self._batch_info[6:10] = torch.tensor( + [ + max_context_length, + max_blocks_per_seq, + block_offset_multiplier, + max_batch_size, + ], + dtype=self._batch_info.dtype, + ) # --- max sequence info (slots 6-9) readers --- def get_max_seq_info(self) -> Tuple[int, int, int, int]: - return tuple(self._batch_info_np[6:10].tolist()) + return tuple(self._batch_info[6:10].tolist()) def get_max_context_length(self) -> int: - return int(self._batch_info_np[6]) + return int(self._batch_info[6]) def get_max_blocks_per_seq(self) -> int: - return int(self._batch_info_np[7]) + return int(self._batch_info[7]) def get_block_offset_multiplier(self) -> int: - return int(self._batch_info_np[8]) + return int(self._batch_info[8]) def get_max_batch_size(self) -> int: - return int(self._batch_info_np[9]) + return int(self._batch_info[9]) # --- tokens gather info (slots 10-11) writers --- def update_tokens_gather_info(self, num_tokens_to_gather: int, gather_required: bool) -> None: - self._batch_info_np[10:12] = [num_tokens_to_gather, int(gather_required)] + self._batch_info[10:12] = torch.tensor( + [num_tokens_to_gather, int(gather_required)], + dtype=self._batch_info.dtype, + ) # --- tokens gather info (slots 10-11) readers --- def get_num_tokens_to_gather(self) -> int: - return int(self._batch_info_np[10]) + return int(self._batch_info[10]) def is_gather_required(self) -> bool: - return bool(self._batch_info_np[11]) + return bool(self._batch_info[11]) class SequenceInfo: @@ -568,8 +576,8 @@ def __init__( self, max_seq_len: int, max_batch_size: int, + max_num_tokens: int, tokens_per_block: Optional[int] = None, - max_num_tokens: Optional[int] = None, vocab_size_padded: Optional[int] = None, ): """Initialize the SequenceInfo object. @@ -579,13 +587,13 @@ def __init__( includes the tokens in the input sequence and the tokens generated by the model. max_batch_size: corresponds to the maximum number of sequences (or requests) that the model can process. - tokens_per_block: corresponds to the tokens per block of the cache. max_num_tokens: corresponds to the maximum number of tokens that the model can process across all sequences in the batch. If a batch is composed of context-only requests of input sequence length ISL, then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens // ISL). Similarly, if a batch is composed of generate-only requests, then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens). + tokens_per_block: corresponds to the tokens per block of the cache. vocab_size_padded: corresponds to the padded vocabulary size of the model. Returns: None @@ -596,10 +604,7 @@ def __init__( self.max_batch_size = max(2, max_batch_size) self.tokens_per_block = tokens_per_block or max_seq_len self.max_blocks_per_seq = math.ceil(max_seq_len / self.tokens_per_block) - # NOTE (lucaslie): +1 is a WAR to address issue when using flashinfer attention with - # (max_batch_size, max_seq_len) input in trtllm runtime. - # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 - self.max_num_tokens = max_num_tokens or (max_seq_len + 1) * max_batch_size + self.max_num_tokens = max_num_tokens # will store num_blocks later... self._num_blocks = None @@ -1248,7 +1253,7 @@ def maybe_gather_and_squeeze(self, token_tnsr: torch.Tensor) -> torch.Tensor: self.get_arg("token_gather_indices"), self.batch_info.serialize(), ) - return self.flatten(token_tnsr) + return token_tnsr.reshape(token_tnsr.shape[0] * token_tnsr.shape[1], *token_tnsr.shape[2:]) @nvtx_range("ad_unnest_sequences") def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py index 055f8e5f9cf..160d7e6fb2f 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py @@ -27,16 +27,14 @@ def gather_tokens( """Gather hidden states using token_gather_indices before LM head. Args: - hidden_states: Hidden states tensor of shape [batch_size, 1, *other_dims] or - [1, total_token_length, *other_dims] + hidden_states: Hidden states tensor of shape [batch_size, 1, *other_dims], + [1, total_token_length, *other_dims], or [batch_size, seq_len, *other_dims]. token_gather_indices: indices for gathering logits. batch_info_host: BatchInfo tensor containing tokens_gather_info. Returns: Gathered and flattened hidden states [num_gathered_tokens, hidden] """ - # final shape is [total_tokens, *other_dims] bsz, sl, *other_dims = hidden_states.shape - assert bsz == 1 or sl == 1, "expected batch size or sequence length to be 1" hidden_states = hidden_states.view(bsz * sl, *other_dims) batch_info = BatchInfo(batch_info_host) @@ -49,10 +47,12 @@ def gather_tokens( else: out = hidden_states.clone(memory_format=torch.contiguous_format) num_tokens_final = bsz * sl - if bsz == 1: - return out.view(1, num_tokens_final, *other_dims) - else: + # Generate-only batches use [batch, 1, ...] and need to preserve batch-major layout for the + # downstream squeeze. Any shape with seq_len > 1 is treated as a flattened token batch. + if sl == 1 and bsz > 1: return out.view(num_tokens_final, 1, *other_dims) + else: + return out.view(1, num_tokens_final, *other_dims) @gather_tokens.register_fake diff --git a/tensorrt_llm/_torch/auto_deploy/llm.py b/tensorrt_llm/_torch/auto_deploy/llm.py index 905efe1135f..9d342c2424b 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm.py +++ b/tensorrt_llm/_torch/auto_deploy/llm.py @@ -46,17 +46,27 @@ def __call__( # multi_modal_data should not be present in the messages field assert "multi_modal_data" not in inputs, f"unexpected multi_modal_data key in {inputs=}" + # Normalize message content to list-of-dicts format only for multimodal + # processors (e.g., Llama4) that expect {"type": "text", "text": "..."} + # instead of plain strings when tokenize=True. Text-only models' Jinja2 + # chat templates expect content as a plain string. + messages = inputs["messages"] + if hasattr(self.processor, "image_processor"): + for msg in messages: + if isinstance(msg.get("content"), str): + msg["content"] = [{"type": "text", "text": msg["content"]}] + # TODO: we don't really need this but it makes for a good sanity check. Consider # removing this in the future if we need to speed things up. prompt = self.processor.apply_chat_template( - inputs["messages"], + messages, add_generation_prompt=True, tokenize=False, ) inputs["prompt"] = prompt all_args = self.processor.apply_chat_template( - inputs["messages"], + messages, add_generation_prompt=True, tokenize=True, return_dict=True, @@ -87,6 +97,11 @@ def __call__( if all_args is not None: # TODO: is there a more reliable way to avoid the attention_mask here? all_args.pop("attention_mask", None) + # token_type_ids is produced by some tokenizers (e.g. Hunyuan) but is not + # a graph input for decoder-only causal LMs; drop it here so it does not get + # forwarded as an extra_arg to the exported model, which would cause a kwarg + # keyword mismatch at inference time. + all_args.pop("token_type_ids", None) # TODO: can we avoid the extra tolist() here eventually? token_ids = all_args.pop("input_ids") diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 99e97aaf9d0..8e9ae60ad22 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -260,11 +260,11 @@ def validate_and_init_tokenizer(self): ### SHORTCUTS FOR COMMON INFERENCE OPTIMIZER CONFIGS ########################################### attn_backend: str = Field( - default="flashinfer", + default="trtllm", description=_shortcut_description("Attention backend to use.", "attn_backend"), ) compile_backend: str = Field( - default="torch-compile", + default="torch-cudagraph", description=_shortcut_description( "The backend to use for compiling the model.", "compile_backend" ), @@ -280,8 +280,8 @@ def validate_and_init_tokenizer(self): ) ### SEQUENCE INTERFACE CONFIG ################################################################## - max_seq_len: int = Field(default=512, ge=1, description="The maximum sequence length.") - max_batch_size: int = Field(default=8, ge=1, description="The maximum batch size.") + max_seq_len: int = Field(default=2048, ge=1, description="The maximum sequence length.") + max_batch_size: int = Field(default=64, ge=1, description="The maximum batch size.") def model_dump(self, *args, **kwargs): """Convert the arguments to a dictionary that can be used as kwargs for the LLM API.""" @@ -338,8 +338,7 @@ def update_cuda_graph_batch_sizes(self): # if not set, use heuristic if self.cuda_graph_batch_sizes is None: cg_bs = {1, self.max_batch_size} - # Only add batch sizes up to max_batch_size - cg_bs.update(range(1, min(128, self.max_batch_size) + 1, 16)) + cg_bs.update(range(16, min(128, self.max_batch_size) + 1, 16)) cg_bs.update(range(128, self.max_batch_size + 1, 128)) else: cg_bs = [b for b in self.cuda_graph_batch_sizes if b <= self.max_batch_size] diff --git a/tensorrt_llm/_torch/auto_deploy/shim/interface.py b/tensorrt_llm/_torch/auto_deploy/shim/interface.py index 8c60753aaab..61bd5d9456c 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/interface.py @@ -57,9 +57,9 @@ def __init__( self, max_seq_len: int, max_batch_size: int, + max_num_tokens: int, device: Optional[DeviceLikeType] = None, kv_cache_config: Optional[KvCacheConfig] = None, - max_num_tokens: Optional[int] = None, vocab_size_padded: Optional[int] = None, spec_config=None, ) -> None: @@ -68,10 +68,9 @@ def __init__( Args: max_seq_len: Maximum sequence length including input and generated tokens. max_batch_size: Maximum number of sequences (requests) that can be processed. + max_num_tokens: Maximum total tokens across all sequences. device: Target device for tensors. Defaults to "cuda". kv_cache_config: KV cache configuration. If None, uses default KvCacheConfig. - max_num_tokens: Maximum total tokens across all sequences. If None, computed from - max_seq_len and max_batch_size. vocab_size_padded: Padded vocabulary size of the model. spec_config: Speculative decoding configuration. Used to set num_extra_kv_tokens, max_draft_len, max_total_draft_tokens on KVCacheManager after creation. diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py b/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py index 0a13f780f81..2efa03b5523 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py @@ -253,6 +253,16 @@ def register_onnx_schemas(): """Register ONNX custom ops.""" - defs.register_schema(_torch_rope_with_explicit_cos_sin_schema) - defs.register_schema(_torch_attention_schema) - defs.register_schema(_attention_plugin_schema) + registered = { + (schema.name, schema.domain, schema.since_version) + for schema in defs.get_all_schemas_with_history() + } + + for schema in ( + _torch_rope_with_explicit_cos_sin_schema, + _torch_attention_schema, + _attention_plugin_schema, + ): + key = (schema.name, schema.domain, schema.since_version) + if key not in registered: + defs.register_schema(schema) diff --git a/tests/unittest/auto_deploy/multigpu/smoke/test_ad_build_small_multi.py b/tests/unittest/auto_deploy/multigpu/smoke/test_ad_build_small_multi.py index bc2769d617c..b4d589df787 100644 --- a/tests/unittest/auto_deploy/multigpu/smoke/test_ad_build_small_multi.py +++ b/tests/unittest/auto_deploy/multigpu/smoke/test_ad_build_small_multi.py @@ -20,18 +20,14 @@ }, }, ), - ( - "meta-llama/Meta-Llama-3.1-8B-Instruct", - { - "transforms": { - "transformers_replace_cached_attn": {"backend": "flashinfer"}, - }, - "mode": "transformers", - }, - ), ], ) def test_build_ad(world_size: int, model_hub_id: str, llm_extra_args: dict): + # TODO: Revisit transformers-mode multigpu smoke coverage for Llama 3.1 specifically. + # This test keeps only the maintained graph-model AutoDeploy path for + # meta-llama/Meta-Llama-3.1-8B-Instruct. In this branch, the Llama path is served by + # branch-owned AD custom modeling rather than the stock HF modeling code, so the deprecated + # mode="transformers" path is no longer the right coverage target for this test. experiment_config = get_small_model_config(model_hub_id, **llm_extra_args) experiment_config["args"]["world_size"] = world_size diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py index eff235a7ad0..eea3963e981 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py @@ -51,7 +51,9 @@ def test_paged_handler_allocate_with_blocks(kv_layout): """Verify KVPagedResourceHandler.allocate() returns correct shape.""" handler = KVPagedResourceHandler(8, 64, dtype=torch.float16, kv_layout=kv_layout) tokens_per_block = 32 - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=tokens_per_block) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=tokens_per_block + ) seq_info.to("cuda") # Set up num_blocks via update_cache_information seq_info.update_cache_information(num_blocks=10) @@ -122,7 +124,7 @@ def test_state_handler_ssm_state_shape(): def test_state_handler_allocate_creates_tensor(): """Verify StateResourceHandler.allocate() creates tensor with correct shape.""" handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16) - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4) + seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4) seq_info.to("cuda") tensor = handler.allocate(seq_info) @@ -173,7 +175,11 @@ def test_unpaged_handler_allocate_returns_correct_shape(num_kv_heads, head_dim, max_seq_len = 128 handler = UnpagedResourceHandler(num_kv_heads, head_dim, dtype=dtype) - seq_info = SequenceInfo(max_seq_len=max_seq_len, max_batch_size=max_batch_size) + seq_info = SequenceInfo( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, + ) seq_info.to("cuda") tensor = handler.allocate(seq_info) @@ -187,7 +193,7 @@ def test_unpaged_handler_allocate_returns_correct_shape(num_kv_heads, head_dim, def test_unpaged_handler_allocate_correct_device(): """Verify UnpagedResourceHandler allocated tensor is on the correct device.""" handler = UnpagedResourceHandler(8, 64, dtype=torch.float16) - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4) + seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4) seq_info.to("cuda") tensor = handler.allocate(seq_info) diff --git a/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py b/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py index 6a5f8d1500c..2041233ebee 100644 --- a/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py +++ b/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py @@ -66,6 +66,7 @@ def test_init_creates_sequence_info_with_tokens_per_block(paged_kv_cache_config) interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -80,6 +81,7 @@ def test_init_uses_default_kv_cache_config_when_not_provided(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", ) @@ -108,6 +110,7 @@ def test_init_propagates_vocab_size_padded(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, vocab_size_padded=vocab_size_padded, device="cuda", ) @@ -120,6 +123,7 @@ def test_init_stores_device(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda:0", ) @@ -131,6 +135,7 @@ def test_init_default_device_is_cuda(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, ) assert interface.device == "cuda" @@ -146,6 +151,7 @@ def test_add_resource_paged_handler(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -162,6 +168,7 @@ def test_add_resource_state_handler(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -177,6 +184,7 @@ def test_add_resource_unpaged_handler(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -193,6 +201,7 @@ def test_add_multiple_resources(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -218,6 +227,7 @@ def test_initialize_resources_paged_only_creates_kv_cache_manager(paged_kv_cache interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -238,6 +248,7 @@ def test_initialize_resources_mixed_creates_mamba_hybrid_cache_manager(paged_kv_ interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -261,6 +272,7 @@ def test_initialize_resources_creates_cache_views_with_correct_shape(paged_kv_ca interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -293,6 +305,7 @@ def test_initialize_resources_creates_state_views_with_correct_shape(paged_kv_ca interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -325,6 +338,7 @@ def test_initialize_resources_unpaged_allocated_locally(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -350,6 +364,7 @@ def test_is_paged_returns_true_for_paged_only(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -366,6 +381,7 @@ def test_is_paged_returns_false_for_hybrid(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -390,6 +406,7 @@ def test_needs_resize_returns_false_when_fraction_is_zero(paged_kv_cache_config) interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -405,6 +422,7 @@ def test_needs_resize_returns_true_when_fraction_is_positive(resizable_kv_cache_ interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=resizable_kv_cache_config, ) @@ -420,6 +438,7 @@ def test_resize_kv_cache_manager_skipped_when_not_needed(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -447,6 +466,7 @@ def test_shutdown_clears_caches(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -467,6 +487,7 @@ def test_clear_caches_clears_all(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -497,6 +518,7 @@ def test_update_kv_cache_config_valid_field(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", ) @@ -510,6 +532,7 @@ def test_update_kv_cache_config_multiple_fields(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", ) @@ -529,6 +552,7 @@ def test_update_kv_cache_config_invalid_field_raises(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", ) @@ -546,6 +570,7 @@ def test_named_args_includes_sequence_info_and_caches(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -570,6 +595,7 @@ def test_args_returns_tuple_of_tensors(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -593,6 +619,7 @@ def test_to_moves_sequence_info(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cpu", kv_cache_config=paged_kv_cache_config, ) @@ -609,13 +636,15 @@ def test_to_moves_sequence_info(paged_kv_cache_config): def test_sequence_info_tokens_per_block_from_constructor(): """Verify tokens_per_block is set correctly from constructor.""" - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=32) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + ) assert seq_info.tokens_per_block == 32 def test_sequence_info_tokens_per_block_defaults_to_max_seq_len(): """Verify tokens_per_block defaults to max_seq_len when not provided.""" - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4) + seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4) assert seq_info.tokens_per_block == 128 @@ -706,6 +735,7 @@ def test_sequence_info_last_page_len_uses_tokens_per_block(): seq_info = SequenceInfo( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, tokens_per_block=16, ) @@ -729,6 +759,7 @@ def test_sequence_info_page_assignments(): seq_info = SequenceInfo( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, tokens_per_block=16, ) @@ -858,6 +889,7 @@ def test_multiple_ssm_resources_contiguous_views(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -885,6 +917,7 @@ def test_multiple_conv_resources_contiguous_views(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -912,6 +945,7 @@ def test_mixed_ssm_conv_resources_uses_min_layers(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -962,6 +996,7 @@ def test_generic_state_handler_allocated_locally(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -985,7 +1020,9 @@ def test_generic_state_handler_allocated_locally(paged_kv_cache_config): def test_active_host_prep_args_initially_empty(): """Verify _active_host_prep_args starts empty and is populated by register_host_prepare.""" - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=32) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + ) # Initially empty -- only populated by register_host_prepare_for_attention_forward assert len(seq_info._active_host_prep_args) == 0 @@ -1004,7 +1041,9 @@ def dummy_host_prepare(batch_info_host: torch.Tensor, cu_num_pages_host: torch.T def test_requires_copy_args_not_in_named_args(): """Verify that _requires_copy args do NOT appear in named_args.""" - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=32) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + ) named_args = seq_info.named_args for rc_arg in seq_info._active_host_prep_args: @@ -1014,7 +1053,9 @@ def test_requires_copy_args_not_in_named_args(): def test_args_stored_to_input_buffer(): """Verify that args are written to InputBuffer by nest_sequences.""" - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=32) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + ) # nest_sequences computes token_gather_indices internally from gather_context_logits # Default (gather_context_logits=False): 1 prefill seq of 3 tokens → gather last token only @@ -1050,7 +1091,9 @@ def test_args_stored_to_input_buffer(): def test_register_host_prepare_populates_requires_copy(): """Verify register_host_prepare_for_attention_forward auto-populates _requires_copy.""" - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, tokens_per_block=32) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + ) # Define a dummy host prepare function def dummy_host_prepare(batch_info_host: torch.Tensor, cu_num_pages_host: torch.Tensor): diff --git a/tests/unittest/auto_deploy/singlegpu/shim/test_engine.py b/tests/unittest/auto_deploy/singlegpu/shim/test_engine.py index dccddec658b..3bdab93e9db 100644 --- a/tests/unittest/auto_deploy/singlegpu/shim/test_engine.py +++ b/tests/unittest/auto_deploy/singlegpu/shim/test_engine.py @@ -61,6 +61,7 @@ def test_engine(engine_cls: Type[ADEngine], tokens_per_block: int): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -108,6 +109,7 @@ def test_demo_engine_sampling(tokens_per_block: int): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -209,6 +211,7 @@ def test_ad_engine_chunked_prefill_equivalence(tokens_per_block: int): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -257,6 +260,7 @@ def test_ad_engine_chunked_prefill_stages_multimodal_runtime_metadata(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -314,6 +318,7 @@ def test_ad_engine_skips_multimodal_runtime_metadata_when_no_multimodal_requests cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -354,6 +359,7 @@ def test_ad_engine_stages_mm_chunk_bounds_for_multimodal_block_reuse(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -399,6 +405,7 @@ def test_ad_engine_rejects_mismatched_multimodal_layout_arrays(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -509,6 +516,7 @@ def test_ad_engine_prepare_inputs_with_hybrid_cache_manager(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -569,6 +577,7 @@ def test_ad_engine_prepare_inputs_generation_with_hybrid_cache(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) @@ -646,6 +655,7 @@ def test_ad_engine_with_regular_kv_cache_manager(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, + max_num_tokens=(max_seq_len + 1) * max_batch_size, device=device, kv_cache_config=kv_cache_config, ) diff --git a/tests/unittest/auto_deploy/singlegpu/shim/test_llm_config.py b/tests/unittest/auto_deploy/singlegpu/shim/test_llm_config.py index 163d3e65857..665337bdd4a 100644 --- a/tests/unittest/auto_deploy/singlegpu/shim/test_llm_config.py +++ b/tests/unittest/auto_deploy/singlegpu/shim/test_llm_config.py @@ -206,8 +206,8 @@ def test_small_max_batch_size_caps_heuristic(self): assert 1 in args.cuda_graph_batch_sizes assert 4 in args.cuda_graph_batch_sizes # Should NOT include heuristic values that exceed max_batch_size - assert 17 not in args.cuda_graph_batch_sizes - assert 113 not in args.cuda_graph_batch_sizes + assert 16 not in args.cuda_graph_batch_sizes + assert 32 not in args.cuda_graph_batch_sizes def test_medium_max_batch_size_caps_heuristic(self): """Test heuristic with medium max_batch_size (e.g., 64).""" @@ -220,15 +220,15 @@ def test_medium_max_batch_size_caps_heuristic(self): assert all(bs <= 64 for bs in args.cuda_graph_batch_sizes), ( f"Expected all batch sizes <= 64, got {args.cuda_graph_batch_sizes}" ) - # Should include some heuristic values up to 64 + # Should include some heuristic values up to 64 (range(16, 128+1, 16)) assert 1 in args.cuda_graph_batch_sizes - assert 17 in args.cuda_graph_batch_sizes - assert 33 in args.cuda_graph_batch_sizes - assert 49 in args.cuda_graph_batch_sizes + assert 16 in args.cuda_graph_batch_sizes + assert 32 in args.cuda_graph_batch_sizes + assert 48 in args.cuda_graph_batch_sizes assert 64 in args.cuda_graph_batch_sizes # Should NOT include values > 64 - assert 65 not in args.cuda_graph_batch_sizes - assert 81 not in args.cuda_graph_batch_sizes + assert 80 not in args.cuda_graph_batch_sizes + assert 96 not in args.cuda_graph_batch_sizes def test_large_max_batch_size_includes_all_heuristic_values(self): """Test heuristic with large max_batch_size (e.g., 256).""" @@ -241,10 +241,10 @@ def test_large_max_batch_size_includes_all_heuristic_values(self): assert all(bs <= 256 for bs in args.cuda_graph_batch_sizes), ( f"Expected all batch sizes <= 256, got {args.cuda_graph_batch_sizes}" ) - # Should include heuristic values from range(1, 129, 16) - for bs in [1, 17, 33, 49, 65, 81, 97, 113]: + # Should include heuristic values from range(16, 129, 16) + for bs in [1, 16, 32, 48, 64, 80, 96, 112, 128]: assert bs in args.cuda_graph_batch_sizes, f"Expected {bs} in batch sizes" - # Should include 128 from range(128, max_batch_size+1, 128) + # Should include 128 and 256 from range(128, max_batch_size+1, 128) assert 128 in args.cuda_graph_batch_sizes assert 256 in args.cuda_graph_batch_sizes diff --git a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_guided_decoding_regex.py b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_guided_decoding_regex.py index 6d0301acd1d..8680bc849ac 100644 --- a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_guided_decoding_regex.py +++ b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_guided_decoding_regex.py @@ -37,7 +37,7 @@ def test_ad_guided_decoding_regex_e2e(): experiment_config["args"]["guided_decoding_backend"] = guided_decoding_backend experiment_config["prompt"]["batch_size"] = 1 - experiment_config["prompt"]["queries"] = {"prompt": test_case["prompt"]} + experiment_config["prompt"]["queries"] = test_case["prompt"] cfg = ExperimentConfig(**experiment_config) diff --git a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_trtllm_sampler.py b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_trtllm_sampler.py index 41e96ae1cf3..8a8dd9178c7 100644 --- a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_trtllm_sampler.py +++ b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_trtllm_sampler.py @@ -32,7 +32,7 @@ def test_ad_trtllm_sampler_smoke(): # Setup simple prompt experiment_config["prompt"]["batch_size"] = 1 - experiment_config["prompt"]["queries"] = {"prompt": "What is the capital of France?"} + experiment_config["prompt"]["queries"] = "What is the capital of France?" experiment_config["prompt"]["sp_kwargs"] = { "max_tokens": 10, "temperature": 1.0, diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_trtllm_attention_quant_fp8.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_trtllm_attention_quant_fp8.py index 29c2ef00b16..55e522e6cc0 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_trtllm_attention_quant_fp8.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_trtllm_attention_quant_fp8.py @@ -247,6 +247,7 @@ def test_insert_cached_attention_trtllm_materializes_out_scale_reciprocal(): cm = CachedSequenceInterface( max_seq_len=64, max_batch_size=4, + max_num_tokens=256, device="cuda", kv_cache_config=kv_cache_config, ) @@ -295,6 +296,7 @@ def test_insert_cached_attention_trtllm_fallback_without_fp8_contract(): cm = CachedSequenceInterface( max_seq_len=64, max_batch_size=4, + max_num_tokens=256, device="cuda", kv_cache_config=kv_cache_config, ) diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py index b803bffc428..fdebcdaff77 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py @@ -164,6 +164,7 @@ def test_gated_delta_rule_with_cache(num_k_heads, num_v_heads): cm = CachedSequenceInterface( max_seq_len=max_position_embeddings, max_batch_size=batch_size, + max_num_tokens=(max_position_embeddings + 1) * batch_size, device="cuda", kv_cache_config=kv_cache_config, ) diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py index 5d6a18d8bfd..8a0f34cd61e 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py @@ -166,6 +166,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): cm = CachedSequenceInterface( max_seq_len=max_position_embeddings, max_batch_size=batch_size, + max_num_tokens=(max_position_embeddings + 1) * batch_size, device="cuda", kv_cache_config=kv_cache_config, ) @@ -290,6 +291,7 @@ def dummy_cached_interface(): return CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=kv_cache_config, ) @@ -366,6 +368,7 @@ def test_resize_kv_cache_transform_runs_when_needed(): cm = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, + max_num_tokens=129 * 4, device="cuda", kv_cache_config=kv_cache_config, ) @@ -415,6 +418,7 @@ def test_insert_cached_attention_uses_add_resource(): cm = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=batch_size, + max_num_tokens=(max_seq_len + 1) * batch_size, device="cuda", kv_cache_config=kv_cache_config, ) @@ -488,6 +492,7 @@ def test_insert_cached_attention_passes_kv_cache_config(): cm = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=batch_size, + max_num_tokens=(max_seq_len + 1) * batch_size, device="cuda", kv_cache_config=kv_cache_config, ) diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_mrope_delta_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_mrope_delta_cache.py index f375cd5fcc9..0dc11047950 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_mrope_delta_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_mrope_delta_cache.py @@ -20,6 +20,7 @@ def test_initialize_mrope_delta_cache_registers_state_resource(): cm = CachedSequenceInterface( max_seq_len=8, max_batch_size=2, + max_num_tokens=18, device="cpu", ) transform = InitializeMropeDeltaCache.from_kwargs(stage="cache_init") diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py index 8d1851c34c6..919d0cec7a8 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py @@ -152,6 +152,7 @@ def test_torch_gated_delta_rule_cache(num_k_heads, num_v_heads): cm = CachedSequenceInterface( max_seq_len=max_position_embeddings, max_batch_size=batch_size, + max_num_tokens=(max_position_embeddings + 1) * batch_size, device="cuda", kv_cache_config=kv_cache_config, ) From 952cc13ac9338d0ee6dbd0a7269c25169f7b2d6e Mon Sep 17 00:00:00 2001 From: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:42:49 -0700 Subject: [PATCH 2/5] [None][infra] Remove unused auto_deploy benchmark utility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Nothing imports this module — it is dead code cleaned up as part of the paperclip infra consolidation. Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> --- .../_torch/auto_deploy/utils/benchmark.py | 125 ------------------ 1 file changed, 125 deletions(-) delete mode 100644 tensorrt_llm/_torch/auto_deploy/utils/benchmark.py diff --git a/tensorrt_llm/_torch/auto_deploy/utils/benchmark.py b/tensorrt_llm/_torch/auto_deploy/utils/benchmark.py deleted file mode 100644 index 40e73101cac..00000000000 --- a/tensorrt_llm/_torch/auto_deploy/utils/benchmark.py +++ /dev/null @@ -1,125 +0,0 @@ -import json -import os -import pathlib -from collections import deque -from contextlib import contextmanager -from typing import Callable, Collection, Deque - -import torch - -from .logger import ad_logger - - -class GenerationProfiler: - def __init__(self, num_runs: int): - self.prefill_start, self.prefill_end = self._create_events() - self.decode_start, self.decode_end = self._create_events() - self.num_runs = num_runs - - self.prefill_times: Deque[float] = deque(maxlen=self.num_runs) - self.decode_times: Deque[float] = deque(maxlen=self.num_runs) - - def _create_events(self): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - return start, end - - def _record_event(self, event: torch.cuda.Event): - event.record() - - def _get_elapsed_time(self, start_event, end_event): - torch.cuda.synchronize() - return start_event.elapsed_time(end_event) - - def _store_time(self, time_list: Deque[float], elapsed_time): - time_list.append(elapsed_time) - - def record_prefill_start(self): - self._record_event(self.prefill_start) - - def record_prefill_end(self) -> float: - self._record_event(self.prefill_end) - elapsed_time_ms = self._get_elapsed_time(self.prefill_start, self.prefill_end) - self._store_time(self.prefill_times, elapsed_time_ms) - return elapsed_time_ms - - def record_decode_start(self): - self._record_event(self.decode_start) - - def record_decode_end(self) -> float: - self._record_event(self.decode_end) - elapsed_time_ms = self._get_elapsed_time(self.decode_start, self.decode_end) - self._store_time(self.decode_times, elapsed_time_ms) - return elapsed_time_ms - - def get_average_prefill_time(self) -> float: - return self._get_average_time(self.prefill_times) - - def get_average_decode_time(self) -> float: - return self._get_average_time(self.decode_times) - - def _get_average_time(self, time_list: Collection[float]) -> float: - if len(time_list): - return sum(time_list) / len(time_list) - return 0.0 - - def reset(self): - self.prefill_start, self.prefill_end = self._create_events() - self.decode_start, self.decode_end = self._create_events() - - @contextmanager - def record_prefill(self): - try: - self.record_prefill_start() - yield - finally: - self.record_prefill_end() - - @contextmanager - def record_decode(self): - try: - self.record_decode_start() - yield - finally: - self.record_decode_end() - - -def benchmark( - func: Callable[[], None], num_runs: int, log_prefix: str = "", results_path: str | None = None -) -> Callable[[Callable], Callable]: - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - latencies = [] - - # warmup - for _ in range(1): - func() - - use_nsys_profiling = bool(os.environ.get("NSYS_PROFILING_SESSION_ID", None)) - if use_nsys_profiling: - torch.cuda.cudart().cudaProfilerStart() - func() - torch.cuda.cudart().cudaProfilerStop() - else: - for _ in range(num_runs): - start.record() - func() - end.record() - torch.cuda.synchronize() - latencies.append(start.elapsed_time(end)) - ad_logger.info( - f"{log_prefix} Average of {len(latencies)} " - f"runs: {sum(latencies) / len(latencies): 0.2f} (millisecond)" - ) - - return { - "avg_latency_ms": sum(latencies) / len(latencies), - "avg_latency_num_runs": num_runs if not use_nsys_profiling else 0, - } - - -def store_benchmark_results(results: dict, results_path: str): - results_path = pathlib.Path(results_path) - results_path.parent.mkdir(parents=True, exist_ok=True) - with results_path.open("w") as results_file: - json.dump(results, results_file, indent=2) From 3e3856a9fe6a6e69b7d0421324887a0035fde904 Mon Sep 17 00:00:00 2001 From: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> Date: Thu, 2 Apr 2026 11:47:53 -0700 Subject: [PATCH 3/5] [None][infra] Remove stale transformers-mode TODO comment The comment was historical context about switching to graph mode. Transformers mode deprecation will be tracked separately. Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> --- .../auto_deploy/multigpu/smoke/test_ad_build_small_multi.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unittest/auto_deploy/multigpu/smoke/test_ad_build_small_multi.py b/tests/unittest/auto_deploy/multigpu/smoke/test_ad_build_small_multi.py index b4d589df787..75b4539f205 100644 --- a/tests/unittest/auto_deploy/multigpu/smoke/test_ad_build_small_multi.py +++ b/tests/unittest/auto_deploy/multigpu/smoke/test_ad_build_small_multi.py @@ -23,11 +23,6 @@ ], ) def test_build_ad(world_size: int, model_hub_id: str, llm_extra_args: dict): - # TODO: Revisit transformers-mode multigpu smoke coverage for Llama 3.1 specifically. - # This test keeps only the maintained graph-model AutoDeploy path for - # meta-llama/Meta-Llama-3.1-8B-Instruct. In this branch, the Llama path is served by - # branch-owned AD custom modeling rather than the stock HF modeling code, so the deprecated - # mode="transformers" path is no longer the right coverage target for this test. experiment_config = get_small_model_config(model_hub_id, **llm_extra_args) experiment_config["args"]["world_size"] = world_size From e1c410b35c3f2018bbd18dd6a00489b63ecb90f7 Mon Sep 17 00:00:00 2001 From: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:33:45 -0700 Subject: [PATCH 4/5] [None][infra] Add default_max_num_tokens test utility Centralizes the (max_seq_len + 1) * max_batch_size formula (a WAR for flashinfer issue #4504) into a single helper in _model_test_utils.py. Replaces hardcoded magic numbers (129 * 4) and inline formulas across 7 test files. Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> --- .../_utils_test/_model_test_utils.py | 9 ++ .../custom_ops/test_resource_handlers.py | 16 +++- .../shim/test_cached_sequence_interface.py | 96 +++++++++++-------- .../auto_deploy/singlegpu/shim/test_engine.py | 21 ++-- .../library/test_gated_delta_rule_cache.py | 3 +- .../transformations/library/test_kv_cache.py | 12 +-- .../test_torch_gated_delta_rule_cache.py | 3 +- 7 files changed, 99 insertions(+), 61 deletions(-) diff --git a/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py index e0a57da550d..613f9ae0df8 100644 --- a/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py @@ -8,6 +8,15 @@ from torch.export import Dim +def default_max_num_tokens(max_seq_len: int, max_batch_size: int) -> int: + """Compute the default max_num_tokens for AutoDeploy tests. + + The +1 is a WAR for a flashinfer attention issue with (max_batch_size, max_seq_len) input. + See https://github.com/NVIDIA/TensorRT-LLM/issues/4504 + """ + return (max_seq_len + 1) * max_batch_size + + def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: freqs_cis = freqs_cis[None, : x.shape[1], None] # --> [1, s, 1, h_d//2, 2] xshaped = x.float().unflatten(-1, (-1, 2)) # [b, s, n_h, h_d] --> [b, s, n_h, h_d//2, 2] diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py index eea3963e981..5ee2e7c06aa 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_resource_handlers.py @@ -9,6 +9,7 @@ import pytest import torch +from _model_test_utils import default_max_num_tokens from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import ( AttentionDescriptor, @@ -52,7 +53,10 @@ def test_paged_handler_allocate_with_blocks(kv_layout): handler = KVPagedResourceHandler(8, 64, dtype=torch.float16, kv_layout=kv_layout) tokens_per_block = 32 seq_info = SequenceInfo( - max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=tokens_per_block + max_seq_len=128, + max_batch_size=4, + max_num_tokens=default_max_num_tokens(128, 4), + tokens_per_block=tokens_per_block, ) seq_info.to("cuda") # Set up num_blocks via update_cache_information @@ -124,7 +128,9 @@ def test_state_handler_ssm_state_shape(): def test_state_handler_allocate_creates_tensor(): """Verify StateResourceHandler.allocate() creates tensor with correct shape.""" handler = StateResourceHandler(4, 64, 16, dtype=torch.bfloat16) - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=default_max_num_tokens(128, 4) + ) seq_info.to("cuda") tensor = handler.allocate(seq_info) @@ -178,7 +184,7 @@ def test_unpaged_handler_allocate_returns_correct_shape(num_kv_heads, head_dim, seq_info = SequenceInfo( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), ) seq_info.to("cuda") @@ -193,7 +199,9 @@ def test_unpaged_handler_allocate_returns_correct_shape(num_kv_heads, head_dim, def test_unpaged_handler_allocate_correct_device(): """Verify UnpagedResourceHandler allocated tensor is on the correct device.""" handler = UnpagedResourceHandler(8, 64, dtype=torch.float16) - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=default_max_num_tokens(128, 4) + ) seq_info.to("cuda") tensor = handler.allocate(seq_info) diff --git a/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py b/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py index 2041233ebee..b10a3c2b00b 100644 --- a/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py +++ b/tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py @@ -8,6 +8,7 @@ import pytest import torch +from _model_test_utils import default_max_num_tokens from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import ( CausalConvResourceHandler, @@ -66,7 +67,7 @@ def test_init_creates_sequence_info_with_tokens_per_block(paged_kv_cache_config) interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -81,7 +82,7 @@ def test_init_uses_default_kv_cache_config_when_not_provided(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", ) @@ -110,7 +111,7 @@ def test_init_propagates_vocab_size_padded(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), vocab_size_padded=vocab_size_padded, device="cuda", ) @@ -123,7 +124,7 @@ def test_init_stores_device(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda:0", ) @@ -135,7 +136,7 @@ def test_init_default_device_is_cuda(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), ) assert interface.device == "cuda" @@ -151,7 +152,7 @@ def test_add_resource_paged_handler(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -168,7 +169,7 @@ def test_add_resource_state_handler(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -184,7 +185,7 @@ def test_add_resource_unpaged_handler(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -201,7 +202,7 @@ def test_add_multiple_resources(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -227,7 +228,7 @@ def test_initialize_resources_paged_only_creates_kv_cache_manager(paged_kv_cache interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -248,7 +249,7 @@ def test_initialize_resources_mixed_creates_mamba_hybrid_cache_manager(paged_kv_ interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -272,7 +273,7 @@ def test_initialize_resources_creates_cache_views_with_correct_shape(paged_kv_ca interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -305,7 +306,7 @@ def test_initialize_resources_creates_state_views_with_correct_shape(paged_kv_ca interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -338,7 +339,7 @@ def test_initialize_resources_unpaged_allocated_locally(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -364,7 +365,7 @@ def test_is_paged_returns_true_for_paged_only(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -381,7 +382,7 @@ def test_is_paged_returns_false_for_hybrid(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -406,7 +407,7 @@ def test_needs_resize_returns_false_when_fraction_is_zero(paged_kv_cache_config) interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -422,7 +423,7 @@ def test_needs_resize_returns_true_when_fraction_is_positive(resizable_kv_cache_ interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=resizable_kv_cache_config, ) @@ -438,7 +439,7 @@ def test_resize_kv_cache_manager_skipped_when_not_needed(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -466,7 +467,7 @@ def test_shutdown_clears_caches(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -487,7 +488,7 @@ def test_clear_caches_clears_all(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -518,7 +519,7 @@ def test_update_kv_cache_config_valid_field(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", ) @@ -532,7 +533,7 @@ def test_update_kv_cache_config_multiple_fields(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", ) @@ -552,7 +553,7 @@ def test_update_kv_cache_config_invalid_field_raises(): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", ) @@ -570,7 +571,7 @@ def test_named_args_includes_sequence_info_and_caches(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -595,7 +596,7 @@ def test_args_returns_tuple_of_tensors(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -619,7 +620,7 @@ def test_to_moves_sequence_info(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cpu", kv_cache_config=paged_kv_cache_config, ) @@ -637,14 +638,19 @@ def test_to_moves_sequence_info(paged_kv_cache_config): def test_sequence_info_tokens_per_block_from_constructor(): """Verify tokens_per_block is set correctly from constructor.""" seq_info = SequenceInfo( - max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + max_seq_len=128, + max_batch_size=4, + max_num_tokens=default_max_num_tokens(128, 4), + tokens_per_block=32, ) assert seq_info.tokens_per_block == 32 def test_sequence_info_tokens_per_block_defaults_to_max_seq_len(): """Verify tokens_per_block defaults to max_seq_len when not provided.""" - seq_info = SequenceInfo(max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4) + seq_info = SequenceInfo( + max_seq_len=128, max_batch_size=4, max_num_tokens=default_max_num_tokens(128, 4) + ) assert seq_info.tokens_per_block == 128 @@ -735,7 +741,7 @@ def test_sequence_info_last_page_len_uses_tokens_per_block(): seq_info = SequenceInfo( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), tokens_per_block=16, ) @@ -759,7 +765,7 @@ def test_sequence_info_page_assignments(): seq_info = SequenceInfo( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), tokens_per_block=16, ) @@ -889,7 +895,7 @@ def test_multiple_ssm_resources_contiguous_views(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -917,7 +923,7 @@ def test_multiple_conv_resources_contiguous_views(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -945,7 +951,7 @@ def test_mixed_ssm_conv_resources_uses_min_layers(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -996,7 +1002,7 @@ def test_generic_state_handler_allocated_locally(paged_kv_cache_config): interface = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=paged_kv_cache_config, ) @@ -1021,7 +1027,10 @@ def test_generic_state_handler_allocated_locally(paged_kv_cache_config): def test_active_host_prep_args_initially_empty(): """Verify _active_host_prep_args starts empty and is populated by register_host_prepare.""" seq_info = SequenceInfo( - max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + max_seq_len=128, + max_batch_size=4, + max_num_tokens=default_max_num_tokens(128, 4), + tokens_per_block=32, ) # Initially empty -- only populated by register_host_prepare_for_attention_forward @@ -1042,7 +1051,10 @@ def dummy_host_prepare(batch_info_host: torch.Tensor, cu_num_pages_host: torch.T def test_requires_copy_args_not_in_named_args(): """Verify that _requires_copy args do NOT appear in named_args.""" seq_info = SequenceInfo( - max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + max_seq_len=128, + max_batch_size=4, + max_num_tokens=default_max_num_tokens(128, 4), + tokens_per_block=32, ) named_args = seq_info.named_args @@ -1054,7 +1066,10 @@ def test_requires_copy_args_not_in_named_args(): def test_args_stored_to_input_buffer(): """Verify that args are written to InputBuffer by nest_sequences.""" seq_info = SequenceInfo( - max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + max_seq_len=128, + max_batch_size=4, + max_num_tokens=default_max_num_tokens(128, 4), + tokens_per_block=32, ) # nest_sequences computes token_gather_indices internally from gather_context_logits @@ -1092,7 +1107,10 @@ def test_args_stored_to_input_buffer(): def test_register_host_prepare_populates_requires_copy(): """Verify register_host_prepare_for_attention_forward auto-populates _requires_copy.""" seq_info = SequenceInfo( - max_seq_len=128, max_batch_size=4, max_num_tokens=129 * 4, tokens_per_block=32 + max_seq_len=128, + max_batch_size=4, + max_num_tokens=default_max_num_tokens(128, 4), + tokens_per_block=32, ) # Define a dummy host prepare function diff --git a/tests/unittest/auto_deploy/singlegpu/shim/test_engine.py b/tests/unittest/auto_deploy/singlegpu/shim/test_engine.py index 3bdab93e9db..3dd63035676 100644 --- a/tests/unittest/auto_deploy/singlegpu/shim/test_engine.py +++ b/tests/unittest/auto_deploy/singlegpu/shim/test_engine.py @@ -3,6 +3,7 @@ import pytest import torch import torch.nn as nn +from _model_test_utils import default_max_num_tokens from tensorrt_llm import SamplingParams from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine @@ -61,7 +62,7 @@ def test_engine(engine_cls: Type[ADEngine], tokens_per_block: int): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -109,7 +110,7 @@ def test_demo_engine_sampling(tokens_per_block: int): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -211,7 +212,7 @@ def test_ad_engine_chunked_prefill_equivalence(tokens_per_block: int): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -260,7 +261,7 @@ def test_ad_engine_chunked_prefill_stages_multimodal_runtime_metadata(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -318,7 +319,7 @@ def test_ad_engine_skips_multimodal_runtime_metadata_when_no_multimodal_requests cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -359,7 +360,7 @@ def test_ad_engine_stages_mm_chunk_bounds_for_multimodal_block_reuse(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -405,7 +406,7 @@ def test_ad_engine_rejects_mismatched_multimodal_layout_arrays(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -516,7 +517,7 @@ def test_ad_engine_prepare_inputs_with_hybrid_cache_manager(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -577,7 +578,7 @@ def test_ad_engine_prepare_inputs_generation_with_hybrid_cache(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) @@ -655,7 +656,7 @@ def test_ad_engine_with_regular_kv_cache_manager(): cache_seq_interface = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=max_batch_size, - max_num_tokens=(max_seq_len + 1) * max_batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, max_batch_size), device=device, kv_cache_config=kv_cache_config, ) diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py index fdebcdaff77..0b2a8736231 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py @@ -24,6 +24,7 @@ import pytest import torch import torch.nn as nn +from _model_test_utils import default_max_num_tokens from _torch_test_utils import all_close # Register all auto_deploy custom ops @@ -164,7 +165,7 @@ def test_gated_delta_rule_with_cache(num_k_heads, num_v_heads): cm = CachedSequenceInterface( max_seq_len=max_position_embeddings, max_batch_size=batch_size, - max_num_tokens=(max_position_embeddings + 1) * batch_size, + max_num_tokens=default_max_num_tokens(max_position_embeddings, batch_size), device="cuda", kv_cache_config=kv_cache_config, ) diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py index 8a0f34cd61e..6aba4cecb49 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py @@ -4,7 +4,7 @@ import pytest import torch import torch.nn as nn -from _model_test_utils import GQA +from _model_test_utils import GQA, default_max_num_tokens from _torch_test_utils import all_close # Initialize resources first (KVPagedResourceHandler is used within tests below) @@ -166,7 +166,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): cm = CachedSequenceInterface( max_seq_len=max_position_embeddings, max_batch_size=batch_size, - max_num_tokens=(max_position_embeddings + 1) * batch_size, + max_num_tokens=default_max_num_tokens(max_position_embeddings, batch_size), device="cuda", kv_cache_config=kv_cache_config, ) @@ -291,7 +291,7 @@ def dummy_cached_interface(): return CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=kv_cache_config, ) @@ -368,7 +368,7 @@ def test_resize_kv_cache_transform_runs_when_needed(): cm = CachedSequenceInterface( max_seq_len=128, max_batch_size=4, - max_num_tokens=129 * 4, + max_num_tokens=default_max_num_tokens(128, 4), device="cuda", kv_cache_config=kv_cache_config, ) @@ -418,7 +418,7 @@ def test_insert_cached_attention_uses_add_resource(): cm = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=batch_size, - max_num_tokens=(max_seq_len + 1) * batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, batch_size), device="cuda", kv_cache_config=kv_cache_config, ) @@ -492,7 +492,7 @@ def test_insert_cached_attention_passes_kv_cache_config(): cm = CachedSequenceInterface( max_seq_len=max_seq_len, max_batch_size=batch_size, - max_num_tokens=(max_seq_len + 1) * batch_size, + max_num_tokens=default_max_num_tokens(max_seq_len, batch_size), device="cuda", kv_cache_config=kv_cache_config, ) diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py index 919d0cec7a8..bfff784f10b 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py @@ -16,6 +16,7 @@ import pytest import torch import torch.nn as nn +from _model_test_utils import default_max_num_tokens from _torch_test_utils import all_close # Register all auto_deploy custom ops @@ -152,7 +153,7 @@ def test_torch_gated_delta_rule_cache(num_k_heads, num_v_heads): cm = CachedSequenceInterface( max_seq_len=max_position_embeddings, max_batch_size=batch_size, - max_num_tokens=(max_position_embeddings + 1) * batch_size, + max_num_tokens=default_max_num_tokens(max_position_embeddings, batch_size), device="cuda", kv_cache_config=kv_cache_config, ) From c44dda78701f30c29481291074ff2a48b534fb79 Mon Sep 17 00:00:00 2001 From: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> Date: Thu, 2 Apr 2026 13:14:59 -0700 Subject: [PATCH 5/5] [None][infra] Address CodeRabbit review feedback - Remove duplicate BB Vision/Multi-Modal section in ad-onboard-reviewer.md - Remove stale --benchmark.enabled flag from .vscode/launch.json - Update copyright year to 2025-2026 in _onnx_schemas.py Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> --- .claude/agents/ad-onboard-reviewer.md | 9 --------- examples/auto_deploy/.vscode/launch.json | 1 - .../auto_deploy/transform/library/_onnx_schemas.py | 2 +- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/.claude/agents/ad-onboard-reviewer.md b/.claude/agents/ad-onboard-reviewer.md index 660379fbbb7..feedc4aeef4 100644 --- a/.claude/agents/ad-onboard-reviewer.md +++ b/.claude/agents/ad-onboard-reviewer.md @@ -44,15 +44,6 @@ Read the actual source code for each check. Cite `file:line_number` for every PA Note: BB1–BB2 only apply if the HF source indicates the model is multi-modal (has image/audio inputs). Mark N/A with justification for pure language models. -### BB. Vision / Multi-Modal Support - -| # | Check | How to verify | -|---|-------|---------------| -| BB1 | If the model has a vision tower (multi-modal), the full `nn.Module` hierarchy for the vision component is present in the modeling file — it is NOT omitted, stubbed out, or replaced with a `pass` body | Grep for vision-related class names (e.g., `VisionTower`, `ViT`, `CLIPVision`, `SiglipVision`) from the HF source. If the model is multi-modal and none appear, flag as FAIL. | -| BB2 | The test file asserts that vision-related weight keys are present in the model's `state_dict` after `load_state_dict` | Grep the test file for assertions on vision weight key names (or a check that vision-prefixed keys are in the loaded state_dict). Absence of any such assertion is a FAIL for multi-modal models. | - -Note: BB1–BB2 only apply if the HF source indicates the model is multi-modal (has image/audio inputs). Mark N/A with justification for pure language models. - ### C. Ops & Compatibility (STRICT — canonical ops are the backbone of AD) | # | Check | How to verify | diff --git a/examples/auto_deploy/.vscode/launch.json b/examples/auto_deploy/.vscode/launch.json index 5eec8ad8c2e..24d1b7be4d0 100644 --- a/examples/auto_deploy/.vscode/launch.json +++ b/examples/auto_deploy/.vscode/launch.json @@ -14,7 +14,6 @@ "--args.attn-page-size=16", "--args.transforms.insert-cached-attention.backend=flashinfer", "--args.model-factory=AutoModelForCausalLM", - "--benchmark.enabled=false", "--prompt.batch-size=2", "--args.model-kwargs.num-hidden-layers=3", "--args.model-kwargs.num-attention-heads=32", diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py b/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py index 2efa03b5523..f5770e5be82 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/_onnx_schemas.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.