Skip to content

Commit f1cf7a1

Browse files
committed
Add unit test to check field inputs
1 parent aa7d35c commit f1cf7a1

File tree

1 file changed

+61
-1
lines changed

1 file changed

+61
-1
lines changed

python/tests/test_udf.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import pyarrow as pa
1919
import pytest
20-
from datafusion import column, udf
20+
from datafusion import SessionContext, column, udf
2121
from datafusion import functions as f
2222

2323

@@ -148,3 +148,63 @@ def uuid_version(uuid):
148148
)
149149

150150
assert results[0][0].to_pylist() == [4, 4, 4, 4, 4]
151+
152+
153+
def test_udf_with_nullability(ctx: SessionContext) -> None:
154+
import pyarrow.compute as pc
155+
156+
field_nullable_i64 = pa.field("with_nulls", type=pa.int64(), nullable=True)
157+
field_non_nullable_i64 = pa.field("no_nulls", type=pa.int64(), nullable=False)
158+
159+
@udf([field_nullable_i64], field_nullable_i64, "stable")
160+
def nullable_abs(input_col):
161+
return pc.abs(input_col)
162+
163+
@udf([field_non_nullable_i64], field_non_nullable_i64, "stable")
164+
def non_nullable_abs(input_col):
165+
return pc.abs(input_col)
166+
167+
batch = pa.record_batch(
168+
{
169+
"with_nulls": pa.array([-2, None, 0, 1, 2]),
170+
"no_nulls": pa.array([-2, -1, 0, 1, 2]),
171+
},
172+
schema=pa.schema(
173+
[
174+
field_nullable_i64,
175+
field_non_nullable_i64,
176+
]
177+
),
178+
)
179+
ctx.register_record_batches("t", [[batch]])
180+
df = ctx.table("t")
181+
182+
# Input matches expected, nullable
183+
df_result = df.select(nullable_abs(column("with_nulls")))
184+
returned_field = df_result.schema().field(0)
185+
assert returned_field.nullable
186+
results = df_result.collect()
187+
assert results[0][0].to_pylist() == [2, None, 0, 1, 2]
188+
189+
# Input coercible to expected, nullable
190+
df_result = df.select(nullable_abs(column("no_nulls")))
191+
returned_field = df_result.schema().field(0)
192+
assert returned_field.nullable
193+
results = df_result.collect()
194+
assert results[0][0].to_pylist() == [2, 1, 0, 1, 2]
195+
196+
# Input matches expected, no nulls
197+
df_result = df.select(non_nullable_abs(column("no_nulls")))
198+
returned_field = df_result.schema().field(0)
199+
assert not returned_field.nullable
200+
results = df_result.collect()
201+
assert results[0][0].to_pylist() == [2, 1, 0, 1, 2]
202+
203+
# Invalid - requires non-nullable input but that is not possible
204+
df_result = df.select(non_nullable_abs(column("with_nulls")))
205+
returned_field = df_result.schema().field(0)
206+
assert not returned_field.nullable
207+
208+
with pytest.raises(Exception) as e_info:
209+
_results = df_result.collect()
210+
assert "InvalidArgumentError" in str(e_info)

0 commit comments

Comments
 (0)