Skip to content

Commit ac5c16f

Browse files
committed
Add aggregate_udf_from_capsule helper for UDFs
Introduce a utility to validate PyCapsules and convert them into reusable DataFusion aggregate UDFs. Update PyAggregateUDF.from_pycapsule to handle raw PyCapsule inputs, leverage the new helper, and maintain existing provider fallback and error handling.
1 parent 9d0b191 commit ac5c16f

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

src/udaf.rs

Lines changed: 23 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
154154
})
155155
}
156156

157+
fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
158+
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
159+
160+
let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
161+
let udaf: ForeignAggregateUDF = udaf.try_into()?;
162+
163+
Ok(udaf.into())
164+
}
165+
157166
/// Represents an AggregateUDF
158167
#[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)]
159168
#[derive(Debug, Clone)]
@@ -186,22 +195,24 @@ impl PyAggregateUDF {
186195

187196
#[staticmethod]
188197
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
198+
if func.is_instance_of::<PyCapsule>() {
199+
let capsule = func.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
200+
let capsule: &Bound<'_, PyCapsule> = capsule.into();
201+
let function = aggregate_udf_from_capsule(capsule)?;
202+
return Ok(Self { function });
203+
}
204+
189205
if func.hasattr("__datafusion_aggregate_udf__")? {
190206
let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
191207
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
192-
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
193-
194-
let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
195-
let udaf: ForeignAggregateUDF = udaf.try_into()?;
196-
197-
Ok(Self {
198-
function: udaf.into(),
199-
})
200-
} else {
201-
Err(crate::errors::PyDataFusionError::Common(
202-
"__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
203-
))
208+
let capsule: &Bound<'_, PyCapsule> = capsule.into();
209+
let function = aggregate_udf_from_capsule(capsule)?;
210+
return Ok(Self { function });
204211
}
212+
213+
Err(crate::errors::PyDataFusionError::Common(
214+
"__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
215+
))
205216
}
206217

207218
/// creates a new PyExpr with the call of the udf

0 commit comments

Comments
 (0)