@@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
154154 } )
155155}
156156
157+ fn aggregate_udf_from_capsule ( capsule : & Bound < ' _ , PyCapsule > ) -> PyDataFusionResult < AggregateUDF > {
158+ validate_pycapsule ( capsule, "datafusion_aggregate_udf" ) ?;
159+
160+ let udaf = unsafe { capsule. reference :: < FFI_AggregateUDF > ( ) } ;
161+ let udaf: ForeignAggregateUDF = udaf. try_into ( ) ?;
162+
163+ Ok ( udaf. into ( ) )
164+ }
165+
157166/// Represents an AggregateUDF
158167#[ pyclass( frozen, name = "AggregateUDF" , module = "datafusion" , subclass) ]
159168#[ derive( Debug , Clone ) ]
@@ -186,22 +195,24 @@ impl PyAggregateUDF {
186195
187196 #[ staticmethod]
188197 pub fn from_pycapsule ( func : Bound < ' _ , PyAny > ) -> PyDataFusionResult < Self > {
198+ if func. is_instance_of :: < PyCapsule > ( ) {
199+ let capsule = func. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
200+ let capsule: & Bound < ' _ , PyCapsule > = capsule. into ( ) ;
201+ let function = aggregate_udf_from_capsule ( capsule) ?;
202+ return Ok ( Self { function } ) ;
203+ }
204+
189205 if func. hasattr ( "__datafusion_aggregate_udf__" ) ? {
190206 let capsule = func. getattr ( "__datafusion_aggregate_udf__" ) ?. call0 ( ) ?;
191207 let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
192- validate_pycapsule ( capsule, "datafusion_aggregate_udf" ) ?;
193-
194- let udaf = unsafe { capsule. reference :: < FFI_AggregateUDF > ( ) } ;
195- let udaf: ForeignAggregateUDF = udaf. try_into ( ) ?;
196-
197- Ok ( Self {
198- function : udaf. into ( ) ,
199- } )
200- } else {
201- Err ( crate :: errors:: PyDataFusionError :: Common (
202- "__datafusion_aggregate_udf__ does not exist on AggregateUDF object." . to_string ( ) ,
203- ) )
208+ let capsule: & Bound < ' _ , PyCapsule > = capsule. into ( ) ;
209+ let function = aggregate_udf_from_capsule ( capsule) ?;
210+ return Ok ( Self { function } ) ;
204211 }
212+
213+ Err ( crate :: errors:: PyDataFusionError :: Common (
214+ "__datafusion_aggregate_udf__ does not exist on AggregateUDF object." . to_string ( ) ,
215+ ) )
205216 }
206217
207218 /// creates a new PyExpr with the call of the udf
0 commit comments