Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/onnx_ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Most of the examples in this doc use `vit_base_patch16_224.onnx` as the input mo

```bash
python download_example_onnx.py \
--vit \
--timm_model_name=vit_base_patch16_224 \
--onnx_save_path=vit_base_patch16_224.onnx \
--fp16 # <Optional, if the desired output ONNX precision is FP16>
```
Expand Down
81 changes: 20 additions & 61 deletions examples/onnx_ptq/download_example_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import argparse
import os
import subprocess

import timm
import torch
Expand Down Expand Up @@ -46,14 +45,10 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
parser = argparse.ArgumentParser(description="Download and export example models to ONNX.")

parser.add_argument(
"--vit",
action="store_true",
help="Export timm/vit_base_patch16_224 model to ONNX.",
)
parser.add_argument(
"--llama",
action="store_true",
help="Export meta-llama/Llama-3.1-8B-Instruct to ONNX with KV cache.",
"--timm_model_name",
type=str,
required=True,
help="Export any timm model to ONNX (e.g., vit_base_patch16_224, swin_tiny_patch4_window7_224).",
)
parser.add_argument(
"--onnx_save_path", type=str, required=False, help="Path to save the final ONNX model."
Expand All @@ -62,7 +57,7 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
"--batch_size",
type=int,
default=1,
help="Batch size for the exported ViT model.",
help="Batch size for the exported model.",
)
parser.add_argument(
"--fp16",
Expand All @@ -71,54 +66,18 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
)
args = parser.parse_args()

if args.vit:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=1000).to(
device
)
data_config = timm.data.resolve_model_data_config(model)
input_shape = (args.batch_size,) + data_config["input_size"]

vit_save_path = args.onnx_save_path or "vit_base_patch16_224.onnx"
weights_dtype = "fp16" if args.fp16 else "fp32"
export_to_onnx(
model,
input_shape,
vit_save_path,
device,
weights_dtype=weights_dtype,
)
print(f"ViT model exported to {vit_save_path}")

if args.llama:
model_name = "meta-llama/Llama-3.1-8B-Instruct"
if not args.onnx_save_path:
args.onnx_save_path = "Llama-3.1-8B-Instruct/model.onnx"

output_dir = os.path.dirname(args.onnx_save_path)
if not output_dir: # Handle cases where only filename is given (save in current dir)
output_dir = "."
os.makedirs(output_dir, exist_ok=True)

command = [
"python",
"-m",
"optimum.commands.optimum_cli",
"export",
"onnx",
"--model",
model_name,
"--task",
"causal-lm-with-past",
"--device",
"cuda",
"--fp16" if args.fp16 else "",
output_dir,
]

try:
print(f"Running optimum-cli export to {output_dir}...")
subprocess.run(command, check=True, capture_output=True, text=True, encoding="utf-8")
print(f"Llama model exported to {output_dir}")
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Failed to export model: {e.stderr}")
# Select GPU when available; both timm inference and ONNX export work on either device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# num_classes=1000: standard ImageNet-1k classification head.
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
# Resolve the model's expected input size from its pretrained data config
# (e.g., (3, 224, 224)) so non-224 models export with the right resolution.
data_config = timm.data.resolve_model_data_config(model)
input_shape = (args.batch_size,) + data_config["input_size"]

# Default the output path to "<timm_model_name>.onnx" when none is given.
save_path = args.onnx_save_path or f"{args.timm_model_name}.onnx"
weights_dtype = "fp16" if args.fp16 else "fp32"
export_to_onnx(
    model,
    input_shape,
    save_path,
    device,
    weights_dtype=weights_dtype,
)
print(f"{args.timm_model_name} model exported to {save_path}")
27 changes: 19 additions & 8 deletions examples/torch_onnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf

- Loads a pretrained timm torch model (default: ViT-Base).
- Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt.
- For models with Conv2d layers (e.g., SwinTransformer), automatically overrides Conv2d quantization to FP8 (for MXFP8/NVFP4 modes) or INT8 (for INT4_AWQ mode) for TensorRT compatibility.
- Exports the quantized model to ONNX.
- Postprocesses the ONNX model to be compatible with TensorRT.
- Saves the final ONNX model.
Expand All @@ -63,11 +64,21 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf

```bash
python torch_quant_to_onnx.py \
--timm_model_name=vit_base_patch16_224 \
--timm_model_name=<timm model name> \
--quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
--onnx_save_path=<path to save the exported ONNX model>
```

### Conv2d Quantization Override

TensorRT only supports FP8 and INT8 for convolution operations. When quantizing models with Conv2d layers (like SwinTransformer), the script automatically applies the following overrides:

| Quantize Mode | Conv2d Override | Reason |
| :---: | :---: | :--- |
| FP8, INT8 | None (already compatible) | Native TRT support |
| MXFP8, NVFP4 | Conv2d -> FP8 | TRT Conv limitation |
| INT4_AWQ | Conv2d -> INT8 | TRT Conv limitation |

### Evaluation

If the input model is of type image classification, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via Hugging Face access token. See <https://huggingface.co/docs/hub/en/security-tokens> for details.
Expand All @@ -79,7 +90,7 @@ python ../onnx_ptq/evaluate.py \
--onnx_path=<path to the exported ONNX model> \
--imagenet_path=<HF dataset card or local path to the ImageNet dataset> \
--engine_precision=stronglyTyped \
--model_name=vit_base_patch16_224
--model_name=<timm model name>
```

## LLM Quantization and Export with TensorRT-Edge-LLM
Expand Down Expand Up @@ -289,13 +300,13 @@ python torch_quant_to_onnx.py \
--onnx_save_path=vit_base_patch16_224.auto_quant.onnx
```

### Results (ViT-Base)
## ONNX Export Supported Vision Models

| | Top-1 accuracy (torch) | Top-5 accuracy (torch) |
| :--- | :---: | :---: |
| Torch autocast (FP16) | 85.11% | 97.53% |
| NVFP4 Quantized | 84.558% | 97.36% |
| Auto Quantized (FP8 + NVFP4, 4.78 effective bits) | 84.726% | 97.434% |
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | |
Comment on lines +305 to +309
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Support matrix overstates swinv2_tiny Auto support

Line 309 marks Auto as ✅ for swinv2_tiny_window8_256, but tests explicitly skip that combo (tests/examples/torch_onnx/test_torch_quant_to_onnx.py Line 35). Please mark it unsupported (or add a footnote with the current limitation).

📝 Suggested docs correction
-| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) |||||||
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) |||||||
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) |||||| |
| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) |||||||
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) |||||||
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) |||||| |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/README.md` around lines 305 - 309, Update the support
matrix row for the model identifier "swinv2_tiny_window8_256" in README.md to
remove the Auto ✅ (change to ❌) or add a footnote explaining it's currently
unsupported; reference the test that skips this combo
(test_torch_quant_to_onnx.py) as the reason for the change so readers know the
limitation is intentional.


## Resources

Expand Down
90 changes: 86 additions & 4 deletions examples/torch_onnx/torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
# limitations under the License.

import argparse
import copy
import json
import re
import sys
import warnings
from pathlib import Path

# Add onnx_ptq to path for shared modules
Expand Down Expand Up @@ -44,20 +47,69 @@

mp.set_start_method("spawn", force=True) # Needed for data loader with multiple workers

QUANT_CONFIG_DICT = {
# Maps the CLI --quantize_mode value to its ModelOpt base quantization config.
# Configs are deep-copied before any per-mode Conv2d overrides are applied,
# so the mtq defaults themselves are never mutated.
QUANT_CONFIG_DICT: dict[str, dict] = {
    "fp8": mtq.FP8_DEFAULT_CFG,
    "int8": mtq.INT8_DEFAULT_CFG,
    "mxfp8": mtq.MXFP8_DEFAULT_CFG,
    "nvfp4": mtq.NVFP4_DEFAULT_CFG,
    "int4_awq": mtq.INT4_AWQ_CFG,
}

# Override entries appended to a quant config so nn.Conv2d layers are quantized
# with FP8 instead of the mode's default format (TensorRT only supports FP8/INT8
# for convolutions). num_bits=(4, 3) selects the E4M3 FP8 format; axis=None
# means per-tensor scaling for both weights and activations.
_FP8_CONV_OVERRIDE: list = [
    {
        "parent_class": "nn.Conv2d",
        "quantizer_name": "*weight_quantizer",
        "cfg": {"num_bits": (4, 3), "axis": None},
    },
    {
        "parent_class": "nn.Conv2d",
        "quantizer_name": "*input_quantizer",
        "cfg": {"num_bits": (4, 3), "axis": None},
    },
]

# Override entries appended to a quant config so nn.Conv2d layers are quantized
# with INT8 (used for int4_awq mode, since TensorRT only supports FP8/INT8
# convolutions). Weights use per-channel scaling along the output-channel axis
# (axis=0); activations use per-tensor scaling (axis=None).
_INT8_CONV_OVERRIDE: list = [
    {
        "parent_class": "nn.Conv2d",
        "quantizer_name": "*weight_quantizer",
        "cfg": {"num_bits": 8, "axis": 0},
    },
    {
        "parent_class": "nn.Conv2d",
        "quantizer_name": "*input_quantizer",
        "cfg": {"num_bits": 8, "axis": None},
    },
]


def get_quant_config(quantize_mode):
    """Return the quantization config for *quantize_mode*, patched for TRT Conv support.

    TensorRT only supports FP8 and INT8 for Conv layers, so Conv2d quantizers
    are overridden as follows:
    - "mxfp8" / "nvfp4": Conv2d is quantized with FP8 instead.
    - "int4_awq": Conv2d is quantized with INT8 instead.

    The base config is deep-copied so the shared defaults stay untouched.
    """
    cfg = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode])

    # Pick the Conv2d override (and its label for the warning), if any applies.
    if quantize_mode in ("mxfp8", "nvfp4"):
        conv_override, conv_fmt = _FP8_CONV_OVERRIDE, "FP8"
    elif quantize_mode == "int4_awq":
        conv_override, conv_fmt = _INT8_CONV_OVERRIDE, "INT8"
    else:
        conv_override = None

    if conv_override is not None:
        warnings.warn(
            f"TensorRT only supports FP8/INT8 for Conv layers. "
            f"Overriding Conv2d quantization to {conv_fmt} for '{quantize_mode}' mode."
        )
        cfg["quant_cfg"].extend(conv_override)
    return cfg


# Compiled once at module import so filter_func avoids re-running re.compile on
# every call. A layer is excluded when any of these substrings appears anywhere
# in its qualified module name (embedding/in-out conv/normalization layers, plus
# cpb_mlp and downsample — presumably SwinV2 position-bias MLPs and patch-merging
# layers; confirm against the supported-model list).
_QUANT_EXCLUDE_PATTERN = re.compile(
    r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
    r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
)


def filter_func(name):
    """Return True if the layer named *name* should be excluded from quantization.

    Args:
        name: Fully-qualified module name (e.g., "model.patch_embed.proj").

    Returns:
        True when the name matches the exclusion pattern, False otherwise.
    """
    return _QUANT_EXCLUDE_PATTERN.match(name) is not None

Expand Down Expand Up @@ -121,6 +173,18 @@ def loss_func(output, batch):
return F.cross_entropy(output, batch["label"])


def _disable_inplace_relu(model):
"""Replace inplace ReLU with non-inplace ReLU throughout the model.

This is needed for auto_quantize which uses backward hooks for gradient-based
sensitivity scoring. Inplace ReLU on views created by custom Functions causes
PyTorch autograd errors.
"""
for module in model.modules():
if isinstance(module, torch.nn.ReLU) and module.inplace:
module.inplace = False


def auto_quantize_model(
model,
data_loader,
Expand All @@ -142,6 +206,7 @@ def auto_quantize_model(
Returns:
Tuple of (quantized_model, search_state_dict)
"""
_disable_inplace_relu(model)
constraints = {"effective_bits": effective_bits}

# Convert string format names to actual config objects
Expand Down Expand Up @@ -255,12 +320,29 @@ def main():
default=128,
help="Number of scoring steps for auto quantization. Default is 128.",
)
parser.add_argument(
"--no_pretrained",
action="store_true",
help="Don't load pretrained weights (useful for testing with random weights).",
)
parser.add_argument(
"--model_kwargs",
type=str,
default=None,
help="JSON string of extra model kwargs (e.g., '{\"depth\": 1}').",
)

args = parser.parse_args()

# Create model and move to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device)
model_kwargs = json.loads(args.model_kwargs) if args.model_kwargs else {}
model = timm.create_model(
args.timm_model_name,
pretrained=not args.no_pretrained,
num_classes=1000,
**model_kwargs,
).to(device)

# Get input shape from model config
input_size = get_model_input_shape(model)
Expand Down Expand Up @@ -297,7 +379,7 @@ def main():
)
else:
# Standard quantization - only load calibration data if needed
config = QUANT_CONFIG_DICT[args.quantize_mode]
config = get_quant_config(args.quantize_mode)
if args.quantize_mode == "mxfp8":
data_loader = None
else:
Expand Down
3 changes: 3 additions & 0 deletions modelopt/onnx/quantization/qdq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,9 +1162,12 @@ def cast_initializer_to_dtype(
node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
):
"""Casts the initializer to the given dtype."""
input_id = None
for id, input_name in enumerate(node.input):
if input_name in initializer_map:
input_id = id
if input_id is None:
return
input_name = node.input[input_id]
input = numpy_helper.to_array(initializer_map[input_name])
input = input.astype(np_dtype_map[dtype])
Expand Down
3 changes: 2 additions & 1 deletion modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,8 @@ def get_onnx_bytes_and_metadata(
op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add", "Sqrt"])
op_list = ["Concat", "Add", "Sqrt", "LayerNormalization", "Clip", "Mul", "Exp"]
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, op_list)
else:
onnx_opt_graph = convert_to_f16(
onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
Expand Down
5 changes: 4 additions & 1 deletion tests/_test_utils/torch/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def forward(self, x):
def _create_timm_fn(name):
    """Return a factory that builds the timm model *name* with a matching dummy input."""

    def get_model_and_input(on_gpu: bool = False):
        model = timm.create_model(name)
        # Ask timm for the model's own input size (e.g., (3, 224, 224)) so
        # models trained at other resolutions get a correctly-shaped input.
        dims = timm.data.resolve_model_data_config(model)["input_size"]
        dummy = torch.randn(1, *dims)
        return process_model_and_inputs(model, (dummy,), {}, on_gpu)

    return get_model_and_input

Expand Down Expand Up @@ -114,6 +116,7 @@ def get_model_and_input(on_gpu: bool = False):
# "vovnet39a",
# "dm_nfnet_f0",
"efficientnet_b0",
"swin_tiny_patch4_window7_224",
],
_create_timm_fn,
),
Expand Down
Loading
Loading