Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ HuggingFace Model → LLM API → Executor (PyTorch/AutoDeploy/TensorRT)
- Target `main` unless fixing a release branch bug
- See `CONTRIBUTING.md` for full PR policies

### GitHub CLI authentication (`GH_CONFIG_DIR`)

The `gh` CLI uses `~/.config/gh` by default for authentication. Different GitHub hosts or forks may require a different config directory. **Before running any `gh` command** (e.g., `gh pr create`, `gh api`, `gh pr comment`):

1. Check if the user has specified a custom `GH_CONFIG_DIR` (e.g., in `CLAUDE.local.md` or environment). If so, use it.
2. If not explicitly set, **ask the user** whether the default `~/.config/gh` is correct or if a different directory should be used. This is especially relevant when the PR target is a fork (e.g., `nv-auto-deploy/TensorRT-LLM`) rather than `NVIDIA/TensorRT-LLM`.
3. Prefix all `gh` commands with the resolved config dir: `GH_CONFIG_DIR=<path> gh ...`

## CI / Testing

See [CI overview](docs/source/developer-guide/ci-overview.md) for full details.
Expand Down
1 change: 0 additions & 1 deletion examples/auto_deploy/.vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"--args.attn-page-size=16",
"--args.transforms.insert-cached-attention.backend=flashinfer",
"--args.model-factory=AutoModelForCausalLM",
"--benchmark.enabled=false",
"--prompt.batch-size=2",
"--args.model-kwargs.num-hidden-layers=3",
"--args.model-kwargs.num-attention-heads=32",
Expand Down
229 changes: 136 additions & 93 deletions examples/auto_deploy/build_and_run_ad.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Main entrypoint to build, test, and prompt AutoDeploy inference models."""

import json
import sys
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Union

import torch
Expand All @@ -20,47 +23,50 @@
DynamicYamlMixInForSettings,
deep_merge_dicts,
)
from tensorrt_llm._torch.auto_deploy.utils.benchmark import benchmark, store_benchmark_results
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.sampling_params import SamplingParams

# Registry paths
_REGISTRY_DIR = Path(__file__).resolve().parent / "model_registry"
_REGISTRY_YAML = _REGISTRY_DIR / "models.yaml"
_REGISTRY_CONFIGS_DIR = _REGISTRY_DIR / "configs"

# Global torch config, set the torch compile cache to fix up to llama 405B
torch._dynamo.config.cache_size_limit = 20

# simple string, TRT-LLM style text-only prompt or full-scale HF message template
PromptInput = Union[str, Dict, List[Dict]]

# A single query is either a plain string or a full HF chat message template.
PromptInput = Union[str, List[Dict]]


class PromptConfig(BaseModel):
"""Prompt configuration.

This configuration class can be used for this example script to configure the example prompts
and the sampling parameters.

Queries can be plain strings or HF-style chat message lists
(``[{"role": "user", "content": "..."}]``). Plain-string queries are automatically wrapped in
a chat template when the model's tokenizer supports one.
"""

batch_size: int = Field(default=2, description="Number of queries")
batch_size: int = Field(default=10, description="Number of queries")
queries: Union[PromptInput, List[PromptInput]] = Field(
default_factory=lambda: [
# OPTION 1: simple text prompt
"How big is the universe? ",
# OPTION 2: wrapped text prompt for TRT-LLM
{"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
# OPTION 3: a full-scale HF message template (this one works for text-only models!)
# Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
# and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
[
{
"role": "user",
"content": "How to fix slicing in golf?",
}
],
# More prompts...
{"prompt": "Where is the capital of Iceland? "},
"In simple words and a single sentence, explain the concept of gravity: ",
"How to fix slicing in golf? ",
"Where is the capital of Iceland? ",
"What are the three laws of thermodynamics? ",
"Summarize the plot of Romeo and Juliet in two sentences: ",
"Write a Python function that checks if a number is prime.",
"Explain the difference between a compiler and an interpreter: ",
"What causes the northern lights? ",
"What are the health benefits of drinking green tea?",
],
description="Example queries to prompt the model with. We support both TRT-LLM text-only "
"queries via the 'prompt' key and full-scale HF message template called via "
"apply_chat_template.",
description="Plain-text queries or HF-style chat message lists. Plain strings are "
"automatically wrapped as chat messages when the model's tokenizer has a chat template.",
)
sp_kwargs: Dict[str, Any] = Field(
default_factory=lambda: {"max_tokens": 100, "top_k": None, "temperature": 1.0},
Expand All @@ -69,33 +75,19 @@ class PromptConfig(BaseModel):
)

def model_post_init(self, __context: Any):
"""Cut queries to batch_size.
"""Repeat and truncate queries to match batch_size.

NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
validators are only run if a value is provided.
"""
queries = self.queries if isinstance(self.queries, list) else [self.queries]
batch_size = self.batch_size
queries = queries * (batch_size // len(queries) + 1)
queries = queries[:batch_size]

# now let's standardize the queries for the LLM api to understand them
queries_processed = []
for query in queries:
if isinstance(query, str):
queries_processed.append({"prompt": query})
elif isinstance(query, dict):
queries_processed.append(query)
elif isinstance(query, list):
queries_processed.append(
{
"prompt": "Fake prompt. Check out messages field for the HF chat template.",
"messages": query, # contains the actual HF chat template
}
)
else:
raise ValueError(f"Invalid query type: {type(query)}")
self.queries = queries_processed
queries = self.queries
if isinstance(queries, str):
queries = [queries]
elif isinstance(queries, list) and queries and isinstance(queries[0], dict):
# single HF message template, e.g. [{"role": "user", "content": "..."}]
queries = [queries]
queries = queries * (self.batch_size // len(queries) + 1)
self.queries = queries[: self.batch_size]

@field_validator("sp_kwargs", mode="after")
@classmethod
Expand All @@ -106,21 +98,10 @@ def validate_sp_kwargs(cls, sp_kwargs):


class BenchmarkConfig(BaseModel):
"""Benchmark configuration.
"""Configuration for storing results."""

This configuration class can be used for this example script to configure the simple
benchmarking we run at the end of the script.
"""

enabled: bool = Field(default=False, description="If true, run simple benchmark")
num: int = Field(default=10, ge=1, description="By default run 10 times and get average")
isl: int = Field(default=2048, ge=1, description="Input seq length for benchmarking")
osl: int = Field(default=128, ge=1, description="Output seq length for benchmarking")
bs: int = Field(default=1, ge=1, description="Batch size for benchmarking")
results_path: Optional[str] = Field(default="./benchmark_results.json")
store_results: bool = Field(
default=False, description="If True, store benchmark res in benchmark_results_path"
)
store_results: bool = Field(default=False, description="If True, store results to results_path")


class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
Expand Down Expand Up @@ -225,17 +206,30 @@ def sync_prompt_batch_size_with_args_max_batch_size(cls, prompt: PromptConfig, i
args.max_batch_size = prompt.batch_size
return prompt

@field_validator("benchmark", mode="after")
@classmethod
def adjust_args_for_benchmark(cls, benchmark: BenchmarkConfig, info):
if "args" not in info.data:
return benchmark
args: LlmArgs = info.data["args"]
if benchmark.enabled:
# propagate benchmark settings to args
args.max_batch_size = max(benchmark.bs, args.max_batch_size)
args.max_seq_len = max(args.max_seq_len, benchmark.isl + benchmark.osl)
return benchmark

def get_registry_yaml_extra(model_name: str) -> List[str]:
    """Look up a model in the registry and return its resolved yaml_extra config paths.

    Args:
        model_name: HuggingFace model id as listed in the registry
            (e.g. ``meta-llama/Llama-3.1-8B-Instruct``).

    Returns:
        List of absolute paths to the yaml config files for the model.

    Raises:
        KeyError: If the model is not found in the registry.
    """
    with open(_REGISTRY_YAML) as f:
        # An empty registry file makes safe_load return None; normalize to {} so the
        # lookup below degrades to "model not found" instead of an AttributeError.
        registry = yaml.safe_load(f) or {}

    for entry in registry.get("models", []):
        # Use .get() so a malformed entry without a "name" key is skipped rather than
        # raising a misleading KeyError before the "not found" diagnostic below.
        if entry.get("name") == model_name:
            return [str(_REGISTRY_CONFIGS_DIR / cfg) for cfg in entry.get("yaml_extra", [])]

    raise KeyError(
        f"Model '{model_name}' not found in the AutoDeploy model registry ({_REGISTRY_YAML}). "
        "Either add it to the registry or provide --yaml-extra directly."
    )


def build_llm_from_config(config: ExperimentConfig) -> LLM:
Expand All @@ -249,6 +243,36 @@ def build_llm_from_config(config: ExperimentConfig) -> LLM:
return llm


def prepare_queries(queries: List[PromptInput], tokenizer=None) -> List[Dict]:
    """Normalize queries into the dict format expected by the LLM API.

    HF-style message lists (``List[Dict]``) pass through with a derived text prompt.
    Plain strings become single-turn chat messages when the tokenizer exposes a chat
    template; otherwise they are forwarded as raw text prompts.
    """
    use_chat_template = getattr(tokenizer, "chat_template", None) is not None

    def _as_entry(query):
        # Already an HF message list: keep it, deriving a plain prompt from the
        # first message's content (empty string for an empty list).
        if isinstance(query, list):
            first_content = query[0].get("content", "") if query else ""
            return {"prompt": first_content, "messages": query}
        # Plain string with a chat template available: wrap as one user turn.
        if use_chat_template:
            return {"prompt": query, "messages": [{"role": "user", "content": query}]}
        # Plain string, no template: pass through as a raw text prompt.
        return {"prompt": query}

    return [_as_entry(q) for q in queries]


def print_outputs(outs: Union[RequestOutput, List[RequestOutput]]) -> List[List[str]]:
prompts_and_outputs: List[List[str]] = []
if isinstance(outs, RequestOutput):
Expand All @@ -260,8 +284,45 @@ def print_outputs(outs: Union[RequestOutput, List[RequestOutput]]) -> List[List[
return prompts_and_outputs


def _inject_registry_yaml_extra() -> None:
    """Expand ``--use-registry`` in ``sys.argv`` into explicit ``--args.yaml-extra`` flags.

    Lets callers simply run::

        python build_and_run_ad.py --model <hf_model_id> --use-registry

    instead of listing every yaml config path by hand. The flag is consumed here and
    the registry-resolved paths are spliced back into ``sys.argv`` before
    pydantic-settings parses the command line.
    """
    if "--use-registry" not in sys.argv:
        return

    # Locate the model name, accepting both "--model=X" and "--model X" spellings;
    # the first matching token wins.
    model_name: Optional[str] = None
    for idx, token in enumerate(sys.argv):
        if token.startswith("--model="):
            model_name = token.split("=", 1)[1]
            break
        if token == "--model" and idx + 1 < len(sys.argv):
            model_name = sys.argv[idx + 1]
            break

    if model_name is None:
        raise ValueError("--use-registry requires --model to be specified.")

    resolved_paths = get_registry_yaml_extra(model_name)

    # Drop the consumed flag, then append one --args.yaml-extra per resolved path —
    # pydantic-settings CLI captures only a single value per flag occurrence.
    new_argv = [token for token in sys.argv if token != "--use-registry"]
    for cfg_path in resolved_paths:
        new_argv.extend(["--args.yaml-extra", cfg_path])
    sys.argv = new_argv


def main(config: Optional[ExperimentConfig] = None):
if config is None:
_inject_registry_yaml_extra()
config: ExperimentConfig = CliApp.run(ExperimentConfig)
ad_logger.info(f"AutoDeploy Experiment Config:\n{yaml.dump(config.model_dump())}")

Expand All @@ -272,8 +333,10 @@ def main(config: Optional[ExperimentConfig] = None):

# prompt the model and print its output
ad_logger.info("Running example prompts...")
hf_tokenizer = getattr(llm.tokenizer, "tokenizer", None)
queries = prepare_queries(config.prompt.queries, hf_tokenizer)
outs = llm.generate(
config.prompt.queries,
queries,
sampling_params=SamplingParams(**config.prompt.sp_kwargs),
)
results = {
Expand All @@ -282,31 +345,11 @@ def main(config: Optional[ExperimentConfig] = None):
# Add config values so they get logged to JET extra
results.update(config.model_dump(mode="json"))

# run a benchmark for the model with batch_size == config.benchmark_bs
if config.benchmark.enabled and config.args.runtime != "trtllm":
ad_logger.info("Running benchmark...")
keys_from_args = []
fields_to_show = [f"benchmark={config.benchmark}"]
fields_to_show.extend([f"{k}={getattr(config.args, k)}" for k in keys_from_args])
results["benchmark_results"] = benchmark(
func=lambda: llm.generate(
torch.randint(0, 100, (config.benchmark.bs, config.benchmark.isl)).tolist(),
sampling_params=SamplingParams(
max_tokens=config.benchmark.osl,
top_k=None,
ignore_eos=True,
),
use_tqdm=False,
),
num_runs=config.benchmark.num,
log_prefix="Benchmark with " + ", ".join(fields_to_show),
results_path=config.benchmark.results_path,
)
elif config.benchmark.enabled:
ad_logger.info("Skipping simple benchmarking for trtllm...")

if config.benchmark.store_results:
store_benchmark_results(results, config.benchmark.results_path)
results_path = Path(config.benchmark.results_path)
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("w") as f:
json.dump(results, f, indent=2)

llm.shutdown()
return results
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ transforms:
###########################################################################################
insert_cached_attention:
stage: cache_init
backend: flashinfer
backend: trtllm
insert_cached_mla_attention:
stage: cache_init
requires_shape_prop: true
Expand Down Expand Up @@ -280,6 +280,6 @@ transforms:
expect_mem_change: true
run_per_gm: false
cuda_graph_batch_sizes: null
backend: torch-compile
backend: torch-cudagraph
piecewise_enabled: false
piecewise_num_tokens: null
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
- All possible "constants" inferred from tensor shapes at runtime
"""

import math
from typing import List, Optional, Tuple

import torch
Expand Down Expand Up @@ -462,7 +463,7 @@ def trtllm_mha_with_cache(
1, # beam_width
int(AttentionMaskType.causal), # mask_type
quant_mode, # quant_mode
1.0, # q_scaling
scale * math.sqrt(head_dim) if scale is not None else 1.0, # q_scaling
0, # position_embedding_type
0, # rotary_embedding_dim
10000.0, # rotary_embedding_base
Expand Down
Loading
Loading