Skip to content

Commit 5cd598a

Browse files
committed
testing FFI for table provider
1 parent 3c66201 commit 5cd598a

File tree

3 files changed

+315
-5
lines changed

3 files changed

+315
-5
lines changed

src/context.rs

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use std::str::FromStr;
2121
use std::sync::Arc;
2222

2323
use arrow::array::RecordBatchReader;
24+
use arrow::ffi::FFI_ArrowSchema;
2425
use arrow::ffi_stream::ArrowArrayStreamReader;
2526
use arrow::pyarrow::FromPyArrow;
2627
use datafusion::execution::session_state::SessionStateBuilder;
@@ -36,6 +37,8 @@ use crate::dataframe::PyDataFrame;
3637
use crate::dataset::Dataset;
3738
use crate::errors::{py_datafusion_err, DataFusionError};
3839
use crate::expr::sort_expr::PySortExpr;
40+
use crate::expr::PyExpr;
41+
use crate::ffi::FFI_TableProvider;
3942
use crate::physical_plan::PyExecutionPlan;
4043
use crate::record_batch::PyRecordBatchStream;
4144
use crate::sql::logical::PyLogicalPlan;
@@ -54,11 +57,9 @@ use datafusion::datasource::file_format::parquet::ParquetFormat;
5457
use datafusion::datasource::listing::{
5558
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
5659
};
57-
use datafusion::datasource::MemTable;
5860
use datafusion::datasource::TableProvider;
59-
use datafusion::execution::context::{
60-
DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
61-
};
61+
use datafusion::datasource::{provider, MemTable};
62+
use datafusion::execution::context::{DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext};
6263
use datafusion::execution::disk_manager::DiskManagerConfig;
6364
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
6465
use datafusion::execution::options::ReadOptions;
@@ -67,7 +68,7 @@ use datafusion::physical_plan::SendableRecordBatchStream;
6768
use datafusion::prelude::{
6869
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
6970
};
70-
use pyo3::types::{PyDict, PyList, PyTuple};
71+
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
7172
use tokio::task::JoinHandle;
7273

7374
/// Configuration options for a SessionContext
@@ -566,6 +567,41 @@ impl PySessionContext {
566567
Ok(())
567568
}
568569

570+
/// Construct datafusion dataframe from Arrow Table
571+
pub fn register_table_provider(
572+
&mut self,
573+
name: &str,
574+
provider: Bound<'_, PyAny>,
575+
py: Python,
576+
) -> PyResult<()> {
577+
if provider.hasattr("__datafusion_table_provider__")? {
578+
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
579+
let capsule = capsule.downcast::<PyCapsule>()?;
580+
// validate_pycapsule(capsule, "arrow_array_stream")?;
581+
582+
let mut provider = unsafe { FFI_TableProvider::from_raw(capsule.pointer() as _) };
583+
584+
println!("Found provider version {}", provider.version);
585+
586+
if let Some(s) = provider.schema {
587+
let mut schema = FFI_ArrowSchema::empty();
588+
589+
let ret_code = unsafe { s(&mut provider, &mut schema) };
590+
591+
if ret_code == 0 {
592+
let schema = Schema::try_from(&schema)
593+
.map_err(|e| PyValueError::new_err(e.to_string()))?;
594+
println!("got schema {}", schema);
595+
} else {
596+
return Err(PyValueError::new_err(format!(
597+
"Cannot get schema from input stream. Error code: {ret_code:?}"
598+
)));
599+
}
600+
}
601+
}
602+
Ok(())
603+
}
604+
569605
pub fn register_record_batches(
570606
&mut self,
571607
name: &str,

src/ffi.rs

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
use std::{
2+
ffi::{c_char, c_int, c_void, CStr, CString},
3+
ptr::addr_of,
4+
sync::Arc,
5+
};
6+
7+
use arrow::{error::ArrowError, ffi::FFI_ArrowSchema};
8+
use datafusion::common::Result;
9+
use datafusion::{
10+
catalog::{Session, TableProvider},
11+
common::DFSchema,
12+
execution::{context::SessionState, session_state::SessionStateBuilder},
13+
physical_plan::ExecutionPlan,
14+
prelude::{Expr, SessionConfig},
15+
};
16+
use tokio::runtime::Runtime;
17+
18+
#[repr(C)]
19+
#[derive(Debug)]
20+
#[allow(non_camel_case_types)]
21+
pub enum FFI_Constraint {
22+
/// Columns with the given indices form a composite primary key (they are
23+
/// jointly unique and not nullable):
24+
PrimaryKey(Vec<usize>),
25+
/// Columns with the given indices form a composite unique key:
26+
Unique(Vec<usize>),
27+
}
28+
29+
#[repr(C)]
30+
#[derive(Debug)]
31+
#[allow(missing_docs)]
32+
#[allow(non_camel_case_types)]
33+
pub struct FFI_ExecutionPlan {
34+
pub private_data: *mut c_void,
35+
}
36+
37+
unsafe impl Send for FFI_ExecutionPlan {}
38+
39+
struct ExecutionPlanPrivateData {
40+
plan: Arc<dyn ExecutionPlan + Send>,
41+
last_error: Option<CString>,
42+
}
43+
44+
#[repr(C)]
45+
#[derive(Debug)]
46+
#[allow(missing_docs)]
47+
#[allow(non_camel_case_types)]
48+
pub struct FFI_SessionConfig {
49+
pub version: i64,
50+
51+
pub private_data: *mut c_void,
52+
}
53+
54+
unsafe impl Send for FFI_SessionConfig {}
55+
56+
struct SessionConfigPrivateData {
57+
config: SessionConfig,
58+
last_error: Option<CString>,
59+
}
60+
61+
struct ExportedSessionConfig {
62+
session: *mut FFI_SessionConfig,
63+
}
64+
65+
impl ExportedSessionConfig {
66+
fn get_private_data(&mut self) -> &mut SessionConfigPrivateData {
67+
unsafe { &mut *((*self.session).private_data as *mut SessionConfigPrivateData) }
68+
}
69+
}
70+
71+
#[repr(C)]
72+
#[derive(Debug)]
73+
#[allow(missing_docs)]
74+
#[allow(non_camel_case_types)]
75+
pub struct FFI_Expr {}
76+
77+
#[repr(C)]
78+
#[derive(Debug)]
79+
#[allow(missing_docs)]
80+
#[allow(non_camel_case_types)]
81+
pub struct FFI_TableProvider {
82+
pub version: i64,
83+
pub schema: Option<
84+
unsafe extern "C" fn(provider: *mut FFI_TableProvider, out: *mut FFI_ArrowSchema) -> c_int,
85+
>,
86+
pub scan: Option<
87+
unsafe extern "C" fn(
88+
provider: *mut FFI_TableProvider,
89+
session_config: *mut FFI_SessionConfig,
90+
n_projections: c_int,
91+
projections: *mut c_int,
92+
n_filters: c_int,
93+
filters: *mut *const c_char,
94+
limit: c_int,
95+
out: *mut FFI_ExecutionPlan,
96+
) -> c_int,
97+
>,
98+
pub private_data: *mut c_void,
99+
}
100+
101+
unsafe impl Send for FFI_TableProvider {}
102+
103+
struct ProviderPrivateData {
104+
provider: Box<dyn TableProvider + Send>,
105+
last_error: Option<CString>,
106+
}
107+
108+
struct ExportedTableProvider {
109+
provider: *mut FFI_TableProvider,
110+
}
111+
112+
// The callback used to get array schema
113+
unsafe extern "C" fn provider_schema(
114+
provider: *mut FFI_TableProvider,
115+
schema: *mut FFI_ArrowSchema,
116+
) -> c_int {
117+
ExportedTableProvider { provider }.schema(schema)
118+
}
119+
120+
unsafe extern "C" fn provider_scan(
121+
provider: *mut FFI_TableProvider,
122+
session_config: *mut FFI_SessionConfig,
123+
n_projections: c_int,
124+
projections: *mut c_int,
125+
n_filters: c_int,
126+
filters: *mut *const c_char,
127+
limit: c_int,
128+
mut out: *mut FFI_ExecutionPlan,
129+
) -> c_int {
130+
let config = unsafe { (*session_config).private_data as *const SessionConfigPrivateData };
131+
let session = SessionStateBuilder::new()
132+
.with_config((*config).config.clone())
133+
.build();
134+
135+
let num_projections: usize = n_projections.try_into().unwrap_or(0);
136+
137+
let projections: Vec<usize> = std::slice::from_raw_parts(projections, num_projections)
138+
.iter()
139+
.filter_map(|v| (*v).try_into().ok())
140+
.collect();
141+
let maybe_projections = match projections.is_empty() {
142+
true => None,
143+
false => Some(&projections),
144+
};
145+
146+
let filters_slice = std::slice::from_raw_parts(filters, n_filters as usize);
147+
let filters_vec: Vec<String> = filters_slice
148+
.iter()
149+
.map(|&s| CStr::from_ptr(s).to_string_lossy().to_string())
150+
.collect();
151+
152+
let limit = limit.try_into().ok();
153+
154+
let plan =
155+
ExportedTableProvider { provider }.scan(&session, maybe_projections, filters_vec, limit);
156+
157+
match plan {
158+
Ok(mut plan) => {
159+
out = &mut plan;
160+
0
161+
}
162+
Err(_) => 1,
163+
}
164+
}
165+
166+
impl ExportedTableProvider {
167+
fn get_private_data(&mut self) -> &mut ProviderPrivateData {
168+
unsafe { &mut *((*self.provider).private_data as *mut ProviderPrivateData) }
169+
}
170+
171+
pub fn schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
172+
let private_data = self.get_private_data();
173+
let provider = &private_data.provider;
174+
175+
let schema = FFI_ArrowSchema::try_from(provider.schema().as_ref());
176+
177+
match schema {
178+
Ok(schema) => {
179+
unsafe { std::ptr::copy(addr_of!(schema), out, 1) };
180+
std::mem::forget(schema);
181+
0
182+
}
183+
Err(ref err) => {
184+
private_data.last_error = Some(
185+
CString::new(err.to_string()).expect("Error string has a null byte in it."),
186+
);
187+
get_error_code(err)
188+
}
189+
}
190+
}
191+
192+
pub fn scan(
193+
&mut self,
194+
session: &SessionState,
195+
projections: Option<&Vec<usize>>,
196+
filters: Vec<String>,
197+
limit: Option<usize>,
198+
) -> Result<FFI_ExecutionPlan> {
199+
let private_data = self.get_private_data();
200+
let provider = &private_data.provider;
201+
202+
let schema = provider.schema();
203+
let df_schema: DFSchema = schema.try_into()?;
204+
205+
let filter_exprs = filters
206+
.into_iter()
207+
.map(|expr_str| session.create_logical_expr(&expr_str, &df_schema))
208+
.collect::<datafusion::common::Result<Vec<Expr>>>()?;
209+
210+
let runtime = Runtime::new().unwrap();
211+
let plan = runtime.block_on(provider.scan(session, projections, &filter_exprs, limit))?;
212+
213+
let plan_ptr = Box::new(ExecutionPlanPrivateData {
214+
plan,
215+
last_error: None,
216+
});
217+
218+
Ok(FFI_ExecutionPlan {
219+
private_data: Box::into_raw(plan_ptr) as *mut c_void,
220+
})
221+
}
222+
}
223+
224+
const ENOMEM: i32 = 12;
225+
const EIO: i32 = 5;
226+
const EINVAL: i32 = 22;
227+
const ENOSYS: i32 = 78;
228+
229+
fn get_error_code(err: &ArrowError) -> i32 {
230+
match err {
231+
ArrowError::NotYetImplemented(_) => ENOSYS,
232+
ArrowError::MemoryError(_) => ENOMEM,
233+
ArrowError::IoError(_, _) => EIO,
234+
_ => EINVAL,
235+
}
236+
}
237+
238+
impl FFI_TableProvider {
239+
/// Creates a new [`FFI_TableProvider`].
240+
pub fn new(provider: Box<dyn TableProvider + Send>) -> Self {
241+
let private_data = Box::new(ProviderPrivateData {
242+
provider,
243+
last_error: None,
244+
});
245+
246+
Self {
247+
version: 2,
248+
schema: Some(provider_schema),
249+
scan: Some(provider_scan),
250+
private_data: Box::into_raw(private_data) as *mut c_void,
251+
}
252+
}
253+
254+
/**
255+
Replace temporary pointer with updated
256+
# Safety
257+
User must validate the raw pointer is valid.
258+
*/
259+
pub unsafe fn from_raw(raw_provider: *mut FFI_TableProvider) -> Self {
260+
std::ptr::replace(raw_provider, Self::empty())
261+
}
262+
263+
/// Creates a new empty [FFI_ArrowArrayStream]. Used to import from the C Stream Interface.
264+
pub fn empty() -> Self {
265+
Self {
266+
version: 0,
267+
schema: None,
268+
scan: None,
269+
private_data: std::ptr::null_mut(),
270+
}
271+
}
272+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ mod udf;
6161
mod udwf;
6262
pub mod utils;
6363

64+
pub mod ffi;
65+
6466
#[cfg(feature = "mimalloc")]
6567
#[global_allocator]
6668
static GLOBAL: MiMalloc = MiMalloc;

0 commit comments

Comments
 (0)