1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use std:: any:: Any ;
19+ use std:: hash:: { Hash , Hasher } ;
20+ use std:: ptr:: addr_of;
1821use std:: sync:: Arc ;
1922
23+ use arrow:: datatypes:: { Field , FieldRef } ;
24+ use arrow:: ffi:: { FFI_ArrowArray , FFI_ArrowSchema } ;
2025use datafusion:: arrow:: array:: { make_array, Array , ArrayData , ArrayRef } ;
2126use datafusion:: arrow:: datatypes:: DataType ;
22- use datafusion:: arrow:: pyarrow:: { FromPyArrow , PyArrowType , ToPyArrow } ;
27+ use datafusion:: arrow:: pyarrow:: { FromPyArrow , PyArrowType } ;
2328use datafusion:: error:: DataFusionError ;
24- use datafusion:: logical_expr:: function:: ScalarFunctionImplementation ;
25- use datafusion:: logical_expr:: { create_udf, ColumnarValue , ScalarUDF , ScalarUDFImpl } ;
29+ use datafusion:: logical_expr:: {
30+ ColumnarValue , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDF , ScalarUDFImpl , Signature ,
31+ Volatility ,
32+ } ;
2633use datafusion_ffi:: udf:: FFI_ScalarUDF ;
34+ use pyo3:: ffi:: Py_uintptr_t ;
2735use pyo3:: prelude:: * ;
2836use pyo3:: types:: { PyCapsule , PyTuple } ;
2937
3038use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionResult } ;
3139use crate :: expr:: PyExpr ;
3240use crate :: utils:: { parse_volatility, validate_pycapsule} ;
3341
34- /// Create a Rust callable function from a python function that expects pyarrow arrays
35- fn pyarrow_function_to_rust (
42+ /// This struct holds the Python written function that is a
43+ /// ScalarUDF.
44+ #[ derive( Debug ) ]
45+ struct PythonFunctionScalarUDF {
46+ name : String ,
3647 func : Py < PyAny > ,
37- ) -> impl Fn ( & [ ArrayRef ] ) -> Result < ArrayRef , DataFusionError > {
38- move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
48+ signature : Signature ,
49+ return_field : FieldRef ,
50+ }
51+
52+ impl PythonFunctionScalarUDF {
53+ fn new (
54+ name : String ,
55+ func : Py < PyAny > ,
56+ input_fields : Vec < Field > ,
57+ return_field : Field ,
58+ volatility : Volatility ,
59+ ) -> Self {
60+ let input_types = input_fields. iter ( ) . map ( |f| f. data_type ( ) . clone ( ) ) . collect ( ) ;
61+ let signature = Signature :: exact ( input_types, volatility) ;
62+ Self {
63+ name,
64+ func,
65+ signature,
66+ return_field : Arc :: new ( return_field) ,
67+ }
68+ }
69+ }
70+
71+ impl Eq for PythonFunctionScalarUDF { }
72+ impl PartialEq for PythonFunctionScalarUDF {
73+ fn eq ( & self , other : & Self ) -> bool {
74+ self . name == other. name
75+ && self . signature == other. signature
76+ && self . return_field == other. return_field
77+ && Python :: attach ( |py| self . func . bind ( py) . eq ( other. func . bind ( py) ) . unwrap_or ( false ) )
78+ }
79+ }
80+
81+ impl Hash for PythonFunctionScalarUDF {
82+ fn hash < H : Hasher > ( & self , state : & mut H ) {
83+ self . name . hash ( state) ;
84+ self . signature . hash ( state) ;
85+ self . return_field . hash ( state) ;
86+
87+ Python :: attach ( |py| {
88+ let py_hash = self . func . bind ( py) . hash ( ) . unwrap_or ( 0 ) ; // Handle unhashable objects
89+
90+ state. write_isize ( py_hash) ;
91+ } ) ;
92+ }
93+ }
94+
95+ fn array_to_pyarrow_with_field (
96+ py : Python ,
97+ array : ArrayRef ,
98+ field : & FieldRef ,
99+ ) -> PyResult < Py < PyAny > > {
100+ let array = FFI_ArrowArray :: new ( & array. to_data ( ) ) ;
101+ let schema = FFI_ArrowSchema :: try_from ( field) . map_err ( py_datafusion_err) ?;
102+
103+ let module = py. import ( "pyarrow" ) ?;
104+ let class = module. getattr ( "Array" ) ?;
105+ let array = class. call_method1 (
106+ "_import_from_c" ,
107+ (
108+ addr_of ! ( array) as Py_uintptr_t ,
109+ addr_of ! ( schema) as Py_uintptr_t ,
110+ ) ,
111+ ) ?;
112+ Ok ( array. unbind ( ) )
113+ }
114+
115+ impl ScalarUDFImpl for PythonFunctionScalarUDF {
116+ fn as_any ( & self ) -> & dyn Any {
117+ self
118+ }
119+
120+ fn name ( & self ) -> & str {
121+ & self . name
122+ }
123+
124+ fn signature ( & self ) -> & Signature {
125+ & self . signature
126+ }
127+
128+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> datafusion:: common:: Result < DataType > {
129+ unimplemented ! ( )
130+ }
131+
132+ fn return_field_from_args (
133+ & self ,
134+ _args : ReturnFieldArgs ,
135+ ) -> datafusion:: common:: Result < FieldRef > {
136+ Ok ( Arc :: clone ( & self . return_field ) )
137+ }
138+
139+ fn invoke_with_args (
140+ & self ,
141+ args : ScalarFunctionArgs ,
142+ ) -> datafusion:: common:: Result < ColumnarValue > {
143+ let num_rows = args. number_rows ;
39144 Python :: attach ( |py| {
40145 // 1. cast args to Pyarrow arrays
41146 let py_args = args
42- . iter ( )
43- . map ( |arg| {
44- arg. into_data ( )
45- . to_pyarrow ( py)
46- . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) )
147+ . args
148+ . into_iter ( )
149+ . zip ( args. arg_fields )
150+ . map ( |( arg, field) | {
151+ let array = arg. to_array ( num_rows) ?;
152+ array_to_pyarrow_with_field ( py, array, & field) . map_err ( to_datafusion_err)
47153 } )
48154 . collect :: < Result < Vec < _ > , _ > > ( ) ?;
49155 let py_args = PyTuple :: new ( py, py_args) . map_err ( to_datafusion_err) ?;
50156
51157 // 2. call function
52- let value = func
158+ let value = self
159+ . func
53160 . call ( py, py_args, None )
54161 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
55162
56163 // 3. cast to arrow::array::Array
57164 let array_data = ArrayData :: from_pyarrow_bound ( value. bind ( py) )
58165 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
59- Ok ( make_array ( array_data) )
166+ Ok ( ColumnarValue :: Array ( make_array ( array_data) ) )
60167 } )
61168 }
62169}
63170
64- /// Create a DataFusion's UDF implementation from a python function
65- /// that expects pyarrow arrays. This is more efficient as it performs
66- /// a zero-copy of the contents.
67- fn to_scalar_function_impl ( func : Py < PyAny > ) -> ScalarFunctionImplementation {
68- // Make the python function callable from rust
69- let pyarrow_func = pyarrow_function_to_rust ( func) ;
70-
71- // Convert input/output from datafusion ColumnarValue to arrow arrays
72- Arc :: new ( move |args : & [ ColumnarValue ] | {
73- let array_refs = ColumnarValue :: values_to_arrays ( args) ?;
74- let array_result = pyarrow_func ( & array_refs) ?;
75- Ok ( array_result. into ( ) )
76- } )
77- }
78-
79171/// Represents a PyScalarUDF
80172#[ pyclass( frozen, name = "ScalarUDF" , module = "datafusion" , subclass) ]
81173#[ derive( Debug , Clone ) ]
@@ -88,19 +180,21 @@ impl PyScalarUDF {
88180 #[ new]
89181 #[ pyo3( signature=( name, func, input_types, return_type, volatility) ) ]
90182 fn new (
91- name : & str ,
183+ name : String ,
92184 func : Py < PyAny > ,
93- input_types : PyArrowType < Vec < DataType > > ,
94- return_type : PyArrowType < DataType > ,
185+ input_types : PyArrowType < Vec < Field > > ,
186+ return_type : PyArrowType < Field > ,
95187 volatility : & str ,
96188 ) -> PyResult < Self > {
97- let function = create_udf (
189+ let py_function = PythonFunctionScalarUDF :: new (
98190 name,
191+ func,
99192 input_types. 0 ,
100193 return_type. 0 ,
101194 parse_volatility ( volatility) ?,
102- to_scalar_function_impl ( func) ,
103195 ) ;
196+ let function = ScalarUDF :: new_from_impl ( py_function) ;
197+
104198 Ok ( Self { function } )
105199 }
106200
0 commit comments