# specific language governing permissions and limitations
# under the License.
1717
18+ from __future__ import annotations
19+
20+ import math
1821import time
1922
2023import pyarrow as pa
2124from datafusion import SessionContext
2225
2326
def run(
    n_batches: int = 8,
    batch_size: int = 1_000_000,
    n_partitions: int | None = None,
) -> None:
    """Build an in-memory DataFrame of consecutive integers and collect it.

    ``n_batches`` record batches are created, each holding a single column
    ``"a"`` with ``batch_size`` consecutive integers. The batches are then
    grouped into partitions and handed to ``ctx.create_dataframe``.

    Args:
        n_batches: Number of record batches to generate; must be at least 1.
        batch_size: Number of rows in each record batch.
        n_partitions: Number of partitions to spread the batches over.
            ``None`` means one partition per batch. The value is clamped to
            the range ``[1, n_batches]`` so no partition is ever empty.

    Raises:
        ValueError: If ``n_batches`` is less than 1.
    """
    if n_batches < 1:
        # Fail early with a clear message; otherwise the split below would
        # have nothing to distribute.
        raise ValueError(f"n_batches must be at least 1, got {n_batches}")

    ctx = SessionContext()
    batches = []
    for i in range(n_batches):
        start = i * batch_size
        arr = pa.array(range(start, start + batch_size))
        batches.append(pa.record_batch([arr], names=["a"]))

    if n_partitions is None:
        n_partitions = n_batches
    n_partitions = max(1, min(n_partitions, n_batches))

    # Balanced split: every partition gets `base` batches and the first
    # `extra` partitions get one more. Fixed-size chunking with a rounded-up
    # chunk size can produce fewer partitions than requested (e.g. 8 batches
    # over 5 partitions gives chunks of 2, i.e. only 4 partitions).
    base, extra = divmod(len(batches), n_partitions)
    partitions = []
    offset = 0
    for index in range(n_partitions):
        size = base + 1 if index < extra else base
        partitions.append(batches[offset : offset + size])
        offset += size
    df = ctx.create_dataframe(partitions)

    # The clock starts only after the DataFrame has been created, so batch
    # generation and partitioning are not part of the measured interval.
    start = time.perf_counter()
    df.collect()
@@ -38,4 +52,14 @@ def run(n_batches: int = 8, batch_size: int = 1_000_000) -> None:
3852
3953
if __name__ == "__main__":
    # argparse is only needed when the file is executed as a script.
    import argparse

    cli = argparse.ArgumentParser()
    # Optional partition count; when omitted, run() receives None.
    cli.add_argument(
        "--partitions",
        default=None,
        type=int,
        help="number of partitions to create (defaults to one per batch)",
    )
    run(n_partitions=cli.parse_args().partitions)
0 commit comments