From b8bdaba3a4d831b76618f96cffe978185009ad07 Mon Sep 17 00:00:00 2001
From: joshua-spacetime
Date: Thu, 12 Feb 2026 16:16:58 -0800
Subject: [PATCH] Fix query optimization for semijoins

---
Reviewer notes (free text in the patch's notes area; not part of the diff):

* crates/physical-plan/src/plan.rs: in the `Self::IxJoin(join, Semi::All)`
  arm, the required-label list still drops `join.rhs_label`, but it now also
  appends `join.lhs_field.label` when that label is not already present,
  before the list is handed to `join.lhs.introduce_semijoins(reqs)`.
  NOTE(review): presumably this keeps the lhs subtree from being rewritten so
  that it no longer yields the relation this join reads its key from; the
  hunk alone does not show that -- confirm against the full
  `introduce_semijoins` implementation.

* crates/core/src/sql/execute.rs: `test_multi_way_join_with_bridge_tables`
  creates the tables `orders`, `customer`, `nation` and `region`, inserts two
  rows into each, and runs a three-way and a four-way join filtered on
  `orders.o_orderstatus = 0`. Each query is asserted to return exactly one
  row (`[10, "NATION_A"]` and `[10, "REGION_A"]` respectively).
  NOTE(review): the third argument of `create_table_for_test` is presumably
  the list of indexed column positions -- verify against its definition.

* NOTE(review): the new test returns `anyhow::Result<()>` while the
  neighbouring `test_insert` returns `ResultTest<()>`; confirm this is
  intentional.

* NOTE(review): leading indentation inside the hunks below was reconstructed
  following rustfmt conventions; verify it against the target files before
  applying the patch.

 crates/core/src/sql/execute.rs   | 79 ++++++++++++++++++++++++++++++++
 crates/physical-plan/src/plan.rs |  5 +-
 2 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs
index 64b4bf8fe89..f42b0c10a83 100644
--- a/crates/core/src/sql/execute.rs
+++ b/crates/core/src/sql/execute.rs
@@ -1303,6 +1303,85 @@ pub(crate) mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_multi_way_join_with_bridge_tables() -> anyhow::Result<()> {
+        let db = TestDB::durable()?;
+
+        let orders = db.create_table_for_test(
+            "orders",
+            &[
+                ("o_orderkey", AlgebraicType::U64),
+                ("o_custkey", AlgebraicType::U64),
+                ("o_orderstatus", AlgebraicType::U64),
+            ],
+            &[0.into(), 1.into(), 2.into()],
+        )?;
+
+        let customer = db.create_table_for_test(
+            "customer",
+            &[("c_custkey", AlgebraicType::U64), ("c_nationkey", AlgebraicType::U64)],
+            &[0.into(), 1.into()],
+        )?;
+
+        let nation = db.create_table_for_test(
+            "nation",
+            &[
+                ("n_nationkey", AlgebraicType::U64),
+                ("n_name", AlgebraicType::String),
+                ("n_regionkey", AlgebraicType::U64),
+            ],
+            &[0.into(), 2.into()],
+        )?;
+
+        let region = db.create_table_for_test(
+            "region",
+            &[("r_regionkey", AlgebraicType::U64), ("r_name", AlgebraicType::String)],
+            &[0.into()],
+        )?;
+
+        insert_rows(&db, orders, [product![1u64, 10u64, 0u64], product![2u64, 20u64, 1u64]])?;
+        insert_rows(&db, customer, [product![10u64, 100u64], product![20u64, 200u64]])?;
+        insert_rows(
+            &db,
+            nation,
+            [
+                product![100u64, "NATION_A", 1000u64],
+                product![200u64, "NATION_B", 2000u64],
+            ],
+        )?;
+        insert_rows(
+            &db,
+            region,
+            [product![1000u64, "REGION_A"], product![2000u64, "REGION_B"]],
+        )?;
+
+        let result_three_way = run_for_testing(
+            &db,
+            "
+            SELECT customer.c_custkey, nation.n_name
+            FROM orders
+            JOIN customer ON customer.c_custkey = orders.o_custkey
+            JOIN nation ON nation.n_nationkey = customer.c_nationkey
+            WHERE orders.o_orderstatus = 0",
+        )?;
+
+        assert_eq!(result_three_way, vec![product![10u64, "NATION_A"]]);
+
+        let result_four_way = run_for_testing(
+            &db,
+            "
+            SELECT customer.c_custkey, region.r_name
+            FROM orders
+            JOIN customer ON customer.c_custkey = orders.o_custkey
+            JOIN nation ON nation.n_nationkey = customer.c_nationkey
+            JOIN region ON region.r_regionkey = nation.n_regionkey
+            WHERE orders.o_orderstatus = 0",
+        )?;
+
+        assert_eq!(result_four_way, vec![product![10u64, "REGION_A"]]);
+        Ok(())
+    }
+
     #[test]
     fn test_insert() -> ResultTest<()> {
         let (db, mut input) = create_data(1)?;
diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs
index b1572040f72..983d9aa882e 100644
--- a/crates/physical-plan/src/plan.rs
+++ b/crates/physical-plan/src/plan.rs
@@ -826,7 +826,10 @@ impl PhysicalPlan {
                 Self::IxJoin(IxJoin { lhs, ..join }, Semi::Lhs)
             }
             Self::IxJoin(join, Semi::All) => {
-                let reqs = reqs.into_iter().filter(|label| label != &join.rhs_label).collect();
+                let mut reqs: Vec<_> = reqs.into_iter().filter(|label| label != &join.rhs_label).collect();
+                if !reqs.contains(&join.lhs_field.label) {
+                    reqs.push(join.lhs_field.label);
+                }
                 let lhs = join.lhs.introduce_semijoins(reqs);
                 let lhs = Box::new(lhs);
                 Self::IxJoin(IxJoin { lhs, ..join }, Semi::All)