|
17 | 17 |
|
18 | 18 | import pyarrow as pa |
19 | 19 | import pytest |
20 | | -from datafusion import column, udf |
| 20 | +from datafusion import SessionContext, column, udf |
21 | 21 | from datafusion import functions as f |
22 | 22 |
|
23 | 23 |
|
@@ -148,3 +148,63 @@ def uuid_version(uuid): |
148 | 148 | ) |
149 | 149 |
|
150 | 150 | assert results[0][0].to_pylist() == [4, 4, 4, 4, 4] |
| 151 | + |
| 152 | + |
def test_udf_with_nullability(ctx: SessionContext) -> None:
    """Check how declared field nullability flows through scalar UDFs.

    Two ``abs`` UDFs are defined from ``pyarrow`` fields instead of bare data
    types: one whose input and output fields are nullable and one whose fields
    are non-nullable. Each is applied to a nullable and a non-nullable column,
    and the test asserts on the nullability reported by the result schema and
    on the collected values. Feeding the nullable column to the non-nullable
    UDF is expected to fail when the query is executed.
    """
    # Kept function-local, as in the original test, so the module-level
    # imports of the file do not need to change.
    import pyarrow.compute as pc

    field_nullable_i64 = pa.field("with_nulls", type=pa.int64(), nullable=True)
    field_non_nullable_i64 = pa.field("no_nulls", type=pa.int64(), nullable=False)

    @udf([field_nullable_i64], field_nullable_i64, "stable")
    def nullable_abs(input_col):
        return pc.abs(input_col)

    @udf([field_non_nullable_i64], field_non_nullable_i64, "stable")
    def non_nullable_abs(input_col):
        return pc.abs(input_col)

    batch = pa.record_batch(
        {
            "with_nulls": pa.array([-2, None, 0, 1, 2]),
            "no_nulls": pa.array([-2, -1, 0, 1, 2]),
        },
        schema=pa.schema(
            [
                field_nullable_i64,
                field_non_nullable_i64,
            ]
        ),
    )
    ctx.register_record_batches("t", [[batch]])
    df = ctx.table("t")

    # Input matches expected, nullable
    df_result = df.select(nullable_abs(column("with_nulls")))
    returned_field = df_result.schema().field(0)
    assert returned_field.nullable
    results = df_result.collect()
    assert results[0][0].to_pylist() == [2, None, 0, 1, 2]

    # Input coercible to expected, nullable
    df_result = df.select(nullable_abs(column("no_nulls")))
    returned_field = df_result.schema().field(0)
    assert returned_field.nullable
    results = df_result.collect()
    assert results[0][0].to_pylist() == [2, 1, 0, 1, 2]

    # Input matches expected, no nulls
    df_result = df.select(non_nullable_abs(column("no_nulls")))
    returned_field = df_result.schema().field(0)
    assert not returned_field.nullable
    results = df_result.collect()
    assert results[0][0].to_pylist() == [2, 1, 0, 1, 2]

    # Invalid - requires non-nullable input but that is not possible
    df_result = df.select(non_nullable_abs(column("with_nulls")))
    returned_field = df_result.schema().field(0)
    assert not returned_field.nullable

    # The schema is still reported as non-nullable above; the failure only
    # shows up on execution. ``match`` searches the text of the raised
    # exception itself, rather than the repr of pytest's ExceptionInfo
    # wrapper, which is what ``str(e_info)`` would inspect.
    with pytest.raises(Exception, match="InvalidArgumentError"):
        df_result.collect()
0 commit comments