@@ -126,11 +126,15 @@ impl PyDataFrame {
126126 }
127127
128128 fn _repr_html_ ( & self , py : Python ) -> PyDataFusionResult < String > {
129- let ( batch , mut has_more) =
130- wait_for_future ( py, get_first_record_batch ( self . df . as_ref ( ) . clone ( ) ) ) ?;
131- let Some ( batch ) = batch else {
129+ let ( batches , mut has_more) =
130+ wait_for_future ( py, get_first_few_record_batches ( self . df . as_ref ( ) . clone ( ) ) ) ?;
131+ let Some ( batches ) = batches else {
132132 return Ok ( "No data to display" . to_string ( ) ) ;
133133 } ;
134+ if batches. is_empty ( ) {
135+ // This should not be reached, but do it for safety since we index into the vector below
136+ return Ok ( "No data to display" . to_string ( ) ) ;
137+ }
134138
135139 let table_uuid = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
136140
@@ -162,12 +166,11 @@ impl PyDataFrame {
162166 }
163167 </style>
164168
165-
166169 <div style=\" width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\" >
167170 <table style=\" border-collapse: collapse; min-width: 100%\" >
168171 <thead>\n " . to_string ( ) ;
169172
170- let schema = batch . schema ( ) ;
173+ let schema = batches [ 0 ] . schema ( ) ;
171174
172175 let mut header = Vec :: new ( ) ;
173176 for field in schema. fields ( ) {
@@ -176,52 +179,75 @@ impl PyDataFrame {
176179 let header_str = header. join ( "" ) ;
177180 html_str. push_str ( & format ! ( "<tr>{}</tr></thead><tbody>\n " , header_str) ) ;
178181
179- let formatters = batch
180- . columns ( )
182+ let batch_formatters = batches
181183 . iter ( )
182- . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
183- . map ( |c| c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) ) )
184+ . map ( |batch| {
185+ batch
186+ . columns ( )
187+ . iter ( )
188+ . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
189+ . map ( |c| {
190+ c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) )
191+ } )
192+ . collect :: < Result < Vec < _ > , _ > > ( )
193+ } )
184194 . collect :: < Result < Vec < _ > , _ > > ( ) ?;
185195
186- let batch_size = batch. get_array_memory_size ( ) ;
187- let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY {
196+ let total_memory: usize = batches
197+ . iter ( )
198+ . map ( |batch| batch. get_array_memory_size ( ) )
199+ . sum ( ) ;
200+ let rows_per_batch = batches. iter ( ) . map ( |batch| batch. num_rows ( ) ) ;
201+ let total_rows = rows_per_batch. clone ( ) . sum ( ) ;
202+
203+ // let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| {
204+ // (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows())
205+ // });
206+
207+ let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY {
188208 true => {
189- let num_batch_rows = batch. num_rows ( ) ;
190- let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32 ;
191- let mut reduced_row_num = ( num_batch_rows as f32 * ratio) . round ( ) as usize ;
209+ let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32 ;
210+ let mut reduced_row_num = ( total_rows as f32 * ratio) . round ( ) as usize ;
192211 if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY {
193- reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY . min ( num_batch_rows ) ;
212+ reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY . min ( total_rows ) ;
194213 }
195214
196- has_more = has_more || reduced_row_num < num_batch_rows ;
215+ has_more = has_more || reduced_row_num < total_rows ;
197216 reduced_row_num
198217 }
199- false => batch . num_rows ( ) ,
218+ false => total_rows ,
200219 } ;
201220
202- for row in 0 ..num_rows_to_display {
203- let mut cells = Vec :: new ( ) ;
204- for ( col, formatter) in formatters. iter ( ) . enumerate ( ) {
205- let cell_data = formatter. value ( row) . to_string ( ) ;
206- // From testing, primitive data types do not typically get larger than 21 characters
207- if cell_data. len ( ) > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
208- let short_cell_data = & cell_data[ 0 ..MAX_LENGTH_CELL_WITHOUT_MINIMIZE ] ;
209- cells. push ( format ! ( "
210- <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
211- <div class=\" expandable-container\" >
212- <span class=\" expandable\" id=\" {table_uuid}-min-text-{row}-{col}\" >{short_cell_data}</span>
213- <span class=\" full-text\" id=\" {table_uuid}-full-text-{row}-{col}\" >{cell_data}</span>
214- <button class=\" expand-btn\" onclick=\" toggleDataFrameCellText('{table_uuid}',{row},{col})\" >...</button>
215- </div>
216- </td>" ) ) ;
217- } else {
218- cells. push ( format ! ( "<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>" , formatter. value( row) ) ) ;
221+ // We need to build up row by row for html
222+ let mut table_row = 0 ;
223+ for ( batch_formatter, num_rows_in_batch) in batch_formatters. iter ( ) . zip ( rows_per_batch) {
224+ for batch_row in 0 ..num_rows_in_batch {
225+ table_row += 1 ;
226+ if table_row > num_rows_to_display {
227+ break ;
228+ }
229+ let mut cells = Vec :: new ( ) ;
230+ for ( col, formatter) in batch_formatter. iter ( ) . enumerate ( ) {
231+ let cell_data = formatter. value ( batch_row) . to_string ( ) ;
232+ // From testing, primitive data types do not typically get larger than 21 characters
233+ if cell_data. len ( ) > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
234+ let short_cell_data = & cell_data[ 0 ..MAX_LENGTH_CELL_WITHOUT_MINIMIZE ] ;
235+ cells. push ( format ! ( "
236+ <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
237+ <div class=\" expandable-container\" >
238+ <span class=\" expandable\" id=\" {table_uuid}-min-text-{table_row}-{col}\" >{short_cell_data}</span>
239+ <span class=\" full-text\" id=\" {table_uuid}-full-text-{table_row}-{col}\" >{cell_data}</span>
240+ <button class=\" expand-btn\" onclick=\" toggleDataFrameCellText('{table_uuid}',{table_row},{col})\" >...</button>
241+ </div>
242+ </td>" ) ) ;
243+ } else {
244+ cells. push ( format ! ( "<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>" , formatter. value( batch_row) ) ) ;
245+ }
219246 }
247+ let row_str = cells. join ( "" ) ;
248+ html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
220249 }
221- let row_str = cells. join ( "" ) ;
222- html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
223250 }
224-
225251 html_str. push_str ( "</tbody></table></div>\n " ) ;
226252
227253 html_str. push_str ( "
@@ -862,24 +888,36 @@ fn record_batch_into_schema(
862888/// It additionally returns a bool, which indicates if there are more record batches available.
863889/// We do this so we can determine if we should indicate to the user that the data has been
864890/// truncated.
865- async fn get_first_record_batch (
891+ async fn get_first_few_record_batches (
866892 df : DataFrame ,
867- ) -> Result < ( Option < RecordBatch > , bool ) , DataFusionError > {
893+ ) -> Result < ( Option < Vec < RecordBatch > > , bool ) , DataFusionError > {
868894 let mut stream = df. execute_stream ( ) . await ?;
869- loop {
895+ let mut size_estimate_so_far = 0 ;
896+ let mut record_batches = Vec :: default ( ) ;
897+ while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY {
870898 let rb = match stream. next ( ) . await {
871- None => return Ok ( ( None , false ) ) ,
899+ None => {
900+ break ;
901+ }
872902 Some ( Ok ( r) ) => r,
873903 Some ( Err ( e) ) => return Err ( e) ,
874904 } ;
875905
876906 if rb. num_rows ( ) > 0 {
877- let has_more = match stream. try_next ( ) . await {
878- Ok ( None ) => false , // reached end
879- Ok ( Some ( _) ) => true ,
880- Err ( _) => false , // Stream disconnected
881- } ;
882- return Ok ( ( Some ( rb) , has_more) ) ;
907+ size_estimate_so_far += rb. get_array_memory_size ( ) ;
908+ record_batches. push ( rb) ;
883909 }
884910 }
911+
912+ if record_batches. is_empty ( ) {
913+ return Ok ( ( None , false ) ) ;
914+ }
915+
916+ let has_more = match stream. try_next ( ) . await {
917+ Ok ( None ) => false , // reached end
918+ Ok ( Some ( _) ) => true ,
919+ Err ( _) => false , // Stream disconnected
920+ } ;
921+
922+ Ok ( ( Some ( record_batches) , has_more) )
885923}
0 commit comments