diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 11928e79ffc62..ff01d4ac835ba 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -208,6 +208,10 @@ def _hash_pandas_object( values, encoding=encoding, hash_key=hash_key, categorize=categorize ) + def _cast_pointwise_result(self, values: ArrayLike) -> ArrayLike: + values = np.asarray(values, dtype=object) + return lib.maybe_convert_objects(values, convert_non_numeric=True) + # Signature of "argmin" incompatible with supertype "ExtensionArray" def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override] # override base class by adding axis keyword diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index edf1d7ddcaa76..1a52df697fff4 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -442,7 +442,7 @@ def _cast_pointwise_result(self, values) -> ArrayLike: # e.g. test_by_column_values_with_same_starting_value with nested # values, one entry of which is an ArrowStringArray # or test_agg_lambda_complex128_dtype_conversion for complex values - return super()._cast_pointwise_result(values) + return values if pa.types.is_null(arr.type): if lib.infer_dtype(values) == "decimal": @@ -498,7 +498,7 @@ def _cast_pointwise_result(self, values) -> ArrayLike: if self.dtype.na_value is np.nan: # ArrowEA has different semantics, so we return numpy-based # result instead - return super()._cast_pointwise_result(values) + return values return ArrowExtensionArray(arr) return self._from_pyarrow_array(arr) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 225cc888d50db..0d2b4c64d6e3d 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -19,6 +19,7 @@ cast, overload, ) +import warnings import numpy as np @@ -33,6 +34,7 @@ cache_readonly, set_module, ) +from pandas.util._exceptions import find_stack_level from pandas.util._validators import ( validate_bool_kwarg, validate_insert_loc, @@ -86,6 +88,7 @@ AstypeArg, AxisInt, Dtype, + DtypeObj, FillnaOptions, InterpolateOptions, NumpySorter, @@ -353,6 +356,38 @@ def _from_sequence_of_strings( """ raise AbstractMethodError(cls) + @classmethod + def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self: + """ + Strict analogue to _from_sequence, allowing only sequences of scalars + that should be specifically inferred to the given dtype. + + Parameters + ---------- + scalars : sequence + dtype : ExtensionDtype + + Raises + ------ + TypeError or ValueError + + Notes + ----- + This is called in a try/except block when casting the result of a + pointwise operation. + """ + try: + return cls._from_sequence(scalars, dtype=dtype, copy=False) + except (ValueError, TypeError): + raise + except Exception: + warnings.warn( + "_from_scalars should only raise ValueError or TypeError. " + "Consider overriding _from_scalars where appropriate.", + stacklevel=find_stack_level(), + ) + raise + @classmethod def _from_factorized(cls, values, original): """ @@ -383,13 +418,26 @@ def _from_factorized(cls, values, original): """ raise AbstractMethodError(cls) - def _cast_pointwise_result(self, values) -> ArrayLike: + def _cast_pointwise_result(self, values: ArrayLike) -> ArrayLike: """ + Construct an ExtensionArray after a pointwise operation. + Cast the result of a pointwise operation (e.g. Series.map) to an - array, preserve dtype_backend if possible. + array. This is not required to return an ExtensionArray of the same + type as self or of the same dtype. It can also return another + ExtensionArray of the same "family" if you implement multiple + ExtensionArrays/Dtypes that are interoperable (e.g. if you have float + array with units, this method can return an int array with units). + + If converting to your own ExtensionArray is not possible, this method + can raise an error (TypeError or ValueError) or return the input + `values` as-is. Then pandas will do the further type inference. + """ - values = np.asarray(values, dtype=object) - return lib.maybe_convert_objects(values, convert_non_numeric=True) + try: + return type(self)._from_scalars(values, dtype=self.dtype) + except (ValueError, TypeError): + return values # ------------------------------------------------------------------------ # Must be a Sequence diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index c15a196dc6727..bd317b99455d9 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -622,7 +622,7 @@ def _from_factorized(cls, values, original) -> Self: return cls(values, dtype=original.dtype) def _cast_pointwise_result(self, values): - result = super()._cast_pointwise_result(values) + result = lib.maybe_convert_objects(values, convert_non_numeric=True) if result.dtype.kind == self.dtype.kind: try: # e.g. test_groupby_agg_extension diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 2dc45bd752ea5..6ff641ecd4056 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -414,6 +414,42 @@ def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: return arr +def cast_pointwise_result( + result: ArrayLike, + original_array: ArrayLike, +) -> ArrayLike: + """ + Try casting result of a pointwise operation back to the original dtype if + appropriate. + + Parameters + ---------- + result : array-like + Result to cast. + original_array : array-like + Input array from which result was calculated. + + Returns + ------- + array-like + """ + if isinstance(original_array.dtype, ExtensionDtype): + try: + result = original_array._cast_pointwise_result(result) + except (TypeError, ValueError): + pass + + if isinstance(result.dtype, ExtensionDtype): + return result + + if not isinstance(result, np.ndarray): + result = np.asarray(result, dtype=object) + + if result.dtype != object: + return result + return lib.maybe_convert_objects(result, convert_non_numeric=True) + + @overload def ensure_dtype_can_hold_na(dtype: np.dtype) -> np.dtype: ... diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 9d6af3c7b9917..284aeb1072dff 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -84,6 +84,7 @@ from pandas.core.dtypes.cast import ( LossySetitemError, can_hold_element, + cast_pointwise_result, construct_1d_arraylike_from_scalar, construct_2d_arraylike_from_scalar, find_common_type, @@ -11200,7 +11201,7 @@ def _append_internal( if isinstance(self.index.dtype, ExtensionDtype): # GH#41626 retain e.g. CategoricalDtype if reached via # df.loc[key] = item - row_df.index = self.index.array._cast_pointwise_result(row_df.index._values) + row_df.index = cast_pointwise_result(row_df.index._values, self.index.array) # infer_objects is needed for # test_append_empty_frame_to_series_with_dateutil_tz diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index d86264cb95dc5..5907f984f0884 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -35,6 +35,7 @@ from pandas.util._decorators import cache_readonly from pandas.core.dtypes.cast import ( + cast_pointwise_result, maybe_downcast_to_dtype, ) from pandas.core.dtypes.common import ( @@ -963,7 +964,7 @@ def agg_series( np.ndarray or ExtensionArray """ result = self._aggregate_series_pure_python(obj, func) - return obj.array._cast_pointwise_result(result) + return cast_pointwise_result(result, obj.array) @final def _aggregate_series_pure_python( diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index df898ccf47bfa..e41f20e5c011c 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -89,6 +89,7 @@ from pandas.core.dtypes.cast import ( LossySetitemError, can_hold_element, + cast_pointwise_result, common_dtype_categorical_compat, find_result_type, infer_dtype_from, @@ -6531,7 +6532,7 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None): # e.g. if we are floating and new_values is all ints, then we # don't want to cast back to floating. But if we are UInt64 # and new_values is all ints, we want to try. - new_values = arr._cast_pointwise_result(new_values) + new_values = cast_pointwise_result(new_values, arr) dtype = new_values.dtype return Index(new_values, dtype=dtype, copy=False, name=self.name) diff --git a/pandas/core/series.py b/pandas/core/series.py index 7c949952801ca..b713979261b13 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -69,6 +69,7 @@ from pandas.core.dtypes.astype import astype_is_view from pandas.core.dtypes.cast import ( LossySetitemError, + cast_pointwise_result, construct_1d_arraylike_from_scalar, find_common_type, infer_dtype_from, @@ -3252,7 +3253,7 @@ def combine( new_values[:] = [func(lv, other) for lv in self._values] new_name = self.name - res_values = self.array._cast_pointwise_result(new_values) + res_values = cast_pointwise_result(new_values, self.array) return self._constructor( res_values, dtype=res_values.dtype, diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 7d055e2143112..17288916b8ffc 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -112,13 +112,12 @@ def _from_factorized(cls, values, original): return cls(values) def _cast_pointwise_result(self, values): - result = super()._cast_pointwise_result(values) try: # If this were ever made a non-test EA, special-casing could # be avoided by handling Decimal in maybe_convert_objects - res = type(self)._from_sequence(result, dtype=self.dtype) + res = type(self)._from_sequence(values, dtype=self.dtype) except (ValueError, TypeError): - return result + return values return res _HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray) diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 1878fac1b8111..07c4a3173c763 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -94,11 +94,10 @@ def _from_factorized(cls, values, original): return cls([UserDict(x) for x in values if x != ()]) def _cast_pointwise_result(self, values): - result = super()._cast_pointwise_result(values) try: - return type(self)._from_sequence(result, dtype=self.dtype) + return type(self)._from_sequence(values, dtype=self.dtype) except (ValueError, TypeError): - return result + return values def __getitem__(self, item): if isinstance(item, tuple):