Skip to content

Commit 03a1022

Browse files
committed
Add tests for CaseBuilder to ensure builder state is preserved on success
1 parent cfc9f2c commit 03a1022

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

python/tests/test_expr.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,29 @@ def test_case_builder_error_preserves_builder_state():
218218
assert "CaseBuilder has already been consumed" not in err_msg
219219

220220

221+
def test_case_builder_success_preserves_builder_state():
222+
ctx = SessionContext()
223+
df = ctx.from_pydict({"flag": [False]}, name="tbl")
224+
225+
case_builder = functions.when(col("flag"), lit("true"))
226+
227+
expr_default_one = case_builder.otherwise(lit("default-1")).alias("result")
228+
result_one = df.select(expr_default_one).collect()
229+
assert result_one[0].column(0).to_pylist() == ["default-1"]
230+
231+
expr_default_two = case_builder.otherwise(lit("default-2")).alias("result")
232+
result_two = df.select(expr_default_two).collect()
233+
assert result_two[0].column(0).to_pylist() == ["default-2"]
234+
235+
expr_end_one = case_builder.end().alias("result")
236+
end_one = df.select(expr_end_one).collect()
237+
assert end_one[0].column(0).to_pylist() == ["default-2"]
238+
239+
expr_end_two = case_builder.end().alias("result")
240+
end_two = df.select(expr_end_two).collect()
241+
assert end_two[0].column(0).to_pylist() == ["default-2"]
242+
243+
221244
def test_expr_getitem() -> None:
222245
ctx = SessionContext()
223246
data = {

src/expr/conditional_expr.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ impl PyCaseBuilder {
7777
fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult<PyExpr> {
7878
let mut builder = self.take_case_builder()?;
7979
match builder.otherwise(else_expr.expr) {
80-
Ok(expr) => Ok(expr.clone().into()),
80+
Ok(expr) => {
81+
self.store_case_builder(builder);
82+
Ok(expr.clone().into())
83+
}
8184
Err(err) => {
8285
self.store_case_builder(builder);
8386
Err(err.into())
@@ -88,7 +91,10 @@ impl PyCaseBuilder {
8891
fn end(&self) -> PyDataFusionResult<PyExpr> {
8992
let builder = self.take_case_builder()?;
9093
match builder.end() {
91-
Ok(expr) => Ok(expr.clone().into()),
94+
Ok(expr) => {
95+
self.store_case_builder(builder);
96+
Ok(expr.clone().into())
97+
}
9298
Err(err) => {
9399
self.store_case_builder(builder);
94100
Err(err.into())

0 commit comments

Comments
 (0)