diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 5b13ff080..278486c4b 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -175,7 +175,7 @@ def convert(
             onnx.ModelProto: The converted mixed precision model.
         """
         try:
-            self.model = onnx_utils.check_model(self.model)
+            onnx_utils.check_model(self.model)
         except onnx.checker.ValidationError as e:
             logger.error(f"Internal error: onnx.checker failed on input model {e}")
             raise Exception(
diff --git a/modelopt/onnx/autocast/referencerunner.py b/modelopt/onnx/autocast/referencerunner.py
index 8dc91ff08..2831f211d 100644
--- a/modelopt/onnx/autocast/referencerunner.py
+++ b/modelopt/onnx/autocast/referencerunner.py
@@ -24,11 +24,13 @@
 import copy
 import io
 import sys
+import tempfile
 from collections import OrderedDict

 import numpy as np
 import onnx

+from modelopt.onnx import utils as onnx_utils
 from modelopt.onnx.autocast.logging_config import configure_logging, logger
 from modelopt.onnx.quantization.ort_utils import _prepare_ep_list
@@ -118,13 +120,62 @@ def _load_inputs(self, inputs):

         return data_loader

+    def _get_ort_runner(self, model):
+        import onnxruntime as ort
+        from polygraphy.backend.onnx import BytesFromOnnx
+        from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
+
+        # Check if model has external data by checking:
+        # 1. If any initializer has data_location set to EXTERNAL (even if data is loaded)
+        # 2. If model size would exceed 2GB (indicating need for external data)
+        has_external_data = onnx_utils.check_model_uses_external_data(self.model)
+
+        # Also check if model would be too large (>2GB) for SerializeToString
+        # This handles cases where model was loaded with external data already loaded
+        if not has_external_data:
+            try:
+                # Try to estimate size by serializing the model
+                # If it fails or exceeds 2GB, we need file-based approach
+                model_size = len(self.model.SerializeToString())
+                if model_size > 2 * (1024**3):  # 2GB threshold
+                    has_external_data = True
+                    logger.debug(
+                        f"Model size ({model_size / (1024**3):.2f} GB) exceeds 2GB, using file-based approach"
+                    )
+            except (ValueError, AttributeError) as e:
+                # SerializeToString failed (likely >2GB limit), use file-based approach
+                if "exceeds maximum protobuf size" in str(e) or "2GB" in str(e):
+                    has_external_data = True
+                    logger.debug("Model exceeds protobuf 2GB limit, using file-based approach")
+
+        if has_external_data:
+            logger.debug("Model has external data, using file-based approach")
+            # Get the actual ONNX ModelProto from ModifyOutputs wrapper
+            modified_model = model()
+
+            # Use a persistent temp file to handle external data files properly
+            tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
+            tmp_file.close()
+            tmp_file_path = tmp_file.name
+            onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
+            logger.debug(f"Model with all outputs saved to {tmp_file_path}")
+            session = ort.InferenceSession(tmp_file_path, providers=self.providers)
+            runners = [OnnxrtRunner(lambda: session)]
+
+        else:
+            # For models without external data, use the original BytesFromOnnx approach (no tmp files)
+            logger.debug("Model has no external data, using BytesFromOnnx approach")
+            serialize_onnx = BytesFromOnnx(model)
+            build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
+            runners = [OnnxrtRunner(build_onnxrt_session)]
+
+        return runners
+
     def run(self, inputs=None):
         """Run FP32 inference with provided or random inputs."""
         import onnxruntime as ort
         from polygraphy import constants
-        from polygraphy.backend.onnx import BytesFromOnnx
         from polygraphy.backend.onnx import ModifyOutputs as ModifyOnnxOutputs
-        from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
         from polygraphy.comparator import Comparator

         logger.info("Running ONNX Runtime to obtain reference outputs (this may take a while)...")
@@ -133,9 +184,9 @@ def run(self, inputs=None):
         model_copy = copy.deepcopy(self.model)

         modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL)
-        serialize_onnx = BytesFromOnnx(modify_outputs)
-        build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
-        runners = [OnnxrtRunner(build_onnxrt_session)]
+
+        # Load the modified model and create an inference session
+        runners = self._get_ort_runner(modify_outputs)

         # Comparator is used despite the fact that we are using ONNXRuntime
         # because it provides the ability to generate random inputs using DataLoader
diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py
index 02306792a..f2b020b06 100644
--- a/modelopt/onnx/utils.py
+++ b/modelopt/onnx/utils.py
@@ -15,6 +15,7 @@

 """Utility functions related to onnx."""

+import copy
 import io
 import os
 import tempfile
@@ -552,7 +553,7 @@ def _get_unique_name(old_name):
     return onnx_model, is_modified


-def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
+def check_model(model: onnx.ModelProto) -> None:
     """Checks if the given model is valid."""
     if model.ByteSize() > (2 * (1024**3)):  # 2GB limit
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -561,10 +562,8 @@ def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
             onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
             save_onnx(model, onnx_tmp_path, save_as_external_data=True)
             onnx.checker.check_model(onnx_tmp_path)
-            return onnx.load(onnx_tmp_path)
     else:
         onnx.checker.check_model(model)
-        return model


 def find_lowest_common_ancestor(node1: Node, node2: Node) -> tuple[str | None, int, int]:
@@ -658,15 +657,16 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo
     # Set ir_version to 10, remove it once ORT supports ir_version 11
     model.ir_version = 10

-
     if save_as_external_data:
         external_data_path = os.path.basename(onnx_path) + "_data"
         if os.path.exists(external_data_path):
             logger.warning(f"Removing existing external data file: {external_data_path}")
             os.remove(external_data_path)

+        # Copy so the onnx.ModelProto object will not be modified
+        model_copy = copy.deepcopy(model)
         onnx.save_model(
-            model,
+            model_copy,
             onnx_path,
             save_as_external_data=True,
             all_tensors_to_one_file=True,
@@ -696,6 +696,21 @@ def get_opset_version(model: onnx.ModelProto) -> int:
     return ai_onnx_domain[0].version


+def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
+    """Checks if the model uses external data.
+
+    Args:
+        model: Loaded in-memory onnx ModelProto.
+
+    Returns:
+        True if any initializer tensor has data_location set to EXTERNAL.
+    """
+    return any(
+        init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
+        for init in model.graph.initializer
+    )
+
+
 def bfloat16_to_float32(bf16_array):
     """Converts a bfloat16 array (as raw data) to a float32 array."""
     uint32_array = bf16_array.astype(np.uint32) << 16
diff --git a/modelopt/torch/_deploy/utils/onnx_utils.py b/modelopt/torch/_deploy/utils/onnx_utils.py
index a377afcb6..9120eb73a 100644
--- a/modelopt/torch/_deploy/utils/onnx_utils.py
+++ b/modelopt/torch/_deploy/utils/onnx_utils.py
@@ -45,14 +45,3 @@ def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> list[str]:
         if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
     ]
     return model_tensors_ext
-
-
-def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
-    """
-    Checks if the model uses external data.
-    """
-    model_tensors = _get_initializer_tensors(model)
-    return any(
-        tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
-        for tensor in model_tensors
-    )
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index 26a5781ed..304fb8ec7 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -42,6 +42,7 @@
 )
 from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero
 from modelopt.onnx.utils import (
+    check_model_uses_external_data,
     get_input_names,
     get_input_shapes,
     get_node_names,
@@ -55,7 +56,6 @@
 from modelopt.torch.utils._pytree import TreeSpec

 from ..utils.onnx_optimizer import Optimizer
-from .onnx_utils import check_model_uses_external_data

 ModelMetadata = dict[str, Any]
 ModelType = Any