@@ -19,13 +19,13 @@ use std::collections::HashMap;
1919use std:: ffi:: CString ;
2020use std:: sync:: Arc ;
2121
22- use arrow:: array:: { new_null_array, RecordBatch , RecordBatchIterator , RecordBatchReader } ;
22+ use arrow:: array:: { new_null_array, RecordBatch , RecordBatchReader } ;
2323use arrow:: compute:: can_cast_types;
2424use arrow:: error:: ArrowError ;
2525use arrow:: ffi:: FFI_ArrowSchema ;
2626use arrow:: ffi_stream:: FFI_ArrowArrayStream ;
2727use arrow:: pyarrow:: FromPyArrow ;
28- use datafusion:: arrow:: datatypes:: Schema ;
28+ use datafusion:: arrow:: datatypes:: { Schema , SchemaRef } ;
2929use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
3030use datafusion:: arrow:: util:: pretty;
3131use datafusion:: common:: UnnestOptions ;
@@ -42,7 +42,7 @@ use pyo3::exceptions::PyValueError;
4242use pyo3:: prelude:: * ;
4343use pyo3:: pybacked:: PyBackedStr ;
4444use pyo3:: types:: { PyCapsule , PyList , PyTuple , PyTupleMethods } ;
45- use tokio:: task:: JoinHandle ;
45+ use tokio:: { runtime :: Handle , task:: JoinHandle } ;
4646
4747use crate :: catalog:: PyTable ;
4848use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError } ;
@@ -354,6 +354,41 @@ impl PyDataFrame {
354354 }
355355}
356356
/// Adapts an async DataFusion `SendableRecordBatchStream` into a synchronous
/// Arrow `RecordBatchReader`, so query results can be exported incrementally
/// over the Arrow C stream FFI instead of being fully collected up front.
struct DataFrameStreamReader {
    /// Async stream of record batches produced by `DataFrame::execute_stream`.
    stream: SendableRecordBatchStream,
    /// Tokio runtime handle used to drive the async stream from blocking `next()` calls.
    runtime: Handle,
    /// Schema reported to consumers of this reader.
    schema: SchemaRef,
    /// When set, each batch is converted to this schema via
    /// `record_batch_into_schema` before being yielded.
    projection: Option<SchemaRef>,
}
363+
364+ impl Iterator for DataFrameStreamReader {
365+ type Item = Result < RecordBatch , ArrowError > ;
366+
367+ fn next ( & mut self ) -> Option < Self :: Item > {
368+ match self . runtime . block_on ( self . stream . next ( ) ) {
369+ Some ( Ok ( batch) ) => {
370+ let batch = if let Some ( ref schema) = self . projection {
371+ match record_batch_into_schema ( batch, schema. as_ref ( ) ) {
372+ Ok ( b) => b,
373+ Err ( e) => return Some ( Err ( e) ) ,
374+ }
375+ } else {
376+ batch
377+ } ;
378+ Some ( Ok ( batch) )
379+ }
380+ Some ( Err ( e) ) => Some ( Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) ) ,
381+ None => None ,
382+ }
383+ }
384+ }
385+
386+ impl RecordBatchReader for DataFrameStreamReader {
387+ fn schema ( & self ) -> SchemaRef {
388+ self . schema . clone ( )
389+ }
390+ }
391+
357392#[ pymethods]
358393impl PyDataFrame {
359394 /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
@@ -879,8 +914,14 @@ impl PyDataFrame {
879914 py : Python < ' py > ,
880915 requested_schema : Option < Bound < ' py , PyCapsule > > ,
881916 ) -> PyDataFusionResult < Bound < ' py , PyCapsule > > {
882- let mut batches = wait_for_future ( py, self . df . as_ref ( ) . clone ( ) . collect ( ) ) ??;
917+ let rt = & get_tokio_runtime ( ) . 0 ;
918+ let df = self . df . as_ref ( ) . clone ( ) ;
919+ let fut: JoinHandle < datafusion:: common:: Result < SendableRecordBatchStream > > =
920+ rt. spawn ( async move { df. execute_stream ( ) . await } ) ;
921+ let stream = wait_for_future ( py, async { fut. await . map_err ( to_datafusion_err) } ) ???;
922+
883923 let mut schema: Schema = self . df . schema ( ) . to_owned ( ) . into ( ) ;
924+ let mut projection: Option < SchemaRef > = None ;
884925
885926 if let Some ( schema_capsule) = requested_schema {
886927 validate_pycapsule ( & schema_capsule, "arrow_schema" ) ?;
@@ -889,16 +930,17 @@ impl PyDataFrame {
889930 let desired_schema = Schema :: try_from ( schema_ptr) ?;
890931
891932 schema = project_schema ( schema, desired_schema) ?;
892-
893- batches = batches
894- . into_iter ( )
895- . map ( |record_batch| record_batch_into_schema ( record_batch, & schema) )
896- . collect :: < Result < Vec < RecordBatch > , ArrowError > > ( ) ?;
933+ projection = Some ( Arc :: new ( schema. clone ( ) ) ) ;
897934 }
898935
899- let batches_wrapped = batches . into_iter ( ) . map ( Ok ) ;
936+ let schema_ref = projection . clone ( ) . unwrap_or_else ( || Arc :: new ( schema ) ) ;
900937
901- let reader = RecordBatchIterator :: new ( batches_wrapped, Arc :: new ( schema) ) ;
938+ let reader = DataFrameStreamReader {
939+ stream,
940+ runtime : rt. handle ( ) . clone ( ) ,
941+ schema : schema_ref,
942+ projection,
943+ } ;
902944 let reader: Box < dyn RecordBatchReader + Send > = Box :: new ( reader) ;
903945
904946 let ffi_stream = FFI_ArrowArrayStream :: new ( reader) ;
0 commit comments