File tree Expand file tree Collapse file tree 4 files changed +31
-6
lines changed
Expand file tree Collapse file tree 4 files changed +31
-6
lines changed Original file line number Diff line number Diff line change @@ -166,11 +166,16 @@ class Table:
166166
167167 def __init__ (
168168 self ,
169- table : _InternalRawTable | _InternalTableProvider | Table ,
169+ table : _InternalRawTable | _InternalTableProvider | Table | Any ,
170170 ) -> None :
171- """Wrap a low level table or table provider."""
171+ """Wrap a low level table, table provider, or convertibles like DataFrame ."""
172172 if isinstance (table , Table ):
173173 table = table .table
174+ else :
175+ from datafusion .dataframe import DataFrame as DataFrameWrapper
176+
177+ if isinstance (table , DataFrameWrapper ):
178+ table = table .df .into_view ()
174179
175180 if not isinstance (table , (_InternalRawTable , _InternalTableProvider )):
176181 raise TypeError (EXPECTED_PROVIDER_MSG )
Original file line number Diff line number Diff line change @@ -196,6 +196,20 @@ def test_schema_register_table_with_dataframe_errors(ctx: SessionContext):
196196 assert str (exc_info .value ) == EXPECTED_PROVIDER_MSG
197197
198198
199+ def test_table_wraps_dataframe (ctx : SessionContext ):
200+ df = ctx .sql ("SELECT 1 AS value" )
201+
202+ table = Table (df )
203+ ctx .register_table ("df_table" , table )
204+
205+ try :
206+ batches = ctx .sql ("SELECT value FROM df_table" ).collect ()
207+ assert len (batches ) == 1
208+ assert batches [0 ].column (0 ).to_pylist () == [1 ]
209+ finally :
210+ ctx .deregister_table ("df_table" )
211+
212+
199213def test_in_end_to_end_python_providers (ctx : SessionContext ):
200214 """Test registering all python providers and running a query against them."""
201215
Original file line number Diff line number Diff line change @@ -429,15 +429,15 @@ impl PySessionContext {
429429 }
430430
431431 /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
432- pub fn sql ( & mut self , query : & str , py : Python ) -> PyDataFusionResult < PyDataFrame > {
432+ pub fn sql ( & self , query : & str , py : Python ) -> PyDataFusionResult < PyDataFrame > {
433433 let result = self . ctx . sql ( query) ;
434434 let df = wait_for_future ( py, result) ??;
435435 Ok ( PyDataFrame :: new ( df) )
436436 }
437437
438438 #[ pyo3( signature = ( query, options=None ) ) ]
439439 pub fn sql_with_options (
440- & mut self ,
440+ & self ,
441441 query : & str ,
442442 options : Option < PySQLOptions > ,
443443 py : Python ,
Original file line number Diff line number Diff line change 8282 const INTERVAL_CHECK_SIGNALS : Duration = Duration :: from_millis ( 1_000 ) ;
8383
8484 py. allow_threads ( || {
85- runtime . block_on ( async {
85+ let wait_future = || async {
8686 tokio:: pin!( fut) ;
8787 loop {
8888 tokio:: select! {
9292 }
9393 }
9494 }
95- } )
95+ } ;
96+
97+ if tokio:: runtime:: Handle :: try_current ( ) . is_ok ( ) {
98+ tokio:: task:: block_in_place ( || runtime. block_on ( wait_future ( ) ) )
99+ } else {
100+ runtime. block_on ( wait_future ( ) )
101+ }
96102 } )
97103}
98104
You can’t perform that action at this time.
0 commit comments