@@ -783,6 +783,79 @@ def _test_exchange_empty_session(proxy: ConformanceService, logs: LogCollector)
783783 assert proxy .echo_int (value = 42 ) == 42
784784
785785
786+ # ---------------------------------------------------------------------------
787+ # Exchange cast-compatible tests
788+ # ---------------------------------------------------------------------------
789+
790+
791+ @_conformance_test (category = "exchange_stream" , name = "cast_int32_to_float64" )
792+ def _test_cast_int32 (proxy : ConformanceService , logs : LogCollector ) -> None :
793+ with proxy .exchange_cast_compatible () as session :
794+ batch = pa .record_batch (
795+ [pa .array ([1 , 2 , 3 ], type = pa .int32 ())],
796+ schema = pa .schema ([pa .field ("value" , pa .int32 ())]),
797+ )
798+ out = session .exchange (AnnotatedBatch (batch = batch ))
799+ assert out .batch .schema .field ("value" ).type == pa .float64 ()
800+ values = cast ("list[float]" , out .batch .column ("value" ).to_pylist ())
801+ assert abs (values [0 ] - 1.0 ) < 1e-6
802+ assert abs (values [1 ] - 2.0 ) < 1e-6
803+ assert abs (values [2 ] - 3.0 ) < 1e-6
804+
805+
806+ @_conformance_test (category = "exchange_stream" , name = "cast_int64_to_float64" )
807+ def _test_cast_int64 (proxy : ConformanceService , logs : LogCollector ) -> None :
808+ with proxy .exchange_cast_compatible () as session :
809+ batch = pa .record_batch (
810+ [pa .array ([10 , 20 , 30 ], type = pa .int64 ())],
811+ schema = pa .schema ([pa .field ("value" , pa .int64 ())]),
812+ )
813+ out = session .exchange (AnnotatedBatch (batch = batch ))
814+ assert out .batch .schema .field ("value" ).type == pa .float64 ()
815+ values = cast ("list[float]" , out .batch .column ("value" ).to_pylist ())
816+ assert abs (values [0 ] - 10.0 ) < 1e-6
817+ assert abs (values [1 ] - 20.0 ) < 1e-6
818+ assert abs (values [2 ] - 30.0 ) < 1e-6
819+
820+
821+ @_conformance_test (category = "exchange_stream" , name = "cast_float32_to_float64" )
822+ def _test_cast_float32 (proxy : ConformanceService , logs : LogCollector ) -> None :
823+ with proxy .exchange_cast_compatible () as session :
824+ batch = pa .record_batch (
825+ [pa .array ([1.5 , 2.5 , 3.5 ], type = pa .float32 ())],
826+ schema = pa .schema ([pa .field ("value" , pa .float32 ())]),
827+ )
828+ out = session .exchange (AnnotatedBatch (batch = batch ))
829+ assert out .batch .schema .field ("value" ).type == pa .float64 ()
830+ values = cast ("list[float]" , out .batch .column ("value" ).to_pylist ())
831+ assert abs (values [0 ] - 1.5 ) < 1e-6
832+ assert abs (values [1 ] - 2.5 ) < 1e-6
833+ assert abs (values [2 ] - 3.5 ) < 1e-6
834+
835+
836+ @_conformance_test (category = "exchange_stream" , name = "cast_exact_schema" )
837+ def _test_cast_exact (proxy : ConformanceService , logs : LogCollector ) -> None :
838+ with proxy .exchange_cast_compatible () as session :
839+ out = session .exchange (AnnotatedBatch .from_pydict ({"value" : [5.0 , 10.0 ]}))
840+ values = cast ("list[float]" , out .batch .column ("value" ).to_pylist ())
841+ assert abs (values [0 ] - 5.0 ) < 1e-6
842+ assert abs (values [1 ] - 10.0 ) < 1e-6
843+
844+
845+ @_conformance_test (category = "exchange_stream" , name = "cast_incompatible_column_name" )
846+ def _test_cast_incompatible (proxy : ConformanceService , logs : LogCollector ) -> None :
847+ with proxy .exchange_cast_compatible () as session :
848+ batch = pa .record_batch (
849+ [pa .array ([1.0 ], type = pa .float64 ())],
850+ schema = pa .schema ([pa .field ("wrong" , pa .float64 ())]),
851+ )
852+ try :
853+ session .exchange (AnnotatedBatch (batch = batch ))
854+ raise AssertionError ("Expected RpcError" )
855+ except RpcError as e :
856+ assert "TypeError" in str (e ) or "type" in str (e ).lower ()
857+
858+
786859# ---------------------------------------------------------------------------
787860# Exchange header tests
788861# ---------------------------------------------------------------------------
@@ -1074,6 +1147,7 @@ def decorator(
10741147 "echo_with_log_extras" ,
10751148 "echo_with_multi_logs" ,
10761149 "exchange_accumulate" ,
1150+ "exchange_cast_compatible" ,
10771151 "exchange_error_on_init" ,
10781152 "exchange_error_on_nth" ,
10791153 "exchange_scale" ,
@@ -1138,6 +1212,7 @@ def decorator(
11381212_STREAM_METHODS = frozenset (
11391213 {
11401214 "exchange_accumulate" ,
1215+ "exchange_cast_compatible" ,
11411216 "exchange_error_on_init" ,
11421217 "exchange_error_on_nth" ,
11431218 "exchange_scale" ,
@@ -1203,7 +1278,7 @@ def _test_desc_describe_version(desc: ServiceDescription) -> None:
12031278
12041279@_describe_test (category = "describe_service" , name = "method_count" )
12051280def _test_desc_method_count (desc : ServiceDescription ) -> None :
1206- assert len (desc .methods ) == 47
1281+ assert len (desc .methods ) == 48
12071282
12081283
12091284# ---------------------------------------------------------------------------
0 commit comments