Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/substrait/type_inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.extended_expression_pb2 as stee
import substrait.gen.proto.type_pb2 as stt
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.type_pb2 as stt


def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type:
Expand Down Expand Up @@ -127,7 +127,7 @@ def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type:
raise Exception(f"Unknown literal_type {literal_type}")


def infer_nested_type(nested: stalg.Expression.Nested) -> stt.Type:
def infer_nested_type(nested: stalg.Expression.Nested, parent_schema) -> stt.Type:
nested_type = nested.WhichOneof("nested_type")

nullability = (
Expand All @@ -139,22 +139,27 @@ def infer_nested_type(nested: stalg.Expression.Nested) -> stt.Type:
if nested_type == "struct":
return stt.Type(
struct=stt.Type.Struct(
types=[infer_expression_type(f) for f in nested.struct.fields],
types=[
infer_expression_type(f, parent_schema)
for f in nested.struct.fields
],
nullability=nullability,
)
)
elif nested_type == "list":
return stt.Type(
list=stt.Type.List(
type=infer_expression_type(nested.list.values[0]),
type=infer_expression_type(nested.list.values[0], parent_schema),
nullability=nullability,
)
)
elif nested_type == "map":
return stt.Type(
map=stt.Type.Map(
key=infer_expression_type(nested.map.key_values[0].key),
value=infer_expression_type(nested.map.key_values[0].value),
key=infer_expression_type(nested.map.key_values[0].key, parent_schema),
value=infer_expression_type(
nested.map.key_values[0].value, parent_schema
),
nullability=nullability,
)
)
Expand Down Expand Up @@ -191,17 +196,19 @@ def infer_expression_type(
elif rex_type == "window_function":
return expression.window_function.output_type
elif rex_type == "if_then":
return infer_expression_type(expression.if_then.ifs[0].then)
return infer_expression_type(expression.if_then.ifs[0].then, parent_schema)
elif rex_type == "switch_expression":
return infer_expression_type(expression.switch_expression.ifs[0].then)
return infer_expression_type(
expression.switch_expression.ifs[0].then, parent_schema
)
elif rex_type == "cast":
return expression.cast.type
elif rex_type == "singular_or_list" or rex_type == "multi_or_list":
return stt.Type(
bool=stt.Type.Boolean(nullability=stt.Type.Nullability.NULLABILITY_NULLABLE)
)
elif rex_type == "nested":
return infer_nested_type(expression.nested)
return infer_nested_type(expression.nested, parent_schema)
elif rex_type == "subquery":
subquery_type = expression.subquery.WhichOneof("subquery_type")

Expand Down
150 changes: 148 additions & 2 deletions tests/test_type_inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.type_pb2 as stt
from substrait.type_inference import infer_rel_schema

from substrait.type_inference import (
infer_expression_type,
infer_nested_type,
infer_rel_schema,
)

struct = stt.Type.Struct(
types=[
Expand Down Expand Up @@ -312,3 +315,146 @@ def test_inference_join_left_mark():
)

assert infer_rel_schema(rel) == expected


def test_infer_expression_type_literal():
    """A non-nullable i64 literal must be inferred as a required i64 type."""
    # Build the literal first, then wrap it in an Expression.
    i64_literal = stalg.Expression.Literal(i64=42, nullable=False)
    literal_expr = stalg.Expression(literal=i64_literal)

    inferred = infer_expression_type(literal_expr, struct)

    required_i64 = stt.Type(
        i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)
    )
    assert inferred == required_i64


def test_infer_expression_type_selection():
    """A root-level reference to struct field 0 must resolve to a required i64.

    The reference is resolved against the module-level ``struct`` schema that
    is passed in as the parent schema.
    """
    # Direct reference segment pointing at the first field of the parent schema.
    first_field_segment = stalg.Expression.ReferenceSegment(
        struct_field=stalg.Expression.ReferenceSegment.StructField(field=0),
    )
    field_reference = stalg.Expression.FieldReference(
        root_reference=stalg.Expression.FieldReference.RootReference(),
        direct_reference=first_field_segment,
    )
    selection_expr = stalg.Expression(selection=field_reference)

    inferred = infer_expression_type(selection_expr, struct)

    # The test asserts that field 0 of the parent schema is a required i64.
    required_i64 = stt.Type(
        i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)
    )
    assert inferred == required_i64


def test_infer_expression_type_window_function():
    """A window function expression must yield its declared ``output_type``."""
    declared_output = stt.Type(
        i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE)
    )
    window_fn = stalg.Expression.WindowFunction(
        function_reference=0,
        output_type=declared_output,
    )
    window_expr = stalg.Expression(window_function=window_fn)

    inferred = infer_expression_type(window_expr, struct)

    # Compare against a freshly built type rather than the one fed in.
    nullable_i64 = stt.Type(
        i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE)
    )
    assert inferred == nullable_i64


def test_infer_nested_type_struct():
    """A nested struct of literals must map to a struct of the literal types.

    The input has a required i32 field and a nullable string field; the
    nested expression itself is non-nullable.
    """
    required_i32_field = stalg.Expression(
        literal=stalg.Expression.Literal(i32=1, nullable=False)
    )
    nullable_string_field = stalg.Expression(
        literal=stalg.Expression.Literal(string="test", nullable=True)
    )
    nested_struct = stalg.Expression.Nested(
        struct=stalg.Expression.Nested.Struct(
            fields=[required_i32_field, nullable_string_field]
        ),
        nullable=False,
    )
    wrapper = stalg.Expression(nested=nested_struct)

    inferred = infer_nested_type(wrapper.nested, struct)

    expected_field_types = [
        stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_REQUIRED)),
        stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)),
    ]
    expected_struct = stt.Type(
        struct=stt.Type.Struct(
            types=expected_field_types,
            nullability=stt.Type.NULLABILITY_REQUIRED,
        )
    )
    assert inferred == expected_struct


def test_infer_nested_type_list():
    """A nested list with one required fp32 literal must be a list of fp32."""
    fp32_element = stalg.Expression(
        literal=stalg.Expression.Literal(fp32=3.14, nullable=False)
    )
    nested_list = stalg.Expression.Nested(
        list=stalg.Expression.Nested.List(values=[fp32_element]),
        nullable=False,
    )
    wrapper = stalg.Expression(nested=nested_list)

    inferred = infer_nested_type(wrapper.nested, struct)

    element_type = stt.Type(
        fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_REQUIRED)
    )
    expected_list = stt.Type(
        list=stt.Type.List(
            type=element_type,
            nullability=stt.Type.NULLABILITY_REQUIRED,
        )
    )
    assert inferred == expected_list


def test_infer_nested_type_map():
    """A nested map with one string->i32 entry must be a map of those types."""
    key_expr = stalg.Expression(
        literal=stalg.Expression.Literal(string="key", nullable=False)
    )
    value_expr = stalg.Expression(
        literal=stalg.Expression.Literal(i32=42, nullable=False)
    )
    single_entry = stalg.Expression.Nested.Map.KeyValue(
        key=key_expr,
        value=value_expr,
    )
    nested_map = stalg.Expression.Nested(
        map=stalg.Expression.Nested.Map(key_values=[single_entry]),
        nullable=False,
    )
    wrapper = stalg.Expression(nested=nested_map)

    inferred = infer_nested_type(wrapper.nested, struct)

    expected_key = stt.Type(
        string=stt.Type.String(nullability=stt.Type.NULLABILITY_REQUIRED)
    )
    expected_value = stt.Type(
        i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_REQUIRED)
    )
    expected_map = stt.Type(
        map=stt.Type.Map(
            key=expected_key,
            value=expected_value,
            nullability=stt.Type.NULLABILITY_REQUIRED,
        )
    )
    assert inferred == expected_map