@@ -1582,6 +1582,37 @@ def test_empty_to_arrow_table(df):
15821582 assert set (pyarrow_table .column_names ) == {"a" , "b" , "c" }
15831583
15841584
def test_arrow_c_stream_to_table(monkeypatch):
    """Build a ``pa.Table`` from a DataFrame batch-by-batch, never via collect()."""
    ctx = SessionContext()

    # Two partitions, so the exported stream yields two record batches.
    first = pa.record_batch([pa.array([1])], names=["a"])
    second = pa.record_batch([pa.array([2])], names=["a"])
    df = ctx.create_dataframe([[first], [second]])

    # Guard: the conversion must stream; any eager pre-collection is a failure.
    def fail_collect(self):  # pragma: no cover - failure path
        msg = "collect should not be called"
        raise AssertionError(msg)

    monkeypatch.setattr(DataFrame, "collect", fail_collect)

    actual = pa.Table.from_batches(df)

    assert actual.equals(pa.Table.from_batches([first, second]))
    assert actual.schema == df.schema()
    # Each source batch must survive as its own chunk (no concatenation).
    assert actual.column("a").num_chunks == 2
1607+
def test_arrow_c_stream_reader(df):
    """The Arrow C stream capsule imports cleanly as a ``RecordBatchReader``."""
    capsule = df.__arrow_c_stream__()
    reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
    assert isinstance(reader, pa.RecordBatchReader)

    # Draining the reader must reproduce exactly what collect() returns.
    streamed = pa.Table.from_batches(reader)
    collected = pa.Table.from_batches(df.collect())
    assert streamed.equals(collected)
1615+
15851616def test_to_pylist (df ):
15861617 # Convert datafusion dataframe to Python list
15871618 pylist = df .to_pylist ()
@@ -2666,6 +2697,110 @@ def trigger_interrupt():
26662697 interrupt_thread .join (timeout = 1.0 )
26672698
26682699
def test_arrow_c_stream_interrupted():
    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.

    Similar to ``test_collect_interrupted`` this test issues a long running
    query, but consumes the results via ``__arrow_c_stream__``. It then raises
    ``KeyboardInterrupt`` in the main thread and verifies that the stream
    iteration stops promptly with the appropriate exception.
    """

    ctx = SessionContext()

    # ~10k rows across 10 batches; the cross-joined query below is expensive
    # enough that reading the stream blocks long enough to be interrupted.
    batches = []
    for i in range(10):
        batch = pa.RecordBatch.from_arrays(
            [
                pa.array(list(range(i * 1000, (i + 1) * 1000))),
                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
            ],
            names=["a", "b"],
        )
        batches.append(batch)

    ctx.register_record_batches("t1", [batches])
    ctx.register_record_batches("t2", [batches])

    df = ctx.sql(
        """
        WITH t1_expanded AS (
            SELECT
                a,
                b,
                CAST(a AS DOUBLE) / 1.5 AS c,
                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
            FROM t1
            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
        ),
        t2_expanded AS (
            SELECT
                a,
                b,
                CAST(a AS DOUBLE) * 2.5 AS e,
                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
            FROM t2
            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
        )
        SELECT
            t1.a, t1.b, t1.c, t1.d,
            t2.a AS a2, t2.b AS b2, t2.e, t2.f
        FROM t1_expanded t1
        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
        WHERE t1.a > 100 AND t2.a > 100
        """
    )

    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())

    interrupted = False
    interrupt_error = None
    query_started = threading.Event()
    max_wait_time = 5.0

    def trigger_interrupt():
        """Wait for consumption to start, then raise KeyboardInterrupt in main."""
        start_time = time.time()
        while not query_started.is_set():
            time.sleep(0.1)
            if time.time() - start_time > max_wait_time:
                msg = f"Query did not start within {max_wait_time} seconds"
                raise RuntimeError(msg)

        thread_id = threading.main_thread().ident
        if thread_id is None:
            msg = "Cannot get main thread ID"
            raise RuntimeError(msg)

        exception = ctypes.py_object(KeyboardInterrupt)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), exception
        )
        if res != 1:
            # Per the CPython C API, a pending async exception is cleared by
            # calling PyThreadState_SetAsyncExc again with NULL for the
            # exception argument — which ctypes spells as ``None``. Passing
            # ``ctypes.py_object(0)`` would instead schedule the int ``0``
            # as the pending "exception" rather than clearing it.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(
                ctypes.c_long(thread_id), None
            )
            msg = "Failed to raise KeyboardInterrupt in main thread"
            raise RuntimeError(msg)

    interrupt_thread = threading.Thread(target=trigger_interrupt)
    interrupt_thread.daemon = True
    interrupt_thread.start()

    try:
        # Signal the watcher just before blocking on the stream.
        query_started.set()
        # Consume the reader; this should block and then be interrupted.
        reader.read_all()
    except KeyboardInterrupt:
        interrupted = True
    except Exception as e:  # pragma: no cover - unexpected errors
        interrupt_error = e

    if not interrupted:
        pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")

    interrupt_thread.join(timeout=1.0)
2803+
26692804def test_show_select_where_no_rows (capsys ) -> None :
26702805 ctx = SessionContext ()
26712806 df = ctx .sql ("SELECT 1 WHERE 1=0" )
0 commit comments