Skip to content

Commit dc396c7

Browse files
committed
Enhance run function to support dynamic partitioning of batches
1 parent 4f5532d commit dc396c7

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

benchmarks/collect_gil_bench.py

Lines changed: 27 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,21 +15,35 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from __future__ import annotations
19+
20+
import math
1821
import time
1922

2023
import pyarrow as pa
2124
from datafusion import SessionContext
2225

2326

24-
def run(n_batches: int = 8, batch_size: int = 1_000_000) -> None:
27+
def run(
28+
n_batches: int = 8,
29+
batch_size: int = 1_000_000,
30+
n_partitions: int | None = None,
31+
) -> None:
2532
ctx = SessionContext()
2633
batches = []
2734
for i in range(n_batches):
2835
start = i * batch_size
2936
arr = pa.array(range(start, start + batch_size))
3037
batches.append(pa.record_batch([arr], names=["a"]))
3138

32-
df = ctx.create_dataframe([batches])
39+
if n_partitions is None:
40+
n_partitions = n_batches
41+
n_partitions = max(1, min(n_partitions, n_batches))
42+
partition_size = math.ceil(len(batches) / n_partitions)
43+
partitions = [
44+
batches[i : i + partition_size] for i in range(0, len(batches), partition_size)
45+
]
46+
df = ctx.create_dataframe(partitions)
3347

3448
start = time.perf_counter()
3549
df.collect()
@@ -38,4 +52,14 @@ def run(n_batches: int = 8, batch_size: int = 1_000_000) -> None:
3852

3953

4054
if __name__ == "__main__":
41-
run()
55+
import argparse
56+
57+
parser = argparse.ArgumentParser()
58+
parser.add_argument(
59+
"--partitions",
60+
type=int,
61+
default=None,
62+
help="number of partitions to create (defaults to one per batch)",
63+
)
64+
args = parser.parse_args()
65+
run(n_partitions=args.partitions)

0 commit comments

Comments
 (0)