@@ -33,7 +33,7 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
3333use datafusion:: execution:: SendableRecordBatchStream ;
3434use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
3535use datafusion:: prelude:: * ;
36- use futures:: StreamExt ;
36+ use futures:: { future , StreamExt } ;
3737use pyo3:: exceptions:: PyValueError ;
3838use pyo3:: prelude:: * ;
3939use pyo3:: pybacked:: PyBackedStr ;
@@ -92,15 +92,7 @@ impl PyDataFrame {
9292
9393 fn __repr__ ( & self , py : Python ) -> PyDataFusionResult < String > {
9494 let df = self . df . as_ref ( ) . clone ( ) ;
95-
96- let stream = wait_for_future ( py, df. execute_stream ( ) ) . map_err ( py_datafusion_err) ?;
97-
98- let batches: Vec < RecordBatch > = wait_for_future (
99- py,
100- stream. take ( 10 ) . collect :: < Vec < _ > > ( ) )
101- . into_iter ( )
102- . collect :: < Result < Vec < _ > , _ > > ( ) ?;
103-
95+ let batches: Vec < RecordBatch > = get_batches ( py, df, 10 ) ?;
10496 let batches_as_string = pretty:: pretty_format_batches ( & batches) ;
10597 match batches_as_string {
10698 Ok ( batch) => Ok ( format ! ( "DataFrame()\n {batch}" ) ) ,
@@ -111,8 +103,8 @@ impl PyDataFrame {
111103 fn _repr_html_ ( & self , py : Python ) -> PyDataFusionResult < String > {
112104 let mut html_str = "<table border='1'>\n " . to_string ( ) ;
113105
114- let df = self . df . as_ref ( ) . clone ( ) . limit ( 0 , Some ( 10 ) ) ? ;
115- let batches = wait_for_future ( py, df. collect ( ) ) ?;
106+ let df = self . df . as_ref ( ) . clone ( ) ;
107+ let batches: Vec < RecordBatch > = get_batches ( py, df, 10 ) ?;
116108
117109 if batches. is_empty ( ) {
118110 html_str. push_str ( "</table>\n " ) ;
@@ -742,3 +734,38 @@ fn record_batch_into_schema(
742734
743735 RecordBatch :: try_new ( schema, data_arrays)
744736}
737+
738+ fn get_batches (
739+ py : Python ,
740+ df : DataFrame ,
741+ max_rows : usize ,
742+ ) -> Result < Vec < RecordBatch > , PyDataFusionError > {
743+ let partitioned_stream = wait_for_future ( py, df. execute_stream_partitioned ( ) ) . map_err ( py_datafusion_err) ?;
744+ let stream = futures:: stream:: iter ( partitioned_stream) . flatten ( ) ;
745+ wait_for_future (
746+ py,
747+ stream
748+ . scan ( 0 , |state, x| {
749+ let total = * state;
750+ if total >= max_rows {
751+ future:: ready ( None )
752+ } else {
753+ match x {
754+ Ok ( batch) => {
755+ if total + batch. num_rows ( ) <= max_rows {
756+ * state = total + batch. num_rows ( ) ;
757+ future:: ready ( Some ( Ok ( batch) ) )
758+ } else {
759+ * state = max_rows;
760+ future:: ready ( Some ( Ok ( batch. slice ( 0 , max_rows - total) ) ) )
761+ }
762+ }
763+ Err ( err) => future:: ready ( Some ( Err ( PyDataFusionError :: from ( err) ) ) ) ,
764+ }
765+ }
766+ } )
767+ . collect :: < Vec < _ > > ( ) ,
768+ )
769+ . into_iter ( )
770+ . collect :: < Result < Vec < _ > , _ > > ( )
771+ }
0 commit comments