diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py
index d6a68e8..cb8faad 100644
--- a/src/substrait/type_inference.py
+++ b/src/substrait/type_inference.py
@@ -1,7 +1,7 @@
 import substrait.gen.proto.algebra_pb2 as stalg
 import substrait.gen.proto.extended_expression_pb2 as stee
-import substrait.gen.proto.type_pb2 as stt
 import substrait.gen.proto.plan_pb2 as stp
+import substrait.gen.proto.type_pb2 as stt
 
 
 def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type:
@@ -127,7 +127,7 @@ def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type:
         raise Exception(f"Unknown literal_type {literal_type}")
 
 
-def infer_nested_type(nested: stalg.Expression.Nested) -> stt.Type:
+def infer_nested_type(nested: stalg.Expression.Nested, parent_schema) -> stt.Type:
     nested_type = nested.WhichOneof("nested_type")
 
     nullability = (
@@ -139,22 +139,27 @@ def infer_nested_type(nested: stalg.Expression.Nested) -> stt.Type:
     if nested_type == "struct":
         return stt.Type(
             struct=stt.Type.Struct(
-                types=[infer_expression_type(f) for f in nested.struct.fields],
+                types=[
+                    infer_expression_type(f, parent_schema)
+                    for f in nested.struct.fields
+                ],
                 nullability=nullability,
             )
         )
     elif nested_type == "list":
         return stt.Type(
             list=stt.Type.List(
-                type=infer_expression_type(nested.list.values[0]),
+                type=infer_expression_type(nested.list.values[0], parent_schema),
                 nullability=nullability,
             )
         )
     elif nested_type == "map":
         return stt.Type(
             map=stt.Type.Map(
-                key=infer_expression_type(nested.map.key_values[0].key),
-                value=infer_expression_type(nested.map.key_values[0].value),
+                key=infer_expression_type(nested.map.key_values[0].key, parent_schema),
+                value=infer_expression_type(
+                    nested.map.key_values[0].value, parent_schema
+                ),
                 nullability=nullability,
             )
         )
@@ -191,9 +196,11 @@ def infer_expression_type(
     elif rex_type == "window_function":
         return expression.window_function.output_type
     elif rex_type == "if_then":
-        return infer_expression_type(expression.if_then.ifs[0].then)
+        return infer_expression_type(expression.if_then.ifs[0].then, parent_schema)
     elif rex_type == "switch_expression":
-        return infer_expression_type(expression.switch_expression.ifs[0].then)
+        return infer_expression_type(
+            expression.switch_expression.ifs[0].then, parent_schema
+        )
     elif rex_type == "cast":
         return expression.cast.type
     elif rex_type == "singular_or_list" or rex_type == "multi_or_list":
@@ -201,7 +208,7 @@ def infer_expression_type(
             bool=stt.Type.Boolean(nullability=stt.Type.Nullability.NULLABILITY_NULLABLE)
         )
     elif rex_type == "nested":
-        return infer_nested_type(expression.nested)
+        return infer_nested_type(expression.nested, parent_schema)
     elif rex_type == "subquery":
         subquery_type = expression.subquery.WhichOneof("subquery_type")
 
diff --git a/tests/test_type_inference.py b/tests/test_type_inference.py
index d761672..8d20eca 100644
--- a/tests/test_type_inference.py
+++ b/tests/test_type_inference.py
@@ -1,7 +1,10 @@
 import substrait.gen.proto.algebra_pb2 as stalg
 import substrait.gen.proto.type_pb2 as stt
-from substrait.type_inference import infer_rel_schema
-
+from substrait.type_inference import (
+    infer_expression_type,
+    infer_nested_type,
+    infer_rel_schema,
+)
 
 struct = stt.Type.Struct(
     types=[
@@ -312,3 +315,146 @@ def test_inference_join_left_mark():
     )
 
     assert infer_rel_schema(rel) == expected
+
+
+def test_infer_expression_type_literal():
+    """Test infer_expression_type with a literal expression."""
+    expr = stalg.Expression(literal=stalg.Expression.Literal(i64=42, nullable=False))
+
+    result = infer_expression_type(expr, struct)
+
+    expected = stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED))
+    assert result == expected
+
+
+def test_infer_expression_type_selection():
+    """Test infer_expression_type with a field selection expression."""
+    expr = stalg.Expression(
+        selection=stalg.Expression.FieldReference(
+            root_reference=stalg.Expression.FieldReference.RootReference(),
+            direct_reference=stalg.Expression.ReferenceSegment(
+                struct_field=stalg.Expression.ReferenceSegment.StructField(field=0),
+            ),
+        )
+    )
+
+    result = infer_expression_type(expr, struct)
+
+    # Should return the type of field 0 from the struct (i64)
+    expected = stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED))
+    assert result == expected
+
+
+def test_infer_expression_type_window_function():
+    """Test infer_expression_type with a window function expression."""
+    expr = stalg.Expression(
+        window_function=stalg.Expression.WindowFunction(
+            function_reference=0,
+            output_type=stt.Type(
+                i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE)
+            ),
+        )
+    )
+
+    result = infer_expression_type(expr, struct)
+
+    expected = stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE))
+    assert result == expected
+
+
+def test_infer_nested_type_struct():
+    """Test infer_nested_type with a struct nested expression."""
+    expr = stalg.Expression(
+        nested=stalg.Expression.Nested(
+            struct=stalg.Expression.Nested.Struct(
+                fields=[
+                    stalg.Expression(
+                        literal=stalg.Expression.Literal(i32=1, nullable=False)
+                    ),
+                    stalg.Expression(
+                        literal=stalg.Expression.Literal(string="test", nullable=True)
+                    ),
+                ]
+            ),
+            nullable=False,
+        )
+    )
+
+    result = infer_nested_type(expr.nested, struct)
+
+    expected = stt.Type(
+        struct=stt.Type.Struct(
+            types=[
+                stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_REQUIRED)),
+                stt.Type(
+                    string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)
+                ),
+            ],
+            nullability=stt.Type.NULLABILITY_REQUIRED,
+        )
+    )
+    assert result == expected
+
+
+def test_infer_nested_type_list():
+    """Test infer_nested_type with a list nested expression."""
+    expr = stalg.Expression(
+        nested=stalg.Expression.Nested(
+            list=stalg.Expression.Nested.List(
+                values=[
+                    stalg.Expression(
+                        literal=stalg.Expression.Literal(fp32=3.14, nullable=False)
+                    ),
+                ]
+            ),
+            nullable=False,
+        )
+    )
+
+    result = infer_nested_type(expr.nested, struct)
+
+    expected = stt.Type(
+        list=stt.Type.List(
+            type=stt.Type(
+                fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_REQUIRED)
+            ),
+            nullability=stt.Type.NULLABILITY_REQUIRED,
+        )
+    )
+    assert result == expected
+
+
+def test_infer_nested_type_map():
+    """Test infer_nested_type with a map nested expression."""
+    expr = stalg.Expression(
+        nested=stalg.Expression.Nested(
+            map=stalg.Expression.Nested.Map(
+                key_values=[
+                    stalg.Expression.Nested.Map.KeyValue(
+                        key=stalg.Expression(
+                            literal=stalg.Expression.Literal(
+                                string="key", nullable=False
+                            )
+                        ),
+                        value=stalg.Expression(
+                            literal=stalg.Expression.Literal(i32=42, nullable=False)
+                        ),
+                    ),
+                ]
+            ),
+            nullable=False,
+        )
+    )
+
+    result = infer_nested_type(expr.nested, struct)
+
+    expected = stt.Type(
+        map=stt.Type.Map(
+            key=stt.Type(
+                string=stt.Type.String(nullability=stt.Type.NULLABILITY_REQUIRED)
+            ),
+            value=stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_REQUIRED)),
+            nullability=stt.Type.NULLABILITY_REQUIRED,
+        )
+    )
+    assert result == expected