From 8feee257a048f4bf0f07190ef6de7086e449d442 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 12 Dec 2025 10:23:33 -0500 Subject: [PATCH 1/4] Allow inference args to be passed in for most cases --- sdks/python/apache_beam/ml/inference/base.py | 13 +++++-------- .../apache_beam/ml/inference/pytorch_inference.py | 6 ------ .../apache_beam/ml/inference/sklearn_inference.py | 3 ++- .../ml/inference/tensorflow_inference.py | 6 ------ .../apache_beam/ml/inference/tensorrt_inference.py | 10 ++++++++++ .../apache_beam/ml/inference/vertex_ai_inference.py | 3 --- 6 files changed, 17 insertions(+), 24 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 2e1c4963f11d..d79565ee24da 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -213,15 +213,12 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]: return {} def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - """Validates inference_args passed in the inference call. - - Because most frameworks do not need extra arguments in their predict() call, - the default behavior is to error out if inference_args are present. """ - if inference_args: - raise ValueError( - 'inference_args were provided, but should be None because this ' - 'framework does not expect extra arguments on inferences.') + Allows model handlers to provide some validation to make sure passed in + inference args are valid. Some ModelHandlers throw here to disallow + inference args altogether. 
+ """ + pass def update_model_path(self, model_path: Optional[str] = None): """ diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index f73eeff808ce..affbcd977f5c 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -342,9 +342,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_PyTorch' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs @@ -590,9 +587,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_PyTorch' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 1e5962ba64cb..84947bec3dfb 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -73,9 +73,10 @@ def _default_numpy_inference_fn( model: BaseEstimator, batch: Sequence[numpy.ndarray], inference_args: Optional[dict[str, Any]] = None) -> Any: + inference_args = {} if not inference_args else inference_args # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) - return model.predict(vectorized_batch) + return model.predict(vectorized_batch, **inference_args) class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray, diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py index d13ea53cf1bc..5ce293a06ac0 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py @@ -219,9 +219,6 @@ def get_metrics_namespace(self) -> str: 
""" return 'BeamML_TF_Numpy' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs @@ -360,9 +357,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_TF_Tensor' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py index 1b11bd9f39e2..b575dfa849da 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py @@ -341,3 +341,13 @@ def share_model_across_processes(self) -> bool: def model_copies(self) -> int: return self._model_copies + + def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): + """ + Currently, this model handler does not support inference args. Given that, + we will throw if any are passed in. 
+ """ + if inference_args: + raise ValueError( + 'inference_args were provided, but should be None because this ' + 'framework does not expect extra arguments on inferences.') diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py index 471f2379cfb1..9858b59039c7 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py @@ -207,8 +207,5 @@ def request( return utils._convert_to_result( batch, prediction.predictions, prediction.deployed_model_id) - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self) -> Mapping[str, Any]: return self._batching_kwargs From 1a0ce98417d459e20a1feb5bc95c8b187ecca3cd Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 12 Dec 2025 10:27:37 -0500 Subject: [PATCH 2/4] CHANGES --- CHANGES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES.md b/CHANGES.md index 09e249630447..cdb56a28c00e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -73,6 +73,7 @@ ## New Features / Improvements * Support configuring Firestore database on ReadFn transforms (Java) ([#36904](https://github.com/apache/beam/issues/36904)). +* (Python) Inference args are now accepted by default in model handlers; handlers that do not support them, such as the TensorRT handler, explicitly reject them ([#37093](https://github.com/apache/beam/issues/37093)).
## Breaking Changes From d71ddb1e7a5dda7a5c1de68a81be31a476f31aa1 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 12 Dec 2025 11:02:16 -0500 Subject: [PATCH 3/4] tests --- sdks/python/apache_beam/ml/inference/base_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 66e85ce163e7..6f2d18b98736 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -292,6 +292,12 @@ def run_inference(self, batch, unused_model, inference_args=None): raise ValueError( 'run_inference should not be called because error should already be ' 'thrown from the validate_inference_args check.') + + def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): + if inference_args: + raise ValueError( + 'inference_args were provided, but should be None because this ' + 'framework does not expect extra arguments on inferences.') class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler): From a55bb69a41b19255ce75edcc295f6deacacfa34a Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 12 Dec 2025 12:27:28 -0500 Subject: [PATCH 4/4] yapf --- sdks/python/apache_beam/ml/inference/base_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 6f2d18b98736..574e71de89ce 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -292,7 +292,7 @@ def run_inference(self, batch, unused_model, inference_args=None): raise ValueError( 'run_inference should not be called because error should already be ' 'thrown from the validate_inference_args check.') - + def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): if inference_args: raise ValueError(