From 921cf808ce5d141607caad99a2f21ad7f483d444 Mon Sep 17 00:00:00 2001 From: AGV Date: Wed, 10 Sep 2025 09:20:23 +0200 Subject: [PATCH 1/2] feat(write_table): added the write table builder and test Signed-off-by: AGV --- src/substrait/builders/plan.py | 36 +++++++++++++++++++++-- tests/builders/plan/test_write.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 3 deletions(-) create mode 100644 tests/builders/plan/test_write.py diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index a4a2180..6b2ac6a 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -7,16 +7,16 @@ from typing import Iterable, Optional, Union, Callable -import substrait.gen.proto.algebra_pb2 as stalg from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.plan_pb2 as stp import substrait.gen.proto.type_pb2 as stt -import substrait.gen.proto.extended_expression_pb2 as stee -from substrait.extension_registry import ExtensionRegistry from substrait.builders.extended_expression import ( ExtendedExpressionOrUnbound, resolve_expression, ) +from substrait.extension_registry import ExtensionRegistry from substrait.type_inference import infer_plan_schema from substrait.utils import ( merge_extension_declarations, @@ -379,3 +379,33 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: ) return resolve + + +def write_table( + table_names: Union[str, Iterable[str]], + input: PlanOrUnbound, + create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None, +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_input = input if isinstance(input, stp.Plan) else input(registry) + ns = infer_plan_schema(bound_input) + _table_names = [table_names] if isinstance(table_names, str) else table_names + _create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS + + write_rel = stalg.Rel( + write=stalg.WriteRel( + input=bound_input.relations[-1].root.input, + table_schema=ns, + op=stalg.WriteRel.WRITE_OP_CTAS, + create_mode=_create_mode, + named_table=stalg.NamedObjectWrite(names=_table_names), + ) + ) + return stp.Plan( + relations=[ + stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names)) + ], + **_merge_extensions(bound_input), + ) + + return resolve diff --git a/tests/builders/plan/test_write.py b/tests/builders/plan/test_write.py new file mode 100644 index 0000000..dff830e --- /dev/null +++ b/tests/builders/plan/test_write.py @@ -0,0 +1,48 @@ +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.type_pb2 as stt +from substrait.builders.plan import read_named_table, write_table +from substrait.builders.type import boolean, i64 + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + + +def test_write_rel(): + actual = write_table( + "example_table_write_test", + read_named_table("example_table", named_struct), + )(None) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + write=stalg.WriteRel( + input=stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon( + direct=stalg.RelCommon.Direct() + ), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable( + names=["example_table"] + ), + ) + ), + op=stalg.WriteRel.WRITE_OP_CTAS, + table_schema=named_struct, + create_mode=stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS, + named_table=stalg.NamedObjectWrite( + names=["example_table_write_test"] + ), + ) + ), + names=["id", "is_applicable"], + ) + ) + ] + ) + assert actual == expected From 8d008ea17c395a86757984033c2ff68b9f8c6bbc Mon Sep 17 00:00:00 2001 From: Giovanni Spadaccini Date: Mon, 8 Dec 2025 10:33:41 +0100 Subject: [PATCH 2/2] feat: renamed the write_table to write_named_table --- src/substrait/builders/plan.py | 8 ++++---- tests/builders/plan/test_write.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index 6b2ac6a..392960f 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -5,9 +5,8 @@ See `examples/builder_example.py` for usage. """ -from typing import Iterable, Optional, Union, Callable +from typing import Callable, Iterable, Optional, Union -from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.plan_pb2 as stp @@ -17,11 +16,12 @@ resolve_expression, ) from substrait.extension_registry import ExtensionRegistry +from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension from substrait.type_inference import infer_plan_schema from substrait.utils import ( merge_extension_declarations, - merge_extension_urns, merge_extension_uris, + merge_extension_urns, ) UnboundPlan = Callable[[ExtensionRegistry], stp.Plan] @@ -381,7 +381,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: return resolve -def write_table( +def write_named_table( table_names: Union[str, Iterable[str]], input: PlanOrUnbound, create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None, diff --git a/tests/builders/plan/test_write.py b/tests/builders/plan/test_write.py index dff830e..b0e1029 100644 --- a/tests/builders/plan/test_write.py +++ b/tests/builders/plan/test_write.py @@ -1,7 +1,7 @@ import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.plan_pb2 as stp import substrait.gen.proto.type_pb2 as stt -from substrait.builders.plan import read_named_table, write_table +from substrait.builders.plan import read_named_table, write_named_table from substrait.builders.type import boolean, i64 struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) @@ -10,7 +10,7 @@ def test_write_rel(): - actual = write_table( + actual = write_named_table( "example_table_write_test", read_named_table("example_table", named_struct), )(None)