Skip to content

Commit af2a6ed

Browse files
rustyconoverclaude
andcommitted
Bump version to 0.1.15 and add cast-compatible exchange conformance tests
Add exchange_cast_compatible method and tests verifying server-side RecordBatch.cast() for compatible schema mismatches (int32/int64/float32 to float64) and rejection of incompatible column names. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e342f5c commit af2a6ed

6 files changed

Lines changed: 162 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vgi-rpc"
3-
version = "0.1.14"
3+
version = "0.1.15"
44
description = "Vector Gateway Interface - RPC framework based on Apache Arrow"
55
readme = "README.md"
66
requires-python = ">=3.13"

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vgi_rpc/conformance/_impl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,14 @@ def exchange_error_on_nth(self, fail_on: int) -> Stream[FailOnExchangeNState]:
307307
input_schema=_SCALE_INPUT_SCHEMA,
308308
)
309309

310+
def exchange_cast_compatible(self) -> Stream[ScaleExchangeState]:
311+
"""Exchange expecting float64 input — echoes values via factor=1.0."""
312+
return Stream(
313+
output_schema=_SCALE_OUTPUT_SCHEMA,
314+
state=ScaleExchangeState(factor=1.0),
315+
input_schema=_SCALE_INPUT_SCHEMA,
316+
)
317+
310318
def exchange_error_on_init(self) -> Stream[ScaleExchangeState]:
311319
"""Raise during exchange init."""
312320
raise RuntimeError("intentional exchange init error")

vgi_rpc/conformance/_protocol.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
"""ConformanceService Protocol definition.
55
6-
Defines ~44 RPC methods covering every framework capability:
6+
Defines ~45 RPC methods covering every framework capability:
77
scalar echo, void, complex types, optionals, dataclass round-trip,
88
annotated types, multi-param, errors, logging, producer streams,
99
exchange streams, headers, and introspection.
@@ -296,6 +296,10 @@ def produce_dynamic_schema(
296296
"""
297297
...
298298

299+
def exchange_cast_compatible(self) -> Stream[StreamState]:
300+
"""Exchange expecting float64 input — tests server-side cast for compatible schemas."""
301+
...
302+
299303
def exchange_with_rich_header(self, seed: int, factor: float) -> Stream[StreamState, RichHeader]:
300304
"""Exchange stream with a rich multi-type header.
301305

vgi_rpc/conformance/_pytest_suite.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,76 @@ def test_zero_row_input(self, conformance_conn: ConnFactory) -> None:
834834
assert out.batch.num_rows == 0
835835

836836

837+
# ---------------------------------------------------------------------------
838+
# Exchange Streams: Cast-Compatible Schemas
839+
# ---------------------------------------------------------------------------
840+
841+
842+
class TestExchangeCastCompatible:
843+
"""Test that exchange streams cast compatible input schemas (e.g. int32 -> float64)."""
844+
845+
def test_cast_int32_to_float64(self, conformance_conn: ConnFactory) -> None:
846+
"""Send int32 values to a float64 exchange, expect float64 output."""
847+
with conformance_conn() as proxy, proxy.exchange_cast_compatible() as session:
848+
batch = pa.record_batch(
849+
[pa.array([1, 2, 3], type=pa.int32())],
850+
schema=pa.schema([pa.field("value", pa.int32())]),
851+
)
852+
out = session.exchange(AnnotatedBatch(batch=batch))
853+
assert out.batch.schema.field("value").type == pa.float64()
854+
assert out.batch.column("value").to_pylist() == [
855+
pytest.approx(1.0),
856+
pytest.approx(2.0),
857+
pytest.approx(3.0),
858+
]
859+
860+
def test_cast_int64_to_float64(self, conformance_conn: ConnFactory) -> None:
861+
"""Send int64 values to a float64 exchange, expect float64 output."""
862+
with conformance_conn() as proxy, proxy.exchange_cast_compatible() as session:
863+
batch = pa.record_batch(
864+
[pa.array([10, 20, 30], type=pa.int64())],
865+
schema=pa.schema([pa.field("value", pa.int64())]),
866+
)
867+
out = session.exchange(AnnotatedBatch(batch=batch))
868+
assert out.batch.schema.field("value").type == pa.float64()
869+
assert out.batch.column("value").to_pylist() == [
870+
pytest.approx(10.0),
871+
pytest.approx(20.0),
872+
pytest.approx(30.0),
873+
]
874+
875+
def test_cast_float32_to_float64(self, conformance_conn: ConnFactory) -> None:
876+
"""Send float32 values to a float64 exchange, expect float64 output."""
877+
with conformance_conn() as proxy, proxy.exchange_cast_compatible() as session:
878+
batch = pa.record_batch(
879+
[pa.array([1.5, 2.5, 3.5], type=pa.float32())],
880+
schema=pa.schema([pa.field("value", pa.float32())]),
881+
)
882+
out = session.exchange(AnnotatedBatch(batch=batch))
883+
assert out.batch.schema.field("value").type == pa.float64()
884+
assert out.batch.column("value").to_pylist() == [
885+
pytest.approx(1.5),
886+
pytest.approx(2.5),
887+
pytest.approx(3.5),
888+
]
889+
890+
def test_cast_exact_schema(self, conformance_conn: ConnFactory) -> None:
891+
"""Send matching float64 values — no cast needed."""
892+
with conformance_conn() as proxy, proxy.exchange_cast_compatible() as session:
893+
out = session.exchange(AnnotatedBatch.from_pydict({"value": [5.0, 10.0]}))
894+
assert out.batch.column("value").to_pylist() == [pytest.approx(5.0), pytest.approx(10.0)]
895+
896+
def test_cast_incompatible_column_name(self, conformance_conn: ConnFactory) -> None:
897+
"""Send wrong column name, expect RpcError."""
898+
with conformance_conn() as proxy, proxy.exchange_cast_compatible() as session:
899+
batch = pa.record_batch(
900+
[pa.array([1.0], type=pa.float64())],
901+
schema=pa.schema([pa.field("wrong", pa.float64())]),
902+
)
903+
with pytest.raises(RpcError):
904+
session.exchange(AnnotatedBatch(batch=batch))
905+
906+
837907
# ---------------------------------------------------------------------------
838908
# Exchange Streams With Headers
839909
# ---------------------------------------------------------------------------
@@ -933,7 +1003,7 @@ def test_run_describe_conformance(self, service_description: ServiceDescription)
9331003

9341004
def test_describe_via_rpc(self, service_description: ServiceDescription) -> None:
9351005
"""Smoke test: basic transport-level describe call works."""
936-
assert len(service_description.methods) == 47
1006+
assert len(service_description.methods) == 48
9371007
assert service_description.protocol_name == "ConformanceService"
9381008
echo_str = service_description.methods["echo_string"]
9391009
assert echo_str.method_type == MethodType.UNARY

vgi_rpc/conformance/_runner.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,79 @@ def _test_exchange_empty_session(proxy: ConformanceService, logs: LogCollector)
783783
assert proxy.echo_int(value=42) == 42
784784

785785

786+
# ---------------------------------------------------------------------------
787+
# Exchange cast-compatible tests
788+
# ---------------------------------------------------------------------------
789+
790+
791+
@_conformance_test(category="exchange_stream", name="cast_int32_to_float64")
792+
def _test_cast_int32(proxy: ConformanceService, logs: LogCollector) -> None:
793+
with proxy.exchange_cast_compatible() as session:
794+
batch = pa.record_batch(
795+
[pa.array([1, 2, 3], type=pa.int32())],
796+
schema=pa.schema([pa.field("value", pa.int32())]),
797+
)
798+
out = session.exchange(AnnotatedBatch(batch=batch))
799+
assert out.batch.schema.field("value").type == pa.float64()
800+
values = cast("list[float]", out.batch.column("value").to_pylist())
801+
assert abs(values[0] - 1.0) < 1e-6
802+
assert abs(values[1] - 2.0) < 1e-6
803+
assert abs(values[2] - 3.0) < 1e-6
804+
805+
806+
@_conformance_test(category="exchange_stream", name="cast_int64_to_float64")
807+
def _test_cast_int64(proxy: ConformanceService, logs: LogCollector) -> None:
808+
with proxy.exchange_cast_compatible() as session:
809+
batch = pa.record_batch(
810+
[pa.array([10, 20, 30], type=pa.int64())],
811+
schema=pa.schema([pa.field("value", pa.int64())]),
812+
)
813+
out = session.exchange(AnnotatedBatch(batch=batch))
814+
assert out.batch.schema.field("value").type == pa.float64()
815+
values = cast("list[float]", out.batch.column("value").to_pylist())
816+
assert abs(values[0] - 10.0) < 1e-6
817+
assert abs(values[1] - 20.0) < 1e-6
818+
assert abs(values[2] - 30.0) < 1e-6
819+
820+
821+
@_conformance_test(category="exchange_stream", name="cast_float32_to_float64")
822+
def _test_cast_float32(proxy: ConformanceService, logs: LogCollector) -> None:
823+
with proxy.exchange_cast_compatible() as session:
824+
batch = pa.record_batch(
825+
[pa.array([1.5, 2.5, 3.5], type=pa.float32())],
826+
schema=pa.schema([pa.field("value", pa.float32())]),
827+
)
828+
out = session.exchange(AnnotatedBatch(batch=batch))
829+
assert out.batch.schema.field("value").type == pa.float64()
830+
values = cast("list[float]", out.batch.column("value").to_pylist())
831+
assert abs(values[0] - 1.5) < 1e-6
832+
assert abs(values[1] - 2.5) < 1e-6
833+
assert abs(values[2] - 3.5) < 1e-6
834+
835+
836+
@_conformance_test(category="exchange_stream", name="cast_exact_schema")
837+
def _test_cast_exact(proxy: ConformanceService, logs: LogCollector) -> None:
838+
with proxy.exchange_cast_compatible() as session:
839+
out = session.exchange(AnnotatedBatch.from_pydict({"value": [5.0, 10.0]}))
840+
values = cast("list[float]", out.batch.column("value").to_pylist())
841+
assert abs(values[0] - 5.0) < 1e-6
842+
assert abs(values[1] - 10.0) < 1e-6
843+
844+
845+
@_conformance_test(category="exchange_stream", name="cast_incompatible_column_name")
846+
def _test_cast_incompatible(proxy: ConformanceService, logs: LogCollector) -> None:
847+
with proxy.exchange_cast_compatible() as session:
848+
batch = pa.record_batch(
849+
[pa.array([1.0], type=pa.float64())],
850+
schema=pa.schema([pa.field("wrong", pa.float64())]),
851+
)
852+
try:
853+
session.exchange(AnnotatedBatch(batch=batch))
854+
raise AssertionError("Expected RpcError")
855+
except RpcError as e:
856+
assert "TypeError" in str(e) or "type" in str(e).lower()
857+
858+
786859
# ---------------------------------------------------------------------------
787860
# Exchange header tests
788861
# ---------------------------------------------------------------------------
@@ -1074,6 +1147,7 @@ def decorator(
10741147
"echo_with_log_extras",
10751148
"echo_with_multi_logs",
10761149
"exchange_accumulate",
1150+
"exchange_cast_compatible",
10771151
"exchange_error_on_init",
10781152
"exchange_error_on_nth",
10791153
"exchange_scale",
@@ -1138,6 +1212,7 @@ def decorator(
11381212
_STREAM_METHODS = frozenset(
11391213
{
11401214
"exchange_accumulate",
1215+
"exchange_cast_compatible",
11411216
"exchange_error_on_init",
11421217
"exchange_error_on_nth",
11431218
"exchange_scale",
@@ -1203,7 +1278,7 @@ def _test_desc_describe_version(desc: ServiceDescription) -> None:
12031278

12041279
@_describe_test(category="describe_service", name="method_count")
12051280
def _test_desc_method_count(desc: ServiceDescription) -> None:
1206-
assert len(desc.methods) == 47
1281+
assert len(desc.methods) == 48
12071282

12081283

12091284
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)