@@ -1582,6 +1582,67 @@ def test_empty_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+def test_iter_batches_dataframe(fail_collect):
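+    # The ``fail_collect`` fixture is assumed to make ``collect()`` raise, so
+    # this passes only if iteration streams batches without a full collect.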
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
+    expected = [batch1, batch2]
+    for got, exp in zip(df, expected):
+        assert got.equals(exp)
+
+
+def test_arrow_c_stream_to_table(fail_collect):
+    ctx = SessionContext()
+
+    # Create a DataFrame with two separate record batches
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
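+    # ``pa.Table.from_batches`` drains the DataFrame via its batch iterator.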
+    table = pa.Table.from_batches(df)
+    batches = table.to_batches()
+
+    assert len(batches) == 2
+    assert batches[0].equals(batch1)
+    assert batches[1].equals(batch2)
+    assert table.schema == df.schema()
+    assert table.column("a").num_chunks == 2
+
+
+def test_arrow_c_stream_order():
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+
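+    # Both batches share a single partition, so output order is deterministic.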
+    df = ctx.create_dataframe([[batch1, batch2]])
+
+    table = pa.Table.from_batches(df)
+    expected = pa.Table.from_batches([batch1, batch2])
+
+    assert table.equals(expected)
+    col = table.column("a")
+    assert col.chunk(0)[0].as_py() == 1
+    assert col.chunk(1)[0].as_py() == 2
+
+
+def test_arrow_c_stream_reader(df):
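+    # ``__arrow_c_stream__`` returns a PyCapsule implementing the Arrow C
+    # stream interface; ``_import_from_c_capsule`` adopts it as a reader.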
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+    assert isinstance(reader, pa.RecordBatchReader)
+    table = pa.Table.from_batches(reader)
+    expected = pa.Table.from_batches(df.collect())
+    assert table.equals(expected)
+
+
 def test_to_pylist(df):
     # Convert datafusion dataframe to Python list
     pylist = df.to_pylist()
@@ -2666,6 +2727,116 @@ def trigger_interrupt():
     interrupt_thread.join(timeout=1.0)
 
 
+def test_arrow_c_stream_interrupted():
+    """``__arrow_c_stream__`` responds to ``KeyboardInterrupt`` signals.
+
+    Similar to ``test_collect_interrupted``, this test issues a long-running
+    query but consumes the results via ``__arrow_c_stream__``. It then raises
+    ``KeyboardInterrupt`` in the main thread and verifies that the stream
+    iteration stops promptly with the appropriate exception.
+    """
+
+    ctx = SessionContext()
+
+    batches = []
+    for i in range(10):
+        batch = pa.RecordBatch.from_arrays(
+            [
+                pa.array(list(range(i * 1000, (i + 1) * 1000))),
+                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+            ],
+            names=["a", "b"],
+        )
+        batches.append(batch)
+
+    ctx.register_record_batches("t1", [batches])
+    ctx.register_record_batches("t2", [batches])
+
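+    # The CROSS JOIN expansions below inflate the self-join so the query runs
+    # long enough to be interrupted mid-stream.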
+    df = ctx.sql(
+        """
+        WITH t1_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) / 1.5 AS c,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+            FROM t1
+            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+        ),
+        t2_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) * 2.5 AS e,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+            FROM t2
+            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+        )
+        SELECT
+            t1.a, t1.b, t1.c, t1.d,
+            t2.a AS a2, t2.b AS b2, t2.e, t2.f
+        FROM t1_expanded t1
+        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+        WHERE t1.a > 100 AND t2.a > 100
+        """
+    )
+
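+    # Importing the capsule is cheap; execution is expected to be lazy, so the
+    # blocking work happens once ``read_all()`` starts consuming batches below.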
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+
+    interrupted = False
+    interrupt_error = None
+    query_started = threading.Event()
+    max_wait_time = 5.0
+
+    def trigger_interrupt():
+        start_time = time.time()
+        while not query_started.is_set():
+            time.sleep(0.1)
+            if time.time() - start_time > max_wait_time:
+                msg = f"Query did not start within {max_wait_time} seconds"
+                raise RuntimeError(msg)
+
+        thread_id = threading.main_thread().ident
+        if thread_id is None:
+            msg = "Cannot get main thread ID"
+            raise RuntimeError(msg)
+
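+        # PyThreadState_SetAsyncExc schedules an exception in the target
+        # thread and returns the number of thread states it modified.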
+        exception = ctypes.py_object(KeyboardInterrupt)
+        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+            ctypes.c_long(thread_id), exception
+        )
+        if res != 1:
+            ctypes.pythonapi.PyThreadState_SetAsyncExc(
+                ctypes.c_long(thread_id), None  # NULL clears the pending exc
+            )
+            msg = "Failed to raise KeyboardInterrupt in main thread"
+            raise RuntimeError(msg)
+
+    interrupt_thread = threading.Thread(target=trigger_interrupt)
+    interrupt_thread.daemon = True
+    interrupt_thread.start()
+
+    try:
+        query_started.set()
+        # Consume the reader; this should block and then be interrupted
+        reader.read_all()
+    except KeyboardInterrupt:
+        interrupted = True
+    except Exception as e:  # pragma: no cover - unexpected errors
+        interrupt_error = e
+
+    if not interrupted:
+        pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
+
+    interrupt_thread.join(timeout=1.0)
+
 def test_show_select_where_no_rows(capsys) -> None:
     ctx = SessionContext()
     df = ctx.sql("SELECT 1 WHERE 1=0")