From 1f514038f19384e0202fa302a3ac9ab1e19d586a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 17 May 2026 14:03:23 -0400 Subject: [PATCH] feat: user-defined OptimizerRule and AnalyzerRule from Python Expose `SessionContext.add_optimizer_rule` and `SessionContext.add_analyzer_rule` symmetric with the existing `remove_optimizer_rule`. Each accepts a Python subclass of the new `datafusion.optimizer.OptimizerRule` / `AnalyzerRule` ABCs. Implementation: * New `crates/core/src/optimizer_rules.rs` wraps user Python instances in `PyOptimizerRuleAdapter` / `PyAnalyzerRuleAdapter`, which implement the upstream `OptimizerRule` / `AnalyzerRule` traits. * `OptimizerRule.rewrite(plan)` returns `None` for "no change" or a new `LogicalPlan`. The adapter maps that to `Transformed::no` / `Transformed::yes` so the upstream optimizer's fixed-point loop terminates correctly. * `AnalyzerRule.analyze(plan)` must always return a `LogicalPlan`; returning `None` surfaces a `DataFusionError::Execution` naming the offending rule. * The upstream `&dyn OptimizerConfig` / `&ConfigOptions` arguments are not surfaced to Python in this MVP; rules that need configuration should capture it at construction time (for example by holding a `SessionContext` reference) or be implemented in Rust. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/context.rs | 12 +++ crates/core/src/lib.rs | 1 + crates/core/src/optimizer_rules.rs | 168 +++++++++++++++++++++++++++++ python/datafusion/context.py | 47 ++++++++ python/datafusion/optimizer.py | 144 +++++++++++++++++++++++++ python/tests/test_optimizer.py | 111 +++++++++++++++++++ 6 files changed, 483 insertions(+) create mode 100644 crates/core/src/optimizer_rules.rs create mode 100644 python/datafusion/optimizer.py create mode 100644 python/tests/test_optimizer.py diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 642afeef7..67f8e001a 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -1145,6 +1145,18 @@ impl PySessionContext { self.ctx.remove_optimizer_rule(name) } + pub fn add_optimizer_rule(&self, rule: Bound<'_, PyAny>) -> PyResult<()> { + let adapter = crate::optimizer_rules::build_optimizer_rule(rule)?; + self.ctx.add_optimizer_rule(adapter); + Ok(()) + } + + pub fn add_analyzer_rule(&self, rule: Bound<'_, PyAny>) -> PyResult<()> { + let adapter = crate::optimizer_rules::build_analyzer_rule(rule)?; + self.ctx.add_analyzer_rule(adapter); + Ok(()) + } + pub fn table_provider(&self, name: &str, py: Python) -> PyResult { let provider = wait_for_future(py, self.ctx.table_provider(name)) // Outer error: runtime/async failure diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 8b622d344..1c1227ce2 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -45,6 +45,7 @@ pub mod expr; #[allow(clippy::borrow_deref_ref)] mod functions; pub mod metrics; +pub mod optimizer_rules; mod options; pub mod physical_plan; mod pyarrow_filter_expression; diff --git a/crates/core/src/optimizer_rules.rs b/crates/core/src/optimizer_rules.rs new file mode 100644 index 000000000..a281272ed --- /dev/null +++ b/crates/core/src/optimizer_rules.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Bridges between user-provided Python rule classes and the upstream +//! [`OptimizerRule`] / [`AnalyzerRule`] traits. +//! +//! The Python side defines abstract base classes ``OptimizerRule`` and +//! ``AnalyzerRule`` with ``name()`` plus, respectively, ``rewrite(plan)`` +//! and ``analyze(plan)``. Instances are wrapped in +//! [`PyOptimizerRuleAdapter`] / [`PyAnalyzerRuleAdapter`] before being +//! handed to [`SessionContext::add_optimizer_rule`] / +//! [`SessionContext::add_analyzer_rule`]. +//! +//! `rewrite` may return ``None`` to signal "no transformation" — the +//! adapter maps that to [`Transformed::no`]. Any returned +//! :class:`LogicalPlan` becomes [`Transformed::yes`]. `analyze` is +//! mandatory-rewrite (must return a plan); returning ``None`` is an +//! error. +//! +//! The upstream ``&dyn OptimizerConfig`` / ``&ConfigOptions`` arguments +//! are not surfaced to Python in this MVP. Rules that need configuration +//! access should be implemented in Rust today; Python rules read state +//! from the plan and from any captured ``SessionContext`` they were +//! constructed with. + +use std::fmt; +use std::sync::Arc; + +use datafusion::common::config::ConfigOptions; +use datafusion::common::tree_node::Transformed; +use datafusion::error::{DataFusionError, Result as DataFusionResult}; +use datafusion::logical_expr::LogicalPlan; +use datafusion::optimizer::analyzer::AnalyzerRule; +use datafusion::optimizer::optimizer::{OptimizerConfig, OptimizerRule}; +use pyo3::prelude::*; + +use crate::errors::to_datafusion_err; +use crate::sql::logical::PyLogicalPlan; + +/// Wraps a Python ``OptimizerRule`` instance so that it can be registered +/// with the upstream optimizer pipeline. +pub struct PyOptimizerRuleAdapter { + rule: Py, + name: String, +} + +impl PyOptimizerRuleAdapter { + pub fn new(rule: Bound<'_, PyAny>) -> PyResult { + let name = rule.call_method0("name")?.extract::()?; + Ok(Self { + rule: rule.unbind(), + name, + }) + } +} + +impl fmt::Debug for PyOptimizerRuleAdapter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PyOptimizerRuleAdapter") + .field("name", &self.name) + .finish() + } +} + +impl OptimizerRule for PyOptimizerRuleAdapter { + fn name(&self) -> &str { + &self.name + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> DataFusionResult> { + Python::attach(|py| { + let py_plan = PyLogicalPlan::from(plan.clone()); + let result = self + .rule + .bind(py) + .call_method1("rewrite", (py_plan,)) + .map_err(to_datafusion_err)?; + if result.is_none() { + return Ok(Transformed::no(plan)); + } + let rewritten: PyLogicalPlan = result.extract().map_err(to_datafusion_err)?; + Ok(Transformed::yes(LogicalPlan::from(rewritten))) + }) + } +} + +/// Wraps a Python ``AnalyzerRule`` instance so that it can be registered +/// with the upstream analyzer pipeline. +pub struct PyAnalyzerRuleAdapter { + rule: Py, + name: String, +} + +impl PyAnalyzerRuleAdapter { + pub fn new(rule: Bound<'_, PyAny>) -> PyResult { + let name = rule.call_method0("name")?.extract::()?; + Ok(Self { + rule: rule.unbind(), + name, + }) + } +} + +impl fmt::Debug for PyAnalyzerRuleAdapter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PyAnalyzerRuleAdapter") + .field("name", &self.name) + .finish() + } +} + +impl AnalyzerRule for PyAnalyzerRuleAdapter { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> DataFusionResult { + Python::attach(|py| { + let py_plan = PyLogicalPlan::from(plan); + let result = self + .rule + .bind(py) + .call_method1("analyze", (py_plan,)) + .map_err(to_datafusion_err)?; + if result.is_none() { + return Err(DataFusionError::Execution(format!( + "AnalyzerRule {} returned None from analyze(); analyzer rules \ + must return a LogicalPlan", + self.name + ))); + } + let rewritten: PyLogicalPlan = result.extract().map_err(to_datafusion_err)?; + Ok(LogicalPlan::from(rewritten)) + }) + } + + fn name(&self) -> &str { + &self.name + } +} + +/// Construct an adapter from a Python ``OptimizerRule`` instance. +pub(crate) fn build_optimizer_rule( + rule: Bound<'_, PyAny>, +) -> PyResult> { + Ok(Arc::new(PyOptimizerRuleAdapter::new(rule)?)) +} + +/// Construct an adapter from a Python ``AnalyzerRule`` instance. +pub(crate) fn build_analyzer_rule( + rule: Bound<'_, PyAny>, +) -> PyResult> { + Ok(Arc::new(PyAnalyzerRuleAdapter::new(rule)?)) +} diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5c3501941..e2a4305c4 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -90,6 +90,7 @@ from datafusion.catalog import CatalogProvider, Table from datafusion.common import DFSchema from datafusion.expr import Expr, SortKey + from datafusion.optimizer import AnalyzerRule, OptimizerRule from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.user_defined import ( AggregateUDF, @@ -1260,6 +1261,52 @@ def register_udwf(self, udwf: WindowUDF) -> None: """Register a user-defined window function (UDWF) with the context.""" self.ctx.register_udwf(udwf._udwf) + def add_optimizer_rule(self, rule: OptimizerRule) -> None: + """Append a user-defined :class:`OptimizerRule` to the session. + + The rule's :py:meth:`OptimizerRule.rewrite` method is invoked + during query planning. Returning ``None`` from ``rewrite`` + signals no change; returning a new + :class:`~datafusion.plan.LogicalPlan` signals a rewrite. + + Args: + rule: An instance of a class that implements + :class:`datafusion.optimizer.OptimizerRule`. + + Examples: + >>> from datafusion.optimizer import OptimizerRule + >>> class NoopRule(OptimizerRule): + ... def name(self) -> str: return "noop" + ... def rewrite(self, plan): return None + >>> ctx = dfn.SessionContext() + >>> ctx.add_optimizer_rule(NoopRule()) + >>> ctx.remove_optimizer_rule("noop") + True + """ + self.ctx.add_optimizer_rule(rule) + + def add_analyzer_rule(self, rule: AnalyzerRule) -> None: + """Append a user-defined :class:`AnalyzerRule` to the session. + + The rule's :py:meth:`AnalyzerRule.analyze` method is invoked + during the analysis phase of query planning. Analyzer rules + must always return a :class:`~datafusion.plan.LogicalPlan` + (return the input plan unchanged when no rewrite applies). + + Args: + rule: An instance of a class that implements + :class:`datafusion.optimizer.AnalyzerRule`. + + Examples: + >>> from datafusion.optimizer import AnalyzerRule + >>> class Identity(AnalyzerRule): + ... def name(self) -> str: return "identity" + ... def analyze(self, plan): return plan + >>> ctx = dfn.SessionContext() + >>> ctx.add_analyzer_rule(Identity()) + """ + self.ctx.add_analyzer_rule(rule) + def deregister_udwf(self, name: str) -> None: """Remove a user-defined window function from the session. diff --git a/python/datafusion/optimizer.py b/python/datafusion/optimizer.py new file mode 100644 index 000000000..84beca001 --- /dev/null +++ b/python/datafusion/optimizer.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Abstract base classes for user-defined optimizer and analyzer rules. + +DataFusion's planner is built from two pipelines: + +* The :class:`Analyzer ` runs first and is responsible for + semantic rewrites — type coercion, function lookup, and rewrites that + cannot leave the plan structurally unchanged. Analyzer rules must + return a fully rewritten :class:`~datafusion.plan.LogicalPlan` every + time they run. +* The :class:`Optimizer ` runs afterwards and applies + cost-driven or semantics-preserving transformations until a fixed + point is reached. Optimizer rules may return ``None`` to signal "no + change," letting the optimizer terminate as soon as no rule mutates + the plan. + +Both ABCs are registered against a :class:`~datafusion.SessionContext` +through :py:meth:`~datafusion.SessionContext.add_optimizer_rule` / +:py:meth:`~datafusion.SessionContext.add_analyzer_rule`. + +The upstream rule traits also receive an +``OptimizerConfig`` / ``ConfigOptions`` reference. Those are not +surfaced to Python here; rules that need configuration access should +capture state at construction time (for example a +:class:`~datafusion.SessionContext` reference) or be implemented in +Rust. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datafusion.plan import LogicalPlan + +__all__ = ["AnalyzerRule", "OptimizerRule"] + + +class OptimizerRule(ABC): + """Abstract base class for a user-defined optimizer rule. + + Subclasses must implement :py:meth:`name` and :py:meth:`rewrite`. + + Examples: + >>> import datafusion as dfn + >>> from datafusion.optimizer import OptimizerRule + >>> from datafusion.plan import LogicalPlan + >>> + >>> class TaggingRule(OptimizerRule): + ... # Mark each plan we see; never actually mutate it. + ... def __init__(self) -> None: + ... self.seen = 0 + ... + ... def name(self) -> str: + ... return "tagging_rule" + ... + ... def rewrite(self, plan: LogicalPlan) -> LogicalPlan | None: + ... self.seen += 1 + ... return None + >>> + >>> ctx = dfn.SessionContext() + >>> rule = TaggingRule() + >>> ctx.add_optimizer_rule(rule) + >>> ctx.from_pydict({"a": [1]}).count() + 1 + >>> rule.seen > 0 + True + """ + + @abstractmethod + def name(self) -> str: + """Return a unique name for this rule. + + DataFusion uses the name to deduplicate rules and to support + removal via :py:meth:`~datafusion.SessionContext.remove_optimizer_rule`. + """ + + @abstractmethod + def rewrite(self, plan: LogicalPlan) -> LogicalPlan | None: + """Attempt to rewrite ``plan``. + + Return a new :class:`~datafusion.plan.LogicalPlan` if the rule + produced one, or ``None`` to indicate no change. The optimizer + calls each rule repeatedly until no rule reports a change, so + returning ``None`` when nothing was rewritten is important for + termination. + """ + + +class AnalyzerRule(ABC): + """Abstract base class for a user-defined analyzer rule. + + Subclasses must implement :py:meth:`name` and :py:meth:`analyze`. + Unlike optimizer rules, analyzer rules must always return a + :class:`~datafusion.plan.LogicalPlan` (return the input plan + unmodified when nothing applies). + + Examples: + >>> import datafusion as dfn + >>> from datafusion.optimizer import AnalyzerRule + >>> from datafusion.plan import LogicalPlan + >>> + >>> class IdentityAnalyzer(AnalyzerRule): + ... def name(self) -> str: + ... return "identity_analyzer" + ... + ... def analyze(self, plan: LogicalPlan) -> LogicalPlan: + ... return plan + >>> + >>> ctx = dfn.SessionContext() + >>> ctx.add_analyzer_rule(IdentityAnalyzer()) + >>> ctx.from_pydict({"a": [1, 2, 3]}).count() + 3 + """ + + @abstractmethod + def name(self) -> str: + """Return a unique name for this rule.""" + + @abstractmethod + def analyze(self, plan: LogicalPlan) -> LogicalPlan: + """Rewrite ``plan`` and return the new plan. + + Analyzer rules must always return a + :class:`~datafusion.plan.LogicalPlan`. Return the input plan + unchanged when there is nothing to rewrite. + """ diff --git a/python/tests/test_optimizer.py b/python/tests/test_optimizer.py new file mode 100644 index 000000000..5c161573b --- /dev/null +++ b/python/tests/test_optimizer.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from datafusion import SessionContext +from datafusion.optimizer import AnalyzerRule, OptimizerRule +from datafusion.plan import LogicalPlan + + +def test_optimizer_rule_is_invoked_during_planning() -> None: + """A registered OptimizerRule is called as the plan is optimized.""" + seen_plans: list[str] = [] + + class TracingRule(OptimizerRule): + def name(self) -> str: + return "tracing_rule" + + def rewrite(self, plan: LogicalPlan) -> LogicalPlan | None: + seen_plans.append(plan.display()) + return None + + ctx = SessionContext() + ctx.add_optimizer_rule(TracingRule()) + df = ctx.from_pydict({"a": [1, 2, 3]}) + result = df.collect() + + # The rule sees the plan at least once during optimization, and + # since it returns None each time the optimizer terminates cleanly. + assert seen_plans, "optimizer rule was not invoked during planning" + assert result[0].column(0).to_pylist() == [1, 2, 3] + + +def test_optimizer_rule_can_be_removed_by_name() -> None: + """remove_optimizer_rule deregisters a user-supplied rule by name.""" + + class NoopRule(OptimizerRule): + def name(self) -> str: + return "noop_for_removal" + + def rewrite(self, plan: LogicalPlan) -> LogicalPlan | None: + return None + + ctx = SessionContext() + ctx.add_optimizer_rule(NoopRule()) + assert ctx.remove_optimizer_rule("noop_for_removal") is True + # Second remove returns False — already gone. + assert ctx.remove_optimizer_rule("noop_for_removal") is False + + +def test_analyzer_rule_is_invoked_during_analysis() -> None: + """A registered AnalyzerRule is called and must return a plan.""" + invocations: list[str] = [] + + class IdentityAnalyzer(AnalyzerRule): + def name(self) -> str: + return "identity_analyzer" + + def analyze(self, plan: LogicalPlan) -> LogicalPlan: + invocations.append(plan.display()) + return plan + + ctx = SessionContext() + ctx.add_analyzer_rule(IdentityAnalyzer()) + df = ctx.from_pydict({"a": [1, 2, 3]}) + result = df.collect() + + assert invocations, "analyzer rule was not invoked" + assert result[0].column(0).to_pylist() == [1, 2, 3] + + +def test_analyzer_rule_returning_none_errors() -> None: + """Analyzer rules must return a LogicalPlan; None surfaces as an error.""" + + class BadAnalyzer(AnalyzerRule): + def name(self) -> str: + return "bad_analyzer" + + def analyze(self, plan: LogicalPlan): # type: ignore[override] + return None + + ctx = SessionContext() + ctx.add_analyzer_rule(BadAnalyzer()) + df = ctx.from_pydict({"a": [1]}) + with pytest.raises(Exception, match="bad_analyzer"): + df.collect() + + +def test_optimizer_rule_abc_cannot_be_instantiated() -> None: + """OptimizerRule is abstract — direct instantiation must fail.""" + with pytest.raises(TypeError): + OptimizerRule() # type: ignore[abstract] + + +def test_analyzer_rule_abc_cannot_be_instantiated() -> None: + """AnalyzerRule is abstract — direct instantiation must fail.""" + with pytest.raises(TypeError): + AnalyzerRule() # type: ignore[abstract]