diff --git a/examples/onnx_ptq/README.md b/examples/onnx_ptq/README.md index f5cbde94fa..80faa42bd6 100644 --- a/examples/onnx_ptq/README.md +++ b/examples/onnx_ptq/README.md @@ -56,7 +56,7 @@ Most of the examples in this doc use `vit_base_patch16_224.onnx` as the input mo ```bash python download_example_onnx.py \ - --vit \ + --timm_model_name=vit_base_patch16_224 \ --onnx_save_path=vit_base_patch16_224.onnx \ --fp16 # ``` diff --git a/examples/onnx_ptq/download_example_onnx.py b/examples/onnx_ptq/download_example_onnx.py index b28eccda88..e78ff2ecbf 100644 --- a/examples/onnx_ptq/download_example_onnx.py +++ b/examples/onnx_ptq/download_example_onnx.py @@ -15,7 +15,6 @@ import argparse import os -import subprocess import timm import torch @@ -46,14 +45,10 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp parser = argparse.ArgumentParser(description="Download and export example models to ONNX.") parser.add_argument( - "--vit", - action="store_true", - help="Export timm/vit_base_patch16_224 model to ONNX.", - ) - parser.add_argument( - "--llama", - action="store_true", - help="Export meta-llama/Llama-3.1-8B-Instruct to ONNX with KV cache.", + "--timm_model_name", + type=str, + required=True, + help="Export any timm model to ONNX (e.g., vit_base_patch16_224, swin_tiny_patch4_window7_224).", ) parser.add_argument( "--onnx_save_path", type=str, required=False, help="Path to save the final ONNX model." 
@@ -62,7 +57,7 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp "--batch_size", type=int, default=1, - help="Batch size for the exported ViT model.", + help="Batch size for the exported model.", ) parser.add_argument( "--fp16", @@ -71,54 +66,18 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp ) args = parser.parse_args() - if args.vit: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=1000).to( - device - ) - data_config = timm.data.resolve_model_data_config(model) - input_shape = (args.batch_size,) + data_config["input_size"] - - vit_save_path = args.onnx_save_path or "vit_base_patch16_224.onnx" - weights_dtype = "fp16" if args.fp16 else "fp32" - export_to_onnx( - model, - input_shape, - vit_save_path, - device, - weights_dtype=weights_dtype, - ) - print(f"ViT model exported to {vit_save_path}") - - if args.llama: - model_name = "meta-llama/Llama-3.1-8B-Instruct" - if not args.onnx_save_path: - args.onnx_save_path = "Llama-3.1-8B-Instruct/model.onnx" - - output_dir = os.path.dirname(args.onnx_save_path) - if not output_dir: # Handle cases where only filename is given (save in current dir) - output_dir = "." 
- os.makedirs(output_dir, exist_ok=True) - - command = [ - "python", - "-m", - "optimum.commands.optimum_cli", - "export", - "onnx", - "--model", - model_name, - "--task", - "causal-lm-with-past", - "--device", - "cuda", - "--fp16" if args.fp16 else "", - output_dir, - ] - - try: - print(f"Running optimum-cli export to {output_dir}...") - subprocess.run(command, check=True, capture_output=True, text=True, encoding="utf-8") - print(f"Llama model exported to {output_dir}") - except subprocess.CalledProcessError as e: - raise RuntimeError(f"Failed to export model: {e.stderr}") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device) + data_config = timm.data.resolve_model_data_config(model) + input_shape = (args.batch_size,) + data_config["input_size"] + + save_path = args.onnx_save_path or f"{args.timm_model_name}.onnx" + weights_dtype = "fp16" if args.fp16 else "fp32" + export_to_onnx( + model, + input_shape, + save_path, + device, + weights_dtype=weights_dtype, + ) + print(f"{args.timm_model_name} model exported to {save_path}") diff --git a/examples/torch_onnx/README.md b/examples/torch_onnx/README.md index a1c01e0cbb..d540770116 100644 --- a/examples/torch_onnx/README.md +++ b/examples/torch_onnx/README.md @@ -53,6 +53,7 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf - Loads a pretrained timm torch model (default: ViT-Base). - Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt. +- For models with Conv2d layers (e.g., SwinTransformer), automatically overrides Conv2d quantization to FP8 (for MXFP8/NVFP4 modes) or INT8 (for INT4_AWQ mode) for TensorRT compatibility. - Exports the quantized model to ONNX. - Postprocesses the ONNX model to be compatible with TensorRT. - Saves the final ONNX model. 
@@ -63,11 +64,21 @@ The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingf ```bash python torch_quant_to_onnx.py \ - --timm_model_name=vit_base_patch16_224 \ + --timm_model_name=${TIMM_MODEL_NAME} \ --quantize_mode=${QUANTIZE_MODE} \ --onnx_save_path=${ONNX_SAVE_PATH} ``` +### Conv2d Quantization Override + +TensorRT only supports FP8 and INT8 for convolution operations. When quantizing models with Conv2d layers (like SwinTransformer), the script automatically applies the following overrides: + +| Quantize Mode | Conv2d Override | Reason | +| :---: | :---: | :--- | +| FP8, INT8 | None (already compatible) | Native TRT support | +| MXFP8, NVFP4 | Conv2d -> FP8 | TRT Conv limitation | +| INT4_AWQ | Conv2d -> INT8 | TRT Conv limitation | + ### Evaluation If the input model is of type image classification, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via a Hugging Face access token. See https://huggingface.co/docs/hub/security-tokens for details. 
@@ -79,7 +90,7 @@ python ../onnx_ptq/evaluate.py \ --onnx_path= \ --imagenet_path= \ --engine_precision=stronglyTyped \ - --model_name=vit_base_patch16_224 + --model_name= ``` ## LLM Quantization and Export with TensorRT-Edge-LLM @@ -289,13 +300,13 @@ python torch_quant_to_onnx.py \ --onnx_save_path=vit_base_patch16_224.auto_quant.onnx ``` -### Results (ViT-Base) +## ONNX Export Supported Vision Models -| | Top-1 accuracy (torch) | Top-5 accuracy (torch) | -| :--- | :---: | :---: | -| Torch autocast (FP16) | 85.11% | 97.53% | -| NVFP4 Quantized | 84.558% | 97.36% | -| Auto Quantized (FP8 + NVFP4, 4.78 effective bits) | 84.726% | 97.434% | +| Model | FP8 | INT8 | MXFP8 | NVFP4 | INT4_AWQ | Auto | +| :---: | :---: | :---: | :---: | :---: | :---: | :---: | +| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ## Resources diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index 05f52aa4fe..7f74e617e8 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -14,8 +14,11 @@ # limitations under the License. 
import argparse +import copy +import json import re import sys +import warnings from pathlib import Path # Add onnx_ptq to path for shared modules @@ -44,7 +47,7 @@ mp.set_start_method("spawn", force=True) # Needed for data loader with multiple workers -QUANT_CONFIG_DICT = { +QUANT_CONFIG_DICT: dict[str, dict] = { "fp8": mtq.FP8_DEFAULT_CFG, "int8": mtq.INT8_DEFAULT_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, @@ -52,12 +55,61 @@ "int4_awq": mtq.INT4_AWQ_CFG, } +_FP8_CONV_OVERRIDE: list = [ + { + "parent_class": "nn.Conv2d", + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "parent_class": "nn.Conv2d", + "quantizer_name": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, +] + +_INT8_CONV_OVERRIDE: list = [ + { + "parent_class": "nn.Conv2d", + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + { + "parent_class": "nn.Conv2d", + "quantizer_name": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, +] + + +def get_quant_config(quantize_mode): + """Get quantization config, overriding Conv2d for TRT compatibility. + + TensorRT only supports FP8 and INT8 for Conv layers. + - For MXFP8, NVFP4: override Conv2d to FP8 + - For INT4_AWQ: override Conv2d to INT8 + """ + config: dict = copy.deepcopy(QUANT_CONFIG_DICT[quantize_mode]) + if quantize_mode in ("mxfp8", "nvfp4"): + warnings.warn( + f"TensorRT only supports FP8/INT8 for Conv layers. " + f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode." + ) + config["quant_cfg"].extend(_FP8_CONV_OVERRIDE) + elif quantize_mode == "int4_awq": + warnings.warn( + "TensorRT only supports FP8/INT8 for Conv layers. " + "Overriding Conv2d quantization to INT8 for 'int4_awq' mode." 
+ ) + config["quant_cfg"].extend(_INT8_CONV_OVERRIDE) + return config + def filter_func(name): """Filter function to exclude certain layers from quantization.""" pattern = re.compile( r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|" - r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed).*" + r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*" ) return pattern.match(name) is not None @@ -121,6 +173,18 @@ def loss_func(output, batch): return F.cross_entropy(output, batch["label"]) +def _disable_inplace_relu(model): + """Replace inplace ReLU with non-inplace ReLU throughout the model. + + This is needed for auto_quantize which uses backward hooks for gradient-based + sensitivity scoring. Inplace ReLU on views created by custom Functions causes + PyTorch autograd errors. + """ + for module in model.modules(): + if isinstance(module, torch.nn.ReLU) and module.inplace: + module.inplace = False + + def auto_quantize_model( model, data_loader, @@ -142,6 +206,7 @@ def auto_quantize_model( Returns: Tuple of (quantized_model, search_state_dict) """ + _disable_inplace_relu(model) constraints = {"effective_bits": effective_bits} # Convert string format names to actual config objects @@ -255,12 +320,29 @@ def main(): default=128, help="Number of scoring steps for auto quantization. 
Default is 128.", ) + parser.add_argument( + "--no_pretrained", + action="store_true", + help="Don't load pretrained weights (useful for testing with random weights).", + ) + parser.add_argument( + "--model_kwargs", + type=str, + default=None, + help="JSON string of extra model kwargs (e.g., '{\"depth\": 1}').", + ) args = parser.parse_args() # Create model and move to appropriate device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = timm.create_model(args.timm_model_name, pretrained=True, num_classes=1000).to(device) + model_kwargs = json.loads(args.model_kwargs) if args.model_kwargs else {} + model = timm.create_model( + args.timm_model_name, + pretrained=not args.no_pretrained, + num_classes=1000, + **model_kwargs, + ).to(device) # Get input shape from model config input_size = get_model_input_shape(model) @@ -297,7 +379,7 @@ def main(): ) else: # Standard quantization - only load calibration data if needed - config = QUANT_CONFIG_DICT[args.quantize_mode] + config = get_quant_config(args.quantize_mode) if args.quantize_mode == "mxfp8": data_loader = None else: diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index a7e4208d00..0cb1a45f68 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -1162,9 +1162,12 @@ def cast_initializer_to_dtype( node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto] ): """Casts the initializer to the given dtype.""" + input_id = None for id, input_name in enumerate(node.input): if input_name in initializer_map: input_id = id + if input_id is None: + return input_name = node.input[input_id] input = numpy_helper.to_array(initializer_map[input_name]) input = input.astype(np_dtype_map[dtype]) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 0bbec135c9..8cb741dbc7 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ 
b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -608,7 +608,8 @@ def get_onnx_bytes_and_metadata( op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"], ) # Change FP32 cast nodes feeding into Concat/Add to FP16 - onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add", "Sqrt"]) + op_list = ["Concat", "Add", "Sqrt", "LayerNormalization", "Clip", "Mul", "Exp"] + onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, op_list) else: onnx_opt_graph = convert_to_f16( onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False diff --git a/tests/_test_utils/torch/vision_models.py b/tests/_test_utils/torch/vision_models.py index 639dc16695..5fed1d20c1 100644 --- a/tests/_test_utils/torch/vision_models.py +++ b/tests/_test_utils/torch/vision_models.py @@ -69,7 +69,9 @@ def forward(self, x): def _create_timm_fn(name): def get_model_and_input(on_gpu: bool = False): model = timm.create_model(name) - return process_model_and_inputs(model, (torch.randn(1, 3, 224, 224),), {}, on_gpu) + data_config = timm.data.resolve_model_data_config(model) + input_size = data_config["input_size"] # e.g., (3, 224, 224) + return process_model_and_inputs(model, (torch.randn(1, *input_size),), {}, on_gpu) return get_model_and_input @@ -114,6 +116,7 @@ def get_model_and_input(on_gpu: bool = False): # "vovnet39a", # "dm_nfnet_f0", "efficientnet_b0", + "swin_tiny_patch4_window7_224", ], _create_timm_fn, ), diff --git a/tests/examples/torch_onnx/test_torch_quant_to_onnx.py b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py index 2e7bf58c31..7c2692c1d9 100644 --- a/tests/examples/torch_onnx/test_torch_quant_to_onnx.py +++ b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py @@ -14,28 +14,67 @@ # limitations under the License. 
+import os +import subprocess + import pytest from _test_utils.examples.run_command import extend_cmd_parts, run_example_command +# TODO: Add int4_awq once the INT4 exporter supports non-MatMul/Gemm consumer patterns +# (e.g., DQ -> Reshape -> Slice in small ViT / SwinTransformer ONNX graphs). +_QUANT_MODES = ["fp8", "int8", "mxfp8", "nvfp4", "auto"] + +_MODELS = { + "vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'), + "swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'), + "swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'), +} + +# Builder optimization level: 4 for low-bit modes, 3 otherwise +_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"} + + +def _verify_trt_engine_build(onnx_save_path, quantize_mode): + """Verify the exported ONNX model can be compiled into a TensorRT engine.""" + example_dir = os.path.join( + os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx" + ) + onnx_path = os.path.join(example_dir, onnx_save_path) + assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}" + + opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3" + cmd = [ + "trtexec", + f"--onnx={onnx_path}", + "--stronglyTyped", + f"--builderOptimizationLevel={opt_level}", + ] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) + assert result.returncode == 0, ( + f"TensorRT engine build failed for {onnx_save_path} " + f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}" + ) + + +@pytest.mark.parametrize("quantize_mode", _QUANT_MODES) +@pytest.mark.parametrize("model_key", list(_MODELS)) +def test_torch_onnx(model_key, quantize_mode): + timm_model_name, model_kwargs = _MODELS[model_key] + onnx_save_path = f"{model_key}.{quantize_mode}.onnx" -# TODO: Add accuracy evaluation after we upgrade TRT version to 10.12 -@pytest.mark.parametrize( - ("quantize_mode", "onnx_save_path", "calib_size", "num_score_steps"), - [ - ("fp8", "vit_base_patch16_224.fp8.onnx", "1", "1"), - ("int8", 
"vit_base_patch16_224.int8.onnx", "1", "1"), - ("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1", "1"), - ("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1", "1"), - ("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1", "1"), - ("auto", "vit_base_patch16_224.auto.onnx", "1", "1"), - ], -) -def test_torch_onnx(quantize_mode, onnx_save_path, calib_size, num_score_steps): + # Step 1: Quantize and export to ONNX cmd_parts = extend_cmd_parts( ["python", "torch_quant_to_onnx.py"], + timm_model_name=timm_model_name, + model_kwargs=model_kwargs, quantize_mode=quantize_mode, onnx_save_path=onnx_save_path, - calibration_data_size=calib_size, - num_score_steps=num_score_steps, + calibration_data_size="1", + num_score_steps="1", ) + cmd_parts.append("--no_pretrained") run_example_command(cmd_parts, "torch_onnx") + + # Step 2: Verify the exported ONNX model builds a TensorRT engine + _verify_trt_engine_build(onnx_save_path, quantize_mode)