diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index b5c76edcc8..8a28fe0d03 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -16,6 +16,7 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio - Put private methods last in class definitions, immediately before private fields. - Do not redesign alias-based or low-level public type shapes to add convenience methods unless the user explicitly asks for that API change. - For cross-language feature ports, match protocol behavior but use idiomatic C++ ownership and layering instead of mirroring Java structure literally. +- Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. ## Key Paths diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index 4ba1e6a361..891339a8fb 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -9,6 +9,7 @@ Load this file when changing `csharp/` or C# xlang behavior. - C# code must build without compiler or analyzer warnings. Treat warnings as blockers in project, test, and generated code. - Fory C# requires .NET SDK `8.0+` and C# `12+`. - Use `dotnet format` to keep C# code style consistent. +- Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. - When extending C# tests from Java references, prioritize xlang spec behavior and the public C# contract before adding complex Java-specific parity cases. ## Commands diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md index 5214a43adc..616a4cc0d0 100644 --- a/.agents/languages/dart.md +++ b/.agents/languages/dart.md @@ -17,6 +17,9 @@ Load this file when changing `dart/`. - Codegen must support private fields through same-library `part` generation. If generated file naming changes from `*.fory.dart`, update builder config, source `part` directives, analysis exclusions, docs, CI snippets, and stale artifacts together. - Keep generated Dart outputs (`*.fory.dart`) and Dart `pubspec.lock` files untracked in this repo. - For generated numeric or xlang changes, test root values and generated required/nullable fields across schema-consistent and compatible serializers, metadata type IDs, rejection paths, and every affected encoding mode. +- Compatible scalar conversion is immediate-field-only. Recursive compatible schema comparison for list elements, typed-array elements, map keys, and map values must reject scalar mismatches instead of applying top-level scalar conversion. +- Generated compatible struct reads must consume per-remote-field read descriptors built before field dispatch. Exact doubled cases read directly from local field metadata and must not receive remote compatible metadata; compatible scalar cases use preclassified scalar read descriptors instead of layout-wide scalar source arrays or hot schema/type-pair eligibility helpers. +- Generated struct serializers should use serializer-owned field descriptors for runtime resolver decisions and emit direct field-specific write/read code for static schemas. Do not route generated hot writes through generic field-info value helpers such as `writeGeneratedStructFieldInfoValue`. - Dart xlang or runtime ownership changes need local Dart package tests plus the Java-driven `DartXlangTest`; package-only smoke tests are not enough. - When claiming non-VM Dart support, prove a relevant non-VM compile path such as `dart compile js` against active runtime or example code. diff --git a/.agents/languages/java.md b/.agents/languages/java.md index a68fd7bd97..163c19da8c 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -128,6 +128,9 @@ Load this file when changing anything under `java/` or when Java drives a cross- `ObjectInstantiators.getObjectInstantiator(TypeResolver, Class)` or bypass the runtime-scoped owner; format builders without a Fory runtime context may use the base `ObjectInstantiators.getObjectInstantiator(Class)` construction default. +- Compatible scalar, list-array, and binary/uint8 adapters are immediate-field-only. For matched + nested field metadata, classify schema pairs before generated dispatch and require exact + nullable/ref/generic/type shape, except for user-defined type-family normalization. - Root codegen and builder classes that still need Unsafe on JDK8-24 must route symbolic Unsafe access through a helper with a Java 25 replacement. Do not leave `_JDKAccess.unsafe()` or `sun.misc.Unsafe` references in JDK25-visible classes outside matching `java25` replacements. @@ -230,9 +233,13 @@ Load this file when changing anything under `java/` or when Java drives a cross- - Source-generated constructor serializers must own their constructor metadata at generation time and call constructors directly. They must not depend on runtime `ObjectInstantiator` constructor-field metadata or varargs constructor calls. -- Java annotation-processor static serializers do not own ordinary-class constructor metadata. - Reject ordinary non-record final fields instead of generating descriptor-based final-field - mutation; records and Kotlin KSP primary-constructor serializers are the constructor-owned paths. +- Java annotation-processor static serializers must support the same Java class surface as normal + Fory object serialization when compatible-read static generation needs it. Ordinary classes that + cannot be assigned directly should allocate through generated-subclass-owned + `ObjectInstantiator` state and write private, inaccessible, or final fields through + generated-subclass-owned cached `FieldAccessor`s. Do not add constructor-binding APIs, per-read + reflective lookup, descriptor-based varargs constructor calls, or shared-parent argument buffers. + Records and Kotlin KSP primary-constructor serializers remain constructor-owned paths. - Generated JVM copy code may direct-copy immutable scalar values, but Java `Collection`/`Map` subclasses must be copied through `CopyContext.copyObject(...)` so collection/map serializers own concrete type, comparator, wrapper, and reference behavior. diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md index b69393ad36..b302673f20 100644 --- a/.agents/languages/javascript.md +++ b/.agents/languages/javascript.md @@ -12,6 +12,9 @@ Load this file when changing `javascript/`. - Runtime value carriers such as decimal or reduced-precision numeric types belong under the core `types/` ownership boundary, with imports, exports, and codegen externals updated together. - Keep `TypeInfo` as schema metadata. Compatibility-sensitive decisions belong on `TypeResolver` or explicit operations, not as retained resolver state on metadata objects. - Normalize optional boolean config values at config construction; do not carry `null` through runtime paths when it means `false`. +- Regenerated compatible read serializers are remote-schema-specific. After classification marks a field as direct, compatible scalar, or skip, generated JavaScript should emit straight-line remote-field-order code. Do not add an outer matched-id switch unless the current regenerated shape cannot preserve those semantics. +- Compatible scalar codegen must decide the exact remote/local scalar pair before emitting source. Generate the concrete `reader.readXxx()` call plus inline trivial conversions such as boolean-to-string or numeric widening, and keep helpers only for semantic validation such as range checks, exactness checks, decimal parsing/formatting, and string-to-bool. Do not call a generic hot-path converter that redispatches on `remoteTypeId`, `localTypeId`, field descriptors, or field names. +- Compatible scalar conversion is immediate-field-only. Recursive schema comparison for collection elements, array elements, map keys, and map values must reject scalar mismatches instead of applying the top-level scalar conversion matrix. ## Commands diff --git a/.agents/languages/python.md b/.agents/languages/python.md index 08c6beb978..7365632c37 100644 --- a/.agents/languages/python.md +++ b/.agents/languages/python.md @@ -18,6 +18,7 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be - For wheel or extension pipeline changes, derive extension-module paths from current build targets, packaging config, or wheel payload discovery rather than historical module names. - Keep new Python test names compact and behavior-focused; avoid sentence-length names that restate setup details already obvious from the test body. - `ENABLE_FORY_DEBUG_OUTPUT=1` enables detailed struct serialization and deserialization logs. +- Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. ## Key Paths diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index 4a9e45e72e..6b1f8db318 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -15,6 +15,7 @@ Load this file when changing `rust/` or Rust xlang behavior. - Runtime carriers belong in `types/`, and schema or type-hash helpers belong with metadata hashing rather than generic wire/type-id modules. - If breakage is explicitly acceptable during a Rust module refactor, rewire macros, tests, and sibling crates directly to the new boundaries instead of adding compatibility re-exports. - For panic-safety in hot paths, preserve TLS context reuse. Add scoped guards or owned fallbacks rather than per-call context allocation, and reset reused contexts at entry and successful exit. +- Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Keep recursive matched-field shape classification owned by `fory-core/src/meta/type_meta.rs`; collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. ## Key Paths @@ -77,5 +78,5 @@ cargo bench cd java mvn -T16 install -DskipTests cd fory-core -RUST_BACKTRACE=1 FORY_PANIC_ON_ERROR=1 FORY_RUST_JAVA_CI=1 ENABLE_FORY_DEBUG_OUTPUT=1 mvn test -Dtest=org.apache.fory.xlang.RustXlangTest +RUST_BACKTRACE=1 FORY_RUST_JAVA_CI=1 ENABLE_FORY_DEBUG_OUTPUT=1 mvn test -Dtest=org.apache.fory.xlang.RustXlangTest ``` diff --git a/.agents/languages/swift.md b/.agents/languages/swift.md index 1e8dacbd3d..cfd5198524 100644 --- a/.agents/languages/swift.md +++ b/.agents/languages/swift.md @@ -12,6 +12,7 @@ Load this file when changing `swift/` or Swift xlang behavior. - Prefer the user-requested or existing Foundation public value type when it is the intended Swift surface; do not invent Fory-prefixed wrappers only to avoid import ambiguity. - Preserve distinct temporal semantics. Timestamp values and day-only local dates should have protocol-accurate helper names and no stale aliases after a refactor. - When temporal or public-type refactors touch generated Swift code, sweep message fields, union payloads, macros, xlang harnesses, and integration fixtures together. +- Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. ## Commands diff --git a/.agents/testing/integration-tests.md b/.agents/testing/integration-tests.md index 0f2086eed9..120734594f 100644 --- a/.agents/testing/integration-tests.md +++ b/.agents/testing/integration-tests.md @@ -7,6 +7,9 @@ Load this file when changing `integration_tests/`, xlang behavior, compiler-gene - Run all commands from within `integration_tests/`. - For Java-related integration tests, install the Java libraries first with `cd ../java && mvn -T16 install -DskipTests` if Java changed. If unsure, run it. - On macOS, GraalVM is usually installed under `/Library/Java/JavaVirtualMachines/graalvm-xxx`. +- Do not set `FORY_PANIC_ON_ERROR` for normal tests, CI reproduction, or xlang validation. It is + only for focused debugging. Verification commands should omit it, while test harnesses must not + filter it when the user command provides it. - For `integration_tests/idl_tests`: - Always run `cd ../java && mvn -T16 install -DskipTests` before the test if Java changed since the last install. If unsure, run it. - Always run `cd ../python && pip install -v -e .` before the test if Python or Cython code changed. Rebuild Cython if needed. diff --git a/AGENTS.md b/AGENTS.md index aca7ebaec1..5c62f3f76d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -90,6 +90,9 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Keep class registration enabled unless explicitly requested otherwise. - Prefer schema-consistent mode unless compatibility work requires something else. - When debugging test errors, always set `ENABLE_FORY_DEBUG_OUTPUT=1` to see debug output. +- Do not set `FORY_PANIC_ON_ERROR` for normal tests, CI reproduction, or xlang validation. + It is a focused debug knob only; omit it from verification commands, but do not filter it + from test harnesses when the user command provides it. - Never work around failures. Find and fix the root cause. Do not hack, weaken, or bypass tests to make them pass. ## Source of Truth @@ -131,6 +134,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th ## Shared Validation Expectations - Run the relevant tests for every touched language or subsystem before finishing. +- A formatter-only pass after successful tests does not invalidate those test results. Do not rerun tests solely because formatting ran after the tests already passed. - When multiple independent language test suites are required, run them concurrently when the environment has enough resources instead of running them one by one; keep each language's logs and results separate, and rerun any failed suite with focused diagnostics. - Run applicable test commands in a subagent with a thinking budget one level lower than the main task budget, using medium when the current budget is unclear, unless the change is docs-only or the user explicitly asks to run them locally. - Reuse the same test subagent for repeated runs within one task and subsystem so it keeps failure context; create a fresh subagent when switching unrelated subsystems or when prior context may be stale or misleading. diff --git a/benchmarks/cpp/README.md b/benchmarks/cpp/README.md index f9d1d62508..0bca3b0fa8 100644 --- a/benchmarks/cpp/README.md +++ b/benchmarks/cpp/README.md @@ -73,6 +73,15 @@ Examples: ./run.sh --data struct --serializer fory --duration 5 ``` +## Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode. This mode is off by default. When enabled, run with +`--serializer fory`; protobuf and MessagePack benchmark modes fail with a +configuration error. Fory serialization uses the normal v1 benchmark structs, +and Fory deserialization uses v2 structs registered with the same Fory type IDs +where one int32 field is widened to int64. + ## Building ```bash diff --git a/benchmarks/cpp/benchmark.cc b/benchmarks/cpp/benchmark.cc index 953075d21a..dfa69806a3 100644 --- a/benchmarks/cpp/benchmark.cc +++ b/benchmarks/cpp/benchmark.cc @@ -19,7 +19,9 @@ #include #include +#include #include +#include #include #include @@ -216,6 +218,113 @@ struct MediaContentList { }; FORY_STRUCT(MediaContentList, (media_content_list, fory::F(1))); +struct NumericStructV2 { + int64_t f1; + int32_t f2; + int32_t f3; + int32_t f4; + int32_t f5; + int32_t f6; + int32_t f7; + int32_t f8; + int32_t f9; + int32_t f10; + int32_t f11; + int32_t f12; +}; +FORY_STRUCT(NumericStructV2, (f1, fory::F(1)), (f2, fory::F(2)), + (f3, fory::F(3)), (f4, fory::F(4)), (f5, fory::F(5)), + (f6, fory::F(6)), (f7, fory::F(7)), (f8, fory::F(8)), + (f9, fory::F(9)), (f10, fory::F(10)), (f11, fory::F(11)), + (f12, fory::F(12))); + +struct SampleV2 { + int64_t int_value; + int64_t long_value; + float float_value; + double double_value; + int32_t short_value; + int32_t char_value; + bool boolean_value; + int32_t int_value_boxed; + int64_t long_value_boxed; + float float_value_boxed; + double double_value_boxed; + int32_t short_value_boxed; + int32_t char_value_boxed; + bool boolean_value_boxed; + std::vector int_array; + std::vector long_array; + std::vector float_array; + std::vector double_array; + std::vector short_array; + std::vector char_array; + std::vector boolean_array; + std::string string; +}; +FORY_STRUCT(SampleV2, (int_value, fory::F(1)), (long_value, fory::F(2)), + (float_value, fory::F(3)), (double_value, fory::F(4)), + (short_value, fory::F(5)), (char_value, fory::F(6)), + (boolean_value, fory::F(7)), (int_value_boxed, fory::F(8)), + (long_value_boxed, fory::F(9)), (float_value_boxed, fory::F(10)), + (double_value_boxed, fory::F(11)), (short_value_boxed, fory::F(12)), + (char_value_boxed, fory::F(13)), (boolean_value_boxed, fory::F(14)), + (int_array, fory::F(15)), (long_array, fory::F(16)), + (float_array, fory::F(17)), (double_array, fory::F(18)), + (short_array, fory::F(19)), (char_array, fory::F(20)), + (boolean_array, fory::F(21)), (string, fory::F(22))); + +struct MediaV2 { + std::string uri; + std::string title; + int64_t width; + int32_t height; + std::string format; + int64_t duration; + int64_t size; + int32_t bitrate; + bool has_bitrate; + std::vector persons; + Player player; + std::string copyright; +}; +FORY_STRUCT(MediaV2, (uri, fory::F(1)), (title, fory::F(2)), + (width, fory::F(3)), (height, fory::F(4)), (format, fory::F(5)), + (duration, fory::F(6)), (size, fory::F(7)), (bitrate, fory::F(8)), + (has_bitrate, fory::F(9)), (persons, fory::F(10)), + (player, fory::F(11)), (copyright, fory::F(12))); + +struct ImageV2 { + std::string uri; + std::string title; + int64_t width; + int32_t height; + Size size; +}; +FORY_STRUCT(ImageV2, (uri, fory::F(1)), (title, fory::F(2)), + (width, fory::F(3)), (height, fory::F(4)), (size, fory::F(5))); + +struct MediaContentV2 { + MediaV2 media; + std::vector images; +}; +FORY_STRUCT(MediaContentV2, (media, fory::F(1)), (images, fory::F(2))); + +struct NumericStructListV2 { + std::vector struct_list; +}; +FORY_STRUCT(NumericStructListV2, (struct_list, fory::F(1))); + +struct SampleListV2 { + std::vector sample_list; +}; +FORY_STRUCT(SampleListV2, (sample_list, fory::F(1))); + +struct MediaContentListV2 { + std::vector media_content_list; +}; +FORY_STRUCT(MediaContentListV2, (media_content_list, fory::F(1))); + // ============================================================================ // Test data creation // ============================================================================ @@ -626,8 +735,153 @@ void register_fory_types(fory::serialization::Fory &fory) { fory.register_struct(8); } +void register_fory_types_v2(fory::serialization::Fory &fory) { + fory.register_struct(1); + fory.register_struct(2); + fory.register_struct(3); + fory.register_struct(4); + fory.register_struct(5); + fory.register_struct(6); + fory.register_struct(7); + fory.register_struct(8); +} + +bool schema_mismatch_enabled() { + const char *value = std::getenv("FORY_BENCH_SCHEMA_MISMATCH"); + return value != nullptr && std::string(value) == "1"; +} + +bool reject_non_fory_schema_mismatch(benchmark::State &state) { + if (!schema_mismatch_enabled()) { + return false; + } + state.SkipWithError( + "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun " + "with --serializer fory"); + return true; +} + +fory::serialization::Fory new_benchmark_fory() { + return fory::serialization::Fory::builder() + .xlang(true) + .compatible(true) + .track_ref(false) + .build(); +} + +template +void verify_schema_mismatch_read(const ReadT &, const WriteT &) {} + +template <> +void verify_schema_mismatch_read(const NumericStructV2 &decoded, + const NumericStruct &expected) { + if (decoded.f1 != expected.f1) { + throw std::runtime_error("NumericStructV2 schema mismatch read failed"); + } +} + +template <> +void verify_schema_mismatch_read(const SampleV2 &decoded, + const Sample &expected) { + if (decoded.int_value != expected.int_value) { + throw std::runtime_error("SampleV2 schema mismatch read failed"); + } +} + +template <> +void verify_schema_mismatch_read(const MediaContentV2 &decoded, + const MediaContent &expected) { + if (decoded.media.width != expected.media.width || decoded.images.empty() || + decoded.images[0].width != expected.images[0].width) { + throw std::runtime_error("MediaContentV2 schema mismatch read failed"); + } +} + +template <> +void verify_schema_mismatch_read(const NumericStructListV2 &decoded, + const NumericStructList &expected) { + if (decoded.struct_list.empty() || + decoded.struct_list[0].f1 != expected.struct_list[0].f1) { + throw std::runtime_error("NumericStructListV2 schema mismatch read failed"); + } +} + +template <> +void verify_schema_mismatch_read(const SampleListV2 &decoded, + const SampleList &expected) { + if (decoded.sample_list.empty() || + decoded.sample_list[0].int_value != expected.sample_list[0].int_value) { + throw std::runtime_error("SampleListV2 schema mismatch read failed"); + } +} + +template <> +void verify_schema_mismatch_read(const MediaContentListV2 &decoded, + const MediaContentList &expected) { + if (decoded.media_content_list.empty() || + decoded.media_content_list[0].media.width != + expected.media_content_list[0].media.width || + decoded.media_content_list[0].images.empty() || + decoded.media_content_list[0].images[0].width != + expected.media_content_list[0].images[0].width) { + throw std::runtime_error("MediaContentListV2 schema mismatch read failed"); + } +} + +template +void run_fory_deserialize_benchmark(benchmark::State &state, Factory factory) { + auto writer = new_benchmark_fory(); + register_fory_types(writer); + auto reader = new_benchmark_fory(); + const bool mismatch = schema_mismatch_enabled(); + if (mismatch) { + register_fory_types_v2(reader); + } else { + register_fory_types(reader); + } + WriteT obj = factory(); + auto serialized = writer.serialize(obj); + if (!serialized.ok()) { + state.SkipWithError("Serialization failed"); + return; + } + auto &bytes = serialized.value(); + + if (mismatch) { + auto test_result = reader.deserialize(bytes.data(), bytes.size()); + if (!test_result.ok()) { + state.SkipWithError("Schema-mismatch deserialization test failed"); + return; + } + try { + verify_schema_mismatch_read(test_result.value(), obj); + } catch (const std::exception &e) { + state.SkipWithError(e.what()); + return; + } + for (auto _ : state) { + auto result = reader.deserialize(bytes.data(), bytes.size()); + benchmark::DoNotOptimize(result); + } + return; + } + + auto test_result = reader.deserialize(bytes.data(), bytes.size()); + if (!test_result.ok()) { + state.SkipWithError("Deserialization test failed"); + return; + } + for (auto _ : state) { + auto result = reader.deserialize(bytes.data(), bytes.size()); + benchmark::DoNotOptimize(result); + } +} + template void run_msgpack_serialize_benchmark(benchmark::State &state, Factory factory) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } T obj = factory(); msgpack::sbuffer output; @@ -642,6 +896,9 @@ void run_msgpack_serialize_benchmark(benchmark::State &state, Factory factory) { template void run_msgpack_deserialize_benchmark(benchmark::State &state, Factory factory) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } T obj = factory(); msgpack::sbuffer output; msgpack::pack(output, obj); @@ -704,6 +961,9 @@ BENCHMARK(BM_Fory_NumericStruct_Serialize); // Fair comparison: convert plain C++ struct to protobuf, then serialize // (Same pattern as Java benchmark's buildPBStruct().toByteArray()) static void BM_Protobuf_NumericStruct_Serialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } NumericStruct obj = create_numeric_struct(); protobuf::NumericStruct pb = to_pb_struct(obj); std::vector output; @@ -718,38 +978,17 @@ static void BM_Protobuf_NumericStruct_Serialize(benchmark::State &state) { BENCHMARK(BM_Protobuf_NumericStruct_Serialize); static void BM_Fory_NumericStruct_Deserialize(benchmark::State &state) { - auto fory = fory::serialization::Fory::builder() - .xlang(true) - .compatible(true) - .track_ref(false) - .build(); - register_fory_types(fory); - NumericStruct obj = create_numeric_struct(); - auto serialized = fory.serialize(obj); - if (!serialized.ok()) { - state.SkipWithError("Serialization failed"); - return; - } - auto &bytes = serialized.value(); - - // Verify deserialization works first - auto test_result = - fory.deserialize(bytes.data(), bytes.size()); - if (!test_result.ok()) { - state.SkipWithError("Deserialization test failed"); - return; - } - - for (auto _ : state) { - auto result = fory.deserialize(bytes.data(), bytes.size()); - benchmark::DoNotOptimize(result); - } + run_fory_deserialize_benchmark( + state, create_numeric_struct); } BENCHMARK(BM_Fory_NumericStruct_Deserialize); // Fair comparison: deserialize and convert protobuf to plain C++ struct // (Same pattern as Java benchmark's fromPBObject()) static void BM_Protobuf_NumericStruct_Deserialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } protobuf::NumericStruct obj = create_proto_struct(); std::string serialized; obj.SerializeToString(&serialized); @@ -790,6 +1029,9 @@ static void BM_Fory_Sample_Serialize(benchmark::State &state) { BENCHMARK(BM_Fory_Sample_Serialize); static void BM_Protobuf_Sample_Serialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } protobuf::Sample obj = create_proto_sample(); std::vector output; output.resize(obj.ByteSizeLong()); @@ -802,35 +1044,14 @@ static void BM_Protobuf_Sample_Serialize(benchmark::State &state) { BENCHMARK(BM_Protobuf_Sample_Serialize); static void BM_Fory_Sample_Deserialize(benchmark::State &state) { - auto fory = fory::serialization::Fory::builder() - .xlang(true) - .compatible(true) - .track_ref(false) - .build(); - register_fory_types(fory); - Sample obj = create_sample(); - auto serialized = fory.serialize(obj); - if (!serialized.ok()) { - state.SkipWithError("Serialization failed"); - return; - } - auto &bytes = serialized.value(); - - // Verify deserialization works first - auto test_result = fory.deserialize(bytes.data(), bytes.size()); - if (!test_result.ok()) { - state.SkipWithError("Deserialization test failed"); - return; - } - - for (auto _ : state) { - auto result = fory.deserialize(bytes.data(), bytes.size()); - benchmark::DoNotOptimize(result); - } + run_fory_deserialize_benchmark(state, create_sample); } BENCHMARK(BM_Fory_Sample_Deserialize); static void BM_Protobuf_Sample_Deserialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } protobuf::Sample obj = create_proto_sample(); std::string serialized; obj.SerializeToString(&serialized); @@ -870,6 +1091,9 @@ static void BM_Fory_MediaContent_Serialize(benchmark::State &state) { BENCHMARK(BM_Fory_MediaContent_Serialize); static void BM_Protobuf_MediaContent_Serialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } MediaContent obj = create_media_content(); protobuf::MediaContent pb = to_pb_mediaContent(obj); std::vector output; @@ -884,35 +1108,15 @@ static void BM_Protobuf_MediaContent_Serialize(benchmark::State &state) { BENCHMARK(BM_Protobuf_MediaContent_Serialize); static void BM_Fory_MediaContent_Deserialize(benchmark::State &state) { - auto fory = fory::serialization::Fory::builder() - .xlang(true) - .compatible(true) - .track_ref(false) - .build(); - register_fory_types(fory); - MediaContent obj = create_media_content(); - auto serialized = fory.serialize(obj); - if (!serialized.ok()) { - state.SkipWithError("Serialization failed"); - return; - } - auto &bytes = serialized.value(); - - // Verify deserialization works first - auto test_result = fory.deserialize(bytes.data(), bytes.size()); - if (!test_result.ok()) { - state.SkipWithError("Deserialization test failed"); - return; - } - - for (auto _ : state) { - auto result = fory.deserialize(bytes.data(), bytes.size()); - benchmark::DoNotOptimize(result); - } + run_fory_deserialize_benchmark( + state, create_media_content); } BENCHMARK(BM_Fory_MediaContent_Deserialize); static void BM_Protobuf_MediaContent_Deserialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } protobuf::MediaContent obj = create_proto_media_content(); std::string serialized; obj.SerializeToString(&serialized); @@ -952,6 +1156,9 @@ static void BM_Fory_NumericStructList_Serialize(benchmark::State &state) { BENCHMARK(BM_Fory_NumericStructList_Serialize); static void BM_Protobuf_NumericStructList_Serialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } NumericStructList obj = create_numeric_struct_list(); protobuf::NumericStructList pb = to_pb_numeric_struct_list(obj); std::vector output; @@ -966,36 +1173,15 @@ static void BM_Protobuf_NumericStructList_Serialize(benchmark::State &state) { BENCHMARK(BM_Protobuf_NumericStructList_Serialize); static void BM_Fory_NumericStructList_Deserialize(benchmark::State &state) { - auto fory = fory::serialization::Fory::builder() - .xlang(true) - .compatible(true) - .track_ref(false) - .build(); - register_fory_types(fory); - NumericStructList obj = create_numeric_struct_list(); - auto serialized = fory.serialize(obj); - if (!serialized.ok()) { - state.SkipWithError("Serialization failed"); - return; - } - auto &bytes = serialized.value(); - - auto test_result = - fory.deserialize(bytes.data(), bytes.size()); - if (!test_result.ok()) { - state.SkipWithError("Deserialization test failed"); - return; - } - - for (auto _ : state) { - auto result = - fory.deserialize(bytes.data(), bytes.size()); - benchmark::DoNotOptimize(result); - } + run_fory_deserialize_benchmark( + state, create_numeric_struct_list); } BENCHMARK(BM_Fory_NumericStructList_Deserialize); static void BM_Protobuf_NumericStructList_Deserialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } protobuf::NumericStructList obj = create_proto_numeric_struct_list(); std::string serialized; obj.SerializeToString(&serialized); @@ -1031,6 +1217,9 @@ static void BM_Fory_SampleList_Serialize(benchmark::State &state) { BENCHMARK(BM_Fory_SampleList_Serialize); static void BM_Protobuf_SampleList_Serialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } SampleList obj = create_sample_list(); protobuf::SampleList pb = to_pb_sample_list(obj); std::vector output; @@ -1045,34 +1234,15 @@ static void BM_Protobuf_SampleList_Serialize(benchmark::State &state) { BENCHMARK(BM_Protobuf_SampleList_Serialize); static void BM_Fory_SampleList_Deserialize(benchmark::State &state) { - auto fory = fory::serialization::Fory::builder() - .xlang(true) - .compatible(true) - .track_ref(false) - .build(); - register_fory_types(fory); - SampleList obj = create_sample_list(); - auto serialized = fory.serialize(obj); - if (!serialized.ok()) { - state.SkipWithError("Serialization failed"); - return; - } - auto &bytes = serialized.value(); - - auto test_result = fory.deserialize(bytes.data(), bytes.size()); - if (!test_result.ok()) { - state.SkipWithError("Deserialization test failed"); - return; - } - - for (auto _ : state) { - auto result = fory.deserialize(bytes.data(), bytes.size()); - benchmark::DoNotOptimize(result); - } + run_fory_deserialize_benchmark(state, + create_sample_list); } BENCHMARK(BM_Fory_SampleList_Deserialize); static void BM_Protobuf_SampleList_Deserialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } protobuf::SampleList obj = create_proto_sample_list(); std::string serialized; obj.SerializeToString(&serialized); @@ -1108,6 +1278,9 @@ static void BM_Fory_MediaContentList_Serialize(benchmark::State &state) { BENCHMARK(BM_Fory_MediaContentList_Serialize); static void BM_Protobuf_MediaContentList_Serialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } MediaContentList obj = create_media_content_list(); protobuf::MediaContentList pb = to_pb_media_content_list(obj); std::vector output; @@ -1122,36 +1295,15 @@ static void BM_Protobuf_MediaContentList_Serialize(benchmark::State &state) { BENCHMARK(BM_Protobuf_MediaContentList_Serialize); static void BM_Fory_MediaContentList_Deserialize(benchmark::State &state) { - auto fory = fory::serialization::Fory::builder() - .xlang(true) - .compatible(true) - .track_ref(false) - .build(); - register_fory_types(fory); - MediaContentList obj = create_media_content_list(); - auto serialized = fory.serialize(obj); - if (!serialized.ok()) { - state.SkipWithError("Serialization failed"); - return; - } - auto &bytes = serialized.value(); - - auto test_result = - fory.deserialize(bytes.data(), bytes.size()); - if (!test_result.ok()) { - state.SkipWithError("Deserialization test failed"); - return; - } - - for (auto _ : state) { - auto result = - fory.deserialize(bytes.data(), bytes.size()); - benchmark::DoNotOptimize(result); - } + run_fory_deserialize_benchmark( + state, create_media_content_list); } BENCHMARK(BM_Fory_MediaContentList_Deserialize); static void BM_Protobuf_MediaContentList_Deserialize(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } protobuf::MediaContentList obj = create_proto_media_content_list(); std::string serialized; obj.SerializeToString(&serialized); @@ -1170,6 +1322,9 @@ BENCHMARK(BM_Protobuf_MediaContentList_Deserialize); // ============================================================================ static void BM_PrintSerializedSizes(benchmark::State &state) { + if (reject_non_fory_schema_mismatch(state)) { + return; + } // Fory auto fory = fory::serialization::Fory::builder() .xlang(true) diff --git a/benchmarks/cpp/run.sh b/benchmarks/cpp/run.sh index ac2f2af8b4..4237a6b3e6 100755 --- a/benchmarks/cpp/run.sh +++ b/benchmarks/cpp/run.sh @@ -97,6 +97,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "$SERIALIZER" != "fory" ]]; then + echo -e "${RED}FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory.${NC}" + exit 1 +fi + # Build benchmark filter FILTER="" DATA_PATTERN="" diff --git a/benchmarks/csharp/BenchmarkModels.cs b/benchmarks/csharp/BenchmarkModels.cs index a7058c994c..95d555530e 100644 --- a/benchmarks/csharp/BenchmarkModels.cs +++ b/benchmarks/csharp/BenchmarkModels.cs @@ -189,7 +189,6 @@ public sealed class SampleList public List Values { get; set; } = []; } -[ForyStruct] [ProtoContract] public enum Player { @@ -199,7 +198,6 @@ public enum Player Flash, } -[ForyStruct] [ProtoContract] public enum MediaSize { @@ -313,6 +311,206 @@ public sealed class MediaContentList public List Values { get; set; } = []; } +[ForyStruct] +public sealed class NumericStructV2 +{ + [ForyField(Id = 1)] + public long F1 { get; set; } + + [ForyField(Id = 2)] + public int F2 { get; set; } + + [ForyField(Id = 3)] + public int F3 { get; set; } + + [ForyField(Id = 4)] + public int F4 { get; set; } + + [ForyField(Id = 5)] + public int F5 { get; set; } + + [ForyField(Id = 6)] + public int F6 { get; set; } + + [ForyField(Id = 7)] + public int F7 { get; set; } + + [ForyField(Id = 8)] + public int F8 { get; set; } + + [ForyField(Id = 9)] + public int F9 { get; set; } + + [ForyField(Id = 10)] + public int F10 { get; set; } + + [ForyField(Id = 11)] + public int F11 { get; set; } + + [ForyField(Id = 12)] + public int F12 { get; set; } +} + +[ForyStruct] +public sealed class NumericStructListV2 +{ + [ForyField(Id = 1)] + public List Values { get; set; } = []; +} + +[ForyStruct] +public sealed class SampleV2 +{ + [ForyField(Id = 1)] + public long IntValue { get; set; } + + [ForyField(Id = 2)] + public long LongValue { get; set; } + + [ForyField(Id = 3)] + public float FloatValue { get; set; } + + [ForyField(Id = 4)] + public double DoubleValue { get; set; } + + [ForyField(Id = 5)] + public int ShortValue { get; set; } + + [ForyField(Id = 6)] + public int CharValue { get; set; } + + [ForyField(Id = 7)] + public bool BooleanValue { get; set; } + + [ForyField(Id = 8)] + public int IntValueBoxed { get; set; } + + [ForyField(Id = 9)] + public long LongValueBoxed { get; set; } + + [ForyField(Id = 10)] + public float FloatValueBoxed { get; set; } + + [ForyField(Id = 11)] + public double DoubleValueBoxed { get; set; } + + [ForyField(Id = 12)] + public int ShortValueBoxed { get; set; } + + [ForyField(Id = 13)] + public int CharValueBoxed { get; set; } + + [ForyField(Id = 14)] + public bool BooleanValueBoxed { get; set; } + + [ForyField(Id = 15)] + public int[] IntArray { get; set; } = []; + + [ForyField(Id = 16)] + public long[] LongArray { get; set; } = []; + + [ForyField(Id = 17)] + public float[] FloatArray { get; set; } = []; + + [ForyField(Id = 18)] + public double[] DoubleArray { get; set; } = []; + + [ForyField(Id = 19)] + public int[] ShortArray { get; set; } = []; + + [ForyField(Id = 20)] + public int[] CharArray { get; set; } = []; + + [ForyField(Id = 21)] + public bool[] BooleanArray { get; set; } = []; + + [ForyField(Id = 22)] + public string String { get; set; } = string.Empty; +} + +[ForyStruct] +public sealed class SampleListV2 +{ + [ForyField(Id = 1)] + public List Values { get; set; } = []; +} + +[ForyStruct] +public sealed class MediaV2 +{ + [ForyField(Id = 1)] + public string Uri { get; set; } = string.Empty; + + [ForyField(Id = 2)] + public string Title { get; set; } = string.Empty; + + [ForyField(Id = 3)] + public long Width { get; set; } + + [ForyField(Id = 4)] + public int Height { get; set; } + + [ForyField(Id = 5)] + public string Format { get; set; } = string.Empty; + + [ForyField(Id = 6)] + public long Duration { get; set; } + + [ForyField(Id = 7)] + public long Size { get; set; } + + [ForyField(Id = 8)] + public int Bitrate { get; set; } + + [ForyField(Id = 9)] + public bool HasBitrate { get; set; } + + [ForyField(Id = 10)] + public List Persons { get; set; } = []; + + [ForyField(Id = 11)] + public Player Player { get; set; } + + [ForyField(Id = 12)] + public string Copyright { get; set; } = string.Empty; +} + +[ForyStruct] +public sealed class ImageV2 +{ + [ForyField(Id = 1)] + public string Uri { get; set; } = string.Empty; + + [ForyField(Id = 2)] + public string Title { get; set; } = string.Empty; + + [ForyField(Id = 3)] + public long Width { get; set; } + + [ForyField(Id = 4)] + public int Height { get; set; } + + [ForyField(Id = 5)] + public MediaSize Size { get; set; } +} + +[ForyStruct] +public sealed class MediaContentV2 +{ + [ForyField(Id = 1)] + public MediaV2 Media { get; set; } = new(); + + [ForyField(Id = 2)] + public List Images { get; set; } = []; +} + +[ForyStruct] +public sealed class MediaContentListV2 +{ + [ForyField(Id = 1)] + public List Values { get; set; } = []; +} + public static class BenchmarkDataFactory { private const int ListSize = 5; diff --git a/benchmarks/csharp/BenchmarkSerializers.cs b/benchmarks/csharp/BenchmarkSerializers.cs index f660e7d4f6..d868ebf780 100644 --- a/benchmarks/csharp/BenchmarkSerializers.cs +++ b/benchmarks/csharp/BenchmarkSerializers.cs @@ -30,28 +30,37 @@ internal interface IBenchmarkSerializer byte[] Serialize(T value); - T Deserialize(byte[] payload); + object? Deserialize(byte[] payload); } -internal sealed class ForySerializer : IBenchmarkSerializer +internal sealed class ForySerializer : IBenchmarkSerializer { - private readonly ForyRuntime _fory = ForyRuntime.Builder().Compatible(true).Build(); + private readonly ForyRuntime _writer = ForyRuntime.Builder().Compatible(true).Build(); + private readonly ForyRuntime _reader = ForyRuntime.Builder().Compatible(true).Build(); - public ForySerializer() + public ForySerializer(bool schemaMismatch) { - BenchmarkTypeRegistry.RegisterAll(_fory); + BenchmarkTypeRegistry.RegisterAll(_writer); + if (schemaMismatch) + { + BenchmarkTypeRegistry.RegisterAllV2(_reader); + } + else + { + BenchmarkTypeRegistry.RegisterAll(_reader); + } } public string Name => "fory"; - public byte[] Serialize(T value) + public byte[] Serialize(TWrite value) { - return _fory.Serialize(value); + return _writer.Serialize(value); } - public T Deserialize(byte[] payload) + public object? Deserialize(byte[] payload) { - return _fory.Deserialize(payload); + return _reader.Deserialize(payload); } } @@ -71,6 +80,21 @@ public static void RegisterAll(ForyRuntime fory) fory.Register(9); fory.Register(10); } + + public static void RegisterAllV2(ForyRuntime fory) + { + // Keep user type IDs identical to v1 so compatible reads exercise schema evolution. + fory.Register(1); + fory.Register(2); + fory.Register(3); + fory.Register(4); + fory.Register(5); + fory.Register(6); + fory.Register(7); + fory.Register(8); + fory.Register(9); + fory.Register(10); + } } internal sealed class ProtobufSerializer : IBenchmarkSerializer @@ -87,7 +111,7 @@ public byte[] Serialize(T value) return _writeStream.ToArray(); } - public T Deserialize(byte[] payload) + public object? Deserialize(byte[] payload) { using MemoryStream stream = new(payload, writable: false); return ProtobufNetSerializer.Deserialize(stream); @@ -105,7 +129,7 @@ public byte[] Serialize(T value) return MessagePackSerializer.Serialize(value, _options); } - public T Deserialize(byte[] payload) + public object? Deserialize(byte[] payload) { return MessagePackSerializer.Deserialize(payload, _options); } diff --git a/benchmarks/csharp/Program.cs b/benchmarks/csharp/Program.cs index 7780ff505e..812810eb05 100644 --- a/benchmarks/csharp/Program.cs +++ b/benchmarks/csharp/Program.cs @@ -52,6 +52,8 @@ internal sealed record BenchmarkOutput( internal sealed class BenchmarkOptions { + private const string SchemaMismatchEnv = "FORY_BENCH_SCHEMA_MISMATCH"; + public HashSet DataFilter { get; init; } = []; public HashSet SerializerFilter { get; init; } = []; @@ -64,6 +66,8 @@ internal sealed class BenchmarkOptions public bool ShowHelp { get; init; } + public bool SchemaMismatch { get; init; } + public static BenchmarkOptions Parse(string[] args) { HashSet dataFilter = new(StringComparer.OrdinalIgnoreCase); @@ -72,6 +76,7 @@ public static BenchmarkOptions Parse(string[] args) double durationSeconds = 3.0; string outputPath = "benchmark_results.json"; bool showHelp = false; + bool schemaMismatch = Environment.GetEnvironmentVariable(SchemaMismatchEnv) == "1"; for (int i = 0; i < args.Length; i++) { @@ -114,6 +119,7 @@ public static BenchmarkOptions Parse(string[] args) DurationSeconds = durationSeconds, OutputPath = outputPath, ShowHelp = showHelp, + SchemaMismatch = schemaMismatch, }; } @@ -127,6 +133,20 @@ public bool IsSerializerEnabled(string serializer) return SerializerFilter.Count == 0 || SerializerFilter.Contains(serializer); } + public void ValidateSchemaMismatch() + { + if (!SchemaMismatch) + { + return; + } + + if (SerializerFilter.Count != 1 || !SerializerFilter.Contains("fory")) + { + throw new ArgumentException( + $"{SchemaMismatchEnv}=1 supports only Fory benchmarks; rerun with --serializer fory"); + } + } + private static void RequireValue(string[] args, int index) { if (index + 1 >= args.Length) @@ -178,6 +198,7 @@ private static int Main(string[] args) List cases; try { + options.ValidateSchemaMismatch(); cases = BuildBenchmarkCases(options); } catch (Exception ex) @@ -196,6 +217,7 @@ private static int Main(string[] args) Console.WriteLine($"Cases: {cases.Count}"); Console.WriteLine($"Warmup: {options.WarmupSeconds.ToString("F2", CultureInfo.InvariantCulture)}s"); Console.WriteLine($"Duration: {options.DurationSeconds.ToString("F2", CultureInfo.InvariantCulture)}s"); + Console.WriteLine($"Schema mismatch: {options.SchemaMismatch}"); Console.WriteLine(); List results = new(cases.Count); @@ -294,43 +316,103 @@ private static List BuildBenchmarkCases(BenchmarkOptions options) { List cases = []; - AddCases("struct", BenchmarkDataFactory.CreateNumericStruct(), options, cases); - AddCases("sample", BenchmarkDataFactory.CreateSample(), options, cases); - AddCases("mediacontent", BenchmarkDataFactory.CreateMediaContent(), options, cases); - AddCases("structlist", BenchmarkDataFactory.CreateNumericStructList(), options, cases); - AddCases("samplelist", BenchmarkDataFactory.CreateSampleList(), options, cases); - AddCases("mediacontentlist", BenchmarkDataFactory.CreateMediaContentList(), options, cases); + AddCases( + "struct", + BenchmarkDataFactory.CreateNumericStruct(), + options, + cases, + (decoded, expected) => decoded.F1 == expected.F1); + AddCases( + "sample", + BenchmarkDataFactory.CreateSample(), + options, + cases, + (decoded, expected) => decoded.IntValue == expected.IntValue); + AddCases( + "mediacontent", + BenchmarkDataFactory.CreateMediaContent(), + options, + cases, + (decoded, expected) => + decoded.Media.Width == expected.Media.Width + && decoded.Images.Count > 0 + && decoded.Images[0].Width == expected.Images[0].Width); + AddCases( + "structlist", + BenchmarkDataFactory.CreateNumericStructList(), + options, + cases, + (decoded, expected) => + decoded.Values.Count > 0 && decoded.Values[0].F1 == expected.Values[0].F1); + AddCases( + "samplelist", + BenchmarkDataFactory.CreateSampleList(), + options, + cases, + (decoded, expected) => + decoded.Values.Count > 0 + && decoded.Values[0].IntValue == expected.Values[0].IntValue); + AddCases( + "mediacontentlist", + BenchmarkDataFactory.CreateMediaContentList(), + options, + cases, + (decoded, expected) => + decoded.Values.Count > 0 + && decoded.Values[0].Media.Width == expected.Values[0].Media.Width + && decoded.Values[0].Images.Count > 0 + && decoded.Values[0].Images[0].Width == expected.Values[0].Images[0].Width); return cases; } - private static void AddCases(string dataType, T value, BenchmarkOptions options, List cases) + private static void AddCases( + string dataType, + TWrite value, + BenchmarkOptions options, + List cases, + Func validateMismatch) { if (!options.IsDataEnabled(dataType)) { return; } - List> serializers = []; + List> serializers = []; if (options.IsSerializerEnabled("fory")) { - serializers.Add(new ForySerializer()); + if (options.SchemaMismatch) + { + serializers.Add(new ForySerializer(schemaMismatch: true)); + } + else + { + serializers.Add(new ForySerializer(schemaMismatch: false)); + } } - if (options.IsSerializerEnabled("protobuf")) + if (!options.SchemaMismatch && options.IsSerializerEnabled("protobuf")) { - serializers.Add(new ProtobufSerializer()); + serializers.Add(new ProtobufSerializer()); } - if (options.IsSerializerEnabled("msgpack")) + if (!options.SchemaMismatch && options.IsSerializerEnabled("msgpack")) { - serializers.Add(new MessagePackRuntimeSerializer()); + serializers.Add(new MessagePackRuntimeSerializer()); } - foreach (IBenchmarkSerializer serializer in serializers) + foreach (IBenchmarkSerializer serializer in serializers) { byte[] payload = serializer.Serialize(value); _sink = serializer.Deserialize(payload); + if (options.SchemaMismatch && serializer.Name == "fory") + { + if (_sink is not TRead decoded || !validateMismatch(decoded, value)) + { + throw new InvalidOperationException( + $"Fory schema-mismatch validation failed for {dataType}"); + } + } cases.Add(new BenchmarkCase( serializer.Name, diff --git a/benchmarks/csharp/README.md b/benchmarks/csharp/README.md index 33df7def87..bd080aa789 100644 --- a/benchmarks/csharp/README.md +++ b/benchmarks/csharp/README.md @@ -51,6 +51,15 @@ Examples: ./run.sh --duration 10 --warmup 2 ``` +## Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode. This mode is off by default. When enabled, run with +`--serializer fory`; protobuf-net and MessagePack benchmark modes fail with a +configuration error. Fory serialization uses the normal v1 benchmark models, +and Fory deserialization uses v2 models registered with the same Fory type IDs +where one int32 field is widened to int64. + ## Benchmark Cases - `struct`: 8-field integer object diff --git a/benchmarks/csharp/run.sh b/benchmarks/csharp/run.sh index 2c078ca711..2f8065d720 100755 --- a/benchmarks/csharp/run.sh +++ b/benchmarks/csharp/run.sh @@ -91,6 +91,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "$SERIALIZER" != "fory" ]]; then + echo -e "${RED}FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory.${NC}" + exit 1 +fi + if [[ -n "$OUTPUT_DIR" ]]; then BUILD_DIR="$OUTPUT_DIR/build" REPORT_DIR="$OUTPUT_DIR/report" diff --git a/benchmarks/dart/README.md b/benchmarks/dart/README.md index 2c9c85641c..e290b7a458 100644 --- a/benchmarks/dart/README.md +++ b/benchmarks/dart/README.md @@ -15,3 +15,12 @@ The harness uses: This package is generated and validated from the repository checkout. Run `./run.sh` to generate code, compile the benchmark runner, and execute the measurements. + +## Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode. This mode is off by default. When enabled, run with +`--serializer fory`; protobuf and JSON benchmark modes fail with a configuration +error. Fory serialization uses the normal v1 benchmark models, and Fory +deserialization uses v2 models registered with the same Fory type IDs where one +int32 field is widened to int64. diff --git a/benchmarks/dart/benchmark_report.py b/benchmarks/dart/benchmark_report.py index fb4bdd3e90..d1d0977d30 100644 --- a/benchmarks/dart/benchmark_report.py +++ b/benchmarks/dart/benchmark_report.py @@ -96,14 +96,18 @@ def format_tick(value: float, _position) -> str: return format_throughput_tick(value, _position) -def format_int(value: float) -> str: - return f"{int(round(value)):,}" +def format_int(value) -> str: + return f"{int(round(value)):,}" if value is not None and value > 0 else "N/A" -def fastest_entry(values: dict[str, float]) -> str: - positive = {serializer: value for serializer, value in values.items() if value > 0} +def fastest_entry(values: dict) -> str: + positive = { + serializer: value + for serializer, value in values.items() + if value is not None and value > 0 + } if not positive: - return "n/a" + return "N/A" ranked = sorted(positive.items(), key=lambda entry: entry[1], reverse=True) best_serializer, best_value = ranked[0] if len(ranked) == 1: @@ -236,7 +240,7 @@ def write_report(payload: dict, results: dict, output_dir: str): continue for operation in ["serialize", "deserialize"]: values = { - serializer: operations.get(operation, {}).get(serializer, 0.0) + serializer: operations.get(operation, {}).get(serializer) for serializer in SERIALIZERS } handle.write( @@ -254,8 +258,11 @@ def write_report(payload: dict, results: dict, output_dir: str): sizes = payload["sizes"].get(data_type) if sizes is None: continue + fory_size = sizes.get("fory", "N/A") + protobuf_size = sizes.get("protobuf", "N/A") + json_size = sizes.get("json", "N/A") handle.write( - f"| {DISPLAY_NAMES[data_type]} | {sizes['fory']} | {sizes['protobuf']} | {sizes['json']} |\n" + f"| {DISPLAY_NAMES[data_type]} | {fory_size} | {protobuf_size} | {json_size} |\n" ) format_markdown_with_prettier(report_path) diff --git a/benchmarks/dart/bin/benchmark_runner.dart b/benchmarks/dart/bin/benchmark_runner.dart index cc722eebee..0796fb8ce6 100644 --- a/benchmarks/dart/bin/benchmark_runner.dart +++ b/benchmarks/dart/bin/benchmark_runner.dart @@ -25,6 +25,7 @@ import 'package:args/args.dart'; import 'package:fory_dart_benchmark/src/workloads.dart'; const int _batchSize = 64; +const String _schemaMismatchEnv = 'FORY_BENCH_SCHEMA_MISMATCH'; final class BenchmarkRecord { final String serializer; @@ -85,6 +86,17 @@ void main(List arguments) { final selectedData = _parseFilter(args['data'] as String?); final selectedSerializers = _parseFilter(args['serializer'] as String?); final selectedOperations = _parseFilter(args['operation'] as String?); + final schemaMismatch = Platform.environment[_schemaMismatchEnv] == '1'; + if (schemaMismatch && + (selectedSerializers == null || + selectedSerializers.length != 1 || + !selectedSerializers.contains('fory'))) { + stderr.writeln( + '$_schemaMismatchEnv=1 supports only Fory benchmarks; rerun with --serializer fory.', + ); + exitCode = 1; + return; + } final samples = int.parse(args['samples'] as String); final duration = Duration( microseconds: (double.parse(args['duration'] as String) * 1000000).round(), @@ -94,6 +106,8 @@ void main(List arguments) { ); final definitions = buildBenchmarkDefinitions(); + final writerFory = newBenchmarkWriterFory(); + final readerFory = newBenchmarkReaderFory(schemaMismatch: schemaMismatch); final records = []; final sizes = >{}; @@ -101,12 +115,18 @@ void main(List arguments) { if (!_matches(selectedData, definition.dataType)) { continue; } - final benchmark = definition.instantiate(newBenchmarkFory()); - sizes[benchmark.dataType] = { - 'fory': benchmark.forySize, - 'protobuf': benchmark.protobufSize, - 'json': benchmark.jsonSize, - }; + final benchmark = definition.instantiate( + writerFory: writerFory, + readerFory: readerFory, + schemaMismatch: schemaMismatch, + ); + sizes[benchmark.dataType] = schemaMismatch + ? {'fory': benchmark.forySize} + : { + 'fory': benchmark.forySize, + 'protobuf': benchmark.protobufSize!, + 'json': benchmark.jsonSize!, + }; if (_matches(selectedSerializers, 'fory') && _matches(selectedOperations, 'serialize')) { @@ -320,11 +340,13 @@ void _printSizes(Map> sizes) { stdout.writeln('Serialized sizes (bytes)'); stdout.writeln('------------------------'); for (final entry in sizes.entries) { + final protobuf = entry.value['protobuf']; + final json = entry.value['json']; stdout.writeln( '${entry.key.padRight(16)} ' 'fory=${entry.value['fory']} ' - 'protobuf=${entry.value['protobuf']} ' - 'json=${entry.value['json']}', + 'protobuf=${protobuf ?? 'n/a'} ' + 'json=${json ?? 'n/a'}', ); } } @@ -334,11 +356,13 @@ String _sizeTableText(Map> sizes) { ..writeln('Serialized sizes (bytes)') ..writeln('========================'); for (final entry in sizes.entries) { + final protobuf = entry.value['protobuf']; + final json = entry.value['json']; buffer.writeln( '${entry.key}: ' 'fory=${entry.value['fory']} ' - 'protobuf=${entry.value['protobuf']} ' - 'json=${entry.value['json']}', + 'protobuf=${protobuf ?? 'n/a'} ' + 'json=${json ?? 'n/a'}', ); } return buffer.toString(); diff --git a/benchmarks/dart/lib/src/models.dart b/benchmarks/dart/lib/src/models.dart index e99254e456..b79712abdb 100644 --- a/benchmarks/dart/lib/src/models.dart +++ b/benchmarks/dart/lib/src/models.dart @@ -350,17 +350,254 @@ class MediaContentList { Map toJson() => _$MediaContentListToJson(this); } +@ForyStruct() +class NumericStructV2 { + NumericStructV2({ + required this.f1, + required this.f2, + required this.f3, + required this.f4, + required this.f5, + required this.f6, + required this.f7, + required this.f8, + required this.f9, + required this.f10, + required this.f11, + required this.f12, + }); + + @ForyField(id: 1, type: Int64Type()) + final int f1; + @ForyField(id: 2, type: Int32Type()) + final int f2; + @ForyField(id: 3, type: Int32Type()) + final int f3; + @ForyField(id: 4, type: Int32Type()) + final int f4; + @ForyField(id: 5, type: Int32Type()) + final int f5; + @ForyField(id: 6, type: Int32Type()) + final int f6; + @ForyField(id: 7, type: Int32Type()) + final int f7; + @ForyField(id: 8, type: Int32Type()) + final int f8; + @ForyField(id: 9, type: Int32Type()) + final int f9; + @ForyField(id: 10, type: Int32Type()) + final int f10; + @ForyField(id: 11, type: Int32Type()) + final int f11; + @ForyField(id: 12, type: Int32Type()) + final int f12; +} + +@ForyStruct() +class SampleV2 { + SampleV2({ + required this.intValue, + required this.longValue, + required this.floatValue, + required this.doubleValue, + required this.shortValue, + required this.charValue, + required this.booleanValue, + required this.intValueBoxed, + required this.longValueBoxed, + required this.floatValueBoxed, + required this.doubleValueBoxed, + required this.shortValueBoxed, + required this.charValueBoxed, + required this.booleanValueBoxed, + required this.intArray, + required this.longArray, + required this.floatArray, + required this.doubleArray, + required this.shortArray, + required this.charArray, + required this.booleanArray, + required this.string, + }); + + @ForyField(id: 1, type: Int64Type()) + final int intValue; + @ForyField(id: 2, type: Int64Type()) + final int longValue; + @ForyField(id: 3) + final Float32 floatValue; + @ForyField(id: 4) + final double doubleValue; + @ForyField(id: 5, type: Int32Type()) + final int shortValue; + @ForyField(id: 6, type: Int32Type()) + final int charValue; + @ForyField(id: 7) + final bool booleanValue; + @ForyField(id: 8, type: Int32Type()) + final int intValueBoxed; + @ForyField(id: 9, type: Int64Type()) + final int longValueBoxed; + @ForyField(id: 10) + final Float32 floatValueBoxed; + @ForyField(id: 11) + final double doubleValueBoxed; + @ForyField(id: 12, type: Int32Type()) + final int shortValueBoxed; + @ForyField(id: 13, type: Int32Type()) + final int charValueBoxed; + @ForyField(id: 14) + final bool booleanValueBoxed; + @ForyField(id: 15) + final Int32List intArray; + @ForyField(id: 16) + final Int64List longArray; + @ForyField(id: 17) + final Float32List floatArray; + @ForyField(id: 18) + final Float64List doubleArray; + @ForyField(id: 19) + final Int32List shortArray; + @ForyField(id: 20) + final Int32List charArray; + @ArrayField(id: 21, element: BoolType()) + final BoolList booleanArray; + @ForyField(id: 22) + final String string; +} + +@ForyStruct() +class MediaV2 { + MediaV2({ + required this.uri, + required this.title, + required this.width, + required this.height, + required this.format, + required this.duration, + required this.size, + required this.bitrate, + required this.hasBitrate, + required this.persons, + required this.player, + required this.copyright, + }); + + @ForyField(id: 1) + final String uri; + @ForyField(id: 2) + final String title; + @ForyField(id: 3, type: Int64Type()) + final int width; + @ForyField(id: 4, type: Int32Type()) + final int height; + @ForyField(id: 5) + final String format; + @ForyField(id: 6, type: Int64Type()) + final int duration; + @ForyField(id: 7, type: Int64Type()) + final int size; + @ForyField(id: 8, type: Int32Type()) + final int bitrate; + @ForyField(id: 9) + final bool hasBitrate; + @ForyField(id: 10) + final List persons; + @ForyField(id: 11) + final Player player; + @ForyField(id: 12) + final String copyright; +} + +@ForyStruct() +class ImageV2 { + ImageV2({ + required this.uri, + required this.title, + required this.width, + required this.height, + required this.size, + }); + + @ForyField(id: 1) + final String uri; + @ForyField(id: 2) + final String title; + @ForyField(id: 3, type: Int64Type()) + final int width; + @ForyField(id: 4, type: Int32Type()) + final int height; + @ForyField(id: 5) + final MediaSize size; +} + +@ForyStruct() +class MediaContentV2 { + MediaContentV2({ + required this.media, + required this.images, + }); + + @ForyField(id: 1) + final MediaV2 media; + @ForyField(id: 2) + final List images; +} + +@ForyStruct() +class NumericStructListV2 { + NumericStructListV2({ + required this.structList, + }); + + @ForyField(id: 1) + final List structList; +} + +@ForyStruct() +class SampleListV2 { + SampleListV2({ + required this.sampleList, + }); + + @ForyField(id: 1) + final List sampleList; +} + +@ForyStruct() +class MediaContentListV2 { + MediaContentListV2({ + required this.mediaContentList, + }); + + @ForyField(id: 1) + final List mediaContentList; +} + void registerBenchmarkTypes(Fory fory) { - ModelsFory.register(fory, NumericStruct, id: 1); - ModelsFory.register(fory, Sample, id: 2); - ModelsFory.register(fory, Media, id: 3); - ModelsFory.register(fory, Image, id: 4); - ModelsFory.register(fory, MediaContent, id: 5); - ModelsFory.register(fory, NumericStructList, id: 6); - ModelsFory.register(fory, SampleList, id: 7); - ModelsFory.register(fory, MediaContentList, id: 8); - ModelsFory.register(fory, Player, id: 9); - ModelsFory.register(fory, MediaSize, id: 10); + ModelsForyModule.register(fory, NumericStruct, id: 1); + ModelsForyModule.register(fory, Sample, id: 2); + ModelsForyModule.register(fory, Media, id: 3); + ModelsForyModule.register(fory, Image, id: 4); + ModelsForyModule.register(fory, MediaContent, id: 5); + ModelsForyModule.register(fory, NumericStructList, id: 6); + ModelsForyModule.register(fory, SampleList, id: 7); + ModelsForyModule.register(fory, MediaContentList, id: 8); + ModelsForyModule.register(fory, Player, id: 9); + ModelsForyModule.register(fory, MediaSize, id: 10); +} + +void registerBenchmarkTypesV2(Fory fory) { + ModelsForyModule.register(fory, NumericStructV2, id: 1); + ModelsForyModule.register(fory, SampleV2, id: 2); + ModelsForyModule.register(fory, MediaV2, id: 3); + ModelsForyModule.register(fory, ImageV2, id: 4); + ModelsForyModule.register(fory, MediaContentV2, id: 5); + ModelsForyModule.register(fory, NumericStructListV2, id: 6); + ModelsForyModule.register(fory, SampleListV2, id: 7); + ModelsForyModule.register(fory, MediaContentListV2, id: 8); + ModelsForyModule.register(fory, Player, id: 9); + ModelsForyModule.register(fory, MediaSize, id: 10); } NumericStruct createNumericStruct() { diff --git a/benchmarks/dart/lib/src/workloads.dart b/benchmarks/dart/lib/src/workloads.dart index 9a9e7a62d6..e2c8966f2e 100644 --- a/benchmarks/dart/lib/src/workloads.dart +++ b/benchmarks/dart/lib/src/workloads.dart @@ -37,7 +37,9 @@ final class BenchmarkDefinition> serializeProtobuf: (model, _) => toPbStruct(model).writeToBuffer(), serializeJson: (model) => jsonEncode(model.toJson()), parseFory: (fory, bytes) => fory.deserialize(bytes), + parseForyMismatch: (fory, bytes) => + fory.deserialize(bytes), + verifyForyMismatch: (decoded, expected) { + if (decoded is! NumericStructV2 || decoded.f1 != expected.f1) { + throw StateError('NumericStructV2 schema mismatch read failed.'); + } + }, parseProtobuf: (bytes) => fromPbStruct(pb.NumericStruct.fromBuffer(bytes)), parseJson: (text) => NumericStruct.fromJson(jsonDecode(text)), @@ -140,6 +183,12 @@ List> protobufMessage.writeToBuffer(), serializeJson: (model) => jsonEncode(model.toJson()), parseFory: (fory, bytes) => fory.deserialize(bytes), + parseForyMismatch: (fory, bytes) => fory.deserialize(bytes), + verifyForyMismatch: (decoded, expected) { + if (decoded is! SampleV2 || decoded.intValue != expected.intValue) { + throw StateError('SampleV2 schema mismatch read failed.'); + } + }, parseProtobuf: pb.Sample.fromBuffer, parseJson: (text) => Sample.fromJson(jsonDecode(text)), ), @@ -150,6 +199,16 @@ List> serializeProtobuf: (model, _) => toPbMediaContent(model).writeToBuffer(), serializeJson: (model) => jsonEncode(model.toJson()), parseFory: (fory, bytes) => fory.deserialize(bytes), + parseForyMismatch: (fory, bytes) => + fory.deserialize(bytes), + verifyForyMismatch: (decoded, expected) { + if (decoded is! MediaContentV2 || + decoded.media.width != expected.media.width || + decoded.images.isEmpty || + decoded.images.first.width != expected.images.first.width) { + throw StateError('MediaContentV2 schema mismatch read failed.'); + } + }, parseProtobuf: (bytes) => fromPbMediaContent(pb.MediaContent.fromBuffer(bytes)), parseJson: (text) => MediaContent.fromJson(jsonDecode(text)), @@ -162,6 +221,15 @@ List> toPbNumericStructList(model).writeToBuffer(), serializeJson: (model) => jsonEncode(model.toJson()), parseFory: (fory, bytes) => fory.deserialize(bytes), + parseForyMismatch: (fory, bytes) => + fory.deserialize(bytes), + verifyForyMismatch: (decoded, expected) { + if (decoded is! NumericStructListV2 || + decoded.structList.isEmpty || + decoded.structList.first.f1 != expected.structList.first.f1) { + throw StateError('NumericStructListV2 schema mismatch read failed.'); + } + }, parseProtobuf: (bytes) => fromPbNumericStructList(pb.NumericStructList.fromBuffer(bytes)), parseJson: (text) => NumericStructList.fromJson(jsonDecode(text)), @@ -173,6 +241,15 @@ List> serializeProtobuf: (model, _) => toPbSampleList(model).writeToBuffer(), serializeJson: (model) => jsonEncode(model.toJson()), parseFory: (fory, bytes) => fory.deserialize(bytes), + parseForyMismatch: (fory, bytes) => fory.deserialize(bytes), + verifyForyMismatch: (decoded, expected) { + if (decoded is! SampleListV2 || + decoded.sampleList.isEmpty || + decoded.sampleList.first.intValue != + expected.sampleList.first.intValue) { + throw StateError('SampleListV2 schema mismatch read failed.'); + } + }, parseProtobuf: (bytes) => fromPbSampleList(pb.SampleList.fromBuffer(bytes)), parseJson: (text) => SampleList.fromJson(jsonDecode(text)), @@ -185,6 +262,19 @@ List> toPbMediaContentList(model).writeToBuffer(), serializeJson: (model) => jsonEncode(model.toJson()), parseFory: (fory, bytes) => fory.deserialize(bytes), + parseForyMismatch: (fory, bytes) => + fory.deserialize(bytes), + verifyForyMismatch: (decoded, expected) { + if (decoded is! MediaContentListV2 || + decoded.mediaContentList.isEmpty || + decoded.mediaContentList.first.media.width != + expected.mediaContentList.first.media.width || + decoded.mediaContentList.first.images.isEmpty || + decoded.mediaContentList.first.images.first.width != + expected.mediaContentList.first.images.first.width) { + throw StateError('MediaContentListV2 schema mismatch read failed.'); + } + }, parseProtobuf: (bytes) => fromPbMediaContentList(pb.MediaContentList.fromBuffer(bytes)), parseJson: (text) => MediaContentList.fromJson(jsonDecode(text)), @@ -193,7 +283,21 @@ List> } Fory newBenchmarkFory() { + return newBenchmarkWriterFory(); +} + +Fory newBenchmarkWriterFory() { final fory = Fory(compatible: true); registerBenchmarkTypes(fory); return fory; } + +Fory newBenchmarkReaderFory({required bool schemaMismatch}) { + final fory = Fory(compatible: true); + if (schemaMismatch) { + registerBenchmarkTypesV2(fory); + } else { + registerBenchmarkTypes(fory); + } + return fory; +} diff --git a/benchmarks/dart/run.sh b/benchmarks/dart/run.sh index 11f5f176a8..5813adac20 100644 --- a/benchmarks/dart/run.sh +++ b/benchmarks/dart/run.sh @@ -95,6 +95,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "$SERIALIZER" != "fory" ]]; then + echo "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory." + exit 1 +fi + mkdir -p "$OUTPUT_DIR" "$BUILD_DIR" echo "============================================" diff --git a/benchmarks/dart/test/benchmark_models_test.dart b/benchmarks/dart/test/benchmark_models_test.dart index 70cb3e0d81..c279c3a1d1 100644 --- a/benchmarks/dart/test/benchmark_models_test.dart +++ b/benchmarks/dart/test/benchmark_models_test.dart @@ -66,7 +66,11 @@ void main() { }; for (final definition in definitions) { - final benchmark = definition.instantiate(fory); + final benchmark = definition.instantiate( + writerFory: fory, + readerFory: fory, + schemaMismatch: false, + ); expect( benchmark.forySize, expectedSizes[benchmark.dataType], diff --git a/benchmarks/go/README.md b/benchmarks/go/README.md index 699c1263de..2fbfc13c5d 100644 --- a/benchmarks/go/README.md +++ b/benchmarks/go/README.md @@ -69,6 +69,15 @@ apt-get install protobuf-compiler ./run.sh --no-report ``` +### Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode. This mode is off by default. When enabled, run with +`--serializer fory`; protobuf and MessagePack benchmark modes fail with a +configuration error. Fory serialization uses the normal v1 benchmark structs, +and Fory deserialization uses v2 structs registered with the same Fory type IDs +where one int32 field is widened to int64. + ### Manual Run ```bash diff --git a/benchmarks/go/benchmark_report.py b/benchmarks/go/benchmark_report.py index cf62cb7bb7..3fca4f1aee 100755 --- a/benchmarks/go/benchmark_report.py +++ b/benchmarks/go/benchmark_report.py @@ -86,6 +86,8 @@ def format_ops_per_sec(value): + if value is None: + return "N/A" if HAS_MATPLOTLIB: return format_throughput_label(value) if value >= 1e6: @@ -93,6 +95,14 @@ def format_ops_per_sec(value): return f"{value / 1e3:.0f}K" +def ns_to_ops_per_sec(value): + return 1e9 / value if value is not None and value > 0 else None + + +def format_ns(value): + return f"{value:.1f}" if value is not None and value > 0 else "N/A" + + def format_ops_tick(value, _position): if HAS_MATPLOTLIB: return format_throughput_tick(value, _position) @@ -321,22 +331,24 @@ def generate_markdown_report(results, output_dir): continue data = results[datatype][op] - fory_ops = 1e9 / data.get("fory", float("inf")) if "fory" in data else 0 - pb_ops = ( - 1e9 / data.get("protobuf", float("inf")) if "protobuf" in data else 0 + fory_ops = ns_to_ops_per_sec(data.get("fory")) + pb_ops = ns_to_ops_per_sec(data.get("protobuf")) + mp_ops = ns_to_ops_per_sec(data.get("msgpack")) + + fory_str = format_ops_per_sec(fory_ops) + pb_str = format_ops_per_sec(pb_ops) + mp_str = format_ops_per_sec(mp_ops) + + fory_vs_pb = ( + f"{fory_ops / pb_ops:.2f}x" + if fory_ops is not None and pb_ops + else "N/A" ) - mp_ops = 1e9 / data.get("msgpack", float("inf")) if "msgpack" in data else 0 - - fory_str = ( - f"{fory_ops / 1e6:.2f}M" - if fory_ops >= 1e6 - else f"{fory_ops / 1e3:.0f}K" + fory_vs_mp = ( + f"{fory_ops / mp_ops:.2f}x" + if fory_ops is not None and mp_ops + else "N/A" ) - pb_str = f"{pb_ops / 1e6:.2f}M" if pb_ops >= 1e6 else f"{pb_ops / 1e3:.0f}K" - mp_str = f"{mp_ops / 1e6:.2f}M" if mp_ops >= 1e6 else f"{mp_ops / 1e3:.0f}K" - - fory_vs_pb = f"{fory_ops / pb_ops:.2f}x" if pb_ops > 0 else "N/A" - fory_vs_mp = f"{fory_ops / mp_ops:.2f}x" if mp_ops > 0 else "N/A" report.append( f"| {display_name(datatype)} | {op.title()} | {fory_str} | {pb_str} | {mp_str} | {fory_vs_pb} | {fory_vs_mp} |" @@ -357,9 +369,9 @@ def generate_markdown_report(results, output_dir): continue data = results[datatype][op] - fory_ns = f"{data.get('fory', 0):.1f}" - pb_ns = f"{data.get('protobuf', 0):.1f}" - mp_ns = f"{data.get('msgpack', 0):.1f}" + fory_ns = format_ns(data.get("fory")) + pb_ns = format_ns(data.get("protobuf")) + mp_ns = format_ns(data.get("msgpack")) report.append( f"| {display_name(datatype)} | {op.title()} | {fory_ns} | {pb_ns} | {mp_ns} |" diff --git a/benchmarks/go/benchmark_test.go b/benchmarks/go/benchmark_test.go index 0993ac9bd6..eabf4a039a 100644 --- a/benchmarks/go/benchmark_test.go +++ b/benchmarks/go/benchmark_test.go @@ -19,6 +19,7 @@ package benchmark import ( "fmt" + "os" "testing" pb "github.com/apache/fory/benchmarks/go/proto" @@ -37,7 +38,26 @@ func newFory() *fory.Fory { fory.WithTrackRef(false), fory.WithCompatible(true), ) - // Register types with IDs matching C++ benchmark + registerForyTypes(f) + return f +} + +func newForyReader() *fory.Fory { + f := fory.New( + fory.WithXlang(true), + fory.WithTrackRef(false), + fory.WithCompatible(true), + ) + if schemaMismatchEnabled() { + registerForyTypesV2(f) + } else { + registerForyTypes(f) + } + return f +} + +func registerForyTypes(f *fory.Fory) { + // Register types with IDs matching the existing Go benchmark schema. if err := f.RegisterStruct(NumericStruct{}, 1); err != nil { panic(err) } @@ -68,7 +88,49 @@ func newFory() *fory.Fory { if err := f.RegisterEnum(Size(0), 7); err != nil { panic(err) } - return f +} + +func registerForyTypesV2(f *fory.Fory) { + if err := f.RegisterStruct(NumericStructV2{}, 1); err != nil { + panic(err) + } + if err := f.RegisterStruct(SampleV2{}, 2); err != nil { + panic(err) + } + if err := f.RegisterStruct(MediaV2{}, 3); err != nil { + panic(err) + } + if err := f.RegisterStruct(ImageV2{}, 4); err != nil { + panic(err) + } + if err := f.RegisterStruct(MediaContentV2{}, 5); err != nil { + panic(err) + } + if err := f.RegisterStruct(NumericStructListV2{}, 8); err != nil { + panic(err) + } + if err := f.RegisterStruct(SampleListV2{}, 9); err != nil { + panic(err) + } + if err := f.RegisterStruct(MediaContentListV2{}, 10); err != nil { + panic(err) + } + if err := f.RegisterEnum(Player(0), 6); err != nil { + panic(err) + } + if err := f.RegisterEnum(Size(0), 7); err != nil { + panic(err) + } +} + +func schemaMismatchEnabled() bool { + return os.Getenv("FORY_BENCH_SCHEMA_MISMATCH") == "1" +} + +func rejectNonForySchemaMismatch(b testing.TB) { + if schemaMismatchEnabled() { + b.Fatalf("FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory") + } } // ============================================================================ @@ -91,6 +153,7 @@ func BenchmarkFory_NumericStruct_Serialize(b *testing.B) { } func BenchmarkProtobuf_NumericStruct_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateNumericStruct() b.ResetTimer() @@ -105,6 +168,7 @@ func BenchmarkProtobuf_NumericStruct_Serialize(b *testing.B) { } func BenchmarkMsgpack_NumericStruct_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateNumericStruct() b.ResetTimer() @@ -117,17 +181,36 @@ func BenchmarkMsgpack_NumericStruct_Serialize(b *testing.B) { } func BenchmarkFory_NumericStruct_Deserialize(b *testing.B) { - f := newFory() + writer := newFory() + reader := newForyReader() obj := CreateNumericStruct() - data, err := f.Serialize(&obj) + data, err := writer.Serialize(&obj) if err != nil { b.Fatal(err) } + if schemaMismatchEnabled() { + var result NumericStructV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + if result.F1 != int64(obj.F1) { + b.Fatal("NumericStructV2 schema mismatch read failed") + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result NumericStructV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + } + return + } + b.ResetTimer() for i := 0; i < b.N; i++ { var result NumericStruct - err := f.Deserialize(data, &result) + err := reader.Deserialize(data, &result) if err != nil { b.Fatal(err) } @@ -135,6 +218,7 @@ func BenchmarkFory_NumericStruct_Deserialize(b *testing.B) { } func BenchmarkProtobuf_NumericStruct_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateNumericStruct() pbObj := ToPbStruct(obj) data, err := proto.Marshal(pbObj) @@ -156,6 +240,7 @@ func BenchmarkProtobuf_NumericStruct_Deserialize(b *testing.B) { } func BenchmarkMsgpack_NumericStruct_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateNumericStruct() data, err := msgpack.Marshal(obj) if err != nil { @@ -192,6 +277,7 @@ func BenchmarkFory_NumericStructList_Serialize(b *testing.B) { } func BenchmarkProtobuf_NumericStructList_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateNumericStructList() b.ResetTimer() @@ -205,6 +291,7 @@ func BenchmarkProtobuf_NumericStructList_Serialize(b *testing.B) { } func BenchmarkMsgpack_NumericStructList_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateNumericStructList() b.ResetTimer() @@ -217,17 +304,36 @@ func BenchmarkMsgpack_NumericStructList_Serialize(b *testing.B) { } func BenchmarkFory_NumericStructList_Deserialize(b *testing.B) { - f := newFory() + writer := newFory() + reader := newForyReader() obj := CreateNumericStructList() - data, err := f.Serialize(&obj) + data, err := writer.Serialize(&obj) if err != nil { b.Fatal(err) } + if schemaMismatchEnabled() { + var result NumericStructListV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + if len(result.NumericStructs) == 0 || result.NumericStructs[0].F1 != int64(obj.NumericStructs[0].F1) { + b.Fatal("NumericStructListV2 schema mismatch read failed") + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result NumericStructListV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + } + return + } + b.ResetTimer() for i := 0; i < b.N; i++ { var result NumericStructList - err := f.Deserialize(data, &result) + err := reader.Deserialize(data, &result) if err != nil { b.Fatal(err) } @@ -235,6 +341,7 @@ func BenchmarkFory_NumericStructList_Deserialize(b *testing.B) { } func BenchmarkProtobuf_NumericStructList_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateNumericStructList() pbObj := ToPbNumericStructList(obj) data, err := proto.Marshal(pbObj) @@ -254,6 +361,7 @@ func BenchmarkProtobuf_NumericStructList_Deserialize(b *testing.B) { } func BenchmarkMsgpack_NumericStructList_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateNumericStructList() data, err := msgpack.Marshal(obj) if err != nil { @@ -290,6 +398,7 @@ func BenchmarkFory_Sample_Serialize(b *testing.B) { } func BenchmarkProtobuf_Sample_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateSample() b.ResetTimer() @@ -303,6 +412,7 @@ func BenchmarkProtobuf_Sample_Serialize(b *testing.B) { } func BenchmarkMsgpack_Sample_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateSample() b.ResetTimer() @@ -315,17 +425,36 @@ func BenchmarkMsgpack_Sample_Serialize(b *testing.B) { } func BenchmarkFory_Sample_Deserialize(b *testing.B) { - f := newFory() + writer := newFory() + reader := newForyReader() obj := CreateSample() - data, err := f.Serialize(&obj) + data, err := writer.Serialize(&obj) if err != nil { b.Fatal(err) } + if schemaMismatchEnabled() { + var result SampleV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + if result.IntValue != int64(obj.IntValue) { + b.Fatal("SampleV2 schema mismatch read failed") + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result SampleV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + } + return + } + b.ResetTimer() for i := 0; i < b.N; i++ { var result Sample - err := f.Deserialize(data, &result) + err := reader.Deserialize(data, &result) if err != nil { b.Fatal(err) } @@ -333,6 +462,7 @@ func BenchmarkFory_Sample_Deserialize(b *testing.B) { } func BenchmarkProtobuf_Sample_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateSample() pbObj := ToPbSample(obj) data, err := proto.Marshal(pbObj) @@ -352,6 +482,7 @@ func BenchmarkProtobuf_Sample_Deserialize(b *testing.B) { } func BenchmarkMsgpack_Sample_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateSample() data, err := msgpack.Marshal(obj) if err != nil { @@ -388,6 +519,7 @@ func BenchmarkFory_SampleList_Serialize(b *testing.B) { } func BenchmarkProtobuf_SampleList_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateSampleList() b.ResetTimer() @@ -401,6 +533,7 @@ func BenchmarkProtobuf_SampleList_Serialize(b *testing.B) { } func BenchmarkMsgpack_SampleList_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateSampleList() b.ResetTimer() @@ -413,17 +546,36 @@ func BenchmarkMsgpack_SampleList_Serialize(b *testing.B) { } func BenchmarkFory_SampleList_Deserialize(b *testing.B) { - f := newFory() + writer := newFory() + reader := newForyReader() obj := CreateSampleList() - data, err := f.Serialize(&obj) + data, err := writer.Serialize(&obj) if err != nil { b.Fatal(err) } + if schemaMismatchEnabled() { + var result SampleListV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + if len(result.SampleList) == 0 || result.SampleList[0].IntValue != int64(obj.SampleList[0].IntValue) { + b.Fatal("SampleListV2 schema mismatch read failed") + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result SampleListV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + } + return + } + b.ResetTimer() for i := 0; i < b.N; i++ { var result SampleList - err := f.Deserialize(data, &result) + err := reader.Deserialize(data, &result) if err != nil { b.Fatal(err) } @@ -431,6 +583,7 @@ func BenchmarkFory_SampleList_Deserialize(b *testing.B) { } func BenchmarkProtobuf_SampleList_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateSampleList() pbObj := ToPbSampleList(obj) data, err := proto.Marshal(pbObj) @@ -450,6 +603,7 @@ func BenchmarkProtobuf_SampleList_Deserialize(b *testing.B) { } func BenchmarkMsgpack_SampleList_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateSampleList() data, err := msgpack.Marshal(obj) if err != nil { @@ -486,6 +640,7 @@ func BenchmarkFory_MediaContent_Serialize(b *testing.B) { } func BenchmarkProtobuf_MediaContent_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateMediaContent() b.ResetTimer() @@ -499,6 +654,7 @@ func BenchmarkProtobuf_MediaContent_Serialize(b *testing.B) { } func BenchmarkMsgpack_MediaContent_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateMediaContent() b.ResetTimer() @@ -511,17 +667,37 @@ func BenchmarkMsgpack_MediaContent_Serialize(b *testing.B) { } func BenchmarkFory_MediaContent_Deserialize(b *testing.B) { - f := newFory() + writer := newFory() + reader := newForyReader() obj := CreateMediaContent() - data, err := f.Serialize(&obj) + data, err := writer.Serialize(&obj) if err != nil { b.Fatal(err) } + if schemaMismatchEnabled() { + var result MediaContentV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + if result.Media.Width != int64(obj.Media.Width) || + len(result.Images) == 0 || result.Images[0].Width != int64(obj.Images[0].Width) { + b.Fatal("MediaContentV2 schema mismatch read failed") + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result MediaContentV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + } + return + } + b.ResetTimer() for i := 0; i < b.N; i++ { var result MediaContent - err := f.Deserialize(data, &result) + err := reader.Deserialize(data, &result) if err != nil { b.Fatal(err) } @@ -529,6 +705,7 @@ func BenchmarkFory_MediaContent_Deserialize(b *testing.B) { } func BenchmarkProtobuf_MediaContent_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateMediaContent() pbObj := ToPbMediaContent(obj) data, err := proto.Marshal(pbObj) @@ -548,6 +725,7 @@ func BenchmarkProtobuf_MediaContent_Deserialize(b *testing.B) { } func BenchmarkMsgpack_MediaContent_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateMediaContent() data, err := msgpack.Marshal(obj) if err != nil { @@ -584,6 +762,7 @@ func BenchmarkFory_MediaContentList_Serialize(b *testing.B) { } func BenchmarkProtobuf_MediaContentList_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateMediaContentList() b.ResetTimer() @@ -597,6 +776,7 @@ func BenchmarkProtobuf_MediaContentList_Serialize(b *testing.B) { } func BenchmarkMsgpack_MediaContentList_Serialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateMediaContentList() b.ResetTimer() @@ -609,17 +789,39 @@ func BenchmarkMsgpack_MediaContentList_Serialize(b *testing.B) { } func BenchmarkFory_MediaContentList_Deserialize(b *testing.B) { - f := newFory() + writer := newFory() + reader := newForyReader() obj := CreateMediaContentList() - data, err := f.Serialize(&obj) + data, err := writer.Serialize(&obj) if err != nil { b.Fatal(err) } + if schemaMismatchEnabled() { + var result MediaContentListV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + if len(result.MediaContentList) == 0 || + result.MediaContentList[0].Media.Width != int64(obj.MediaContentList[0].Media.Width) || + len(result.MediaContentList[0].Images) == 0 || + result.MediaContentList[0].Images[0].Width != int64(obj.MediaContentList[0].Images[0].Width) { + b.Fatal("MediaContentListV2 schema mismatch read failed") + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result MediaContentListV2 + if err := reader.Deserialize(data, &result); err != nil { + b.Fatal(err) + } + } + return + } + b.ResetTimer() for i := 0; i < b.N; i++ { var result MediaContentList - err := f.Deserialize(data, &result) + err := reader.Deserialize(data, &result) if err != nil { b.Fatal(err) } @@ -627,6 +829,7 @@ func BenchmarkFory_MediaContentList_Deserialize(b *testing.B) { } func BenchmarkProtobuf_MediaContentList_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateMediaContentList() pbObj := ToPbMediaContentList(obj) data, err := proto.Marshal(pbObj) @@ -646,6 +849,7 @@ func BenchmarkProtobuf_MediaContentList_Deserialize(b *testing.B) { } func BenchmarkMsgpack_MediaContentList_Deserialize(b *testing.B) { + rejectNonForySchemaMismatch(b) obj := CreateMediaContentList() data, err := msgpack.Marshal(obj) if err != nil { @@ -668,6 +872,38 @@ func BenchmarkMsgpack_MediaContentList_Deserialize(b *testing.B) { func TestPrintSerializedSizes(t *testing.T) { f := newFory() + if schemaMismatchEnabled() { + numericStruct := CreateNumericStruct() + foryStructData, _ := f.Serialize(&numericStruct) + sample := CreateSample() + forySampleData, _ := f.Serialize(&sample) + mediaContent := CreateMediaContent() + foryMediaData, _ := f.Serialize(&mediaContent) + structList := CreateNumericStructList() + foryStructListData, _ := f.Serialize(&structList) + sampleList := CreateSampleList() + forySampleListData, _ := f.Serialize(&sampleList) + mediaContentList := CreateMediaContentList() + foryMediaContentListData, _ := f.Serialize(&mediaContentList) + + fmt.Println("============================================") + fmt.Println("Serialized Sizes (bytes):") + fmt.Println("============================================") + fmt.Printf("NumericStruct:\n") + fmt.Printf(" Fory: %d bytes\n", len(foryStructData)) + fmt.Printf("Sample:\n") + fmt.Printf(" Fory: %d bytes\n", len(forySampleData)) + fmt.Printf("MediaContent:\n") + fmt.Printf(" Fory: %d bytes\n", len(foryMediaData)) + fmt.Printf("NumericStructList:\n") + fmt.Printf(" Fory: %d bytes\n", len(foryStructListData)) + fmt.Printf("SampleList:\n") + fmt.Printf(" Fory: %d bytes\n", len(forySampleListData)) + fmt.Printf("MediaContentList:\n") + fmt.Printf(" Fory: %d bytes\n", len(foryMediaContentListData)) + fmt.Println("============================================") + return + } // NumericStruct sizes numericStruct := CreateNumericStruct() diff --git a/benchmarks/go/models.go b/benchmarks/go/models.go index b3c10acc9e..cb4e5c0c5b 100644 --- a/benchmarks/go/models.go +++ b/benchmarks/go/models.go @@ -34,6 +34,21 @@ type NumericStruct struct { F12 int32 `msgpack:"12" fory:"id=12"` } +type NumericStructV2 struct { + F1 int64 `fory:"id=1"` + F2 int32 `fory:"id=2"` + F3 int32 `fory:"id=3"` + F4 int32 `fory:"id=4"` + F5 int32 `fory:"id=5"` + F6 int32 `fory:"id=6"` + F7 int32 `fory:"id=7"` + F8 int32 `fory:"id=8"` + F9 int32 `fory:"id=9"` + F10 int32 `fory:"id=10"` + F11 int32 `fory:"id=11"` + F12 int32 `fory:"id=12"` +} + // Sample is a complex struct with various types and arrays // Matches the C++ Sample and protobuf Sample message type Sample struct { @@ -61,6 +76,31 @@ type Sample struct { String string `msgpack:"22" fory:"id=22"` } +type SampleV2 struct { + IntValue int64 `fory:"id=1"` + LongValue int64 `fory:"id=2"` + FloatValue float32 `fory:"id=3"` + DoubleValue float64 `fory:"id=4"` + ShortValue int32 `fory:"id=5"` + CharValue int32 `fory:"id=6"` + BooleanValue bool `fory:"id=7"` + IntValueBoxed int32 `fory:"id=8"` + LongValueBoxed int64 `fory:"id=9"` + FloatValueBoxed float32 `fory:"id=10"` + DoubleValueBoxed float64 `fory:"id=11"` + ShortValueBoxed int32 `fory:"id=12"` + CharValueBoxed int32 `fory:"id=13"` + BooleanValueBoxed bool `fory:"id=14"` + IntArray []int32 `fory:"id=15,type=array(element=int32)"` + LongArray []int64 `fory:"id=16,type=array(element=int64)"` + FloatArray []float32 `fory:"id=17,type=array(element=float32)"` + DoubleArray []float64 `fory:"id=18,type=array(element=float64)"` + ShortArray []int32 `fory:"id=19,type=array(element=int32)"` + CharArray []int32 `fory:"id=20,type=array(element=int32)"` + BooleanArray []bool `fory:"id=21,type=array(element=bool)"` + String string `fory:"id=22"` +} + // Player enum type type Player int32 @@ -93,6 +133,21 @@ type Media struct { Copyright string `msgpack:"12" fory:"id=12"` } +type MediaV2 struct { + URI string `fory:"id=1"` + Title string `fory:"id=2"` + Width int64 `fory:"id=3"` + Height int32 `fory:"id=4"` + Format string `fory:"id=5"` + Duration int64 `fory:"id=6"` + Size int64 `fory:"id=7"` + Bitrate int32 `fory:"id=8"` + HasBitrate bool `fory:"id=9"` + Persons []string `fory:"id=10"` + Player Player `fory:"id=11"` + Copyright string `fory:"id=12"` +} + // Image represents image metadata type Image struct { URI string `msgpack:"1" fory:"id=1"` @@ -102,24 +157,49 @@ type Image struct { Size Size `msgpack:"5" fory:"id=5"` } +type ImageV2 struct { + URI string `fory:"id=1"` + Title string `fory:"id=2"` + Width int64 `fory:"id=3"` + Height int32 `fory:"id=4"` + Size Size `fory:"id=5"` +} + // MediaContent contains media and images type MediaContent struct { Media Media `msgpack:"1" fory:"id=1"` Images []Image `msgpack:"2" fory:"id=2"` } +type MediaContentV2 struct { + Media MediaV2 `fory:"id=1"` + Images []ImageV2 `fory:"id=2"` +} + type NumericStructList struct { NumericStructs []NumericStruct `msgpack:"1" fory:"id=1"` } +type NumericStructListV2 struct { + NumericStructs []NumericStructV2 `fory:"id=1"` +} + type SampleList struct { SampleList []Sample `msgpack:"1" fory:"id=1"` } +type SampleListV2 struct { + SampleList []SampleV2 `fory:"id=1"` +} + type MediaContentList struct { MediaContentList []MediaContent `msgpack:"1" fory:"id=1"` } +type MediaContentListV2 struct { + MediaContentList []MediaContentV2 `fory:"id=1"` +} + // CreateNumericStruct creates test data matching C++ benchmark func CreateNumericStruct() NumericStruct { return NumericStruct{ diff --git a/benchmarks/go/run.sh b/benchmarks/go/run.sh index 98b1089ca7..fea1f8fc30 100755 --- a/benchmarks/go/run.sh +++ b/benchmarks/go/run.sh @@ -87,6 +87,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "$SERIALIZER" != "fory" ]]; then + echo "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory." + exit 1 +fi + # Build benchmark filter FILTER="" if [[ -n "$DATA_TYPE" ]]; then diff --git a/benchmarks/java/README.md b/benchmarks/java/README.md index 38f40281ea..d74c237838 100644 --- a/benchmarks/java/README.md +++ b/benchmarks/java/README.md @@ -65,6 +65,15 @@ See `org.openjdk.jmh.runner.options.CommandLineOptions` for more information abo -rf Result format type ``` +## Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode for `XlangBenchmark`. This mode is off by default. When +enabled, run with `--serializer fory`; protobuf and FlatBuffers benchmark modes +fail with a configuration error. Fory serialization uses the normal v1 benchmark +classes, and Fory deserialization uses v2 classes registered with the same Fory +type IDs where one int32 field is widened to int64. + Save benchmark data to specified dir, then run `tool.py` to plot graphs. ## Plotting diff --git a/benchmarks/java/benchmark_report.py b/benchmarks/java/benchmark_report.py index 7bf472b00c..eb0327a6bf 100755 --- a/benchmarks/java/benchmark_report.py +++ b/benchmarks/java/benchmark_report.py @@ -50,9 +50,10 @@ apply_benchmark_style(plt) -SERIALIZER_ORDER = ["fory", "protobuf", "flatbuffer"] +SERIALIZER_ORDER = ["fory-codegen=true", "fory-codegen=false", "protobuf", "flatbuffer"] COLORS = { - "fory": "#FF6f01", + "fory-codegen=true": "#FF6f01", + "fory-codegen=false": "#C94700", "protobuf": "#55BCC2", "flatbuffer": (0.55, 0.40, 0.45), } @@ -146,6 +147,9 @@ def collect_results(payload: Any) -> dict: score = float(metric.get("score", bench.get("opsPerSec", 0.0))) unit = metric.get("scoreUnit", "ops/s") serializer = match.group("serializer").lower() + if serializer == "fory": + codegen = str(bench.get("params", {}).get("codegen", "unknown")).lower() + serializer = f"fory-codegen={codegen}" datatype = datatype_key(match.group("datatype")) operation = match.group("operation").lower() results[datatype][operation][serializer] = score_to_ops_per_sec(score, unit) @@ -230,12 +234,18 @@ def format_table_value(value: float) -> str: return f"{value:,.0f}" if value > 0 else "N/A" +def serializer_title(serializer: str) -> str: + if serializer.startswith("fory-codegen="): + return "Fory " + serializer[len("fory-") :] + return serializer.capitalize() + + def winner_cell(values: dict) -> str: positive = {name: value for name, value in values.items() if value > 0} if not positive: return "N/A" winner = max(positive, key=positive.get) - return winner.capitalize() + return serializer_title(winner) def build_xlang_section(results: dict, image_name: str) -> str: @@ -251,8 +261,14 @@ def build_xlang_section(results: dict, image_name: str) -> str: "JMH parameters: `-f 1 -wi 3 -i 3 -t 1 -w 3s -r 3s -bm thrpt -tu s`. " "Higher throughput is better.\n\n", f"![Java Xlang Serialization Throughput]({image_name})\n\n", - "| Data type | Operation | Fory ops/sec | Protobuf ops/sec | Flatbuffer ops/sec | Fastest |\n", - "|-----------|-----------|--------------|------------------|--------------------|---------|\n", + "| Data type | Operation | " + + " | ".join( + f"{serializer_title(serializer)} ops/sec" for serializer in SERIALIZER_ORDER + ) + + " | Fastest |\n", + "|-----------|-----------|" + + "|".join("---:" for _ in SERIALIZER_ORDER) + + "|---------|\n", ] for datatype in DATATYPE_ORDER: diff --git a/benchmarks/java/run.sh b/benchmarks/java/run.sh index ea22143ba4..58e68250c7 100755 --- a/benchmarks/java/run.sh +++ b/benchmarks/java/run.sh @@ -170,6 +170,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "$(lower "${SERIALIZER_FILTER}")" != "fory" ]]; then + echo "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory." >&2 + exit 1 +fi + SERIALIZERS="$(serializer_regex "${SERIALIZER_FILTER}")" DATA_TYPES="$(data_regex "${DATA_FILTER}")" JMH_DURATION="$(jmh_time "${DURATION_SECONDS}")" @@ -184,18 +189,24 @@ if [[ "${SKIP_BUILD}" != "true" ]]; then mvn -Pjmh -DskipTests package fi -ENABLE_FORY_DEBUG_OUTPUT=0 \ - java -jar target/benchmarks.jar "${BENCHMARK_REGEX}" \ - -f 1 \ - -wi 3 \ - -i 3 \ - -t 1 \ - -w "${JMH_DURATION}" \ - -r "${JMH_DURATION}" \ - -bm thrpt \ - -tu s \ - -rf json \ +JMH_CMD=( + java --add-opens=java.base/java.lang.invoke=ALL-UNNAMED + -jar target/benchmarks.jar "${BENCHMARK_REGEX}" + -f 1 + -wi 3 + -i 3 + -t 1 +) +JMH_CMD+=( + -w "${JMH_DURATION}" + -r "${JMH_DURATION}" + -bm thrpt + -tu s + -rf json -rff "${RESULT_JSON}" +) + +ENABLE_FORY_DEBUG_OUTPUT=0 "${JMH_CMD[@]}" if [[ "${GENERATE_REPORT}" == "true" ]]; then REPORT_ARGS=(--json-file "${RESULT_JSON}" --output-dir "${REPORT_DIR}") diff --git a/benchmarks/java/src/main/java/org/apache/fory/benchmark/XlangBenchmark.java b/benchmarks/java/src/main/java/org/apache/fory/benchmark/XlangBenchmark.java index 08051e40f8..0f8ff25dad 100644 --- a/benchmarks/java/src/main/java/org/apache/fory/benchmark/XlangBenchmark.java +++ b/benchmarks/java/src/main/java/org/apache/fory/benchmark/XlangBenchmark.java @@ -64,13 +64,16 @@ @CompilerControl(value = CompilerControl.Mode.INLINE) public class XlangBenchmark { private static final int LIST_SIZE = 5; + private static final String SCHEMA_MISMATCH_ENV = "FORY_BENCH_SCHEMA_MISMATCH"; @State(Scope.Thread) public static class XlangState { @Param({"true", "false"}) public boolean codegen; - public Fory fory; + public Fory foryWriter; + public Fory foryReader; + public boolean schemaMismatch; public NumericStruct numericStruct; public Sample sample; @@ -109,16 +112,15 @@ public static class XlangState { @Setup(Level.Trial) public void setup() { - fory = - Fory.builder() - .withXlang(true) - .withCompatible(true) - .withCodegen(codegen) - .withRefTracking(false) - .withClassVersionCheck(false) - .requireClassRegistration(true) - .build(); - registerForyTypes(fory); + schemaMismatch = schemaMismatchEnabled(); + foryWriter = newBenchmarkFory(codegen); + registerForyTypes(foryWriter); + if (schemaMismatch) { + foryReader = newBenchmarkFory(codegen); + registerForyTypesV2(foryReader); + } else { + foryReader = foryWriter; + } numericStruct = createNumericStruct(); sample = createSample(); @@ -127,47 +129,56 @@ public void setup() { sampleList = createSampleList(); mediaContentList = createMediaContentList(); - foryNumericStructBytes = fory.serialize(numericStruct); - forySampleBytes = fory.serialize(sample); - foryMediaContentBytes = fory.serialize(mediaContent); - foryNumericStructListBytes = fory.serialize(numericStructList); - forySampleListBytes = fory.serialize(sampleList); - foryMediaContentListBytes = fory.serialize(mediaContentList); - - protobufNumericStructBytes = toFixedProto(numericStruct).toByteArray(); - protobufSampleBytes = toProto(sample).toByteArray(); - protobufMediaContentBytes = toProto(mediaContent).toByteArray(); - protobufNumericStructListBytes = toProto(numericStructList).toByteArray(); - protobufSampleListBytes = toProto(sampleList).toByteArray(); - protobufMediaContentListBytes = toProto(mediaContentList).toByteArray(); - - flatbufferNumericStructBytes = toFlatBuffer(numericStruct); - flatbufferSampleBytes = toFlatBuffer(sample); - flatbufferMediaContentBytes = toFlatBuffer(mediaContent); - flatbufferNumericStructListBytes = toFlatBuffer(numericStructList); - flatbufferSampleListBytes = toFlatBuffer(sampleList); - flatbufferMediaContentListBytes = toFlatBuffer(mediaContentList); - - flatbufferNumericStructBuffer = ByteBuffer.wrap(flatbufferNumericStructBytes); - flatbufferSampleBuffer = ByteBuffer.wrap(flatbufferSampleBytes); - flatbufferMediaContentBuffer = ByteBuffer.wrap(flatbufferMediaContentBytes); - flatbufferNumericStructListBuffer = ByteBuffer.wrap(flatbufferNumericStructListBytes); - flatbufferSampleListBuffer = ByteBuffer.wrap(flatbufferSampleListBytes); - flatbufferMediaContentListBuffer = ByteBuffer.wrap(flatbufferMediaContentListBytes); + foryNumericStructBytes = foryWriter.serialize(numericStruct); + forySampleBytes = foryWriter.serialize(sample); + foryMediaContentBytes = foryWriter.serialize(mediaContent); + foryNumericStructListBytes = foryWriter.serialize(numericStructList); + forySampleListBytes = foryWriter.serialize(sampleList); + foryMediaContentListBytes = foryWriter.serialize(mediaContentList); + + if (!schemaMismatch) { + protobufNumericStructBytes = toFixedProto(numericStruct).toByteArray(); + protobufSampleBytes = toProto(sample).toByteArray(); + protobufMediaContentBytes = toProto(mediaContent).toByteArray(); + protobufNumericStructListBytes = toProto(numericStructList).toByteArray(); + protobufSampleListBytes = toProto(sampleList).toByteArray(); + protobufMediaContentListBytes = toProto(mediaContentList).toByteArray(); + + flatbufferNumericStructBytes = toFlatBuffer(numericStruct); + flatbufferSampleBytes = toFlatBuffer(sample); + flatbufferMediaContentBytes = toFlatBuffer(mediaContent); + flatbufferNumericStructListBytes = toFlatBuffer(numericStructList); + flatbufferSampleListBytes = toFlatBuffer(sampleList); + flatbufferMediaContentListBytes = toFlatBuffer(mediaContentList); + + flatbufferNumericStructBuffer = ByteBuffer.wrap(flatbufferNumericStructBytes); + flatbufferSampleBuffer = ByteBuffer.wrap(flatbufferSampleBytes); + flatbufferMediaContentBuffer = ByteBuffer.wrap(flatbufferMediaContentBytes); + flatbufferNumericStructListBuffer = ByteBuffer.wrap(flatbufferNumericStructListBytes); + flatbufferSampleListBuffer = ByteBuffer.wrap(flatbufferSampleListBytes); + flatbufferMediaContentListBuffer = ByteBuffer.wrap(flatbufferMediaContentListBytes); + } verifySetup(); } private void verifySetup() { - verifyForySerializerMode(NumericStruct.class); - verifyForySerializerMode(Sample.class); - verifyForySerializerMode(MediaContent.class); - fory.deserialize(foryNumericStructBytes); - fromProtoStruct(protobufNumericStructBytes); - fromFlatBufferNumericStruct(flatbufferNumericStructBuffer); + verifyForySerializerMode(foryWriter, NumericStruct.class); + verifyForySerializerMode(foryWriter, Sample.class); + verifyForySerializerMode(foryWriter, MediaContent.class); + if (schemaMismatch) { + verifyForySerializerMode(foryReader, NumericStructV2.class); + verifyForySerializerMode(foryReader, SampleV2.class); + verifyForySerializerMode(foryReader, MediaContentV2.class); + verifySchemaMismatch(); + } else { + foryReader.deserialize(foryNumericStructBytes); + fromProtoStruct(protobufNumericStructBytes); + fromFlatBufferNumericStruct(flatbufferNumericStructBuffer); + } } - private void verifyForySerializerMode(Class type) { + private void verifyForySerializerMode(Fory fory, Class type) { Serializer serializer = fory.getTypeResolver().getSerializer(type); boolean staticSerializer = serializer instanceof StaticGeneratedStructSerializer; if (staticSerializer == codegen) { @@ -180,6 +191,67 @@ private void verifyForySerializerMode(Class type) { + serializer.getClass().getName()); } } + + private void verifySchemaMismatch() { + NumericStructV2 numeric = (NumericStructV2) foryReader.deserialize(foryNumericStructBytes); + if (numeric.f1 != numericStruct.f1) { + throw new IllegalStateException("NumericStructV2 schema mismatch read failed"); + } + SampleV2 sampleV2 = (SampleV2) foryReader.deserialize(forySampleBytes); + if (sampleV2.int_value != sample.int_value) { + throw new IllegalStateException("SampleV2 schema mismatch read failed"); + } + MediaContentV2 mediaContentV2 = + (MediaContentV2) foryReader.deserialize(foryMediaContentBytes); + if (mediaContentV2.media.width != mediaContent.media.width + || mediaContentV2.images.isEmpty() + || mediaContentV2.images.get(0).width != mediaContent.images.get(0).width) { + throw new IllegalStateException("MediaContentV2 schema mismatch read failed"); + } + NumericStructListV2 structList = + (NumericStructListV2) foryReader.deserialize(foryNumericStructListBytes); + if (structList.struct_list.isEmpty() + || structList.struct_list.get(0).f1 != numericStructList.struct_list.get(0).f1) { + throw new IllegalStateException("NumericStructListV2 schema mismatch read failed"); + } + SampleListV2 sampleListV2 = (SampleListV2) foryReader.deserialize(forySampleListBytes); + if (sampleListV2.sample_list.isEmpty() + || sampleListV2.sample_list.get(0).int_value != sampleList.sample_list.get(0).int_value) { + throw new IllegalStateException("SampleListV2 schema mismatch read failed"); + } + MediaContentListV2 mediaList = + (MediaContentListV2) foryReader.deserialize(foryMediaContentListBytes); + if (mediaList.media_content_list.isEmpty() + || mediaList.media_content_list.get(0).media.width + != mediaContentList.media_content_list.get(0).media.width + || mediaList.media_content_list.get(0).images.isEmpty() + || mediaList.media_content_list.get(0).images.get(0).width + != mediaContentList.media_content_list.get(0).images.get(0).width) { + throw new IllegalStateException("MediaContentListV2 schema mismatch read failed"); + } + } + } + + private static Fory newBenchmarkFory(boolean codegen) { + return Fory.builder() + .withXlang(true) + .withCompatible(true) + .withCodegen(codegen) + .withRefTracking(false) + .withClassVersionCheck(false) + .requireClassRegistration(true) + .build(); + } + + private static boolean schemaMismatchEnabled() { + return "1".equals(System.getenv(SCHEMA_MISMATCH_ENV)); + } + + private static void rejectNonForySchemaMismatch() { + if (schemaMismatchEnabled()) { + throw new IllegalStateException( + SCHEMA_MISMATCH_ENV + "=1 supports only Fory benchmarks; rerun with --serializer fory"); + } } private static void registerForyTypes(Fory fory) { @@ -195,183 +267,220 @@ private static void registerForyTypes(Fory fory) { fory.register(Size.class, 102); } + private static void registerForyTypesV2(Fory fory) { + fory.register(NumericStructV2.class, 1); + fory.register(SampleV2.class, 2); + fory.register(MediaV2.class, 3); + fory.register(ImageV2.class, 4); + fory.register(MediaContentV2.class, 5); + fory.register(NumericStructListV2.class, 6); + fory.register(SampleListV2.class, 7); + fory.register(MediaContentListV2.class, 8); + fory.register(Player.class, 101); + fory.register(Size.class, 102); + } + @Benchmark public Object BM_Fory_NumericStruct_Serialize(XlangState state) { - return state.fory.serialize(state.numericStruct); + return state.foryWriter.serialize(state.numericStruct); } @Benchmark public Object BM_Fory_NumericStruct_Deserialize(XlangState state) { - return state.fory.deserialize(state.foryNumericStructBytes); + return state.foryReader.deserialize(state.foryNumericStructBytes); } @Benchmark public Object BM_Protobuf_NumericStruct_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toFixedProto(state.numericStruct).toByteArray(); } @Benchmark public Object BM_Protobuf_NumericStruct_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromProtoStruct(state.protobufNumericStructBytes); } @Benchmark public Object BM_Flatbuffer_NumericStruct_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toFlatBuffer(state.numericStruct); } @Benchmark public Object BM_Flatbuffer_NumericStruct_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromFlatBufferNumericStruct(state.flatbufferNumericStructBuffer); } @Benchmark public Object BM_Fory_Sample_Serialize(XlangState state) { - return state.fory.serialize(state.sample); + return state.foryWriter.serialize(state.sample); } @Benchmark public Object BM_Fory_Sample_Deserialize(XlangState state) { - return state.fory.deserialize(state.forySampleBytes); + return state.foryReader.deserialize(state.forySampleBytes); } @Benchmark public Object BM_Protobuf_Sample_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toProto(state.sample).toByteArray(); } @Benchmark public Object BM_Protobuf_Sample_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromProtoSample(state.protobufSampleBytes); } @Benchmark public Object BM_Flatbuffer_Sample_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toFlatBuffer(state.sample); } @Benchmark public Object BM_Flatbuffer_Sample_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromFlatBufferSample(state.flatbufferSampleBuffer); } @Benchmark public Object BM_Fory_MediaContent_Serialize(XlangState state) { - return state.fory.serialize(state.mediaContent); + return state.foryWriter.serialize(state.mediaContent); } @Benchmark public Object BM_Fory_MediaContent_Deserialize(XlangState state) { - return state.fory.deserialize(state.foryMediaContentBytes); + return state.foryReader.deserialize(state.foryMediaContentBytes); } @Benchmark public Object BM_Protobuf_MediaContent_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toProto(state.mediaContent).toByteArray(); } @Benchmark public Object BM_Protobuf_MediaContent_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromProtoMediaContent(state.protobufMediaContentBytes); } @Benchmark public Object BM_Flatbuffer_MediaContent_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toFlatBuffer(state.mediaContent); } @Benchmark public Object BM_Flatbuffer_MediaContent_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromFlatBufferMediaContent(state.flatbufferMediaContentBuffer); } @Benchmark public Object BM_Fory_NumericStructList_Serialize(XlangState state) { - return state.fory.serialize(state.numericStructList); + return state.foryWriter.serialize(state.numericStructList); } @Benchmark public Object BM_Fory_NumericStructList_Deserialize(XlangState state) { - return state.fory.deserialize(state.foryNumericStructListBytes); + return state.foryReader.deserialize(state.foryNumericStructListBytes); } @Benchmark public Object BM_Protobuf_NumericStructList_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toProto(state.numericStructList).toByteArray(); } @Benchmark public Object BM_Protobuf_NumericStructList_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromProtoNumericStructList(state.protobufNumericStructListBytes); } @Benchmark public Object BM_Flatbuffer_NumericStructList_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toFlatBuffer(state.numericStructList); } @Benchmark public Object BM_Flatbuffer_NumericStructList_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromFlatBufferNumericStructList(state.flatbufferNumericStructListBuffer); } @Benchmark public Object BM_Fory_SampleList_Serialize(XlangState state) { - return state.fory.serialize(state.sampleList); + return state.foryWriter.serialize(state.sampleList); } @Benchmark public Object BM_Fory_SampleList_Deserialize(XlangState state) { - return state.fory.deserialize(state.forySampleListBytes); + return state.foryReader.deserialize(state.forySampleListBytes); } @Benchmark public Object BM_Protobuf_SampleList_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toProto(state.sampleList).toByteArray(); } @Benchmark public Object BM_Protobuf_SampleList_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromProtoSampleList(state.protobufSampleListBytes); } @Benchmark public Object BM_Flatbuffer_SampleList_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toFlatBuffer(state.sampleList); } @Benchmark public Object BM_Flatbuffer_SampleList_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromFlatBufferSampleList(state.flatbufferSampleListBuffer); } @Benchmark public Object BM_Fory_MediaContentList_Serialize(XlangState state) { - return state.fory.serialize(state.mediaContentList); + return state.foryWriter.serialize(state.mediaContentList); } @Benchmark public Object BM_Fory_MediaContentList_Deserialize(XlangState state) { - return state.fory.deserialize(state.foryMediaContentListBytes); + return state.foryReader.deserialize(state.foryMediaContentListBytes); } @Benchmark public Object BM_Protobuf_MediaContentList_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toProto(state.mediaContentList).toByteArray(); } @Benchmark public Object BM_Protobuf_MediaContentList_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromProtoMediaContentList(state.protobufMediaContentListBytes); } @Benchmark public Object BM_Flatbuffer_MediaContentList_Serialize(XlangState state) { + rejectNonForySchemaMismatch(); return toFlatBuffer(state.mediaContentList); } @Benchmark public Object BM_Flatbuffer_MediaContentList_Deserialize(XlangState state) { + rejectNonForySchemaMismatch(); return fromFlatBufferMediaContentList(state.flatbufferMediaContentListBuffer); } @@ -1317,4 +1426,196 @@ public static class MediaContentList { @ForyField(id = 1) public List media_content_list; } + + @ForyStruct + public static class NumericStructV2 { + @ForyField(id = 1) + public long f1; + + @ForyField(id = 2) + public @Int32Type(encoding = Int32Encoding.FIXED) int f2; + + @ForyField(id = 3) + public @Int32Type(encoding = Int32Encoding.FIXED) int f3; + + @ForyField(id = 4) + public @Int32Type(encoding = Int32Encoding.FIXED) int f4; + + @ForyField(id = 5) + public @Int32Type(encoding = Int32Encoding.FIXED) int f5; + + @ForyField(id = 6) + public @Int32Type(encoding = Int32Encoding.FIXED) int f6; + + @ForyField(id = 7) + public @Int32Type(encoding = Int32Encoding.FIXED) int f7; + + @ForyField(id = 8) + public @Int32Type(encoding = Int32Encoding.FIXED) int f8; + + @ForyField(id = 9) + public @Int32Type(encoding = Int32Encoding.FIXED) int f9; + + @ForyField(id = 10) + public @Int32Type(encoding = Int32Encoding.FIXED) int f10; + + @ForyField(id = 11) + public @Int32Type(encoding = Int32Encoding.FIXED) int f11; + + @ForyField(id = 12) + public @Int32Type(encoding = Int32Encoding.FIXED) int f12; + } + + @ForyStruct + public static class SampleV2 { + @ForyField(id = 1) + public long int_value; + + @ForyField(id = 2) + public long long_value; + + @ForyField(id = 3) + public float float_value; + + @ForyField(id = 4) + public double double_value; + + @ForyField(id = 5) + public int short_value; + + @ForyField(id = 6) + public int char_value; + + @ForyField(id = 7) + public boolean boolean_value; + + @ForyField(id = 8) + public int int_value_boxed; + + @ForyField(id = 9) + public long long_value_boxed; + + @ForyField(id = 10) + public float float_value_boxed; + + @ForyField(id = 11) + public double double_value_boxed; + + @ForyField(id = 12) + public int short_value_boxed; + + @ForyField(id = 13) + public int char_value_boxed; + + @ForyField(id = 14) + public boolean boolean_value_boxed; + + @ForyField(id = 15) + public int[] int_array; + + @ForyField(id = 16) + public long[] long_array; + + @ForyField(id = 17) + public float[] float_array; + + @ForyField(id = 18) + public double[] double_array; + + @ForyField(id = 19) + public int[] short_array; + + @ForyField(id = 20) + public int[] char_array; + + @ForyField(id = 21) + public boolean[] boolean_array; + + @ForyField(id = 22) + public String string; + } + + @ForyStruct + public static class MediaV2 { + @ForyField(id = 1) + public String uri; + + @ForyField(id = 2) + public String title; + + @ForyField(id = 3) + public long width; + + @ForyField(id = 4) + public int height; + + @ForyField(id = 5) + public String format; + + @ForyField(id = 6) + public long duration; + + @ForyField(id = 7) + public long size; + + @ForyField(id = 8) + public int bitrate; + + @ForyField(id = 9) + public boolean has_bitrate; + + @ForyField(id = 10) + public List persons; + + @ForyField(id = 11) + public Player player; + + @ForyField(id = 12) + public String copyright; + } + + @ForyStruct + public static class ImageV2 { + @ForyField(id = 1) + public String uri; + + @ForyField(id = 2) + public String title; + + @ForyField(id = 3) + public long width; + + @ForyField(id = 4) + public int height; + + @ForyField(id = 5) + public Size size; + } + + @ForyStruct + public static class MediaContentV2 { + @ForyField(id = 1) + public MediaV2 media; + + @ForyField(id = 2) + public List images; + } + + @ForyStruct + public static class NumericStructListV2 { + @ForyField(id = 1) + public List struct_list; + } + + @ForyStruct + public static class SampleListV2 { + @ForyField(id = 1) + public List sample_list; + } + + @ForyStruct + public static class MediaContentListV2 { + @ForyField(id = 1) + public List media_content_list; + } } diff --git a/benchmarks/javascript/README.md b/benchmarks/javascript/README.md index ac64bd36ea..3f65ec383c 100644 --- a/benchmarks/javascript/README.md +++ b/benchmarks/javascript/README.md @@ -43,6 +43,15 @@ Examples: ./run.sh --data sample --serializer protobuf --duration 10 ``` +## Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode. This mode is off by default. When enabled, run with +`--serializer fory`; protobuf and JSON benchmark modes fail with a configuration +error. Fory serialization uses the normal v1 schemas, and Fory deserialization +uses v2 schemas registered with the same Fory type IDs where one int32 field is +widened to int64. + ## Generated Artifacts Running the pipeline writes: diff --git a/benchmarks/javascript/benchmark.js b/benchmarks/javascript/benchmark.js index 1365d80ba9..b2d3edd432 100644 --- a/benchmarks/javascript/benchmark.js +++ b/benchmarks/javascript/benchmark.js @@ -32,6 +32,7 @@ const Fory = core.default; const { BoolArray, Type } = core; const DEFAULT_DURATION_SECONDS = 3; +const SCHEMA_MISMATCH_ENV = "FORY_BENCH_SCHEMA_MISMATCH"; const SERIALIZER_ORDER = ["fory", "protobuf", "json"]; const DATA_ORDER = [ "struct", @@ -239,6 +240,83 @@ function createSchemas() { }; } +function createSchemasV2() { + return { + NumericStruct: Type.struct(1, { + f1: int64Field(1), + f2: int32Field(2), + f3: int32Field(3), + f4: int32Field(4), + f5: int32Field(5), + f6: int32Field(6), + f7: int32Field(7), + f8: int32Field(8), + f9: int32Field(9), + f10: int32Field(10), + f11: int32Field(11), + f12: int32Field(12), + }), + Sample: Type.struct(2, { + int_value: int64Field(1), + long_value: int64Field(2), + float_value: float32Field(3), + double_value: float64Field(4), + short_value: int32Field(5), + char_value: int32Field(6), + boolean_value: boolField(7), + int_value_boxed: int32Field(8), + long_value_boxed: int64Field(9), + float_value_boxed: float32Field(10), + double_value_boxed: float64Field(11), + short_value_boxed: int32Field(12), + char_value_boxed: int32Field(13), + boolean_value_boxed: boolField(14), + int_array: int32ArrayField(15), + long_array: int64ArrayField(16), + float_array: float32ArrayField(17), + double_array: float64ArrayField(18), + short_array: int32ArrayField(19), + char_array: int32ArrayField(20), + boolean_array: boolArrayField(21), + string: stringField(22), + }), + Media: Type.struct(3, { + uri: stringField(1), + title: stringField(2), + width: int64Field(3), + height: int32Field(4), + format: stringField(5), + duration: int64Field(6), + size: int64Field(7), + bitrate: int32Field(8), + has_bitrate: boolField(9), + persons: listField(10, Type.string()), + player: enumField(11, 101, PLAYER_ENUM), + copyright: stringField(12), + }), + Image: Type.struct(4, { + uri: stringField(1), + title: stringField(2), + width: int64Field(3), + height: int32Field(4), + size: enumField(5, 102, SIZE_ENUM), + }), + MediaContent: Type.struct(5, { + media: structField(1, 3), + images: listField(2, Type.struct(4)), + }), + NumericStructList: Type.struct(6, { + struct_list: listField(1, Type.struct(1)), + }), + SampleList: Type.struct(7, { + sample_list: listField(1, Type.struct(2)), + }), + MediaContentList: Type.struct(8, { + media_content_list: listField(1, Type.struct(5)), + }), + }; +} + function createNumericStruct() { return { f1: -12345, @@ -526,13 +604,8 @@ function fromProtoMediaContentList(value) { }; } -function createForyBenchmarks() { - const fory = new Fory({ - compatible: true, - ref: false, - }); - const schemas = createSchemas(); - const serializers = { +function registerForySchemas(fory, schemas) { + return { struct: fory.register(schemas.NumericStruct), sample: fory.register(schemas.Sample), media: fory.register(schemas.Media), @@ -542,10 +615,28 @@ function createForyBenchmarks() { samplelist: fory.register(schemas.SampleList), mediacontentlist: fory.register(schemas.MediaContentList), }; - return { fory, serializers }; } -function createDatasets(root) { +function createForyBenchmarks(schemaMismatch) { + const fory = new Fory({ + compatible: true, + ref: false, + }); + const writerSerializers = registerForySchemas(fory, createSchemas()); + if (!schemaMismatch) { + return { writerSerializers, readerSerializers: writerSerializers }; + } + const reader = new Fory({ + compatible: true, + ref: false, + }); + return { + writerSerializers, + readerSerializers: registerForySchemas(reader, createSchemasV2()), + }; +} + +function createDatasets(root, schemaMismatch) { const StructType = root.lookupType("protobuf.NumericStruct"); const SampleType = root.lookupType("protobuf.Sample"); const MediaContentType = root.lookupType("protobuf.MediaContent"); @@ -553,7 +644,8 @@ function createDatasets(root) { const SampleListType = root.lookupType("protobuf.SampleList"); const MediaContentListType = root.lookupType("protobuf.MediaContentList"); - const { serializers } = createForyBenchmarks(); + const { writerSerializers, readerSerializers } = + createForyBenchmarks(schemaMismatch); return [ { @@ -563,7 +655,8 @@ function createDatasets(root) { toProto: toProtoStruct, fromProto: fromProtoStruct, protoType: StructType, - forySerializer: serializers.struct, + foryWriter: writerSerializers.struct, + foryReader: readerSerializers.struct, sizeKey: "struct", }, { @@ -573,7 +666,8 @@ function createDatasets(root) { toProto: toProtoSample, fromProto: fromProtoSample, protoType: SampleType, - forySerializer: serializers.sample, + foryWriter: writerSerializers.sample, + foryReader: readerSerializers.sample, sizeKey: "sample", }, { @@ -583,7 +677,8 @@ function createDatasets(root) { toProto: toProtoMediaContent, fromProto: fromProtoMediaContent, protoType: MediaContentType, - forySerializer: serializers.mediacontent, + foryWriter: writerSerializers.mediacontent, + foryReader: readerSerializers.mediacontent, sizeKey: "media", }, { @@ -593,7 +688,8 @@ function createDatasets(root) { toProto: toProtoNumericStructList, fromProto: fromProtoNumericStructList, protoType: StructListType, - forySerializer: serializers.structlist, + foryWriter: writerSerializers.structlist, + foryReader: readerSerializers.structlist, sizeKey: "struct_list", }, { @@ -603,7 +699,8 @@ function createDatasets(root) { toProto: toProtoSampleList, fromProto: fromProtoSampleList, protoType: SampleListType, - forySerializer: serializers.samplelist, + foryWriter: writerSerializers.samplelist, + foryReader: readerSerializers.samplelist, sizeKey: "sample_list", }, { @@ -613,12 +710,28 @@ function createDatasets(root) { toProto: toProtoMediaContentList, fromProto: fromProtoMediaContentList, protoType: MediaContentListType, - forySerializer: serializers.mediacontentlist, + foryWriter: writerSerializers.mediacontentlist, + foryReader: readerSerializers.mediacontentlist, sizeKey: "media_list", }, ]; } +function schemaMismatchEnabled() { + return process.env[SCHEMA_MISMATCH_ENV] === "1"; +} + +function validateSchemaMismatchSelection(options, schemaMismatch) { + if (!schemaMismatch) { + return; + } + if (options.serializer !== "fory") { + throw new Error( + `${SCHEMA_MISMATCH_ENV}=1 supports only Fory benchmarks; rerun with --serializer fory` + ); + } +} + function decodeProtoObject(protoType, bytes) { const message = protoType.decode(bytes); return protoType.toObject(message, { @@ -717,8 +830,8 @@ function normalizeProtobufValue(datasetKey, value) { function ensureSerializationWorks(dataset) { const value = dataset.createValue(); const foryValue = normalizeForyValue(dataset.key, value); - const foryBytes = dataset.forySerializer.serialize(foryValue); - const foryRoundTrip = dataset.forySerializer.deserialize(foryBytes); + const foryBytes = dataset.foryWriter.serialize(foryValue); + const foryRoundTrip = dataset.foryReader.deserialize(foryBytes); assert.deepStrictEqual( normalizeForyRoundTripValue(dataset.key, foryRoundTrip), foryValue @@ -734,10 +847,49 @@ function ensureSerializationWorks(dataset) { assert.deepStrictEqual(jsonRoundTrip, value); } +function verifySchemaMismatch(dataset) { + const value = dataset.createValue(); + const foryValue = normalizeForyValue(dataset.key, value); + const decoded = dataset.foryReader.deserialize(dataset.foryWriter.serialize(foryValue)); + switch (dataset.key) { + case "struct": + assert.equal(decoded.f1, BigInt(value.f1)); + break; + case "sample": + assert.equal(decoded.int_value, BigInt(value.int_value)); + break; + case "mediacontent": + assert.equal(decoded.media.width, BigInt(value.media.width)); + assert.equal(decoded.images[0].width, BigInt(value.images[0].width)); + break; + case "structlist": + assert.equal(decoded.struct_list[0].f1, BigInt(value.struct_list[0].f1)); + break; + case "samplelist": + assert.equal( + decoded.sample_list[0].int_value, + BigInt(value.sample_list[0].int_value) + ); + break; + case "mediacontentlist": + assert.equal( + decoded.media_content_list[0].media.width, + BigInt(value.media_content_list[0].media.width) + ); + assert.equal( + decoded.media_content_list[0].images[0].width, + BigInt(value.media_content_list[0].images[0].width) + ); + break; + default: + throw new Error(`Unknown dataset ${dataset.key}`); + } +} + function serializeBytes(serializerName, dataset, value) { switch (serializerName) { case "fory": - return dataset.forySerializer.serialize(normalizeForyValue(dataset.key, value)); + return dataset.foryWriter.serialize(normalizeForyValue(dataset.key, value)); case "protobuf": return dataset.protoType.encode(dataset.toProto(value)).finish(); case "json": @@ -754,13 +906,13 @@ function createBenchmarkCase(serializerName, dataset, operation) { const foryValue = normalizeForyValue(dataset.key, value); if (operation === "Serialize") { return () => { - const bytes = dataset.forySerializer.serialize(foryValue); + const bytes = dataset.foryWriter.serialize(foryValue); blackhole ^= bytes.length; }; } - const bytes = dataset.forySerializer.serialize(foryValue); + const bytes = dataset.foryWriter.serialize(foryValue); return () => { - const decoded = dataset.forySerializer.deserialize(bytes); + const decoded = dataset.foryReader.deserialize(bytes); blackhole ^= Array.isArray(decoded) ? decoded.length : 1; }; } @@ -858,9 +1010,10 @@ function buildResults(datasets, options) { const sizeCounters = { name: "BM_PrintSerializedSizes", }; + const sizeSerializers = options.schemaMismatch ? ["fory"] : SERIALIZER_ORDER; for (const dataset of datasets) { const value = dataset.createValue(); - for (const serializerName of SERIALIZER_ORDER) { + for (const serializerName of sizeSerializers) { const bytes = serializeBytes(serializerName, dataset, value); sizeCounters[`${serializerName}_${dataset.sizeKey}_size`] = bytes.length; } @@ -871,9 +1024,16 @@ function buildResults(datasets, options) { function main() { const options = parseArgs(process.argv.slice(2)); + options.schemaMismatch = schemaMismatchEnabled(); + validateSchemaMismatchSelection(options, options.schemaMismatch); const root = protobuf.loadSync(path.join(REPO_ROOT, "benchmarks", "proto", "bench.proto")); - const datasets = createDatasets(root); - datasets.forEach(ensureSerializationWorks); + const datasets = createDatasets(root, options.schemaMismatch); + const verificationDatasets = options.data + ? datasets.filter((dataset) => dataset.key === options.data) + : datasets; + verificationDatasets.forEach( + options.schemaMismatch ? verifySchemaMismatch : ensureSerializationWorks + ); const structSize = serializeBytes("fory", datasets.find((item) => item.key === "struct"), createNumericStruct()).length; console.log(`Fory NumericStruct serialized size: ${structSize} bytes`); @@ -887,6 +1047,7 @@ function main() { node_version: process.version, v8_version: process.versions.v8, duration_seconds: options.durationSeconds, + schema_mismatch: options.schemaMismatch, }, benchmarks: buildResults(datasets, options), }; diff --git a/benchmarks/javascript/run.sh b/benchmarks/javascript/run.sh index 1bf4cb2502..87ba74a858 100755 --- a/benchmarks/javascript/run.sh +++ b/benchmarks/javascript/run.sh @@ -73,6 +73,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "${SERIALIZER}" != "fory" ]]; then + echo -e "${RED}FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory.${NC}" + exit 1 +fi + echo -e "${GREEN}=== Fory JavaScript Benchmark ===${NC}" echo "" diff --git a/benchmarks/python/README.md b/benchmarks/python/README.md index 290955ec69..a17150d252 100644 --- a/benchmarks/python/README.md +++ b/benchmarks/python/README.md @@ -45,6 +45,15 @@ Supported values: - `--serializer`: `fory,protobuf,pickle` - `--operation`: `all|serialize|deserialize` +## Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode. This mode is off by default. When enabled, run with +`--serializer fory`; protobuf and pickle benchmark modes fail with a +configuration error. Fory serialization uses the normal v1 benchmark dataclasses, +and Fory deserialization uses v2 dataclasses registered with the same Fory type +IDs where one int32 field is widened to int64. + ## CPython Microbenchmark `fory_benchmark.py` can be used directly: diff --git a/benchmarks/python/benchmark.py b/benchmarks/python/benchmark.py index 2ada05e2d9..31345a094e 100755 --- a/benchmarks/python/benchmark.py +++ b/benchmarks/python/benchmark.py @@ -46,6 +46,7 @@ LIST_SIZE = 5 +SCHEMA_MISMATCH_ENV = "FORY_BENCH_SCHEMA_MISMATCH" DATA_TYPE_ORDER = [ "struct", "sample", @@ -118,6 +119,22 @@ class NumericStruct: f12: pyfory.Int32 = pyfory.field(id=12) +@dataclass +class NumericStructV2: + f1: pyfory.Int64 = pyfory.field(id=1) + f2: pyfory.Int32 = pyfory.field(id=2) + f3: pyfory.Int32 = pyfory.field(id=3) + f4: pyfory.Int32 = pyfory.field(id=4) + f5: pyfory.Int32 = pyfory.field(id=5) + f6: pyfory.Int32 = pyfory.field(id=6) + f7: pyfory.Int32 = pyfory.field(id=7) + f8: pyfory.Int32 = pyfory.field(id=8) + f9: pyfory.Int32 = pyfory.field(id=9) + f10: pyfory.Int32 = pyfory.field(id=10) + f11: pyfory.Int32 = pyfory.field(id=11) + f12: pyfory.Int32 = pyfory.field(id=12) + + @dataclass class Sample: int_value: pyfory.Int32 = pyfory.field(id=1) @@ -144,6 +161,32 @@ class Sample: string: str = pyfory.field(id=22) +@dataclass +class SampleV2: + int_value: pyfory.Int64 = pyfory.field(id=1) + long_value: pyfory.Int64 = pyfory.field(id=2) + float_value: pyfory.Float32 = pyfory.field(id=3) + double_value: pyfory.Float64 = pyfory.field(id=4) + short_value: pyfory.Int32 = pyfory.field(id=5) + char_value: pyfory.Int32 = pyfory.field(id=6) + boolean_value: bool = pyfory.field(id=7) + int_value_boxed: pyfory.Int32 = pyfory.field(id=8) + long_value_boxed: pyfory.Int64 = pyfory.field(id=9) + float_value_boxed: pyfory.Float32 = pyfory.field(id=10) + double_value_boxed: pyfory.Float64 = pyfory.field(id=11) + short_value_boxed: pyfory.Int32 = pyfory.field(id=12) + char_value_boxed: pyfory.Int32 = pyfory.field(id=13) + boolean_value_boxed: bool = pyfory.field(id=14) + int_array: pyfory.NDArray[pyfory.Int32] = pyfory.field(id=15) + long_array: pyfory.NDArray[pyfory.Int64] = pyfory.field(id=16) + float_array: pyfory.NDArray[pyfory.Float32] = pyfory.field(id=17) + double_array: pyfory.NDArray[pyfory.Float64] = pyfory.field(id=18) + short_array: pyfory.NDArray[pyfory.Int32] = pyfory.field(id=19) + char_array: pyfory.NDArray[pyfory.Int32] = pyfory.field(id=20) + boolean_array: pyfory.NDArray[bool] = pyfory.field(id=21) + string: str = pyfory.field(id=22) + + @dataclass class Media: uri: str = pyfory.field(id=1) @@ -160,6 +203,22 @@ class Media: copyright: str = pyfory.field(id=12) +@dataclass +class MediaV2: + uri: str = pyfory.field(id=1) + title: str = pyfory.field(id=2) + width: pyfory.Int64 = pyfory.field(id=3) + height: pyfory.Int32 = pyfory.field(id=4) + format: str = pyfory.field(id=5) + duration: pyfory.Int64 = pyfory.field(id=6) + size: pyfory.Int64 = pyfory.field(id=7) + bitrate: pyfory.Int32 = pyfory.field(id=8) + has_bitrate: bool = pyfory.field(id=9) + persons: List[str] = pyfory.field(id=10) + player: Player = pyfory.field(id=11) + copyright: str = pyfory.field(id=12) + + @dataclass class Image: uri: str = pyfory.field(id=1) @@ -169,27 +228,57 @@ class Image: size: Size = pyfory.field(id=5) +@dataclass +class ImageV2: + uri: str = pyfory.field(id=1) + title: str = pyfory.field(id=2) + width: pyfory.Int64 = pyfory.field(id=3) + height: pyfory.Int32 = pyfory.field(id=4) + size: Size = pyfory.field(id=5) + + @dataclass class MediaContent: media: Media = pyfory.field(id=1) images: List[Image] = pyfory.field(id=2) +@dataclass +class MediaContentV2: + media: MediaV2 = pyfory.field(id=1) + images: List[ImageV2] = pyfory.field(id=2) + + @dataclass class NumericStructList: struct_list: List[NumericStruct] = pyfory.field(id=1) +@dataclass +class NumericStructListV2: + struct_list: List[NumericStructV2] = pyfory.field(id=1) + + @dataclass class SampleList: sample_list: List[Sample] = pyfory.field(id=1) +@dataclass +class SampleListV2: + sample_list: List[SampleV2] = pyfory.field(id=1) + + @dataclass class MediaContentList: media_content_list: List[MediaContent] = pyfory.field(id=1) +@dataclass +class MediaContentListV2: + media_content_list: List[MediaContentV2] = pyfory.field(id=1) + + def create_numeric_struct() -> NumericStruct: return NumericStruct( f1=-12345, @@ -541,6 +630,10 @@ def from_pb_mediacontentlist(pb_obj) -> MediaContentList: def build_fory() -> pyfory.Fory: + return build_fory_v1() + + +def build_fory_v1() -> pyfory.Fory: fory = pyfory.Fory(xlang=True, compatible=True, ref=False) fory.register_type(Player, type_id=101) fory.register_type(Size, type_id=102) @@ -555,6 +648,102 @@ def build_fory() -> pyfory.Fory: return fory +def build_fory_v2() -> pyfory.Fory: + fory = pyfory.Fory(xlang=True, compatible=True, ref=False) + fory.register_type(Player, type_id=101) + fory.register_type(Size, type_id=102) + fory.register_type(NumericStructV2, type_id=1) + fory.register_type(SampleV2, type_id=2) + fory.register_type(MediaV2, type_id=3) + fory.register_type(ImageV2, type_id=4) + fory.register_type(MediaContentV2, type_id=5) + fory.register_type(NumericStructListV2, type_id=6) + fory.register_type(SampleListV2, type_id=7) + fory.register_type(MediaContentListV2, type_id=8) + return fory + + +def schema_mismatch_enabled() -> bool: + return os.getenv(SCHEMA_MISMATCH_ENV) == "1" + + +def validate_schema_mismatch_selection(selected_serializers: List[str]) -> None: + if schema_mismatch_enabled() and selected_serializers != ["fory"]: + raise ValueError( + f"{SCHEMA_MISMATCH_ENV}=1 supports only Fory benchmarks; " + "rerun with --serializer fory" + ) + + +def verify_schema_mismatch(datatype: str, decoded: Any, expected: Any) -> None: + if datatype == "struct": + if not isinstance(decoded, NumericStructV2) or decoded.f1 != expected.f1: + raise AssertionError("NumericStructV2 schema mismatch read failed") + return + if datatype == "sample": + if not isinstance(decoded, SampleV2) or decoded.int_value != expected.int_value: + raise AssertionError("SampleV2 schema mismatch read failed") + return + if datatype == "mediacontent": + if ( + not isinstance(decoded, MediaContentV2) + or not isinstance(decoded.media, MediaV2) + or decoded.media.width != expected.media.width + or not decoded.images + or not isinstance(decoded.images[0], ImageV2) + or decoded.images[0].width != expected.images[0].width + ): + raise AssertionError("MediaContentV2 schema mismatch read failed") + return + if datatype == "structlist": + if ( + not isinstance(decoded, NumericStructListV2) + or not decoded.struct_list + or not isinstance(decoded.struct_list[0], NumericStructV2) + or decoded.struct_list[0].f1 != expected.struct_list[0].f1 + ): + raise AssertionError("NumericStructListV2 schema mismatch read failed") + return + if datatype == "samplelist": + if ( + not isinstance(decoded, SampleListV2) + or not decoded.sample_list + or not isinstance(decoded.sample_list[0], SampleV2) + or decoded.sample_list[0].int_value != expected.sample_list[0].int_value + ): + raise AssertionError("SampleListV2 schema mismatch read failed") + return + if datatype == "mediacontentlist": + if ( + not isinstance(decoded, MediaContentListV2) + or not decoded.media_content_list + or not isinstance(decoded.media_content_list[0], MediaContentV2) + or not isinstance(decoded.media_content_list[0].media, MediaV2) + or decoded.media_content_list[0].media.width + != expected.media_content_list[0].media.width + or not decoded.media_content_list[0].images + or not isinstance(decoded.media_content_list[0].images[0], ImageV2) + or decoded.media_content_list[0].images[0].width + != expected.media_content_list[0].images[0].width + ): + raise AssertionError("MediaContentListV2 schema mismatch read failed") + return + raise AssertionError(f"Unknown datatype for schema mismatch: {datatype}") + + +def verify_fory_schema_mismatch( + benchmark_data: Dict[str, Any], + selected_datatypes: Iterable[str], + *, + writer: pyfory.Fory, + reader: pyfory.Fory, +) -> None: + for datatype in selected_datatypes: + value = benchmark_data[datatype] + decoded = reader.deserialize(writer.serialize(value)) + verify_schema_mismatch(datatype, decoded, value) + + def run_benchmark( func: Callable[..., Any], args: Tuple[Any, ...], @@ -629,13 +818,14 @@ def build_case( datatype: str, obj: Any, *, - fory: pyfory.Fory, + fory_writer: pyfory.Fory, + fory_reader: pyfory.Fory, bench_pb2, ) -> Tuple[Callable[..., Any], Tuple[Any, ...]]: if serializer == "fory": if operation == "serialize": - return fory_serialize, (fory, obj) - return fory_deserialize, (fory, fory.serialize(obj)) + return fory_serialize, (fory_writer, obj) + return fory_deserialize, (fory_reader, fory_writer.serialize(obj)) if serializer == "pickle": if operation == "serialize": @@ -660,19 +850,27 @@ def calculate_serialized_sizes( *, fory: pyfory.Fory, bench_pb2, + selected_serializers: Iterable[str], + schema_mismatch: bool, ) -> Dict[str, Dict[str, int]]: sizes: Dict[str, Dict[str, int]] = {} + serializer_names = ( + list(selected_serializers) if schema_mismatch else SERIALIZER_ORDER + ) for datatype in selected_datatypes: obj = benchmark_data[datatype] datatype_sizes: Dict[str, int] = {} - datatype_sizes["fory"] = len(fory.serialize(obj)) - datatype_sizes["pickle"] = len( - pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) - ) + if "fory" in serializer_names: + datatype_sizes["fory"] = len(fory.serialize(obj)) + if "pickle" in serializer_names: + datatype_sizes["pickle"] = len( + pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + ) - to_pb, _, _ = PROTO_CONVERTERS[datatype] - datatype_sizes["protobuf"] = len(to_pb(bench_pb2, obj).SerializeToString()) + if "protobuf" in serializer_names: + to_pb, _, _ = PROTO_CONVERTERS[datatype] + datatype_sizes["protobuf"] = len(to_pb(bench_pb2, obj).SerializeToString()) sizes[datatype] = datatype_sizes return sizes @@ -770,12 +968,22 @@ def main() -> int: selected_serializers = parse_csv_list( args.serializer, SERIALIZER_ORDER, SERIALIZER_ORDER ) + validate_schema_mismatch_selection(selected_serializers) selected_operations = ( OPERATION_ORDER if args.operation == "all" else [args.operation] ) benchmark_data = create_benchmark_data() - fory = build_fory() + mismatch = schema_mismatch_enabled() + fory_writer = build_fory_v1() + fory_reader = build_fory_v2() if mismatch else build_fory_v1() + if mismatch: + verify_fory_schema_mismatch( + benchmark_data, + selected_datatypes, + writer=fory_writer, + reader=fory_reader, + ) print( f"Benchmarking {len(selected_datatypes)} data type(s), {len(selected_serializers)} serializer(s), {len(selected_operations)} operation(s)" @@ -800,7 +1008,8 @@ def main() -> int: operation, datatype, obj, - fory=fory, + fory_writer=fory_writer, + fory_reader=fory_reader, bench_pb2=bench_pb2, ) mean, stdev = run_benchmark( @@ -830,8 +1039,10 @@ def main() -> int: sizes = calculate_serialized_sizes( benchmark_data, selected_datatypes, - fory=fory, + fory=fory_writer, bench_pb2=bench_pb2, + selected_serializers=selected_serializers, + schema_mismatch=mismatch, ) output_path = Path(args.output_json) @@ -853,6 +1064,7 @@ def main() -> int: "datatypes": selected_datatypes, "serializers": selected_serializers, "list_size": LIST_SIZE, + "schema_mismatch": mismatch, }, "benchmarks": results, "sizes": sizes, diff --git a/benchmarks/python/run.sh b/benchmarks/python/run.sh index ef0bb1b482..cd1d938709 100755 --- a/benchmarks/python/run.sh +++ b/benchmarks/python/run.sh @@ -107,6 +107,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "$SERIALIZER" != "fory" ]]; then + echo "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory." + exit 1 +fi + if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then echo "Error: $PYTHON_BIN is not available" exit 1 diff --git a/benchmarks/rust/README.md b/benchmarks/rust/README.md index ca2ef95f90..354dd231cc 100644 --- a/benchmarks/rust/README.md +++ b/benchmarks/rust/README.md @@ -38,6 +38,15 @@ Examples: ./run.sh --data sample,mediacontent --serializer protobuf ``` +## Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode. This mode is off by default. When enabled, run with +`--serializer fory`; protobuf and MessagePack benchmark modes fail with a +configuration error. Fory serialization uses the normal v1 benchmark structs, +and Fory deserialization uses v2 structs registered with the same Fory type IDs +where one int32 field is widened to int64. + ## Benchmark Cases | Benchmark case | Description | diff --git a/benchmarks/rust/benchmark_report.py b/benchmarks/rust/benchmark_report.py index 1400f30f41..e3c9c00d5f 100644 --- a/benchmarks/rust/benchmark_report.py +++ b/benchmarks/rust/benchmark_report.py @@ -176,7 +176,10 @@ def load_serialized_sizes(size_file): if not os.path.exists(size_file): return {} - pattern = re.compile(r"^\|\s*([^|]+?)\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|$") + pattern = re.compile( + r"^\|\s*([^|]+?)\s*\|\s*(\d+|n/a)\s*\|\s*(\d+|n/a)\s*\|\s*(\d+|n/a)\s*\|$", + re.IGNORECASE, + ) sizes = {} with open(size_file, "r", encoding="utf-8") as file: for line in file: @@ -187,13 +190,21 @@ def load_serialized_sizes(size_file): if datatype == "Datatype": continue sizes[datatype] = { - "fory": int(fory_size), - "protobuf": int(protobuf_size), - "msgpack": int(msgpack_size), + "fory": parse_size_value(fory_size), + "protobuf": parse_size_value(protobuf_size), + "msgpack": parse_size_value(msgpack_size), } return sizes +def parse_size_value(value): + return None if value.lower() == "n/a" else int(value) + + +def format_size_value(value): + return str(value) if isinstance(value, int) and value >= 0 else "N/A" + + def format_tps_tick(tps, _position): return format_throughput_tick(tps, _position) @@ -358,7 +369,9 @@ def write_report(system_info, results, sizes, output_dir, plot_prefix): continue entry = sizes[title] report.append( - f"| {title} | {entry['fory']} | {entry['protobuf']} | {entry['msgpack']} |\n" + f"| {title} | {format_size_value(entry['fory'])} | " + + f"{format_size_value(entry['protobuf'])} | " + + f"{format_size_value(entry['msgpack'])} |\n" ) report_path = os.path.join(output_dir, "README.md") diff --git a/benchmarks/rust/run.sh b/benchmarks/rust/run.sh index 109d6d10eb..90bf08344b 100755 --- a/benchmarks/rust/run.sh +++ b/benchmarks/rust/run.sh @@ -87,6 +87,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "$SERIALIZER_FILTER" != "fory" ]]; then + echo "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory." + exit 1 +fi + normalize_data_filter() { local input="$1" if [[ -z "$input" || "$input" == "all" ]]; then diff --git a/benchmarks/rust/src/data.rs b/benchmarks/rust/src/data.rs index fa5292dbdd..7b75613359 100644 --- a/benchmarks/rust/src/data.rs +++ b/benchmarks/rust/src/data.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use fory::{Error, Fory}; +use fory::{Error, Fory, ForyDefault, Serializer as ForyValueSerializer}; use fory_derive::{ForyEnum, ForyStruct}; use serde::{Deserialize, Serialize}; @@ -70,6 +70,12 @@ pub trait BenchmarkCase: Clone + PartialEq + Sized + 'static { fn create() -> Self; } +pub trait SchemaMismatchCase: BenchmarkCase { + type Read: ForyValueSerializer + ForyDefault + 'static; + + fn verify_mismatch(decoded: &Self::Read, expected: &Self); +} + #[derive(Debug, Clone, PartialEq, Eq, ForyStruct, Serialize, Deserialize)] pub struct NumericStruct { #[fory(id = 1)] @@ -98,6 +104,34 @@ pub struct NumericStruct { pub f12: i32, } +#[derive(Debug, Clone, PartialEq, Eq, ForyStruct)] +pub struct NumericStructV2 { + #[fory(id = 1)] + pub f1: i64, + #[fory(id = 2)] + pub f2: i32, + #[fory(id = 3)] + pub f3: i32, + #[fory(id = 4)] + pub f4: i32, + #[fory(id = 5)] + pub f5: i32, + #[fory(id = 6)] + pub f6: i32, + #[fory(id = 7)] + pub f7: i32, + #[fory(id = 8)] + pub f8: i32, + #[fory(id = 9)] + pub f9: i32, + #[fory(id = 10)] + pub f10: i32, + #[fory(id = 11)] + pub f11: i32, + #[fory(id = 12)] + pub f12: i32, +} + #[derive(Debug, Clone, PartialEq, ForyStruct, Serialize, Deserialize)] pub struct Sample { #[fory(id = 1)] @@ -146,6 +180,54 @@ pub struct Sample { pub string: String, } +#[derive(Debug, Clone, PartialEq, ForyStruct)] +pub struct SampleV2 { + #[fory(id = 1)] + pub int_value: i64, + #[fory(id = 2)] + pub long_value: i64, + #[fory(id = 3)] + pub float_value: f32, + #[fory(id = 4)] + pub double_value: f64, + #[fory(id = 5)] + pub short_value: i32, + #[fory(id = 6)] + pub char_value: i32, + #[fory(id = 7)] + pub boolean_value: bool, + #[fory(id = 8)] + pub int_value_boxed: i32, + #[fory(id = 9)] + pub long_value_boxed: i64, + #[fory(id = 10)] + pub float_value_boxed: f32, + #[fory(id = 11)] + pub double_value_boxed: f64, + #[fory(id = 12)] + pub short_value_boxed: i32, + #[fory(id = 13)] + pub char_value_boxed: i32, + #[fory(id = 14)] + pub boolean_value_boxed: bool, + #[fory(id = 15, array)] + pub int_array: Vec, + #[fory(id = 16, array)] + pub long_array: Vec, + #[fory(id = 17, array)] + pub float_array: Vec, + #[fory(id = 18, array)] + pub double_array: Vec, + #[fory(id = 19, array)] + pub short_array: Vec, + #[fory(id = 20, array)] + pub char_array: Vec, + #[fory(id = 21, array)] + pub boolean_array: Vec, + #[fory(id = 22)] + pub string: String, +} + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ForyEnum, Serialize, Deserialize)] #[repr(i32)] pub enum Player { @@ -190,6 +272,34 @@ pub struct Media { pub copyright: String, } +#[derive(Debug, Clone, PartialEq, Eq, ForyStruct)] +pub struct MediaV2 { + #[fory(id = 1)] + pub uri: String, + #[fory(id = 2)] + pub title: String, + #[fory(id = 3)] + pub width: i64, + #[fory(id = 4)] + pub height: i32, + #[fory(id = 5)] + pub format: String, + #[fory(id = 6)] + pub duration: i64, + #[fory(id = 7)] + pub size: i64, + #[fory(id = 8)] + pub bitrate: i32, + #[fory(id = 9)] + pub has_bitrate: bool, + #[fory(id = 10)] + pub persons: Vec, + #[fory(id = 11)] + pub player: Player, + #[fory(id = 12)] + pub copyright: String, +} + #[derive(Debug, Clone, PartialEq, Eq, ForyStruct, Serialize, Deserialize)] pub struct Image { #[fory(id = 1)] @@ -204,6 +314,20 @@ pub struct Image { pub size: Size, } +#[derive(Debug, Clone, PartialEq, Eq, ForyStruct)] +pub struct ImageV2 { + #[fory(id = 1)] + pub uri: String, + #[fory(id = 2)] + pub title: String, + #[fory(id = 3)] + pub width: i64, + #[fory(id = 4)] + pub height: i32, + #[fory(id = 5)] + pub size: Size, +} + #[derive(Debug, Clone, PartialEq, Eq, ForyStruct, Serialize, Deserialize)] pub struct MediaContent { #[fory(id = 1)] @@ -212,24 +336,50 @@ pub struct MediaContent { pub images: Vec, } +#[derive(Debug, Clone, PartialEq, Eq, ForyStruct)] +pub struct MediaContentV2 { + #[fory(id = 1)] + pub media: MediaV2, + #[fory(id = 2)] + pub images: Vec, +} + #[derive(Debug, Clone, PartialEq, Eq, ForyStruct, Serialize, Deserialize)] pub struct NumericStructList { #[fory(id = 1)] pub struct_list: Vec, } +#[derive(Debug, Clone, PartialEq, Eq, ForyStruct)] +pub struct NumericStructListV2 { + #[fory(id = 1)] + pub struct_list: Vec, +} + #[derive(Debug, Clone, PartialEq, ForyStruct, Serialize, Deserialize)] pub struct SampleList { #[fory(id = 1)] pub sample_list: Vec, } +#[derive(Debug, Clone, PartialEq, ForyStruct)] +pub struct SampleListV2 { + #[fory(id = 1)] + pub sample_list: Vec, +} + #[derive(Debug, Clone, PartialEq, Eq, ForyStruct, Serialize, Deserialize)] pub struct MediaContentList { #[fory(id = 1)] pub media_content_list: Vec, } +#[derive(Debug, Clone, PartialEq, Eq, ForyStruct)] +pub struct MediaContentListV2 { + #[fory(id = 1)] + pub media_content_list: Vec, +} + impl BenchmarkCase for NumericStruct { const KIND: DataKind = DataKind::Struct; @@ -238,6 +388,14 @@ impl BenchmarkCase for NumericStruct { } } +impl SchemaMismatchCase for NumericStruct { + type Read = NumericStructV2; + + fn verify_mismatch(decoded: &Self::Read, expected: &Self) { + assert_eq!(decoded.f1, i64::from(expected.f1)); + } +} + impl BenchmarkCase for Sample { const KIND: DataKind = DataKind::Sample; @@ -246,6 +404,14 @@ impl BenchmarkCase for Sample { } } +impl SchemaMismatchCase for Sample { + type Read = SampleV2; + + fn verify_mismatch(decoded: &Self::Read, expected: &Self) { + assert_eq!(decoded.int_value, i64::from(expected.int_value)); + } +} + impl BenchmarkCase for MediaContent { const KIND: DataKind = DataKind::MediaContent; @@ -254,6 +420,15 @@ impl BenchmarkCase for MediaContent { } } +impl SchemaMismatchCase for MediaContent { + type Read = MediaContentV2; + + fn verify_mismatch(decoded: &Self::Read, expected: &Self) { + assert_eq!(decoded.media.width, i64::from(expected.media.width)); + assert_eq!(decoded.images[0].width, i64::from(expected.images[0].width)); + } +} + impl BenchmarkCase for NumericStructList { const KIND: DataKind = DataKind::NumericStructList; @@ -264,6 +439,17 @@ impl BenchmarkCase for NumericStructList { } } +impl SchemaMismatchCase for NumericStructList { + type Read = NumericStructListV2; + + fn verify_mismatch(decoded: &Self::Read, expected: &Self) { + assert_eq!( + decoded.struct_list[0].f1, + i64::from(expected.struct_list[0].f1) + ); + } +} + impl BenchmarkCase for SampleList { const KIND: DataKind = DataKind::SampleList; @@ -274,6 +460,17 @@ impl BenchmarkCase for SampleList { } } +impl SchemaMismatchCase for SampleList { + type Read = SampleListV2; + + fn verify_mismatch(decoded: &Self::Read, expected: &Self) { + assert_eq!( + decoded.sample_list[0].int_value, + i64::from(expected.sample_list[0].int_value) + ); + } +} + impl BenchmarkCase for MediaContentList { const KIND: DataKind = DataKind::MediaContentList; @@ -284,6 +481,21 @@ impl BenchmarkCase for MediaContentList { } } +impl SchemaMismatchCase for MediaContentList { + type Read = MediaContentListV2; + + fn verify_mismatch(decoded: &Self::Read, expected: &Self) { + assert_eq!( + decoded.media_content_list[0].media.width, + i64::from(expected.media_content_list[0].media.width) + ); + assert_eq!( + decoded.media_content_list[0].images[0].width, + i64::from(expected.media_content_list[0].images[0].width) + ); + } +} + pub fn register_fory_types(fory: &mut Fory) -> Result<(), Error> { fory.register::(1)?; fory.register::(2)?; @@ -298,6 +510,20 @@ pub fn register_fory_types(fory: &mut Fory) -> Result<(), Error> { Ok(()) } +pub fn register_fory_types_v2(fory: &mut Fory) -> Result<(), Error> { + fory.register::(1)?; + fory.register::(2)?; + fory.register::(3)?; + fory.register::(4)?; + fory.register::(5)?; + fory.register::(6)?; + fory.register::(7)?; + fory.register::(8)?; + fory.register::(9)?; + fory.register::(10)?; + Ok(()) +} + fn create_numeric_struct() -> NumericStruct { NumericStruct { f1: -12_345, diff --git a/benchmarks/rust/src/lib.rs b/benchmarks/rust/src/lib.rs index 6f7d666bf3..aa3d7de68b 100644 --- a/benchmarks/rust/src/lib.rs +++ b/benchmarks/rust/src/lib.rs @@ -24,11 +24,13 @@ pub mod generated { use criterion::{black_box, Criterion}; use data::{ - BenchmarkCase, MediaContent, MediaContentList, NumericStruct, NumericStructList, Sample, - SampleList, + MediaContent, MediaContentList, NumericStruct, NumericStructList, Sample, SampleList, + SchemaMismatchCase, }; use serializers::{ - fory::ForySerializer, msgpack::MsgpackSerializer, protobuf::ProtobufSerializer, + fory::{schema_mismatch_enabled, ForySerializer}, + msgpack::MsgpackSerializer, + protobuf::ProtobufSerializer, BenchmarkSerializer, }; @@ -81,12 +83,13 @@ fn run_benchmark_case( protobuf_serializer: &ProtobufSerializer, msgpack_serializer: &MsgpackSerializer, ) where - T: BenchmarkCase, + T: SchemaMismatchCase, ForySerializer: BenchmarkSerializer, ProtobufSerializer: BenchmarkSerializer, MsgpackSerializer: BenchmarkSerializer, { let data = T::create(); + let mismatch = schema_mismatch_enabled(); let mut group = c.benchmark_group(T::KIND.group_name()); group.bench_function("fory_serialize", |b| { @@ -96,48 +99,64 @@ fn run_benchmark_case( }); let fory_bytes = fory_serializer.serialize(&data).unwrap(); + if mismatch { + let value: T::Read = fory_serializer.deserialize_as(&fory_bytes).unwrap(); + T::verify_mismatch(&value, &data); + } group.bench_function("fory_deserialize", |b| { b.iter(|| { - let value: T = black_box(fory_serializer.deserialize(black_box(&fory_bytes)).unwrap()); - black_box(value); - }) - }); - - group.bench_function("protobuf_serialize", |b| { - b.iter(|| { - let _ = black_box(protobuf_serializer.serialize(black_box(&data)).unwrap()); - }) - }); - - let protobuf_bytes = protobuf_serializer.serialize(&data).unwrap(); - group.bench_function("protobuf_deserialize", |b| { - b.iter(|| { - let value: T = black_box( - protobuf_serializer - .deserialize(black_box(&protobuf_bytes)) - .unwrap(), - ); - black_box(value); + if mismatch { + let value: T::Read = black_box( + fory_serializer + .deserialize_as(black_box(&fory_bytes)) + .unwrap(), + ); + black_box(value); + } else { + let value: T = + black_box(fory_serializer.deserialize(black_box(&fory_bytes)).unwrap()); + black_box(value); + } }) }); - group.bench_function("msgpack_serialize", |b| { - b.iter(|| { - let _ = black_box(msgpack_serializer.serialize(black_box(&data)).unwrap()); - }) - }); - - let msgpack_bytes = msgpack_serializer.serialize(&data).unwrap(); - group.bench_function("msgpack_deserialize", |b| { - b.iter(|| { - let value: T = black_box( - msgpack_serializer - .deserialize(black_box(&msgpack_bytes)) - .unwrap(), - ); - black_box(value); - }) - }); + if !mismatch { + group.bench_function("protobuf_serialize", |b| { + b.iter(|| { + let _ = black_box(protobuf_serializer.serialize(black_box(&data)).unwrap()); + }) + }); + + let protobuf_bytes = protobuf_serializer.serialize(&data).unwrap(); + group.bench_function("protobuf_deserialize", |b| { + b.iter(|| { + let value: T = black_box( + protobuf_serializer + .deserialize(black_box(&protobuf_bytes)) + .unwrap(), + ); + black_box(value); + }) + }); + + group.bench_function("msgpack_serialize", |b| { + b.iter(|| { + let _ = black_box(msgpack_serializer.serialize(black_box(&data)).unwrap()); + }) + }); + + let msgpack_bytes = msgpack_serializer.serialize(&data).unwrap(); + group.bench_function("msgpack_deserialize", |b| { + b.iter(|| { + let value: T = black_box( + msgpack_serializer + .deserialize(black_box(&msgpack_bytes)) + .unwrap(), + ); + black_box(value); + }) + }); + } group.finish(); } diff --git a/benchmarks/rust/src/main.rs b/benchmarks/rust/src/main.rs index 7915ac094c..b3d5fb17f0 100644 --- a/benchmarks/rust/src/main.rs +++ b/benchmarks/rust/src/main.rs @@ -18,10 +18,12 @@ use clap::{Parser, ValueEnum}; use fory_benchmarks::data::{ BenchmarkCase, DataKind, MediaContent, MediaContentList, NumericStruct, NumericStructList, - Sample, SampleList, + Sample, SampleList, SchemaMismatchCase, }; use fory_benchmarks::serializers::{ - fory::ForySerializer, msgpack::MsgpackSerializer, protobuf::ProtobufSerializer, + fory::{schema_mismatch_enabled, ForySerializer}, + msgpack::MsgpackSerializer, + protobuf::ProtobufSerializer, BenchmarkSerializer, }; use std::hint::black_box; @@ -99,14 +101,23 @@ where fn profile_case(iterations: usize, serializer: SerializerKind, operation: Operation) where - T: BenchmarkCase, + T: SchemaMismatchCase, ForySerializer: BenchmarkSerializer, ProtobufSerializer: BenchmarkSerializer, MsgpackSerializer: BenchmarkSerializer, { let value = T::create(); + let mismatch = schema_mismatch_enabled(); + if mismatch && serializer != SerializerKind::Fory { + panic!( + "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory" + ); + } match serializer { + SerializerKind::Fory if mismatch => { + profile_fory_mismatch::(iterations, &value, operation) + } SerializerKind::Fory => profile(iterations, &value, &ForySerializer::new(), operation), SerializerKind::Protobuf => { profile(iterations, &value, &ProtobufSerializer::new(), operation) @@ -117,6 +128,42 @@ where } } +fn profile_fory_mismatch(iterations: usize, value: &T, operation: Operation) +where + T: SchemaMismatchCase, + ForySerializer: BenchmarkSerializer, +{ + let serializer = ForySerializer::new(); + match operation { + Operation::Serialize => { + for _ in 0..1000 { + let _ = black_box(serializer.serialize(black_box(value)).unwrap()); + } + + for _ in 0..iterations { + let _ = black_box(serializer.serialize(black_box(value)).unwrap()); + } + } + Operation::Deserialize => { + let bytes = serializer.serialize(value).unwrap(); + let decoded: T::Read = serializer.deserialize_as(&bytes).unwrap(); + T::verify_mismatch(&decoded, value); + + for _ in 0..1000 { + let value: T::Read = + black_box(serializer.deserialize_as(black_box(&bytes)).unwrap()); + black_box(value); + } + + for _ in 0..iterations { + let value: T::Read = + black_box(serializer.deserialize_as(black_box(&bytes)).unwrap()); + black_box(value); + } + } + } +} + fn print_size_row(label: &str) where T: BenchmarkCase, @@ -126,6 +173,11 @@ where { let value = T::create(); let fory = ForySerializer::new().serialize(&value).unwrap().len(); + if schema_mismatch_enabled() { + println!("| {label} | {fory} | n/a | n/a |"); + return; + } + let protobuf = ProtobufSerializer::new().serialize(&value).unwrap().len(); let msgpack = MsgpackSerializer::new().serialize(&value).unwrap().len(); @@ -145,6 +197,11 @@ fn print_all_serialized_sizes() { fn main() { let cli = Cli::parse(); + if schema_mismatch_enabled() && cli.serializer != SerializerKind::Fory { + panic!( + "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory" + ); + } if cli.print_all_serialized_sizes { print_all_serialized_sizes(); diff --git a/benchmarks/rust/src/serializers/fory.rs b/benchmarks/rust/src/serializers/fory.rs index 0a85a5e173..edf37e2d33 100644 --- a/benchmarks/rust/src/serializers/fory.rs +++ b/benchmarks/rust/src/serializers/fory.rs @@ -15,22 +15,40 @@ // specific language governing permissions and limitations // under the License. -use crate::data::register_fory_types; +use crate::data::{register_fory_types, register_fory_types_v2}; use crate::serializers::{BenchmarkSerializer, BoxError}; use fory::{Fory, ForyDefault, Serializer as ForyValueSerializer}; #[derive(Default)] pub struct ForySerializer { - fory: Fory, + writer: Fory, + reader: Fory, } impl ForySerializer { pub fn new() -> Self { - let mut fory = Fory::builder().xlang(true).compatible(true).build(); - register_fory_types(&mut fory).expect("register benchmark types"); + let mut writer = Fory::builder().xlang(true).compatible(true).build(); + register_fory_types(&mut writer).expect("register benchmark writer types"); + let mut reader = Fory::builder().xlang(true).compatible(true).build(); + if schema_mismatch_enabled() { + register_fory_types_v2(&mut reader).expect("register benchmark v2 reader types"); + } else { + register_fory_types(&mut reader).expect("register benchmark reader types"); + } - Self { fory } + Self { writer, reader } } + + pub fn deserialize_as(&self, data: &[u8]) -> Result + where + T: ForyValueSerializer + ForyDefault, + { + Ok(self.reader.deserialize(data)?) + } +} + +pub fn schema_mismatch_enabled() -> bool { + std::env::var("FORY_BENCH_SCHEMA_MISMATCH").as_deref() == Ok("1") } impl BenchmarkSerializer for ForySerializer @@ -38,10 +56,10 @@ where T: ForyValueSerializer + ForyDefault, { fn serialize(&self, data: &T) -> Result, BoxError> { - Ok(self.fory.serialize(data)?) + Ok(self.writer.serialize(data)?) } fn deserialize(&self, data: &[u8]) -> Result { - Ok(self.fory.deserialize(data)?) + self.deserialize_as(data) } } diff --git a/benchmarks/swift/README.md b/benchmarks/swift/README.md index c9c6e9b126..1026b0a6ff 100644 --- a/benchmarks/swift/README.md +++ b/benchmarks/swift/README.md @@ -58,6 +58,15 @@ Examples: ./run.sh --no-report ``` +## Schema Mismatch Mode + +Set `FORY_BENCH_SCHEMA_MISMATCH=1` to run the Fory-only compatible-read +schema-mismatch mode. This mode is off by default. When enabled, run with +`--serializer fory`; protobuf and JSON benchmark modes fail with a configuration +error. Fory serialization uses the normal v1 benchmark models, and Fory +deserialization uses v2 models registered with the same Fory type IDs where one +int32 field is widened to int64. + ## Manual Commands ```bash diff --git a/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkModels.swift b/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkModels.swift index 9887003eeb..e5f0e12b73 100644 --- a/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkModels.swift +++ b/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkModels.swift @@ -219,6 +219,150 @@ struct MediaContentList: Codable, Equatable { var mediaContentList: [MediaContent] = [] } +@ForyStruct +struct NumericStructV2: Equatable { + @ForyField(id: 1) + var f1: Int64 = 0 + @ForyField(id: 2) + var f2: Int32 = 0 + @ForyField(id: 3) + var f3: Int32 = 0 + @ForyField(id: 4) + var f4: Int32 = 0 + @ForyField(id: 5) + var f5: Int32 = 0 + @ForyField(id: 6) + var f6: Int32 = 0 + @ForyField(id: 7) + var f7: Int32 = 0 + @ForyField(id: 8) + var f8: Int32 = 0 + @ForyField(id: 9) + var f9: Int32 = 0 + @ForyField(id: 10) + var f10: Int32 = 0 + @ForyField(id: 11) + var f11: Int32 = 0 + @ForyField(id: 12) + var f12: Int32 = 0 +} + +@ForyStruct +struct SampleV2: Equatable { + @ForyField(id: 1) + var intValue: Int64 = 0 + @ForyField(id: 2) + var longValue: Int64 = 0 + @ForyField(id: 3) + var floatValue: Float = 0 + @ForyField(id: 4) + var doubleValue: Double = 0 + @ForyField(id: 5) + var shortValue: Int32 = 0 + @ForyField(id: 6) + var charValue: Int32 = 0 + @ForyField(id: 7) + var booleanValue: Bool = false + @ForyField(id: 8) + var intValueBoxed: Int32 = 0 + @ForyField(id: 9) + var longValueBoxed: Int64 = 0 + @ForyField(id: 10) + var floatValueBoxed: Float = 0 + @ForyField(id: 11) + var doubleValueBoxed: Double = 0 + @ForyField(id: 12) + var shortValueBoxed: Int32 = 0 + @ForyField(id: 13) + var charValueBoxed: Int32 = 0 + @ForyField(id: 14) + var booleanValueBoxed: Bool = false + @ForyField(id: 15, type: .array(element: .int32())) + var intArray: [Int32] = [] + @ForyField(id: 16, type: .array(element: .int64())) + var longArray: [Int64] = [] + @ForyField(id: 17, type: .array(element: .float32)) + var floatArray: [Float] = [] + @ForyField(id: 18, type: .array(element: .float64)) + var doubleArray: [Double] = [] + @ForyField(id: 19, type: .array(element: .int32())) + var shortArray: [Int32] = [] + @ForyField(id: 20, type: .array(element: .int32())) + var charArray: [Int32] = [] + @ForyField(id: 21, type: .array(element: .bool)) + var booleanArray: [Bool] = [] + @ForyField(id: 22) + var string: String = "" +} + +@ForyStruct +struct MediaV2: Equatable { + @ForyField(id: 1) + var uri: String = "" + @ForyField(id: 2) + var title: String = "" + @ForyField(id: 3) + var width: Int64 = 0 + @ForyField(id: 4) + var height: Int32 = 0 + @ForyField(id: 5) + var format: String = "" + @ForyField(id: 6) + var duration: Int64 = 0 + @ForyField(id: 7) + var size: Int64 = 0 + @ForyField(id: 8) + var bitrate: Int32 = 0 + @ForyField(id: 9) + var hasBitrate: Bool = false + @ForyField(id: 10) + var persons: [String] = [] + @ForyField(id: 11) + var player: Player = .java + @ForyField(id: 12) + var copyright: String = "" +} + +@ForyStruct +struct ImageV2: Equatable { + @ForyField(id: 1) + var uri: String = "" + @ForyField(id: 2) + var title: String = "" + @ForyField(id: 3) + var width: Int64 = 0 + @ForyField(id: 4) + var height: Int32 = 0 + @ForyField(id: 5) + var size: Size = .small +} + +@ForyStruct +struct MediaContentV2: Equatable { + @ForyField(id: 1) + var media: MediaV2 = .init() + @ForyField(id: 2) + var images: [ImageV2] = [] +} + +@ForyStruct +struct NumericStructListV2: Equatable { + @ForyField(id: 1) + var structList: [NumericStructV2] = [] +} + +@ForyStruct +struct SampleListV2: Equatable { + @ForyField(id: 1) + var sampleList: [SampleV2] = [] +} + +@ForyStruct +struct MediaContentListV2: Equatable { + @ForyField(id: 1) + var mediaContentList: [MediaContentV2] = [] +} + enum BenchmarkDataFactory { static let listSize = 5 diff --git a/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkRunner.swift b/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkRunner.swift index f488a109c2..5ce2b649d8 100644 --- a/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkRunner.swift +++ b/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkRunner.swift @@ -23,6 +23,7 @@ struct BenchmarkConfig { var durationSeconds: Double = 3.0 var dataFilter: DataKind? var serializerFilter: SerializerKind? + var schemaMismatch: Bool = false } struct BenchmarkEntry: Codable { @@ -40,8 +41,8 @@ struct BenchmarkEntry: Codable { struct SizeEntry: Codable { let dataType: String let fory: Int - let protobuf: Int - let json: Int + let protobuf: Int? + let json: Int? } struct BenchmarkContext: Codable { @@ -61,14 +62,21 @@ struct BenchmarkOutput: Codable { final class BenchmarkSuite { private let config: BenchmarkConfig - private let fory: Fory + private let foryWriter: Fory + private let foryReader: Fory private let jsonEncoder = JSONEncoder() private let jsonDecoder = JSONDecoder() init(config: BenchmarkConfig) { self.config = config - self.fory = Fory(xlang: false, ref: false, compatible: true) - registerTypes() + self.foryWriter = Fory(ref: false, compatible: true) + self.foryReader = Fory(ref: false, compatible: true) + registerV1Types(foryWriter) + if config.schemaMismatch { + registerV2Types(foryReader) + } else { + registerV1Types(foryReader) + } } func run() throws -> BenchmarkOutput { @@ -78,36 +86,58 @@ final class BenchmarkSuite { try runBenchmarks( dataKind: .numericStruct, value: BenchmarkDataFactory.makeNumericStruct(), + validateMismatch: { (decoded: NumericStructV2, expected: NumericStruct) in + decoded.f1 == Int64(expected.f1) + }, entries: &entries, sizeEntries: &sizeEntries ) try runBenchmarks( dataKind: .sample, value: BenchmarkDataFactory.makeSample(), + validateMismatch: { (decoded: SampleV2, expected: Sample) in + decoded.intValue == Int64(expected.intValue) + }, entries: &entries, sizeEntries: &sizeEntries ) try runBenchmarks( dataKind: .mediaContent, value: BenchmarkDataFactory.makeMediaContent(), + validateMismatch: { (decoded: MediaContentV2, expected: MediaContent) in + decoded.media.width == Int64(expected.media.width) + && decoded.images.first?.width == Int64(expected.images[0].width) + }, entries: &entries, sizeEntries: &sizeEntries ) try runBenchmarks( dataKind: .numericStructList, value: BenchmarkDataFactory.makeNumericStructList(), + validateMismatch: { (decoded: NumericStructListV2, expected: NumericStructList) in + decoded.structList.first?.f1 == Int64(expected.structList[0].f1) + }, entries: &entries, sizeEntries: &sizeEntries ) try runBenchmarks( dataKind: .sampleList, value: BenchmarkDataFactory.makeSampleList(), + validateMismatch: { (decoded: SampleListV2, expected: SampleList) in + decoded.sampleList.first?.intValue == Int64(expected.sampleList[0].intValue) + }, entries: &entries, sizeEntries: &sizeEntries ) try runBenchmarks( dataKind: .mediaContentList, value: BenchmarkDataFactory.makeMediaContentList(), + validateMismatch: { (decoded: MediaContentListV2, expected: MediaContentList) in + decoded.mediaContentList.first?.media.width + == Int64(expected.mediaContentList[0].media.width) + && decoded.mediaContentList.first?.images.first?.width + == Int64(expected.mediaContentList[0].images[0].width) + }, entries: &entries, sizeEntries: &sizeEntries ) @@ -122,7 +152,7 @@ final class BenchmarkSuite { ) } - private func registerTypes() { + private func registerV1Types(_ fory: Fory) { fory.register(NumericStruct.self, id: 1) fory.register(Sample.self, id: 2) fory.register(Media.self, id: 3) @@ -133,6 +163,17 @@ final class BenchmarkSuite { fory.register(MediaContentList.self, id: 8) } + private func registerV2Types(_ fory: Fory) { + fory.register(NumericStructV2.self, id: 1) + fory.register(SampleV2.self, id: 2) + fory.register(MediaV2.self, id: 3) + fory.register(ImageV2.self, id: 4) + fory.register(MediaContentV2.self, id: 5) + fory.register(NumericStructListV2.self, id: 6) + fory.register(SampleListV2.self, id: 7) + fory.register(MediaContentListV2.self, id: 8) + } + private func shouldRun(_ dataKind: DataKind, _ serializer: SerializerKind) -> Bool { if let filter = config.dataFilter, filter != dataKind { return false @@ -143,9 +184,10 @@ final class BenchmarkSuite { return true } - private func runBenchmarks( + private func runBenchmarks( dataKind: DataKind, value: T, + validateMismatch: (TRead, T) -> Bool, entries: inout [BenchmarkEntry], sizeEntries: inout [SizeEntry] ) throws { @@ -153,18 +195,24 @@ final class BenchmarkSuite { return } - let foryBytes = try fory.serialize(value) - let protobufBytes = try value.toProtobuf().serializedData() - let jsonBytes = try jsonEncoder.encode(value) + let foryBytes = try foryWriter.serialize(value) + let protobufBytes = config.schemaMismatch ? nil : try value.toProtobuf().serializedData() + let jsonBytes = config.schemaMismatch ? nil : try jsonEncoder.encode(value) sizeEntries.append( SizeEntry( dataType: dataKind.title, fory: foryBytes.count, - protobuf: protobufBytes.count, - json: jsonBytes.count + protobuf: protobufBytes?.count, + json: jsonBytes?.count ) ) + if config.schemaMismatch { + let decoded: TRead = try foryReader.deserialize(foryBytes, as: TRead.self) + guard validateMismatch(decoded, value) else { + throw BenchmarkError.schemaMismatchValidation(dataKind.rawValue) + } + } if shouldRun(dataKind, .fory) { entries.append( @@ -174,7 +222,7 @@ final class BenchmarkSuite { operation: .serialize, bytes: foryBytes.count ) { - try self.fory.serialize(value).count + try self.foryWriter.serialize(value).count } ) entries.append( @@ -184,14 +232,22 @@ final class BenchmarkSuite { operation: .deserialize, bytes: foryBytes.count ) { - let decoded: T = try self.fory.deserialize(foryBytes, as: T.self) - withExtendedLifetime(decoded) {} + if self.config.schemaMismatch { + let decoded: TRead = try self.foryReader.deserialize(foryBytes, as: TRead.self) + withExtendedLifetime(decoded) {} + } else { + let decoded: T = try self.foryReader.deserialize(foryBytes, as: T.self) + withExtendedLifetime(decoded) {} + } return foryBytes.count } ) } - if shouldRun(dataKind, .protobuf) { + if !config.schemaMismatch && shouldRun(dataKind, .protobuf) { + guard let protobufBytes = protobufBytes else { + throw BenchmarkError.schemaMismatchValidation(dataKind.rawValue) + } entries.append( try runSingleCase( serializer: .protobuf, @@ -217,7 +273,10 @@ final class BenchmarkSuite { ) } - if shouldRun(dataKind, .json) { + if !config.schemaMismatch && shouldRun(dataKind, .json) { + guard let jsonBytes = jsonBytes else { + throw BenchmarkError.schemaMismatchValidation(dataKind.rawValue) + } entries.append( try runSingleCase( serializer: .json, @@ -327,3 +386,14 @@ final class BenchmarkSuite { ) } } + +enum BenchmarkError: Error, CustomStringConvertible { + case schemaMismatchValidation(String) + + var description: String { + switch self { + case let .schemaMismatchValidation(dataType): + return "Fory schema-mismatch validation failed for \(dataType)" + } + } +} diff --git a/benchmarks/swift/Sources/SwiftBenchmark/main.swift b/benchmarks/swift/Sources/SwiftBenchmark/main.swift index d1c946dc88..68df2ce0c9 100644 --- a/benchmarks/swift/Sources/SwiftBenchmark/main.swift +++ b/benchmarks/swift/Sources/SwiftBenchmark/main.swift @@ -37,6 +37,8 @@ private struct CLI { static func parse(arguments: [String]) throws -> (BenchmarkConfig, String) { var config = BenchmarkConfig() + config.schemaMismatch = + ProcessInfo.processInfo.environment["FORY_BENCH_SCHEMA_MISMATCH"] == "1" var outputPath = "results/benchmark_results.json" var i = 1 @@ -84,6 +86,12 @@ private struct CLI { } } + if config.schemaMismatch && config.serializerFilter != .fory { + throw CLIError.invalidArgument( + "FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory" + ) + } + return (config, outputPath) } } @@ -110,6 +118,7 @@ private func runMain() throws { if let filter = config.serializerFilter { print("Serializer filter: \(filter.rawValue)") } + print("Schema mismatch: \(config.schemaMismatch)") let suite = BenchmarkSuite(config: config) let output = try suite.run() @@ -119,8 +128,17 @@ private func runMain() throws { print("\nSerialized sizes (bytes):") for entry in output.serializedSizes { - print(" \(entry.dataType): fory=\(entry.fory), protobuf=\(entry.protobuf), json=\(entry.json)") + let protobufSize = formatSize(entry.protobuf) + let jsonSize = formatSize(entry.json) + print(" \(entry.dataType): fory=\(entry.fory), protobuf=\(protobufSize), json=\(jsonSize)") + } +} + +private func formatSize(_ value: Int?) -> String { + guard let value = value else { + return "N/A" } + return String(value) } private func writeOutput(_ output: BenchmarkOutput, to path: String) throws { diff --git a/benchmarks/swift/benchmark_report.py b/benchmarks/swift/benchmark_report.py index ca733fba0d..6fe360993a 100755 --- a/benchmarks/swift/benchmark_report.py +++ b/benchmarks/swift/benchmark_report.py @@ -131,8 +131,12 @@ def datatype_plot_label(datatype: str) -> str: return datatype.capitalize() -def format_tps(value: float) -> str: - return f"{value:,.0f}" +def format_tps(value) -> str: + return f"{value:,.0f}" if value is not None and value > 0 else "N/A" + + +def format_size(value) -> str: + return str(value) if isinstance(value, int) and value >= 0 else "N/A" def format_tps_label(value: float) -> str: @@ -288,9 +292,9 @@ def write_report( continue for operation in OPERATIONS: throughputs = results.get(datatype, {}).get(operation, {}) - fory = throughputs.get("fory", 0.0) - protobuf = throughputs.get("protobuf", 0.0) - json_tps = throughputs.get("json", 0.0) + fory = throughputs.get("fory") + protobuf = throughputs.get("protobuf") + json_tps = throughputs.get("json") lines.append( "| " + f"{datatype_title(datatype)} | {operation.capitalize()} | " @@ -314,9 +318,9 @@ def write_report( lines.append( "| " + f"{datatype_label} | " - + f"{entry.get('fory', '-')} | " - + f"{entry.get('protobuf', '-')} | " - + f"{entry.get('json', '-')} |" + + f"{format_size(entry.get('fory'))} | " + + f"{format_size(entry.get('protobuf'))} | " + + f"{format_size(entry.get('json'))} |" ) report_path = os.path.join(output_dir, "README.md") diff --git a/benchmarks/swift/run.sh b/benchmarks/swift/run.sh index 7e72253422..c9c2792a1d 100755 --- a/benchmarks/swift/run.sh +++ b/benchmarks/swift/run.sh @@ -92,6 +92,11 @@ while [[ $# -gt 0 ]]; do esac done +if [[ "${FORY_BENCH_SCHEMA_MISMATCH:-0}" == "1" && "$SERIALIZER" != "fory" ]]; then + echo -e "${RED}FORY_BENCH_SCHEMA_MISMATCH=1 supports only Fory benchmarks; rerun with --serializer fory.${NC}" + exit 1 +fi + mkdir -p results echo -e "${GREEN}=== Fory Swift Benchmark ===${NC}" diff --git a/compiler/fory_compiler/generators/javascript.py b/compiler/fory_compiler/generators/javascript.py index 2a33961bb7..ac17e6ea02 100644 --- a/compiler/fory_compiler/generators/javascript.py +++ b/compiler/fory_compiler/generators/javascript.py @@ -838,17 +838,22 @@ def _field_type_expr( self, field_type: FieldType, parent_stack: Optional[List[Message]] = None, + *, + nullable: bool = False, + ref: bool = False, + element_nullable_override: bool = False, + element_ref_override: bool = False, ) -> str: """Return the Fory JS runtime ``Type.xxx()`` expression for a field type.""" parent_stack = parent_stack or [] if isinstance(field_type, PrimitiveType): integer_expr = self._integer_type_expr(field_type) if integer_expr is not None: - return integer_expr - expr = self.PRIMITIVE_RUNTIME_MAP.get(field_type.kind) - if expr is None: - return "Type.any()" - return expr + expr = integer_expr + else: + expr = self.PRIMITIVE_RUNTIME_MAP.get(field_type.kind) + if expr is None: + expr = "Type.any()" elif isinstance(field_type, NamedType): # Check for primitive-like shorthand names (e.g. "float", "double") lower = field_type.name.lower() @@ -857,83 +862,125 @@ def _field_type_expr( "double": PrimitiveKind.FLOAT64, } if lower in shorthand_map: - return self.PRIMITIVE_RUNTIME_MAP[shorthand_map[lower]] - for pk in PrimitiveKind: - if pk.value == lower: - expr = self.PRIMITIVE_RUNTIME_MAP.get(pk) - if expr is None: - raise ValueError( - f"Primitive type '{pk.value}' has no JavaScript " - f"runtime mapping." - ) - return expr - - # Named type — could be a Message, Enum, or Union - resolved = self._resolve_named_type(field_type.name, parent_stack) - if isinstance(resolved, Enum): - if self.should_register_by_id(resolved): - name_info = str(resolved.type_id) + expr = self.PRIMITIVE_RUNTIME_MAP[shorthand_map[lower]] + else: + primitive_expr: Optional[str] = None + for pk in PrimitiveKind: + if pk.value == lower: + primitive_expr = self.PRIMITIVE_RUNTIME_MAP.get(pk) + if primitive_expr is None: + raise ValueError( + f"Primitive type '{pk.value}' has no JavaScript " + f"runtime mapping." + ) + break + if primitive_expr is not None: + expr = primitive_expr else: - ns = self._get_type_package(resolved) - qname = self._qualified_type_names.get(id(resolved), resolved.name) - name_info = f'"{ns}.{qname}"' - props = ", ".join( - f"{self.strip_enum_prefix(resolved.name, v.name)}: {v.value}" - for v in resolved.values - ) - return f"Type.enum({name_info}, {{ {props} }})" - if isinstance(resolved, Union): - case_parts = [] - for case_field in resolved.fields: - case_num = ( - case_field.tag_id - if case_field.tag_id is not None - else case_field.number - ) - if case_num is not None: - case_type_expr = self._field_type_expr( - case_field.field_type, parent_stack + # Named type — could be a Message, Enum, or Union + resolved = self._resolve_named_type(field_type.name, parent_stack) + if isinstance(resolved, Enum): + if self.should_register_by_id(resolved): + name_info = str(resolved.type_id) + else: + ns = self._get_type_package(resolved) + qname = self._qualified_type_names.get( + id(resolved), resolved.name + ) + name_info = f'"{ns}.{qname}"' + props = ", ".join( + f"{self.strip_enum_prefix(resolved.name, v.name)}: {v.value}" + for v in resolved.values ) - case_parts.append(f"{case_num}: {case_type_expr}") - cases_arg = f", {{ {', '.join(case_parts)} }}" if case_parts else "" - if self.should_register_by_id(resolved): - name_info = str(resolved.type_id) - else: - ns = self._get_type_package(resolved) - qname = self._qualified_type_names.get(id(resolved), resolved.name) - name_info = f'{{ namespace: "{ns}", typeName: "{qname}" }}' - return f"Type.union({name_info}{cases_arg})" - if isinstance(resolved, Message): - evolving = self.get_effective_evolving(resolved) - if self.should_register_by_id(resolved): - if evolving: - return f"Type.struct({resolved.type_id})" + expr = f"Type.enum({name_info}, {{ {props} }})" + elif isinstance(resolved, Union): + case_parts = [] + for case_field in resolved.fields: + case_num = ( + case_field.tag_id + if case_field.tag_id is not None + else case_field.number + ) + if case_num is not None: + case_type_expr = self._field_type_expr( + case_field.field_type, parent_stack + ) + case_parts.append(f"{case_num}: {case_type_expr}") + cases_arg = ( + f", {{ {', '.join(case_parts)} }}" if case_parts else "" + ) + if self.should_register_by_id(resolved): + name_info = str(resolved.type_id) + else: + ns = self._get_type_package(resolved) + qname = self._qualified_type_names.get( + id(resolved), resolved.name + ) + name_info = f'{{ namespace: "{ns}", typeName: "{qname}" }}' + expr = f"Type.union({name_info}{cases_arg})" + elif isinstance(resolved, Message): + evolving = self.get_effective_evolving(resolved) + if self.should_register_by_id(resolved): + if evolving: + expr = f"Type.struct({resolved.type_id})" + else: + expr = ( + f"Type.struct({{ typeId: {resolved.type_id}, " + "evolving: false })" + ) + else: + ns = self._get_type_package(resolved) + qname = self._qualified_type_names.get( + id(resolved), resolved.name + ) + if evolving: + expr = ( + f'Type.struct({{ namespace: "{ns}", ' + f'typeName: "{qname}" }})' + ) + else: + expr = ( + f'Type.struct({{ namespace: "{ns}", ' + f'typeName: "{qname}", evolving: false }})' + ) else: - return f"Type.struct({{ typeId: {resolved.type_id}, evolving: false }})" - ns = self._get_type_package(resolved) - qname = self._qualified_type_names.get(id(resolved), resolved.name) - if evolving: - return f'Type.struct({{ namespace: "{ns}", typeName: "{qname}" }})' - else: - return f'Type.struct({{ namespace: "{ns}", typeName: "{qname}", evolving: false }})' - # Unresolved — fall back to any - return "Type.any()" + # Unresolved — fall back to any + expr = "Type.any()" elif isinstance(field_type, ListType): - inner = self._field_type_expr(field_type.element_type, parent_stack) - return f"Type.list({inner})" + inner = self._field_type_expr( + field_type.element_type, + parent_stack, + nullable=field_type.element_optional or element_nullable_override, + ref=field_type.element_ref or element_ref_override, + ) + expr = f"Type.list({inner})" elif isinstance(field_type, ArrayType): if isinstance(field_type.element_type, PrimitiveType): type_expr = self.PRIMITIVE_ARRAY_RUNTIME_MAP.get( field_type.element_type.kind ) if type_expr: - return type_expr - return "Type.any()" + expr = type_expr + else: + expr = "Type.any()" + else: + expr = "Type.any()" elif isinstance(field_type, MapType): key = self._field_type_expr(field_type.key_type, parent_stack) - value = self._field_type_expr(field_type.value_type, parent_stack) - return f"Type.map({key}, {value})" - return "Type.any()" + value = self._field_type_expr( + field_type.value_type, + parent_stack, + nullable=field_type.value_optional, + ref=field_type.value_ref, + ) + expr = f"Type.map({key}, {value})" + else: + expr = "Type.any()" + if nullable or ref: + expr += ".setNullable(true)" + if ref: + expr += ".setTrackingRef(true)" + return expr def _integer_type_expr(self, field_type: PrimitiveType) -> Optional[str]: method = { @@ -966,13 +1013,16 @@ def _register_type_line( used_field_names: Set[str] = set() for field in type_def.fields: member = self._field_member_name(field, type_def, used_field_names) - expr = self._field_type_expr(field.field_type, field_parent_stack) + expr = self._field_type_expr( + field.field_type, + field_parent_stack, + nullable=field.optional or field.ref, + ref=field.ref, + element_nullable_override=field.element_optional, + element_ref_override=field.element_ref, + ) if field.tag_id is not None: expr += f".setId({field.tag_id})" - if field.optional: - expr += ".setNullable(true)" - if field.ref: - expr += ".setTrackingRef(true)" props_parts.append(f"{member}: {expr}") props_str = ", ".join(props_parts) diff --git a/compiler/fory_compiler/tests/test_javascript_codegen.py b/compiler/fory_compiler/tests/test_javascript_codegen.py index fa428d602d..9a6a6467ff 100644 --- a/compiler/fory_compiler/tests/test_javascript_codegen.py +++ b/compiler/fory_compiler/tests/test_javascript_codegen.py @@ -219,6 +219,29 @@ def test_javascript_collection_types(): assert "config: Map;" in output +def test_javascript_collection_ref_registration(): + source = dedent( + """ + package example; + + message Node [id=100] { + list children = 1; + map refs = 2; + } + """ + ) + output = generate_javascript(source) + + assert ( + "children: Type.list(Type.struct(100).setNullable(true).setTrackingRef(true)).setId(1)" + in output + ) + assert ( + "refs: Type.map(Type.string(), Type.struct(100).setNullable(true).setTrackingRef(true)).setId(2)" + in output + ) + + def test_javascript_decimal_generation_uses_runtime_decimal_type(): source = dedent( """ diff --git a/cpp/fory/serialization/compatible_scalar.cc b/cpp/fory/serialization/compatible_scalar.cc index 838340ab15..1f31f1c49f 100644 --- a/cpp/fory/serialization/compatible_scalar.cc +++ b/cpp/fory/serialization/compatible_scalar.cc @@ -1229,6 +1229,194 @@ bool scalar_to_float64(const ScalarValue &value, uint32_t remote_type_id, decimal_to_float64(decimal, negative_zero, out); } +bool read_int_target(ReadContext &ctx, uint32_t remote_type_id, + uint32_t local_type_id, std::string_view field, + int64_t min_value, int64_t max_value, int64_t &out) { + int64_t signed_value = 0; + uint64_t unsigned_value = 0; + bool is_unsigned = false; + switch (static_cast(remote_type_id)) { + case TypeId::BOOL: { + uint8_t raw = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + if (FORY_PREDICT_FALSE(raw != 0 && raw != 1)) { + conversion_error(ctx, remote_type_id, local_type_id, field, + "invalid bool payload"); + return false; + } + signed_value = raw != 0 ? 1 : 0; + break; + } + case TypeId::INT8: + signed_value = ctx.read_int8(ctx.error()); + break; + case TypeId::INT16: + signed_value = ctx.read_int16(ctx.error()); + break; + case TypeId::INT32: + signed_value = ctx.read_int32(ctx.error()); + break; + case TypeId::VARINT32: + signed_value = ctx.read_var_int32(ctx.error()); + break; + case TypeId::INT64: + signed_value = ctx.read_int64(ctx.error()); + break; + case TypeId::VARINT64: + signed_value = ctx.read_var_int64(ctx.error()); + break; + case TypeId::TAGGED_INT64: + signed_value = ctx.read_tagged_int64(ctx.error()); + break; + case TypeId::UINT8: + is_unsigned = true; + unsigned_value = ctx.read_uint8(ctx.error()); + break; + case TypeId::UINT16: + is_unsigned = true; + unsigned_value = static_cast(ctx.read_int16(ctx.error())); + break; + case TypeId::UINT32: + is_unsigned = true; + unsigned_value = static_cast(ctx.read_int32(ctx.error())); + break; + case TypeId::VAR_UINT32: + is_unsigned = true; + unsigned_value = ctx.read_var_uint32(ctx.error()); + break; + case TypeId::UINT64: + is_unsigned = true; + unsigned_value = static_cast(ctx.read_int64(ctx.error())); + break; + case TypeId::VAR_UINT64: + is_unsigned = true; + unsigned_value = ctx.read_var_uint64(ctx.error()); + break; + case TypeId::TAGGED_UINT64: + is_unsigned = true; + unsigned_value = ctx.read_tagged_uint64(ctx.error()); + break; + default: + return false; + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + if (is_unsigned) { + if (FORY_PREDICT_FALSE(unsigned_value > static_cast(max_value))) { + conversion_error(ctx, remote_type_id, local_type_id, field, + "value is not lossless"); + return false; + } + out = static_cast(unsigned_value); + return true; + } + if (FORY_PREDICT_FALSE(signed_value < min_value || + signed_value > max_value)) { + conversion_error(ctx, remote_type_id, local_type_id, field, + "value is not lossless"); + return false; + } + out = signed_value; + return true; +} + +bool read_uint_target(ReadContext &ctx, uint32_t remote_type_id, + uint32_t local_type_id, std::string_view field, + uint64_t max_value, uint64_t &out) { + int64_t signed_value = 0; + uint64_t unsigned_value = 0; + bool is_unsigned = false; + switch (static_cast(remote_type_id)) { + case TypeId::BOOL: { + uint8_t raw = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + if (FORY_PREDICT_FALSE(raw != 0 && raw != 1)) { + conversion_error(ctx, remote_type_id, local_type_id, field, + "invalid bool payload"); + return false; + } + signed_value = raw != 0 ? 1 : 0; + break; + } + case TypeId::INT8: + signed_value = ctx.read_int8(ctx.error()); + break; + case TypeId::INT16: + signed_value = ctx.read_int16(ctx.error()); + break; + case TypeId::INT32: + signed_value = ctx.read_int32(ctx.error()); + break; + case TypeId::VARINT32: + signed_value = ctx.read_var_int32(ctx.error()); + break; + case TypeId::INT64: + signed_value = ctx.read_int64(ctx.error()); + break; + case TypeId::VARINT64: + signed_value = ctx.read_var_int64(ctx.error()); + break; + case TypeId::TAGGED_INT64: + signed_value = ctx.read_tagged_int64(ctx.error()); + break; + case TypeId::UINT8: + is_unsigned = true; + unsigned_value = ctx.read_uint8(ctx.error()); + break; + case TypeId::UINT16: + is_unsigned = true; + unsigned_value = static_cast(ctx.read_int16(ctx.error())); + break; + case TypeId::UINT32: + is_unsigned = true; + unsigned_value = static_cast(ctx.read_int32(ctx.error())); + break; + case TypeId::VAR_UINT32: + is_unsigned = true; + unsigned_value = ctx.read_var_uint32(ctx.error()); + break; + case TypeId::UINT64: + is_unsigned = true; + unsigned_value = static_cast(ctx.read_int64(ctx.error())); + break; + case TypeId::VAR_UINT64: + is_unsigned = true; + unsigned_value = ctx.read_var_uint64(ctx.error()); + break; + case TypeId::TAGGED_UINT64: + is_unsigned = true; + unsigned_value = ctx.read_tagged_uint64(ctx.error()); + break; + default: + return false; + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + if (is_unsigned) { + if (FORY_PREDICT_FALSE(unsigned_value > max_value)) { + conversion_error(ctx, remote_type_id, local_type_id, field, + "value is not lossless"); + return false; + } + out = unsigned_value; + return true; + } + if (FORY_PREDICT_FALSE(signed_value < 0 || + static_cast(signed_value) > max_value)) { + conversion_error(ctx, remote_type_id, local_type_id, field, + "value is not lossless"); + return false; + } + out = static_cast(signed_value); + return true; +} + template ResultType read_converted(ReadContext &ctx, uint32_t remote_type_id, uint32_t local_type_id, std::string_view field, @@ -1307,6 +1495,15 @@ FORY_NOINLINE bool read_compatible_bool(ReadContext &ctx, FORY_NOINLINE int8_t read_compatible_int8(ReadContext &ctx, uint32_t remote_type_id, std::string_view field) { + int64_t value = 0; + if (read_int_target(ctx, remote_type_id, static_cast(TypeId::INT8), + field, std::numeric_limits::min(), + std::numeric_limits::max(), value)) { + return static_cast(value); + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return 0; + } return read_converted( ctx, remote_type_id, static_cast(TypeId::INT8), field, [](const ScalarValue &value, int8_t &out) { @@ -1323,6 +1520,15 @@ FORY_NOINLINE int8_t read_compatible_int8(ReadContext &ctx, FORY_NOINLINE uint8_t read_compatible_uint8(ReadContext &ctx, uint32_t remote_type_id, std::string_view field) { + uint64_t value = 0; + if (read_uint_target(ctx, remote_type_id, + static_cast(TypeId::UINT8), field, + std::numeric_limits::max(), value)) { + return static_cast(value); + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return 0; + } return read_converted( ctx, remote_type_id, static_cast(TypeId::UINT8), field, [](const ScalarValue &value, uint8_t &out) { @@ -1339,6 +1545,15 @@ FORY_NOINLINE uint8_t read_compatible_uint8(ReadContext &ctx, FORY_NOINLINE int16_t read_compatible_int16(ReadContext &ctx, uint32_t remote_type_id, std::string_view field) { + int64_t value = 0; + if (read_int_target(ctx, remote_type_id, static_cast(TypeId::INT16), + field, std::numeric_limits::min(), + std::numeric_limits::max(), value)) { + return static_cast(value); + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return 0; + } return read_converted( ctx, remote_type_id, static_cast(TypeId::INT16), field, [](const ScalarValue &value, int16_t &out) { @@ -1355,6 +1570,15 @@ FORY_NOINLINE int16_t read_compatible_int16(ReadContext &ctx, FORY_NOINLINE uint16_t read_compatible_uint16(ReadContext &ctx, uint32_t remote_type_id, std::string_view field) { + uint64_t value = 0; + if (read_uint_target(ctx, remote_type_id, + static_cast(TypeId::UINT16), field, + std::numeric_limits::max(), value)) { + return static_cast(value); + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return 0; + } return read_converted( ctx, remote_type_id, static_cast(TypeId::UINT16), field, [](const ScalarValue &value, uint16_t &out) { @@ -1371,6 +1595,16 @@ FORY_NOINLINE uint16_t read_compatible_uint16(ReadContext &ctx, FORY_NOINLINE int32_t read_compatible_int32(ReadContext &ctx, uint32_t remote_type_id, std::string_view field) { + int64_t value = 0; + if (read_int_target(ctx, remote_type_id, + static_cast(TypeId::VARINT32), field, + std::numeric_limits::min(), + std::numeric_limits::max(), value)) { + return static_cast(value); + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return 0; + } return read_converted( ctx, remote_type_id, static_cast(TypeId::VARINT32), field, [](const ScalarValue &value, int32_t &out) { @@ -1387,6 +1621,15 @@ FORY_NOINLINE int32_t read_compatible_int32(ReadContext &ctx, FORY_NOINLINE uint32_t read_compatible_uint32(ReadContext &ctx, uint32_t remote_type_id, std::string_view field) { + uint64_t value = 0; + if (read_uint_target(ctx, remote_type_id, + static_cast(TypeId::UINT32), field, + std::numeric_limits::max(), value)) { + return static_cast(value); + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return 0; + } return read_converted( ctx, remote_type_id, static_cast(TypeId::UINT32), field, [](const ScalarValue &value, uint32_t &out) { @@ -1403,6 +1646,16 @@ FORY_NOINLINE uint32_t read_compatible_uint32(ReadContext &ctx, FORY_NOINLINE int64_t read_compatible_int64(ReadContext &ctx, uint32_t remote_type_id, std::string_view field) { + int64_t value = 0; + if (read_int_target(ctx, remote_type_id, + static_cast(TypeId::VARINT64), field, + std::numeric_limits::min(), + std::numeric_limits::max(), value)) { + return value; + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return 0; + } return read_converted( ctx, remote_type_id, static_cast(TypeId::VARINT64), field, [](const ScalarValue &value, int64_t &out) { @@ -1414,6 +1667,15 @@ FORY_NOINLINE int64_t read_compatible_int64(ReadContext &ctx, FORY_NOINLINE uint64_t read_compatible_uint64(ReadContext &ctx, uint32_t remote_type_id, std::string_view field) { + uint64_t value = 0; + if (read_uint_target(ctx, remote_type_id, + static_cast(TypeId::UINT64), field, + std::numeric_limits::max(), value)) { + return value; + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return 0; + } return read_converted( ctx, remote_type_id, static_cast(TypeId::UINT64), field, [](const ScalarValue &value, uint64_t &out) { diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 6eec0267d9..95b2453205 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -563,8 +563,8 @@ Result ReadContext::read_type_meta() { // Have local type - assign field_ids by comparing schemas // Note: Extension types don't have type_meta (only structs do) if (local_type_info->type_meta) { - TypeMeta::assign_field_ids(local_type_info->type_meta.get(), - parsed_meta->field_infos); + FORY_RETURN_NOT_OK(TypeMeta::assign_field_ids( + local_type_info->type_meta.get(), parsed_meta->field_infos)); } type_info->type_id = local_type_info->type_id; type_info->user_type_id = local_type_info->user_type_id; diff --git a/cpp/fory/serialization/struct_compatible_test.cc b/cpp/fory/serialization/struct_compatible_test.cc index a41e8ac522..da3450eaef 100644 --- a/cpp/fory/serialization/struct_compatible_test.cc +++ b/cpp/fory/serialization/struct_compatible_test.cc @@ -261,6 +261,22 @@ struct CompatibleNestedArrayField { fory::T::array(fory::T::int32())))); }; +struct CompatibleUnsignedExactV1 { + uint32_t id = 0; + uint64_t count = 0; + std::string extra; + + FORY_STRUCT(CompatibleUnsignedExactV1, (id, fory::F(1)), (count, fory::F(2)), + (extra, fory::F(3))); +}; + +struct CompatibleUnsignedExactV2 { + uint32_t id = 0; + uint64_t count = 0; + + FORY_STRUCT(CompatibleUnsignedExactV2, (id, fory::F(1)), (count, fory::F(2))); +}; + struct ScalarBoolField { bool value = false; FORY_STRUCT(ScalarBoolField, (value, fory::F(1))); @@ -604,7 +620,7 @@ TEST(SchemaEvolutionTest, ImmediateArrayFieldCanReadIntoListCarrier) { EXPECT_EQ(decoded.value().values, (std::vector{4, 5, 6})); } -TEST(SchemaEvolutionTest, NullableListElementsCannotReadIntoArrayCarrier) { +TEST(SchemaEvolutionTest, NullableListElementsReadIntoArrayCarrier) { auto writer = Fory::builder().compatible(true).xlang(true).build(); auto reader = Fory::builder().compatible(true).xlang(true).build(); @@ -622,13 +638,14 @@ TEST(SchemaEvolutionTest, NullableListElementsCannotReadIntoArrayCarrier) { ASSERT_TRUE(decoded.ok()) << decoded.error().to_string(); EXPECT_EQ(decoded.value().values, (std::vector{1, 2})); - bytes = writer.serialize(CompatibleNullableListField{{1, std::nullopt}}); - ASSERT_TRUE(bytes.ok()) << bytes.error().to_string(); - payload = std::move(bytes).value(); - decoded = - reader.deserialize(payload.data(), payload.size()); - - ASSERT_FALSE(decoded.ok()); + auto null_bytes = writer.serialize( + CompatibleNullableListField{{std::optional{1}, std::nullopt}}); + ASSERT_TRUE(null_bytes.ok()) << null_bytes.error().to_string(); + auto null_payload = std::move(null_bytes).value(); + auto null_decoded = reader.deserialize( + null_payload.data(), null_payload.size()); + ASSERT_FALSE(null_decoded.ok()); + EXPECT_EQ(null_decoded.error().code(), ErrorCode::InvalidData); } TEST(SchemaEvolutionTest, NestedListArraySchemaPairsAreNotMatched) { @@ -645,8 +662,27 @@ TEST(SchemaEvolutionTest, NestedListArraySchemaPairsAreNotMatched) { auto decoded = reader.deserialize( bytes.value().data(), bytes.value().size()); + ASSERT_FALSE(decoded.ok()); + EXPECT_EQ(decoded.error().code(), ErrorCode::TypeError); +} + +TEST(SchemaEvolutionTest, ChangedSchemaReadsExactUnsignedFields) { + auto writer = Fory::builder().compatible(true).xlang(true).build(); + auto reader = Fory::builder().compatible(true).xlang(true).build(); + + constexpr uint32_t TYPE_ID = 1009; + ASSERT_TRUE(writer.register_struct(TYPE_ID).ok()); + ASSERT_TRUE(reader.register_struct(TYPE_ID).ok()); + + CompatibleUnsignedExactV1 value{300u, 5000000000ull, "skip"}; + auto bytes = writer.serialize(value); + ASSERT_TRUE(bytes.ok()) << bytes.error().to_string(); + auto decoded = reader.deserialize( + bytes.value().data(), bytes.value().size()); + ASSERT_TRUE(decoded.ok()) << decoded.error().to_string(); - EXPECT_TRUE(decoded.value().values.empty()); + EXPECT_EQ(decoded.value().id, value.id); + EXPECT_EQ(decoded.value().count, value.count); } TEST(SchemaEvolutionTest, ScalarBoolString) { @@ -735,6 +771,10 @@ TEST(SchemaEvolutionTest, ScalarNumberNumber) { ASSERT_TRUE(narrowed.ok()) << narrowed.error().to_string(); EXPECT_EQ(narrowed.value().value, 127); + auto widened = convert_field({123}, 1050); + ASSERT_TRUE(widened.ok()) << widened.error().to_string(); + EXPECT_EQ(widened.value().value, 123); + auto range_error = convert_field({128}, 1017); ASSERT_FALSE(range_error.ok()); diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index d2dbf83f28..967d5775d7 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -1227,30 +1227,12 @@ ValueType read_configured_value(ReadContext &ctx, RefMode ref_mode, } } -template -void dispatch_field_index_impl(size_t target_index, Func &&func, - std::index_sequence, bool &handled) { - handled = ((target_index == Indices - ? (func(std::integral_constant{}), true) - : false) || - ...); -} - -template -void dispatch_field_index(size_t target_index, Func &&func, bool &handled) { - constexpr size_t field_count = - decltype(fory_field_info(std::declval()))::Size; - dispatch_field_index_impl(target_index, std::forward(func), - std::make_index_sequence{}, - handled); -} - // ------------------------------------------------------------------ // Compile-time helpers to compute sorted field indices / names and -// create small jump-table wrappers to unroll read/write per-field calls. +// create small switch-dispatch wrappers to unroll read/write per-field calls. // The goal is to mimic the Rust-derived serializer behaviour where the // sorted field order is known at compile-time and the read path for -// compatible mode uses a fast switch/jump table. +// compatible mode uses generated source-level switch cases. // ------------------------------------------------------------------ template struct CompileTimeFieldHelpers { @@ -2886,6 +2868,9 @@ FORY_ALWAYS_INLINE TargetType read_compatible_scalar_by_type_id( read_compatible_uint32(ctx, remote_type_id, field)); } else if constexpr (std::is_same_v || std::is_same_v) { + if (remote_type_id == static_cast(TypeId::VARINT32)) { + return static_cast(ctx.read_var_int32(ctx.error())); + } return static_cast( read_compatible_int64(ctx, remote_type_id, field)); } else if constexpr (std::is_same_v || @@ -3033,17 +3018,17 @@ void read_single_field_by_index(T &obj, ReadContext &ctx) { constexpr bool is_raw_prim = is_raw_primitive_v; if constexpr (is_raw_prim && is_primitive_field && !field_type_is_nullable && !is_nullable) { - auto read_value = [&ctx]() -> FieldType { - if constexpr (is_configurable_int_v) { - return read_configurable_int(ctx); - } - return read_primitive_field_direct(ctx, ctx.error()); - }; + FieldType value; + if constexpr (is_configurable_int_v) { + value = read_configurable_int(ctx); + } else { + value = read_primitive_field_direct(ctx, ctx.error()); + } // Assign to field. if constexpr (is_fory_field_v) { - (obj.*field_ptr).value = read_value(); + (obj.*field_ptr).value = value; } else { - obj.*field_ptr = read_value(); + obj.*field_ptr = value; } } else { // Special handling for std::optional with encoding @@ -3416,30 +3401,350 @@ void read_single_field_by_index_compatible(T &obj, ReadContext &ctx, } } -/// Helper to dispatch field reading by field_id in compatible mode. -/// Uses fold expression with short-circuit to avoid lambda overhead. -/// Sets handled=true if field was matched. -/// @param remote_field_type The field type tree from the remote schema. -template +template +FORY_ALWAYS_INLINE void read_compatible_exact_case(T &obj, ReadContext &ctx) { + using Helpers = CompileTimeFieldHelpers; + constexpr size_t local_sorted_id = MatchedId / 2; + constexpr size_t original_index = Helpers::sorted_indices[local_sorted_id]; + read_single_field_by_index(obj, ctx); +} + +template +FORY_ALWAYS_INLINE T read_primitive_at_checked(Buffer &buffer, uint32_t &offset, + Error &error); + +template +constexpr bool can_read_exact_primitive_with_offset() { + using Helpers = CompileTimeFieldHelpers; + constexpr size_t original_index = Helpers::sorted_indices[Index]; + using FieldPtr = + typename std::tuple_element::type; + using RawFieldType = typename meta::RemoveMemberPointerCVRefT; + using FieldType = unwrap_field_t; + constexpr TypeId field_type_id = Serializer::type_id; + return is_raw_primitive_v && is_primitive_type_id(field_type_id) && + !is_nullable_v && + !Helpers::template field_nullable() && + !configured_node_has_override(); +} + +template +FORY_ALWAYS_INLINE void read_single_exact_primitive_at(T &obj, ReadContext &ctx, + uint32_t &offset) { + using Helpers = CompileTimeFieldHelpers; + constexpr size_t original_index = Helpers::sorted_indices[Index]; + const auto field_info = fory_field_info(obj); + const auto field_ptr = + std::get(decltype(field_info)::ptrs_ref()); + using RawFieldType = + typename meta::RemoveMemberPointerCVRefT; + using FieldType = unwrap_field_t; + FieldType value = + read_primitive_at_checked(ctx.buffer(), offset, ctx.error()); + if constexpr (is_fory_field_v) { + (obj.*field_ptr).value = value; + } else { + obj.*field_ptr = value; + } +} + +template +FORY_ALWAYS_INLINE void read_compatible_exact_case_at(T &obj, ReadContext &ctx, + uint32_t &offset) { + constexpr size_t local_sorted_id = MatchedId / 2; + if constexpr (can_read_exact_primitive_with_offset()) { + read_single_exact_primitive_at(obj, ctx, offset); + } else { + ctx.buffer().reader_index(offset); + read_compatible_exact_case(obj, ctx); + offset = ctx.buffer().reader_index(); + } +} + +template FORY_ALWAYS_INLINE void -dispatch_compatible_field_read_impl(T &obj, ReadContext &ctx, int16_t field_id, - const FieldType &remote_field_type, - bool &handled, - std::index_sequence) { +read_exact_primitive_run(T &obj, ReadContext &ctx, + const std::vector &remote_fields, + size_t &remote_idx, uint32_t &offset) { + read_single_exact_primitive_at(obj, ctx, offset); + constexpr size_t next_sorted_idx = SortedIdx + 1; + if constexpr (next_sorted_idx < CompileTimeFieldHelpers::FieldCount) { + if constexpr (can_read_exact_primitive_with_offset()) { + constexpr int16_t next_matched_id = + static_cast(next_sorted_idx * 2); + if (remote_idx + 1 < remote_fields.size() && + remote_fields[remote_idx + 1].field_id == next_matched_id) { + ++remote_idx; + read_exact_primitive_run(obj, ctx, remote_fields, + remote_idx, offset); + } + } + } +} + +template +FORY_NOINLINE void +read_compatible_conversion_case(T &obj, ReadContext &ctx, + const FieldType &remote_field_type) { using Helpers = CompileTimeFieldHelpers; + constexpr size_t local_sorted_id = MatchedId / 2; + constexpr size_t original_index = Helpers::sorted_indices[local_sorted_id]; + read_single_field_by_index_compatible(obj, ctx, + remote_field_type); +} - // Short-circuit fold: stops at first match - // Each element evaluates to bool; || short-circuits on first true - (void)((static_cast(Indices) == field_id - ? (handled = true, - read_single_field_by_index_compatible< - Helpers::sorted_indices[Indices]>(obj, ctx, - remote_field_type), - true) - : false) || - ...); +#define FORY_COMPAT_EXACT_READ_SWITCH_CASE(N) \ + case Base + (N): { \ + constexpr size_t matched_case = static_cast(Base + (N)); \ + if constexpr (matched_case < total_cases && (matched_case & 1U) == 0) { \ + read_compatible_exact_case(obj, ctx); \ + } else { \ + ctx.set_error(Error::type_error("Invalid compatible matched id")); \ + } \ + return; \ + } + +#define FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(O) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 0) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 1) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 2) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 3) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 4) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 5) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 6) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 7) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 8) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 9) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 10) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 11) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 12) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 13) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 14) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASE((O) + 15) + +#define FORY_COMPAT_EXACT_READ_SWITCH_CASES_128() \ + FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(0) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(16) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(32) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(48) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(64) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(80) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(96) \ + FORY_COMPAT_EXACT_READ_SWITCH_CASES_16(112) + +template +FORY_ALWAYS_INLINE void +dispatch_compat_exact_read_impl(T &obj, ReadContext &ctx, int16_t matched_id) { + constexpr size_t total_cases = + CompileTimeFieldHelpers::FieldCount * static_cast(2); + switch (matched_id) { + FORY_COMPAT_EXACT_READ_SWITCH_CASES_128() + default: + if constexpr (static_cast(Base) + 128U < total_cases) { + dispatch_compat_exact_read_impl(obj, ctx, matched_id); + } else { + ctx.set_error(Error::type_error("Invalid compatible matched id")); + } + return; + } } +#undef FORY_COMPAT_EXACT_READ_SWITCH_CASES_128 +#undef FORY_COMPAT_EXACT_READ_SWITCH_CASES_16 +#undef FORY_COMPAT_EXACT_READ_SWITCH_CASE + +#define FORY_COMPAT_CONV_READ_SWITCH_CASE(N) \ + case Base + (N): { \ + constexpr size_t matched_case = static_cast(Base + (N)); \ + if constexpr (matched_case < total_cases && (matched_case & 1U) != 0) { \ + read_compatible_conversion_case(obj, ctx, \ + remote_field_type); \ + } else { \ + ctx.set_error(Error::type_error("Invalid compatible matched id")); \ + } \ + return; \ + } + +#define FORY_COMPAT_CONV_READ_SWITCH_CASES_16(O) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 0) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 1) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 2) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 3) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 4) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 5) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 6) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 7) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 8) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 9) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 10) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 11) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 12) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 13) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 14) \ + FORY_COMPAT_CONV_READ_SWITCH_CASE((O) + 15) + +#define FORY_COMPAT_CONV_READ_SWITCH_CASES_128() \ + FORY_COMPAT_CONV_READ_SWITCH_CASES_16(0) \ + FORY_COMPAT_CONV_READ_SWITCH_CASES_16(16) \ + FORY_COMPAT_CONV_READ_SWITCH_CASES_16(32) \ + FORY_COMPAT_CONV_READ_SWITCH_CASES_16(48) \ + FORY_COMPAT_CONV_READ_SWITCH_CASES_16(64) \ + FORY_COMPAT_CONV_READ_SWITCH_CASES_16(80) \ + FORY_COMPAT_CONV_READ_SWITCH_CASES_16(96) \ + FORY_COMPAT_CONV_READ_SWITCH_CASES_16(112) + +template +FORY_NOINLINE void +dispatch_compat_conversion_read_impl(T &obj, ReadContext &ctx, + int16_t matched_id, + const FieldType &remote_field_type) { + constexpr size_t total_cases = + CompileTimeFieldHelpers::FieldCount * static_cast(2); + switch (matched_id) { + FORY_COMPAT_CONV_READ_SWITCH_CASES_128() + default: + if constexpr (static_cast(Base) + 128U < total_cases) { + dispatch_compat_conversion_read_impl(obj, ctx, matched_id, + remote_field_type); + } else { + ctx.set_error(Error::type_error("Invalid compatible matched id")); + } + return; + } +} + +#undef FORY_COMPAT_CONV_READ_SWITCH_CASES_128 +#undef FORY_COMPAT_CONV_READ_SWITCH_CASES_16 +#undef FORY_COMPAT_CONV_READ_SWITCH_CASE + +#define FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE(N) \ + case Base + (N): { \ + constexpr size_t matched_case = static_cast(Base + (N)); \ + if constexpr (matched_case < total_cases && (matched_case & 1U) == 0) { \ + read_compatible_exact_case_at(obj, ctx, offset); \ + } else { \ + ctx.set_error(Error::type_error("Invalid compatible matched id")); \ + } \ + return; \ + } + +#define FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(O) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 0) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 1) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 2) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 3) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 4) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 5) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 6) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 7) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 8) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 9) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 10) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 11) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 12) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 13) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 14) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE((O) + 15) + +#define FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_128() \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(0) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(16) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(32) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(48) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(64) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(80) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(96) \ + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16(112) + +template +FORY_ALWAYS_INLINE void +dispatch_compat_exact_read_at_impl(T &obj, ReadContext &ctx, int16_t matched_id, + uint32_t &offset) { + constexpr size_t total_cases = + CompileTimeFieldHelpers::FieldCount * static_cast(2); + switch (matched_id) { + FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_128() + default: + if constexpr (static_cast(Base) + 128U < total_cases) { + dispatch_compat_exact_read_at_impl(obj, ctx, matched_id, + offset); + } else { + ctx.set_error(Error::type_error("Invalid compatible matched id")); + } + return; + } +} + +#undef FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_128 +#undef FORY_COMPAT_EXACT_READ_AT_SWITCH_CASES_16 +#undef FORY_COMPAT_EXACT_READ_AT_SWITCH_CASE + +#define FORY_COMPAT_CONV_READ_AT_SWITCH_CASE(N) \ + case Base + (N): { \ + constexpr size_t matched_case = static_cast(Base + (N)); \ + if constexpr (matched_case < total_cases && (matched_case & 1U) != 0) { \ + ctx.buffer().reader_index(offset); \ + read_compatible_conversion_case(obj, ctx, \ + remote_field_type); \ + offset = ctx.buffer().reader_index(); \ + } else { \ + ctx.set_error(Error::type_error("Invalid compatible matched id")); \ + } \ + return; \ + } + +#define FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(O) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 0) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 1) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 2) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 3) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 4) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 5) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 6) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 7) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 8) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 9) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 10) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 11) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 12) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 13) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 14) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASE((O) + 15) + +#define FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_128() \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(0) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(16) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(32) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(48) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(64) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(80) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(96) \ + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16(112) + +template +FORY_NOINLINE void dispatch_compat_conversion_read_at_impl( + T &obj, ReadContext &ctx, int16_t matched_id, + const FieldType &remote_field_type, uint32_t &offset) { + constexpr size_t total_cases = + CompileTimeFieldHelpers::FieldCount * static_cast(2); + switch (matched_id) { + FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_128() + default: + if constexpr (static_cast(Base) + 128U < total_cases) { + dispatch_compat_conversion_read_at_impl( + obj, ctx, matched_id, remote_field_type, offset); + } else { + ctx.set_error(Error::type_error("Invalid compatible matched id")); + } + return; + } +} + +#undef FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_128 +#undef FORY_COMPAT_CONV_READ_AT_SWITCH_CASES_16 +#undef FORY_COMPAT_CONV_READ_AT_SWITCH_CASE + /// Helper to read a single field at compile-time sorted position template void read_field_at_sorted_position(T &obj, ReadContext &ctx) { @@ -3611,6 +3916,163 @@ FORY_ALWAYS_INLINE T read_varint_at(Buffer &buffer, uint32_t &offset) { } } +template +FORY_ALWAYS_INLINE T read_varint_at_checked(Buffer &buffer, uint32_t &offset, + Error &error) { + uint32_t bytes_read = 0; + if constexpr (std::is_same_v || std::is_same_v) { + uint32_t raw = buffer.get_var_uint32(offset, &bytes_read); + if (FORY_PREDICT_FALSE(bytes_read == 0)) { + error.set_buffer_out_of_bound(offset, 1, buffer.size()); + return T{}; + } + offset += bytes_read; + return static_cast((raw >> 1) ^ (~(raw & 1) + 1)); + } else if constexpr (std::is_same_v || + std::is_same_v) { + uint64_t raw = buffer.get_var_uint64(offset, &bytes_read); + if (FORY_PREDICT_FALSE(bytes_read == 0)) { + error.set_buffer_out_of_bound(offset, 1, buffer.size()); + return T{}; + } + offset += bytes_read; + return static_cast((raw >> 1) ^ (~(raw & 1) + 1)); + } else if constexpr (std::is_same_v || + std::is_same_v) { + uint32_t raw = buffer.get_var_uint32(offset, &bytes_read); + if (FORY_PREDICT_FALSE(bytes_read == 0)) { + error.set_buffer_out_of_bound(offset, 1, buffer.size()); + return T{}; + } + offset += bytes_read; + return static_cast(raw); + } else if constexpr (std::is_same_v || + std::is_same_v) { + uint64_t raw = buffer.get_var_uint64(offset, &bytes_read); + if (FORY_PREDICT_FALSE(bytes_read == 0)) { + error.set_buffer_out_of_bound(offset, 1, buffer.size()); + return T{}; + } + offset += bytes_read; + return static_cast(raw); + } else { + static_assert(sizeof(T) == 0, "Unsupported checked varint type"); + return T{}; + } +} + +template +FORY_ALWAYS_INLINE bool ensure_offset_readable(Buffer &buffer, uint32_t offset, + Error &error) { + static_assert(Bytes > 0, "Bytes must be positive"); + if (FORY_PREDICT_FALSE(offset > buffer.size() || + buffer.size() - offset < Bytes)) { + error.set_buffer_out_of_bound(offset, Bytes, buffer.size()); + return false; + } + return true; +} + +template +FORY_ALWAYS_INLINE T read_fixed_primitive_at_checked(Buffer &buffer, + uint32_t &offset, + Error &error) { + if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<1>(buffer, offset, error))) { + return false; + } + bool value = buffer.unsafe_get(offset) != 0; + offset += 1; + return value; + } else if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<1>(buffer, offset, error))) { + return T{}; + } + T value = static_cast(buffer.unsafe_get(offset)); + offset += 1; + return value; + } else if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<1>(buffer, offset, error))) { + return T{}; + } + T value = buffer.unsafe_get(offset); + offset += 1; + return value; + } else if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<2>(buffer, offset, error))) { + return T{}; + } + T value = buffer.unsafe_get(offset); + offset += 2; + return value; + } else if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<2>(buffer, offset, error))) { + return T{}; + } + T value = buffer.unsafe_get(offset); + offset += 2; + return value; + } else if constexpr (std::is_same_v || + std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<4>(buffer, offset, error))) { + return T{}; + } + T value = static_cast(buffer.unsafe_get(offset)); + offset += 4; + return value; + } else if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<4>(buffer, offset, error))) { + return T{}; + } + T value = buffer.unsafe_get(offset); + offset += 4; + return value; + } else if constexpr (std::is_same_v || + std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<8>(buffer, offset, error))) { + return T{}; + } + T value = static_cast(buffer.unsafe_get(offset)); + offset += 8; + return value; + } else if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<8>(buffer, offset, error))) { + return T{}; + } + T value = buffer.unsafe_get(offset); + offset += 8; + return value; + } else if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<2>(buffer, offset, error))) { + return T{}; + } + T value = float16_t::from_bits(buffer.unsafe_get(offset)); + offset += 2; + return value; + } else if constexpr (std::is_same_v) { + if (FORY_PREDICT_FALSE(!ensure_offset_readable<2>(buffer, offset, error))) { + return T{}; + } + T value = bfloat16_t::from_bits(buffer.unsafe_get(offset)); + offset += 2; + return value; + } else { + static_assert(sizeof(T) == 0, "Unsupported fixed primitive type"); + return T{}; + } +} + +template +FORY_ALWAYS_INLINE T read_primitive_at_checked(Buffer &buffer, uint32_t &offset, + Error &error) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return read_varint_at_checked(buffer, offset, error); + } else { + return read_fixed_primitive_at_checked(buffer, offset, error); + } +} + /// Helper to read a single varint primitive field. /// No lambda overhead - direct function call that will be inlined. /// Handles both standard varint and tagged encoding based on field config. @@ -3783,49 +4245,130 @@ read_struct_fields_impl_fast(T &obj, ReadContext &ctx, /// Read struct fields with schema evolution (compatible mode) /// Reads fields in remote schema order, dispatching by field_id to local fields template -void read_struct_fields_compatible(T &obj, ReadContext &ctx, - const TypeMeta *remote_type_meta, - std::index_sequence) { +FORY_NOINLINE void +read_struct_fields_compatible(T &obj, ReadContext &ctx, + const TypeMeta *remote_type_meta, + std::index_sequence) { const auto &remote_fields = remote_type_meta->get_field_infos(); + Buffer &buffer = ctx.buffer(); + const bool use_exact_offset_reads = !buffer.has_input_stream(); + uint32_t offset = buffer.reader_index(); + constexpr size_t total_cases = + CompileTimeFieldHelpers::FieldCount * static_cast(2); +#define FORY_COMPAT_LOOP_SWITCH_CASE(N) \ + case (N): { \ + constexpr size_t matched_case = static_cast(N); \ + if constexpr (matched_case < total_cases) { \ + if constexpr ((matched_case & 1U) == 0) { \ + if (use_exact_offset_reads) { \ + constexpr size_t local_sorted_id = matched_case / 2; \ + if constexpr (can_read_exact_primitive_with_offset< \ + T, local_sorted_id>()) { \ + read_exact_primitive_run( \ + obj, ctx, remote_fields, remote_idx, offset); \ + } else { \ + read_compatible_exact_case_at(obj, ctx, offset); \ + } \ + } else { \ + read_compatible_exact_case(obj, ctx); \ + } \ + } else { \ + if (use_exact_offset_reads) { \ + buffer.reader_index(offset); \ + } \ + read_compatible_conversion_case( \ + obj, ctx, remote_field.field_type); \ + if (use_exact_offset_reads) { \ + offset = buffer.reader_index(); \ + } \ + } \ + } else { \ + ctx.set_error(Error::type_error("Invalid compatible matched id")); \ + } \ + break; \ + } +#define FORY_COMPAT_LOOP_SWITCH_CASES_16(O) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 0) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 1) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 2) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 3) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 4) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 5) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 6) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 7) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 8) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 9) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 10) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 11) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 12) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 13) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 14) \ + FORY_COMPAT_LOOP_SWITCH_CASE((O) + 15) +#define FORY_COMPAT_LOOP_SWITCH_CASES_128() \ + FORY_COMPAT_LOOP_SWITCH_CASES_16(0) \ + FORY_COMPAT_LOOP_SWITCH_CASES_16(16) \ + FORY_COMPAT_LOOP_SWITCH_CASES_16(32) \ + FORY_COMPAT_LOOP_SWITCH_CASES_16(48) \ + FORY_COMPAT_LOOP_SWITCH_CASES_16(64) \ + FORY_COMPAT_LOOP_SWITCH_CASES_16(80) \ + FORY_COMPAT_LOOP_SWITCH_CASES_16(96) \ + FORY_COMPAT_LOOP_SWITCH_CASES_16(112) // Iterate through remote fields in their serialization order for (size_t remote_idx = 0; remote_idx < remote_fields.size(); ++remote_idx) { const auto &remote_field = remote_fields[remote_idx]; int16_t field_id = remote_field.field_id; - // Use the precomputed ref_mode from remote field metadata. - // This is computed from nullable and track_ref flags in the remote - // field's header during FieldInfo::from_bytes. - RefMode remote_ref_mode = remote_field.field_type.ref_mode; if (field_id == -1) { + if (use_exact_offset_reads) { + buffer.reader_index(offset); + } // Field unknown locally — skip its value + RefMode remote_ref_mode = remote_field.field_type.ref_mode; skip_field_value(ctx, remote_field.field_type, remote_ref_mode); if (FORY_PREDICT_FALSE(ctx.has_error())) { return; } - continue; - } - - // Dispatch to the correct local field by field_id - // Uses fold expression with short-circuit - no lambda overhead - // Pass remote field type for correct encoding and ref metadata. - bool handled = false; - dispatch_compatible_field_read_impl(obj, ctx, field_id, - remote_field.field_type, handled, - std::index_sequence{}); - - if (!handled) { - // Shouldn't happen if TypeMeta::assign_field_ids worked correctly - skip_field_value(ctx, remote_field.field_type, remote_ref_mode); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return; + if (use_exact_offset_reads) { + offset = buffer.reader_index(); } continue; } - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return; + switch (field_id) { + FORY_COMPAT_LOOP_SWITCH_CASES_128() + default: + if constexpr (128U < total_cases) { + if ((field_id & 1) == 0) { + if (use_exact_offset_reads) { + dispatch_compat_exact_read_at_impl(obj, ctx, field_id, + offset); + } else { + dispatch_compat_exact_read_impl(obj, ctx, field_id); + } + } else { + if (use_exact_offset_reads) { + dispatch_compat_conversion_read_at_impl( + obj, ctx, field_id, remote_field.field_type, offset); + } else { + dispatch_compat_conversion_read_impl( + obj, ctx, field_id, remote_field.field_type); + } + } + } else { + ctx.set_error(Error::type_error("Invalid compatible matched id")); + } + break; } } +#undef FORY_COMPAT_LOOP_SWITCH_CASES_128 +#undef FORY_COMPAT_LOOP_SWITCH_CASES_16 +#undef FORY_COMPAT_LOOP_SWITCH_CASE + if (use_exact_offset_reads) { + buffer.reader_index(offset); + } + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } } } // namespace detail diff --git a/cpp/fory/serialization/struct_test.cc b/cpp/fory/serialization/struct_test.cc index e32ab4e7cd..970cf7154b 100644 --- a/cpp/fory/serialization/struct_test.cc +++ b/cpp/fory/serialization/struct_test.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -133,6 +134,18 @@ struct MixedFieldIdentityStruct { alpha_value, (count, fory::F(2).varint())); }; +struct SignedToUnsignedWriter { + int32_t value; + + FORY_STRUCT(SignedToUnsignedWriter, (value, fory::F(1))); +}; + +struct UnsignedTargetReader { + uint32_t value = 0; + + FORY_STRUCT(UnsignedTargetReader, (value, fory::F(1))); +}; + class PrivateFieldsStruct { public: PrivateFieldsStruct() = default; @@ -969,26 +982,46 @@ TEST(StructComprehensiveTest, NonPrimitiveFieldsSortByFieldIdentifier) { EXPECT_EQ(fields[3].field_name, "custom_value"); } -TEST(StructComprehensiveTest, - FieldTypeCompatibleFingerprintNormalizesEncoding) { +TEST(StructComprehensiveTest, FieldTypeCompatibilitySeparatesAdapters) { FieldType fixed_i32 = make_test_field_type(TypeId::INT32); FieldType var_i32 = make_test_field_type(TypeId::VARINT32); - EXPECT_TRUE(field_types_compatible(fixed_i32, var_i32)); + EXPECT_FALSE(field_types_compatible(fixed_i32, var_i32)); + EXPECT_TRUE(field_types_compatible_top_level(fixed_i32, var_i32)); EXPECT_EQ(fixed_i32.compatible_fingerprint, var_i32.compatible_fingerprint); FieldType fixed_list = make_test_field_type(TypeId::LIST, {make_test_field_type(TypeId::INT32)}); FieldType var_list = make_test_field_type( TypeId::LIST, {make_test_field_type(TypeId::VARINT32)}); - EXPECT_TRUE(field_types_compatible(fixed_list, var_list)); + EXPECT_FALSE(field_types_compatible(fixed_list, var_list)); FieldType int64_list = make_test_field_type( TypeId::LIST, {make_test_field_type(TypeId::VARINT64)}); EXPECT_FALSE(field_types_compatible(fixed_list, int64_list)); - EXPECT_TRUE( + FieldType nullable_i32(static_cast(TypeId::INT32), true); + FieldType nullable_list = make_test_field_type(TypeId::LIST, {nullable_i32}); + EXPECT_FALSE(field_types_compatible(fixed_list, nullable_list)); + + EXPECT_FALSE( field_types_compatible(make_test_field_type(TypeId::BINARY), make_test_field_type(TypeId::UINT8_ARRAY))); + EXPECT_TRUE(field_types_compatible_top_level( + make_test_field_type(TypeId::BINARY), + make_test_field_type(TypeId::UINT8_ARRAY))); + + FieldType int32_array = make_test_field_type(TypeId::INT32_ARRAY); + EXPECT_TRUE(field_types_compatible_top_level(fixed_list, int32_array)); + EXPECT_TRUE(field_types_compatible_top_level(nullable_list, int32_array)); + EXPECT_TRUE(field_types_compatible_top_level(int32_array, nullable_list)); + FieldType tracked_i32(static_cast(TypeId::INT32), false, true); + FieldType tracked_list = make_test_field_type(TypeId::LIST, {tracked_i32}); + EXPECT_TRUE(field_types_compatible_top_level(tracked_list, int32_array)); + EXPECT_FALSE(field_types_compatible_top_level(int32_array, tracked_list)); + FieldType nullable_int32_array(static_cast(TypeId::INT32_ARRAY), + true); + EXPECT_FALSE( + field_types_compatible_top_level(fixed_list, nullable_int32_array)); } TEST(StructComprehensiveTest, @@ -1004,15 +1037,26 @@ TEST(StructComprehensiveTest, make_test_field_type(TypeId::MAP, {make_test_field_type(TypeId::VAR_UINT32), make_test_field_type(TypeId::VAR_UINT32)}))}; - TypeMeta::assign_field_ids(&local_type, incompatible_remote); - EXPECT_EQ(incompatible_remote[0].field_id, -1); + auto incompatible_result = + TypeMeta::assign_field_ids(&local_type, incompatible_remote); + EXPECT_FALSE(incompatible_result.ok()); - std::vector compatible_remote = {make_test_field_info( + std::vector nested_scalar_remote = {make_test_field_info( "items", 7, make_test_field_type(TypeId::LIST, {make_test_field_type(TypeId::UINT32)}))}; - TypeMeta::assign_field_ids(&local_type, compatible_remote); - EXPECT_EQ(compatible_remote[0].field_id, 0); + auto nested_scalar_result = + TypeMeta::assign_field_ids(&local_type, nested_scalar_remote); + EXPECT_FALSE(nested_scalar_result.ok()); + + TypeMeta scalar_local; + scalar_local.field_infos = {make_test_field_info( + "count", 8, make_test_field_type(TypeId::VAR_UINT32))}; + std::vector scalar_remote = { + make_test_field_info("count", 8, make_test_field_type(TypeId::UINT32))}; + auto scalar_result = TypeMeta::assign_field_ids(&scalar_local, scalar_remote); + ASSERT_TRUE(scalar_result.ok()); + EXPECT_EQ(scalar_remote[0].field_id, 1); TypeMeta name_mode_local; name_mode_local.field_infos = {make_test_field_info( @@ -1023,14 +1067,15 @@ TEST(StructComprehensiveTest, "items", 7, make_test_field_type(TypeId::LIST, {make_test_field_type(TypeId::UINT32)}))}; - TypeMeta::assign_field_ids(&name_mode_local, mixed_mode_remote); + ASSERT_TRUE( + TypeMeta::assign_field_ids(&name_mode_local, mixed_mode_remote).ok()); EXPECT_EQ(mixed_mode_remote[0].field_id, -1); std::vector name_remote = {make_test_field_info( "items", -1, make_test_field_type(TypeId::LIST, {make_test_field_type(TypeId::UINT32)}))}; - TypeMeta::assign_field_ids(&local_type, name_remote); + ASSERT_TRUE(TypeMeta::assign_field_ids(&local_type, name_remote).ok()); EXPECT_EQ(name_remote[0].field_id, -1); TypeMeta mixed_local; @@ -1042,17 +1087,79 @@ TEST(StructComprehensiveTest, make_test_field_info("alpha", -1, make_test_field_type(TypeId::BINARY)), make_test_field_info("tagged", 3, make_test_field_type(TypeId::STRING)), make_test_field_info("beta", -1, make_test_field_type(TypeId::VARINT32))}; - TypeMeta::assign_field_ids(&mixed_local, mixed_remote); - EXPECT_EQ(mixed_remote[0].field_id, 1); + ASSERT_TRUE(TypeMeta::assign_field_ids(&mixed_local, mixed_remote).ok()); + EXPECT_EQ(mixed_remote[0].field_id, 2); EXPECT_EQ(mixed_remote[1].field_id, 0); - EXPECT_EQ(mixed_remote[2].field_id, 2); + EXPECT_EQ(mixed_remote[2].field_id, 4); std::vector untagged_remote_for_tagged_local = { make_test_field_info("tagged", -1, make_test_field_type(TypeId::STRING))}; - TypeMeta::assign_field_ids(&mixed_local, untagged_remote_for_tagged_local); + ASSERT_TRUE( + TypeMeta::assign_field_ids(&mixed_local, untagged_remote_for_tagged_local) + .ok()); EXPECT_EQ(untagged_remote_for_tagged_local[0].field_id, -1); } +TEST(StructComprehensiveTest, CompatibleSignedToUnsignedStructRead) { + auto writer = + Fory::builder().xlang(true).compatible(true).track_ref(false).build(); + auto reader = + Fory::builder().xlang(true).compatible(true).track_ref(false).build(); + ASSERT_TRUE(writer.register_struct(608).ok()); + ASSERT_TRUE(reader.register_struct(608).ok()); + + auto bytes_result = writer.serialize(SignedToUnsignedWriter{123}); + ASSERT_TRUE(bytes_result.ok()) + << "Serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = std::move(bytes_result).value(); + auto result = + reader.deserialize(bytes.data(), bytes.size()); + ASSERT_TRUE(result.ok()) << "Deserialization failed: " + << result.error().to_string(); + EXPECT_EQ(result.value().value, 123u); +} + +TEST(StructComprehensiveTest, CompatibleNegativeSignedToUnsignedFails) { + auto writer = + Fory::builder().xlang(true).compatible(true).track_ref(false).build(); + auto reader = + Fory::builder().xlang(true).compatible(true).track_ref(false).build(); + ASSERT_TRUE(writer.register_struct(609).ok()); + ASSERT_TRUE(reader.register_struct(609).ok()); + + auto bytes_result = writer.serialize(SignedToUnsignedWriter{-1}); + ASSERT_TRUE(bytes_result.ok()) + << "Serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = std::move(bytes_result).value(); + auto result = + reader.deserialize(bytes.data(), bytes.size()); + EXPECT_FALSE(result.ok()); +} + +TEST(StructComprehensiveTest, AssignFieldIdsRejectsMatchedIdOverflow) { + constexpr size_t max_compatible_matched_field_index = + (static_cast(std::numeric_limits::max()) - 1) / 2; + const FieldType field_type = make_test_field_type(TypeId::INT32); + + TypeMeta local_type; + local_type.field_infos.reserve(max_compatible_matched_field_index + 2); + for (size_t index = 0; index <= max_compatible_matched_field_index + 1; + ++index) { + local_type.field_infos.push_back( + make_test_field_info("field_" + std::to_string(index), -1, field_type)); + } + + std::vector remote_fields = {make_test_field_info( + "field_" + std::to_string(max_compatible_matched_field_index + 1), -1, + field_type)}; + auto result = TypeMeta::assign_field_ids(&local_type, remote_fields); + + ASSERT_FALSE(result.ok()); + EXPECT_NE(result.error().message().find("exceeds max"), std::string::npos); +} + TEST(StructComprehensiveTest, OptionalFieldsAllEmpty) { test_roundtrip( OptionalFieldsStruct{"John", std::nullopt, std::nullopt, std::nullopt}); diff --git a/cpp/fory/serialization/type_resolver.cc b/cpp/fory/serialization/type_resolver.cc index 2cabba4e13..92971affd8 100644 --- a/cpp/fory/serialization/type_resolver.cc +++ b/cpp/fory/serialization/type_resolver.cc @@ -697,7 +697,7 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { // Assign field IDs by comparing with local type if (local_type_info != nullptr) { - assign_field_ids(local_type_info, field_infos); + FORY_RETURN_IF_ERROR(assign_field_ids(local_type_info, field_infos)); } size_t current_pos = buffer.reader_index(); @@ -987,20 +987,71 @@ bool name_sorter(const FieldInfo &a, const FieldInfo &b) { return compare_field_sort_key(a, b) < 0; } -bool allows_empty_generic_fallback(uint32_t type_id) { - return type_id == static_cast(TypeId::LIST) || - type_id == static_cast(TypeId::SET) || - type_id == static_cast(TypeId::MAP); +uint32_t exact_schema_type_id(uint32_t type_id) { + switch (static_cast(type_id)) { + case TypeId::STRUCT: + case TypeId::COMPATIBLE_STRUCT: + case TypeId::NAMED_STRUCT: + case TypeId::NAMED_COMPATIBLE_STRUCT: + case TypeId::UNKNOWN: + return static_cast(TypeId::STRUCT); + case TypeId::ENUM: + case TypeId::NAMED_ENUM: + return static_cast(TypeId::ENUM); + case TypeId::EXT: + case TypeId::NAMED_EXT: + return static_cast(TypeId::EXT); + case TypeId::UNION: + case TypeId::TYPED_UNION: + case TypeId::NAMED_UNION: + return static_cast(TypeId::UNION); + default: + return type_id; + } } -bool union_type_ids_compatible(uint32_t local_type_id, - uint32_t remote_type_id) { - auto is_union = [](uint32_t type_id) { - TypeId tid = static_cast(type_id); - return tid == TypeId::UNION || tid == TypeId::TYPED_UNION || - tid == TypeId::NAMED_UNION; - }; - return is_union(local_type_id) && is_union(remote_type_id); +bool user_type_ids_compatible(const FieldType &local, const FieldType &remote) { + return local.user_type_id == remote.user_type_id || + local.type_id == static_cast(TypeId::UNKNOWN) || + remote.type_id == static_cast(TypeId::UNKNOWN); +} + +bool byte_sequence_field_types_compatible(const FieldType &local, + const FieldType &remote) { + if (local.track_ref || remote.track_ref || + local.nullable != remote.nullable) { + return false; + } + return (local.type_id == static_cast(TypeId::BINARY) && + remote.type_id == static_cast(TypeId::UINT8_ARRAY)) || + (local.type_id == static_cast(TypeId::UINT8_ARRAY) && + remote.type_id == static_cast(TypeId::BINARY)); +} + +bool scalar_field_type_id(uint32_t type_id) { + return compatible_scalar_field_types(type_id, type_id); +} + +bool compatible_payload_field_types(const FieldType &local, + const FieldType &remote) { + if (scalar_field_type_id(local.type_id) || + scalar_field_type_id(remote.type_id)) { + return !local.track_ref && !remote.track_ref && + local.type_id == remote.type_id; + } + if (exact_schema_type_id(local.type_id) != + exact_schema_type_id(remote.type_id) || + !user_type_ids_compatible(local, remote) || + local.generics.size() != remote.generics.size()) { + return false; + } + for (size_t i = 0; i < local.generics.size(); ++i) { + if (!compatible_payload_field_types(local.generics[i], + remote.generics[i])) { + return false; + } + } + return true; } bool direct_field_types_compatible(const FieldType &local, @@ -1028,14 +1079,20 @@ bool direct_field_types_compatible(const FieldType &local, if (local.type_id == static_cast(TypeId::LIST) && remote.generics.size() == 0 && primitive_array_element_type_id(remote.type_id, array_element_type_id) && - local.generics.size() == 1) { + local.generics.size() == 1 && !local.nullable && !local.track_ref && + !remote.nullable && !remote.track_ref) { return compatible_fingerprint_type_id(local.generics[0].type_id) == compatible_fingerprint_type_id(array_element_type_id); } if (remote.type_id == static_cast(TypeId::LIST) && local.generics.size() == 0 && primitive_array_element_type_id(local.type_id, array_element_type_id) && - remote.generics.size() == 1) { + remote.generics.size() == 1 && !local.nullable && !local.track_ref && + !remote.nullable && !remote.track_ref && !remote.generics[0].track_ref) { + // Nullable element schema is compatible with dense arrays when the payload + // has no nulls; actual null elements are rejected by the array reader. + // Ref-tracked element framing stays rejected because this path is + // primitive-only. return compatible_fingerprint_type_id(remote.generics[0].type_id) == compatible_fingerprint_type_id(array_element_type_id); } @@ -1082,15 +1139,22 @@ uint32_t compatible_fingerprint_type_id(uint32_t type_id) { } bool field_types_compatible(const FieldType &local, const FieldType &remote) { - if (local.compatible_fingerprint == remote.compatible_fingerprint) { - return true; + if (exact_schema_type_id(local.type_id) != + exact_schema_type_id(remote.type_id) || + !user_type_ids_compatible(local, remote) || + local.nullable != remote.nullable || + local.track_ref != remote.track_ref) { + return false; } - if (union_type_ids_compatible(local.type_id, remote.type_id)) { - return true; + if (local.generics.size() != remote.generics.size()) { + return false; } - return local.type_id == remote.type_id && - allows_empty_generic_fallback(local.type_id) && - (local.generics.empty() || remote.generics.empty()); + for (size_t i = 0; i < local.generics.size(); ++i) { + if (!field_types_compatible(local.generics[i], remote.generics[i])) { + return false; + } + } + return true; } bool primitive_array_element_type_id(uint32_t array_type_id, @@ -1143,8 +1207,10 @@ bool primitive_array_element_type_id(uint32_t array_type_id, bool field_types_compatible_top_level(const FieldType &local, const FieldType &remote) { return direct_field_types_compatible(local, remote) || + byte_sequence_field_types_compatible(local, remote) || (!local.track_ref && !remote.track_ref && - compatible_scalar_field_types(local.type_id, remote.type_id)); + compatible_scalar_field_types(local.type_id, remote.type_id)) || + compatible_payload_field_types(local, remote); } std::vector @@ -1196,9 +1262,12 @@ TypeMeta::sort_field_infos(std::vector fields) { // Field ID Assignment (KEY FUNCTION for schema evolution!) // ============================================================================ -void TypeMeta::assign_field_ids(const TypeMeta *local_type, - std::vector &remote_fields) { +Result +TypeMeta::assign_field_ids(const TypeMeta *local_type, + std::vector &remote_fields) { const auto &local_fields = local_type->field_infos; + constexpr size_t max_compatible_matched_field_index = + (static_cast(std::numeric_limits::max()) - 1) / 2; // Primary mapping: field name -> sorted index in local schema std::unordered_map local_field_index_map; @@ -1219,17 +1288,53 @@ void TypeMeta::assign_field_ids(const TypeMeta *local_type, // back to type-based matching. std::vector used(local_fields.size(), false); - // For each remote field, assign field ID (sorted index in local schema) + auto assign_matched_field = [&](FieldInfo &remote_field, + size_t local_index) -> Result { + if (used[local_index]) { + return false; + } + const FieldInfo &local_field = local_fields[local_index]; + if (field_types_compatible(local_field.field_type, + remote_field.field_type)) { + if (local_index > max_compatible_matched_field_index) { + return Unexpected(Error::type_error( + "Cannot assign compatible matched id for local field " + + local_field.field_name + ": local field index " + + std::to_string(local_index) + " exceeds max " + + std::to_string(max_compatible_matched_field_index))); + } + remote_field.field_id = static_cast(local_index * 2); + used[local_index] = true; + return true; + } + if (field_types_compatible_top_level(local_field.field_type, + remote_field.field_type)) { + if (local_index > max_compatible_matched_field_index) { + return Unexpected(Error::type_error( + "Cannot assign compatible matched id for local field " + + local_field.field_name + ": local field index " + + std::to_string(local_index) + " exceeds max " + + std::to_string(max_compatible_matched_field_index))); + } + remote_field.field_id = static_cast(local_index * 2 + 1); + used[local_index] = true; + return true; + } + return Unexpected(Error::type_error( + "Cannot read remote field " + remote_field.field_name + + " as local field " + local_field.field_name + + ": remote and local field schemas are not compatible")); + }; + + // For each remote field, assign doubled dispatch id in local schema. for (auto &remote_field : remote_fields) { - size_t local_index = static_cast(-1); + bool matched = false; if (remote_field.field_id >= 0) { auto id_it = local_field_id_map.find(remote_field.field_id); - if (id_it != local_field_id_map.end() && !used[id_it->second] && - field_types_compatible_top_level( - local_fields[id_it->second].field_type, - remote_field.field_type)) { - local_index = id_it->second; + if (id_it != local_field_id_map.end()) { + FORY_TRY(is_matched, assign_matched_field(remote_field, id_it->second)); + matched = is_matched; } } else { // 1) Try exact name + type match first (fast path for same-language @@ -1240,43 +1345,40 @@ void TypeMeta::assign_field_ids(const TypeMeta *local_type, if (it != local_field_index_map.end()) { size_t idx = it->second; const FieldInfo &local_field = local_fields[idx]; - if (local_field.field_id < 0 && - field_types_compatible_top_level(local_field.field_type, - remote_field.field_type)) { - local_index = idx; + if (local_field.field_id < 0) { + FORY_TRY(is_matched, assign_matched_field(remote_field, idx)); + matched = is_matched; } } - // 2) Fallback: match by type signature and position when field names - // are not available or differ across languages. Keep this fallback - // within name-based fields so a mixed schema does not switch into a - // global tag-ID mode or bind an untagged field to a tagged local - // field. - if (local_index == static_cast(-1)) { + // 2) Fallback by type signature only when no canonical remote name was + // carried. Named remote fields that miss the local name map are + // remote-only fields; matching them by type can bind an added string + // field such as `email` to an unrelated local string field. + if (!matched && remote_field.field_name.empty()) { for (size_t i = 0; i < local_fields.size(); ++i) { if (used[i] || local_fields[i].field_id >= 0) { continue; } - // Scalar conversion requires a tag or canonical-name match; the - // type-only fallback is only for schemas whose scalar identity - // already matches, otherwise unrelated fields can bind by value type. - if (direct_field_types_compatible(local_fields[i].field_type, - remote_field.field_type)) { - local_index = i; + // Compatible adapters require tag or canonical-name identity. The + // type-only fallback is only for anonymous legacy fields with exact + // schema shape, otherwise unrelated fields can bind by value type. + if (field_types_compatible(local_fields[i].field_type, + remote_field.field_type)) { + FORY_TRY(is_matched, assign_matched_field(remote_field, i)); + matched = is_matched; break; } } } } - if (local_index == static_cast(-1)) { + if (!matched) { // No suitable local field found -> mark as skipped. remote_field.field_id = -1; - } else { - remote_field.field_id = static_cast(local_index); - used[local_index] = true; } } + return Result(); } int64_t TypeMeta::compute_hash(const std::vector &meta_bytes) { diff --git a/cpp/fory/serialization/type_resolver.h b/cpp/fory/serialization/type_resolver.h index 1f0c4a11e7..e51d535a05 100644 --- a/cpp/fory/serialization/type_resolver.h +++ b/cpp/fory/serialization/type_resolver.h @@ -274,8 +274,9 @@ class TypeMeta { /// Assign field IDs by comparing with local type /// This is the key function for schema evolution! - static void assign_field_ids(const TypeMeta *local_type, - std::vector &remote_fields); + static Result + assign_field_ids(const TypeMeta *local_type, + std::vector &remote_fields); const std::vector &get_field_infos() const { return field_infos; } int64_t get_hash() const { return hash; } diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 6319993b40..2e6ad7d4b7 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -27,6 +27,8 @@ namespace Apache.Fory.Generator; [Generator(LanguageNames.CSharp)] public sealed class ForyModelGenerator : IIncrementalGenerator { + private const uint UInt8ArrayTypeId = 48; + private static readonly SymbolDisplayFormat FullNameFormat = SymbolDisplayFormat.FullyQualifiedFormat.WithMiscellaneousOptions( SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier); @@ -240,58 +242,6 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) EmitCompatibleFieldCodecMethods(sb, model); - sb.AppendLine(" private static object __ForyReadCompatiblePrimitivePayload(global::Apache.Fory.TypeId typeId, global::Apache.Fory.ReadContext context)"); - sb.AppendLine(" {"); - sb.AppendLine(" return typeId switch"); - sb.AppendLine(" {"); - sb.AppendLine(" global::Apache.Fory.TypeId.Bool => context.Reader.ReadUInt8() != 0,"); - sb.AppendLine(" global::Apache.Fory.TypeId.Int8 => context.Reader.ReadInt8(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.Int16 => context.Reader.ReadInt16(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.Int32 => context.Reader.ReadInt32(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.VarInt32 => context.Reader.ReadVarInt32(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.Int64 => context.Reader.ReadInt64(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.VarInt64 => context.Reader.ReadVarInt64(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.TaggedInt64 => context.Reader.ReadTaggedInt64(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.UInt8 => context.Reader.ReadUInt8(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.UInt16 => context.Reader.ReadUInt16(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.UInt32 => context.Reader.ReadUInt32(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.VarUInt32 => context.Reader.ReadVarUInt32(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.UInt64 => context.Reader.ReadUInt64(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.VarUInt64 => context.Reader.ReadVarUInt64(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.TaggedUInt64 => context.Reader.ReadTaggedUInt64(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.Float16 => global::System.BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16()),"); - sb.AppendLine(" global::Apache.Fory.TypeId.BFloat16 => global::Apache.Fory.BFloat16.FromBits(context.Reader.ReadUInt16()),"); - sb.AppendLine(" global::Apache.Fory.TypeId.Float32 => context.Reader.ReadFloat32(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.Float64 => context.Reader.ReadFloat64(),"); - sb.AppendLine(" global::Apache.Fory.TypeId.String => global::Apache.Fory.StringSerializer.ReadString(context),"); - sb.AppendLine(" _ => throw new global::Apache.Fory.InvalidDataException($\"unsupported compatible primitive type id {typeId}\"),"); - sb.AppendLine(" };"); - sb.AppendLine(" }"); - sb.AppendLine(); - sb.AppendLine(" private static T __ForyReadCompatibleField("); - sb.AppendLine(" global::Apache.Fory.ReadContext context,"); - sb.AppendLine(" global::Apache.Fory.TypeMetaFieldType fieldType,"); - sb.AppendLine(" global::Apache.Fory.TypeId localTypeId,"); - sb.AppendLine(" string fieldName,"); - sb.AppendLine(" global::Apache.Fory.RefMode refMode,"); - sb.AppendLine(" global::Apache.Fory.RefMode localRefMode,"); - sb.AppendLine(" bool readTypeInfo)"); - sb.AppendLine(" {"); - sb.AppendLine(" global::Apache.Fory.TypeId typeId = (global::Apache.Fory.TypeId)fieldType.TypeId;"); - sb.AppendLine(" bool scalarPair = global::Apache.Fory.CompatibleScalarConverter.IsScalarType(fieldType.TypeId) &&"); - sb.AppendLine(" global::Apache.Fory.CompatibleScalarConverter.IsScalarType((uint)localTypeId);"); - // Compatible scalar reads must use remote field metadata even for same CLR scalar types; - // C# serializers default to one wire form while IDL fields may use fixed, varint, or tagged forms. - sb.AppendLine(" bool compatibleScalarRead = fieldType.TypeId == (uint)localTypeId ||"); - sb.AppendLine(" global::Apache.Fory.CompatibleScalarConverter.RequiresScalarRead(fieldType.TypeId, (uint)localTypeId);"); - sb.AppendLine(" if (!readTypeInfo && refMode != global::Apache.Fory.RefMode.Tracking && localRefMode != global::Apache.Fory.RefMode.Tracking && scalarPair && compatibleScalarRead)"); - sb.AppendLine(" {"); - sb.AppendLine(" return global::Apache.Fory.CompatibleScalarConverter.ReadField(context, typeId, localTypeId, fieldName, refMode);"); - sb.AppendLine(" }"); - sb.AppendLine(); - sb.AppendLine(" return context.TypeResolver.GetSerializer().Read(context, refMode, readTypeInfo);"); - sb.AppendLine(" }"); - sb.AppendLine(); sb.AppendLine( " private static global::System.Collections.Generic.IReadOnlyList __ForyBuildTypeMetaFields(bool trackRef)"); sb.AppendLine(" {"); @@ -578,18 +528,44 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" for (int i = 0; i < typeMeta.Fields.Count; i++)"); sb.AppendLine(" {"); sb.AppendLine(" global::Apache.Fory.TypeMetaFieldInfo remoteField = typeMeta.Fields[i];"); - sb.AppendLine(" global::Apache.Fory.RefMode remoteRefMode = __ForyRefMode(remoteField.FieldType.Nullable, remoteField.FieldType.TrackRef);"); sb.AppendLine(" switch (remoteField.AssignedFieldId)"); sb.AppendLine(" {"); + sb.AppendLine(" case -1:"); + sb.AppendLine(" global::Apache.Fory.FieldSkipper.SkipFieldValue(context, remoteField.FieldType);"); + sb.AppendLine(" break;"); for (int idx = 0; idx < model.SortedMembers.Length; idx++) { MemberModel member = model.SortedMembers[idx]; - sb.AppendLine($" case {idx}:"); + sb.AppendLine($" case {idx * 2}:"); + sb.AppendLine(" {"); + EmitReadMemberAssignment( + sb, + member, + BuildWriteRefModeExpression(member), + BuildFieldTypeInfoLiteral(member), + "value", + "CompatDirect", + 7, + true); + sb.AppendLine(" break;"); + sb.AppendLine(" }"); + sb.AppendLine($" case {idx * 2 + 1}:"); sb.AppendLine(" {"); + string compatRefModeExpr; + if (CompatibleCaseNeedsRemoteRefMode(member)) + { + sb.AppendLine(" global::Apache.Fory.RefMode remoteRefMode = __ForyRefMode(remoteField.FieldType.Nullable, remoteField.FieldType.TrackRef);"); + compatRefModeExpr = "remoteRefMode"; + } + else + { + compatRefModeExpr = "default"; + } + EmitReadMemberAssignment( sb, member, - "remoteRefMode", + compatRefModeExpr, BuildFieldTypeInfoLiteral(member), "value", "Compat", @@ -600,8 +576,7 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) } sb.AppendLine(" default:"); - sb.AppendLine(" global::Apache.Fory.FieldSkipper.SkipFieldValue(context, remoteField.FieldType);"); - sb.AppendLine(" break;"); + sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException($\"invalid compatible matched id {remoteField.AssignedFieldId}\");"); sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(" return value;"); @@ -960,7 +935,7 @@ private static void EmitCompatibleFieldCodecMethods(StringBuilder sb, TypeModel foreach (MemberModel member in model.SortedMembers) { if (member.FieldCodec is not null && - TryBuildCompatibleListArrayReadCodec(member.FieldCodec, out _)) + CanReadCompatibleField(member.FieldCodec)) { hasCompatibleField = true; break; @@ -977,9 +952,9 @@ private static void EmitCompatibleFieldCodecMethods(StringBuilder sb, TypeModel foreach (MemberModel member in model.SortedMembers) { if (member.FieldCodec is not null && - TryBuildCompatibleListArrayReadCodec(member.FieldCodec, out FieldCodecModel? alternateCodec)) + CanReadCompatibleField(member.FieldCodec)) { - EmitCompatibleFieldCodecMethod(sb, member, member.FieldCodec, alternateCodec); + EmitCompatibleFieldCodecMethod(sb, member, member.FieldCodec); } } @@ -990,52 +965,115 @@ private static void EmitCompatibleFieldCodecMethods(StringBuilder sb, TypeModel private static void EmitCompatibleFieldCodecMethod( StringBuilder sb, MemberModel member, - FieldCodecModel codec, - FieldCodecModel alternateCodec) + FieldCodecModel codec) { string memberId = Sanitize(member.Name); sb.AppendLine(" [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.NoInlining)]"); sb.AppendLine( - $" internal static {member.TypeName} Read{memberId}ListArrayBridge(global::Apache.Fory.ReadContext context, global::Apache.Fory.TypeMetaFieldType remoteFieldType, global::Apache.Fory.RefMode refMode)"); + $" internal static {member.TypeName} Read{memberId}FieldBridge(global::Apache.Fory.ReadContext context, global::Apache.Fory.TypeMetaFieldType remoteFieldType, global::Apache.Fory.RefMode refMode)"); sb.AppendLine(" {"); - sb.AppendLine(" if (remoteFieldType.TypeId == " + alternateCodec.TypeId + ")"); + sb.AppendLine(" if (remoteFieldType.TypeId == " + codec.TypeId + ")"); sb.AppendLine(" {"); - if (codec.Kind == FieldCodecKind.PackedArray) + sb.AppendLine($" return __ForyRead{memberId}Field(context, refMode);"); + sb.AppendLine(" }"); + sb.AppendLine(); + if (TryBuildCompatibleListArrayReadCodec(codec, out FieldCodecModel? alternateCodec)) { - sb.AppendLine(" if (remoteFieldType.Generics.Count != 1)"); - sb.AppendLine(" {"); - sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException(\"compatible list to array field requires one element schema\");"); - sb.AppendLine(" }"); + sb.AppendLine(" if (remoteFieldType.TypeId == " + alternateCodec.TypeId + ")"); + sb.AppendLine(" {"); + if (codec.Kind == FieldCodecKind.PackedArray) + { + sb.AppendLine(" if (remoteFieldType.Generics.Count != 1)"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException(\"compatible list to array field requires one element schema\");"); + sb.AppendLine(" }"); + } + + EmitReadNullOnlyPrefix(sb, member, 4); + int id = 0; + string compatibleResultVar = $"__{memberId}CompatibleValue"; + if (codec.Kind == FieldCodecKind.PackedArray && alternateCodec.Kind == FieldCodecKind.List) + { + EmitReadCompatibleListArrayPayload(sb, codec, compatibleResultVar, 4, ref id); + } + else + { + EmitReadPayload(sb, alternateCodec, compatibleResultVar, 4, ref id); + } + + sb.AppendLine($" return {compatibleResultVar};"); + sb.AppendLine(" }"); } - sb.AppendLine(" if (refMode == global::Apache.Fory.RefMode.NullOnly)"); - sb.AppendLine(" {"); - sb.AppendLine(" sbyte refFlag = context.Reader.ReadInt8();"); - sb.AppendLine(" if (refFlag == (sbyte)global::Apache.Fory.RefFlag.Null)"); - sb.AppendLine(" {"); - sb.AppendLine($" return ({member.TypeName})default!;"); - sb.AppendLine(" }"); + if (CanReadCompatibleBinaryField(codec)) + { + sb.AppendLine(" if (remoteFieldType.TypeId == (uint)global::Apache.Fory.TypeId.Binary)"); + sb.AppendLine(" {"); + EmitReadNullOnlyPrefix(sb, member, 4); + EmitReadBinaryPayload(sb, codec, $"__{memberId}BinaryValue", 4); + sb.AppendLine($" return __{memberId}BinaryValue;"); + sb.AppendLine(" }"); + } + + sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException($\"unsupported compatible field schema pair: local " + codec.TypeId + ", remote {remoteFieldType.TypeId}\");"); + sb.AppendLine(" }"); + } + + private static void EmitReadNullOnlyPrefix(StringBuilder sb, MemberModel member, int indentLevel) + { + string indent = new(' ', indentLevel * 4); + sb.AppendLine($"{indent}if (refMode == global::Apache.Fory.RefMode.NullOnly)"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{indent} sbyte refFlag = context.Reader.ReadInt8();"); + sb.AppendLine($"{indent} if (refFlag == (sbyte)global::Apache.Fory.RefFlag.Null)"); + sb.AppendLine($"{indent} {{"); + sb.AppendLine($"{indent} return ({member.TypeName})default!;"); + sb.AppendLine($"{indent} }}"); sb.AppendLine(); - sb.AppendLine(" if (refFlag != (sbyte)global::Apache.Fory.RefFlag.NotNullValue)"); - sb.AppendLine(" {"); - sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException($\"invalid nullOnly ref flag {refFlag}\");"); - sb.AppendLine(" }"); - sb.AppendLine(" }"); - int id = 0; - string compatibleResultVar = $"__{memberId}CompatibleValue"; - if (codec.Kind == FieldCodecKind.PackedArray && alternateCodec.Kind == FieldCodecKind.List) + sb.AppendLine($"{indent} if (refFlag != (sbyte)global::Apache.Fory.RefFlag.NotNullValue)"); + sb.AppendLine($"{indent} {{"); + sb.AppendLine($"{indent} throw new global::Apache.Fory.InvalidDataException($\"invalid nullOnly ref flag {{refFlag}}\");"); + sb.AppendLine($"{indent} }}"); + sb.AppendLine($"{indent}}}"); + } + + private static void EmitReadBinaryPayload( + StringBuilder sb, + FieldCodecModel codec, + string targetVar, + int indentLevel) + { + string indent = new(' ', indentLevel * 4); + sb.AppendLine($"{indent}int __foryLength = checked((int)context.Reader.ReadVarUInt32());"); + if (codec.CarrierKind == CarrierKind.Array) { - EmitReadCompatibleListArrayPayload(sb, codec, compatibleResultVar, 4, ref id); + sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = context.Reader.ReadBytes(__foryLength);"); + return; } - else + + if (codec.CarrierKind == CarrierKind.List) { - EmitReadPayload(sb, alternateCodec, compatibleResultVar, 4, ref id); + sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new(__foryLength);"); + sb.AppendLine($"{indent}for (int __foryIndex = 0; __foryIndex < __foryLength; __foryIndex++)"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{indent} {targetVar}.Add(context.Reader.ReadUInt8());"); + sb.AppendLine($"{indent}}}"); + return; } - sb.AppendLine($" return {compatibleResultVar};"); - sb.AppendLine(" }"); - sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException($\"unsupported compatible field schema pair: local " + codec.TypeId + ", remote {remoteFieldType.TypeId}\");"); - sb.AppendLine(" }"); + throw new InvalidOperationException($"unsupported binary compatible carrier {codec.TypeName}"); + } + + private static bool CanReadCompatibleField(FieldCodecModel codec) + { + return TryBuildCompatibleListArrayReadCodec(codec, out _) || CanReadCompatibleBinaryField(codec); + } + + private static bool CanReadCompatibleBinaryField(FieldCodecModel codec) + { + return codec.Kind == FieldCodecKind.PackedArray && + codec.TypeId == UInt8ArrayTypeId && + codec.CarrierKind is CarrierKind.Array or CarrierKind.List; } private static bool TryBuildCompatibleListArrayReadCodec(FieldCodecModel codec, out FieldCodecModel compatibleCodec) @@ -1116,8 +1154,8 @@ private static void EmitReadCompatibleListArrayPayload( sb.AppendLine($"{innerIndent} }}"); sb.AppendLine($"{innerIndent}}}"); sb.AppendLine($"{indent}}}"); - string indexVar = $"__foryIndex{id++}"; string elementTypeName = codec.CarrierKind == CarrierKind.Array ? ElementTypeName(codec.TypeName) : PackedArrayElementTypeName(codec.TypeId); + uint elementTypeId = PackedArrayElementTypeId(codec.TypeId); if (codec.CarrierKind == CarrierKind.Array) { sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{lengthVar}];"); @@ -1127,21 +1165,49 @@ private static void EmitReadCompatibleListArrayPayload( sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({lengthVar});"); } - sb.AppendLine($"{indent}for (int {indexVar} = 0; {indexVar} < {lengthVar}; {indexVar}++)"); + string indexVar = $"__foryIndex{id++}"; + sb.AppendLine($"{indent}switch (remoteFieldType.Generics[0].TypeId)"); sb.AppendLine($"{indent}{{"); - sb.AppendLine($"{innerIndent}object __foryItem = __ForyReadCompatiblePrimitivePayload((global::Apache.Fory.TypeId)remoteFieldType.Generics[0].TypeId, context);"); - if (codec.CarrierKind == CarrierKind.Array) + foreach (uint remoteElementTypeId in CompatibleElementReadTypeIds(elementTypeId)) { - sb.AppendLine($"{innerIndent}{targetVar}[{indexVar}] = ({elementTypeName})__foryItem;"); - } - else - { - sb.AppendLine($"{innerIndent}{targetVar}.Add(({elementTypeName})__foryItem);"); - } + if (!TryBuildDirectPayloadRead(remoteElementTypeId, out string? itemReadExpr)) + { + throw new InvalidOperationException($"unsupported compatible list element type id {remoteElementTypeId}"); + } + sb.AppendLine($"{indent} case {remoteElementTypeId}:"); + sb.AppendLine($"{indent} for (int {indexVar} = 0; {indexVar} < {lengthVar}; {indexVar}++)"); + sb.AppendLine($"{indent} {{"); + sb.AppendLine($"{indent} {elementTypeName} __foryItem = {itemReadExpr};"); + if (codec.CarrierKind == CarrierKind.Array) + { + sb.AppendLine($"{indent} {targetVar}[{indexVar}] = __foryItem;"); + } + else + { + sb.AppendLine($"{indent} {targetVar}.Add(__foryItem);"); + } + + sb.AppendLine($"{indent} }}"); + sb.AppendLine($"{indent} break;"); + } + sb.AppendLine($"{indent} default:"); + sb.AppendLine($"{indent} throw new global::Apache.Fory.InvalidDataException($\"unsupported compatible list element type {{remoteFieldType.Generics[0].TypeId}}\");"); sb.AppendLine($"{indent}}}"); } + private static uint[] CompatibleElementReadTypeIds(uint elementTypeId) + { + return elementTypeId switch + { + 4 or 5 => [4, 5], + 6 or 7 or 8 => [6, 7, 8], + 11 or 12 => [11, 12], + 13 or 14 or 15 => [13, 14, 15], + _ => [elementTypeId], + }; + } + private static void EmitWritePayload( StringBuilder sb, FieldCodecModel codec, @@ -1927,21 +1993,20 @@ private static void EmitReadMemberAssignment( throw new InvalidOperationException($"unsupported dynamic any kind {member.DynamicAnyKind}"); } + if (variableSuffix == "Compat" && + TryBuildCompatibleScalarReadExpression(member, out string? compatibleScalarReadExpr)) + { + sb.AppendLine($"{indent}{assignmentTarget} = {compatibleScalarReadExpr};"); + return; + } + if (member.FieldCodec is not null) { if (variableSuffix == "Compat" && - TryBuildCompatibleListArrayReadCodec(member.FieldCodec, out _)) + CanReadCompatibleField(member.FieldCodec)) { - sb.AppendLine($"{indent}if (remoteField.FieldType.TypeId == {member.FieldCodec.TypeId})"); - sb.AppendLine($"{indent}{{"); sb.AppendLine( - $"{indent} {assignmentTarget} = __ForyRead{Sanitize(member.Name)}Field(context, {refModeExpr});"); - sb.AppendLine($"{indent}}}"); - sb.AppendLine($"{indent}else"); - sb.AppendLine($"{indent}{{"); - sb.AppendLine( - $"{indent} {assignmentTarget} = __ForyCompatibleFieldReaders.Read{Sanitize(member.Name)}ListArrayBridge(context, remoteField.FieldType, {refModeExpr});"); - sb.AppendLine($"{indent}}}"); + $"{indent}{assignmentTarget} = __ForyCompatibleFieldReaders.Read{Sanitize(member.Name)}FieldBridge(context, remoteField.FieldType, {refModeExpr});"); } else { @@ -1967,7 +2032,7 @@ private static void EmitReadMemberAssignment( if (variableSuffix == "Compat") { sb.AppendLine( - $"{indent}{assignmentTarget} = __ForyReadCompatibleField<{member.TypeName}>(context, remoteField.FieldType, (global::Apache.Fory.TypeId){member.TypeMeta.TypeIdExpr}, \"{EscapeString(member.FieldIdentifier)}\", {refModeExpr}, {BuildWriteRefModeExpression(member)}, {readTypeInfoExpr});"); + $"{indent}{assignmentTarget} = context.TypeResolver.GetSerializer<{member.TypeName}>().Read(context, {refModeExpr}, {readTypeInfoExpr});"); return; } @@ -1975,6 +2040,69 @@ private static void EmitReadMemberAssignment( $"{indent}{assignmentTarget} = context.TypeResolver.GetSerializer<{member.TypeName}>().Read(context, {refModeExpr}, {readTypeInfoExpr});"); } + private static bool CompatibleCaseNeedsRemoteRefMode(MemberModel member) + { + return !IsCompatibleScalarMember(member); + } + + private static bool IsCompatibleScalarMember(MemberModel member) + { + return TryResolveCompatibleScalarTarget(member, out _); + } + + private static bool TryBuildCompatibleScalarReadExpression(MemberModel member, out string? readExpr) + { + readExpr = null; + if (!TryResolveCompatibleScalarTarget(member, out string? methodTarget)) + { + return false; + } + + string methodName = member.IsNullable ? $"ReadNullable{methodTarget}Field" : $"Read{methodTarget}Field"; + readExpr = + $"global::Apache.Fory.CompatibleScalarConverter.{methodName}(context, remoteField)"; + return true; + } + + private static bool TryResolveCompatibleScalarTarget(MemberModel member, out string? methodTarget) + { + methodTarget = null; + if (member.DynamicAnyKind != DynamicAnyKind.None || + !IsCompatibleScalarTypeId(member.Classification.TypeId)) + { + return false; + } + + string targetName = StripNullableForTypeOf(member.TypeName); + methodTarget = targetName switch + { + "bool" or "global::System.Boolean" => "Bool", + "sbyte" or "global::System.SByte" => "SByte", + "short" or "global::System.Int16" => "Int16", + "int" or "global::System.Int32" => "Int32", + "long" or "global::System.Int64" => "Int64", + "byte" or "global::System.Byte" => "Byte", + "ushort" or "global::System.UInt16" => "UInt16", + "uint" or "global::System.UInt32" => "UInt32", + "ulong" or "global::System.UInt64" => "UInt64", + "global::System.Half" => "Half", + "global::Apache.Fory.BFloat16" => "BFloat16", + "float" or "global::System.Single" => "Float", + "double" or "global::System.Double" => "Double", + "string" or "global::System.String" => "String", + "decimal" or "global::System.Decimal" => "Decimal", + "global::Apache.Fory.ForyDecimal" => "ForyDecimal", + _ => null, + }; + + return methodTarget is not null; + } + + private static bool IsCompatibleScalarTypeId(uint typeId) + { + return typeId is >= 1 and <= 15 or >= 17 and <= 21 or 40; + } + private static string StripNullableForTypeOf(string typeName) { return typeName.Replace("?", string.Empty); diff --git a/csharp/src/Fory/CompatibleScalarConverter.cs b/csharp/src/Fory/CompatibleScalarConverter.cs index 7da3e2013c..54fc5891e8 100644 --- a/csharp/src/Fory/CompatibleScalarConverter.cs +++ b/csharp/src/Fory/CompatibleScalarConverter.cs @@ -21,6 +21,27 @@ namespace Apache.Fory; +internal readonly struct CompatibleScalarRead( + TypeId rawRemoteTypeId, + TypeId remoteTypeId, + TypeId localTypeId, + bool remoteNullable, + bool int64FastSource, + string fieldName) +{ + public TypeId RawRemoteTypeId { get; } = rawRemoteTypeId; + + public TypeId RemoteTypeId { get; } = remoteTypeId; + + public TypeId LocalTypeId { get; } = localTypeId; + + public bool RemoteNullable { get; } = remoteNullable; + + public bool Int64FastSource { get; } = int64FastSource; + + public string FieldName { get; } = fieldName; +} + /// /// Provides compatible scalar field conversion for generated serializers. /// @@ -96,103 +117,419 @@ public static bool CanConvert(uint remoteTypeId, uint localTypeId) return IsNumeric(remote) && IsNumeric(local); } - /// - /// Returns whether the generated compatible reader should read the field through this scalar reader. - /// - public static bool RequiresScalarRead(uint remoteTypeId, uint localTypeId) + internal static CompatibleScalarRead? TryBuildRead( + TypeMetaFieldInfo remoteField, + TypeMetaFieldInfo localField) { - TypeId remote = NormalizeScalarTypeId(remoteTypeId); - TypeId local = NormalizeScalarTypeId(localTypeId); - return IsScalar(remote) && IsScalar(local) && remoteTypeId != localTypeId && - (remote == local || CanConvert(remoteTypeId, localTypeId)); + TypeMetaFieldType remoteFieldType = remoteField.FieldType; + TypeMetaFieldType localFieldType = localField.FieldType; + if (remoteFieldType.TrackRef || localFieldType.TrackRef) + { + return null; + } + + TypeId rawRemoteTypeId = (TypeId)remoteFieldType.TypeId; + TypeId remoteTypeId = NormalizeScalarTypeId(remoteFieldType.TypeId); + TypeId localTypeId = NormalizeScalarTypeId(localFieldType.TypeId); + if (!IsScalar(remoteTypeId) || !IsScalar(localTypeId)) + { + return null; + } + + bool sameWireType = remoteFieldType.TypeId == localFieldType.TypeId; + bool sameNullable = remoteFieldType.Nullable == localFieldType.Nullable; + if (sameWireType && sameNullable) + { + return null; + } + + if (!sameWireType && !CanConvert(remoteFieldType.TypeId, localFieldType.TypeId)) + { + return null; + } + + return new CompatibleScalarRead( + rawRemoteTypeId, + remoteTypeId, + localTypeId, + remoteFieldType.Nullable, + IsInt64FastSource(rawRemoteTypeId), + localField.FieldName); } - /// - /// Reads a remote scalar payload and converts it to the local field type. - /// [MethodImpl(MethodImplOptions.NoInlining)] - public static T ReadField( - ReadContext context, - TypeId remoteTypeId, - TypeId localTypeId, - string fieldName) + public static bool ReadBoolField(ReadContext context, TypeMetaFieldInfo remoteField) { - return ReadPayloadAs(context, remoteTypeId, localTypeId, fieldName); + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToBool(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : false; } - /// - /// Reads remote scalar null framing and converts the remote scalar payload to the local field type. - /// [MethodImpl(MethodImplOptions.NoInlining)] - public static T ReadField( - ReadContext context, - TypeId remoteTypeId, - TypeId localTypeId, - string fieldName, - RefMode refMode) + public static bool? ReadNullableBoolField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToBool(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static sbyte ReadSByteField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToSByte(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static sbyte? ReadNullableSByteField(ReadContext context, TypeMetaFieldInfo remoteField) { - switch (refMode) + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToSByte(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static short ReadInt16Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToInt16(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static short? ReadNullableInt16Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToInt16(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static int ReadInt32Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToInt32(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static int? ReadNullableInt32Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToInt32(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static long ReadInt64Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + if (TryReadIntegralToInt64Field(context, remoteField, out long fastValue, out bool present)) { - case RefMode.None: - return ReadPayloadAs(context, remoteTypeId, localTypeId, fieldName); - case RefMode.NullOnly: - { - RefFlag flag = context.RefReader.ReadRefFlag(context.Reader); - return flag switch - { - RefFlag.Null => default!, - RefFlag.NotNullValue => ReadPayloadAs(context, remoteTypeId, localTypeId, fieldName), - _ => throw Fail( - remoteTypeId, - localTypeId, - fieldName, - $"invalid compatible nullOnly ref flag {(sbyte)flag}"), - }; - } - default: - throw Fail(remoteTypeId, localTypeId, fieldName, $"unsupported compatible ref mode {refMode}"); + return present ? fastValue : default; + } + + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToInt64(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static long? ReadNullableInt64Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + if (TryReadIntegralToInt64Field(context, remoteField, out long fastValue, out bool present)) + { + return present ? fastValue : null; } + + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToInt64(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static byte ReadByteField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToByte(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static byte? ReadNullableByteField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToByte(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static ushort ReadUInt16Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToUInt16(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static ushort? ReadNullableUInt16Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToUInt16(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; } - private static T ReadPayloadAs( + [MethodImpl(MethodImplOptions.NoInlining)] + public static uint ReadUInt32Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToUInt32(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static uint? ReadNullableUInt32Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToUInt32(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static ulong ReadUInt64Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToUInt64(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static ulong? ReadNullableUInt64Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToUInt64(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static Half ReadHalfField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToHalfTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static Half? ReadNullableHalfField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToHalfTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static BFloat16 ReadBFloat16Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToBFloat16Target(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static BFloat16? ReadNullableBFloat16Field(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToBFloat16Target(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static float ReadFloatField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToSingleTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static float? ReadNullableFloatField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToSingleTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static double ReadDoubleField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToDoubleTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static double? ReadNullableDoubleField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToDoubleTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static string ReadStringField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToStringValue(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default!; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static string? ReadNullableStringField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToStringValue(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static ForyDecimal ReadForyDecimalField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToForyDecimalTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static ForyDecimal? ReadNullableForyDecimalField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToForyDecimalTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static decimal ReadDecimalField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToSystemDecimalTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static decimal? ReadNullableDecimalField(ReadContext context, TypeMetaFieldInfo remoteField) + { + return TryReadScalarValue(context, remoteField, out CompatibleScalarRead read, out ScalarValue value) + ? ToSystemDecimalTarget(value, read.RemoteTypeId, read.LocalTypeId, read.FieldName) + : null; + } + + private static bool TryReadScalarValue( ReadContext context, - TypeId remoteTypeId, - TypeId localTypeId, - string fieldName) + TypeMetaFieldInfo remoteField, + out CompatibleScalarRead scalarRead, + out ScalarValue value) { - object value = ReadPayload(context, remoteTypeId, localTypeId, fieldName); - return ConvertReadValue(value, remoteTypeId, localTypeId, fieldName); + scalarRead = RequireRead(remoteField); + value = default; + + if (!scalarRead.RemoteNullable) + { + value = ReadScalarPayload( + context, + scalarRead.RawRemoteTypeId, + scalarRead.LocalTypeId, + scalarRead.FieldName); + return true; + } + + RefFlag flag = context.RefReader.ReadRefFlag(context.Reader); + switch (flag) + { + case RefFlag.Null: + return false; + case RefFlag.NotNullValue: + value = ReadScalarPayload( + context, + scalarRead.RawRemoteTypeId, + scalarRead.LocalTypeId, + scalarRead.FieldName); + return true; + default: + throw Fail( + scalarRead.RemoteTypeId, + scalarRead.LocalTypeId, + scalarRead.FieldName, + $"invalid compatible nullOnly ref flag {(sbyte)flag}"); + } } - private static T ConvertReadValue( - object? value, - TypeId remoteTypeId, - TypeId localTypeId, - string fieldName) + private static bool TryReadIntegralToInt64Field( + ReadContext context, + TypeMetaFieldInfo remoteField, + out long value, + out bool present) { - if (value is null) + CompatibleScalarRead scalarRead = RequireRead(remoteField); + value = default; + present = false; + if (!scalarRead.Int64FastSource) { - return default!; + return false; } - if (value is T typedValue) + if (!scalarRead.RemoteNullable) { - return typedValue; + value = ReadInt64FastPayload( + context, + scalarRead.RawRemoteTypeId, + scalarRead.LocalTypeId, + scalarRead.FieldName); + present = true; + return true; } - object converted = ConvertValue(value, remoteTypeId, localTypeId, typeof(T), fieldName); - return (T)converted; + RefFlag flag = context.RefReader.ReadRefFlag(context.Reader); + switch (flag) + { + case RefFlag.Null: + return true; + case RefFlag.NotNullValue: + value = ReadInt64FastPayload( + context, + scalarRead.RawRemoteTypeId, + scalarRead.LocalTypeId, + scalarRead.FieldName); + present = true; + return true; + default: + throw Fail( + scalarRead.RemoteTypeId, + scalarRead.LocalTypeId, + scalarRead.FieldName, + $"invalid compatible nullOnly ref flag {(sbyte)flag}"); + } } - private static object ReadPayload( - ReadContext context, - TypeId remoteTypeId, - TypeId localTypeId, - string fieldName) + private static CompatibleScalarRead RequireRead(TypeMetaFieldInfo remoteField) + { + return remoteField.CompatibleScalarRead + ?? throw new InvalidDataException( + $"compatible scalar field {remoteField.FieldName} was not classified for scalar conversion"); + } + + private static bool IsInt64FastSource(TypeId remoteTypeId) + { + return remoteTypeId is TypeId.Bool or TypeId.Int8 or TypeId.Int16 or TypeId.Int32 or + TypeId.VarInt32 or TypeId.Int64 or TypeId.VarInt64 or TypeId.TaggedInt64 or + TypeId.UInt8 or TypeId.UInt16 or TypeId.UInt32 or TypeId.VarUInt32 or + TypeId.UInt64 or TypeId.VarUInt64 or TypeId.TaggedUInt64; + } + + private static long ReadInt64FastPayload(ReadContext context, TypeId remoteTypeId, TypeId localTypeId, string fieldName) { return remoteTypeId switch { - TypeId.Bool => ReadBool(context, localTypeId, fieldName), + TypeId.Bool => ReadBool(context, localTypeId, fieldName) ? 1 : 0, TypeId.Int8 => context.Reader.ReadInt8(), TypeId.Int16 => context.Reader.ReadInt16(), TypeId.Int32 => context.Reader.ReadInt32(), @@ -204,15 +541,52 @@ private static object ReadPayload( TypeId.UInt16 => context.Reader.ReadUInt16(), TypeId.UInt32 => context.Reader.ReadUInt32(), TypeId.VarUInt32 => context.Reader.ReadVarUInt32(), - TypeId.UInt64 => context.Reader.ReadUInt64(), - TypeId.VarUInt64 => context.Reader.ReadVarUInt64(), - TypeId.TaggedUInt64 => context.Reader.ReadTaggedUInt64(), - TypeId.Float16 => BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16()), - TypeId.BFloat16 => BFloat16.FromBits(context.Reader.ReadUInt16()), - TypeId.Float32 => context.Reader.ReadFloat32(), - TypeId.Float64 => context.Reader.ReadFloat64(), - TypeId.Decimal => ReadDecimal(context), - TypeId.String => StringSerializer.ReadString(context), + TypeId.UInt64 => ToInt64(context.Reader.ReadUInt64(), remoteTypeId, localTypeId, fieldName), + TypeId.VarUInt64 => ToInt64(context.Reader.ReadVarUInt64(), remoteTypeId, localTypeId, fieldName), + TypeId.TaggedUInt64 => ToInt64(context.Reader.ReadTaggedUInt64(), remoteTypeId, localTypeId, fieldName), + _ => throw Fail(remoteTypeId, localTypeId, fieldName, $"unsupported compatible scalar type id {remoteTypeId}"), + }; + } + + private static long ToInt64(ulong value, TypeId remote, TypeId local, string fieldName) + { + if (value <= long.MaxValue) + { + return (long)value; + } + + throw Fail(remote, local, fieldName, $"integer value {value} is outside {local} range"); + } + + private static ScalarValue ReadScalarPayload( + ReadContext context, + TypeId remoteTypeId, + TypeId localTypeId, + string fieldName) + { + return remoteTypeId switch + { + TypeId.Bool => ScalarValue.ForBool(ReadBool(context, localTypeId, fieldName)), + TypeId.Int8 => ScalarValue.ForSigned(context.Reader.ReadInt8()), + TypeId.Int16 => ScalarValue.ForSigned(context.Reader.ReadInt16()), + TypeId.Int32 => ScalarValue.ForSigned(context.Reader.ReadInt32()), + TypeId.VarInt32 => ScalarValue.ForSigned(context.Reader.ReadVarInt32()), + TypeId.Int64 => ScalarValue.ForSigned(context.Reader.ReadInt64()), + TypeId.VarInt64 => ScalarValue.ForSigned(context.Reader.ReadVarInt64()), + TypeId.TaggedInt64 => ScalarValue.ForSigned(context.Reader.ReadTaggedInt64()), + TypeId.UInt8 => ScalarValue.ForUnsigned(context.Reader.ReadUInt8()), + TypeId.UInt16 => ScalarValue.ForUnsigned(context.Reader.ReadUInt16()), + TypeId.UInt32 => ScalarValue.ForUnsigned(context.Reader.ReadUInt32()), + TypeId.VarUInt32 => ScalarValue.ForUnsigned(context.Reader.ReadVarUInt32()), + TypeId.UInt64 => ScalarValue.ForUnsigned(context.Reader.ReadUInt64()), + TypeId.VarUInt64 => ScalarValue.ForUnsigned(context.Reader.ReadVarUInt64()), + TypeId.TaggedUInt64 => ScalarValue.ForUnsigned(context.Reader.ReadTaggedUInt64()), + TypeId.Float16 => ScalarValue.ForHalf(BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16())), + TypeId.BFloat16 => ScalarValue.ForBFloat16(BFloat16.FromBits(context.Reader.ReadUInt16())), + TypeId.Float32 => ScalarValue.ForSingle(context.Reader.ReadFloat32()), + TypeId.Float64 => ScalarValue.ForDouble(context.Reader.ReadFloat64()), + TypeId.Decimal => ScalarValue.ForDecimal(ReadDecimal(context)), + TypeId.String => ScalarValue.ForString(StringSerializer.ReadString(context)), _ => throw Fail( remoteTypeId, localTypeId, @@ -223,12 +597,12 @@ private static object ReadPayload( private static bool ReadBool(ReadContext context, TypeId localTypeId, string fieldName) { - byte value = context.Reader.ReadUInt8(); - return value switch + byte raw = context.Reader.ReadUInt8(); + return raw switch { 0 => false, 1 => true, - _ => throw Fail(TypeId.Bool, localTypeId, fieldName, $"invalid bool payload {value}"), + _ => throw Fail(TypeId.Bool, localTypeId, fieldName, $"invalid bool payload {raw}"), }; } @@ -238,204 +612,173 @@ private static ForyDecimal ReadDecimal(ReadContext context) return new ForyDecimal(unscaled, scale); } - private static object ConvertValue( - object value, - TypeId remoteTypeId, - TypeId localTypeId, - Type targetType, - string fieldName) + private static bool ToBool(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - TypeId remote = NormalizeScalarTypeId((uint)remoteTypeId); - TypeId local = NormalizeScalarTypeId((uint)localTypeId); - Type unwrappedTarget = Nullable.GetUnderlyingType(targetType) ?? targetType; - if (local == TypeId.Bool) + switch (value.Kind) { - return ToBool(value, remote, local, fieldName); - } + case ScalarValueKind.Bool: + return value.BoolValue; + case ScalarValueKind.String: + return value.StringValue switch + { + "0" or "false" => false, + "1" or "true" => true, + string s => throw Fail(remote, local, fieldName, $"cannot convert string '{s}' to bool"), + _ => throw Fail(remote, local, fieldName, "remote value is not a string"), + }; + default: + DecimalValue numeric = Normalize(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); + if (numeric.Unscaled.IsZero) + { + return false; + } - if (local == TypeId.String) - { - return ToStringValue(value, remote, local, fieldName); - } + if (numeric.Scale == 0 && numeric.Unscaled.IsOne) + { + return true; + } - if (IsNumeric(local)) - { - return ToNumeric(value, remote, local, unwrappedTarget, fieldName); + throw Fail(remote, local, fieldName, "numeric value is not exactly 0 or 1"); } - - throw Fail(remote, local, fieldName, "unsupported compatible scalar target"); } - private static bool ToBool(object value, TypeId remote, TypeId local, string fieldName) + private static string ToStringValue(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - if (remote == TypeId.String) - { - return value switch - { - "0" or "false" => false, - "1" or "true" => true, - string s => throw Fail(remote, local, fieldName, $"cannot convert string '{s}' to bool"), - _ => throw Fail(remote, local, fieldName, "remote value is not a string"), - }; - } + switch (value.Kind) + { + case ScalarValueKind.Bool: + return value.BoolValue ? "true" : "false"; + case ScalarValueKind.String: + return value.StringValue!; + case ScalarValueKind.Float16: + case ScalarValueKind.BFloat16: + case ScalarValueKind.Float32: + case ScalarValueKind.Float64: + { + DecimalValue numeric = ToDecimalValue(value, remote, local, fieldName); + if (numeric.Unscaled.IsZero) + { + return numeric.NegativeZero ? "-0.0" : "0.0"; + } - DecimalValue numeric = ToDecimalValue(value, remote, local, fieldName); - numeric = Normalize(numeric, remote, local, fieldName); - if (numeric.Unscaled.IsZero) - { - return false; + return FormatFloating(Normalize(numeric, remote, local, fieldName)); + } + case ScalarValueKind.Decimal: + return FormatDecimal(Normalize(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName)); + default: + return FormatInteger(ToInteger(value, remote, local, fieldName)); } + } - if (numeric.Scale == 0 && numeric.Unscaled.IsOne) - { - return true; - } + private static sbyte ToSByte(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + return (sbyte)CheckedInteger(ToSignedInteger(value, remote, local, fieldName), SByteMin, SByteMax, remote, local, fieldName); + } - throw Fail(remote, local, fieldName, "numeric value is not exactly 0 or 1"); + private static short ToInt16(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + return (short)CheckedInteger(ToSignedInteger(value, remote, local, fieldName), Int16Min, Int16Max, remote, local, fieldName); } - private static string ToStringValue(object value, TypeId remote, TypeId local, string fieldName) + private static int ToInt32(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - if (remote == TypeId.Bool) - { - return (bool)value ? "true" : "false"; - } + return (int)CheckedInteger(ToSignedInteger(value, remote, local, fieldName), Int32Min, Int32Max, remote, local, fieldName); + } - if (!IsNumeric(remote)) - { - throw Fail(remote, local, fieldName, "remote value is not numeric"); - } + private static long ToInt64(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + return (long)CheckedInteger(ToSignedInteger(value, remote, local, fieldName), Int64Min, Int64Max, remote, local, fieldName); + } - if (IsFloating(remote)) - { - DecimalValue numeric = ToDecimalValue(value, remote, local, fieldName); - if (numeric.Unscaled.IsZero) - { - return numeric.NegativeZero ? "-0.0" : "0.0"; - } + private static byte ToByte(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + return (byte)CheckedInteger(ToUnsignedInteger(value, remote, local, fieldName), BigInteger.Zero, ByteMax, remote, local, fieldName); + } - return FormatFloating(Normalize(numeric, remote, local, fieldName)); - } + private static ushort ToUInt16(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + return (ushort)CheckedInteger(ToUnsignedInteger(value, remote, local, fieldName), BigInteger.Zero, UInt16Max, remote, local, fieldName); + } - if (remote == TypeId.Decimal) - { - return FormatDecimal(Normalize(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName)); - } + private static uint ToUInt32(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + return (uint)CheckedInteger(ToUnsignedInteger(value, remote, local, fieldName), BigInteger.Zero, UInt32Max, remote, local, fieldName); + } - return FormatInteger(ToInteger(value, remote, local, fieldName)); + private static ulong ToUInt64(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + return (ulong)CheckedInteger(ToUnsignedInteger(value, remote, local, fieldName), BigInteger.Zero, UInt64Max, remote, local, fieldName); } - private static object ToNumeric( - object value, - TypeId remote, - TypeId local, - Type targetType, - string fieldName) + private static Half ToHalfTarget(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - if (remote == TypeId.Bool) + if (TryNonFinite(value, out int sign)) { - return FromInteger((bool)value ? BigInteger.One : BigInteger.Zero, remote, local, targetType, fieldName); + return sign switch + { + < 0 => Half.NegativeInfinity, + > 0 => Half.PositiveInfinity, + _ => throw Fail(remote, local, fieldName, "NaN cannot convert across floating scalar types"), + }; } - if (remote == TypeId.String) + return ToHalf(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); + } + + private static BFloat16 ToBFloat16Target(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + if (TryNonFinite(value, out int sign)) { - if (!TryParseNumber((string)value, out DecimalValue parsed)) + return sign switch { - throw Fail(remote, local, fieldName, "string is not a compatible numeric literal"); - } - - return FromDecimal(parsed, local, targetType, remote, fieldName); + < 0 => BFloat16.NegativeInfinity, + > 0 => BFloat16.PositiveInfinity, + _ => throw Fail(remote, local, fieldName, "NaN cannot convert across floating scalar types"), + }; } - if (IsInteger(remote)) - { - return FromInteger(ToInteger(value, remote, local, fieldName), remote, local, targetType, fieldName); - } + return ToBFloat16(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); + } - if (IsFloating(remote) && IsFloating(local) && TryNonFinite(value, out int sign)) + private static float ToSingleTarget(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + if (TryNonFinite(value, out int sign)) { - if (sign == 0) + return sign switch { - throw Fail(remote, local, fieldName, "NaN cannot convert across floating scalar types"); - } - - return NonFiniteFloat(sign, local, fieldName); + < 0 => float.NegativeInfinity, + > 0 => float.PositiveInfinity, + _ => throw Fail(remote, local, fieldName, "NaN cannot convert across floating scalar types"), + }; } - DecimalValue numeric = ToDecimalValue(value, remote, local, fieldName); - return FromDecimal(numeric, local, targetType, remote, fieldName); + return ToSingle(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); } - private static object FromInteger( - BigInteger value, - TypeId remote, - TypeId local, - Type targetType, - string fieldName) + private static double ToDoubleTarget(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - return local switch - { - TypeId.Int8 => CheckedSigned(value, SByteMin, SByteMax, remote, local, fieldName), - TypeId.Int16 => CheckedSigned(value, Int16Min, Int16Max, remote, local, fieldName), - TypeId.Int32 => CheckedSigned(value, Int32Min, Int32Max, remote, local, fieldName), - TypeId.Int64 => CheckedSigned(value, Int64Min, Int64Max, remote, local, fieldName), - TypeId.UInt8 => CheckedUnsigned(value, ByteMax, remote, local, fieldName), - TypeId.UInt16 => CheckedUnsigned(value, UInt16Max, remote, local, fieldName), - TypeId.UInt32 => CheckedUnsigned(value, UInt32Max, remote, local, fieldName), - TypeId.UInt64 => CheckedUnsigned(value, UInt64Max, remote, local, fieldName), - TypeId.Float16 or TypeId.BFloat16 or TypeId.Float32 or TypeId.Float64 => - FromDecimal(new DecimalValue(value, 0, false), local, targetType, remote, fieldName), - TypeId.Decimal => FromDecimal(new DecimalValue(value, 0, false), local, targetType, remote, fieldName), - _ => throw Fail(remote, local, fieldName, "unsupported numeric target"), - }; + if (TryNonFinite(value, out int sign)) + { + return sign switch + { + < 0 => double.NegativeInfinity, + > 0 => double.PositiveInfinity, + _ => throw Fail(remote, local, fieldName, "NaN cannot convert across floating scalar types"), + }; + } + + return ToDouble(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); } - private static object FromDecimal( - DecimalValue value, - TypeId local, - Type targetType, - TypeId remote, - string fieldName) + private static ForyDecimal ToForyDecimalTarget(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - value = Normalize(value, remote, local, fieldName); - return local switch - { - TypeId.Int8 => CheckedSigned(Integral(value, remote, local, fieldName), SByteMin, SByteMax, remote, local, fieldName), - TypeId.Int16 => CheckedSigned(Integral(value, remote, local, fieldName), Int16Min, Int16Max, remote, local, fieldName), - TypeId.Int32 => CheckedSigned(Integral(value, remote, local, fieldName), Int32Min, Int32Max, remote, local, fieldName), - TypeId.Int64 => CheckedSigned(Integral(value, remote, local, fieldName), Int64Min, Int64Max, remote, local, fieldName), - TypeId.UInt8 => CheckedUnsigned(UnsignedIntegral(value, remote, local, fieldName), ByteMax, remote, local, fieldName), - TypeId.UInt16 => CheckedUnsigned(UnsignedIntegral(value, remote, local, fieldName), UInt16Max, remote, local, fieldName), - TypeId.UInt32 => CheckedUnsigned(UnsignedIntegral(value, remote, local, fieldName), UInt32Max, remote, local, fieldName), - TypeId.UInt64 => CheckedUnsigned(UnsignedIntegral(value, remote, local, fieldName), UInt64Max, remote, local, fieldName), - TypeId.Float16 => ToHalf(value, remote, local, fieldName), - TypeId.BFloat16 => ToBFloat16(value, remote, local, fieldName), - TypeId.Float32 => ToSingle(value, remote, local, fieldName), - TypeId.Float64 => ToDouble(value, remote, local, fieldName), - TypeId.Decimal => ToDecimalCarrier(value, targetType, remote, local, fieldName), - _ => throw Fail(remote, local, fieldName, "unsupported numeric target"), - }; + DecimalValue decimalValue = Normalize(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); + return new ForyDecimal(decimalValue.Unscaled, decimalValue.Scale); } - private static object ToDecimalCarrier( - DecimalValue value, - Type targetType, - TypeId remote, - TypeId local, - string fieldName) + private static decimal ToSystemDecimalTarget(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - value = Normalize(value, remote, local, fieldName); - if (targetType == typeof(ForyDecimal)) - { - return new ForyDecimal(value.Unscaled, value.Scale); - } - - if (targetType == typeof(decimal)) - { - return ToSystemDecimal(value, remote, local, fieldName); - } - - throw Fail(remote, local, fieldName, $"unsupported decimal carrier {targetType}"); + return ToSystemDecimal(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); } private static decimal ToSystemDecimal(DecimalValue value, TypeId remote, TypeId local, string fieldName) @@ -572,7 +915,7 @@ private static BigInteger UnsignedIntegral(DecimalValue value, TypeId remote, Ty throw Fail(remote, local, fieldName, "negative value cannot convert to unsigned target"); } - private static object CheckedSigned( + private static BigInteger CheckedInteger( BigInteger value, BigInteger min, BigInteger max, @@ -585,135 +928,104 @@ private static object CheckedSigned( throw Fail(remote, local, fieldName, $"integer value {value} is outside {local} range"); } - if (typeof(T) == typeof(sbyte)) - { - return (sbyte)value; - } + return value; + } - if (typeof(T) == typeof(short)) + private static BigInteger ToSignedInteger(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + if (value.Kind == ScalarValueKind.Bool) { - return (short)value; + return value.BoolValue ? BigInteger.One : BigInteger.Zero; } - if (typeof(T) == typeof(int)) + if (value.Kind is ScalarValueKind.Signed or ScalarValueKind.Unsigned) { - return (int)value; + return ToInteger(value, remote, local, fieldName); } - return (long)value; + return Integral(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); } - private static object CheckedUnsigned( - BigInteger value, - BigInteger max, - TypeId remote, - TypeId local, - string fieldName) + private static BigInteger ToUnsignedInteger(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - if (value.Sign < 0 || value > max) - { - throw Fail(remote, local, fieldName, $"integer value {value} is outside {local} range"); - } - - if (typeof(T) == typeof(byte)) + if (value.Kind == ScalarValueKind.Bool) { - return (byte)value; + return value.BoolValue ? BigInteger.One : BigInteger.Zero; } - if (typeof(T) == typeof(ushort)) + if (value.Kind is ScalarValueKind.Signed or ScalarValueKind.Unsigned) { - return (ushort)value; - } + BigInteger integer = ToInteger(value, remote, local, fieldName); + if (integer.Sign >= 0) + { + return integer; + } - if (typeof(T) == typeof(uint)) - { - return (uint)value; + throw Fail(remote, local, fieldName, "negative value cannot convert to unsigned target"); } - return (ulong)value; + return UnsignedIntegral(ToDecimalValue(value, remote, local, fieldName), remote, local, fieldName); } - private static BigInteger ToInteger(object value, TypeId remote, TypeId local, string fieldName) + private static BigInteger ToInteger(ScalarValue value, TypeId remote, TypeId local, string fieldName) { - return value switch + return value.Kind switch { - sbyte v => v, - short v => v, - int v => v, - long v => v, - byte v => v, - ushort v => v, - uint v => v, - ulong v => v, - _ => throw Fail(remote, local, fieldName, $"remote value type {value.GetType()} is not an integer"), + ScalarValueKind.Signed => value.SignedValue, + ScalarValueKind.Unsigned => value.UnsignedValue, + _ => throw Fail(remote, local, fieldName, $"remote value kind {value.Kind} is not an integer"), }; } - private static DecimalValue ToDecimalValue(object value, TypeId remote, TypeId local, string fieldName) - { - return value switch - { - sbyte v => new DecimalValue(v, 0, false), - short v => new DecimalValue(v, 0, false), - int v => new DecimalValue(v, 0, false), - long v => new DecimalValue(v, 0, false), - byte v => new DecimalValue(v, 0, false), - ushort v => new DecimalValue(v, 0, false), - uint v => new DecimalValue(v, 0, false), - ulong v => new DecimalValue(v, 0, false), - Half v => FromHalfChecked(v, remote, local, fieldName), - BFloat16 v => FromBFloat16Checked(v, remote, local, fieldName), - float v => FromSingleChecked(v, remote, local, fieldName), - double v => FromDoubleChecked(v, remote, local, fieldName), - ForyDecimal v => new DecimalValue(v.UnscaledValue, v.Scale, false), - decimal v => FromSystemDecimal(v), - _ => throw Fail(remote, local, fieldName, $"remote value type {value.GetType()} is not numeric"), + private static DecimalValue ToDecimalValue(ScalarValue value, TypeId remote, TypeId local, string fieldName) + { + return value.Kind switch + { + ScalarValueKind.Bool => new DecimalValue(value.BoolValue ? BigInteger.One : BigInteger.Zero, 0, false), + ScalarValueKind.Signed => new DecimalValue(value.SignedValue, 0, false), + ScalarValueKind.Unsigned => new DecimalValue(value.UnsignedValue, 0, false), + ScalarValueKind.Float16 => FromHalfChecked(value.HalfValue, remote, local, fieldName), + ScalarValueKind.BFloat16 => FromBFloat16Checked(value.BFloat16Value, remote, local, fieldName), + ScalarValueKind.Float32 => FromSingleChecked(value.SingleValue, remote, local, fieldName), + ScalarValueKind.Float64 => FromDoubleChecked(value.DoubleValue, remote, local, fieldName), + ScalarValueKind.Decimal => new DecimalValue(value.DecimalValue.UnscaledValue, value.DecimalValue.Scale, false), + ScalarValueKind.String => TryParseNumber(value.StringValue!, out DecimalValue parsed) + ? parsed + : throw Fail(remote, local, fieldName, "string is not a compatible numeric literal"), + _ => throw Fail(remote, local, fieldName, $"remote value kind {value.Kind} is not numeric"), }; } - private static bool TryNonFinite(object value, out int sign) + private static bool TryNonFinite(ScalarValue value, out int sign) { sign = 0; - switch (value) + switch (value.Kind) { - case Half v when Half.IsNaN(v): + case ScalarValueKind.Float16 when Half.IsNaN(value.HalfValue): return true; - case Half v when Half.IsInfinity(v): - sign = Half.IsNegative(v) ? -1 : 1; + case ScalarValueKind.Float16 when Half.IsInfinity(value.HalfValue): + sign = Half.IsNegative(value.HalfValue) ? -1 : 1; return true; - case BFloat16 v when v.IsNaN: + case ScalarValueKind.BFloat16 when value.BFloat16Value.IsNaN: return true; - case BFloat16 v when v.IsInfinity: - sign = v.SignBit ? -1 : 1; + case ScalarValueKind.BFloat16 when value.BFloat16Value.IsInfinity: + sign = value.BFloat16Value.SignBit ? -1 : 1; return true; - case float v when float.IsNaN(v): + case ScalarValueKind.Float32 when float.IsNaN(value.SingleValue): return true; - case float v when float.IsInfinity(v): - sign = float.IsNegative(v) ? -1 : 1; + case ScalarValueKind.Float32 when float.IsInfinity(value.SingleValue): + sign = float.IsNegative(value.SingleValue) ? -1 : 1; return true; - case double v when double.IsNaN(v): + case ScalarValueKind.Float64 when double.IsNaN(value.DoubleValue): return true; - case double v when double.IsInfinity(v): - sign = double.IsNegative(v) ? -1 : 1; + case ScalarValueKind.Float64 when double.IsInfinity(value.DoubleValue): + sign = double.IsNegative(value.DoubleValue) ? -1 : 1; return true; default: return false; } } - private static object NonFiniteFloat(int sign, TypeId local, string fieldName) - { - bool negative = sign < 0; - return local switch - { - TypeId.Float16 => negative ? Half.NegativeInfinity : Half.PositiveInfinity, - TypeId.BFloat16 => negative ? BFloat16.NegativeInfinity : BFloat16.PositiveInfinity, - TypeId.Float32 => negative ? float.NegativeInfinity : float.PositiveInfinity, - TypeId.Float64 => negative ? double.NegativeInfinity : double.PositiveInfinity, - _ => throw Fail(local, local, fieldName, "non-finite value requires a floating target"), - }; - } - private static DecimalValue FromHalfChecked(Half value, TypeId remote, TypeId local, string fieldName) { if (!Half.IsFinite(value)) @@ -1150,5 +1462,110 @@ private static InvalidDataException Fail(TypeId remote, TypeId local, string fie $"compatible scalar conversion failed for field '{fieldName}' from {remote} to {local}: {reason}"); } + private enum ScalarValueKind + { + Bool, + String, + Signed, + Unsigned, + Float16, + BFloat16, + Float32, + Float64, + Decimal, + } + + private readonly struct ScalarValue + { + private ScalarValue( + ScalarValueKind kind, + bool boolValue = false, + string? stringValue = null, + long signedValue = 0, + ulong unsignedValue = 0, + Half halfValue = default, + BFloat16 bfloat16Value = default, + float singleValue = default, + double doubleValue = default, + ForyDecimal decimalValue = default) + { + Kind = kind; + BoolValue = boolValue; + StringValue = stringValue; + SignedValue = signedValue; + UnsignedValue = unsignedValue; + HalfValue = halfValue; + BFloat16Value = bfloat16Value; + SingleValue = singleValue; + DoubleValue = doubleValue; + DecimalValue = decimalValue; + } + + public ScalarValueKind Kind { get; } + + public bool BoolValue { get; } + + public string? StringValue { get; } + + public long SignedValue { get; } + + public ulong UnsignedValue { get; } + + public Half HalfValue { get; } + + public BFloat16 BFloat16Value { get; } + + public float SingleValue { get; } + + public double DoubleValue { get; } + + public ForyDecimal DecimalValue { get; } + + public static ScalarValue ForBool(bool value) + { + return new ScalarValue(ScalarValueKind.Bool, boolValue: value); + } + + public static ScalarValue ForString(string value) + { + return new ScalarValue(ScalarValueKind.String, stringValue: value); + } + + public static ScalarValue ForSigned(long value) + { + return new ScalarValue(ScalarValueKind.Signed, signedValue: value); + } + + public static ScalarValue ForUnsigned(ulong value) + { + return new ScalarValue(ScalarValueKind.Unsigned, unsignedValue: value); + } + + public static ScalarValue ForHalf(Half value) + { + return new ScalarValue(ScalarValueKind.Float16, halfValue: value); + } + + public static ScalarValue ForBFloat16(BFloat16 value) + { + return new ScalarValue(ScalarValueKind.BFloat16, bfloat16Value: value); + } + + public static ScalarValue ForSingle(float value) + { + return new ScalarValue(ScalarValueKind.Float32, singleValue: value); + } + + public static ScalarValue ForDouble(double value) + { + return new ScalarValue(ScalarValueKind.Float64, doubleValue: value); + } + + public static ScalarValue ForDecimal(ForyDecimal value) + { + return new ScalarValue(ScalarValueKind.Decimal, decimalValue: value); + } + } + private readonly record struct DecimalValue(BigInteger Unscaled, int Scale, bool NegativeZero); } diff --git a/csharp/src/Fory/TypeMeta.cs b/csharp/src/Fory/TypeMeta.cs index c14d550d80..97d3de1177 100644 --- a/csharp/src/Fory/TypeMeta.cs +++ b/csharp/src/Fory/TypeMeta.cs @@ -268,6 +268,8 @@ public TypeMetaFieldInfo(short? fieldId, string fieldName, TypeMetaFieldType fie public int AssignedFieldId { get; internal set; } + internal CompatibleScalarRead? CompatibleScalarRead { get; set; } + internal void Write(ByteWriter writer) { byte header = 0; @@ -706,10 +708,11 @@ private static bool IsNamedKind(uint typeId) } /// - /// Assigns local sorted field indexes for a remote compatible type meta. + /// Assigns doubled compatible dispatch ids for a remote compatible type meta. /// The result is written to each remote field's : - /// - local sorted field index when a compatible local field is found - /// - -1 when no compatible local field is found and the field should be skipped + /// - local sorted field index * 2 when the matched field schema is exact + /// - local sorted field index * 2 + 1 when the matched field schema needs compatible conversion + /// - -1 when no local field identity is found and the field should be skipped /// public static void AssignFieldIds( TypeMeta remoteTypeMeta, @@ -723,11 +726,6 @@ public static void AssignFieldIds( for (int i = 0; i < localFieldInfos.Count; i++) { TypeMetaFieldInfo localField = localFieldInfos[i]; - if (!string.IsNullOrEmpty(localField.FieldName)) - { - localByName.TryAdd(localField.FieldName, (i, localField)); - } - if (localField.FieldId.HasValue && localField.FieldId.Value >= 0) { short fieldId = localField.FieldId.Value; @@ -737,12 +735,18 @@ public static void AssignFieldIds( $"duplicate local field id {fieldId} in compatible type metadata"); } } + else if (!string.IsNullOrEmpty(localField.FieldName)) + { + localByName.TryAdd(localField.FieldName, (i, localField)); + } } HashSet? remoteFieldIds = null; + HashSet usedLocalFields = []; for (int i = 0; i < remoteTypeMeta.Fields.Count; i++) { TypeMetaFieldInfo remoteField = remoteTypeMeta.Fields[i]; + remoteField.CompatibleScalarRead = null; if (remoteField.FieldId.HasValue && remoteField.FieldId.Value >= 0) { short fieldId = remoteField.FieldId.Value; @@ -758,12 +762,13 @@ public static void AssignFieldIds( TypeMetaFieldInfo? localMatch = null; if (remoteField.FieldId.HasValue && + remoteField.FieldId.Value >= 0 && localById.TryGetValue(remoteField.FieldId.Value, out (int Index, TypeMetaFieldInfo Field) byId)) { localIndex = byId.Index; localMatch = byId.Field; } - else + else if (!remoteField.FieldId.HasValue) { if (localByName.TryGetValue(remoteField.FieldName, out (int Index, TypeMetaFieldInfo Field) byName)) { @@ -782,11 +787,31 @@ public static void AssignFieldIds( } } - if (localIndex >= 0 && - localMatch is not null && - IsCompatibleFieldType(remoteField.FieldType, localMatch.FieldType, topLevel: true)) + if (localIndex >= 0 && localMatch is not null) { - remoteField.AssignedFieldId = localIndex; + if (!usedLocalFields.Add(localIndex)) + { + throw new InvalidDataException( + $"compatible field {remoteField.FieldName} duplicates local field {localMatch.FieldName}"); + } + if (IsDirectFieldType(remoteField.FieldType, localMatch.FieldType)) + { + remoteField.AssignedFieldId = localIndex * 2; + } + else if (CompatibleScalarConverter.TryBuildRead(remoteField, localMatch) is { } scalarRead) + { + remoteField.AssignedFieldId = localIndex * 2 + 1; + remoteField.CompatibleScalarRead = scalarRead; + } + else if (IsCompatibleFieldType(remoteField.FieldType, localMatch.FieldType, topLevel: true)) + { + remoteField.AssignedFieldId = localIndex * 2 + 1; + } + else + { + throw new InvalidDataException( + $"compatible field {remoteField.FieldName} cannot be read as local field {localMatch.FieldName}: remote and local field schemas are not compatible"); + } } else { @@ -797,29 +822,59 @@ localMatch is not null && private static bool IsCompatibleFieldType(TypeMetaFieldType remote, TypeMetaFieldType local, bool topLevel) { + if (topLevel && IsCompatibleByteSequenceFieldPair(remote, local)) + { + return true; + } + if (topLevel && IsCompatibleListArrayFieldPair(remote, local)) { return true; } - if (topLevel && - (remote.TrackRef || local.TrackRef) && - CompatibleScalarConverter.IsScalarType(remote.TypeId) && - CompatibleScalarConverter.IsScalarType(local.TypeId)) + bool remoteScalar = CompatibleScalarConverter.IsScalarType(remote.TypeId); + bool localScalar = CompatibleScalarConverter.IsScalarType(local.TypeId); + if (remoteScalar || localScalar) { - return remote.TrackRef == local.TrackRef && + return !remote.TrackRef && + !local.TrackRef && remote.TypeId == local.TypeId && - remote.Nullable == local.Nullable; + remote.Generics.Count == local.Generics.Count; + } + + if (NormalizeTypeIdForMatch(remote.TypeId) != NormalizeTypeIdForMatch(local.TypeId)) + { + return false; + } + + if (remote.Generics.Count != local.Generics.Count) + { + return false; + } + + for (int i = 0; i < remote.Generics.Count; i++) + { + if (!IsCompatibleFieldType(remote.Generics[i], local.Generics[i], topLevel: false)) + { + return false; + } } - if (topLevel && - !remote.TrackRef && - !local.TrackRef && - CompatibleScalarConverter.CanConvert(remote.TypeId, local.TypeId)) + return true; + } + + private static bool IsDirectFieldType(TypeMetaFieldType remote, TypeMetaFieldType local) + { + if (remote.Equals(local)) { return true; } + if (remote.Nullable != local.Nullable || remote.TrackRef != local.TrackRef) + { + return false; + } + if (NormalizeTypeIdForMatch(remote.TypeId) != NormalizeTypeIdForMatch(local.TypeId)) { return false; @@ -832,7 +887,7 @@ private static bool IsCompatibleFieldType(TypeMetaFieldType remote, TypeMetaFiel for (int i = 0; i < remote.Generics.Count; i++) { - if (!IsCompatibleFieldType(remote.Generics[i], local.Generics[i], topLevel: false)) + if (!IsDirectFieldType(remote.Generics[i], local.Generics[i])) { return false; } @@ -841,13 +896,35 @@ private static bool IsCompatibleFieldType(TypeMetaFieldType remote, TypeMetaFiel return true; } + private static bool IsCompatibleByteSequenceFieldPair(TypeMetaFieldType remote, TypeMetaFieldType local) + { + if (remote.TrackRef || local.TrackRef || remote.Nullable != local.Nullable) + { + return false; + } + + return (remote.TypeId == (uint)global::Apache.Fory.TypeId.Binary && + local.TypeId == (uint)global::Apache.Fory.TypeId.UInt8Array) || + (remote.TypeId == (uint)global::Apache.Fory.TypeId.UInt8Array && + local.TypeId == (uint)global::Apache.Fory.TypeId.Binary); + } + private static bool IsCompatibleListArrayFieldPair(TypeMetaFieldType remote, TypeMetaFieldType local) { + if (remote.Nullable || local.Nullable || remote.TrackRef || local.TrackRef) + { + return false; + } + uint? localArrayElementTypeId = TryPackedArrayElementTypeId(local.TypeId); uint? remoteArrayElementTypeId = TryPackedArrayElementTypeId(remote.TypeId); + // Nullable element schema is allowed for list -> array; actual null payload + // elements fail in the dense-array reader. Ref-tracked element framing is rejected here + // because this path stays primitive-only. bool remoteListLocalArray = remote.TypeId == (uint)global::Apache.Fory.TypeId.List && localArrayElementTypeId.HasValue && remote.Generics.Count == 1 && + !remote.Generics[0].TrackRef && NormalizeScalarTypeIdForMatch(localArrayElementTypeId.Value) == NormalizeScalarTypeIdForMatch(remote.Generics[0].TypeId); if (remoteListLocalArray) @@ -917,9 +994,6 @@ private static uint NormalizeTypeIdForMatch(uint typeId) (uint)global::Apache.Fory.TypeId.Union or (uint)global::Apache.Fory.TypeId.NamedUnion or (uint)global::Apache.Fory.TypeId.TypedUnion => (uint)global::Apache.Fory.TypeId.Union, - (uint)global::Apache.Fory.TypeId.Binary or - (uint)global::Apache.Fory.TypeId.Int8Array or - (uint)global::Apache.Fory.TypeId.UInt8Array => (uint)global::Apache.Fory.TypeId.Binary, _ => typeId, }; } diff --git a/csharp/tests/Fory.Tests/ForyGeneratorTests.cs b/csharp/tests/Fory.Tests/ForyGeneratorTests.cs index f546fa13b4..5b6bb4a665 100644 --- a/csharp/tests/Fory.Tests/ForyGeneratorTests.cs +++ b/csharp/tests/Fory.Tests/ForyGeneratorTests.cs @@ -128,6 +128,66 @@ public sealed partial record Unknown(UnknownCase Value) : OnlyUnknown; Assert.DoesNotContain(output.GetDiagnostics(), diagnostic => diagnostic.Severity == DiagnosticSeverity.Error && diagnostic.Id != "FORY006"); } + [Fact] + public void CompatibleReadSourceUsesTypedCases() + { + const string source = """ + using System.Collections.Generic; + using Apache.Fory; + using S = Apache.Fory.Schema.Types; + + namespace GeneratedDiagnostics; + + [ForyStruct] + public sealed class Shape + { + [ForyField(1, Type = typeof(S.Bool))] + public bool Flag { get; set; } + + [ForyField(2, Type = typeof(S.Int32))] + public int? Count { get; set; } + + [ForyField(3, Type = typeof(S.String))] + public string? Name { get; set; } + + [ForyField(4, Type = typeof(S.Array))] + public int[] Values { get; set; } = []; + } + """; + + string generated = GenerateSource(source); + + Assert.Contains("case 0:", generated, StringComparison.Ordinal); + Assert.Contains("case 1:", generated, StringComparison.Ordinal); + Assert.Contains("case 2:", generated, StringComparison.Ordinal); + Assert.Contains("case 3:", generated, StringComparison.Ordinal); + Assert.DoesNotContain("__ForyLocalFields", generated, StringComparison.Ordinal); + Assert.Contains("ReadBoolField(context, remoteField)", generated, StringComparison.Ordinal); + Assert.Contains("ReadNullableStringField(context, remoteField)", generated, StringComparison.Ordinal); + Assert.Contains("ReadNullableInt32Field(context, remoteField)", generated, StringComparison.Ordinal); + Assert.Contains("ReadValuesFieldBridge(context, remoteField.FieldType", generated, StringComparison.Ordinal); + Assert.DoesNotContain("__ForyReadCompatibleField<", generated, StringComparison.Ordinal); + Assert.DoesNotContain("RequiresScalarRead", generated, StringComparison.Ordinal); + Assert.DoesNotContain("CompatibleScalarConverter.ReadBoolField(context, remoteField.FieldType", generated, StringComparison.Ordinal); + Assert.DoesNotContain("if (remoteField.FieldType.TypeId ==", generated, StringComparison.Ordinal); + } + + private static string GenerateSource(string source) + { + CSharpCompilation compilation = CreateCompilation(source); + GeneratorDriver driver = CSharpGeneratorDriver.Create(new ForyModelGenerator()); + driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out Compilation output, out ImmutableArray diagnostics); + + Assert.DoesNotContain( + diagnostics.Concat(output.GetDiagnostics()), + diagnostic => diagnostic.Severity == DiagnosticSeverity.Error); + + return string.Join( + "\n", + driver.GetRunResult().Results.SelectMany(result => result.GeneratedSources) + .Select(sourceResult => sourceResult.SourceText.ToString())); + } + private static CSharpCompilation CreateCompilation(string source) { IEnumerable platformReferences = diff --git a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs index 45a39017a3..b8c06f695d 100644 --- a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs +++ b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs @@ -173,6 +173,27 @@ public sealed class CompatibleUInt32ArrayListCarrierSchema public List Values { get; set; } = []; } +[ForyStruct] +public sealed class CompatibleBinarySchema +{ + [ForyField(1, Type = typeof(S.Binary))] + public byte[] Value { get; set; } = []; +} + +[ForyStruct] +public sealed class CompatibleUInt8ArrayBytesSchema +{ + [ForyField(1, Type = typeof(S.Array))] + public byte[] Value { get; set; } = []; +} + +[ForyStruct] +public sealed class CompatibleUInt8ArrayListSchema +{ + [ForyField(1, Type = typeof(S.Array))] + public List Value { get; set; } = []; +} + [ForyStruct] public sealed class CompatibleNestedListSchema { @@ -235,6 +256,13 @@ public sealed class ScalarInt32Field public int Value { get; set; } } +[ForyStruct] +public sealed class ScalarVarInt32Field +{ + [ForyField(1)] + public int Value { get; set; } +} + [ForyStruct] public sealed class ScalarNullableInt32Field { @@ -242,6 +270,20 @@ public sealed class ScalarNullableInt32Field public int? Value { get; set; } } +[ForyStruct] +public sealed class ScalarInt64Field +{ + [ForyField(1, Type = typeof(S.Int64))] + public long Value { get; set; } +} + +[ForyStruct] +public sealed class ScalarNullableInt64Field +{ + [ForyField(1, Type = typeof(S.Int64))] + public long? Value { get; set; } +} + [ForyStruct] public sealed class ScalarUInt32Field { @@ -249,6 +291,13 @@ public sealed class ScalarUInt32Field public uint Value { get; set; } } +[ForyStruct] +public sealed class ScalarUInt64Field +{ + [ForyField(1, Type = typeof(S.UInt64))] + public ulong Value { get; set; } +} + [ForyStruct] public sealed class ScalarFixedUInt32Field { @@ -1243,25 +1292,57 @@ public void CompatibleReadSupportsUInt32ListArrayFieldPairs() } [Fact] - public void CompatibleReadRejectsNullableListElementsIntoArrayCarrier() + public void CompatibleReadSupportsBinaryUint8ArrayPairs() + { + ForyRuntime binaryWriter = ForyRuntime.Builder().Compatible(true).Build(); + binaryWriter.Register(311); + ForyRuntime uint8ArrayReader = ForyRuntime.Builder().Compatible(true).Build(); + uint8ArrayReader.Register(311); + + byte[] binaryPayload = [0, 1, 2, 250, 255]; + CompatibleUInt8ArrayBytesSchema decodedArray = + uint8ArrayReader.Deserialize( + binaryWriter.Serialize(new CompatibleBinarySchema { Value = binaryPayload })); + Assert.Equal(binaryPayload, decodedArray.Value); + + ForyRuntime listReader = ForyRuntime.Builder().Compatible(true).Build(); + listReader.Register(311); + CompatibleUInt8ArrayListSchema decodedList = + listReader.Deserialize( + binaryWriter.Serialize(new CompatibleBinarySchema { Value = binaryPayload })); + Assert.Equal(binaryPayload, decodedList.Value); + + ForyRuntime uint8ArrayWriter = ForyRuntime.Builder().Compatible(true).Build(); + uint8ArrayWriter.Register(312); + ForyRuntime binaryReader = ForyRuntime.Builder().Compatible(true).Build(); + binaryReader.Register(312); + + byte[] arrayPayload = [3, 4, 5, 253]; + CompatibleBinarySchema decodedBinary = binaryReader.Deserialize( + uint8ArrayWriter.Serialize(new CompatibleUInt8ArrayBytesSchema { Value = arrayPayload })); + Assert.Equal(arrayPayload, decodedBinary.Value); + } + + [Fact] + public void CompatibleReadAllowsNullableListSchemaIntoArrayCarrier() { ForyRuntime writer = ForyRuntime.Builder().Compatible(true).Build(); writer.Register(308); ForyRuntime reader = ForyRuntime.Builder().Compatible(true).Build(); reader.Register(308); - byte[] nonNullPayload = writer.Serialize(new CompatibleNullableListSchema { Values = [1, 2] }); - CompatibleArraySchema decoded = reader.Deserialize(nonNullPayload); + byte[] payload = writer.Serialize(new CompatibleNullableListSchema { Values = [1, 2] }); + CompatibleArraySchema decoded = reader.Deserialize(payload); Assert.Equal([1, 2], decoded.Values); - byte[] payload = writer.Serialize(new CompatibleNullableListSchema { Values = [1, null] }); + byte[] nullPayload = writer.Serialize(new CompatibleNullableListSchema { Values = [1, null, 3] }); InvalidDataException exception = - Assert.Throws(() => reader.Deserialize(payload)); - Assert.Contains("compatible list to array field requires non-null elements", exception.Message); + Assert.Throws(() => reader.Deserialize(nullPayload)); + Assert.Contains("non-null elements", exception.Message); } [Fact] - public void CompatibleReadDoesNotMatchNestedListArraySchemaPairs() + public void RejectsNestedListArrayPairs() { List localFields = [ @@ -1301,9 +1382,150 @@ [new TypeMetaFieldType((uint)TypeId.VarInt32, false)]), ])), ]); - TypeMeta.AssignFieldIds(remoteTypeMeta, localFields); + InvalidDataException exception = + Assert.Throws(() => TypeMeta.AssignFieldIds(remoteTypeMeta, localFields)); + Assert.Contains("remote and local field schemas are not compatible", exception.Message); + } - Assert.Equal(-1, remoteTypeMeta.Fields[0].AssignedFieldId); + [Fact] + public void AllowsNestedScalarNullableChanges() + { + List localFields = + [ + new TypeMetaFieldInfo( + null, + "values", + new TypeMetaFieldType( + (uint)TypeId.List, + false, + false, + [new TypeMetaFieldType((uint)TypeId.String, false)])), + ]; + + TypeMeta remoteNullableTypeMeta = new( + (uint)TypeId.CompatibleStruct, + 0, + MetaString.Empty('_', '_'), + new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), + false, + [ + new TypeMetaFieldInfo( + null, + "values", + new TypeMetaFieldType( + (uint)TypeId.List, + false, + false, + [new TypeMetaFieldType((uint)TypeId.String, true)])), + ]); + + TypeMeta.AssignFieldIds(remoteNullableTypeMeta, localFields); + Assert.Equal(1, remoteNullableTypeMeta.Fields[0].AssignedFieldId); + + TypeMeta remoteTrackedTypeMeta = new( + (uint)TypeId.CompatibleStruct, + 0, + MetaString.Empty('_', '_'), + new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), + false, + [ + new TypeMetaFieldInfo( + null, + "values", + new TypeMetaFieldType( + (uint)TypeId.List, + false, + false, + [new TypeMetaFieldType((uint)TypeId.String, false, true)])), + ]); + + Assert.Throws(() => TypeMeta.AssignFieldIds(remoteTrackedTypeMeta, localFields)); + } + + [Fact] + public void RestrictsByteSequenceCompatibility() + { + List localBytes = + [ + new TypeMetaFieldInfo(1, "value", new TypeMetaFieldType((uint)TypeId.Binary, false)), + ]; + TypeMeta remoteUInt8Array = new( + (uint)TypeId.CompatibleStruct, + 0, + MetaString.Empty('_', '_'), + new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), + false, + [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.UInt8Array, false))]); + + TypeMeta.AssignFieldIds(remoteUInt8Array, localBytes); + Assert.Equal(1, remoteUInt8Array.Fields[0].AssignedFieldId); + + TypeMeta remoteInt8Array = new( + (uint)TypeId.CompatibleStruct, + 0, + MetaString.Empty('_', '_'), + new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), + false, + [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.Int8Array, false))]); + + Assert.Throws(() => TypeMeta.AssignFieldIds(remoteInt8Array, localBytes)); + + List localNested = + [ + new TypeMetaFieldInfo( + null, + "value", + new TypeMetaFieldType( + (uint)TypeId.List, + false, + false, + [new TypeMetaFieldType((uint)TypeId.Binary, false)])), + ]; + TypeMeta remoteNested = new( + (uint)TypeId.CompatibleStruct, + 0, + MetaString.Empty('_', '_'), + new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), + false, + [ + new TypeMetaFieldInfo( + null, + "value", + new TypeMetaFieldType( + (uint)TypeId.List, + false, + false, + [new TypeMetaFieldType((uint)TypeId.UInt8Array, false)])), + ]); + + Assert.Throws(() => TypeMeta.AssignFieldIds(remoteNested, localNested)); + } + + [Fact] + public void RejectsFramedListArrayPairs() + { + List localFields = + [ + new TypeMetaFieldInfo(1, "values", new TypeMetaFieldType((uint)TypeId.Int32Array, false)), + ]; + TypeMeta remoteTypeMeta = new( + (uint)TypeId.CompatibleStruct, + 0, + MetaString.Empty('_', '_'), + new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), + false, + [ + new TypeMetaFieldInfo( + 1, + "$tag1", + new TypeMetaFieldType( + (uint)TypeId.List, + true, + false, + [new TypeMetaFieldType((uint)TypeId.VarInt32, false)])), + ]); + + Assert.Throws(() => TypeMeta.AssignFieldIds(remoteTypeMeta, localFields)); } [Fact] @@ -1352,6 +1574,10 @@ public void CompatibleScalarNumberBounds() { Assert.Equal((short)123, CompatibleRead( new ScalarInt32Field { Value = 123 }).Value); + Assert.Equal(123L, CompatibleRead( + new ScalarInt32Field { Value = 123 }).Value); + Assert.Equal((long)int.MaxValue, CompatibleRead( + new ScalarVarInt32Field { Value = int.MaxValue }).Value); Assert.Equal(1, CompatibleRead( new ScalarFloat64Field { Value = 1.0d }).Value); Assert.Equal(double.PositiveInfinity, CompatibleRead( @@ -1365,6 +1591,8 @@ public void CompatibleScalarNumberBounds() new ScalarInt32Field { Value = 16_777_217 })); Assert.Throws(() => CompatibleRead( new ScalarFloat32Field { Value = float.NaN })); + Assert.Throws(() => CompatibleRead( + new ScalarUInt64Field { Value = ulong.MaxValue })); } [Fact] @@ -1449,6 +1677,10 @@ public void CompatibleScalarNullable() new ScalarStringField { Value = null }).Value); Assert.Equal(1, CompatibleRead( new ScalarBoolField { Value = true }).Value); + Assert.Equal(1L, CompatibleRead( + new ScalarBoolField { Value = true }).Value); + Assert.Null(CompatibleRead( + new ScalarNullableInt32Field { Value = null }).Value); } [Fact] @@ -1488,8 +1720,8 @@ public void CompatibleScalarTrackingRefRules() new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), false, [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.String, false, true))]); - TypeMeta.AssignFieldIds(remoteTrackingTypeMeta, localFields); - Assert.Equal(-1, remoteTrackingTypeMeta.Fields[0].AssignedFieldId); + Assert.Throws( + () => TypeMeta.AssignFieldIds(remoteTrackingTypeMeta, localFields)); List localTrackingFields = [ @@ -1502,8 +1734,8 @@ public void CompatibleScalarTrackingRefRules() new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), false, [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.String, false))]); - TypeMeta.AssignFieldIds(remoteTypeMeta, localTrackingFields); - Assert.Equal(-1, remoteTypeMeta.Fields[0].AssignedFieldId); + Assert.Throws( + () => TypeMeta.AssignFieldIds(remoteTypeMeta, localTrackingFields)); TypeMeta remoteBoolTrackingTypeMeta = new( (uint)TypeId.CompatibleStruct, @@ -1512,8 +1744,8 @@ public void CompatibleScalarTrackingRefRules() new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), false, [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.Bool, false, true))]); - TypeMeta.AssignFieldIds(remoteBoolTrackingTypeMeta, localFields); - Assert.Equal(-1, remoteBoolTrackingTypeMeta.Fields[0].AssignedFieldId); + Assert.Throws( + () => TypeMeta.AssignFieldIds(remoteBoolTrackingTypeMeta, localFields)); TypeMeta remoteBoolTypeMeta = new( (uint)TypeId.CompatibleStruct, @@ -1522,8 +1754,8 @@ public void CompatibleScalarTrackingRefRules() new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), false, [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.Bool, false))]); - TypeMeta.AssignFieldIds(remoteBoolTypeMeta, localTrackingFields); - Assert.Equal(-1, remoteBoolTypeMeta.Fields[0].AssignedFieldId); + Assert.Throws( + () => TypeMeta.AssignFieldIds(remoteBoolTypeMeta, localTrackingFields)); TypeMeta remoteBoolBothTrackingTypeMeta = new( (uint)TypeId.CompatibleStruct, @@ -1556,8 +1788,8 @@ public void CompatibleScalarTrackingRefRules() new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), false, [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.Bool, true, true))]); - TypeMeta.AssignFieldIds(remoteBoolTrackingNullableTypeMeta, localTrackingFields); - Assert.Equal(-1, remoteBoolTrackingNullableTypeMeta.Fields[0].AssignedFieldId); + Assert.Throws( + () => TypeMeta.AssignFieldIds(remoteBoolTrackingNullableTypeMeta, localTrackingFields)); TypeMeta remoteBoolTrackingRequiredTypeMeta = new( (uint)TypeId.CompatibleStruct, @@ -1566,8 +1798,8 @@ public void CompatibleScalarTrackingRefRules() new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), false, [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.Bool, false, true))]); - TypeMeta.AssignFieldIds(remoteBoolTrackingRequiredTypeMeta, localNullableTrackingFields); - Assert.Equal(-1, remoteBoolTrackingRequiredTypeMeta.Fields[0].AssignedFieldId); + Assert.Throws( + () => TypeMeta.AssignFieldIds(remoteBoolTrackingRequiredTypeMeta, localNullableTrackingFields)); List localTrackingUIntFields = [ @@ -1580,17 +1812,16 @@ public void CompatibleScalarTrackingRefRules() new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), false, [new TypeMetaFieldInfo(1, "$tag1", new TypeMetaFieldType((uint)TypeId.UInt32, false, true))]); - TypeMeta.AssignFieldIds(remoteFixedUIntTrackingTypeMeta, localTrackingUIntFields); - Assert.Equal(-1, remoteFixedUIntTrackingTypeMeta.Fields[0].AssignedFieldId); + Assert.Throws( + () => TypeMeta.AssignFieldIds(remoteFixedUIntTrackingTypeMeta, localTrackingUIntFields)); } [Fact] - public void NestedScalarConversionNotApplied() + public void NestedScalarConversionRejected() { - ScalarInt32ListField decoded = CompatibleRead( - new ScalarStringListField { Values = ["1"] }); - - Assert.Empty(decoded.Values); + Assert.Throws( + () => CompatibleRead( + new ScalarStringListField { Values = ["1"] })); } [Fact] @@ -2341,7 +2572,7 @@ public void TypeMetaAssignFieldIdsPrefersIdAndFallsBackToName() { List localFields = [ - new TypeMetaFieldInfo(1, "int_value", new TypeMetaFieldType((uint)TypeId.VarInt32, false)), + new TypeMetaFieldInfo(null, "int_value", new TypeMetaFieldType((uint)TypeId.VarInt32, false)), new TypeMetaFieldInfo(2, "name", new TypeMetaFieldType((uint)TypeId.String, true)), ]; List remoteFields = @@ -2359,13 +2590,61 @@ public void TypeMetaAssignFieldIdsPrefersIdAndFallsBackToName() remoteFields); TypeMeta.AssignFieldIds(remoteTypeMeta, localFields); - Assert.Equal(1, remoteTypeMeta.Fields[0].AssignedFieldId); + Assert.Equal(2, remoteTypeMeta.Fields[0].AssignedFieldId); Assert.Equal(0, remoteTypeMeta.Fields[1].AssignedFieldId); Assert.Equal(-1, remoteTypeMeta.Fields[2].AssignedFieldId); } [Fact] - public void TypeMetaAssignFieldIdsSkipsTypeMismatchedField() + public void TypeMetaAssignFieldIdsDoesNotNameMatchTaggedLocalField() + { + List localFields = + [ + new TypeMetaFieldInfo(1, "value", new TypeMetaFieldType((uint)TypeId.String, false)), + ]; + List remoteFields = + [ + new TypeMetaFieldInfo(null, "value", new TypeMetaFieldType((uint)TypeId.String, false)), + ]; + TypeMeta remoteTypeMeta = new( + (uint)TypeId.CompatibleStruct, + 505, + MetaString.Empty('.', '_'), + MetaString.Empty('$', '_'), + registerByName: false, + remoteFields); + + TypeMeta.AssignFieldIds(remoteTypeMeta, localFields); + Assert.Equal(-1, remoteTypeMeta.Fields[0].AssignedFieldId); + } + + [Fact] + public void TypeMetaAssignFieldIdsThrowsOnDuplicateRemoteNameBinding() + { + List localFields = + [ + new TypeMetaFieldInfo(null, "value", new TypeMetaFieldType((uint)TypeId.String, false)), + ]; + List remoteFields = + [ + new TypeMetaFieldInfo(null, "value", new TypeMetaFieldType((uint)TypeId.String, false)), + new TypeMetaFieldInfo(null, "value", new TypeMetaFieldType((uint)TypeId.String, false)), + ]; + TypeMeta remoteTypeMeta = new( + (uint)TypeId.CompatibleStruct, + 506, + MetaString.Empty('.', '_'), + MetaString.Empty('$', '_'), + registerByName: false, + remoteFields); + + InvalidDataException exception = Assert.Throws( + () => TypeMeta.AssignFieldIds(remoteTypeMeta, localFields)); + Assert.Contains("duplicates local field value", exception.Message, StringComparison.Ordinal); + } + + [Fact] + public void RejectsTypeMismatchedField() { List localFields = [ @@ -2383,8 +2662,9 @@ public void TypeMetaAssignFieldIdsSkipsTypeMismatchedField() registerByName: false, remoteFields); - TypeMeta.AssignFieldIds(remoteTypeMeta, localFields); - Assert.Equal(-1, remoteTypeMeta.Fields[0].AssignedFieldId); + InvalidDataException exception = + Assert.Throws(() => TypeMeta.AssignFieldIds(remoteTypeMeta, localFields)); + Assert.Contains("remote and local field schemas are not compatible", exception.Message); } [Fact] @@ -2407,7 +2687,7 @@ public void TypeMetaAssignScalarConversion() remoteFields); TypeMeta.AssignFieldIds(remoteTypeMeta, localFields); - Assert.Equal(0, remoteTypeMeta.Fields[0].AssignedFieldId); + Assert.Equal(1, remoteTypeMeta.Fields[0].AssignedFieldId); } [Fact] diff --git a/dart/packages/fory-test/lib/entity/xlang_test_manual.dart b/dart/packages/fory-test/lib/entity/xlang_test_manual.dart index 21af02627d..f9a551f640 100644 --- a/dart/packages/fory-test/lib/entity/xlang_test_manual.dart +++ b/dart/packages/fory-test/lib/entity/xlang_test_manual.dart @@ -168,19 +168,19 @@ _refOverrideContainerForySchema = GeneratedStructSchema( final class _RefOverrideContainerForySerializer extends Serializer implements GeneratedStructSerializer { - List? _generatedFields; + List? _fieldDescriptors; _RefOverrideContainerForySerializer(); - List _writeFields(WriteContext context) { - return _generatedFields ??= buildGeneratedStructFieldInfos( + List _writeFields(WriteContext context) { + return _fieldDescriptors ??= buildGeneratedStructFieldDescriptors( context.typeResolver, _refOverrideContainerForySchema, ); } - List _readFields(ReadContext context) { - return _generatedFields ??= buildGeneratedStructFieldInfos( + List _readFields(ReadContext context) { + return _fieldDescriptors ??= buildGeneratedStructFieldDescriptors( context.typeResolver, _refOverrideContainerForySchema, ); @@ -189,9 +189,36 @@ final class _RefOverrideContainerForySerializer @override void write(WriteContext context, RefOverrideContainer value) { final fields = _writeFields(context); - writeGeneratedStructFieldInfoValue(context, fields[0], value.listField); - writeGeneratedStructFieldInfoValue(context, fields[1], value.mapField); - writeGeneratedStructFieldInfoValue(context, fields[2], value.setField); + final field0 = fields[0]; + final field0Value = value.listField; + final field0Declared = field0.declaredTypeInfo; + if (field0Declared != null && field0.usesDeclaredType) { + context.writeResolvedValue(field0Declared, field0Value, field0.fieldType); + } else { + final actualResolved = context.typeResolver.resolveValue(field0Value); + context.writeTypeMetaValue(actualResolved, field0Value); + context.writeResolvedValue(actualResolved, field0Value, field0.fieldType); + } + final field1 = fields[1]; + final field1Value = value.mapField; + final field1Declared = field1.declaredTypeInfo; + if (field1Declared != null && field1.usesDeclaredType) { + context.writeResolvedValue(field1Declared, field1Value, field1.fieldType); + } else { + final actualResolved = context.typeResolver.resolveValue(field1Value); + context.writeTypeMetaValue(actualResolved, field1Value); + context.writeResolvedValue(actualResolved, field1Value, field1.fieldType); + } + final field2 = fields[2]; + final field2Value = value.setField; + final field2Declared = field2.declaredTypeInfo; + if (field2Declared != null && field2.usesDeclaredType) { + context.writeResolvedValue(field2Declared, field2Value, field2.fieldType); + } else { + final actualResolved = context.typeResolver.resolveValue(field2Value); + context.writeTypeMetaValue(actualResolved, field2Value); + context.writeResolvedValue(actualResolved, field2Value, field2.fieldType); + } } @override @@ -199,15 +226,15 @@ final class _RefOverrideContainerForySerializer final value = RefOverrideContainer(); final fields = _readFields(context); value.listField = _readRefOverrideContainerListField( - readGeneratedStructFieldInfoValue(context, fields[0], value.listField), + readGeneratedStructDescriptorValue(context, fields[0], value.listField), value.listField, ); value.mapField = _readRefOverrideContainerMapField( - readGeneratedStructFieldInfoValue(context, fields[1], value.mapField), + readGeneratedStructDescriptorValue(context, fields[1], value.mapField), value.mapField, ); value.setField = _readRefOverrideContainerSetField( - readGeneratedStructFieldInfoValue(context, fields[2], value.setField), + readGeneratedStructDescriptorValue(context, fields[2], value.setField), value.setField, ); return value; @@ -219,34 +246,64 @@ final class _RefOverrideContainerForySerializer CompatibleStructReadLayout layout, ) { final value = RefOverrideContainer(); + final fields = _readFields(context); for (var index = 0; index < layout.fieldCount; index += 1) { - final field = layout.localFieldAt(index); - if (field == null) { - skipGeneratedCompatibleStructField(context, layout, index); - continue; - } - switch (field.index) { + final field = layout.fieldAt(index); + switch (field.matchedId) { + case -1: + skipGeneratedCompatibleStructField(context, field); + break; case 0: value.listField = _readRefOverrideContainerListField( - readGeneratedCompatibleStructField(context, layout, index), + readGeneratedStructDescriptorValue( + context, + fields[0], + value.listField, + ), value.listField, ); break; case 1: + value.listField = _readRefOverrideContainerListField( + readGeneratedCompatibleStructField(context, field), + value.listField, + ); + break; + case 2: value.mapField = _readRefOverrideContainerMapField( - readGeneratedCompatibleStructField(context, layout, index), + readGeneratedStructDescriptorValue( + context, + fields[1], + value.mapField, + ), value.mapField, ); break; - case 2: + case 3: + value.mapField = _readRefOverrideContainerMapField( + readGeneratedCompatibleStructField(context, field), + value.mapField, + ); + break; + case 4: + value.setField = _readRefOverrideContainerSetField( + readGeneratedStructDescriptorValue( + context, + fields[2], + value.setField, + ), + value.setField, + ); + break; + case 5: value.setField = _readRefOverrideContainerSetField( - readGeneratedCompatibleStructField(context, layout, index), + readGeneratedCompatibleStructField(context, field), value.setField, ); break; default: throw StateError( - 'Compatible field index is out of range for RefOverrideContainer.', + 'Compatible matched id is out of range for RefOverrideContainer.', ); } } diff --git a/dart/packages/fory/lib/src/codegen/fory_generator.dart b/dart/packages/fory/lib/src/codegen/fory_generator.dart index 97c133ed48..79e33a278d 100644 --- a/dart/packages/fory/lib/src/codegen/fory_generator.dart +++ b/dart/packages/fory/lib/src/codegen/fory_generator.dart @@ -158,12 +158,12 @@ final class ForyGenerator extends Generator { .toList(growable: false); final sortedFields = _sortFields(fields); - final constructorPlan = _buildConstructorPlan(element, sortedFields); + final constructionModel = _buildConstructionModel(element, sortedFields); return _GeneratedStructSpec( name: element.displayName, evolving: evolving, fields: sortedFields, - constructorPlan: constructorPlan, + constructionModel: constructionModel, ); } @@ -372,7 +372,7 @@ final class ForyGenerator extends Generator { ); } - _ConstructorPlan _buildConstructorPlan( + _ConstructionModel _buildConstructionModel( ClassElement element, List<_GeneratedFieldSpec> fields, ) { @@ -384,7 +384,7 @@ final class ForyGenerator extends Generator { (parameter) => parameter.isOptional, ); if (hasZeroArgConstructor && fields.every((field) => field.writable)) { - return const _ConstructorPlan.mutable(); + return const _ConstructionModel.mutable(); } if (unnamedConstructor == null || unnamedConstructor.isFactory) { @@ -449,7 +449,7 @@ final class ForyGenerator extends Generator { .map((field) => field.name) .toList(growable: false); - return _ConstructorPlan.constructor( + return _ConstructionModel.constructor( arguments: arguments, postConstructionFieldNames: postConstructionFields, ); @@ -490,6 +490,9 @@ final class ForyGenerator extends Generator { final hasRuntimeFastPath = structSpec.fields.any( (field) => !_usesDirectGeneratedBasicFastPath(field), ); + final writeNeedsFieldDescriptors = structSpec.fields.any( + _writeUsesFieldDescriptor, + ); final writeUsesBuffer = structSpec.fields.any( _directGeneratedBasicWriteNeedsBuffer, ); @@ -539,15 +542,15 @@ final class ForyGenerator extends Generator { ..writeln( 'final class $serializerClassName extends Serializer<${structSpec.name}> implements GeneratedStructSerializer<${structSpec.name}> {', ) - ..writeln(' List? _generatedFields;') + ..writeln(' List? _fieldDescriptors;') ..writeln() ..writeln(' $serializerClassName();') ..writeln() ..writeln( - ' List _writeFields(WriteContext context) {', + ' List _writeFields(WriteContext context) {', ) ..writeln( - ' return _generatedFields ??= buildGeneratedStructFieldInfos(', + ' return _fieldDescriptors ??= buildGeneratedStructFieldDescriptors(', ) ..writeln(' context.typeResolver,') ..writeln(' $schemaName,') @@ -555,10 +558,10 @@ final class ForyGenerator extends Generator { ..writeln(' }') ..writeln() ..writeln( - ' List _readFields(ReadContext context) {', + ' List _readFields(ReadContext context) {', ) ..writeln( - ' return _generatedFields ??= buildGeneratedStructFieldInfos(', + ' return _fieldDescriptors ??= buildGeneratedStructFieldDescriptors(', ) ..writeln(' context.typeResolver,') ..writeln(' $schemaName,') @@ -571,7 +574,7 @@ final class ForyGenerator extends Generator { if (writeUsesBuffer) { output.writeln(' final buffer = context.buffer;'); } - if (hasRuntimeFastPath) { + if (writeNeedsFieldDescriptors) { output.writeln(' final fields = _writeFields(context);'); } for (var index = 0; index < structSpec.fields.length; index += 1) { @@ -603,12 +606,12 @@ final class ForyGenerator extends Generator { ' ${_directGeneratedTypedContainerWriteStatement(field, index, 'value.${field.name}')};', ); } else { - final fieldValue = _generatedFieldInfoWriteValueExpression( + _writeGeneratedDescriptorValue( + output, field, + index, 'value.${field.name}', - ); - output.writeln( - ' writeGeneratedStructFieldInfoValue(context, fields[$index], $fieldValue);', + ' ', ); } final directPrimitiveEndRun = directPrimitiveRunByEnd[index]; @@ -626,7 +629,7 @@ final class ForyGenerator extends Generator { ..writeln(' @override') ..writeln(' ${structSpec.name} read(ReadContext context) {'); - switch (structSpec.constructorPlan.mode) { + switch (structSpec.constructionModel.mode) { case _ConstructorMode.mutable: output.writeln(' final value = ${structSpec.name}();'); if (_structNeedsEarlyReadReference(structSpec)) { @@ -680,7 +683,7 @@ final class ForyGenerator extends Generator { ); } else { output.writeln( - ' value.${field.name} = $readerFunctionName(readGeneratedStructFieldInfoValue(context, fields[$index], value.${field.name}), value.${field.name});', + ' value.${field.name} = $readerFunctionName(readGeneratedStructDescriptorValue(context, fields[$index], value.${field.name}), value.${field.name});', ); } final directPrimitiveEndRun = directPrimitiveRunByEnd[index]; @@ -739,7 +742,7 @@ final class ForyGenerator extends Generator { ); } else { output.writeln( - ' final ${field.displayType} ${field.localName} = $readerFunctionName(readGeneratedStructFieldInfoValue(context, fields[$index]));', + ' final ${field.displayType} ${field.localName} = $readerFunctionName(readGeneratedStructDescriptorValue(context, fields[$index]));', ); } final directPrimitiveEndRun = directPrimitiveRunByEnd[index]; @@ -754,7 +757,7 @@ final class ForyGenerator extends Generator { final constructorInvocation = _constructorInvocation(structSpec); output.writeln(' final value = $constructorInvocation;'); for (final fieldName - in structSpec.constructorPlan.postConstructionFieldNames) { + in structSpec.constructionModel.postConstructionFieldNames) { final field = structSpec.fields.firstWhere( (item) => item.name == fieldName, ); @@ -764,7 +767,12 @@ final class ForyGenerator extends Generator { } output.writeln(' }'); - _writeCompatibleStructReadMethod(output, structSpec); + _writeCompatibleStructReadMethod( + output, + structSpec, + readUsesBuffer: readUsesBuffer, + hasRuntimeFastPath: hasRuntimeFastPath, + ); output ..writeln('}') ..writeln(); @@ -788,15 +796,18 @@ final class ForyGenerator extends Generator { void _writeCompatibleStructReadMethod( StringBuffer output, - _GeneratedStructSpec structSpec, - ) { + _GeneratedStructSpec structSpec, { + required bool readUsesBuffer, + required bool hasRuntimeFastPath, + }) { + final splitFallback = structSpec.fields.length >= 16; output ..writeln() ..writeln(' @override') ..writeln( ' ${structSpec.name} readCompatibleStruct(ReadContext context, CompatibleStructReadLayout layout) {', ); - switch (structSpec.constructorPlan.mode) { + switch (structSpec.constructionModel.mode) { case _ConstructorMode.mutable: output.writeln(' final value = ${structSpec.name}();'); if (_structNeedsEarlyReadReference(structSpec)) { @@ -805,93 +816,317 @@ final class ForyGenerator extends Generator { ..writeln(' context.reference(value);') ..writeln(' }'); } - output - ..writeln( - ' for (var index = 0; index < layout.fieldCount; index += 1) {', - ) - ..writeln(' final field = layout.localFieldAt(index);') - ..writeln(' if (field == null) {') - ..writeln( - ' skipGeneratedCompatibleStructField(context, layout, index);', - ) - ..writeln(' continue;') - ..writeln(' }') - ..writeln(' switch (field.index) {'); - for (var index = 0; index < structSpec.fields.length; index += 1) { - final field = structSpec.fields[index]; - final readerFunctionName = field.readerFunctionName(structSpec.name); - output - ..writeln(' case $index:') - ..writeln( - ' value.${field.name} = $readerFunctionName(readGeneratedCompatibleStructField(context, layout, index), value.${field.name});', - ) - ..writeln(' break;'); + if (!splitFallback && readUsesBuffer) { + output.writeln(' final buffer = context.buffer;'); } - output - ..writeln(' default:') - ..writeln( - " throw StateError('Compatible field index is out of range for ${structSpec.name}.');", - ) - ..writeln(' }') - ..writeln(' }') - ..writeln(' return value;'); - case _ConstructorMode.constructor: - for (var index = 0; index < structSpec.fields.length; index += 1) { - final field = structSpec.fields[index]; - output - ..writeln(' late final ${field.displayType} ${field.localName};') - ..writeln(' var hasField$index = false;'); + if (!splitFallback && hasRuntimeFastPath) { + output.writeln(' final fields = _readFields(context);'); } - output - ..writeln( - ' for (var index = 0; index < layout.fieldCount; index += 1) {', - ) - ..writeln(' final field = layout.localFieldAt(index);') - ..writeln(' if (field == null) {') - ..writeln( - ' skipGeneratedCompatibleStructField(context, layout, index);', - ) - ..writeln(' continue;') - ..writeln(' }') - ..writeln(' switch (field.index) {'); - for (var index = 0; index < structSpec.fields.length; index += 1) { - final field = structSpec.fields[index]; - final readerFunctionName = field.readerFunctionName(structSpec.name); - output - ..writeln(' case $index:') - ..writeln( - ' ${field.localName} = $readerFunctionName(readGeneratedCompatibleStructField(context, layout, index));', - ) - ..writeln(' hasField$index = true;') - ..writeln(' break;'); + if (splitFallback) { + output.writeln( + ' return _readCompatibleStructFallback(context, layout, value);', + ); + } else { + _writeMutableCompatibleSwitchFallback(output, structSpec); } - output - ..writeln(' default:') - ..writeln( - " throw StateError('Compatible field index is out of range for ${structSpec.name}.');", - ) - ..writeln(' }') - ..writeln(' }'); - for (var index = 0; index < structSpec.fields.length; index += 1) { - final field = structSpec.fields[index]; - final readerFunctionName = field.readerFunctionName(structSpec.name); - output - ..writeln(' if (!hasField$index) {') - ..writeln(' ${field.localName} = $readerFunctionName(null);') - ..writeln(' }'); + case _ConstructorMode.constructor: + if (!splitFallback && readUsesBuffer) { + output.writeln(' final buffer = context.buffer;'); } - final constructorInvocation = _constructorInvocation(structSpec); - output.writeln(' final value = $constructorInvocation;'); - for (final fieldName - in structSpec.constructorPlan.postConstructionFieldNames) { - final field = structSpec.fields.firstWhere( - (item) => item.name == fieldName, + if (!splitFallback && hasRuntimeFastPath) { + output.writeln(' final fields = _readFields(context);'); + } + if (splitFallback) { + output.writeln( + ' return _readCompatibleStructFallback(context, layout);', ); - output.writeln(' value.${field.name} = ${field.localName};'); + } else { + _writeCtorCompatSwitch(output, structSpec); } - output.writeln(' return value;'); } output.writeln(' }'); + if (splitFallback) { + _writeCompatibleStructFallbackMethod( + output, + structSpec, + readUsesBuffer: readUsesBuffer, + hasRuntimeFastPath: hasRuntimeFastPath, + ); + } + } + + void _writeCompatibleStructFallbackMethod( + StringBuffer output, + _GeneratedStructSpec structSpec, { + required bool readUsesBuffer, + required bool hasRuntimeFastPath, + }) { + final mutable = + structSpec.constructionModel.mode == _ConstructorMode.mutable; + final valueParameter = mutable ? ', ${structSpec.name} value' : ''; + output + ..writeln() + ..writeln(" @pragma('vm:never-inline')") + ..writeln( + ' ${structSpec.name} _readCompatibleStructFallback(ReadContext context, CompatibleStructReadLayout layout$valueParameter) {', + ); + if (readUsesBuffer) { + output.writeln(' final buffer = context.buffer;'); + } + if (hasRuntimeFastPath) { + output.writeln(' final fields = _readFields(context);'); + } + if (mutable) { + _writeMutableCompatibleSwitchFallback(output, structSpec); + } else { + _writeCtorCompatSwitch(output, structSpec); + } + output.writeln(' }'); + } + + void _writeMutableCompatibleSwitchFallback( + StringBuffer output, + _GeneratedStructSpec structSpec, + ) { + output + ..writeln( + ' for (var index = 0; index < layout.fieldCount; index += 1) {', + ) + ..writeln(' final field = layout.fieldAt(index);') + ..writeln(' switch (field.matchedId) {') + ..writeln(' case -1:') + ..writeln(' skipGeneratedCompatibleStructField(context, field);') + ..writeln(' break;'); + for (var index = 0; index < structSpec.fields.length; index += 1) { + final field = structSpec.fields[index]; + output.writeln(' case ${index * 2}:'); + _writeExactCompatRead( + output, + structSpec, + field, + index, + 'value.${field.name}', + 'value.${field.name}', + ' ', + ); + output + ..writeln(' break;') + ..writeln(' case ${index * 2 + 1}:'); + _writeCompatConversionRead( + output, + structSpec, + field, + index, + 'value.${field.name}', + 'value.${field.name}', + 'field', + ' ', + ); + output.writeln(' break;'); + } + output + ..writeln(' default:') + ..writeln( + " throw StateError('Compatible matched id is out of range for ${structSpec.name}.');", + ) + ..writeln(' }') + ..writeln(' }') + ..writeln(' return value;'); + } + + void _writeCtorCompatSwitch( + StringBuffer output, + _GeneratedStructSpec structSpec, + ) { + for (var index = 0; index < structSpec.fields.length; index += 1) { + final field = structSpec.fields[index]; + output + ..writeln(' late final ${field.displayType} ${field.localName};') + ..writeln(' var hasField$index = false;'); + } + output + ..writeln( + ' for (var index = 0; index < layout.fieldCount; index += 1) {', + ) + ..writeln(' final field = layout.fieldAt(index);') + ..writeln(' switch (field.matchedId) {') + ..writeln(' case -1:') + ..writeln(' skipGeneratedCompatibleStructField(context, field);') + ..writeln(' break;'); + for (var index = 0; index < structSpec.fields.length; index += 1) { + final field = structSpec.fields[index]; + output.writeln(' case ${index * 2}:'); + _writeExactCompatRead( + output, + structSpec, + field, + index, + field.localName, + null, + ' ', + ); + output + ..writeln(' hasField$index = true;') + ..writeln(' break;') + ..writeln(' case ${index * 2 + 1}:'); + _writeCompatConversionRead( + output, + structSpec, + field, + index, + field.localName, + null, + 'field', + ' ', + ); + output + ..writeln(' hasField$index = true;') + ..writeln(' break;'); + } + output + ..writeln(' default:') + ..writeln( + " throw StateError('Compatible matched id is out of range for ${structSpec.name}.');", + ) + ..writeln(' }') + ..writeln(' }'); + for (var index = 0; index < structSpec.fields.length; index += 1) { + final field = structSpec.fields[index]; + final readerFunctionName = field.readerFunctionName(structSpec.name); + output + ..writeln(' if (!hasField$index) {') + ..writeln(' ${field.localName} = $readerFunctionName(null);') + ..writeln(' }'); + } + final constructorInvocation = _constructorInvocation(structSpec); + output.writeln(' final value = $constructorInvocation;'); + for (final fieldName + in structSpec.constructionModel.postConstructionFieldNames) { + final field = structSpec.fields.firstWhere( + (item) => item.name == fieldName, + ); + output.writeln(' value.${field.name} = ${field.localName};'); + } + output.writeln(' return value;'); + } + + void _writeCompatConversionRead( + StringBuffer output, + _GeneratedStructSpec structSpec, + _GeneratedFieldSpec field, + int index, + String target, + String? fallback, + String readField, + String indent, + ) { + final directScalarRead = _directCompatibleScalarRead(field); + if (directScalarRead != null) { + final scalarRead = 'scalarRead$index'; + final fallbackArg = fallback == null ? '' : ', $fallback'; + output.writeln('${indent}final $scalarRead = $readField.scalarRead!;'); + output.writeln( + '$indent$target = ${directScalarRead.method}(context, $scalarRead$fallbackArg);', + ); + return; + } + final readerFunctionName = field.readerFunctionName(structSpec.name); + output.writeln( + '$indent$target = ${_readerCall(readerFunctionName, 'readGeneratedCompatibleStructField(context, $readField)', fallback)};', + ); + } + + void _writeExactCompatRead( + StringBuffer output, + _GeneratedStructSpec structSpec, + _GeneratedFieldSpec field, + int index, + String target, + String? fallback, + String indent, + ) { + final readerFunctionName = field.readerFunctionName(structSpec.name); + if (_usesDirectGeneratedBasicFastPath(field)) { + output.writeln( + '$indent$target = ${_directGeneratedReadExpression(field)};', + ); + return; + } + if (_usesDirectGeneratedTypedContainerReadFastPath(field)) { + output.writeln( + '$indent$target = ${_directGeneratedTypedContainerReadExpression(structSpec.name, field, 'fields[$index]')};', + ); + return; + } + if (_usesDirectGeneratedStructFieldFastPath(field)) { + output.writeln( + '$indent$target = ${_readerCall(readerFunctionName, 'readGeneratedStructDirectValue(context, fields[$index])', fallback)};', + ); + return; + } + if (_usesDirectGeneratedDeclaredReadFastPath(field)) { + output.writeln( + '$indent$target = ${_readerCall(readerFunctionName, 'readGeneratedStructDeclaredValue(context, fields[$index])', fallback)};', + ); + return; + } + final valueExpression = + fallback == null + ? 'readGeneratedStructDescriptorValue(context, fields[$index])' + : 'readGeneratedStructDescriptorValue(context, fields[$index], $fallback)'; + output.writeln('$indent$target = $readerFunctionName($valueExpression);'); + } + + _DirectCompatibleScalarRead? _directCompatibleScalarRead( + _GeneratedFieldSpec field, + ) { + if (field.fieldType.nullable || + field.fieldType.ref || + _isGeneratedDynamicField(field)) { + return null; + } + final typeLiteral = _typeLiteral(_withoutNullability(field.type)); + if (typeLiteral == 'int' && + (field.fieldType.typeId == TypeIds.int8 || + field.fieldType.typeId == TypeIds.int16 || + field.fieldType.typeId == TypeIds.int32 || + field.fieldType.typeId == TypeIds.varInt32 || + field.fieldType.typeId == TypeIds.int64 || + field.fieldType.typeId == TypeIds.varInt64 || + field.fieldType.typeId == TypeIds.taggedInt64 || + field.fieldType.typeId == TypeIds.uint8 || + field.fieldType.typeId == TypeIds.uint16 || + field.fieldType.typeId == TypeIds.uint32 || + field.fieldType.typeId == TypeIds.varUint32)) { + return const _DirectCompatibleScalarRead('readGenCompatScalarAsInt'); + } + if (typeLiteral == 'double' && + (field.fieldType.typeId == TypeIds.float32 || + field.fieldType.typeId == TypeIds.float64)) { + return const _DirectCompatibleScalarRead('readGenCompatScalarAsDouble'); + } + if (typeLiteral == 'Int64' && _isSigned64TypeId(field.fieldType.typeId)) { + return const _DirectCompatibleScalarRead('readGenCompatScalarAsInt64'); + } + if (typeLiteral == 'Uint64' && + _isUnsigned64TypeId(field.fieldType.typeId)) { + return const _DirectCompatibleScalarRead('readGenCompatScalarAsUint64'); + } + if (typeLiteral == 'Float32' && field.fieldType.typeId == TypeIds.float32) { + return const _DirectCompatibleScalarRead('readGenCompatScalarAsFloat32'); + } + return null; + } + + String _readerCall( + String functionName, + String valueExpression, + String? fallback, + ) { + if (fallback == null) { + return '$functionName($valueExpression)'; + } + return '$functionName($valueExpression, $fallback)'; } void _writeGeneratedSupport( @@ -966,7 +1201,7 @@ final class ForyGenerator extends Generator { String _constructorInvocation(_GeneratedStructSpec structSpec) { final positionalArguments = []; final namedArguments = []; - for (final argument in structSpec.constructorPlan.arguments) { + for (final argument in structSpec.constructionModel.arguments) { final field = structSpec.fields.firstWhere( (item) => item.name == argument.fieldName, ); @@ -1081,7 +1316,7 @@ GeneratedFieldType( if (_isNullable(type)) { return valueExpression; } - return '$valueExpression == null ? $nullExpression : $valueExpression'; + return '$valueExpression ?? $nullExpression'; } if (_isNullable(type)) { final nonNullableType = _withoutNullability(type); @@ -1222,7 +1457,7 @@ GeneratedFieldType( bool _usesDirectGeneratedBasicFastPath(_GeneratedFieldSpec field) { if (field.fieldType.nullable || field.fieldType.ref || - field.fieldType.dynamic == true) { + _isGeneratedDynamicField(field)) { return false; } return _isPrimitiveTypeId(field.fieldType.typeId) || @@ -1294,7 +1529,7 @@ GeneratedFieldType( bool _usesDirectGeneratedDeclaredReadFastPath(_GeneratedFieldSpec field) { if (field.fieldType.nullable || field.fieldType.ref || - field.fieldType.dynamic == true) { + _isGeneratedDynamicField(field)) { return false; } final typeId = field.fieldType.typeId; @@ -1304,7 +1539,10 @@ GeneratedFieldType( bool _usesDirectGeneratedStructFieldFastPath(_GeneratedFieldSpec field) { if (field.fieldType.nullable || field.fieldType.ref || - field.fieldType.dynamic == true) { + _isGeneratedDynamicField(field)) { + return false; + } + if (!_isGeneratedStructType(field.type)) { return false; } final typeId = field.fieldType.typeId; @@ -1319,7 +1557,7 @@ GeneratedFieldType( ) { if (field.fieldType.nullable || field.fieldType.ref || - field.fieldType.dynamic == true) { + _isGeneratedDynamicField(field)) { return false; } if (_isBoolList(field.type)) { @@ -1330,12 +1568,18 @@ GeneratedFieldType( field.fieldType.typeId == TypeIds.map; } + bool _isGeneratedStructType(DartType type) { + final element = _withoutNullability(type).element; + return element is ClassElement && + _foryStructChecker.hasAnnotationOf(element); + } + bool _usesDirectGeneratedTypedContainerWriteFastPath( _GeneratedFieldSpec field, ) { if (field.fieldType.nullable || field.fieldType.ref || - field.fieldType.dynamic == true) { + _isGeneratedDynamicField(field)) { return false; } if (_isBoolList(field.type)) { @@ -1346,13 +1590,29 @@ GeneratedFieldType( return false; } final elementFieldType = field.fieldType.arguments.single; - if (elementFieldType.ref || elementFieldType.dynamic == true) { + if (elementFieldType.ref || _isGeneratedDynamicType(elementFieldType)) { return false; } final elementType = (field.type as InterfaceType).typeArguments.single; return !_isNullable(elementType); } + bool _writeUsesFieldDescriptor(_GeneratedFieldSpec field) { + if (_usesDirectGeneratedBasicFastPath(field)) { + return false; + } + if (_usesDirectGeneratedTypedContainerWriteFastPath(field)) { + return true; + } + return !_isGeneratedDynamicField(field); + } + + bool _isGeneratedDynamicField(_GeneratedFieldSpec field) => + _isGeneratedDynamicType(field.fieldType); + + bool _isGeneratedDynamicType(_GeneratedFieldTypeSpec fieldType) => + fieldType.dynamic == true || fieldType.typeId == TypeIds.unknown; + String _directGeneratedTypedContainerWriteStatement( _GeneratedFieldSpec field, int fieldIndex, @@ -1372,14 +1632,16 @@ GeneratedFieldType( } List<_DirectGeneratedPrimitiveRun> _directGeneratedPrimitiveRuns( - List<_GeneratedFieldSpec> fields, - ) { + List<_GeneratedFieldSpec> fields, { + bool includeCoreInt64Varints = false, + }) { final runs = <_DirectGeneratedPrimitiveRun>[]; int? start; var bytes = 0; for (var index = 0; index < fields.length; index += 1) { final fieldBytes = _directGeneratedPrimitiveReservationBytes( fields[index], + includeCoreInt64Varints: includeCoreInt64Varints, ); if (fieldBytes == null) { if (start != null) { @@ -1398,11 +1660,21 @@ GeneratedFieldType( return runs; } - bool _usesReservedGeneratedFastPath(_GeneratedFieldSpec field) { - return _directGeneratedPrimitiveReservationBytes(field) != null; + bool _usesReservedGeneratedFastPath( + _GeneratedFieldSpec field, { + bool includeCoreInt64Varints = false, + }) { + return _directGeneratedPrimitiveReservationBytes( + field, + includeCoreInt64Varints: includeCoreInt64Varints, + ) != + null; } - int? _directGeneratedPrimitiveReservationBytes(_GeneratedFieldSpec field) { + int? _directGeneratedPrimitiveReservationBytes( + _GeneratedFieldSpec field, { + bool includeCoreInt64Varints = false, + }) { if (!_usesDirectGeneratedBasicFastPath(field)) { return null; } @@ -1425,6 +1697,9 @@ GeneratedFieldType( case TypeIds.varInt32: case TypeIds.varUint32: return 5; + case TypeIds.varInt64: + case TypeIds.varUint64: + return includeCoreInt64Varints && field.type.isDartCoreInt ? 9 : null; default: return null; } @@ -1438,7 +1713,9 @@ GeneratedFieldType( switch (fields[index].fieldType.typeId) { case TypeIds.boolType: case TypeIds.varInt32: + case TypeIds.varInt64: case TypeIds.varUint32: + case TypeIds.varUint64: return true; } } @@ -1453,7 +1730,9 @@ GeneratedFieldType( switch (fields[index].fieldType.typeId) { case TypeIds.boolType: case TypeIds.varInt32: + case TypeIds.varInt64: case TypeIds.varUint32: + case TypeIds.varUint64: break; default: return true; @@ -1713,6 +1992,36 @@ GeneratedFieldType( output.writeln( '$indent$target = (($result >>> 1) ^ -($result & 1)).toSigned(32);', ); + case TypeIds.varInt64: + if (!field.type.isDartCoreInt) { + throw StateError( + 'Generated varint64 cursor read is only supported for Dart int fields.', + ); + } + final assignTarget = _writeGeneratedReadAssignmentTarget( + output, + target, + indent, + ); + final result = 'result$fieldIndex'; + output + ..writeln('$indent if (generatedIsWeb) {') + ..writeln('$indent bufferSetReaderIndex(buffer, $offset);') + ..writeln('$indent $assignTarget = buffer.readVarInt64AsInt();') + ..writeln('$indent $offset = bufferReaderIndex(buffer);') + ..writeln('$indent } else {'); + _writeDirectGeneratedVarUint64Read( + output, + result, + bytes, + offset, + '$indent ', + ); + output + ..writeln( + '$indent $assignTarget = ($result >>> 1) ^ -($result & 1);', + ) + ..writeln('$indent }'); case TypeIds.varUint32: final result = 'result$fieldIndex'; _writeDirectGeneratedVarUint32Read( @@ -1723,6 +2032,34 @@ GeneratedFieldType( indent, ); output.writeln('$indent$target = $result;'); + case TypeIds.varUint64: + if (!field.type.isDartCoreInt) { + throw StateError( + 'Generated varuint64 cursor read is only supported for Dart int fields.', + ); + } + final assignTarget = _writeGeneratedReadAssignmentTarget( + output, + target, + indent, + ); + final result = 'result$fieldIndex'; + output + ..writeln('$indent if (generatedIsWeb) {') + ..writeln('$indent bufferSetReaderIndex(buffer, $offset);') + ..writeln('$indent $assignTarget = buffer.readVarUint64().toInt();') + ..writeln('$indent $offset = bufferReaderIndex(buffer);') + ..writeln('$indent } else {'); + _writeDirectGeneratedVarUint64Read( + output, + result, + bytes, + offset, + '$indent ', + ); + output + ..writeln('$indent $assignTarget = $result;') + ..writeln('$indent }'); default: throw StateError( 'Unsupported generated direct buffer read fast path for ${field.name}.', @@ -1730,6 +2067,23 @@ GeneratedFieldType( } } + String _writeGeneratedReadAssignmentTarget( + StringBuffer output, + String target, + String indent, + ) { + if (!target.startsWith('final ')) { + return target; + } + final nameStart = target.lastIndexOf(' ') + 1; + final name = target.substring(nameStart); + final typePrefix = target + .substring(0, nameStart) + .replaceFirst('final ', 'late final '); + output.writeln('$indent$typePrefix$name;'); + return name; + } + void _writeDirectGeneratedVarUint32Read( StringBuffer output, String result, @@ -1753,6 +2107,34 @@ GeneratedFieldType( ..writeln('$indent }'); } + void _writeDirectGeneratedVarUint64Read( + StringBuffer output, + String result, + String bytes, + String offset, + String indent, + ) { + final shift = '${result}Shift'; + final byte = '${result}Byte'; + output + ..writeln('$indent var $shift = 0;') + ..writeln('$indent var $result = 0;') + ..writeln('$indent while ($shift < 56) {') + ..writeln('$indent final $byte = $bytes[$offset];') + ..writeln('$indent $offset += 1;') + ..writeln('$indent $result |= ($byte & 0x7f) << $shift;') + ..writeln('$indent if (($byte & 0x80) == 0) {') + ..writeln('$indent break;') + ..writeln('$indent }') + ..writeln('$indent $shift += 7;') + ..writeln('$indent }') + ..writeln('$indent if ($shift == 56) {') + ..writeln('$indent final $byte = $bytes[$offset];') + ..writeln('$indent $offset += 1;') + ..writeln('$indent $result |= $byte << 56;') + ..writeln('$indent }'); + } + String _directGeneratedWriteStatement( _GeneratedFieldSpec field, String valueExpression, @@ -2009,10 +2391,134 @@ GeneratedFieldType( } } - String _generatedFieldInfoWriteValueExpression( + void _writeGeneratedDescriptorValue( + StringBuffer output, + _GeneratedFieldSpec field, + int index, + String valueExpression, + String indent, + ) { + final value = 'field${index}Value'; + output.writeln('${indent}final $value = $valueExpression;'); + if (_isGeneratedDynamicField(field)) { + _writeGeneratedDynamicValue(output, field, value, indent); + return; + } + final descriptor = 'field$index'; + final fieldType = 'field${index}Type'; + output + ..writeln('${indent}final $descriptor = fields[$index];') + ..writeln('${indent}final $fieldType = $descriptor.fieldType;') + ..writeln( + '${indent}final field${index}Declared = $descriptor.declaredTypeInfo;', + ) + ..writeln( + '${indent}if (field${index}Declared != null && $descriptor.usesDeclaredType) {', + ); + _writeGeneratedDeclaredValue( + output, + field, + resolved: 'field${index}Declared', + fieldType: fieldType, + value: value, + indent: '$indent ', + ); + output.writeln('$indent} else {'); + _writeGeneratedUndeclaredValue( + output, + field, + fieldType: fieldType, + value: value, + indent: '$indent ', + ); + output.writeln('$indent}'); + } + + void _writeGeneratedDynamicValue( + StringBuffer output, _GeneratedFieldSpec field, String valueExpression, - ) => valueExpression; + String indent, + ) { + if (field.fieldType.ref) { + output.writeln('${indent}context.writeRef($valueExpression);'); + return; + } + output + ..writeln( + '${indent}if (!context.refWriter.writeRefOrNull(context.buffer, $valueExpression, trackRef: false)) {', + ) + ..writeln( + '$indent context.writeNonRef(${_nonNullObjectExpression(field, valueExpression)});', + ) + ..writeln('$indent}'); + } + + void _writeGeneratedDeclaredValue( + StringBuffer output, + _GeneratedFieldSpec field, { + required String resolved, + required String fieldType, + required String value, + required String indent, + }) { + if (field.fieldType.nullable || field.fieldType.ref) { + final trackRef = field.fieldType.ref ? '$resolved.supportsRef' : 'false'; + output + ..writeln( + '${indent}if (!context.refWriter.writeRefOrNull(context.buffer, $value, trackRef: $trackRef)) {', + ) + ..writeln( + '$indent context.writeResolvedValue($resolved, $value as Object, $fieldType);', + ) + ..writeln('$indent}'); + return; + } + output + ..writeln('${indent}if ($value == null) {') + ..writeln( + "$indent throw StateError('Field ${field.name} is not nullable.');", + ) + ..writeln('$indent}') + ..writeln( + '${indent}context.writeResolvedValue($resolved, $value as Object, $fieldType);', + ); + } + + void _writeGeneratedUndeclaredValue( + StringBuffer output, + _GeneratedFieldSpec field, { + required String fieldType, + required String value, + required String indent, + }) { + if (field.fieldType.ref) { + output.writeln('${indent}context.writeRef($value);'); + return; + } + if (field.fieldType.nullable) { + output + ..writeln( + '${indent}if (!context.refWriter.writeRefOrNull(context.buffer, $value, trackRef: false)) {', + ) + ..writeln('$indent context.writeNonRef($value as Object);') + ..writeln('$indent}'); + return; + } + output + ..writeln('${indent}if ($value == null) {') + ..writeln( + "$indent throw StateError('Field ${field.name} is not nullable.');", + ) + ..writeln('$indent}') + ..writeln( + '${indent}final actualResolved = context.typeResolver.resolveValue($value as Object);', + ) + ..writeln('${indent}context.writeTypeMetaValue(actualResolved, $value);') + ..writeln( + '${indent}context.writeResolvedValue(actualResolved, $value, $fieldType);', + ); + } String _nullExpression( DartType type, { @@ -2024,11 +2530,25 @@ GeneratedFieldType( return 'null as $displayType'; } if (fallbackExpression != null) { + if (_withoutNullability(type).isDartCoreObject) { + return '($fallbackExpression ?? (throw StateError(\'Received null for non-nullable $errorTarget.\')))'; + } return '($fallbackExpression != null ? $fallbackExpression as $displayType : (throw StateError(\'Received null for non-nullable $errorTarget.\')))'; } return '(throw StateError(\'Received null for non-nullable $errorTarget.\'))'; } + String _nonNullObjectExpression( + _GeneratedFieldSpec field, + String valueExpression, + ) { + if (_withoutNullability(field.type).isDartCoreObject && + !_isNullable(field.type)) { + return valueExpression; + } + return '$valueExpression as Object'; + } + _GeneratedFieldTypeSpec _nonNullableFieldType( _GeneratedFieldTypeSpec fieldType, ) { @@ -3226,7 +3746,8 @@ GeneratedFieldType( } bool _fieldTypeUsesNestedTypeDefinitions(_GeneratedFieldTypeSpec fieldType) { - if (fieldType.dynamic == true || TypeIds.isUserType(fieldType.typeId)) { + if (_isGeneratedDynamicType(fieldType) || + TypeIds.isUserType(fieldType.typeId)) { return true; } for (final argument in fieldType.arguments) { @@ -3249,16 +3770,22 @@ final class _GeneratedStructSpec { final String name; final bool evolving; final List<_GeneratedFieldSpec> fields; - final _ConstructorPlan constructorPlan; + final _ConstructionModel constructionModel; const _GeneratedStructSpec({ required this.name, required this.evolving, required this.fields, - required this.constructorPlan, + required this.constructionModel, }); } +final class _DirectCompatibleScalarRead { + final String method; + + const _DirectCompatibleScalarRead(this.method); +} + final class _GeneratedFieldSpec { final String name; final DartType type; @@ -3321,17 +3848,17 @@ final class _GeneratedFieldTypeSpec { enum _ConstructorMode { mutable, constructor } -final class _ConstructorPlan { +final class _ConstructionModel { final _ConstructorMode mode; final List<_ConstructorArgumentSpec> arguments; final List postConstructionFieldNames; - const _ConstructorPlan.mutable() + const _ConstructionModel.mutable() : mode = _ConstructorMode.mutable, arguments = const <_ConstructorArgumentSpec>[], postConstructionFieldNames = const []; - const _ConstructorPlan.constructor({ + const _ConstructionModel.constructor({ required this.arguments, required this.postConstructionFieldNames, }) : mode = _ConstructorMode.constructor; diff --git a/dart/packages/fory/lib/src/codegen/generated_support.dart b/dart/packages/fory/lib/src/codegen/generated_support.dart index 58c8167945..709f9cd3e1 100644 --- a/dart/packages/fory/lib/src/codegen/generated_support.dart +++ b/dart/packages/fory/lib/src/codegen/generated_support.dart @@ -122,6 +122,23 @@ final class GeneratedEnumSchema { @internal typedef GeneratedStructFieldInfo = SerializationFieldInfo; +@internal +final class GeneratedStructFieldDescriptor { + final GeneratedStructFieldInfo field; + final resolver.TypeInfo? declaredTypeInfo; + final bool usesDeclaredType; + + const GeneratedStructFieldDescriptor({ + required this.field, + required this.declaredTypeInfo, + required this.usesDeclaredType, + }); + + String get name => field.name; + + meta_types.FieldType get fieldType => field.fieldType; +} + @internal final class GeneratedStructSchema { final Type type; @@ -222,7 +239,9 @@ int generatedCheckedUint16(int value) => checkedUint16(value); int generatedCheckedUint32(int value) => checkedUint32(value); const int _generatedJsSafeUint64IntMax = 9007199254740991; -const bool _generatedIsWeb = bool.fromEnvironment('dart.library.js_interop'); +@internal +const bool generatedIsWeb = bool.fromEnvironment('dart.library.js_interop'); +const bool _generatedIsWeb = generatedIsWeb; @internal @pragma('vm:prefer-inline') @@ -366,6 +385,22 @@ List buildGeneratedStructFieldInfos( .localFields; } +@internal +List buildGeneratedStructFieldDescriptors( + resolver.TypeResolver typeResolver, + GeneratedStructSchema schema, +) { + final fields = buildGeneratedStructFieldInfos(typeResolver, schema); + return List.generate(fields.length, (index) { + final field = fields[index]; + return GeneratedStructFieldDescriptor( + field: field, + declaredTypeInfo: fieldDeclaredTypeInfo(typeResolver, field), + usesDeclaredType: fieldUsesDeclaredType(typeResolver, field), + ); + }, growable: false); +} + @internal List buildGeneratedUnionCaseFieldInfos( List fields, @@ -441,36 +476,11 @@ Object? readGeneratedUnionCaseValue( return value; } -@internal -void writeGeneratedStructFieldInfoValue( - WriteContext context, - GeneratedStructFieldInfo field, - Object? value, -) { - final fieldType = field.fieldType; - if (!fieldType.isDynamic && !fieldType.ref && !fieldType.nullable) { - if (fieldType.isPrimitive) { - context.writePrimitiveValue(fieldType.typeId, value as Object); - return; - } - final resolved = fieldDeclaredTypeInfo(context.typeResolver, field)!; - if (fieldUsesDeclaredType(context.typeResolver, field)) { - context.writeResolvedValue(resolved, value as Object, fieldType); - return; - } - final actualResolved = context.typeResolver.resolveValue(value as Object); - context.writeTypeMetaValue(actualResolved, value); - context.writeResolvedValue(actualResolved, value, fieldType); - return; - } - writeFieldValue(context, field, value); -} - @internal @pragma('vm:prefer-inline') -Object? readGeneratedStructFieldInfoValue( +Object? readGeneratedStructDescriptorValue( ReadContext context, - GeneratedStructFieldInfo field, [ + GeneratedStructFieldDescriptor field, [ Object? fallback, ]) { final fieldType = field.fieldType; @@ -484,24 +494,24 @@ Object? readGeneratedStructFieldInfoValue( fieldType, ); } - final resolved = fieldDeclaredTypeInfo(context.typeResolver, field)!; - if (fieldUsesDeclaredType(context.typeResolver, field)) { + final resolved = field.declaredTypeInfo!; + if (field.usesDeclaredType) { return context.readResolvedValue(resolved, fieldType); } final actualResolved = context.readTypeMetaValue(resolved); return context.readResolvedValue(actualResolved, fieldType); } - return readFieldValue(context, field, fallback); + return readFieldValue(context, field.field, fallback); } @internal @pragma('vm:prefer-inline') Object? readGeneratedStructDeclaredValue( ReadContext context, - GeneratedStructFieldInfo field, + GeneratedStructFieldDescriptor field, ) { - final resolved = fieldDeclaredTypeInfo(context.typeResolver, field)!; - if (fieldUsesDeclaredType(context.typeResolver, field)) { + final resolved = field.declaredTypeInfo!; + if (field.usesDeclaredType) { return context.readResolvedValue(resolved, field.fieldType); } final actualResolved = context.readTypeMetaValue(resolved); @@ -512,17 +522,21 @@ Object? readGeneratedStructDeclaredValue( @pragma('vm:prefer-inline') Object readGeneratedStructDirectValue( ReadContext context, - GeneratedStructFieldInfo field, + GeneratedStructFieldDescriptor field, ) { - final declared = fieldDeclaredTypeInfo(context.typeResolver, field)!; + final declared = field.declaredTypeInfo!; final resolver.TypeInfo resolved; - if (fieldUsesDeclaredType(context.typeResolver, field)) { + if (field.usesDeclaredType) { resolved = declared; } else { resolved = context.readTypeMetaValue(declared); } context.increaseDepth(); - final value = resolved.structSerializer!.readValue(context, resolved); + final structSerializer = resolved.structSerializer!; + final value = + resolved.remoteTypeDef == null + ? structSerializer.readValue(context, resolved) + : structSerializer.readGeneratedCompatibleValue(context, resolved); context.decreaseDepth(); return value; } @@ -530,100 +544,61 @@ Object readGeneratedStructDirectValue( @internal void writeGeneratedDirectListValue( WriteContext context, - GeneratedStructFieldInfo field, + GeneratedStructFieldDescriptor field, List value, ) { - final fieldType = field.fieldType; - if (fieldType.typeId != TypeIds.list || - fieldType.nullable || - fieldType.ref || - fieldType.isDynamic) { - throw StateError('Field ${field.name} is not a direct list path.'); - } - final elementFieldType = fieldType.arguments.single; - if (elementFieldType.ref || elementFieldType.isDynamic) { - throw StateError( - 'Field ${field.name} element type is not a direct list path.', - ); - } - writeTypedListPayload(context, value, elementFieldType); + writeTypedListPayload(context, value, field.fieldType.arguments.single); } @internal void writeGeneratedDirectSetValue( WriteContext context, - GeneratedStructFieldInfo field, + GeneratedStructFieldDescriptor field, Set value, ) { - final fieldType = field.fieldType; - if (fieldType.typeId != TypeIds.set || - fieldType.nullable || - fieldType.ref || - fieldType.isDynamic) { - throw StateError('Field ${field.name} is not a direct set path.'); - } - final elementFieldType = fieldType.arguments.single; - if (elementFieldType.ref || elementFieldType.isDynamic) { - throw StateError( - 'Field ${field.name} element type is not a direct set path.', - ); - } - writeTypedSetPayload(context, value, elementFieldType); + writeTypedSetPayload(context, value, field.fieldType.arguments.single); } @internal @pragma('vm:prefer-inline') List readGeneratedDirectListValue( ReadContext context, - GeneratedStructFieldInfo field, + GeneratedStructFieldDescriptor field, T Function(Object? value) convert, ) { - final fieldType = field.fieldType; - if (fieldType.typeId != TypeIds.list || - fieldType.nullable || - fieldType.ref || - fieldType.isDynamic) { - throw StateError('Field ${field.name} is not a direct list path.'); - } - return readTypedListPayload(context, fieldType.arguments.single, convert); + return readTypedListPayload( + context, + field.fieldType.arguments.single, + convert, + ); } @internal @pragma('vm:prefer-inline') Set readGeneratedDirectSetValue( ReadContext context, - GeneratedStructFieldInfo field, + GeneratedStructFieldDescriptor field, T Function(Object? value) convert, ) { - final fieldType = field.fieldType; - if (fieldType.typeId != TypeIds.set || - fieldType.nullable || - fieldType.ref || - fieldType.isDynamic) { - throw StateError('Field ${field.name} is not a direct set path.'); - } - return readTypedSetPayload(context, fieldType.arguments.single, convert); + return readTypedSetPayload( + context, + field.fieldType.arguments.single, + convert, + ); } @internal @pragma('vm:prefer-inline') Map readGeneratedDirectMapValue( ReadContext context, - GeneratedStructFieldInfo field, + GeneratedStructFieldDescriptor field, K Function(Object? value) convertKey, V Function(Object? value) convertValue, ) { - final fieldType = field.fieldType; - if (fieldType.typeId != TypeIds.map || - fieldType.nullable || - fieldType.ref || - fieldType.isDynamic) { - throw StateError('Field ${field.name} is not a direct map path.'); - } return readTypedMapPayload( context, - fieldType.arguments[0], - fieldType.arguments[1], + field.fieldType.arguments[0], + field.fieldType.arguments[1], convertKey, convertValue, ); diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index 1ed9dbed9e..57b3d0cf28 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -87,16 +87,16 @@ final class ReadContext { @internal @pragma('vm:prefer-inline') - TypeInfo readTypeMetaValue([ - TypeInfo? expectedNamedType, - ]) => + TypeInfo readTypeMetaValue([TypeInfo? expectedNamedType]) => _readTypeMeta(expectedNamedType); @internal @pragma('vm:prefer-inline') Object? readSerializerPayload( - Serializer serializer, TypeInfo resolved, - {required bool hasCurrentPreservedRef}) { + Serializer serializer, + TypeInfo resolved, { + required bool hasCurrentPreservedRef, + }) { return serializer.read(this); } @@ -217,38 +217,43 @@ final class ReadContext { return _refReader.getReadRef(); } final expectedRootType = _typeResolver.expectedRootType(); - final typeMetaResolved = expectedRootType == null - ? _readTypeMeta() - : _typeResolver.readExpectedInitialTypeDefMeta( - _buffer, - expectedRootType, - sharedTypes: _sharedTypes, - ) ?? - _readTypeMeta(expectedRootType); - final resolved = - _typeResolver.resolveExpectedRootWireType(typeMetaResolved); - final rootPreservedRefId = preservedRefId == null && - flag == RefWriter.notNullValueFlag && - _depth == 0 && - resolved.needsRootRef - ? _refReader.preserveRefId() - : null; + final typeMetaResolved = + expectedRootType == null + ? _readTypeMeta() + : _typeResolver.readExpectedInitialTypeDefMeta( + _buffer, + expectedRootType, + sharedTypes: _sharedTypes, + ) ?? + _readTypeMeta(expectedRootType); + final resolved = _typeResolver.resolveExpectedRootWireType( + typeMetaResolved, + ); + final rootPreservedRefId = + preservedRefId == null && + flag == RefWriter.notNullValueFlag && + _depth == 0 && + resolved.needsRootRef + ? _refReader.preserveRefId() + : null; if (preservedRefId == null && rootPreservedRefId == null && expectedRootType != null && identical(resolved, expectedRootType) && resolved.kind == RegistrationKind.struct && resolved.remoteTypeDef == null) { - _depth += 1; - if (_depth > config.maxDepth) { - _throwMaxDepthExceeded(); - } - final value = resolved.structSerializer!.readSameTypeValue( - this, - resolved, - ); - _depth -= 1; - return value; + return _readRootStructValue(resolved, compatible: false); + } + if (preservedRefId == null && + rootPreservedRefId == null && + expectedRootType != null && + resolved.kind == RegistrationKind.struct && + resolved.remoteTypeDef != null && + identical( + resolved.structSerializer, + expectedRootType.structSerializer, + )) { + return _readRootStructValue(resolved, compatible: true); } final value = readResolvedValue( resolved, @@ -267,6 +272,19 @@ final class ReadContext { return value; } + Object _readRootStructValue(TypeInfo resolved, {required bool compatible}) { + _depth += 1; + if (_depth > config.maxDepth) { + _throwMaxDepthExceeded(); + } + final value = + compatible + ? resolved.structSerializer!.readValue(this, resolved) + : resolved.structSerializer!.readSameTypeValue(this, resolved); + _depth -= 1; + return value; + } + Object? _readRefWithResolved(TypeInfo Function(TypeInfo) resolveRootType) { final flag = _refReader.tryPreserveRefId(_buffer); final preservedRefId = flag >= RefWriter.refValueFlag ? flag : null; @@ -277,12 +295,13 @@ final class ReadContext { return _refReader.getReadRef(); } final resolved = resolveRootType(_readTypeMeta()); - final rootPreservedRefId = preservedRefId == null && - flag == RefWriter.notNullValueFlag && - _depth == 0 && - resolved.needsRootRef - ? _refReader.preserveRefId() - : null; + final rootPreservedRefId = + preservedRefId == null && + flag == RefWriter.notNullValueFlag && + _depth == 0 && + resolved.needsRootRef + ? _refReader.preserveRefId() + : null; final value = readResolvedValue( resolved, null, @@ -324,8 +343,11 @@ final class ReadContext { @internal @pragma('vm:prefer-inline') - Object? readResolvedValue(TypeInfo resolved, FieldType? declaredFieldType, - {bool hasPreservedRef = false}) { + Object? readResolvedValue( + TypeInfo resolved, + FieldType? declaredFieldType, { + bool hasPreservedRef = false, + }) { if (!_tracksDepth(resolved)) { return _readPayloadValue( resolved, @@ -442,9 +464,7 @@ final class ReadContext { } @pragma('vm:prefer-inline') - TypeInfo _readTypeMeta([ - TypeInfo? expectedNamedType, - ]) { + TypeInfo _readTypeMeta([TypeInfo? expectedNamedType]) { return _typeResolver.readTypeMeta( _buffer, expectedNamedType: expectedNamedType, diff --git a/dart/packages/fory/lib/src/resolver/type_resolver.dart b/dart/packages/fory/lib/src/resolver/type_resolver.dart index 12209a5f95..4ec28c03cd 100644 --- a/dart/packages/fory/lib/src/resolver/type_resolver.dart +++ b/dart/packages/fory/lib/src/resolver/type_resolver.dart @@ -787,15 +787,22 @@ final class TypeResolver { return null; } final marker = buffer.readVarUint32Small14(); - if (marker != 0) { - bufferSetReaderIndex(buffer, start); - return null; + if ((marker & 1) == 1) { + return sharedTypes[marker >>> 1]; } final header = TypeHeader(buffer.readInt64()); final expectedTypeDef = expected.typeDef; if (expectedTypeDef == null || expectedTypeDef.header != header.value) { - bufferSetReaderIndex(buffer, start); - return null; + final cached = _parsedTypeMetaCache.lookup(header); + if (cached != null) { + header.skipRemaining(buffer); + sharedTypes.add(cached); + return cached; + } + final resolved = _readTypeDefWithHeader(buffer, header); + _parsedTypeMetaCache.remember(header, resolved); + sharedTypes.add(resolved); + return resolved; } header.skipRemaining(buffer); sharedTypes.add(expected); @@ -1293,7 +1300,7 @@ final class TypeResolver { typeId: typeId, nullable: nullable, ref: ref, - dynamic: typeId == TypeIds.unknown ? true : false, + dynamic: typeId == TypeIds.unknown && !ref ? true : null, arguments: arguments, ); } diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index 125292cc7f..98d33ffc23 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -121,10 +121,7 @@ bool tracksNestedPayloadDepth(TypeInfo typeInfo) { } } -bool sameTypeInfo( - TypeInfo left, - TypeInfo right, -) { +bool sameTypeInfo(TypeInfo left, TypeInfo right) { if (left.kind != right.kind || left.typeId != right.typeId) { return false; } @@ -208,9 +205,10 @@ T readFieldTypeValue( } if (fieldType.isPrimitive && !fieldType.nullable) { return convertPrimitiveFieldValue( - context.readPrimitiveValue(fieldType.typeId), - fieldType, - ) as T; + context.readPrimitiveValue(fieldType.typeId), + fieldType, + ) + as T; } if (!usesDeclaredType) { if (fieldType.ref) { @@ -251,12 +249,7 @@ final class ListSerializer extends Serializer { @override void write(WriteContext context, List value) { - writePayload( - context, - value, - null, - trackRef: context.rootTrackRef, - ); + writePayload(context, value, null, trackRef: context.rootTrackRef); } @override @@ -280,12 +273,14 @@ final class ListSerializer extends Serializer { if (size == 0) { return; } - final declaredTypeInfo = elementFieldType == null || - elementFieldType.isDynamic || - elementFieldType.typeId == TypeIds.unknown - ? null - : context.typeResolver.resolveFieldType(elementFieldType); - final usesDeclaredType = declaredTypeInfo != null && + final declaredTypeInfo = + elementFieldType == null || + elementFieldType.isDynamic || + elementFieldType.typeId == TypeIds.unknown + ? null + : context.typeResolver.resolveFieldType(elementFieldType); + final usesDeclaredType = + declaredTypeInfo != null && usesDeclaredTypeInfo( context.config.compatible, elementFieldType!, @@ -296,7 +291,8 @@ final class ListSerializer extends Serializer { values, usesDeclaredType: usesDeclaredType, ); - final elementTrackRef = (elementFieldType?.ref ?? false) || + final elementTrackRef = + (elementFieldType?.ref ?? false) || (elementFieldType == null && trackRef); final header = _buildCollectionHeader( trackRef: elementTrackRef, @@ -310,10 +306,7 @@ final class ListSerializer extends Serializer { if (!usesDeclaredType && sameTypeInfo != null && analysis.firstNonNull != null) { - context.writeTypeMetaValue( - sameTypeInfo, - analysis.firstNonNull!, - ); + context.writeTypeMetaValue(sameTypeInfo, analysis.firstNonNull!); } if (declaredTypeInfo != null) { _writeSameTypeElements( @@ -337,12 +330,7 @@ final class ListSerializer extends Serializer { ); return; } - _writeDifferentTypeElements( - context, - values, - analysis, - elementTrackRef, - ); + _writeDifferentTypeElements(context, values, analysis, elementTrackRef); } static List readPayload( @@ -423,7 +411,8 @@ Object? readCompatibleMatchedCollectionArrayField( _arrayElementTypeId(localType.typeId) != _compatibleArrayElementTypeId(elementType.typeId)) { throw StateError( - 'Compatible list-to-array field ${localField.name} is unsupported.'); + 'Compatible list-to-array field ${localField.name} is unsupported.', + ); } return _readCompatibleListAsArrayField( context, @@ -440,7 +429,8 @@ Object? readCompatibleMatchedCollectionArrayField( _arrayElementTypeId(remoteType.typeId) != _compatibleArrayElementTypeId(localElementType.typeId)) { throw StateError( - 'Compatible array-to-list field ${localField.name} is unsupported.'); + 'Compatible array-to-list field ${localField.name} is unsupported.', + ); } final raw = readCompatibleField(context, remoteField); return _arrayToListValue(raw); @@ -464,13 +454,27 @@ bool isCompatibleCollectionArrayTypePair( FieldType localType, FieldType remoteType, ) { + if (localType.nullable || + remoteType.nullable || + localType.ref || + remoteType.ref) { + return false; + } if (isCompatibleArrayType(localType.typeId) && remoteType.typeId == TypeIds.list) { - return _listElementMatchesArray(remoteType, localType.typeId); + return _listElementMatchesArray( + remoteType, + localType.typeId, + requireUnframedElement: true, + ); } if (localType.typeId == TypeIds.list && isCompatibleArrayType(remoteType.typeId)) { - return _listElementMatchesArray(localType, remoteType.typeId); + return _listElementMatchesArray( + localType, + remoteType.typeId, + requireUnframedElement: false, + ); } return false; } @@ -485,10 +489,18 @@ bool isCompatibleCollectionArrayRootTypePair( (isCompatibleArrayType(localTypeId) && remoteTypeId == TypeIds.list); } -bool _listElementMatchesArray(FieldType listType, int arrayTypeId) { +bool _listElementMatchesArray( + FieldType listType, + int arrayTypeId, { + required bool requireUnframedElement, +}) { final elementType = listType.arguments.isEmpty ? null : listType.arguments.single; + // Nullable element schema is allowed for list -> array; actual + // null payload elements fail in the dense-array reader. Ref-tracked + // element framing is rejected here because this path stays primitive-only. return elementType != null && + (!requireUnframedElement || !elementType.ref) && _arrayElementTypeId(arrayTypeId) == _compatibleArrayElementTypeId(elementType.typeId); } @@ -651,11 +663,23 @@ List readTypedListPayload( if (directTypeInfo.type == T && directTypeInfo.kind == RegistrationKind.struct) { final structSerializer = directTypeInfo.structSerializer!; - final result = List.generate( - state.size, - (_) => structSerializer.readValue(context, directTypeInfo) as T, - growable: false, - ); + final result = + directTypeInfo.remoteTypeDef == null + ? List.generate( + state.size, + (_) => structSerializer.readValue(context, directTypeInfo) as T, + growable: false, + ) + : List.generate( + state.size, + (_) => + structSerializer.readGeneratedCompatibleValue( + context, + directTypeInfo, + ) + as T, + growable: false, + ); if (state.tracksDepth) { context.decreaseDepth(); } @@ -715,8 +739,9 @@ void writeTypedListPayload( } context.buffer.writeVarUint32(size); if (size == 0) return; - final declaredTypeInfo = - context.typeResolver.resolveFieldType(elementFieldType); + final declaredTypeInfo = context.typeResolver.resolveFieldType( + elementFieldType, + ); final usesDeclaredType = usesDeclaredTypeInfo( context.config.compatible, elementFieldType, @@ -749,7 +774,10 @@ void writeTypedSetPayload( FieldType elementFieldType, ) { writeTypedListPayload( - context, values.toList(growable: false), elementFieldType); + context, + values.toList(growable: false), + elementFieldType, + ); } int _buildCollectionHeader({ @@ -799,19 +827,9 @@ void _writeSameTypeElements( ); } else if (hasNull) { context.buffer.writeByte(RefWriter.notNullValueFlag); - _writeDirectTypeInfoValue( - context, - typeInfo, - fieldType, - value as Object, - ); + _writeDirectTypeInfoValue(context, typeInfo, fieldType, value as Object); } else { - _writeDirectTypeInfoValue( - context, - typeInfo, - fieldType, - value as Object, - ); + _writeDirectTypeInfoValue(context, typeInfo, fieldType, value as Object); } } if (tracksDepth) { @@ -928,17 +946,21 @@ _PreparedListRead _prepareListRead( final usesDeclaredType = (header & CollectionFlags.isDeclaredElementType) != 0; final sameType = (header & CollectionFlags.isSameType) != 0; - final needsExpectedElementType = elementFieldType != null && + final needsExpectedElementType = + elementFieldType != null && (usesDeclaredType || (sameType && TypeIds.isUserType(elementFieldType.typeId))); - final expectedElementTypeInfo = needsExpectedElementType - ? context.typeResolver.tryResolveFieldType(elementFieldType) - : null; + final expectedElementTypeInfo = + needsExpectedElementType + ? context.typeResolver.tryResolveFieldType(elementFieldType) + : null; final declaredTypeInfo = usesDeclaredType ? expectedElementTypeInfo : null; - final sameTypeInfo = (!usesDeclaredType && sameType) - ? context.readTypeMetaValue(expectedElementTypeInfo) - : null; - final tracksDepth = (declaredTypeInfo != null && + final sameTypeInfo = + (!usesDeclaredType && sameType) + ? context.readTypeMetaValue(expectedElementTypeInfo) + : null; + final tracksDepth = + (declaredTypeInfo != null && tracksNestedPayloadDepth(declaredTypeInfo)) || (sameTypeInfo != null && tracksNestedPayloadDepth(sameTypeInfo)); return _PreparedListRead( @@ -954,10 +976,7 @@ _PreparedListRead _prepareListRead( } @pragma('vm:prefer-inline') -Object? _readPreparedListItem( - ReadContext context, - _PreparedListRead state, -) { +Object? _readPreparedListItem(ReadContext context, _PreparedListRead state) { if (state.declaredTypeInfo != null) { return _readSameTypeElement( context, @@ -1008,11 +1027,7 @@ Object? _readPreparedListItem( } return context.readResolvedValue(state.sameTypeInfo!, null); } - return _readDifferentTypeElement( - context, - state.trackRef, - state.hasNull, - ); + return _readDifferentTypeElement(context, state.trackRef, state.hasNull); } @pragma('vm:prefer-inline') diff --git a/dart/packages/fory/lib/src/serializer/generated_struct_serializer.dart b/dart/packages/fory/lib/src/serializer/generated_struct_serializer.dart index 3a9acf5e3c..4633b0d75d 100644 --- a/dart/packages/fory/lib/src/serializer/generated_struct_serializer.dart +++ b/dart/packages/fory/lib/src/serializer/generated_struct_serializer.dart @@ -26,6 +26,9 @@ import 'package:fory/src/serializer/scalar_conversion.dart'; import 'package:fory/src/serializer/serialization_field_info.dart'; import 'package:fory/src/serializer/serializer.dart'; import 'package:fory/src/serializer/serializer_support.dart'; +import 'package:fory/src/types/float32.dart'; +import 'package:fory/src/types/int64.dart'; +import 'package:fory/src/types/uint64.dart'; @internal abstract interface class GeneratedStructSerializer implements Serializer { @@ -37,59 +40,167 @@ abstract interface class GeneratedStructSerializer implements Serializer { @internal final class CompatibleStructReadLayout { - final List remoteFields; - final List fields; - final List? _scalarConversions; - final List? _topLevelListArrayPairs; + final List fields; - const CompatibleStructReadLayout( - this.remoteFields, - this.fields, [ - this._scalarConversions, - this._topLevelListArrayPairs, - ]); + const CompatibleStructReadLayout(this.fields); - int get fieldCount => remoteFields.length; + int get fieldCount => fields.length; - FieldInfo remoteFieldAt(int index) => remoteFields[index]; + CompatibleStructReadField fieldAt(int index) => fields[index]; +} - SerializationFieldInfo? localFieldAt(int index) => fields[index]; +@internal +final class CompatibleStructReadField { + final FieldInfo remoteField; + final int matchedId; + final SerializationFieldInfo? localField; + final CompatibleScalarReadDescriptor? scalarRead; + final bool topLevelListArrayPair; - CompatibleScalarConversion? scalarConversionAt(int index) => - _scalarConversions?[index]; + const CompatibleStructReadField({ + required this.remoteField, + required this.matchedId, + required this.localField, + required this.scalarRead, + required this.topLevelListArrayPair, + }); +} - bool topLevelListArrayPairAt(int index) => - _topLevelListArrayPairs?[index] ?? false; +@internal +@pragma('vm:never-inline') +Object? readGeneratedCompatibleScalarField( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, +) { + return readCompatibleScalarField(context, scalarRead.conversion); } @internal @pragma('vm:never-inline') Object? readGeneratedCompatibleStructField( ReadContext context, - CompatibleStructReadLayout layout, - int index, + CompatibleStructReadField field, ) { - final localField = layout.localFieldAt(index)!; - final scalarConversion = layout.scalarConversionAt(index); - if (scalarConversion != null) { - return readCompatibleScalarField(context, scalarConversion); + final localField = field.localField!; + final scalarRead = field.scalarRead; + if (scalarRead != null) { + return readGeneratedCompatibleScalarField(context, scalarRead); } - if (layout.topLevelListArrayPairAt(index)) { + if (field.topLevelListArrayPair) { return readCompatibleMatchedCollectionArrayField( context, localField, - layout.remoteFieldAt(index), + field.remoteField, ); } return readFieldValue(context, localField); } +@internal +@pragma('vm:prefer-inline') +int readGenCompatScalarAsInt( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + int? fallback, +]) { + if (scalarRead.readsDirectIntAsInt) { + return readCompatScalarPayloadAsInt(context, scalarRead, fallback); + } + final value = readGeneratedCompatibleScalarField(context, scalarRead); + if (value == null) { + if (fallback != null) { + return fallback; + } + throw StateError('Received null for non-nullable field value.'); + } + return value as int; +} + +@internal +@pragma('vm:prefer-inline') +double readGenCompatScalarAsDouble( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + double? fallback, +]) { + if (scalarRead.readsDirectDoubleAsDouble) { + return readCompatScalarPayloadAsDouble(context, scalarRead, fallback); + } + final value = readGeneratedCompatibleScalarField(context, scalarRead); + if (value == null) { + if (fallback != null) { + return fallback; + } + throw StateError('Received null for non-nullable field value.'); + } + return value as double; +} + +@internal +@pragma('vm:prefer-inline') +Int64 readGenCompatScalarAsInt64( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + Int64? fallback, +]) { + if (scalarRead.readsDirectInt64) { + return readCompatScalarPayloadAsInt64(context, scalarRead, fallback); + } + final value = readGeneratedCompatibleScalarField(context, scalarRead); + if (value == null) { + if (fallback != null) { + return fallback; + } + throw StateError('Received null for non-nullable field value.'); + } + return value as Int64; +} + +@internal +@pragma('vm:prefer-inline') +Uint64 readGenCompatScalarAsUint64( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + Uint64? fallback, +]) { + if (scalarRead.readsDirectUint64) { + return readCompatScalarPayloadAsUint64(context, scalarRead, fallback); + } + final value = readGeneratedCompatibleScalarField(context, scalarRead); + if (value == null) { + if (fallback != null) { + return fallback; + } + throw StateError('Received null for non-nullable field value.'); + } + return value as Uint64; +} + +@internal +@pragma('vm:prefer-inline') +Float32 readGenCompatScalarAsFloat32( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + Float32? fallback, +]) { + if (scalarRead.readsDirectFloat32) { + return readCompatScalarPayloadAsFloat32(context, scalarRead, fallback); + } + final value = readGeneratedCompatibleScalarField(context, scalarRead); + if (value == null) { + if (fallback != null) { + return fallback; + } + throw StateError('Received null for non-nullable field value.'); + } + return value as Float32; +} + @internal @pragma('vm:never-inline') void skipGeneratedCompatibleStructField( ReadContext context, - CompatibleStructReadLayout layout, - int index, + CompatibleStructReadField field, ) { - readCompatibleField(context, layout.remoteFieldAt(index)); + readCompatibleField(context, field.remoteField); } diff --git a/dart/packages/fory/lib/src/serializer/scalar_conversion.dart b/dart/packages/fory/lib/src/serializer/scalar_conversion.dart index f8204c6bab..d4caca3b2d 100644 --- a/dart/packages/fory/lib/src/serializer/scalar_conversion.dart +++ b/dart/packages/fory/lib/src/serializer/scalar_conversion.dart @@ -50,12 +50,19 @@ final BigInt _uint64Max = (BigInt.one << 64) - BigInt.one; final BigInt _doubleSignMask = BigInt.one << 63; final BigInt _doubleFractionMask = (BigInt.one << 52) - BigInt.one; final BigInt _ten = BigInt.from(10); +const int _int64SignHigh32 = 0x80000000; const int _maxCompatibleDecimalDigits = 256; const int _maxCompatibleNumericTextLength = 320; final BigInt _maxCompatibleDecimalMagnitude = BigInt.from( 10, ).pow(_maxCompatibleDecimalDigits); final ByteData _floatData = ByteData(8); +const int _directTargetNone = 0; +const int _directTargetInt = 1; +const int _directTargetDouble = 2; +const int _directTargetInt64 = 3; +const int _directTargetUint64 = 4; +const int _directTargetFloat32 = 5; final class CompatibleScalarConversion { final FieldInfo remoteField; @@ -64,6 +71,32 @@ final class CompatibleScalarConversion { const CompatibleScalarConversion(this.remoteField, this.localField); } +final class CompatibleScalarReadDescriptor { + final CompatibleScalarConversion conversion; + final int directTargetKind; + final int directSourceTypeId; + final int directTargetTypeId; + final bool sourceNullable; + + const CompatibleScalarReadDescriptor( + this.conversion, + this.directTargetKind, + this.directSourceTypeId, + this.directTargetTypeId, + this.sourceNullable, + ); + + bool get readsDirectIntAsInt => directTargetKind == _directTargetInt; + + bool get readsDirectDoubleAsDouble => directTargetKind == _directTargetDouble; + + bool get readsDirectInt64 => directTargetKind == _directTargetInt64; + + bool get readsDirectUint64 => directTargetKind == _directTargetUint64; + + bool get readsDirectFloat32 => directTargetKind == _directTargetFloat32; +} + CompatibleScalarConversion? compatibleScalarConversion( FieldInfo remoteField, FieldInfo localField, @@ -87,6 +120,29 @@ CompatibleScalarConversion? compatibleScalarConversion( return CompatibleScalarConversion(remoteField, localField); } +CompatibleScalarReadDescriptor compatibleScalarReadDescriptor( + CompatibleScalarConversion conversion, +) { + final localType = conversion.localField.fieldType; + final remoteType = conversion.remoteField.fieldType; + final directTargetKind = _directTargetKind(localType); + final directSourceTypeId = switch (directTargetKind) { + _directTargetInt || + _directTargetInt64 || + _directTargetUint64 => _directIntegerSourceTypeId(remoteType.typeId), + _directTargetDouble || + _directTargetFloat32 => _directDoubleSourceTypeId(remoteType.typeId), + _ => -1, + }; + return CompatibleScalarReadDescriptor( + conversion, + directSourceTypeId >= 0 ? directTargetKind : _directTargetNone, + directSourceTypeId, + directSourceTypeId >= 0 ? localType.typeId : -1, + remoteType.nullable, + ); +} + bool _supportsCompatibleScalarConversion(int remoteTypeId, int localTypeId) { if (remoteTypeId == localTypeId) { return _isScalarConversionType(remoteTypeId); @@ -144,6 +200,98 @@ bool _isFloatingType(int typeId) => typeId == TypeIds.float32 || typeId == TypeIds.float64; +int _directTargetKind(FieldType localType) { + if (localType.nullable || localType.ref || localType.dynamic == true) { + return _directTargetNone; + } + if (localType.type == int && _isDirectIntTargetType(localType.typeId)) { + return _directTargetInt; + } + if (localType.type == double && _isDirectDoubleTargetType(localType.typeId)) { + return _directTargetDouble; + } + if (localType.type == Int64 && _isSigned64TargetType(localType.typeId)) { + return _directTargetInt64; + } + if (localType.type == Uint64 && _isUnsigned64TargetType(localType.typeId)) { + return _directTargetUint64; + } + if (localType.type == Float32 && localType.typeId == TypeIds.float32) { + return _directTargetFloat32; + } + return _directTargetNone; +} + +bool _isDirectIntTargetType(int typeId) => + typeId == TypeIds.int8 || + typeId == TypeIds.int16 || + typeId == TypeIds.int32 || + typeId == TypeIds.varInt32 || + typeId == TypeIds.int64 || + typeId == TypeIds.varInt64 || + typeId == TypeIds.taggedInt64 || + typeId == TypeIds.uint8 || + typeId == TypeIds.uint16 || + typeId == TypeIds.uint32 || + typeId == TypeIds.varUint32; + +bool _isDirectDoubleTargetType(int typeId) => + typeId == TypeIds.float32 || typeId == TypeIds.float64; + +bool _isSigned64TargetType(int typeId) => + typeId == TypeIds.int64 || + typeId == TypeIds.varInt64 || + typeId == TypeIds.taggedInt64; + +bool _isUnsigned64TargetType(int typeId) => + typeId == TypeIds.uint64 || + typeId == TypeIds.varUint64 || + typeId == TypeIds.taggedUint64; + +int _directIntegerSourceTypeId(int remoteTypeId) { + switch (remoteTypeId) { + case TypeIds.boolType: + case TypeIds.int8: + case TypeIds.int16: + case TypeIds.int32: + case TypeIds.varInt32: + case TypeIds.int64: + case TypeIds.varInt64: + case TypeIds.taggedInt64: + case TypeIds.uint8: + case TypeIds.uint16: + case TypeIds.uint32: + case TypeIds.varUint32: + case TypeIds.uint64: + case TypeIds.varUint64: + case TypeIds.taggedUint64: + return remoteTypeId; + default: + return -1; + } +} + +int _directDoubleSourceTypeId(int remoteTypeId) { + switch (remoteTypeId) { + case TypeIds.boolType: + case TypeIds.int8: + case TypeIds.int16: + case TypeIds.int32: + case TypeIds.varInt32: + case TypeIds.uint8: + case TypeIds.uint16: + case TypeIds.uint32: + case TypeIds.varUint32: + case TypeIds.float16: + case TypeIds.bfloat16: + case TypeIds.float32: + case TypeIds.float64: + return remoteTypeId; + default: + return -1; + } +} + @pragma('vm:never-inline') Object? readCompatibleScalarField( ReadContext context, @@ -162,6 +310,344 @@ Object? readCompatibleScalarField( } } +@pragma('vm:prefer-inline') +int readCompatScalarPayloadAsInt( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + int? fallback, +]) { + if (!_readCompatibleScalarHeader(context, scalarRead.sourceNullable)) { + if (fallback != null) { + return fallback; + } + throw InvalidDataException( + 'Received null for non-nullable compatible scalar field.', + ); + } + final buffer = context.buffer; + final value = switch (scalarRead.directSourceTypeId) { + TypeIds.boolType => _readBoolIntPayload(buffer.readUint8()), + TypeIds.int8 => buffer.readByte(), + TypeIds.int16 => buffer.readInt16(), + TypeIds.int32 => buffer.readInt32(), + TypeIds.varInt32 => buffer.readVarInt32(), + TypeIds.int64 => buffer.readInt64AsInt(), + TypeIds.varInt64 => buffer.readVarInt64AsInt(), + TypeIds.taggedInt64 => buffer.readTaggedInt64AsInt(), + TypeIds.uint8 => buffer.readUint8(), + TypeIds.uint16 => buffer.readUint16(), + TypeIds.uint32 => buffer.readUint32(), + TypeIds.varUint32 => buffer.readVarUint32(), + TypeIds.uint64 => _uint64ToSignedInt64Target(buffer.readUint64()), + TypeIds.varUint64 => _uint64ToSignedInt64Target(buffer.readVarUint64()), + TypeIds.taggedUint64 => _uint64ToSignedInt64Target( + buffer.readTaggedUint64(), + ), + _ => + throw StateError( + 'Unsupported int-compatible payload type ${scalarRead.directSourceTypeId}.', + ), + }; + _checkIntTargetRange(value, scalarRead.directTargetTypeId); + return value; +} + +@pragma('vm:prefer-inline') +double readCompatScalarPayloadAsDouble( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + double? fallback, +]) { + if (!_readCompatibleScalarHeader(context, scalarRead.sourceNullable)) { + if (fallback != null) { + return fallback; + } + throw InvalidDataException( + 'Received null for non-nullable compatible scalar field.', + ); + } + final buffer = context.buffer; + final value = switch (scalarRead.directSourceTypeId) { + TypeIds.boolType => _readBoolIntPayload(buffer.readUint8()).toDouble(), + TypeIds.int8 => buffer.readByte().toDouble(), + TypeIds.int16 => buffer.readInt16().toDouble(), + TypeIds.int32 => buffer.readInt32().toDouble(), + TypeIds.varInt32 => buffer.readVarInt32().toDouble(), + TypeIds.uint8 => buffer.readUint8().toDouble(), + TypeIds.uint16 => buffer.readUint16().toDouble(), + TypeIds.uint32 => buffer.readUint32().toDouble(), + TypeIds.varUint32 => buffer.readVarUint32().toDouble(), + TypeIds.float16 => buffer.readFloat16(), + TypeIds.bfloat16 => buffer.readBfloat16(), + TypeIds.float32 => buffer.readFloat32(), + TypeIds.float64 => buffer.readFloat64(), + _ => + throw StateError( + 'Unsupported double-compatible payload type ${scalarRead.directSourceTypeId}.', + ), + }; + return _checkedDoubleTarget(value, scalarRead.directTargetTypeId); +} + +@pragma('vm:prefer-inline') +Int64 readCompatScalarPayloadAsInt64( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + Int64? fallback, +]) { + if (!_readCompatibleScalarHeader(context, scalarRead.sourceNullable)) { + if (fallback != null) { + return fallback; + } + throw InvalidDataException( + 'Received null for non-nullable compatible scalar field.', + ); + } + final buffer = context.buffer; + return switch (scalarRead.directSourceTypeId) { + TypeIds.boolType => Int64(_readBoolIntPayload(buffer.readUint8())), + TypeIds.int8 => Int64(buffer.readByte()), + TypeIds.int16 => Int64(buffer.readInt16()), + TypeIds.int32 => Int64(buffer.readInt32()), + TypeIds.varInt32 => Int64(buffer.readVarInt32()), + TypeIds.int64 => buffer.readInt64(), + TypeIds.varInt64 => buffer.readVarInt64(), + TypeIds.taggedInt64 => buffer.readTaggedInt64(), + TypeIds.uint8 => Int64(buffer.readUint8()), + TypeIds.uint16 => Int64(buffer.readUint16()), + TypeIds.uint32 => Int64(buffer.readUint32()), + TypeIds.varUint32 => Int64(buffer.readVarUint32()), + TypeIds.uint64 => _uint64ToInt64Target(buffer.readUint64()), + TypeIds.varUint64 => _uint64ToInt64Target(buffer.readVarUint64()), + TypeIds.taggedUint64 => _uint64ToInt64Target(buffer.readTaggedUint64()), + _ => + throw StateError( + 'Unsupported Int64-compatible payload type ${scalarRead.directSourceTypeId}.', + ), + }; +} + +@pragma('vm:prefer-inline') +Uint64 readCompatScalarPayloadAsUint64( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + Uint64? fallback, +]) { + if (!_readCompatibleScalarHeader(context, scalarRead.sourceNullable)) { + if (fallback != null) { + return fallback; + } + throw InvalidDataException( + 'Received null for non-nullable compatible scalar field.', + ); + } + final buffer = context.buffer; + return switch (scalarRead.directSourceTypeId) { + TypeIds.boolType => Uint64(_readBoolIntPayload(buffer.readUint8())), + TypeIds.int8 => _signedIntToUint64Target(buffer.readByte()), + TypeIds.int16 => _signedIntToUint64Target(buffer.readInt16()), + TypeIds.int32 => _signedIntToUint64Target(buffer.readInt32()), + TypeIds.varInt32 => _signedIntToUint64Target(buffer.readVarInt32()), + TypeIds.int64 => _int64ToUint64Target(buffer.readInt64()), + TypeIds.varInt64 => _int64ToUint64Target(buffer.readVarInt64()), + TypeIds.taggedInt64 => _int64ToUint64Target(buffer.readTaggedInt64()), + TypeIds.uint8 => Uint64(buffer.readUint8()), + TypeIds.uint16 => Uint64(buffer.readUint16()), + TypeIds.uint32 => Uint64(buffer.readUint32()), + TypeIds.varUint32 => Uint64(buffer.readVarUint32()), + TypeIds.uint64 => buffer.readUint64(), + TypeIds.varUint64 => buffer.readVarUint64(), + TypeIds.taggedUint64 => buffer.readTaggedUint64(), + _ => + throw StateError( + 'Unsupported Uint64-compatible payload type ${scalarRead.directSourceTypeId}.', + ), + }; +} + +@pragma('vm:prefer-inline') +Float32 readCompatScalarPayloadAsFloat32( + ReadContext context, + CompatibleScalarReadDescriptor scalarRead, [ + Float32? fallback, +]) { + if (!_readCompatibleScalarHeader(context, scalarRead.sourceNullable)) { + if (fallback != null) { + return fallback; + } + throw InvalidDataException( + 'Received null for non-nullable compatible scalar field.', + ); + } + final buffer = context.buffer; + final value = switch (scalarRead.directSourceTypeId) { + TypeIds.boolType => _readBoolIntPayload(buffer.readUint8()).toDouble(), + TypeIds.int8 => buffer.readByte().toDouble(), + TypeIds.int16 => buffer.readInt16().toDouble(), + TypeIds.int32 => buffer.readInt32().toDouble(), + TypeIds.varInt32 => buffer.readVarInt32().toDouble(), + TypeIds.uint8 => buffer.readUint8().toDouble(), + TypeIds.uint16 => buffer.readUint16().toDouble(), + TypeIds.uint32 => buffer.readUint32().toDouble(), + TypeIds.varUint32 => buffer.readVarUint32().toDouble(), + TypeIds.float16 => buffer.readFloat16(), + TypeIds.bfloat16 => buffer.readBfloat16(), + TypeIds.float32 => buffer.readFloat32(), + TypeIds.float64 => buffer.readFloat64(), + _ => + throw StateError( + 'Unsupported Float32-compatible payload type ${scalarRead.directSourceTypeId}.', + ), + }; + return Float32(_checkedDoubleTarget(value, scalarRead.directTargetTypeId)); +} + +bool _readCompatibleScalarHeader(ReadContext context, bool nullable) { + if (!nullable) { + return true; + } + final flag = context.buffer.readByte(); + if (flag == RefWriter.nullFlag) { + return false; + } + if (flag != RefWriter.notNullValueFlag) { + throw InvalidDataException( + 'Invalid nullable compatible scalar field flag $flag.', + ); + } + return true; +} + +int _readBoolIntPayload(int raw) { + if (raw == 0) { + return 0; + } + if (raw == 1) { + return 1; + } + throw InvalidDataException('Bool payload must be encoded as 0 or 1.'); +} + +void _checkIntTargetRange(int value, int targetTypeId) { + switch (targetTypeId) { + case TypeIds.int8: + if (value < -128 || value > 127) { + throw InvalidDataException( + 'Integer value $value is outside int8 range.', + ); + } + return; + case TypeIds.int16: + if (value < -32768 || value > 32767) { + throw InvalidDataException( + 'Integer value $value is outside int16 range.', + ); + } + return; + case TypeIds.int32: + case TypeIds.varInt32: + if (value < -2147483648 || value > 2147483647) { + throw InvalidDataException( + 'Integer value $value is outside int32 range.', + ); + } + return; + case TypeIds.uint8: + if (value < 0 || value > 255) { + throw InvalidDataException( + 'Integer value $value is outside uint8 range.', + ); + } + return; + case TypeIds.uint16: + if (value < 0 || value > 65535) { + throw InvalidDataException( + 'Integer value $value is outside uint16 range.', + ); + } + return; + case TypeIds.uint32: + case TypeIds.varUint32: + if (value < 0 || value > 4294967295) { + throw InvalidDataException( + 'Integer value $value is outside uint32 range.', + ); + } + return; + case TypeIds.int64: + case TypeIds.varInt64: + case TypeIds.taggedInt64: + return; + default: + throw StateError('Unsupported int-compatible target type $targetTypeId.'); + } +} + +double _checkedDoubleTarget(double value, int targetTypeId) { + if (value.isNaN) { + throw const InvalidDataException('NaN is not convertible.'); + } + switch (targetTypeId) { + case TypeIds.float64: + return value; + case TypeIds.float32: + final rounded = Float32(value).value; + if (_sameFloatValue(value, rounded)) { + return rounded; + } + throw const InvalidDataException( + 'Numeric value is not exactly representable by float32.', + ); + default: + throw StateError( + 'Unsupported double-compatible target type $targetTypeId.', + ); + } +} + +int _uint64ToSignedInt64Target(Uint64 value) { + if (value.high32Unsigned >= _int64SignHigh32) { + throw InvalidDataException( + 'Unsigned 64-bit compatible scalar value exceeds signed int64 range.', + ); + } + try { + return value.toInt(); + } on StateError catch (error) { + throw InvalidDataException( + 'Unsigned 64-bit compatible scalar value is not representable as a Dart int.', + error, + ); + } +} + +Int64 _uint64ToInt64Target(Uint64 value) { + if (value.high32Unsigned >= _int64SignHigh32) { + throw InvalidDataException( + 'Unsigned 64-bit compatible scalar value exceeds signed int64 range.', + ); + } + return Int64.fromWords(value.low32, value.high32Unsigned); +} + +Uint64 _signedIntToUint64Target(int value) { + if (value < 0) { + throw InvalidDataException( + 'Signed compatible scalar value $value is outside uint64 range.', + ); + } + return Uint64(value); +} + +Uint64 _int64ToUint64Target(Int64 value) { + if (value.isNegative) { + throw InvalidDataException( + 'Signed compatible scalar value ${value.toBigInt()} is outside uint64 range.', + ); + } + return Uint64.fromWords(value.low32, value.high32Unsigned); +} + Object? _readCompatibleScalarPayload( ReadContext context, CompatibleScalarConversion conversion, @@ -221,26 +707,26 @@ Object convertCompatibleScalarValue( return convertPrimitiveFieldValue(value, localType); } if (localTypeId == TypeIds.boolType) { - return _toBool(value); + return _toBool(value, remoteTypeId); } if (localTypeId == TypeIds.string) { - return _toString(value); + return _toString(value, remoteTypeId); } if (_isIntegerType(localTypeId)) { - return _toIntegerTarget(value, localType); + return _toIntegerTarget(value, localType, remoteTypeId); } if (_isFloatingType(localTypeId)) { - return _toFloatingTarget(value, localType); + return _toFloatingTarget(value, localType, remoteTypeId); } if (localTypeId == TypeIds.decimal) { - return _toDecimal(value); + return _toDecimal(value, remoteTypeId); } throw InvalidDataException( 'Unsupported compatible scalar target $localTypeId.', ); } -bool _toBool(Object value) { +bool _toBool(Object value, int remoteTypeId) { if (value is bool) { return value; } @@ -251,7 +737,7 @@ bool _toBool(Object value) { _ => throw const FormatException('String is not a bool literal.'), }; } - final decimal = _numericDecimal(value); + final decimal = _numericDecimal(value, remoteTypeId); if (decimal.unscaled == BigInt.zero) { return false; } @@ -261,7 +747,7 @@ bool _toBool(Object value) { throw const FormatException('Numeric bool value must be 0 or 1.'); } -String _toString(Object value) { +String _toString(Object value, int remoteTypeId) { if (value is bool) { return value ? 'true' : 'false'; } @@ -274,11 +760,11 @@ String _toString(Object value) { if (value is double) { return _exactFloatText(value); } - return _integerValue(value).toString(); + return _integerValue(value, remoteTypeId).toString(); } -Object _toIntegerTarget(Object value, FieldType localType) { - final decimal = _numericDecimal(value); +Object _toIntegerTarget(Object value, FieldType localType, int remoteTypeId) { + final decimal = _numericDecimal(value, remoteTypeId); if (decimal.scale != 0) { throw const FormatException('Numeric value is not integral.'); } @@ -287,7 +773,7 @@ Object _toIntegerTarget(Object value, FieldType localType) { return _integerCarrier(integer, localType); } -Object _toFloatingTarget(Object value, FieldType localType) { +Object _toFloatingTarget(Object value, FieldType localType, int remoteTypeId) { final localTypeId = localType.typeId; if (value is String && _isNegativeZeroLiteral(value)) { _parseDecimalLiteral(value); @@ -304,7 +790,7 @@ Object _toFloatingTarget(Object value, FieldType localType) { } return _floatCarrier(rounded, localType); } - final decimal = _numericDecimal(value); + final decimal = _numericDecimal(value, remoteTypeId); if (decimal.unscaled == BigInt.zero) { return _floatCarrier(0.0, localType); } @@ -316,7 +802,7 @@ Object _toFloatingTarget(Object value, FieldType localType) { return _floatCarrier(rounded, localType); } -Decimal _toDecimal(Object value) { +Decimal _toDecimal(Object value, int remoteTypeId) { if (value is bool) { return Decimal(value ? BigInt.one : BigInt.zero, 0); } @@ -336,10 +822,10 @@ Decimal _toDecimal(Object value) { } return _canonicalDecimalFromFiniteFloat(source).toDecimal(); } - return Decimal(_integerValue(value), 0); + return Decimal(_integerValue(value, remoteTypeId), 0); } -_DecimalValue _numericDecimal(Object value) { +_DecimalValue _numericDecimal(Object value, int remoteTypeId) { if (value is bool) { return _DecimalValue(value ? BigInt.one : BigInt.zero, 0); } @@ -359,18 +845,43 @@ _DecimalValue _numericDecimal(Object value) { } return _canonicalDecimalFromFiniteFloat(source); } - return _DecimalValue(_integerValue(value), 0); + return _DecimalValue(_integerValue(value, remoteTypeId), 0); } -BigInt _integerValue(Object value) { - if (value is Uint64) { - return value.toBigInt(); - } - if (value is Int64) { - return value.toBigInt(); - } - if (value is int) { - return BigInt.from(value); +BigInt _integerValue(Object value, int remoteTypeId) { + switch (remoteTypeId) { + case TypeIds.int8: + case TypeIds.int16: + case TypeIds.int32: + case TypeIds.varInt32: + case TypeIds.uint8: + case TypeIds.uint16: + case TypeIds.uint32: + case TypeIds.varUint32: + if (value is int) { + return BigInt.from(value); + } + break; + case TypeIds.int64: + case TypeIds.varInt64: + case TypeIds.taggedInt64: + if (value is Int64) { + return value.toBigInt(); + } + if (value is int) { + return Int64(value).toBigInt(); + } + break; + case TypeIds.uint64: + case TypeIds.varUint64: + case TypeIds.taggedUint64: + if (value is Uint64) { + return value.toBigInt(); + } + if (value is int) { + return Uint64(value).toBigInt(); + } + break; } throw StateError( 'Expected integer-compatible value, got ${value.runtimeType}.', diff --git a/dart/packages/fory/lib/src/serializer/serializer_support.dart b/dart/packages/fory/lib/src/serializer/serializer_support.dart index 23b81025c6..1d9a486286 100644 --- a/dart/packages/fory/lib/src/serializer/serializer_support.dart +++ b/dart/packages/fory/lib/src/serializer/serializer_support.dart @@ -335,7 +335,7 @@ FieldInfo mergeCompatibleReadField( typeId: remote.typeId, nullable: remote.nullable, ref: remote.ref, - dynamic: local.dynamic ?? remote.dynamic, + dynamic: remote.dynamic ?? local.dynamic, arguments: mergedArguments, ); } diff --git a/dart/packages/fory/lib/src/serializer/struct_serializer.dart b/dart/packages/fory/lib/src/serializer/struct_serializer.dart index 5accd18756..beb0d2c1f9 100644 --- a/dart/packages/fory/lib/src/serializer/struct_serializer.dart +++ b/dart/packages/fory/lib/src/serializer/struct_serializer.dart @@ -22,6 +22,7 @@ import 'package:fory/src/context/write_context.dart'; import 'package:fory/src/meta/field_info.dart'; import 'package:fory/src/meta/field_type.dart'; import 'package:fory/src/meta/type_def.dart'; +import 'package:fory/src/meta/type_ids.dart'; import 'package:fory/src/resolver/type_resolver.dart'; import 'package:fory/src/serializer/collection_serializers.dart'; import 'package:fory/src/serializer/generated_struct_serializer.dart'; @@ -49,8 +50,10 @@ final class StructSerializer extends Serializer { { for (final field in _localFields) field.identifier: field, }; - final Expando _compatibleReadLayouts = - Expando('fory_compatible_read_layout'); + final Map _compatibleReadLayouts = + Map.identity(); + TypeDef? _lastCompatibleRemoteTypeDef; + CompatibleStructReadLayout? _lastCompatibleReadLayout; StructSerializer(this._payloadSerializer, this._typeDef, this._typeResolver); @@ -114,6 +117,11 @@ final class StructSerializer extends Serializer { @pragma('vm:never-inline') Object _readCompatible(ReadContext context, TypeInfo resolved) { + return readGeneratedCompatibleValue(context, resolved); + } + + @pragma('vm:prefer-inline') + Object readGeneratedCompatibleValue(ReadContext context, TypeInfo resolved) { final payloadSerializer = _payloadSerializer; if (payloadSerializer is! GeneratedStructSerializer) { throw StateError( @@ -136,27 +144,53 @@ final class StructSerializer extends Serializer { ) { final remoteTypeDef = resolved.remoteTypeDef; if (remoteTypeDef == null) { - return CompatibleStructReadLayout(_typeDef.fields, _localFields); + return CompatibleStructReadLayout( + List.unmodifiable( + List.generate( + _localFields.length, + (index) => CompatibleStructReadField( + remoteField: _typeDef.fields[index], + matchedId: index * 2, + localField: _localFields[index], + scalarRead: null, + topLevelListArrayPair: false, + ), + ), + ), + ); + } + final lastRemoteTypeDef = _lastCompatibleRemoteTypeDef; + if (identical(lastRemoteTypeDef, remoteTypeDef)) { + return _lastCompatibleReadLayout!; } final cached = _compatibleReadLayouts[remoteTypeDef]; if (cached != null) { + _lastCompatibleRemoteTypeDef = remoteTypeDef; + _lastCompatibleReadLayout = cached; return cached; } return _buildCompatibleReadLayout(remoteTypeDef); } CompatibleStructReadLayout _buildCompatibleReadLayout(TypeDef remoteTypeDef) { - final fields = []; - List? scalarConversions; - var hasScalarConversions = false; - List? topLevelListArrayPairs; - var hasTopLevelListArrayPairs = false; - for (final remoteField in remoteTypeDef.fields) { + final fields = []; + for ( + var remoteIndex = 0; + remoteIndex < remoteTypeDef.fields.length; + remoteIndex += 1 + ) { + final remoteField = remoteTypeDef.fields[remoteIndex]; final localField = _localFieldsByIdentifier[remoteField.identifier]; if (localField == null) { - fields.add(null); - scalarConversions?.add(null); - topLevelListArrayPairs?.add(false); + fields.add( + CompatibleStructReadField( + remoteField: remoteField, + matchedId: -1, + localField: null, + scalarRead: null, + topLevelListArrayPair: false, + ), + ); continue; } final topLevelListArrayPair = _topLevelListArrayPair( @@ -172,65 +206,55 @@ final class StructSerializer extends Serializer { 'Compatible field ${localField.name} has unsupported list/array schema mismatch.', ); } - if (topLevelListArrayPair) { - topLevelListArrayPairs ??= List.filled( - fields.length, - false, - growable: true, - ); - hasTopLevelListArrayPairs = true; - } - final scalarPair = - isCompatibleScalarType(remoteField.fieldType.typeId) && - isCompatibleScalarType(localField.field.fieldType.typeId); - final refTrackedScalarSchemaMismatch = - scalarPair && - (remoteField.fieldType.ref != localField.field.fieldType.ref || - ((remoteField.fieldType.ref || localField.field.fieldType.ref) && - (remoteField.fieldType.typeId != - localField.field.fieldType.typeId || - remoteField.fieldType.nullable != - localField.field.fieldType.nullable))); - if (refTrackedScalarSchemaMismatch) { - fields.add(null); - scalarConversions?.add(null); - topLevelListArrayPairs?.add(false); - continue; - } final scalarConversion = topLevelListArrayPair ? null : compatibleScalarConversion(remoteField, localField.field); - if (scalarConversion != null) { - scalarConversions ??= List.filled( - fields.length, - null, - growable: true, + final scalarRead = + scalarConversion == null + ? null + : compatibleScalarReadDescriptor(scalarConversion); + final exactField = _sameFieldType( + localField.field.fieldType, + remoteField.fieldType, + ); + if (!exactField && + !topLevelListArrayPair && + scalarConversion == null && + !_compatibleFieldType( + localField.field.fieldType, + remoteField.fieldType, + topLevel: true, + )) { + throw StateError( + 'Compatible field ${localField.name} has incompatible local and remote schemas.', ); - hasScalarConversions = true; } + final matchedId = + exactField ? localField.index * 2 : localField.index * 2 + 1; final mergedField = - topLevelListArrayPair || scalarConversion != null + exactField || topLevelListArrayPair || scalarRead != null ? localField : _typeResolver.serializationFieldInfo( mergeCompatibleReadField(localField.field, remoteField), index: localField.index, ); - fields.add(mergedField); - scalarConversions?.add(scalarConversion); - topLevelListArrayPairs?.add(topLevelListArrayPair); + fields.add( + CompatibleStructReadField( + remoteField: remoteField, + matchedId: matchedId, + localField: mergedField, + scalarRead: scalarRead, + topLevelListArrayPair: topLevelListArrayPair, + ), + ); } final layout = CompatibleStructReadLayout( - remoteTypeDef.fields, - List.unmodifiable(fields), - hasScalarConversions - ? List.unmodifiable(scalarConversions!) - : null, - hasTopLevelListArrayPairs - ? List.unmodifiable(topLevelListArrayPairs!) - : null, + List.unmodifiable(fields), ); _compatibleReadLayouts[remoteTypeDef] = layout; + _lastCompatibleRemoteTypeDef = remoteTypeDef; + _lastCompatibleReadLayout = layout; return layout; } } @@ -239,6 +263,91 @@ bool _topLevelListArrayPair(FieldInfo localField, FieldInfo remoteField) { return isCompatibleCollectionArrayFieldPair(localField, remoteField); } +bool _sameFieldType(FieldType localType, FieldType remoteType) { + if (localType.typeId != remoteType.typeId || + localType.nullable != remoteType.nullable || + localType.ref != remoteType.ref || + _explicitDynamic(localType) != _explicitDynamic(remoteType) || + localType.arguments.length != remoteType.arguments.length) { + return false; + } + for (var index = 0; index < localType.arguments.length; index += 1) { + if (!_sameFieldType( + localType.arguments[index], + remoteType.arguments[index], + )) { + return false; + } + } + return true; +} + +bool _compatibleFieldType( + FieldType localType, + FieldType remoteType, { + required bool topLevel, +}) { + if (_sameFieldType(localType, remoteType)) { + return true; + } + if (_explicitDynamic(localType) != _explicitDynamic(remoteType)) { + return false; + } + final scalarPair = + isCompatibleScalarType(remoteType.typeId) && + isCompatibleScalarType(localType.typeId); + if (scalarPair) { + // Scalar conversion is an immediate field adaptation; nested container + // element/key/value scalar types must stay schema-compatible as written. + if (!topLevel) { + return localType.typeId == remoteType.typeId && + localType.arguments.length == remoteType.arguments.length; + } + if (remoteType.ref || localType.ref) { + return false; + } + if (localType.typeId == remoteType.typeId && + localType.arguments.length == remoteType.arguments.length) { + return true; + } + return compatibleScalarConversion( + FieldInfo(name: '', identifier: '', id: null, fieldType: remoteType), + FieldInfo(name: '', identifier: '', id: null, fieldType: localType), + ) != + null; + } + if (_isStructWireType(localType.typeId) && + _isStructWireType(remoteType.typeId) && + localType.arguments.length == remoteType.arguments.length) { + return true; + } + if (topLevel && isCompatibleCollectionArrayTypePair(localType, remoteType)) { + return true; + } + final sameWireFamily = + localType.typeId == remoteType.typeId || + _compatibleUnknownUserType(localType, remoteType) || + _compatibleGeneratedManualUserType(localType, remoteType) || + (_isStructWireType(localType.typeId) && + _isStructWireType(remoteType.typeId)); + if (!sameWireFamily || + localType.arguments.length != remoteType.arguments.length) { + return false; + } + for (var index = 0; index < localType.arguments.length; index += 1) { + if (!_compatibleFieldType( + localType.arguments[index], + remoteType.arguments[index], + topLevel: false, + )) { + return false; + } + } + return true; +} + +bool _explicitDynamic(FieldType fieldType) => fieldType.dynamic == true; + bool _hasUnsupportedListArrayMismatch( FieldType localType, FieldType remoteType, { @@ -263,3 +372,40 @@ bool _hasUnsupportedListArrayMismatch( } return false; } + +bool _isStructWireType(int typeId) => + typeId == TypeIds.struct || + typeId == TypeIds.compatibleStruct || + typeId == TypeIds.namedStruct || + typeId == TypeIds.namedCompatibleStruct; + +bool _isManualUserWireType(int typeId) => + typeId == TypeIds.ext || + typeId == TypeIds.namedExt || + typeId == TypeIds.union || + typeId == TypeIds.typedUnion || + typeId == TypeIds.namedUnion; + +bool _compatibleGeneratedManualUserType( + FieldType localType, + FieldType remoteType, +) { + if (localType.arguments.isNotEmpty || remoteType.arguments.isNotEmpty) { + return false; + } + return (_isStructWireType(localType.typeId) && + _isManualUserWireType(remoteType.typeId)) || + (_isManualUserWireType(localType.typeId) && + _isStructWireType(remoteType.typeId)); +} + +bool _compatibleUnknownUserType(FieldType localType, FieldType remoteType) { + if (localType.typeId == TypeIds.unknown) { + return remoteType.typeId == TypeIds.unknown || + TypeIds.isUserType(remoteType.typeId); + } + if (remoteType.typeId == TypeIds.unknown) { + return TypeIds.isUserType(localType.typeId); + } + return false; +} diff --git a/dart/packages/fory/test/nested_type_spec_test.dart b/dart/packages/fory/test/nested_type_spec_test.dart index c66739fbda..68972eaa66 100644 --- a/dart/packages/fory/test/nested_type_spec_test.dart +++ b/dart/packages/fory/test/nested_type_spec_test.dart @@ -135,22 +135,29 @@ void _registerNullableReader(Fory fory) { void main() { group('nested type specs', () { - test('compatible mode reads nested overridden fixed int32 list values', () { + test('compatible mode rejects nested scalar encoding drift', () { final writer = Fory(compatible: true); final reader = Fory(compatible: true); _registerFixedCompatible(writer); _registerVarintConsistent(reader); - final result = reader.deserialize( - writer.serialize( - NestedFixedContainer() - ..nested = >{ - 'a': [1, null, -7], - }, - ), + final bytes = writer.serialize( + NestedFixedContainer() + ..nested = >{ + 'a': [1, null, -7], + }, ); - expect(result.nested['a'], orderedEquals([1, null, -7])); + expect( + () => reader.deserialize(bytes), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('incompatible local and remote schemas'), + ), + ), + ); }); test('schema-consistent mode rejects nested encoding hash mismatches', () { diff --git a/dart/packages/fory/test/object_and_compatible_serializer_test.dart b/dart/packages/fory/test/object_and_compatible_serializer_test.dart index ea5e6c7b1c..dcd53c3e0e 100644 --- a/dart/packages/fory/test/object_and_compatible_serializer_test.dart +++ b/dart/packages/fory/test/object_and_compatible_serializer_test.dart @@ -109,6 +109,22 @@ class CompatibleEnvelopeV2 { Object? payload; } +@ForyStruct() +class DynamicPayloadEnvelopeV1 { + DynamicPayloadEnvelopeV1(); + + @ForyField(id: 1, dynamic: true) + Object? payload; +} + +@ForyStruct() +class StaticPayloadEnvelopeV2 { + StaticPayloadEnvelopeV2(); + + @ForyField(id: 1) + SharedLeaf? payload; +} + void _registerCommonTypes(Fory fory) { ObjectAndCompatibleSerializerTestForyModule.register( fory, @@ -281,5 +297,38 @@ void main() { orderedEquals(reader.serialize(local)), ); }); + + test('rejects compatible dynamic field framing drift', () { + final writer = Fory(compatible: true); + final reader = Fory(compatible: true); + _registerCommonTypes(writer); + _registerCommonTypes(reader); + ObjectAndCompatibleSerializerTestForyModule.register( + writer, + DynamicPayloadEnvelopeV1, + name: 'compat.DynamicPayloadEnvelope', + ); + ObjectAndCompatibleSerializerTestForyModule.register( + reader, + StaticPayloadEnvelopeV2, + name: 'compat.DynamicPayloadEnvelope', + ); + + final leaf = SharedLeaf()..label = 'leaf'; + final bytes = writer.serialize( + DynamicPayloadEnvelopeV1()..payload = leaf, + ); + + expect( + () => reader.deserialize(bytes), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('incompatible local and remote schemas'), + ), + ), + ); + }); }); } diff --git a/dart/packages/fory/test/runtime_validation_test.dart b/dart/packages/fory/test/runtime_validation_test.dart index d72fb9971f..5fa7c64a06 100644 --- a/dart/packages/fory/test/runtime_validation_test.dart +++ b/dart/packages/fory/test/runtime_validation_test.dart @@ -81,6 +81,14 @@ class DynamicAnimalEnvelope { DynamicAnimal? animal; } +@ForyStruct() +class ExplicitUnknownEnvelope { + ExplicitUnknownEnvelope(); + + @ForyField(dynamic: false) + Object value = 'unset'; +} + @ForyStruct() class SkipEnvelope { SkipEnvelope(); @@ -155,6 +163,11 @@ void _registerValidationTypes(Fory fory) { DynamicAnimalEnvelope, name: 'validation.DynamicAnimalEnvelope', ); + RuntimeValidationTestForyModule.register( + fory, + ExplicitUnknownEnvelope, + name: 'validation.ExplicitUnknownEnvelope', + ); RuntimeValidationTestForyModule.register( fory, SkipEnvelope, @@ -265,6 +278,17 @@ void main() { expect((catEnvelope.animal as DynamicCat).lives, equals(7)); }); + test('unknown object fields use dynamic generated payloads', () { + final fory = Fory(); + _registerValidationTypes(fory); + + final roundTrip = fory.deserialize( + fory.serialize(ExplicitUnknownEnvelope()..value = 'dynamic-payload'), + ); + + expect(roundTrip.value, equals('dynamic-payload')); + }); + test('compatible mode ignores skipped fields from older writers', () { final writer = Fory(compatible: true); final reader = Fory(compatible: true); diff --git a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart index 285027dc90..a1ad837677 100644 --- a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart +++ b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart @@ -135,6 +135,19 @@ class CompatibleNullableListEnvelope { List values = []; } +@ForyStruct() +class CompatibleRootNullableListEnvelope { + CompatibleRootNullableListEnvelope(); + + @ForyField( + type: ListType( + element: Int32Type(encoding: Encoding.fixed), + nullable: true, + ), + ) + List? values = []; +} + @ForyStruct() class CompatibleStringListEnvelope { CompatibleStringListEnvelope(); @@ -159,6 +172,53 @@ class CompatibleNestedListEnvelope { List> values = >[]; } +@ForyStruct() +class CompatibleNestedStringListEnvelope { + CompatibleNestedStringListEnvelope(); + + @ListField(element: ListType(element: StringType())) + List> values = >[]; +} + +@ForyStruct() +class CompatibleNestedNullableListEnvelope { + CompatibleNestedNullableListEnvelope(); + + @ListField( + element: ListType( + element: Int32Type(nullable: true, encoding: Encoding.fixed), + ), + ) + List> values = >[]; +} + +@ForyStruct() +class CompatibleNestedIntMapEnvelope { + CompatibleNestedIntMapEnvelope(); + + @MapField(key: StringType(), value: Int32Type(encoding: Encoding.fixed)) + Map values = {}; +} + +@ForyStruct() +class CompatibleNestedNullableMapEnvelope { + CompatibleNestedNullableMapEnvelope(); + + @MapField( + key: StringType(), + value: Int32Type(nullable: true, encoding: Encoding.fixed), + ) + Map values = {}; +} + +@ForyStruct() +class CompatibleNestedStringMapEnvelope { + CompatibleNestedStringMapEnvelope(); + + @MapField(key: StringType(), value: StringType()) + Map values = {}; +} + @ForyStruct() class CompatibleScalarStringEnvelope { CompatibleScalarStringEnvelope(); @@ -239,6 +299,14 @@ class CompatibleScalarUint64Envelope { Uint64 value = Uint64(0); } +@ForyStruct() +class CompatibleScalarWrappedInt64Envelope { + CompatibleScalarWrappedInt64Envelope(); + + @ForyField(id: 1, type: Int64Type(encoding: Encoding.varint)) + Int64 value = Int64(0); +} + @ForyStruct() class CompatibleScalarFloat64Envelope { CompatibleScalarFloat64Envelope(); @@ -255,6 +323,14 @@ class CompatibleScalarFloat32Envelope { double value = 0.0; } +@ForyStruct() +class CompatibleScalarWrappedFloat32Envelope { + CompatibleScalarWrappedFloat32Envelope(); + + @ForyField(id: 1, type: Float32Type()) + Float32 value = Float32(0); +} + @ForyStruct() class CompatibleScalarDecimalEnvelope { CompatibleScalarDecimalEnvelope(); @@ -803,40 +879,74 @@ void main() { expect(decoded.values, orderedEquals([1, 2, 3])); }); - test( - 'rejects compatible list payload with nullable elements for dense array fields', - () { - final writer = Fory(); - final reader = Fory(); - ScalarAndTypedArraySerializerTestForyModule.register( - writer, - CompatibleNullableListEnvelope, - name: 'test.CompatibleNullableListArrayEnvelope', - ); - ScalarAndTypedArraySerializerTestForyModule.register( - reader, - CompatibleArrayEnvelope, - name: 'test.CompatibleNullableListArrayEnvelope', - ); + test('adapts compatible nullable list schema into dense array fields', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestForyModule.register( + writer, + CompatibleNullableListEnvelope, + name: 'test.CompatibleNullableListArrayEnvelope', + ); + ScalarAndTypedArraySerializerTestForyModule.register( + reader, + CompatibleArrayEnvelope, + name: 'test.CompatibleNullableListArrayEnvelope', + ); - final nonNullBytes = writer.serialize( - CompatibleNullableListEnvelope()..values = [1, 2, 3], - ); - final decoded = reader.deserialize( - nonNullBytes, - ); - expect(decoded.values, orderedEquals([1, 2, 3])); + final bytes = writer.serialize( + CompatibleNullableListEnvelope()..values = [1, 2, 3], + ); - final nullableBytes = writer.serialize( - CompatibleNullableListEnvelope()..values = [1, null, 3], - ); + final decoded = reader.deserialize(bytes); + _expectInt32ListEquals( + decoded.values, + Int32List.fromList([1, 2, 3]), + ); - expect( - () => reader.deserialize(nullableBytes), - throwsStateError, - ); - }, - ); + final nullBytes = writer.serialize( + CompatibleNullableListEnvelope()..values = [1, null, 3], + ); + expect( + () => reader.deserialize(nullBytes), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('cannot read nullable'), + ), + ), + ); + }); + + test('rejects compatible nullable list field into dense array field', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestForyModule.register( + writer, + CompatibleRootNullableListEnvelope, + name: 'test.CompatibleRootNullableListArrayEnvelope', + ); + ScalarAndTypedArraySerializerTestForyModule.register( + reader, + CompatibleArrayEnvelope, + name: 'test.CompatibleRootNullableListArrayEnvelope', + ); + + final bytes = writer.serialize( + CompatibleRootNullableListEnvelope()..values = [1, 2, 3], + ); + + expect( + () => reader.deserialize(bytes), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('unsupported list/array schema mismatch'), + ), + ), + ); + }); test( 'rejects incompatible compatible list and dense array element fields', @@ -891,6 +1001,119 @@ void main() { ); }); + test('rejects nested list scalar changes', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestForyModule.register( + writer, + CompatibleNestedStringListEnvelope, + name: 'test.CompatibleNestedScalarListEnvelope', + ); + ScalarAndTypedArraySerializerTestForyModule.register( + reader, + CompatibleNestedListEnvelope, + name: 'test.CompatibleNestedScalarListEnvelope', + ); + + final bytes = writer.serialize( + CompatibleNestedStringListEnvelope() + ..values = >[ + ['1'], + ], + ); + + expect( + () => reader.deserialize(bytes), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('incompatible local and remote schemas'), + ), + ), + ); + }); + + test('reads nested list scalar nullable changes', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestForyModule.register( + writer, + CompatibleNestedNullableListEnvelope, + name: 'test.CompatibleNestedScalarListNullableEnvelope', + ); + ScalarAndTypedArraySerializerTestForyModule.register( + reader, + CompatibleNestedListEnvelope, + name: 'test.CompatibleNestedScalarListNullableEnvelope', + ); + + final bytes = writer.serialize( + CompatibleNestedNullableListEnvelope() + ..values = >[ + [7], + ], + ); + + final decoded = reader.deserialize(bytes); + expect(decoded.values, >[ + [7], + ]); + }); + + test('rejects nested map scalar changes', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestForyModule.register( + writer, + CompatibleNestedStringMapEnvelope, + name: 'test.CompatibleNestedScalarMapEnvelope', + ); + ScalarAndTypedArraySerializerTestForyModule.register( + reader, + CompatibleNestedIntMapEnvelope, + name: 'test.CompatibleNestedScalarMapEnvelope', + ); + + final bytes = writer.serialize( + CompatibleNestedStringMapEnvelope() + ..values = {'x': '1'}, + ); + + expect( + () => reader.deserialize(bytes), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('incompatible local and remote schemas'), + ), + ), + ); + }); + + test('reads nested map scalar nullable changes', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestForyModule.register( + writer, + CompatibleNestedNullableMapEnvelope, + name: 'test.CompatibleNestedScalarMapNullableEnvelope', + ); + ScalarAndTypedArraySerializerTestForyModule.register( + reader, + CompatibleNestedIntMapEnvelope, + name: 'test.CompatibleNestedScalarMapNullableEnvelope', + ); + + final bytes = writer.serialize( + CompatibleNestedNullableMapEnvelope()..values = {'x': 9}, + ); + + final decoded = reader.deserialize(bytes); + expect(decoded.values, {'x': 9}); + }); + test('converts compatible scalar fields losslessly', () { expect( _compatibleScalarRoundTrip( @@ -949,6 +1172,22 @@ void main() { ).value, equals('18446744073709551615'), ); + expect( + _compatibleScalarRoundTrip( + CompatibleScalarInt64Envelope, + CompatibleScalarUint64Envelope, + CompatibleScalarInt64Envelope()..value = 42, + ).value, + equals(Uint64(42)), + ); + expect( + _compatibleScalarRoundTrip( + CompatibleScalarUint64Envelope, + CompatibleScalarWrappedInt64Envelope, + CompatibleScalarUint64Envelope()..value = Uint64(42), + ).value, + equals(Int64(42)), + ); expect( _compatibleScalarRoundTrip( CompatibleScalarStringEnvelope, @@ -957,6 +1196,38 @@ void main() { ).value, equals(0.5), ); + expect( + _compatibleScalarRoundTrip( + CompatibleScalarFloat64Envelope, + CompatibleScalarFloat32Envelope, + CompatibleScalarFloat64Envelope()..value = 0.5, + ).value, + equals(0.5), + ); + expect( + _compatibleScalarRoundTrip( + CompatibleScalarFloat32Envelope, + CompatibleScalarFloat64Envelope, + CompatibleScalarFloat32Envelope()..value = double.infinity, + ).value, + equals(double.infinity), + ); + expect( + _compatibleScalarRoundTrip( + CompatibleScalarFloat64Envelope, + CompatibleScalarFloat32Envelope, + CompatibleScalarFloat64Envelope()..value = double.negativeInfinity, + ).value, + equals(double.negativeInfinity), + ); + expect( + _compatibleScalarRoundTrip( + CompatibleScalarFloat64Envelope, + CompatibleScalarWrappedFloat32Envelope, + CompatibleScalarFloat64Envelope()..value = 0.5, + ).value, + equals(Float32(0.5)), + ); final negativeZero = _compatibleScalarRoundTrip( CompatibleScalarStringEnvelope, @@ -1066,6 +1337,11 @@ void main() { CompatibleScalarInt64Envelope, CompatibleScalarDecimalEnvelope()..value = Decimal(BigInt.one << 63, 0), ); + _expectCompatibleScalarError( + CompatibleScalarUint64Envelope, + CompatibleScalarInt64Envelope, + CompatibleScalarUint64Envelope()..value = _uint64HighBit, + ); _expectCompatibleScalarError( CompatibleScalarDecimalEnvelope, CompatibleScalarStringEnvelope, @@ -1205,24 +1481,25 @@ void main() { expect(buffer.readableBytes, equals(0)); }); - test('rejects tracked scalar nullable framing mismatch', () { + test('rejects tracked scalar nullable framing mismatch during layout', () { expect( - _compatibleScalarRoundTrip( - CompatibleScalarOptionalTrackingRefBoolEnvelope, - CompatibleScalarTrackingRefBoolEnvelope, - CompatibleScalarOptionalTrackingRefBoolEnvelope()..value = true, - ).value, - isFalse, + () => + _compatibleScalarRoundTrip( + CompatibleScalarOptionalTrackingRefBoolEnvelope, + CompatibleScalarTrackingRefBoolEnvelope, + CompatibleScalarOptionalTrackingRefBoolEnvelope()..value = true, + ), + throwsA(isA()), ); expect( - _compatibleScalarRoundTrip< + () => _compatibleScalarRoundTrip< CompatibleScalarOptionalTrackingRefBoolEnvelope >( CompatibleScalarTrackingRefBoolEnvelope, CompatibleScalarOptionalTrackingRefBoolEnvelope, CompatibleScalarTrackingRefBoolEnvelope()..value = true, - ).value, - isNull, + ), + throwsA(isA()), ); expect( _compatibleScalarRoundTrip< diff --git a/docs/guide/java/schema-evolution.md b/docs/guide/java/schema-evolution.md index 18e2613f5c..6f056f4525 100644 --- a/docs/guide/java/schema-evolution.md +++ b/docs/guide/java/schema-evolution.md @@ -48,6 +48,9 @@ is lost. Numeric strings use finite ASCII decimal syntax. Nullable and boxed fie these conversions, but reference-tracked scalar type changes are incompatible. Invalid strings and lossy conversions fail during deserialization. +Extra writer fields with no matching local field are skipped. A field that matches by tag ID or name +but has an incompatible schema is not treated as missing; deserialization fails instead. + ```java Fory fory = Fory.builder().withXlang(false) .build(); diff --git a/docs/guide/python/schema-metadata.md b/docs/guide/python/schema-metadata.md index 2af95f9ccb..50e4456a6e 100644 --- a/docs/guide/python/schema-metadata.md +++ b/docs/guide/python/schema-metadata.md @@ -320,9 +320,10 @@ written as tagged int64. Runtime type inference is used only for dynamic or unkn schemas. In compatible mode, readers consume field bytes using the remote schema metadata. Python assigns the -decoded value only when it can safely satisfy the local declared schema. Different integer encodings -in the same signedness and width domain are compatible, and same-signedness narrowing is assigned -only after range validation. +decoded value only when it can safely satisfy the local declared schema. Scalar conversion and +integer encoding adaptation apply only to the immediate matched field schema. Nested collection +elements, map keys, and map values must keep exact nullability, reference-tracking, and type shape +metadata, except for user-type family normalization such as named and unnamed struct metadata. ## Complete Example diff --git a/docs/guide/xlang/serialization.md b/docs/guide/xlang/serialization.md index 409f91cf53..338f58fcb6 100644 --- a/docs/guide/xlang/serialization.md +++ b/docs/guide/xlang/serialization.md @@ -34,7 +34,7 @@ Reduced-precision floating-point values are also part of the built-in xlang type Use the language-specific carrier types documented in the type mapping reference. Python uses `pyfory.Float16` and `pyfory.BFloat16` as annotation markers only; scalar values are native Python `float`, and dense reduced-precision arrays use `pyfory.Float16Array` and `pyfory.BFloat16Array`. Go uses the `float16` and `bfloat16` packages for scalar, slice, and array carriers; JavaScript uses `number` for scalar `float16` and `bfloat16`, and dense array carriers `BoolArray`, `Float16Array`, and `BFloat16Array` for the corresponding `array` schemas. Dart uses `double` plus `Float16Type` or `Bfloat16Type` metadata for scalar fields, and `Float16List` / `Bfloat16List` for dense arrays. Java uses `@ArrayType` on supported reduced-precision carriers for `array` / `array` schema, while general object arrays stay on the `list` path; C++, Rust, and C# provide their own dedicated scalar and array carriers. -When `compatible=true`, a direct struct/class field can evolve between `list` and `array` for dense bool/numeric `T`. Integer list element encodings in the same signedness and width domain match the corresponding dense array element domain. This applies only to the immediate matched field schema. It does not apply to nested collection, map, array, union, or generic positions. If a peer `list` payload declares nullable or ref-tracked elements, reading it into a local `array` field raises a compatible-read error. +When `compatible=true`, a direct struct/class field can evolve between `list` and `array` for dense bool/numeric `T`. Integer list element encodings in the same signedness and width domain match the corresponding dense array element domain. This applies only to the immediate matched field schema. It does not apply to nested collection, map, array, union, or generic positions. A peer `list` schema can be read into a local `array` field when the actual payload has no null elements. If the payload carries a null element or ref-tracked element encoding, reading it into a local `array` field raises a compatible-read error. ### Java diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index 60ed45ef63..9ad0ec1189 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -306,8 +306,8 @@ In Dart that internal owner is `StructSerializer`. - caching compatible read layouts - skipping unknown compatible fields - passing compatible read layouts explicitly to generated serializers -- classifying matched compatible fields that require top-level scalar - conversion and routing those fields through cold conversion helpers +- classifying matched compatible fields as exact direct reads, compatible + conversions, or remote-only skips before generated dispatch When `Config.compatible` is enabled and the struct is marked evolving: @@ -316,9 +316,20 @@ When `Config.compatible` is enabled and the struct is marked evolving: - reads map incoming fields by identifier and skip unknown fields - generated serializers apply matched fields directly while preserving their own object construction and default-value rules +- exact matched field schemas use the same direct read shape as same-schema + reads and must not receive remote compatible metadata - matched scalar fields may use compatible scalar conversion only when the layout has classified a remote/local top-level scalar pair as lossless convertible and both field schemas have `trackingRef = false` +- compatible scalar conversion applies only to the immediate matched field. + Nested collection, array, map key, and map value schemas must not be accepted + by recursively applying scalar conversion to child schemas. +- direct top-level `list` to dense `array` matched fields must be + classified as compatible when element domains match; the nullable element + schema bit alone is not a schema-pair rejection. Actual null element payloads + fail in the dense-array reader. Ref-tracked list-element framing is separate + and may remain rejected when the runtime cannot materialize it without + generic/reference paths. When `compatible` is disabled and `checkStructVersion` is enabled: @@ -334,8 +345,25 @@ decision and value adaptation stay with the serializer-owned compatible field layout. Layout classification must reject top-level scalar conversions when either matched schema has `trackingRef = true` and must reject same scalar type pairs whose top-level `trackingRef` framing differs; converters must not add a -reference-table path for scalar mismatches. Same-schema readers with matching -reference and null/optional framing must keep direct scalar read paths without conversion branches or per-field conversion +reference-table path for scalar mismatches. Recursive schema comparison inside +containers must reject scalar mismatches instead of reusing the top-level scalar +conversion matrix. Generated serializers should consume the classified layout +decision directly: + +- source-generated serializers use the layout's matched-field dispatch key to + select exact direct field code, compatible conversion code, or skip code +- regenerated serializers may instead compile a remote-schema-specific + straight-line reader after classification, without a second outer matched-id + switch, when the generated source still has pure direct, pure conversion, and + explicit skip operations +- compatible scalar conversion cases must read the concrete remote wire scalar + selected by classification and compose only the required lossless conversion; + they must not call a generic runtime converter that redispatches by remote and + local scalar type IDs, field descriptors, field names, or schema eligibility + helpers + +Same-schema readers with matching reference and null/optional framing must keep +direct scalar read paths without conversion branches or per-field conversion objects. Same raw scalar types with different null/optional framing may still use the compatible nullable/optional composition path when both fields are not reference-tracked. diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index 27b8e19ee4..cf46904344 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -196,14 +196,31 @@ and `array` as distinct kinds. The adaptation is limited to the immediate schema of the matched compatible field. It does not apply when `list` or `array` appears inside another field type, including collection elements, map keys or values, array elements, -union alternatives, or other generic/container positions. A peer `list` -TypeDef element may be declared nullable or ref-tracked; that declaration alone -does not prevent a local matched `array` field from reading the value. The -reader must decide from the collection payload. If the payload actually carries -a null element or reference-tracked element encoding that cannot be represented -as a dense array element value, the local `array` field must raise a +union alternatives, or other generic/container positions. A peer `list` +TypeDef element schema is not immediate schema incompatibility for a local +matched `array` field. Classification must accept the matched field when the +element domains match and the only element-schema difference is nullable +metadata. The reader must decide from the collection payload: if the payload +actually carries a null element, the local `array` field must raise a compatible-read error. Null list elements must not be coerced to dense-array -default values. +default values. Reference-tracked list-element framing is separate from +nullable element schema. A runtime that cannot materialize ref-tracked list +elements into a dense array without generic/reference paths may reject that +field during compatible classification; if it accepts the field, reference +payloads that cannot be represented as dense array element values must fail +during read. + +The dense-array error rule applies to dense-array targets. A matched +`list` field read into a local `list` target must keep using list +semantics and preserve actual null elements; implementations must not route that +payload through a dense primitive-array materialization path that rejects nulls. + +In schema-compatible mode only, a matched struct/class field may read between +direct top-level `binary` and direct top-level `array` schemas. This is a +byte-sequence adaptation only: it does not merge TypeDef/ClassDef type IDs, +schema fingerprints, dynamic root serialization, same-schema mode, or nested +collection/map/array/union/generic positions. `array` is not part of this +adapter. In schema-compatible mode only, a matched struct/class field may also read between direct top-level scalar schemas when the remote value can be represented @@ -329,6 +346,12 @@ action, an invalid payload value MUST be reported through the implementation's data-error path with enough context to identify the remote type, local type, and field when that path has the information. +Unknown-field skipping applies only when the remote field has no matching local +field identity. If a local field matches by tag ID or name but its schema is +outside the exact-read and compatible-adaptation rules, the reader MUST reject +the compatible layout instead of treating the field as missing, remote-only, or +skippable. + Users can also provide meta hints for fields of a type, or the type whole. Here is an example in java which use annotation to provide such information. @@ -1673,6 +1696,9 @@ MurmurHash3 x64_128 of the struct fingerprint string: - `LIST` / `SET`: `,,[]` - `MAP`: `,,[|]` - Nested container element/key/value fingerprints include nested type ID, container shape, and effective integer encoding, but nested `nullable` and `ref` policy are always hashed as `0`. Only the root field `nullable` and `ref` bits participate in schema hash, because nested reads honor the wire null/ref flags directly. +- This schema-hash rule is only for same-schema mode without TypeDef metadata. It + does not permit compatible-mode matched-field classification to accept nested + nullability or reference-tracking mismatches. Field values are serialized in Fory order. Primitive fields are written as raw values (nullable primitives include a null flag). Non-primitive fields write ref/null flags as needed and then the diff --git a/docs/specification/xlang_type_mapping.md b/docs/specification/xlang_type_mapping.md index bb8d49785d..1e7976fa06 100644 --- a/docs/specification/xlang_type_mapping.md +++ b/docs/specification/xlang_type_mapping.md @@ -141,9 +141,16 @@ Notes: field, and a direct top-level `array` field may be read as a direct top-level `list` field, when `T` is one of the dense bool/numeric array domains. Integer list element encodings in the same signedness and width domain match the corresponding dense array element domain. The rule does - not apply inside nested collection, map, array, union, or generic positions. A peer `list` - payload that declares nullable or ref-tracked elements must raise a compatible-read error when the - local matched field is `array`. + not apply inside nested collection, map, array, union, or generic positions. A peer `list` + element schema is still a compatible schema match for a local `array` field; if the actual + payload contains a null element, the dense-array reader raises a compatible-read error instead of + coercing the value. Reference-tracked list-element framing is separate from nullable element + schema and may be rejected during compatible field classification when the local matched field is + `array` and the runtime cannot materialize it without generic/reference paths. +- `binary` and `array` remain distinct schema kinds. In schema-compatible struct/class field + matching only, a direct top-level `binary` field may be read as a direct top-level `array` + field and the reverse may be read as the same byte sequence. This rule does not apply inside + nested collection, map, array, union, or generic positions, and it does not include `array`. - The table above remains the canonical xlang schema mapping. Compatible readers may apply the scalar field adaptation rules defined by `xlang_serialization_spec.md` during schema-compatible struct/class field matching. Those rules do not change TypeDef metadata, dynamic root type diff --git a/go/fory/compatible_scalar.go b/go/fory/compatible_scalar.go index c09349d402..542899ac65 100644 --- a/go/fory/compatible_scalar.go +++ b/go/fory/compatible_scalar.go @@ -174,6 +174,12 @@ func readCompatibleScalarField(ctx *ReadContext, field *FieldInfo, fieldPtr unsa return } } + if readI32ToI64Scalar(ctx, field, fieldPtr) { + return + } + if readDirectIntegerScalar(ctx, field, fieldPtr) { + return + } value := readCompatibleScalarValue(ctx, field.Meta.CompatibleScalar.remoteTypeID) if ctx.HasError() { return @@ -181,6 +187,130 @@ func readCompatibleScalarField(ctx *ReadContext, field *FieldInfo, fieldPtr unsa storeCompatibleScalarValue(ctx, field, fieldPtr, value) } +func readI32ToI64Scalar(ctx *ReadContext, field *FieldInfo, fieldPtr unsafe.Pointer) bool { + scalar := field.Meta.CompatibleScalar + if field.Kind != FieldKindValue || scalar.localType.Kind() != reflect.Int64 { + return false + } + err := ctx.Err() + switch scalar.remoteTypeID { + case INT32: + *(*int64)(fieldPtr) = int64(ctx.buffer.ReadInt32(err)) + case VARINT32: + *(*int64)(fieldPtr) = int64(ctx.buffer.ReadVarint32(err)) + default: + return false + } + return true +} + +func readDirectIntegerScalar(ctx *ReadContext, field *FieldInfo, fieldPtr unsafe.Pointer) bool { + scalar := field.Meta.CompatibleScalar + targetSigned := isSignedTypeID(scalar.localTypeID) + targetUnsigned := isUnsignedTypeID(scalar.localTypeID) + if !targetSigned && !targetUnsigned { + return false + } + var signed int64 + var unsigned uint64 + sourceSigned := false + sourceUnsigned := false + err := ctx.Err() + buf := ctx.buffer + switch scalar.remoteTypeID { + case INT8: + signed = int64(buf.ReadInt8(err)) + sourceSigned = true + case INT16: + signed = int64(buf.ReadInt16(err)) + sourceSigned = true + case INT32: + signed = int64(buf.ReadInt32(err)) + sourceSigned = true + case VARINT32: + signed = int64(buf.ReadVarint32(err)) + sourceSigned = true + case INT64: + signed = buf.ReadInt64(err) + sourceSigned = true + case VARINT64: + signed = buf.ReadVarint64(err) + sourceSigned = true + case TAGGED_INT64: + signed = buf.ReadTaggedInt64(err) + sourceSigned = true + case UINT8: + unsigned = uint64(buf.ReadUint8(err)) + sourceUnsigned = true + case UINT16: + unsigned = uint64(buf.ReadUint16(err)) + sourceUnsigned = true + case UINT32: + unsigned = uint64(buf.ReadUint32(err)) + sourceUnsigned = true + case VAR_UINT32: + unsigned = uint64(buf.ReadVarUint32(err)) + sourceUnsigned = true + case UINT64: + unsigned = buf.ReadUint64(err) + sourceUnsigned = true + case VAR_UINT64: + unsigned = buf.ReadVarUint64(err) + sourceUnsigned = true + case TAGGED_UINT64: + unsigned = buf.ReadTaggedUint64(err) + sourceUnsigned = true + default: + return false + } + if ctx.HasError() { + return true + } + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional { + optInfo = field.Meta.OptionalInfo + } + if targetSigned { + min, max := signedRange(scalar.localTypeID, scalar.localType.Kind()) + if sourceSigned { + if signed < min || signed > max { + compatibleScalarFail(ctx, field.Meta.Name, scalar.remoteTypeID, scalar.localTypeID, "value is not an in-range integral target value") + return true + } + storeCompatibleInt(field.Kind, fieldPtr, optInfo, scalar.localType.Kind(), signed) + return true + } + if sourceUnsigned { + if unsigned > uint64(max) { + compatibleScalarFail(ctx, field.Meta.Name, scalar.remoteTypeID, scalar.localTypeID, "value is not an in-range integral target value") + return true + } + storeCompatibleInt(field.Kind, fieldPtr, optInfo, scalar.localType.Kind(), int64(unsigned)) + return true + } + } + if targetUnsigned { + max := unsignedMax(scalar.localTypeID, scalar.localType.Kind()) + if sourceSigned { + if signed < 0 || uint64(signed) > max { + compatibleScalarFail(ctx, field.Meta.Name, scalar.remoteTypeID, scalar.localTypeID, "value is not an in-range unsigned integral target value") + return true + } + storeCompatibleUint(field.Kind, fieldPtr, optInfo, scalar.localType.Kind(), uint64(signed)) + return true + } + if sourceUnsigned { + if unsigned > max { + compatibleScalarFail(ctx, field.Meta.Name, scalar.remoteTypeID, scalar.localTypeID, "value is not an in-range unsigned integral target value") + return true + } + storeCompatibleUint(field.Kind, fieldPtr, optInfo, scalar.localType.Kind(), unsigned) + return true + } + } + return false +} + func readCompatibleScalarValue(ctx *ReadContext, typeID TypeId) compatibleScalarValue { err := ctx.Err() buf := ctx.buffer @@ -440,6 +570,19 @@ func compatibleValueToInt(value compatibleScalarValue, target TypeId, kind refle } return 0, true } + min, max := signedRange(target, kind) + switch value.typeID { + case INT8, INT16, INT32, VARINT32, INT64, VARINT64, TAGGED_INT64: + if value.signed < min || value.signed > max { + return 0, false + } + return value.signed, true + case UINT8, UINT16, UINT32, VAR_UINT32, UINT64, VAR_UINT64, TAGGED_UINT64: + if value.unsigned > uint64(max) { + return 0, false + } + return int64(value.unsigned), true + } rat, ok := compatibleFiniteRat(value) if !ok { return 0, false @@ -448,7 +591,6 @@ func compatibleValueToInt(value compatibleScalarValue, target TypeId, kind refle return 0, false } i := rat.Num() - min, max := signedRange(target, kind) if i.Cmp(big.NewInt(min)) < 0 || i.Cmp(big.NewInt(max)) > 0 { return 0, false } @@ -462,6 +604,19 @@ func compatibleValueToUint(value compatibleScalarValue, target TypeId, kind refl } return 0, true } + max := unsignedMax(target, kind) + switch value.typeID { + case INT8, INT16, INT32, VARINT32, INT64, VARINT64, TAGGED_INT64: + if value.signed < 0 || uint64(value.signed) > max { + return 0, false + } + return uint64(value.signed), true + case UINT8, UINT16, UINT32, VAR_UINT32, UINT64, VAR_UINT64, TAGGED_UINT64: + if value.unsigned > max { + return 0, false + } + return value.unsigned, true + } rat, ok := compatibleFiniteRat(value) if !ok || !rat.IsInt() { return 0, false @@ -470,7 +625,6 @@ func compatibleValueToUint(value compatibleScalarValue, target TypeId, kind refl if i.Sign() < 0 { return 0, false } - max := unsignedMax(target, kind) if i.Cmp(new(big.Int).SetUint64(max)) > 0 { return 0, false } @@ -1039,6 +1193,15 @@ func isSignedTypeID(typeID TypeId) bool { } } +func isUnsignedTypeID(typeID TypeId) bool { + switch typeID { + case UINT8, UINT16, UINT32, VAR_UINT32, UINT64, VAR_UINT64, TAGGED_UINT64: + return true + default: + return false + } +} + func storeCompatibleInt(kind FieldKind, fieldPtr unsafe.Pointer, optInfo optionalInfo, targetKind reflect.Kind, value int64) { switch targetKind { case reflect.Int8: diff --git a/go/fory/compatible_scalar_test.go b/go/fory/compatible_scalar_test.go index be6cb7ff88..4a043b1099 100644 --- a/go/fory/compatible_scalar_test.go +++ b/go/fory/compatible_scalar_test.go @@ -456,11 +456,9 @@ func TestCompatibleScalarTrackingRefMismatch(t *testing.T) { tagID: -1, } ser := newStructSerializerFromTypeDef(reflect.TypeOf(scalarBool{}), "ScalarTrackingRefMismatch", []FieldDef{remoteDef}) - require.NoError(t, ser.initialize(f.typeResolver)) - require.Len(t, ser.fields, 1) - assert.Equal(t, -1, ser.fields[0].Meta.FieldIndex) - assert.Nil(t, ser.fields[0].Meta.CompatibleScalar) - assert.True(t, ser.typeDefDiffers) + err := ser.initialize(f.typeResolver) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot be read as local field") remoteDef.trackRef = false ser = newStructSerializerFromTypeDef(reflect.TypeOf(scalarBool{}), "ScalarTrackingRefMatch", []FieldDef{remoteDef}) @@ -473,19 +471,15 @@ func TestCompatibleScalarTrackingRefMismatch(t *testing.T) { remoteDef.trackRef = true remoteDef.nullable = true ser = newStructSerializerFromTypeDef(reflect.TypeOf(scalarBool{}), "ScalarTrackingRefNullableRemote", []FieldDef{remoteDef}) - require.NoError(t, ser.initialize(f.typeResolver)) - require.Len(t, ser.fields, 1) - assert.Equal(t, -1, ser.fields[0].Meta.FieldIndex) - assert.Nil(t, ser.fields[0].Meta.CompatibleScalar) - assert.True(t, ser.typeDefDiffers) + err = ser.initialize(f.typeResolver) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot be read as local field") remoteDef.trackRef = true remoteDef.nullable = false remoteDef.typeSpec = NewSimpleTypeSpec(INT32) ser = newStructSerializerFromTypeDef(reflect.TypeOf(scalarTrackingRefInt32{}), "ScalarTrackingRefTypeChange", []FieldDef{remoteDef}) - require.NoError(t, ser.initialize(f.typeResolver)) - require.Len(t, ser.fields, 1) - assert.Equal(t, -1, ser.fields[0].Meta.FieldIndex) - assert.Nil(t, ser.fields[0].Meta.CompatibleScalar) - assert.True(t, ser.typeDefDiffers) + err = ser.initialize(f.typeResolver) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot be read as local field") } diff --git a/go/fory/field_info.go b/go/fory/field_info.go index 5b00a677b8..97941bb1a7 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -44,6 +44,20 @@ type PrimitiveFieldInfo struct { Meta *FieldMeta } +type remoteFieldReadAction uint8 + +const ( + remoteFieldReadSkip remoteFieldReadAction = iota + remoteFieldReadCompatibleScalar + remoteFieldReadExactFixed + remoteFieldReadExactVarint + remoteFieldReadExactRemaining + remoteFieldReadExactNullableFixed + remoteFieldReadExactNullableVarint + remoteFieldReadExactEnum + remoteFieldReadSerializer +) + // FieldMeta contains cold/rarely-accessed field metadata. // Accessed via pointer from FieldInfo to keep FieldInfo small for cache efficiency. type FieldMeta struct { @@ -60,6 +74,11 @@ type FieldMeta struct { // top-level scalar schema differs from the local scalar schema. CompatibleScalar *compatibleScalarConversion + // ExactSchema is true when the remote field schema exactly matches the + // matched local field schema. Changed-schema reads can use direct field + // readers for these fields without consulting remote compatible metadata. + ExactSchema bool + // Optional fields (fory/optional.Optional[T]) - only valid when FieldKindOptional OptionalInfo optionalInfo @@ -78,8 +97,9 @@ type FieldInfo struct { // Hot fields - accessed frequently during serialization Offset uintptr // Field offset for unsafe access DispatchId DispatchId // Type dispatch ID - WriteOffset int // Offset within fixed-fields buffer region (sum of preceding field sizes) - RefMode RefMode // ref mode for serializer.Write/Read + ReadAction remoteFieldReadAction + WriteOffset int // Offset within fixed-fields buffer region (sum of preceding field sizes) + RefMode RefMode // ref mode for serializer.Write/Read Kind FieldKind Serializer Serializer // Serializer for this field diff --git a/go/fory/fory_compatible_test.go b/go/fory/fory_compatible_test.go index 459ebe1090..3b85fdec94 100644 --- a/go/fory/fory_compatible_test.go +++ b/go/fory/fory_compatible_test.go @@ -358,13 +358,7 @@ func TestCompatibleSerializationScenarios(t *testing.T) { Items: []string{"item1", "item2"}, Nums: []int32{1, 2, 3}, }, - assertFunc: func(t *testing.T, input any, output any) { - in := input.(SliceDataClass) - out := output.(InconsistentSliceDataClass) - assert.Equal(t, in.Name, out.Name) - assert.Nil(t, out.Items) - assert.Equal(t, in.Nums, out.Nums) - }, + unmarshalErrContains: "cannot be read as local field", }, { name: "MapFields", @@ -460,13 +454,7 @@ func TestCompatibleSerializationScenarios(t *testing.T) { "c2": 20, }, }, - assertFunc: func(t *testing.T, input any, output any) { - in := input.(MapDataClass) - out := output.(InconsistentMapDataClass) - assert.Equal(t, in.Name, out.Name) - assert.Nil(t, out.Metadata) - assert.Nil(t, out.Counters) - }, + unmarshalErrContains: "cannot be read as local field", }, { name: "NestedStruct", @@ -541,10 +529,7 @@ func TestCompatibleSerializationScenarios(t *testing.T) { input: ByteFamilyInt8ArrayDataClass{ Payload: []int8{-1, 0, 1}, }, - assertFunc: func(t *testing.T, input any, output any) { - out := output.(ByteFamilyBinaryDataClass) - assert.Nil(t, out.Payload) - }, + unmarshalErrContains: "cannot be read as local field", }, { name: "Int32ListToArray", @@ -614,7 +599,7 @@ func TestCompatibleSerializationScenarios(t *testing.T) { input: NullableInt32ListPayloadDataClass{ Payload: []*int32{ptr(int32(1)), nil, ptr(int32(3))}, }, - unmarshalErrContains: "compatible list to array field requires non-null elements", + unmarshalErrContains: "requires non-null elements", }, { name: "NestedListArrayMismatch", @@ -624,10 +609,7 @@ func TestCompatibleSerializationScenarios(t *testing.T) { input: NestedInt32ListPayloadDataClass{ Payload: [][]int32{{1, 2}, {3, 4}}, }, - assertFunc: func(t *testing.T, input any, output any) { - out := output.(NestedInt32ArrayPayloadDataClass) - assert.Nil(t, out.Payload) - }, + unmarshalErrContains: "cannot be read as local field", }, } diff --git a/go/fory/set.go b/go/fory/set.go index f466841727..ed3f9bac1b 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -146,11 +146,12 @@ func (s setSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, keys []re firstElem := UnwrapReflectValue(keys[0]) if isNull(firstElem) { hasNull = true - } else if declaredGenerics { - elemTypeInfo = &TypeInfo{Type: firstElem.Type(), Serializer: s.elemSerializer} } else { // Get type info for first element to use as reference elemTypeInfo, _ = ctx.TypeResolver().getTypeInfo(firstElem, true) + if declaredGenerics && elemTypeInfo != nil && !needsElemTypeInfo(TypeId(elemTypeInfo.TypeID)) { + elemTypeInfo = &TypeInfo{Type: firstElem.Type(), Serializer: s.elemSerializer} + } } } @@ -188,6 +189,9 @@ func (s setSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, keys []re } // When hasGenerics is true, element type is declared from schema (known at compile time) // so we don't need to write the element type ID + if declaredGenerics && elemTypeInfo != nil && needsElemTypeInfo(TypeId(elemTypeInfo.TypeID)) { + declaredGenerics = false + } if declaredGenerics { collectFlag |= CollectionIsDeclElementType collectFlag |= CollectionIsSameType @@ -322,10 +326,19 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // If all elements are same type, get element type info if (collectFlag & CollectionIsSameType) != 0 { - if (collectFlag&CollectionIsDeclElementType) != 0 && s.elemSerializer != nil { + if (collectFlag & CollectionIsDeclElementType) != 0 { // Element type is declared in schema, derive from Go type's key type keyType := type_.Key() - elemTypeInfo = &TypeInfo{Type: keyType, Serializer: s.elemSerializer} + elemSerializer := s.elemSerializer + if elemSerializer == nil { + var err error + elemSerializer, err = ctx.TypeResolver().getSerializerByType(keyType, false) + if err != nil { + ctx.SetError(FromError(err)) + return + } + } + elemTypeInfo = &TypeInfo{Type: keyType, Serializer: elemSerializer} } else { // Element type is not declared, read from buffer elemTypeInfo = ctx.TypeResolver().ReadTypeInfo(buf, err) @@ -352,31 +365,52 @@ func (s setSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, value ref // Determine if reference tracking is enabled trackRefs := (flag & CollectionTrackingRef) != 0 declaredGenerics := (flag & CollectionIsDeclElementType) != 0 - serializer := typeInfo.Serializer + hasNull := (flag & CollectionHasNull) != 0 + serializer := s.elemSerializer keyType := value.Type().Key() + elemType := keyType + if !declaredGenerics && typeInfo != nil && typeInfo.Serializer != nil { + serializer = typeInfo.Serializer + if typeInfo.Type != nil { + elemType, serializer = wrapMapSerializerIfNeeded(keyType, typeInfo.Type, serializer) + } + } + if keyType.Kind() != reflect.Ptr { + if ptrSer, ok := serializer.(*ptrToValueSerializer); ok { + serializer = ptrSer.valueSerializer + } + } + declaredGenericDispatch := declaredGenerics && serializerNeedsGenericDispatch(serializer) + + elemRefMode := RefModeNone + if trackRefs { + elemRefMode = RefModeTracking + } for i := 0; i < length; i++ { - var refID int32 + elem := reflect.New(elemType).Elem() if trackRefs { - // Handle reference tracking if enabled - refID, _ = ctx.RefResolver().TryPreserveRefId(buf) - if refID < int32(NotNullValueFlag) { - // Use existing reference if available - elem := ctx.RefResolver().GetReadObject(refID) - setMapKey(value, elem, keyType) + serializer.Read(ctx, elemRefMode, false, declaredGenericDispatch, elem) + if ctx.HasError() { + return + } + if isNull(elem) { continue } - } - - // Create new element and deserialize from buffer - elem := reflect.New(typeInfo.Type).Elem() - readSerializerData(ctx, serializer, declaredGenerics, elem) - if ctx.HasError() { - return - } - // Register new reference if tracking - if trackRefs { - ctx.RefResolver().SetReadObject(refID, elem) + } else if hasNull { + refFlag := buf.ReadInt8(ctx.Err()) + if refFlag == NullFlag { + continue + } + readSerializerData(ctx, serializer, declaredGenericDispatch, elem) + if ctx.HasError() { + return + } + } else { + readSerializerData(ctx, serializer, declaredGenericDispatch, elem) + if ctx.HasError() { + return + } } // Add element to set setMapKey(value, elem, keyType) diff --git a/go/fory/slice.go b/go/fory/slice.go index 8edc005e39..7fd787b814 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -31,6 +31,15 @@ const ( CollectionDeclSameType = CollectionIsSameType | CollectionIsDeclElementType ) +func needsElemTypeInfo(typeID TypeId) bool { + switch typeID { + case STRUCT, COMPATIBLE_STRUCT, NAMED_STRUCT, NAMED_COMPATIBLE_STRUCT, EXT, NAMED_EXT: + return true + default: + return false + } +} + // writeSliceRefAndType handles reference and type writing for slice serializers. // Returns true if the value was already written (nil or ref), false if data should be written. func writeSliceRefAndType(ctx *WriteContext, refMode RefMode, writeType bool, value reflect.Value, typeId TypeId) bool { @@ -156,17 +165,20 @@ func (s *sliceSerializer) writeDataWithGenerics(ctx *WriteContext, value reflect return } - // Determine collection flags - // When hasGenerics is true, element type is known from TypeDef, use CollectionDeclSameType - // When hasGenerics is false, write element type info, use CollectionIsSameType - var collectFlag int - if hasGenerics { - collectFlag = CollectionDeclSameType - } else { - collectFlag = CollectionIsSameType + elemType := s.type_.Elem() + elemTypeInfo, _ := ctx.TypeResolver().getTypeInfo(reflect.New(elemType).Elem(), false) + elemDeclared := hasGenerics + if elemDeclared && elemTypeInfo != nil && needsElemTypeInfo(TypeId(elemTypeInfo.TypeID)) { + elemDeclared = false + } + + // Determine collection flags. User/ext elements still write TypeInfo so + // nested compatible struct readers can classify the element schema. + collectFlag := CollectionIsSameType + if elemDeclared { + collectFlag |= CollectionIsDeclElementType } hasNull := false - elemType := s.type_.Elem() isPointerElem := elemType.Kind() == reflect.Ptr // Check for null values first (only applicable for pointer element types) @@ -189,10 +201,8 @@ func (s *sliceSerializer) writeDataWithGenerics(ctx *WriteContext, value reflect } buf.WriteInt8(int8(collectFlag)) - // Write element type info for deserialization only when hasGenerics is false - // When hasGenerics is true, the element type is known from the struct's TypeDef - if !hasGenerics { - elemTypeInfo, _ := ctx.TypeResolver().getTypeInfo(reflect.New(elemType).Elem(), false) + // Write element type info unless the element schema fully declares it. + if !elemDeclared { ctx.TypeResolver().WriteTypeInfo(buf, elemTypeInfo, ctx.Err()) } @@ -305,11 +315,24 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // ReadData collection flags collectFlag := buf.ReadInt8(ctxErr) - // ReadData element type info if present in buffer - // We must consume these bytes for protocol compliance + elemSerializer := s.elemSerializer + + // ReadData element type info if present in buffer. if (collectFlag & CollectionIsSameType) != 0 { if (collectFlag & CollectionIsDeclElementType) == 0 { - ctx.TypeResolver().ReadTypeInfo(buf, ctxErr) + elemTypeInfo := ctx.TypeResolver().ReadTypeInfo(buf, ctxErr) + if elemTypeInfo != nil && elemTypeInfo.Serializer != nil { + elemSerializer = elemTypeInfo.Serializer + elemType := value.Type().Elem() + if elemTypeInfo.Type != nil { + _, elemSerializer = wrapMapSerializerIfNeeded(elemType, elemTypeInfo.Type, elemSerializer) + } + if elemType.Kind() != reflect.Ptr { + if ptrSer, ok := elemSerializer.(*ptrToValueSerializer); ok { + elemSerializer = ptrSer.valueSerializer + } + } + } } } @@ -320,7 +343,7 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // comment during cleanup or refactors. trackRefs := (collectFlag & CollectionTrackingRef) != 0 hasNull := (collectFlag & CollectionHasNull) != 0 - declaredGenericDispatch := (collectFlag&CollectionIsDeclElementType) != 0 && serializerNeedsGenericDispatch(s.elemSerializer) + declaredGenericDispatch := (collectFlag&CollectionIsDeclElementType) != 0 && serializerNeedsGenericDispatch(elemSerializer) // Handle slice vs array allocation if isArrayType { @@ -347,14 +370,14 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { if !trackRefs && !hasNull { if declaredGenericDispatch { for i := 0; i < length; i++ { - s.elemSerializer.Read(ctx, RefModeNone, false, true, value.Index(i)) + elemSerializer.Read(ctx, RefModeNone, false, true, value.Index(i)) if ctx.HasError() { return } } } else { for i := 0; i < length; i++ { - s.elemSerializer.ReadData(ctx, value.Index(i)) + elemSerializer.ReadData(ctx, value.Index(i)) if ctx.HasError() { return } @@ -370,7 +393,7 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { if trackRefs { // When trackRefs is true, elemSerializer will read the ref flag via TryPreserveRefId // For pointer types, elemSerializer will handle allocation and reference tracking - s.elemSerializer.Read(ctx, elemRefMode, false, declaredGenericDispatch, elem) + elemSerializer.Read(ctx, elemRefMode, false, declaredGenericDispatch, elem) } else if hasNull { // When hasNull is set, read a flag byte for each element: // - NullFlag (-3) for null elements @@ -382,16 +405,16 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } // refFlag should be NotNullValueFlag, now read the actual data if declaredGenericDispatch { - s.elemSerializer.Read(ctx, RefModeNone, false, true, elem) + elemSerializer.Read(ctx, RefModeNone, false, true, elem) } else { - s.elemSerializer.ReadData(ctx, elem) + elemSerializer.ReadData(ctx, elem) } } else { // No ref tracking and no nulls: directly read data if declaredGenericDispatch { - s.elemSerializer.Read(ctx, RefModeNone, false, true, elem) + elemSerializer.Read(ctx, RefModeNone, false, true, elem) } else { - s.elemSerializer.ReadData(ctx, elem) + elemSerializer.ReadData(ctx, elem) } } if ctx.HasError() { diff --git a/go/fory/struct.go b/go/fory/struct.go index 0f7bf2334d..890ca21082 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -2314,78 +2314,30 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val buf := ctx.Buffer() ptr := unsafe.Pointer(value.UnsafeAddr()) err := ctx.Err() - readField := func(field *FieldInfo) { - if field.Meta.FieldIndex < 0 { + for i := 0; i < len(s.fields); i++ { + field := &s.fields[i] + switch field.ReadAction { + case remoteFieldReadSkip: s.skipField(ctx, field) - return - } - if field.Meta.CompatibleScalar != nil { - readCompatibleScalarField(ctx, field, unsafe.Add(ptr, field.Offset)) - return - } - - // Fast path for fixed-size primitive types (no ref flag from remote schema) - if isFixedSizePrimitive(field.DispatchId) { - fieldPtr := unsafe.Add(ptr, field.Offset) - optInfo := optionalInfo{} - if field.Kind == FieldKindOptional { - optInfo = field.Meta.OptionalInfo - } - switch field.DispatchId { - case PrimitiveBoolDispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadBool(err)) - case PrimitiveInt8DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt8(err)) - case PrimitiveUint8DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, uint8(buf.ReadInt8(err))) - case PrimitiveInt16DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt16(err)) - case PrimitiveUint16DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint16(err)) - case PrimitiveInt32DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt32(err)) - case PrimitiveUint32DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint32(err)) - case PrimitiveInt64DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadInt64(err)) - case PrimitiveUint64DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint64(err)) - case PrimitiveFloat32DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadFloat32(err)) - case PrimitiveFloat64DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadFloat64(err)) - case PrimitiveFloat16DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint16(err)) - } - return - } - - // Fast path for varint primitive types (no ref flag from remote schema) - if isVarintPrimitive(field.DispatchId) && !fieldHasNonPrimitiveSerializer(field) { - fieldPtr := unsafe.Add(ptr, field.Offset) - optInfo := optionalInfo{} - if field.Kind == FieldKindOptional { - optInfo = field.Meta.OptionalInfo + if ctx.HasError() { + return } - switch field.DispatchId { - case PrimitiveVarint32DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint32(err)) - case PrimitiveVarint64DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint64(err)) - case PrimitiveVarUint32DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarUint32(err)) - case PrimitiveVarUint64DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarUint64(err)) - case PrimitiveTaggedInt64DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedInt64(err)) - case PrimitiveTaggedUint64DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedUint64(err)) - case PrimitiveIntDispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, int(buf.ReadVarint64(err))) - case PrimitiveUintDispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, uint(buf.ReadVarUint64(err))) + continue + case remoteFieldReadExactFixed: + i = readExactFixedPrimitiveRun(ctx, s.fields, i, ptr) - 1 + continue + case remoteFieldReadExactVarint: + i = readExactVarintPrimitiveRun(ctx, s.fields, i, ptr) - 1 + continue + case remoteFieldReadCompatibleScalar: + readCompatibleScalarField(ctx, field, unsafe.Add(ptr, field.Offset)) + if ctx.HasError() { + return } - return + continue + case remoteFieldReadExactRemaining: + s.readRemainingField(ctx, ptr, field, value) + continue } fieldPtr := unsafe.Add(ptr, field.Offset) @@ -2394,13 +2346,12 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val optInfo = field.Meta.OptionalInfo } - // Handle nullable fixed-size primitives (read ref flag + fixed bytes) - // These have Nullable=true but use fixed encoding, not varint - if isNullableFixedSizePrimitive(field.DispatchId) { + switch field.ReadAction { + case remoteFieldReadExactNullableFixed: refFlag := buf.ReadInt8(err) if refFlag == NullFlag { clearFieldValue(field.Kind, fieldPtr, optInfo) - return + continue } // Read fixed-size value based on dispatch ID switch field.DispatchId { @@ -2429,15 +2380,12 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val case NullableFloat16DispatchId: storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadUint16(err)) } - return - } - - // Handle nullable varint primitives (read ref flag + varint) - if isNullableVarintPrimitive(field.DispatchId) { + continue + case remoteFieldReadExactNullableVarint: refFlag := buf.ReadInt8(err) if refFlag == NullFlag { clearFieldValue(field.Kind, fieldPtr, optInfo) - return + continue } // Read varint value based on dispatch ID switch field.DispatchId { @@ -2458,14 +2406,12 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val case NullableUintDispatchId: storeFieldValue(field.Kind, fieldPtr, optInfo, uint(buf.ReadVarUint64(err))) } - return - } - if isEnumField(field) { + continue + case remoteFieldReadExactEnum: readEnumFieldUnsafe(ctx, field, fieldPtr) - return + continue } - // Slow path for non-primitives (all need ref flag per xlang spec) fieldValue := value.Field(field.Meta.FieldIndex) if field.Serializer != nil { // Use pre-computed RefMode and WriteType from field initialization @@ -2473,15 +2419,260 @@ func (s *structSerializer) readFieldsInOrder(ctx *ReadContext, value reflect.Val } else { ctx.ReadValue(fieldValue, RefModeTracking, true) } - } - - for i := range s.fields { - field := &s.fields[i] - readField(field) if ctx.HasError() { return } } + if ctx.HasError() { + return + } +} + +func computeRemoteFieldReadAction(field *FieldInfo) remoteFieldReadAction { + if field.Meta.FieldIndex < 0 { + return remoteFieldReadSkip + } + if field.Meta.CompatibleScalar != nil { + return remoteFieldReadCompatibleScalar + } + if field.Meta.CompatibleScalar == nil && + field.Meta.FieldIndex >= 0 && + isFixedSizePrimitive(field.DispatchId) { + return remoteFieldReadExactFixed + } + if field.Meta.CompatibleScalar == nil && + field.Meta.FieldIndex >= 0 && + isVarintPrimitive(field.DispatchId) && + !fieldHasNonPrimitiveSerializer(field) { + return remoteFieldReadExactVarint + } + if field.Meta.ExactSchema && !isPrimitiveFieldGroupType(field.Meta.TypeId) { + return remoteFieldReadExactRemaining + } + if isNullableFixedSizePrimitive(field.DispatchId) { + return remoteFieldReadExactNullableFixed + } + if isNullableVarintPrimitive(field.DispatchId) { + return remoteFieldReadExactNullableVarint + } + if isEnumField(field) { + return remoteFieldReadExactEnum + } + return remoteFieldReadSerializer +} + +func canReadFixedRun(field *FieldInfo) bool { + return field.ReadAction == remoteFieldReadExactFixed +} + +func readExactFixedPrimitiveRun(ctx *ReadContext, fields []FieldInfo, start int, ptr unsafe.Pointer) int { + end := start + size := 0 + for end < len(fields) { + field := &fields[end] + if !canReadFixedRun(field) { + break + } + size += getFixedSizeByDispatchId(field.DispatchId) + end++ + } + buf := ctx.Buffer() + var errOut Error + if !buf.CheckReadable(size, &errOut) { + ctx.SetError(errOut) + return end + } + baseOffset := buf.ReaderIndex() + data := buf.GetData() + bufOffset := baseOffset + for i := start; i < end; i++ { + field := &fields[i] + fieldPtr := unsafe.Add(ptr, field.Offset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional { + optInfo = field.Meta.OptionalInfo + } + switch field.DispatchId { + case PrimitiveBoolDispatchId: + storeFieldValue(field.Kind, fieldPtr, optInfo, data[bufOffset] != 0) + bufOffset++ + case PrimitiveInt8DispatchId: + storeFieldValue(field.Kind, fieldPtr, optInfo, int8(data[bufOffset])) + bufOffset++ + case PrimitiveUint8DispatchId: + storeFieldValue(field.Kind, fieldPtr, optInfo, data[bufOffset]) + bufOffset++ + case PrimitiveInt16DispatchId: + var v int16 + if isLittleEndian { + v = *(*int16)(unsafe.Pointer(&data[bufOffset])) + } else { + v = int16(binary.LittleEndian.Uint16(data[bufOffset:])) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 2 + case PrimitiveUint16DispatchId: + var v uint16 + if isLittleEndian { + v = *(*uint16)(unsafe.Pointer(&data[bufOffset])) + } else { + v = binary.LittleEndian.Uint16(data[bufOffset:]) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 2 + case PrimitiveInt32DispatchId: + var v int32 + if isLittleEndian { + v = *(*int32)(unsafe.Pointer(&data[bufOffset])) + } else { + v = int32(binary.LittleEndian.Uint32(data[bufOffset:])) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 4 + case PrimitiveUint32DispatchId: + var v uint32 + if isLittleEndian { + v = *(*uint32)(unsafe.Pointer(&data[bufOffset])) + } else { + v = binary.LittleEndian.Uint32(data[bufOffset:]) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 4 + case PrimitiveInt64DispatchId: + var v int64 + if isLittleEndian { + v = *(*int64)(unsafe.Pointer(&data[bufOffset])) + } else { + v = int64(binary.LittleEndian.Uint64(data[bufOffset:])) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 8 + case PrimitiveUint64DispatchId: + var v uint64 + if isLittleEndian { + v = *(*uint64)(unsafe.Pointer(&data[bufOffset])) + } else { + v = binary.LittleEndian.Uint64(data[bufOffset:]) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 8 + case PrimitiveFloat32DispatchId: + var v float32 + if isLittleEndian { + v = *(*float32)(unsafe.Pointer(&data[bufOffset])) + } else { + v = math.Float32frombits(binary.LittleEndian.Uint32(data[bufOffset:])) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 4 + case PrimitiveFloat64DispatchId: + var v float64 + if isLittleEndian { + v = *(*float64)(unsafe.Pointer(&data[bufOffset])) + } else { + v = math.Float64frombits(binary.LittleEndian.Uint64(data[bufOffset:])) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 8 + case PrimitiveFloat16DispatchId: + var v uint16 + if isLittleEndian { + v = *(*uint16)(unsafe.Pointer(&data[bufOffset])) + } else { + v = binary.LittleEndian.Uint16(data[bufOffset:]) + } + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + bufOffset += 2 + } + } + buf.SetReaderIndex(baseOffset + size) + return end +} + +func canReadVarintRun(field *FieldInfo) bool { + return field.ReadAction == remoteFieldReadExactVarint +} + +func readExactVarintPrimitiveRun(ctx *ReadContext, fields []FieldInfo, start int, ptr unsafe.Pointer) int { + end := start + maxSize := 0 + plainVarint32Values := true + for end < len(fields) { + field := &fields[end] + if !canReadVarintRun(field) { + break + } + maxSize += getVarintMaxSizeByDispatchId(field.DispatchId) + if field.DispatchId != PrimitiveVarint32DispatchId || field.Kind != FieldKindValue { + plainVarint32Values = false + } + end++ + } + buf := ctx.Buffer() + err := ctx.Err() + useUnsafe := buf.remaining() >= maxSize+8 + if plainVarint32Values { + for i := start; i < end; i++ { + field := &fields[i] + if useUnsafe { + *(*int32)(unsafe.Add(ptr, field.Offset)) = buf.UnsafeReadVarint32(err) + } else { + *(*int32)(unsafe.Add(ptr, field.Offset)) = buf.ReadVarint32(err) + } + } + return end + } + for i := start; i < end; i++ { + field := &fields[i] + fieldPtr := unsafe.Add(ptr, field.Offset) + optInfo := optionalInfo{} + if field.Kind == FieldKindOptional { + optInfo = field.Meta.OptionalInfo + } + switch field.DispatchId { + case PrimitiveVarint32DispatchId: + if useUnsafe { + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarint32(err)) + } else { + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint32(err)) + } + case PrimitiveVarint64DispatchId: + if useUnsafe { + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarint64()) + } else { + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarint64(err)) + } + case PrimitiveVarUint32DispatchId: + if useUnsafe { + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarUint32(err)) + } else { + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarUint32(err)) + } + case PrimitiveVarUint64DispatchId: + if useUnsafe { + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarUint64()) + } else { + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadVarUint64(err)) + } + case PrimitiveTaggedInt64DispatchId: + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedInt64(err)) + case PrimitiveTaggedUint64DispatchId: + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.ReadTaggedUint64(err)) + case PrimitiveIntDispatchId: + if useUnsafe { + storeFieldValue(field.Kind, fieldPtr, optInfo, int(buf.UnsafeReadVarint64())) + } else { + storeFieldValue(field.Kind, fieldPtr, optInfo, int(buf.ReadVarint64(err))) + } + case PrimitiveUintDispatchId: + if useUnsafe { + storeFieldValue(field.Kind, fieldPtr, optInfo, uint(buf.UnsafeReadVarUint64())) + } else { + storeFieldValue(field.Kind, fieldPtr, optInfo, uint(buf.ReadVarUint64(err))) + } + } + } + return end } // skipField skips a field that doesn't exist or is incompatible diff --git a/go/fory/struct_init.go b/go/fory/struct_init.go index 0cb2f8398b..716ac2698f 100644 --- a/go/fory/struct_init.go +++ b/go/fory/struct_init.go @@ -313,6 +313,7 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err fieldInfo := FieldInfo{ Offset: 0, DispatchId: dispatchId, + ReadAction: remoteFieldReadSkip, RefMode: refMode, Kind: FieldKindValue, Serializer: fieldSerializer, @@ -406,6 +407,7 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err var localFieldSpec *FieldSpec var exists bool var scalarConversion *compatibleScalarConversion + exactSchema := false if def.tagID >= 0 { if binding, ok := fieldTagIDToBinding[def.tagID]; ok { @@ -445,6 +447,14 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err scalarPair := false scalarExactSchema := false if localFieldSpec != nil { + exactSchema = fieldSpecEqualForDiff( + def.typeSpec, + def.nullable, + def.trackRef, + localFieldSpec.Type, + localNullableByIndex[fieldIndex], + localTrackRefByIndex[fieldIndex], + ) remoteScalar := compatibleScalarType(def.typeSpec.TypeId()) localScalar := compatibleScalarType(localFieldSpec.Type.TypeId()) if remoteScalar && localScalar { @@ -553,7 +563,15 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err ) { shouldRead = true fieldType = localType - } else if defTypeId == LIST && localFieldSpec != nil && canReadCompatibleListAsLocalArray(def.typeSpec, localFieldSpec.Type, localType) { + } else if defTypeId == LIST && localFieldSpec != nil && canReadCompatibleListAsLocalArray( + def.typeSpec, + def.nullable, + def.trackRef, + localFieldSpec.Type, + localNullableByIndex[fieldIndex], + localTrackRefByIndex[fieldIndex], + localType, + ) { shouldRead = true usesCompatibleCollectionArrayReader = true fieldType = localType @@ -564,7 +582,8 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err listReader: listReader.(primitiveListSerializer), } } - } else if defTypeId == LIST && localType.Kind() == reflect.Array { + } else if defTypeId == LIST && localFieldSpec != nil && + isPrimitiveArrayType(localFieldSpec.Type.TypeID) { shouldRead = false } else if !refTrackedScalarSchemaMismatch && !typeLookupFailed && typesCompatible(localType, remoteType) && (!scalarPair || scalarExactSchema) { shouldRead = true @@ -640,9 +659,11 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err } } } else { - fieldType = remoteType - fieldIndex = -1 - offset = 0 + return fmt.Errorf( + "compatible field %s cannot be read as local field %s", + def.name, + localFieldName, + ) } } else { fieldType = remoteType @@ -762,8 +783,10 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err Spec: metaSpec, TypeSpec: def.typeSpec, CompatibleScalar: scalarConversion, + ExactSchema: exactSchema, }, } + fieldInfo.ReadAction = computeRemoteFieldReadAction(&fieldInfo) fields = append(fields, fieldInfo) } @@ -840,7 +863,10 @@ func sameListSchemaCanReadLocalArray(remoteSpec *TypeSpec, remoteNullable bool, return fieldSpecEqualForDiff(remoteSpec, remoteNullable, remoteTrackRef, localSpec, localNullable, localTrackRef) } -func canReadCompatibleListAsLocalArray(remoteSpec *TypeSpec, localSpec *TypeSpec, localType reflect.Type) bool { +func canReadCompatibleListAsLocalArray(remoteSpec *TypeSpec, remoteNullable bool, remoteTrackRef bool, localSpec *TypeSpec, localNullable bool, localTrackRef bool, localType reflect.Type) bool { + if remoteNullable || remoteTrackRef || localNullable || localTrackRef { + return false + } if remoteSpec == nil || localSpec == nil || localType == nil { return false } @@ -850,6 +876,12 @@ func canReadCompatibleListAsLocalArray(remoteSpec *TypeSpec, localSpec *TypeSpec if remoteSpec.TypeID != LIST || remoteSpec.Element == nil { return false } + // Nullable element schema is allowed for list -> array; actual + // null payload elements fail in the dense-array reader. Ref-tracked + // element framing is rejected here because this path stays primitive-only. + if remoteSpec.Element.TrackRef { + return false + } if !isPrimitiveArrayType(localSpec.TypeID) { return false } diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java index 33015c2959..d0ccf91f16 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java @@ -376,13 +376,6 @@ private SourceField buildField( boolean serialized, SerializerMode mode) { Set modifiers = field.getModifiers(); - if (!record && modifiers.contains(Modifier.FINAL)) { - throw new InvalidStructException( - "Static serializers cannot assign final field " - + field.getSimpleName() - + "; use a record component or mark the field @Ignore/transient", - field); - } ForyFieldMeta foryField = foryField(field); Object fieldTypeTree = typeTree(field); boolean nullable = fieldNullable(field.asType(), fieldTypeTree, mode); @@ -402,7 +395,8 @@ private SourceField buildField( writeKind = SourceField.AccessKind.METHOD; readAccess = field.getSimpleName().toString(); writeAccess = null; - } else if (isAccessibleFromGenerated(field, generatedPackage)) { + } else if (isAccessibleFromGenerated(field, generatedPackage) + && !modifiers.contains(Modifier.FINAL)) { readKind = SourceField.AccessKind.FIELD; writeKind = SourceField.AccessKind.FIELD; readAccess = field.getSimpleName().toString(); @@ -410,18 +404,27 @@ private SourceField buildField( } else { ExecutableElement getter = findGetter(owner, field, generatedPackage); ExecutableElement setter = findSetter(owner, field, generatedPackage); - if (getter == null || setter == null) { - throw new InvalidStructException( - "Field " - + field.getSimpleName() - + " is not directly accessible from the generated serializer. Add accessible " - + "non-private getter/setter methods or mark it @Ignore/transient.", - field); + if (getter != null) { + readKind = SourceField.AccessKind.METHOD; + readAccess = getter.getSimpleName().toString(); + } else if (isAccessibleFromGenerated(field, generatedPackage)) { + readKind = SourceField.AccessKind.FIELD; + readAccess = field.getSimpleName().toString(); + } else { + readKind = SourceField.AccessKind.ACCESSOR; + readAccess = null; + } + if (!modifiers.contains(Modifier.FINAL) && setter != null) { + writeKind = SourceField.AccessKind.METHOD; + writeAccess = setter.getSimpleName().toString(); + } else if (!modifiers.contains(Modifier.FINAL) + && isAccessibleFromGenerated(field, generatedPackage)) { + writeKind = SourceField.AccessKind.FIELD; + writeAccess = field.getSimpleName().toString(); + } else { + writeKind = SourceField.AccessKind.ACCESSOR; + writeAccess = null; } - readKind = SourceField.AccessKind.METHOD; - writeKind = SourceField.AccessKind.METHOD; - readAccess = getter.getSimpleName().toString(); - writeAccess = setter.getSimpleName().toString(); } return new SourceField( id, diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceField.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceField.java index 3d33d37422..a8581a7f69 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceField.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceField.java @@ -21,6 +21,7 @@ final class SourceField { enum AccessKind { + ACCESSOR, FIELD, METHOD } @@ -84,6 +85,9 @@ enum AccessKind { } String readExpression(String target) { + if (readAccessKind == AccessKind.ACCESSOR) { + return fieldAccessorName() + ".get(" + target + ")"; + } if (readAccessKind == AccessKind.METHOD) { return target + "." + readAccess + "()"; } @@ -91,12 +95,23 @@ String readExpression(String target) { } String writeStatement(String target, String valueExpression) { + if (writeAccessKind == AccessKind.ACCESSOR) { + return fieldAccessorName() + ".set(" + target + ", " + valueExpression + ");"; + } if (writeAccessKind == AccessKind.METHOD) { return target + "." + writeAccess + "(" + valueExpression + ");"; } return target + "." + writeAccess + " = " + valueExpression + ";"; } + boolean usesFieldAccessor() { + return readAccessKind == AccessKind.ACCESSOR || writeAccessKind == AccessKind.ACCESSOR; + } + + String fieldAccessorName() { + return "fieldAccessor" + id; + } + String defaultValue() { switch (erasedType) { case "boolean": diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java index 46a772a5bc..4d29797710 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java @@ -60,10 +60,15 @@ private void writeHeader() { builder.append("import org.apache.fory.memory.MemoryBuffer;\n"); builder.append("import org.apache.fory.meta.TypeDef;\n"); builder.append("import org.apache.fory.meta.TypeExtMeta;\n"); + builder.append("import org.apache.fory.reflect.FieldAccessor;\n"); + if (!struct.record) { + builder.append("import org.apache.fory.reflect.ObjectInstantiator;\n"); + } builder.append("import org.apache.fory.resolver.TypeResolver;\n"); builder.append("import org.apache.fory.serializer.FieldGroups;\n"); builder.append("import org.apache.fory.serializer.FieldGroups.SerializationFieldInfo;\n"); builder.append("import org.apache.fory.serializer.StaticGeneratedStructSerializer;\n"); + builder.append("import org.apache.fory.serializer.converter.FieldConverters;\n"); builder.append("import org.apache.fory.type.Descriptor;\n"); builder.append("import org.apache.fory.type.DispatchId;\n"); builder.append("import org.apache.fory.type.Types;\n\n"); @@ -88,6 +93,17 @@ private void writeClassStart() { builder.append(" private final SerializationFieldInfo[] allFields;\n"); builder.append(" private final int[] allFieldIds;\n"); builder.append(" private final SerializationFieldInfo[] fieldsById;\n"); + if (!struct.record) { + builder.append(" private final ObjectInstantiator generatedObjectInstantiator;\n"); + } + for (SourceField field : struct.fields) { + if (field.usesFieldAccessor()) { + builder + .append(" private final FieldAccessor ") + .append(field.fieldAccessorName()) + .append(";\n"); + } + } builder.append(" private final int classVersionHash;\n"); builder.append(" private final boolean sameSchemaCompatible;\n\n"); } @@ -140,6 +156,14 @@ private void writeConstructors() { builder.append(" this.allFields = null;\n"); builder.append(" this.allFieldIds = null;\n"); builder.append(" this.fieldsById = null;\n"); + if (!struct.record) { + builder.append(" this.generatedObjectInstantiator = null;\n"); + } + for (SourceField field : struct.fields) { + if (field.usesFieldAccessor()) { + builder.append(" this.").append(field.fieldAccessorName()).append(" = null;\n"); + } + } builder.append(" this.classVersionHash = 0;\n"); builder.append(" this.sameSchemaCompatible = false;\n"); builder.append(" }\n\n"); @@ -166,11 +190,27 @@ private void writeConstructorBody(String fieldGroupsExpression, String sameSchem builder.append(" this.allFields = fieldGroups.allFields;\n"); builder.append(" this.allFieldIds = localFieldIds(allFields, DESCRIPTORS);\n"); builder.append(" this.fieldsById = new SerializationFieldInfo[DESCRIPTORS.size()];\n"); + if (!struct.record) { + builder.append( + " this.generatedObjectInstantiator = typeResolver.getObjectInstantiator(type);\n"); + } builder.append(" SerializationFieldInfo[] allFields = fieldGroups.allFields;\n"); builder.append(" int[] allFieldIds = localFieldIds(allFields, DESCRIPTORS);\n"); builder.append(" for (int i = 0; i < allFields.length; i++) {\n"); builder.append(" this.fieldsById[allFieldIds[i]] = allFields[i];\n"); builder.append(" }\n"); + for (SourceField field : struct.fields) { + if (field.usesFieldAccessor()) { + builder + .append(" this.") + .append(field.fieldAccessorName()) + .append(" = FieldAccessor.createAccessor(declaredField(type, \"") + .append(escape(field.declaringClass)) + .append("\", \"") + .append(escape(field.name)) + .append("\"));\n"); + } + } builder.append( " this.classVersionHash = typeResolver.checkClassVersion() ? computeClassVersionHash(DESCRIPTORS) : 0;\n"); builder.append(" this.sameSchemaCompatible = ").append(sameSchemaExpression).append(";\n"); @@ -221,21 +261,16 @@ private void writeSchemaConsistentRead() { .append(field.defaultValue()) .append(";\n"); } - builder.append(" Object[] values = new Object[DESCRIPTORS.size()];\n"); - builder.append(" readRecordFields(readContext, values);\n"); - for (SourceField field : struct.fields) { - builder - .append(" field") - .append(field.id) - .append(" = ") - .append(field.castExpression("values[" + field.id + "]")) - .append(";\n"); - } - builder.append(" return new ").append(struct.typeName).append("("); - appendRecordConstructorArguments("field"); - builder.append(");\n"); + appendSchemaConsistentRecordLoop(); + appendRecordConstruction("record", "field", 4); + builder.append(" return record;\n"); } else { - builder.append(" ").append(struct.typeName).append(" value = newBean();\n"); + builder + .append(" ") + .append(struct.typeName) + .append(" value = ") + .append(newGeneratedBeanExpression()) + .append(";\n"); builder.append(" readContext.reference(value);\n"); builder.append(" readFields(readContext, value);\n"); builder.append(" return value;\n"); @@ -247,6 +282,37 @@ private void writeWriteGroups() { writeWriteGroup("", "allFields", "allFieldIds", "writeFieldValue"); } + private void appendSchemaConsistentRecordLoop() { + builder.append(" for (int i = 0; i < allFields.length; i++) {\n"); + builder.append(" SerializationFieldInfo fieldInfo = allFields[i];\n"); + builder.append(" switch (allFieldIds[i]) {\n"); + for (SourceField field : struct.fields) { + builder.append(" case ").append(field.id).append(":\n"); + appendDebugRead("before", "fieldInfo", 10); + if (canEmitDirectReadField(field)) { + appendDirectReadLocal(field, "field" + field.id, "fieldInfo", 10); + } else { + builder + .append(" Object fieldValue") + .append(field.id) + .append(" = readFieldValue(readContext, fieldInfo);\n"); + builder + .append(" field") + .append(field.id) + .append(" = ") + .append(field.castExpression("fieldValue" + field.id)) + .append(";\n"); + } + appendDebugRead("after", "fieldInfo", 10); + builder.append(" break;\n"); + } + builder.append(" default:\n"); + builder.append( + " throw new IllegalStateException(\"Unknown generated field id \" + allFieldIds[i]);\n"); + builder.append(" }\n"); + builder.append(" }\n"); + } + private void writeWriteGroup( String groupName, String fieldsName, String idsName, String helperName) { builder @@ -288,9 +354,7 @@ private void writeWriteGroup( } private void writeReadGroups() { - if (struct.record) { - writeReadRecordGroup("", "allFields", "allFieldIds", "readFieldValue"); - } else { + if (!struct.record) { writeReadBeanGroup("", "allFields", "allFieldIds", "readFieldValue"); } } @@ -369,6 +433,9 @@ private boolean hasDirectReadField() { } private boolean canEmitDirectWriteField(SourceField field) { + if (field.readAccessKind == SourceField.AccessKind.ACCESSOR) { + return false; + } if (canEmitDirectStringField(field)) { return true; } @@ -443,52 +510,77 @@ private void appendDirectWrite(SourceField field) { } private void appendDirectRead(SourceField field) { + appendDirectReadTarget(field, "value", "fieldInfo", 10, false); + } + + private void appendDirectReadLocal( + SourceField field, String localName, String fieldInfoName, int indent) { + appendDirectReadTarget(field, localName, fieldInfoName, indent, true); + } + + private void appendDirectReadTarget( + SourceField field, String target, String fieldInfoName, int indent, boolean localTarget) { if (canEmitDirectStringField(field)) { - builder - .append(" ") - .append(field.writeStatement("value", "readContext.readString()")) - .append("\n"); + appendIndented(indent, assignment(field, target, "readContext.readString()", localTarget)); return; } if (canEmitDirectArrayField(field)) { - builder.append(" readContext.preserveRefId(-1);\n"); - builder - .append(" ") - .append( - field.writeStatement( - "value", "(" + field.erasedType + ") readContext.readNonRef(fieldInfo.typeInfo)")) - .append("\n"); + appendIndented(indent, "readContext.preserveRefId(-1);"); + appendIndented( + indent, + assignment( + field, + target, + "(" + field.erasedType + ") readContext.readNonRef(" + fieldInfoName + ".typeInfo)", + localTarget)); return; } String exactRead = exactPrimitiveReadExpression(field); if (exactRead == null) { - appendPrimitiveReadSwitch(field); + appendPrimitiveReadSwitch(field, target, fieldInfoName, indent, localTarget); return; } - builder.append(" ").append(field.writeStatement("value", exactRead)).append("\n"); + appendIndented(indent, assignment(field, target, exactRead, localTarget)); } private void appendPrimitiveReadSwitch(SourceField field) { - builder.append(" switch (fieldInfo.dispatchId) {\n"); + appendPrimitiveReadSwitch(field, "value", "fieldInfo", 10, false); + } + + private void appendPrimitiveReadSwitch( + SourceField field, String target, String fieldInfoName, int indent, boolean localTarget) { + appendIndented(indent, "switch (" + fieldInfoName + ".dispatchId) {"); String[][] cases = primitiveReadCases(field); for (String[] readCase : cases) { - builder.append(" case DispatchId.").append(readCase[0]).append(":\n"); - builder - .append(" ") - .append(field.writeStatement("value", readCase[1])) - .append("\n"); - builder.append(" break;\n"); + appendIndented(indent + 2, "case DispatchId." + readCase[0] + ":"); + appendIndented(indent + 4, assignment(field, target, readCase[1], localTarget)); + appendIndented(indent + 4, "break;"); } - builder.append(" default:\n"); - builder - .append(" Object fieldValue") - .append(field.id) - .append(" = readBuildInFieldValue(readContext, fieldInfo);\n"); - builder - .append(" ") - .append(field.writeStatement("value", field.castExpression("fieldValue" + field.id))) - .append("\n"); - builder.append(" }\n"); + appendIndented(indent + 2, "default:"); + appendIndented( + indent + 4, + "Object fieldValue" + + field.id + + " = readBuildInFieldValue(readContext, " + + fieldInfoName + + ");"); + appendIndented( + indent + 4, + assignment(field, target, field.castExpression("fieldValue" + field.id), localTarget)); + appendIndented(indent, "}"); + } + + private String assignment( + SourceField field, String target, String valueExpression, boolean localTarget) { + if (localTarget) { + return target + " = " + valueExpression + ";"; + } + return field.writeStatement(target, valueExpression); + } + + private void appendIndented(int spaces, String code) { + appendIndent(spaces); + builder.append(code).append("\n"); } private String[][] primitiveWriteCases(SourceField field) { @@ -702,26 +794,6 @@ private String exactPrimitiveTypeId(SourceField field) { return meta.substring(start + prefix.length(), end); } - private void writeReadRecordGroup( - String groupName, String fieldsName, String idsName, String helperName) { - builder - .append(" private void read") - .append(groupName) - .append("RecordFields(ReadContext readContext, Object[] values) {\n"); - builder.append(" for (int i = 0; i < ").append(fieldsName).append(".length; i++) {\n"); - builder.append(" SerializationFieldInfo fieldInfo = ").append(fieldsName).append("[i];\n"); - appendDebugRead("before", "fieldInfo", 6); - builder - .append(" values[") - .append(idsName) - .append("[i]] = ") - .append(helperName) - .append("(readContext, fieldInfo);\n"); - appendDebugRead("after", "fieldInfo", 6); - builder.append(" }\n"); - builder.append(" }\n\n"); - } - private void writeCompatibleRead() { builder.append(" @Override\n"); builder @@ -732,19 +804,6 @@ private void writeCompatibleRead() { builder.append(" return readSchemaConsistent(readContext);\n"); builder.append(" }\n"); if (struct.record) { - builder.append(" Object[] values = new Object[DESCRIPTORS.size()];\n"); - for (SourceField field : struct.fields) { - builder - .append(" values[") - .append(field.id) - .append("] = ") - .append(field.defaultValue()) - .append(";\n"); - } - builder.append(" for (int i = 0; i < remoteFields.size(); i++) {\n"); - builder.append(" RemoteFieldInfo remoteField = remoteFields.get(i);\n"); - builder.append(" readCompatibleRecordField(readContext, values, remoteField);\n"); - builder.append(" }\n"); for (SourceField field : struct.fields) { builder .append(" ") @@ -752,14 +811,28 @@ private void writeCompatibleRead() { .append(" field") .append(field.id) .append(" = ") - .append(field.castExpression("values[" + field.id + "]")) + .append(field.defaultValue()) .append(";\n"); } - builder.append(" return new ").append(struct.typeName).append("("); - appendRecordConstructorArguments("field"); - builder.append(");\n"); + builder.append(" for (int i = 0; i < remoteFields.size(); i++) {\n"); + builder.append(" RemoteFieldInfo remoteField = remoteFields.get(i);\n"); + builder.append(" if (remoteField.matchedId == -1) {\n"); + appendDebugRemoteRead("before skip", "remoteField", 8); + builder.append(" skipField(readContext, remoteField);\n"); + appendDebugRemoteRead("after skip", "remoteField", 8); + builder.append(" continue;\n"); + builder.append(" }\n"); + appendCompatibleRecordSwitch(); + builder.append(" }\n"); + appendRecordConstruction("record", "field", 4); + builder.append(" return record;\n"); } else { - builder.append(" ").append(struct.typeName).append(" value = newBean();\n"); + builder + .append(" ") + .append(struct.typeName) + .append(" value = ") + .append(newGeneratedBeanExpression()) + .append(";\n"); builder.append(" readContext.reference(value);\n"); builder.append(" for (int i = 0; i < remoteFields.size(); i++) {\n"); builder.append(" RemoteFieldInfo remoteField = remoteFields.get(i);\n"); @@ -771,27 +844,105 @@ private void writeCompatibleRead() { writeCompatibleDispatchMethods(); } + private void appendCompatibleRecordSwitch() { + builder.append(" switch (remoteField.matchedId) {\n"); + for (int matchedId = 0; matchedId < struct.fields.size() * 2; matchedId++) { + SourceField field = struct.fields.get(matchedId >> 1); + builder.append(" case ").append(matchedId).append(": {\n"); + if ((matchedId & 1) == 0) { + String fieldInfoName = "fieldInfo" + matchedId; + builder + .append(" SerializationFieldInfo ") + .append(fieldInfoName) + .append(" = fieldsById[") + .append(field.id) + .append("];\n"); + appendDebugRead("before", fieldInfoName, 10); + if (canEmitDirectReadField(field)) { + builder.append(" MemoryBuffer buffer = readContext.getBuffer();\n"); + appendDirectReadLocal(field, "field" + field.id, fieldInfoName, 10); + } else { + String fieldValueName = "fieldValue" + matchedId; + builder + .append(" Object ") + .append(fieldValueName) + .append(" = readFieldValue(readContext, ") + .append(fieldInfoName) + .append(");\n"); + builder + .append(" field") + .append(field.id) + .append(" = ") + .append(field.castExpression(fieldValueName)) + .append(";\n"); + } + appendDebugRead("after", fieldInfoName, 10); + builder.append(" break;\n"); + builder.append(" }\n"); + continue; + } + appendDebugRemoteRead("before read", "remoteField", 10); + String scalarRead = + compatibleScalarReadExpression( + field, "remoteField.serializationFieldInfo", "fieldsById[" + field.id + "]"); + if (scalarRead != null) { + builder + .append(" field") + .append(field.id) + .append(" = ") + .append(scalarRead) + .append(";\n"); + appendDebugRemoteRead("after read", "remoteField", 10); + } else { + String fieldValueName = "fieldValue" + matchedId; + builder + .append(" Object ") + .append(fieldValueName) + .append(" = readCompatibleFieldValue(readContext, remoteField, fieldsById[") + .append(field.id) + .append("]);\n"); + appendDebugRemoteRead("after read", "remoteField", 10); + builder + .append(" field") + .append(field.id) + .append(" = ") + .append(field.castExpression(fieldValueName)) + .append(";\n"); + } + builder.append(" break;\n"); + builder.append(" }\n"); + } + builder.append(" default:\n"); + builder.append( + " throw new IllegalStateException(\"Invalid compatible matched id \" + remoteField.matchedId);\n"); + builder.append(" }\n"); + } + private void writeCompatibleDispatchMethods() { - int groupCount = (struct.fields.size() + DISPATCH_GROUP_SIZE - 1) / DISPATCH_GROUP_SIZE; + int caseCount = struct.fields.size() * 2; + int groupCount = (caseCount + DISPATCH_GROUP_SIZE - 1) / DISPATCH_GROUP_SIZE; if (struct.record) { - writeCompatibleDispatchRouter("readCompatibleRecordField", true, groupCount); - for (int group = 0; group < groupCount; group++) { - writeCompatibleRecordDispatchGroup(group); - } + return; } else { - writeCompatibleDispatchRouter("readCompatibleField", false, groupCount); + writeCompatibleDispatchRouter("readCompatibleField", groupCount); for (int group = 0; group < groupCount; group++) { writeCompatibleBeanDispatchGroup(group); } } } - private void writeCompatibleDispatchRouter(String methodName, boolean record, int groupCount) { + private void writeCompatibleDispatchRouter(String methodName, int groupCount) { builder.append(" private void ").append(methodName).append("("); - appendCompatibleDispatchParameters(record); + appendCompatibleDispatchParameters(); builder.append(") {\n"); + builder.append(" if (remoteField.matchedId == -1) {\n"); + appendDebugRemoteRead("before skip", "remoteField", 6); + builder.append(" skipField(readContext, remoteField);\n"); + appendDebugRemoteRead("after skip", "remoteField", 6); + builder.append(" return;\n"); + builder.append(" }\n"); for (int group = 0; group < groupCount; group++) { - int upperBound = Math.min(struct.fields.size(), (group + 1) * DISPATCH_GROUP_SIZE); + int upperBound = Math.min(struct.fields.size() * 2, (group + 1) * DISPATCH_GROUP_SIZE); if (group == 0) { builder .append(" if (remoteField.matchedId >= 0 && remoteField.matchedId < ") @@ -801,114 +952,181 @@ private void writeCompatibleDispatchRouter(String methodName, boolean record, in builder.append(" if (remoteField.matchedId < ").append(upperBound).append(") {\n"); } builder.append(" ").append(methodName).append(group).append("("); - appendCompatibleDispatchArguments(record); + appendCompatibleDispatchArguments(); builder.append(");\n"); builder.append(" return;\n"); builder.append(" }\n"); } - appendDebugRemoteRead("before skip", "remoteField", 4); - builder.append(" skipField(readContext, remoteField);\n"); - appendDebugRemoteRead("after skip", "remoteField", 4); + builder.append( + " throw new IllegalStateException(\"Invalid compatible matched id \" + remoteField.matchedId);\n"); builder.append(" }\n\n"); } private void writeCompatibleBeanDispatchGroup(int group) { int start = group * DISPATCH_GROUP_SIZE; - int end = Math.min(struct.fields.size(), start + DISPATCH_GROUP_SIZE); + int end = Math.min(struct.fields.size() * 2, start + DISPATCH_GROUP_SIZE); builder.append(" private void readCompatibleField").append(group).append("("); - appendCompatibleDispatchParameters(false); + appendCompatibleDispatchParameters(); builder.append(") {\n"); builder.append(" switch (remoteField.matchedId) {\n"); - for (int i = start; i < end; i++) { - SourceField field = struct.fields.get(i); - builder.append(" case ").append(field.id).append(":\n"); - appendDebugRemoteRead("before read", "remoteField", 8); - builder - .append(" if (canReadGeneratedField(remoteField, fieldsById[") - .append(field.id) - .append("])) {\n"); - builder - .append( - " Object fieldValue = readCompatibleFieldValue(readContext, remoteField, fieldsById[") - .append(field.id) - .append("]);\n"); - appendDebugRemoteRead("after read", "remoteField", 10); - builder - .append(" ") - .append(field.writeStatement("value", field.castExpression("fieldValue"))) - .append("\n"); - builder.append(" } else {\n"); - appendDebugRemoteRead("before skip", "remoteField", 10); - builder.append(" skipField(readContext, remoteField);\n"); - appendDebugRemoteRead("after skip", "remoteField", 10); - builder.append(" }\n"); - builder.append(" return;\n"); - } - builder.append(" default:\n"); - appendDebugRemoteRead("before skip", "remoteField", 8); - builder.append(" skipField(readContext, remoteField);\n"); - appendDebugRemoteRead("after skip", "remoteField", 8); - builder.append(" }\n"); - builder.append(" }\n\n"); - } - - private void writeCompatibleRecordDispatchGroup(int group) { - int start = group * DISPATCH_GROUP_SIZE; - int end = Math.min(struct.fields.size(), start + DISPATCH_GROUP_SIZE); - builder.append(" private void readCompatibleRecordField").append(group).append("("); - appendCompatibleDispatchParameters(true); - builder.append(") {\n"); - builder.append(" switch (remoteField.matchedId) {\n"); - for (int i = start; i < end; i++) { - SourceField field = struct.fields.get(i); - builder.append(" case ").append(field.id).append(":\n"); + for (int matchedId = start; matchedId < end; matchedId++) { + SourceField field = struct.fields.get(matchedId >> 1); + builder.append(" case ").append(matchedId).append(": {\n"); + if ((matchedId & 1) == 0) { + String fieldInfoName = "fieldInfo"; + builder + .append(" SerializationFieldInfo ") + .append(fieldInfoName) + .append(" = fieldsById[") + .append(field.id) + .append("];\n"); + appendDebugRead("before", fieldInfoName, 8); + if (canEmitDirectReadField(field)) { + builder.append(" MemoryBuffer buffer = readContext.getBuffer();\n"); + appendDirectRead(field); + } else { + String fieldValueName = "fieldValue" + matchedId; + builder + .append(" Object ") + .append(fieldValueName) + .append(" = readFieldValue(readContext, ") + .append(fieldInfoName) + .append(");\n"); + builder + .append(" ") + .append(field.writeStatement("value", field.castExpression(fieldValueName))) + .append("\n"); + } + appendDebugRead("after", fieldInfoName, 8); + builder.append(" return;\n"); + builder.append(" }\n"); + continue; + } appendDebugRemoteRead("before read", "remoteField", 8); - builder - .append(" if (canReadGeneratedField(remoteField, fieldsById[") - .append(field.id) - .append("])) {\n"); - builder - .append(" values[") - .append(field.id) - .append("] = readCompatibleFieldValue(readContext, remoteField, fieldsById[") - .append(field.id) - .append("]);\n"); - appendDebugRemoteRead("after read", "remoteField", 10); - builder.append(" } else {\n"); - appendDebugRemoteRead("before skip", "remoteField", 10); - builder.append(" skipField(readContext, remoteField);\n"); - appendDebugRemoteRead("after skip", "remoteField", 10); - builder.append(" }\n"); + String scalarRead = + compatibleScalarReadExpression( + field, "remoteField.serializationFieldInfo", "fieldsById[" + field.id + "]"); + if (scalarRead != null) { + builder.append(" ").append(field.writeStatement("value", scalarRead)).append("\n"); + appendDebugRemoteRead("after read", "remoteField", 8); + } else { + String fieldValueName = "fieldValue" + matchedId; + builder + .append(" Object ") + .append(fieldValueName) + .append(" = readCompatibleFieldValue(readContext, remoteField, fieldsById[") + .append(field.id) + .append("]);\n"); + appendDebugRemoteRead("after read", "remoteField", 8); + builder + .append(" ") + .append(field.writeStatement("value", field.castExpression(fieldValueName))) + .append("\n"); + } builder.append(" return;\n"); + builder.append(" }\n"); } builder.append(" default:\n"); - appendDebugRemoteRead("before skip", "remoteField", 8); - builder.append(" skipField(readContext, remoteField);\n"); - appendDebugRemoteRead("after skip", "remoteField", 8); + builder.append( + " throw new IllegalStateException(\"Invalid compatible matched id \" + remoteField.matchedId);\n"); builder.append(" }\n"); builder.append(" }\n\n"); } - private void appendCompatibleDispatchParameters(boolean record) { + private void appendCompatibleDispatchParameters() { builder.append("ReadContext readContext, "); - if (record) { - builder.append("Object[] values, "); - } else { - builder.append(struct.typeName).append(" value, "); - } + builder.append(struct.typeName).append(" value, "); builder.append("RemoteFieldInfo remoteField"); } - private void appendCompatibleDispatchArguments(boolean record) { + private void appendCompatibleDispatchArguments() { builder.append("readContext, "); - if (record) { - builder.append("values, "); - } else { - builder.append("value, "); - } + builder.append("value, "); builder.append("remoteField"); } + private String compatibleScalarReadExpression( + SourceField field, String remoteFieldInfo, String localFieldInfo) { + String helper; + switch (field.erasedType) { + case "boolean": + helper = "readBooleanTarget"; + break; + case "java.lang.Boolean": + helper = "readBoxedBooleanTarget"; + break; + case "byte": + helper = "readByteTarget"; + break; + case "java.lang.Byte": + helper = "readBoxedByteTarget"; + break; + case "short": + helper = "readShortTarget"; + break; + case "java.lang.Short": + helper = "readBoxedShortTarget"; + break; + case "int": + helper = "readIntTarget"; + break; + case "java.lang.Integer": + helper = "readBoxedIntTarget"; + break; + case "long": + helper = "readLongTarget"; + break; + case "java.lang.Long": + helper = "readBoxedLongTarget"; + break; + case "float": + helper = "readFloatTarget"; + break; + case "java.lang.Float": + helper = "readBoxedFloatTarget"; + break; + case "double": + helper = "readDoubleTarget"; + break; + case "java.lang.Double": + helper = "readBoxedDoubleTarget"; + break; + case "java.lang.String": + helper = "readStringTarget"; + break; + case "java.math.BigDecimal": + helper = "readDecimalTarget"; + break; + case "org.apache.fory.type.unsigned.UInt8": + helper = "readUInt8Target"; + break; + case "org.apache.fory.type.unsigned.UInt16": + helper = "readUInt16Target"; + break; + case "org.apache.fory.type.unsigned.UInt32": + helper = "readUInt32Target"; + break; + case "org.apache.fory.type.unsigned.UInt64": + helper = "readUInt64Target"; + break; + case "org.apache.fory.type.Float16": + helper = "readFloat16Target"; + break; + case "org.apache.fory.type.BFloat16": + helper = "readBFloat16Target"; + break; + default: + return null; + } + return "FieldConverters." + + helper + + "(readContext, " + + remoteFieldInfo + + ", " + + localFieldInfo + + ")"; + } + private void appendDebugWrite(String stage, String fieldInfoName, int indent) { if (!struct.debug) { return; @@ -982,18 +1200,16 @@ private void writeCopy() { + "])")) .append(";\n"); } - builder - .append(" ") - .append(struct.typeName) - .append(" copied = new ") - .append(struct.typeName) - .append("("); - appendRecordConstructorArguments("field"); - builder.append(");\n"); + appendRecordConstruction("copied", "field", 4); builder.append(" copyContext.reference(value, copied);\n"); builder.append(" return copied;\n"); } else { - builder.append(" ").append(struct.typeName).append(" copied = newBean();\n"); + builder + .append(" ") + .append(struct.typeName) + .append(" copied = ") + .append(newGeneratedBeanExpression()) + .append(";\n"); builder.append(" copyContext.reference(value, copied);\n"); for (SourceField field : struct.fields) { builder @@ -1019,11 +1235,28 @@ private void writeDescriptorHelpers() { " private static TypeExtMeta meta(int typeId, boolean nullable, boolean trackingRef) {\n"); builder.append(" return TypeExtMeta.of(typeId, nullable, trackingRef);\n"); builder.append(" }\n"); + builder.append( + " private static java.lang.reflect.Field declaredField(Class type, String declaringClassName, String fieldName) {\n"); + builder.append(" try {\n"); + builder.append( + " return Class.forName(declaringClassName, false, type.getClassLoader()).getDeclaredField(fieldName);\n"); + builder.append(" } catch (ReflectiveOperationException e) {\n"); + builder.append(" throw new IllegalStateException(e);\n"); + builder.append(" }\n"); + builder.append(" }\n"); } - private void appendRecordConstructorArguments(String prefix) { + private void appendRecordConstruction(String variableName, String prefix, int indent) { + appendIndent(indent); + builder + .append(struct.typeName) + .append(" ") + .append(variableName) + .append(" = new ") + .append(struct.typeName) + .append("("); for (int i = 0; i < struct.recordConstructorFields.size(); i++) { - if (i > 0) { + if (i != 0) { builder.append(", "); } SourceField field = struct.recordConstructorFields.get(i); @@ -1033,6 +1266,11 @@ private void appendRecordConstructorArguments(String prefix) { builder.append(field.defaultValue()); } } + builder.append(");\n"); + } + + private String newGeneratedBeanExpression() { + return "(" + struct.typeName + ") generatedObjectInstantiator.newInstance()"; } private static String escape(String value) { diff --git a/java/fory-annotation-processor/src/test/java/org/apache/fory/annotation/processing/ForyStructProcessorTest.java b/java/fory-annotation-processor/src/test/java/org/apache/fory/annotation/processing/ForyStructProcessorTest.java index 2b4df3d7c5..3e915e9239 100644 --- a/java/fory-annotation-processor/src/test/java/org/apache/fory/annotation/processing/ForyStructProcessorTest.java +++ b/java/fory-annotation-processor/src/test/java/org/apache/fory/annotation/processing/ForyStructProcessorTest.java @@ -40,7 +40,7 @@ import org.apache.fory.ThreadSafeFory; import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.MetaWriteContext; -import org.apache.fory.exception.DeserializationException; +import org.apache.fory.exception.ForyException; import org.apache.fory.exception.SerializationException; import org.apache.fory.meta.FieldInfo; import org.apache.fory.meta.TypeDef; @@ -192,18 +192,77 @@ public void testPrivateFieldUsesAccessibleAccessors() throws Exception { } @Test - public void testPrivateFieldWithoutAccessorsFailsCompilation() throws Exception { + public void testPrivateFieldAccessor() throws Exception { CompilationResult result = compile( - "test.BadStruct", + "test.InaccessibleStruct", "package test;\n" + "import org.apache.fory.annotation.ForyStruct;\n" - + "@ForyStruct public class BadStruct {\n" + + "@ForyStruct public class InaccessibleStruct {\n" + " private int id;\n" - + " public BadStruct() {}\n" + + " public InaccessibleStruct() {}\n" + "}\n"); - Assert.assertFalse(result.success); - Assert.assertTrue(result.diagnostics().contains("getter/setter"), result.diagnostics()); + Assert.assertTrue(result.success, result.diagnostics()); + String generatedSource = + result.generatedSource("test/InaccessibleStruct_ForyNativeSerializer.java"); + Assert.assertTrue(generatedSource.contains("FieldAccessor fieldAccessor0"), generatedSource); + try (URLClassLoader loader = result.classLoader()) { + Class type = loader.loadClass("test.InaccessibleStruct"); + Object value = type.getConstructor().newInstance(); + setField(type, value, "id", 8); + Fory fory = + Fory.builder() + .withXlang(false) + .withClassLoader(loader) + .withCodegen(false) + .requireClassRegistration(false) + .withCompatible(false) + .build(); + Object roundTrip = fory.deserialize(fory.serialize(value)); + Assert.assertEquals(getField(type, roundTrip, "id"), 8); + } + } + + @Test + public void testNoArgFinalFieldsUseInstantiator() throws Exception { + CompilationResult result = + compile( + "test.FinalNoArgStruct", + "package test;\n" + + "import org.apache.fory.annotation.ForyField;\n" + + "import org.apache.fory.annotation.ForyStruct;\n" + + "@ForyStruct public class FinalNoArgStruct {\n" + + " @ForyField(id = 1) private final int id;\n" + + " @ForyField(id = 2) private final String name;\n" + + " public FinalNoArgStruct(int id, String name) {\n" + + " this.id = id;\n" + + " this.name = name;\n" + + " }\n" + + " public int id() { return id; }\n" + + " public String name() { return name; }\n" + + "}\n"); + Assert.assertTrue(result.success, result.diagnostics()); + String generatedSource = + result.generatedSource("test/FinalNoArgStruct_ForyNativeSerializer.java"); + Assert.assertTrue(generatedSource.contains("ObjectInstantiator generatedObjectInstantiator")); + Assert.assertTrue(generatedSource.contains("generatedObjectInstantiator.newInstance()")); + Assert.assertTrue(generatedSource.contains("FieldAccessor fieldAccessor0")); + Assert.assertTrue(generatedSource.contains("FieldAccessor fieldAccessor1")); + try (URLClassLoader loader = result.classLoader()) { + Class type = loader.loadClass("test.FinalNoArgStruct"); + Object value = type.getConstructor(int.class, String.class).newInstance(9, "final"); + Fory fory = + Fory.builder() + .withXlang(false) + .withClassLoader(loader) + .withCodegen(false) + .requireClassRegistration(false) + .withCompatible(false) + .build(); + Object roundTrip = fory.deserialize(fory.serialize(value)); + Assert.assertEquals(invoke(type, roundTrip, "id"), 9); + Assert.assertEquals(invoke(type, roundTrip, "name"), "final"); + } } @Test @@ -483,8 +542,16 @@ public void testRecordReadAndCopyUseCanonicalConstructor() throws Exception { + "}\n"); Assert.assertTrue(result.success, result.diagnostics()); String generatedSource = result.generatedSource("test/RecordStruct_ForyNativeSerializer.java"); - Assert.assertTrue(generatedSource.contains("private void readCompatibleRecordField0(")); Assert.assertTrue(generatedSource.contains("switch (remoteField.matchedId)")); + Assert.assertTrue(generatedSource.contains("case 0:")); + Assert.assertTrue(generatedSource.contains("case 1:")); + Assert.assertTrue(generatedSource.contains("remoteField.matchedId == -1")); + Assert.assertFalse(generatedSource.contains("ObjectInstantiator generatedObjectInstantiator")); + Assert.assertFalse(generatedSource.contains("private final Object[] recordArgs")); + Assert.assertFalse(generatedSource.contains("newInstanceWithArguments")); + Assert.assertTrue(generatedSource.contains("new test.RecordStruct(")); + Assert.assertFalse(generatedSource.contains("Object[] values")); + Assert.assertFalse(generatedSource.contains("readCompatibleRecordField")); try (URLClassLoader loader = result.classLoader()) { Class type = loader.loadClass("test.RecordStruct"); Object value = @@ -556,6 +623,14 @@ public void testCompatibleReadUsesGeneratedSerializer() throws Exception { generatedSource.contains("readCompatibleField(readContext, value, remoteField)")); Assert.assertTrue(generatedSource.contains("private void readCompatibleField0(")); Assert.assertTrue(generatedSource.contains("switch (remoteField.matchedId)")); + Assert.assertTrue(generatedSource.contains("case 0:")); + Assert.assertTrue(generatedSource.contains("case 1:")); + Assert.assertTrue(generatedSource.contains("remoteField.matchedId == -1")); + Assert.assertTrue(generatedSource.contains("ObjectInstantiator generatedObjectInstantiator")); + Assert.assertTrue(generatedSource.contains("generatedObjectInstantiator.newInstance()")); + Assert.assertFalse(generatedSource.contains("canReadGeneratedField")); + Assert.assertTrue(generatedSource.contains("FieldConverters.readIntTarget")); + Assert.assertTrue(generatedSource.contains("FieldConverters.readBooleanTarget")); try (URLClassLoader writerLoader = writerResult.classLoader(); URLClassLoader readerLoader = readerResult.classLoader()) { Class writerType = writerLoader.loadClass("test.EvolvingStruct"); @@ -911,7 +986,7 @@ public void testStaticArrayTypeListWritesDenseArrayPayload() throws Exception { } @Test - public void testStaticListArrayCompatibleReadPayloadValidation() throws Exception { + public void testStaticListArraySchemaValidation() throws Exception { CompilationResult nullableListWriter = compile( "test.ListArrayMismatchStruct", @@ -944,12 +1019,9 @@ public void testStaticListArrayCompatibleReadPayloadValidation() throws Exceptio Object result = reader.deserialize(writer.serialize(writerValue)); Assert.assertTrue( Arrays.equals((int[]) getField(readerType, result, "values"), new int[] {1, 2, 3})); - - Object nullElementWriterValue = writerType.getConstructor().newInstance(); - setField(writerType, nullElementWriterValue, "values", Arrays.asList(1, null, 3)); - byte[] nullElementPayload = writer.serialize(nullElementWriterValue); - Assert.expectThrows( - DeserializationException.class, () -> reader.deserialize(nullElementPayload)); + setField(writerType, writerValue, "values", Arrays.asList(1, null, 3)); + byte[] nullElementPayload = writer.serialize(writerValue); + Assert.expectThrows(ForyException.class, () -> reader.deserialize(nullElementPayload)); } CompilationResult nestedListWriter = @@ -986,8 +1058,7 @@ public void testStaticListArrayCompatibleReadPayloadValidation() throws Exceptio Object writerValue = writerType.getConstructor().newInstance(); setField(writerType, writerValue, "values", Arrays.asList(Arrays.asList(1, 2, 3))); byte[] payload = writer.serialize(writerValue); - Object result = reader.deserialize(payload); - Assert.assertNull(getField(readerType, result, "values")); + Assert.expectThrows(ForyException.class, () -> reader.deserialize(payload)); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java index ce5bdae3db..c431f7ec7c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java @@ -2327,6 +2327,12 @@ protected TypeRef compatibleLocalTypeRef(Descriptor descriptor) { : TypeRef.of(descriptor.getField().getGenericType()); } + protected TypeRef readValueTypeRef(Descriptor descriptor) { + return hasCompatibleCollectionArrayRead(descriptor) + ? compatibleLocalTypeRef(descriptor) + : descriptor.getTypeRef(); + } + private Expression deserializeForNotNullForField( Expression buffer, Descriptor descriptor, Expression serializer) { TypeRef typeRef = descriptor.getTypeRef(); diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/CompatibleCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/CompatibleCodecBuilder.java index 4cae8b3d99..a95f478d70 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/CompatibleCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/CompatibleCodecBuilder.java @@ -25,7 +25,11 @@ import java.lang.reflect.Field; import java.lang.reflect.Member; +import java.math.BigDecimal; import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.SortedMap; @@ -54,10 +58,16 @@ import org.apache.fory.serializer.Serializers; import org.apache.fory.serializer.converter.FieldConverter; import org.apache.fory.serializer.converter.FieldConverters; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Descriptor; import org.apache.fory.type.DescriptorBuilder; import org.apache.fory.type.DescriptorGrouper; +import org.apache.fory.type.Float16; import org.apache.fory.type.ScalaTypes; +import org.apache.fory.type.unsigned.UInt16; +import org.apache.fory.type.unsigned.UInt32; +import org.apache.fory.type.unsigned.UInt64; +import org.apache.fory.type.unsigned.UInt8; import org.apache.fory.util.DefaultValueUtils; import org.apache.fory.util.ExceptionUtils; import org.apache.fory.util.Preconditions; @@ -80,10 +90,16 @@ */ public class CompatibleCodecBuilder extends ObjectCodecBuilder { private static final Logger LOG = LoggerFactory.getLogger(CompatibleCodecBuilder.class); + private static final String REMOTE_FIELD_INFOS_NAME = "_f_remoteFieldInfos"; + private static final String LOCAL_FIELD_INFOS_NAME = "_f_localFieldInfosByRemoteOrder"; + private static final String CONSTRUCTOR_TYPE_DEF_NAME = "_f_typeDef"; private final TypeDef typeDef; private final String defaultValueLanguage; private final DefaultValueUtils.DefaultValueField[] defaultValueFields; + private final Map fieldInfoIds = new IdentityHashMap<>(); + private String remoteFieldInfosName; + private String localFieldInfosName; public CompatibleCodecBuilder(TypeRef beanType, Fory fory, TypeDef typeDef) { super(beanType, fory, GeneratedCompatibleSerializer.class); @@ -93,6 +109,9 @@ public CompatibleCodecBuilder(TypeRef beanType, Fory fory, TypeDef typeDef) { this.typeDef = typeDef; DescriptorGrouper grouper = typeResolver(r -> r.createDescriptorGrouper(typeDef, beanClass)); List sortedDescriptors = grouper.getSortedDescriptors(); + for (int i = 0; i < sortedDescriptors.size(); i++) { + fieldInfoIds.put(sortedDescriptors.get(i), i); + } if (org.apache.fory.util.Utils.DEBUG_OUTPUT_ENABLED) { LOG.info( "========== {} sorted descriptors for {} ==========", @@ -161,6 +180,7 @@ public String genCode() { ctx.extendsClasses(ctx.type(parentSerializerClass)); ctx.reserveName(POJO_CLASS_TYPE_NAME); ctx.reserveName(SERIALIZER_FIELD_NAME); + ctx.reserveName(CONSTRUCTOR_TYPE_DEF_NAME); String constructorCode = StringUtils.format( "" @@ -209,7 +229,9 @@ public String genCode() { TypeResolver.class, CONSTRUCTOR_TYPE_RESOLVER_NAME, Class.class, - POJO_CLASS_TYPE_NAME); + POJO_CLASS_TYPE_NAME, + TypeDef.class, + CONSTRUCTOR_TYPE_DEF_NAME); return ctx.genCode(); } @@ -292,13 +314,23 @@ protected List deserializePrimitives( } appendPrimitiveBulkReads(expressions, bean, buffer, bulkGroups); for (Descriptor descriptor : group) { - expressions.add( - deserializeField( - buffer, - descriptor, - value -> - setFieldValue( - bean, descriptor, tryInlineCast(value, descriptor.getTypeRef())))); + FieldConverter converter = descriptor.getFieldConverter(); + Expression targetValue = + converter == null ? null : fieldConverterTargetRead(descriptor, converter); + if (targetValue != null) { + expressions.add( + new Expression.ListExpression( + targetValue, + setFieldConverterTargetValue(bean, descriptor, converter, targetValue))); + } else { + expressions.add( + deserializeField( + buffer, + descriptor, + value -> + setFieldValue( + bean, descriptor, tryInlineCast(value, descriptor.getTypeRef())))); + } } } appendPrimitiveBulkReads(expressions, bean, buffer, bulkGroups); @@ -312,19 +344,11 @@ protected Expression deserializeField( if (converter == null) { return super.deserializeField(buffer, descriptor, callback); } - StaticInvoke sourceValue = - new StaticInvoke( - FieldConverters.class, - "readSourceScalar", - OBJECT_TYPE, - readContextRef(), - Literal.ofInt(FieldConverters.fromDispatchId(converter)), - Literal.ofClass(FieldConverters.fromType(converter)), - Literal.ofBoolean(descriptor.isNullable()), - Literal.ofBoolean( - new SerializationFieldInfo(typeResolver, descriptor).useDeclaredTypeInfo), - Literal.ofString(FieldConverters.fieldName(converter))); - return new Expression.ListExpression(sourceValue, callback.apply(sourceValue)); + Expression targetValue = fieldConverterTargetRead(descriptor, converter); + Preconditions.checkState( + targetValue != null, + "Unsupported compatible scalar converter target " + FieldConverters.toType(converter)); + return new Expression.ListExpression(targetValue, callback.apply(targetValue)); } @Override @@ -332,26 +356,7 @@ protected Expression setFieldValue(Expression bean, Descriptor descriptor, Expre if (descriptor.getField() == null) { FieldConverter converter = descriptor.getFieldConverter(); if (converter != null) { - Field field = converter.getField(); - StaticInvoke convertedValue = - new StaticInvoke( - FieldConverters.class, - "convertValue", - OBJECT_TYPE, - Literal.ofInt(FieldConverters.fromDispatchId(converter)), - Literal.ofClass(FieldConverters.fromType(converter)), - Literal.ofInt(FieldConverters.toDispatchId(converter)), - Literal.ofClass(FieldConverters.toType(converter)), - Literal.ofString(FieldConverters.fieldName(converter)), - value); - Expression converted = new Expression.Cast(convertedValue, TypeRef.of(field.getType())); - Descriptor newDesc = - new DescriptorBuilder(descriptor) - .field(field) - .type(field.getType()) - .typeRef(TypeRef.of(field.getType())) - .build(); - return super.setFieldValue(bean, newDesc, converted); + return setFieldConverterTargetValue(bean, descriptor, converter, value); } // Field doesn't exist in current class, skip set this field value. // Note that the field value shouldn't be an inlined value, otherwise field value read may @@ -362,6 +367,82 @@ protected Expression setFieldValue(Expression bean, Descriptor descriptor, Expre return super.setFieldValue(bean, descriptor, value); } + private Expression setFieldConverterTargetValue( + Expression bean, Descriptor descriptor, FieldConverter converter, Expression value) { + Field field = converter.getField(); + Descriptor targetDescriptor = + new DescriptorBuilder(descriptor) + .field(field) + .type(field.getType()) + .typeRef(TypeRef.of(field.getType())) + .build(); + return super.setFieldValue(bean, targetDescriptor, value); + } + + private Expression fieldConverterTargetRead(Descriptor descriptor, FieldConverter converter) { + Class targetType = converter.getField().getType(); + String helper = fieldConverterTargetReader(targetType); + if (helper == null) { + return null; + } + return new StaticInvoke( + FieldConverters.class, + helper, + TypeRef.of(targetType), + readContextRef(), + remoteFieldInfo(descriptor), + localFieldInfo(descriptor)); + } + + private static String fieldConverterTargetReader(Class targetType) { + if (targetType == boolean.class) { + return "readBooleanTarget"; + } else if (targetType == Boolean.class) { + return "readBoxedBooleanTarget"; + } else if (targetType == byte.class) { + return "readByteTarget"; + } else if (targetType == Byte.class) { + return "readBoxedByteTarget"; + } else if (targetType == short.class) { + return "readShortTarget"; + } else if (targetType == Short.class) { + return "readBoxedShortTarget"; + } else if (targetType == int.class) { + return "readIntTarget"; + } else if (targetType == Integer.class) { + return "readBoxedIntTarget"; + } else if (targetType == long.class) { + return "readLongTarget"; + } else if (targetType == Long.class) { + return "readBoxedLongTarget"; + } else if (targetType == float.class) { + return "readFloatTarget"; + } else if (targetType == Float.class) { + return "readBoxedFloatTarget"; + } else if (targetType == double.class) { + return "readDoubleTarget"; + } else if (targetType == Double.class) { + return "readBoxedDoubleTarget"; + } else if (targetType == String.class) { + return "readStringTarget"; + } else if (targetType == BigDecimal.class) { + return "readDecimalTarget"; + } else if (targetType == UInt8.class) { + return "readUInt8Target"; + } else if (targetType == UInt16.class) { + return "readUInt16Target"; + } else if (targetType == UInt32.class) { + return "readUInt32Target"; + } else if (targetType == UInt64.class) { + return "readUInt64Target"; + } else if (targetType == Float16.class) { + return "readFloat16Target"; + } else if (targetType == BFloat16.class) { + return "readBFloat16Target"; + } + return null; + } + private static boolean hasFieldConverter(List descriptors) { for (Descriptor descriptor : descriptors) { if (descriptor.getFieldConverter() != null) { @@ -371,6 +452,124 @@ private static boolean hasFieldConverter(List descriptors) { return false; } + @Override + protected TypeRef readValueTypeRef(Descriptor descriptor) { + FieldConverter converter = descriptor.getFieldConverter(); + if (converter != null) { + return TypeRef.of(converter.getField().getGenericType()); + } + return super.readValueTypeRef(descriptor); + } + + private Expression remoteFieldInfo(Descriptor descriptor) { + ensureCompatibleFieldInfos(); + return fieldInfo(remoteFieldInfosName, descriptor); + } + + private Expression localFieldInfo(Descriptor descriptor) { + ensureCompatibleFieldInfos(); + return fieldInfo(localFieldInfosName, descriptor); + } + + private Expression fieldInfo(String fieldInfosName, Descriptor descriptor) { + Integer fieldInfoId = fieldInfoIds.get(descriptor); + Preconditions.checkState( + fieldInfoId != null, "Unknown compatible field descriptor " + descriptor); + return new Expression.Reference( + fieldInfosName + "[" + fieldInfoId + "]", TypeRef.of(SerializationFieldInfo.class)); + } + + private void ensureCompatibleFieldInfos() { + if (remoteFieldInfosName != null) { + return; + } + remoteFieldInfosName = ctx.newName(REMOTE_FIELD_INFOS_NAME); + localFieldInfosName = ctx.newName(LOCAL_FIELD_INFOS_NAME); + Expression constructorResolver = + new Expression.Reference(CONSTRUCTOR_TYPE_RESOLVER_NAME, TypeRef.of(TypeResolver.class)); + Expression constructorClass = + new Expression.Reference(POJO_CLASS_TYPE_NAME, TypeRef.of(Class.class)); + Expression constructorTypeDef = + new Expression.Reference(CONSTRUCTOR_TYPE_DEF_NAME, TypeRef.of(TypeDef.class)); + TypeRef fieldInfoArrayType = + TypeRef.of(SerializationFieldInfo[].class); + ctx.addField( + false, + true, + ctx.type(fieldInfoArrayType), + remoteFieldInfosName, + new StaticInvoke( + CompatibleCodecBuilder.class, + "buildRemoteFieldInfos", + fieldInfoArrayType, + constructorResolver, + constructorClass, + constructorTypeDef)); + ctx.addField( + false, + true, + ctx.type(fieldInfoArrayType), + localFieldInfosName, + new StaticInvoke( + CompatibleCodecBuilder.class, + "buildLocalFieldInfosByRemoteOrder", + fieldInfoArrayType, + constructorResolver, + constructorClass, + constructorTypeDef)); + } + + public static SerializationFieldInfo[] buildRemoteFieldInfos( + TypeResolver typeResolver, Class cls, TypeDef typeDef) { + List descriptors = + typeResolver.createDescriptorGrouper(typeDef, cls).getSortedDescriptors(); + SerializationFieldInfo[] fieldInfos = new SerializationFieldInfo[descriptors.size()]; + for (int i = 0; i < descriptors.size(); i++) { + fieldInfos[i] = new SerializationFieldInfo(typeResolver, descriptors.get(i)); + } + return fieldInfos; + } + + public static SerializationFieldInfo[] buildLocalFieldInfosByRemoteOrder( + TypeResolver typeResolver, Class cls, TypeDef typeDef) { + List descriptors = + typeResolver.createDescriptorGrouper(typeDef, cls).getSortedDescriptors(); + Map localDescriptors = localDescriptorsByField(typeResolver, cls); + SerializationFieldInfo[] fieldInfos = new SerializationFieldInfo[descriptors.size()]; + for (int i = 0; i < descriptors.size(); i++) { + Descriptor localDescriptor = localDescriptor(descriptors.get(i), localDescriptors); + if (localDescriptor != null) { + fieldInfos[i] = new SerializationFieldInfo(typeResolver, localDescriptor); + } + } + return fieldInfos; + } + + private static Map localDescriptorsByField( + TypeResolver typeResolver, Class cls) { + Collection descriptors = typeResolver.getFieldDescriptors(cls, true); + Map descriptorsByField = new HashMap<>(); + for (Descriptor descriptor : descriptors) { + Field field = descriptor.getField(); + if (field != null) { + descriptorsByField.put(field, descriptor); + } + } + return descriptorsByField; + } + + private static Descriptor localDescriptor( + Descriptor remoteDescriptor, Map localDescriptors) { + FieldConverter converter = remoteDescriptor.getFieldConverter(); + if (converter != null) { + return localDescriptors.get(converter.getField()); + } + if (remoteDescriptor.getField() != null) { + return remoteDescriptor; + } + return null; + } + private void appendPrimitiveBulkReads( List expressions, Expression bean, diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java index edb7eccb0e..e22cd2b38c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java @@ -905,8 +905,7 @@ protected Expression deserializeGroup( for (Descriptor d : group) { ExpressionVisitor.ExprHolder exprHolder = ExpressionVisitor.ExprHolder.of("bean", bean); walkPath.add(d.getDeclaringClass() + d.getName()); - TypeRef castTypeRef = - hasCompatibleCollectionArrayRead(d) ? compatibleLocalTypeRef(d) : d.getTypeRef(); + TypeRef castTypeRef = readValueTypeRef(d); Expression action = deserializeField( buffer, @@ -945,8 +944,7 @@ protected Expression deserializeGroupForRecord( ListExpression groupExpressions = new ListExpression(); // use Reference to cut-off expr dependency. for (Descriptor d : group) { - TypeRef castTypeRef = - hasCompatibleCollectionArrayRead(d) ? compatibleLocalTypeRef(d) : d.getTypeRef(); + TypeRef castTypeRef = readValueTypeRef(d); Expression value = deserializeField(buffer, d, expr -> expr); Expression action = setFieldValue(bean, d, tryInlineCast(value, castTypeRef)); groupExpressions.add(action); diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java index fc872e18d7..4a43269771 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java @@ -31,6 +31,7 @@ import org.apache.fory.codegen.CodeGenerator; import org.apache.fory.codegen.Expression; import org.apache.fory.codegen.Expression.Invoke; +import org.apache.fory.codegen.Expression.Literal; import org.apache.fory.codegen.Expression.Reference; import org.apache.fory.config.ForyBuilder; import org.apache.fory.context.ReadContext; @@ -65,13 +66,18 @@ public final class StaticCompatibleCodecBuilder extends ObjectCodecBuilder { private final List localDescriptors; private final boolean debug; + private Expression objectInstantiator; + private String recordArgsFieldName; public StaticCompatibleCodecBuilder(TypeRef beanType, Fory fory, TypeDef typeDef) { super(beanType, fory, GeneratedStaticCompatibleSerializer.class); Preconditions.checkArgument( !fory.getConfig().checkClassVersion(), "Class version check should be disabled when compatible mode is enabled."); - localDescriptors = Collections.unmodifiableList(Descriptor.getDescriptors(beanClass)); + localDescriptors = + Collections.unmodifiableList( + typeResolver.normalizeFieldDescriptors( + beanClass, true, Descriptor.getDescriptors(beanClass))); debug = beanClass.isAnnotationPresent(ForyStruct.class) && beanClass.isAnnotationPresent(ForyDebug.class); @@ -90,6 +96,16 @@ public String genCode() { ctx.extendsClasses(ctx.type(parentSerializerClass)); ctx.reserveName(POJO_CLASS_TYPE_NAME); ctx.addImports(List.class, TypeDef.class, Descriptor.class, SerializationFieldInfo.class); + if (!isRecord || !recordCtrAccessible) { + generatedObjectInstantiator(); + } + if (isRecord && !recordCtrAccessible) { + recordArgsFieldName(RecordUtils.getRecordComponents(beanClass).length); + } + String readCompatibleCode = isRecord ? genRecordCompatibleRead() : genObjectCompatibleRead(); + genDispatchMethods(); + // Read/dispatch generation can add serializer fields and instance-init code. Add the + // constructor after it because CodegenContext snapshots instance init into constructors. String constructorCode = StringUtils.format( "" @@ -115,12 +131,7 @@ public String genCode() { "_f_typeDef"); ctx.addMethod("getGeneratedDescriptors", "return Descriptor.getDescriptors(type);", List.class); ctx.overrideMethod( - "readCompatible", - isRecord ? genRecordCompatibleRead() : genObjectCompatibleRead(), - Object.class, - ReadContext.class, - READ_CONTEXT_NAME); - genDispatchMethods(); + "readCompatible", readCompatibleCode, Object.class, ReadContext.class, READ_CONTEXT_NAME); return ctx.genCode(); } @@ -142,8 +153,12 @@ public Expression buildDecodeExpression() { private String genObjectCompatibleRead() { ctx.clearExprState(); - Code.ExprCode beanCode = newBean().genCode(ctx); - String bean = beanCode.value().toString(); + Code.ExprCode beanCode = + new Invoke(generatedObjectInstantiator(), "newInstance", OBJECT_TYPE).genCode(ctx); + String bean = + ctx.sourcePublicAccessible(beanClass) + ? "((" + ctx.type(beanClass) + ") " + beanCode.value() + ")" + : beanCode.value().toString(); StringBuilder code = new StringBuilder(); if (StringUtils.isNotBlank(beanCode.code())) { code.append(beanCode.code()).append('\n'); @@ -174,94 +189,72 @@ private String genObjectCompatibleRead() { private String genRecordCompatibleRead() { RecordComponent[] components = RecordUtils.getRecordComponents(beanClass); StringBuilder code = new StringBuilder(); - String recordValues; - if (recordCtrAccessible) { - recordValues = "_f_recordValues"; - code.append("Object[] ") - .append(recordValues) - .append(" = new Object[") - .append(components.length) - .append("];\n"); - for (int i = 0; i < components.length; i++) { - String defaultValue = boxedDefaultValue(components[i].getType()); - if (defaultValue != null) { - code.append(recordValues) - .append("[") - .append(i) - .append("] = ") - .append(defaultValue) - .append(";\n"); - } - } - } else { - ctx.clearExprState(); - Code.ExprCode recordValuesCode = buildComponentsArray().genCode(ctx); - recordValues = recordValuesCode.value().toString(); - if (StringUtils.isNotBlank(recordValuesCode.code())) { - code.append(recordValuesCode.code()).append('\n'); - } + for (int i = 0; i < components.length; i++) { + Class componentType = components[i].getType(); + code.append(recordLocalType(componentType)) + .append(" _f_recordValue") + .append(i) + .append(" = ") + .append(defaultValue(componentType)) + .append(";\n"); } code.append("for (int _f_i = 0; _f_i < remoteFields.size(); _f_i++) {\n") .append(" RemoteFieldInfo _f_remoteField = (RemoteFieldInfo) remoteFields.get(_f_i);\n") - .append(" readMatchedRecordField(") - .append(READ_CONTEXT_NAME) - .append(", ") - .append(recordValues) - .append(", _f_remoteField);\n") + .append(" if (_f_remoteField.matchedId == -1) {\n"); + appendDebugRemoteRead(code, "before skip", "_f_remoteField", 4); + code.append(" skipField(").append(READ_CONTEXT_NAME).append(", _f_remoteField);\n"); + appendDebugRemoteRead(code, "after skip", "_f_remoteField", 4); + code.append(" continue;\n") + .append(" }\n") + .append(indent(genRecordDispatchSwitch(), 2)) + .append('\n') .append("}\n"); if (recordCtrAccessible) { - code.append("return new ") - .append(ctx.type(beanClass)) - .append("(") - .append(recordConstructorArgs(components, recordValues)) - .append(");"); - } else { - ctx.clearExprState(); - Reference values = new Reference(recordValues, objectArrayTypeRef, false); - Code.ExprCode newRecord = - new Invoke( - getObjectInstantiator(beanClass), "newInstanceWithArguments", OBJECT_TYPE, values) - .genCode(ctx); - if (StringUtils.isNotBlank(newRecord.code())) { - code.append(newRecord.code()).append('\n'); + code.append("return new ").append(ctx.type(beanClass)).append("("); + for (int i = 0; i < components.length; i++) { + if (i > 0) { + code.append(", "); + } + code.append("_f_recordValue").append(i); } - code.append("return ").append(newRecord.value()).append(';'); + code.append(");"); + return code.toString(); + } + String recordArgs = recordArgsFieldName(components.length); + code.append("Object[] _f_recordArgs = this.").append(recordArgs).append(";\n"); + for (int i = 0; i < components.length; i++) { + code.append("_f_recordArgs[") + .append(i) + .append("] = ") + .append("_f_recordValue") + .append(i) + .append(";\n"); } + ctx.clearExprState(); + Reference values = new Reference("_f_recordArgs", objectArrayTypeRef, false); + Code.ExprCode newRecord = + new Invoke(generatedObjectInstantiator(), "newInstanceWithArguments", OBJECT_TYPE, values) + .genCode(ctx); + if (StringUtils.isNotBlank(newRecord.code())) { + code.append(newRecord.code()).append('\n'); + } + code.append("Object _f_record = ").append(newRecord.value()).append(";\n"); + for (int i = 0; i < components.length; i++) { + code.append("_f_recordArgs[").append(i).append("] = null;\n"); + } + code.append("return _f_record;"); return code.toString(); } private void genDispatchMethods() { - String valueType = - ctx.sourcePublicAccessible(beanClass) ? ctx.type(beanClass) : ctx.type(Object.class); - TypeRef valueTypeRef = ctx.sourcePublicAccessible(beanClass) ? beanType : OBJECT_TYPE; - int groupCount = (localDescriptors.size() + DISPATCH_GROUP_SIZE - 1) / DISPATCH_GROUP_SIZE; if (isRecord) { - ctx.addMethod( - DISPATCH_METHOD_MODIFIERS, - "readMatchedRecordField", - genDispatchRouter("readMatchedRecordField", groupCount), - void.class, - ReadContext.class, - READ_CONTEXT_NAME, - Object[].class, - "_f_recordValues", - "RemoteFieldInfo", - "_f_remoteField"); - for (int group = 0; group < groupCount; group++) { - ctx.addMethod( - DISPATCH_METHOD_MODIFIERS, - "readMatchedRecordField" + group, - genRecordDispatchGroup(group), - void.class, - ReadContext.class, - READ_CONTEXT_NAME, - Object[].class, - "_f_recordValues", - "RemoteFieldInfo", - "_f_remoteField"); - } return; } + String valueType = + ctx.sourcePublicAccessible(beanClass) ? ctx.type(beanClass) : ctx.type(Object.class); + TypeRef valueTypeRef = ctx.sourcePublicAccessible(beanClass) ? beanType : OBJECT_TYPE; + int caseCount = localDescriptors.size() * 2; + int groupCount = (caseCount + DISPATCH_GROUP_SIZE - 1) / DISPATCH_GROUP_SIZE; ctx.addMethod( DISPATCH_METHOD_MODIFIERS, "readMatchedField", @@ -290,8 +283,14 @@ private void genDispatchMethods() { private String genDispatchRouter(String methodPrefix, int groupCount) { StringBuilder code = new StringBuilder(); + code.append("if (_f_remoteField.matchedId == -1) {\n"); + appendDebugRemoteRead(code, "before skip", "_f_remoteField", 2); + code.append(" skipField(").append(READ_CONTEXT_NAME).append(", _f_remoteField);\n"); + appendDebugRemoteRead(code, "after skip", "_f_remoteField", 2); + code.append(" return;\n"); + code.append("}\n"); for (int group = 0; group < groupCount; group++) { - int upperBound = Math.min(localDescriptors.size(), (group + 1) * DISPATCH_GROUP_SIZE); + int upperBound = Math.min(localDescriptors.size() * 2, (group + 1) * DISPATCH_GROUP_SIZE); if (group == 0) { code.append("if (_f_remoteField.matchedId >= 0 && _f_remoteField.matchedId < ") .append(upperBound) @@ -305,25 +304,50 @@ private String genDispatchRouter(String methodPrefix, int groupCount) { .append("(") .append(READ_CONTEXT_NAME) .append(", "); - code.append(isRecord ? "_f_recordValues" : "_f_value"); + code.append("_f_value"); code.append(", _f_remoteField);\n").append(" return;\n").append("}\n"); } - appendDebugRemoteRead(code, "before skip", "_f_remoteField", 0); - code.append("skipField(").append(READ_CONTEXT_NAME).append(", _f_remoteField);\n"); - appendDebugRemoteRead(code, "after skip", "_f_remoteField", 0); + code.append("throw new IllegalStateException(\"Invalid compatible matched id \"") + .append(" + _f_remoteField.matchedId);\n"); return code.toString(); } private String genObjectDispatchGroup(int group, TypeRef valueTypeRef) { int start = group * DISPATCH_GROUP_SIZE; - int end = Math.min(localDescriptors.size(), start + DISPATCH_GROUP_SIZE); + int end = Math.min(localDescriptors.size() * 2, start + DISPATCH_GROUP_SIZE); StringBuilder code = new StringBuilder("switch (_f_remoteField.matchedId) {\n"); - for (int i = start; i < end; i++) { - Descriptor descriptor = localDescriptors.get(i); - code.append(" case ") - .append(i) - .append(": {\n") - .append(debugRemoteReadCode("before read", "_f_remoteField", 4)) + for (int matchedId = start; matchedId < end; matchedId++) { + int localId = matchedId >> 1; + Descriptor descriptor = localDescriptors.get(localId); + code.append(" case ").append(matchedId).append(": {\n"); + if ((matchedId & 1) == 0) { + code.append(" MemoryBuffer ") + .append(BUFFER_NAME) + .append(" = ") + .append(READ_CONTEXT_NAME) + .append(".getBuffer();\n") + .append(indent(genDirectReadAndSetFieldCode(descriptor, valueTypeRef), 4)) + .append('\n') + .append(" return;\n") + .append(" }\n"); + continue; + } + String scalarRead = + genCompatibleScalarReadExpr( + descriptor, + "_f_remoteField.serializationFieldInfo", + "localFieldInfo(" + localId + ")"); + if (scalarRead != null) { + code.append(debugRemoteReadCode("before read", "_f_remoteField", 4)) + .append(" ") + .append(genSetFieldCode(descriptor, valueTypeRef, scalarRead)) + .append('\n') + .append(debugRemoteReadCode("after read", "_f_remoteField", 4)) + .append(" return;\n") + .append(" }\n"); + continue; + } + code.append(debugRemoteReadCode("before read", "_f_remoteField", 4)) .append(" if (_f_remoteField.serializationFieldInfo.fieldConverter != null) {\n") .append(" Object _f_fieldValue = readFieldConverterSource(") .append(READ_CONTEXT_NAME) @@ -332,16 +356,9 @@ private String genObjectDispatchGroup(int group, TypeRef valueTypeRef) { .append( " _f_remoteField.serializationFieldInfo.fieldConverter.set(_f_value, _f_fieldValue);\n") .append(" } else {\n") - .append( - " SerializationFieldInfo _f_localField = localFieldInfo(_f_remoteField.matchedId);\n") - .append(" if (!canReadGeneratedField(_f_remoteField, _f_localField)) {\n") - .append(debugRemoteReadCode("before skip", "_f_remoteField", 8)) - .append(" skipField(") - .append(READ_CONTEXT_NAME) - .append(", _f_remoteField);\n") - .append(debugRemoteReadCode("after skip", "_f_remoteField", 8)) - .append(" return;\n") - .append(" }\n") + .append(" SerializationFieldInfo _f_localField = localFieldInfo(") + .append(localId) + .append(");\n") .append(" Object _f_fieldValue = readCompatibleFieldValue(") .append(READ_CONTEXT_NAME) .append(", _f_remoteField, _f_localField);\n") @@ -353,54 +370,72 @@ private String genObjectDispatchGroup(int group, TypeRef valueTypeRef) { .append(" }\n"); } code.append(" default:\n") - .append(debugRemoteReadCode("before skip", "_f_remoteField", 4)) - .append(" skipField(") - .append(READ_CONTEXT_NAME) - .append(", _f_remoteField);\n") - .append(debugRemoteReadCode("after skip", "_f_remoteField", 4)) + .append(" throw new IllegalStateException(\"Invalid compatible matched id \"") + .append(" + _f_remoteField.matchedId);\n") .append("}\n"); return code.toString(); } - private String genRecordDispatchGroup(int group) { - int start = group * DISPATCH_GROUP_SIZE; - int end = Math.min(localDescriptors.size(), start + DISPATCH_GROUP_SIZE); + private String genRecordDispatchSwitch() { StringBuilder code = new StringBuilder("switch (_f_remoteField.matchedId) {\n"); - for (int i = start; i < end; i++) { - Descriptor descriptor = localDescriptors.get(i); + for (int matchedId = 0; matchedId < localDescriptors.size() * 2; matchedId++) { + int localId = matchedId >> 1; + Descriptor descriptor = localDescriptors.get(localId); Integer componentIndex = recordReversedMapping.get(descriptor.getName()); if (componentIndex == null) { continue; } - code.append(" case ") - .append(i) - .append(": {\n") - .append(debugRemoteReadCode("before read", "_f_remoteField", 4)) - .append( - " SerializationFieldInfo _f_localField = localFieldInfo(_f_remoteField.matchedId);\n") - .append(" if (canReadGeneratedField(_f_remoteField, _f_localField)) {\n") - .append(" _f_recordValues[") - .append(componentIndex) - .append("] = readCompatibleFieldValue(") + code.append(" case ").append(matchedId).append(": {\n"); + if ((matchedId & 1) == 0) { + code.append(" MemoryBuffer ") + .append(BUFFER_NAME) + .append(" = ") + .append(READ_CONTEXT_NAME) + .append(".getBuffer();\n") + .append( + indent( + genDirectReadRecordLocalCode(descriptor, "_f_recordValue" + componentIndex), 4)) + .append('\n') + .append(" break;\n") + .append(" }\n"); + continue; + } + String scalarRead = + genCompatibleScalarReadExpr( + descriptor, + "_f_remoteField.serializationFieldInfo", + "localFieldInfo(" + localId + ")"); + if (scalarRead != null) { + code.append(debugRemoteReadCode("before read", "_f_remoteField", 4)) + .append(" _f_recordValue") + .append(componentIndex) + .append(" = ") + .append(scalarRead) + .append(";\n") + .append(debugRemoteReadCode("after read", "_f_remoteField", 4)) + .append(" break;\n") + .append(" }\n"); + continue; + } + code.append(debugRemoteReadCode("before read", "_f_remoteField", 4)) + .append(" SerializationFieldInfo _f_localField = localFieldInfo(") + .append(localId) + .append(");\n") + .append(" Object _f_fieldValue = readCompatibleFieldValue(") .append(READ_CONTEXT_NAME) .append(", _f_remoteField, _f_localField);\n") .append(debugRemoteReadCode("after read", "_f_remoteField", 6)) - .append(" } else {\n") - .append(debugRemoteReadCode("before skip", "_f_remoteField", 6)) - .append(" skipField(") - .append(READ_CONTEXT_NAME) - .append(", _f_remoteField);\n") - .append(debugRemoteReadCode("after skip", "_f_remoteField", 6)) - .append(" }\n") - .append(" return;\n") + .append(" _f_recordValue") + .append(componentIndex) + .append(" = ") + .append(castRecordComponent("_f_fieldValue", descriptor.getRawType())) + .append(";\n") + .append(" break;\n") .append(" }\n"); } code.append(" default:\n") - .append(debugRemoteReadCode("before skip", "_f_remoteField", 4)) - .append(" skipField(") - .append(READ_CONTEXT_NAME) - .append(", _f_remoteField);\n") - .append(debugRemoteReadCode("after skip", "_f_remoteField", 4)) + .append(" throw new IllegalStateException(\"Invalid compatible matched id \"") + .append(" + _f_remoteField.matchedId);\n") .append("}\n"); return code.toString(); } @@ -431,6 +466,34 @@ private String debugRemoteReadCode(String stage, String remoteField, int indent) return code.toString(); } + private String genDirectReadAndSetFieldCode(Descriptor descriptor, TypeRef valueTypeRef) { + ctx.clearExprState(); + Reference value = new Reference("_f_value", valueTypeRef, false); + Reference buffer = new Reference(BUFFER_NAME, bufferTypeRef, false); + Expression readAndSet = + deserializeField( + buffer, + descriptor, + fieldValue -> + setFieldValue( + value, descriptor, tryInlineCast(fieldValue, descriptor.getTypeRef()))); + Code.ExprCode readCode = readAndSet.genCode(ctx); + String code = ctx.optimizeMethodCode(readCode.code()); + return code == null ? "" : code; + } + + private String genDirectReadRecordLocalCode(Descriptor descriptor, String recordValue) { + ctx.clearExprState(); + Reference target = new Reference(recordValue, descriptor.getTypeRef(), false); + Reference buffer = new Reference(BUFFER_NAME, bufferTypeRef, false); + Expression fieldValue = deserializeField(buffer, descriptor, value -> value); + Expression assign = + new Expression.Assign(target, tryInlineCast(fieldValue, descriptor.getTypeRef())); + Code.ExprCode readCode = assign.genCode(ctx); + String code = ctx.optimizeMethodCode(readCode.code()); + return code == null ? "" : code; + } + private String genSetFieldCode(Descriptor descriptor, TypeRef valueTypeRef) { ctx.clearExprState(); Reference value = new Reference("_f_value", valueTypeRef, false); @@ -442,15 +505,82 @@ private String genSetFieldCode(Descriptor descriptor, TypeRef valueTypeRef) { return code == null ? "" : code; } - private String recordConstructorArgs(RecordComponent[] components, String recordValues) { - StringBuilder args = new StringBuilder(); - for (int i = 0; i < components.length; i++) { - if (i > 0) { - args.append(", "); - } - args.append(castRecordComponent(recordValues + "[" + i + "]", components[i].getType())); + private String genSetFieldCode( + Descriptor descriptor, TypeRef valueTypeRef, String valueExpression) { + ctx.clearExprState(); + Reference value = new Reference("_f_value", valueTypeRef, false); + Expression setField = + setFieldValue( + value, + descriptor, + tryInlineCast( + new Reference(valueExpression, descriptor.getTypeRef(), false), + descriptor.getTypeRef())); + Code.ExprCode setCode = setField.genCode(ctx); + String code = ctx.optimizeMethodCode(setCode.code()); + return code == null ? "" : code; + } + + private String genCompatibleScalarReadExpr( + Descriptor descriptor, String remoteFieldInfo, String localFieldInfo) { + String helper; + Class rawType = descriptor.getRawType(); + if (rawType == boolean.class) { + helper = "readBooleanTarget"; + } else if (rawType == Boolean.class) { + helper = "readBoxedBooleanTarget"; + } else if (rawType == byte.class) { + helper = "readByteTarget"; + } else if (rawType == Byte.class) { + helper = "readBoxedByteTarget"; + } else if (rawType == short.class) { + helper = "readShortTarget"; + } else if (rawType == Short.class) { + helper = "readBoxedShortTarget"; + } else if (rawType == int.class) { + helper = "readIntTarget"; + } else if (rawType == Integer.class) { + helper = "readBoxedIntTarget"; + } else if (rawType == long.class) { + helper = "readLongTarget"; + } else if (rawType == Long.class) { + helper = "readBoxedLongTarget"; + } else if (rawType == float.class) { + helper = "readFloatTarget"; + } else if (rawType == Float.class) { + helper = "readBoxedFloatTarget"; + } else if (rawType == double.class) { + helper = "readDoubleTarget"; + } else if (rawType == Double.class) { + helper = "readBoxedDoubleTarget"; + } else if (rawType == String.class) { + helper = "readStringTarget"; + } else if (rawType == java.math.BigDecimal.class) { + helper = "readDecimalTarget"; + } else if (rawType == org.apache.fory.type.unsigned.UInt8.class) { + helper = "readUInt8Target"; + } else if (rawType == org.apache.fory.type.unsigned.UInt16.class) { + helper = "readUInt16Target"; + } else if (rawType == org.apache.fory.type.unsigned.UInt32.class) { + helper = "readUInt32Target"; + } else if (rawType == org.apache.fory.type.unsigned.UInt64.class) { + helper = "readUInt64Target"; + } else if (rawType == org.apache.fory.type.Float16.class) { + helper = "readFloat16Target"; + } else if (rawType == org.apache.fory.type.BFloat16.class) { + helper = "readBFloat16Target"; + } else { + return null; } - return args.toString(); + return "org.apache.fory.serializer.converter.FieldConverters." + + helper + + "(" + + READ_CONTEXT_NAME + + ", " + + remoteFieldInfo + + ", " + + localFieldInfo + + ")"; } private String castRecordComponent(String value, Class type) { @@ -460,9 +590,35 @@ private String castRecordComponent(String value, Class type) { return "(" + ctx.type(boxedType(type)) + ") " + value; } - private String boxedDefaultValue(Class type) { + private String recordLocalType(Class type) { + return sourcePublicAccessible(type) ? ctx.type(type) : ctx.type(Object.class); + } + + private String recordArgsFieldName(int componentCount) { + if (recordArgsFieldName != null) { + return recordArgsFieldName; + } + String fieldName = "_f_recordArgs"; + ctx.addField( + false, + false, + ctx.type(Object[].class), + fieldName, + new Expression.NewArray(Object.class, Literal.ofInt(componentCount))); + recordArgsFieldName = fieldName; + return fieldName; + } + + private Expression generatedObjectInstantiator() { + if (objectInstantiator == null) { + objectInstantiator = getObjectInstantiator(beanClass); + } + return objectInstantiator; + } + + private String defaultValue(Class type) { if (!type.isPrimitive()) { - return null; + return "null"; } if (type == boolean.class) { return "false"; diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java index f5ec2e263b..3f1dacb561 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java @@ -143,50 +143,34 @@ public Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { .build(); } if (peerArrayTypeId != Types.UNKNOWN - && peerArrayTypeId == arrayTypeId(localFieldType) - && TypeAnnotationUtils.isArrayType(descriptor)) { - return new DescriptorBuilder(descriptor) - .trackingRef(remoteTrackingRef) - .nullable(remoteNullable) - .build(); + && localFieldType != null + && peerArrayTypeId == arrayTypeId(localFieldType)) { + // Dense-array schema ids, not JVM carrier classes, define xlang compatibility. Keep the + // remote carrier so compatible generated arms can perform any required local adaptation. + return builder.build(); } if (localFieldType != null && isListArrayRootPair(fieldType, localFieldType)) { - throw new IllegalArgumentException( - StringUtils.format( - "Unsupported list/array compatible field mismatch for field ${definedClass}.${fieldName}: peer=${peer}, local=${local}", - "definedClass", - definedClass, - "fieldName", - fieldName, - "peer", - fieldType, - "local", - localFieldType)); + throw incompatibleField("unsupported list/array compatible field mismatch", localFieldType); } - if (localFieldType != null && hasNestedListArrayShapeMismatch(fieldType, localFieldType)) { - // List/array bridging is only defined for the matched field itself. If the shape differs - // deeper in a container, keep the remote descriptor for skipping but do not assign it to - // the local field. - TypeRef remoteTypeRef = fieldType.toTypeToken(resolver, null); - return new DescriptorBuilder(descriptor) - .typeName(fieldType.getTypeName(resolver, remoteTypeRef)) - .trackingRef(remoteTrackingRef) - .nullable(remoteNullable) - .typeRef(remoteTypeRef) - .type(remoteTypeRef.getRawType()) - .field(null) - .build(); + if (localFieldType != null && hasNestedFieldSchemaMismatch(fieldType, localFieldType)) { + throw incompatibleField("nested field schema mismatch", localFieldType); + } + if (localFieldType != null && hasIncompatibleRootArrayOrBinary(fieldType, localFieldType)) { + throw incompatibleField("primitive-array/binary field schema mismatch", localFieldType); } + boolean rootArrayOrBinaryBridge = + localFieldType != null && isBinaryUint8ArrayRootPair(fieldType, localFieldType); if (remoteNullable == descriptor.isNullable() && remoteTrackingRef == descriptor.isTrackingRef() - && typeRef.equals(descriptor.getTypeRef())) { + && typeRef.equals(descriptor.getTypeRef()) + && !rootArrayOrBinaryBridge) { if (typeName.equals(descriptor.getTypeName())) { return descriptor; } } Descriptor remoteDescriptor = builder.build(); if (localFieldType != null && isRefTrackedScalarSchemaMismatch(fieldType, localFieldType)) { - return new DescriptorBuilder(remoteDescriptor).field(null).build(); + throw incompatibleField("reference-tracked scalar schema mismatch", localFieldType); } FieldConverter converter = FieldConverters.getConverter(resolver, remoteDescriptor, descriptor); @@ -196,6 +180,9 @@ public Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { .fieldConverter(converter) .build(); } + if (FieldConverters.canConvert(resolver, remoteDescriptor, descriptor)) { + return remoteDescriptor; + } if (FieldTypes.useFieldType(rawType, descriptor)) { return remoteDescriptor; } @@ -209,9 +196,8 @@ public Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { Class remotePrimitive = typeRef.unwrap().getRawType(); boolean bothPrimitives = declaredPrimitive.isPrimitive() && remotePrimitive.isPrimitive(); boolean samePrimitiveType = bothPrimitives && declaredPrimitive.equals(remotePrimitive); - // Set field to null if types are incompatible (not the same primitive type) if (!samePrimitiveType) { - builder.field(null); + throw incompatibleField("field type mismatch", localFieldType); } } } @@ -241,7 +227,13 @@ private boolean isTopLevelListArrayCompatibleReadPair( return false; } FieldTypes.FieldType localFieldType = FieldTypes.buildFieldType(resolver, localDescriptor); - int peerListElementTypeId = listElementTypeId(fieldType); + if (fieldType.nullable() + || localFieldType.nullable() + || fieldType.trackingRef() + || localFieldType.trackingRef()) { + return false; + } + int peerListElementTypeId = untrackedListElementTypeId(fieldType); if (peerListElementTypeId != Types.UNKNOWN) { int localArrayTypeId = arrayTypeId(localFieldType); return localArrayTypeId != Types.UNKNOWN @@ -256,17 +248,11 @@ private boolean isTopLevelListArrayCompatibleReadPair( return false; } - private static boolean hasListArrayShapeMismatch( + private static boolean hasNestedFieldSchemaMismatch( FieldTypes.FieldType peerFieldType, FieldTypes.FieldType localFieldType) { - if (isListArrayRootPair(peerFieldType, localFieldType)) { - return true; - } - if (peerFieldType.getTypeId() != localFieldType.getTypeId()) { - return false; - } if (peerFieldType instanceof FieldTypes.CollectionFieldType && localFieldType instanceof FieldTypes.CollectionFieldType) { - return hasListArrayShapeMismatch( + return !sameNestedFieldSchema( ((FieldTypes.CollectionFieldType) peerFieldType).getElementType(), ((FieldTypes.CollectionFieldType) localFieldType).getElementType()); } @@ -274,43 +260,114 @@ private static boolean hasListArrayShapeMismatch( && localFieldType instanceof FieldTypes.MapFieldType) { FieldTypes.MapFieldType peerMap = (FieldTypes.MapFieldType) peerFieldType; FieldTypes.MapFieldType localMap = (FieldTypes.MapFieldType) localFieldType; - return hasListArrayShapeMismatch(peerMap.getKeyType(), localMap.getKeyType()) - || hasListArrayShapeMismatch(peerMap.getValueType(), localMap.getValueType()); + return !sameNestedFieldSchema(peerMap.getKeyType(), localMap.getKeyType()) + || !sameNestedFieldSchema(peerMap.getValueType(), localMap.getValueType()); } if (peerFieldType instanceof FieldTypes.ArrayFieldType && localFieldType instanceof FieldTypes.ArrayFieldType) { - return hasListArrayShapeMismatch( + return !sameNestedFieldSchema( ((FieldTypes.ArrayFieldType) peerFieldType).getComponentType(), ((FieldTypes.ArrayFieldType) localFieldType).getComponentType()); } return false; } - private static boolean hasNestedListArrayShapeMismatch( + private static boolean sameNestedFieldSchema( FieldTypes.FieldType peerFieldType, FieldTypes.FieldType localFieldType) { - if (peerFieldType.getTypeId() != localFieldType.getTypeId()) { + if (compatibleScalarType(peerFieldType.getTypeId()) + || compatibleScalarType(localFieldType.getTypeId())) { + return peerFieldType.getTypeId() == localFieldType.getTypeId() + && peerFieldType.trackingRef() == localFieldType.trackingRef() + && (!peerFieldType.trackingRef() + || peerFieldType.nullable() == localFieldType.nullable()); + } + // Preserve peer TypeDef type ids during decode. TYPED_UNION/NAMED_UNION remain peer metadata + // identities there; compatible schema comparison owns accepting the union family as one nested + // field domain. + if (Types.isUnionType(peerFieldType.getTypeId()) + && Types.isUnionType(localFieldType.getTypeId())) { + return true; + } + // Compatible reads use remote collection/map element metadata for nested user-defined object + // presence. Rejecting nullable/ref drift here would block valid xlang IDL payloads whose + // runtime element framing still carries the actual null/ref state. + if (isUserDefinedNestedType(peerFieldType) && isUserDefinedNestedType(localFieldType)) { + return sameUnresolvedOrNormalizedNestedTypeId( + peerFieldType.getTypeId(), localFieldType.getTypeId()); + } + if (peerFieldType.trackingRef() != localFieldType.trackingRef()) { return false; } if (peerFieldType instanceof FieldTypes.CollectionFieldType && localFieldType instanceof FieldTypes.CollectionFieldType) { - return hasListArrayShapeMismatch( - ((FieldTypes.CollectionFieldType) peerFieldType).getElementType(), - ((FieldTypes.CollectionFieldType) localFieldType).getElementType()); + return sameContainerType(peerFieldType, localFieldType) + && sameNestedFieldSchema( + ((FieldTypes.CollectionFieldType) peerFieldType).getElementType(), + ((FieldTypes.CollectionFieldType) localFieldType).getElementType()); } if (peerFieldType instanceof FieldTypes.MapFieldType && localFieldType instanceof FieldTypes.MapFieldType) { FieldTypes.MapFieldType peerMap = (FieldTypes.MapFieldType) peerFieldType; FieldTypes.MapFieldType localMap = (FieldTypes.MapFieldType) localFieldType; - return hasListArrayShapeMismatch(peerMap.getKeyType(), localMap.getKeyType()) - || hasListArrayShapeMismatch(peerMap.getValueType(), localMap.getValueType()); + return sameContainerType(peerFieldType, localFieldType) + && sameNestedFieldSchema(peerMap.getKeyType(), localMap.getKeyType()) + && sameNestedFieldSchema(peerMap.getValueType(), localMap.getValueType()); } if (peerFieldType instanceof FieldTypes.ArrayFieldType && localFieldType instanceof FieldTypes.ArrayFieldType) { - return hasListArrayShapeMismatch( - ((FieldTypes.ArrayFieldType) peerFieldType).getComponentType(), - ((FieldTypes.ArrayFieldType) localFieldType).getComponentType()); + FieldTypes.ArrayFieldType peerArray = (FieldTypes.ArrayFieldType) peerFieldType; + FieldTypes.ArrayFieldType localArray = (FieldTypes.ArrayFieldType) localFieldType; + return sameContainerType(peerFieldType, localFieldType) + && peerArray.getDimensions() == localArray.getDimensions() + && sameNestedFieldSchema(peerArray.getComponentType(), localArray.getComponentType()); } - return false; + if (peerFieldType instanceof FieldTypes.RegisteredFieldType + && localFieldType instanceof FieldTypes.RegisteredFieldType) { + return normalizedNestedTypeId(peerFieldType.getTypeId()) + == normalizedNestedTypeId(localFieldType.getTypeId()); + } + if (peerFieldType instanceof FieldTypes.EnumFieldType + && localFieldType instanceof FieldTypes.EnumFieldType) { + return sameUnresolvedOrNormalizedNestedTypeId( + peerFieldType.getTypeId(), localFieldType.getTypeId()); + } + return peerFieldType instanceof FieldTypes.UnionFieldType + && localFieldType instanceof FieldTypes.UnionFieldType; + } + + private static boolean sameContainerType( + FieldTypes.FieldType peerFieldType, FieldTypes.FieldType localFieldType) { + return sameUnresolvedOrNormalizedNestedTypeId( + peerFieldType.getTypeId(), localFieldType.getTypeId()); + } + + private static boolean sameUnresolvedOrNormalizedNestedTypeId(int peerTypeId, int localTypeId) { + if (peerTypeId <= Types.UNKNOWN || localTypeId <= Types.UNKNOWN) { + return true; + } + return normalizedNestedTypeId(peerTypeId) == normalizedNestedTypeId(localTypeId); + } + + private static int normalizedNestedTypeId(int typeId) { + if (typeId == Types.UNKNOWN || Types.isStructType(typeId)) { + return Types.STRUCT; + } + if (Types.isEnumType(typeId)) { + return Types.ENUM; + } + if (Types.isExtType(typeId)) { + return Types.EXT; + } + if (Types.isUnionType(typeId)) { + return Types.UNION; + } + return typeId; + } + + private static boolean isUserDefinedNestedType(FieldTypes.FieldType fieldType) { + int typeId = fieldType.getTypeId(); + return (fieldType instanceof FieldTypes.RegisteredFieldType && Types.isUserDefinedType(typeId)) + || fieldType instanceof FieldTypes.ObjectFieldType; } private static boolean isRefTrackedScalarSchemaMismatch( @@ -333,6 +390,49 @@ private static boolean compatibleScalarType(int typeId) { || typeId == Types.DECIMAL; } + private static boolean hasIncompatibleRootArrayOrBinary( + FieldTypes.FieldType remoteFieldType, FieldTypes.FieldType localFieldType) { + int remoteTypeId = rootArrayOrBinaryTypeId(remoteFieldType); + int localTypeId = rootArrayOrBinaryTypeId(localFieldType); + if (remoteTypeId == Types.UNKNOWN && localTypeId == Types.UNKNOWN) { + return false; + } + return remoteTypeId != localTypeId && !isBinaryUint8ArrayTypePair(remoteTypeId, localTypeId); + } + + private static boolean isBinaryUint8ArrayRootPair( + FieldTypes.FieldType remoteFieldType, FieldTypes.FieldType localFieldType) { + return isBinaryUint8ArrayTypePair( + rootArrayOrBinaryTypeId(remoteFieldType), rootArrayOrBinaryTypeId(localFieldType)); + } + + private static boolean isBinaryUint8ArrayTypePair(int remoteTypeId, int localTypeId) { + return (remoteTypeId == Types.BINARY && localTypeId == Types.UINT8_ARRAY) + || (remoteTypeId == Types.UINT8_ARRAY && localTypeId == Types.BINARY); + } + + private static int rootArrayOrBinaryTypeId(FieldTypes.FieldType fieldType) { + int typeId = fieldType.getTypeId(); + return typeId == Types.BINARY || Types.isPrimitiveArray(typeId) ? typeId : Types.UNKNOWN; + } + + private IllegalArgumentException incompatibleField( + String reason, FieldTypes.FieldType localFieldType) { + return new IllegalArgumentException( + StringUtils.format( + "Incompatible compatible field schema for field ${definedClass}.${fieldName}: ${reason}; peer=${peer}, local=${local}", + "definedClass", + definedClass, + "fieldName", + fieldName, + "reason", + reason, + "peer", + fieldType, + "local", + localFieldType)); + } + private static boolean isListArrayRootPair( FieldTypes.FieldType peerFieldType, FieldTypes.FieldType localFieldType) { boolean peerList = isListField(peerFieldType); @@ -370,6 +470,10 @@ private TypeRef primitiveListCarrierType() { } private static int listElementTypeId(FieldTypes.FieldType fieldType) { + return listElementTypeId(fieldType, false); + } + + private static int listElementTypeId(FieldTypes.FieldType fieldType, boolean requireUntracked) { if (!(fieldType instanceof FieldTypes.CollectionFieldType) || fieldType.getTypeId() != Types.LIST) { return Types.UNKNOWN; @@ -377,11 +481,20 @@ private static int listElementTypeId(FieldTypes.FieldType fieldType) { FieldTypes.FieldType elementType = ((FieldTypes.CollectionFieldType) fieldType).getElementType(); if (elementType instanceof FieldTypes.RegisteredFieldType) { + // Nullable element schema is allowed for list -> array compatibility; + // actual null payload elements are rejected by the dense-array reader. + if (requireUntracked && elementType.trackingRef()) { + return Types.UNKNOWN; + } return ((FieldTypes.RegisteredFieldType) elementType).getTypeId(); } return Types.UNKNOWN; } + private static int untrackedListElementTypeId(FieldTypes.FieldType fieldType) { + return listElementTypeId(fieldType, true); + } + private static int arrayTypeId(FieldTypes.FieldType fieldType) { if (fieldType instanceof FieldTypes.RegisteredFieldType) { int typeId = ((FieldTypes.RegisteredFieldType) fieldType).getTypeId(); diff --git a/java/fory-core/src/main/java/org/apache/fory/platform/internal/DefineClass.java b/java/fory-core/src/main/java/org/apache/fory/platform/internal/DefineClass.java index 4ea20ae703..adfd99a645 100644 --- a/java/fory-core/src/main/java/org/apache/fory/platform/internal/DefineClass.java +++ b/java/fory-core/src/main/java/org/apache/fory/platform/internal/DefineClass.java @@ -89,7 +89,9 @@ public static Class defineHiddenNestmate(Class neighbor, byte[] bytecodes) throw hiddenClassUnsupported(null); } try { - Lookup lookup = _Lookup.privateLookupIn(neighbor, MethodHandles.lookup()); + // Lookup#defineHiddenClass with NESTMATE requires full-privilege access. A private lookup + // can resolve members but is not sufficient on JDK25+ for classes from user loaders. + Lookup lookup = _JDKAccess._trustedLookup(neighbor); return hiddenClassDefiner().define(lookup, bytecodes); } catch (Throwable e) { throw hiddenClassFailure(neighbor, e); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java b/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java index ca516748d3..d24464b1ff 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/StaticGeneratedSerializerRegistry.java @@ -19,14 +19,16 @@ package org.apache.fory.resolver; -import java.lang.reflect.Constructor; +import java.lang.invoke.MethodHandle; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import org.apache.fory.annotation.Internal; import org.apache.fory.exception.ForyException; import org.apache.fory.meta.TypeDef; +import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.serializer.StaticGeneratedStructSerializer; import org.apache.fory.type.Descriptor; +import org.apache.fory.util.ExceptionUtils; /** Shared registry of build-time generated static serializer mappings. */ @Internal @@ -41,15 +43,15 @@ public enum Mode { private static final class Entry { private final Class serializerClass; - private final Constructor descriptorConstructor; - private final Constructor runtimeConstructor; - private final Constructor compatibleConstructor; + private final MethodHandle descriptorConstructor; + private final MethodHandle runtimeConstructor; + private final MethodHandle compatibleConstructor; private Entry( Class serializerClass, - Constructor descriptorConstructor, - Constructor runtimeConstructor, - Constructor compatibleConstructor) { + MethodHandle descriptorConstructor, + MethodHandle runtimeConstructor, + MethodHandle compatibleConstructor) { this.serializerClass = serializerClass; this.descriptorConstructor = descriptorConstructor; this.runtimeConstructor = runtimeConstructor; @@ -64,26 +66,20 @@ StaticGeneratedStructSerializer newSerializer( TypeResolver resolver, Class type, TypeDef typeDef) { try { return typeDef == null - ? runtimeConstructor.newInstance(resolver, type) - : compatibleConstructor.newInstance(resolver, type, typeDef); - } catch (ReflectiveOperationException e) { - throw new ForyException( - "Failed to create static generated serializer " - + serializerClass.getName() - + " for " - + type.getName(), - e); + ? (StaticGeneratedStructSerializer) runtimeConstructor.invoke(resolver, type) + : (StaticGeneratedStructSerializer) + compatibleConstructor.invoke(resolver, type, typeDef); + } catch (Throwable e) { + throw ExceptionUtils.throwException(e); } } List getGeneratedDescriptors() { try { - return descriptorConstructor.newInstance().getGeneratedDescriptors(); - } catch (ReflectiveOperationException e) { - throw new ForyException( - "Failed to create descriptor-only static generated serializer " - + serializerClass.getName(), - e); + return ((StaticGeneratedStructSerializer) descriptorConstructor.invoke()) + .getGeneratedDescriptors(); + } catch (Throwable e) { + throw ExceptionUtils.throwException(e); } } } @@ -159,29 +155,26 @@ private Entry loadEntry(Class targetType, Mode mode) { } private Entry newEntry(Class serializerClass) { - Constructor descriptorConstructor = - constructor(serializerClass); - Constructor runtimeConstructor = - constructor(serializerClass, TypeResolver.class, Class.class); - Constructor compatibleConstructor = + MethodHandle descriptorConstructor = constructor(serializerClass); + MethodHandle runtimeConstructor = constructor(serializerClass, TypeResolver.class, Class.class); + MethodHandle compatibleConstructor = constructor(serializerClass, TypeResolver.class, Class.class, TypeDef.class); return new Entry( serializerClass, descriptorConstructor, runtimeConstructor, compatibleConstructor); } - private Constructor constructor( + private MethodHandle constructor( Class serializerClass, Class... parameterTypes) { try { - Constructor constructor = - serializerClass.getDeclaredConstructor(parameterTypes); - constructor.setAccessible(true); - return constructor; - } catch (NoSuchMethodException e) { + // Static generated serializers are emitted into application packages, which may be in named + // modules not opened to Fory. Trusted constructor handles keep JPMS access in _JDKAccess. + return ReflectionUtils.getCtrHandle(serializerClass, parameterTypes); + } catch (RuntimeException e) { throw new ForyException( "Generated serializer " + serializerClass.getName() - + " is missing required constructor " + + " is missing or cannot access required constructor " + constructorSignature(parameterTypes), e); } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 2dab906fbb..80c9f59982 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -21,7 +21,7 @@ import static org.apache.fory.type.Types.INVALID_USER_TYPE_ID; -import java.lang.reflect.Constructor; +import java.lang.invoke.MethodHandle; import java.lang.reflect.Field; import java.lang.reflect.Member; import java.lang.reflect.Type; @@ -100,6 +100,7 @@ import org.apache.fory.type.ScalaTypes; import org.apache.fory.type.TypeUtils; import org.apache.fory.type.Types; +import org.apache.fory.util.ExceptionUtils; import org.apache.fory.util.Preconditions; import org.apache.fory.util.function.Functions; @@ -1137,7 +1138,9 @@ private TypeInfo getMetaSharedTypeInfo(TypeDef typeDef, Class clz) { jitContext.registerSerializerJITCallback( () -> CompatibleSerializer.class, () -> CodecUtils.loadOrGenCompatibleCodecClass(this, cls, typeDef), - c -> typeInfo.setSerializer(this, Serializers.newSerializer(this, cls, c))); + c -> + typeInfo.setSerializer( + this, newGeneratedCompatibleSerializer(cls, c, typeDef))); } else if (sc == null) { sc = CompatibleSerializer.class; } @@ -1158,12 +1161,28 @@ private TypeInfo getMetaSharedTypeInfo(TypeDef typeDef, Class clz) { typeInfo.setSerializer(this, newStaticGeneratedStructSerializer(sc, cls, typeDef)); } else if (sc == CompatibleSerializer.class) { typeInfo.setSerializer(this, new CompatibleSerializer(this, cls, typeDef)); + } else if (GeneratedCompatibleSerializer.class.isAssignableFrom(sc)) { + typeInfo.setSerializer(this, newGeneratedCompatibleSerializer(cls, sc, typeDef)); } else { typeInfo.setSerializer(this, Serializers.newSerializer(this, cls, sc)); } return typeInfo; } + private Serializer newGeneratedCompatibleSerializer( + Class cls, Class serializerClass, TypeDef typeDef) { + try { + // Generated serializers can live in non-open application modules. Use Fory's trusted + // lookup owner instead of reflective setAccessible, which JPMS would reject there. + MethodHandle constructor = + ReflectionUtils.getCtrHandle( + serializerClass, TypeResolver.class, Class.class, TypeDef.class); + return (Serializer) constructor.invoke(this, cls, typeDef); + } catch (Throwable e) { + throw ExceptionUtils.throwException(e); + } + } + private Class loadGraalvmCompatibleDeserializerClass( Class cls, TypeDef typeDef) { if (typeDef.getId() == TypeDef.buildTypeDef(this, cls).getId()) { @@ -1746,17 +1765,14 @@ protected final StaticGeneratedStructSerializer newStaticGeneratedStructSeria private StaticGeneratedStructSerializer newRuntimeStaticCompatibleSerializer( Class serializerClass, Class cls, TypeDef typeDef) { try { - Constructor constructor = - serializerClass.getDeclaredConstructor(TypeResolver.class, Class.class, TypeDef.class); - constructor.setAccessible(true); - return (StaticGeneratedStructSerializer) constructor.newInstance(this, cls, typeDef); - } catch (ReflectiveOperationException e) { - throw new ForyException( - "Failed to create runtime static compatible serializer " - + serializerClass.getName() - + " for " - + cls.getName(), - e); + // Generated serializers can live in non-open application modules. Use Fory's trusted + // lookup owner instead of reflective setAccessible, which JPMS would reject there. + MethodHandle constructor = + ReflectionUtils.getCtrHandle( + serializerClass, TypeResolver.class, Class.class, TypeDef.class); + return (StaticGeneratedStructSerializer) constructor.invoke(this, cls, typeDef); + } catch (Throwable e) { + throw ExceptionUtils.throwException(e); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java index 42b1436e90..7f4a8856e9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -63,6 +63,7 @@ final class CompatibleCollectionArrayReader { static final int READ_LIST_TO_ARRAY = 1; static final int READ_ARRAY_TO_LIST = 2; static final int READ_LIST_TO_LIST = 3; + static final int READ_ARRAY_TO_ARRAY = 4; static final class ReadAction { final int mode; @@ -86,23 +87,23 @@ static ReadAction readAction(TypeResolver resolver, Descriptor descriptor) { return null; } FieldTypes.FieldType localFieldType = FieldTypes.buildFieldType(resolver, field); - int peerListElementTypeId = listElementTypeId(descriptor); + if (localFieldType.nullable() || localFieldType.trackingRef()) { + return null; + } + int peerListElementTypeId = untrackedListElementTypeId(descriptor); if (peerListElementTypeId != Types.UNKNOWN) { - // Element nullable/ref flags in TypeDef describe what the peer schema can encode, not what - // this payload actually contains. Dense-array compatibility is decided here by element type; - // readListPayloadAsPrimitiveArray rejects payloads that carry null/ref element markers. int localArrayTypeId = arrayTypeId(localFieldType); if (localArrayTypeId != Types.UNKNOWN && localArrayTypeId == denseArrayTypeId(peerListElementTypeId)) { return new ReadAction( READ_LIST_TO_ARRAY, localArrayTypeId, peerListElementTypeId, field.getType()); } - int nonNullablePeerListElementTypeId = nonNullableListElementTypeId(descriptor); - int localListElementTypeId = nonNullableListElementTypeId(localFieldType); + int untrackedPeerListElementTypeId = untrackedListElementTypeId(descriptor); + int localListElementTypeId = untrackedListElementTypeId(localFieldType); int peerArrayTypeId = denseArrayTypeId(peerListElementTypeId); - // List-to-array and list-to-list materialize through a dense primitive array, so they cannot - // preserve nullable or ref-tracked peer elements. - if (nonNullablePeerListElementTypeId != Types.UNKNOWN + // Actual null or ref-tracked payload elements are rejected by + // readListPayloadAsPrimitiveArray. + if (untrackedPeerListElementTypeId != Types.UNKNOWN && localListElementTypeId != Types.UNKNOWN && peerArrayTypeId != Types.UNKNOWN && peerArrayTypeId == denseArrayTypeId(localListElementTypeId) @@ -131,27 +132,32 @@ static ReadAction readAction( return null; } FieldTypes.FieldType remoteFieldType = remoteFieldInfo.getFieldType(); + FieldTypes.FieldType localFieldType = FieldTypes.buildFieldType(resolver, localDescriptor); + if (remoteFieldType.trackingRef() || localFieldType.trackingRef()) { + return null; + } TypeRef localType = localDescriptor.getTypeRef(); - int peerListElementTypeId = listElementTypeId(remoteFieldType); + boolean nullableArrayField = remoteFieldType.nullable() || localFieldType.nullable(); + int peerListElementTypeId = untrackedListElementTypeId(remoteFieldType); if (peerListElementTypeId != Types.UNKNOWN) { + if (nullableArrayField) { + return null; + } int localArrayTypeId = arrayTypeId(localDescriptor); if (localArrayTypeId != Types.UNKNOWN && localArrayTypeId == denseArrayTypeId(peerListElementTypeId)) { - // Remote element nullable/ref flags describe what the schema can encode. Actual payload - // null/ref markers are validated while reading so nullable list payloads without nulls - // remain compatible with local array fields. return new ReadAction( READ_LIST_TO_ARRAY, localArrayTypeId, peerListElementTypeId, localDescriptor.getRawType()); } - int nonNullablePeerListElementTypeId = nonNullableListElementTypeId(remoteFieldType); - int localListElementTypeId = nonNullableListElementTypeId(localType); + int untrackedPeerListElementTypeId = untrackedListElementTypeId(remoteFieldType); + int localListElementTypeId = listElementTypeId(localType); int peerArrayTypeId = denseArrayTypeId(peerListElementTypeId); - // List-to-array and list-to-list materialize through a dense primitive array, so they cannot - // preserve nullable or ref-tracked peer elements. - if (nonNullablePeerListElementTypeId != Types.UNKNOWN + // Actual null or ref-tracked payload elements are rejected by + // readListPayloadAsPrimitiveArray. + if (untrackedPeerListElementTypeId != Types.UNKNOWN && localListElementTypeId != Types.UNKNOWN && peerArrayTypeId != Types.UNKNOWN && peerArrayTypeId == denseArrayTypeId(localListElementTypeId) @@ -166,6 +172,17 @@ && canMaterializeListTarget(localDescriptor.getRawType(), peerArrayTypeId)) { } int peerArrayTypeId = arrayTypeId(remoteFieldType); if (peerArrayTypeId != Types.UNKNOWN) { + int localArrayTypeId = arrayTypeId(localFieldType); + if (localArrayTypeId == peerArrayTypeId && !remoteFieldType.equals(localFieldType)) { + return new ReadAction( + READ_ARRAY_TO_ARRAY, + peerArrayTypeId, + Types.UNKNOWN, + denseArrayTargetType(localDescriptor.getRawType(), peerArrayTypeId)); + } + if (nullableArrayField) { + return null; + } int localListElementTypeId = listElementTypeId(localType); if (localListElementTypeId != Types.UNKNOWN && peerArrayTypeId == denseArrayTypeId(localListElementTypeId) @@ -330,16 +347,16 @@ private static Object readNotNull( return materializeTarget(array, arrayTypeId, targetType); } if (readMode == READ_LIST_TO_LIST) { - Object array = readListPayloadAsPrimitiveArray(readContext, arrayTypeId, elementTypeId); - if (array == null) { - return null; - } - return materializeTarget(array, arrayTypeId, targetType); + return readListPayloadAsListTarget(readContext, arrayTypeId, elementTypeId, targetType); } if (readMode == READ_ARRAY_TO_LIST) { Object array = readDenseArrayPayload(readContext, arrayTypeId); return materializeTarget(array, arrayTypeId, targetType); } + if (readMode == READ_ARRAY_TO_ARRAY) { + Object array = readDenseArrayPayload(readContext, arrayTypeId); + return materializeTarget(array, arrayTypeId, targetType); + } throw new IllegalStateException("Unexpected compatible read mode " + readMode); } @@ -347,7 +364,7 @@ private static int listElementTypeId(FieldTypes.FieldType fieldType) { return listElementTypeId(fieldType, false); } - private static int listElementTypeId(FieldTypes.FieldType fieldType, boolean requireNonNullable) { + private static int listElementTypeId(FieldTypes.FieldType fieldType, boolean requireUntracked) { if (!(fieldType instanceof FieldTypes.CollectionFieldType) || fieldType.getTypeId() != Types.LIST) { return Types.UNKNOWN; @@ -355,7 +372,9 @@ private static int listElementTypeId(FieldTypes.FieldType fieldType, boolean req FieldTypes.FieldType elementType = ((FieldTypes.CollectionFieldType) fieldType).getElementType(); if (elementType instanceof FieldTypes.RegisteredFieldType) { - if (requireNonNullable && (elementType.nullable() || elementType.trackingRef())) { + // Nullable element schema is allowed for list -> array compatibility; + // actual null payload elements are rejected by the dense-array reader. + if (requireUntracked && elementType.trackingRef()) { return Types.UNKNOWN; } return ((FieldTypes.RegisteredFieldType) elementType).getTypeId(); @@ -367,7 +386,7 @@ private static int listElementTypeId(Descriptor descriptor) { return listElementTypeId(descriptor, false); } - private static int listElementTypeId(Descriptor descriptor, boolean requireNonNullable) { + private static int listElementTypeId(Descriptor descriptor, boolean requireUntracked) { Class rawType = descriptor.getRawType(); if (TypeUtils.isPrimitiveListClass(rawType) && TypeAnnotationUtils.isArrayType(descriptor)) { return Types.UNKNOWN; @@ -383,15 +402,16 @@ private static int listElementTypeId(Descriptor descriptor, boolean requireNonNu // wire shape here; otherwise array->list reads are misclassified as list->list reads. return Types.UNKNOWN; } - if (Types.isPrimitiveType(typeId) - && (!requireNonNullable || (!extMeta.nullable() && !extMeta.trackingRef()))) { + if (Types.isPrimitiveType(typeId) && (!requireUntracked || !extMeta.trackingRef())) { + // Nullable element metadata is not a schema-pair rejection. The + // dense-array read path fails only when the payload contains nulls. return typeId; } } TypeRef elementTypeRef = TypeAnnotationUtils.getPrimitiveListElementTypeRef(descriptor); if (elementTypeRef != null) { TypeExtMeta elementExtMeta = elementTypeRef.getTypeExtMeta(); - if (isPrimitiveElement(elementExtMeta, requireNonNullable)) { + if (isPrimitiveElement(elementExtMeta, requireUntracked)) { return elementExtMeta.typeId(); } } @@ -399,7 +419,7 @@ private static int listElementTypeId(Descriptor descriptor, boolean requireNonNu } if (extMeta != null && extMeta.typeId() == Types.LIST) { TypeExtMeta elementExtMeta = TypeUtils.getElementType(typeRef).getTypeExtMeta(); - return isPrimitiveElement(elementExtMeta, requireNonNullable) + return isPrimitiveElement(elementExtMeta, requireUntracked) ? elementExtMeta.typeId() : Types.UNKNOWN; } @@ -410,11 +430,11 @@ private static int listElementTypeId(TypeRef typeRef) { return listElementTypeId(typeRef, false); } - private static int listElementTypeId(TypeRef typeRef, boolean requireNonNullable) { + private static int listElementTypeId(TypeRef typeRef, boolean requireUntracked) { TypeExtMeta extMeta = typeRef.getTypeExtMeta(); if (extMeta != null && extMeta.typeId() == Types.LIST) { TypeExtMeta elementExtMeta = TypeUtils.getElementType(typeRef).getTypeExtMeta(); - return isPrimitiveElement(elementExtMeta, requireNonNullable) + return isPrimitiveElement(elementExtMeta, requireUntracked) ? elementExtMeta.typeId() : Types.UNKNOWN; } @@ -427,8 +447,9 @@ private static int listElementTypeId(TypeRef typeRef, boolean requireNonNulla // wire shape here; otherwise array->list reads are misclassified as list->list reads. return Types.UNKNOWN; } - if (Types.isPrimitiveType(typeId) - && (!requireNonNullable || (!extMeta.nullable() && !extMeta.trackingRef()))) { + if (Types.isPrimitiveType(typeId) && (!requireUntracked || !extMeta.trackingRef())) { + // Nullable element metadata is not a schema-pair rejection. The + // dense-array read path fails only when the payload contains nulls. return typeId; } } @@ -436,30 +457,30 @@ private static int listElementTypeId(TypeRef typeRef, boolean requireNonNulla } if (TypeUtils.isCollection(typeRef.getRawType())) { TypeExtMeta elementExtMeta = TypeUtils.getElementType(typeRef).getTypeExtMeta(); - return isPrimitiveElement(elementExtMeta, requireNonNullable) + return isPrimitiveElement(elementExtMeta, requireUntracked) ? elementExtMeta.typeId() : Types.UNKNOWN; } return Types.UNKNOWN; } - private static int nonNullableListElementTypeId(FieldTypes.FieldType fieldType) { + private static int untrackedListElementTypeId(FieldTypes.FieldType fieldType) { return listElementTypeId(fieldType, true); } - private static int nonNullableListElementTypeId(Descriptor descriptor) { + private static int untrackedListElementTypeId(Descriptor descriptor) { return listElementTypeId(descriptor, true); } - private static int nonNullableListElementTypeId(TypeRef typeRef) { + private static int untrackedListElementTypeId(TypeRef typeRef) { return listElementTypeId(typeRef, true); } - private static boolean isPrimitiveElement( - TypeExtMeta elementExtMeta, boolean requireNonNullable) { + private static boolean isPrimitiveElement(TypeExtMeta elementExtMeta, boolean requireUntracked) { + // Nullable element metadata is allowed; actual null payload elements fail while reading. return elementExtMeta != null && Types.isPrimitiveType(elementExtMeta.typeId()) - && (!requireNonNullable || (!elementExtMeta.nullable() && !elementExtMeta.trackingRef())); + && (!requireUntracked || !elementExtMeta.trackingRef()); } private static int arrayTypeId(Descriptor descriptor) { @@ -572,9 +593,9 @@ private static Object readListPayloadAsPrimitiveArray( boolean sameType = (flags & CollectionFlags.IS_SAME_TYPE) == CollectionFlags.IS_SAME_TYPE; boolean declared = (flags & CollectionFlags.IS_DECL_ELEMENT_TYPE) == CollectionFlags.IS_DECL_ELEMENT_TYPE; - if (hasNull || trackingRef) { + if (trackingRef) { throw new DeserializationException( - "Cannot read nullable or ref-tracked peer list payload into local array field"); + "Cannot read ref-tracked peer list payload into local array field"); } if (!sameType) { throw new DeserializationException( @@ -590,8 +611,58 @@ private static Object readListPayloadAsPrimitiveArray( + elementTypeId); } } + return readListPrimitiveElements(buffer, numElements, arrayTypeId, elementTypeId, hasNull); } - return readListPrimitiveElements(buffer, numElements, arrayTypeId, elementTypeId); + return readListPrimitiveElements(buffer, numElements, arrayTypeId, elementTypeId, false); + } + + private static Object readListPayloadAsListTarget( + ReadContext readContext, int arrayTypeId, int elementTypeId, Class targetType) { + MemoryBuffer buffer = readContext.getBuffer(); + int numElements = buffer.readVarUInt32Small7(); + validateElementCount(readContext.getConfig(), numElements); + validateElementStorageSize(readContext.getConfig(), numElements, elementSize(arrayTypeId)); + if (numElements == 0) { + Object array = readListPrimitiveElements(buffer, 0, arrayTypeId, elementTypeId, false); + return materializeTarget(array, arrayTypeId, targetType); + } + int flags = buffer.readByte(); + boolean hasNull = (flags & CollectionFlags.HAS_NULL) == CollectionFlags.HAS_NULL; + boolean trackingRef = (flags & CollectionFlags.TRACKING_REF) == CollectionFlags.TRACKING_REF; + boolean sameType = (flags & CollectionFlags.IS_SAME_TYPE) == CollectionFlags.IS_SAME_TYPE; + boolean declared = + (flags & CollectionFlags.IS_DECL_ELEMENT_TYPE) == CollectionFlags.IS_DECL_ELEMENT_TYPE; + if (trackingRef) { + throw new DeserializationException( + "Cannot read ref-tracked peer list payload into local list field"); + } + if (!sameType) { + throw new DeserializationException( + "Cannot read peer list payload into local list field"); + } + if (!declared) { + TypeInfo payloadElementTypeInfo = readContext.getTypeResolver().readTypeInfo(readContext); + if (payloadElementTypeInfo.getTypeId() != elementTypeId) { + throw new DeserializationException( + "Cannot read peer list element type id " + + payloadElementTypeInfo.getTypeId() + + " as local element type id " + + elementTypeId); + } + } + if (hasNull) { + // Nullable LIST element metadata is not a schema-pair rejection. Only boxed list targets can + // preserve actual null elements; dense primitive array/list targets fail while reading the + // nullable payload because they cannot represent null elements. + if (!targetType.isAssignableFrom(ArrayList.class)) { + throw new DeserializationException( + "Cannot read null peer list element into local list field"); + } + return readNullableListBoxedElements(buffer, numElements, arrayTypeId, elementTypeId); + } + Object array = + readListPrimitiveElements(buffer, numElements, arrayTypeId, elementTypeId, false); + return materializeTarget(array, arrayTypeId, targetType); } private static Object readDenseArrayPayload(ReadContext readContext, int arrayTypeId) { @@ -689,12 +760,15 @@ private static Object readPrimitiveElements( } private static Object readListPrimitiveElements( - MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId) { + MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId, boolean hasNull) { switch (elementTypeId) { case Types.BOOL: { boolean[] values = new boolean[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readBoolean(); } return values; @@ -703,7 +777,14 @@ private static Object readListPrimitiveElements( case Types.UINT8: { byte[] values = new byte[numElements]; - buffer.readBytes(values); + if (hasNull) { + for (int i = 0; i < numElements; i++) { + readNonNullListElement(buffer); + values[i] = buffer.readByte(); + } + } else { + buffer.readBytes(values); + } return values; } case Types.INT16: @@ -713,6 +794,9 @@ private static Object readListPrimitiveElements( { short[] values = new short[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readInt16(); } return values; @@ -722,6 +806,9 @@ private static Object readListPrimitiveElements( { int[] values = new int[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readInt32(); } return values; @@ -730,6 +817,9 @@ private static Object readListPrimitiveElements( { int[] values = new int[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readVarInt32(); } return values; @@ -738,6 +828,9 @@ private static Object readListPrimitiveElements( { int[] values = new int[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readVarUInt32(); } return values; @@ -747,6 +840,9 @@ private static Object readListPrimitiveElements( { long[] values = new long[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readInt64(); } return values; @@ -755,6 +851,9 @@ private static Object readListPrimitiveElements( { long[] values = new long[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readVarInt64(); } return values; @@ -763,6 +862,9 @@ private static Object readListPrimitiveElements( { long[] values = new long[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readTaggedInt64(); } return values; @@ -771,6 +873,9 @@ private static Object readListPrimitiveElements( { long[] values = new long[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readVarUInt64(); } return values; @@ -779,6 +884,9 @@ private static Object readListPrimitiveElements( { long[] values = new long[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readTaggedUInt64(); } return values; @@ -787,6 +895,9 @@ private static Object readListPrimitiveElements( { float[] values = new float[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readFloat32(); } return values; @@ -795,6 +906,9 @@ private static Object readListPrimitiveElements( { double[] values = new double[numElements]; for (int i = 0; i < numElements; i++) { + if (hasNull) { + readNonNullListElement(buffer); + } values[i] = buffer.readFloat64(); } return values; @@ -808,6 +922,85 @@ private static Object readListPrimitiveElements( } } + private static void readNonNullListElement(MemoryBuffer buffer) { + byte headFlag = buffer.readByte(); + if (headFlag == Fory.NULL_FLAG) { + throw new DeserializationException( + "Cannot read null peer list element into local array field"); + } + if (headFlag != Fory.NOT_NULL_VALUE_FLAG) { + throw new DeserializationException( + "Unexpected nullable peer list element flag " + headFlag); + } + } + + private static List readNullableListBoxedElements( + MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId) { + ArrayList values = new ArrayList<>(numElements); + for (int i = 0; i < numElements; i++) { + byte headFlag = buffer.readByte(); + if (headFlag == Fory.NULL_FLAG) { + values.add(null); + continue; + } + if (headFlag != Fory.NOT_NULL_VALUE_FLAG) { + throw new DeserializationException( + "Unexpected nullable peer list element flag " + headFlag); + } + values.add(readBoxedListElement(buffer, arrayTypeId, elementTypeId)); + } + return values; + } + + private static Object readBoxedListElement( + MemoryBuffer buffer, int arrayTypeId, int elementTypeId) { + switch (elementTypeId) { + case Types.BOOL: + return buffer.readBoolean(); + case Types.INT8: + return buffer.readByte(); + case Types.UINT8: + return buffer.readByte() & 0xFF; + case Types.INT16: + return buffer.readInt16(); + case Types.UINT16: + return buffer.readInt16() & 0xFFFF; + case Types.FLOAT16: + return Float16.fromBits(buffer.readInt16()); + case Types.BFLOAT16: + return BFloat16.fromBits(buffer.readInt16()); + case Types.INT32: + return buffer.readInt32(); + case Types.UINT32: + return Integer.toUnsignedLong(buffer.readInt32()); + case Types.VARINT32: + return buffer.readVarInt32(); + case Types.VAR_UINT32: + return Integer.toUnsignedLong(buffer.readVarUInt32()); + case Types.INT64: + case Types.UINT64: + return buffer.readInt64(); + case Types.VARINT64: + return buffer.readVarInt64(); + case Types.TAGGED_INT64: + return buffer.readTaggedInt64(); + case Types.VAR_UINT64: + return buffer.readVarUInt64(); + case Types.TAGGED_UINT64: + return buffer.readTaggedUInt64(); + case Types.FLOAT32: + return buffer.readFloat32(); + case Types.FLOAT64: + return buffer.readFloat64(); + default: + throw new DeserializationException( + "Unsupported peer list element type id " + + elementTypeId + + " for local array type id " + + arrayTypeId); + } + } + private static Object materializeTarget(Object array, int arrayTypeId, Class targetType) { if (targetType.isArray()) { return array; @@ -828,6 +1021,43 @@ private static Object materializeTarget(Object array, int arrayTypeId, Class throw new DeserializationException("Unsupported compatible list/array target " + targetType); } + private static Class denseArrayTargetType(Class targetType, int arrayTypeId) { + if (targetType.isArray() + || targetType == Float16Array.class + || targetType == BFloat16Array.class + || canMaterializePrimitiveListTarget(targetType, arrayTypeId)) { + return targetType; + } + return primitiveArrayClass(arrayTypeId); + } + + private static Class primitiveArrayClass(int arrayTypeId) { + switch (arrayTypeId) { + case Types.BOOL_ARRAY: + return boolean[].class; + case Types.INT8_ARRAY: + case Types.UINT8_ARRAY: + return byte[].class; + case Types.INT16_ARRAY: + case Types.UINT16_ARRAY: + case Types.FLOAT16_ARRAY: + case Types.BFLOAT16_ARRAY: + return short[].class; + case Types.INT32_ARRAY: + case Types.UINT32_ARRAY: + return int[].class; + case Types.INT64_ARRAY: + case Types.UINT64_ARRAY: + return long[].class; + case Types.FLOAT32_ARRAY: + return float[].class; + case Types.FLOAT64_ARRAY: + return double[].class; + default: + throw new IllegalArgumentException("Unsupported dense array type id " + arrayTypeId); + } + } + private static Object materializePrimitiveList( Object array, int arrayTypeId, Class targetType) { switch (arrayTypeId) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java index 78477f2480..81b6927434 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StaticGeneratedStructSerializer.java @@ -19,7 +19,7 @@ package org.apache.fory.serializer; -import java.lang.reflect.Constructor; +import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -35,8 +35,10 @@ import org.apache.fory.exception.ForyException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.meta.FieldInfo; +import org.apache.fory.meta.FieldTypes; import org.apache.fory.meta.TypeDef; import org.apache.fory.meta.TypeExtMeta; +import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.reflect.TypeRef; import org.apache.fory.resolver.TypeInfo; import org.apache.fory.resolver.TypeResolver; @@ -44,6 +46,7 @@ import org.apache.fory.serializer.converter.FieldConverters; import org.apache.fory.type.Descriptor; import org.apache.fory.type.DescriptorGrouper; +import org.apache.fory.util.ExceptionUtils; import org.apache.fory.util.StringUtils; /** Base class used by javac-generated {@code @ForyStruct} serializers. */ @@ -123,20 +126,17 @@ private void setSerializerIfAbsent(TypeResolver typeResolver, Class type) { public StaticGeneratedStructSerializer copySerializer( TypeResolver typeResolver, Class type, TypeDef typeDef) { try { - Constructor constructor = - getClass() - .asSubclass(StaticGeneratedStructSerializer.class) - .getDeclaredConstructor(TypeResolver.class, Class.class, TypeDef.class); - constructor.setAccessible(true); - return (StaticGeneratedStructSerializer) - constructor.newInstance(typeResolver, type, typeDef); - } catch (ReflectiveOperationException e) { - throw new ForyException( - "Failed to copy static generated serializer " - + getClass().getName() - + " for " - + type.getName(), - e); + // The generated serializer class may be in a named application module that is not open to + // Fory. Reuse the trusted constructor owner instead of reflective setAccessible. + MethodHandle constructor = + ReflectionUtils.getCtrHandle( + getClass().asSubclass(StaticGeneratedStructSerializer.class), + TypeResolver.class, + Class.class, + TypeDef.class); + return (StaticGeneratedStructSerializer) constructor.invoke(typeResolver, type, typeDef); + } catch (Throwable e) { + throw ExceptionUtils.throwException(e); } } @@ -477,43 +477,8 @@ public final void skipField(ReadContext readContext, RemoteFieldInfo remoteField } } - protected final SerializationFieldInfo localFieldInfo(int matchedId) { - return localFieldsById[matchedId]; - } - - public final boolean canReadRemoteField(RemoteFieldInfo remoteField) { - if (remoteField.incompatibleCollectionArrayMatch) { - throw new DeserializationException( - "Cannot read remote field " - + remoteField.descriptor.getName() - + " as local field " - + localFieldsById[remoteField.matchedId].descriptor.getName() - + ": compatible list/array adaptation requires a matching non-null primitive element" - + " schema and does not apply recursively"); - } - if (remoteField.nestedCollectionArrayMatch) { - return false; - } - return remoteField.canRead; - } - - public final boolean canReadGeneratedField( - RemoteFieldInfo remoteField, SerializationFieldInfo localFieldInfo) { - if (remoteField.incompatibleCollectionArrayMatch) { - throw new DeserializationException( - "Cannot read remote field " - + remoteField.descriptor.getName() - + " as local field " - + localFieldInfo.descriptor.getName() - + ": compatible list/array adaptation requires a matching non-null primitive element" - + " schema and does not apply recursively"); - } - if (remoteField.nestedCollectionArrayMatch) { - return false; - } - return remoteField.canRead - || FieldConverters.canReadGeneratedField( - remoteField.serializationFieldInfo, localFieldInfo); + protected final SerializationFieldInfo localFieldInfo(int localFieldId) { + return localFieldsById[localFieldId]; } public final Object readCompatibleFieldValue( @@ -718,8 +683,12 @@ private void appendRemoteFields( matchedId == UNKNOWN_FIELD ? null : localDescriptors.get(matchedId); SerializationFieldInfo localFieldInfo = matchedId == UNKNOWN_FIELD ? null : localFieldsById[matchedId]; + boolean exactFieldSchema = false; if (localDescriptor != null) { Descriptor readDescriptor = fieldInfo.toDescriptor(typeResolver, localDescriptor); + exactFieldSchema = + readDescriptor == localDescriptor + || registeredFieldSchemaEquals(fieldInfo, localDescriptor); serializationFieldInfo = new SerializationFieldInfo( typeResolver, readDescriptor, serializationFieldInfo.codecCategory); @@ -732,10 +701,19 @@ private void appendRemoteFields( descriptor, serializationFieldInfo, localDescriptor, - localFieldInfo)); + localFieldInfo, + exactFieldSchema)); } } + private boolean registeredFieldSchemaEquals(FieldInfo fieldInfo, Descriptor localDescriptor) { + FieldTypes.FieldType remoteFieldType = fieldInfo.getFieldType(); + FieldTypes.FieldType localFieldType = FieldTypes.buildFieldType(typeResolver, localDescriptor); + return remoteFieldType instanceof FieldTypes.RegisteredFieldType + && localFieldType instanceof FieldTypes.RegisteredFieldType + && remoteFieldType.equals(localFieldType); + } + private SerializationFieldInfo[] buildLocalFieldsById(List descriptors) { FieldGroups fieldGroups = buildFieldGroups(descriptors); SerializationFieldInfo[] allFields = fieldGroups.allFields; @@ -795,7 +773,11 @@ private static String remoteFieldKey(Descriptor descriptor) { /** Remote field metadata consumed by generated compatible read methods. */ @Internal public static final class RemoteFieldInfo { + /** + * Doubled compatible-read dispatch id: local id * 2 for exact, local id * 2 + 1 for conversion. + */ public final int matchedId; + public final FieldInfo fieldInfo; public final Descriptor descriptor; public final SerializationFieldInfo serializationFieldInfo; @@ -812,8 +794,8 @@ private RemoteFieldInfo( Descriptor descriptor, SerializationFieldInfo serializationFieldInfo, Descriptor localDescriptor, - SerializationFieldInfo localFieldInfo) { - this.matchedId = matchedId; + SerializationFieldInfo localFieldInfo, + boolean exactFieldSchema) { this.fieldInfo = fieldInfo; this.descriptor = descriptor; this.serializationFieldInfo = serializationFieldInfo; @@ -825,20 +807,63 @@ private RemoteFieldInfo( this.nestedCollectionArrayMatch = CompatibleCollectionArrayReader.nestedCollectionArrayMatch( typeResolver, fieldInfo, localDescriptor); - if (localFieldInfo == null - || incompatibleCollectionArrayMatch - || nestedCollectionArrayMatch) { + if (localFieldInfo == null) { + this.matchedId = UNKNOWN_FIELD; this.canRead = false; this.compatibleScalarRead = false; } else if (compatibleCollectionArrayReadAction != null) { + this.matchedId = matchedId * 2 + 1; this.canRead = true; this.compatibleScalarRead = false; } else { - this.canRead = FieldConverters.canConvert(serializationFieldInfo, localFieldInfo); - this.compatibleScalarRead = - canRead - && FieldConverters.requiresSourceScalarRead(serializationFieldInfo, localFieldInfo); + boolean canGeneratedRead = + !incompatibleCollectionArrayMatch + && !nestedCollectionArrayMatch + && FieldConverters.canReadCompatibleField( + typeResolver, serializationFieldInfo, localFieldInfo); + if (exactFieldSchema) { + this.matchedId = matchedId * 2; + this.canRead = true; + this.compatibleScalarRead = false; + } else if (canGeneratedRead) { + this.matchedId = matchedId * 2 + 1; + this.canRead = true; + this.compatibleScalarRead = + FieldConverters.requiresSourceScalarRead(serializationFieldInfo, localFieldInfo); + } else { + throw incompatibleFieldError( + fieldInfo, + localFieldInfo, + incompatibleCollectionArrayMatch, + nestedCollectionArrayMatch); + } } } + + private static DeserializationException incompatibleFieldError( + FieldInfo fieldInfo, + SerializationFieldInfo localFieldInfo, + boolean incompatibleCollectionArrayMatch, + boolean nestedCollectionArrayMatch) { + String reason; + if (incompatibleCollectionArrayMatch || nestedCollectionArrayMatch) { + reason = + "compatible list/array adaptation requires a matching non-null primitive element" + + " schema and does not apply recursively"; + } else { + reason = "remote and local field schemas are not compatible"; + } + return new DeserializationException( + "Cannot read remote field " + + fieldInfo.getDefinedClass() + + "." + + fieldInfo.getFieldName() + + " as local field " + + localFieldInfo.descriptor.getDeclaringClass() + + "." + + localFieldInfo.descriptor.getName() + + ": " + + reason); + } } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/converter/CompatibleScalarConverter.java b/java/fory-core/src/main/java/org/apache/fory/serializer/converter/CompatibleScalarConverter.java index bf0aa6c106..02bb7fdcdd 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/converter/CompatibleScalarConverter.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/converter/CompatibleScalarConverter.java @@ -21,9 +21,14 @@ import java.math.BigDecimal; import java.math.BigInteger; +import org.apache.fory.Fory; import org.apache.fory.annotation.Internal; +import org.apache.fory.context.ReadContext; import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.resolver.RefMode; +import org.apache.fory.resolver.TypeInfo; +import org.apache.fory.serializer.FieldGroups.SerializationFieldInfo; import org.apache.fory.type.BFloat16; import org.apache.fory.type.DispatchId; import org.apache.fory.type.Float16; @@ -183,7 +188,225 @@ static Object convert( throw dataError(fromDispatchId, fromType, toDispatchId, toType, fieldName); } - static Boolean readBool( + static boolean readBooleanTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return false; + } + return readBooleanTargetPayload(readContext, buffer, from, to); + } + + static Boolean readBoxedBooleanTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return readBooleanTargetPayload(readContext, buffer, from, to); + } + + static byte readByteTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return 0; + } + return (byte) readIntegerBitsTarget(readContext, buffer, from, to); + } + + static Byte readBoxedByteTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return (byte) readIntegerBitsTarget(readContext, buffer, from, to); + } + + static short readShortTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return 0; + } + return (short) readIntegerBitsTarget(readContext, buffer, from, to); + } + + static Short readBoxedShortTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return (short) readIntegerBitsTarget(readContext, buffer, from, to); + } + + static int readIntTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return 0; + } + return (int) readIntegerBitsTarget(readContext, buffer, from, to); + } + + static Integer readBoxedIntTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return (int) readIntegerBitsTarget(readContext, buffer, from, to); + } + + static long readLongTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return 0L; + } + return readIntegerBitsTarget(readContext, buffer, from, to); + } + + static Long readBoxedLongTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return readIntegerBitsTarget(readContext, buffer, from, to); + } + + static float readFloatTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return 0.0f; + } + return readFloatTargetPayload(readContext, buffer, from, to); + } + + static Float readBoxedFloatTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return readFloatTargetPayload(readContext, buffer, from, to); + } + + static double readDoubleTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return 0.0d; + } + return readDoubleTargetPayload(readContext, buffer, from, to); + } + + static Double readBoxedDoubleTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return readDoubleTargetPayload(readContext, buffer, from, to); + } + + static String readStringTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + String fieldName = to.qualifiedFieldName; + int fromDomain = domain(from.dispatchId, from.type); + switch (fromDomain) { + case BOOL: + return readBool(buffer, from.dispatchId, from.type, fieldName) ? "true" : "false"; + case STRING: + return readContext.readString(); + case SIGNED_INT: + case UNSIGNED_INT: + return readIntegerSource(buffer, from, fieldName).toString(); + case DECIMAL: + return decimalText(readDecimalSource(readContext, from)); + case FLOAT: + return readFloatText(buffer, from, fieldName); + default: + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + } + + static BigDecimal readDecimalTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return readNumericDecimal(readContext, buffer, from, to.qualifiedFieldName); + } + + static UInt8 readUInt8Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + BigInteger value = readIntegerTargetPayload(readContext, buffer, from, to); + return UInt8.valueOf(value.intValue()); + } + + static UInt16 readUInt16Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + BigInteger value = readIntegerTargetPayload(readContext, buffer, from, to); + return UInt16.valueOf(value.intValue()); + } + + static UInt32 readUInt32Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + BigInteger value = readIntegerTargetPayload(readContext, buffer, from, to); + return UInt32.valueOf(value.intValue()); + } + + static UInt64 readUInt64Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + BigInteger value = readIntegerTargetPayload(readContext, buffer, from, to); + return UInt64.valueOf(value.longValue()); + } + + static Float16 readFloat16Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return readFloat16TargetPayload(readContext, buffer, from, to); + } + + static BFloat16 readBFloat16Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + MemoryBuffer buffer = readScalarBuffer(readContext, from); + if (buffer == null) { + return null; + } + return readBFloat16TargetPayload(readContext, buffer, from, to); + } + + static boolean readBool( MemoryBuffer buffer, int fromDispatchId, Class fromType, String fieldName) { byte raw = buffer.readByte(); if (raw == 0) { @@ -195,6 +418,864 @@ static Boolean readBool( throw dataError(fromDispatchId, fromType, DispatchId.BOOL, Boolean.class, fieldName); } + private static MemoryBuffer readScalarBuffer( + ReadContext readContext, SerializationFieldInfo from) { + if (from.refMode == RefMode.TRACKING) { + throw new DeserializationException( + "Reference-tracked scalar conversion is schema incompatible for " + + from.qualifiedFieldName); + } + MemoryBuffer buffer = readContext.getBuffer(); + if (from.refMode == RefMode.NULL_ONLY) { + byte flag = buffer.readByte(); + if (flag == Fory.NULL_FLAG) { + return null; + } + if (flag != Fory.NOT_NULL_VALUE_FLAG) { + throw new DeserializationException( + "Invalid nullable compatible scalar field flag " + + flag + + " for " + + from.qualifiedFieldName); + } + } + return buffer; + } + + private static boolean readBooleanTargetPayload( + ReadContext readContext, + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to) { + String fieldName = to.qualifiedFieldName; + switch (domain(from.dispatchId, from.type)) { + case BOOL: + return readBool(buffer, from.dispatchId, from.type, fieldName); + case STRING: + return boolFromString(readContext.readString(), from, to); + case SIGNED_INT: + case UNSIGNED_INT: + return boolFromInteger(readIntegerSource(buffer, from, fieldName), from, to); + case DECIMAL: + return boolFromDecimal(readDecimalSource(readContext, from), from, to); + case FLOAT: + return readFloatBool(buffer, from, to); + default: + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + } + + private static boolean boolFromString( + String text, SerializationFieldInfo from, SerializationFieldInfo to) { + if ("0".equals(text) || "false".equals(text)) { + return false; + } + if ("1".equals(text) || "true".equals(text)) { + return true; + } + throw dataError(from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + + private static boolean boolFromInteger( + BigInteger value, SerializationFieldInfo from, SerializationFieldInfo to) { + if (BIG_ZERO.equals(value)) { + return false; + } + if (BIG_ONE.equals(value)) { + return true; + } + throw dataError(from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + + private static boolean boolFromDecimal( + BigDecimal value, SerializationFieldInfo from, SerializationFieldInfo to) { + if (value.compareTo(DECIMAL_ZERO) == 0) { + return false; + } + if (value.compareTo(DECIMAL_ONE) == 0) { + return true; + } + throw dataError(from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + + private static BigInteger readIntegerTargetPayload( + ReadContext readContext, + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to) { + String fieldName = to.qualifiedFieldName; + BigInteger value; + switch (domain(from.dispatchId, from.type)) { + case BOOL: + value = readBool(buffer, from.dispatchId, from.type, fieldName) ? BIG_ONE : BIG_ZERO; + break; + case STRING: + value = + integerFromDecimal(parseDecimalString(readContext.readString(), from, to), from, to); + break; + case SIGNED_INT: + case UNSIGNED_INT: + value = readIntegerSource(buffer, from, fieldName); + break; + case DECIMAL: + value = integerFromDecimal(readDecimalSource(readContext, from), from, to); + break; + case FLOAT: + value = integerFromDecimal(readFiniteFloatDecimal(buffer, from, fieldName), from, to); + break; + default: + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + checkIntegerRange(value, to.dispatchId, to.type, fieldName); + return value; + } + + private static long readIntegerBitsTarget( + ReadContext readContext, + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to) { + String fieldName = to.qualifiedFieldName; + switch (domain(from.dispatchId, from.type)) { + case BOOL: + return checkedSignedIntegerBits( + readBool(buffer, from.dispatchId, from.type, fieldName) ? 1L : 0L, + to.dispatchId, + to.type, + fieldName); + case STRING: + case DECIMAL: + case FLOAT: + return readIntegerTargetPayload(readContext, buffer, from, to).longValue(); + case SIGNED_INT: + return readSignedIntegerBitsTarget(buffer, from, to.dispatchId, to.type, fieldName); + case UNSIGNED_INT: + return readUnsignedIntegerBitsTarget(buffer, from, to.dispatchId, to.type, fieldName); + default: + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + } + + private static long readSignedIntegerBitsTarget( + MemoryBuffer buffer, + SerializationFieldInfo from, + int toDispatchId, + Class toType, + String fieldName) { + long value; + switch (from.dispatchId) { + case DispatchId.INT8: + value = buffer.readByte(); + break; + case DispatchId.INT16: + value = buffer.readInt16(); + break; + case DispatchId.INT32: + value = buffer.readInt32(); + break; + case DispatchId.VARINT32: + value = buffer.readVarInt32(); + break; + case DispatchId.INT64: + value = buffer.readInt64(); + break; + case DispatchId.VARINT64: + value = buffer.readVarInt64(); + break; + case DispatchId.TAGGED_INT64: + value = buffer.readTaggedInt64(); + break; + default: + throw dataError(from.dispatchId, from.type, toDispatchId, toType, fieldName); + } + return checkedSignedIntegerBits(value, toDispatchId, toType, fieldName); + } + + private static long readUnsignedIntegerBitsTarget( + MemoryBuffer buffer, + SerializationFieldInfo from, + int toDispatchId, + Class toType, + String fieldName) { + long value; + boolean unsigned64 = false; + switch (from.dispatchId) { + case DispatchId.UINT8: + case DispatchId.EXT_UINT8: + value = buffer.readByte() & 0xFFL; + break; + case DispatchId.UINT16: + case DispatchId.EXT_UINT16: + value = buffer.readInt16() & 0xFFFFL; + break; + case DispatchId.UINT32: + case DispatchId.EXT_UINT32: + value = Integer.toUnsignedLong(buffer.readInt32()); + break; + case DispatchId.VAR_UINT32: + case DispatchId.EXT_VAR_UINT32: + value = Integer.toUnsignedLong(buffer.readVarUInt32()); + break; + case DispatchId.UINT64: + case DispatchId.EXT_UINT64: + value = buffer.readInt64(); + unsigned64 = true; + break; + case DispatchId.VAR_UINT64: + case DispatchId.EXT_VAR_UINT64: + value = buffer.readVarUInt64(); + unsigned64 = true; + break; + case DispatchId.TAGGED_UINT64: + value = buffer.readTaggedUInt64(); + unsigned64 = true; + break; + default: + throw dataError(from.dispatchId, from.type, toDispatchId, toType, fieldName); + } + return unsigned64 + ? checkedUnsigned64IntegerBits(value, toDispatchId, toType, fieldName) + : checkedUnsignedIntegerBits(value, toDispatchId, toType, fieldName); + } + + private static long checkedSignedIntegerBits( + long value, int toDispatchId, Class toType, String fieldName) { + switch (normalizedIntegerId(toDispatchId, toType)) { + case DispatchId.INT8: + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT16: + if (value < Short.MIN_VALUE || value > Short.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT32: + if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT64: + return value; + case DispatchId.UINT8: + if (value < 0 || value > 0xFFL) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT16: + if (value < 0 || value > 0xFFFFL) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT32: + if (value < 0 || value > 0xFFFF_FFFFL) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT64: + if (value < 0) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + default: + throw dataError(DispatchId.UNKNOWN, BigInteger.class, toDispatchId, toType, fieldName); + } + } + + private static long checkedUnsignedIntegerBits( + long value, int toDispatchId, Class toType, String fieldName) { + switch (normalizedIntegerId(toDispatchId, toType)) { + case DispatchId.INT8: + if (value > Byte.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT16: + if (value > Short.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT32: + if (value > Integer.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT64: + case DispatchId.UINT64: + return value; + case DispatchId.UINT8: + if (value > 0xFFL) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT16: + if (value > 0xFFFFL) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT32: + if (value > 0xFFFF_FFFFL) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + default: + throw dataError(DispatchId.UNKNOWN, BigInteger.class, toDispatchId, toType, fieldName); + } + } + + private static long checkedUnsigned64IntegerBits( + long value, int toDispatchId, Class toType, String fieldName) { + switch (normalizedIntegerId(toDispatchId, toType)) { + case DispatchId.INT8: + if (value < 0 || value > Byte.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT16: + if (value < 0 || value > Short.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT32: + if (value < 0 || value > Integer.MAX_VALUE) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.INT64: + if (value < 0) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT8: + if (Long.compareUnsigned(value, 0xFFL) > 0) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT16: + if (Long.compareUnsigned(value, 0xFFFFL) > 0) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT32: + if (Long.compareUnsigned(value, 0xFFFF_FFFFL) > 0) { + throwIntegerRangeError(toDispatchId, toType, fieldName); + } + return value; + case DispatchId.UINT64: + return value; + default: + throw dataError(DispatchId.UNKNOWN, BigInteger.class, toDispatchId, toType, fieldName); + } + } + + private static void throwIntegerRangeError(int toDispatchId, Class toType, String fieldName) { + throw dataError(DispatchId.UNKNOWN, BigInteger.class, toDispatchId, toType, fieldName); + } + + private static BigInteger integerFromDecimal( + BigDecimal decimal, SerializationFieldInfo from, SerializationFieldInfo to) { + try { + return canonicalDecimal(decimal).toBigIntegerExact(); + } catch (ArithmeticException e) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, to.qualifiedFieldName, e); + } + } + + private static BigInteger readIntegerSource( + MemoryBuffer buffer, SerializationFieldInfo from, String fieldName) { + switch (from.dispatchId) { + case DispatchId.INT8: + return BigInteger.valueOf(buffer.readByte()); + case DispatchId.UINT8: + case DispatchId.EXT_UINT8: + return BigInteger.valueOf(buffer.readByte() & 0xFF); + case DispatchId.INT16: + return BigInteger.valueOf(buffer.readInt16()); + case DispatchId.UINT16: + case DispatchId.EXT_UINT16: + return BigInteger.valueOf(buffer.readInt16() & 0xFFFF); + case DispatchId.INT32: + return BigInteger.valueOf(buffer.readInt32()); + case DispatchId.UINT32: + case DispatchId.EXT_UINT32: + return BigInteger.valueOf(Integer.toUnsignedLong(buffer.readInt32())); + case DispatchId.VARINT32: + return BigInteger.valueOf(buffer.readVarInt32()); + case DispatchId.VAR_UINT32: + case DispatchId.EXT_VAR_UINT32: + return BigInteger.valueOf(Integer.toUnsignedLong(buffer.readVarUInt32())); + case DispatchId.INT64: + return BigInteger.valueOf(buffer.readInt64()); + case DispatchId.UINT64: + case DispatchId.EXT_UINT64: + return unsignedLongBigInteger(buffer.readInt64()); + case DispatchId.VARINT64: + return BigInteger.valueOf(buffer.readVarInt64()); + case DispatchId.TAGGED_INT64: + return BigInteger.valueOf(buffer.readTaggedInt64()); + case DispatchId.VAR_UINT64: + case DispatchId.EXT_VAR_UINT64: + return unsignedLongBigInteger(buffer.readVarUInt64()); + case DispatchId.TAGGED_UINT64: + return unsignedLongBigInteger(buffer.readTaggedUInt64()); + default: + throw dataError( + from.dispatchId, from.type, DispatchId.UNKNOWN, BigInteger.class, fieldName); + } + } + + private static BigInteger unsignedLongBigInteger(long bits) { + if (bits >= 0) { + return BigInteger.valueOf(bits); + } + return BigInteger.valueOf(bits & Long.MAX_VALUE).setBit(63); + } + + private static BigDecimal readNumericDecimal( + ReadContext readContext, MemoryBuffer buffer, SerializationFieldInfo from, String fieldName) { + switch (domain(from.dispatchId, from.type)) { + case BOOL: + return readBool(buffer, from.dispatchId, from.type, fieldName) ? DECIMAL_ONE : DECIMAL_ZERO; + case STRING: + return canonicalDecimal(parseDecimalString(readContext.readString(), from, from)); + case SIGNED_INT: + case UNSIGNED_INT: + return new BigDecimal(readIntegerSource(buffer, from, fieldName)); + case DECIMAL: + return canonicalDecimal(readDecimalSource(readContext, from)); + case FLOAT: + return readFiniteFloatDecimal(buffer, from, fieldName); + default: + throw dataError( + from.dispatchId, from.type, DispatchId.UNKNOWN, BigDecimal.class, fieldName); + } + } + + private static BigDecimal parseDecimalString( + String value, SerializationFieldInfo from, SerializationFieldInfo to) { + if (!numericLiteralFits(value)) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, to.qualifiedFieldName); + } + try { + return new BigDecimal(value); + } catch (NumberFormatException e) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, to.qualifiedFieldName, e); + } + } + + private static BigDecimal readDecimalSource( + ReadContext readContext, SerializationFieldInfo from) { + if (from.useDeclaredTypeInfo) { + readContext.preserveRefId(-1); + return (BigDecimal) readContext.readNonRef(from.typeInfo); + } + TypeInfo typeInfo = readContext.getTypeResolver().readTypeInfo(readContext, from.type); + return (BigDecimal) typeInfo.getSerializer().read(readContext, RefMode.NONE); + } + + private static float readFloatTargetPayload( + ReadContext readContext, + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to) { + String fieldName = to.qualifiedFieldName; + switch (domain(from.dispatchId, from.type)) { + case BOOL: + return readBool(buffer, from.dispatchId, from.type, fieldName) ? 1.0f : 0.0f; + case STRING: + { + String value = readContext.readString(); + BigDecimal decimal = parseDecimalString(value, from, to); + return decimalToFloat32( + decimal, value.charAt(0) == '-' && decimal.signum() == 0, to, fieldName); + } + case SIGNED_INT: + case UNSIGNED_INT: + return decimalToFloat32( + new BigDecimal(readIntegerSource(buffer, from, fieldName)), false, to, fieldName); + case DECIMAL: + return decimalToFloat32(readDecimalSource(readContext, from), false, to, fieldName); + case FLOAT: + return readFloatAsFloat32(buffer, from, to, fieldName); + default: + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + } + + private static double readDoubleTargetPayload( + ReadContext readContext, + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to) { + String fieldName = to.qualifiedFieldName; + switch (domain(from.dispatchId, from.type)) { + case BOOL: + return readBool(buffer, from.dispatchId, from.type, fieldName) ? 1.0d : 0.0d; + case STRING: + { + String value = readContext.readString(); + BigDecimal decimal = parseDecimalString(value, from, to); + return decimalToFloat64( + decimal, value.charAt(0) == '-' && decimal.signum() == 0, to, fieldName); + } + case SIGNED_INT: + case UNSIGNED_INT: + return decimalToFloat64( + new BigDecimal(readIntegerSource(buffer, from, fieldName)), false, to, fieldName); + case DECIMAL: + return decimalToFloat64(readDecimalSource(readContext, from), false, to, fieldName); + case FLOAT: + return readFloatAsFloat64(buffer, from, to, fieldName); + default: + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + } + + private static Float16 readFloat16TargetPayload( + ReadContext readContext, + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to) { + if (sameScalar(from.dispatchId, from.type, to.dispatchId, to.type)) { + return Float16.fromBits(buffer.readInt16()); + } + return (Float16) + decimalToFloat( + readNumericDecimal(readContext, buffer, from, to.qualifiedFieldName), + to.dispatchId, + to.type, + false, + to.qualifiedFieldName); + } + + private static BFloat16 readBFloat16TargetPayload( + ReadContext readContext, + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to) { + if (sameScalar(from.dispatchId, from.type, to.dispatchId, to.type)) { + return BFloat16.fromBits(buffer.readInt16()); + } + return (BFloat16) + decimalToFloat( + readNumericDecimal(readContext, buffer, from, to.qualifiedFieldName), + to.dispatchId, + to.type, + false, + to.qualifiedFieldName); + } + + private static boolean readFloatBool( + MemoryBuffer buffer, SerializationFieldInfo from, SerializationFieldInfo to) { + switch (normalizedFloatId(from.dispatchId, from.type)) { + case DispatchId.FLOAT16: + { + Float16 value = Float16.fromBits(buffer.readInt16()); + if (!value.isFinite()) { + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + if (value.isZero()) { + return false; + } + if (value.toBits() == Float16.ONE.toBits()) { + return true; + } + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + case DispatchId.BFLOAT16: + { + BFloat16 value = BFloat16.fromBits(buffer.readInt16()); + if (!value.isFinite()) { + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + if (value.isZero()) { + return false; + } + if (value.toBits() == BFloat16.ONE.toBits()) { + return true; + } + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + case DispatchId.FLOAT32: + { + float value = buffer.readFloat32(); + if (!Float.isFinite(value)) { + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + if (value == 0.0f) { + return false; + } + if (value == 1.0f) { + return true; + } + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + case DispatchId.FLOAT64: + { + double value = buffer.readFloat64(); + if (!Double.isFinite(value)) { + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + if (value == 0.0d) { + return false; + } + if (value == 1.0d) { + return true; + } + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + default: + throw dataError( + from.dispatchId, from.type, DispatchId.BOOL, to.type, to.qualifiedFieldName); + } + } + + private static BigDecimal readFiniteFloatDecimal( + MemoryBuffer buffer, SerializationFieldInfo from, String fieldName) { + switch (normalizedFloatId(from.dispatchId, from.type)) { + case DispatchId.FLOAT16: + { + Float16 value = Float16.fromBits(buffer.readInt16()); + if (!value.isFinite()) { + throw dataError( + from.dispatchId, from.type, DispatchId.UNKNOWN, BigDecimal.class, fieldName); + } + return value.isZero() ? BigDecimal.ZERO : exactFloatDecimal(value.floatValue()); + } + case DispatchId.BFLOAT16: + { + BFloat16 value = BFloat16.fromBits(buffer.readInt16()); + if (!value.isFinite()) { + throw dataError( + from.dispatchId, from.type, DispatchId.UNKNOWN, BigDecimal.class, fieldName); + } + return value.isZero() ? BigDecimal.ZERO : exactFloatDecimal(value.floatValue()); + } + case DispatchId.FLOAT32: + { + float value = buffer.readFloat32(); + if (!Float.isFinite(value)) { + throw dataError( + from.dispatchId, from.type, DispatchId.UNKNOWN, BigDecimal.class, fieldName); + } + return value == 0.0f ? BigDecimal.ZERO : exactFloatDecimal(value); + } + case DispatchId.FLOAT64: + { + double value = buffer.readFloat64(); + if (!Double.isFinite(value)) { + throw dataError( + from.dispatchId, from.type, DispatchId.UNKNOWN, BigDecimal.class, fieldName); + } + return value == 0.0d ? BigDecimal.ZERO : exactDoubleDecimal(value); + } + default: + throw dataError( + from.dispatchId, from.type, DispatchId.UNKNOWN, BigDecimal.class, fieldName); + } + } + + private static String readFloatText( + MemoryBuffer buffer, SerializationFieldInfo from, String fieldName) { + switch (normalizedFloatId(from.dispatchId, from.type)) { + case DispatchId.FLOAT16: + { + Float16 value = Float16.fromBits(buffer.readInt16()); + if (!value.isFinite()) { + throw dataError(from.dispatchId, from.type, DispatchId.STRING, String.class, fieldName); + } + if (value.isZero()) { + return value.signbit() ? "-0.0" : "0.0"; + } + String text = decimalText(exactFloatDecimal(value.floatValue())); + return text.indexOf('.') >= 0 ? text : text + ".0"; + } + case DispatchId.BFLOAT16: + { + BFloat16 value = BFloat16.fromBits(buffer.readInt16()); + if (!value.isFinite()) { + throw dataError(from.dispatchId, from.type, DispatchId.STRING, String.class, fieldName); + } + if (value.isZero()) { + return value.signbit() ? "-0.0" : "0.0"; + } + String text = decimalText(exactFloatDecimal(value.floatValue())); + return text.indexOf('.') >= 0 ? text : text + ".0"; + } + case DispatchId.FLOAT32: + { + float value = buffer.readFloat32(); + if (!Float.isFinite(value)) { + throw dataError(from.dispatchId, from.type, DispatchId.STRING, String.class, fieldName); + } + if (value == 0.0f) { + return Float.floatToRawIntBits(value) < 0 ? "-0.0" : "0.0"; + } + String text = decimalText(exactFloatDecimal(value)); + return text.indexOf('.') >= 0 ? text : text + ".0"; + } + case DispatchId.FLOAT64: + { + double value = buffer.readFloat64(); + if (!Double.isFinite(value)) { + throw dataError(from.dispatchId, from.type, DispatchId.STRING, String.class, fieldName); + } + if (value == 0.0d) { + return Double.doubleToRawLongBits(value) < 0 ? "-0.0" : "0.0"; + } + String text = decimalText(exactDoubleDecimal(value)); + return text.indexOf('.') >= 0 ? text : text + ".0"; + } + default: + throw dataError(from.dispatchId, from.type, DispatchId.STRING, String.class, fieldName); + } + } + + private static float readFloatAsFloat32( + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to, + String fieldName) { + switch (normalizedFloatId(from.dispatchId, from.type)) { + case DispatchId.FLOAT16: + { + Float16 value = Float16.fromBits(buffer.readInt16()); + if (value.isNaN()) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + if (value.isInfinite()) { + return value.signbit() ? Float.NEGATIVE_INFINITY : Float.POSITIVE_INFINITY; + } + if (value.isZero()) { + return value.signbit() ? -0.0f : 0.0f; + } + return decimalToFloat32(exactFloatDecimal(value.floatValue()), false, to, fieldName); + } + case DispatchId.BFLOAT16: + { + BFloat16 value = BFloat16.fromBits(buffer.readInt16()); + if (value.isNaN()) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + if (value.isInfinite()) { + return value.signbit() ? Float.NEGATIVE_INFINITY : Float.POSITIVE_INFINITY; + } + if (value.isZero()) { + return value.signbit() ? -0.0f : 0.0f; + } + return decimalToFloat32(exactFloatDecimal(value.floatValue()), false, to, fieldName); + } + case DispatchId.FLOAT32: + return buffer.readFloat32(); + case DispatchId.FLOAT64: + { + double value = buffer.readFloat64(); + if (Double.isNaN(value)) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + if (Double.isInfinite(value)) { + return value < 0 ? Float.NEGATIVE_INFINITY : Float.POSITIVE_INFINITY; + } + if (value == 0.0d) { + return Double.doubleToRawLongBits(value) < 0 ? -0.0f : 0.0f; + } + return decimalToFloat32(exactDoubleDecimal(value), false, to, fieldName); + } + default: + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + } + + private static double readFloatAsFloat64( + MemoryBuffer buffer, + SerializationFieldInfo from, + SerializationFieldInfo to, + String fieldName) { + switch (normalizedFloatId(from.dispatchId, from.type)) { + case DispatchId.FLOAT16: + { + Float16 value = Float16.fromBits(buffer.readInt16()); + if (value.isNaN()) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + if (value.isInfinite()) { + return value.signbit() ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY; + } + if (value.isZero()) { + return value.signbit() ? -0.0d : 0.0d; + } + return decimalToFloat64(exactFloatDecimal(value.floatValue()), false, to, fieldName); + } + case DispatchId.BFLOAT16: + { + BFloat16 value = BFloat16.fromBits(buffer.readInt16()); + if (value.isNaN()) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + if (value.isInfinite()) { + return value.signbit() ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY; + } + if (value.isZero()) { + return value.signbit() ? -0.0d : 0.0d; + } + return decimalToFloat64(exactFloatDecimal(value.floatValue()), false, to, fieldName); + } + case DispatchId.FLOAT32: + { + float value = buffer.readFloat32(); + if (Float.isNaN(value)) { + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + if (Float.isInfinite(value)) { + return value < 0 ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY; + } + if (value == 0.0f) { + return Float.floatToRawIntBits(value) < 0 ? -0.0d : 0.0d; + } + return decimalToFloat64(exactFloatDecimal(value), false, to, fieldName); + } + case DispatchId.FLOAT64: + return buffer.readFloat64(); + default: + throw dataError(from.dispatchId, from.type, to.dispatchId, to.type, fieldName); + } + } + + private static float decimalToFloat32( + BigDecimal decimal, boolean negativeZero, SerializationFieldInfo to, String fieldName) { + decimal = canonicalDecimal(decimal); + if (decimal.signum() == 0) { + return negativeZero ? -0.0f : 0.0f; + } + float result = decimal.floatValue(); + if (!Float.isFinite(result) || exactFloatDecimal(result).compareTo(decimal) != 0) { + throw dataError(DispatchId.UNKNOWN, BigDecimal.class, to.dispatchId, to.type, fieldName); + } + return result; + } + + private static double decimalToFloat64( + BigDecimal decimal, boolean negativeZero, SerializationFieldInfo to, String fieldName) { + decimal = canonicalDecimal(decimal); + if (decimal.signum() == 0) { + return negativeZero ? -0.0d : 0.0d; + } + double result = decimal.doubleValue(); + if (!Double.isFinite(result) || exactDoubleDecimal(result).compareTo(decimal) != 0) { + throw dataError(DispatchId.UNKNOWN, BigDecimal.class, to.dispatchId, to.type, fieldName); + } + return result; + } + private static Object fromBool( boolean value, int toDispatchId, Class toType, int toDomain, String fieldName) { if (toDomain == BOOL) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/converter/FieldConverters.java b/java/fory-core/src/main/java/org/apache/fory/serializer/converter/FieldConverters.java index f09d34ec5f..5cb3559e35 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/converter/FieldConverters.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/converter/FieldConverters.java @@ -37,6 +37,7 @@ import org.apache.fory.type.DispatchId; import org.apache.fory.type.Float16; import org.apache.fory.type.TypeUtils; +import org.apache.fory.type.Types; import org.apache.fory.type.unsigned.UInt16; import org.apache.fory.type.unsigned.UInt32; import org.apache.fory.type.unsigned.UInt64; @@ -122,6 +123,9 @@ public static boolean canConvert(Class from, Class to) { */ @Internal public static boolean canConvert(TypeResolver resolver, Descriptor from, Descriptor to) { + if (hasIncompatibleRootArrayOrBinary(resolver, from, to)) { + return false; + } int fromDispatchId = DispatchId.getDispatchId(resolver, from); int toDispatchId = DispatchId.getDispatchId(resolver, to); if (isRefTrackedScalarSchemaMismatch( @@ -178,13 +182,13 @@ public static boolean canConvert(SerializationFieldInfo from, SerializationField return CompatibleScalarConverter.canConvert(from.dispatchId, from.type, to.dispatchId, to.type); } - /** - * Returns whether a generated serializer can read {@code from} and adapt it with generated - * field-access code for {@code to}. - */ + /** Returns whether a remote field can be classified as a compatible local-field read. */ @Internal - public static boolean canReadGeneratedField( - SerializationFieldInfo from, SerializationFieldInfo to) { + public static boolean canReadCompatibleField( + TypeResolver resolver, SerializationFieldInfo from, SerializationFieldInfo to) { + if (hasIncompatibleRootArrayOrBinary(resolver, from.descriptor, to.descriptor)) { + return false; + } if (canConvert(from, to)) { return true; } @@ -246,19 +250,6 @@ public static Object convertValue( from.dispatchId, from.type, to.dispatchId, to.type, value, to.qualifiedFieldName); } - /** Converts a compatible field value using scalar metadata captured at code generation time. */ - @Internal - public static Object convertValue( - int fromDispatchId, - Class fromType, - int toDispatchId, - Class toType, - String fieldName, - Object value) { - return CompatibleScalarConverter.convert( - fromDispatchId, fromType, toDispatchId, toType, value, fieldName); - } - /** Returns whether descriptor-level compatible read must read a source scalar payload. */ @Internal public static boolean requiresSourceScalarRead( @@ -289,7 +280,13 @@ public static boolean requiresSourceScalarRead( public static Object readSourceScalar( ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { return readSourceScalar( - readContext, from, from.refMode, from.dispatchId, from.type, false, to.qualifiedFieldName); + readContext, + from, + from.refMode, + from.dispatchId, + from.type, + from.useDeclaredTypeInfo, + to.qualifiedFieldName); } /** Reads a remote scalar conversion source value for an existing field converter. */ @@ -303,29 +300,10 @@ public static Object readSourceScalar( from.refMode, scalar.fromDispatchId, scalar.fromType, - false, + from.useDeclaredTypeInfo, scalar.fieldName); } - /** Reads a remote scalar conversion source value from generated compatible serializers. */ - @Internal - public static Object readSourceScalar( - ReadContext readContext, - int fromDispatchId, - Class fromType, - boolean nullable, - boolean declaredTypeInfo, - String fieldName) { - return readSourceScalar( - readContext, - null, - nullable ? RefMode.NULL_ONLY : RefMode.NONE, - fromDispatchId, - fromType, - declaredTypeInfo, - fieldName); - } - private static Object readSourceScalar( ReadContext readContext, SerializationFieldInfo from, @@ -411,28 +389,140 @@ private static Object readSourceScalar( } @Internal - public static int fromDispatchId(FieldConverter converter) { - return scalarConverter(converter).fromDispatchId; + public static boolean readBooleanTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBooleanTarget(readContext, from, to); } @Internal - public static Class fromType(FieldConverter converter) { - return scalarConverter(converter).fromType; + public static Boolean readBoxedBooleanTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBoxedBooleanTarget(readContext, from, to); } @Internal - public static int toDispatchId(FieldConverter converter) { - return scalarConverter(converter).toDispatchId; + public static byte readByteTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readByteTarget(readContext, from, to); } @Internal - public static Class toType(FieldConverter converter) { - return scalarConverter(converter).toType; + public static Byte readBoxedByteTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBoxedByteTarget(readContext, from, to); + } + + @Internal + public static short readShortTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readShortTarget(readContext, from, to); + } + + @Internal + public static Short readBoxedShortTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBoxedShortTarget(readContext, from, to); } @Internal - public static String fieldName(FieldConverter converter) { - return scalarConverter(converter).fieldName; + public static int readIntTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readIntTarget(readContext, from, to); + } + + @Internal + public static Integer readBoxedIntTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBoxedIntTarget(readContext, from, to); + } + + @Internal + public static long readLongTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readLongTarget(readContext, from, to); + } + + @Internal + public static Long readBoxedLongTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBoxedLongTarget(readContext, from, to); + } + + @Internal + public static float readFloatTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readFloatTarget(readContext, from, to); + } + + @Internal + public static Float readBoxedFloatTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBoxedFloatTarget(readContext, from, to); + } + + @Internal + public static double readDoubleTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readDoubleTarget(readContext, from, to); + } + + @Internal + public static Double readBoxedDoubleTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBoxedDoubleTarget(readContext, from, to); + } + + @Internal + public static String readStringTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readStringTarget(readContext, from, to); + } + + @Internal + public static BigDecimal readDecimalTarget( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readDecimalTarget(readContext, from, to); + } + + @Internal + public static UInt8 readUInt8Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readUInt8Target(readContext, from, to); + } + + @Internal + public static UInt16 readUInt16Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readUInt16Target(readContext, from, to); + } + + @Internal + public static UInt32 readUInt32Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readUInt32Target(readContext, from, to); + } + + @Internal + public static UInt64 readUInt64Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readUInt64Target(readContext, from, to); + } + + @Internal + public static Float16 readFloat16Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readFloat16Target(readContext, from, to); + } + + @Internal + public static BFloat16 readBFloat16Target( + ReadContext readContext, SerializationFieldInfo from, SerializationFieldInfo to) { + return CompatibleScalarConverter.readBFloat16Target(readContext, from, to); + } + + @Internal + public static Class toType(FieldConverter converter) { + return scalarConverter(converter).toType; } private static Object readSourceDecimal( @@ -539,6 +629,26 @@ private static boolean isScalarField(SerializationFieldInfo fieldInfo) { return CompatibleScalarConverter.isScalar(fieldInfo.dispatchId, fieldInfo.type); } + private static boolean hasIncompatibleRootArrayOrBinary( + TypeResolver resolver, Descriptor from, Descriptor to) { + int fromTypeId = rootArrayOrBinaryTypeId(resolver, from); + int toTypeId = rootArrayOrBinaryTypeId(resolver, to); + if (fromTypeId == Types.UNKNOWN && toTypeId == Types.UNKNOWN) { + return false; + } + return fromTypeId != toTypeId && !isBinaryUint8ArrayTypePair(fromTypeId, toTypeId); + } + + private static boolean isBinaryUint8ArrayTypePair(int fromTypeId, int toTypeId) { + return (fromTypeId == Types.BINARY && toTypeId == Types.UINT8_ARRAY) + || (fromTypeId == Types.UINT8_ARRAY && toTypeId == Types.BINARY); + } + + private static int rootArrayOrBinaryTypeId(TypeResolver resolver, Descriptor descriptor) { + int typeId = Types.getDescriptorTypeId(resolver, descriptor); + return typeId == Types.BINARY || Types.isPrimitiveArray(typeId) ? typeId : Types.UNKNOWN; + } + private static boolean isDirectlyAssignable(Class from, Class to) { if (to.isAssignableFrom(from)) { return true; diff --git a/java/fory-core/src/test/java/org/apache/fory/builder/StaticCompatibleCodecBuilderTest.java b/java/fory-core/src/test/java/org/apache/fory/builder/StaticCompatibleCodecBuilderTest.java index 8ff3472fde..8bfc9fcf45 100644 --- a/java/fory-core/src/test/java/org/apache/fory/builder/StaticCompatibleCodecBuilderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/builder/StaticCompatibleCodecBuilderTest.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; import java.net.URL; import java.net.URLClassLoader; import java.nio.charset.StandardCharsets; @@ -225,6 +226,20 @@ public void testStaticCompatibleRecordSerializerConvertsRemoteField() throws Exc Class readerType = readerLoader.loadClass("test.StaticCompatibleRecordPayload"); Fory writer = compatibleFory(writerLoader, writerType, false, "record-writer"); Fory reader = compatibleFory(readerLoader, readerType, false, "record-reader"); + TypeDef remoteTypeDef = TypeDef.buildTypeDef(writer.getTypeResolver(), writerType); + String generatedSource = + new StaticCompatibleCodecBuilder(TypeRef.of(readerType), reader, remoteTypeDef).genCode(); + Assert.assertTrue(generatedSource.contains("case 0:")); + Assert.assertTrue(generatedSource.contains("case 1:")); + Assert.assertTrue(generatedSource.contains("FieldConverters.readIntTarget")); + Assert.assertTrue(generatedSource.contains("int _f_recordValue0 = 0;")); + Assert.assertTrue( + generatedSource.contains( + "return new test.StaticCompatibleRecordPayload(_f_recordValue0);")); + Assert.assertFalse(generatedSource.contains("newInstanceWithArguments")); + Assert.assertFalse(generatedSource.contains("Object[] _f_recordArgs = this._f_recordArgs")); + Assert.assertFalse(generatedSource.contains("_f_recordValues")); + Assert.assertFalse(generatedSource.contains("readMatchedRecordField")); Object writerValue = writerType.getConstructor().newInstance(); setField(writerType, writerValue, "id", "73"); @@ -236,6 +251,39 @@ public void testStaticCompatibleRecordSerializerConvertsRemoteField() throws Exc } } + @Test + public void testInaccessibleRecordInstantiator() throws Exception { + assumeRecordSupport(); + CompilationResult writerResult = + compile( + "test.StaticCompatibleHiddenRecordPayload", + "package test;\n" + + "public class StaticCompatibleHiddenRecordPayload {\n" + + " public String id;\n" + + " public StaticCompatibleHiddenRecordPayload() {}\n" + + "}\n"); + CompilationResult readerResult = + compile( + "test.StaticCompatibleHiddenRecordPayload", + "package test;\n" + "record StaticCompatibleHiddenRecordPayload(int id) {}\n"); + Assert.assertTrue(writerResult.success, writerResult.diagnostics()); + Assert.assertTrue(readerResult.success, readerResult.diagnostics()); + try (URLClassLoader writerLoader = writerResult.classLoader(); + URLClassLoader readerLoader = readerResult.classLoader()) { + Class writerType = writerLoader.loadClass("test.StaticCompatibleHiddenRecordPayload"); + Class readerType = readerLoader.loadClass("test.StaticCompatibleHiddenRecordPayload"); + Fory writer = compatibleFory(writerLoader, writerType, false, "hidden-record-writer"); + Fory reader = compatibleFory(readerLoader, readerType, false, "hidden-record-reader"); + TypeDef remoteTypeDef = TypeDef.buildTypeDef(writer.getTypeResolver(), writerType); + String generatedSource = + new StaticCompatibleCodecBuilder(TypeRef.of(readerType), reader, remoteTypeDef).genCode(); + Assert.assertTrue(generatedSource.contains("newInstanceWithArguments")); + Assert.assertTrue(generatedSource.contains("Object[] _f_recordArgs = this._f_recordArgs")); + Assert.assertFalse( + generatedSource.contains("return new test.StaticCompatibleHiddenRecordPayload")); + } + } + @Test public void testCompatibleRecordSerializerConvertsRemoteField() throws Exception { assumeRecordSupport(); @@ -406,6 +454,131 @@ public void testStaticCompatibleArrayPayloadReadsOrdinaryAnnotatedList() throws } } + @Test + public void testStaticRejectsNestedRefSchema() throws Exception { + assertStaticSchemaFails( + "package test;\n" + + "import java.util.List;\n" + + "import org.apache.fory.annotation.Ref;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " public List<@Ref String> values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n", + "package test;\n" + + "import java.util.List;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " public List values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n"); + } + + @Test + public void testStaticRejectsNestedArraySchema() throws Exception { + assertStaticSchemaFails( + "package test;\n" + + "import java.util.List;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " public List values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n", + "package test;\n" + + "import java.util.List;\n" + + "import org.apache.fory.annotation.UInt8Type;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " public List<@UInt8Type byte[]> values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n"); + } + + @Test + public void testStaticBinaryUint8ArrayUsesCompatibleCase() throws Exception { + CompilationResult writerResult = + compile( + "test.StaticCompatibleBytePayload", + "package test;\n" + + "public class StaticCompatibleBytePayload {\n" + + " public byte[] value;\n" + + " public StaticCompatibleBytePayload() {}\n" + + "}\n"); + CompilationResult readerResult = + compile( + "test.StaticCompatibleBytePayload", + "package test;\n" + + "import org.apache.fory.annotation.UInt8Type;\n" + + "public class StaticCompatibleBytePayload {\n" + + " @UInt8Type public byte[] value;\n" + + " public StaticCompatibleBytePayload() {}\n" + + "}\n"); + Assert.assertTrue(writerResult.success, writerResult.diagnostics()); + Assert.assertTrue(readerResult.success, readerResult.diagnostics()); + try (URLClassLoader writerLoader = writerResult.classLoader(); + URLClassLoader readerLoader = readerResult.classLoader()) { + Class writerType = writerLoader.loadClass("test.StaticCompatibleBytePayload"); + Class readerType = readerLoader.loadClass("test.StaticCompatibleBytePayload"); + Fory writer = compatibleFory(writerLoader, writerType, true, "byte-writer"); + Fory reader = compatibleFory(readerLoader, readerType, true, "uint8-reader"); + TypeDef remoteTypeDef = TypeDef.buildTypeDef(writer.getTypeResolver(), writerType); + Class compatibleSerializerClass = + CodecUtils.loadOrGenStaticCompatibleCodecClass( + reader.getTypeResolver(), cast(readerType), remoteTypeDef); + Serializer compatibleSerializer = + compatibleSerializerClass + .getConstructor(TypeResolver.class, Class.class, TypeDef.class) + .newInstance(reader.getTypeResolver(), readerType, remoteTypeDef); + Assert.assertEquals(remoteFields(compatibleSerializer).get(0).matchedId, 1); + + Object writerValue = writerType.getConstructor().newInstance(); + setField(writerType, writerValue, "value", new byte[] {1, 2}); + Object result = + roundTripThroughStaticCompatibleSerializer( + writer, reader, writerType, readerType, writerValue); + Assert.assertTrue( + Arrays.equals((byte[]) getField(readerType, result, "value"), new byte[] {1, 2})); + } + } + + @Test + public void testStaticRejectsRootPrimitiveArraySchema() throws Exception { + assertStaticSchemaFails( + "package test;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " public byte[] values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n", + "package test;\n" + + "import org.apache.fory.annotation.Int8Type;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " @Int8Type public byte[] values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n"); + assertStaticSchemaFails( + "package test;\n" + + "import org.apache.fory.annotation.Int8Type;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " @Int8Type public byte[] values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n", + "package test;\n" + + "import org.apache.fory.annotation.UInt8Type;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " @UInt8Type public byte[] values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n"); + assertStaticSchemaFails( + "package test;\n" + + "import org.apache.fory.annotation.Float16Type;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " @Float16Type public short[] values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n", + "package test;\n" + + "import org.apache.fory.annotation.BFloat16Type;\n" + + "public class StaticCompatibleNestedPayload {\n" + + " @BFloat16Type public short[] values;\n" + + " public StaticCompatibleNestedPayload() {}\n" + + "}\n"); + } + @Test public void testGraalCompatibleSerializerRegistryUsesLocalReaderClass() throws Exception { CompilationResult writerAResult = @@ -491,21 +664,53 @@ private static Object roundTripThroughStaticCompatibleSerializer( return reader.deserialize(bytes); } + private static void assertStaticSchemaFails(String writerSource, String readerSource) + throws Exception { + CompilationResult writerResult = compile("test.StaticCompatibleNestedPayload", writerSource); + CompilationResult readerResult = compile("test.StaticCompatibleNestedPayload", readerSource); + Assert.assertTrue(writerResult.success, writerResult.diagnostics()); + Assert.assertTrue(readerResult.success, readerResult.diagnostics()); + try (URLClassLoader writerLoader = writerResult.classLoader(); + URLClassLoader readerLoader = readerResult.classLoader()) { + Class writerType = writerLoader.loadClass("test.StaticCompatibleNestedPayload"); + Class readerType = readerLoader.loadClass("test.StaticCompatibleNestedPayload"); + Fory writer = compatibleFory(writerLoader, writerType, true, "schema-writer"); + Fory reader = compatibleFory(readerLoader, readerType, true, "schema-reader"); + TypeDef remoteTypeDef = TypeDef.buildTypeDef(writer.getTypeResolver(), writerType); + Class compatibleSerializerClass = + CodecUtils.loadOrGenStaticCompatibleCodecClass( + reader.getTypeResolver(), cast(readerType), remoteTypeDef); + InvocationTargetException exception = + Assert.expectThrows( + InvocationTargetException.class, + () -> + compatibleSerializerClass + .getConstructor(TypeResolver.class, Class.class, TypeDef.class) + .newInstance(reader.getTypeResolver(), readerType, remoteTypeDef)); + Assert.assertTrue(exception.getCause() instanceof RuntimeException, exception.toString()); + } + } + private static Map remoteCodecCategories( Serializer compatibleSerializer) throws Exception { - Field remoteFieldsField = - StaticGeneratedStructSerializer.class.getDeclaredField("remoteFields"); - remoteFieldsField.setAccessible(true); - List remoteFields = (List) remoteFieldsField.get(compatibleSerializer); + List remoteFields = remoteFields(compatibleSerializer); Map categories = new HashMap<>(); - for (Object field : remoteFields) { - RemoteFieldInfo remoteField = (RemoteFieldInfo) field; + for (RemoteFieldInfo remoteField : remoteFields) { categories.put( remoteField.descriptor.getName(), remoteField.serializationFieldInfo.codecCategory); } return categories; } + @SuppressWarnings("unchecked") + private static List remoteFields(Serializer compatibleSerializer) + throws Exception { + Field remoteFieldsField = + StaticGeneratedStructSerializer.class.getDeclaredField("remoteFields"); + remoteFieldsField.setAccessible(true); + return (List) remoteFieldsField.get(compatibleSerializer); + } + private static Fory compatibleFory( ClassLoader classLoader, Class type, boolean xlang, String role) { return compatibleFory(classLoader, type, xlang, role, true); diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/NestedTypeAnnotationTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/NestedTypeAnnotationTest.java index cb2928256d..04f2dbdfea 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/NestedTypeAnnotationTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/NestedTypeAnnotationTest.java @@ -440,17 +440,16 @@ public void compatibleMissingFieldSkipUsesRemoteNestedMetadata(boolean enableCod } @Test(dataProvider = "enableCodegen") - public void compatibleDifferentFieldMetadataUsesRemoteNestedMetadata(boolean enableCodegen) { + public void compatibleNestedScalarMismatchFails(boolean enableCodegen) { Fory writer = xlangFory(true, enableCodegen); writer.register(RemoteNestedField.class, 716); Fory reader = xlangFory(true, enableCodegen); reader.register(LocalDifferentNestedField.class, 716); RemoteNestedField value = remoteNestedField(); - LocalDifferentNestedField copy = (LocalDifferentNestedField) serDeObject(writer, reader, value); - Assert.assertEquals(copy.id, value.id); - Assert.assertEquals(copy.dropped, value.dropped); - Assert.assertEquals(copy.tail, value.tail); + Assert.expectThrows( + org.apache.fory.exception.DeserializationException.class, + () -> serDeObject(writer, reader, value)); } private static Fory xlangFory(boolean compatible, boolean enableCodegen) { diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java index 56e42dbf84..613d50342f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java @@ -31,7 +31,10 @@ import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.TypeResolver; +import org.apache.fory.resolver.XtypeResolver; +import org.apache.fory.type.Descriptor; import org.apache.fory.type.Types; +import org.apache.fory.type.union.Union; import org.apache.fory.util.MurmurHash3; import org.testng.Assert; import org.testng.annotations.Test; @@ -233,6 +236,32 @@ public static class ClassWithLargeTagIds { private int field2; } + public static class NestedUnionMapField { + private Map values; + } + + @Test + public void testNestedUnionSchemaCompare() { + Fory fory = Fory.builder().withXlang(true).build(); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(16); + buffer.writeVarUInt32Small7(Types.MAP << 2); + buffer.writeVarUInt32Small7(Types.STRING << 2); + buffer.writeVarUInt32Small7(Types.TYPED_UNION << 2); + buffer.readerIndex(0); + + FieldTypes.MapFieldType fieldType = + (FieldTypes.MapFieldType) + FieldTypes.FieldType.readCrossLanguage(buffer, (XtypeResolver) fory.getTypeResolver()); + + Assert.assertTrue(fieldType.getValueType() instanceof FieldTypes.ObjectFieldType); + Assert.assertEquals(fieldType.getValueType().getTypeId(), Types.TYPED_UNION); + + Descriptor localDescriptor = + Descriptor.getDescriptorsMap(NestedUnionMapField.class).get("values"); + new FieldInfo(NestedUnionMapField.class.getName(), "values", fieldType) + .toDescriptor(fory.getTypeResolver(), localDescriptor); + } + @Test public void testBuildFieldsInfoWithDuplicateTagIds() { Fory fory = Fory.builder().withXlang(true).withCompatible(false).withMetaShare(true).build(); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleFieldConvertTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleFieldConvertTest.java index 876dea0452..b5a500ac3f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleFieldConvertTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleFieldConvertTest.java @@ -22,13 +22,19 @@ import com.google.common.collect.ImmutableSet; import java.lang.reflect.Field; import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.annotation.BFloat16Type; +import org.apache.fory.annotation.Float16Type; import org.apache.fory.annotation.ForyField; +import org.apache.fory.annotation.Int8Type; import org.apache.fory.annotation.Nullable; import org.apache.fory.annotation.Ref; import org.apache.fory.annotation.UInt64Type; +import org.apache.fory.annotation.UInt8Type; import org.apache.fory.config.Int64Encoding; import org.apache.fory.exception.DeserializationException; import org.apache.fory.reflect.ReflectionUtils; @@ -206,6 +212,13 @@ public static final class RefStringBoolWriter { public String value = "true"; } + public static final class NullableRefStringWriter { + @Nullable + @Ref + @ForyField(id = 0) + public String value = "true"; + } + public static final class RefBooleanWriter { @Ref @ForyField(id = 0) @@ -287,6 +300,19 @@ public static final class StringReader { public String value; } + public static final class RefStringReader { + @Ref + @ForyField(id = 0) + public String value; + } + + public static final class NullableRefStringReader { + @Nullable + @Ref + @ForyField(id = 0) + public String value; + } + public static final class UnsignedLongWriter { @UInt64Type(encoding = Int64Encoding.FIXED) @ForyField(id = 0) @@ -384,6 +410,82 @@ public static final class NumberReader { public Number value; } + public static final class StringListWriter { + @ForyField(id = 0) + public List value = Collections.singletonList("1"); + } + + public static final class RefStringListWriter { + @ForyField(id = 0) + public List<@Ref String> value = Collections.singletonList("1"); + } + + public static final class PlainStringListReader { + @ForyField(id = 0) + public List value; + } + + public static final class IntListReader { + @ForyField(id = 0) + public List value; + } + + public static final class ByteArrayListWriter { + @ForyField(id = 0) + public List value = Collections.singletonList(new byte[] {1}); + } + + public static final class UInt8ByteArrayListReader { + @ForyField(id = 0) + public List<@UInt8Type byte[]> value; + } + + public static final class BinaryWriter { + @ForyField(id = 0) + public byte[] value = new byte[] {1, 2}; + } + + public static final class BinaryReader { + @ForyField(id = 0) + public byte[] value; + } + + public static final class Int8ArrayWriter { + @Int8Type + @ForyField(id = 0) + public byte[] value = new byte[] {1, 2}; + } + + public static final class Int8ArrayReader { + @Int8Type + @ForyField(id = 0) + public byte[] value; + } + + public static final class UInt8ArrayWriter { + @UInt8Type + @ForyField(id = 0) + public byte[] value = new byte[] {(byte) 200, (byte) 250}; + } + + public static final class UInt8ArrayReader { + @UInt8Type + @ForyField(id = 0) + public byte[] value; + } + + public static final class Float16ArrayWriter { + @Float16Type + @ForyField(id = 0) + public short[] value = new short[] {0, 1}; + } + + public static final class BFloat16ArrayReader { + @BFloat16Type + @ForyField(id = 0) + public short[] value; + } + @DataProvider public static Object[][] xlangAndCodegen() { return new Object[][] {{false, false}, {false, true}, {true, false}, {true, true}}; @@ -499,7 +601,7 @@ public void testScalarAssignableSupertype(boolean codegen) { } @Test(dataProvider = "codegenModes") - public void testScalarTrackingRefConversionRejected(boolean codegen) { + public void testScalarTrackingRefRejected(boolean codegen) { Fory writer = Fory.builder() .withXlang(true) @@ -518,8 +620,7 @@ public void testScalarTrackingRefConversionRejected(boolean codegen) { .withRefTracking(true) .build(); reader.register(BoolReader.class, 28000); - BoolReader value = (BoolReader) reader.deserialize(bytes); - Assert.assertFalse(value.value); + Assert.assertThrows(RuntimeException.class, () -> reader.deserialize(bytes)); } @Test @@ -556,50 +657,53 @@ public void testScalarTrackingRefClassifier() { } @Test(dataProvider = "xlang") - public void testScalarTrackingRefMismatchSkipped(boolean xlang) { - Fory refWriter = compatibleRefFory(xlang, false); - refWriter.register(RefBooleanWriter.class, 28000); - byte[] refBytes = refWriter.serialize(new RefBooleanWriter()); + public void testScalarTrackingRefMismatchRejected(boolean xlang) { + assertRefSchemaFails(new RefStringBoolWriter(), StringReader.class, xlang, false); + assertRefSchemaFails(new StringBoolWriter(), RefStringReader.class, xlang, false); + assertRefSchemaFails(new NullableRefStringWriter(), RefStringReader.class, xlang, false); + assertRefSchemaFails(new RefStringBoolWriter(), NullableRefStringReader.class, xlang, false); - Fory nonRefReader = compatibleRefFory(xlang, false); - nonRefReader.register(BoolReader.class, 28000); - BoolReader boolReader = (BoolReader) nonRefReader.deserialize(refBytes); - Assert.assertFalse(boolReader.value); + Fory nullableRefWriter = compatibleRefFory(xlang, false); + nullableRefWriter.register(NullableRefStringWriter.class, 28000); + byte[] nullableRefBytes = nullableRefWriter.serialize(new NullableRefStringWriter()); - Fory nonRefWriter = compatibleRefFory(xlang, false); - nonRefWriter.register(BoolStringWriter.class, 28000); - byte[] nonRefBytes = nonRefWriter.serialize(new BoolStringWriter()); + Fory exactNullableRefReaderFory = compatibleRefFory(xlang, false); + exactNullableRefReaderFory.register(NullableRefStringReader.class, 28000); + NullableRefStringReader exactNullableRefReader = + (NullableRefStringReader) exactNullableRefReaderFory.deserialize(nullableRefBytes); + Assert.assertEquals("true", exactNullableRefReader.value); + } - Fory refReaderFory = compatibleRefFory(xlang, false); - refReaderFory.register(RefBooleanReader.class, 28000); - RefBooleanReader refReader = (RefBooleanReader) refReaderFory.deserialize(nonRefBytes); - Assert.assertNull(refReader.value); + @Test + public void testNestedScalarMismatchRejected() { + assertSchemaFails(new StringListWriter(), IntListReader.class, true, false); + } - Fory nullableRefWriter = compatibleRefFory(xlang, false); - nullableRefWriter.register(NullableRefBooleanWriter.class, 28000); - byte[] nullableRefBytes = nullableRefWriter.serialize(new NullableRefBooleanWriter()); + @Test(dataProvider = "codegenModes") + public void testNestedRefTrackingRejected(boolean codegen) { + assertRefSchemaFails(new RefStringListWriter(), PlainStringListReader.class, true, codegen); + } - Fory requiredRefReaderFory = compatibleRefFory(xlang, false); - requiredRefReaderFory.register(RefBooleanReader.class, 28000); - RefBooleanReader requiredRefReader = - (RefBooleanReader) requiredRefReaderFory.deserialize(nullableRefBytes); - Assert.assertNull(requiredRefReader.value); + @Test(dataProvider = "codegenModes") + public void testNestedArrayTypeRejected(boolean codegen) { + assertSchemaFails(new ByteArrayListWriter(), UInt8ByteArrayListReader.class, true, codegen); + } - Fory requiredRefWriter = compatibleRefFory(xlang, false); - requiredRefWriter.register(RefBooleanWriter.class, 28000); - byte[] requiredRefBytes = requiredRefWriter.serialize(new RefBooleanWriter()); + @Test(dataProvider = "codegenModes") + public void testBinaryUint8ArrayBridge(boolean codegen) { + UInt8ArrayReader uint8Reader = + readAs(new BinaryWriter(), UInt8ArrayReader.class, true, codegen); + Assert.assertTrue(Arrays.equals(uint8Reader.value, new byte[] {1, 2})); - Fory nullableRefReaderFory = compatibleRefFory(xlang, false); - nullableRefReaderFory.register(NullableRefBooleanReader.class, 28000); - NullableRefBooleanReader nullableRefReader = - (NullableRefBooleanReader) nullableRefReaderFory.deserialize(requiredRefBytes); - Assert.assertNull(nullableRefReader.value); + BinaryReader binaryReader = readAs(new UInt8ArrayWriter(), BinaryReader.class, true, codegen); + Assert.assertTrue(Arrays.equals(binaryReader.value, new byte[] {(byte) 200, (byte) 250})); + } - Fory exactNullableRefReaderFory = compatibleRefFory(xlang, false); - exactNullableRefReaderFory.register(NullableRefBooleanReader.class, 28000); - NullableRefBooleanReader exactNullableRefReader = - (NullableRefBooleanReader) exactNullableRefReaderFory.deserialize(nullableRefBytes); - Assert.assertEquals(Boolean.TRUE, exactNullableRefReader.value); + @Test(dataProvider = "codegenModes") + public void testRootPrimitiveArrayMismatchRejected(boolean codegen) { + assertSchemaFails(new BinaryWriter(), Int8ArrayReader.class, true, codegen); + assertSchemaFails(new Int8ArrayWriter(), UInt8ArrayReader.class, true, codegen); + assertSchemaFails(new Float16ArrayWriter(), BFloat16ArrayReader.class, true, codegen); } @Test(dataProvider = "xlangAndCodegen") @@ -675,12 +779,34 @@ private static T readAs( return readerClass.cast(reader.deserialize(bytes)); } + private static T readAsRef( + Object writerObject, Class readerClass, boolean xlang, boolean codegen) { + Fory writer = compatibleRefFory(xlang, codegen); + writer.register(writerObject.getClass(), 28000); + byte[] bytes = writer.serialize(writerObject); + Fory reader = compatibleRefFory(xlang, codegen); + reader.register(readerClass, 28000); + return readerClass.cast(reader.deserialize(bytes)); + } + private static void assertConversionFails( Object writerObject, Class readerClass, boolean xlang, boolean codegen) { Assert.assertThrows( DeserializationException.class, () -> readAs(writerObject, readerClass, xlang, codegen)); } + private static void assertSchemaFails( + Object writerObject, Class readerClass, boolean xlang, boolean codegen) { + Assert.assertThrows( + RuntimeException.class, () -> readAs(writerObject, readerClass, xlang, codegen)); + } + + private static void assertRefSchemaFails( + Object writerObject, Class readerClass, boolean xlang, boolean codegen) { + Assert.assertThrows( + RuntimeException.class, () -> readAsRef(writerObject, readerClass, xlang, codegen)); + } + private static Fory compatibleFory(boolean xlang, boolean codegen) { return Fory.builder().withXlang(xlang).withCompatible(true).withCodegen(codegen).build(); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/DuplicateFieldsTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/DuplicateFieldsTest.java index 096d3315e5..da9a25882a 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/DuplicateFieldsTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/DuplicateFieldsTest.java @@ -206,22 +206,7 @@ public void testDuplicateFieldsCompatible() { assertEquals(newC, c); } { - // Use CompatibleSerializer JIT version - Serializer serializer = - Serializers.newSerializer( - fory, - C.class, - CodecUtils.loadOrGenCompatibleCodecClass( - fory, C.class, fory.getTypeResolver().getTypeDef(C.class, true))); - MemoryBuffer buffer = MemoryUtils.buffer(32); - writeSerializer(fory, serializer, buffer, c); - C newC = readSerializer(fory, serializer, buffer); - assertEquals(newC.f1, c.f1); - assertEquals(((B) newC).f1, ((B) c).f1); - assertEquals(newC, c); - } - { - // FallbackSerializer/CodegenSerializer will set itself to ClassResolver. + // The compatible generated serializer is schema-pair owned and installed by TypeResolver. Fory fory1 = builder.build(); C newC = serDeCheckSerializer(fory1, c, ".*Codec|.*Serializer"); assertEquals(newC.f1, c.f1); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/MetaShareXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/MetaShareXlangTest.java index e6082e8869..70c4e664d7 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/MetaShareXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/MetaShareXlangTest.java @@ -20,7 +20,6 @@ package org.apache.fory.xlang; import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertNull; import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; @@ -189,7 +188,7 @@ public void testTopLevelListArrayCompatibleReadWithoutCodegen() { } @Test - public void testNullableListCompatibleReadToArrayRejectsNullElements() { + public void testNullableListCompatibleReadToArray() { for (boolean codegen : new boolean[] {false, true}) { Fory listFory = compatibleFory(DirectNullableListField.class, codegen); DirectNullableListField listStruct = new DirectNullableListField(); @@ -200,30 +199,29 @@ public void testNullableListCompatibleReadToArrayRejectsNullElements() { DirectArrayField arrayStruct = (DirectArrayField) arrayFory.deserialize(listBytes); assertTrue(Arrays.equals(arrayStruct.values, new int[] {1, 2, 3})); - listStruct.values = Arrays.asList(1, null, 3); - byte[] nullablePayload = listFory.serialize(listStruct); - assertThrows(DeserializationException.class, () -> arrayFory.deserialize(nullablePayload)); + DirectNullableListField nullElementStruct = new DirectNullableListField(); + nullElementStruct.values = Arrays.asList(1, null, 3); + byte[] nullElementBytes = listFory.serialize(nullElementStruct); + assertThrows(DeserializationException.class, () -> arrayFory.deserialize(nullElementBytes)); } } @Test - public void testNestedListArrayCompatibleReadSkipped() { + public void testNestedListArrayRejected() { Fory nestedListFory = compatibleFory(NestedListField.class); NestedListField nestedListStruct = new NestedListField(); nestedListStruct.values = Arrays.asList(Arrays.asList(1, 2)); byte[] nestedListBytes = nestedListFory.serialize(nestedListStruct); Fory nestedArrayFory = compatibleFory(NestedArrayElementField.class); - NestedArrayElementField skippedNestedArrayStruct = - (NestedArrayElementField) nestedArrayFory.deserialize(nestedListBytes); - assertNull(skippedNestedArrayStruct.values); + assertThrows( + DeserializationException.class, () -> nestedArrayFory.deserialize(nestedListBytes)); NestedArrayElementField nestedArrayStruct = new NestedArrayElementField(); nestedArrayStruct.values = Arrays.asList(new int[] {1, 2}); byte[] nestedArrayBytes = nestedArrayFory.serialize(nestedArrayStruct); - NestedListField skippedNestedListStruct = - (NestedListField) nestedListFory.deserialize(nestedArrayBytes); - assertNull(skippedNestedListStruct.values); + assertThrows( + DeserializationException.class, () -> nestedListFory.deserialize(nestedArrayBytes)); Fory nestedSetListFory = compatibleFory(NestedSetListField.class, false); NestedSetListField nestedSetListStruct = new NestedSetListField(); @@ -231,16 +229,14 @@ public void testNestedListArrayCompatibleReadSkipped() { byte[] nestedSetListBytes = nestedSetListFory.serialize(nestedSetListStruct); Fory nestedSetArrayFory = compatibleFory(NestedSetArrayElementField.class, false); - NestedSetArrayElementField skippedNestedSetArrayStruct = - (NestedSetArrayElementField) nestedSetArrayFory.deserialize(nestedSetListBytes); - assertNull(skippedNestedSetArrayStruct.values); + assertThrows( + DeserializationException.class, () -> nestedSetArrayFory.deserialize(nestedSetListBytes)); NestedSetArrayElementField nestedSetArrayStruct = new NestedSetArrayElementField(); nestedSetArrayStruct.values = new LinkedHashSet<>(Arrays.asList(new int[] {1, 2})); byte[] nestedSetArrayBytes = nestedSetArrayFory.serialize(nestedSetArrayStruct); - NestedSetListField skippedNestedSetListStruct = - (NestedSetListField) nestedSetListFory.deserialize(nestedSetArrayBytes); - assertNull(skippedNestedSetListStruct.values); + assertThrows( + DeserializationException.class, () -> nestedSetListFory.deserialize(nestedSetArrayBytes)); } private static Fory compatibleFory(Class type) { diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java index fd6ceb4dcb..16d7634925 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java @@ -80,8 +80,10 @@ protected CommandContext buildCommandContext(String caseName, Path dataFile) { Map env = envBuilder(dataFile); env.put("RUSTFLAGS", "-Awarnings"); env.put("RUST_BACKTRACE", "1"); + // Rust test threads default to a small stack, and generated xlang compatible reads can exceed + // it on large schema-change structs before the test reaches Fory assertions. + env.put("RUST_MIN_STACK", "4194304"); env.put("ENABLE_FORY_DEBUG_OUTPUT", "1"); - env.put("FORY_PANIC_ON_ERROR", caseName.endsWith("_error") ? "0" : "1"); return new CommandContext(command, env, new File("../../rust")); } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java index 649492ecf3..5f4bfec2ff 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java @@ -1879,22 +1879,29 @@ protected void testListArrayCompatibleRead(boolean enableCodegen) throws java.io buffer = MemoryBuffer.newHeapBuffer(256); nullableListFory.serialize(buffer, nullableListWithoutNulls); byte[] nullableListWithoutNullsPayload = buffer.getBytes(0, buffer.writerIndex()); - XlangCompatibleInt32ArrayField nullableListWithoutNullsArray = + XlangCompatibleInt32ArrayField nullableArrayResult = (XlangCompatibleInt32ArrayField) arrayFory.deserialize(MemoryUtils.wrap(nullableListWithoutNullsPayload)); - Assert.assertEquals(nullableListWithoutNullsArray.values, new int[] {1, 2, 3}); + assertIntArrayEquals(nullableArrayResult.values, 1, 2, 3); + ctx = + prepareExecution( + "test_list_array_compatible_list_to_array", nullableListWithoutNullsPayload); + runPeer(ctx); + XlangCompatibleInt32ArrayField peerNullableArrayResult = + (XlangCompatibleInt32ArrayField) arrayFory.deserialize(readBuffer(ctx.dataFile())); + assertIntArrayEquals(peerNullableArrayResult.values, 1, 2, 3); - XlangCompatibleNullableInt32ListField nullableListValue = + XlangCompatibleNullableInt32ListField nullableListWithNull = newCompatibleNullableInt32ListField(1, null, 3); buffer = MemoryBuffer.newHeapBuffer(256); - nullableListFory.serialize(buffer, nullableListValue); - byte[] nullablePayload = buffer.getBytes(0, buffer.writerIndex()); + nullableListFory.serialize(buffer, nullableListWithNull); + byte[] nullableListWithNullPayload = buffer.getBytes(0, buffer.writerIndex()); Assert.expectThrows( DeserializationException.class, - () -> arrayFory.deserialize(MemoryUtils.wrap(nullablePayload))); + () -> arrayFory.deserialize(MemoryUtils.wrap(nullableListWithNullPayload))); ctx = prepareExecution( - "test_list_array_compatible_nullable_list_to_array_error", nullablePayload); + "test_list_array_compatible_nullable_list_to_array_error", nullableListWithNullPayload); runPeer(ctx); Path nullableDataFile = ctx.dataFile(); Assert.expectThrows( diff --git a/javascript/packages/core/lib/compatible/field.ts b/javascript/packages/core/lib/compatible/field.ts new file mode 100644 index 0000000000..40510a42c0 --- /dev/null +++ b/javascript/packages/core/lib/compatible/field.ts @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import type { TypeInfo } from "../typeInfo"; + +const skipReadActions = new WeakSet(); + +export function markCompatibleSkipRead(typeInfo: TypeInfo): TypeInfo { + skipReadActions.add(typeInfo); + return typeInfo; +} + +export function shouldSkipCompatibleRead(typeInfo: TypeInfo): boolean { + return skipReadActions.has(typeInfo); +} diff --git a/javascript/packages/core/lib/compatible/scalar.ts b/javascript/packages/core/lib/compatible/scalar.ts index d5a895666d..56759ec06d 100644 --- a/javascript/packages/core/lib/compatible/scalar.ts +++ b/javascript/packages/core/lib/compatible/scalar.ts @@ -39,7 +39,6 @@ type DecimalParts = { }; const scalarReadActions = new WeakMap(); -const scalarSkipActions = new WeakSet(); const float32Array = new Float32Array(1); const float64Buffer = new ArrayBuffer(8); @@ -76,15 +75,6 @@ export function getCompatibleScalarReadAction( return scalarReadActions.get(typeInfo); } -export function markCompatibleScalarSkipRead(typeInfo: TypeInfo): TypeInfo { - scalarSkipActions.add(typeInfo); - return typeInfo; -} - -export function shouldSkipCompatibleScalarRead(typeInfo: TypeInfo): boolean { - return scalarSkipActions.has(typeInfo); -} - export function isCompatibleScalarType(typeId: number): boolean { return scalarKind(typeId) !== undefined; } @@ -155,18 +145,6 @@ function scalarKind(typeId: number): ScalarKind | undefined { } } -function isFloatType(typeId: number): boolean { - switch (canonicalScalarTypeId(typeId)) { - case TypeId.FLOAT16: - case TypeId.BFLOAT16: - case TypeId.FLOAT32: - case TypeId.FLOAT64: - return true; - default: - return false; - } -} - function readDecimal(reader: BinaryReader): Decimal { const scale = reader.readVarInt32(); const header = reader.readVarUInt64(); @@ -189,63 +167,6 @@ function readDecimal(reader: BinaryReader): Decimal { return new Decimal((meta & 1n) === 0n ? magnitude : -magnitude, scale); } -function readScalarPayload( - reader: BinaryReader, - remoteTypeId: number, -): unknown { - switch (remoteTypeId) { - case TypeId.BOOL: { - const value = reader.readUint8(); - if (value !== 0 && value !== 1) { - throw new Error(`Invalid boolean scalar value ${value}.`); - } - return value === 1; - } - case TypeId.STRING: - return reader.stringWithHeader(); - case TypeId.INT8: - return reader.readInt8(); - case TypeId.INT16: - return reader.readInt16(); - case TypeId.INT32: - return reader.readInt32(); - case TypeId.VARINT32: - return reader.readVarInt32(); - case TypeId.INT64: - return reader.readInt64(); - case TypeId.VARINT64: - return reader.readVarInt64(); - case TypeId.TAGGED_INT64: - return reader.readTaggedInt64(); - case TypeId.UINT8: - return reader.readUint8(); - case TypeId.UINT16: - return reader.readUint16(); - case TypeId.UINT32: - return reader.readUint32(); - case TypeId.VAR_UINT32: - return reader.readVarUInt32(); - case TypeId.UINT64: - return reader.readUint64(); - case TypeId.VAR_UINT64: - return reader.readVarUInt64(); - case TypeId.TAGGED_UINT64: - return reader.readTaggedUInt64(); - case TypeId.FLOAT16: - return reader.readFloat16(); - case TypeId.BFLOAT16: - return reader.readBfloat16(); - case TypeId.FLOAT32: - return reader.readFloat32(); - case TypeId.FLOAT64: - return reader.readFloat64(); - case TypeId.DECIMAL: - return readDecimal(reader); - default: - throw new Error(`Unsupported compatible scalar type ${remoteTypeId}.`); - } -} - function pow10(exp: number): bigint { let result = 1n; for (let i = 0; i < exp; i++) { @@ -541,16 +462,6 @@ function exactInteger(value: DecimalParts): bigint { return normalized.unscaled; } -function exactSafeNumber(value: bigint): number { - const result = Number(value); - if (!Number.isSafeInteger(result) || BigInt(result) !== value) { - throw new Error( - `Scalar integer ${value.toString()} is not exactly representable as a number.`, - ); - } - return result; -} - function exactFloat(value: DecimalParts, localTypeId: number): number { let candidate = partsToNumber(value); if (!Number.isFinite(candidate)) { @@ -581,85 +492,50 @@ function exactFloat(value: DecimalParts, localTypeId: number): number { } function rangeCheckedInteger( - value: bigint, + value: number | bigint, localTypeId: number, ): number | bigint { + const integer = typeof value === "bigint" ? value : BigInt(value); switch (canonicalScalarTypeId(localTypeId)) { case TypeId.INT8: - if (value < INT8_MIN || value > INT8_MAX) + if (integer < INT8_MIN || integer > INT8_MAX) throw new Error("Scalar integer is outside int8 range."); - return Number(value); + return Number(integer); case TypeId.INT16: - if (value < INT16_MIN || value > INT16_MAX) + if (integer < INT16_MIN || integer > INT16_MAX) throw new Error("Scalar integer is outside int16 range."); - return Number(value); + return Number(integer); case TypeId.INT32: - if (value < INT32_MIN || value > INT32_MAX) + if (integer < INT32_MIN || integer > INT32_MAX) throw new Error("Scalar integer is outside int32 range."); - return Number(value); + return Number(integer); case TypeId.INT64: - if (value < INT64_MIN || value > INT64_MAX) + if (integer < INT64_MIN || integer > INT64_MAX) throw new Error("Scalar integer is outside int64 range."); - return value; + return integer; case TypeId.UINT8: - if (value < 0n || value > UINT8_MAX) + if (integer < 0n || integer > UINT8_MAX) throw new Error("Scalar integer is outside uint8 range."); - return Number(value); + return Number(integer); case TypeId.UINT16: - if (value < 0n || value > UINT16_MAX) + if (integer < 0n || integer > UINT16_MAX) throw new Error("Scalar integer is outside uint16 range."); - return Number(value); + return Number(integer); case TypeId.UINT32: - if (value < 0n || value > UINT32_MAX) + if (integer < 0n || integer > UINT32_MAX) throw new Error("Scalar integer is outside uint32 range."); - return Number(value); + return Number(integer); case TypeId.UINT64: - if (value < 0n || value > UINT64_MAX) + if (integer < 0n || integer > UINT64_MAX) throw new Error("Scalar integer is outside uint64 range."); - return value; + return integer; default: throw new Error("Target scalar type is not an integer."); } } -function valueToParts(value: unknown, remoteTypeId: number): DecimalParts { - if (value instanceof Decimal) { - return decimalToParts(value); - } - if (typeof value === "bigint") { - return integerToParts(value); - } - if (typeof value === "number") { - return isFloatType(remoteTypeId) - ? floatToParts(value) - : integerToParts(value); - } - if (typeof value === "string") { - return parseDecimalString(value); - } - if (typeof value === "boolean") { - return integerToParts(value ? 1 : 0); - } - throw new Error("Unsupported scalar value."); -} - -function convertToBool(value: unknown, remoteTypeId: number): boolean { - if (typeof value === "boolean") { - return value; - } - if (typeof value === "string") { - switch (value) { - case "0": - case "false": - return false; - case "1": - case "true": - return true; - default: - throw new Error(`Scalar string "${value}" is not a boolean value.`); - } - } - const integer = exactInteger(valueToParts(value, remoteTypeId)); +function integerToBool(value: number | bigint): boolean { + const integer = exactInteger(integerToParts(value)); if (integer === 0n) { return false; } @@ -669,72 +545,191 @@ function convertToBool(value: unknown, remoteTypeId: number): boolean { throw new Error("Scalar numeric value is not a boolean value."); } -function convertToString(value: unknown, remoteTypeId: number): string { - if (typeof value === "boolean") { - return value ? "true" : "false"; +function stringToBool(value: string): boolean { + switch (value) { + case "0": + case "false": + return false; + case "1": + case "true": + return true; + default: + throw new Error(`Scalar string "${value}" is not a boolean value.`); } - const parts = valueToParts(value, remoteTypeId); - return formatParts(parts, isFloatType(remoteTypeId)); } -function convertToDecimal(value: unknown, remoteTypeId: number): Decimal { - const parts = valueToParts(value, remoteTypeId); - const normalized = normalizeParts(parts); +function partsToDecimal(value: DecimalParts): Decimal { + const normalized = normalizeParts(value); return new Decimal(normalized.unscaled, normalized.scale); } -function convertToNumber( - value: unknown, - remoteTypeId: number, - localTypeId: number, -): number | bigint | Decimal { - if (canonicalScalarTypeId(localTypeId) === TypeId.DECIMAL) { - return convertToDecimal(value, remoteTypeId); - } - if (isFloatType(localTypeId)) { - if (typeof value === "number" && !Number.isFinite(value)) { - if (Number.isNaN(value)) { - throw new Error("Scalar NaN cannot be converted losslessly."); - } - return value; +function floatToFloat(value: number, localTypeId: number): number { + if (!Number.isFinite(value)) { + if (Number.isNaN(value)) { + throw new Error("Scalar NaN cannot be converted losslessly."); } - return exactFloat(valueToParts(value, remoteTypeId), localTypeId); - } - const integer = exactInteger(valueToParts(value, remoteTypeId)); - const result = rangeCheckedInteger(integer, localTypeId); - if (typeof result === "number") { - return exactSafeNumber(BigInt(result)); + return value; } - return result; + return exactFloat(floatToParts(value), localTypeId); } export class CompatibleScalarConverter { - static read( - reader: BinaryReader, - remoteTypeId: number, - localTypeId: number, - fieldName: string, - ): unknown { - try { - const value = readScalarPayload(reader, remoteTypeId); - if (remoteTypeId === localTypeId) { - return value; - } - switch (scalarKind(localTypeId)) { - case "bool": - return convertToBool(value, remoteTypeId); - case "string": - return convertToString(value, remoteTypeId); - case "number": - return convertToNumber(value, remoteTypeId, localTypeId); - default: - throw new Error(`Unsupported target scalar type ${localTypeId}.`); - } - } catch (error) { - const message = error instanceof Error ? error.message : String(error); - throw new Error( - `Failed to convert compatible field ${fieldName}: ${message}`, - ); + static readDecimal(reader: BinaryReader): Decimal { + return readDecimal(reader); + } + + static checkedBool(value: number): boolean { + if (value !== 0 && value !== 1) { + throw new Error(`Invalid boolean scalar value ${value}.`); } + return value === 1; + } + + static stringToBool(value: string): boolean { + return stringToBool(value); + } + + static integerToBool(value: number | bigint): boolean { + return integerToBool(value); + } + + static floatToBool(value: number): boolean { + return integerToBool(exactInteger(floatToParts(value))); + } + + static decimalToBool(value: Decimal): boolean { + return integerToBool(exactInteger(decimalToParts(value))); + } + + static floatToString(value: number): string { + return formatParts(floatToParts(value), true); + } + + static decimalToString(value: Decimal): string { + return formatParts(decimalToParts(value), false); + } + + static stringToInteger(value: string): bigint { + return exactInteger(parseDecimalString(value)); + } + + static floatToInteger(value: number): bigint { + return exactInteger(floatToParts(value)); + } + + static decimalToInteger(value: Decimal): bigint { + return exactInteger(decimalToParts(value)); + } + + static checkedInt8(value: number | bigint): number { + return rangeCheckedInteger(value, TypeId.INT8) as number; + } + + static checkedInt16(value: number | bigint): number { + return rangeCheckedInteger(value, TypeId.INT16) as number; + } + + static checkedInt32(value: number | bigint): number { + return rangeCheckedInteger(value, TypeId.INT32) as number; + } + + static checkedInt64(value: number | bigint): bigint { + return rangeCheckedInteger(value, TypeId.INT64) as bigint; + } + + static checkedUint8(value: number | bigint): number { + return rangeCheckedInteger(value, TypeId.UINT8) as number; + } + + static checkedUint16(value: number | bigint): number { + return rangeCheckedInteger(value, TypeId.UINT16) as number; + } + + static checkedUint32(value: number | bigint): number { + return rangeCheckedInteger(value, TypeId.UINT32) as number; + } + + static checkedUint64(value: number | bigint): bigint { + return rangeCheckedInteger(value, TypeId.UINT64) as bigint; + } + + static boolToDecimal(value: boolean): Decimal { + return new Decimal(value ? 1n : 0n, 0); + } + + static integerToDecimal(value: number | bigint): Decimal { + return partsToDecimal(integerToParts(value)); + } + + static floatToDecimal(value: number): Decimal { + return partsToDecimal(floatToParts(value)); + } + + static stringToDecimal(value: string): Decimal { + return partsToDecimal(parseDecimalString(value)); + } + + static integerToFloat16(value: number | bigint): number { + return exactFloat(integerToParts(value), TypeId.FLOAT16); + } + + static integerToBfloat16(value: number | bigint): number { + return exactFloat(integerToParts(value), TypeId.BFLOAT16); + } + + static integerToFloat32(value: number | bigint): number { + return exactFloat(integerToParts(value), TypeId.FLOAT32); + } + + static integerToFloat64(value: number | bigint): number { + return exactFloat(integerToParts(value), TypeId.FLOAT64); + } + + static floatToFloat16(value: number): number { + return floatToFloat(value, TypeId.FLOAT16); + } + + static floatToBfloat16(value: number): number { + return floatToFloat(value, TypeId.BFLOAT16); + } + + static floatToFloat32(value: number): number { + return floatToFloat(value, TypeId.FLOAT32); + } + + static floatToFloat64(value: number): number { + return floatToFloat(value, TypeId.FLOAT64); + } + + static decimalToFloat16(value: Decimal): number { + return exactFloat(decimalToParts(value), TypeId.FLOAT16); + } + + static decimalToBfloat16(value: Decimal): number { + return exactFloat(decimalToParts(value), TypeId.BFLOAT16); + } + + static decimalToFloat32(value: Decimal): number { + return exactFloat(decimalToParts(value), TypeId.FLOAT32); + } + + static decimalToFloat64(value: Decimal): number { + return exactFloat(decimalToParts(value), TypeId.FLOAT64); + } + + static stringToFloat16(value: string): number { + return exactFloat(parseDecimalString(value), TypeId.FLOAT16); + } + + static stringToBfloat16(value: string): number { + return exactFloat(parseDecimalString(value), TypeId.BFLOAT16); + } + + static stringToFloat32(value: string): number { + return exactFloat(parseDecimalString(value), TypeId.FLOAT32); + } + + static stringToFloat64(value: string): number { + return exactFloat(parseDecimalString(value), TypeId.FLOAT64); } } diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 8f0aa56426..d13a08138c 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -32,8 +32,8 @@ import { isCompatibleScalarPair, isCompatibleScalarType, markCompatibleScalarRead, - markCompatibleScalarSkipRead, } from "./compatible/scalar"; +import { markCompatibleSkipRead } from "./compatible/field"; type TypeResolverLike = { config: Config; @@ -47,10 +47,9 @@ type TypeResolverLike = { regenerateReadSerializer(typeInfo: TypeInfo): Serializer; }; -type RegeneratedReadSerializerCacheEntry = { +type CompatibleReadSerializerCacheEntry = { localHash: number; - localTypeInfo: TypeInfo; - serializers: Map; + serializer: Serializer; }; function remoteListElementType( @@ -562,7 +561,7 @@ export class WriteContext { export class ReadContext { private static readonly MAX_CACHED_TYPE_META = 8192; - private static readonly MAX_CACHED_REGENERATED_READ_SERIALIZER = 8192; + private static readonly MAX_CACHED_COMPATIBLE_READ_SERIALIZER = 8192; readonly reader: BinaryReader; readonly refReader: RefReader; @@ -578,10 +577,10 @@ export class ReadContext { private recentTypeMetaHeaderLows = [-1, -1, -1, -1]; private recentTypeMetaHeaderHighs = [-1, -1, -1, -1]; private recentTypeMetas: Array = [null, null, null, null]; - private regeneratedReadSerializers: WeakMap< - Serializer, - RegeneratedReadSerializerCacheEntry - > = new WeakMap(); + private compatibleReadSerializers = new Map< + number, + CompatibleReadSerializerCacheEntry + >(); private _depth = 0; private _maxDepth: number; @@ -798,17 +797,96 @@ export class ReadContext { remoteHash = ReadContext.typeMetaHeaderHash(headerLow, headerHigh); } if (expectedHash !== remoteHash) { - return this.genSerializerByTypeMetaRuntime(typeMeta, original); + const cached = this.compatibleReadSerializers.get(remoteHash); + if (cached !== undefined && cached.localHash === expectedHash) { + return cached.serializer; + } + const serializer = this.genSerializerByTypeMetaRuntime( + typeMeta, + original, + ); + if ( + this.compatibleReadSerializers.size + < ReadContext.MAX_CACHED_COMPATIBLE_READ_SERIALIZER + ) { + this.compatibleReadSerializers.set(remoteHash, { + localHash: expectedHash, + serializer, + }); + } + return serializer; } return undefined; } + private canonicalTypeId(typeId: number): number { + switch (typeId) { + case TypeId.NAMED_ENUM: + return TypeId.ENUM; + case TypeId.NAMED_EXT: + return TypeId.EXT; + case TypeId.NAMED_COMPATIBLE_STRUCT: + case TypeId.NAMED_STRUCT: + case TypeId.COMPATIBLE_STRUCT: + case TypeId.STRUCT: + return TypeId.STRUCT; + case TypeId.NAMED_UNION: + case TypeId.TYPED_UNION: + return TypeId.UNION; + default: + return typeId; + } + } + + private canonicalFieldTypeId(typeInfo: TypeInfo): number { + return this.canonicalTypeId(this.typeResolver.computeTypeId(typeInfo)); + } + + private fieldSchemasEqual( + remote: InnerFieldInfo | undefined, + local: TypeInfo | undefined, + ): boolean { + if (remote === undefined || local === undefined) { + return false; + } + if ( + this.canonicalTypeId(remote.typeId) !== this.canonicalFieldTypeId(local) + ) { + return false; + } + if ( + (remote.trackingRef === true) !== (local.trackingRef === true) + || (remote.nullable === true) !== (local.nullable === true) + ) { + return false; + } + switch (remote.typeId) { + case TypeId.MAP: + return ( + this.fieldSchemasEqual(remote.options?.key, local.options?.key) + && this.fieldSchemasEqual(remote.options?.value, local.options?.value) + ); + case TypeId.LIST: + return this.fieldSchemasEqual( + remote.options?.inner, + local.options?.inner, + ); + case TypeId.SET: + return this.fieldSchemasEqual(remote.options?.key, local.options?.key); + default: + return true; + } + } + private fieldInfoToTypeInfo( fieldInfo: InnerFieldInfo, fallbackTypeInfo?: TypeInfo, topLevel = true, ): TypeInfo { if (topLevel && fallbackTypeInfo) { + if (this.fieldSchemasEqual(fieldInfo, fallbackTypeInfo)) { + return fallbackTypeInfo.clone(); + } const compatible = this.compatibleFieldTypeInfo( fieldInfo, fallbackTypeInfo, @@ -826,8 +904,8 @@ export class ReadContext { && (fieldInfo.typeId !== fallbackTypeInfo.typeId || fieldInfo.nullable !== fallbackTypeInfo.nullable))) ) { - return markCompatibleScalarSkipRead( - this.fieldInfoToTypeInfo(fieldInfo, undefined, false), + throw new Error( + "unsupported compatible scalar tracking-ref schema mismatch", ); } if ( @@ -836,10 +914,27 @@ export class ReadContext { && (fieldInfo.trackingRef === true || fallbackTypeInfo.trackingRef === true) ) { - return markCompatibleScalarSkipRead( - this.fieldInfoToTypeInfo(fieldInfo, undefined, false), + throw new Error( + "unsupported compatible scalar tracking-ref schema mismatch", ); } + if ( + this.hasUnsupportedListArrayMismatch( + fieldInfo, + fallbackTypeInfo, + topLevel, + ) + ) { + throw new Error("unsupported compatible list/array schema mismatch"); + } + if ( + fieldInfo.typeId !== TypeId.UNKNOWN + && this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN + && this.canonicalTypeId(fieldInfo.typeId) + !== this.canonicalFieldTypeId(fallbackTypeInfo) + ) { + throw new Error("unsupported compatible field schema mismatch"); + } } if ( this.hasUnsupportedListArrayMismatch( @@ -850,6 +945,9 @@ export class ReadContext { ) { throw new Error("unsupported compatible list/array schema mismatch"); } + if (this.hasNestedSchemaMismatch(fieldInfo, fallbackTypeInfo, topLevel)) { + throw new Error("unsupported compatible field schema mismatch"); + } switch (fieldInfo.typeId) { case TypeId.MAP: return Type.map( @@ -914,13 +1012,103 @@ export class ReadContext { } } + private hasNestedSchemaMismatch( + remote: InnerFieldInfo, + local: TypeInfo | undefined, + topLevel: boolean, + ): boolean { + if (topLevel || local === undefined) { + return false; + } + if ( + this.schemaMatchTypeId(remote.typeId) + !== this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) + ) { + return true; + } + const remoteTracksRef = remote.trackingRef === true; + const localTracksRef = local.trackingRef === true; + if ( + remoteTracksRef !== localTracksRef + || ((remoteTracksRef || localTracksRef) + && (remote.nullable === true) !== (local.nullable === true)) + ) { + return true; + } + switch (remote.typeId) { + case TypeId.MAP: + return ( + local.options?.key === undefined + || local.options?.value === undefined + || this.hasNestedSchemaMismatch( + remote.options!.key!, + local.options.key, + false, + ) + || this.hasNestedSchemaMismatch( + remote.options!.value!, + local.options.value, + false, + ) + ); + case TypeId.LIST: + return ( + local.options?.inner === undefined + || this.hasNestedSchemaMismatch( + remote.options!.inner!, + local.options.inner, + false, + ) + ); + case TypeId.SET: + return ( + local.options?.key === undefined + || this.hasNestedSchemaMismatch( + remote.options!.key!, + local.options.key, + false, + ) + ); + default: + return false; + } + } + + private schemaMatchTypeId(typeId: number): number { + return this.canonicalTypeId(typeId); + } + private compatibleFieldTypeInfo( remote: InnerFieldInfo, local: TypeInfo, ): TypeInfo | undefined { + if (this.isByteSequenceRootPair(remote, local)) { + if ( + (remote.nullable === true) !== (local.nullable === true) + || (remote.trackingRef === true) !== (local.trackingRef === true) + ) { + return undefined; + } + return local.clone(); + } + if ( + this.isListArrayRootPair(remote, local) + && (remote.nullable === true + || local.nullable === true + || remote.trackingRef === true + || local.trackingRef === true) + ) { + return undefined; + } const remoteElement = remoteListElementType(remote); const localElement = denseArrayElementTypeId(local.typeId); if (remoteElement !== undefined && localElement !== undefined) { + // Nullable element schema is allowed for list -> array; actual + // null payload elements fail in the dense-array reader. Ref-tracked + // element framing is rejected here because this path stays primitive-only. + if (remoteElement.trackingRef === true) { + return undefined; + } if (compatibleArrayElementTypeId(remoteElement.typeId) !== localElement) { return undefined; } @@ -939,8 +1127,10 @@ export class ReadContext { if ( remote.trackingRef !== true && local.trackingRef !== true - && !(remote.typeId === local.typeId - && (remote.nullable === true) === (local.nullable === true)) + && !( + remote.typeId === local.typeId + && (remote.nullable === true) === (local.nullable === true) + ) && isCompatibleScalarPair(remote.typeId, local.typeId) ) { return markCompatibleScalarRead(local.clone(), { @@ -1009,25 +1199,15 @@ export class ReadContext { ); } - private getRegeneratedReadSerializerCache( - original: Serializer, - ): RegeneratedReadSerializerCacheEntry { - const localTypeInfo = original.getTypeInfo(); - const localHash = original.getHash(); - let entry = this.regeneratedReadSerializers.get(original); - if ( - entry === undefined - || entry.localTypeInfo !== localTypeInfo - || entry.localHash !== localHash - ) { - entry = { - localHash, - localTypeInfo, - serializers: new Map(), - }; - this.regeneratedReadSerializers.set(original, entry); - } - return entry; + private isByteSequenceRootPair( + remote: InnerFieldInfo, + local: TypeInfo, + ): boolean { + return ( + (remote.typeId === TypeId.BINARY + && local.typeId === TypeId.UINT8_ARRAY) + || (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) + ); } genSerializerByTypeMetaRuntime(typeMeta: TypeMeta, original?: Serializer) { @@ -1046,15 +1226,6 @@ export class ReadContext { ); } } - const cacheEntry - = original === undefined - ? undefined - : this.getRegeneratedReadSerializerCache(original); - const remoteHash = typeMeta.getHash(); - const cached = cacheEntry?.serializers.get(remoteHash); - if (cached !== undefined) { - return cached; - } let typeInfo: TypeInfo; if (original) { typeInfo = original.getTypeInfo().clone(); @@ -1067,33 +1238,34 @@ export class ReadContext { }); } const localProps = original?.getTypeInfo().options?.props; - const props = Object.fromEntries( - typeMeta.remapFieldNames(localProps).map((fieldInfo) => { + const fieldEntries = typeMeta + .remapFieldNames(localProps) + .map((fieldInfo) => { const localFieldTypeInfo = localProps?.[fieldInfo.getFieldName()]; - const fieldTypeInfo = this.fieldInfoToTypeInfo( + let fieldTypeInfo = this.fieldInfoToTypeInfo( fieldInfo, localFieldTypeInfo, ) .setNullable(fieldInfo.nullable) .setTrackingRef(fieldInfo.trackingRef) .setId(fieldInfo.fieldId); - return [fieldInfo.getFieldName(), fieldTypeInfo]; - }), + if (localFieldTypeInfo === undefined) { + fieldTypeInfo = markCompatibleSkipRead(fieldTypeInfo); + } + return { key: fieldInfo.getFieldName(), typeInfo: fieldTypeInfo }; + }); + const props = Object.fromEntries( + fieldEntries.map(({ key, typeInfo }) => [key, typeInfo]), ); typeInfo.options = { ...typeInfo.options, + preserveFieldOrder: true, + fieldEntries, props, }; const serializer = original ? this.typeResolver.generateReadSerializer(typeInfo) : this.typeResolver.regenerateReadSerializer(typeInfo); - if ( - cacheEntry !== undefined - && cacheEntry.serializers.size - < ReadContext.MAX_CACHED_REGENERATED_READ_SERIALIZER - ) { - cacheEntry.serializers.set(remoteHash, serializer); - } return serializer; } diff --git a/javascript/packages/core/lib/gen/struct.ts b/javascript/packages/core/lib/gen/struct.ts index 7a24037bae..3c791e705f 100644 --- a/javascript/packages/core/lib/gen/struct.ts +++ b/javascript/packages/core/lib/gen/struct.ts @@ -28,8 +28,8 @@ import { getCompatibleCollectionArrayReadAction } from "./collection"; import { CompatibleScalarConverter, getCompatibleScalarReadAction, - shouldSkipCompatibleScalarRead, } from "../compatible/scalar"; +import { shouldSkipCompatibleRead } from "../compatible/field"; /** * Returns true when a field's read cannot recurse and needs no depth tracking. @@ -65,8 +65,16 @@ function compatibleReadTargetExpr(typeInfo: TypeInfo, expr: string): string { } const sortProps = (typeInfo: TypeInfo, typeResolver: CodecBuilder["resolver"]) => { - const names = TypeMeta.fromTypeInfo(typeInfo, typeResolver).getFieldInfo(); const props = typeInfo.options!.props; + if (typeInfo.options!.preserveFieldOrder) { + return typeInfo.options!.fieldEntries ?? Object.entries(props!).map( + ([key, fieldTypeInfo]) => ({ + key, + typeInfo: fieldTypeInfo, + }), + ); + } + const names = TypeMeta.fromTypeInfo(typeInfo, typeResolver).getFieldInfo(); return names.map((x) => { return { key: x.fieldName, @@ -101,11 +109,30 @@ function isDirectVarInt32Field( typeInfo: TypeInfo, typeResolver: CodecBuilder["resolver"], ) { - return typeInfo.typeId === TypeId.VARINT32 - && toRefMode(typeInfo.trackingRef, typeInfo.nullable) === RefMode.NONE - && getCompatibleScalarReadAction(typeInfo) === undefined - && !shouldSkipCompatibleScalarRead(typeInfo) - && typeResolver.isMonomorphic(typeInfo, typeInfo.dynamic); + return varInt32ObjectReadKind(typeInfo, typeResolver) === "number"; +} + +function varInt32ObjectReadKind( + typeInfo: TypeInfo, + typeResolver: CodecBuilder["resolver"], +): "number" | "bigint" | null { + if ( + toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE + || !typeResolver.isMonomorphic(typeInfo, typeInfo.dynamic) + ) { + return null; + } + const scalarAction = getCompatibleScalarReadAction(typeInfo); + if (scalarAction !== undefined) { + return scalarAction.remoteNullable !== true + && scalarAction.remoteTypeId === TypeId.VARINT32 + && (scalarAction.localTypeId === TypeId.INT64 + || scalarAction.localTypeId === TypeId.VARINT64 + || scalarAction.localTypeId === TypeId.TAGGED_INT64) + ? "bigint" + : null; + } + return typeInfo.typeId === TypeId.VARINT32 ? "number" : null; } function directNumericFieldReadExpr( @@ -116,7 +143,6 @@ function directNumericFieldReadExpr( toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE || !builder.resolver.isMonomorphic(typeInfo, typeInfo.dynamic) || getCompatibleScalarReadAction(typeInfo) !== undefined - || shouldSkipCompatibleScalarRead(typeInfo) ) { return null; } @@ -162,6 +188,372 @@ function directNumericFieldReadExpr( } } +function compatibleScalarFieldReadExpr( + remoteTypeId: number, + localTypeId: number, + builder: CodecBuilder, +): string | null { + const converter = builder.getExternal(CompatibleScalarConverter.name); + const remoteRead = compatibleScalarRemoteReadExpr( + remoteTypeId, + builder, + converter, + ); + if (remoteRead === null) { + return null; + } + if (remoteTypeId === localTypeId) { + return remoteRead; + } + const remoteCanonical = canonicalScalarTypeId(remoteTypeId); + const localCanonical = canonicalScalarTypeId(localTypeId); + switch (localCanonical) { + case TypeId.BOOL: + return scalarToBoolExpr(remoteCanonical, remoteRead, converter); + case TypeId.STRING: + return scalarToStringExpr(remoteCanonical, remoteRead, converter); + case TypeId.DECIMAL: + return scalarToDecimalExpr(remoteCanonical, remoteRead, converter); + case TypeId.INT8: + case TypeId.INT16: + case TypeId.INT32: + case TypeId.INT64: + case TypeId.UINT8: + case TypeId.UINT16: + case TypeId.UINT32: + case TypeId.UINT64: + return scalarToIntegerExpr( + remoteCanonical, + localCanonical, + remoteRead, + converter, + ); + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + case TypeId.FLOAT32: + case TypeId.FLOAT64: + return scalarToFloatExpr( + remoteCanonical, + localCanonical, + remoteRead, + converter, + ); + default: + return null; + } +} + +function canonicalScalarTypeId(typeId: number): number { + switch (typeId) { + case TypeId.VARINT32: + return TypeId.INT32; + case TypeId.VARINT64: + case TypeId.TAGGED_INT64: + return TypeId.INT64; + case TypeId.VAR_UINT32: + return TypeId.UINT32; + case TypeId.VAR_UINT64: + case TypeId.TAGGED_UINT64: + return TypeId.UINT64; + default: + return typeId; + } +} + +function compatibleScalarRemoteReadExpr( + remoteTypeId: number, + builder: CodecBuilder, + converter: string, +): string | null { + const reader = builder.reader; + switch (remoteTypeId) { + case TypeId.BOOL: + return `${converter}.checkedBool(${reader.readUint8()})`; + case TypeId.INT8: + return reader.readInt8(); + case TypeId.INT16: + return reader.readInt16(); + case TypeId.INT32: + return reader.readInt32(); + case TypeId.VARINT32: + return reader.readVarInt32(); + case TypeId.INT64: + return reader.readInt64(); + case TypeId.VARINT64: + return reader.readVarInt64(); + case TypeId.TAGGED_INT64: + return reader.readTaggedInt64(); + case TypeId.UINT8: + return reader.readUint8(); + case TypeId.UINT16: + return reader.readUint16(); + case TypeId.UINT32: + return reader.readUint32(); + case TypeId.VAR_UINT32: + return reader.readVarUInt32(); + case TypeId.UINT64: + return reader.readUint64(); + case TypeId.VAR_UINT64: + return reader.readVarUInt64(); + case TypeId.TAGGED_UINT64: + return reader.readTaggedUInt64(); + case TypeId.FLOAT16: + return reader.readFloat16(); + case TypeId.BFLOAT16: + return reader.readBfloat16(); + case TypeId.FLOAT32: + return reader.readFloat32(); + case TypeId.FLOAT64: + return reader.readFloat64(); + case TypeId.STRING: + return reader.stringWithHeader(); + case TypeId.DECIMAL: + return `${converter}.readDecimal(${reader.ownName()})`; + default: + return null; + } +} + +function scalarToBoolExpr( + remoteTypeId: number, + value: string, + converter: string, +): string | null { + switch (remoteTypeId) { + case TypeId.BOOL: + return value; + case TypeId.STRING: + return `${converter}.stringToBool(${value})`; + case TypeId.DECIMAL: + return `${converter}.decimalToBool(${value})`; + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + case TypeId.FLOAT32: + case TypeId.FLOAT64: + return `${converter}.floatToBool(${value})`; + default: + return `${converter}.integerToBool(${value})`; + } +} + +function scalarToStringExpr( + remoteTypeId: number, + value: string, + converter: string, +): string | null { + switch (remoteTypeId) { + case TypeId.BOOL: + return `(${value} ? "true" : "false")`; + case TypeId.STRING: + return value; + case TypeId.DECIMAL: + return `${converter}.decimalToString(${value})`; + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + case TypeId.FLOAT32: + case TypeId.FLOAT64: + return `${converter}.floatToString(${value})`; + default: + return `${value}.toString()`; + } +} + +function scalarToDecimalExpr( + remoteTypeId: number, + value: string, + converter: string, +): string | null { + switch (remoteTypeId) { + case TypeId.BOOL: + return `${converter}.boolToDecimal(${value})`; + case TypeId.STRING: + return `${converter}.stringToDecimal(${value})`; + case TypeId.DECIMAL: + return value; + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + case TypeId.FLOAT32: + case TypeId.FLOAT64: + return `${converter}.floatToDecimal(${value})`; + default: + return `${converter}.integerToDecimal(${value})`; + } +} + +function scalarToIntegerExpr( + remoteTypeId: number, + localTypeId: number, + value: string, + converter: string, +): string | null { + const checkedMethod = checkedIntegerMethod(localTypeId); + if (checkedMethod === null) { + return null; + } + switch (remoteTypeId) { + case TypeId.BOOL: + if (localTypeId === TypeId.INT64 || localTypeId === TypeId.UINT64) { + return `(${value} ? 1n : 0n)`; + } + return `(${value} ? 1 : 0)`; + case TypeId.STRING: + return `${converter}.${checkedMethod}(${converter}.stringToInteger(${value}))`; + case TypeId.DECIMAL: + return `${converter}.${checkedMethod}(${converter}.decimalToInteger(${value}))`; + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + case TypeId.FLOAT32: + case TypeId.FLOAT64: + return `${converter}.${checkedMethod}(${converter}.floatToInteger(${value}))`; + default: + return integerToIntegerExpr(remoteTypeId, localTypeId, value, converter); + } +} + +function integerToIntegerExpr( + remoteTypeId: number, + localTypeId: number, + value: string, + converter: string, +): string { + if (remoteTypeId === localTypeId) { + return value; + } + if (integerRangeFits(remoteTypeId, localTypeId)) { + if (localTypeId === TypeId.INT64 || localTypeId === TypeId.UINT64) { + return `BigInt(${value})`; + } + return value; + } + return `${converter}.${checkedIntegerMethod(localTypeId)}(${value})`; +} + +function checkedIntegerMethod(localTypeId: number): string | null { + switch (localTypeId) { + case TypeId.INT8: + return "checkedInt8"; + case TypeId.INT16: + return "checkedInt16"; + case TypeId.INT32: + return "checkedInt32"; + case TypeId.INT64: + return "checkedInt64"; + case TypeId.UINT8: + return "checkedUint8"; + case TypeId.UINT16: + return "checkedUint16"; + case TypeId.UINT32: + return "checkedUint32"; + case TypeId.UINT64: + return "checkedUint64"; + default: + return null; + } +} + +function integerRangeFits(remoteTypeId: number, localTypeId: number): boolean { + switch (localTypeId) { + case TypeId.INT16: + return remoteTypeId === TypeId.INT8 || remoteTypeId === TypeId.UINT8; + case TypeId.INT32: + return remoteTypeId === TypeId.INT8 + || remoteTypeId === TypeId.INT16 + || remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16; + case TypeId.INT64: + return remoteTypeId === TypeId.INT8 + || remoteTypeId === TypeId.INT16 + || remoteTypeId === TypeId.INT32 + || remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16 + || remoteTypeId === TypeId.UINT32; + case TypeId.UINT16: + return remoteTypeId === TypeId.UINT8; + case TypeId.UINT32: + return remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16; + case TypeId.UINT64: + return remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16 + || remoteTypeId === TypeId.UINT32; + default: + return false; + } +} + +function scalarToFloatExpr( + remoteTypeId: number, + localTypeId: number, + value: string, + converter: string, +): string | null { + const method = floatMethod("Float", localTypeId); + if (method === null) { + return null; + } + switch (remoteTypeId) { + case TypeId.BOOL: + return `(${value} ? 1 : 0)`; + case TypeId.STRING: + return `${converter}.stringTo${method}(${value})`; + case TypeId.DECIMAL: + return `${converter}.decimalTo${method}(${value})`; + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + case TypeId.FLOAT32: + case TypeId.FLOAT64: + if (remoteTypeId === localTypeId) { + return value; + } + return `${converter}.floatTo${method}(${value})`; + default: + if (integerRangeFitsFloat(remoteTypeId, localTypeId)) { + return `Number(${value})`; + } + return `${converter}.integerTo${method}(${value})`; + } +} + +function floatMethod(prefix: string, localTypeId: number): string | null { + switch (localTypeId) { + case TypeId.FLOAT16: + return `${prefix}16`; + case TypeId.BFLOAT16: + return `${prefix === "Float" ? "Bfloat" : "bfloat"}16`; + case TypeId.FLOAT32: + return `${prefix}32`; + case TypeId.FLOAT64: + return `${prefix}64`; + default: + return null; + } +} + +function integerRangeFitsFloat( + remoteTypeId: number, + localTypeId: number, +): boolean { + switch (localTypeId) { + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + return remoteTypeId === TypeId.INT8 || remoteTypeId === TypeId.UINT8; + case TypeId.FLOAT32: + return remoteTypeId === TypeId.INT8 + || remoteTypeId === TypeId.INT16 + || remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16; + case TypeId.FLOAT64: + return remoteTypeId === TypeId.INT8 + || remoteTypeId === TypeId.INT16 + || remoteTypeId === TypeId.INT32 + || remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16 + || remoteTypeId === TypeId.UINT32; + default: + return false; + } +} + class StructSerializerGenerator extends BaseSerializerGenerator { typeInfo: TypeInfo; sortedProps: { key: string; typeInfo: TypeInfo }[]; @@ -196,8 +588,8 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const { nullable = false, dynamic, trackingRef } = fieldTypeInfo; const refMode = toRefMode(trackingRef, nullable); const assignCompatible = (expr: string) => assignStmt(compatibleReadTargetExpr(fieldTypeInfo, expr)); - const discard = (expr: string) => `${expr};`; - if (shouldSkipCompatibleScalarRead(fieldTypeInfo)) { + if (shouldSkipCompatibleRead(fieldTypeInfo)) { + const discard = (expr: string) => `${expr};`; if (this.builder.resolver.isMonomorphic(fieldTypeInfo, dynamic)) { if (refMode == RefMode.TRACKING || refMode === RefMode.NULL_ONLY) { return embedGenerator.readRefWithoutTypeInfo(discard); @@ -214,8 +606,16 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } const scalarAction = getCompatibleScalarReadAction(fieldTypeInfo); if (scalarAction) { - const converter = this.builder.getExternal(CompatibleScalarConverter.name); - const readValue = `${converter}.read(${this.builder.reader.ownName()}, ${scalarAction.remoteTypeId}, ${scalarAction.localTypeId}, ${CodecBuilder.safeString(fieldName)})`; + const readValue = compatibleScalarFieldReadExpr( + scalarAction.remoteTypeId, + scalarAction.localTypeId, + this.builder, + ); + if (readValue === null) { + throw new Error( + `Unsupported compatible scalar conversion from ${scalarAction.remoteTypeId} to ${scalarAction.localTypeId}`, + ); + } if (scalarAction.remoteNullable !== true) { return assignStmt(readValue); } @@ -449,8 +849,11 @@ class StructSerializerGenerator extends BaseSerializerGenerator { : ` const ${result} = { ${this.sortedProps.map(({ key }) => { + if (shouldSkipCompatibleRead(this.typeInfo.options!.props![key])) { + return ""; + } return `${CodecBuilder.safePropName(key)}: null`; - }).join(",\n")} + }).filter(Boolean).join(",\n")} }; ` } @@ -482,7 +885,19 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } const fields: Array<{ key: string; expr: string }> = []; for (const { key, typeInfo } of this.sortedProps) { - const expr = directNumericFieldReadExpr(typeInfo, this.builder); + if (shouldSkipCompatibleRead(typeInfo)) { + return null; + } + const scalarAction = getCompatibleScalarReadAction(typeInfo); + const expr = scalarAction?.remoteNullable === true + ? null + : scalarAction + ? compatibleScalarFieldReadExpr( + scalarAction.remoteTypeId, + scalarAction.localTypeId, + this.builder, + ) + : directNumericFieldReadExpr(typeInfo, this.builder); if (expr === null) { return null; } @@ -505,12 +920,24 @@ class StructSerializerGenerator extends BaseSerializerGenerator { if ( this.typeInfo.options!.withConstructor || this.sortedProps.length === 0 - || !this.sortedProps.every(({ typeInfo }) => - isDirectVarInt32Field(typeInfo, this.builder.resolver) - ) ) { return null; } + const fields = []; + for (const { key, typeInfo } of this.sortedProps) { + if (shouldSkipCompatibleRead(typeInfo)) { + return null; + } + const kind = varInt32ObjectReadKind(typeInfo, this.builder.resolver); + if (kind === null) { + return null; + } + fields.push({ + key, + kind, + local: this.scope.uniqueName(key), + }); + } const cursor = this.scope.uniqueName("cursor"); const dataView = this.scope.uniqueName("dataView"); const byteLength = this.scope.uniqueName("byteLength"); @@ -518,11 +945,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const byte = this.scope.uniqueName("byte"); const value = this.scope.uniqueName("value"); const result = this.scope.uniqueName("result"); - const fields = this.sortedProps.map(({ key }) => ({ - key, - local: this.scope.uniqueName(key), - })); - const reads = fields.map(({ local }) => ` + const reads = fields.map(({ local, kind }) => ` if (${byteLength} - ${cursor} >= 5) { ${fourByteValue} = ${dataView}.getUint32(${cursor}, true); ${cursor}++; @@ -562,7 +985,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } } } - const ${local} = (${value} >>> 1) ^ -(${value} & 1); + const ${local} = ${kind === "bigint" ? "BigInt(" : ""}(${value} >>> 1) ^ -(${value} & 1)${kind === "bigint" ? ")" : ""}; `).join("\n"); return ` let ${cursor} = ${this.builder.reader.readGetCursor()}; diff --git a/javascript/packages/core/lib/meta/TypeMeta.ts b/javascript/packages/core/lib/meta/TypeMeta.ts index 351b21d49a..c4cfed4af7 100644 --- a/javascript/packages/core/lib/meta/TypeMeta.ts +++ b/javascript/packages/core/lib/meta/TypeMeta.ts @@ -287,6 +287,7 @@ function nonStructTypeId(kindCode: number): number { export class TypeMeta { private headerHash: number | null; private readonly compressed: boolean; + private bytes: Uint8Array | null = null; private constructor( private fields: FieldInfo[], @@ -774,6 +775,9 @@ export class TypeMeta { if (this.compressed) { throw new Error("compressed TypeMeta is not supported yet"); } + if (this.bytes !== null) { + return this.bytes; + } const writer = new BinaryWriter({}); writer.writeUint8(-1); // placeholder for header, update later @@ -827,7 +831,8 @@ export class TypeMeta { const buffer = writer.dump(); - return this.prependHeader(buffer, false); + this.bytes = this.prependHeader(buffer, false); + return this.bytes; } writePkgName(writer: BinaryWriter, pkg: string) { diff --git a/javascript/packages/core/lib/typeInfo.ts b/javascript/packages/core/lib/typeInfo.ts index acfd33e42e..2535004807 100644 --- a/javascript/packages/core/lib/typeInfo.ts +++ b/javascript/packages/core/lib/typeInfo.ts @@ -62,6 +62,8 @@ class ExtensibleFunction extends Function { interface TypeInfoOptions { props?: { [key: string]: TypeInfo }; + fieldEntries?: Array<{ key: string; typeInfo: TypeInfo }>; + preserveFieldOrder?: boolean; withConstructor?: boolean; creator?: Function; key?: TypeInfo; diff --git a/javascript/test/typemeta.test.ts b/javascript/test/typemeta.test.ts index f2e7013824..83318c01e5 100644 --- a/javascript/test/typemeta.test.ts +++ b/javascript/test/typemeta.test.ts @@ -70,6 +70,13 @@ function readCompatibleScalar( return reader.deserialize(writer.serialize({ value })); } +function typeMetaRecord(typeMeta: TypeMeta): Uint8Array { + const writer = new BinaryWriter({}); + writer.writeVarUInt32(0); + writer.buffer(typeMeta.toBytes()); + return writer.dump(); +} + describe("typemeta", () => { test("splits dotted names", () => { const structInfo = Type.struct({ typeName: "com.example.User" }, {}); @@ -299,7 +306,7 @@ describe("typemeta", () => { expect(result).toEqual({ value: 123 }); }); - test("does not retain regenerated compatible serializer after a schema mismatch read", () => { + test("changed-schema reader does not replace the original serializer", () => { const changedWriterFory = new Fory({ compatible: true }); const localWriterFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); @@ -343,7 +350,7 @@ describe("typemeta", () => { expect(serializer.getTypeInfo().named).toBe("example$repro_struct"); }); - test("caches regenerated compatible readers for alternating nested schemas", () => { + test("caches compatible readers for alternating nested schemas", () => { const stringWriterFory = new Fory({ compatible: true }); const boolWriterFory = new Fory({ compatible: true }); const localWriterFory = new Fory({ compatible: true }); @@ -407,6 +414,52 @@ describe("typemeta", () => { expect(generatedReaders).toBe(2); }); + test("compatible reader cache uses remote hash and local stale guard", () => { + const typeMeta = TypeMeta.fromTypeInfo( + Type.struct(7313, { + value: Type.string().setId(1), + }), + ); + const bytes = typeMetaRecord(typeMeta); + const config = { ref: false, useSliceString: false, hooks: {} } as any; + const context = new ReadContext( + { + config, + trackingRef: false, + computeTypeId: (typeInfo: any) => typeInfo.typeId, + getSerializerById: () => undefined, + getSerializerByName: () => undefined, + getSerializerByData: () => undefined, + isCompatible: () => true, + generateReadSerializer: () => { + throw new Error("unused"); + }, + regenerateReadSerializer: () => { + throw new Error("unused"); + }, + } as any, + config, + ); + const serializers = [{ name: "localA" }, { name: "localB" }] as any[]; + let generatedReaders = 0; + (context as any).genSerializerByTypeMetaRuntime = () => + serializers[generatedReaders++]; + const localHashA = typeMeta.getHash() + 1; + const localHashB = typeMeta.getHash() + 2; + const originalA = { name: "originalA" } as any; + const originalB = { name: "originalB" } as any; + const readChanged = (localHash: number, original: any) => { + context.reset(bytes); + return context.readTypeMetaIfSchemaChanged(localHash, original); + }; + + expect(readChanged(localHashA, originalA)).toBe(serializers[0]); + expect(readChanged(localHashA, originalB)).toBe(serializers[0]); + expect(readChanged(localHashB, originalA)).toBe(serializers[1]); + expect(readChanged(localHashB, originalB)).toBe(serializers[1]); + expect(generatedReaders).toBe(2); + }); + test("remaps compatible tag-id fields onto local property names during regeneration", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); @@ -466,10 +519,113 @@ describe("typemeta", () => { expect(decimalResult.value.equals(decimal(0n, 0))).toBe(true); }); + test("generates direct compatible scalar reads", () => { + const generated: string[] = []; + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ + compatible: true, + hooks: { + afterCodeGenerated: (code) => { + generated.push(code); + return code; + }, + }, + }); + const writer = writerFory.register( + Type.struct(7261, { + wide: Type.int32({ encoding: "fixed" }).setId(1), + flag: Type.string().setId(2), + label: Type.bool().setId(3), + narrow: Type.int64({ encoding: "fixed" }).setId(4), + }), + ); + const reader = readerFory.register( + Type.struct(7261, { + wide: Type.int64({ encoding: "fixed" }).setId(1), + flag: Type.bool().setId(2), + label: Type.string().setId(3), + narrow: Type.int32({ encoding: "fixed" }).setId(4), + }), + ); + + const result = reader.deserialize( + writer.serialize({ + wide: 7, + flag: "true", + label: false, + narrow: 42n, + }), + ); + const source = generated.join("\n"); + + expect(result).toEqual({ + wide: 7n, + flag: true, + label: "false", + narrow: 42, + }); + expect(source).toContain("BigInt(br.readInt32())"); + expect(source).toContain( + "external.CompatibleScalarConverter.stringToBool(br.stringWithHeader())", + ); + expect(source).toContain( + '(external.CompatibleScalarConverter.checkedBool(br.readUint8()) ? "true" : "false")', + ); + expect(source).toContain( + "external.CompatibleScalarConverter.checkedInt32(br.readInt64())", + ); + expect(source).not.toContain("CompatibleScalarConverter.read("); + expect(source).not.toContain("remoteTypeId"); + expect(source).not.toContain("scalarKind("); + }); + + test("preserves regenerated compatible remote field order", () => { + const generated: string[] = []; + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ + compatible: true, + hooks: { + afterCodeGenerated: (code) => { + generated.push(code); + return code; + }, + }, + }); + const writer = writerFory.register( + Type.struct(7264, { + remoteFirst: Type.string().setId(1), + remoteSecond: Type.string().setId(2), + }), + ); + const reader = readerFory.register( + Type.struct(7264, { + "10": Type.string().setId(1), + "1": Type.string().setId(2), + }), + ); + + const result = reader.deserialize( + writer.serialize({ + remoteFirst: "first", + remoteSecond: "second", + }), + ); + const source = generated.join("\n"); + + expect(result).toEqual({ + "10": "first", + "1": "second", + }); + const firstRead = source.indexOf('["10"] = result_'); + const secondRead = source.indexOf('["1"] = result_'); + expect(firstRead).toBeGreaterThanOrEqual(0); + expect(secondRead).toBeGreaterThan(firstRead); + }); + test("rejects invalid bool scalars", () => { expect(() => readCompatibleScalar(7225, Type.string(), Type.bool(), "yes"), - ).toThrow(/compatible field value/); + ).toThrow(/not a boolean value/); expect(() => readCompatibleScalar( 7226, @@ -477,7 +633,7 @@ describe("typemeta", () => { Type.bool(), 2, ), - ).toThrow(/compatible field value/); + ).toThrow(/not a boolean value/); }); test("converts exact number scalars", () => { @@ -542,16 +698,16 @@ describe("typemeta", () => { test("rejects inexact number scalars", () => { expect(() => readCompatibleScalar(7232, Type.string(), Type.float64(), "0.1"), - ).toThrow(/compatible field value/); + ).toThrow(/not exactly representable/); expect(() => readCompatibleScalar(7248, Type.string(), Type.int32(), "+1"), - ).toThrow(/compatible field value/); + ).toThrow(/Invalid scalar string/); expect(() => readCompatibleScalar(7249, Type.string(), Type.float64(), ".5"), - ).toThrow(/compatible field value/); + ).toThrow(/Invalid scalar string/); expect(() => readCompatibleScalar(7250, Type.string(), Type.float64(), "1."), - ).toThrow(/compatible field value/); + ).toThrow(/Invalid scalar string/); expect(() => readCompatibleScalar( 7251, @@ -559,7 +715,7 @@ describe("typemeta", () => { Type.decimal(), "1".repeat(257), ), - ).toThrow(/compatible field value/); + ).toThrow(/Invalid scalar string/); expect(() => readCompatibleScalar( 7253, @@ -567,16 +723,16 @@ describe("typemeta", () => { Type.decimal(), `0.${"0".repeat(319)}`, ), - ).toThrow(/compatible field value/); + ).toThrow(/Invalid scalar string/); expect(() => readCompatibleScalar(7257, Type.string(), Type.decimal(), "1e1000000"), - ).toThrow(/compatible field value/); + ).toThrow(/Invalid scalar string/); expect(() => readCompatibleScalar(7258, Type.string(), Type.decimal(), "1e256"), - ).toThrow(/compatible field value/); + ).toThrow(/Invalid scalar string/); expect(() => readCompatibleScalar(7233, Type.decimal(), Type.int32(), decimal(5n, 1)), - ).toThrow(/compatible field value/); + ).toThrow(/not an integer/); expect(() => readCompatibleScalar( 7259, @@ -584,7 +740,7 @@ describe("typemeta", () => { Type.string(), decimal(1n, -256), ), - ).toThrow(/compatible field value/); + ).toThrow(/magnitude exceeds compatible conversion limit/); expect(() => readCompatibleScalar( 7234, @@ -592,10 +748,10 @@ describe("typemeta", () => { Type.int8(), 128, ), - ).toThrow(/compatible field value/); + ).toThrow(/outside int8 range/); expect(() => readCompatibleScalar(7235, Type.float64(), Type.string(), Number.NaN), - ).toThrow(/compatible field value/); + ).toThrow(/Non-finite scalar value NaN/); }); test("composes scalar conversion with nulls", () => { @@ -625,69 +781,181 @@ describe("typemeta", () => { ).toEqual({ value: true }); }); - test("applies tracking-ref scalar rules", () => { + test("rejects tracking-ref scalar mismatches", () => { const writerFory = new Fory({ compatible: true, ref: true }); const readerFory = new Fory({ compatible: true, ref: true }); class RemoteScalars { flag = "true"; - count = "1"; - same = true; - remoteNullable = true; - localNullable = true; - bothNullable = true; } class LocalScalars { flag = false; - count = 0; - same = false; - remoteNullable = false; - localNullable = false; - bothNullable = false; } Type.struct(7254, { flag: Type.string().setId(1).setTrackingRef(true), - count: Type.string().setId(2), - same: Type.bool().setId(3).setTrackingRef(true), - remoteNullable: Type.bool() - .setId(4) - .setTrackingRef(true) - .setNullable(true), - localNullable: Type.bool().setId(5).setTrackingRef(true), - bothNullable: Type.bool().setId(6).setTrackingRef(true).setNullable(true), })(RemoteScalars); Type.struct(7254, { flag: Type.bool().setId(1), - count: Type.int32().setId(2).setTrackingRef(true), - same: Type.bool().setId(3).setTrackingRef(true), - remoteNullable: Type.bool().setId(4).setTrackingRef(true), - localNullable: Type.bool() - .setId(5) - .setTrackingRef(true) - .setNullable(true), - bothNullable: Type.bool().setId(6).setTrackingRef(true).setNullable(true), })(LocalScalars); const writer = writerFory.register(RemoteScalars); const reader = readerFory.register(LocalScalars); - const result = reader.deserialize(writer.serialize(new RemoteScalars())); - expect(result).toBeInstanceOf(LocalScalars); - expect(result.flag).toBe(false); - expect(result.count).toBe(0); - expect(result.same).toBe(true); - expect(result.remoteNullable).toBe(false); - expect(result.localNullable).toBe(false); - expect(result.bothNullable).toBe(true); + expect(() => + reader.deserialize(writer.serialize(new RemoteScalars())), + ).toThrow(/unsupported compatible scalar tracking-ref schema mismatch/); }); - test("keeps nested scalars unconverted", () => { - expect( + test("rejects incompatible matched fields", () => { + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ compatible: true }); + const writer = writerFory.register( + Type.struct(7255, { + value: Type.string().setId(1), + }), + ); + const reader = readerFory.register( + Type.struct(7255, { + value: Type.map(Type.string(), Type.int32()).setId(1), + }), + ); + + expect(() => + reader.deserialize(writer.serialize({ value: "abc" })), + ).toThrow(/unsupported compatible field schema mismatch/); + }); + + test("rejects nested scalar mismatches", () => { + expect(() => readCompatibleScalar( 7238, Type.list(Type.string()), Type.list(Type.int32()), ["1", "2"], ), - ).toEqual({ value: ["1", "2"] }); + ).toThrow(/unsupported compatible field schema mismatch/); + + expect(() => + readCompatibleScalar( + 7240, + Type.map(Type.string(), Type.string()), + Type.map(Type.string(), Type.int32()), + new Map([["one", "1"]]), + ), + ).toThrow(/unsupported compatible field schema mismatch/); + }); + + test("rejects nested collection shape drift", () => { + expect(() => + readCompatibleScalar( + 7263, + Type.list(Type.string()), + Type.list(Type.map(Type.string(), Type.int32())), + ["one"], + ), + ).toThrow(/unsupported compatible field schema mismatch/); + + expect(() => + readCompatibleScalar( + 7264, + Type.map(Type.string(), Type.list(Type.string())), + Type.map( + Type.string(), + Type.list(Type.map(Type.string(), Type.int32())), + ), + new Map([["values", ["one"]]]), + ), + ).toThrow(/unsupported compatible field schema mismatch/); + + expect(() => + readCompatibleScalar( + 7268, + Type.list(Type.any()), + Type.list(Type.struct(7269, { name: Type.string() })), + ["one"], + ), + ).toThrow(/unsupported compatible field schema mismatch/); + }); + + test("reads nested scalar nullable drift", () => { + expect( + readCompatibleScalar( + 7241, + Type.list(Type.string().setNullable(true)), + Type.list(Type.string()), + ["a", null], + ), + ).toEqual({ value: ["a", null] }); + expect( + readCompatibleScalar( + 7246, + Type.map(Type.string(), Type.string().setNullable(true)), + Type.map(Type.string(), Type.string()), + new Map([["a", null]]), + ), + ).toEqual({ value: new Map([["a", null]]) }); + }); + + test("rejects nested scalar tracking-ref drift", () => { + expect(() => + readCompatibleScalar( + 7242, + Type.list(Type.string().setTrackingRef(true)), + Type.list(Type.string()), + ["a"], + ), + ).toThrow(/unsupported compatible field schema mismatch/); + }); + + test("reuses local struct metadata across struct wire families", () => { + const fory = new Fory({ compatible: true }); + const readContext = (fory as any).readContext; + const local = Type.struct(7243, { + name: Type.string(), + }); + const remote = { + typeId: TypeId.STRUCT, + nullable: false, + trackingRef: false, + options: {}, + }; + + const regenerated = readContext.fieldInfoToTypeInfo(remote, local); + + expect(regenerated.typeId).toBe(local.typeId); + expect(regenerated.options.props.name.typeId).toBe(TypeId.STRING); + }); + + test("reuses local ext metadata across ext wire families", () => { + const fory = new Fory({ compatible: true }); + const readContext = (fory as any).readContext; + const numericLocal = Type.ext(7244); + const namedRemote = { + typeId: TypeId.NAMED_EXT, + nullable: false, + trackingRef: false, + options: {}, + }; + + const numericRegenerated = readContext.fieldInfoToTypeInfo( + namedRemote, + numericLocal, + ); + expect(numericRegenerated.typeId).toBe(numericLocal.typeId); + expect(numericRegenerated.userTypeId).toBe(7244); + + const namedLocal = Type.ext("example.External"); + const numericRemote = { + typeId: TypeId.EXT, + nullable: false, + trackingRef: false, + options: {}, + }; + + const namedRegenerated = readContext.fieldInfoToTypeInfo( + numericRemote, + namedLocal, + ); + expect(namedRegenerated.typeId).toBe(namedLocal.typeId); + expect(namedRegenerated.typeName).toBe("External"); }); test("keeps same-schema scalar reads direct", () => { @@ -820,7 +1088,46 @@ describe("typemeta", () => { expect(result).toEqual({ values: [1, 2, 3] }); }); - test("rejects compatible list to dense array when payload has nullable elements", () => { + test("adapts immediate binary and uint8 array field pairs", () => { + const bytes = new Uint8Array([0, 1, 2, 250, 255]); + expect( + Array.from( + readCompatibleScalar(7265, Type.binary(), Type.uint8Array(), bytes) + .value as Uint8Array, + ), + ).toEqual(Array.from(bytes)); + + expect( + Array.from( + readCompatibleScalar(7266, Type.uint8Array(), Type.binary(), bytes) + .value as Uint8Array, + ), + ).toEqual(Array.from(bytes)); + + expect( + Array.from( + readCompatibleScalar( + 7270, + Type.binary().setTrackingRef(true), + Type.uint8Array().setTrackingRef(true), + bytes, + ).value as Uint8Array, + ), + ).toEqual(Array.from(bytes)); + }); + + test("rejects nested binary and uint8 array positions", () => { + expect(() => + readCompatibleScalar( + 7267, + Type.list(Type.binary()), + Type.list(Type.uint8Array()), + [new Uint8Array([1, 2])], + ), + ).toThrow(/unsupported compatible field schema mismatch/); + }); + + test("adapts compatible nullable list schema to dense array", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); @@ -834,18 +1141,37 @@ describe("typemeta", () => { }); const serializer = writerFory.register(writerType); - const nonNullBytes = serializer.serialize({ + const bytes = serializer.serialize({ values: [1, 2, 3], }); - const result = readerFory.register(readerType).deserialize(nonNullBytes); - expect(Array.from(result.values as Int32Array)).toEqual([1, 2, 3]); + const reader = readerFory.register(readerType); + const decoded = reader.deserialize(bytes); + expect(Array.from(decoded.values as Int32Array)).toEqual([1, 2, 3]); - const nullableBytes = serializer.serialize({ + const nullBytes = serializer.serialize({ values: [1, null, 3], }); + expect(() => reader.deserialize(nullBytes)).toThrow(/nullable/); + }); + + test("rejects compatible list and dense array root framing drift", () => { + expect(() => + readCompatibleScalar( + 7261, + Type.list(Type.int32({ encoding: "fixed" })).setNullable(true), + Type.int32Array(), + [1, 2, 3], + ), + ).toThrow(/list\/array/); + expect(() => - readerFory.register(readerType).deserialize(nullableBytes), - ).toThrow(); + readCompatibleScalar( + 7262, + Type.int32Array(), + Type.list(Type.int32({ encoding: "fixed" })).setNullable(true), + new Int32Array([1, 2, 3]), + ), + ).toThrow(/list\/array/); }); test("rejects incompatible immediate list and dense array element fields", () => { @@ -888,7 +1214,7 @@ describe("typemeta", () => { ); }); - test("keeps compatible named schema evolution working when field count differs", () => { + test("skips remote-only named compatible fields", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); @@ -908,8 +1234,8 @@ describe("typemeta", () => { expect(result).toEqual({ bar: "hello", - bar2: 123, }); + expect((result as { bar2?: number }).bar2).toBeUndefined(); }); test("remaps regenerated compatible field names onto local snake_case properties", () => { @@ -946,7 +1272,9 @@ describe("typemeta", () => { expect( (result as ReaderHolder & { animalMap?: Map }).animalMap, ).toBeUndefined(); - expect((result as ReaderHolder & { marker?: number }).marker).toBe(99); + expect( + (result as ReaderHolder & { marker?: number }).marker, + ).toBeUndefined(); }); test("skips unknown named custom fields by falling back to any when no local field exists", () => { diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt index 099fc1bdad..0c288ef604 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt @@ -75,6 +75,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append("import org.apache.fory.serializer.FieldGroups\n") builder.append("import org.apache.fory.serializer.FieldGroups.SerializationFieldInfo\n") builder.append("import org.apache.fory.serializer.StaticGeneratedStructSerializer\n") + builder.append("import org.apache.fory.serializer.converter.FieldConverters\n") builder.append("import org.apache.fory.type.Descriptor\n") builder.append("import org.apache.fory.type.BFloat16Array\n") builder.append("import org.apache.fory.type.Float16Array\n") @@ -280,113 +281,6 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru return } - builder - .append(" private fun readCompatibleConstructor(readContext: ReadContext): ") - .append(struct.typeName) - .append(" {\n") - builder.append(" val fieldValues = arrayOfNulls(DESCRIPTORS.size)\n") - builder.append(" val bufferedFields = newFieldBits(DESCRIPTORS.size)\n") - builder.append(" val presentFields = newFieldBits(DESCRIPTORS.size)\n") - builder.append(" beginConstructorRef(readContext)\n") - builder.append(" try {\n") - builder.append(" var remaining = countConstructorFields(constructorFieldBits!!)\n") - builder.append(" var value: ").append(struct.typeName).append("? = null\n") - builder.append(" if (remaining == 0) {\n") - builder.append(" val constructed = newConstructorObject(fieldValues)\n") - builder.append(" value = constructed\n") - builder.append(" referenceConstructorRef(readContext, constructed)\n") - builder.append(" }\n") - builder.append(" for (i in remoteFields.indices) {\n") - builder.append(" val remoteField = remoteFields[i]\n") - builder.append(" val fieldId = remoteField.matchedId\n") - builder.append(" if (fieldId < 0) {\n") - builder.append(" skipField(readContext, remoteField)\n") - builder.append(" continue\n") - builder.append(" }\n") - builder.append(" val localField = fieldsById[fieldId]!!\n") - builder.append(" if (!canReadGeneratedField(remoteField, localField)) {\n") - builder.append(" skipField(readContext, remoteField)\n") - builder.append(" continue\n") - builder.append(" }\n") - builder.append( - " val fieldValue = readCompatibleConstructorField(readContext, remoteField, localField, fieldId)\n" - ) - builder.append(" markField(presentFields, fieldId)\n") - builder.append(" if (hasField(constructorFieldBits!!, fieldId)) {\n") - builder.append( - " fieldValues[fieldId] = ctorFieldValue(readContext, fieldValue, type)\n" - ) - builder.append(" remaining--\n") - builder.append(" if (remaining == 0) {\n") - builder.append(" checkNoUnresolvedReadRef(readContext)\n") - builder.append(" val constructed = newConstructorObject(fieldValues)\n") - builder.append(" value = constructed\n") - builder.append(" referenceConstructorRef(readContext, constructed)\n") - builder.append(" setBufferedFields(constructed, fieldValues, bufferedFields)\n") - builder.append(" }\n") - builder.append(" } else if (value == null) {\n") - builder.append( - " fieldValues[fieldId] = bufferFieldValue(readContext, fieldValue, type)\n" - ) - builder.append(" markField(bufferedFields, fieldId)\n") - builder.append(" } else {\n") - builder.append(" setFieldById(value!!, localField, fieldId, fieldValue)\n") - builder.append(" }\n") - builder.append(" }\n") - for (field in struct.fields) { - if (field.hasDefault || field.nullable) { - continue - } - builder.append(" if (!hasField(presentFields, ").append(field.id).append(")) {\n") - builder - .append(" throw DeserializationException(\"Required Kotlin field ") - .append(struct.qualifiedTypeName) - .append('.') - .append(field.name) - .append(" is missing in compatible xlang payload\")\n") - builder.append(" }\n") - } - builder.append(" if (value == null) {\n") - builder.append(" checkNoUnresolvedReadRef(readContext)\n") - builder.append(" val constructed = newConstructorObject(fieldValues)\n") - builder.append(" value = constructed\n") - builder.append(" referenceConstructorRef(readContext, constructed)\n") - builder.append(" setBufferedFields(constructed, fieldValues, bufferedFields)\n") - builder.append(" }\n") - builder.append(" return value!!\n") - builder.append(" } finally {\n") - builder.append(" endConstructorRef(readContext)\n") - builder.append(" }\n") - builder.append(" }\n\n") - - builder.append( - " private fun readCompatibleConstructorField(readContext: ReadContext, remoteField: StaticGeneratedStructSerializer.RemoteFieldInfo, localField: SerializationFieldInfo, fieldId: Int): Any? {\n" - ) - builder.append(" val buffer = readContext.buffer\n") - builder.append(" return when (fieldId) {\n") - for (field in struct.fields) { - val readExpression = - castReadExpression( - field, - "readCompatibleFieldValue(readContext, remoteField, localField)", - compatible = true, - ) - val expression = constructorReadExpression(field, readExpression) - if (field.trackingRef) { - builder.append(" ").append(field.id).append(" -> {\n") - builder.append(" trackConstructorRefRead(readContext, buffer)\n") - builder.append(" ").append(expression).append("\n") - builder.append(" }\n") - } else { - builder.append(" ").append(field.id).append(" -> ").append(expression).append("\n") - } - } - builder.append( - " else -> throw IllegalStateException(\"Unknown generated field id \${fieldId}\")\n" - ) - builder.append(" }\n") - builder.append(" }\n\n") - builder .append(" private fun newConstructorObject(fieldValues: Array): ") .append(struct.typeName) @@ -795,18 +689,10 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append(" if (sameSchemaCompatible) {\n") builder.append(" return readSchemaConsistent(readContext)\n") builder.append(" }\n") - if ( - struct.construction == KotlinStructConstruction.CONSTRUCTOR && - struct.fields.any { it.hasDefault } - ) { - builder.append(" return readCompatibleDefaultConstructor(readContext)\n") - builder.append(" }\n\n") - writeCompatibleDefaultConstructorRead() - return - } if (struct.construction == KotlinStructConstruction.CONSTRUCTOR) { builder.append(" return readCompatibleConstructor(readContext)\n") builder.append(" }\n\n") + writeCompatibleConstructorRead() return } if (struct.construction == KotlinStructConstruction.MUTABLE) { @@ -818,9 +704,9 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append(" }\n\n") } - private fun writeCompatibleDefaultConstructorRead() { + private fun writeCompatibleConstructorRead() { builder - .append(" private fun readCompatibleDefaultConstructor(readContext: ReadContext): ") + .append(" private fun readCompatibleConstructor(readContext: ReadContext): ") .append(struct.typeName) .append(" {\n") builder.append(" beginConstructorRef(readContext)\n") @@ -839,14 +725,49 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append(indent).append(" val remoteField = remoteFields[i]\n") builder.append(indent).append(" when (remoteField.matchedId) {\n") for (field in struct.fields) { - builder.append(indent).append(" ").append(field.id).append(" -> {\n") + builder.append(indent).append(" ").append(field.id * 2).append(" -> {\n") + builder.append(indent).append(" val buffer = readContext.buffer\n") + val directRead = + directReadExpression(field) + ?: castReadExpression( + field, + "readFieldValue(readContext, fieldsById[${field.id}]!!)", + compatible = false, + ) + val constructorDirectRead = + if (constructorRefs && field.trackingRef) { + "run { trackConstructorRefRead(readContext, buffer); ctorFieldValue(readContext, $directRead, type) }" + } else if (constructorRefs) { + "ctorFieldValue(readContext, $directRead, type)" + } else { + directRead + } + val directAssignment = + if (constructorRefs) { + localValueExpression(field, constructorDirectRead) + } else { + constructorDirectRead + } + builder + .append(indent) + .append(" ") + .append(field.localName) + .append(" = ") + .append(directAssignment) + .append("\n") + builder.append(indent).append(" ") + appendPresenceSet(field) + builder.append("\n") + builder.append(indent).append(" }\n") + builder.append(indent).append(" ").append(field.id * 2 + 1).append(" -> {\n") builder .append(indent) .append(" val localField = fieldsById[") .append(field.id) .append("]!!\n") - builder.append(indent).append(" if (canReadGeneratedField(remoteField, localField)) {\n") - val readExpression = "readCompatibleFieldValue(readContext, remoteField, localField)" + val readExpression = + compatibleScalarReadExpression(field) + ?: "readCompatibleFieldValue(readContext, remoteField, localField)" val constructorReadExpression = if (constructorRefs && field.trackingRef) { "run { trackConstructorRefRead(readContext, readContext.buffer); ctorFieldValue(readContext, $readExpression, type) }" @@ -864,22 +785,25 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru castReadExpression( field, constructorReadExpression, - compatible = true, + compatible = readExpression.startsWith("readCompatibleFieldValue"), ) ) .append("\n") - builder.append(indent).append(" ") + builder.append(indent).append(" ") appendPresenceSet(field) builder.append("\n") - builder.append(indent).append(" } else {\n") - builder.append(indent).append(" skipField(readContext, remoteField)\n") - builder.append(indent).append(" }\n") builder.append(indent).append(" }\n") } - builder.append(indent).append(" else -> skipField(readContext, remoteField)\n") + builder.append(indent).append(" -1 -> skipField(readContext, remoteField)\n") + builder + .append(indent) + .append( + " else -> throw IllegalStateException(\"Invalid compatible matched id \${remoteField.matchedId}\")\n" + ) builder.append(indent).append(" }\n") builder.append(indent).append("}\n") - builder.append(indent).append("var missingDefaultMask = 0L\n") + writeMissingRequiredChecks(indent) + builder.append(indent).append("var missingDefaultMask = 0\n") val defaultFields = struct.fields.filter { it.hasDefault } for (field in struct.fields) { if (!field.hasDefault && field.nullable) { @@ -888,21 +812,12 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append(indent).append("if (") appendPresenceMissing(field) builder.append(") {\n") - when { - field.hasDefault -> - builder - .append(indent) - .append(" missingDefaultMask = missingDefaultMask or ") - .append(1L shl defaultFields.indexOf(field)) - .append("L\n") - else -> - builder - .append(indent) - .append(" throw DeserializationException(\"Required Kotlin field ") - .append(struct.qualifiedTypeName) - .append('.') - .append(field.name) - .append(" is missing in compatible xlang payload\")\n") + if (field.hasDefault) { + builder + .append(indent) + .append(" missingDefaultMask = missingDefaultMask or ") + .append(1 shl defaultFields.indexOf(field)) + .append("\n") } builder.append(indent).append("}\n") } @@ -922,48 +837,72 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append(" val remoteField = remoteFields[i]\n") builder.append(" when (remoteField.matchedId) {\n") for (field in struct.fields) { - builder.append(" ").append(field.id).append(" -> {\n") - builder.append(" val localField = fieldsById[").append(field.id).append("]!!\n") - builder.append(" if (canReadGeneratedField(remoteField, localField)) {\n") + builder.append(" ").append(field.id * 2).append(" -> {\n") + builder.append(" val buffer = readContext.buffer\n") + val directRead = + directReadExpression(field) + ?: castReadExpression( + field, + "readFieldValue(readContext, fieldsById[${field.id}]!!)", + compatible = false, + ) builder .append(" value.") .append(field.name) .append(" = ") + .append(directRead) + .append("\n") + builder.append(" ") + appendPresenceSet(field) + builder.append("\n") + builder.append(" }\n") + builder.append(" ").append(field.id * 2 + 1).append(" -> {\n") + builder.append(" val localField = fieldsById[").append(field.id).append("]!!\n") + builder + .append(" value.") + .append(field.name) + .append(" = ") .append( castReadExpression( field, - "readCompatibleFieldValue(readContext, remoteField, localField)", - compatible = true, + compatibleScalarReadExpression(field) + ?: "readCompatibleFieldValue(readContext, remoteField, localField)", + compatible = compatibleScalarReadExpression(field) == null, ) ) .append("\n") - builder.append(" ") + builder.append(" ") appendPresenceSet(field) builder.append("\n") - builder.append(" } else {\n") - builder.append(" skipField(readContext, remoteField)\n") - builder.append(" }\n") builder.append(" }\n") } - builder.append(" else -> skipField(readContext, remoteField)\n") + builder.append(" -1 -> skipField(readContext, remoteField)\n") + builder.append( + " else -> throw IllegalStateException(\"Invalid compatible matched id \${remoteField.matchedId}\")\n" + ) builder.append(" }\n") builder.append(" }\n") + writeMissingRequiredChecks(" ") + builder.append(" return value\n") + } + + private fun writeMissingRequiredChecks(indent: String) { for (field in struct.fields) { - if (field.nullable) { + if (field.hasDefault || field.nullable) { continue } - builder.append(" if (") + builder.append(indent).append("if (") appendPresenceMissing(field) builder.append(") {\n") builder - .append(" throw DeserializationException(\"Required Kotlin field ") + .append(indent) + .append(" throw DeserializationException(\"Required Kotlin field ") .append(struct.qualifiedTypeName) - .append('.') + .append(".") .append(field.name) - .append(" is missing in compatible xlang payload\")\n") - builder.append(" }\n") + .append(" is missing\")\n") + builder.append(indent).append("}\n") } - builder.append(" return value\n") } private fun writePresenceVars(indent: String = " ") { @@ -1093,13 +1032,13 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru if (defaultFields.isEmpty()) { if (referenceConstructor) { builder.append(indent).append("val constructed = ") - appendConstructorCall(defaultMask = 0L) + appendConstructorCall(defaultMask = 0) builder.append("\n") builder.append(indent).append("referenceConstructorRef(readContext, constructed)\n") builder.append(indent).append("return constructed\n") } else { builder.append(indent).append("return ") - appendConstructorCall(defaultMask = 0L) + appendConstructorCall(defaultMask = 0) builder.append("\n") } return @@ -1109,15 +1048,15 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru } else { builder.append(indent).append("return when (missingDefaultMask) {\n") } - val combinations = 1L shl defaultFields.size + val combinations = 1 shl defaultFields.size for (combination in 0 until combinations) { - var mask = 0L + var mask = 0 for (i in defaultFields.indices) { - if ((combination and (1L shl i)) != 0L) { - mask = mask or (1L shl i) + if ((combination and (1 shl i)) != 0) { + mask = mask or (1 shl i) } } - builder.append(indent).append(" ").append(mask).append("L -> ") + builder.append(indent).append(" ").append(mask).append(" -> ") appendConstructorCall(defaultMask = mask) builder.append("\n") } @@ -1133,14 +1072,14 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru } } - private fun appendConstructorCall(defaultMask: Long) { + private fun appendConstructorCall(defaultMask: Int) { val defaultFields = struct.fields.filter { it.hasDefault } builder.append(struct.typeName).append("(") var first = true for (field in struct.fields) { if (field.hasDefault) { val defaultIndex = defaultFields.indexOf(field) - if (defaultIndex >= 0 && (defaultMask and (1L shl defaultIndex)) != 0L) { + if (defaultIndex >= 0 && (defaultMask and (1 shl defaultIndex)) != 0) { continue } } @@ -1180,6 +1119,13 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru return "($source as ${field.propertyTypeName})" } + private fun localValueExpression(field: KotlinSourceField, expression: String): String { + if (field.type.valueTypeName == "Any?") { + return expression + } + return "($expression as ${field.type.valueTypeName})" + } + private fun constructorValueExpression(field: KotlinSourceField): String { val localValue = if (field.nullable || field.type.primitive || isScalarUnsigned(field)) { @@ -1402,6 +1348,88 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru return "($expression as ${field.type.valueTypeName})" } + private fun compatibleScalarReadExpression(field: KotlinSourceField): String? { + if (field.type.componentType != null || field.type.typeArguments.isNotEmpty()) { + return null + } + val nullable = field.type.nullable + val helperCall = + when (field.type.valueTypeName.removeSuffix("?")) { + "Boolean" -> + if (nullable) { + "FieldConverters.readBoxedBooleanTarget(readContext, remoteField.serializationFieldInfo, localField)" + } else { + "FieldConverters.readBooleanTarget(readContext, remoteField.serializationFieldInfo, localField)" + } + "Byte" -> + if (nullable) { + "FieldConverters.readBoxedByteTarget(readContext, remoteField.serializationFieldInfo, localField)" + } else { + "FieldConverters.readByteTarget(readContext, remoteField.serializationFieldInfo, localField)" + } + "Short" -> + if (nullable) { + "FieldConverters.readBoxedShortTarget(readContext, remoteField.serializationFieldInfo, localField)" + } else { + "FieldConverters.readShortTarget(readContext, remoteField.serializationFieldInfo, localField)" + } + "Int" -> + if (nullable) { + "FieldConverters.readBoxedIntTarget(readContext, remoteField.serializationFieldInfo, localField)" + } else { + "FieldConverters.readIntTarget(readContext, remoteField.serializationFieldInfo, localField)" + } + "Long" -> + if (nullable) { + "FieldConverters.readBoxedLongTarget(readContext, remoteField.serializationFieldInfo, localField)" + } else { + "FieldConverters.readLongTarget(readContext, remoteField.serializationFieldInfo, localField)" + } + "Float" -> + if (nullable) { + "FieldConverters.readBoxedFloatTarget(readContext, remoteField.serializationFieldInfo, localField)" + } else { + "FieldConverters.readFloatTarget(readContext, remoteField.serializationFieldInfo, localField)" + } + "Double" -> + if (nullable) { + "FieldConverters.readBoxedDoubleTarget(readContext, remoteField.serializationFieldInfo, localField)" + } else { + "FieldConverters.readDoubleTarget(readContext, remoteField.serializationFieldInfo, localField)" + } + "String" -> + "FieldConverters.readStringTarget(readContext, remoteField.serializationFieldInfo, localField)" + "java.math.BigDecimal" -> + "FieldConverters.readDecimalTarget(readContext, remoteField.serializationFieldInfo, localField)" + "UByte" -> + if (nullable) { + "FieldConverters.readBoxedIntTarget(readContext, remoteField.serializationFieldInfo, localField)?.toUByte()" + } else { + "FieldConverters.readIntTarget(readContext, remoteField.serializationFieldInfo, localField).toUByte()" + } + "UShort" -> + if (nullable) { + "FieldConverters.readBoxedIntTarget(readContext, remoteField.serializationFieldInfo, localField)?.toUShort()" + } else { + "FieldConverters.readIntTarget(readContext, remoteField.serializationFieldInfo, localField).toUShort()" + } + "UInt" -> + if (nullable) { + "FieldConverters.readBoxedLongTarget(readContext, remoteField.serializationFieldInfo, localField)?.toUInt()" + } else { + "FieldConverters.readLongTarget(readContext, remoteField.serializationFieldInfo, localField).toUInt()" + } + "ULong" -> + if (nullable) { + "FieldConverters.readBoxedLongTarget(readContext, remoteField.serializationFieldInfo, localField)?.toULong()" + } else { + "FieldConverters.readLongTarget(readContext, remoteField.serializationFieldInfo, localField).toULong()" + } + else -> return null + } + return helperCall + } + private fun hasKotlinScalar(type: KotlinSourceTypeNode): Boolean = (type.unsigned && type.componentType == null) || type.typeId == "Types.DURATION" || diff --git a/kotlin/fory-kotlin-ksp/src/test/kotlin/org/apache/fory/kotlin/ksp/ProcessorValidationTest.kt b/kotlin/fory-kotlin-ksp/src/test/kotlin/org/apache/fory/kotlin/ksp/ProcessorValidationTest.kt index c279e8667a..2ce36fa3dc 100644 --- a/kotlin/fory-kotlin-ksp/src/test/kotlin/org/apache/fory/kotlin/ksp/ProcessorValidationTest.kt +++ b/kotlin/fory-kotlin-ksp/src/test/kotlin/org/apache/fory/kotlin/ksp/ProcessorValidationTest.kt @@ -175,7 +175,9 @@ class ProcessorValidationTest { .write() assertTrue(source.contains("private fun readSchemaConstructorField")) - assertTrue(source.contains("private fun readCompatibleConstructorField")) + assertTrue( + source.contains("private fun readCompatibleConstructor(readContext: ReadContext): Node") + ) assertTrue(source.contains("trackConstructorRefRead(readContext, buffer)")) } @@ -524,19 +526,19 @@ class ProcessorValidationTest { ) ) assertTrue( - source.contains( - "KotlinCollectionAdapters.toTreeMap((readCompatibleFieldValue(readContext, remoteField, localField) as kotlin.collections.Map))" - ) + source.contains("KotlinCollectionAdapters.toTreeMap") && + source.contains("readCompatibleFieldValue(readContext, remoteField, localField)") && + source.contains("ctorFieldValue(readContext") ) + assertTrue(source.contains("3 -> {") && source.contains("fieldsById[1]!!")) assertTrue( - source.contains( - "1 -> run { val readSource0 = ((readCompatibleFieldValue(readContext, remoteField, localField) as kotlin.collections.List>) as Collection<*>);" - ) + source.contains("readCompatibleFieldValue(readContext, remoteField, localField)") && + source.contains("as Collection<*>") ) + assertTrue(source.contains("7 -> {") && source.contains("fieldsById[3]!!")) assertTrue( - source.contains( - "3 -> run { val readSource0 = ((readCompatibleFieldValue(readContext, remoteField, localField) as java.util.TreeMap>) as Map<*, *>); val readTarget0 = java.util.TreeMap();" - ) + source.contains("readCompatibleFieldValue(readContext, remoteField, localField)") && + source.contains("as Map<*, *>") ) assertTrue( source.contains("KotlinCollectionAdapters.toTreeSet((readEntry0.value as Collection<*>))") @@ -617,8 +619,8 @@ class ProcessorValidationTest { val copyStart = source.indexOf("override fun copy") val compatibleSource = source.substring(compatibleStart, copyStart) - assertFalse(compatibleSource.contains("return readCompatibleConstructor(readContext)")) - assertTrue(compatibleSource.contains("return readCompatibleDefaultConstructor(readContext)")) + assertTrue(compatibleSource.contains("return readCompatibleConstructor(readContext)")) + assertFalse(compatibleSource.contains("readCompatibleDefaultConstructor")) assertTrue(compatibleSource.contains("beginConstructorRef(readContext)")) assertTrue(compatibleSource.contains("checkNoUnresolvedReadRef(readContext)")) assertTrue( @@ -628,6 +630,7 @@ class ProcessorValidationTest { assertTrue(compatibleSource.contains("endConstructorRef(readContext)")) assertTrue(compatibleSource.contains("missingDefaultMask")) assertTrue(compatibleSource.contains("val constructed = when (missingDefaultMask)")) + assertFalse(compatibleSource.contains("fieldValues = arrayOfNulls")) } @Test @@ -682,6 +685,37 @@ class ProcessorValidationTest { unsigned = true, ), ), + field( + 7, + "nullableFlag", + scalar( + "Boolean::class.javaObjectType", + "Boolean?", + "java.lang.Boolean", + "Types.BOOL", + false + ) + .copy(nullable = true), + ), + field( + 8, + "nullableNumber", + scalar("Int::class.javaObjectType", "Int?", "java.lang.Integer", "Types.INT32", false) + .copy(nullable = true), + ), + field( + 9, + "nullableUnsigned", + scalar( + "Int::class.javaObjectType", + "UInt?", + "java.lang.Integer", + "Types.UINT32", + false, + unsigned = true + ) + .copy(nullable = true), + ), ) val source = KotlinSerializerSourceWriter( @@ -701,10 +735,24 @@ class ProcessorValidationTest { val copyStart = source.indexOf("override fun copy") val compatibleSource = source.substring(compatibleStart, copyStart) - assertTrue(compatibleSource.contains("if (canReadGeneratedField(remoteField, localField))")) - assertTrue( + assertFalse(compatibleSource.contains("canReadGeneratedField")) + assertFalse(compatibleSource.contains("arrayOfNulls")) + assertFalse(compatibleSource.contains("Array")) + assertTrue(compatibleSource.contains("when (remoteField.matchedId)")) + assertTrue(compatibleSource.contains("0 -> {")) + assertTrue(compatibleSource.contains("1 -> {")) + assertTrue(compatibleSource.contains("-1 -> skipField(readContext, remoteField)")) + assertFalse( compatibleSource.contains("readCompatibleFieldValue(readContext, remoteField, localField)") ) + assertTrue(compatibleSource.contains("FieldConverters.readBooleanTarget")) + assertTrue(compatibleSource.contains("FieldConverters.readIntTarget")) + assertTrue(compatibleSource.contains("FieldConverters.readStringTarget")) + assertTrue(compatibleSource.contains("FieldConverters.readDecimalTarget")) + assertTrue(compatibleSource.contains("FieldConverters.readLongTarget")) + assertTrue(compatibleSource.contains("FieldConverters.readBoxedBooleanTarget")) + assertTrue(compatibleSource.contains("FieldConverters.readBoxedIntTarget")) + assertTrue(compatibleSource.contains("FieldConverters.readBoxedLongTarget")) assertTrue( compatibleSource.contains("value.flag =") && compatibleSource.contains(" as Boolean") ) @@ -720,11 +768,10 @@ class ProcessorValidationTest { compatibleSource.contains("value.decimalValue =") && compatibleSource.contains(" as String") ) assertTrue(compatibleSource.contains("value.narrow =") && compatibleSource.contains(" as Int")) - assertTrue( - compatibleSource.contains("value.unsigned = run") && - compatibleSource.contains("is org.apache.fory.type.unsigned.UInt32") && - compatibleSource.contains("compatibleValue0.toLong().toUInt()") - ) + assertTrue(compatibleSource.contains("value.unsigned =")) + assertTrue(compatibleSource.contains("FieldConverters.readLongTarget")) + assertTrue(compatibleSource.contains(".toUInt()")) + assertTrue(compatibleSource.contains("?.toUInt()")) } @Test @@ -775,7 +822,7 @@ class ProcessorValidationTest { assertTrue(source.contains("var presentMask0 = 0L")) assertTrue(source.contains("var presentMask1 = 0L")) assertTrue(source.contains("presentMask1 = presentMask1 or (1L shl 5)")) - assertTrue(source.contains("if ((presentMask1 and (1L shl 5)) == 0L)")) + assertTrue(source.contains("Required Kotlin field example.WideStruct.field69 is missing")) } @Test @@ -1395,14 +1442,12 @@ class ProcessorValidationTest { ) .write() - assertTrue(source.contains("is org.apache.fory.type.unsigned.UInt8")) - assertTrue(source.contains("compatibleValue0.toInt().toUByte()")) - assertTrue(source.contains("is org.apache.fory.type.unsigned.UInt16")) - assertTrue(source.contains("compatibleValue0.toInt().toUShort()")) - assertTrue(source.contains("is org.apache.fory.type.unsigned.UInt32")) - assertTrue(source.contains("compatibleValue0.toLong().toUInt()")) - assertTrue(source.contains("is org.apache.fory.type.unsigned.UInt64")) - assertTrue(source.contains("compatibleValue0.toLong().toULong()")) + assertTrue(source.contains("FieldConverters.readIntTarget")) + assertTrue(source.contains(".toUByte()")) + assertTrue(source.contains(".toUShort()")) + assertTrue(source.contains("FieldConverters.readLongTarget")) + assertTrue(source.contains(".toUInt()")) + assertTrue(source.contains(".toULong()")) assertFalse(source.contains("as Number).to")) } diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 6a663cac8f..94c9d3a50c 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -46,6 +46,17 @@ ctypedef PyObject *PyObjectPtr cdef class ListSerializer +cdef inline bint needs_element_type_info(uint8_t type_id): + return ( + type_id == TypeId.STRUCT + or type_id == TypeId.COMPATIBLE_STRUCT + or type_id == TypeId.NAMED_STRUCT + or type_id == TypeId.NAMED_COMPATIBLE_STRUCT + or type_id == TypeId.EXT + or type_id == TypeId.NAMED_EXT + ) + + cdef class CollectionSerializer(Serializer): # Element serializers may be Cython or pure-Python implementations. cdef Serializer elem_serializer @@ -108,7 +119,9 @@ cdef class CollectionSerializer(Serializer): if elem_type is not None: elem_type_info = self.type_resolver.get_type_info(elem_type) else: - collect_flag |= COLL_IS_DECL_ELEMENT_TYPE | COLL_IS_SAME_TYPE + collect_flag |= COLL_IS_SAME_TYPE + if not needs_element_type_info(elem_type_info.type_id): + collect_flag |= COLL_IS_DECL_ELEMENT_TYPE if items != NULL: for i in range(size): if items[i] == None: diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index ea17bcd9a3..f9f0610c0d 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -26,6 +26,7 @@ from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION from pyfory._serializer import Serializer, StringSerializer from pyfory.resolver import NOT_NULL_VALUE_FLAG, NULL_FLAG +from pyfory.types import TypeId COLL_DEFAULT_FLAG = 0b0 COLL_TRACKING_REF = 0b1 @@ -34,6 +35,17 @@ COLL_IS_SAME_TYPE = 0b1000 +def _needs_element_type_info(type_id): + return type_id in { + TypeId.STRUCT, + TypeId.COMPATIBLE_STRUCT, + TypeId.NAMED_STRUCT, + TypeId.NAMED_COMPATIBLE_STRUCT, + TypeId.EXT, + TypeId.NAMED_EXT, + } + + class CollectionSerializer(Serializer): __slots__ = ( "elem_serializer", @@ -78,7 +90,9 @@ def write_header(self, write_context, value): if elem_type is not None: elem_type_info = self.type_resolver.get_type_info(elem_type) else: - collect_flag |= COLL_IS_DECL_ELEMENT_TYPE | COLL_IS_SAME_TYPE + collect_flag |= COLL_IS_SAME_TYPE + if not _needs_element_type_info(elem_type_info.type_id): + collect_flag |= COLL_IS_DECL_ELEMENT_TYPE for item in value: if item is None: has_null = True diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py index 0ac31eaced..32a70773bd 100644 --- a/python/pyfory/meta/typedef.py +++ b/python/pyfory/meta/typedef.py @@ -121,6 +121,24 @@ def xlang_non_struct_type_id(kind_code: int) -> int: raise ValueError(f"Unsupported TypeDef kind code {kind_code}") from exc +def _normalize_user_type_id(type_id: int) -> int: + if type_id in { + TypeId.STRUCT, + TypeId.COMPATIBLE_STRUCT, + TypeId.NAMED_STRUCT, + TypeId.NAMED_COMPATIBLE_STRUCT, + TypeId.UNKNOWN, + }: + return TypeId.STRUCT + if type_id in {TypeId.ENUM, TypeId.NAMED_ENUM}: + return TypeId.ENUM + if type_id in {TypeId.EXT, TypeId.NAMED_EXT}: + return TypeId.EXT + if type_id in {TypeId.UNION, TypeId.TYPED_UNION, TypeId.NAMED_UNION}: + return TypeId.UNION + return type_id + + def _typedef_header_hash(encoded: bytes, header_low_bits: int) -> int: hash_input = encoded + bytes((header_low_bits & 0xFF, (header_low_bits >> 8) & 0xFF)) hash_value = hash_buffer(hash_input, 47)[0] @@ -261,6 +279,10 @@ def create_serializer(self, resolver): field_info.field_type, local_info.field_type if local_info is not None else None, ) + if local_info is not None and not can_assign: + from pyfory.error import TypeNotCompatibleError + + raise TypeNotCompatibleError(f"Compatible field {resolved_name!r} cannot be read as local field {local_info.name!r}") type_hint = type_hints.get(resolved_name, typing.Any) unwrapped_type, _ = unwrap_optional(type_hint, field_nullable=resolver.field_nullable) serializer = _create_compatible_field_serializer( @@ -631,13 +653,21 @@ def __repr__(self): } -def _list_array_element_type_matches(list_field_type: FieldType, array_field_type: FieldType) -> bool: +def _list_array_element_type_matches(list_field_type: FieldType, array_field_type: FieldType, *, require_unframed_element: bool) -> bool: array_element_type_id = _ARRAY_ELEMENT_TYPE_IDS.get(array_field_type.type_id) if array_element_type_id is None: return False - return list_field_type.type_id == TypeId.LIST and _list_element_type_matches_array_element( - list_field_type.element_type.type_id, array_element_type_id - ) + if list_field_type.is_nullable or list_field_type.is_tracking_ref or array_field_type.is_nullable or array_field_type.is_tracking_ref: + return False + element_type = list_field_type.element_type + if element_type is None: + return False + # Nullable element schema is allowed for list -> array; actual null + # payload elements fail in the dense-array reader. Ref-tracked element + # framing is rejected here because this path stays primitive-only. + if require_unframed_element and element_type.is_tracking_ref: + return False + return list_field_type.type_id == TypeId.LIST and _list_element_type_matches_array_element(element_type.type_id, array_element_type_id) def _list_element_type_matches_array_element(list_element_type_id: TypeId, array_element_type_id: TypeId) -> bool: @@ -651,9 +681,9 @@ def _is_root_list_array_pair(remote_field_type: FieldType, local_field_type: Fie if local_field_type is None: return False if remote_field_type.type_id == TypeId.LIST and local_field_type.type_id in _ARRAY_TYPE_IDS: - return _list_array_element_type_matches(remote_field_type, local_field_type) + return _list_array_element_type_matches(remote_field_type, local_field_type, require_unframed_element=True) if local_field_type.type_id == TypeId.LIST and remote_field_type.type_id in _ARRAY_TYPE_IDS: - return _list_array_element_type_matches(local_field_type, remote_field_type) + return _list_array_element_type_matches(local_field_type, remote_field_type, require_unframed_element=False) return False @@ -672,25 +702,58 @@ def _remote_list_to_local_array_allowed(remote_field_type: FieldType, local_fiel return ( remote_field_type.type_id == TypeId.LIST and local_field_type.type_id in _ARRAY_TYPE_IDS - and _list_array_element_type_matches(remote_field_type, local_field_type) + and _list_array_element_type_matches(remote_field_type, local_field_type, require_unframed_element=True) ) -def _payload_shape_matches(remote_field_type: FieldType, local_field_type: FieldType, top_level: bool = True) -> bool: +def _exact_field_type_match(remote_field_type: FieldType, local_field_type: FieldType) -> bool: + if ( + remote_field_type.is_nullable != local_field_type.is_nullable + or remote_field_type.is_tracking_ref != local_field_type.is_tracking_ref + or _normalize_user_type_id(remote_field_type.type_id) != _normalize_user_type_id(local_field_type.type_id) + ): + return False + remote_type_id = remote_field_type.type_id + local_type_id = local_field_type.type_id + if remote_type_id in (TypeId.LIST, TypeId.SET): + return local_type_id == remote_type_id and _exact_field_type_match( + remote_field_type.element_type, + local_field_type.element_type, + ) + if remote_type_id == TypeId.MAP: + return ( + local_type_id == TypeId.MAP + and _exact_field_type_match(remote_field_type.key_type, local_field_type.key_type) + and _exact_field_type_match(remote_field_type.value_type, local_field_type.value_type) + ) + return True + + +def _compatible_field_type_match(remote_field_type: FieldType, local_field_type: FieldType, top_level: bool = True) -> bool: if local_field_type is None: return False remote_type_id = remote_field_type.type_id local_type_id = local_field_type.type_id - if _is_bytes_uint8_array_pair(remote_type_id, local_type_id): + if top_level and _is_bytes_uint8_array_pair(remote_type_id, local_type_id): return True if top_level and _is_root_list_array_pair(remote_field_type, local_field_type): return True - if remote_type_id != local_type_id: + if remote_type_id in _COMPATIBLE_SCALAR_TYPE_IDS or local_type_id in _COMPATIBLE_SCALAR_TYPE_IDS: + return remote_type_id == local_type_id + if _normalize_user_type_id(remote_type_id) != _normalize_user_type_id(local_type_id): return False if remote_type_id in (TypeId.LIST, TypeId.SET): - return _payload_shape_matches(remote_field_type.element_type, local_field_type.element_type, False) + return _compatible_field_type_match( + remote_field_type.element_type, + local_field_type.element_type, + False, + ) if remote_type_id == TypeId.MAP: - return _payload_shape_matches(remote_field_type.key_type, local_field_type.key_type, False) and _payload_shape_matches( + return _compatible_field_type_match( + remote_field_type.key_type, + local_field_type.key_type, + False, + ) and _compatible_field_type_match( remote_field_type.value_type, local_field_type.value_type, False, @@ -698,10 +761,14 @@ def _payload_shape_matches(remote_field_type: FieldType, local_field_type: Field return True +def _payload_shape_matches(remote_field_type: FieldType, local_field_type: FieldType, top_level: bool = True) -> bool: + return _compatible_field_type_match(remote_field_type, local_field_type, top_level) + + def _payload_shape_needs_local_carrier(remote_field_type: FieldType, local_field_type: FieldType, top_level: bool = True) -> bool: remote_type_id = remote_field_type.type_id local_type_id = local_field_type.type_id - if _is_bytes_uint8_array_pair(remote_type_id, local_type_id): + if top_level and _is_bytes_uint8_array_pair(remote_type_id, local_type_id): return True if top_level and _is_root_list_array_pair(remote_field_type, local_field_type): return True @@ -728,6 +795,20 @@ def _create_local_typehint_serializer(resolver, field_name, type_hint): return infer_field(field_name, unwrapped_type, StructFieldSerializerVisitor(resolver)) +def _can_direct_read_integer_scalar(remote_field_type: FieldType, local_field_type: FieldType) -> bool: + if remote_field_type.is_nullable or local_field_type.is_nullable: + return False + if remote_field_type.is_tracking_ref or local_field_type.is_tracking_ref: + return False + remote_domain = _INT_TYPE_DOMAINS.get(remote_field_type.type_id) + local_domain = _INT_TYPE_DOMAINS.get(local_field_type.type_id) + if remote_domain is None or local_domain is None: + return False + remote_signed, remote_bits = remote_domain + local_signed, local_bits = local_domain + return remote_signed == local_signed and remote_bits <= local_bits + + def _create_compatible_field_serializer( resolver, field_name, @@ -802,6 +883,8 @@ def _create_compatible_field_serializer( and not exact_scalar_field_type and supports_compatible_scalar_conversion(remote_field_type.type_id, local_field_type.type_id) ): + if _can_direct_read_integer_scalar(remote_field_type, local_field_type): + return remote_field_type.create_serializer(resolver, local_declared_type) remote_serializer = remote_field_type.create_serializer(resolver, local_declared_type) return CompatibleScalarFieldSerializer( resolver, @@ -860,12 +943,14 @@ def _field_type_assignment(remote_field_type: FieldType, local_field_type: Field needs_validation = _requires_nullable_validation(remote_field_type, local_field_type) remote_type_id = remote_field_type.type_id local_type_id = local_field_type.type_id - if local_type_id == TypeId.UNKNOWN: + if top_level and local_type_id == TypeId.UNKNOWN: return True, needs_validation - if remote_type_id == TypeId.UNKNOWN: + if top_level and remote_type_id == TypeId.UNKNOWN: return True, True if top_level and _is_root_list_array_pair(remote_field_type, local_field_type): return True, True + if top_level and _is_bytes_uint8_array_pair(remote_type_id, local_type_id): + return True, True if top_level: from pyfory.converter import supports_compatible_scalar_conversion @@ -889,47 +974,7 @@ def _field_type_assignment(remote_field_type: FieldType, local_field_type: Field and supports_compatible_scalar_conversion(remote_type_id, local_type_id) ): return True, needs_validation - if remote_type_id in (TypeId.LIST, TypeId.SET): - if local_type_id != remote_type_id: - return False, False - child_assignable, child_needs_validation = _field_type_assignment( - remote_field_type.element_type, - local_field_type.element_type, - False, - ) - return child_assignable, needs_validation or child_needs_validation - if remote_type_id == TypeId.MAP: - if local_type_id != TypeId.MAP: - return False, False - key_assignable, key_needs_validation = _field_type_assignment( - remote_field_type.key_type, - local_field_type.key_type, - False, - ) - value_assignable, value_needs_validation = _field_type_assignment( - remote_field_type.value_type, - local_field_type.value_type, - False, - ) - return ( - key_assignable and value_assignable, - needs_validation or key_needs_validation or value_needs_validation, - ) - if _is_bytes_uint8_array_pair(remote_type_id, local_type_id): - return True, True - remote_int_domain = _INT_TYPE_DOMAINS.get(remote_type_id) - local_int_domain = _INT_TYPE_DOMAINS.get(local_type_id) - if remote_int_domain is not None or local_int_domain is not None: - if remote_int_domain is None or local_int_domain is None: - return False, False - remote_signed, remote_width = remote_int_domain - local_signed, local_width = local_int_domain - if remote_signed != local_signed: - return False, False - return True, needs_validation or remote_width > local_width - if remote_type_id == local_type_id: - return True, needs_validation - return False, False + return _compatible_field_type_match(remote_field_type, local_field_type), needs_validation def _is_compatible_scalar_type_id(type_id: TypeId) -> bool: diff --git a/python/pyfory/tests/test_meta_share.py b/python/pyfory/tests/test_meta_share.py index 0853d2718b..35bb4f6e9c 100644 --- a/python/pyfory/tests/test_meta_share.py +++ b/python/pyfory/tests/test_meta_share.py @@ -22,6 +22,7 @@ import pyfory from pyfory import Fory +from pyfory.error import TypeNotCompatibleError @dataclasses.dataclass @@ -192,10 +193,9 @@ def test_schema_inconsistent_list_fields(self): fory2 = Fory(xlang=True, compatible=True) fory2.register_type(ListFieldsClassInconsistent) - deserialized = fory2.deserialize(buffer) - assert isinstance(deserialized, ListFieldsClassInconsistent) - assert deserialized.name == "test" + with pytest.raises(TypeNotCompatibleError): + fory2.deserialize(buffer) def test_schema_inconsistent_dict_fields(self): fory1 = Fory(xlang=True, compatible=True) @@ -204,7 +204,6 @@ def test_schema_inconsistent_dict_fields(self): fory2 = Fory(xlang=True, compatible=True) fory2.register_type(DictFieldsClassInconsistent) - deserialized = fory2.deserialize(buffer) - assert isinstance(deserialized, DictFieldsClassInconsistent) - assert deserialized.name == "test" + with pytest.raises(TypeNotCompatibleError): + fory2.deserialize(buffer) diff --git a/python/pyfory/tests/test_struct.py b/python/pyfory/tests/test_struct.py index 4097e87598..55d2b2c805 100644 --- a/python/pyfory/tests/test_struct.py +++ b/python/pyfory/tests/test_struct.py @@ -28,7 +28,7 @@ import pyfory from pyfory import Fory -from pyfory.error import ForyInvalidDataError, TypeUnregisteredError +from pyfory.error import ForyInvalidDataError, TypeNotCompatibleError, TypeUnregisteredError from pyfory.resolver import NOT_NULL_VALUE_FLAG, REF_VALUE_FLAG from pyfory.struct import DataClassSerializer, build_default_values_factory from pyfory.types import TypeId @@ -159,6 +159,26 @@ class LocalNestedSignedDefault: values: Dict[pyfory.FixedInt32, List[pyfory.TaggedInt64]] = dataclasses.field(default_factory=lambda: {-1: [-1]}) +@dataclass +class RemoteNestedNullable: + values: Dict[pyfory.Int32, List[Optional[pyfory.Int32]]] = dataclasses.field(default_factory=dict) + + +@dataclass +class LocalNestedRequired: + values: Dict[pyfory.Int32, List[pyfory.Int32]] = dataclasses.field(default_factory=dict) + + +@dataclass +class RemoteNestedRef: + values: Dict[pyfory.Int32, List[pyfory.Ref[str]]] = dataclasses.field(default_factory=dict) + + +@dataclass +class LocalNestedPlain: + values: Dict[pyfory.Int32, List[str]] = dataclasses.field(default_factory=dict) + + @dataclass class RemoteStringScalar: value: str = "" @@ -194,6 +214,11 @@ class RemoteInt64Scalar: value: pyfory.Int64 = 0 +@dataclass +class RemoteInt32Scalar: + value: pyfory.Int32 = 0 + + @dataclass class LocalInt8Scalar: value: pyfory.Int8 = 0 @@ -367,29 +392,34 @@ def test_scalar_tracking_ref_is_not_converted(): assert [field_info.field_type.is_tracking_ref for field_info in field_infos] == [True, True] shared = "".join(["", "1"]) - result = reader.deserialize(writer.serialize(RemoteTwoStringScalars(shared, shared))) - assert result == LocalBoolIntScalars(False, 0) + with pytest.raises(TypeNotCompatibleError): + reader.deserialize(writer.serialize(RemoteTwoStringScalars(shared, shared))) def test_scalar_tracking_ref_rules(): - assert compat_ser_de(RemoteRefBoolScalar, LocalBoolScalar, RemoteRefBoolScalar(True), 739, ref=True) == LocalBoolScalar(False) - assert compat_ser_de(RemoteBoolScalar, LocalRefBoolScalar, RemoteBoolScalar(True), 740, ref=True) == LocalRefBoolScalar(False) + with pytest.raises(TypeNotCompatibleError): + compat_ser_de(RemoteRefBoolScalar, LocalBoolScalar, RemoteRefBoolScalar(True), 739, ref=True) + with pytest.raises(TypeNotCompatibleError): + compat_ser_de(RemoteBoolScalar, LocalRefBoolScalar, RemoteBoolScalar(True), 740, ref=True) assert compat_ser_de(RemoteRefBoolScalar, LocalRefBoolScalar, RemoteRefBoolScalar(True), 741, ref=True) == LocalRefBoolScalar(True) - assert compat_ser_de(RemoteRefFixedInt32Scalar, LocalRefInt32Scalar, RemoteRefFixedInt32Scalar(7), 742, ref=True) == LocalRefInt32Scalar(0) - assert compat_ser_de( - RemoteOptionalRefBoolScalar, - LocalRefBoolScalar, - RemoteOptionalRefBoolScalar(True), - 748, - ref=True, - ) == LocalRefBoolScalar(False) - assert compat_ser_de( - RemoteRefBoolScalar, - LocalOptionalRefBoolScalar, - RemoteRefBoolScalar(True), - 749, - ref=True, - ) == LocalOptionalRefBoolScalar(None) + with pytest.raises(TypeNotCompatibleError): + compat_ser_de(RemoteRefFixedInt32Scalar, LocalRefInt32Scalar, RemoteRefFixedInt32Scalar(7), 742, ref=True) + with pytest.raises(TypeNotCompatibleError): + compat_ser_de( + RemoteOptionalRefBoolScalar, + LocalRefBoolScalar, + RemoteOptionalRefBoolScalar(True), + 748, + ref=True, + ) + with pytest.raises(TypeNotCompatibleError): + compat_ser_de( + RemoteRefBoolScalar, + LocalOptionalRefBoolScalar, + RemoteRefBoolScalar(True), + 749, + ref=True, + ) assert compat_ser_de( RemoteOptionalRefBoolScalar, LocalOptionalRefBoolScalar, @@ -410,42 +440,69 @@ def test_same_schema_scalar_read_is_direct(): assert fory.deserialize(fory.serialize(value)) == value -def test_compatible_read_accepts_nested_same_domain_integer_encoding(): - result = compat_ser_de( - RemoteNestedFixedTagged, - LocalNestedVarint, - RemoteNestedFixedTagged(values={1: [2, -3], -4: [5]}), - 702, - ) - assert result == LocalNestedVarint(values={1: [2, -3], -4: [5]}) +def test_integer_widening_direct(): + from pyfory.converter import CompatibleScalarFieldSerializer + writer, reader, payload = compat_ser(RemoteInt32Scalar, LocalInt64Scalar, RemoteInt32Scalar(42), 751) + assert reader.deserialize(payload) == LocalInt64Scalar(42) + type_info = next(iter(reader.type_resolver._meta_shared_type_info.values())) + field_serializer = type_info.serializer._serializers[0] + assert type(field_serializer).__name__ == "Int32Serializer" + assert not isinstance(field_serializer, CompatibleScalarFieldSerializer) + + writer, reader, payload = compat_ser(RemoteInt64Scalar, LocalInt8Scalar, RemoteInt64Scalar(42), 752) + assert reader.deserialize(payload) == LocalInt8Scalar(42) + type_info = next(iter(reader.type_resolver._meta_shared_type_info.values())) + assert isinstance(type_info.serializer._serializers[0], CompatibleScalarFieldSerializer) + + +def test_nested_integer_encoding_rejected(): + with pytest.raises(TypeNotCompatibleError): + compat_ser_de( + RemoteNestedFixedTagged, + LocalNestedVarint, + RemoteNestedFixedTagged(values={1: [2, -3], -4: [5]}), + 702, + ) -def test_compatible_read_validates_nested_integer_narrowing(): - result = compat_ser_de( - RemoteNestedWide, - LocalNestedNarrow, - RemoteNestedWide(values={1: [2, -3]}), - 703, - ) - assert result == LocalNestedNarrow(values={1: [2, -3]}) - result = compat_ser_de( - RemoteNestedWide, - LocalNestedNarrow, - RemoteNestedWide(values={1: [1 << 40]}), - 704, - ) - assert result == LocalNestedNarrow() +def test_nested_integer_narrowing_rejected(): + with pytest.raises(TypeNotCompatibleError): + compat_ser_de( + RemoteNestedWide, + LocalNestedNarrow, + RemoteNestedWide(values={1: [2, -3]}), + 703, + ) -def test_compatible_read_skips_nested_signed_unsigned_mismatch(): - result = compat_ser_de( - RemoteNestedUnsigned, - LocalNestedSignedDefault, - RemoteNestedUnsigned(values={1: [2]}), - 705, - ) - assert result == LocalNestedSignedDefault() +def test_nested_signed_unsigned_rejected(): + with pytest.raises(TypeNotCompatibleError): + compat_ser_de( + RemoteNestedUnsigned, + LocalNestedSignedDefault, + RemoteNestedUnsigned(values={1: [2]}), + 705, + ) + + +def test_nested_nullable_scalar_compatible(): + assert compat_ser_de( + RemoteNestedNullable, + LocalNestedRequired, + RemoteNestedNullable(values={1: [2]}), + 706, + ) == LocalNestedRequired(values={1: [2]}) + + +def test_nested_ref_tracking_compatible(): + assert compat_ser_de( + RemoteNestedRef, + LocalNestedPlain, + RemoteNestedRef(values={1: ["one"]}), + 707, + ref=True, + ) == LocalNestedPlain(values={1: ["one"]}) @dataclass @@ -1065,6 +1122,27 @@ class CompatibleRequiredDefaultsV2: f_dict: Dict[str, int] +@dataclass +class CompatibleListItemV1: + value: pyfory.Int32 + + +@dataclass +class CompatibleListItemV2: + value: pyfory.Int64 + added: str + + +@dataclass +class CompatibleListOwnerV1: + items: List[CompatibleListItemV1] + + +@dataclass +class CompatibleListOwnerV2: + items: List[CompatibleListItemV2] + + @pytest.mark.parametrize("xlang", [False, True]) def test_compatible_mode_add_field(xlang): """Test that adding a field with default value works in compatible mode.""" @@ -1177,6 +1255,44 @@ def test_compatible_mode_add_required_fields_use_type_defaults(xlang): assert ser_de(fory_v2, v2_result) == v2_result +def test_compatible_nested_list_struct(): + writer = Fory(xlang=True, compatible=True, ref=False) + reader = Fory(xlang=True, compatible=True, ref=False) + + writer.register_type(CompatibleListItemV1, type_id=501) + writer.register_type(CompatibleListOwnerV1, type_id=502) + reader.register_type(CompatibleListItemV2, type_id=501) + reader.register_type(CompatibleListOwnerV2, type_id=502) + + decoded = reader.deserialize(writer.serialize(CompatibleListOwnerV1(items=[CompatibleListItemV1(123), CompatibleListItemV1(456)]))) + + assert isinstance(decoded, CompatibleListOwnerV2) + assert [item.value for item in decoded.items] == [123, 456] + assert [item.added for item in decoded.items] == ["", ""] + + +@dataclass +class CompatibleListStringField: + items: List[str] + + +@dataclass +class CompatibleListIntField: + items: List[pyfory.Int32] + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_incompatible_matched_field(xlang): + writer = Fory(xlang=xlang, compatible=True, ref=False) + reader = Fory(xlang=xlang, compatible=True, ref=False) + writer.register_type(CompatibleListStringField, name="example.CompatibleListField") + reader.register_type(CompatibleListIntField, name="example.CompatibleListField") + + data = writer.serialize(CompatibleListStringField(items=["one", "two"])) + with pytest.raises(TypeNotCompatibleError): + reader.deserialize(data) + + @dataclass class CompatibleWithOptional: f1: Optional[int] = None diff --git a/python/pyfory/tests/test_typedef_encoding.py b/python/pyfory/tests/test_typedef_encoding.py index daf305a5b6..1890e06ad6 100644 --- a/python/pyfory/tests/test_typedef_encoding.py +++ b/python/pyfory/tests/test_typedef_encoding.py @@ -41,6 +41,7 @@ TYPEDEF_HASH_SHIFT, _INT64_MIN, _UINT64_MASK, + plan_field_assignment, ) from pyfory.meta.typedef_encoder import ( FIELD_NAME_ENCODER, @@ -213,6 +214,26 @@ def test_dynamic_field_type(): assert dynamic_field.is_tracking_ref is False +def test_nested_user_type_shape_matching(): + remote = CollectionFieldType(TypeId.LIST, True, False, False, DynamicFieldType(TypeId.UNKNOWN, False, False, False)) + local = CollectionFieldType(TypeId.LIST, True, False, False, DynamicFieldType(TypeId.STRUCT, False, False, False)) + + can_assign, validation = plan_field_assignment(remote, local) + + assert can_assign + assert validation is None + + +def test_nested_unknown_does_not_match_scalar(): + remote = CollectionFieldType(TypeId.LIST, True, False, False, DynamicFieldType(TypeId.UNKNOWN, False, False, False)) + local = CollectionFieldType(TypeId.LIST, True, False, False, FieldType(TypeId.INT32, True, False, False)) + + can_assign, validation = plan_field_assignment(remote, local) + + assert not can_assign + assert validation is None + + def test_encode_decode_typedef(): """Test encoding and decoding a TypeDef.""" fory = Fory(xlang=True, compatible=False) @@ -583,7 +604,7 @@ def test_compatible_int32_pyarray_assigns_to_list(): assert decoded.payload == [1, 2, 3] -def test_compatible_nullable_int32_list_payload_rejects_array_read(): +def test_compatible_nullable_int32_list_schema_assigns_to_array(): writer = Fory(xlang=True, compatible=True) reader = Fory(xlang=True, compatible=True) _register_int32_payload(writer, NullableInt32ListPayload) @@ -591,6 +612,7 @@ def test_compatible_nullable_int32_list_payload_rejects_array_read(): decoded = reader.deserialize(writer.serialize(NullableInt32ListPayload(payload=[1, 2, 3]))) assert isinstance(decoded, Int32ArrayPayload) + assert isinstance(decoded.payload, pyfory.Int32Array) assert list(decoded.payload) == [1, 2, 3] with pytest.raises(TypeNotCompatibleError): @@ -607,16 +629,14 @@ def test_compatible_incompatible_list_array_elements_reject(): reader.deserialize(writer.serialize(StringListPayload(payload=["1", "2"]))) -def test_compatible_nested_list_array_mismatch_not_assigned(): +def test_nested_list_array_mismatch_rejects(): writer = Fory(xlang=True, compatible=True) reader = Fory(xlang=True, compatible=True) _register_int32_payload(writer, NestedInt32ListPayload) _register_int32_payload(reader, NestedInt32ArrayPayload) - decoded = reader.deserialize(writer.serialize(NestedInt32ListPayload(payload=[[1, 2], [3]]))) - - assert isinstance(decoded, NestedInt32ArrayPayload) - assert decoded.payload == [] + with pytest.raises(TypeNotCompatibleError): + reader.deserialize(writer.serialize(NestedInt32ListPayload(payload=[[1, 2], [3]]))) if __name__ == "__main__": diff --git a/rust/fory-core/src/meta/mod.rs b/rust/fory-core/src/meta/mod.rs index 90f07c556b..fd2aa28351 100644 --- a/rust/fory-core/src/meta/mod.rs +++ b/rust/fory-core/src/meta/mod.rs @@ -22,6 +22,9 @@ pub use meta_string::{ Encoding, MetaString, MetaStringDecoder, MetaStringEncoder, FIELD_NAME_DECODER, FIELD_NAME_ENCODER, NAMESPACE_DECODER, NAMESPACE_ENCODER, TYPE_NAME_DECODER, TYPE_NAME_ENCODER, }; +#[doc(hidden)] +pub use type_meta::assign_remote_field_ids; +pub(crate) use type_meta::compatible_scalar_field_pair; pub use type_meta::{ compute_field_hash, compute_struct_hash, sort_fields, FieldInfo, FieldType, TypeMeta, NAMESPACE_ENCODINGS, TYPE_NAME_ENCODINGS, diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index 75c5329936..a6c1013e63 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -24,7 +24,7 @@ use crate::meta::{ use crate::resolver::{TypeInfo, TypeResolver}; use crate::type_id::{ TypeId, BINARY, COMPATIBLE_STRUCT, ENUM, EXT, NAMED_COMPATIBLE_STRUCT, NAMED_ENUM, NAMED_EXT, - NAMED_STRUCT, NAMED_UNION, STRUCT, TYPED_UNION, UINT8_ARRAY, UNKNOWN, + NAMED_STRUCT, NAMED_UNION, STRUCT, TYPED_UNION, UINT8_ARRAY, UNION, UNKNOWN, }; use crate::util::{murmurhash3_x64_128, to_snake_case}; @@ -32,7 +32,7 @@ use crate::util::{murmurhash3_x64_128, to_snake_case}; /// This treats all struct variants (STRUCT, COMPATIBLE_STRUCT, NAMED_STRUCT, /// NAMED_COMPATIBLE_STRUCT) and UNKNOWN as equivalent to STRUCT. /// UNKNOWN (0) is used for polymorphic types (interfaces) in cross-language serialization. -/// Similarly for ENUM and EXT variants. Dense byte arrays stay distinct here because schema +/// Similarly for ENUM, EXT, and UNION variants. Dense byte arrays stay distinct here because schema /// equality and schema hashes must not turn compatibility-only byte-sequence assignment into /// same-schema equality. fn normalize_type_id_for_eq(type_id: u32) -> u32 { @@ -50,10 +50,39 @@ fn normalize_type_id_for_eq(type_id: u32) -> u32 { _ if type_id == ENUM || type_id == NAMED_ENUM => ENUM, // All ext variants normalize to EXT _ if type_id == EXT || type_id == NAMED_EXT => EXT, + // All union variants normalize to UNION + _ if type_id == UNION || type_id == TYPED_UNION || type_id == NAMED_UNION => UNION, // Everything else stays the same _ => type_id, } } + +fn compatible_scalar_type_id(type_id: u32) -> bool { + matches!( + type_id, + crate::type_id::BOOL + | crate::type_id::INT8 + | crate::type_id::INT16 + | crate::type_id::INT32 + | crate::type_id::VARINT32 + | crate::type_id::INT64 + | crate::type_id::VARINT64 + | crate::type_id::TAGGED_INT64 + | crate::type_id::UINT8 + | crate::type_id::UINT16 + | crate::type_id::UINT32 + | crate::type_id::VAR_UINT32 + | crate::type_id::UINT64 + | crate::type_id::VAR_UINT64 + | crate::type_id::TAGGED_UINT64 + | crate::type_id::FLOAT16 + | crate::type_id::BFLOAT16 + | crate::type_id::FLOAT32 + | crate::type_id::FLOAT64 + | crate::type_id::STRING + | crate::type_id::DECIMAL + ) +} use std::clone::Clone; use std::cmp::min; use std::collections::HashMap; @@ -61,6 +90,7 @@ use std::rc::Rc; const SMALL_NUM_FIELDS_THRESHOLD: usize = 0b11111; const MAX_TYPE_META_FIELDS: usize = i16::MAX as usize; +const MAX_COMPATIBLE_MATCHED_FIELD_INDEX: usize = (i16::MAX as usize - 1) / 2; const REGISTER_BY_NAME_FLAG: u8 = 0b0010_0000; const COMPATIBLE_TYPEDEF_FLAG: u8 = 0b0100_0000; const STRUCT_TYPEDEF_FLAG: u8 = 0b1000_0000; @@ -233,6 +263,45 @@ impl FieldType { self.compatible_fingerprint } + #[inline(always)] + pub(crate) fn exact_shape_match(&self, other: &Self) -> bool { + if normalize_type_id_for_eq(self.type_id) != normalize_type_id_for_eq(other.type_id) + || self.nullable != other.nullable + || self.track_ref != other.track_ref + || self.generics.len() != other.generics.len() + { + return false; + } + self.generics + .iter() + .zip(other.generics.iter()) + .all(|(left, right)| left.exact_shape_match(right)) + } + + #[inline(always)] + pub(crate) fn compatible_shape_match(&self, other: &Self) -> bool { + if compatible_scalar_type_id(self.type_id) || compatible_scalar_type_id(other.type_id) { + return !self.track_ref && !other.track_ref && self.type_id == other.type_id; + } + if normalize_type_id_for_eq(self.type_id) != normalize_type_id_for_eq(other.type_id) { + return false; + } + if self.generics.len() != other.generics.len() { + return collection_missing_generics_match(self, other); + } + if self.type_id == TypeId::MAP as u32 { + return self + .generics + .iter() + .zip(other.generics.iter()) + .all(|(left, right)| map_entry_shape_match(left, right)); + } + self.generics + .iter() + .zip(other.generics.iter()) + .all(|(left, right)| left.compatible_shape_match(right)) + } + fn to_bytes(&self, writer: &mut Writer, write_flag: bool, nullable: bool) -> Result<(), Error> { let mut header = self.type_id; if header == NAMED_ENUM { @@ -541,6 +610,40 @@ fn compatible_fingerprint_type_id(type_id: u32) -> u32 { } } +#[inline(always)] +fn same_numeric_wire_shape(left: u32, right: u32) -> bool { + let left = compatible_fingerprint_type_id(left); + left == compatible_fingerprint_type_id(right) + && (left == TypeId::VARINT32 as u32 + || left == TypeId::VARINT64 as u32 + || left == TypeId::VAR_UINT32 as u32 + || left == TypeId::VAR_UINT64 as u32) +} + +#[inline(always)] +fn collection_missing_generics_match(local: &FieldType, remote: &FieldType) -> bool { + matches!( + normalize_type_id_for_eq(local.type_id), + x if x == TypeId::LIST as u32 || x == TypeId::SET as u32 || x == TypeId::MAP as u32 + ) && (local.generics.is_empty() || remote.generics.is_empty()) +} + +#[inline(always)] +fn map_entry_shape_match(local: &FieldType, remote: &FieldType) -> bool { + if local.exact_shape_match(remote) || local.compatible_shape_match(remote) { + return true; + } + // Map readers consume key/value payloads through the remote entry FieldType. + // Fixed and variable integer encodings are therefore payload-shape variants + // for the same Rust key/value carrier, not top-level scalar conversions. + !local.track_ref + && !remote.track_ref + && local.nullable == remote.nullable + && local.generics.is_empty() + && remote.generics.is_empty() + && same_numeric_wire_shape(local.type_id, remote.type_id) +} + #[inline(always)] fn compute_field_type_fingerprint(type_id: u32, generics: &[FieldType]) -> u64 { let mut hash = fnv1a_hash_u32(FNV_OFFSET_BASIS, compatible_fingerprint_type_id(type_id)); @@ -601,6 +704,61 @@ pub fn compute_struct_hash(field_ids: impl IntoIterator) -> u32 { field_ids.into_iter().fold(17u32, compute_field_hash) } +#[inline(always)] +fn scalar_numeric_type(type_id: u32) -> bool { + matches!( + type_id, + x if x == TypeId::INT8 as u32 + || x == TypeId::INT16 as u32 + || x == TypeId::INT32 as u32 + || x == TypeId::VARINT32 as u32 + || x == TypeId::INT64 as u32 + || x == TypeId::VARINT64 as u32 + || x == TypeId::TAGGED_INT64 as u32 + || x == TypeId::UINT8 as u32 + || x == TypeId::UINT16 as u32 + || x == TypeId::UINT32 as u32 + || x == TypeId::VAR_UINT32 as u32 + || x == TypeId::UINT64 as u32 + || x == TypeId::VAR_UINT64 as u32 + || x == TypeId::TAGGED_UINT64 as u32 + || x == TypeId::FLOAT16 as u32 + || x == TypeId::BFLOAT16 as u32 + || x == TypeId::FLOAT32 as u32 + || x == TypeId::FLOAT64 as u32 + || x == TypeId::DECIMAL as u32 + ) +} + +#[inline(always)] +fn scalar_conversion_type(type_id: u32) -> bool { + type_id == TypeId::BOOL as u32 + || type_id == TypeId::STRING as u32 + || scalar_numeric_type(type_id) +} + +#[inline(always)] +fn scalar_types_compatible(local: u32, remote: u32) -> bool { + if local == remote { + return scalar_conversion_type(local); + } + let local_numeric = scalar_numeric_type(local); + let remote_numeric = scalar_numeric_type(remote); + (local == TypeId::BOOL as u32 && (remote == TypeId::STRING as u32 || remote_numeric)) + || (remote == TypeId::BOOL as u32 && (local == TypeId::STRING as u32 || local_numeric)) + || (local == TypeId::STRING as u32 && remote_numeric) + || (remote == TypeId::STRING as u32 && local_numeric) + || (local_numeric && remote_numeric) +} + +#[inline(always)] +pub(crate) fn compatible_scalar_field_pair(local: &FieldType, remote: &FieldType) -> bool { + !local.track_ref + && !remote.track_ref + && (local.type_id != remote.type_id || local.nullable != remote.nullable) + && scalar_types_compatible(local.type_id, remote.type_id) +} + /// Sorts field infos according to the provided sorted field names and assigns field IDs. /// /// This function takes a vector of field infos and a slice of sorted field names, @@ -648,16 +806,102 @@ pub fn sort_fields( impl PartialEq for FieldType { fn eq(&self, other: &Self) -> bool { - // Normalize type IDs for comparison to handle cross-language schema evolution. - // This allows UNKNOWN (0) polymorphic types to match STRUCT (15) in Rust. - if normalize_type_id_for_eq(self.type_id) != normalize_type_id_for_eq(other.type_id) { - return false; - } - if self.generics != other.generics { - return false; + self.exact_shape_match(other) + } +} + +#[doc(hidden)] +pub fn assign_remote_field_ids( + local_field_infos: &[FieldInfo], + field_infos: &mut [FieldInfo], +) -> Result<(), Error> { + // Build maps for both name-based and ID-based lookup. + // The value is the sorted local index, not the field's ID attribute. + let field_index_by_name: HashMap = local_field_infos + .iter() + .enumerate() + .filter(|(_, f)| f.field_id < 0 && !f.field_name.is_empty()) + .map(|(i, f)| (f.field_name.clone(), (i, f))) + .collect(); + + let field_index_by_id: HashMap = local_field_infos + .iter() + .enumerate() + .filter(|(_, f)| f.field_id >= 0) + .map(|(i, f)| (f.field_id, (i, f))) + .collect(); + + let mut used_local_fields = vec![false; local_field_infos.len()]; + for field in field_infos.iter_mut() { + let local_match = if field.field_id >= 0 { + field_index_by_id.get(&field.field_id).copied() + } else { + let snake_case_name = to_snake_case(&field.field_name); + field_index_by_name.get(&snake_case_name).copied() + }; + + match local_match { + Some((sorted_index, local_info)) => { + if used_local_fields[sorted_index] { + return Err(Error::type_error(format!( + "Remote field {} duplicates local field {} in compatible metadata", + field.field_name, local_info.field_name + ))); + } + let exact_field = local_info.field_type.exact_shape_match(&field.field_type); + if !exact_field + && !crate::serializer::codec::compatible_field_pair( + &local_info.field_type, + &field.field_type, + ) + { + if crate::util::ENABLE_FORY_DEBUG_OUTPUT { + eprintln!( + "[fory-debug] schema-incompatible field: name={}, remote_type={:?}, local_type={:?}", + field.field_name, field.field_type, local_info.field_type + ); + } + return Err(Error::type_error(format!( + "Cannot read remote field {} as local field {}: remote and local field schemas are not compatible", + field.field_name, local_info.field_name + ))); + } + // ID-encoded remote fields have no name. Copying the local name keeps later + // metadata reserialization and diagnostics tied to the matched local field. + if field.field_name.is_empty() { + field.field_name = local_info.field_name.clone(); + } + if sorted_index > MAX_COMPATIBLE_MATCHED_FIELD_INDEX { + return Err(Error::type_error(format!( + "Cannot assign compatible matched id for local field {}: local field index {} exceeds max {}", + local_info.field_name, sorted_index, MAX_COMPATIBLE_MATCHED_FIELD_INDEX + ))); + } + field.field_id = if exact_field { + (sorted_index * 2) as i16 + } else { + (sorted_index * 2 + 1) as i16 + }; + used_local_fields[sorted_index] = true; + if crate::util::ENABLE_FORY_DEBUG_OUTPUT { + eprintln!( + "[fory-debug] matched field: name={}, assigned_field_id={}, remote_type={:?}, local_type={:?}", + field.field_name, field.field_id, field.field_type, local_info.field_type + ); + } + } + None => { + if crate::util::ENABLE_FORY_DEBUG_OUTPUT { + eprintln!( + "[fory-debug] no local match for field: name={}", + field.field_name + ); + } + field.field_id = -1; + } } - true } + Ok(()) } #[derive(Debug)] @@ -958,7 +1202,7 @@ impl TypeMeta { "TypeMeta kind does not match registered type metadata", )); } - Self::assign_field_ids(&type_info_current, &mut sorted_field_infos); + Self::assign_field_ids(&type_info_current, &mut sorted_field_infos)?; } } else if user_type_id != NO_USER_TYPE_ID { if let Some(type_info_current) = type_resolver.get_user_type_info_by_id(user_type_id) { @@ -967,10 +1211,10 @@ impl TypeMeta { "TypeMeta kind does not match registered type metadata", )); } - Self::assign_field_ids(&type_info_current, &mut sorted_field_infos); + Self::assign_field_ids(&type_info_current, &mut sorted_field_infos)?; } } else if let Some(type_info_current) = type_resolver.get_type_info_by_id(type_id) { - Self::assign_field_ids(&type_info_current, &mut sorted_field_infos); + Self::assign_field_ids(&type_info_current, &mut sorted_field_infos)?; } // if no type found, keep all fields id as -1 to be skipped. TypeMeta::new( @@ -983,7 +1227,10 @@ impl TypeMeta { ) } - fn assign_field_ids(type_info_current: &TypeInfo, field_infos: &mut [FieldInfo]) { + fn assign_field_ids( + type_info_current: &TypeInfo, + field_infos: &mut [FieldInfo], + ) -> Result<(), Error> { if crate::util::ENABLE_FORY_DEBUG_OUTPUT { eprintln!( "[fory-debug] assign_field_ids called for type: {:?}", @@ -1007,77 +1254,7 @@ impl TypeMeta { } } - // Build maps for both name-based and ID-based lookup. - // The value is the SORTED INDEX (position in local_field_infos), not the field's ID attribute. - // This index is used for matching in generated code. - let field_index_by_name: HashMap = local_field_infos - .iter() - .enumerate() - .filter(|(_, f)| !f.field_name.is_empty()) - .map(|(i, f)| (f.field_name.clone(), (i, f))) - .collect(); - - let field_index_by_id: HashMap = local_field_infos - .iter() - .enumerate() - .filter(|(_, f)| f.field_id >= 0) - .map(|(i, f)| (f.field_id, (i, f))) - .collect(); - - for field in field_infos.iter_mut() { - // Try to match by field ID first (if the incoming field was encoded with ID) - let local_match = if field.field_id >= 0 && field.field_name.is_empty() { - // Field was encoded with ID, match by ID - field_index_by_id.get(&field.field_id).copied() - } else { - // Field was encoded with name, match by name - // Convert incoming field name to snake_case for cross-language compatibility - // (Java uses camelCase, Rust uses snake_case) - let snake_case_name = to_snake_case(&field.field_name); - field_index_by_name.get(&snake_case_name).copied() - }; - - match local_match { - Some((sorted_index, local_info)) => { - if !crate::serializer::codec::compatible_field_pair( - &local_info.field_type, - &field.field_type, - ) { - if crate::util::ENABLE_FORY_DEBUG_OUTPUT { - eprintln!( - "[fory-debug] schema-incompatible field: name={}, remote_type={:?}, local_type={:?}", - field.field_name, field.field_type, local_info.field_type - ); - } - field.field_id = -1; - continue; - } - // Always copy field name if it was ID-encoded - // This is needed because TypeMeta may need to re-serialize the field info - if field.field_name.is_empty() { - field.field_name = local_info.field_name.clone(); - } - // Assign SORTED INDEX for generated code only after the pair is - // accepted as an exact read or a supported compatible read action. - field.field_id = sorted_index as i16; - if crate::util::ENABLE_FORY_DEBUG_OUTPUT { - eprintln!( - "[fory-debug] matched field: name={}, assigned_field_id={}, remote_type={:?}, local_type={:?}", - field.field_name, field.field_id, field.field_type, local_info.field_type - ); - } - } - None => { - if crate::util::ENABLE_FORY_DEBUG_OUTPUT { - eprintln!( - "[fory-debug] no local match for field: name={}", - field.field_name - ); - } - field.field_id = -1; // No match, skip - } - } - } + assign_remote_field_ids(local_field_infos, field_infos) } #[allow(dead_code)] @@ -1170,6 +1347,183 @@ impl TypeMeta { mod tests { use super::*; + #[test] + fn exact_field_type_requires_wire_metadata() { + let base = FieldType::new(crate::type_id::INT32, false, vec![]); + let nullable = FieldType::new(crate::type_id::INT32, true, vec![]); + let tracked = FieldType::new_with_ref(crate::type_id::INT32, false, true, vec![]); + + assert_eq!(base, base); + assert_ne!(base, nullable); + assert_ne!(base, tracked); + assert!(base.exact_shape_match(&base)); + assert!(!base.exact_shape_match(&nullable)); + assert!(!base.exact_shape_match(&tracked)); + + let list = FieldType::new(crate::type_id::LIST, false, vec![base.clone()]); + let nullable_list = FieldType::new(crate::type_id::LIST, false, vec![nullable]); + assert!(!list.exact_shape_match(&nullable_list)); + + let struct_field = + FieldType::new_with_user_type_id(crate::type_id::STRUCT, 1, false, false, vec![]); + let unknown_field = + FieldType::new_with_user_type_id(crate::type_id::UNKNOWN, 2, false, false, vec![]); + assert!(struct_field.exact_shape_match(&unknown_field)); + + let union_field = FieldType::new(crate::type_id::UNION, false, vec![]); + let typed_union_field = FieldType::new(crate::type_id::TYPED_UNION, false, vec![]); + assert!(union_field.exact_shape_match(&typed_union_field)); + } + + #[test] + fn rejects_unrepresentable_matched_field_id() { + let field_type = FieldType::new(crate::type_id::INT32, false, vec![]); + let mut local_fields = Vec::with_capacity(MAX_COMPATIBLE_MATCHED_FIELD_INDEX + 2); + for index in 0..=MAX_COMPATIBLE_MATCHED_FIELD_INDEX + 1 { + local_fields.push(FieldInfo::new( + &format!("field_{}", index), + field_type.clone(), + )); + } + let mut remote_fields = [FieldInfo::new( + &format!("field_{}", MAX_COMPATIBLE_MATCHED_FIELD_INDEX + 1), + field_type, + )]; + + let message = assign_remote_field_ids(&local_fields, &mut remote_fields) + .err() + .map(|error| error.to_string()) + .unwrap_or_default(); + + assert!(message.contains("exceeds max")); + } + + #[test] + fn classifies_compatible_scalar_field() { + let local_type = FieldType::new(crate::type_id::INT32, false, vec![]); + let remote_type = FieldType::new(crate::type_id::INT16, false, vec![]); + let local_fields = [FieldInfo::new("value", local_type)]; + let mut remote_fields = [FieldInfo::new("value", remote_type)]; + + assign_remote_field_ids(&local_fields, &mut remote_fields).unwrap(); + + assert_eq!(remote_fields[0].field_id, 1); + } + + #[test] + fn classifies_list_array_bridge() { + let array_type = FieldType::new(crate::type_id::INT32_ARRAY, false, vec![]); + let int_type = FieldType::new(crate::type_id::INT32, false, vec![]); + let local_fields = [FieldInfo::new("values", array_type.clone())]; + let mut remote_fields = [FieldInfo::new( + "values", + FieldType::new(crate::type_id::LIST, false, vec![int_type]), + )]; + + assign_remote_field_ids(&local_fields, &mut remote_fields).unwrap(); + assert_eq!(remote_fields[0].field_id, 1); + + let nullable_int = FieldType::new(crate::type_id::INT32, true, vec![]); + let mut nullable_remote = [FieldInfo::new( + "values", + FieldType::new(crate::type_id::LIST, false, vec![nullable_int]), + )]; + assign_remote_field_ids(&local_fields, &mut nullable_remote).unwrap(); + assert_eq!(nullable_remote[0].field_id, 1); + + let tracked_int = FieldType::new_with_ref(crate::type_id::INT32, false, true, vec![]); + let mut tracked_remote = [FieldInfo::new( + "values", + FieldType::new(crate::type_id::LIST, false, vec![tracked_int]), + )]; + assert!(assign_remote_field_ids(&local_fields, &mut tracked_remote).is_err()); + + let nullable_array_local = [FieldInfo::new( + "values", + FieldType::new(crate::type_id::INT32_ARRAY, true, vec![]), + )]; + let mut remote_list = [FieldInfo::new( + "values", + FieldType::new( + crate::type_id::LIST, + false, + vec![FieldType::new(crate::type_id::INT32, false, vec![])], + ), + )]; + assert!(assign_remote_field_ids(&nullable_array_local, &mut remote_list).is_err()); + } + + #[test] + fn classifies_map_entry_encoding_shape() { + let any_type = FieldType::new_with_ref(crate::type_id::UNKNOWN, false, true, vec![]); + let local_map = FieldType::new( + crate::type_id::MAP, + false, + vec![ + FieldType::new(crate::type_id::VAR_UINT32, false, vec![]), + any_type.clone(), + ], + ); + let remote_map = FieldType::new( + crate::type_id::MAP, + false, + vec![ + FieldType::new(crate::type_id::UINT32, false, vec![]), + any_type, + ], + ); + let local_fields = [FieldInfo::new_with_id(0, "values", local_map)]; + let mut remote_fields = [FieldInfo::new_with_id(0, "", remote_map)]; + + assign_remote_field_ids(&local_fields, &mut remote_fields).unwrap(); + + assert_eq!(remote_fields[0].field_id, 1); + } + + #[test] + fn classifies_missing_collection_generics() { + let local_list = FieldType::new(crate::type_id::LIST, false, vec![]); + let remote_list = FieldType::new( + crate::type_id::LIST, + false, + vec![FieldType::new(crate::type_id::UNKNOWN, true, vec![])], + ); + let local_fields = [FieldInfo::new("values", local_list)]; + let mut remote_fields = [FieldInfo::new("values", remote_list)]; + + assign_remote_field_ids(&local_fields, &mut remote_fields).unwrap(); + + assert_eq!(remote_fields[0].field_id, 1); + } + + #[test] + fn name_remote_does_not_match_tagged_local() { + let field_type = FieldType::new(crate::type_id::STRING, false, vec![]); + let local_fields = [FieldInfo::new_with_id(1, "value", field_type.clone())]; + let mut remote_fields = [FieldInfo::new("value", field_type)]; + + assign_remote_field_ids(&local_fields, &mut remote_fields).unwrap(); + + assert_eq!(remote_fields[0].field_id, -1); + } + + #[test] + fn duplicate_remote_name_binding_fails() { + let field_type = FieldType::new(crate::type_id::STRING, false, vec![]); + let local_fields = [FieldInfo::new("value", field_type.clone())]; + let mut remote_fields = [ + FieldInfo::new("value", field_type.clone()), + FieldInfo::new("value", field_type), + ]; + + let message = assign_remote_field_ids(&local_fields, &mut remote_fields) + .err() + .map(|error| error.to_string()) + .unwrap_or_default(); + + assert!(message.contains("duplicates local field")); + } + #[test] fn rejects_body_hash_mismatch_after_successful_parse() { let meta = TypeMeta::new( diff --git a/rust/fory-core/src/resolver/meta_resolver.rs b/rust/fory-core/src/resolver/meta_resolver.rs index 103c2fdb99..48b7770042 100644 --- a/rust/fory-core/src/resolver/meta_resolver.rs +++ b/rust/fory-core/src/resolver/meta_resolver.rs @@ -178,13 +178,18 @@ impl MetaReaderResolver { if let Some(local_type_info) = type_resolver.get_type_info_by_name(namespace, type_name) { - // Use local harness with remote metadata - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - Some(local_type_info.get_harness()), - Some(local_type_info.get_type_id() as u32), - Some(local_type_info.get_user_type_id()), - )) + // Exact schemas can reuse the local TypeInfo; changed + // schemas keep the remote metadata with the local harness. + if type_meta.get_hash() == local_type_info.get_type_meta_ref().get_hash() { + local_type_info + } else { + Rc::new(TypeInfo::from_remote_meta( + type_meta.clone(), + Some(local_type_info.get_harness()), + Some(local_type_info.get_type_id() as u32), + Some(local_type_info.get_user_type_id()), + )) + } } else { // No local type found, use stub harness Rc::new(TypeInfo::from_remote_meta( @@ -202,13 +207,20 @@ impl MetaReaderResolver { if let Some(local_type_info) = type_resolver.get_user_type_info_by_id(user_type_id) { - // Use local harness with remote metadata - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - Some(local_type_info.get_harness()), - Some(local_type_info.get_type_id() as u32), - Some(local_type_info.get_user_type_id()), - )) + // Exact schemas can reuse the local TypeInfo; changed + // schemas keep the remote metadata with the local harness. + if type_meta.get_hash() + == local_type_info.get_type_meta_ref().get_hash() + { + local_type_info + } else { + Rc::new(TypeInfo::from_remote_meta( + type_meta.clone(), + Some(local_type_info.get_harness()), + Some(local_type_info.get_type_id() as u32), + Some(local_type_info.get_user_type_id()), + )) + } } else { // No local type found, use stub harness Rc::new(TypeInfo::from_remote_meta( @@ -220,13 +232,18 @@ impl MetaReaderResolver { } } else if let Some(local_type_info) = type_resolver.get_type_info_by_id(type_id) { - // Use local harness with remote metadata - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - Some(local_type_info.get_harness()), - Some(local_type_info.get_type_id() as u32), - Some(local_type_info.get_user_type_id()), - )) + // Exact schemas can reuse the local TypeInfo; changed + // schemas keep the remote metadata with the local harness. + if type_meta.get_hash() == local_type_info.get_type_meta_ref().get_hash() { + local_type_info + } else { + Rc::new(TypeInfo::from_remote_meta( + type_meta.clone(), + Some(local_type_info.get_harness()), + Some(local_type_info.get_type_id() as u32), + Some(local_type_info.get_user_type_id()), + )) + } } else { // No local type found, use stub harness Rc::new(TypeInfo::from_remote_meta( diff --git a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index 844379ce6b..41e6c53262 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -195,6 +195,9 @@ pub struct TypeInfo { namespace: Rc, type_name: Rc, register_by_name: bool, + // False only for a remote TypeMeta paired with the local harness. Those + // values must read through compatible schema mapping, not direct local data. + exact_local_schema: bool, harness: Harness, } @@ -221,6 +224,7 @@ impl TypeInfo { namespace: Rc::from(namespace_meta_string), type_name: Rc::from(type_name_meta_string), register_by_name, + exact_local_schema: true, harness, }) } @@ -242,6 +246,7 @@ impl TypeInfo { namespace, type_name, register_by_name, + exact_local_schema: true, harness, }) } @@ -286,6 +291,11 @@ impl TypeInfo { self.register_by_name } + #[inline(always)] + pub(crate) fn has_exact_local_schema(&self) -> bool { + self.exact_local_schema + } + #[inline(always)] pub fn get_harness(&self) -> &Harness { &self.harness @@ -302,6 +312,7 @@ impl TypeInfo { namespace: Rc::new((*self.namespace).clone()), type_name: Rc::new((*self.type_name).clone()), register_by_name: self.register_by_name, + exact_local_schema: self.exact_local_schema, harness: self.harness.clone(), } } @@ -347,6 +358,7 @@ impl TypeInfo { namespace, type_name, register_by_name, + exact_local_schema: false, harness, } } @@ -438,6 +450,7 @@ fn build_struct_type_infos( namespace: partial_info.namespace.clone(), type_name: partial_info.type_name.clone(), register_by_name: partial_info.register_by_name, + exact_local_schema: true, harness: partial_info.harness.clone(), }; @@ -539,6 +552,7 @@ fn build_serializer_type_infos( namespace: partial_info.namespace.clone(), type_name: partial_info.type_name.clone(), register_by_name: partial_info.register_by_name, + exact_local_schema: true, harness: partial_info.harness.clone(), }; diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index 395854788e..c0198aa66f 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -27,7 +27,7 @@ use super::collection::{ }; use crate::context::{ReadContext, WriteContext}; use crate::error::Error; -use crate::meta::FieldType; +use crate::meta::{FieldInfo, FieldType}; use crate::resolver::{RefFlag, RefMode, TypeResolver}; use crate::serializer::{primitive_list, ForyDefault, Serializer}; use crate::type_id::{self, need_to_write_type_for_field, TypeId, SIZE_OF_REF_AND_TYPE, UNKNOWN}; @@ -219,43 +219,156 @@ pub(super) fn collection_type_with_fallback_generics(type_id: u32) -> bool { #[inline(always)] pub fn field_types_compatible(local: &FieldType, remote: &FieldType) -> bool { - let local_scalar = super::scalar_conversion::is_compatible_scalar_type(local.type_id); - let remote_scalar = super::scalar_conversion::is_compatible_scalar_type(remote.type_id); - if local_scalar && remote_scalar { - if local.track_ref != remote.track_ref { - return false; - } - if (local.track_ref || remote.track_ref) - && (local.type_id != remote.type_id || local.nullable != remote.nullable) - { - return false; - } - if !local.track_ref - && (local.type_id != remote.type_id || local.nullable != remote.nullable) - { - return false; - } - } - if local.compatible_fingerprint() == remote.compatible_fingerprint() { - return true; - } - if local.type_id == remote.type_id - && collection_type_with_fallback_generics(local.type_id) - && (local.generics.is_empty() || remote.generics.is_empty()) - { - return true; - } - false + local.exact_shape_match(remote) +} + +#[inline(always)] +fn compatible_byte_sequence_field(local: &FieldType, remote: &FieldType) -> bool { + !local.track_ref + && !remote.track_ref + && local.nullable == remote.nullable + && ((local.type_id == type_id::BINARY && remote.type_id == type_id::UINT8_ARRAY) + || (local.type_id == type_id::UINT8_ARRAY && remote.type_id == type_id::BINARY)) } #[cold] #[inline(never)] pub fn compatible_field_pair(local: &FieldType, remote: &FieldType) -> bool { field_types_compatible(local, remote) - || super::scalar_conversion::scalar_field_types_compatible(local, remote) + || compatible_byte_sequence_field(local, remote) + || crate::meta::compatible_scalar_field_pair(local, remote) || compatible_list_array_field(local, remote) + || local.compatible_shape_match(remote) +} + +macro_rules! compatible_scalar_reader { + ($read:ident, $read_option:ident, $target:ident, $target_option:ident, $ty:ty) => { + #[inline(always)] + pub fn $read( + context: &mut ReadContext, + local_type: u32, + remote_field: &FieldInfo, + ) -> Result<$ty, Error> { + super::scalar_conversion::$target(context, local_type, remote_field) + } + + #[inline(always)] + pub fn $read_option( + context: &mut ReadContext, + local_type: u32, + remote_field: &FieldInfo, + ) -> Result, Error> { + super::scalar_conversion::$target_option(context, local_type, remote_field) + } + }; } +compatible_scalar_reader!( + read_bool_compatible_scalar, + read_bool_option_compatible_scalar, + read_bool_target, + read_bool_option_target, + bool +); +compatible_scalar_reader!( + read_string_compatible_scalar, + read_string_option_compatible_scalar, + read_string_target, + read_string_option_target, + String +); +compatible_scalar_reader!( + read_i8_compatible_scalar, + read_i8_option_compatible_scalar, + read_i8_target, + read_i8_option_target, + i8 +); +compatible_scalar_reader!( + read_i16_compatible_scalar, + read_i16_option_compatible_scalar, + read_i16_target, + read_i16_option_target, + i16 +); +compatible_scalar_reader!( + read_i32_compatible_scalar, + read_i32_option_compatible_scalar, + read_i32_target, + read_i32_option_target, + i32 +); +compatible_scalar_reader!( + read_i64_compatible_scalar, + read_i64_option_compatible_scalar, + read_i64_target, + read_i64_option_target, + i64 +); +compatible_scalar_reader!( + read_u8_compatible_scalar, + read_u8_option_compatible_scalar, + read_u8_target, + read_u8_option_target, + u8 +); +compatible_scalar_reader!( + read_u16_compatible_scalar, + read_u16_option_compatible_scalar, + read_u16_target, + read_u16_option_target, + u16 +); +compatible_scalar_reader!( + read_u32_compatible_scalar, + read_u32_option_compatible_scalar, + read_u32_target, + read_u32_option_target, + u32 +); +compatible_scalar_reader!( + read_u64_compatible_scalar, + read_u64_option_compatible_scalar, + read_u64_target, + read_u64_option_target, + u64 +); +compatible_scalar_reader!( + read_f32_compatible_scalar, + read_f32_option_compatible_scalar, + read_f32_target, + read_f32_option_target, + f32 +); +compatible_scalar_reader!( + read_f64_compatible_scalar, + read_f64_option_compatible_scalar, + read_f64_target, + read_f64_option_target, + f64 +); +compatible_scalar_reader!( + read_float16_compatible_scalar, + read_float16_option_compatible_scalar, + read_float16_target, + read_float16_option_target, + crate::types::float16::float16 +); +compatible_scalar_reader!( + read_bfloat16_compatible_scalar, + read_bfloat16_option_compatible_scalar, + read_bfloat16_target, + read_bfloat16_option_target, + crate::types::bfloat16::bfloat16 +); +compatible_scalar_reader!( + read_decimal_compatible_scalar, + read_decimal_option_compatible_scalar, + read_decimal_target, + read_decimal_option_target, + crate::types::Decimal +); + #[inline(always)] pub(super) fn generic_field_type<'a>( field_type: &'a FieldType, @@ -269,6 +382,38 @@ pub(super) fn generic_field_type<'a>( }) } +pub enum CodecReadType { + Field(FieldType), + TypeInfo(Rc), +} + +enum ElementReadType { + Direct, + Field(FieldType), + TypeInfo(Rc), +} + +#[inline(always)] +fn element_read_type( + context: &mut ReadContext, + read_type: CodecReadType, +) -> Result +where + T: 'static, + C: Codec, +{ + match read_type { + CodecReadType::Field(field_type) => Ok(ElementReadType::Field(field_type)), + CodecReadType::TypeInfo(type_info) => { + if C::type_info_exact(context, &type_info)? { + Ok(ElementReadType::Direct) + } else { + Ok(ElementReadType::TypeInfo(type_info)) + } + } + } +} + #[inline(always)] fn field_type_for_serializer( type_resolver: &TypeResolver, @@ -339,7 +484,9 @@ pub trait Codec: 'static { local_field_type: &FieldType, remote_field_type: &FieldType, ) -> Result, Error> { - if field_types_compatible(local_field_type, remote_field_type) { + if field_types_compatible(local_field_type, remote_field_type) + || local_field_type.compatible_shape_match(remote_field_type) + { return Self::read_field_with_type(context, remote_field_type).map(Some); } super::scalar_conversion::read_scalar_field::( @@ -361,6 +508,22 @@ pub trait Codec: 'static { Self::read_data(context) } + #[inline(always)] + fn read_data_with_type_info( + context: &mut ReadContext, + type_info: &Rc, + ) -> Result { + Self::read_with_type_info(context, RefMode::None, type_info.clone()) + } + + #[inline(always)] + fn type_info_exact( + _context: &ReadContext, + _type_info: &Rc, + ) -> Result { + Ok(false) + } + fn read_field_with_type( context: &mut ReadContext, remote_field_type: &FieldType, @@ -392,6 +555,11 @@ pub trait Codec: 'static { fn read_type_info(context: &mut ReadContext) -> Result<(), Error>; + #[inline(always)] + fn read_type_info_value(context: &mut ReadContext) -> Result { + Self::read_type_info_as_field_type(context).map(CodecReadType::Field) + } + #[inline(always)] fn read_type_info_as_field_type(context: &mut ReadContext) -> Result { Self::read_type_info(context)?; @@ -477,6 +645,28 @@ where T::fory_read_data(context) } + #[inline(always)] + fn read_data_with_type_info( + context: &mut ReadContext, + type_info: &Rc, + ) -> Result { + if Self::type_info_exact(context, type_info)? { + return T::fory_read_data(context); + } + T::fory_read_with_type_info(context, RefMode::None, type_info.clone()) + } + + #[inline(always)] + fn type_info_exact( + context: &ReadContext, + type_info: &Rc, + ) -> Result { + if !context.is_compatible() { + return Ok(false); + } + Ok(type_info.has_exact_local_schema()) + } + #[inline(always)] fn read_field_with_type( context: &mut ReadContext, @@ -533,6 +723,15 @@ where T::fory_read_type_info(context) } + #[inline(always)] + fn read_type_info_value(context: &mut ReadContext) -> Result { + if context.is_compatible() { + return context.read_any_type_info().map(CodecReadType::TypeInfo); + } + T::fory_read_type_info(context)?; + Self::field_type(context.get_type_resolver()).map(CodecReadType::Field) + } + #[inline(always)] fn static_type_id() -> TypeId { T::fory_static_type_id() @@ -927,7 +1126,9 @@ where local_field_type: &FieldType, remote_field_type: &FieldType, ) -> Result>, Error> { - if field_types_compatible(local_field_type, remote_field_type) { + if field_types_compatible(local_field_type, remote_field_type) + || local_field_type.compatible_shape_match(remote_field_type) + { return Self::read_field_with_type(context, remote_field_type).map(Some); } super::scalar_conversion::read_scalar_option_field::( @@ -1317,6 +1518,71 @@ signed_int_codec!( pub struct VecCodec(PhantomData<(T, C)>); +#[inline(always)] +fn read_vec_items( + context: &mut ReadContext, + len: u32, + has_null: bool, + read_type: Option, +) -> Result, Error> +where + T: 'static, + C: Codec, +{ + let mut vec = Vec::with_capacity(len as usize); + match read_type { + None | Some(ElementReadType::Direct) => { + if has_null { + for _ in 0..len { + let flag = context.reader.read_i8()?; + if flag == RefFlag::Null as i8 { + vec.push(C::default_value()); + } else { + vec.push(C::read_data(context)?); + } + } + } else { + for _ in 0..len { + vec.push(C::read_data(context)?); + } + } + } + Some(ElementReadType::Field(field_type)) => { + if has_null { + for _ in 0..len { + let flag = context.reader.read_i8()?; + if flag == RefFlag::Null as i8 { + vec.push(C::default_value()); + } else { + vec.push(C::read_data_with_type(context, &field_type)?); + } + } + } else { + for _ in 0..len { + vec.push(C::read_data_with_type(context, &field_type)?); + } + } + } + Some(ElementReadType::TypeInfo(type_info)) => { + if has_null { + for _ in 0..len { + let flag = context.reader.read_i8()?; + if flag == RefFlag::Null as i8 { + vec.push(C::default_value()); + } else { + vec.push(C::read_data_with_type_info(context, &type_info)?); + } + } + } else { + for _ in 0..len { + vec.push(C::read_data_with_type_info(context, &type_info)?); + } + } + } + } + Ok(vec) +} + impl Codec> for VecCodec where @@ -1364,7 +1630,9 @@ where local_field_type: &FieldType, remote_field_type: &FieldType, ) -> Result>, Error> { - if local_field_type.compatible_fingerprint() == remote_field_type.compatible_fingerprint() { + if field_types_compatible(local_field_type, remote_field_type) + || local_field_type.compatible_shape_match(remote_field_type) + { return Self::read_field_with_type(context, remote_field_type).map(Some); } if local_field_type.type_id == remote_field_type.type_id @@ -1445,26 +1713,14 @@ where "Type inconsistent, target collection element type is not polymorphic", )); } - if (header & DECL_ELEMENT_TYPE) == 0 { - C::read_type_info(context)?; - } - let has_null = (header & HAS_NULL) != 0; - let mut vec = Vec::with_capacity(len as usize); - if has_null { - for _ in 0..len { - let flag = context.reader.read_i8()?; - if flag == RefFlag::Null as i8 { - vec.push(C::default_value()); - } else { - vec.push(C::read_data(context)?); - } - } + let read_type = if (header & DECL_ELEMENT_TYPE) == 0 { + let codec_read_type = C::read_type_info_value(context)?; + Some(element_read_type::(context, codec_read_type)?) } else { - for _ in 0..len { - vec.push(C::read_data(context)?); - } - } - Ok(vec) + None + }; + let has_null = (header & HAS_NULL) != 0; + read_vec_items::(context, len, has_null, read_type) } fn read_data_with_type( @@ -1494,29 +1750,13 @@ where "Type inconsistent, target collection element type is not polymorphic", )); } - let owned_element_type; - let element_type = if is_declared { - generic_field_type(remote_field_type, 0, "list")? + let read_type = if is_declared { + ElementReadType::Field(generic_field_type(remote_field_type, 0, "list")?.clone()) } else { - owned_element_type = C::read_type_info_as_field_type(context)?; - &owned_element_type + let codec_read_type = C::read_type_info_value(context)?; + element_read_type::(context, codec_read_type)? }; - let mut vec = Vec::with_capacity(len as usize); - if has_null { - for _ in 0..len { - let flag = context.reader.read_i8()?; - if flag == RefFlag::Null as i8 { - vec.push(C::default_value()); - } else { - vec.push(C::read_data_with_type(context, element_type)?); - } - } - } else { - for _ in 0..len { - vec.push(C::read_data_with_type(context, element_type)?); - } - } - Ok(vec) + read_vec_items::(context, len, has_null, Some(read_type)) } #[inline(always)] @@ -1651,7 +1891,7 @@ where local_field_type: &FieldType, remote_field_type: &FieldType, ) -> Result>, Error> { - if local_field_type.compatible_fingerprint() == remote_field_type.compatible_fingerprint() { + if field_types_compatible(local_field_type, remote_field_type) { return Self::read_field_with_type(context, remote_field_type).map(Some); } read_primitive_array_vec_compatible_mismatch::( @@ -2807,10 +3047,14 @@ mod tests { let uint8_array = FieldType::new(type_id::UINT8_ARRAY, false, vec![]); let int8_array = FieldType::new(type_id::INT8_ARRAY, false, vec![]); - assert!(field_types_compatible(&bytes, &uint8_array)); - assert!(field_types_compatible(&uint8_array, &bytes)); + assert!(!field_types_compatible(&bytes, &uint8_array)); + assert!(!field_types_compatible(&uint8_array, &bytes)); + assert!(compatible_field_pair(&bytes, &uint8_array)); + assert!(compatible_field_pair(&uint8_array, &bytes)); assert!(!field_types_compatible(&bytes, &int8_array)); assert!(!field_types_compatible(&int8_array, &bytes)); + assert!(!compatible_field_pair(&bytes, &int8_array)); + assert!(!compatible_field_pair(&int8_array, &bytes)); } #[test] @@ -2832,6 +3076,7 @@ mod tests { let fixed_i32 = FieldType::new(type_id::INT32, false, vec![]); let var_i32 = FieldType::new(type_id::VARINT32, false, vec![]); assert!(!field_types_compatible(&fixed_i32, &var_i32)); + assert!(compatible_field_pair(&fixed_i32, &var_i32)); let ref_fixed_i32 = FieldType::new_with_ref(type_id::INT32, false, true, vec![]); let ref_var_i32 = FieldType::new_with_ref(type_id::VARINT32, false, true, vec![]); @@ -2851,6 +3096,27 @@ mod tests { let list_i16 = FieldType::new(type_id::LIST, false, vec![int16]); assert!(!compatible_field_pair(&list_i16, &list_i8)); + let list_fixed_i32 = FieldType::new( + type_id::LIST, + false, + vec![FieldType::new(type_id::INT32, false, vec![])], + ); + let list_var_i32 = FieldType::new( + type_id::LIST, + false, + vec![FieldType::new(type_id::VARINT32, false, vec![])], + ); + assert!(!field_types_compatible(&list_fixed_i32, &list_var_i32)); + assert!(!compatible_field_pair(&list_fixed_i32, &list_var_i32)); + + let list_nullable_i32 = FieldType::new( + type_id::LIST, + false, + vec![FieldType::new(type_id::INT32, true, vec![])], + ); + assert!(!field_types_compatible(&list_fixed_i32, &list_nullable_i32)); + assert!(compatible_field_pair(&list_fixed_i32, &list_nullable_i32)); + let int32_array = FieldType::new(type_id::INT32_ARRAY, false, vec![]); let list_i32 = FieldType::new( type_id::LIST, @@ -2858,6 +3124,7 @@ mod tests { vec![FieldType::new(type_id::INT32, false, vec![])], ); assert!(compatible_field_pair(&list_i32, &int32_array)); + assert!(compatible_field_pair(&list_nullable_i32, &int32_array)); assert!(compatible_field_pair(&int32_array, &list_i32)); } } diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index 1456ba9e34..0e84016089 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -528,19 +528,35 @@ fn read_primitive_array_data_bulk( } } -fn list_element_type_matches_array(list: &FieldType, array: &FieldType) -> bool { +fn list_element_type_matches_array( + list: &FieldType, + array: &FieldType, + require_unframed_element: bool, +) -> bool { primitive_array_element_type_id(array.type_id).is_some_and(|element_type_id| { - list.type_id == type_id::LIST - && list.generics.len() == 1 - && primitive_array_element_type_matches(element_type_id, list.generics[0].type_id) + if list.type_id != type_id::LIST + || list.generics.len() != 1 + || list.nullable + || list.track_ref + || array.nullable + || array.track_ref + { + return false; + } + let element = &list.generics[0]; + // Nullable element schema is allowed for list -> array; actual + // null payload elements fail in the dense-array reader. Ref-tracked + // element framing is rejected here because this path stays primitive-only. + if require_unframed_element && element.track_ref { + return false; + } + primitive_array_element_type_matches(element_type_id, element.type_id) }) } pub(super) fn compatible_list_array_field(local: &FieldType, remote: &FieldType) -> bool { - (local.type_id == type_id::LIST && list_element_type_matches_array(local, remote)) - || (remote.type_id == type_id::LIST - && !remote.generics.is_empty() - && list_element_type_matches_array(remote, local)) + (local.type_id == type_id::LIST && list_element_type_matches_array(local, remote, false)) + || (remote.type_id == type_id::LIST && list_element_type_matches_array(remote, local, true)) } fn primitive_array_element_type_matches( @@ -794,7 +810,7 @@ where C: Codec, { if local_field_type.type_id == type_id::LIST - && list_element_type_matches_array(local_field_type, remote_field_type) + && list_element_type_matches_array(local_field_type, remote_field_type, false) { return read_array_data_as_vec_bridge::(context, remote_field_type).map(Some); } @@ -836,7 +852,7 @@ where { if remote_field_type.type_id == type_id::LIST && !remote_field_type.generics.is_empty() - && list_element_type_matches_array(remote_field_type, local_field_type) + && list_element_type_matches_array(remote_field_type, local_field_type, true) { if field_ref_mode(remote_field_type) != RefMode::None { let ref_flag = context.reader.read_i8()?; diff --git a/rust/fory-core/src/serializer/scalar_conversion.rs b/rust/fory-core/src/serializer/scalar_conversion.rs index c61cd0d885..ed757ea8a7 100644 --- a/rust/fory-core/src/serializer/scalar_conversion.rs +++ b/rust/fory-core/src/serializer/scalar_conversion.rs @@ -18,9 +18,9 @@ use super::codec::{field_ref_mode, Codec}; use crate::context::ReadContext; use crate::error::Error; -use crate::meta::FieldType; +use crate::meta::{FieldInfo, FieldType}; use crate::resolver::{RefFlag, RefMode}; -use crate::serializer::Serializer; +use crate::serializer::{ForyDefault, Serializer}; use crate::type_id; use crate::types::{bfloat16::bfloat16, float16::float16, Decimal}; use num_bigint::BigInt; @@ -90,31 +90,931 @@ where boxed_to_value(converted).map(|value| Some(Some(value))) } +macro_rules! scalar_target_reader { + ($read:ident, $read_option:ident, $ty:ty, $payload:ident) => { + #[inline(never)] + pub(super) fn $read( + context: &mut ReadContext, + local_type: u32, + remote_field: &FieldInfo, + ) -> Result<$ty, Error> { + let remote_field_type = &remote_field.field_type; + if !read_present_ref(context, remote_field_type)? { + return Ok(<$ty as ForyDefault>::fory_default()); + } + // The doubled compatible arm is reached only after schema-pair + // classification accepts a scalar pair. This dispatch only chooses + // the remote wire payload reader. + $payload(context, local_type, remote_field_type.type_id) + } + + #[inline(never)] + pub(super) fn $read_option( + context: &mut ReadContext, + local_type: u32, + remote_field: &FieldInfo, + ) -> Result, Error> { + let remote_field_type = &remote_field.field_type; + if !read_present_ref(context, remote_field_type)? { + return Ok(None); + } + $payload(context, local_type, remote_field_type.type_id).map(Some) + } + }; +} + +scalar_target_reader!( + read_bool_target, + read_bool_option_target, + bool, + read_bool_payload +); +scalar_target_reader!( + read_string_target, + read_string_option_target, + String, + read_string_payload +); +scalar_target_reader!(read_i8_target, read_i8_option_target, i8, read_i8_payload); +scalar_target_reader!( + read_i16_target, + read_i16_option_target, + i16, + read_i16_payload +); +scalar_target_reader!( + read_i32_target, + read_i32_option_target, + i32, + read_i32_payload +); +scalar_target_reader!(read_u8_target, read_u8_option_target, u8, read_u8_payload); +scalar_target_reader!( + read_u16_target, + read_u16_option_target, + u16, + read_u16_payload +); +scalar_target_reader!( + read_u32_target, + read_u32_option_target, + u32, + read_u32_payload +); +scalar_target_reader!( + read_u64_target, + read_u64_option_target, + u64, + read_u64_payload +); +scalar_target_reader!( + read_f32_target, + read_f32_option_target, + f32, + read_f32_payload +); +scalar_target_reader!( + read_f64_target, + read_f64_option_target, + f64, + read_f64_payload +); +scalar_target_reader!( + read_float16_target, + read_float16_option_target, + float16, + read_float16_payload +); +scalar_target_reader!( + read_bfloat16_target, + read_bfloat16_option_target, + bfloat16, + read_bfloat16_payload +); +scalar_target_reader!( + read_decimal_target, + read_decimal_option_target, + Decimal, + read_decimal_payload +); + +#[inline(never)] +pub(super) fn read_i64_target( + context: &mut ReadContext, + local_type: u32, + remote_field: &FieldInfo, +) -> Result { + let remote_field_type = &remote_field.field_type; + if !read_present_ref(context, remote_field_type)? { + return Ok(::fory_default()); + } + read_i64_payload(context, local_type, remote_field_type.type_id) +} + +#[inline(never)] +pub(super) fn read_i64_option_target( + context: &mut ReadContext, + local_type: u32, + remote_field: &FieldInfo, +) -> Result, Error> { + let remote_field_type = &remote_field.field_type; + if !read_present_ref(context, remote_field_type)? { + return Ok(None); + } + read_i64_payload(context, local_type, remote_field_type.type_id).map(Some) +} + +#[inline(always)] +fn read_i64_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::BOOL => match context.reader.read_u8()? { + 0 => Ok(0), + 1 => Ok(1), + _ => Err(conversion_error( + remote_type, + local_type, + "invalid bool payload", + )), + }, + type_id::INT8 => Ok(i64::from(context.reader.read_i8()?)), + type_id::INT16 => Ok(i64::from(context.reader.read_i16()?)), + type_id::INT32 => Ok(i64::from(context.reader.read_i32()?)), + type_id::VARINT32 => Ok(i64::from(context.reader.read_var_i32()?)), + type_id::INT64 => context.reader.read_i64(), + type_id::VARINT64 => context.reader.read_var_i64(), + type_id::TAGGED_INT64 => context.reader.read_tagged_i64(), + type_id::UINT8 => Ok(i64::from(context.reader.read_u8()?)), + type_id::UINT16 => Ok(i64::from(context.reader.read_u16()?)), + type_id::UINT32 => Ok(i64::from(context.reader.read_u32()?)), + type_id::VAR_UINT32 => Ok(i64::from(context.reader.read_var_u32()?)), + type_id::UINT64 => u64_to_i64(context.reader.read_u64()?, remote_type, local_type), + type_id::VAR_UINT64 => u64_to_i64(context.reader.read_var_u64()?, remote_type, local_type), + type_id::TAGGED_UINT64 => { + u64_to_i64(context.reader.read_tagged_u64()?, remote_type, local_type) + } + _ => read_i64_cold(context, local_type, remote_type), + } +} + +#[inline(always)] +fn u64_to_i64(value: u64, remote_type: u32, local_type: u32) -> Result { + i64::try_from(value) + .map_err(|_| conversion_error(remote_type, local_type, "integer value is out of range")) +} + +#[cold] +#[inline(never)] +fn read_i64_cold( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::FLOAT16 | type_id::BFLOAT16 | type_id::FLOAT32 | type_id::FLOAT64 => { + let value = read_float_value(context, remote_type)?; + float_to_integral_num(value, remote_type, local_type, false) + } + type_id::STRING => { + let value = String::fory_read_data(context)?; + string_to_integral_num(&value, remote_type, local_type, false) + } + type_id::DECIMAL => { + let value = Decimal::fory_read_data(context)?; + decimal_to_integral_num(&value, remote_type, local_type, false) + } + _ => Err(Error::invalid_data("invalid compatible scalar remote type")), + } +} + +#[inline(always)] +fn read_bool_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + let value = match remote_type { + type_id::BOOL => { + return match context.reader.read_u8()? { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(conversion_error( + remote_type, + local_type, + "invalid bool payload", + )), + }; + } + type_id::INT8 + | type_id::INT16 + | type_id::INT32 + | type_id::VARINT32 + | type_id::INT64 + | type_id::VARINT64 + | type_id::TAGGED_INT64 => read_i64_payload(context, local_type, remote_type)?, + type_id::UINT8 + | type_id::UINT16 + | type_id::UINT32 + | type_id::VAR_UINT32 + | type_id::UINT64 + | type_id::VAR_UINT64 + | type_id::TAGGED_UINT64 => { + let unsigned = read_u64_payload(context, local_type, remote_type)?; + if unsigned == 0 { + return Ok(false); + } + if unsigned == 1 { + return Ok(true); + } + return Err(conversion_error( + remote_type, + local_type, + "numeric value is not 0 or 1", + )); + } + _ => return read_bool_cold(context, local_type, remote_type), + }; + match value { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(conversion_error( + remote_type, + local_type, + "numeric value is not 0 or 1", + )), + } +} + +#[inline(always)] +fn read_string_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::BOOL => match context.reader.read_u8()? { + 0 => Ok("false".to_string()), + 1 => Ok("true".to_string()), + _ => Err(conversion_error( + remote_type, + local_type, + "invalid bool payload", + )), + }, + type_id::STRING => String::fory_read_data(context), + type_id::INT8 + | type_id::INT16 + | type_id::INT32 + | type_id::VARINT32 + | type_id::INT64 + | type_id::VARINT64 + | type_id::TAGGED_INT64 => { + read_i64_payload(context, local_type, remote_type).map(|value| value.to_string()) + } + type_id::UINT8 + | type_id::UINT16 + | type_id::UINT32 + | type_id::VAR_UINT32 + | type_id::UINT64 + | type_id::VAR_UINT64 + | type_id::TAGGED_UINT64 => { + read_u64_payload(context, local_type, remote_type).map(|value| value.to_string()) + } + _ => read_string_cold(context, local_type, remote_type), + } +} + +macro_rules! signed_payload { + ($name:ident, $ty:ty) => { + #[inline(always)] + fn $name( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, + ) -> Result<$ty, Error> { + let value = read_i64_payload(context, local_type, remote_type)?; + <$ty>::try_from(value).map_err(|_| { + conversion_error(remote_type, local_type, "integer value is out of range") + }) + } + }; +} + +signed_payload!(read_i8_payload, i8); +signed_payload!(read_i16_payload, i16); +signed_payload!(read_i32_payload, i32); + +#[inline(always)] +fn read_u64_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::BOOL => match context.reader.read_u8()? { + 0 => Ok(0), + 1 => Ok(1), + _ => Err(conversion_error( + remote_type, + local_type, + "invalid bool payload", + )), + }, + type_id::INT8 => signed_to_u64( + i64::from(context.reader.read_i8()?), + remote_type, + local_type, + ), + type_id::INT16 => signed_to_u64( + i64::from(context.reader.read_i16()?), + remote_type, + local_type, + ), + type_id::INT32 => signed_to_u64( + i64::from(context.reader.read_i32()?), + remote_type, + local_type, + ), + type_id::VARINT32 => signed_to_u64( + i64::from(context.reader.read_var_i32()?), + remote_type, + local_type, + ), + type_id::INT64 => signed_to_u64(context.reader.read_i64()?, remote_type, local_type), + type_id::VARINT64 => signed_to_u64(context.reader.read_var_i64()?, remote_type, local_type), + type_id::TAGGED_INT64 => { + signed_to_u64(context.reader.read_tagged_i64()?, remote_type, local_type) + } + type_id::UINT8 => Ok(u64::from(context.reader.read_u8()?)), + type_id::UINT16 => Ok(u64::from(context.reader.read_u16()?)), + type_id::UINT32 => Ok(u64::from(context.reader.read_u32()?)), + type_id::VAR_UINT32 => Ok(u64::from(context.reader.read_var_u32()?)), + type_id::UINT64 => context.reader.read_u64(), + type_id::VAR_UINT64 => context.reader.read_var_u64(), + type_id::TAGGED_UINT64 => context.reader.read_tagged_u64(), + _ => read_u64_cold(context, local_type, remote_type), + } +} + +#[inline(always)] +fn signed_to_u64(value: i64, remote_type: u32, local_type: u32) -> Result { + u64::try_from(value) + .map_err(|_| conversion_error(remote_type, local_type, "integer value is out of range")) +} + +macro_rules! unsigned_payload { + ($name:ident, $ty:ty) => { + #[inline(always)] + fn $name( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, + ) -> Result<$ty, Error> { + let value = read_u64_payload(context, local_type, remote_type)?; + <$ty>::try_from(value).map_err(|_| { + conversion_error(remote_type, local_type, "integer value is out of range") + }) + } + }; +} + +unsigned_payload!(read_u8_payload, u8); +unsigned_payload!(read_u16_payload, u16); +unsigned_payload!(read_u32_payload, u32); + +#[inline(always)] +fn read_f32_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::BOOL => match context.reader.read_u8()? { + 0 => Ok(0.0), + 1 => Ok(1.0), + _ => Err(conversion_error( + remote_type, + local_type, + "invalid bool payload", + )), + }, + type_id::INT8 + | type_id::INT16 + | type_id::INT32 + | type_id::VARINT32 + | type_id::INT64 + | type_id::VARINT64 + | type_id::TAGGED_INT64 => { + let value = read_i64_payload(context, local_type, remote_type)?; + signed_integer_to_f32(value, remote_type, local_type) + } + type_id::UINT8 + | type_id::UINT16 + | type_id::UINT32 + | type_id::VAR_UINT32 + | type_id::UINT64 + | type_id::VAR_UINT64 + | type_id::TAGGED_UINT64 => { + let value = read_u64_payload(context, local_type, remote_type)?; + unsigned_integer_to_f32(value, remote_type, local_type) + } + type_id::FLOAT16 => { + let value = context.reader.read_f16()?; + checked_float16(value, remote_type, local_type).map(float16::to_f32) + } + type_id::BFLOAT16 => { + let value = context.reader.read_bf16()?; + checked_bfloat16(value, remote_type, local_type).map(bfloat16::to_f32) + } + type_id::FLOAT32 => checked_f32(context.reader.read_f32()?, remote_type, local_type), + type_id::FLOAT64 => f64_to_f32_exact(context.reader.read_f64()?, remote_type, local_type), + _ => read_f32_cold(context, local_type, remote_type), + } +} + +#[inline(always)] +fn read_f64_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::BOOL => match context.reader.read_u8()? { + 0 => Ok(0.0), + 1 => Ok(1.0), + _ => Err(conversion_error( + remote_type, + local_type, + "invalid bool payload", + )), + }, + type_id::INT8 + | type_id::INT16 + | type_id::INT32 + | type_id::VARINT32 + | type_id::INT64 + | type_id::VARINT64 + | type_id::TAGGED_INT64 => { + let value = read_i64_payload(context, local_type, remote_type)?; + signed_integer_to_f64(value, remote_type, local_type) + } + type_id::UINT8 + | type_id::UINT16 + | type_id::UINT32 + | type_id::VAR_UINT32 + | type_id::UINT64 + | type_id::VAR_UINT64 + | type_id::TAGGED_UINT64 => { + let value = read_u64_payload(context, local_type, remote_type)?; + unsigned_integer_to_f64(value, remote_type, local_type) + } + type_id::FLOAT16 => { + let value = context.reader.read_f16()?; + checked_float16(value, remote_type, local_type).map(|value| f64::from(value.to_f32())) + } + type_id::BFLOAT16 => { + let value = context.reader.read_bf16()?; + checked_bfloat16(value, remote_type, local_type).map(|value| f64::from(value.to_f32())) + } + type_id::FLOAT32 => { + checked_f32(context.reader.read_f32()?, remote_type, local_type).map(f64::from) + } + type_id::FLOAT64 => checked_f64(context.reader.read_f64()?, remote_type, local_type), + _ => read_f64_cold(context, local_type, remote_type), + } +} + +#[inline(always)] +fn read_float16_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::FLOAT16 => checked_float16(context.reader.read_f16()?, remote_type, local_type), + _ => match remote_type { + type_id::STRING => { + let value = String::fory_read_data(context)?; + string_to_float16_value(&value, remote_type, local_type) + } + type_id::DECIMAL => { + let value = Decimal::fory_read_data(context)?; + decimal_to_float16(&value, false, remote_type, local_type) + } + _ => { + let value = read_f32_payload(context, local_type, remote_type)?; + f32_to_float16_exact(value, remote_type, local_type) + } + }, + } +} + +#[inline(always)] +fn read_bfloat16_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::BFLOAT16 => checked_bfloat16(context.reader.read_bf16()?, remote_type, local_type), + _ => match remote_type { + type_id::STRING => { + let value = String::fory_read_data(context)?; + string_to_bfloat16_value(&value, remote_type, local_type) + } + type_id::DECIMAL => { + let value = Decimal::fory_read_data(context)?; + decimal_to_bfloat16(&value, false, remote_type, local_type) + } + _ => { + let value = read_f32_payload(context, local_type, remote_type)?; + f32_to_bfloat16_exact(value, remote_type, local_type) + } + }, + } +} + +#[inline(always)] +fn read_decimal_payload( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::BOOL => match context.reader.read_u8()? { + 0 => Ok(Decimal::new(BigInt::zero(), 0)), + 1 => Ok(Decimal::new(BigInt::one(), 0)), + _ => Err(conversion_error( + remote_type, + local_type, + "invalid bool payload", + )), + }, + type_id::INT8 + | type_id::INT16 + | type_id::INT32 + | type_id::VARINT32 + | type_id::INT64 + | type_id::VARINT64 + | type_id::TAGGED_INT64 => { + let value = read_i64_payload(context, local_type, remote_type)?; + Ok(Decimal::new(BigInt::from(value), 0)) + } + type_id::UINT8 + | type_id::UINT16 + | type_id::UINT32 + | type_id::VAR_UINT32 + | type_id::UINT64 + | type_id::VAR_UINT64 + | type_id::TAGGED_UINT64 => { + let value = read_u64_payload(context, local_type, remote_type)?; + Ok(Decimal::new(BigInt::from(value), 0)) + } + type_id::FLOAT16 | type_id::BFLOAT16 | type_id::FLOAT32 | type_id::FLOAT64 => { + let value = read_float_value(context, remote_type)?; + float_to_decimal_value(value, remote_type, local_type) + } + type_id::STRING => { + let value = String::fory_read_data(context)?; + string_to_decimal_value(&value, remote_type, local_type).map(|(decimal, _)| decimal) + } + type_id::DECIMAL => canonical_decimal(Decimal::fory_read_data(context)?), + _ => Err(Error::invalid_data("invalid compatible scalar remote type")), + } +} + +#[cold] +#[inline(never)] +fn read_bool_cold( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::FLOAT16 | type_id::BFLOAT16 | type_id::FLOAT32 | type_id::FLOAT64 => { + let value = read_float_value(context, remote_type)?; + float_to_bool_value(value, remote_type, local_type) + } + type_id::STRING => { + let value = String::fory_read_data(context)?; + string_to_bool_value(&value, remote_type, local_type) + } + type_id::DECIMAL => { + let value = Decimal::fory_read_data(context)?; + decimal_to_bool_value(&value, remote_type, local_type) + } + _ => Err(Error::invalid_data("invalid compatible scalar remote type")), + } +} + +#[cold] +#[inline(never)] +fn read_string_cold( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::FLOAT16 | type_id::BFLOAT16 | type_id::FLOAT32 | type_id::FLOAT64 => { + let value = read_float_value(context, remote_type)?; + float_to_string(value, remote_type, local_type) + } + type_id::DECIMAL => { + let value = canonical_decimal(Decimal::fory_read_data(context)?)?; + Ok(decimal_to_string(&value)) + } + _ => Err(Error::invalid_data("invalid compatible scalar remote type")), + } +} + +#[cold] +#[inline(never)] +fn read_u64_cold( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::FLOAT16 | type_id::BFLOAT16 | type_id::FLOAT32 | type_id::FLOAT64 => { + let value = read_float_value(context, remote_type)?; + float_to_integral_num(value, remote_type, local_type, true) + } + type_id::STRING => { + let value = String::fory_read_data(context)?; + string_to_integral_num(&value, remote_type, local_type, true) + } + type_id::DECIMAL => { + let value = Decimal::fory_read_data(context)?; + decimal_to_integral_num(&value, remote_type, local_type, true) + } + _ => Err(Error::invalid_data("invalid compatible scalar remote type")), + } +} + +#[cold] +#[inline(never)] +fn read_f32_cold( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::STRING => { + let value = String::fory_read_data(context)?; + string_to_f32_value(&value, remote_type, local_type) + } + type_id::DECIMAL => { + let value = Decimal::fory_read_data(context)?; + decimal_to_f32(&value, false, remote_type, local_type) + } + _ => Err(Error::invalid_data("invalid compatible scalar remote type")), + } +} + +#[cold] +#[inline(never)] +fn read_f64_cold( + context: &mut ReadContext, + local_type: u32, + remote_type: u32, +) -> Result { + match remote_type { + type_id::STRING => { + let value = String::fory_read_data(context)?; + string_to_f64_value(&value, remote_type, local_type) + } + type_id::DECIMAL => { + let value = Decimal::fory_read_data(context)?; + decimal_to_f64(&value, false, remote_type, local_type) + } + _ => Err(Error::invalid_data("invalid compatible scalar remote type")), + } +} + +#[inline(always)] +fn read_float_value(context: &mut ReadContext, remote_type: u32) -> Result { + match remote_type { + type_id::FLOAT16 => Ok(FloatValue::F16(context.reader.read_f16()?)), + type_id::BFLOAT16 => Ok(FloatValue::BF16(context.reader.read_bf16()?)), + type_id::FLOAT32 => Ok(FloatValue::F32(context.reader.read_f32()?)), + type_id::FLOAT64 => Ok(FloatValue::F64(context.reader.read_f64()?)), + _ => Err(Error::invalid_data("invalid compatible scalar remote type")), + } +} + +#[inline(always)] +fn integer_exact_in_binary(value: u64, precision: u32) -> bool { + if value == 0 { + return true; + } + let bits = u64::BITS - value.leading_zeros(); + bits <= precision || value.trailing_zeros() >= bits - precision +} + +#[inline(always)] +fn signed_integer_to_f32(value: i64, remote_type: u32, local_type: u32) -> Result { + let magnitude = value.unsigned_abs(); + if integer_exact_in_binary(magnitude, f32::MANTISSA_DIGITS) { + Ok(value as f32) + } else { + Err(conversion_error( + remote_type, + local_type, + "integer value is not exactly representable by target float", + )) + } +} + +#[inline(always)] +fn unsigned_integer_to_f32(value: u64, remote_type: u32, local_type: u32) -> Result { + if integer_exact_in_binary(value, f32::MANTISSA_DIGITS) { + Ok(value as f32) + } else { + Err(conversion_error( + remote_type, + local_type, + "integer value is not exactly representable by target float", + )) + } +} + +#[inline(always)] +fn signed_integer_to_f64(value: i64, remote_type: u32, local_type: u32) -> Result { + let magnitude = value.unsigned_abs(); + if integer_exact_in_binary(magnitude, f64::MANTISSA_DIGITS) { + Ok(value as f64) + } else { + Err(conversion_error( + remote_type, + local_type, + "integer value is not exactly representable by target float", + )) + } +} + +#[inline(always)] +fn unsigned_integer_to_f64(value: u64, remote_type: u32, local_type: u32) -> Result { + if integer_exact_in_binary(value, f64::MANTISSA_DIGITS) { + Ok(value as f64) + } else { + Err(conversion_error( + remote_type, + local_type, + "integer value is not exactly representable by target float", + )) + } +} + +#[inline(always)] +fn f64_to_f32_exact(value: f64, remote_type: u32, local_type: u32) -> Result { + if value.is_nan() { + return Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )); + } + if value.is_infinite() { + return Ok(if value.is_sign_negative() { + f32::NEG_INFINITY + } else { + f32::INFINITY + }); + } + let converted = value as f32; + if f64::from(converted) == value + && (converted != 0.0 || converted.is_sign_negative() == value.is_sign_negative()) + { + Ok(converted) + } else { + Err(conversion_error( + remote_type, + local_type, + "float value is not exactly representable by target float", + )) + } +} + +#[inline(always)] +fn f32_to_float16_exact(value: f32, remote_type: u32, local_type: u32) -> Result { + if value.is_nan() { + return Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )); + } + if value.is_infinite() { + return Ok(if value.is_sign_negative() { + float16::NEG_INFINITY + } else { + float16::INFINITY + }); + } + let converted = float16::from_f32(value); + let roundtrip = converted.to_f32(); + if roundtrip == value + && (roundtrip != 0.0 || roundtrip.is_sign_negative() == value.is_sign_negative()) + { + Ok(converted) + } else { + Err(conversion_error( + remote_type, + local_type, + "float value is not exactly representable by target float", + )) + } +} + +#[inline(always)] +fn f32_to_bfloat16_exact(value: f32, remote_type: u32, local_type: u32) -> Result { + if value.is_nan() { + return Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )); + } + if value.is_infinite() { + return Ok(if value.is_sign_negative() { + bfloat16::NEG_INFINITY + } else { + bfloat16::INFINITY + }); + } + let converted = bfloat16::from_f32(value); + let roundtrip = converted.to_f32(); + if roundtrip == value + && (roundtrip != 0.0 || roundtrip.is_sign_negative() == value.is_sign_negative()) + { + Ok(converted) + } else { + Err(conversion_error( + remote_type, + local_type, + "float value is not exactly representable by target float", + )) + } +} + #[inline(always)] -pub(super) fn scalar_types_compatible(local: u32, remote: u32) -> bool { - if local == remote { - return is_compatible_scalar_type(local); +fn checked_f32(value: f32, remote_type: u32, local_type: u32) -> Result { + if value.is_nan() { + Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )) + } else { + Ok(value) + } +} + +#[inline(always)] +fn checked_f64(value: f64, remote_type: u32, local_type: u32) -> Result { + if value.is_nan() { + Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )) + } else { + Ok(value) } - let local_numeric = numeric_type(local); - let remote_numeric = numeric_type(remote); - (local == type_id::BOOL && (remote == type_id::STRING || remote_numeric)) - || (remote == type_id::BOOL && (local == type_id::STRING || local_numeric)) - || (local == type_id::STRING && remote_numeric) - || (remote == type_id::STRING && local_numeric) - || (local_numeric && remote_numeric) } #[inline(always)] -pub(super) fn is_compatible_scalar_type(type_id: u32) -> bool { - type_id == type_id::BOOL || type_id == type_id::STRING || numeric_type(type_id) +fn checked_float16(value: float16, remote_type: u32, local_type: u32) -> Result { + if value.is_nan() { + Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )) + } else { + Ok(value) + } +} + +#[inline(always)] +fn checked_bfloat16(value: bfloat16, remote_type: u32, local_type: u32) -> Result { + if value.is_nan() { + Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )) + } else { + Ok(value) + } } #[inline(always)] pub(super) fn scalar_field_types_compatible(local: &FieldType, remote: &FieldType) -> bool { - !local.track_ref - && !remote.track_ref - && (local.type_id != remote.type_id || local.nullable != remote.nullable) - && scalar_types_compatible(local.type_id, remote.type_id) + crate::meta::compatible_scalar_field_pair(local, remote) } #[inline(always)] @@ -235,126 +1135,304 @@ fn convert_scalar( _ => Err(conversion_error( remote_type, local_type, - "unsupported scalar target", + "unsupported scalar target", + )), + } +} + +trait IntoOk { + fn into_ok(self) -> Result, Error>; +} + +impl IntoOk for Box { + #[inline(always)] + fn into_ok(self) -> Result, Error> { + Ok(self) + } +} + +fn value_to_bool(value: ScalarValue, remote_type: u32, local_type: u32) -> Result { + match value { + ScalarValue::Bool(value) => Ok(value), + ScalarValue::String(value) => match value.as_str() { + "0" | "false" => Ok(false), + "1" | "true" => Ok(true), + _ => Err(conversion_error( + remote_type, + local_type, + "string is not an exact bool literal", + )), + }, + value => { + if numeric_zero(&value, remote_type, local_type)? { + Ok(false) + } else if numeric_one(&value, remote_type, local_type)? { + Ok(true) + } else { + Err(conversion_error( + remote_type, + local_type, + "numeric value is not 0 or 1", + )) + } + } + } +} + +fn value_to_string(value: ScalarValue, remote_type: u32, local_type: u32) -> Result { + match value { + ScalarValue::Bool(value) => Ok(if value { "true" } else { "false" }.to_string()), + ScalarValue::String(value) => Ok(value), + ScalarValue::Int(value) => Ok(value.to_string()), + ScalarValue::Float(value) => float_to_string(value, remote_type, local_type), + ScalarValue::Decimal(value) => Ok(decimal_to_string(&canonical_decimal(value)?)), + } +} + +fn value_to_number( + value: ScalarValue, + local_type: u32, + remote_type: u32, +) -> Result, Error> { + match local_type { + type_id::INT8 => Box::new(value_to_i8(value, remote_type, local_type)?).into_ok(), + type_id::INT16 => Box::new(value_to_i16(value, remote_type, local_type)?).into_ok(), + type_id::INT32 | type_id::VARINT32 => { + Box::new(value_to_i32(value, remote_type, local_type)?).into_ok() + } + type_id::INT64 | type_id::VARINT64 | type_id::TAGGED_INT64 => { + Box::new(value_to_i64(value, remote_type, local_type)?).into_ok() + } + type_id::UINT8 => Box::new(value_to_u8(value, remote_type, local_type)?).into_ok(), + type_id::UINT16 => Box::new(value_to_u16(value, remote_type, local_type)?).into_ok(), + type_id::UINT32 | type_id::VAR_UINT32 => { + Box::new(value_to_u32(value, remote_type, local_type)?).into_ok() + } + type_id::UINT64 | type_id::VAR_UINT64 | type_id::TAGGED_UINT64 => { + Box::new(value_to_u64(value, remote_type, local_type)?).into_ok() + } + type_id::FLOAT16 => Box::new(value_to_float16(value, remote_type, local_type)?).into_ok(), + type_id::BFLOAT16 => Box::new(value_to_bfloat16(value, remote_type, local_type)?).into_ok(), + type_id::FLOAT32 => Box::new(value_to_f32(value, remote_type, local_type)?).into_ok(), + type_id::FLOAT64 => Box::new(value_to_f64(value, remote_type, local_type)?).into_ok(), + type_id::DECIMAL => Box::new(value_to_decimal(value, remote_type, local_type)?).into_ok(), + _ => Err(conversion_error( + remote_type, + local_type, + "unsupported numeric target", + )), + } +} + +macro_rules! signed_target { + ($name:ident, $ty:ty) => { + fn $name(value: ScalarValue, remote_type: u32, local_type: u32) -> Result<$ty, Error> { + value_to_integral::<$ty>(value, remote_type, local_type, false) + } + }; +} + +macro_rules! unsigned_target { + ($name:ident, $ty:ty) => { + fn $name(value: ScalarValue, remote_type: u32, local_type: u32) -> Result<$ty, Error> { + value_to_integral::<$ty>(value, remote_type, local_type, true) + } + }; +} + +signed_target!(value_to_i8, i8); +signed_target!(value_to_i16, i16); +signed_target!(value_to_i32, i32); +signed_target!(value_to_i64, i64); +unsigned_target!(value_to_u8, u8); +unsigned_target!(value_to_u16, u16); +unsigned_target!(value_to_u32, u32); +unsigned_target!(value_to_u64, u64); + +fn value_to_integral( + value: ScalarValue, + remote_type: u32, + local_type: u32, + unsigned: bool, +) -> Result +where + T: FromBigInt, +{ + let (decimal, _) = value_to_decimal_parts(value, remote_type, local_type)?; + decimal_to_integral_num(&decimal, remote_type, local_type, unsigned) +} + +fn string_to_integral_num( + value: &str, + remote_type: u32, + local_type: u32, + unsigned: bool, +) -> Result +where + T: FromBigInt, +{ + let (decimal, _) = string_to_decimal_value(value, remote_type, local_type)?; + decimal_to_integral_num(&decimal, remote_type, local_type, unsigned) +} + +fn float_to_integral_num( + value: FloatValue, + remote_type: u32, + local_type: u32, + unsigned: bool, +) -> Result +where + T: FromBigInt, +{ + let (decimal, _) = finite_float_decimal(value, remote_type, local_type)?; + decimal_to_integral_num(&decimal, remote_type, local_type, unsigned) +} + +fn decimal_to_integral_num( + decimal: &Decimal, + remote_type: u32, + local_type: u32, + unsigned: bool, +) -> Result +where + T: FromBigInt, +{ + let value = decimal_to_integral(decimal, remote_type, local_type)?; + if unsigned && value.is_negative() { + return Err(conversion_error( + remote_type, + local_type, + "negative value cannot convert to unsigned integer", + )); + } + T::from_bigint(&value) + .ok_or_else(|| conversion_error(remote_type, local_type, "integer value is out of range")) +} + +fn string_to_bool_value(value: &str, remote_type: u32, local_type: u32) -> Result { + match value { + "0" | "false" => Ok(false), + "1" | "true" => Ok(true), + _ => Err(conversion_error( + remote_type, + local_type, + "string is not an exact bool literal", )), } } -trait IntoOk { - fn into_ok(self) -> Result, Error>; -} - -impl IntoOk for Box { - #[inline(always)] - fn into_ok(self) -> Result, Error> { - Ok(self) +fn decimal_to_bool_value( + value: &Decimal, + remote_type: u32, + local_type: u32, +) -> Result { + if value.unscaled.is_zero() { + Ok(false) + } else if decimal_eq(value, &Decimal::new(BigInt::one(), 0)) { + Ok(true) + } else { + Err(conversion_error( + remote_type, + local_type, + "numeric value is not 0 or 1", + )) } } -fn value_to_bool(value: ScalarValue, remote_type: u32, local_type: u32) -> Result { - match value { - ScalarValue::Bool(value) => Ok(value), - ScalarValue::String(value) => match value.as_str() { - "0" | "false" => Ok(false), - "1" | "true" => Ok(true), - _ => Err(conversion_error( - remote_type, - local_type, - "string is not an exact bool literal", - )), - }, - value => { - if numeric_zero(&value, remote_type, local_type)? { - Ok(false) - } else if numeric_one(&value, remote_type, local_type)? { - Ok(true) - } else { - Err(conversion_error( - remote_type, - local_type, - "numeric value is not 0 or 1", - )) - } - } +fn float_to_bool_value( + value: FloatValue, + remote_type: u32, + local_type: u32, +) -> Result { + if float_is_nan(value) || float_is_infinite(value) { + return Err(conversion_error( + remote_type, + local_type, + "non-finite float is not convertible to bool", + )); } -} - -fn value_to_string(value: ScalarValue, remote_type: u32, local_type: u32) -> Result { - match value { - ScalarValue::Bool(value) => Ok(if value { "true" } else { "false" }.to_string()), - ScalarValue::String(value) => Ok(value), - ScalarValue::Int(value) => Ok(value.to_string()), - ScalarValue::Float(value) => float_to_string(value, remote_type, local_type), - ScalarValue::Decimal(value) => Ok(decimal_to_string(&canonical_decimal(value)?)), + if float_is_zero(value) { + return Ok(false); } + let decimal = canonical_float_decimal(value, remote_type, local_type)?; + decimal_to_bool_value(&decimal, remote_type, local_type) } -fn value_to_number( +fn value_to_decimal( value: ScalarValue, + remote_type: u32, local_type: u32, +) -> Result { + let (decimal, _) = value_to_decimal_parts(value, remote_type, local_type)?; + canonical_decimal(decimal) +} + +fn value_to_decimal_parts( + value: ScalarValue, remote_type: u32, -) -> Result, Error> { + local_type: u32, +) -> Result<(Decimal, bool), Error> { match value { - ScalarValue::Bool(value) => number_from_decimal( - &Decimal::new(BigInt::from(if value { 1 } else { 0 }), 0), + ScalarValue::Bool(value) => Ok(( + Decimal::new(BigInt::from(if value { 1 } else { 0 }), 0), false, - local_type, - remote_type, - ), + )), ScalarValue::String(value) => { let parsed = parse_number(&value).ok_or_else(|| { conversion_error(remote_type, local_type, "invalid numeric literal") })?; - number_from_decimal( - &parsed.decimal, - parsed.negative_zero, - local_type, - remote_type, - ) - } - ScalarValue::Int(value) => { - number_from_decimal(&Decimal::new(value, 0), false, local_type, remote_type) + Ok((parsed.decimal, parsed.negative_zero)) } - ScalarValue::Decimal(value) => { - let decimal = canonical_decimal(value)?; - number_from_decimal(&decimal, false, local_type, remote_type) - } - ScalarValue::Float(value) => float_to_number(value, local_type, remote_type), + ScalarValue::Int(value) => Ok((Decimal::new(value, 0), false)), + ScalarValue::Decimal(value) => canonical_decimal(value).map(|value| (value, false)), + ScalarValue::Float(value) => finite_float_decimal(value, remote_type, local_type), } } -fn number_from_decimal( - decimal: &Decimal, - negative_zero: bool, +fn string_to_decimal_value( + value: &str, + remote_type: u32, local_type: u32, +) -> Result<(Decimal, bool), Error> { + let parsed = parse_number(value) + .ok_or_else(|| conversion_error(remote_type, local_type, "invalid numeric literal"))?; + Ok((parsed.decimal, parsed.negative_zero)) +} + +fn float_to_decimal_value( + value: FloatValue, remote_type: u32, -) -> Result, Error> { - match local_type { - type_id::INT8 => boxed_int::(decimal, remote_type, local_type), - type_id::INT16 => boxed_int::(decimal, remote_type, local_type), - type_id::INT32 | type_id::VARINT32 => boxed_int::(decimal, remote_type, local_type), - type_id::INT64 | type_id::VARINT64 | type_id::TAGGED_INT64 => { - boxed_int::(decimal, remote_type, local_type) - } - type_id::UINT8 => boxed_uint::(decimal, remote_type, local_type), - type_id::UINT16 => boxed_uint::(decimal, remote_type, local_type), - type_id::UINT32 | type_id::VAR_UINT32 => { - boxed_uint::(decimal, remote_type, local_type) - } - type_id::UINT64 | type_id::VAR_UINT64 | type_id::TAGGED_UINT64 => { - boxed_uint::(decimal, remote_type, local_type) - } - type_id::FLOAT16 => boxed_float16(decimal, negative_zero, remote_type, local_type), - type_id::BFLOAT16 => boxed_bfloat16(decimal, negative_zero, remote_type, local_type), - type_id::FLOAT32 => boxed_f32(decimal, negative_zero, remote_type, local_type), - type_id::FLOAT64 => boxed_f64(decimal, negative_zero, remote_type, local_type), - type_id::DECIMAL => Ok(Box::new(canonical_decimal(decimal.clone())?)), - _ => Err(conversion_error( - remote_type, - local_type, - "unsupported numeric target", - )), - } + local_type: u32, +) -> Result { + finite_float_decimal(value, remote_type, local_type).map(|(decimal, _)| decimal) +} + +fn string_to_f32_value(value: &str, remote_type: u32, local_type: u32) -> Result { + let (decimal, negative_zero) = string_to_decimal_value(value, remote_type, local_type)?; + decimal_to_f32(&decimal, negative_zero, remote_type, local_type) +} + +fn string_to_f64_value(value: &str, remote_type: u32, local_type: u32) -> Result { + let (decimal, negative_zero) = string_to_decimal_value(value, remote_type, local_type)?; + decimal_to_f64(&decimal, negative_zero, remote_type, local_type) +} + +fn string_to_float16_value( + value: &str, + remote_type: u32, + local_type: u32, +) -> Result { + let (decimal, negative_zero) = string_to_decimal_value(value, remote_type, local_type)?; + decimal_to_float16(&decimal, negative_zero, remote_type, local_type) +} + +fn string_to_bfloat16_value( + value: &str, + remote_type: u32, + local_type: u32, +) -> Result { + let (decimal, negative_zero) = string_to_decimal_value(value, remote_type, local_type)?; + decimal_to_bfloat16(&decimal, negative_zero, remote_type, local_type) } trait FromBigInt: Any + Sized { @@ -385,43 +1463,22 @@ impl_from_bigint!( u64 => to_u64, ); -fn boxed_int(decimal: &Decimal, remote_type: u32, local_type: u32) -> Result, Error> -where - T: FromBigInt, -{ - let value = decimal_to_integral(decimal, remote_type, local_type)?; - T::from_bigint(&value) - .map(|value| Box::new(value) as Box) - .ok_or_else(|| conversion_error(remote_type, local_type, "integer value is out of range")) -} - -fn boxed_uint( - decimal: &Decimal, - remote_type: u32, - local_type: u32, -) -> Result, Error> -where - T: FromBigInt, -{ - let value = decimal_to_integral(decimal, remote_type, local_type)?; - if value.is_negative() { - return Err(conversion_error( - remote_type, - local_type, - "negative value cannot convert to unsigned integer", - )); +fn value_to_f32(value: ScalarValue, remote_type: u32, local_type: u32) -> Result { + match value { + ScalarValue::Float(value) => float_to_f32_value(value, remote_type, local_type), + value => { + let (decimal, negative_zero) = value_to_decimal_parts(value, remote_type, local_type)?; + decimal_to_f32(&decimal, negative_zero, remote_type, local_type) + } } - T::from_bigint(&value) - .map(|value| Box::new(value) as Box) - .ok_or_else(|| conversion_error(remote_type, local_type, "integer value is out of range")) } -fn boxed_f32( +fn decimal_to_f32( decimal: &Decimal, negative_zero: bool, remote_type: u32, local_type: u32, -) -> Result, Error> { +) -> Result { let value = if decimal.unscaled.is_zero() && negative_zero { -0.0 } else { @@ -438,7 +1495,7 @@ fn boxed_f32( } let actual = canonical_float_decimal(FloatValue::F32(value), local_type, local_type)?; if decimal_eq(&actual, decimal) { - Ok(Box::new(value)) + Ok(value) } else { Err(conversion_error( remote_type, @@ -448,12 +1505,22 @@ fn boxed_f32( } } -fn boxed_f64( +fn value_to_f64(value: ScalarValue, remote_type: u32, local_type: u32) -> Result { + match value { + ScalarValue::Float(value) => float_to_f64_value(value, remote_type, local_type), + value => { + let (decimal, negative_zero) = value_to_decimal_parts(value, remote_type, local_type)?; + decimal_to_f64(&decimal, negative_zero, remote_type, local_type) + } + } +} + +fn decimal_to_f64( decimal: &Decimal, negative_zero: bool, remote_type: u32, local_type: u32, -) -> Result, Error> { +) -> Result { let value = if decimal.unscaled.is_zero() && negative_zero { -0.0 } else { @@ -470,7 +1537,7 @@ fn boxed_f64( } let actual = canonical_float_decimal(FloatValue::F64(value), local_type, local_type)?; if decimal_eq(&actual, decimal) { - Ok(Box::new(value)) + Ok(value) } else { Err(conversion_error( remote_type, @@ -480,16 +1547,27 @@ fn boxed_f64( } } -fn boxed_float16( +fn value_to_float16( + value: ScalarValue, + remote_type: u32, + local_type: u32, +) -> Result { + match value { + ScalarValue::Float(value) => float_to_float16_value(value, remote_type, local_type), + value => { + let (decimal, negative_zero) = value_to_decimal_parts(value, remote_type, local_type)?; + decimal_to_float16(&decimal, negative_zero, remote_type, local_type) + } + } +} + +fn decimal_to_float16( decimal: &Decimal, negative_zero: bool, remote_type: u32, local_type: u32, -) -> Result, Error> { - let f32_box = boxed_f32(decimal, negative_zero, remote_type, type_id::FLOAT32)?; - let value = *f32_box.downcast::().map_err(|_| { - conversion_error(remote_type, local_type, "internal float conversion failed") - })?; +) -> Result { + let value = decimal_to_f32(decimal, negative_zero, remote_type, type_id::FLOAT32)?; let value = float16::from_f32(value); if !value.is_finite() { return Err(conversion_error( @@ -500,7 +1578,7 @@ fn boxed_float16( } let actual = canonical_float_decimal(FloatValue::F16(value), local_type, local_type)?; if decimal_eq(&actual, decimal) { - Ok(Box::new(value)) + Ok(value) } else { Err(conversion_error( remote_type, @@ -510,16 +1588,27 @@ fn boxed_float16( } } -fn boxed_bfloat16( +fn value_to_bfloat16( + value: ScalarValue, + remote_type: u32, + local_type: u32, +) -> Result { + match value { + ScalarValue::Float(value) => float_to_bfloat16_value(value, remote_type, local_type), + value => { + let (decimal, negative_zero) = value_to_decimal_parts(value, remote_type, local_type)?; + decimal_to_bfloat16(&decimal, negative_zero, remote_type, local_type) + } + } +} + +fn decimal_to_bfloat16( decimal: &Decimal, negative_zero: bool, remote_type: u32, local_type: u32, -) -> Result, Error> { - let f32_box = boxed_f32(decimal, negative_zero, remote_type, type_id::FLOAT32)?; - let value = *f32_box.downcast::().map_err(|_| { - conversion_error(remote_type, local_type, "internal float conversion failed") - })?; +) -> Result { + let value = decimal_to_f32(decimal, negative_zero, remote_type, type_id::FLOAT32)?; let value = bfloat16::from_f32(value); if !value.is_finite() { return Err(conversion_error( @@ -530,7 +1619,7 @@ fn boxed_bfloat16( } let actual = canonical_float_decimal(FloatValue::BF16(value), local_type, local_type)?; if decimal_eq(&actual, decimal) { - Ok(Box::new(value)) + Ok(value) } else { Err(conversion_error( remote_type, @@ -540,11 +1629,11 @@ fn boxed_bfloat16( } } -fn float_to_number( +fn finite_float_decimal( value: FloatValue, - local_type: u32, remote_type: u32, -) -> Result, Error> { + local_type: u32, +) -> Result<(Decimal, bool), Error> { if float_is_nan(value) { return Err(conversion_error( remote_type, @@ -553,52 +1642,102 @@ fn float_to_number( )); } if float_is_infinite(value) { - return float_infinity_to_number(value, local_type, remote_type); + return Err(conversion_error( + remote_type, + local_type, + "infinity is only convertible to floating targets", + )); } let negative_zero = float_is_negative_zero(value); - let decimal = canonical_float_decimal(value, remote_type, local_type)?; - number_from_decimal(&decimal, negative_zero, local_type, remote_type) + canonical_float_decimal(value, remote_type, local_type).map(|decimal| (decimal, negative_zero)) } -fn float_infinity_to_number( - value: FloatValue, - local_type: u32, - remote_type: u32, -) -> Result, Error> { - let negative = float_sign_negative(value); - match local_type { - type_id::FLOAT16 => { - let value = if negative { - float16::NEG_INFINITY - } else { - float16::INFINITY - }; - Ok(Box::new(value)) - } - type_id::BFLOAT16 => { - let value = if negative { - bfloat16::NEG_INFINITY - } else { - bfloat16::INFINITY - }; - Ok(Box::new(value)) - } - type_id::FLOAT32 => Ok(Box::new(if negative { +fn float_to_f32_value(value: FloatValue, remote_type: u32, local_type: u32) -> Result { + if float_is_nan(value) { + return Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )); + } + if float_is_infinite(value) { + return Ok(if float_sign_negative(value) { f32::NEG_INFINITY } else { f32::INFINITY - })), - type_id::FLOAT64 => Ok(Box::new(if negative { + }); + } + let negative_zero = float_is_negative_zero(value); + let decimal = canonical_float_decimal(value, remote_type, local_type)?; + decimal_to_f32(&decimal, negative_zero, remote_type, local_type) +} + +fn float_to_f64_value(value: FloatValue, remote_type: u32, local_type: u32) -> Result { + if float_is_nan(value) { + return Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )); + } + if float_is_infinite(value) { + return Ok(if float_sign_negative(value) { f64::NEG_INFINITY } else { f64::INFINITY - })), - _ => Err(conversion_error( + }); + } + let negative_zero = float_is_negative_zero(value); + let decimal = canonical_float_decimal(value, remote_type, local_type)?; + decimal_to_f64(&decimal, negative_zero, remote_type, local_type) +} + +fn float_to_float16_value( + value: FloatValue, + remote_type: u32, + local_type: u32, +) -> Result { + if float_is_nan(value) { + return Err(conversion_error( remote_type, local_type, - "infinity is only convertible to floating targets", - )), + "NaN is not convertible", + )); + } + if float_is_infinite(value) { + return Ok(if float_sign_negative(value) { + float16::NEG_INFINITY + } else { + float16::INFINITY + }); + } + let negative_zero = float_is_negative_zero(value); + let decimal = canonical_float_decimal(value, remote_type, local_type)?; + decimal_to_float16(&decimal, negative_zero, remote_type, local_type) +} + +fn float_to_bfloat16_value( + value: FloatValue, + remote_type: u32, + local_type: u32, +) -> Result { + if float_is_nan(value) { + return Err(conversion_error( + remote_type, + local_type, + "NaN is not convertible", + )); + } + if float_is_infinite(value) { + return Ok(if float_sign_negative(value) { + bfloat16::NEG_INFINITY + } else { + bfloat16::INFINITY + }); } + let negative_zero = float_is_negative_zero(value); + let decimal = canonical_float_decimal(value, remote_type, local_type)?; + decimal_to_bfloat16(&decimal, negative_zero, remote_type, local_type) } fn float_to_string(value: FloatValue, remote_type: u32, local_type: u32) -> Result { diff --git a/rust/fory-derive/src/object/field_codec.rs b/rust/fory-derive/src/object/field_codec.rs index bce0aa3f41..071c4c3ec1 100644 --- a/rust/fory-derive/src/object/field_codec.rs +++ b/rust/fory-derive/src/object/field_codec.rs @@ -217,73 +217,95 @@ impl<'a> ResolvedField<'a> { pub fn declare_compatible_var(&self) -> TokenStream { let var = &self.private_ident; let ty = self.value_ty; + let default_expr = default_expr_for_type(self.value_ty); quote! { - let mut #var: Option<#ty> = None; + let mut #var: #ty = #default_expr; } } pub fn assign_value(&self) -> TokenStream { let var = &self.private_ident; - let default_expr = default_expr_for_type(self.value_ty); - quote! { - #var.unwrap_or_else(|| #default_expr) - } + quote! { #var } } - pub fn read_compatible(&self) -> TokenStream { + pub fn read_compatible_direct(&self) -> TokenStream { let var = &self.private_ident; match &self.dispatch { FieldDispatch::Codec { .. } => { let call = self.codec_call(); quote! { - let remote_field_type = &_field.field_type; - if ::fory_core::serializer::codec::field_types_compatible( - local_field_type, - remote_field_type, - ) { - #var = Some(#call::read_field_with_type(context, remote_field_type)?); - } else { - if let Some(value) = #call::read_compatible( - context, - local_field_type, - remote_field_type, - ).map_err(|err| ::fory_core::Error::invalid_data( - format!("compatible field '{}': {}", _field.field_name.as_str(), err) - ))? { - #var = Some(value); - } else { - return Err(::fory_core::Error::invalid_data(format!( - "compatible field '{}' cannot convert remote type {} to local type {}", - _field.field_name.as_str(), - remote_field_type.type_id, - local_field_type.type_id, - ))); - } - } + #var = #call::read_field(context)?; } } FieldDispatch::Serializer { .. } => { let ty = self.value_ty; - quote! { - let remote_field_type = &_field.field_type; - if ::fory_core::serializer::codec::field_types_compatible( - local_field_type, - remote_field_type, - ) { + if serializer_field_can_use_data_path(self.source.field) { + quote! { + #var = <#ty as ::fory_core::Serializer>::fory_read_data(context)?; + } + } else { + quote! { let read_ref_mode = - ::fory_core::serializer::codec::field_ref_mode(remote_field_type); + ::fory_core::serializer::codec::field_ref_mode(local_field_type); let read_type_info = if context.is_compatible() { ::fory_core::serializer::util::field_need_read_type_info( - remote_field_type.type_id + local_field_type.type_id ) } else { <#ty as ::fory_core::Serializer>::fory_is_polymorphic() }; - #var = Some(<#ty as ::fory_core::Serializer>::fory_read( + #var = <#ty as ::fory_core::Serializer>::fory_read( context, read_ref_mode, read_type_info, - )?); + )?; + } + } + } + } + } + + pub fn direct_needs_local_field_type(&self) -> bool { + match &self.dispatch { + FieldDispatch::Codec { .. } => false, + FieldDispatch::Serializer { .. } => { + !serializer_field_can_use_data_path(self.source.field) + } + } + } + + pub fn read_compatible_conversion(&self) -> TokenStream { + let var = &self.private_ident; + match &self.dispatch { + FieldDispatch::Codec { .. } => { + if let Some(read_scalar) = compatible_scalar_reader_for(self.value_ty) { + let call = self.codec_call(); + let local_type = if extract_option_inner_type(self.value_ty).is_some() { + quote! { local_field_type.type_id } + } else { + quote! { #call::static_type_id() as u32 } + }; + return quote! { + #var = #read_scalar( + context, + #local_type, + _field, + ).map_err(|err| ::fory_core::Error::invalid_data( + format!("compatible field '{}': {}", _field.field_name.as_str(), err) + ))?; + }; + } + let call = self.codec_call(); + quote! { + let remote_field_type = &_field.field_type; + if let Some(value) = #call::read_compatible( + context, + local_field_type, + remote_field_type, + ).map_err(|err| ::fory_core::Error::invalid_data( + format!("compatible field '{}': {}", _field.field_name.as_str(), err) + ))? { + #var = value; } else { return Err(::fory_core::Error::invalid_data(format!( "compatible field '{}' cannot convert remote type {} to local type {}", @@ -294,6 +316,36 @@ impl<'a> ResolvedField<'a> { } } } + FieldDispatch::Serializer { .. } => { + let ty = self.value_ty; + quote! { + let remote_field_type = &_field.field_type; + let read_ref_mode = + ::fory_core::serializer::codec::field_ref_mode(remote_field_type); + let read_type_info = if context.is_compatible() { + ::fory_core::serializer::util::field_need_read_type_info( + remote_field_type.type_id + ) + } else { + <#ty as ::fory_core::Serializer>::fory_is_polymorphic() + }; + #var = <#ty as ::fory_core::Serializer>::fory_read( + context, + read_ref_mode, + read_type_info, + )?; + } + } + } + } + + pub fn compatible_needs_local_field_type(&self) -> bool { + match &self.dispatch { + FieldDispatch::Codec { .. } => { + compatible_scalar_reader_for(self.value_ty).is_none() + || extract_option_inner_type(self.value_ty).is_some() + } + FieldDispatch::Serializer { .. } => false, } } @@ -903,6 +955,87 @@ fn integer_codec_type( } } +fn compatible_scalar_reader_for(ty: &Type) -> Option { + if let Some(inner) = extract_option_inner_type(ty) { + return compatible_scalar_reader_name(&inner, true); + } + compatible_scalar_reader_name(ty, false) +} + +fn compatible_scalar_reader_name(ty: &Type, option: bool) -> Option { + let name = type_name_and_args(ty) + .map(|(name, _)| name) + .unwrap_or_else(|| ty.to_token_stream().to_string().replace(' ', "")); + let reader = match (name.as_str(), option) { + ("bool", false) => quote! { ::fory_core::serializer::codec::read_bool_compatible_scalar }, + ("bool", true) => { + quote! { ::fory_core::serializer::codec::read_bool_option_compatible_scalar } + } + ("String", false) => { + quote! { ::fory_core::serializer::codec::read_string_compatible_scalar } + } + ("String", true) => { + quote! { ::fory_core::serializer::codec::read_string_option_compatible_scalar } + } + ("i8", false) => quote! { ::fory_core::serializer::codec::read_i8_compatible_scalar }, + ("i8", true) => quote! { ::fory_core::serializer::codec::read_i8_option_compatible_scalar }, + ("i16", false) => quote! { ::fory_core::serializer::codec::read_i16_compatible_scalar }, + ("i16", true) => { + quote! { ::fory_core::serializer::codec::read_i16_option_compatible_scalar } + } + ("i32", false) => quote! { ::fory_core::serializer::codec::read_i32_compatible_scalar }, + ("i32", true) => { + quote! { ::fory_core::serializer::codec::read_i32_option_compatible_scalar } + } + ("i64", false) => quote! { ::fory_core::serializer::codec::read_i64_compatible_scalar }, + ("i64", true) => { + quote! { ::fory_core::serializer::codec::read_i64_option_compatible_scalar } + } + ("u8", false) => quote! { ::fory_core::serializer::codec::read_u8_compatible_scalar }, + ("u8", true) => quote! { ::fory_core::serializer::codec::read_u8_option_compatible_scalar }, + ("u16", false) => quote! { ::fory_core::serializer::codec::read_u16_compatible_scalar }, + ("u16", true) => { + quote! { ::fory_core::serializer::codec::read_u16_option_compatible_scalar } + } + ("u32", false) => quote! { ::fory_core::serializer::codec::read_u32_compatible_scalar }, + ("u32", true) => { + quote! { ::fory_core::serializer::codec::read_u32_option_compatible_scalar } + } + ("u64", false) => quote! { ::fory_core::serializer::codec::read_u64_compatible_scalar }, + ("u64", true) => { + quote! { ::fory_core::serializer::codec::read_u64_option_compatible_scalar } + } + ("f32", false) => quote! { ::fory_core::serializer::codec::read_f32_compatible_scalar }, + ("f32", true) => { + quote! { ::fory_core::serializer::codec::read_f32_option_compatible_scalar } + } + ("f64", false) => quote! { ::fory_core::serializer::codec::read_f64_compatible_scalar }, + ("f64", true) => { + quote! { ::fory_core::serializer::codec::read_f64_option_compatible_scalar } + } + ("float16" | "Float16", false) => { + quote! { ::fory_core::serializer::codec::read_float16_compatible_scalar } + } + ("float16" | "Float16", true) => { + quote! { ::fory_core::serializer::codec::read_float16_option_compatible_scalar } + } + ("bfloat16" | "BFloat16", false) => { + quote! { ::fory_core::serializer::codec::read_bfloat16_compatible_scalar } + } + ("bfloat16" | "BFloat16", true) => { + quote! { ::fory_core::serializer::codec::read_bfloat16_option_compatible_scalar } + } + ("Decimal", false) => { + quote! { ::fory_core::serializer::codec::read_decimal_compatible_scalar } + } + ("Decimal", true) => { + quote! { ::fory_core::serializer::codec::read_decimal_option_compatible_scalar } + } + _ => return None, + }; + Some(reader) +} + fn type_name_and_args( ty: &Type, ) -> Option<( @@ -1430,4 +1563,20 @@ mod tests { let err = codec_type_for(&ty, &meta, false, false).unwrap_err(); assert!(err.to_string().contains("bytes schema requires Vec")); } + + #[test] + fn compatible_scalar_reader_is_typed() { + let ty: Type = parse_quote! { i32 }; + let reader = compatible_scalar_reader_for(&ty).unwrap().to_string(); + assert!(reader.contains("read_i32_compatible_scalar")); + assert!(!reader.contains("read_compatible")); + + let ty: Type = parse_quote! { Option }; + let reader = compatible_scalar_reader_for(&ty).unwrap().to_string(); + assert!(reader.contains("read_u64_option_compatible_scalar")); + assert!(!reader.contains("read_compatible")); + + let ty: Type = parse_quote! { Vec }; + assert!(compatible_scalar_reader_for(&ty).is_none()); + } } diff --git a/rust/fory-derive/src/object/read.rs b/rust/fory-derive/src/object/read.rs index 5e790d49b1..66242a7935 100644 --- a/rust/fory-derive/src/object/read.rs +++ b/rust/fory-derive/src/object/read.rs @@ -258,6 +258,28 @@ pub(crate) fn gen_read_compatible_with_construction( }; let declare_ts: Vec = declare_var(source_fields); let assign_ts: Vec = assign_value(source_fields); + let is_tuple = source_fields + .first() + .map(|sf| sf.is_tuple_struct) + .unwrap_or(false); + + let construction = if let Some(variant) = variant_ident { + quote! { + Ok(Self::#variant { + #(#assign_ts),* + }) + } + } else { + crate::util::ok_self_construction(is_tuple, &assign_ts) + }; + let same_schema_construction = construction.clone(); + let same_schema_read_ts: Vec = bindings + .iter() + .map(|binding| match binding { + FieldBinding::Codec(binding) => binding.read_field(), + FieldBinding::Skipped(binding) => binding.read_default(), + }) + .collect(); let match_arms: Vec = bindings .iter() @@ -266,29 +288,52 @@ pub(crate) fn gen_read_compatible_with_construction( FieldBinding::Skipped(_) => None, }) .enumerate() - .map(|(sorted_idx, binding)| { - // Use sorted index for matching. Field assignment only reaches this arm - // for exact reads or supported compatible read actions; unmatched and - // schema-incompatible fields keep field_id = -1 and are skipped. - let field_id = sorted_idx as i16; + .flat_map(|(sorted_idx, binding)| { + let direct_field_id = (sorted_idx * 2) as i16; + let compatible_field_id = (sorted_idx * 2 + 1) as i16; let field_index = sorted_idx; - let body = binding.read_compatible(); - quote! { - #field_id => { - let local_field_type = unsafe { - &(*local_fields_ptr.add(#field_index)).field_type - }; - #body + let direct_body = binding.read_compatible_direct(); + let compatible_body = binding.read_compatible_conversion(); + let direct_arm = if binding.direct_needs_local_field_type() { + quote! { + #direct_field_id => { + let local_field_type = unsafe { + &(*local_fields_ptr.add(#field_index)).field_type + }; + #direct_body + } } - } + } else { + quote! { + #direct_field_id => { + #direct_body + } + } + }; + let compatible_arm = if binding.compatible_needs_local_field_type() { + quote! { + #compatible_field_id => { + let local_field_type = unsafe { + &(*local_fields_ptr.add(#field_index)).field_type + }; + #compatible_body + } + } + } else { + quote! { + #compatible_field_id => { + #compatible_body + } + } + }; + [direct_arm, compatible_arm] }) .collect(); - let skip_arm = if is_debug_enabled() { let struct_name = get_struct_name().expect("struct context not set"); let struct_name_lit = syn::LitStr::new(&struct_name, proc_macro2::Span::call_site()); quote! { - _ => { + -1 => { let field_type = &_field.field_type; let read_ref_flag = ::fory_core::serializer::util::field_need_write_ref_into( field_type.type_id, @@ -312,7 +357,7 @@ pub(crate) fn gen_read_compatible_with_construction( } } else { quote! { - _ => { + -1 => { let field_type = &_field.field_type; let read_ref_flag = ::fory_core::serializer::util::field_need_write_ref_into( field_type.type_id, @@ -323,21 +368,14 @@ pub(crate) fn gen_read_compatible_with_construction( } }; - // Generate construction based on whether this is a struct or enum variant - let is_tuple = source_fields - .first() - .map(|sf| sf.is_tuple_struct) - .unwrap_or(false); - - let construction = if let Some(variant) = variant_ident { - // Enum variants use named syntax (struct variants) or tuple syntax (tuple variants) - quote! { - Ok(Self::#variant { - #(#assign_ts),* - }) + let invalid_arm = quote! { + field_id => { + return Err(::fory_core::Error::invalid_data(format!( + "invalid compatible matched id {} for field '{}'", + field_id, + _field.field_name.as_str(), + ))); } - } else { - crate::util::ok_self_construction(is_tuple, &assign_ts) }; let variant_field_remap = if let Some(variant) = variant_ident { @@ -357,50 +395,6 @@ pub(crate) fn gen_read_compatible_with_construction( ))?; let local_variant_type_meta = local_variant_type_info.get_type_meta(); let local_fields = local_variant_type_meta.get_field_infos(); - - let field_index_by_name: ::std::collections::HashMap<_, _> = local_fields - .iter() - .enumerate() - .filter(|(_, f)| !f.field_name.is_empty()) - .map(|(i, f)| (f.field_name.clone(), (i, f))) - .collect(); - - let field_index_by_id: ::std::collections::HashMap<_, _> = local_fields - .iter() - .enumerate() - .filter(|(_, f)| f.field_id >= 0) - .map(|(i, f)| (f.field_id, (i, f))) - .collect(); - - for field in fields.iter_mut() { - let local_match = if field.field_id >= 0 && field.field_name.is_empty() { - field_index_by_id.get(&field.field_id).copied() - } else { - let snake_case_name = - ::fory_core::util::to_snake_case(&field.field_name); - field_index_by_name.get(&snake_case_name).copied() - }; - - match local_match { - Some((sorted_index, local_info)) - if ::fory_core::serializer::codec::compatible_field_pair( - &local_info.field_type, - &field.field_type, - ) => - { - if field.field_name.is_empty() { - field.field_name = local_info.field_name.clone(); - } - field.field_id = sorted_index as i16; - } - Some(_) => { - field.field_id = -1; - } - None => { - field.field_id = -1; - } - } - } } } else { quote! {} @@ -408,8 +402,26 @@ pub(crate) fn gen_read_compatible_with_construction( let fields_binding = if variant_ident.is_some() { quote! { - let mut fields = remote_meta.get_field_infos().clone(); #variant_field_remap + let mut remapped_fields: ::std::vec::Vec<::fory_core::meta::FieldInfo>; + let fields = if remote_meta.get_namespace().original.as_str() + == local_variant_type_meta.get_namespace().original.as_str() + && remote_meta.get_type_name().original.as_str() + == local_variant_type_meta.get_type_name().original.as_str() + { + // Same-name variant TypeMeta is resolved by the synthetic variant TypeInfo during + // parsing, so its field ids are already doubled matched dispatch ids. Running + // schema matching here again would treat those internal ids as explicit wire ids + // and skip every field. + remote_meta.get_field_infos() + } else { + // Compatible enums match variants by tag first. If that tag now names a different + // local variant, the remote synthetic variant TypeMeta could not be classified by + // name during parsing, so the selected local variant owns the field remap here. + remapped_fields = remote_meta.get_field_infos().clone(); + ::fory_core::meta::assign_remote_field_ids(local_fields, &mut remapped_fields)?; + &remapped_fields + }; let local_fields_ptr = local_fields.as_ptr(); } } else { @@ -418,24 +430,43 @@ pub(crate) fn gen_read_compatible_with_construction( let fields = remote_meta.get_field_infos(); } }; + let schema_setup = if variant_ident.is_some() { + quote! { + let remote_meta = type_info.get_type_meta_ref(); + let remote_type_hash = remote_meta.get_hash(); + #fields_binding + if remote_type_hash == local_variant_type_meta.get_hash() { + // The payload is still only the variant fields. Reading the whole enum data here + // would consume field bytes as a fresh enum tag, so exact variant schemas use the + // local sorted field reader directly. + #(#same_schema_read_ts)* + return #same_schema_construction; + } + } + } else { + quote! { + let meta = context.get_type_resolver().get_type_meta_by_index_ref( + &::std::any::TypeId::of::(), + ::fory_type_index(), + )?; + let local_type_hash = meta.get_hash(); + let remote_meta = type_info.get_type_meta_ref(); + let remote_type_hash = remote_meta.get_hash(); + if remote_type_hash == local_type_hash { + return ::fory_read_data(context); + } + #fields_binding + } + }; quote! { - let meta = context.get_type_resolver().get_type_meta_by_index_ref( - &::std::any::TypeId::of::(), - ::fory_type_index(), - )?; - let local_type_hash = meta.get_hash(); - let remote_meta = type_info.get_type_meta_ref(); - let remote_type_hash = remote_meta.get_hash(); - if remote_type_hash == local_type_hash { - return ::fory_read_data(context); - } - #fields_binding + #schema_setup #(#declare_ts)* for _field in fields.iter() { match _field.field_id { #(#match_arms)* #skip_arm + #invalid_arm } } #construction diff --git a/rust/fory-derive/src/object/serializer.rs b/rust/fory-derive/src/object/serializer.rs index 9dc9dcd8cd..5ee87b86f3 100644 --- a/rust/fory-derive/src/object/serializer.rs +++ b/rust/fory-derive/src/object/serializer.rs @@ -180,7 +180,7 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea #variants_fields_info_ts } - #[inline] + #[inline(never)] fn fory_read_compatible(context: &mut ::fory_core::ReadContext, type_info: ::std::rc::Rc<::fory_core::TypeInfo>) -> ::std::result::Result { #read_compatible_ts } diff --git a/rust/tests/tests/compatible/test_struct.rs b/rust/tests/tests/compatible/test_struct.rs index 770f044e08..e5a40c87ad 100644 --- a/rust/tests/tests/compatible/test_struct.rs +++ b/rust/tests/tests/compatible/test_struct.rs @@ -158,8 +158,14 @@ fn compatible_list_array_field_pairs() { payload: vec![vec![1, 2], vec![3]], }) .unwrap(); - let decoded: NestedArrayPayload = reader.deserialize(&bytes).unwrap(); - assert_eq!(decoded.payload, Vec::>::default()); + let err = reader + .deserialize::(&bytes) + .expect_err("expected nested list/array mismatch to fail classification"); + assert!( + err.to_string() + .contains("remote and local field schemas are not compatible"), + "{err}" + ); } #[test] @@ -236,7 +242,7 @@ fn nonexistent_struct() { } #[test] -fn serializer_container_mismatch_defaults() { +fn rejects_serializer_container_mismatch() { #[derive(ForyStruct, Debug)] struct SetI8 { values: HashSet, @@ -256,8 +262,14 @@ fn serializer_container_mismatch_defaults() { values: HashSet::from([1]), }) .unwrap(); - let decoded: SetI16 = fory2.deserialize(&bin).unwrap(); - assert_eq!(decoded.values, HashSet::default()); + let err = fory2 + .deserialize::(&bin) + .expect_err("expected incompatible container element schema to fail classification"); + assert!( + err.to_string() + .contains("remote and local field schemas are not compatible"), + "{err}" + ); } #[test] diff --git a/rust/tests/tests/test_field_meta.rs b/rust/tests/tests/test_field_meta.rs index f03052c966..7511fd6c4d 100644 --- a/rust/tests/tests/test_field_meta.rs +++ b/rust/tests/tests/test_field_meta.rs @@ -641,10 +641,15 @@ fn test_compatible_nested_integer_encoding_mismatch() { }; let bytes = writer.serialize(&original).unwrap(); - let deserialized: NestedFixedEncoding = reader.deserialize(&bytes).unwrap(); - assert_eq!(original.values, deserialized.values); - assert_eq!(original.data, deserialized.data); - assert_eq!(original.maybe_values, deserialized.maybe_values); + let error = reader + .deserialize::(&bytes) + .expect_err("expected nested integer encoding drift to fail classification"); + assert!( + error + .to_string() + .contains("remote and local field schemas are not compatible"), + "{error}" + ); } /// Test struct with field IDs for compact encoding diff --git a/rust/tests/tests/test_meta.rs b/rust/tests/tests/test_meta.rs index 86fbd6882f..7bc624d6df 100644 --- a/rust/tests/tests/test_meta.rs +++ b/rust/tests/tests/test_meta.rs @@ -26,11 +26,11 @@ fn test_meta_hash() { MetaString::get_empty().clone(), MetaString::get_empty().clone(), false, - vec![FieldInfo { - field_id: 43, - field_name: "f1".to_string(), - field_type: FieldType::new(TypeId::BOOL as u32, true, vec![]), - }], + vec![FieldInfo::new_with_id( + 43, + "f1", + FieldType::new(TypeId::BOOL as u32, true, vec![]), + )], ) .unwrap(); assert_ne!(meta.get_hash(), 0); diff --git a/rust/tests/tests/test_one_struct.rs b/rust/tests/tests/test_one_struct.rs index 3bf883f41d..139bb4073f 100644 --- a/rust/tests/tests/test_one_struct.rs +++ b/rust/tests/tests/test_one_struct.rs @@ -42,7 +42,7 @@ fn test_simple() { f3: Vec, f4: String, f5: i8, - f6: Vec, + f6: Vec, f7: i16, last: i8, } @@ -66,7 +66,7 @@ fn test_simple() { assert_eq!(animal.f3, obj.f3); assert_eq!(obj.f4, String::default()); assert_eq!(obj.f5, 5); - assert_eq!(obj.f6, Vec::::default()); + assert_eq!(animal.f6, obj.f6); assert_eq!(obj.f7, 43); assert_eq!(animal.last, obj.last); } diff --git a/rust/tests/tests/test_tuple_compatible.rs b/rust/tests/tests/test_tuple_compatible.rs index 74a113bcba..2618ae1c50 100644 --- a/rust/tests/tests/test_tuple_compatible.rs +++ b/rust/tests/tests/test_tuple_compatible.rs @@ -682,6 +682,8 @@ fn run_struct_tuple_element_decrease(xlang: bool) { /// Helper: Test struct with complex nested tuple evolution fn run_struct_nested_tuple_evolution(xlang: bool) { + type NestedTupleV2 = ((i32, String, Vec), (f64, bool, Option)); + // V1: Struct with simple nested tuple #[derive(ForyStruct, Debug, PartialEq)] struct StructV1 { @@ -691,10 +693,9 @@ fn run_struct_nested_tuple_evolution(xlang: bool) { // V2: Struct with evolved nested tuple (more elements) #[derive(ForyStruct, Debug, PartialEq)] - #[allow(clippy::type_complexity)] struct StructV2 { id: i32, - nested: ((i32, String, Vec), (f64, bool, Option)), + nested: NestedTupleV2, } // Use separate Fory instances with the same type ID @@ -889,6 +890,9 @@ fn test_struct_complex_evolution_scenario_xlang() { /// - Multiple tuple fields evolving simultaneously /// - Mix of simple, nested, and collection-based tuples fn run_struct_complex_evolution_scenario(xlang: bool) { + type MetadataTupleV2 = ((String, i32, Vec), (bool, f64, Option)); + type AttributesTupleV2 = ((Vec, HashMap), (Option,)); + // V1: Original schema with multiple tuple fields #[derive(ForyStruct, Debug, PartialEq)] struct DataRecordV1 { @@ -906,7 +910,6 @@ fn run_struct_complex_evolution_scenario(xlang: bool) { // V2: Evolved schema with complex changes #[derive(ForyStruct, Debug, PartialEq)] - #[allow(clippy::type_complexity)] struct DataRecordV2 { id: i32, name: String, @@ -915,13 +918,13 @@ fn run_struct_complex_evolution_scenario(xlang: bool) { // category reduced to single element (2 -> 1 elements) category: (String,), // metadata nested tuple expanded (both inner tuples gain elements) - metadata: ((String, i32, Vec), (bool, f64, Option)), + metadata: MetadataTupleV2, // tags remains same tags: (Vec, Vec), // NEW FIELD: status tuple added status: (bool, String, i32), // NEW FIELD: nested tuple with collections - attributes: ((Vec, HashMap), (Option,)), + attributes: AttributesTupleV2, } // Use separate Fory instances with the same type ID diff --git a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala index 219df49a91..025ff14c82 100644 --- a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala +++ b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala @@ -29,6 +29,7 @@ import org.apache.fory.serializer.{ StaticGeneratedStructSerializer, UnionSerializer } +import org.apache.fory.serializer.converter.FieldConverters import org.apache.fory.`type`.{Descriptor, ScalaTypes, Types} import java.lang.reflect.Modifier @@ -57,7 +58,13 @@ object ForySerializerMacros { nullable: Boolean, trackingRef: Boolean, hasTrackingRefMetadata: Boolean, - constructorOwned: Boolean) + constructorOwned: Boolean, + hasDefault: Boolean, + directReadable: Boolean, + directWritable: Boolean) { + def usesFieldAccessor: Boolean = + !constructorOwned && (!directReadable || !directWritable) + } if !hasAnnotation[ForyStruct](owner) then { report.errorAndAbort( @@ -157,19 +164,13 @@ object ForySerializerMacros { val bodyFields = { val candidates = owner.fieldMembers.filter { field => !constructorFieldSet.contains(field) && - !field.flags.is(Flags.Private) && !field.flags.is(Flags.Synthetic) && - !field.name.contains("$") + !field.name.contains("$") && + (!field.flags.is(Flags.Private) || annotationIntArg[ForyField](field, "id").nonEmpty) } val selected = if constructorFields.isEmpty then candidates else candidates.filter(field => annotationIntArg[ForyField](field, "id").nonEmpty) - selected.foreach { field => - if !field.flags.is(Flags.Mutable) then { - report.errorAndAbort( - s"${owner.fullName}.${field.name} is a post-construction field and must be a mutable var") - } - } selected } val serializableFields = constructorFields ++ bodyFields @@ -188,6 +189,8 @@ object ForySerializerMacros { case None => (sourceType, false, false) } val refTracking = refAnnotation(field).orElse(topLevelTypeRefTracking(sourceType)) + val constructorOwned = constructorFieldSet.contains(field) + val privateField = field.flags.is(Flags.Private) FieldMeta( field, field.name, @@ -199,7 +202,10 @@ object ForySerializerMacros { nullable, refTracking.getOrElse(false), refTracking.nonEmpty, - constructorFieldSet.contains(field)) + constructorOwned, + field.flags.is(Flags.HasDefault), + !privateField, + constructorOwned || (field.flags.is(Flags.Mutable) && !privateField)) } val hasNestedCompatibleStructFields = fields.exists(field => hasNestedCompatibleStruct(field.sourceType)) @@ -329,21 +335,28 @@ object ForySerializerMacros { } } - def selectValue(valueExpr: Expr[T], field: FieldMeta): Expr[Any] = { - Select.unique(valueExpr.asTerm, field.name).asExpr - } + def selectValue( + valueExpr: Expr[T], + field: FieldMeta, + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[Any] = + if field.usesFieldAccessor then { + '{ $fieldAccessorsExpr(${ Expr(field.index) }).getObject($valueExpr) } + } else { + Select.unique(valueExpr.asTerm, field.name).asExpr + } def writeDispatch( valueExpr: Expr[T], fieldIdExpr: Expr[Int], fieldInfoExpr: Expr[FieldGroups.SerializationFieldInfo], writeContextExpr: Expr[org.apache.fory.context.WriteContext], - resolverExpr: Expr[TypeResolver]): Expr[Unit] = { + resolverExpr: Expr[TypeResolver], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[Unit] = { fields.foldRight( '{ throw new IllegalStateException("Unknown generated Scala field id " + $fieldIdExpr) }: Expr[Unit]) { (field, next) => - val fieldValue = selectValue(valueExpr, field) + val fieldValue = selectValue(valueExpr, field, fieldAccessorsExpr) val wireValue = if field.option then '{ $fieldValue.asInstanceOf[Option[Any]].orNull } else fieldValue @@ -376,35 +389,482 @@ object ForySerializerMacros { def valueArg(valuesExpr: Expr[Array[Any]], field: FieldMeta): Expr[Any] = decodeValue('{ $valuesExpr(${ Expr(field.index) }) }, field) - def assignRawValue(objExpr: Expr[T], field: FieldMeta, raw: Expr[Any]): Expr[Unit] = - Assign(Select.unique(objExpr.asTerm, field.name), decodeValue(raw, field).asTerm) - .asExprOf[Unit] + def localDefault(field: FieldMeta): Term = { + if field.option then { + field.sourceType.asType match { + case '[a] => '{ None.asInstanceOf[a] }.asTerm + } + } else { + val normalized = peelAnnotations(field.sourceType.widen)._1.dealias + if normalized =:= TypeRepr.of[Boolean] then '{ false }.asTerm + else if normalized =:= TypeRepr.of[Byte] then '{ 0.toByte }.asTerm + else if normalized =:= TypeRepr.of[Short] then '{ 0.toShort }.asTerm + else if normalized =:= TypeRepr.of[Int] then '{ 0 }.asTerm + else if normalized =:= TypeRepr.of[Long] then '{ 0L }.asTerm + else if normalized =:= TypeRepr.of[Float] then '{ 0.0f }.asTerm + else if normalized =:= TypeRepr.of[Double] then '{ 0.0d }.asTerm + else if normalized =:= TypeRepr.of[Char] then '{ 0.toChar }.asTerm + else { + field.sourceType.asType match { + case '[a] => '{ null.asInstanceOf[a] }.asTerm + } + } + } + } - def assignValueById(objExpr: Expr[T], fieldIdExpr: Expr[Int], raw: Expr[Any]): Expr[Unit] = { - fields.foldRight( + def assignRawValue( + objExpr: Expr[T], + field: FieldMeta, + raw: Expr[Any], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[Unit] = + if field.usesFieldAccessor then { + val decoded = decodeValue(raw, field) + '{ + $fieldAccessorsExpr(${ Expr(field.index) }) + .putObject($objExpr, $decoded.asInstanceOf[AnyRef]) + } + } else { + Assign(Select.unique(objExpr.asTerm, field.name), decodeValue(raw, field).asTerm) + .asExprOf[Unit] + } + + def assignLocalValue(local: Symbol, field: FieldMeta, raw: Expr[Any]): Term = + Assign(Ref(local), decodeValue(raw, field).asTerm) + + def assignLocalSource(local: Symbol, field: FieldMeta, raw: Expr[Any]): Expr[Unit] = + field.sourceType.asType match { + case '[a] => Assign(Ref(local), '{ $raw.asInstanceOf[a] }.asTerm).asExprOf[Unit] + } + + def assignValueById( + objExpr: Expr[T], + fieldIdExpr: Expr[Int], + raw: Expr[Any], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[Unit] = { + val cases = fields.map { field => + CaseDef( + Literal(IntConstant(field.index)), + None, + assignRawValue(objExpr, field, raw, fieldAccessorsExpr).asTerm) + } :+ CaseDef( + Wildcard(), + None, '{ throw new IllegalStateException("Unknown generated Scala field id " + $fieldIdExpr) - }: Expr[Unit]) { (field, next) => + }.asTerm) + Match(fieldIdExpr.asTerm, cases).asExprOf[Unit] + } + + def compatibleScalarValue( + field: FieldMeta, + fieldInfoExpr: Expr[FieldGroups.SerializationFieldInfo], + remoteFieldExpr: Expr[StaticGeneratedStructSerializer.RemoteFieldInfo], + readContextExpr: Expr[org.apache.fory.context.ReadContext]): Option[Expr[Any]] = { + val targetType = optionElement(field.sourceType).getOrElse(field.sourceType) + val normalized = peelAnnotations(targetType.widen)._1.dealias + val boxed = field.option + normalized.asType match { + case '[Boolean] => + Some( + if boxed then + '{ + FieldConverters.readBoxedBooleanTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any] + else + '{ + FieldConverters.readBooleanTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[java.lang.Boolean] => + Some( + '{ + FieldConverters.readBoxedBooleanTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[Byte] => + Some( + if boxed then + '{ + FieldConverters.readBoxedByteTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any] + else + '{ + FieldConverters.readByteTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[java.lang.Byte] => + Some( + '{ + FieldConverters.readBoxedByteTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[Short] => + Some( + if boxed then + '{ + FieldConverters.readBoxedShortTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any] + else + '{ + FieldConverters.readShortTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[java.lang.Short] => + Some( + '{ + FieldConverters.readBoxedShortTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[Int] => + Some( + if boxed then + '{ + FieldConverters.readBoxedIntTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any] + else + '{ + FieldConverters.readIntTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[java.lang.Integer] => + Some( + '{ + FieldConverters.readBoxedIntTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[Long] => + Some( + if boxed then + '{ + FieldConverters.readBoxedLongTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any] + else + '{ + FieldConverters.readLongTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[java.lang.Long] => + Some( + '{ + FieldConverters.readBoxedLongTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[Float] => + Some( + if boxed then + '{ + FieldConverters.readBoxedFloatTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any] + else + '{ + FieldConverters.readFloatTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[java.lang.Float] => + Some( + '{ + FieldConverters.readBoxedFloatTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[Double] => + Some( + if boxed then + '{ + FieldConverters.readBoxedDoubleTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any] + else + '{ + FieldConverters.readDoubleTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[java.lang.Double] => + Some( + '{ + FieldConverters.readBoxedDoubleTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[String] => + Some( + '{ + FieldConverters.readStringTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[java.math.BigDecimal] => + Some( + '{ + FieldConverters.readDecimalTarget( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[org.apache.fory.`type`.unsigned.UInt8] => + Some( + '{ + FieldConverters.readUInt8Target( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[org.apache.fory.`type`.unsigned.UInt16] => + Some( + '{ + FieldConverters.readUInt16Target( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[org.apache.fory.`type`.unsigned.UInt32] => + Some( + '{ + FieldConverters.readUInt32Target( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[org.apache.fory.`type`.unsigned.UInt64] => + Some( + '{ + FieldConverters.readUInt64Target( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[org.apache.fory.`type`.Float16] => + Some( + '{ + FieldConverters.readFloat16Target( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case '[org.apache.fory.`type`.BFloat16] => + Some( + '{ + FieldConverters.readBFloat16Target( + $readContextExpr, + $remoteFieldExpr.serializationFieldInfo, + $fieldInfoExpr) + }.asExprOf[Any]) + case _ => None + } + } + + def compatibleValue( + serializerExpr: Expr[StaticGeneratedStructSerializer[T]], + field: FieldMeta, + fieldInfoExpr: Expr[FieldGroups.SerializationFieldInfo], + remoteFieldExpr: Expr[StaticGeneratedStructSerializer.RemoteFieldInfo], + readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[Any] = + compatibleScalarValue(field, fieldInfoExpr, remoteFieldExpr, readContextExpr).getOrElse { '{ - if $fieldIdExpr == ${ Expr(field.index) } then { - ${ assignRawValue(objExpr, field, raw) } - } else { - $next - } + $serializerExpr.readCompatibleFieldValue( + $readContextExpr, + $remoteFieldExpr, + $fieldInfoExpr) } } - } def readAndAssignDispatch( objExpr: Expr[T], fieldIdExpr: Expr[Int], fieldInfoExpr: Expr[FieldGroups.SerializationFieldInfo], readContextExpr: Expr[org.apache.fory.context.ReadContext], - resolverExpr: Expr[TypeResolver]): Expr[Unit] = { + resolverExpr: Expr[TypeResolver], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[Unit] = { '{ val fieldValue = StaticGeneratedStructSerializer.readFieldValue($resolverExpr, $readContextExpr, $fieldInfoExpr) - ${ assignValueById(objExpr, fieldIdExpr, 'fieldValue) } + ${ assignValueById(objExpr, fieldIdExpr, 'fieldValue, fieldAccessorsExpr) } + } + } + + def compatibleAssignByMatchedId( + serializerExpr: Expr[StaticGeneratedStructSerializer[T]], + resolverExpr: Expr[TypeResolver], + fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]], + objExpr: Expr[T], + matchedIdExpr: Expr[Int], + remoteFieldExpr: Expr[StaticGeneratedStructSerializer.RemoteFieldInfo], + readContextExpr: Expr[org.apache.fory.context.ReadContext], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[Unit] = { + val fieldCases = fields.flatMap { field => + val fieldInfoExpr = '{ $fieldsByIdExpr(${ Expr(field.index) }) } + val directValue = + '{ + StaticGeneratedStructSerializer.readFieldValue( + $resolverExpr, + $readContextExpr, + $fieldInfoExpr) + } + val compatibleFieldValue = + compatibleValue(serializerExpr, field, fieldInfoExpr, remoteFieldExpr, readContextExpr) + List( + CaseDef( + Literal(IntConstant(field.index * 2)), + None, + assignRawValue(objExpr, field, directValue, fieldAccessorsExpr).asTerm), + CaseDef( + Literal(IntConstant(field.index * 2 + 1)), + None, + assignRawValue(objExpr, field, compatibleFieldValue, fieldAccessorsExpr).asTerm)) + } + val skipCase = + CaseDef( + Literal(IntConstant(-1)), + None, + '{ $serializerExpr.skipField($readContextExpr, $remoteFieldExpr) }.asTerm) + val invalidCase = + CaseDef( + Wildcard(), + None, + '{ + throw new IllegalStateException("Invalid compatible matched id " + $matchedIdExpr) + }.asTerm) + Match(matchedIdExpr.asTerm, fieldCases :+ skipCase :+ invalidCase).asExprOf[Unit] + } + + def compatibleLocalsByMatchedId( + serializerExpr: Expr[StaticGeneratedStructSerializer[T]], + resolverExpr: Expr[TypeResolver], + fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]], + localFields: Seq[(FieldMeta, Symbol)], + markPresent: FieldMeta => Term, + matchedIdExpr: Expr[Int], + remoteFieldExpr: Expr[StaticGeneratedStructSerializer.RemoteFieldInfo], + readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[Unit] = { + val fieldCases = localFields.flatMap { (field, local) => + val fieldInfoExpr = '{ $fieldsByIdExpr(${ Expr(field.index) }) } + val directValue = + '{ + StaticGeneratedStructSerializer.readFieldValue( + $resolverExpr, + $readContextExpr, + $fieldInfoExpr) + } + val compatibleFieldValue = + compatibleValue(serializerExpr, field, fieldInfoExpr, remoteFieldExpr, readContextExpr) + List( + CaseDef( + Literal(IntConstant(field.index * 2)), + None, + Block( + assignLocalValue(local, field, directValue) :: markPresent(field) :: Nil, + '{ () }.asTerm)), + CaseDef( + Literal(IntConstant(field.index * 2 + 1)), + None, + Block( + assignLocalValue(local, field, compatibleFieldValue) :: markPresent(field) :: Nil, + '{ () }.asTerm))) + }.toList + val skipCase = + CaseDef( + Literal(IntConstant(-1)), + None, + '{ $serializerExpr.skipField($readContextExpr, $remoteFieldExpr) }.asTerm) + val invalidCase = + CaseDef( + Wildcard(), + None, + '{ + throw new IllegalStateException("Invalid compatible matched id " + $matchedIdExpr) + }.asTerm) + Match(matchedIdExpr.asTerm, fieldCases :+ skipCase :+ invalidCase).asExprOf[Unit] + } + + def assignPostConstruction( + obj: Term, + field: FieldMeta, + value: Term, + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Term = + if field.usesFieldAccessor then { + '{ + $fieldAccessorsExpr(${ Expr(field.index) }) + .putObject(${ obj.asExprOf[T] }, ${ value.asExpr }.asInstanceOf[AnyRef]) + }.asTerm + } else { + Assign(Select.unique(obj, field.name), value) + } + + def constructFromLocals( + localFields: Seq[(FieldMeta, Symbol)], + instantiatorExpr: Expr[org.apache.fory.reflect.ObjectInstantiator[T]], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[T] = { + val constructorOwned = fields.filter(_.constructorOwned) + val postConstruction = fields.filterNot(_.constructorOwned) + val localById = localFields.map { (field, local) => field.index -> local }.toMap + def localRef(field: FieldMeta): Term = Ref(localById(field.index)) + if postConstruction.isEmpty then { + val args = constructorOwned.map(localRef) + Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args).asExprOf[T] + } else if constructorOwned.isEmpty then { + val obj = + Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) + val construct = '{ $instantiatorExpr.newInstance() }.asTerm + val assignments = postConstruction.map { field => + assignPostConstruction(Ref(obj), field, localRef(field), fieldAccessorsExpr) + } + Block(ValDef(obj, Some(construct)) :: assignments, Ref(obj)).asExprOf[T] + } else { + val obj = + Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) + val args = constructorOwned.map(localRef) + val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args) + val assignments = postConstruction.map { field => + assignPostConstruction(Ref(obj), field, localRef(field), fieldAccessorsExpr) + } + Block(ValDef(obj, Some(construct)) :: assignments, Ref(obj)).asExprOf[T] } } @@ -440,8 +900,9 @@ object ForySerializerMacros { valueExpr: Expr[T], field: FieldMeta, copyContextExpr: Expr[org.apache.fory.context.CopyContext], - fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]]): Expr[Any] = { - val selected = selectValue(valueExpr, field) + fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[Any] = { + val selected = selectValue(valueExpr, field, fieldAccessorsExpr) val wireValue = if field.option then '{ $selected.asInstanceOf[Option[Any]].orNull } else selected @@ -479,7 +940,9 @@ object ForySerializerMacros { def copyValue( valueExpr: Expr[T], copyContextExpr: Expr[org.apache.fory.context.CopyContext], - fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]]): Expr[T] = { + fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]], + instantiatorExpr: Expr[org.apache.fory.reflect.ObjectInstantiator[T]], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[T] = { val constructorOwned = fields.filter(_.constructorOwned) val postConstruction = fields.filterNot(_.constructorOwned) @@ -487,7 +950,7 @@ object ForySerializerMacros { if postConstruction.isEmpty then { val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) val args = constructorOwned.map { field => - copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr).asTerm + copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr, fieldAccessorsExpr).asTerm } val cycleCheck = failIfCopiedDuringConstructorArgCopy( valueExpr, @@ -501,10 +964,11 @@ object ForySerializerMacros { Ref(obj)).asExprOf[T] } else if constructorOwned.isEmpty then { val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) - val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil) + val construct = '{ $instantiatorExpr.newInstance() }.asTerm val assignments = postConstruction.map { field => - val copied = copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr) - Assign(Select.unique(Ref(obj), field.name), copied.asTerm) + val copied = + copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr, fieldAccessorsExpr) + assignPostConstruction(Ref(obj), field, copied.asTerm, fieldAccessorsExpr) } Block( ValDef(obj, Some(construct)) :: @@ -514,15 +978,16 @@ object ForySerializerMacros { } else { val obj = Symbol.newVal(Symbol.spliceOwner, "obj", TypeRepr.of[T], Flags.EmptyFlags, Symbol.noSymbol) val args = constructorOwned.map { field => - copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr).asTerm + copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr, fieldAccessorsExpr).asTerm } val cycleCheck = failIfCopiedDuringConstructorArgCopy( valueExpr, copyContextExpr).asTerm val construct = Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), args) val assignments = postConstruction.map { field => - val copied = copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr) - Assign(Select.unique(Ref(obj), field.name), copied.asTerm) + val copied = + copiedValueArg(valueExpr, field, copyContextExpr, fieldsByIdExpr, fieldAccessorsExpr) + assignPostConstruction(Ref(obj), field, copied.asTerm, fieldAccessorsExpr) } Block( cycleCheck :: @@ -543,6 +1008,132 @@ object ForySerializerMacros { } else constructFromValues(valuesExpr) } + def compatibleConstructRead( + serializerExpr: Expr[StaticGeneratedStructSerializer[T]], + resolverExpr: Expr[TypeResolver], + fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]], + readContextExpr: Expr[org.apache.fory.context.ReadContext], + instantiatorExpr: Expr[org.apache.fory.reflect.ObjectInstantiator[T]], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[T] = { + val localFields = fields.map { field => + val local = Symbol.newVal( + Symbol.spliceOwner, + s"${field.name}Value", + field.sourceType, + Flags.Mutable, + Symbol.noSymbol) + field -> local + } + val localDefs = localFields.map { (field, local) => + ValDef(local, Some(localDefault(field))) + } + val defaultedFields = fields.filter(field => field.constructorOwned && field.hasDefault) + val defaultIndexByField = defaultedFields.zipWithIndex.map { (field, index) => + field.index -> index + }.toMap + val defaultMaskSym = + if defaultedFields.nonEmpty && defaultedFields.size <= 64 then { + Some( + Symbol.newVal( + Symbol.spliceOwner, + "missingDefaultMask", + TypeRepr.of[Long], + Flags.Mutable, + Symbol.noSymbol)) + } else None + val defaultArraySym = + if defaultedFields.size > 64 then { + Some( + Symbol.newVal( + Symbol.spliceOwner, + "missingDefaultMask", + TypeRepr.of[Array[Boolean]], + Flags.EmptyFlags, + Symbol.noSymbol)) + } else None + def clearDefault(field: FieldMeta): Term = + defaultIndexByField.get(field.index) match { + case Some(defaultIndex) => + defaultMaskSym match { + case Some(mask) => + val bit = 1L << defaultIndex + Assign(Ref(mask), '{ ${ Ref(mask).asExprOf[Long] } & ${ Expr(~bit) } }.asTerm) + case None => + val mask = defaultArraySym.get + '{ + ${ Ref(mask).asExprOf[Array[Boolean]] }.update(${ Expr(defaultIndex) }, false) + }.asTerm + } + case None => + '{ () }.asTerm + } + def applyDefault(field: FieldMeta, local: Symbol, defaultIndex: Int): Term = { + val defaultValue = + '{ + org.apache.fory.util.DefaultValueUtils.getScalaDefaultValue( + Class.forName(${ Expr(ownerClassName) }), + ${ Expr(field.name) }) + } + val assignDefault = assignLocalSource(local, field, defaultValue) + defaultMaskSym match { + case Some(mask) => + val bit = 1L << defaultIndex + '{ + if ((${ Ref(mask).asExprOf[Long] } & ${ Expr(bit) }) != 0L) { + $assignDefault + } + }.asTerm + case None => + val mask = defaultArraySym.get + '{ + if (${ Ref(mask).asExprOf[Array[Boolean]] }(${ Expr(defaultIndex) })) { + $assignDefault + } + }.asTerm + } + } + val defaultAssignments = defaultedFields.zipWithIndex.map { (field, defaultIndex) => + applyDefault(field, localFields.find(_._1.index == field.index).get._2, defaultIndex) + } + val readLoop = + '{ + val remoteFields = $serializerExpr.getRemoteFields() + var i = 0 + while i < remoteFields.size() do { + val remoteField = remoteFields.get(i) + val matchedId = remoteField.matchedId + ${ + compatibleLocalsByMatchedId( + serializerExpr, + resolverExpr, + fieldsByIdExpr, + localFields, + clearDefault, + 'matchedId, + 'remoteField, + readContextExpr) + } + i += 1 + } + } + val maskDefs = + defaultMaskSym.map { mask => + val initial = + if defaultedFields.size == 64 then -1L + else (1L << defaultedFields.size) - 1L + ValDef(mask, Some(Literal(LongConstant(initial)))) + }.toList ++ + defaultArraySym.map { mask => + ValDef(mask, Some('{ Array.fill[Boolean](${ Expr(defaultedFields.size) })(true) }.asTerm)) + }.toList + Block( + localDefs ++ maskDefs, + Block( + readLoop.asTerm :: defaultAssignments.toList, + constructFromLocals(localFields, instantiatorExpr, fieldAccessorsExpr).asTerm)) + .asExprOf[T] + } + def readSchemaConsistentBody( serializerExpr: Expr[StaticGeneratedStructSerializer[T]], resolverExpr: Expr[TypeResolver], @@ -550,7 +1141,9 @@ object ForySerializerMacros { classVersionHashExpr: Expr[Int], allFieldsExpr: Expr[Array[FieldGroups.SerializationFieldInfo]], allFieldIdsExpr: Expr[Array[Int]], - readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[T] = { + readContextExpr: Expr[org.apache.fory.context.ReadContext], + instantiatorExpr: Expr[org.apache.fory.reflect.ObjectInstantiator[T]], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[T] = { val constructorOwned = fields.filter(_.constructorOwned) if constructorOwned.isEmpty then { '{ @@ -558,13 +1151,21 @@ object ForySerializerMacros { if $resolverExpr.checkClassVersion() then { $serializerExpr.checkClassVersion(buffer.readInt32(), $classVersionHashExpr) } - val obj = ${ Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil).asExprOf[T] } + val obj = $instantiatorExpr.newInstance() $readContextExpr.reference(obj) var i = 0 while i < $allFieldsExpr.length do { val fieldInfo = $allFieldsExpr(i) val fieldId = $allFieldIdsExpr(i) - ${ readAndAssignDispatch('obj, 'fieldId, 'fieldInfo, readContextExpr, resolverExpr) } + ${ + readAndAssignDispatch( + 'obj, + 'fieldId, + 'fieldInfo, + readContextExpr, + resolverExpr, + fieldAccessorsExpr) + } i += 1 } obj @@ -590,35 +1191,35 @@ object ForySerializerMacros { def readCompatibleBody( serializerExpr: Expr[StaticGeneratedStructSerializer[T]], + resolverExpr: Expr[TypeResolver], descriptorsExpr: Expr[java.util.List[Descriptor]], fieldsByIdExpr: Expr[Array[FieldGroups.SerializationFieldInfo]], sameSchemaCompatibleExpr: Expr[Boolean], - readContextExpr: Expr[org.apache.fory.context.ReadContext]): Expr[T] = { + readContextExpr: Expr[org.apache.fory.context.ReadContext], + instantiatorExpr: Expr[org.apache.fory.reflect.ObjectInstantiator[T]], + fieldAccessorsExpr: Expr[Array[org.apache.fory.reflect.FieldAccessor]]): Expr[T] = { val constructorOwned = fields.filter(_.constructorOwned) if constructorOwned.isEmpty then { '{ if $sameSchemaCompatibleExpr then { $serializerExpr.read($readContextExpr) } else { - val obj = ${ Apply(Select(New(TypeTree.of[T]), owner.primaryConstructor), Nil).asExprOf[T] } + val obj = $instantiatorExpr.newInstance() $readContextExpr.reference(obj) val remoteFields = $serializerExpr.getRemoteFields() var i = 0 while i < remoteFields.size() do { val remoteField = remoteFields.get(i) val matchedId = remoteField.matchedId - if matchedId >= 0 then { - val localField = $fieldsByIdExpr(matchedId) - if $serializerExpr.canReadGeneratedField(remoteField, localField) then { - val fieldValue = - $serializerExpr.readCompatibleFieldValue($readContextExpr, remoteField, localField) - ${ assignValueById('obj, 'matchedId, 'fieldValue) } - } else { - $serializerExpr.skipField($readContextExpr, remoteField) - } - } else { - $serializerExpr.skipField($readContextExpr, remoteField) - } + ${ compatibleAssignByMatchedId( + serializerExpr, + resolverExpr, + fieldsByIdExpr, + 'obj, + 'matchedId, + 'remoteField, + readContextExpr, + fieldAccessorsExpr) } i += 1 } obj @@ -629,31 +1230,51 @@ object ForySerializerMacros { if $sameSchemaCompatibleExpr then { $serializerExpr.read($readContextExpr) } else { - val values = new Array[Any]($descriptorsExpr.size()) - val remoteFields = $serializerExpr.getRemoteFields() - var i = 0 - while i < remoteFields.size() do { - val remoteField = remoteFields.get(i) - val matchedId = remoteField.matchedId - if matchedId >= 0 then { - val localField = $fieldsByIdExpr(matchedId) - if $serializerExpr.canReadGeneratedField(remoteField, localField) then { - values(matchedId) = - $serializerExpr.readCompatibleFieldValue($readContextExpr, remoteField, localField) - } else { - $serializerExpr.skipField($readContextExpr, remoteField) - } - } else { - $serializerExpr.skipField($readContextExpr, remoteField) - } - i += 1 + ${ + compatibleConstructRead( + serializerExpr, + resolverExpr, + fieldsByIdExpr, + readContextExpr, + instantiatorExpr, + fieldAccessorsExpr) } - ${ constructRead('values, readContextExpr) } } } } } + def fieldAccessors(clsExpr: Expr[Class[T]]) + : Expr[Array[org.apache.fory.reflect.FieldAccessor]] = { + val accessorsSym = + Symbol.newVal( + Symbol.spliceOwner, + "accessors", + TypeRepr.of[Array[org.apache.fory.reflect.FieldAccessor]], + Flags.EmptyFlags, + Symbol.noSymbol) + val assignments = fields.filter(_.usesFieldAccessor).map { field => + '{ + ${ Ref(accessorsSym).asExprOf[Array[org.apache.fory.reflect.FieldAccessor]] }( + ${ Expr(field.index) }) = + org.apache.fory.reflect.FieldAccessor.createAccessor( + Class + .forName(${ Expr(ownerClassName) }, false, $clsExpr.getClassLoader) + .getDeclaredField(${ Expr(field.name) })) + }.asTerm + } + Block( + ValDef( + accessorsSym, + Some( + '{ + new Array[org.apache.fory.reflect.FieldAccessor](${ Expr(fields.size) }) + }.asTerm)) + :: assignments, + Ref(accessorsSym)) + .asExprOf[Array[org.apache.fory.reflect.FieldAccessor]] + } + val classExpr: Expr[Class[T]] = '{ Class.forName(${ Expr(ownerClassName) }).asInstanceOf[Class[T]] } @@ -681,6 +1302,12 @@ object ForySerializerMacros { } result } + private val generatedObjectInstantiator + : org.apache.fory.reflect.ObjectInstantiator[T] = + resolver.getObjectInstantiator(cls) + private val generatedFieldAccessors + : Array[org.apache.fory.reflect.FieldAccessor] = + ${ fieldAccessors('cls) } private val classVersionHash: Int = if resolver.checkClassVersion() then computeClassVersionHash(descriptors) else 0 private val sameSchemaCompatible: Boolean = @@ -707,7 +1334,15 @@ object ForySerializerMacros { while i < allFields.length do { val fieldInfo = allFields(i) val fieldId = allFieldIds(i) - ${ writeDispatch('value, 'fieldId, 'fieldInfo, 'writeContext, 'resolver) } + ${ + writeDispatch( + 'value, + 'fieldId, + 'fieldInfo, + 'writeContext, + 'resolver, + 'generatedFieldAccessors) + } i += 1 } } @@ -727,23 +1362,36 @@ object ForySerializerMacros { 'classVersionHash, 'allFields, 'allFieldIds, - 'readContext) + 'readContext, + 'generatedObjectInstantiator, + 'generatedFieldAccessors) } override def readCompatible(readContext: org.apache.fory.context.ReadContext): T = { ${ readCompatibleBody( 'generatedSerializer, + 'resolver, 'descriptors, 'fieldsById, 'sameSchemaCompatible, - 'readContext) + 'readContext, + 'generatedObjectInstantiator, + 'generatedFieldAccessors) } } override def copy( copyContext: org.apache.fory.context.CopyContext, - value: T): T = ${ copyValue('value, 'copyContext, 'fieldsById) } + value: T): T = + ${ + copyValue( + 'value, + 'copyContext, + 'fieldsById, + 'generatedObjectInstantiator, + 'generatedFieldAccessors) + } } } } diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala index 71a0868561..7cf598f381 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala @@ -33,6 +33,7 @@ import org.apache.fory.annotation.{ import org.apache.fory.config.Int64Encoding import org.apache.fory.memory.MemoryBuffer import org.apache.fory.meta.TypeDef +import org.apache.fory.reflect.{FieldAccessor, ObjectInstantiators} import org.apache.fory.scala.ForySerializer import org.apache.fory.scala.ForyScala import org.apache.fory.scala.register @@ -102,6 +103,37 @@ object ForySerializerDerivationTest { @ForyField(id = 6) narrow: Int) derives ForySerializer + @ForyStruct + final class AccessorScalarWriter private () derives ForySerializer { + @ForyField(id = 1) + private var id: Long = 0L + + @ForyField(id = 2) + private var name: String = "" + + @ForyField(id = 3) + private var ignored: String = "" + + def idValue: Long = id + + def nameValue: String = name + + def ignoredValue: String = ignored + } + + @ForyStruct + final class AccessorScalarReader private () derives ForySerializer { + @ForyField(id = 1) + private val id: Int = 0 + + @ForyField(id = 2) + private var name: String = "" + + def idValue: Int = id + + def nameValue: String = name + } + @ForyStruct final case class CopyBox( @ForyField(id = 1) user: SearchUser, @@ -205,6 +237,14 @@ object ForySerializerDerivationTest { .build() fory } + + def newAccessorValue[T](cls: Class[T], values: (String, AnyRef)*): T = { + val value = ObjectInstantiators.getObjectInstantiator(cls).newInstance() + values.foreach { (fieldName, fieldValue) => + FieldAccessor.createAccessor(cls.getDeclaredField(fieldName)).putObject(value, fieldValue) + } + value + } } class ForySerializerDerivationTest extends AnyWordSpec with Matchers { @@ -337,6 +377,39 @@ class ForySerializerDerivationTest extends AnyWordSpec with Matchers { readerValue.narrow shouldBe 123 } + "read compatible private ordinary fields through generated accessors" in { + val writerFory = ForySerializerDerivationTest.compatibleXlangFory() + ForySerializer.register( + writerFory, + classOf[AccessorScalarWriter], + "scala_test", + "AccessorScalar") + val readerFory = ForySerializerDerivationTest.compatibleXlangFory() + ForySerializer.register( + readerFory, + classOf[AccessorScalarReader], + "scala_test", + "AccessorScalar") + + val writerValue = newAccessorValue( + classOf[AccessorScalarWriter], + "id" -> java.lang.Long.valueOf(321L), + "name" -> "Ada", + "ignored" -> "remote") + val readerValue = + readerFory + .deserialize(writerFory.serialize(writerValue)) + .asInstanceOf[AccessorScalarReader] + + readerValue.idValue shouldBe 321 + readerValue.nameValue shouldBe "Ada" + + val copied = readerFory.copy(readerValue).asInstanceOf[AccessorScalarReader] + copied should not be theSameInstanceAs(readerValue) + copied.idValue shouldBe 321 + copied.nameValue shouldBe "Ada" + } + "emit inner nullable metadata for Option collection elements" in { val fory = xlangFory() val serializer = diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala index 1ad51ffd68..172e42bec6 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ScalaXlangPeer.scala @@ -867,14 +867,7 @@ object ScalaXlangPeer { } private def listArrayListToArray(dataFile: Path): Unit = { - val value = readOne[XlangCompatibleInt32ListField](dataFile, int32ListFory()) - val values = new Array[Int](value.values.size()) - var i = 0 - while i < value.values.size() do { - values(i) = value.values.get(i) - i += 1 - } - writeOne(dataFile, int32ArrayFory(), XlangCompatibleInt32ArrayField(values)) + roundTripValues(dataFile, int32ArrayFory()) } private def listArrayArrayToList(dataFile: Path): Unit = { diff --git a/swift/Sources/Fory/CompatibleScalarConversion.swift b/swift/Sources/Fory/CompatibleScalarConversion.swift index 75a0ec629c..57f5adf192 100644 --- a/swift/Sources/Fory/CompatibleScalarConversion.swift +++ b/swift/Sources/Fory/CompatibleScalarConversion.swift @@ -32,24 +32,6 @@ private enum CompatibleScalarValue { case decimal(Decimal) } -private enum CompatibleScalarKind { - case bool - case string - case signedInteger - case unsignedInteger - case floatingPoint - case decimal - - var isNumeric: Bool { - switch self { - case .signedInteger, .unsignedInteger, .floatingPoint, .decimal: - return true - case .bool, .string: - return false - } - } -} - private struct DecimalLiteral { let negative: Bool let digits: String @@ -75,58 +57,239 @@ private struct BinaryFloatLayout { } @inline(never) -public func foryReadCompatibleScalarField( +public func foryReadCompatibleBoolField( _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - localTypeID: UInt32, - fieldName: String, - directRead: () throws -> T -) throws -> T { - guard let remoteTypeID = TypeId(rawValue: remoteFieldType.typeID), - let localTypeID = TypeId(rawValue: localTypeID) + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Bool { + try foryReadCompatibleOptionalBoolField( + context, remoteField: remoteField, localField: localField) + ?? false +} + +@inline(never) +public func foryReadCompatibleOptionalBoolField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Bool? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) else { - return try directRead() + return nil } - guard compatibleScalarKind(localTypeID) != nil else { - return try directRead() + guard let converted = try compatibleScalarToBool(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) } - guard compatibleScalarKind(remoteTypeID) != nil else { + return converted +} + +@inline(never) +public func foryReadCompatibleInt8Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int8 { + try foryReadCompatibleOptionalInt8Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalInt8Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int8? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil + } + guard let converted = try compatibleScalarToInt8Target(remoteValue) else { throw compatibleScalarError( fieldName: fieldName, remoteTypeID: remoteTypeID, - localTypeID: localTypeID, - reason: "remote field type is not a compatible scalar" + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" ) } - guard compatibleScalarPair(remoteTypeID: remoteTypeID, localTypeID: localTypeID) else { + return converted +} + +@inline(never) +public func foryReadCompatibleInt16Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int16 { + try foryReadCompatibleOptionalInt16Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalInt16Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int16? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil + } + guard let converted = try compatibleScalarToInt16Target(remoteValue) else { throw compatibleScalarError( fieldName: fieldName, remoteTypeID: remoteTypeID, - localTypeID: localTypeID, - reason: "schema pair is outside the scalar conversion matrix" + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" ) } - guard !remoteFieldType.trackRef else { + return converted +} + +@inline(never) +public func foryReadCompatibleInt32Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int32 { + try foryReadCompatibleOptionalInt32Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalInt32Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int32? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil + } + guard let converted = try compatibleScalarToInt32Target(remoteValue) else { throw compatibleScalarError( fieldName: fieldName, remoteTypeID: remoteTypeID, - localTypeID: localTypeID, - reason: "trackingRef scalar conversion is not supported" + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" ) } + return converted +} + +@inline(never) +public func foryReadCompatibleInt64Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int64 { + try foryReadCompatibleOptionalInt64Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalInt64Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int64? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName guard - let remoteValue = try readCompatibleRemoteScalar( - context, remoteTypeID: remoteTypeID, fieldType: remoteFieldType) + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) else { - return compatibleScalarDefault(T.self, localTypeID: localTypeID) + return nil + } + guard let converted = try compatibleScalarToInt64(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) } + return converted +} + +@inline(never) +public func foryReadCompatibleIntField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int { + try foryReadCompatibleOptionalIntField( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalIntField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Int? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName guard - let converted = try convertCompatibleScalar(remoteValue, to: T.self, localTypeID: localTypeID) + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) else { + return nil + } + guard let converted = try compatibleScalarToIntTarget(remoteValue) else { throw compatibleScalarError( fieldName: fieldName, remoteTypeID: remoteTypeID, - localTypeID: localTypeID, + localTypeID: resolvedLocalTypeID, reason: "value is not lossless for the local field type" ) } @@ -134,105 +297,467 @@ public func foryReadCompatibleScalarField( } @inline(never) -public func foryReadCompatibleOptionalScalarField( +public func foryReadCompatibleUInt8Field( _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - localTypeID: UInt32, - fieldName: String, - directRead: () throws -> T? -) throws -> T? { - guard let remoteTypeID = TypeId(rawValue: remoteFieldType.typeID), - let localTypeID = TypeId(rawValue: localTypeID) + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt8 { + try foryReadCompatibleOptionalUInt8Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalUInt8Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt8? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) else { - return try directRead() + return nil } - guard compatibleScalarKind(localTypeID) != nil else { - return try directRead() + guard let converted = try compatibleScalarToUInt8Target(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) } - guard compatibleScalarKind(remoteTypeID) != nil else { + return converted +} + +@inline(never) +public func foryReadCompatibleUInt16Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt16 { + try foryReadCompatibleOptionalUInt16Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalUInt16Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt16? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil + } + guard let converted = try compatibleScalarToUInt16Target(remoteValue) else { throw compatibleScalarError( fieldName: fieldName, remoteTypeID: remoteTypeID, - localTypeID: localTypeID, - reason: "remote field type is not a compatible scalar" + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" ) } - guard compatibleScalarPair(remoteTypeID: remoteTypeID, localTypeID: localTypeID) else { + return converted +} + +@inline(never) +public func foryReadCompatibleUInt32Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt32 { + try foryReadCompatibleOptionalUInt32Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalUInt32Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt32? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil + } + guard let converted = try compatibleScalarToUInt32Target(remoteValue) else { throw compatibleScalarError( fieldName: fieldName, remoteTypeID: remoteTypeID, - localTypeID: localTypeID, - reason: "schema pair is outside the scalar conversion matrix" + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" ) } - guard !remoteFieldType.trackRef else { + return converted +} + +@inline(never) +public func foryReadCompatibleUInt64Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt64 { + try foryReadCompatibleOptionalUInt64Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalUInt64Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt64? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil + } + guard let converted = try compatibleScalarToUInt64(remoteValue) else { throw compatibleScalarError( fieldName: fieldName, remoteTypeID: remoteTypeID, - localTypeID: localTypeID, - reason: "trackingRef scalar conversion is not supported" + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" ) } + return converted +} + +@inline(never) +public func foryReadCompatibleUIntField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt { + try foryReadCompatibleOptionalUIntField( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalUIntField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> UInt? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName guard - let remoteValue = try readCompatibleRemoteScalar( - context, remoteTypeID: remoteTypeID, fieldType: remoteFieldType) + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) else { return nil } + guard let converted = try compatibleScalarToUIntTarget(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) + } + return converted +} + +@inline(never) +public func foryReadCompatibleFloat16Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Float16 { + try foryReadCompatibleOptionalFloat16Field( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalFloat16Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Float16? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName guard - let converted = try convertCompatibleScalar(remoteValue, to: T.self, localTypeID: localTypeID) + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) else { + return nil + } + guard let converted = try compatibleScalarToFloat16(remoteValue) else { throw compatibleScalarError( fieldName: fieldName, remoteTypeID: remoteTypeID, - localTypeID: localTypeID, + localTypeID: resolvedLocalTypeID, reason: "value is not lossless for the local field type" ) } return converted } -private func compatibleScalarKind(_ typeID: TypeId) -> CompatibleScalarKind? { - switch typeID { - case .bool: - return .bool - case .string: - return .string - case .int8, .int16, .int32, .varint32, .int64, .varint64, .taggedInt64: - return .signedInteger - case .uint8, .uint16, .uint32, .varUInt32, .uint64, .varUInt64, .taggedUInt64: - return .unsignedInteger - case .float16, .bfloat16, .float32, .float64: - return .floatingPoint - case .decimal: - return .decimal - default: +@inline(never) +public func foryReadCompatibleBFloat16Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> BFloat16 { + try foryReadCompatibleOptionalBFloat16Field( + context, remoteField: remoteField, localField: localField) + ?? BFloat16() +} + +@inline(never) +public func foryReadCompatibleOptionalBFloat16Field( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> BFloat16? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { return nil } + guard let converted = try compatibleScalarToBFloat16(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) + } + return converted } -private func compatibleScalarPair(remoteTypeID: TypeId, localTypeID: TypeId) -> Bool { - guard let remote = compatibleScalarKind(remoteTypeID), - let local = compatibleScalarKind(localTypeID) +@inline(never) +public func foryReadCompatibleFloatField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Float { + try foryReadCompatibleOptionalFloatField( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalFloatField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Float? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) else { - return false + return nil } - if remoteTypeID == localTypeID { - return true + guard let converted = try compatibleScalarToFloat(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) } - if remote == .bool { - return local == .string || local.isNumeric + return converted +} + +@inline(never) +public func foryReadCompatibleDoubleField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Double { + try foryReadCompatibleOptionalDoubleField( + context, remoteField: remoteField, localField: localField) + ?? 0 +} + +@inline(never) +public func foryReadCompatibleOptionalDoubleField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Double? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil } - if local == .bool { - return remote == .string || remote.isNumeric + guard let converted = try compatibleScalarToDouble(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) } - if remote == .string { - return local.isNumeric + return converted +} + +@inline(never) +public func foryReadCompatibleStringField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> String { + try foryReadCompatibleOptionalStringField( + context, remoteField: remoteField, localField: localField) + ?? "" +} + +@inline(never) +public func foryReadCompatibleOptionalStringField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> String? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil + } + guard let converted = try compatibleScalarToString(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) + } + return converted +} + +@inline(never) +public func foryReadCompatibleDecimalField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Decimal { + try foryReadCompatibleOptionalDecimalField( + context, remoteField: remoteField, localField: localField) + ?? .zero +} + +@inline(never) +public func foryReadCompatibleOptionalDecimalField( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + localField: TypeMeta.FieldInfo +) throws -> Decimal? { + var remoteTypeID: TypeId = .unknown + let resolvedLocalTypeID = try compatibleScalarLocalTypeID(localField) + let fieldName = localField.fieldName + guard + let remoteValue = try readCompatibleScalarValue( + context, + remoteField: remoteField, + fieldName: fieldName, + remoteTypeID: &remoteTypeID) + else { + return nil + } + guard let converted = try compatibleScalarToDecimal(remoteValue) else { + throw compatibleScalarError( + fieldName: fieldName, + remoteTypeID: remoteTypeID, + localTypeID: resolvedLocalTypeID, + reason: "value is not lossless for the local field type" + ) } - if local == .string { - return remote.isNumeric + return converted +} + +private func readCompatibleScalarValue( + _ context: ReadContext, + remoteField: TypeMeta.FieldInfo, + fieldName: String, + remoteTypeID: inout TypeId +) throws -> CompatibleScalarValue? { + let remoteFieldType = remoteField.fieldType + guard let resolvedRemoteTypeID = TypeId(rawValue: remoteFieldType.typeID) else { + throw ForyError.invalidData( + "unknown compatible scalar remote type \(remoteFieldType.typeID) for field \(fieldName)") + } + remoteTypeID = resolvedRemoteTypeID + return try readCompatibleRemoteScalar( + context, remoteTypeID: resolvedRemoteTypeID, fieldType: remoteFieldType) +} + +private func compatibleScalarLocalTypeID(_ localField: TypeMeta.FieldInfo) throws -> TypeId { + guard let localTypeID = TypeId(rawValue: localField.fieldType.typeID) else { + throw ForyError.invalidData( + "unknown compatible scalar local type for field \(localField.fieldName)") } - return remote.isNumeric && local.isNumeric + return localTypeID } private func readCompatibleRemoteScalar( @@ -325,73 +850,81 @@ private func readCompatibleBoolPayload(_ context: ReadContext) throws -> Bool { } } -private func convertCompatibleScalar( - _ value: CompatibleScalarValue, - to _: T.Type, - localTypeID: TypeId -) throws -> T? { - switch localTypeID { - case .bool: - guard T.self == Bool.self, let converted = try compatibleScalarToBool(value) else { - return nil - } - return converted as? T - case .string: - guard T.self == String.self, let converted = try compatibleScalarToString(value) else { - return nil - } - return converted as? T - case .int8: - return compatibleCastSigned( - try compatibleScalarToInt64(value), min: Int64(Int8.min), max: Int64(Int8.max), to: T.self) - case .int16: - return compatibleCastSigned( - try compatibleScalarToInt64(value), min: Int64(Int16.min), max: Int64(Int16.max), to: T.self) - case .int32, .varint32: - return compatibleCastSigned( - try compatibleScalarToInt64(value), min: Int64(Int32.min), max: Int64(Int32.max), to: T.self) - case .int64, .varint64, .taggedInt64: - return compatibleCastSigned( - try compatibleScalarToInt64(value), min: Int64.min, max: Int64.max, to: T.self) - case .uint8: - return compatibleCastUnsigned( - try compatibleScalarToUInt64(value), max: UInt64(UInt8.max), to: T.self) - case .uint16: - return compatibleCastUnsigned( - try compatibleScalarToUInt64(value), max: UInt64(UInt16.max), to: T.self) - case .uint32, .varUInt32: - return compatibleCastUnsigned( - try compatibleScalarToUInt64(value), max: UInt64(UInt32.max), to: T.self) - case .uint64, .varUInt64, .taggedUInt64: - return compatibleCastUnsigned(try compatibleScalarToUInt64(value), max: UInt64.max, to: T.self) - case .float16: - guard T.self == Float16.self, let converted = try compatibleScalarToFloat16(value) else { - return nil - } - return converted as? T - case .bfloat16: - guard T.self == BFloat16.self, let converted = try compatibleScalarToBFloat16(value) else { - return nil - } - return converted as? T - case .float32: - guard T.self == Float.self, let converted = try compatibleScalarToFloat(value) else { - return nil - } - return converted as? T - case .float64: - guard T.self == Double.self, let converted = try compatibleScalarToDouble(value) else { - return nil - } - return converted as? T - case .decimal: - guard T.self == Decimal.self, let converted = try compatibleScalarToDecimal(value) else { - return nil - } - return converted as? T - default: +private func compatibleScalarToInt8Target(_ value: CompatibleScalarValue) throws -> Int8? { + guard + let converted = try compatibleScalarToInt64(value), + converted >= Int64(Int8.min), + converted <= Int64(Int8.max) + else { + return nil + } + return Int8(converted) +} + +private func compatibleScalarToInt16Target(_ value: CompatibleScalarValue) throws -> Int16? { + guard + let converted = try compatibleScalarToInt64(value), + converted >= Int64(Int16.min), + converted <= Int64(Int16.max) + else { + return nil + } + return Int16(converted) +} + +private func compatibleScalarToInt32Target(_ value: CompatibleScalarValue) throws -> Int32? { + guard + let converted = try compatibleScalarToInt64(value), + converted >= Int64(Int32.min), + converted <= Int64(Int32.max) + else { + return nil + } + return Int32(converted) +} + +private func compatibleScalarToIntTarget(_ value: CompatibleScalarValue) throws -> Int? { + guard let converted = try compatibleScalarToInt64(value) else { + return nil + } + return Int(exactly: converted) +} + +private func compatibleScalarToUInt8Target(_ value: CompatibleScalarValue) throws -> UInt8? { + guard + let converted = try compatibleScalarToUInt64(value), + converted <= UInt64(UInt8.max) + else { + return nil + } + return UInt8(converted) +} + +private func compatibleScalarToUInt16Target(_ value: CompatibleScalarValue) throws -> UInt16? { + guard + let converted = try compatibleScalarToUInt64(value), + converted <= UInt64(UInt16.max) + else { + return nil + } + return UInt16(converted) +} + +private func compatibleScalarToUInt32Target(_ value: CompatibleScalarValue) throws -> UInt32? { + guard + let converted = try compatibleScalarToUInt64(value), + converted <= UInt64(UInt32.max) + else { + return nil + } + return UInt32(converted) +} + +private func compatibleScalarToUIntTarget(_ value: CompatibleScalarValue) throws -> UInt? { + guard let converted = try compatibleScalarToUInt64(value) else { return nil } + return UInt(exactly: converted) } private func compatibleScalarToBool(_ value: CompatibleScalarValue) throws -> Bool? { @@ -612,8 +1145,7 @@ private func compatibleScalarToDecimal(_ value: CompatibleScalarValue) throws -> } private func compatibleScalarToDecimalLiteral(_ value: CompatibleScalarValue) throws - -> DecimalLiteral? -{ + -> DecimalLiteral? { switch value { case .bool(let value): return DecimalLiteral(negative: false, digits: value ? "1" : "0", scale: 0, negativeZero: false) @@ -641,102 +1173,6 @@ private func compatibleScalarToDecimalLiteral(_ value: CompatibleScalarValue) th } } -private func compatibleCastSigned(_ value: Int64?, min: Int64, max: Int64, to _: T.Type) -> T? { - guard let value, value >= min, value <= max else { - return nil - } - switch T.self { - case is Int8.Type: - return Int8(value) as? T - case is Int16.Type: - return Int16(value) as? T - case is Int32.Type: - return Int32(value) as? T - case is Int64.Type: - return value as? T - case is Int.Type: - guard let converted = Int(exactly: value) else { - return nil - } - return converted as? T - default: - return nil - } -} - -private func compatibleCastUnsigned(_ value: UInt64?, max: UInt64, to _: T.Type) -> T? { - guard let value, value <= max else { - return nil - } - switch T.self { - case is UInt8.Type: - return UInt8(value) as? T - case is UInt16.Type: - return UInt16(value) as? T - case is UInt32.Type: - return UInt32(value) as? T - case is UInt64.Type: - return value as? T - case is UInt.Type: - guard let converted = UInt(exactly: value) else { - return nil - } - return converted as? T - default: - return nil - } -} - -private func compatibleScalarDefault(_: T.Type, localTypeID: TypeId) -> T { - guard let value = compatibleScalarDefaultValue(T.self, localTypeID: localTypeID) else { - preconditionFailure("unsupported compatible scalar default") - } - return value -} - -private func compatibleScalarDefaultValue(_: T.Type, localTypeID: TypeId) -> T? { - switch localTypeID { - case .bool: - return false as? T - case .string: - return "" as? T - case .int8: - return Int8(0) as? T - case .int16: - return Int16(0) as? T - case .int32, .varint32: - return Int32(0) as? T - case .int64, .varint64, .taggedInt64: - if T.self == Int.self { - return Int(0) as? T - } - return Int64(0) as? T - case .uint8: - return UInt8(0) as? T - case .uint16: - return UInt16(0) as? T - case .uint32, .varUInt32: - return UInt32(0) as? T - case .uint64, .varUInt64, .taggedUInt64: - if T.self == UInt.self { - return UInt(0) as? T - } - return UInt64(0) as? T - case .float16: - return Float16(0) as? T - case .bfloat16: - return BFloat16() as? T - case .float32: - return Float(0) as? T - case .float64: - return Double(0) as? T - case .decimal: - return Decimal.zero as? T - default: - return nil - } -} - private func compatibleParseNumericLiteral(_ text: String) -> DecimalLiteral? { guard !text.isEmpty, text.utf8.count <= maxCompatibleNumericTextLength else { return nil @@ -903,7 +1339,9 @@ private func compatibleParseNumericLiteral(_ text: String) -> DecimalLiteral? { else { return nil } - let normalized = String(decoding: digitBytes, as: UTF8.self) + guard let normalized = String(bytes: digitBytes, encoding: .utf8) else { + return nil + } let negativeZero = negative && normalized == "0" return DecimalLiteral( negative: negative, digits: normalized, scale: scale, negativeZero: negativeZero) @@ -1039,8 +1477,7 @@ private func compatibleFormatDecimalText(negative: Bool, digits: String, scale: return fraction.isEmpty ? sign + integer : "\(sign)\(integer).\(fraction)" } -private func compatibleFloatCanonicalText(_ literal: DecimalLiteral, forceFraction: Bool) -> String? -{ +private func compatibleFloatCanonicalText(_ literal: DecimalLiteral, forceFraction: Bool) -> String? { if literal.isZero { return literal.negativeZero ? "-0.0" : "0.0" } diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index 9d48cd6e74..96e0babd1b 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -643,7 +643,11 @@ public enum ListFieldCodec: FieldCodec { typeID: TypeId.list.rawValue, nullable: nullable, trackRef: trackRef, - generics: [ElementCodec.fieldType(nullable: ElementCodec.isNullableType, trackRef: false)] + generics: [ + ElementCodec.fieldType( + nullable: ElementCodec.isNullableType, + trackRef: trackRef && ElementCodec.isRefType) + ] ) } @@ -823,7 +827,11 @@ public enum SetFieldCodec: FieldCodec where ElementCod typeID: TypeId.set.rawValue, nullable: nullable, trackRef: trackRef, - generics: [ElementCodec.fieldType(nullable: ElementCodec.isNullableType, trackRef: false)] + generics: [ + ElementCodec.fieldType( + nullable: ElementCodec.isNullableType, + trackRef: trackRef && ElementCodec.isRefType) + ] ) } @@ -860,8 +868,12 @@ where KeyCodec.Value: Hashable { nullable: nullable, trackRef: trackRef, generics: [ - KeyCodec.fieldType(nullable: KeyCodec.isNullableType, trackRef: false), - ValueCodec.fieldType(nullable: ValueCodec.isNullableType, trackRef: false) + KeyCodec.fieldType( + nullable: KeyCodec.isNullableType, + trackRef: trackRef && KeyCodec.isRefType), + ValueCodec.fieldType( + nullable: ValueCodec.isNullableType, + trackRef: trackRef && ValueCodec.isRefType) ] ) } diff --git a/swift/Sources/Fory/MacroDeclarations.swift b/swift/Sources/Fory/MacroDeclarations.swift index 60b700e462..c40096055c 100644 --- a/swift/Sources/Fory/MacroDeclarations.swift +++ b/swift/Sources/Fory/MacroDeclarations.swift @@ -16,211 +16,227 @@ // under the License. @attached( - member, - names: named(staticTypeId), - named(foryEvolving), - named(isRefType), - named(__foryNormalizeSchemaFingerprintTypeID), - named(__forySchemaHash), - named(__forySchemaHashTrackRefDisabled), - named(__forySchemaHashTrackRefEnabled), - named(foryFieldsInfo), - named(__foryFieldsInfoTrackRefDisabled), - named(__foryFieldsInfoTrackRefEnabled), - named(foryDefault), - named(foryWrite), - named(foryRead), - named(foryWriteData), - named(foryReadData), - named(foryReadCompatibleData), - named(__foryReadDataImpl), - named(__foryReadCompatibleDataImpl) + member, + names: named(staticTypeId), + named(foryEvolving), + named(isRefType), + named(__foryNormalizeSchemaFingerprintTypeID), + named(__forySchemaHash), + named(__forySchemaHashTrackRefDisabled), + named(__forySchemaHashTrackRefEnabled), + named(foryFieldsInfo), + named(__foryFieldsInfoTrackRefDisabled), + named(__foryFieldsInfoTrackRefEnabled), + named(foryDefault), + named(foryWrite), + named(foryRead), + named(foryWriteData), + named(foryReadData), + named(foryReadCompatibleData), + named(__foryReadChangedData), + named(__foryReadDataImpl), + named(__foryReadCompatibleDataImpl) ) @attached(extension, conformances: Serializer, StructSerializer) -public macro ForyStruct(evolving: Bool = true) = #externalMacro(module: "ForyMacro", type: "ForyStructMacro") +public macro ForyStruct(evolving: Bool = true) = + #externalMacro(module: "ForyMacro", type: "ForyStructMacro") @attached( - member, - names: named(staticTypeId), - named(foryEvolving), - named(isRefType), - named(__foryNormalizeSchemaFingerprintTypeID), - named(__forySchemaHash), - named(__forySchemaHashTrackRefDisabled), - named(__forySchemaHashTrackRefEnabled), - named(foryFieldsInfo), - named(__foryFieldsInfoTrackRefDisabled), - named(__foryFieldsInfoTrackRefEnabled), - named(foryDefault), - named(foryWrite), - named(foryRead), - named(foryWriteData), - named(foryReadData), - named(foryReadCompatibleData), - named(__foryReadDataImpl), - named(__foryReadCompatibleDataImpl) + member, + names: named(staticTypeId), + named(foryEvolving), + named(isRefType), + named(__foryNormalizeSchemaFingerprintTypeID), + named(__forySchemaHash), + named(__forySchemaHashTrackRefDisabled), + named(__forySchemaHashTrackRefEnabled), + named(foryFieldsInfo), + named(__foryFieldsInfoTrackRefDisabled), + named(__foryFieldsInfoTrackRefEnabled), + named(foryDefault), + named(foryWrite), + named(foryRead), + named(foryWriteData), + named(foryReadData), + named(foryReadCompatibleData), + named(__foryReadChangedData), + named(__foryReadDataImpl), + named(__foryReadCompatibleDataImpl) ) @attached(extension, conformances: Serializer, StructSerializer) public macro ForyEnum() = #externalMacro(module: "ForyMacro", type: "ForyEnumMacro") @attached( - member, - names: named(staticTypeId), - named(foryEvolving), - named(isRefType), - named(__foryNormalizeSchemaFingerprintTypeID), - named(__forySchemaHash), - named(__forySchemaHashTrackRefDisabled), - named(__forySchemaHashTrackRefEnabled), - named(foryFieldsInfo), - named(__foryFieldsInfoTrackRefDisabled), - named(__foryFieldsInfoTrackRefEnabled), - named(foryDefault), - named(foryWrite), - named(foryRead), - named(foryWriteData), - named(foryReadData), - named(foryReadCompatibleData), - named(__foryReadDataImpl), - named(__foryReadCompatibleDataImpl) + member, + names: named(staticTypeId), + named(foryEvolving), + named(isRefType), + named(__foryNormalizeSchemaFingerprintTypeID), + named(__forySchemaHash), + named(__forySchemaHashTrackRefDisabled), + named(__forySchemaHashTrackRefEnabled), + named(foryFieldsInfo), + named(__foryFieldsInfoTrackRefDisabled), + named(__foryFieldsInfoTrackRefEnabled), + named(foryDefault), + named(foryWrite), + named(foryRead), + named(foryWriteData), + named(foryReadData), + named(foryReadCompatibleData), + named(__foryReadChangedData), + named(__foryReadDataImpl), + named(__foryReadCompatibleDataImpl) ) @attached(extension, conformances: Serializer, StructSerializer) public macro ForyUnion() = #externalMacro(module: "ForyMacro", type: "ForyUnionMacro") public enum ForyFieldEncoding: String { - case varint - case fixed - case tagged + case varint + case fixed + case tagged } public struct ForyFieldType: Sendable { - private init() {} - - public static func encoding(_ encoding: ForyFieldEncoding) -> ForyFieldType { - _ = encoding - return .init() - } - - public static var bool: ForyFieldType { .init() } - public static var int8: ForyFieldType { .init() } - public static var int16: ForyFieldType { .init() } - public static var uint8: ForyFieldType { .init() } - public static var uint16: ForyFieldType { .init() } - public static var float16: ForyFieldType { .init() } - public static var bfloat16: ForyFieldType { .init() } - public static var float32: ForyFieldType { .init() } - public static var float64: ForyFieldType { .init() } - public static var string: ForyFieldType { .init() } - public static var date: ForyFieldType { .init() } - public static var timestamp: ForyFieldType { .init() } - public static var duration: ForyFieldType { .init() } - public static var decimal: ForyFieldType { .init() } - public static var binary: ForyFieldType { .init() } - - public static func int32(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) -> ForyFieldType { - _ = nullable - _ = encoding - return .init() - } - - public static func int64(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) -> ForyFieldType { - _ = nullable - _ = encoding - return .init() - } - - public static func int(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) -> ForyFieldType { - _ = nullable - _ = encoding - return .init() - } - - public static func uint32(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) -> ForyFieldType { - _ = nullable - _ = encoding - return .init() - } - - public static func uint64(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) -> ForyFieldType { - _ = nullable - _ = encoding - return .init() - } - - public static func uint(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) -> ForyFieldType { - _ = nullable - _ = encoding - return .init() - } - - public static func list(_ element: ForyFieldType) -> ForyFieldType { - list(element: element) - } - - public static func list(element: ForyFieldType) -> ForyFieldType { - _ = element - return .init() - } - - public static func array(element: ForyFieldType) -> ForyFieldType { - _ = element - return .init() - } - - public static func set(element: ForyFieldType) -> ForyFieldType { - _ = element - return .init() - } - - public static func map(key: ForyFieldType) -> ForyFieldType { - _ = key - return .init() - } - - public static func map(value: ForyFieldType) -> ForyFieldType { - _ = value - return .init() - } - - public static func map(key: ForyFieldType, value: ForyFieldType) -> ForyFieldType { - _ = key - _ = value - return .init() - } + private init() {} + + public static func encoding(_ encoding: ForyFieldEncoding) -> ForyFieldType { + _ = encoding + return .init() + } + + public static var bool: ForyFieldType { .init() } + public static var int8: ForyFieldType { .init() } + public static var int16: ForyFieldType { .init() } + public static var uint8: ForyFieldType { .init() } + public static var uint16: ForyFieldType { .init() } + public static var float16: ForyFieldType { .init() } + public static var bfloat16: ForyFieldType { .init() } + public static var float32: ForyFieldType { .init() } + public static var float64: ForyFieldType { .init() } + public static var string: ForyFieldType { .init() } + public static var date: ForyFieldType { .init() } + public static var timestamp: ForyFieldType { .init() } + public static var duration: ForyFieldType { .init() } + public static var decimal: ForyFieldType { .init() } + public static var binary: ForyFieldType { .init() } + + public static func int32(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) + -> ForyFieldType + { + _ = nullable + _ = encoding + return .init() + } + + public static func int64(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) + -> ForyFieldType + { + _ = nullable + _ = encoding + return .init() + } + + public static func int(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) + -> ForyFieldType + { + _ = nullable + _ = encoding + return .init() + } + + public static func uint32(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) + -> ForyFieldType + { + _ = nullable + _ = encoding + return .init() + } + + public static func uint64(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) + -> ForyFieldType + { + _ = nullable + _ = encoding + return .init() + } + + public static func uint(nullable: Bool = false, encoding: ForyFieldEncoding = .varint) + -> ForyFieldType + { + _ = nullable + _ = encoding + return .init() + } + + public static func list(_ element: ForyFieldType) -> ForyFieldType { + list(element: element) + } + + public static func list(element: ForyFieldType) -> ForyFieldType { + _ = element + return .init() + } + + public static func array(element: ForyFieldType) -> ForyFieldType { + _ = element + return .init() + } + + public static func set(element: ForyFieldType) -> ForyFieldType { + _ = element + return .init() + } + + public static func map(key: ForyFieldType) -> ForyFieldType { + _ = key + return .init() + } + + public static func map(value: ForyFieldType) -> ForyFieldType { + _ = value + return .init() + } + + public static func map(key: ForyFieldType, value: ForyFieldType) -> ForyFieldType { + _ = key + _ = value + return .init() + } } @attached(peer) public macro ForyField( - id: Int? = nil, - encoding: ForyFieldEncoding? = nil, - type: ForyFieldType? = nil + id: Int? = nil, + encoding: ForyFieldEncoding? = nil, + type: ForyFieldType? = nil ) = #externalMacro(module: "ForyMacro", type: "ForyFieldMacro") @attached(peer) public macro ListField( - element: ForyFieldType + element: ForyFieldType ) = #externalMacro(module: "ForyMacro", type: "ListFieldMacro") @attached(peer) public macro ArrayField( - element: ForyFieldType + element: ForyFieldType ) = #externalMacro(module: "ForyMacro", type: "ArrayFieldMacro") @attached(peer) public macro SetField( - element: ForyFieldType + element: ForyFieldType ) = #externalMacro(module: "ForyMacro", type: "SetFieldMacro") @attached(peer) public macro MapField( - key: ForyFieldType? = nil, - value: ForyFieldType? = nil + key: ForyFieldType? = nil, + value: ForyFieldType? = nil ) = #externalMacro(module: "ForyMacro", type: "MapFieldMacro") @attached(peer) public macro ForyCase( - id: Int? = nil, - payload: ForyFieldType? = nil + id: Int? = nil, + payload: ForyFieldType? = nil ) = #externalMacro(module: "ForyMacro", type: "ForyCaseMacro") @attached(peer) diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 04a7e73bd6..664916a0fc 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -20,623 +20,659 @@ import Foundation private let typeMetaSizeMask = 0xFF public final class ReadContext { - public let buffer: ByteBuffer - let typeResolver: TypeResolver - public let trackRef: Bool - public let compatible: Bool - public let checkClassVersion: Bool - public let maxCollectionSize: Int - public let maxBinarySize: Int - public let maxDepth: Int - public let refReader: RefReader - private let compatibleTypeDefTypeInfos = ReusableArray(defaultValue: nil, reserve: 2) - private let metaStrings = ReusableArray(defaultValue: nil, reserve: 16) - private var dynamicAnyDepth = 0 - - private var typeInfoStack = UInt64Map(initialCapacity: 8) - private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] - private var lastTypeInfo = TypeInfo.uncached - - init( - buffer: ByteBuffer, - typeResolver: TypeResolver, - trackRef: Bool, - compatible: Bool = false, - checkClassVersion: Bool = true, - maxCollectionSize: Int = 1_000_000, - maxBinarySize: Int = 64 * 1024 * 1024, - maxDepth: Int = 5 - ) { - self.buffer = buffer - self.typeResolver = typeResolver - self.trackRef = trackRef - self.compatible = compatible - self.checkClassVersion = checkClassVersion - self.maxCollectionSize = maxCollectionSize - self.maxBinarySize = maxBinarySize - self.maxDepth = maxDepth - self.refReader = RefReader() - } - - @inline(__always) - func enterDynamicAnyDepth() throws { - if maxDepth < 0 { - throw ForyError.invalidData("configured maxDepth \(maxDepth) is negative") - } - let nextDepth = dynamicAnyDepth + 1 - if nextDepth > maxDepth { - throw ForyError.invalidData( - "dynamic Any nesting depth \(nextDepth) exceeds configured maxDepth \(maxDepth)" - ) - } - dynamicAnyDepth = nextDepth + public let buffer: ByteBuffer + let typeResolver: TypeResolver + public let trackRef: Bool + public let compatible: Bool + public let checkClassVersion: Bool + public let maxCollectionSize: Int + public let maxBinarySize: Int + public let maxDepth: Int + public let refReader: RefReader + private let compatibleTypeDefTypeInfos = ReusableArray(defaultValue: nil, reserve: 2) + private let metaStrings = ReusableArray(defaultValue: nil, reserve: 16) + private var dynamicAnyDepth = 0 + + private var typeInfoStack = UInt64Map(initialCapacity: 8) + private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] + private var lastTypeInfo = TypeInfo.uncached + + init( + buffer: ByteBuffer, + typeResolver: TypeResolver, + trackRef: Bool, + compatible: Bool = false, + checkClassVersion: Bool = true, + maxCollectionSize: Int = 1_000_000, + maxBinarySize: Int = 64 * 1024 * 1024, + maxDepth: Int = 5 + ) { + self.buffer = buffer + self.typeResolver = typeResolver + self.trackRef = trackRef + self.compatible = compatible + self.checkClassVersion = checkClassVersion + self.maxCollectionSize = maxCollectionSize + self.maxBinarySize = maxBinarySize + self.maxDepth = maxDepth + self.refReader = RefReader() + } + + @inline(__always) + func enterDynamicAnyDepth() throws { + if maxDepth < 0 { + throw ForyError.invalidData("configured maxDepth \(maxDepth) is negative") } - - @inline(__always) - func leaveDynamicAnyDepth() { - if dynamicAnyDepth > 0 { - dynamicAnyDepth -= 1 - } + let nextDepth = dynamicAnyDepth + 1 + if nextDepth > maxDepth { + throw ForyError.invalidData( + "dynamic Any nesting depth \(nextDepth) exceeds configured maxDepth \(maxDepth)" + ) } + dynamicAnyDepth = nextDepth + } - @inline(__always) - func ensureCollectionLength(_ length: Int, label: String) throws { - if length < 0 { - throw ForyError.invalidData("\(label) length is negative") - } - if length > maxCollectionSize { - throw ForyError.invalidData( - "\(label) length \(length) exceeds configured maxCollectionSize \(maxCollectionSize)" - ) - } + @inline(__always) + func leaveDynamicAnyDepth() { + if dynamicAnyDepth > 0 { + dynamicAnyDepth -= 1 } + } - @inline(__always) - func ensureBinaryLength(_ length: Int, label: String) throws { - if length < 0 { - throw ForyError.invalidData("\(label) size is negative") - } - if length > maxBinarySize { - throw ForyError.invalidData( - "\(label) size \(length) exceeds configured maxBinarySize \(maxBinarySize)" - ) - } + @inline(__always) + func ensureCollectionLength(_ length: Int, label: String) throws { + if length < 0 { + throw ForyError.invalidData("\(label) length is negative") } + if length > maxCollectionSize { + throw ForyError.invalidData( + "\(label) length \(length) exceeds configured maxCollectionSize \(maxCollectionSize)" + ) + } + } - @inline(__always) - func ensureRemainingBytes(_ byteCount: Int, label: String) throws { - if byteCount < 0 { - throw ForyError.invalidData("\(label) size is negative") - } - let remainingBytes = buffer.remaining - if byteCount > remainingBytes { - throw ForyError.invalidData( - "\(label) requires \(byteCount) bytes but only \(remainingBytes) remain in buffer" - ) - } + @inline(__always) + func ensureBinaryLength(_ length: Int, label: String) throws { + if length < 0 { + throw ForyError.invalidData("\(label) size is negative") + } + if length > maxBinarySize { + throw ForyError.invalidData( + "\(label) size \(length) exceeds configured maxBinarySize \(maxBinarySize)" + ) } + } - @inline(__always) - func typeInfo(for type: T.Type) throws -> TypeInfo { - let typeID = ObjectIdentifier(type) - if lastTypeInfo.swiftTypeID == typeID { - return lastTypeInfo - } - let info = try typeResolver.requireTypeInfo(for: type) - lastTypeInfo = info - return info + @inline(__always) + func ensureRemainingBytes(_ byteCount: Int, label: String) throws { + if byteCount < 0 { + throw ForyError.invalidData("\(label) size is negative") } + let remainingBytes = buffer.remaining + if byteCount > remainingBytes { + throw ForyError.invalidData( + "\(label) requires \(byteCount) bytes but only \(remainingBytes) remain in buffer" + ) + } + } - @inline(__always) - func readStaticTypeInfo(_ typeID: TypeId) throws -> TypeInfo? { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let actualTypeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown type id \(rawTypeID)") - } - if actualTypeID != typeID { - throw ForyError.typeMismatch(expected: typeID.rawValue, actual: rawTypeID) - } - return nil + @inline(__always) + func typeInfo(for type: T.Type) throws -> TypeInfo { + let typeID = ObjectIdentifier(type) + if lastTypeInfo.swiftTypeID == typeID { + return lastTypeInfo } + let info = try typeResolver.requireTypeInfo(for: type) + lastTypeInfo = info + return info + } + + @inline(__always) + func readStaticTypeInfo(_ typeID: TypeId) throws -> TypeInfo? { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let actualTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") + } + if actualTypeID != typeID { + throw ForyError.typeMismatch(expected: typeID.rawValue, actual: rawTypeID) + } + return nil + } - func readTypeInfo() throws -> TypeInfo { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let wireTypeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown dynamic type id \(rawTypeID)") - } + func readTypeInfo() throws -> TypeInfo { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let wireTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown dynamic type id \(rawTypeID)") + } - switch wireTypeID { - case .compatibleStruct, .namedCompatibleStruct: - return try readCompatibleTypeInfo() - case .namedEnum, .namedStruct, .namedExt, .namedUnion: - if compatible { - return try readCompatibleTypeInfo() - } - let namespace = try readMetaString( - context: self, - decoder: .namespace, - encodings: namespaceMetaStringEncodings - ) - let typeName = try readMetaString( - context: self, - decoder: .typeName, - encodings: typeNameMetaStringEncodings - ) - return try typeResolver.requireTypeInfo(namespace: namespace.value, typeName: typeName.value) - case .structType, .enumType, .ext, .typedUnion, .union: - let userTypeID = try buffer.readVarUInt32() - return try typeResolver.requireTypeInfo(userTypeID: userTypeID) - default: - return typeResolver.builtinTypeInfo(for: wireTypeID) - } + switch wireTypeID { + case .compatibleStruct, .namedCompatibleStruct: + return try readCompatibleTypeInfo() + case .namedEnum, .namedStruct, .namedExt, .namedUnion: + if compatible { + return try readCompatibleTypeInfo() + } + let namespace = try readMetaString( + context: self, + decoder: .namespace, + encodings: namespaceMetaStringEncodings + ) + let typeName = try readMetaString( + context: self, + decoder: .typeName, + encodings: typeNameMetaStringEncodings + ) + return try typeResolver.requireTypeInfo(namespace: namespace.value, typeName: typeName.value) + case .structType, .enumType, .ext, .typedUnion, .union: + let userTypeID = try buffer.readVarUInt32() + return try typeResolver.requireTypeInfo(userTypeID: userTypeID) + default: + return typeResolver.builtinTypeInfo(for: wireTypeID) } + } - func readTypeInfo(for type: T.Type) throws -> TypeInfo? { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let typeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown type id \(rawTypeID)") - } + func readTypeInfo(for type: T.Type) throws -> TypeInfo? { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let typeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") + } - guard T.staticTypeId.isUserTypeKind else { - if typeID != T.staticTypeId { - throw ForyError.typeMismatch(expected: T.staticTypeId.rawValue, actual: rawTypeID) - } - return nil - } + guard T.staticTypeId.isUserTypeKind else { + if typeID != T.staticTypeId { + throw ForyError.typeMismatch(expected: T.staticTypeId.rawValue, actual: rawTypeID) + } + return nil + } - let localTypeInfo = try typeInfo(for: type) - let expectedWireTypeID = localTypeInfo.wireTypeID(compatible: compatible) - if !isAllowedRegisteredWireTypeID( - typeID, - declaredTypeID: localTypeInfo.typeID, - registerByName: localTypeInfo.registerByName, - compatible: compatible, - evolving: localTypeInfo.evolving - ) { - throw ForyError.typeMismatch(expected: expectedWireTypeID.rawValue, actual: rawTypeID) - } + let localTypeInfo = try typeInfo(for: type) + let expectedWireTypeID = localTypeInfo.wireTypeID(compatible: compatible) + if !isAllowedRegisteredWireTypeID( + typeID, + declaredTypeID: localTypeInfo.typeID, + registerByName: localTypeInfo.registerByName, + compatible: compatible, + evolving: localTypeInfo.evolving + ) { + throw ForyError.typeMismatch(expected: expectedWireTypeID.rawValue, actual: rawTypeID) + } - switch typeID { - case .compatibleStruct, .namedCompatibleStruct: - return try readCompatibleTypeInfoIfNeeded( - for: localTypeInfo, - wireTypeID: typeID - ) - case .namedEnum, .namedStruct, .namedExt, .namedUnion: - if compatible { - _ = try readCompatibleTypeInfoIfNeeded( - for: localTypeInfo, - wireTypeID: typeID - ) - } else { - let namespace = try readMetaString( - context: self, - decoder: .namespace, - encodings: namespaceMetaStringEncodings - ) - let typeName = try readMetaString( - context: self, - decoder: .typeName, - encodings: typeNameMetaStringEncodings - ) - guard localTypeInfo.registerByName else { - throw ForyError.invalidData("received name-registered type info for id-registered local type") - } - if namespace.value != localTypeInfo.namespace.value || - typeName.value != localTypeInfo.typeName.value { - let expectedTypeName = "\(localTypeInfo.namespace.value)::\(localTypeInfo.typeName.value)" - let actualTypeName = "\(namespace.value)::\(typeName.value)" - throw ForyError.invalidData( - "type name mismatch: expected \(expectedTypeName), got \(actualTypeName)" - ) - } - } - default: - if !localTypeInfo.registerByName && registeredWireTypeNeedsUserTypeID(typeID) { - guard let localUserTypeID = localTypeInfo.userTypeID else { - throw ForyError.invalidData("missing user type id for id-registered type") - } - let remoteUserTypeID = try buffer.readVarUInt32() - if remoteUserTypeID != localUserTypeID { - throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) - } - } - } - return nil - } - - @inline(__always) - private func readCompatibleTypeInfoIfNeeded( - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo? { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - if !checkClassVersion, - compatibleTypeDefTypeInfos.isEmpty, - !localTypeInfo.typeDefHasUserTypeFields, - let localTypeDefHeader = localTypeInfo.typeDefHeader { - let typeMetaStart = buffer.getCursor() - let indexMarker = try buffer.readVarUInt32() - if indexMarker == 0 { - let header = try buffer.readUInt64() - var bodySize = Int(header & UInt64(typeMetaSizeMask)) - if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) - } - if header == localTypeDefHeader { - // Header-cache hits intentionally skip without rehashing. Entries reach this - // cache only after a successful TypeDef parse and 52-bit metadata-hash validation. - compatibleTypeDefTypeInfos.push(localTypeInfo) - try buffer.skip(bodySize) - return nil - } - } - buffer.setCursor(typeMetaStart) - } - return try readCompatibleTypeInfo( - for: localTypeInfo, - wireTypeID: wireTypeID + switch typeID { + case .compatibleStruct, .namedCompatibleStruct: + return try readCompatibleTypeInfoIfNeeded( + for: localTypeInfo, + wireTypeID: typeID + ) + case .namedEnum, .namedStruct, .namedExt, .namedUnion: + if compatible { + _ = try readCompatibleTypeInfoIfNeeded( + for: localTypeInfo, + wireTypeID: typeID + ) + } else { + let namespace = try readMetaString( + context: self, + decoder: .namespace, + encodings: namespaceMetaStringEncodings + ) + let typeName = try readMetaString( + context: self, + decoder: .typeName, + encodings: typeNameMetaStringEncodings ) + guard localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received name-registered type info for id-registered local type") + } + if namespace.value != localTypeInfo.namespace.value + || typeName.value != localTypeInfo.typeName.value { + let expectedTypeName = "\(localTypeInfo.namespace.value)::\(localTypeInfo.typeName.value)" + let actualTypeName = "\(namespace.value)::\(typeName.value)" + throw ForyError.invalidData( + "type name mismatch: expected \(expectedTypeName), got \(actualTypeName)" + ) + } + } + default: + if !localTypeInfo.registerByName && registeredWireTypeNeedsUserTypeID(typeID) { + guard let localUserTypeID = localTypeInfo.userTypeID else { + throw ForyError.invalidData("missing user type id for id-registered type") + } + let remoteUserTypeID = try buffer.readVarUInt32() + if remoteUserTypeID != localUserTypeID { + throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) + } + } } - - private func readCompatibleTypeInfo() throws -> TypeInfo { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - let indexMarker = try buffer.readVarUInt32() - let isRef = (indexMarker & 1) == 1 - let index = Int(indexMarker >> 1) - if isRef { - guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { - throw ForyError.invalidData("unknown compatible type definition ref index \(index)") - } - return typeInfo - } - - let typeMetaStart = buffer.getCursor() + return nil + } + + @inline(__always) + private func readCompatibleTypeInfoIfNeeded( + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo? { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + if !checkClassVersion, + compatibleTypeDefTypeInfos.isEmpty, + !localTypeInfo.typeDefHasUserTypeFields, + let localTypeDefHeader = localTypeInfo.typeDefHeader { + let indexMarker = try buffer.readVarUInt32() + if indexMarker == 0 { + let headerStart = buffer.getCursor() let header = try buffer.readUInt64() var bodySize = Int(header & UInt64(typeMetaSizeMask)) if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) + bodySize += Int(try buffer.readVarUInt32()) + } + if header == localTypeDefHeader { + // Header-cache hits intentionally skip without rehashing. Entries reach this + // cache only after a successful TypeDef parse and 52-bit metadata-hash validation. + compatibleTypeDefTypeInfos.push(localTypeInfo) + try buffer.skip(bodySize) + return nil } if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return cached + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) } - - buffer.setCursor(typeMetaStart) + buffer.setCursor(headerStart) let decoded = try TypeMeta.decode(buffer) let cachedTypeInfo = try typeResolver.cacheTypeInfo(decoded, forHeader: header) compatibleTypeDefTypeInfos.push(cachedTypeInfo) - return cachedTypeInfo - } - - @inline(__always) - private func readCompatibleTypeInfo( - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - let remoteTypeInfo: TypeInfo - if compatibleTypeDefTypeInfos.isEmpty, - let localTypeDefHeader = localTypeInfo.typeDefHeader { - let typeMetaStart = buffer.getCursor() - let indexMarker = try buffer.readVarUInt32() - if indexMarker != 0 { - buffer.setCursor(typeMetaStart) - remoteTypeInfo = try readCompatibleTypeInfo() - } else { - let header = try buffer.readUInt64() - var bodySize = Int(header & UInt64(typeMetaSizeMask)) - if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) - } - - if header == localTypeDefHeader { - // Header-cache hits intentionally skip without rehashing. Entries reach this - // cache only after a successful TypeDef parse and 52-bit metadata-hash validation. - compatibleTypeDefTypeInfos.push(localTypeInfo) - try buffer.skip(bodySize) - return localTypeInfo - } - - buffer.setCursor(typeMetaStart) - remoteTypeInfo = try readCompatibleTypeInfo() - } - } else { - remoteTypeInfo = try readCompatibleTypeInfo() - } - guard let remoteTypeMeta = remoteTypeInfo.compatibleTypeMeta else { - throw ForyError.invalidData("compatible type metadata is required") - } - if let localTypeMeta = localTypeInfo.typeMeta, - remoteTypeMeta === localTypeMeta { - return localTypeInfo - } - if remoteTypeMeta.registerByName { - guard localTypeInfo.registerByName else { - throw ForyError.invalidData("received name-registered compatible metadata for id-registered local type") - } - if remoteTypeMeta.namespace.value != localTypeInfo.namespace.value { - throw ForyError.invalidData( - "namespace mismatch: expected \(localTypeInfo.namespace.value), got \(remoteTypeMeta.namespace.value)" - ) - } - if remoteTypeMeta.typeName.value != localTypeInfo.typeName.value { - throw ForyError.invalidData( - "type name mismatch: expected \(localTypeInfo.typeName.value), got \(remoteTypeMeta.typeName.value)" - ) - } - } else { - guard !localTypeInfo.registerByName else { - throw ForyError.invalidData("received id-registered compatible metadata for name-registered local type") - } - guard let remoteUserTypeID = remoteTypeMeta.userTypeID else { - throw ForyError.invalidData("missing user type id in compatible type metadata") - } - guard let localUserTypeID = localTypeInfo.userTypeID else { - throw ForyError.invalidData("missing local user type id metadata for id-registered type") - } - if remoteUserTypeID != localUserTypeID { - throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) - } - } - - if let remoteTypeID = remoteTypeMeta.typeID, - let remoteWireTypeID = TypeId(rawValue: remoteTypeID), - !isAllowedRegisteredWireTypeID( - remoteWireTypeID, - declaredTypeID: localTypeInfo.typeID, - registerByName: localTypeInfo.registerByName, - compatible: compatible, - evolving: localTypeInfo.evolving - ) { - throw ForyError.typeMismatch(expected: wireTypeID.rawValue, actual: remoteTypeID) - } - return remoteTypeInfo - } - - func readAnyValue(typeInfo: TypeInfo) throws -> Any { - try enterDynamicAnyDepth() - defer { leaveDynamicAnyDepth() } - - let value: Any - switch typeInfo.typeID { - case .bool: - value = try Bool.foryRead(self, refMode: .none, readTypeInfo: false) - case .int8: - value = try Int8.foryRead(self, refMode: .none, readTypeInfo: false) - case .int16: - value = try Int16.foryRead(self, refMode: .none, readTypeInfo: false) - case .int32: - value = try buffer.readInt32() - case .varint32: - value = try Int32.foryRead(self, refMode: .none, readTypeInfo: false) - case .int64: - value = try buffer.readInt64() - case .varint64: - value = try Int64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedInt64: - value = try buffer.readTaggedInt64() - case .uint8: - value = try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint16: - value = try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint32: - value = try buffer.readUInt32() - case .varUInt32: - value = try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint64: - value = try buffer.readUInt64() - case .varUInt64: - value = try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedUInt64: - value = try buffer.readTaggedUInt64() - case .float16: - value = try Float16.foryRead(self, refMode: .none, readTypeInfo: false) - case .bfloat16: - value = try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) - case .float32: - value = try Float.foryRead(self, refMode: .none, readTypeInfo: false) - case .float64: - value = try Double.foryRead(self, refMode: .none, readTypeInfo: false) - case .string: - value = try String.foryRead(self, refMode: .none, readTypeInfo: false) - case .duration: - value = try Duration.foryRead(self, refMode: .none, readTypeInfo: false) - case .timestamp: - value = try Date.foryRead(self, refMode: .none, readTypeInfo: false) - case .date: - value = try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) - case .decimal: - value = try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) - case .binary: - value = try Data.foryRead(self, refMode: .none, readTypeInfo: false) - case .boolArray: - value = try readPrimitiveArray(self) as [Bool] - case .int8Array: - value = try readPrimitiveArray(self) as [Int8] - case .int16Array: - value = try readPrimitiveArray(self) as [Int16] - case .int32Array: - value = try readPrimitiveArray(self) as [Int32] - case .int64Array: - value = try readPrimitiveArray(self) as [Int64] - case .uint8Array: - value = try readPrimitiveArray(self) as [UInt8] - case .uint16Array: - value = try readPrimitiveArray(self) as [UInt16] - case .uint32Array: - value = try readPrimitiveArray(self) as [UInt32] - case .uint64Array: - value = try readPrimitiveArray(self) as [UInt64] - case .float16Array: - value = try readPrimitiveArray(self) as [Float16] - case .bfloat16Array: - value = try readPrimitiveArray(self) as [BFloat16] - case .float32Array: - value = try readPrimitiveArray(self) as [Float] - case .float64Array: - value = try readPrimitiveArray(self) as [Double] - case .array, .list: - value = try readListOfAny(refMode: .none) ?? [] - case .set: - value = try Set.foryRead(self, refMode: .none, readTypeInfo: false) - case .map: - value = try readDynamicAnyMapValue(context: self) - case .none: - value = ForyAnyNullValue() - default: - if typeInfo.typeID.isUserTypeKind { - value = try typeInfo.read(self) - } else { - throw ForyError.invalidData("unsupported dynamic type id \(typeInfo.typeID)") - } - } - return value + return try validateCompatibleTypeInfo( + cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + let remoteTypeInfo = try readCompatibleTypeInfo(afterMarker: indexMarker) + return try validateCompatibleTypeInfo( + remoteTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + return try readCompatibleTypeInfo( + for: localTypeInfo, + wireTypeID: wireTypeID + ) + } + + private func readCompatibleTypeInfo() throws -> TypeInfo { + let indexMarker = try buffer.readVarUInt32() + return try readCompatibleTypeInfo(afterMarker: indexMarker) + } + + private func readCompatibleTypeInfo(afterMarker indexMarker: UInt32) throws -> TypeInfo { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + let isRef = (indexMarker & 1) == 1 + let index = Int(indexMarker >> 1) + if isRef { + guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { + throw ForyError.invalidData("unknown compatible type definition ref index \(index)") + } + return typeInfo } - @inline(__always) - func getTypeInfo(for type: T.Type) -> TypeInfo? { - typeInfoStack.value(for: UInt64(UInt(bitPattern: ObjectIdentifier(type)))) + let typeMetaStart = buffer.getCursor() + let header = try buffer.readUInt64() + var bodySize = Int(header & UInt64(typeMetaSizeMask)) + if bodySize == typeMetaSizeMask { + bodySize += Int(try buffer.readVarUInt32()) + } + if let cached = typeResolver.getTypeInfo(forHeader: header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return cached } - func withTypeInfo( - _ typeInfo: TypeInfo?, - for type: T.Type, - _ body: () throws -> R - ) rethrows -> R { - guard let typeInfo else { - return try body() + buffer.setCursor(typeMetaStart) + let decoded = try TypeMeta.decode(buffer) + let cachedTypeInfo = try typeResolver.cacheTypeInfo(decoded, forHeader: header) + compatibleTypeDefTypeInfos.push(cachedTypeInfo) + return cachedTypeInfo + } + + @inline(__always) + private func readCompatibleTypeInfo( + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + let remoteTypeInfo: TypeInfo + if compatibleTypeDefTypeInfos.isEmpty, + let localTypeDefHeader = localTypeInfo.typeDefHeader { + let indexMarker = try buffer.readVarUInt32() + if indexMarker != 0 { + remoteTypeInfo = try readCompatibleTypeInfo(afterMarker: indexMarker) + } else { + let headerStart = buffer.getCursor() + let header = try buffer.readUInt64() + var bodySize = Int(header & UInt64(typeMetaSizeMask)) + if bodySize == typeMetaSizeMask { + bodySize += Int(try buffer.readVarUInt32()) } - let typeKey = UInt64(UInt(bitPattern: ObjectIdentifier(type))) - let previousTypeInfo = typeInfoStack.value(for: typeKey) - typeInfoScopeStack.append((typeKey: typeKey, previousTypeInfo: previousTypeInfo)) - typeInfoStack.set(typeInfo, for: typeKey) - defer { - if let scope = typeInfoScopeStack.popLast() { - if let previousTypeInfo = scope.previousTypeInfo { - typeInfoStack.set(previousTypeInfo, for: scope.typeKey) - } else { - _ = typeInfoStack.removeValue(for: scope.typeKey) - } - } else { - assertionFailure("type info scope stack underflow") - } + if header == localTypeDefHeader { + // Header-cache hits intentionally skip without rehashing. Entries reach this + // cache only after a successful TypeDef parse and 52-bit metadata-hash validation. + compatibleTypeDefTypeInfos.push(localTypeInfo) + try buffer.skip(bodySize) + return localTypeInfo } - return try body() - } - @inline(__always) - func getReadMetaString(at index: Int) -> MetaString? { - metaStrings.get(index) + if let cached = typeResolver.getTypeInfo(forHeader: header) { + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + remoteTypeInfo = cached + } else { + buffer.setCursor(headerStart) + let decoded = try TypeMeta.decode(buffer) + remoteTypeInfo = try typeResolver.cacheTypeInfo(decoded, forHeader: header) + compatibleTypeDefTypeInfos.push(remoteTypeInfo) + } + } + } else { + remoteTypeInfo = try readCompatibleTypeInfo() + } + return try validateCompatibleTypeInfo( + remoteTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + + private func validateCompatibleTypeInfo( + _ remoteTypeInfo: TypeInfo, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + guard let remoteTypeMeta = remoteTypeInfo.compatibleTypeMeta else { + throw ForyError.invalidData("compatible type metadata is required") + } + if let localTypeMeta = localTypeInfo.typeMeta, + remoteTypeMeta === localTypeMeta { + return localTypeInfo + } + if remoteTypeMeta.registerByName { + guard localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received name-registered compatible metadata for id-registered local type") + } + if remoteTypeMeta.namespace.value != localTypeInfo.namespace.value { + throw ForyError.invalidData( + "namespace mismatch: expected \(localTypeInfo.namespace.value), got \(remoteTypeMeta.namespace.value)" + ) + } + if remoteTypeMeta.typeName.value != localTypeInfo.typeName.value { + throw ForyError.invalidData( + "type name mismatch: expected \(localTypeInfo.typeName.value), got \(remoteTypeMeta.typeName.value)" + ) + } + } else { + guard !localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received id-registered compatible metadata for name-registered local type") + } + guard let remoteUserTypeID = remoteTypeMeta.userTypeID else { + throw ForyError.invalidData("missing user type id in compatible type metadata") + } + guard let localUserTypeID = localTypeInfo.userTypeID else { + throw ForyError.invalidData("missing local user type id metadata for id-registered type") + } + if remoteUserTypeID != localUserTypeID { + throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) + } } - @inline(__always) - func appendReadMetaString(_ value: MetaString) { - metaStrings.push(value) + if let remoteTypeID = remoteTypeMeta.typeID, + let remoteWireTypeID = TypeId(rawValue: remoteTypeID), + !isAllowedRegisteredWireTypeID( + remoteWireTypeID, + declaredTypeID: localTypeInfo.typeID, + registerByName: localTypeInfo.registerByName, + compatible: compatible, + evolving: localTypeInfo.evolving + ) { + throw ForyError.typeMismatch(expected: wireTypeID.rawValue, actual: remoteTypeID) + } + return remoteTypeInfo + } + + func readAnyValue(typeInfo: TypeInfo) throws -> Any { + try enterDynamicAnyDepth() + defer { leaveDynamicAnyDepth() } + + let value: Any + switch typeInfo.typeID { + case .bool: + value = try Bool.foryRead(self, refMode: .none, readTypeInfo: false) + case .int8: + value = try Int8.foryRead(self, refMode: .none, readTypeInfo: false) + case .int16: + value = try Int16.foryRead(self, refMode: .none, readTypeInfo: false) + case .int32: + value = try buffer.readInt32() + case .varint32: + value = try Int32.foryRead(self, refMode: .none, readTypeInfo: false) + case .int64: + value = try buffer.readInt64() + case .varint64: + value = try Int64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedInt64: + value = try buffer.readTaggedInt64() + case .uint8: + value = try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint16: + value = try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint32: + value = try buffer.readUInt32() + case .varUInt32: + value = try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint64: + value = try buffer.readUInt64() + case .varUInt64: + value = try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedUInt64: + value = try buffer.readTaggedUInt64() + case .float16: + value = try Float16.foryRead(self, refMode: .none, readTypeInfo: false) + case .bfloat16: + value = try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) + case .float32: + value = try Float.foryRead(self, refMode: .none, readTypeInfo: false) + case .float64: + value = try Double.foryRead(self, refMode: .none, readTypeInfo: false) + case .string: + value = try String.foryRead(self, refMode: .none, readTypeInfo: false) + case .duration: + value = try Duration.foryRead(self, refMode: .none, readTypeInfo: false) + case .timestamp: + value = try Date.foryRead(self, refMode: .none, readTypeInfo: false) + case .date: + value = try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) + case .decimal: + value = try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) + case .binary: + value = try Data.foryRead(self, refMode: .none, readTypeInfo: false) + case .boolArray: + value = try readPrimitiveArray(self) as [Bool] + case .int8Array: + value = try readPrimitiveArray(self) as [Int8] + case .int16Array: + value = try readPrimitiveArray(self) as [Int16] + case .int32Array: + value = try readPrimitiveArray(self) as [Int32] + case .int64Array: + value = try readPrimitiveArray(self) as [Int64] + case .uint8Array: + value = try readPrimitiveArray(self) as [UInt8] + case .uint16Array: + value = try readPrimitiveArray(self) as [UInt16] + case .uint32Array: + value = try readPrimitiveArray(self) as [UInt32] + case .uint64Array: + value = try readPrimitiveArray(self) as [UInt64] + case .float16Array: + value = try readPrimitiveArray(self) as [Float16] + case .bfloat16Array: + value = try readPrimitiveArray(self) as [BFloat16] + case .float32Array: + value = try readPrimitiveArray(self) as [Float] + case .float64Array: + value = try readPrimitiveArray(self) as [Double] + case .array, .list: + value = try readListOfAny(refMode: .none) ?? [] + case .set: + value = try Set.foryRead(self, refMode: .none, readTypeInfo: false) + case .map: + value = try readDynamicAnyMapValue(context: self) + case .none: + value = ForyAnyNullValue() + default: + if typeInfo.typeID.isUserTypeKind { + value = try typeInfo.read(self) + } else { + throw ForyError.invalidData("unsupported dynamic type id \(typeInfo.typeID)") + } + } + return value + } + + @inline(__always) + func getTypeInfo(for type: T.Type) -> TypeInfo? { + typeInfoStack.value(for: UInt64(UInt(bitPattern: ObjectIdentifier(type)))) + } + + func withTypeInfo( + _ typeInfo: TypeInfo?, + for type: T.Type, + _ body: () throws -> R + ) rethrows -> R { + guard let typeInfo else { + return try body() } - func reset() { - if dynamicAnyDepth != 0 { - dynamicAnyDepth = 0 - } - if trackRef { - refReader.reset() - } - if !typeInfoStack.isEmpty { - typeInfoStack.clear() - } - if !typeInfoScopeStack.isEmpty { - typeInfoScopeStack.removeAll(keepingCapacity: true) + let typeKey = UInt64(UInt(bitPattern: ObjectIdentifier(type))) + let previousTypeInfo = typeInfoStack.value(for: typeKey) + typeInfoScopeStack.append((typeKey: typeKey, previousTypeInfo: previousTypeInfo)) + typeInfoStack.set(typeInfo, for: typeKey) + defer { + if let scope = typeInfoScopeStack.popLast() { + if let previousTypeInfo = scope.previousTypeInfo { + typeInfoStack.set(previousTypeInfo, for: scope.typeKey) + } else { + _ = typeInfoStack.removeValue(for: scope.typeKey) } - compatibleTypeDefTypeInfos.reset() - metaStrings.reset() + } else { + assertionFailure("type info scope stack underflow") + } + } + return try body() + } + + @inline(__always) + func getReadMetaString(at index: Int) -> MetaString? { + metaStrings.get(index) + } + + @inline(__always) + func appendReadMetaString(_ value: MetaString) { + metaStrings.push(value) + } + + func reset() { + if dynamicAnyDepth != 0 { + dynamicAnyDepth = 0 + } + if trackRef { + refReader.reset() + } + if !typeInfoStack.isEmpty { + typeInfoStack.clear() + } + if !typeInfoScopeStack.isEmpty { + typeInfoScopeStack.removeAll(keepingCapacity: true) } + compatibleTypeDefTypeInfos.reset() + metaStrings.reset() + } } public extension ReadContext { - func readAny( - refMode: RefMode, - readTypeInfo: Bool = true - ) throws -> Any? { - try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() - } - - func readListOfAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [Any]? { - let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - return wrapped?.map { $0.anyValueForCollection() } + func readAny( + refMode: RefMode, + readTypeInfo: Bool = true + ) throws -> Any? { + try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() + } + + func readListOfAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [Any]? { + let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + return wrapped?.map { $0.anyValueForCollection() } + } + + func readMapStringToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [String: Any]? { + let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil } - - func readMapStringToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [String: Any]? { - let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - var map: [String: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + var map: [String: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() } - - func readMapInt32ToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [Int32: Any]? { - let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - var map: [Int32: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + return map + } + + func readMapInt32ToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [Int32: Any]? { + let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil } - - func readMapAnyHashableToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [AnyHashable: Any]? { - let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - var map: [AnyHashable: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + var map: [Int32: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map + } + + func readMapAnyHashableToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [AnyHashable: Any]? { + let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + var map: [AnyHashable: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() } + return map + } } diff --git a/swift/Sources/Fory/TypeMeta.swift b/swift/Sources/Fory/TypeMeta.swift index 6179d930f9..fe4e53a047 100644 --- a/swift/Sources/Fory/TypeMeta.swift +++ b/swift/Sources/Fory/TypeMeta.swift @@ -34,20 +34,20 @@ private let noUserTypeID: UInt32 = UInt32.max public let namespaceMetaStringEncodings: [MetaStringEncoding] = [ .utf8, .allToLowerSpecial, - .lowerUpperDigitSpecial, + .lowerUpperDigitSpecial ] public let typeNameMetaStringEncodings: [MetaStringEncoding] = [ .utf8, .allToLowerSpecial, .lowerUpperDigitSpecial, - .firstToLowerSpecial, + .firstToLowerSpecial ] public let fieldNameMetaStringEncodings: [MetaStringEncoding] = [ .utf8, .allToLowerSpecial, - .lowerUpperDigitSpecial, + .lowerUpperDigitSpecial ] public final class TypeMeta: Equatable, @unchecked Sendable { @@ -408,8 +408,7 @@ public final class TypeMeta: Equatable, @unchecked Sendable { throw ForyError.invalidData("unexpected trailing bytes in TypeMeta body") } if (header & Self.hashMask()) - != Self.typeMetaHeaderHash(encodedBody, headerLowBits: header & ~Self.hashMask()) - { + != Self.typeMetaHeaderHash(encodedBody, headerLowBits: header & ~Self.hashMask()) { throw ForyError.invalidData("invalid TypeMeta metadata hash") } @@ -599,7 +598,25 @@ public final class TypeMeta: Equatable, @unchecked Sendable { let localFields = localTypeMeta.fields guard !localFields.isEmpty else { - return self + var resolvedFields = fields + var changed = false + for index in resolvedFields.indices where resolvedFields[index].fieldID != -1 { + resolvedFields[index].fieldID = -1 + changed = true + } + guard changed else { + return self + } + return try TypeMeta( + typeID: typeID, + userTypeID: userTypeID, + namespace: namespace, + typeName: typeName, + registerByName: registerByName, + fields: resolvedFields, + compressed: compressed, + headerHash: headerHash + ) } var fieldIndexByName: [String: (Int, FieldInfo)] = [:] @@ -608,9 +625,10 @@ public final class TypeMeta: Equatable, @unchecked Sendable { fieldIndexByID.reserveCapacity(localFields.count) for (index, localField) in localFields.enumerated() { - fieldIndexByName[toSnakeCase(localField.fieldName)] = (index, localField) if let fieldID = localField.fieldID, fieldID >= 0 { fieldIndexByID[fieldID] = (index, localField) + } else { + fieldIndexByName[toSnakeCase(localField.fieldName)] = (index, localField) } } @@ -623,26 +641,35 @@ public final class TypeMeta: Equatable, @unchecked Sendable { var localMatch: (Int, FieldInfo)? if let fieldID = field.fieldID, fieldID >= 0 { - if let candidate = fieldIndexByID[fieldID], - Self.isCompatibleFieldType(field.fieldType, candidate.1.fieldType) - { + if let candidate = fieldIndexByID[fieldID] { + guard Self.isCompatibleFieldType(field.fieldType, candidate.1.fieldType) else { + throw ForyError.invalidData( + "compatible field \(field.fieldName) cannot be read as local field \(candidate.1.fieldName)" + ) + } localMatch = candidate } } - if localMatch == nil { - if let candidate = fieldIndexByName[toSnakeCase(field.fieldName)], - Self.isCompatibleFieldType(field.fieldType, candidate.1.fieldType) - { + if localMatch == nil && field.fieldID == nil { + if let candidate = fieldIndexByName[toSnakeCase(field.fieldName)] { + guard Self.isCompatibleFieldType(field.fieldType, candidate.1.fieldType) else { + throw ForyError.invalidData( + "compatible field \(field.fieldName) cannot be read as local field \(candidate.1.fieldName)" + ) + } localMatch = candidate } } - if localMatch == nil { + // Only anonymous legacy fields may fall back by exact type. Named or + // tagged remote-only fields must stay unmatched so the reader skips them. + if localMatch == nil && field.fieldID == nil && field.fieldName.isEmpty { for localIndex in localFields.indices where !usedLocalFields[localIndex] { if Self.isCompatibleFieldType( field.fieldType, localFields[localIndex].fieldType, + topLevel: false, allowScalarConversion: false ) { localMatch = (localIndex, localFields[localIndex]) @@ -651,17 +678,27 @@ public final class TypeMeta: Equatable, @unchecked Sendable { } } - guard let (sortedIndex, _) = localMatch, - sortedIndex <= Int(Int16.max) - else { + guard let (sortedIndex, _) = localMatch else { if field.fieldID != -1 { resolvedFields[index].fieldID = -1 changed = true } continue } + guard sortedIndex <= Int(Int16.max) / 2 else { + throw ForyError.invalidData( + "compatible field \(field.fieldName) matched local field index \(sortedIndex) beyond Int16 dispatch range" + ) + } + if usedLocalFields[sortedIndex] { + throw ForyError.invalidData( + "compatible field \(field.fieldName) duplicates local field \(localFields[sortedIndex].fieldName)" + ) + } - let resolvedFieldID = Int16(sortedIndex) + let localField = localFields[sortedIndex] + let exactField = field.fieldType == localField.fieldType + let resolvedFieldID = Int16(sortedIndex * 2 + (exactField ? 0 : 1)) if field.fieldID != resolvedFieldID { resolvedFields[index].fieldID = resolvedFieldID changed = true @@ -694,29 +731,40 @@ public final class TypeMeta: Equatable, @unchecked Sendable { if topLevel, isCompatibleTopLevelListArrayFieldType(remoteType, localType) { return true } + if topLevel, isCompatibleByteSequenceFieldType(remoteType, localType) { + return true + } if topLevel, remoteType.trackRef != localType.trackRef, compatibleScalarKind(remoteType.typeID) != nil, - compatibleScalarKind(localType.typeID) != nil - { + compatibleScalarKind(localType.typeID) != nil { return false } if topLevel, remoteType.trackRef || localType.trackRef, compatibleScalarKind(remoteType.typeID) != nil, compatibleScalarKind(localType.typeID) != nil, - remoteType.typeID != localType.typeID || remoteType.nullable != localType.nullable - { + remoteType.typeID != localType.typeID || remoteType.nullable != localType.nullable { return false } if topLevel, allowScalarConversion, isCompatibleScalarFieldType(remoteType, localType) { return true } - if normalizeCompatibleTypeIDForComparison(remoteType.typeID) - != normalizeCompatibleTypeIDForComparison(localType.typeID) - { + if normalizeUserTypeIDForComparison(remoteType.typeID) + != normalizeUserTypeIDForComparison(localType.typeID) { return false } + if !topLevel, + compatibleScalarKind(remoteType.typeID) != nil + || compatibleScalarKind(localType.typeID) != nil { + // Untracked nested scalar payloads carry null markers at read time. A nullable schema can + // match a non-null local container element; the container reader owns actual-null semantics. + return remoteType.typeID == localType.typeID + && remoteType.trackRef == localType.trackRef + && (!remoteType.trackRef || remoteType.nullable == localType.nullable) + && remoteType.generics.isEmpty + && localType.generics.isEmpty + } if remoteType.generics.count != localType.generics.count { return false } @@ -736,27 +784,63 @@ public final class TypeMeta: Equatable, @unchecked Sendable { _ remoteType: FieldType, _ localType: FieldType ) -> Bool { + guard !remoteType.nullable, !localType.nullable, !remoteType.trackRef, !localType.trackRef + else { + return false + } if remoteType.typeID == TypeId.list.rawValue { - return listElementMatchesDenseArrayTypeID(remoteType, arrayTypeID: localType.typeID) + return listElementMatchesDenseArrayTypeID( + remoteType, + arrayTypeID: localType.typeID, + requireUnframedElement: true + ) } if localType.typeID == TypeId.list.rawValue { - return listElementMatchesDenseArrayTypeID(localType, arrayTypeID: remoteType.typeID) + return listElementMatchesDenseArrayTypeID( + localType, + arrayTypeID: remoteType.typeID, + requireUnframedElement: false + ) } return false } private static func listElementMatchesDenseArrayTypeID( _ listType: FieldType, - arrayTypeID: UInt32 + arrayTypeID: UInt32, + requireUnframedElement: Bool ) -> Bool { guard listType.typeID == TypeId.list.rawValue, let elementType = listType.generics.first else { return false } + // Nullable element schema is allowed for list -> array; actual + // null payload elements fail in the dense-array reader. Ref-tracked + // element framing is rejected here because this path stays primitive-only. + if requireUnframedElement, elementType.trackRef { + return false + } return TypeId.listElementTypeID(elementType.typeID, matchesDenseArrayTypeID: arrayTypeID) } + private static func isCompatibleByteSequenceFieldType( + _ remoteType: FieldType, + _ localType: FieldType + ) -> Bool { + guard !remoteType.trackRef, + !localType.trackRef, + remoteType.nullable == localType.nullable + else { + return false + } + return + (remoteType.typeID == TypeId.binary.rawValue + && localType.typeID == TypeId.uint8Array.rawValue) + || (remoteType.typeID == TypeId.uint8Array.rawValue + && localType.typeID == TypeId.binary.rawValue) + } + private static func isCompatibleScalarFieldType( _ remoteType: FieldType, _ localType: FieldType @@ -838,7 +922,7 @@ public final class TypeMeta: Equatable, @unchecked Sendable { } } - private static func normalizeCompatibleTypeIDForComparison(_ typeID: UInt32) -> UInt32 { + private static func normalizeUserTypeIDForComparison(_ typeID: UInt32) -> UInt32 { switch typeID { case TypeId.structType.rawValue, TypeId.compatibleStruct.rawValue, @@ -852,10 +936,6 @@ public final class TypeMeta: Equatable, @unchecked Sendable { case TypeId.ext.rawValue, TypeId.namedExt.rawValue: return TypeId.ext.rawValue - case TypeId.binary.rawValue, - TypeId.int8Array.rawValue, - TypeId.uint8Array.rawValue: - return TypeId.binary.rawValue case TypeId.union.rawValue, TypeId.typedUnion.rawValue, TypeId.namedUnion.rawValue: diff --git a/swift/Sources/Fory/TypeResolver.swift b/swift/Sources/Fory/TypeResolver.swift index 38a8951079..06963509a4 100644 --- a/swift/Sources/Fory/TypeResolver.swift +++ b/swift/Sources/Fory/TypeResolver.swift @@ -390,6 +390,7 @@ public final class TypeInfo: @unchecked Sendable { } return try reader(context) } + } private struct TypeNameKey: Hashable { @@ -567,9 +568,8 @@ final class TypeResolver { return localTypeInfo } let canonicalTypeMeta: TypeMeta - if let localTypeMeta = localTypeInfo.typeMeta, - let remapped = try? typeMeta.assigningFieldIDs(from: localTypeMeta) { - canonicalTypeMeta = remapped + if let localTypeMeta = localTypeInfo.typeMeta { + canonicalTypeMeta = try typeMeta.assigningFieldIDs(from: localTypeMeta) } else { canonicalTypeMeta = typeMeta } diff --git a/swift/Sources/ForyMacro/ForyObjectMacro.swift b/swift/Sources/ForyMacro/ForyObjectMacro.swift index 6f52b601e3..a70a47ae12 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacro.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacro.swift @@ -2860,9 +2860,9 @@ TypeMeta.FieldType( private func compatibleFieldTypeIDExpression(_ typeText: String) -> String { let staticTypeIDExpr = "\(typeText).staticTypeId" // A typed union field already has schema from the owning struct field; only - // dynamic/top-level union values need TYPED_UNION metadata. + // dynamic/top-level union values need typed or named union metadata. return "UInt32((\(staticTypeIDExpr) == .structType ? TypeId.compatibleStruct : " - + "(\(staticTypeIDExpr) == .typedUnion ? TypeId.union : \(staticTypeIDExpr))).rawValue)" + + "(\(staticTypeIDExpr) == .typedUnion || \(staticTypeIDExpr) == .namedUnion ? TypeId.union : \(staticTypeIDExpr))).rawValue)" } private func compatibleGenericNullableExpression(_ typeText: String) -> String { diff --git a/swift/Sources/ForyMacro/ForyObjectMacroPrimitiveFastPath.swift b/swift/Sources/ForyMacro/ForyObjectMacroPrimitiveFastPath.swift index 8601b1c717..42ff413da6 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacroPrimitiveFastPath.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacroPrimitiveFastPath.swift @@ -16,110 +16,111 @@ // under the License. func leadingPrimitiveFastPathFields(_ fields: [ParsedField]) -> [ParsedField] { - var result: [ParsedField] = [] - result.reserveCapacity(fields.count) - for field in fields { - if isPrimitiveFastPathField(field) { - result.append(field) - } else { - break - } + var result: [ParsedField] = [] + result.reserveCapacity(fields.count) + for field in fields { + if isPrimitiveFastPathField(field) { + result.append(field) + } else { + break } - return result + } + return result } private func leadingFixedPrimitiveFields(_ fields: [ParsedField]) -> [ParsedField] { - var result: [ParsedField] = [] - result.reserveCapacity(fields.count) - for field in fields { - if primitiveFixedByteWidth(for: field) != nil { - result.append(field) - } else { - break - } + var result: [ParsedField] = [] + result.reserveCapacity(fields.count) + for field in fields { + if primitiveFixedByteWidth(for: field) != nil { + result.append(field) + } else { + break } - return result + } + return result } private func primitiveFixedPrefixBytes(_ fields: [ParsedField]) -> Int { - fields.reduce(0) { partial, field in - partial + (primitiveFixedByteWidth(for: field) ?? 0) - } + fields.reduce(0) { partial, field in + partial + (primitiveFixedByteWidth(for: field) ?? 0) + } } private func primitiveFixedByteWidth(for field: ParsedField) -> Int? { - switch trimType(field.typeText) { - case "Bool", "Int8", "UInt8": - return 1 - case "Int16", "UInt16": - return 2 - case "Float": - return 4 - case "Double": - return 8 - default: - return nil - } + switch trimType(field.typeText) { + case "Bool", "Int8", "UInt8": + return 1 + case "Int16", "UInt16": + return 2 + case "Float": + return 4 + case "Double": + return 8 + default: + return nil + } } private func isPrimitiveFastPathField(_ field: ParsedField) -> Bool { - guard !field.isOptional else { - return false - } - guard field.dynamicAnyCodec == nil, field.customCodecType == nil else { - return false - } - guard field.typeID != 27, !compatibleFieldNeedsTypeInfo(field) else { - return false - } - return primitiveUnsafeWriteMethod(for: field) != nil && primitiveUnsafeReadMethod(for: field) != nil + guard !field.isOptional else { + return false + } + guard field.dynamicAnyCodec == nil, field.customCodecType == nil else { + return false + } + guard field.typeID != 27, !compatibleFieldNeedsTypeInfo(field) else { + return false + } + return primitiveUnsafeWriteMethod(for: field) != nil + && primitiveUnsafeReadMethod(for: field) != nil } func buildPrimitiveFastWriteBlock(_ fields: [ParsedField]) -> String? { - guard !fields.isEmpty else { - return nil - } - let fixedFields = leadingFixedPrimitiveFields(fields) - let remainingFields = Array(fields.dropFirst(fixedFields.count)) - let fixedPrefixBytes = primitiveFixedPrefixBytes(fixedFields) - let maxNumericBytes = fields.reduce(0) { partial, field in - partial + (primitiveMaxEncodedByteWidth(for: field) ?? 0) + guard !fields.isEmpty else { + return nil + } + let fixedFields = leadingFixedPrimitiveFields(fields) + let remainingFields = Array(fields.dropFirst(fixedFields.count)) + let fixedPrefixBytes = primitiveFixedPrefixBytes(fixedFields) + let maxNumericBytes = fields.reduce(0) { partial, field in + partial + (primitiveMaxEncodedByteWidth(for: field) ?? 0) + } + guard maxNumericBytes > 0 else { + return nil + } + let locals = fields.map { field in + "let __\(field.name) = self.\(field.name)" + }.joined(separator: "\n ") + var fixedOffset = 0 + let fixedWrites = fixedFields.compactMap { field -> String? in + guard let line = primitiveUnsafeWriteFixedLine(for: field, offset: fixedOffset) else { + return nil } - guard maxNumericBytes > 0 else { - return nil - } - let locals = fields.map { field in - "let __\(field.name) = self.\(field.name)" - }.joined(separator: "\n ") - var fixedOffset = 0 - let fixedWrites = fixedFields.compactMap { field -> String? in - guard let line = primitiveUnsafeWriteFixedLine(for: field, offset: fixedOffset) else { - return nil - } - fixedOffset += primitiveFixedByteWidth(for: field) ?? 0 - return line - }.joined(separator: "\n ") - let remainingWrites = remainingFields.compactMap { field in - primitiveUnsafeWriteAdvanceLine(for: field, indexExpr: "__writerIndex") - }.joined(separator: "\n ") - var bodySections: [String] = [] - if !fixedWrites.isEmpty { - bodySections.append(fixedWrites) - } - if !remainingWrites.isEmpty { - bodySections.append( - """ - var __writerIndex = \(fixedPrefixBytes) - \(remainingWrites) - assert(__writerIndex <= \(maxNumericBytes)) - return __writerIndex - """ - ) - } else { - bodySections.append("return \(fixedPrefixBytes)") - } - let writeBody = bodySections.joined(separator: "\n ") - return """ + fixedOffset += primitiveFixedByteWidth(for: field) ?? 0 + return line + }.joined(separator: "\n ") + let remainingWrites = remainingFields.compactMap { field in + primitiveUnsafeWriteAdvanceLine(for: field, indexExpr: "__writerIndex") + }.joined(separator: "\n ") + var bodySections: [String] = [] + if !fixedWrites.isEmpty { + bodySections.append(fixedWrites) + } + if !remainingWrites.isEmpty { + bodySections.append( + """ + var __writerIndex = \(fixedPrefixBytes) + \(remainingWrites) + assert(__writerIndex <= \(maxNumericBytes)) + return __writerIndex + """ + ) + } else { + bodySections.append("return \(fixedPrefixBytes)") + } + let writeBody = bodySections.joined(separator: "\n ") + return """ \(locals) UnsafeUtil.writeRegion(buffer: __buffer, maxCount: \(maxNumericBytes)) { __base in \(writeBody) @@ -128,214 +129,224 @@ func buildPrimitiveFastWriteBlock(_ fields: [ParsedField]) -> String? { } private func primitiveMaxEncodedByteWidth(for field: ParsedField) -> Int? { - switch trimType(field.typeText) { - case "Bool", "Int8", "UInt8": - return 1 - case "Int16", "UInt16": - return 2 - case "Float": - return 4 - case "Double": - return 8 - case "Int32": - return 5 - case "UInt32": - return 5 - case "Int64": - return 9 - case "UInt64": - return 9 - case "Int": - return 9 - case "UInt": - return 9 - default: - return nil - } + switch trimType(field.typeText) { + case "Bool", "Int8", "UInt8": + return 1 + case "Int16", "UInt16": + return 2 + case "Float": + return 4 + case "Double": + return 8 + case "Int32": + return 5 + case "UInt32": + return 5 + case "Int64": + return 9 + case "UInt64": + return 9 + case "Int": + return 9 + case "UInt": + return 9 + default: + return nil + } } private func primitiveUnsafeWriteFixedLine(for field: ParsedField, offset: Int) -> String? { - guard let method = primitiveUnsafeWriteMethod(for: field) else { - return nil - } - return "_ = UnsafeUtil.\(method)(__\(field.name), to: __base, index: \(offset))" + guard let method = primitiveUnsafeWriteMethod(for: field) else { + return nil + } + return "_ = UnsafeUtil.\(method)(__\(field.name), to: __base, index: \(offset))" } private func primitiveUnsafeWriteAdvanceLine(for field: ParsedField, indexExpr: String) -> String? { - guard let method = primitiveUnsafeWriteMethod(for: field) else { - return nil - } - return "__writerIndex = UnsafeUtil.\(method)(__\(field.name), to: __base, index: \(indexExpr))" + guard let method = primitiveUnsafeWriteMethod(for: field) else { + return nil + } + return "__writerIndex = UnsafeUtil.\(method)(__\(field.name), to: __base, index: \(indexExpr))" } private func primitiveUnsafeWriteMethod(for field: ParsedField) -> String? { - switch trimType(field.typeText) { - case "Bool": - return "writeBool" - case "Int8": - return "writeInt8" - case "Int16": - return "writeInt16" - case "Int32": - return "writeInt32" - case "Int64": - return "writeInt64" - case "Int": - return "writeInt" - case "UInt8": - return "writeUInt8" - case "UInt16": - return "writeUInt16" - case "UInt32": - return "writeUInt32" - case "UInt64": - return "writeUInt64" - case "UInt": - return "writeUInt" - case "Float": - return "writeFloat32" - case "Double": - return "writeFloat64" - default: - return nil - } + switch trimType(field.typeText) { + case "Bool": + return "writeBool" + case "Int8": + return "writeInt8" + case "Int16": + return "writeInt16" + case "Int32": + return "writeInt32" + case "Int64": + return "writeInt64" + case "Int": + return "writeInt" + case "UInt8": + return "writeUInt8" + case "UInt16": + return "writeUInt16" + case "UInt32": + return "writeUInt32" + case "UInt64": + return "writeUInt64" + case "UInt": + return "writeUInt" + case "Float": + return "writeFloat32" + case "Double": + return "writeFloat64" + default: + return nil + } } private func primitiveUnsafeReadMethod(for field: ParsedField) -> String? { - switch trimType(field.typeText) { - case "Bool": - return "readBool" - case "Int8": - return "readInt8" - case "Int16": - return "readInt16" - case "Int32": - return "readInt32" - case "Int64": - return "readInt64" - case "Int": - return "readInt" - case "UInt8": - return "readUInt8" - case "UInt16": - return "readUInt16" - case "UInt32": - return "readUInt32" - case "UInt64": - return "readUInt64" - case "UInt": - return "readUInt" - case "Float": - return "readFloat32" - case "Double": - return "readFloat64" - default: - return nil - } + switch trimType(field.typeText) { + case "Bool": + return "readBool" + case "Int8": + return "readInt8" + case "Int16": + return "readInt16" + case "Int32": + return "readInt32" + case "Int64": + return "readInt64" + case "Int": + return "readInt" + case "UInt8": + return "readUInt8" + case "UInt16": + return "readUInt16" + case "UInt32": + return "readUInt32" + case "UInt64": + return "readUInt64" + case "UInt": + return "readUInt" + case "Float": + return "readFloat32" + case "Double": + return "readFloat64" + default: + return nil + } } private func primitiveUnsafeFixedReadMethod(for field: ParsedField) -> String? { - switch trimType(field.typeText) { - case "Bool": - return "readBoolUnchecked" - case "Int8": - return "readInt8Unchecked" - case "UInt8": - return "readUInt8Unchecked" - case "Int16": - return "readInt16Unchecked" - case "UInt16": - return "readUInt16Unchecked" - case "Float": - return "readFloat32Unchecked" - case "Double": - return "readFloat64Unchecked" - default: - return nil - } + switch trimType(field.typeText) { + case "Bool": + return "readBoolUnchecked" + case "Int8": + return "readInt8Unchecked" + case "UInt8": + return "readUInt8Unchecked" + case "Int16": + return "readInt16Unchecked" + case "UInt16": + return "readUInt16Unchecked" + case "Float": + return "readFloat32Unchecked" + case "Double": + return "readFloat64Unchecked" + default: + return nil + } } -private func primitiveUnsafeFixedReadExpr(for field: ParsedField, baseExpr: String, offset: Int) -> String? { - guard let method = primitiveUnsafeFixedReadMethod(for: field) else { - return nil - } - return "UnsafeUtil.\(method)(from: \(baseExpr), index: \(offset))" +private func primitiveUnsafeFixedReadExpr(for field: ParsedField, baseExpr: String, offset: Int) + -> String? +{ + guard let method = primitiveUnsafeFixedReadMethod(for: field) else { + return nil + } + return "UnsafeUtil.\(method)(from: \(baseExpr), index: \(offset))" } -private func primitiveUnsafePointerReadAdvanceExpr(for field: ParsedField) -> String? { - guard let method = primitiveUnsafeReadMethod(for: field) else { - return nil - } - return "try UnsafeUtil.\(method)(from: __base, length: __length, index: &__readerIndex)" +func primitiveUnsafePointerReadAdvanceExpr(for field: ParsedField) -> String? { + guard let method = primitiveUnsafeReadMethod(for: field) else { + return nil + } + return "try UnsafeUtil.\(method)(from: __base, length: __length, index: &__readerIndex)" } private struct PrimitiveFastReadLayout { - let statements: [String] - let consumedExpr: String - let fixedPrefixBytes: Int + let statements: [String] + let consumedExpr: String + let fixedPrefixBytes: Int } private func buildPrimitiveFastReadStatements( - _ fields: [ParsedField], - assignLine: (ParsedField, String) -> String, - remainingReadExpr: (ParsedField) -> String? + _ fields: [ParsedField], + assignLine: (ParsedField, String) -> String, + remainingReadExpr: (ParsedField) -> String? ) -> PrimitiveFastReadLayout? { - guard !fields.isEmpty else { - return nil + guard !fields.isEmpty else { + return nil + } + let fixedFields = leadingFixedPrimitiveFields(fields) + let remainingFields = Array(fields.dropFirst(fixedFields.count)) + let fixedPrefixBytes = primitiveFixedPrefixBytes(fixedFields) + var fixedOffset = 0 + let fixedReads = fixedFields.compactMap { field -> String? in + guard + let readExpr = primitiveUnsafeFixedReadExpr( + for: field, baseExpr: "__base", offset: fixedOffset) + else { + return nil } - let fixedFields = leadingFixedPrimitiveFields(fields) - let remainingFields = Array(fields.dropFirst(fixedFields.count)) - let fixedPrefixBytes = primitiveFixedPrefixBytes(fixedFields) - var fixedOffset = 0 - let fixedReads = fixedFields.compactMap { field -> String? in - guard let readExpr = primitiveUnsafeFixedReadExpr(for: field, baseExpr: "__base", offset: fixedOffset) else { - return nil - } - fixedOffset += primitiveFixedByteWidth(for: field) ?? 0 - return assignLine(field, readExpr) - }.joined(separator: "\n ") - let remainingReads = remainingFields.compactMap { field -> String? in - guard let readExpr = remainingReadExpr(field) else { - return nil - } - return assignLine(field, readExpr) - }.joined(separator: "\n ") - var readSections: [String] = [] - if !fixedReads.isEmpty { - readSections.append(fixedReads) + fixedOffset += primitiveFixedByteWidth(for: field) ?? 0 + return assignLine(field, readExpr) + }.joined(separator: "\n ") + let remainingReads = remainingFields.compactMap { field -> String? in + guard let readExpr = remainingReadExpr(field) else { + return nil } - if !remainingReads.isEmpty { - readSections.append("var __readerIndex = \(fixedPrefixBytes)") - readSections.append(remainingReads) - } - let consumedExpr = remainingReads.isEmpty ? "\(fixedPrefixBytes)" : "__readerIndex" - return PrimitiveFastReadLayout( - statements: readSections, - consumedExpr: consumedExpr, - fixedPrefixBytes: fixedPrefixBytes - ) + return assignLine(field, readExpr) + }.joined(separator: "\n ") + var readSections: [String] = [] + if !fixedReads.isEmpty { + readSections.append(fixedReads) + } + if !remainingReads.isEmpty { + readSections.append("var __readerIndex = \(fixedPrefixBytes)") + readSections.append(remainingReads) + } + let consumedExpr = remainingReads.isEmpty ? "\(fixedPrefixBytes)" : "__readerIndex" + return PrimitiveFastReadLayout( + statements: readSections, + consumedExpr: consumedExpr, + fixedPrefixBytes: fixedPrefixBytes + ) } private func buildPrimitiveFastReadBlock( - _ fields: [ParsedField], - assignLine: (ParsedField, String) -> String + _ fields: [ParsedField], + assignLine: (ParsedField, String) -> String ) -> String? { - guard let readLayout = buildPrimitiveFastReadStatements( - fields, - assignLine: assignLine, - remainingReadExpr: primitiveUnsafePointerReadAdvanceExpr - ) else { - return nil - } - var readSections: [String] = [] - if readLayout.fixedPrefixBytes > 0 { - readSections.append("try UnsafeUtil.checkReadable(length: __length, index: 0, need: \(readLayout.fixedPrefixBytes))") - } - readSections.append(contentsOf: readLayout.statements) - readSections.append("return \(readLayout.consumedExpr)") - let readBody = readSections.joined(separator: "\n ") - let lengthArgument = readLayout.consumedExpr == "__readerIndex" || readLayout.fixedPrefixBytes > 0 ? "__length" : "_" - return """ + guard + let readLayout = buildPrimitiveFastReadStatements( + fields, + assignLine: assignLine, + remainingReadExpr: primitiveUnsafePointerReadAdvanceExpr + ) + else { + return nil + } + var readSections: [String] = [] + if readLayout.fixedPrefixBytes > 0 { + readSections.append( + "try UnsafeUtil.checkReadable(length: __length, index: 0, need: \(readLayout.fixedPrefixBytes))" + ) + } + readSections.append(contentsOf: readLayout.statements) + readSections.append("return \(readLayout.consumedExpr)") + let readBody = readSections.joined(separator: "\n ") + let lengthArgument = + readLayout.consumedExpr == "__readerIndex" || readLayout.fixedPrefixBytes > 0 ? "__length" : "_" + return """ try UnsafeUtil.readRegion(buffer: __buffer) { __base, \(lengthArgument) in \(readBody) } @@ -343,22 +354,22 @@ private func buildPrimitiveFastReadBlock( } func buildPrimitiveFastClassReadBlock(_ fields: [ParsedField]) -> String? { - buildPrimitiveFastReadBlock(fields) { field, readExpr in - "value.\(field.name) = \(readExpr)" - } + buildPrimitiveFastReadBlock(fields) { field, readExpr in + "value.\(field.name) = \(readExpr)" + } } func buildPrimitiveFastStructReadDeclarations(_ fields: [ParsedField]) -> String? { - guard !fields.isEmpty else { - return nil - } - return fields.map { field in - "var __\(field.name): \(field.typeText) = \(field.typeText).foryDefault()" - }.joined(separator: "\n ") + guard !fields.isEmpty else { + return nil + } + return fields.map { field in + "var __\(field.name): \(field.typeText) = \(field.typeText).foryDefault()" + }.joined(separator: "\n ") } func buildPrimitiveFastStructReadBlock(_ fields: [ParsedField]) -> String? { - buildPrimitiveFastReadBlock(fields) { field, readExpr in - "__\(field.name) = \(readExpr)" - } + buildPrimitiveFastReadBlock(fields) { field, readExpr in + "__\(field.name) = \(readExpr)" + } } diff --git a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift index 9e58695384..7f82c6cbeb 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift @@ -181,9 +181,12 @@ private func buildClassReadCompatibleDataDecl( let bufferBinding = (schemaAssignBody.contains("__buffer") || compatibleAlignedAssignBody.contains("__buffer") || compatibleCases.contains("__buffer")) ? "let __buffer = context.buffer\n " : "" + let localFieldsBinding = + compatibleCases.contains("__foryLocalFields") + ? "let __foryLocalFields = Self.foryFieldsInfo(trackRef: context.trackRef)\n " : "" return """ - @inline(__always) + @inline(never) private static func __foryReadCompatibleDataImpl( _ context: ReadContext, remoteTypeInfo: TypeInfo, @@ -206,17 +209,19 @@ private func buildClassReadCompatibleDataDecl( \(compatibleAlignedAssignBody) return value } - for remoteField in typeMeta.fields { + \(localFieldsBinding)for remoteField in typeMeta.fields { switch Int(remoteField.fieldID ?? -1) { \(compatibleCases) - default: + case -1: try context.skipFieldValue(remoteField.fieldType) + default: + throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") } } return value } - @inline(__always) + @inline(never) \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { try Self.__foryReadCompatibleDataImpl(context, remoteTypeInfo: remoteTypeInfo, reservedRefID: nil) } @@ -225,7 +230,7 @@ private func buildClassReadCompatibleDataDecl( private func buildEmptyStructReadCompatibleDataDecl(accessPrefix: String) -> String { """ - @inline(__always) + @inline(never) \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") @@ -266,12 +271,19 @@ private func buildStructReadCompatibleDataDecl( ) { sortedIndex, field, valueExpr in "case \(sortedIndex): __\(field.name) = \(valueExpr)" } + let changedFallbackDecl = buildStructChangedFallbackDecl( + defaults: compatibleDefaults, + cases: compatibleCases, + ctorArgs: ctorArgs + ) let bufferBinding = - (schemaReadBody.contains("__buffer") || compatibleAlignedReadBody.contains("__buffer") - || compatibleCases.contains("__buffer")) ? "let __buffer = context.buffer\n " : "" + (schemaReadBody.contains("__buffer") || compatibleAlignedReadBody.contains("__buffer")) + ? "let __buffer = context.buffer\n " : "" return """ - @inline(__always) + \(changedFallbackDecl) + + @inline(never) \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") @@ -290,28 +302,53 @@ private func buildStructReadCompatibleDataDecl( \(ctorArgs) ) } - \(compatibleDefaults) - for remoteField in typeMeta.fields { - switch Int(remoteField.fieldID ?? -1) { - \(compatibleCases) - default: - try context.skipFieldValue(remoteField.fieldType) - } - } - return Self( - \(ctorArgs) + return try Self.__foryReadChangedData( + context, + typeMeta: typeMeta ) } """ } +private func buildStructChangedFallbackDecl( + defaults: String, + cases: String, + ctorArgs: String +) -> String { + let bufferBinding = cases.contains("__buffer") ? "let __buffer = context.buffer\n " : "" + let localFieldsBinding = + cases.contains("__foryLocalFields") + ? "let __foryLocalFields = Self.foryFieldsInfo(trackRef: context.trackRef)\n " : "" + return """ + @inline(never) + private static func __foryReadChangedData( + _ context: ReadContext, + typeMeta: TypeMeta + ) throws -> Self { + \(bufferBinding) + \(defaults) + \(localFieldsBinding)for remoteField in typeMeta.fields { + switch Int(remoteField.fieldID ?? -1) { + \(cases) + case -1: + try context.skipFieldValue(remoteField.fieldType) + default: + throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") + } + } + return Self( + \(ctorArgs) + ) + } + """ +} + private func buildClassAssignBody( sortedFields: [ParsedField], primitiveFastFields: [ParsedField], compatibleAligned: Bool ) -> String { - let remainingAssignLines = sortedFields.dropFirst(primitiveFastFields.count).map { - field -> String in + let remainingAssignLines = sortedFields.dropFirst(primitiveFastFields.count).map { field -> String in let valueExpr: String if compatibleAligned { valueExpr = compatibleSchemaReadFieldExpr(field) @@ -343,8 +380,7 @@ private func buildStructReadBody( primitiveFastFields: [ParsedField], compatibleAligned: Bool ) -> String { - let remainingReadLines = sortedFields.dropFirst(primitiveFastFields.count).map { - field -> String in + let remainingReadLines = sortedFields.dropFirst(primitiveFastFields.count).map { field -> String in let valueExpr = compatibleAligned ? compatibleSchemaReadFieldExpr(field) : schemaReadFieldExpr(field) return "let __\(field.name) = \(valueExpr)" @@ -395,64 +431,107 @@ private func buildCompatibleReadCases( assignCase: (Int, ParsedField, String) -> String ) -> String { sortedFields.enumerated().map { sortedIndex, field -> String in - let directValueExpr = readFieldExpr( + let directValueExpr = compatibleSchemaReadFieldExpr(field) + let compatibleValueExpr = readFieldExpr( field, refModeExpr: "RefMode.from(nullable: remoteField.fieldType.nullable, trackRef: remoteField.fieldType.trackRef)", readTypeInfoExpr: "TypeId.needsTypeInfoForField(TypeId(rawValue: remoteField.fieldType.typeID) ?? .unknown)" ) - let valueExpr = compatibleScalarReadExpr(field, directValueExpr: directValueExpr) - return assignCase(sortedIndex, field, valueExpr) + let compatibleCaseExpr = compatibleScalarReadExpr( + field, + sortedIndex: sortedIndex, + compatibleValueExpr: compatibleValueExpr + ) + return [ + assignCase(sortedIndex * 2, field, directValueExpr), + assignCase(sortedIndex * 2 + 1, field, compatibleCaseExpr) + ].joined(separator: "\n\(indent)") }.joined(separator: "\n\(indent)") } -private func compatibleScalarReadExpr(_ field: ParsedField, directValueExpr: String) -> String { - guard field.dynamicAnyCodec == nil, compatibleScalarTypeID(field.typeID) else { - return directValueExpr - } - let fieldName = swiftStringLiteral(field.schemaIdentifier) - let localRefModeExpr = fieldRefModeExpression(field) - if field.isOptional { - return """ - try { - let __localRefMode = \(localRefModeExpr) - if remoteField.fieldType.typeID == \(field.typeID) && - RefMode.from(nullable: remoteField.fieldType.nullable, trackRef: remoteField.fieldType.trackRef) == __localRefMode { - return \(directValueExpr) - } - return try foryReadCompatibleOptionalScalarField( - context, - remoteFieldType: remoteField.fieldType, - localTypeID: \(field.typeID), - fieldName: \(fieldName), - directRead: { - \(directValueExpr) - } - ) - }() - """ +private func compatibleScalarReadExpr( + _ field: ParsedField, + sortedIndex: Int, + compatibleValueExpr: String +) -> String { + guard + field.dynamicAnyCodec == nil, + let helperTarget = compatibleScalarReaderTarget(field) + else { + return compatibleValueExpr } + let helperName = + field.isOptional + ? "foryReadCompatibleOptional\(helperTarget)Field" + : "foryReadCompatible\(helperTarget)Field" return """ - try { - let __localRefMode = \(localRefModeExpr) - if remoteField.fieldType.typeID == \(field.typeID) && - RefMode.from(nullable: remoteField.fieldType.nullable, trackRef: remoteField.fieldType.trackRef) == __localRefMode { - return \(directValueExpr) - } - return try foryReadCompatibleScalarField( - context, - remoteFieldType: remoteField.fieldType, - localTypeID: \(field.typeID), - fieldName: \(fieldName), - directRead: { - \(directValueExpr) - } - ) - }() + try \(helperName)( + context, + remoteField: remoteField, + localField: __foryLocalFields[\(sortedIndex)] + ) """ } +private func compatibleScalarReaderTarget(_ field: ParsedField) -> String? { + guard compatibleScalarTypeID(field.typeID) else { + return nil + } + switch compatibleScalarPayloadType(field.typeText) { + case "Bool": + return "Bool" + case "Int8": + return "Int8" + case "Int16": + return "Int16" + case "Int32": + return "Int32" + case "Int64": + return "Int64" + case "Int": + return "Int" + case "UInt8": + return "UInt8" + case "UInt16": + return "UInt16" + case "UInt32": + return "UInt32" + case "UInt64": + return "UInt64" + case "UInt": + return "UInt" + case "Float16": + return "Float16" + case "BFloat16": + return "BFloat16" + case "Float": + return "Float" + case "Double": + return "Double" + case "String": + return "String" + case "Decimal": + return "Decimal" + default: + return nil + } +} + +private func compatibleScalarPayloadType(_ typeText: String) -> String { + var type = trimType(typeText) + if type.hasSuffix("?") { + type.removeLast() + } else if type.hasPrefix("Optional<"), type.hasSuffix(">") { + type = String(type.dropFirst("Optional<".count).dropLast()) + } + for prefix in ["Swift.", "Foundation.", "Fory."] where type.hasPrefix(prefix) { + return String(type.dropFirst(prefix.count)) + } + return type +} + private func compatibleScalarTypeID(_ typeID: UInt32) -> Bool { switch typeID { case 1...15, 17...21, 40: diff --git a/swift/Tests/ForyTests/CompatibilityTests.swift b/swift/Tests/ForyTests/CompatibilityTests.swift index 20f28223e7..6494d3db3d 100644 --- a/swift/Tests/ForyTests/CompatibilityTests.swift +++ b/swift/Tests/ForyTests/CompatibilityTests.swift @@ -88,6 +88,39 @@ private struct LocalVarUInt32V2: Equatable { } +@ForyStruct +private struct ExactEncodedV1: Equatable { + @ForyField(id: 1, encoding: .fixed) + var fixed32: Int32 = 0 + + @ForyField(id: 2, encoding: .fixed) + var fixed64: Int64 = 0 + + @ForyField(id: 3, encoding: .tagged) + var tagged64: Int64 = 0 + + @ForyField(id: 4) + var tail: String = "" +} + +@ForyStruct +private struct ExactEncodedV2: Equatable { + @ForyField(id: 0) + var added: Int32 = 0 + + @ForyField(id: 3, encoding: .tagged) + var tagged64: Int64 = 0 + + @ForyField(id: 1, encoding: .fixed) + var fixed32: Int32 = 0 + + @ForyField(id: 4) + var tail: String = "" + + @ForyField(id: 2, encoding: .fixed) + var fixed64: Int64 = 0 +} + @ForyStruct private struct ScalarBoolBox: Equatable { var value: Bool = false @@ -210,6 +243,18 @@ private struct CompatibleNullableListFieldV1: Equatable { var extra: Int32 = 0 } +@ForyStruct +private struct CompatibleNestedNullableListV1: Equatable { + @ListField(element: .list(element: .int32(nullable: true))) + var values: [[Int32?]] = [] +} + +@ForyStruct +private struct CompatibleNestedRequiredListV2: Equatable { + @ListField(element: .list(element: .int32())) + var values: [[Int32]] = [] +} + @ForyStruct private struct CompatibleNestedListArrayFieldV1: Equatable { @ListField(element: .list(element: .int32(encoding: .fixed))) @@ -522,7 +567,7 @@ func sameTypeNullableScalarUsesStrictSourceRead() throws { } @Test -func scalarTrackRefMismatchIsUnassigned() throws { +func scalarTrackRefMismatchIsRejected() throws { let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") let local = try TypeMeta( typeID: TypeId.compatibleStruct.rawValue, @@ -549,8 +594,9 @@ func scalarTrackRefMismatchIsUnassigned() throws { fieldType: TypeMeta.FieldType(typeID: TypeId.bool.rawValue, nullable: false, trackRef: true) ) ]) - let resolvedRemote = try remoteTracking.assigningFieldIDs(from: local) - #expect(resolvedRemote.fields[0].fieldID == -1) + try expectInvalidData { + _ = try remoteTracking.assigningFieldIDs(from: local) + } let localTracking = try TypeMeta( typeID: TypeId.compatibleStruct.rawValue, @@ -577,8 +623,9 @@ func scalarTrackRefMismatchIsUnassigned() throws { fieldName: "$tag1", fieldType: TypeMeta.FieldType(typeID: TypeId.bool.rawValue, nullable: false)) ]) - let resolvedLocalTracking = try remote.assigningFieldIDs(from: localTracking) - #expect(resolvedLocalTracking.fields[0].fieldID == -1) + try expectInvalidData { + _ = try remote.assigningFieldIDs(from: localTracking) + } let resolvedBothTracking = try remoteTracking.assigningFieldIDs(from: localTracking) #expect(resolvedBothTracking.fields[0].fieldID == 0) @@ -612,12 +659,283 @@ func scalarTrackRefMismatchIsUnassigned() throws { let resolvedBothNullableTracking = try remoteNullableTracking.assigningFieldIDs( from: localNullableTracking) #expect(resolvedBothNullableTracking.fields[0].fieldID == 0) - let resolvedRemoteNullableTracking = try remoteNullableTracking.assigningFieldIDs( - from: localTracking) - #expect(resolvedRemoteNullableTracking.fields[0].fieldID == -1) - let resolvedLocalNullableTracking = try remoteTracking.assigningFieldIDs( - from: localNullableTracking) - #expect(resolvedLocalNullableTracking.fields[0].fieldID == -1) + try expectInvalidData { + _ = try remoteNullableTracking.assigningFieldIDs(from: localTracking) + } + try expectInvalidData { + _ = try remoteTracking.assigningFieldIDs(from: localNullableTracking) + } +} + +@Test +func namedRemoteOnlyFieldIsSkipped() throws { + let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") + let stringType = TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false) + let intType = TypeMeta.FieldType(typeID: TypeId.int64.rawValue, nullable: false) + let local = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "username", fieldType: stringType), + TypeMeta.FieldInfo(fieldID: nil, fieldName: "id", fieldType: intType) + ]) + let remote = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "email", fieldType: stringType), + TypeMeta.FieldInfo(fieldID: nil, fieldName: "id", fieldType: intType) + ]) + + let resolved = try remote.assigningFieldIDs(from: local) + #expect(resolved.fields[0].fieldID == -1) + #expect(resolved.fields[1].fieldID == 2) +} + +@Test +func remoteOnlyFieldsSkipWhenLocalSchemaIsEmpty() throws { + let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") + let stringType = TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false) + let intType = TypeMeta.FieldType(typeID: TypeId.int64.rawValue, nullable: false) + let local = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: []) + let remote = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: 1, fieldName: "$tag1", fieldType: stringType), + TypeMeta.FieldInfo(fieldID: 2, fieldName: "$tag2", fieldType: intType) + ]) + + let resolved = try remote.assigningFieldIDs(from: local) + #expect(resolved.fields[0].fieldID == -1) + #expect(resolved.fields[1].fieldID == -1) +} + +@Test +func nameRemoteFieldDoesNotMatchTaggedLocalField() throws { + let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") + let stringType = TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false) + let local = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: 1, fieldName: "value", fieldType: stringType) + ]) + let remote = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "value", fieldType: stringType) + ]) + + let resolved = try remote.assigningFieldIDs(from: local) + #expect(resolved.fields[0].fieldID == -1) +} + +@Test +func duplicateRemoteNameBindingFails() throws { + let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") + let stringType = TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false) + let local = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "value", fieldType: stringType) + ]) + let remote = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "value", fieldType: stringType), + TypeMeta.FieldInfo(fieldID: nil, fieldName: "value", fieldType: stringType) + ]) + + #expect(throws: ForyError.invalidData("compatible field value duplicates local field value")) { + _ = try remote.assigningFieldIDs(from: local) + } +} + +@Test +func matchedFieldIdOverflowFails() throws { + let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") + let fieldType = TypeMeta.FieldType(typeID: TypeId.bool.rawValue, nullable: false) + let overflowIndex = Int(Int16.max) / 2 + 1 + let localFields = (0...overflowIndex).map { + TypeMeta.FieldInfo(fieldID: nil, fieldName: "f\($0)", fieldType: fieldType) + } + let local = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: localFields) + let remote = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: "f\(overflowIndex)", + fieldType: fieldType) + ]) + + try expectInvalidData { + _ = try remote.assigningFieldIDs(from: local) + } +} + +@Test +func matchedByteFamilyClassification() throws { + let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") + let binaryType = TypeMeta.FieldType(typeID: TypeId.binary.rawValue, nullable: false) + let uint8ArrayType = TypeMeta.FieldType(typeID: TypeId.uint8Array.rawValue, nullable: false) + let int8ArrayType = TypeMeta.FieldType(typeID: TypeId.int8Array.rawValue, nullable: false) + let local = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "payload", fieldType: binaryType) + ]) + let remoteUInt8Array = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "payload", fieldType: uint8ArrayType) + ]) + let remoteInt8Array = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "payload", fieldType: int8ArrayType) + ]) + + let resolved = try remoteUInt8Array.assigningFieldIDs(from: local) + #expect(resolved.fields[0].fieldID == 1) + try expectInvalidData { + _ = try remoteInt8Array.assigningFieldIDs(from: local) + } +} + +@Test +func matchedNestedScalarShapeAcceptsNullableDrift() throws { + let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") + let local = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: "values", + fieldType: TypeMeta.FieldType( + typeID: TypeId.list.rawValue, + nullable: false, + generics: [TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false)])) + ]) + let remote = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: "values", + fieldType: TypeMeta.FieldType( + typeID: TypeId.list.rawValue, + nullable: false, + generics: [TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: true)])) + ]) + + let resolved = try remote.assigningFieldIDs(from: local) + #expect(resolved.fields[0].fieldID == 1) +} + +@Test +func matchedNestedScalarShapeRejectsRefDrift() throws { + let empty = MetaString.empty(specialChar1: "_", specialChar2: "_") + let local = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: "values", + fieldType: TypeMeta.FieldType( + typeID: TypeId.list.rawValue, + nullable: false, + generics: [TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false)])) + ]) + let remote = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 1, + namespace: empty, + typeName: empty, + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: "values", + fieldType: TypeMeta.FieldType( + typeID: TypeId.list.rawValue, + nullable: false, + generics: [ + TypeMeta.FieldType( + typeID: TypeId.string.rawValue, + nullable: false, + trackRef: true) + ])) + ]) + + try expectInvalidData { + _ = try remote.assigningFieldIDs(from: local) + } } @Test @@ -705,6 +1023,28 @@ func sameSchemaScalarPreserved() throws { #expect(decoded == value) } +@Test +func encodedExactReadsUseCodecs() throws { + let value = ExactEncodedV1( + fixed32: 0x0102_0304, + fixed64: 0x0102_0304_0506_0708, + tagged64: Int64(Int32.max) + 0x1020, + tail: "aligned" + ) + + let sameSchema = Fory(config: .init(trackRef: false, compatible: true)) + sameSchema.register(ExactEncodedV1.self, id: 9960) + let sameDecoded: ExactEncodedV1 = try sameSchema.deserialize(try sameSchema.serialize(value)) + #expect(sameDecoded == value) + + let changedDecoded: ExactEncodedV2 = try compatibleDecode(value, as: ExactEncodedV2.self, id: 9961) + #expect(changedDecoded.added == 0) + #expect(changedDecoded.fixed32 == value.fixed32) + #expect(changedDecoded.fixed64 == value.fixed64) + #expect(changedDecoded.tagged64 == value.tagged64) + #expect(changedDecoded.tail == value.tail) +} + @Test func scalarConversionDoesNotDriveFallbackMatch() throws { let decoded: LocalFallbackStringBox = try compatibleDecode( @@ -746,7 +1086,7 @@ func compatibleModePreservesSharedAndCircularReferencesForMacroObjects() throws items: [shared, shared], byName: [ "left": shared, - "right": shared, + "right": shared ] ) @@ -806,7 +1146,7 @@ func compatibleNestedArrayEvolves() throws { let sourceV1 = CompatibleNestedArrayV1( items: [ CompatibleNestedProfileV1(id: 1, name: "alpha"), - CompatibleNestedProfileV1(id: 2, name: "beta"), + CompatibleNestedProfileV1(id: 2, name: "beta") ] ) let decodedAsV2: CompatibleNestedArrayV2 = try readerV2.deserialize( @@ -827,7 +1167,7 @@ func compatibleNestedArrayEvolves() throws { let sourceV2 = CompatibleNestedArrayV2( items: [ CompatibleNestedProfileV2(id: 3, name: "gamma", alias: "g", scores: [3, 4]), - CompatibleNestedProfileV2(id: 4, name: "delta", alias: "d", scores: []), + CompatibleNestedProfileV2(id: 4, name: "delta", alias: "d", scores: []) ] ) let decodedAsV1: CompatibleNestedArrayV1 = try readerV1.deserialize( @@ -835,7 +1175,7 @@ func compatibleNestedArrayEvolves() throws { #expect( decodedAsV1.items == [ CompatibleNestedProfileV1(id: 3, name: "gamma"), - CompatibleNestedProfileV1(id: 4, name: "delta"), + CompatibleNestedProfileV1(id: 4, name: "delta") ]) } @@ -854,7 +1194,7 @@ func compatibleReadConvertsFixedUInt32() throws { } @Test -func compatibleSkipUsesRemoteMetadataForNestedMapListSetFields() throws { +func compatibleRejectsNestedMapListMismatch() throws { let writer = Fory(config: .init(trackRef: false, compatible: true)) writer.register(RemoteNestedFixedMapV1.self, id: 9921) @@ -864,15 +1204,17 @@ func compatibleSkipUsesRemoteMetadataForNestedMapListSetFields() throws { let source = RemoteNestedFixedMapV1( data: [ "a": [1, nil, Int32.max], - "b": [], + "b": [] ], keep: 84, ids: [nil, -1, Int32.max] ) - let decoded: LocalNestedVarintMapV2 = try reader.deserialize(try writer.serialize(source)) - #expect(decoded.data.isEmpty) - #expect(decoded.keep == source.keep) - #expect(decoded.ids.isEmpty) + let bytes = try writer.serialize(source) + #expect( + throws: ForyError.invalidData("compatible field $tag1 cannot be read as local field data") + ) { + let _: LocalNestedVarintMapV2 = try reader.deserialize(bytes) + } } @Test @@ -918,7 +1260,7 @@ func compatibleReadAdaptsArrayFieldToDefaultVarintListField() throws { } @Test -func compatibleReadRejectsNullableListElementsForArrayField() throws { +func compatibleReadAcceptsNullableListSchemaForArrayField() throws { let writer = Fory(config: .init(trackRef: false, compatible: true)) writer.register(CompatibleNullableListFieldV1.self, id: 9923) @@ -929,28 +1271,45 @@ func compatibleReadRejectsNullableListElementsForArrayField() throws { let decoded: CompatibleArrayFieldV2 = try reader.deserialize(bytes) #expect(decoded.values == [1, 2, 3]) - let nullableBytes = try writer.serialize( - CompatibleNullableListFieldV1(values: [1, nil, 3], extra: 9)) + let nullBytes = try writer.serialize(CompatibleNullableListFieldV1(values: [1, nil, 3], extra: 9)) #expect( throws: ForyError.invalidData("compatible list-to-array field cannot read nullable elements") ) { - let _: CompatibleArrayFieldV2 = try reader.deserialize(nullableBytes) + let _: CompatibleArrayFieldV2 = try reader.deserialize(nullBytes) } } @Test -func compatibleReadSkipsNestedListArrayFieldPair() throws { +func compatibleRejectsNestedListArrayPair() throws { let writer = Fory(config: .init(trackRef: false, compatible: true)) writer.register(CompatibleNestedListArrayFieldV1.self, id: 9926) let reader = Fory(config: .init(trackRef: false, compatible: true)) reader.register(CompatibleNestedArrayListFieldV2.self, id: 9926) - let decoded: CompatibleNestedArrayListFieldV2 = try reader.deserialize( - try writer.serialize(CompatibleNestedListArrayFieldV1(values: [[1, 2]], keep: 7)) - ) - #expect(decoded.values.isEmpty) - #expect(decoded.keep == 7) + let bytes = try writer.serialize(CompatibleNestedListArrayFieldV1(values: [[1, 2]], keep: 7)) + #expect( + throws: ForyError.invalidData("compatible field values cannot be read as local field values") + ) { + let _: CompatibleNestedArrayListFieldV2 = try reader.deserialize(bytes) + } +} + +@Test +func compatibleReadsNestedNullableScalarWithoutNulls() throws { + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(CompatibleNestedNullableListV1.self, id: 9927) + + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(CompatibleNestedRequiredListV2.self, id: 9927) + + let bytes = try writer.serialize(CompatibleNestedNullableListV1(values: [[1, 2]])) + let decoded: CompatibleNestedRequiredListV2 = try reader.deserialize(bytes) + #expect(decoded.values == [[1, 2]]) + + let nullBytes = try writer.serialize(CompatibleNestedNullableListV1(values: [[1, nil, 2]])) + let decodedWithNull: CompatibleNestedRequiredListV2 = try reader.deserialize(nullBytes) + #expect(decodedWithNull.values == [[1, 0, 2]]) } @Test @@ -966,7 +1325,7 @@ func compatibleNestedMapEvolves() throws { let sourceV1 = CompatibleNestedMapV1( items: [ 1: CompatibleNestedProfileV1(id: 10, name: "first"), - 2: CompatibleNestedProfileV1(id: 20, name: "second"), + 2: CompatibleNestedProfileV1(id: 20, name: "second") ] ) let decodedAsV2: CompatibleNestedMapV2 = try readerV2.deserialize( @@ -994,14 +1353,14 @@ func compatibleNestedReadsReuseTypeMeta() throws { let first = CompatibleNestedArrayV1( items: [ CompatibleNestedProfileV1(id: 1, name: "alpha"), - CompatibleNestedProfileV1(id: 2, name: "beta"), + CompatibleNestedProfileV1(id: 2, name: "beta") ] ) let second = CompatibleNestedArrayV1( items: [ CompatibleNestedProfileV1(id: 3, name: "gamma"), CompatibleNestedProfileV1(id: 4, name: "delta"), - CompatibleNestedProfileV1(id: 5, name: "epsilon"), + CompatibleNestedProfileV1(id: 5, name: "epsilon") ] ) diff --git a/swift/Tests/ForyXlangTests/main.swift b/swift/Tests/ForyXlangTests/main.swift index 869298c03b..2756c341b1 100644 --- a/swift/Tests/ForyXlangTests/main.swift +++ b/swift/Tests/ForyXlangTests/main.swift @@ -42,12 +42,15 @@ private struct Item { @ForyStruct private struct SimpleStruct { - var f1: [Int32: Double] = [:] + // Java xlang uses boxed collection element types for this case, so the + // schema marks map keys/values and list elements nullable even when the + // sample payload contains only non-null values. + var f1: [Int32?: Double?] = [:] var f2: Int32 = 0 var f3: Item = .foryDefault() var f4: String = "" var f5: PeerColor = .green - var f6: [String] = [] + var f6: [String?] = [] var f7: Int32 = 0 var f8: Int32 = 0 var last: Int32 = 0