Skip to content

Commit cfc9f2c

Browse files
committed
Add concurrency tests for SqlSchema, Config, and DataFrame
1 parent dba5c6a commit cfc9f2c

File tree

1 file changed: +102 −0 lines

python/tests/test_concurrency.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from concurrent.futures import ThreadPoolExecutor
21+
22+
import pyarrow as pa
23+
24+
from datafusion import Config, SessionContext, col, lit
25+
from datafusion.common import SqlSchema
26+
from datafusion import functions as f
27+
28+
29+
def _run_in_threads(fn, count: int = 8) -> None:
30+
with ThreadPoolExecutor(max_workers=count) as executor:
31+
futures = [executor.submit(fn, i) for i in range(count)]
32+
for future in futures:
33+
# Propagate any exception raised in the worker thread.
34+
future.result()
35+
36+
37+
def test_concurrent_access_to_shared_structures() -> None:
    """Exercise SqlSchema, Config, and DataFrame concurrently."""
    schema = SqlSchema("concurrency")
    config = Config()
    ctx = SessionContext()

    batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int32())], names=["value"])
    df = ctx.create_dataframe([[batch]])

    config_key = "datafusion.execution.batch_size"
    expected_rows = batch.num_rows

    def worker(thread_id: int) -> None:
        # Mutate and read the schema name from several threads at once.
        schema.name = f"concurrency-{thread_id}"
        assert schema.name.startswith("concurrency-")
        # These getters go through internal locks; hammer them concurrently.
        for listing in (schema.tables, schema.views, schema.functions):
            assert isinstance(listing, list)

        config.set(config_key, str(1024 + thread_id))
        assert config.get(config_key) is not None
        # Reading the full config map stresses the same lock from another path.
        assert config_key in config.get_all()

        collected = df.collect()
        assert sum(rb.num_rows for rb in collected) == expected_rows

    _run_in_threads(worker, count=12)
67+
68+
69+
def test_case_builder_reuse_from_multiple_threads() -> None:
    """Ensure the case builder can be safely reused across threads."""
    ctx = SessionContext()
    values = pa.array([0, 1, 2, 3, 4], type=pa.int32())
    df = ctx.create_dataframe([[pa.record_batch([values], names=["value"])]])

    # One CASE builder shared by every worker thread.
    shared_case = f.case(col("value"))

    def append_branch(i: int) -> None:
        shared_case.when(lit(i), lit(f"value-{i}"))

    _run_in_threads(append_branch, count=8)

    # Finish the builder from yet another thread.
    with ThreadPoolExecutor(max_workers=2) as pool:
        case_expr = pool.submit(shared_case.otherwise, lit("default")).result()

    labelled = df.select(case_expr.alias("label")).collect()
    assert sum(rb.num_rows for rb in labelled) == len(values)

    # Same exercise for the predicate-style WHEN builder.
    shared_when = f.when(col("value") == lit(0), lit("zero"))

    def append_when(i: int) -> None:
        shared_when.when(col("value") == lit(i + 1), lit(f"value-{i + 1}"))

    _run_in_threads(append_when, count=4)

    with ThreadPoolExecutor(max_workers=2) as pool:
        predicate_expr = pool.submit(shared_when.end).result()

    labelled = df.select(predicate_expr.alias("label")).collect()
    assert sum(rb.num_rows for rb in labelled) == len(values)

0 commit comments

Comments
 (0)