@@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions;
3131use datafusion:: config:: { CsvOptions , TableParquetOptions } ;
3232use datafusion:: dataframe:: { DataFrame , DataFrameWriteOptions } ;
3333use datafusion:: datasource:: TableProvider ;
34+ use datafusion:: error:: DataFusionError ;
3435use datafusion:: execution:: SendableRecordBatchStream ;
3536use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
3637use datafusion:: prelude:: * ;
38+ use futures:: { StreamExt , TryStreamExt } ;
3739use pyo3:: exceptions:: PyValueError ;
3840use pyo3:: prelude:: * ;
3941use pyo3:: pybacked:: PyBackedStr ;
@@ -70,6 +72,7 @@ impl PyTableProvider {
7072 PyTable :: new ( table_provider)
7173 }
7274}
75+ const MAX_TABLE_BYTES_TO_DISPLAY : usize = 2 * 1024 * 1024 ; // 2 MB
7376
7477/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
7578/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -121,46 +124,57 @@ impl PyDataFrame {
121124 }
122125
123126 fn _repr_html_ ( & self , py : Python ) -> PyDataFusionResult < String > {
124- let mut html_str = "<table border='1'>\n " . to_string ( ) ;
125-
126- let df = self . df . as_ref ( ) . clone ( ) . limit ( 0 , Some ( 10 ) ) ?;
127- let batches = wait_for_future ( py, df. collect ( ) ) ?;
127+ let ( batch, mut has_more) =
128+ wait_for_future ( py, get_first_record_batch ( self . df . as_ref ( ) . clone ( ) ) ) ?;
129+ let Some ( batch) = batch else {
130+ return Ok ( "No data to display" . to_string ( ) ) ;
131+ } ;
128132
129- if batches . is_empty ( ) {
130- html_str . push_str ( "</table> \n " ) ;
131- return Ok ( html_str ) ;
132- }
133+ let mut html_str = "
134+ <div style= \" width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc; \" >
135+ <table style= \" border-collapse: collapse; min-width: 100% \" >
136+ <thead> \n " . to_string ( ) ;
133137
134- let schema = batches [ 0 ] . schema ( ) ;
138+ let schema = batch . schema ( ) ;
135139
136140 let mut header = Vec :: new ( ) ;
137141 for field in schema. fields ( ) {
138- header. push ( format ! ( "<th>{}</td >" , field. name( ) ) ) ;
142+ header. push ( format ! ( "<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;' >{}</th >" , field. name( ) ) ) ;
139143 }
140144 let header_str = header. join ( "" ) ;
141- html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , header_str) ) ;
145+ html_str. push_str ( & format ! ( "<tr>{}</tr></thead><tbody>\n " , header_str) ) ;
146+
147+ let formatters = batch
148+ . columns ( )
149+ . iter ( )
150+ . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
151+ . map ( |c| c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) ) )
152+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
153+
154+ let batch_size = batch. get_array_memory_size ( ) ;
155+ let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY {
156+ true => {
157+ has_more = true ;
158+ let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32 ;
159+ ( batch. num_rows ( ) as f32 * ratio) . round ( ) as usize
160+ }
161+ false => batch. num_rows ( ) ,
162+ } ;
142163
143- for batch in batches {
144- let formatters = batch
145- . columns ( )
146- . iter ( )
147- . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
148- . map ( |c| {
149- c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) )
150- } )
151- . collect :: < Result < Vec < _ > , _ > > ( ) ?;
152-
153- for row in 0 ..batch. num_rows ( ) {
154- let mut cells = Vec :: new ( ) ;
155- for formatter in & formatters {
156- cells. push ( format ! ( "<td>{}</td>" , formatter. value( row) ) ) ;
157- }
158- let row_str = cells. join ( "" ) ;
159- html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
164+ for row in 0 ..num_rows_to_display {
165+ let mut cells = Vec :: new ( ) ;
166+ for formatter in & formatters {
167+ cells. push ( format ! ( "<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>" , formatter. value( row) ) ) ;
160168 }
169+ let row_str = cells. join ( "" ) ;
170+ html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
161171 }
162172
163- html_str. push_str ( "</table>\n " ) ;
173+ html_str. push_str ( "</tbody></table></div>\n " ) ;
174+
175+ if has_more {
176+ html_str. push_str ( "Data truncated due to size." ) ;
177+ }
164178
165179 Ok ( html_str)
166180 }
@@ -771,3 +785,29 @@ fn record_batch_into_schema(
771785
772786 RecordBatch :: try_new ( schema, data_arrays)
773787}
788+
789+ /// This is a helper function to return the first non-empty record batch from executing a DataFrame.
790+ /// It additionally returns a bool, which indicates if there are more record batches available.
791+ /// We do this so we can determine if we should indicate to the user that the data has been
792+ /// truncated.
793+ async fn get_first_record_batch (
794+ df : DataFrame ,
795+ ) -> Result < ( Option < RecordBatch > , bool ) , DataFusionError > {
796+ let mut stream = df. execute_stream ( ) . await ?;
797+ loop {
798+ let rb = match stream. next ( ) . await {
799+ None => return Ok ( ( None , false ) ) ,
800+ Some ( Ok ( r) ) => r,
801+ Some ( Err ( e) ) => return Err ( e) ,
802+ } ;
803+
804+ if rb. num_rows ( ) > 0 {
805+ let has_more = match stream. try_next ( ) . await {
806+ Ok ( None ) => false , // reached end
807+ Ok ( Some ( _) ) => true ,
808+ Err ( _) => false , // Stream disconnected
809+ } ;
810+ return Ok ( ( Some ( rb) , has_more) ) ;
811+ }
812+ }
813+ }
0 commit comments