|
46 | 46 | from datafusion.expr import Window |
47 | 47 | from pyarrow.csv import write_csv |
48 | 48 |
|
| 49 | +pa_cffi = pytest.importorskip("pyarrow.cffi") |
| 50 | + |
49 | 51 | MB = 1024 * 1024 |
50 | 52 |
|
51 | 53 |
|
@@ -1582,6 +1584,120 @@ def test_empty_to_arrow_table(df): |
1582 | 1584 | assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
1583 | 1585 |
|
1584 | 1586 |
|
| 1587 | +def test_iter_batches_dataframe(fail_collect): |
| 1588 | + ctx = SessionContext() |
| 1589 | + |
| 1590 | + batch1 = pa.record_batch([pa.array([1])], names=["a"]) |
| 1591 | + batch2 = pa.record_batch([pa.array([2])], names=["a"]) |
| 1592 | + df = ctx.create_dataframe([[batch1], [batch2]]) |
| 1593 | + |
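| | +    # Iterating the DataFrame streams record batches in order; the |
| | +    # fail_collect fixture guards against an eager collect() fallback. |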
| 1594 | + expected = [batch1, batch2] |
| 1595 | +    for got, exp in zip(df, expected, strict=True): |
| 1596 | + assert got.equals(exp) |
| 1597 | + |
| 1598 | + |
| 1599 | +def test_arrow_c_stream_to_table_and_reader(fail_collect): |
| 1600 | + ctx = SessionContext() |
| 1601 | + |
| 1602 | + # Create a DataFrame with two separate record batches |
| 1603 | + batch1 = pa.record_batch([pa.array([1])], names=["a"]) |
| 1604 | + batch2 = pa.record_batch([pa.array([2])], names=["a"]) |
| 1605 | + df = ctx.create_dataframe([[batch1], [batch2]]) |
| 1606 | + |
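| | +    # pa.Table.from_batches accepts any iterable of record batches, so |
| | +    # the DataFrame can be consumed directly through its iterator. |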
| 1607 | + table = pa.Table.from_batches(df) |
| 1608 | + batches = table.to_batches() |
| 1609 | + |
| 1610 | + assert len(batches) == 2 |
| 1611 | + assert batches[0].equals(batch1) |
| 1612 | + assert batches[1].equals(batch2) |
| 1613 | + assert table.schema == df.schema() |
| 1614 | + assert table.column("a").num_chunks == 2 |
| 1615 | + |
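| | +    # __arrow_c_stream__ returns a PyCapsule implementing the Arrow C |
| | +    # stream interface, which pyarrow can import as a RecordBatchReader. |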
| 1616 | + reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__()) |
| 1617 | + assert isinstance(reader, pa.RecordBatchReader) |
| 1618 | + reader_table = pa.Table.from_batches(reader) |
| 1619 | + expected = pa.Table.from_batches([batch1, batch2]) |
| 1620 | + assert reader_table.equals(expected) |
| 1621 | + |
| 1622 | + |
| 1623 | +def test_arrow_c_stream_order(): |
| 1624 | + ctx = SessionContext() |
| 1625 | + |
| 1626 | + batch1 = pa.record_batch([pa.array([1])], names=["a"]) |
| 1627 | + batch2 = pa.record_batch([pa.array([2])], names=["a"]) |
| 1628 | + |
| 1629 | + df = ctx.create_dataframe([[batch1, batch2]]) |
| 1630 | + |
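| | +    # Both batches live in a single partition, so the exported stream |
| | +    # must preserve their original order. |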
| 1631 | + table = pa.Table.from_batches(df) |
| 1632 | + expected = pa.Table.from_batches([batch1, batch2]) |
| 1633 | + |
| 1634 | + assert table.equals(expected) |
| 1635 | + col = table.column("a") |
| 1636 | + assert col.chunk(0)[0].as_py() == 1 |
| 1637 | + assert col.chunk(1)[0].as_py() == 2 |
| 1638 | + |
| 1639 | + |
| 1640 | +def test_arrow_c_stream_schema_selection(fail_collect): |
| 1641 | + ctx = SessionContext() |
| 1642 | + |
| 1643 | + batch = pa.RecordBatch.from_arrays( |
| 1644 | + [ |
| 1645 | + pa.array([1, 2]), |
| 1646 | + pa.array([3, 4]), |
| 1647 | + pa.array([5, 6]), |
| 1648 | + ], |
| 1649 | + names=["a", "b", "c"], |
| 1650 | + ) |
| 1651 | + df = ctx.create_dataframe([[batch]]) |
| 1652 | + |
| 1653 | + requested_schema = pa.schema([("c", pa.int64()), ("a", pa.int64())]) |
| 1654 | + |
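| | +    # Export the requested schema via the Arrow C Data Interface and wrap |
| | +    # the raw struct pointer in a PyCapsule named "arrow_schema", the name |
| | +    # mandated by the Arrow PyCapsule protocol for schema requests. |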
| 1655 | + c_schema = pa_cffi.ffi.new("struct ArrowSchema*") |
| 1656 | + address = int(pa_cffi.ffi.cast("uintptr_t", c_schema)) |
| 1657 | + requested_schema._export_to_c(address) |
| 1658 | + capsule_new = ctypes.pythonapi.PyCapsule_New |
| 1659 | + capsule_new.restype = ctypes.py_object |
| 1660 | + capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] |
| 1661 | + schema_capsule = capsule_new(ctypes.c_void_p(address), b"arrow_schema", None) |
| 1662 | + |
| 1663 | + reader = pa.RecordBatchReader._import_from_c_capsule( |
| 1664 | + df.__arrow_c_stream__(schema_capsule) |
| 1665 | + ) |
| 1666 | + |
| 1667 | + assert reader.schema == requested_schema |
| 1668 | + |
| 1669 | + batches = list(reader) |
| 1670 | + |
| 1671 | + assert len(batches) == 1 |
| 1672 | + expected_batch = pa.record_batch( |
| 1673 | + [pa.array([5, 6]), pa.array([1, 2])], names=["c", "a"] |
| 1674 | + ) |
| 1675 | + assert batches[0].equals(expected_batch) |
| 1676 | + |
| 1677 | + |
| 1678 | +def test_arrow_c_stream_schema_mismatch(fail_collect): |
| 1679 | + ctx = SessionContext() |
| 1680 | + |
| 1681 | + batch = pa.RecordBatch.from_arrays( |
| 1682 | + [pa.array([1, 2]), pa.array([3, 4])], names=["a", "b"] |
| 1683 | + ) |
| 1684 | + df = ctx.create_dataframe([[batch]]) |
| 1685 | + |
| 1686 | + bad_schema = pa.schema([("a", pa.string())]) |
| 1687 | + |
| 1688 | + c_schema = pa_cffi.ffi.new("struct ArrowSchema*") |
| 1689 | + address = int(pa_cffi.ffi.cast("uintptr_t", c_schema)) |
| 1690 | + bad_schema._export_to_c(address) |
| 1691 | + |
| 1692 | + capsule_new = ctypes.pythonapi.PyCapsule_New |
| 1693 | + capsule_new.restype = ctypes.py_object |
| 1694 | + capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] |
| 1695 | + bad_capsule = capsule_new(ctypes.c_void_p(address), b"arrow_schema", None) |
| 1696 | + |
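| | +    # Requesting column "a" as a string conflicts with its int64 type, so |
| | +    # exporting the stream should fail while merging the schemas. |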
| 1697 | + with pytest.raises(Exception, match="Fail to merge schema"): |
| 1698 | + df.__arrow_c_stream__(bad_capsule) |
| 1699 | + |
| 1700 | + |
1585 | 1701 | def test_to_pylist(df): |
1586 | 1702 | # Convert datafusion dataframe to Python list |
1587 | 1703 | pylist = df.to_pylist() |
@@ -2666,6 +2782,110 @@ def trigger_interrupt(): |
2666 | 2782 | interrupt_thread.join(timeout=1.0) |
2667 | 2783 |
|
2668 | 2784 |
|
| 2785 | +def test_arrow_c_stream_interrupted(): |
| 2786 | + """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals. |
| 2787 | +
|
 | 2788 | +    Similar to ``test_collect_interrupted``, this test issues a long-running |
 | 2789 | +    query but consumes the results via ``__arrow_c_stream__``. It then raises |
| 2790 | + ``KeyboardInterrupt`` in the main thread and verifies that the stream |
| 2791 | + iteration stops promptly with the appropriate exception. |
| 2792 | + """ |
| 2793 | + |
| 2794 | + ctx = SessionContext() |
| 2795 | + |
| 2796 | + batches = [] |
| 2797 | + for i in range(10): |
| 2798 | + batch = pa.RecordBatch.from_arrays( |
| 2799 | + [ |
| 2800 | + pa.array(list(range(i * 1000, (i + 1) * 1000))), |
| 2801 | + pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]), |
| 2802 | + ], |
| 2803 | + names=["a", "b"], |
| 2804 | + ) |
| 2805 | + batches.append(batch) |
| 2806 | + |
| 2807 | + ctx.register_record_batches("t1", [batches]) |
| 2808 | + ctx.register_record_batches("t2", [batches]) |
| 2809 | + |
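| | +    # Self cross joins inflate the result so the query runs long enough |
| | +    # for the interrupt to arrive while batches are still being produced. |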
| 2810 | + df = ctx.sql( |
| 2811 | + """ |
| 2812 | + WITH t1_expanded AS ( |
| 2813 | + SELECT |
| 2814 | + a, |
| 2815 | + b, |
| 2816 | + CAST(a AS DOUBLE) / 1.5 AS c, |
| 2817 | + CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d |
| 2818 | + FROM t1 |
| 2819 | + CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5) |
| 2820 | + ), |
| 2821 | + t2_expanded AS ( |
| 2822 | + SELECT |
| 2823 | + a, |
| 2824 | + b, |
| 2825 | + CAST(a AS DOUBLE) * 2.5 AS e, |
| 2826 | + CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f |
| 2827 | + FROM t2 |
| 2828 | + CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5) |
| 2829 | + ) |
| 2830 | + SELECT |
| 2831 | + t1.a, t1.b, t1.c, t1.d, |
| 2832 | + t2.a AS a2, t2.b AS b2, t2.e, t2.f |
| 2833 | + FROM t1_expanded t1 |
| 2834 | + JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100 |
| 2835 | + WHERE t1.a > 100 AND t2.a > 100 |
| 2836 | + """ |
| 2837 | + ) |
| 2838 | + |
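| | +    # Importing the capsule does not execute the query; batches are |
| | +    # produced lazily as the reader is consumed below. |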
| 2839 | + reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__()) |
| 2840 | + |
| 2841 | + interrupted = False |
| 2842 | + interrupt_error = None |
| 2843 | + query_started = threading.Event() |
| 2844 | + max_wait_time = 5.0 |
| 2845 | + |
| 2846 | + def trigger_interrupt(): |
| 2847 | + start_time = time.time() |
| 2848 | + while not query_started.is_set(): |
| 2849 | + time.sleep(0.1) |
| 2850 | + if time.time() - start_time > max_wait_time: |
| 2851 | + msg = f"Query did not start within {max_wait_time} seconds" |
| 2852 | + raise RuntimeError(msg) |
| 2853 | + |
| 2854 | + thread_id = threading.main_thread().ident |
| 2855 | + if thread_id is None: |
| 2856 | + msg = "Cannot get main thread ID" |
| 2857 | + raise RuntimeError(msg) |
| 2858 | + |
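| | +        # Asynchronously raise KeyboardInterrupt in the main thread via |
| | +        # the CPython C API. |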
| 2859 | + exception = ctypes.py_object(KeyboardInterrupt) |
| 2860 | + res = ctypes.pythonapi.PyThreadState_SetAsyncExc( |
| 2861 | + ctypes.c_long(thread_id), exception |
| 2862 | + ) |
| 2863 | + if res != 1: |
 | | +            # Passing None (NULL) clears the pending async exception. |
 | 2864 | +            ctypes.pythonapi.PyThreadState_SetAsyncExc( |
 | 2865 | +                ctypes.c_long(thread_id), None |
 | 2866 | +            ) |
| 2867 | + msg = "Failed to raise KeyboardInterrupt in main thread" |
| 2868 | + raise RuntimeError(msg) |
| 2869 | + |
| 2870 | + interrupt_thread = threading.Thread(target=trigger_interrupt) |
| 2871 | + interrupt_thread.daemon = True |
| 2872 | + interrupt_thread.start() |
| 2873 | + |
| 2874 | + try: |
| 2875 | + query_started.set() |
 | 2876 | +        # Consume the reader; read_all() should block until interrupted. |
| 2877 | + reader.read_all() |
| 2878 | + except KeyboardInterrupt: |
| 2879 | + interrupted = True |
| 2880 | + except Exception as e: # pragma: no cover - unexpected errors |
| 2881 | + interrupt_error = e |
| 2882 | + |
| 2883 | + if not interrupted: |
| 2884 | + pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}") |
| 2885 | + |
| 2886 | + interrupt_thread.join(timeout=1.0) |
| 2887 | + |
| 2888 | + |
2669 | 2889 | def test_show_select_where_no_rows(capsys) -> None: |
2670 | 2890 | ctx = SessionContext() |
2671 | 2891 | df = ctx.sql("SELECT 1 WHERE 1=0") |
|