Skip to content

Commit 55c549e

Browse files
committed
Was able to get round trip schema from datafusion -> delta table -> datafusion
1 parent 7876933 commit 55c549e

File tree

2 files changed

+122
-47
lines changed

2 files changed

+122
-47
lines changed

src/context.rs

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -583,21 +583,22 @@ impl PySessionContext {
583583

584584
println!("Found provider version {}", provider.version);
585585

586-
if let Some(s) = provider.schema {
587-
let mut schema = FFI_ArrowSchema::empty();
588-
589-
let ret_code = unsafe { s(&mut provider, &mut schema) };
590-
591-
if ret_code == 0 {
592-
let schema = Schema::try_from(&schema)
593-
.map_err(|e| PyValueError::new_err(e.to_string()))?;
594-
println!("got schema {}", schema);
595-
} else {
596-
return Err(PyValueError::new_err(format!(
597-
"Cannot get schema from input stream. Error code: {ret_code:?}"
598-
)));
599-
}
600-
}
586+
let schema = provider.schema();
587+
println!("Got schema through TableProvider trait {}", schema);
588+
589+
// if let Some(s) = provider.schema {
590+
// let mut schema = s(provider);
591+
592+
// if ret_code == 0 {
593+
// let schema = Schema::try_from(&schema)
594+
// .map_err(|e| PyValueError::new_err(e.to_string()))?;
595+
// println!("got schema {}", schema);
596+
// } else {
597+
// return Err(PyValueError::new_err(format!(
598+
// "Cannot get schema from input stream. Error code: {ret_code:?}"
599+
// )));
600+
// }
601+
// }
601602
}
602603
Ok(())
603604
}

src/ffi.rs

Lines changed: 106 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
use std::{
2+
any::Any,
23
ffi::{c_char, c_int, c_void, CStr, CString},
3-
ptr::addr_of,
4+
ptr::{addr_of, addr_of_mut},
45
sync::Arc,
56
};
67

7-
use arrow::{error::ArrowError, ffi::FFI_ArrowSchema};
8-
use datafusion::common::Result;
8+
use arrow::{
9+
datatypes::{Schema, SchemaRef},
10+
error::ArrowError,
11+
ffi::FFI_ArrowSchema,
12+
};
13+
use async_trait::async_trait;
914
use datafusion::{
1015
catalog::{Session, TableProvider},
1116
common::DFSchema,
1217
execution::{context::SessionState, session_state::SessionStateBuilder},
1318
physical_plan::ExecutionPlan,
1419
prelude::{Expr, SessionConfig},
1520
};
21+
use datafusion::{
22+
common::Result, datasource::TableType, logical_expr::TableProviderFilterPushDown,
23+
};
1624
use tokio::runtime::Runtime;
1725

1826
#[repr(C)]
@@ -80,9 +88,7 @@ pub struct FFI_Expr {}
8088
#[allow(non_camel_case_types)]
8189
pub struct FFI_TableProvider {
8290
pub version: i64,
83-
pub schema: Option<
84-
unsafe extern "C" fn(provider: *mut FFI_TableProvider, out: *mut FFI_ArrowSchema) -> c_int,
85-
>,
91+
pub schema: Option<unsafe extern "C" fn(provider: *const FFI_TableProvider) -> FFI_ArrowSchema>,
8692
pub scan: Option<
8793
unsafe extern "C" fn(
8894
provider: *mut FFI_TableProvider,
@@ -99,6 +105,7 @@ pub struct FFI_TableProvider {
99105
}
100106

101107
unsafe impl Send for FFI_TableProvider {}
108+
unsafe impl Sync for FFI_TableProvider {}
102109

103110
struct ProviderPrivateData {
104111
provider: Box<dyn TableProvider + Send>,
@@ -108,13 +115,14 @@ struct ProviderPrivateData {
108115
struct ExportedTableProvider {
109116
provider: *mut FFI_TableProvider,
110117
}
118+
struct ConstExportedTableProvider {
119+
provider: *const FFI_TableProvider,
120+
}
111121

112122
// The callback used to get array schema
113-
unsafe extern "C" fn provider_schema(
114-
provider: *mut FFI_TableProvider,
115-
schema: *mut FFI_ArrowSchema,
116-
) -> c_int {
117-
ExportedTableProvider { provider }.schema(schema)
123+
unsafe extern "C" fn provider_schema(provider: *const FFI_TableProvider) -> FFI_ArrowSchema {
124+
println!("callback function");
125+
ConstExportedTableProvider { provider }.provider_schema()
118126
}
119127

120128
unsafe extern "C" fn provider_scan(
@@ -151,8 +159,12 @@ unsafe extern "C" fn provider_scan(
151159

152160
let limit = limit.try_into().ok();
153161

154-
let plan =
155-
ExportedTableProvider { provider }.scan(&session, maybe_projections, filters_vec, limit);
162+
let plan = ExportedTableProvider { provider }.provider_scan(
163+
&session,
164+
maybe_projections,
165+
filters_vec,
166+
limit,
167+
);
156168

157169
match plan {
158170
Ok(mut plan) => {
@@ -163,33 +175,33 @@ unsafe extern "C" fn provider_scan(
163175
}
164176
}
165177

166-
impl ExportedTableProvider {
167-
fn get_private_data(&mut self) -> &mut ProviderPrivateData {
168-
unsafe { &mut *((*self.provider).private_data as *mut ProviderPrivateData) }
178+
impl ConstExportedTableProvider {
179+
fn get_private_data(&self) -> &ProviderPrivateData {
180+
unsafe { &*((*self.provider).private_data as *const ProviderPrivateData) }
169181
}
170182

171-
pub fn schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
183+
pub fn provider_schema(&self) -> FFI_ArrowSchema {
184+
println!("Enter exported table provider");
172185
let private_data = self.get_private_data();
173186
let provider = &private_data.provider;
174187

175-
let schema = FFI_ArrowSchema::try_from(provider.schema().as_ref());
188+
println!("about to try from in provider.schema()");
189+
// This does silently fail because TableProvider does not return a result
190+
// so we expect it to always pass. Maybe some logging should be added.
191+
let mut schema = FFI_ArrowSchema::try_from(provider.schema().as_ref())
192+
.unwrap_or(FFI_ArrowSchema::empty());
176193

177-
match schema {
178-
Ok(schema) => {
179-
unsafe { std::ptr::copy(addr_of!(schema), out, 1) };
180-
std::mem::forget(schema);
181-
0
182-
}
183-
Err(ref err) => {
184-
private_data.last_error = Some(
185-
CString::new(err.to_string()).expect("Error string has a null byte in it."),
186-
);
187-
get_error_code(err)
188-
}
189-
}
194+
println!("Found the schema but can we return it?");
195+
schema
190196
}
197+
}
191198

192-
pub fn scan(
199+
impl ExportedTableProvider {
200+
fn get_private_data(&mut self) -> &mut ProviderPrivateData {
201+
unsafe { &mut *((*self.provider).private_data as *mut ProviderPrivateData) }
202+
}
203+
204+
pub fn provider_scan(
193205
&mut self,
194206
session: &SessionState,
195207
projections: Option<&Vec<usize>>,
@@ -270,3 +282,65 @@ impl FFI_TableProvider {
270282
}
271283
}
272284
}
285+
286+
#[async_trait]
287+
impl TableProvider for FFI_TableProvider {
288+
/// Returns the table provider as [`Any`](std::any::Any) so that it can be
289+
/// downcast to a specific implementation.
290+
fn as_any(&self) -> &dyn Any {
291+
self
292+
}
293+
294+
/// Get a reference to the schema for this table
295+
fn schema(&self) -> SchemaRef {
296+
let schema = match self.schema {
297+
Some(func) => {
298+
println!("About to call the function to get the schema");
299+
unsafe {
300+
let v = func(self);
301+
println!("Got the mutalbe ffi_arrow_schmea?");
302+
// func(self).as_ref().and_then(|s| Schema::try_from(s).ok())
303+
Schema::try_from(&func(self)).ok()
304+
}
305+
}
306+
None => None,
307+
};
308+
Arc::new(schema.unwrap_or(Schema::empty()))
309+
}
310+
311+
/// Get the type of this table for metadata/catalog purposes.
312+
fn table_type(&self) -> TableType {
313+
TableType::Base
314+
}
315+
316+
/// Create an ExecutionPlan that will scan the table.
317+
/// The table provider will be usually responsible of grouping
318+
/// the source data into partitions that can be efficiently
319+
/// parallelized or distributed.
320+
async fn scan(
321+
&self,
322+
_ctx: &dyn Session,
323+
projection: Option<&Vec<usize>>,
324+
filters: &[Expr],
325+
// limit can be used to reduce the amount scanned
326+
// from the datasource as a performance optimization.
327+
// If set, it contains the amount of rows needed by the `LogicalPlan`,
328+
// The datasource should return *at least* this number of rows if available.
329+
_limit: Option<usize>,
330+
) -> Result<Arc<dyn ExecutionPlan>> {
331+
Err(datafusion::error::DataFusionError::NotImplemented(
332+
"scan not implemented".to_string(),
333+
))
334+
}
335+
336+
/// Tests whether the table provider can make use of a filter expression
337+
/// to optimise data retrieval.
338+
fn supports_filters_pushdown(
339+
&self,
340+
filter: &[&Expr],
341+
) -> Result<Vec<TableProviderFilterPushDown>> {
342+
Err(datafusion::error::DataFusionError::NotImplemented(
343+
"support filter pushdown not implemented".to_string(),
344+
))
345+
}
346+
}

0 commit comments

Comments
 (0)