Skip to content

Commit a7de15e

Browse files
committed
Fix call to internal function. Drive-by update to equality on logical plan. Switch unit test to focus on JSON parsing and not byte serialization.
1 parent 63da3ae commit a7de15e

File tree

7 files changed

+89
-37
lines changed

7 files changed

+89
-37
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,4 @@ crate-type = ["cdylib", "rlib"]
9090

9191
[profile.release]
9292
lto = true
93-
codegen-units = 1
93+
codegen-units = 1

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,6 @@ dev = [
185185
"toml>=0.10.2",
186186
"pygithub==2.5.0",
187187
"codespell==2.4.1",
188-
"protobuf>=6.33.5",
189-
"substrait>=0.27.0",
190188
]
191189
docs = [
192190
"sphinx>=7.1.2",
@@ -198,4 +196,4 @@ docs = [
198196
"pickleshare>=0.7.5",
199197
"sphinx-autoapi>=3.4.0",
200198
"setuptools>=75.3.0",
201-
]
199+
]

python/datafusion/plan.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ def to_proto(self) -> bytes:
9898
"""
9999
return self._raw_plan.to_proto()
100100

101+
def __eq__(self, other: LogicalPlan) -> bool:
102+
"""Test equality."""
103+
if not isinstance(other, LogicalPlan):
104+
return False
105+
return self._raw_plan.__eq__(other._raw_plan)
106+
101107

102108
class ExecutionPlan:
103109
"""Represent nodes in the DataFusion Physical Plan."""

python/datafusion/substrait.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,17 @@ def to_json(self) -> str:
7676
return self.plan_internal.to_json()
7777

7878
@staticmethod
79-
def parse_json(json: str) -> Plan:
80-
"""
81-
Parse a plan from a JSON string.
79+
def from_json(json: str) -> Plan:
80+
"""Parse a plan from a JSON string representation.
8281
8382
Args:
8483
json: JSON representation of a Substrait plan.
8584
8685
Returns:
87-
PyPlan object representing the Substrait plan.
86+
Plan object representing the Substrait plan.
8887
"""
89-
return Plan(substrait_internal.Plan.parse_json(json))
88+
return Plan(substrait_internal.Plan.from_json(json))
89+
9090

9191
@deprecated("Use `Plan` instead.")
9292
class plan(Plan): # noqa: N801
@@ -210,4 +210,4 @@ def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan:
210210

211211
@deprecated("Use `Consumer` instead.")
212212
class consumer(Consumer): # noqa: N801
213-
"""Use `Consumer` instead."""
213+
"""Use `Consumer` instead."""

python/tests/test_substrait.py

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
import pytest
2020
from datafusion import SessionContext
2121
from datafusion import substrait as ss
22-
from substrait import plan_pb2
23-
from google.protobuf import json_format
22+
2423

2524
@pytest.fixture
2625
def ctx():
@@ -77,23 +76,74 @@ def test_substrait_file_serialization(ctx, tmp_path, path_to_str):
7776
assert str(expected_logical_plan) == str(expected_actual_plan)
7877

7978

80-
def test_plan_json_stability_and_validator_compatibility(ctx):
81-
sql = "SELECT * FROM (VALUES (1, 4), (2, 5), (3, 6)) AS t(a, b)"
82-
83-
expected_binary_plan = ss.Serde.serialize_bytes(sql, ctx)
84-
actual_plan = ss.Serde.serialize_to_plan(sql, ctx)
85-
86-
json_str = actual_plan.to_json()
87-
reconstructed_plan = ss.Plan.parse_json(json_str)
88-
89-
expected_logical_plan = ss.Consumer.from_substrait_plan(ctx, actual_plan)
90-
actual_logical_plan = ss.Consumer.from_substrait_plan(ctx, reconstructed_plan)
91-
92-
assert str(expected_logical_plan) == str(actual_logical_plan)
93-
94-
# Verify that the JSON can be parsed by the Substrait protobuf library
95-
proto_plan = plan_pb2.Plan()
96-
json_format.Parse(json_str, proto_plan)
97-
actual_binary_plan = proto_plan.SerializeToString()
79+
def test_json_processing_round_trip(ctx: SessionContext):
80+
ctx.register_record_batches("t", [[pa.record_batch({"a": [1]})]])
81+
original_logical_plan = ctx.sql("SELECT * FROM t").logical_plan()
82+
83+
substrait_plan = ss.Producer.to_substrait_plan(original_logical_plan, ctx)
84+
json_plan = substrait_plan.to_json()
85+
86+
expected = """\
87+
"relations": [
88+
{
89+
"root": {
90+
"input": {
91+
"project": {
92+
"common": {
93+
"emit": {
94+
"outputMapping": [
95+
1
96+
]
97+
}
98+
},
99+
"input": {
100+
"read": {
101+
"baseSchema": {
102+
"names": [
103+
"a"
104+
],
105+
"struct": {
106+
"types": [
107+
{
108+
"i64": {
109+
"nullability": "NULLABILITY_NULLABLE"
110+
}
111+
}
112+
],
113+
"nullability": "NULLABILITY_REQUIRED"
114+
}
115+
},
116+
"namedTable": {
117+
"names": [
118+
"t"
119+
]
120+
}
121+
}
122+
},
123+
"expressions": [
124+
{
125+
"selection": {
126+
"directReference": {
127+
"structField": {}
128+
},
129+
"rootReference": {}
130+
}
131+
}
132+
]
133+
}
134+
},
135+
"names": [
136+
"a"
137+
]
138+
}
139+
}
140+
]"""
141+
142+
assert expected in json_plan
143+
144+
round_trip_substrait_plan = ss.Plan.from_json(json_plan)
145+
round_trip_logical_plan = ss.Consumer.from_substrait_plan(
146+
ctx, round_trip_substrait_plan
147+
)
98148

99-
assert expected_binary_plan == actual_binary_plan
149+
assert round_trip_logical_plan == original_logical_plan

src/sql/logical.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ use crate::expr::unnest::PyUnnest;
6666
use crate::expr::values::PyValues;
6767
use crate::expr::window::PyWindowExpr;
6868

69-
#[pyclass(frozen, name = "LogicalPlan", module = "datafusion", subclass)]
70-
#[derive(Debug, Clone)]
69+
#[pyclass(frozen, name = "LogicalPlan", module = "datafusion", subclass, eq)]
70+
#[derive(Debug, Clone, PartialEq, Eq)]
7171
pub struct PyLogicalPlan {
7272
pub(crate) plan: Arc<LogicalPlan>,
7373
}

src/substrait.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,14 @@ impl PyPlan {
4545

4646
/// Get the JSON representation of the substrait plan
4747
fn to_json(&self) -> PyDataFusionResult<String> {
48-
let json = serde_json::to_string_pretty(&self.plan)
49-
.map_err(|e| to_datafusion_err(e))?;
48+
let json = serde_json::to_string_pretty(&self.plan).map_err(|e| to_datafusion_err(e))?;
5049
Ok(json)
5150
}
5251

5352
/// Parse a Substrait Plan from its JSON representation
5453
#[staticmethod]
5554
fn from_json(json: &str) -> PyDataFusionResult<PyPlan> {
56-
let plan: Plan = serde_json::from_str(json)
57-
.map_err(|e| to_datafusion_err(e))?;
55+
let plan: Plan = serde_json::from_str(json).map_err(|e| to_datafusion_err(e))?;
5856
Ok(PyPlan { plan })
5957
}
6058
}

0 commit comments

Comments
 (0)