|
 from datafusion.expr import Window
 from pyarrow.csv import write_csv

-pa_cffi = pytest.importorskip("pyarrow.cffi")
-
 MB = 1024 * 1024


@@ -1584,120 +1582,6 @@ def test_empty_to_arrow_table(df): |
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}


-def test_iter_batches_dataframe(fail_collect):
-    ctx = SessionContext()
-
-    batch1 = pa.record_batch([pa.array([1])], names=["a"])
-    batch2 = pa.record_batch([pa.array([2])], names=["a"])
-    df = ctx.create_dataframe([[batch1], [batch2]])
-
-    expected = [batch1, batch2]
-    for got, exp in zip(df, expected):
-        assert got.equals(exp)
-
-
-def test_arrow_c_stream_to_table_and_reader(fail_collect):
-    ctx = SessionContext()
-
-    # Create a DataFrame with two separate record batches
-    batch1 = pa.record_batch([pa.array([1])], names=["a"])
-    batch2 = pa.record_batch([pa.array([2])], names=["a"])
-    df = ctx.create_dataframe([[batch1], [batch2]])
-
-    table = pa.Table.from_batches(df)
-    batches = table.to_batches()
-
-    assert len(batches) == 2
-    assert batches[0].equals(batch1)
-    assert batches[1].equals(batch2)
-    assert table.schema == df.schema()
-    assert table.column("a").num_chunks == 2
-
-    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
-    assert isinstance(reader, pa.RecordBatchReader)
-    reader_table = pa.Table.from_batches(reader)
-    expected = pa.Table.from_batches([batch1, batch2])
-    assert reader_table.equals(expected)
-
-
-def test_arrow_c_stream_order():
-    ctx = SessionContext()
-
-    batch1 = pa.record_batch([pa.array([1])], names=["a"])
-    batch2 = pa.record_batch([pa.array([2])], names=["a"])
-
-    df = ctx.create_dataframe([[batch1, batch2]])
-
-    table = pa.Table.from_batches(df)
-    expected = pa.Table.from_batches([batch1, batch2])
-
-    assert table.equals(expected)
-    col = table.column("a")
-    assert col.chunk(0)[0].as_py() == 1
-    assert col.chunk(1)[0].as_py() == 2
-
-
-def test_arrow_c_stream_schema_selection(fail_collect):
-    ctx = SessionContext()
-
-    batch = pa.RecordBatch.from_arrays(
-        [
-            pa.array([1, 2]),
-            pa.array([3, 4]),
-            pa.array([5, 6]),
-        ],
-        names=["a", "b", "c"],
-    )
-    df = ctx.create_dataframe([[batch]])
-
-    requested_schema = pa.schema([("c", pa.int64()), ("a", pa.int64())])
-
-    c_schema = pa_cffi.ffi.new("struct ArrowSchema*")
-    address = int(pa_cffi.ffi.cast("uintptr_t", c_schema))
-    requested_schema._export_to_c(address)
-    capsule_new = ctypes.pythonapi.PyCapsule_New
-    capsule_new.restype = ctypes.py_object
-    capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
-    schema_capsule = capsule_new(ctypes.c_void_p(address), b"arrow_schema", None)
-
-    reader = pa.RecordBatchReader._import_from_c_capsule(
-        df.__arrow_c_stream__(schema_capsule)
-    )
-
-    assert reader.schema == requested_schema
-
-    batches = list(reader)
-
-    assert len(batches) == 1
-    expected_batch = pa.record_batch(
-        [pa.array([5, 6]), pa.array([1, 2])], names=["c", "a"]
-    )
-    assert batches[0].equals(expected_batch)
-
-
-def test_arrow_c_stream_schema_mismatch(fail_collect):
-    ctx = SessionContext()
-
-    batch = pa.RecordBatch.from_arrays(
-        [pa.array([1, 2]), pa.array([3, 4])], names=["a", "b"]
-    )
-    df = ctx.create_dataframe([[batch]])
-
-    bad_schema = pa.schema([("a", pa.string())])
-
-    c_schema = pa_cffi.ffi.new("struct ArrowSchema*")
-    address = int(pa_cffi.ffi.cast("uintptr_t", c_schema))
-    bad_schema._export_to_c(address)
-
-    capsule_new = ctypes.pythonapi.PyCapsule_New
-    capsule_new.restype = ctypes.py_object
-    capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
-    bad_capsule = capsule_new(ctypes.c_void_p(address), b"arrow_schema", None)
-
-    with pytest.raises(Exception, match="Fail to merge schema"):
-        df.__arrow_c_stream__(bad_capsule)
-
-
 def test_to_pylist(df):
     # Convert datafusion dataframe to Python list
     pylist = df.to_pylist()
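
The removed tests above drove the Arrow PyCapsule stream protocol by hand: raw `__arrow_c_stream__` capsules, schemas exported through `pyarrow.cffi`, and `PyCapsule_New` via ctypes. For context, here is a minimal sketch of the same round trip through pyarrow's public API; it assumes pyarrow 14+, where `pa.table()` and `pa.RecordBatchReader.from_stream()` accept objects exposing `__arrow_c_stream__`, and is illustrative only:

```python
import pyarrow as pa
from datafusion import SessionContext

# Sketch only: assumes pyarrow >= 14, where the PyCapsule stream protocol is
# understood by pa.table() and RecordBatchReader.from_stream().
ctx = SessionContext()
batch = pa.record_batch([pa.array([1, 2, 3])], names=["a"])
df = ctx.create_dataframe([[batch]])

# Materialize every batch eagerly into a Table.
table = pa.table(df)

# Or stream batches lazily; passing a schema asks the producer for that
# projection (what test_arrow_c_stream_schema_selection built by hand with
# cffi and a raw schema capsule).
reader = pa.RecordBatchReader.from_stream(df, schema=pa.schema([("a", pa.int64())]))
for rb in reader:
    print(rb.num_rows)
```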
@@ -2782,110 +2666,6 @@ def trigger_interrupt(): |
     interrupt_thread.join(timeout=1.0)


-def test_arrow_c_stream_interrupted():
-    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
-
-    Similar to ``test_collect_interrupted`` this test issues a long running
-    query, but consumes the results via ``__arrow_c_stream__``. It then raises
-    ``KeyboardInterrupt`` in the main thread and verifies that the stream
-    iteration stops promptly with the appropriate exception.
-    """
-
-    ctx = SessionContext()
-
-    batches = []
-    for i in range(10):
-        batch = pa.RecordBatch.from_arrays(
-            [
-                pa.array(list(range(i * 1000, (i + 1) * 1000))),
-                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
-            ],
-            names=["a", "b"],
-        )
-        batches.append(batch)
-
-    ctx.register_record_batches("t1", [batches])
-    ctx.register_record_batches("t2", [batches])
-
-    df = ctx.sql(
-        """
-        WITH t1_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) / 1.5 AS c,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
-            FROM t1
-            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
-        ),
-        t2_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) * 2.5 AS e,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
-            FROM t2
-            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
-        )
-        SELECT
-            t1.a, t1.b, t1.c, t1.d,
-            t2.a AS a2, t2.b AS b2, t2.e, t2.f
-        FROM t1_expanded t1
-        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
-        WHERE t1.a > 100 AND t2.a > 100
-        """
-    )
-
-    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
-
-    interrupted = False
-    interrupt_error = None
-    query_started = threading.Event()
-    max_wait_time = 5.0
-
-    def trigger_interrupt():
-        start_time = time.time()
-        while not query_started.is_set():
-            time.sleep(0.1)
-            if time.time() - start_time > max_wait_time:
-                msg = f"Query did not start within {max_wait_time} seconds"
-                raise RuntimeError(msg)
-
-        thread_id = threading.main_thread().ident
-        if thread_id is None:
-            msg = "Cannot get main thread ID"
-            raise RuntimeError(msg)
-
-        exception = ctypes.py_object(KeyboardInterrupt)
-        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
-            ctypes.c_long(thread_id), exception
-        )
-        if res != 1:
-            ctypes.pythonapi.PyThreadState_SetAsyncExc(
-                ctypes.c_long(thread_id), ctypes.py_object(0)
-            )
-            msg = "Failed to raise KeyboardInterrupt in main thread"
-            raise RuntimeError(msg)
-
-    interrupt_thread = threading.Thread(target=trigger_interrupt)
-    interrupt_thread.daemon = True
-    interrupt_thread.start()
-
-    try:
-        query_started.set()
-        # consume the reader which should block and be interrupted
-        reader.read_all()
-    except KeyboardInterrupt:
-        interrupted = True
-    except Exception as e:  # pragma: no cover - unexpected errors
-        interrupt_error = e
-
-    if not interrupted:
-        pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
-
-    interrupt_thread.join(timeout=1.0)
-
-
 def test_show_select_where_no_rows(capsys) -> None:
     ctx = SessionContext()
     df = ctx.sql("SELECT 1 WHERE 1=0")
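
The removed `test_arrow_c_stream_interrupted` relied on a helper thread injecting `KeyboardInterrupt` into the main thread while it blocked on stream consumption. The delivery pattern in isolation, as a sketch (the busy loop and the 0.5 s delay are illustrative stand-ins, not taken from the test; `reader.read_all()` played this role there):

```python
import ctypes
import threading
import time


def interrupt_main_thread_after(delay: float) -> None:
    # Deliver KeyboardInterrupt to the main thread via the CPython C API.
    time.sleep(delay)
    main_id = threading.main_thread().ident
    assert main_id is not None
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_long(main_id), ctypes.py_object(KeyboardInterrupt)
    )
    if res != 1:
        # 0 means the thread id was invalid; >1 means delivery must be undone.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(main_id), ctypes.py_object(0)
        )


threading.Thread(target=interrupt_main_thread_after, args=(0.5,), daemon=True).start()
try:
    while True:  # async exceptions are raised at the next bytecode check
        pass
except KeyboardInterrupt:
    print("interrupted as expected")
```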
|