7 changes: 1 addition & 6 deletions python/pyspark/errors/error-conditions.json
@@ -1258,12 +1258,7 @@
"Cannot convert UDTF output to Arrow. Data: <data>. Schema: <schema>. Arrow Schema: <arrow_schema>."
]
},
"UDTF_ARROW_TYPE_CAST_ERROR": {
"message": [
"Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<arrow_type>'. Please check if the data types match and try again."
]
},
"UDTF_ARROW_TYPE_CONVERSION_ERROR": {
"UDTF_ARROW_TYPE_CONVERSION_ERROR": {
"message": [
"PyArrow UDTF must return an iterator of pyarrow.Table or pyarrow.RecordBatch objects."
]
81 changes: 47 additions & 34 deletions python/pyspark/sql/conversion.py
@@ -228,7 +228,7 @@ def convert(
assign_cols_by_name: bool = False,
int_to_decimal_coercion_enabled: bool = False,
ignore_unexpected_complex_type_values: bool = False,
- is_udtf: bool = False,
+ is_legacy: bool = False,
) -> "pa.RecordBatch":
"""
Convert a pandas DataFrame or list of Series/DataFrames to an Arrow RecordBatch.
@@ -255,14 +255,13 @@
Whether to enable int to decimal coercion (default False)
ignore_unexpected_complex_type_values : bool
Whether to ignore unexpected complex type values in converter (default False)
- is_udtf : bool
- Whether this conversion is for a UDTF. UDTFs use broader Arrow exception
- handling to allow more type coercions (e.g., struct field casting via
- ArrowTypeError), and convert errors to UDTF_ARROW_TYPE_CAST_ERROR.
- # TODO(SPARK-55502): Unify UDTF and regular UDF conversion paths to
- # eliminate the is_udtf flag.
- Regular UDFs only catch ArrowInvalid to preserve legacy behavior where
- e.g. string->decimal must raise an error. (default False)
+ is_legacy : bool
+ Whether to use the legacy pandas-to-Arrow conversion path. The legacy
+ path uses broader Arrow exception handling (ArrowException) to allow
+ more implicit type coercions (e.g., int->boolean, dict->struct via
+ ArrowTypeError). The non-legacy path only catches ArrowInvalid for
+ the cast fallback, so type mismatches like string->decimal raise
+ immediately. (default False)

Returns
-------
@@ -271,7 +270,7 @@
import pyarrow as pa
import pandas as pd

- from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkRuntimeError
+ from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql.pandas.types import to_arrow_type, _create_converter_from_pandas

# Handle empty schema (0 columns)
@@ -318,7 +317,7 @@ def convert_column(
assign_cols_by_name=assign_cols_by_name,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
- is_udtf=is_udtf,
+ is_legacy=is_legacy,
)
# Wrap the nested RecordBatch as a single StructArray column
return ArrowBatchTransformer.wrap_struct(nested_batch).column(0)
@@ -343,9 +342,10 @@ def convert_column(

mask = None if hasattr(series.array, "__arrow_array__") else series.isnull()

- if is_udtf:
- # UDTF path: broad ArrowException catch so that both ArrowInvalid
- # AND ArrowTypeError (e.g. dict→struct) trigger the cast fallback.
+ if is_legacy:
+ # Legacy pandas conversion path: broad ArrowException catch so
+ # that both ArrowInvalid AND ArrowTypeError (e.g. dict->struct)
+ # trigger the cast fallback.
try:
try:
return pa.Array.from_pandas(
@@ -357,18 +357,26 @@
target_type=arrow_type, safe=safecheck
)
raise
- except pa.lib.ArrowException: # convert any Arrow error to user-friendly message
- raise PySparkRuntimeError(
- errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
- messageParameters={
- "col_name": field_name,
- "col_type": str(series.dtype),
- "arrow_type": str(arrow_type),
- },
- ) from None
+ except pa.lib.ArrowException as e:
+ error_msg = (
+ "Exception thrown when converting pandas.Series (%s) "
+ "with name '%s' to Arrow Array (%s)."
+ % (series.dtype, field_name, arrow_type)
+ )
+ if isinstance(e, TypeError):
+ raise PySparkTypeError(error_msg) from e
+ if safecheck:
+ error_msg += (
+ " It can be caused by overflows or other "
+ "unsafe conversions warned by Arrow. Arrow safe "
+ "type check can be disabled by using SQL config "
+ "`spark.sql.execution.pandas."
+ "convertToArrowArraySafely`."
+ )
+ raise PySparkValueError(error_msg) from e
else:
- # UDF path: only ArrowInvalid triggers the cast fallback.
- # ArrowTypeError (e.g. string→decimal) must NOT be silently cast.
+ # Non-legacy path: only ArrowInvalid triggers the cast fallback.
+ # ArrowTypeError (e.g. string->decimal) must NOT be silently cast.
try:
try:
return pa.Array.from_pandas(
@@ -380,21 +388,26 @@
target_type=arrow_type, safe=safecheck
)
raise
- except TypeError as e: # includes pa.lib.ArrowTypeError
+ except TypeError as e:
raise PySparkTypeError(
f"Exception thrown when converting pandas.Series ({series.dtype}) "
f"with name '{field_name}' to Arrow Array ({arrow_type})."
f"Cannot convert the output value of the column "
f"'{field_name}' with type '{series.dtype}' to the "
f"specified return type of the column: '{arrow_type}'."
f" Please check if the data types match and try again."
) from e
- except ValueError as e: # includes pa.lib.ArrowInvalid
+ except ValueError as e:
error_msg = (
f"Exception thrown when converting pandas.Series ({series.dtype}) "
f"with name '{field_name}' to Arrow Array ({arrow_type})."
f"Failed to convert the value of the column "
f"'{field_name}' with type '{series.dtype}' to Arrow "
f"type '{arrow_type}'."
)
if safecheck:
error_msg += (
" It can be caused by overflows or other unsafe conversions "
"warned by Arrow. Arrow safe type check can be disabled by using "
"SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`."
" It can be caused by overflows or other unsafe "
"conversions warned by Arrow. Arrow safe type "
"check can be disabled by using SQL config "
"`spark.sql.execution.pandas."
"convertToArrowArraySafely`."
)
raise PySparkValueError(error_msg) from e

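A minimal standalone sketch of the cast-fallback pattern described in the is_legacy docstring above. The helper name to_arrow is illustrative and the PySparkTypeError/PySparkValueError wrapping is omitted, so this sketches the pattern under stated assumptions rather than reproducing the PR's exact code:

import pandas as pd
import pyarrow as pa

def to_arrow(series: pd.Series, arrow_type: pa.DataType,
             is_legacy: bool, safe: bool = True) -> pa.Array:
    # Legacy path: any ArrowException (ArrowInvalid as well as
    # ArrowTypeError) triggers the infer-then-cast fallback.
    # Non-legacy path: only ArrowInvalid does, so ArrowTypeError-style
    # mismatches (the docstring's string->decimal example) propagate
    # instead of being silently cast.
    fallback_on = pa.lib.ArrowException if is_legacy else pa.lib.ArrowInvalid
    try:
        return pa.Array.from_pandas(series, type=arrow_type, safe=safe)
    except fallback_on:
        # Build with the inferred type first, then cast to the target type.
        return pa.Array.from_pandas(series).cast(target_type=arrow_type, safe=safe)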
12 changes: 6 additions & 6 deletions python/pyspark/sql/pandas/serializers.py
@@ -513,7 +513,7 @@ def __init__(
int_to_decimal_coercion_enabled: bool = False,
prefers_large_types: bool = False,
ignore_unexpected_complex_type_values: bool = False,
- is_udtf: bool = False,
+ is_legacy: bool = False,
):
super().__init__(
timezone,
@@ -528,7 +528,7 @@
)
self._assign_cols_by_name = assign_cols_by_name
self._ignore_unexpected_complex_type_values = ignore_unexpected_complex_type_values
- self._is_udtf = is_udtf
+ self._is_legacy = is_legacy

def dump_stream(self, iterator, stream):
"""
@@ -567,7 +567,7 @@ def create_batch(
assign_cols_by_name=self._assign_cols_by_name,
int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
ignore_unexpected_complex_type_values=self._ignore_unexpected_complex_type_values,
- is_udtf=self._is_udtf,
+ is_legacy=self._is_legacy,
)

batches = self._write_stream_start(
@@ -767,9 +767,9 @@ def __init__(self, timezone, safecheck, input_type, int_to_decimal_coercion_enabled
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
# UDTF-specific: ignore unexpected complex type values in converter
ignore_unexpected_complex_type_values=True,
- # UDTF-specific: enables broader Arrow exception handling and
- # converts errors to UDTF_ARROW_TYPE_CAST_ERROR
- is_udtf=True,
+ # Legacy UDTF pandas conversion: enables broader Arrow exception
+ # handling to allow more implicit type coercions
+ is_legacy=True,
)

def __repr__(self):
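The plain "except TypeError" / "except ValueError" clauses in the non-legacy path above work because pyarrow's exception classes double as the corresponding Python builtins. A quick standalone sanity check of that hierarchy:

import pyarrow as pa

# ArrowException is the root of pyarrow's error hierarchy; the two
# specific errors also subclass the matching Python builtins, which is
# what lets plain TypeError/ValueError handlers catch them.
assert issubclass(pa.lib.ArrowInvalid, pa.lib.ArrowException)
assert issubclass(pa.lib.ArrowInvalid, ValueError)
assert issubclass(pa.lib.ArrowTypeError, pa.lib.ArrowException)
assert issubclass(pa.lib.ArrowTypeError, TypeError)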
@@ -1382,9 +1382,7 @@ def close(self):
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
):
- with self.assertRaisesRegex(
- Exception, "Exception thrown when converting pandas.Series"
- ):
+ with self.assertRaisesRegex(Exception, "Failed to convert the value"):
(
df.groupBy("id")
.transformWithStateInPandas(
19 changes: 9 additions & 10 deletions python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -256,14 +256,14 @@ def check_apply_in_pandas_returning_incompatible_type(self):
with self.subTest(convert="string to double"):
pandas_type_name = "object" if LooseVersion(pd.__version__) < "3.0.0" else "str"
expected = (
rf"ValueError: Exception thrown when converting pandas.Series \({pandas_type_name}\) "
r"with name 'k' to Arrow Array \(double\)."
rf"ValueError: Failed to convert the value of the column 'k' "
rf"with type '{pandas_type_name}' to Arrow type 'double'\."
)
if safely:
expected = expected + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
" It can be caused by overflows or other unsafe "
"conversions warned by Arrow. Arrow safe type "
"check can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
self._test_merge_error(
@@ -276,8 +276,9 @@
# sometimes we see TypeErrors
with self.subTest(convert="double to string"):
expected = (
r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
r"with name 'k' to Arrow Array \(string\)."
r"TypeError: Cannot convert the output value of the column 'k' "
r"with type 'float64' to the specified return type of the column: "
r"'string'\. Please check if the data types match and try again\."
)
self._test_merge_error(
fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": [2.0]}),
@@ -321,9 +322,7 @@ def int_to_decimal_merge(lft, rgt):
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
):
- with self.assertRaisesRegex(
- PythonException, "Exception thrown when converting pandas.Series"
- ):
+ with self.assertRaisesRegex(PythonException, "Failed to convert the value"):
(
left.groupby("id")
.cogroup(right.groupby("id"))
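The two subtests above track the exceptions pyarrow itself raises for these conversions. The exact error type can vary across pyarrow versions (hence the "sometimes we see TypeErrors" comment in the test), but typically it behaves as in this standalone sketch:

import pandas as pd
import pyarrow as pa

# str -> double: ArrowInvalid (a ValueError), surfaced by the PR as the
# "Failed to convert the value ..." PySparkValueError.
try:
    pa.Array.from_pandas(pd.Series(["not a number"]), type=pa.float64())
except pa.lib.ArrowInvalid as e:
    print(type(e).__name__, isinstance(e, ValueError))  # ArrowInvalid True

# float64 -> string: ArrowTypeError (a TypeError), surfaced by the PR as
# the "Cannot convert the output value ..." PySparkTypeError.
try:
    pa.Array.from_pandas(pd.Series([2.0]), type=pa.string())
except pa.lib.ArrowTypeError as e:
    print(type(e).__name__, isinstance(e, TypeError))  # ArrowTypeError True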
19 changes: 9 additions & 10 deletions python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -349,14 +349,14 @@ def check_apply_in_pandas_returning_incompatible_type(self):
with self.subTest(convert="string to double"):
pandas_type_name = "object" if LooseVersion(pd.__version__) < "3.0.0" else "str"
expected = (
rf"ValueError: Exception thrown when converting pandas.Series \({pandas_type_name}\) "
r"with name 'mean' to Arrow Array \(double\)."
rf"ValueError: Failed to convert the value of the column 'mean' "
rf"with type '{pandas_type_name}' to Arrow type 'double'\."
)
if safely:
expected = expected + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
" It can be caused by overflows or other unsafe "
"conversions warned by Arrow. Arrow safe type "
"check can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
with self.assertRaisesRegex(PythonException, expected):
@@ -369,8 +369,9 @@ def check_apply_in_pandas_returning_incompatible_type(self):
with self.subTest(convert="double to string"):
with self.assertRaisesRegex(
PythonException,
r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
r"with name 'mean' to Arrow Array \(string\).",
r"TypeError: Cannot convert the output value of the column 'mean' "
r"with type 'float64' to the specified return type of the column: "
r"'string'\. Please check if the data types match and try again\.",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]),
@@ -397,9 +398,7 @@ def int_to_decimal_func(key, pdf):
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
):
- with self.assertRaisesRegex(
- PythonException, "Exception thrown when converting pandas.Series"
- ):
+ with self.assertRaisesRegex(PythonException, "Failed to convert the value"):
(
self.data.groupby("id")
.applyInPandas(
24 changes: 12 additions & 12 deletions python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -303,14 +303,14 @@ def func(iterator):
pandas_type_name = "object" if LooseVersion(pd.__version__) < "3.0.0" else "str"

expected = (
r"ValueError: Exception thrown when converting pandas.Series "
rf"\({pandas_type_name}\) with name 'id' to Arrow Array \(double\)."
rf"ValueError: Failed to convert the value of the column 'id' "
rf"with type '{pandas_type_name}' to Arrow type 'double'\."
)
if safely:
expected = expected + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
" It can be caused by overflows or other unsafe "
"conversions warned by Arrow. Arrow safe type "
"check can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
with self.assertRaisesRegex(PythonException, expected):
@@ -333,11 +333,11 @@ def func(iterator):
)
if safely:
expected = (
r"ValueError: Exception thrown when converting pandas.Series "
r"\(float64\) with name 'id' to Arrow Array \(int32\)."
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
r"ValueError: Failed to convert the value of the column 'id' "
r"with type 'float64' to Arrow type 'int32'\."
" It can be caused by overflows or other unsafe "
"conversions warned by Arrow. Arrow safe type "
"check can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
with self.assertRaisesRegex(PythonException, expected):
@@ -489,8 +489,8 @@ def func(iterator):
pandas_type_name = "object" if LooseVersion(pd.__version__) < "3.0.0" else "str"
with self.assertRaisesRegex(
PythonException,
f"PySparkValueError: Exception thrown when converting pandas.Series \\({pandas_type_name}\\) "
"with name 'id' to Arrow Array \\(int32\\)\\.",
f"PySparkValueError: Failed to convert the value of the column 'id' "
f"with type '{pandas_type_name}' to Arrow type 'int32'\\.",
):
df.collect()

10 changes: 3 additions & 7 deletions python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -322,9 +322,7 @@ def udf(column):

# Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
- with self.assertRaisesRegex(
- Exception, "Exception thrown when converting pandas.Series"
- ):
+ with self.assertRaisesRegex(Exception, "Failed to convert the value"):
df.select(["A"]).withColumn("udf", udf("A")).collect()

# Disabling Arrow safe type check.
@@ -342,9 +340,7 @@ def udf(column):

# When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
- with self.assertRaisesRegex(
- Exception, "Exception thrown when converting pandas.Series"
- ):
+ with self.assertRaisesRegex(Exception, "Failed to convert the value"):
df.withColumn("udf", udf("id")).collect()

# Disabling safe type check, let Arrow do the cast anyway.
@@ -375,7 +371,7 @@ def int_to_decimal_udf(column):
):
self.assertRaisesRegex(
PythonException,
"Exception thrown when converting pandas.Series",
"Failed to convert the value",
df.withColumn("decimal_val", int_to_decimal_udf("id")).collect,
)

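A hedged end-to-end sketch of the overflow scenario these tests exercise, assuming an active SparkSession bound to the name spark; the UDF name and values are illustrative, not taken from the PR:

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("int")  # IntegerType, i.e. 32-bit
def plus_big(s: pd.Series) -> pd.Series:
    return s + (1 << 40)  # results overflow int32

df = spark.range(3)

# Safe check on: collect() raises a PySparkValueError whose message
# starts with "Failed to convert the value of the column ...".
spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", True)
# df.withColumn("v", plus_big("id")).collect()  # raises

# Safe check off: Arrow performs the unsafe cast and values wrap around.
spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", False)
df.withColumn("v", plus_big("id")).collect()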