diff --git a/CHANGES.md b/CHANGES.md index 09e249630447..cdb56a28c00e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -73,6 +73,7 @@ ## New Features / Improvements * Support configuring Firestore database on ReadFn transforms (Java) ([#36904](https://github.com/apache/beam/issues/36904)). +* (Python) Inference args are now allowed in most model handlers, except where they are explicitly/intentionally disallowed ([#37093](https://github.com/apache/beam/issues/37093)). ## Breaking Changes diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 2e1c4963f11d..d79565ee24da 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -213,15 +213,12 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]: return {} def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - """Validates inference_args passed in the inference call. - - Because most frameworks do not need extra arguments in their predict() call, - the default behavior is to error out if inference_args are present. """ - if inference_args: - raise ValueError( - 'inference_args were provided, but should be None because this ' - 'framework does not expect extra arguments on inferences.') + Allows model handlers to provide some validation to make sure passed in + inference args are valid. Some ModelHandlers throw here to disallow + inference args altogether. + """ + pass def update_model_path(self, model_path: Optional[str] = None): """ diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 66e85ce163e7..574e71de89ce 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -293,6 +293,12 @@ def run_inference(self, batch, unused_model, inference_args=None): 'run_inference should not be called because error should already be ' 'thrown from the validate_inference_args check.') + def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): + if inference_args: + raise ValueError( + 'inference_args were provided, but should be None because this ' + 'framework does not expect extra arguments on inferences.') + class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler): def run_inference(self, batch, unused_model, inference_args=None): diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index f73eeff808ce..affbcd977f5c 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -342,9 +342,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_PyTorch' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs @@ -590,9 +587,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_PyTorch' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 1e5962ba64cb..84947bec3dfb 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -73,9 +73,10 @@ def _default_numpy_inference_fn( model: BaseEstimator, batch: Sequence[numpy.ndarray], inference_args: Optional[dict[str, Any]] = None) -> Any: + inference_args = {} if not inference_args else inference_args # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) - return model.predict(vectorized_batch) + return model.predict(vectorized_batch, **inference_args) class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray, diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py index d13ea53cf1bc..5ce293a06ac0 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py @@ -219,9 +219,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_TF_Numpy' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs @@ -360,9 +357,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_TF_Tensor' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py index 1b11bd9f39e2..b575dfa849da 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py @@ -341,3 +341,13 @@ def share_model_across_processes(self) -> bool: def model_copies(self) -> int: return self._model_copies + + def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): + """ + Currently, this model handler does not support inference args. Given that, + we will throw if any are passed in. + """ + if inference_args: + raise ValueError( + 'inference_args were provided, but should be None because this ' + 'framework does not expect extra arguments on inferences.') diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py index 471f2379cfb1..9858b59039c7 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py @@ -207,8 +207,5 @@ def request( return utils._convert_to_result( batch, prediction.predictions, prediction.deployed_model_id) - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self) -> Mapping[str, Any]: return self._batching_kwargs