diff --git a/BUILD b/BUILD index 809727199e..51bdffc8d6 100644 --- a/BUILD +++ b/BUILD @@ -31,6 +31,7 @@ pyx_library( linkstatic = 1, ), deps = [ + "//cpp/fory/python:_pyfory", "//cpp/fory/util:fory_util", ], ) diff --git a/cpp/fory/python/pyfory.cc b/cpp/fory/python/pyfory.cc index 04117a7b54..c240639d4f 100644 --- a/cpp/fory/python/pyfory.cc +++ b/cpp/fory/python/pyfory.cc @@ -19,6 +19,12 @@ #include "fory/python/pyfory.h" +#include +#include +#include +#include +#include + static PyObject **py_sequence_get_items(PyObject *collection) { if (PyList_CheckExact(collection)) { return ((PyListObject *)collection)->ob_item; @@ -29,6 +35,118 @@ static PyObject **py_sequence_get_items(PyObject *collection) { } namespace fory { + +static std::string fetch_python_error_message() { + PyObject *type = nullptr; + PyObject *value = nullptr; + PyObject *traceback = nullptr; + PyErr_Fetch(&type, &value, &traceback); + PyErr_NormalizeException(&type, &value, &traceback); + std::string message = "python stream read failed"; + if (value != nullptr) { + PyObject *value_str = PyObject_Str(value); + if (value_str != nullptr) { + const char *c_str = PyUnicode_AsUTF8(value_str); + if (c_str != nullptr) { + message = c_str; + } + Py_DECREF(value_str); + } else { + PyErr_Clear(); + } + } + Py_XDECREF(type); + Py_XDECREF(value); + Py_XDECREF(traceback); + return message; +} + +class PythonInputStreamBuf final : public std::streambuf { +public: + explicit PythonInputStreamBuf(PyObject *stream) : stream_(stream) { + Py_INCREF(stream_); + } + + ~PythonInputStreamBuf() override { + if (stream_ != nullptr) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + Py_DECREF(stream_); + PyGILState_Release(gil_state); + stream_ = nullptr; + } + } + + bool has_error() const { return has_error_; } + + const std::string &error_message() const { return error_message_; } + +protected: + std::streamsize xsgetn(char *s, std::streamsize count) override { + if (count <= 0 || has_error_) { + return 0; + } + 
PyGILState_STATE gil_state = PyGILState_Ensure(); + const Py_ssize_t requested = static_cast( + std::min(count, PY_SSIZE_T_MAX)); + PyObject *chunk = PyObject_CallMethod(stream_, "read", "n", requested); + if (chunk == nullptr) { + has_error_ = true; + error_message_ = fetch_python_error_message(); + PyGILState_Release(gil_state); + return 0; + } + Py_buffer view; + if (PyObject_GetBuffer(chunk, &view, PyBUF_CONTIG_RO) != 0) { + has_error_ = true; + error_message_ = fetch_python_error_message(); + Py_DECREF(chunk); + PyGILState_Release(gil_state); + return 0; + } + const std::streamsize read_bytes = std::min( + count, static_cast(view.len)); + if (read_bytes > 0) { + std::memcpy(s, view.buf, static_cast(read_bytes)); + } + PyBuffer_Release(&view); + Py_DECREF(chunk); + PyGILState_Release(gil_state); + return read_bytes; + } + + int_type underflow() override { + if (gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + if (xsgetn(¤t_, 1) != 1) { + return traits_type::eof(); + } + setg(¤t_, ¤t_, ¤t_ + 1); + return traits_type::to_int_type(current_); + } + +private: + PyObject *stream_ = nullptr; + bool has_error_ = false; + std::string error_message_; + char current_ = 0; +}; + +class PythonInputStream final : public std::istream { +public: + explicit PythonInputStream(PyObject *stream) + : std::istream(nullptr), buf_(stream) { + rdbuf(&buf_); + } + + bool has_error() const { return buf_.has_error(); } + + const std::string &error_message() const { return buf_.error_message(); } + +private: + PythonInputStreamBuf buf_; +}; + int Fory_PyBooleanSequenceWriteToBuffer(PyObject *collection, Buffer *buffer, Py_ssize_t start_index) { PyObject **items = py_sequence_get_items(collection); @@ -58,4 +176,32 @@ int Fory_PyFloatSequenceWriteToBuffer(PyObject *collection, Buffer *buffer, } return 0; } + +int Fory_PyCreateBufferFromStream(PyObject *stream, uint32_t buffer_size, + Buffer **out, std::string *error_message) { + if (stream == nullptr) { + *error_message 
= "stream must not be null"; + return -1; + } + const int has_read = PyObject_HasAttrString(stream, "read"); + if (has_read < 0) { + *error_message = fetch_python_error_message(); + return -1; + } + if (has_read == 0) { + *error_message = "stream object must provide read(size) method"; + return -1; + } + try { + auto py_stream = std::make_shared(stream); + auto source_stream = std::static_pointer_cast(py_stream); + auto fory_stream = std::make_shared( + std::move(source_stream), buffer_size); + *out = new Buffer(std::move(fory_stream)); + return 0; + } catch (const std::exception &e) { + *error_message = e.what(); + return -1; + } +} } // namespace fory diff --git a/cpp/fory/python/pyfory.h b/cpp/fory/python/pyfory.h index 467a86970f..780ef35343 100644 --- a/cpp/fory/python/pyfory.h +++ b/cpp/fory/python/pyfory.h @@ -18,6 +18,8 @@ */ #pragma once +#include + #include "Python.h" #include "fory/util/buffer.h" @@ -26,4 +28,6 @@ int Fory_PyBooleanSequenceWriteToBuffer(PyObject *collection, Buffer *buffer, Py_ssize_t start_index); int Fory_PyFloatSequenceWriteToBuffer(PyObject *collection, Buffer *buffer, Py_ssize_t start_index); -} // namespace fory \ No newline at end of file +int Fory_PyCreateBufferFromStream(PyObject *stream, uint32_t buffer_size, + Buffer **out, std::string *error_message); +} // namespace fory diff --git a/cpp/fory/serialization/BUILD b/cpp/fory/serialization/BUILD index 5f737ae5f2..c3dfe5bd58 100644 --- a/cpp/fory/serialization/BUILD +++ b/cpp/fory/serialization/BUILD @@ -174,6 +174,16 @@ cc_test( ], ) +cc_test( + name = "stream_test", + srcs = ["stream_test.cc"], + deps = [ + ":fory_serialization", + "@googletest//:gtest", + "@googletest//:gtest_main", + ], +) + cc_binary( name = "xlang_test_main", srcs = ["xlang_test_main.cc"], diff --git a/cpp/fory/serialization/CMakeLists.txt b/cpp/fory/serialization/CMakeLists.txt index d4852742e0..0582554c79 100644 --- a/cpp/fory/serialization/CMakeLists.txt +++ b/cpp/fory/serialization/CMakeLists.txt @@ 
-113,6 +113,10 @@ if(FORY_BUILD_TESTS) add_executable(fory_serialization_any_test any_serializer_test.cc) target_link_libraries(fory_serialization_any_test fory_serialization GTest::gtest GTest::gtest_main) gtest_discover_tests(fory_serialization_any_test) + + add_executable(fory_serialization_stream_test stream_test.cc) + target_link_libraries(fory_serialization_stream_test fory_serialization GTest::gtest GTest::gtest_main) + gtest_discover_tests(fory_serialization_stream_test) endif() # xlang test binary diff --git a/cpp/fory/serialization/serializer.h b/cpp/fory/serialization/serializer.h index 5eeae493ec..4d7ac37029 100644 --- a/cpp/fory/serialization/serializer.h +++ b/cpp/fory/serialization/serializer.h @@ -29,6 +29,7 @@ #include "fory/util/error.h" #include "fory/util/result.h" #include +#include #include namespace fory { @@ -83,10 +84,15 @@ struct HeaderInfo { /// @param buffer Input buffer /// @return Header information or error inline Result read_header(Buffer &buffer) { - // Check minimum header size (1 byte: flags) - if (buffer.reader_index() + 1 > buffer.size()) { - return Unexpected( - Error::buffer_out_of_bound(buffer.reader_index(), 1, buffer.size())); + // Ensure minimum header size (1 byte: flags), including stream-backed + // buffers. + Error error; + const uint64_t target = static_cast(buffer.reader_index()) + 1; + if (target > std::numeric_limits::max()) { + return Unexpected(Error::out_of_bound("header reader index overflow")); + } + if (!buffer.ensure_size(static_cast(target), error)) { + return Unexpected(std::move(error)); } HeaderInfo info; diff --git a/cpp/fory/serialization/stream_test.cc b/cpp/fory/serialization/stream_test.cc new file mode 100644 index 0000000000..7cbe7fb459 --- /dev/null +++ b/cpp/fory/serialization/stream_test.cc @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "fory/serialization/fory.h" +#include "fory/util/stream.h" + +namespace fory { +namespace serialization { +namespace test { + +struct StreamPoint { + int32_t x; + int32_t y; + + bool operator==(const StreamPoint &other) const { + return x == other.x && y == other.y; + } + + FORY_STRUCT(StreamPoint, x, y); +}; + +struct StreamEnvelope { + std::string name; + std::vector values; + std::map metrics; + StreamPoint point; + bool active; + + bool operator==(const StreamEnvelope &other) const { + return name == other.name && values == other.values && + metrics == other.metrics && point == other.point && + active == other.active; + } + + FORY_STRUCT(StreamEnvelope, name, values, metrics, point, active); +}; + +struct SharedIntPair { + std::shared_ptr first; + std::shared_ptr second; + + FORY_STRUCT(SharedIntPair, first, second); +}; + +class OneByteStreamBuf final : public std::streambuf { +public: + explicit OneByteStreamBuf(std::vector data) + : data_(std::move(data)), pos_(0) {} + +protected: + std::streamsize xsgetn(char *s, std::streamsize count) override { + if (pos_ >= data_.size() || count <= 0) { + return 0; + } + s[0] = static_cast(data_[pos_]); + 
++pos_; + return 1; + } + + int_type underflow() override { + if (pos_ >= data_.size()) { + return traits_type::eof(); + } + current_ = static_cast(data_[pos_]); + setg(¤t_, ¤t_, ¤t_ + 1); + return traits_type::to_int_type(current_); + } + +private: + std::vector data_; + size_t pos_; + char current_ = 0; +}; + +class OneByteIStream final : public std::istream { +public: + explicit OneByteIStream(std::vector data) + : std::istream(nullptr), buf_(std::move(data)) { + rdbuf(&buf_); + } + +private: + OneByteStreamBuf buf_; +}; + +static inline void register_stream_types(Fory &fory) { + uint32_t type_id = 1; + fory.register_struct(type_id++); + fory.register_struct(type_id++); + fory.register_struct(type_id++); +} + +TEST(StreamSerializationTest, PrimitiveAndStringRoundTrip) { + auto fory = Fory::builder().xlang(true).track_ref(false).build(); + + auto number_bytes_result = fory.serialize(-9876543212345LL); + ASSERT_TRUE(number_bytes_result.ok()) + << number_bytes_result.error().to_string(); + OneByteIStream number_source(std::move(number_bytes_result).value()); + ForyInputStream number_stream(number_source, 8); + Buffer number_buffer(number_stream); + auto number_result = fory.deserialize(number_buffer); + ASSERT_TRUE(number_result.ok()) << number_result.error().to_string(); + EXPECT_EQ(number_result.value(), -9876543212345LL); + + auto string_bytes_result = fory.serialize("stream-hello-世界"); + ASSERT_TRUE(string_bytes_result.ok()) + << string_bytes_result.error().to_string(); + OneByteIStream string_source(std::move(string_bytes_result).value()); + ForyInputStream string_stream(string_source, 8); + Buffer string_buffer(string_stream); + auto string_result = fory.deserialize(string_buffer); + ASSERT_TRUE(string_result.ok()) << string_result.error().to_string(); + EXPECT_EQ(string_result.value(), "stream-hello-世界"); +} + +TEST(StreamSerializationTest, StructRoundTrip) { + auto fory = Fory::builder().xlang(true).track_ref(true).build(); + register_stream_types(fory); + 
+ StreamEnvelope original{ + "payload-name", + {1, 3, 5, 7, 9}, + {{"count", 5}, {"sum", 25}, {"max", 9}}, + {42, -7}, + true, + }; + + auto bytes_result = fory.serialize(original); + ASSERT_TRUE(bytes_result.ok()) << bytes_result.error().to_string(); + + OneByteIStream source(std::move(bytes_result).value()); + ForyInputStream stream(source, 4); + Buffer buffer(stream); + auto result = fory.deserialize(buffer); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), original); +} + +TEST(StreamSerializationTest, SequentialDeserializeFromSingleStream) { + auto fory = Fory::builder().xlang(true).track_ref(true).build(); + register_stream_types(fory); + + StreamEnvelope envelope{ + "batch", {10, 20, 30}, {{"a", 1}, {"b", 2}}, {9, 8}, false, + }; + + std::vector bytes; + ASSERT_TRUE(fory.serialize_to(bytes, static_cast(12345)).ok()); + ASSERT_TRUE(fory.serialize_to(bytes, std::string("next-value")).ok()); + ASSERT_TRUE(fory.serialize_to(bytes, envelope).ok()); + + OneByteIStream source(bytes); + ForyInputStream stream(source, 3); + Buffer buffer(stream); + + auto first = fory.deserialize(buffer); + ASSERT_TRUE(first.ok()) << first.error().to_string(); + EXPECT_EQ(first.value(), 12345); + + auto second = fory.deserialize(buffer); + ASSERT_TRUE(second.ok()) << second.error().to_string(); + EXPECT_EQ(second.value(), "next-value"); + + auto third = fory.deserialize(buffer); + ASSERT_TRUE(third.ok()) << third.error().to_string(); + EXPECT_EQ(third.value(), envelope); + + EXPECT_EQ(buffer.reader_index(), bytes.size()); +} + +TEST(StreamSerializationTest, SharedPointerIdentityRoundTrip) { + auto fory = Fory::builder().xlang(true).track_ref(true).build(); + register_stream_types(fory); + + auto shared = std::make_shared(2026); + SharedIntPair pair{shared, shared}; + + auto bytes_result = fory.serialize(pair); + ASSERT_TRUE(bytes_result.ok()) << bytes_result.error().to_string(); + + OneByteIStream source(std::move(bytes_result).value()); + 
ForyInputStream stream(source, 2); + Buffer buffer(stream); + auto result = fory.deserialize(buffer); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + ASSERT_NE(result.value().first, nullptr); + ASSERT_NE(result.value().second, nullptr); + EXPECT_EQ(*result.value().first, 2026); + EXPECT_EQ(result.value().first, result.value().second); +} + +TEST(StreamSerializationTest, TruncatedStreamReturnsError) { + auto fory = Fory::builder().xlang(true).track_ref(true).build(); + register_stream_types(fory); + + StreamEnvelope original{ + "truncated", {1, 2, 3, 4}, {{"k", 99}}, {7, 7}, true, + }; + auto bytes_result = fory.serialize(original); + ASSERT_TRUE(bytes_result.ok()) << bytes_result.error().to_string(); + + std::vector truncated = std::move(bytes_result).value(); + ASSERT_GT(truncated.size(), 1u); + truncated.pop_back(); + + OneByteIStream source(truncated); + ForyInputStream stream(source, 4); + Buffer buffer(stream); + auto result = fory.deserialize(buffer); + EXPECT_FALSE(result.ok()); +} + +} // namespace test +} // namespace serialization +} // namespace fory diff --git a/cpp/fory/serialization/string_serializer.h b/cpp/fory/serialization/string_serializer.h index 139564b93a..432718ccd6 100644 --- a/cpp/fory/serialization/string_serializer.h +++ b/cpp/fory/serialization/string_serializer.h @@ -25,6 +25,7 @@ #include "fory/util/error.h" #include "fory/util/string_util.h" #include +#include #include #include #include @@ -113,9 +114,14 @@ inline std::string read_string_data(ReadContext &ctx) { return std::string(); } - // Validate length against buffer remaining size - if (length > ctx.buffer().remaining_size()) { - ctx.set_error(Error::invalid_data("String length exceeds buffer size")); + const uint64_t target = + static_cast(ctx.buffer().reader_index()) + length; + if (target > std::numeric_limits::max()) { + ctx.set_error(Error::invalid_data("String length exceeds uint32 range")); + return std::string(); + } + if 
(FORY_PREDICT_FALSE(!ctx.buffer().ensure_size( + static_cast(target), ctx.error()))) { return std::string(); } @@ -176,9 +182,14 @@ inline std::u16string read_u16string_data(ReadContext &ctx) { return std::u16string(); } - // Validate length against buffer remaining size - if (length > ctx.buffer().remaining_size()) { - ctx.set_error(Error::invalid_data("String length exceeds buffer size")); + const uint64_t target = + static_cast(ctx.buffer().reader_index()) + length; + if (target > std::numeric_limits::max()) { + ctx.set_error(Error::invalid_data("String length exceeds uint32 range")); + return std::u16string(); + } + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_size( + static_cast(target), ctx.error()))) { return std::u16string(); } diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index bb4a24bb81..9a6d9a0340 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -31,6 +31,7 @@ #include "fory/util/string_util.h" #include #include +#include #include #include #include @@ -831,6 +832,12 @@ template struct CompileTimeFieldHelpers { if constexpr (is_configurable_int_v) { return configurable_int_max_varint_bytes(); + } else if constexpr (std::is_same_v || + std::is_same_v) { + return 5; + } else if constexpr (std::is_same_v || + std::is_same_v) { + return 10; } return 0; } @@ -2753,11 +2760,15 @@ void read_struct_fields_impl(T &obj, ReadContext &ctx, // Phase 1: Read leading fixed-size primitives if any if constexpr (fixed_count > 0 && fixed_bytes > 0) { - // Pre-check bounds for all fixed-size fields at once - if (FORY_PREDICT_FALSE(buffer.reader_index() + fixed_bytes > - buffer.size())) { - ctx.set_error(Error::buffer_out_of_bound(buffer.reader_index(), - fixed_bytes, buffer.size())); + const uint64_t target = + static_cast(buffer.reader_index()) + fixed_bytes; + if (FORY_PREDICT_FALSE(target > std::numeric_limits::max())) { + ctx.set_error( + 
Error::out_of_bound("fixed-size struct read exceeds uint32 range")); + return; + } + if (FORY_PREDICT_FALSE(!buffer.ensure_size(static_cast(target), + ctx.error()))) { return; } // Fast read fixed-size primitives @@ -2769,6 +2780,12 @@ void read_struct_fields_impl(T &obj, ReadContext &ctx, // Note: varint bounds checking is done per-byte during reading since // varint lengths are variable (actual size << max possible size) if constexpr (varint_count > 0) { + if (FORY_PREDICT_FALSE(buffer.is_stream_backed())) { + // Stream-backed buffers may not have all varint bytes materialized yet. + // Fall back to per-field readers that propagate stream read errors. + read_remaining_fields(obj, ctx); + return; + } // Track offset locally for batch varint reading uint32_t offset = buffer.reader_index(); // Fast read varint primitives (bounds checking happens in @@ -2807,11 +2824,15 @@ read_struct_fields_impl_fast(T &obj, ReadContext &ctx, // Phase 1: Read leading fixed-size primitives if any if constexpr (fixed_count > 0 && fixed_bytes > 0) { - // Pre-check bounds for all fixed-size fields at once - if (FORY_PREDICT_FALSE(buffer.reader_index() + fixed_bytes > - buffer.size())) { - ctx.set_error(Error::buffer_out_of_bound(buffer.reader_index(), - fixed_bytes, buffer.size())); + const uint64_t target = + static_cast(buffer.reader_index()) + fixed_bytes; + if (FORY_PREDICT_FALSE(target > std::numeric_limits::max())) { + ctx.set_error( + Error::out_of_bound("fixed-size struct read exceeds uint32 range")); + return; + } + if (FORY_PREDICT_FALSE( + !buffer.ensure_size(static_cast(target), ctx.error()))) { return; } // Fast read fixed-size primitives @@ -2821,6 +2842,12 @@ read_struct_fields_impl_fast(T &obj, ReadContext &ctx, // Phase 2: Read consecutive varint primitives (int32, int64) if any if constexpr (varint_count > 0) { + if (FORY_PREDICT_FALSE(buffer.is_stream_backed())) { + // Stream-backed buffers may not have all varint bytes materialized yet. 
+ // Fall back to per-field readers that propagate stream read errors. + read_remaining_fields(obj, ctx); + return; + } // Track offset locally for batch varint reading uint32_t offset = buffer.reader_index(); // Fast read varint primitives (bounds checking happens in diff --git a/cpp/fory/util/CMakeLists.txt b/cpp/fory/util/CMakeLists.txt index 2bc4f30c5f..1a5d42dbe1 100644 --- a/cpp/fory/util/CMakeLists.txt +++ b/cpp/fory/util/CMakeLists.txt @@ -32,6 +32,7 @@ set(FORY_UTIL_HEADERS platform.h pool.h result.h + stream.h string_util.h time_util.h flat_int_map.h diff --git a/cpp/fory/util/buffer.cc b/cpp/fory/util/buffer.cc index 4c1e5f5939..b46f62252f 100644 --- a/cpp/fory/util/buffer.cc +++ b/cpp/fory/util/buffer.cc @@ -31,6 +31,7 @@ Buffer::Buffer() { writer_index_ = 0; reader_index_ = 0; wrapped_vector_ = nullptr; + stream_ = nullptr; } Buffer::Buffer(Buffer &&buffer) noexcept { @@ -40,10 +41,13 @@ Buffer::Buffer(Buffer &&buffer) noexcept { writer_index_ = buffer.writer_index_; reader_index_ = buffer.reader_index_; wrapped_vector_ = buffer.wrapped_vector_; + stream_ = buffer.stream_; + stream_owner_ = std::move(buffer.stream_owner_); buffer.data_ = nullptr; buffer.size_ = 0; buffer.own_data_ = false; buffer.wrapped_vector_ = nullptr; + buffer.stream_ = nullptr; } Buffer &Buffer::operator=(Buffer &&buffer) noexcept { @@ -57,10 +61,13 @@ Buffer &Buffer::operator=(Buffer &&buffer) noexcept { writer_index_ = buffer.writer_index_; reader_index_ = buffer.reader_index_; wrapped_vector_ = buffer.wrapped_vector_; + stream_ = buffer.stream_; + stream_owner_ = std::move(buffer.stream_owner_); buffer.data_ = nullptr; buffer.size_ = 0; buffer.own_data_ = false; buffer.wrapped_vector_ = nullptr; + buffer.stream_ = nullptr; return *this; } diff --git a/cpp/fory/util/buffer.h b/cpp/fory/util/buffer.h index 0631d989bf..55a6926744 100644 --- a/cpp/fory/util/buffer.h +++ b/cpp/fory/util/buffer.h @@ -30,6 +30,7 @@ #include "fory/util/error.h" #include "fory/util/logging.h" #include 
"fory/util/result.h" +#include "fory/util/stream.h" namespace fory { @@ -40,8 +41,8 @@ class Buffer { Buffer(); Buffer(uint8_t *data, uint32_t size, bool own_data = true) - : data_(data), size_(size), own_data_(own_data), - wrapped_vector_(nullptr) { + : data_(data), size_(size), own_data_(own_data), wrapped_vector_(nullptr), + stream_(nullptr) { writer_index_ = 0; reader_index_ = 0; } @@ -54,7 +55,23 @@ class Buffer { explicit Buffer(std::vector &vec) : data_(vec.data()), size_(static_cast(vec.size())), own_data_(false), writer_index_(static_cast(vec.size())), - reader_index_(0), wrapped_vector_(&vec) {} + reader_index_(0), wrapped_vector_(&vec), stream_(nullptr) {} + + explicit Buffer(ForyInputStream &stream) + : data_(stream.data()), size_(stream.size()), own_data_(false), + writer_index_(stream.size()), reader_index_(stream.reader_index()), + wrapped_vector_(nullptr), stream_(&stream) {} + + explicit Buffer(std::shared_ptr stream) + : data_(nullptr), size_(0), own_data_(false), writer_index_(0), + reader_index_(0), wrapped_vector_(nullptr), stream_(stream.get()), + stream_owner_(std::move(stream)) { + FORY_CHECK(stream_ != nullptr) << "stream must not be null"; + data_ = stream_->data(); + size_ = stream_->size(); + writer_index_ = stream_->size(); + reader_index_ = stream_->reader_index(); + } Buffer(Buffer &&buffer) noexcept; @@ -70,6 +87,10 @@ class Buffer { FORY_ALWAYS_INLINE bool own_data() const { return own_data_; } + FORY_ALWAYS_INLINE bool is_stream_backed() const { + return stream_ != nullptr; + } + FORY_ALWAYS_INLINE uint32_t writer_index() { return writer_index_; } FORY_ALWAYS_INLINE uint32_t reader_index() { return reader_index_; } @@ -79,6 +100,24 @@ class Buffer { return size_ - reader_index_; } + FORY_ALWAYS_INLINE bool ensure_size(uint32_t target_size, Error &error) { + if (FORY_PREDICT_TRUE(target_size <= size_)) { + return true; + } + if (FORY_PREDICT_TRUE(stream_ == nullptr)) { + error.set_buffer_out_of_bound(target_size, 0, size_); + return 
false; + } + if (FORY_PREDICT_FALSE(!fill_to(target_size, error))) { + return false; + } + if (FORY_PREDICT_FALSE(target_size > size_)) { + error.set_buffer_out_of_bound(target_size, 0, size_); + return false; + } + return true; + } + FORY_ALWAYS_INLINE void writer_index(uint32_t writer_index) { FORY_CHECK(writer_index < std::numeric_limits::max()) << "Buffer overflow writer_index" << writer_index_ @@ -94,17 +133,40 @@ class Buffer { } FORY_ALWAYS_INLINE void reader_index(uint32_t reader_index) { - FORY_CHECK(reader_index < std::numeric_limits::max()) - << "Buffer overflow reader_index" << reader_index_ - << " target reader_index " << reader_index; + if (FORY_PREDICT_FALSE(reader_index > size_ && stream_ != nullptr)) { + Error error; + const bool ok = fill_to(reader_index, error); + FORY_CHECK(ok) + << "failed to fill stream buffer while setting reader index: " + << error.to_string(); + } + FORY_CHECK(reader_index <= size_) + << "Buffer overflow reader_index " << reader_index_ + << " target reader_index " << reader_index << " size " << size_; reader_index_ = reader_index; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } } FORY_ALWAYS_INLINE void increase_reader_index(uint32_t diff) { - uint64_t reader_index = reader_index_ + diff; - FORY_CHECK(reader_index < std::numeric_limits::max()) - << "Buffer overflow reader_index" << reader_index_ << " diff " << diff; - reader_index_ = reader_index; + const uint64_t target = static_cast(reader_index_) + diff; + FORY_CHECK(target <= std::numeric_limits::max()) + << "Buffer overflow reader_index " << reader_index_ << " diff " << diff; + if (FORY_PREDICT_FALSE(target > size_ && stream_ != nullptr)) { + Error error; + const bool ok = fill_to(static_cast(target), error); + FORY_CHECK(ok) + << "failed to fill stream buffer while increasing reader index: " + << error.to_string(); + } + FORY_CHECK(target <= size_) + << "Buffer overflow reader_index " << reader_index_ << " diff " << diff + 
<< " size " << size_; + reader_index_ = static_cast(target); + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } } // Unsafe methods don't check bound @@ -143,6 +205,15 @@ class Buffer { } template FORY_ALWAYS_INLINE T get(uint32_t relative_offset) { + const uint64_t target = static_cast(relative_offset) + sizeof(T); + if (FORY_PREDICT_FALSE(stream_ != nullptr && target > size_)) { + FORY_CHECK(target <= std::numeric_limits::max()) + << "target offset out of range " << target; + Error error; + const bool ok = fill_to(static_cast(target), error); + FORY_CHECK(ok) << "failed to fill stream buffer for get: " + << error.to_string(); + } FORY_CHECK(relative_offset < size_) << "Out of range " << relative_offset << " should be less than " << size_; T value = reinterpret_cast(data_ + relative_offset)[0]; @@ -153,6 +224,15 @@ class Buffer { std::is_same, std::is_same, std::is_same>>> FORY_ALWAYS_INLINE T get_byte_as(uint32_t relative_offset) { + const uint64_t target = static_cast(relative_offset) + 1; + if (FORY_PREDICT_FALSE(stream_ != nullptr && target > size_)) { + FORY_CHECK(target <= std::numeric_limits::max()) + << "target offset out of range " << target; + Error error; + const bool ok = fill_to(static_cast(target), error); + FORY_CHECK(ok) << "failed to fill stream buffer for get_byte_as: " + << error.to_string(); + } FORY_CHECK(relative_offset < size_) << "Out of range " << relative_offset << " should be less than " << size_; return data_[relative_offset]; @@ -171,6 +251,15 @@ class Buffer { } FORY_ALWAYS_INLINE int32_t get_int24(uint32_t offset) { + const uint64_t target = static_cast(offset) + 3; + if (FORY_PREDICT_FALSE(stream_ != nullptr && target > size_)) { + FORY_CHECK(target <= std::numeric_limits::max()) + << "target offset out of range " << target; + Error error; + const bool ok = fill_to(static_cast(target), error); + FORY_CHECK(ok) << "failed to fill stream buffer for get_int24: " + << error.to_string(); + } 
FORY_CHECK(offset + 3 <= size_) << "Out of range " << offset << " should be less than " << size_; int32_t b0 = data_[offset]; @@ -197,6 +286,19 @@ class Buffer { FORY_ALWAYS_INLINE Result get_bytes_as_int64(uint32_t offset, uint32_t length, int64_t *target) { + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + const uint64_t target_size = static_cast(offset) + length; + if (target_size > size_) { + if (target_size > std::numeric_limits::max()) { + return Unexpected( + Error::out_of_bound("buffer offset exceeds uint32 range")); + } + Error error; + if (!fill_to(static_cast(target_size), error)) { + return Unexpected(std::move(error)); + } + } + } if (length == 0) { *target = 0; return Result(); @@ -694,140 +796,164 @@ class Buffer { /// Read uint8_t value from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE uint8_t read_uint8(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 1 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 1, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { return 0; } uint8_t value = data_[reader_index_]; reader_index_ += 1; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read int8_t value from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE int8_t read_int8(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 1 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 1, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { return 0; } int8_t value = static_cast(data_[reader_index_]); reader_index_ += 1; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read uint16_t value from buffer. Sets error on bounds violation. 
FORY_ALWAYS_INLINE uint16_t read_uint16(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 2 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 2, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(2, error))) { return 0; } uint16_t value = reinterpret_cast(data_ + reader_index_)[0]; reader_index_ += 2; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read int16_t value from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE int16_t read_int16(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 2 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 2, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(2, error))) { return 0; } int16_t value = reinterpret_cast(data_ + reader_index_)[0]; reader_index_ += 2; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read int24 value from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE int32_t read_int24(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 3 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 3, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(3, error))) { return 0; } int32_t b0 = data_[reader_index_]; int32_t b1 = data_[reader_index_ + 1]; int32_t b2 = data_[reader_index_ + 2]; reader_index_ += 3; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16); } /// Read uint32_t value from buffer (fixed 4 bytes). Sets error on bounds /// violation. 
FORY_ALWAYS_INLINE uint32_t read_uint32(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 4 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 4, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(4, error))) { return 0; } uint32_t value = reinterpret_cast(data_ + reader_index_)[0]; reader_index_ += 4; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read int32_t value from buffer (fixed 4 bytes). Sets error on bounds /// violation. FORY_ALWAYS_INLINE int32_t read_int32(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 4 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 4, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(4, error))) { return 0; } int32_t value = reinterpret_cast(data_ + reader_index_)[0]; reader_index_ += 4; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read uint64_t value from buffer (fixed 8 bytes). Sets error on bounds /// violation. FORY_ALWAYS_INLINE uint64_t read_uint64(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 8 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 8, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(8, error))) { return 0; } uint64_t value = reinterpret_cast(data_ + reader_index_)[0]; reader_index_ += 8; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read int64_t value from buffer (fixed 8 bytes). Sets error on bounds /// violation. FORY_ALWAYS_INLINE int64_t read_int64(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 8 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 8, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(8, error))) { return 0; } int64_t value = reinterpret_cast(data_ + reader_index_)[0]; reader_index_ += 8; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read float value from buffer. 
Sets error on bounds violation. FORY_ALWAYS_INLINE float read_float(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 4 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 4, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(4, error))) { return 0.0f; } float value = reinterpret_cast(data_ + reader_index_)[0]; reader_index_ += 4; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read double value from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE double read_double(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 8 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 8, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(8, error))) { return 0.0; } double value = reinterpret_cast(data_ + reader_index_)[0]; reader_index_ += 8; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } /// Read uint32_t value as varint from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE uint32_t read_var_uint32(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 1 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 1, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { return 0; } + if (FORY_PREDICT_FALSE(stream_ != nullptr && size_ - reader_index_ < 5)) { + return read_var_uint32_stream(error); + } uint32_t read_bytes = 0; uint32_t value = get_var_uint32(reader_index_, &read_bytes); increase_reader_index(read_bytes); @@ -837,10 +963,16 @@ class Buffer { /// Read int32_t value as varint (zigzag encoded). Sets error on bounds /// violation. 
FORY_ALWAYS_INLINE int32_t read_var_int32(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 1 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 1, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { return 0; } + if (FORY_PREDICT_FALSE(stream_ != nullptr && size_ - reader_index_ < 5)) { + uint32_t raw = read_var_uint32_stream(error); + if (FORY_PREDICT_FALSE(!error.ok())) { + return 0; + } + return static_cast((raw >> 1) ^ (~(raw & 1) + 1)); + } uint32_t read_bytes = 0; uint32_t raw = get_var_uint32(reader_index_, &read_bytes); increase_reader_index(read_bytes); @@ -849,10 +981,12 @@ class Buffer { /// Read uint64_t value as varint from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE uint64_t read_var_uint64(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 1 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 1, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { return 0; } + if (FORY_PREDICT_FALSE(stream_ != nullptr && size_ - reader_index_ < 9)) { + return read_var_uint64_stream(error); + } uint32_t read_bytes = 0; uint64_t value = get_var_uint64(reader_index_, &read_bytes); increase_reader_index(read_bytes); @@ -886,22 +1020,26 @@ class Buffer { /// If bit 0 is 0, return value >> 1 (arithmetic shift). /// Otherwise, skip flag byte and read 8 bytes as int64. 
FORY_ALWAYS_INLINE int64_t read_tagged_int64(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 4 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 4, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(4, error))) { return 0; } int32_t i = reinterpret_cast(data_ + reader_index_)[0]; if ((i & 0b1) != 0b1) { reader_index_ += 4; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return static_cast(i >> 1); // arithmetic right shift } else { - if (FORY_PREDICT_FALSE(reader_index_ + 9 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 9, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(9, error))) { return 0; } int64_t value = reinterpret_cast(data_ + reader_index_ + 1)[0]; reader_index_ += 9; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } } @@ -925,32 +1063,38 @@ class Buffer { /// If bit 0 is 0, return value >> 1. /// Otherwise, skip flag byte and read 8 bytes as uint64. FORY_ALWAYS_INLINE uint64_t read_tagged_uint64(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 4 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 4, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(4, error))) { return 0; } uint32_t i = reinterpret_cast(data_ + reader_index_)[0]; if ((i & 0b1) != 0b1) { reader_index_ += 4; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return static_cast(i >> 1); } else { - if (FORY_PREDICT_FALSE(reader_index_ + 9 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 9, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(9, error))) { return 0; } uint64_t value = reinterpret_cast(data_ + reader_index_ + 1)[0]; reader_index_ += 9; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } return value; } } /// Read uint64_t value as varuint36small. Sets error on bounds violation. 
FORY_ALWAYS_INLINE uint64_t read_var_uint36_small(Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + 1 > size_)) { - error.set_buffer_out_of_bound(reader_index_, 1, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { return 0; } + if (FORY_PREDICT_FALSE(stream_ != nullptr && size_ - reader_index_ < 8)) { + return read_var_uint36_small_stream(error); + } uint32_t offset = reader_index_; // Fast path: need at least 8 bytes for safe bulk read if (FORY_PREDICT_TRUE(size_ - offset >= 8)) { @@ -1008,8 +1152,7 @@ class Buffer { /// Read raw bytes from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE void read_bytes(void *data, uint32_t length, Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + length > size_)) { - error.set_buffer_out_of_bound(reader_index_, length, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(length, error))) { return; } copy(reader_index_, length, static_cast(data)); @@ -1018,8 +1161,7 @@ class Buffer { /// skip bytes in buffer. Sets error on bounds violation. 
FORY_ALWAYS_INLINE void skip(uint32_t length, Error &error) { - if (FORY_PREDICT_FALSE(reader_index_ + length > size_)) { - error.set_buffer_out_of_bound(reader_index_, length, size_); + if (FORY_PREDICT_FALSE(!ensure_readable(length, error))) { return; } increase_reader_index(length); @@ -1117,12 +1259,130 @@ class Buffer { std::string hex() const; private: + FORY_ALWAYS_INLINE bool fill_to(uint32_t target_size, Error &error) { + if (FORY_PREDICT_TRUE(target_size <= size_)) { + return true; + } + if (FORY_PREDICT_TRUE(stream_ == nullptr)) { + return false; + } + stream_->reader_index(reader_index_); + auto fill_result = stream_->fill_buffer(target_size - reader_index_); + if (FORY_PREDICT_FALSE(!fill_result.ok())) { + error = std::move(fill_result).error(); + return false; + } + data_ = stream_->data(); + size_ = stream_->size(); + reader_index_ = stream_->reader_index(); + writer_index_ = size_; + return true; + } + + FORY_ALWAYS_INLINE bool ensure_readable(uint32_t length, Error &error) { + const uint64_t target = static_cast(reader_index_) + length; + if (FORY_PREDICT_TRUE(target <= size_)) { + return true; + } + if (FORY_PREDICT_TRUE(stream_ == nullptr)) { + error.set_buffer_out_of_bound(reader_index_, length, size_); + return false; + } + if (FORY_PREDICT_FALSE(target > std::numeric_limits::max())) { + error.set_error(ErrorCode::OutOfBound, + "reader index exceeds uint32 range"); + return false; + } + if (FORY_PREDICT_FALSE(!fill_to(static_cast(target), error))) { + if (error.ok()) { + error.set_buffer_out_of_bound(reader_index_, length, size_); + } + return false; + } + if (FORY_PREDICT_FALSE(target > size_)) { + error.set_buffer_out_of_bound(reader_index_, length, size_); + return false; + } + return true; + } + + FORY_ALWAYS_INLINE uint32_t read_var_uint32_stream(Error &error) { + uint32_t result = 0; + for (int i = 0; i < 5; ++i) { + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { + return 0; + } + uint8_t b = data_[reader_index_]; + reader_index_ 
+= 1; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } + result |= static_cast(b & 0x7F) << (i * 7); + if ((b & 0x80) == 0) { + return result; + } + } + error.set_error(ErrorCode::InvalidData, "Invalid var_uint32 encoding"); + return 0; + } + + FORY_ALWAYS_INLINE uint64_t read_var_uint64_stream(Error &error) { + uint64_t result = 0; + for (int i = 0; i < 8; ++i) { + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { + return 0; + } + uint8_t b = data_[reader_index_]; + reader_index_ += 1; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } + result |= static_cast(b & 0x7F) << (i * 7); + if ((b & 0x80) == 0) { + return result; + } + } + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { + return 0; + } + uint8_t b = data_[reader_index_]; + reader_index_ += 1; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } + result |= static_cast(b) << 56; + return result; + } + + FORY_ALWAYS_INLINE uint64_t read_var_uint36_small_stream(Error &error) { + uint64_t result = 0; + for (int i = 0; i < 5; ++i) { + if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { + return 0; + } + uint8_t b = data_[reader_index_]; + reader_index_ += 1; + if (FORY_PREDICT_FALSE(stream_ != nullptr)) { + stream_->reader_index(reader_index_); + } + result |= static_cast(b & 0x7F) << (i * 7); + if ((b & 0x80) == 0) { + return result; + } + } + error.set_error(ErrorCode::InvalidData, + "Invalid var_uint36_small encoding"); + return 0; + } + uint8_t *data_; uint32_t size_; bool own_data_; uint32_t writer_index_; uint32_t reader_index_; std::vector *wrapped_vector_ = nullptr; + ForyInputStream *stream_ = nullptr; + std::shared_ptr stream_owner_; }; /// \brief Allocate a fixed-size mutable buffer from the default memory pool diff --git a/cpp/fory/util/buffer_test.cc b/cpp/fory/util/buffer_test.cc index 7263e49053..5b07ad5cfb 100644 --- a/cpp/fory/util/buffer_test.cc +++ 
b/cpp/fory/util/buffer_test.cc @@ -28,6 +28,47 @@ namespace fory { +class OneByteStreamBuf : public std::streambuf { +public: + explicit OneByteStreamBuf(std::vector data) + : data_(std::move(data)), pos_(0) {} + +protected: + std::streamsize xsgetn(char *s, std::streamsize count) override { + if (pos_ >= data_.size() || count <= 0) { + return 0; + } + s[0] = static_cast(data_[pos_]); + ++pos_; + return 1; + } + + int_type underflow() override { + if (pos_ >= data_.size()) { + return traits_type::eof(); + } + current_ = static_cast(data_[pos_]); + setg(&current_, &current_, &current_ + 1); + return traits_type::to_int_type(current_); + } + +private: + std::vector data_; + size_t pos_; + char current_ = 0; +}; + +class OneByteIStream : public std::istream { +public: + explicit OneByteIStream(std::vector data) + : std::istream(nullptr), buf_(std::move(data)) { + rdbuf(&buf_); + } + +private: + OneByteStreamBuf buf_; +}; + TEST(Buffer, to_string) { std::shared_ptr buffer; allocate_buffer(16, &buffer); @@ -119,6 +160,77 @@ TEST(Buffer, TestGetBytesAsInt64) { EXPECT_TRUE(buffer->get_bytes_as_int64(0, 1, &result).ok()); EXPECT_EQ(result, 100); } + +TEST(Buffer, StreamReadFromOneByteSource) { + std::vector raw; + raw.reserve(64); + Buffer writer(raw); + writer.write_uint32(0x01020304U); + writer.write_int64(-1234567890LL); + writer.write_var_uint32(300); + writer.write_var_int64(-4567890123LL); + writer.write_tagged_uint64(0x123456789ULL); + writer.write_var_uint36_small(0x1FFFFULL); + + raw.resize(writer.writer_index()); + OneByteIStream one_byte_stream(raw); + ForyInputStream stream(one_byte_stream, 8); + Buffer reader(stream); + Error error; + + EXPECT_EQ(reader.read_uint32(error), 0x01020304U); + ASSERT_TRUE(error.ok()) << error.to_string(); + EXPECT_EQ(reader.read_int64(error), -1234567890LL); + ASSERT_TRUE(error.ok()) << error.to_string(); + EXPECT_EQ(reader.read_var_uint32(error), 300U); + ASSERT_TRUE(error.ok()) << error.to_string(); + EXPECT_EQ(reader.read_var_int64(error),
-4567890123LL); + ASSERT_TRUE(error.ok()) << error.to_string(); + EXPECT_EQ(reader.read_tagged_uint64(error), 0x123456789ULL); + ASSERT_TRUE(error.ok()) << error.to_string(); + EXPECT_EQ(reader.read_var_uint36_small(error), 0x1FFFFULL); + ASSERT_TRUE(error.ok()) << error.to_string(); +} + +TEST(Buffer, StreamGetAndReaderIndexFromOneByteSource) { + std::vector raw{0x11, 0x22, 0x33, 0x44, 0x55}; + OneByteIStream one_byte_stream(raw); + ForyInputStream stream(one_byte_stream, 2); + Buffer reader(stream); + + EXPECT_EQ(reader.get(0), 0x44332211U); + reader.reader_index(4); + Error error; + EXPECT_EQ(reader.read_uint8(error), 0x55); + ASSERT_TRUE(error.ok()) << error.to_string(); +} + +TEST(Buffer, StreamRewind) { + std::vector raw{0x01, 0x02, 0x03, 0x04, 0x05}; + OneByteIStream one_byte_stream(raw); + ForyInputStream stream(one_byte_stream, 2); + auto fill_result = stream.fill_buffer(4); + ASSERT_TRUE(fill_result.ok()) << fill_result.error().to_string(); + EXPECT_EQ(stream.size(), 4U); + EXPECT_EQ(stream.reader_index(), 0U); + stream.consume(3); + EXPECT_EQ(stream.reader_index(), 3U); + stream.rewind(2); + EXPECT_EQ(stream.reader_index(), 1U); + stream.consume(1); + EXPECT_EQ(stream.reader_index(), 2U); +} + +TEST(Buffer, StreamReadErrorWhenInsufficientData) { + std::vector raw{0x01, 0x02, 0x03}; + OneByteIStream one_byte_stream(raw); + ForyInputStream stream(one_byte_stream, 2); + Buffer reader(stream); + Error error; + EXPECT_EQ(reader.read_uint32(error), 0U); + EXPECT_FALSE(error.ok()); + EXPECT_EQ(error.code(), ErrorCode::BufferOutOfBound); +} } // namespace fory int main(int argc, char **argv) { diff --git a/cpp/fory/util/stream.h b/cpp/fory/util/stream.h new file mode 100644 index 0000000000..0ef59c2d27 --- /dev/null +++ b/cpp/fory/util/stream.h @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fory/util/error.h" +#include "fory/util/logging.h" +#include "fory/util/result.h" + +namespace fory { + +class ForyInputStreamBuf final : public std::streambuf { +public: + explicit ForyInputStreamBuf(std::istream &stream, uint32_t buffer_size = 4096) + : stream_(&stream), + buffer_(std::max(buffer_size, static_cast(1))) { + char *base = buffer_.data(); + setg(base, base, base); + } + + Result fill_buffer(uint32_t min_fill_size) { + if (min_fill_size == 0) { + return Result(); + } + if (remaining_size() >= min_fill_size) { + return Result(); + } + + const uint32_t read_pos = reader_index(); + const uint32_t cur_size = size(); + const uint64_t required = + static_cast(cur_size) + (min_fill_size - remaining_size()); + if (required > std::numeric_limits::max()) { + return Unexpected( + Error::out_of_bound("stream buffer size exceeds uint32 range")); + } + if (required > buffer_.size()) { + uint64_t new_size = static_cast(buffer_.size()) * 2; + if (new_size < required) { + new_size = required; + } + reserve(static_cast(new_size)); + } + + char *base = buffer_.data(); + uint32_t write_pos = cur_size; + setg(base, base + read_pos, base + write_pos); + while 
(remaining_size() < min_fill_size) { + uint32_t writable = static_cast(buffer_.size()) - write_pos; + if (writable == 0) { + uint64_t new_size = static_cast(buffer_.size()) * 2 + 1; + if (new_size > std::numeric_limits::max()) { + return Unexpected( + Error::out_of_bound("stream buffer size exceeds uint32 range")); + } + reserve(static_cast(new_size)); + base = buffer_.data(); + writable = static_cast(buffer_.size()) - write_pos; + setg(base, base + read_pos, base + write_pos); + } + std::streambuf *source = stream_->rdbuf(); + if (source == nullptr) { + return Unexpected(Error::io_error("input stream has no stream buffer")); + } + const std::streamsize read_bytes = source->sgetn( + base + write_pos, static_cast(writable)); + if (read_bytes <= 0) { + return Unexpected(Error::buffer_out_of_bound(read_pos, min_fill_size, + remaining_size())); + } + write_pos += static_cast(read_bytes); + setg(base, base + read_pos, base + write_pos); + } + return Result(); + } + + void rewind(uint32_t size) { + const uint32_t consumed = reader_index(); + FORY_CHECK(size <= consumed) + << "rewind size " << size << " exceeds consumed bytes " << consumed; + setg(eback(), gptr() - static_cast(size), egptr()); + } + + void consume(uint32_t size) { + const uint32_t available = remaining_size(); + FORY_CHECK(size <= available) + << "consume size " << size << " exceeds available bytes " << available; + gbump(static_cast(size)); + } + + void reader_index(uint32_t index) { + const uint32_t cur_size = size(); + FORY_CHECK(index <= cur_size) << "reader_index " << index + << " exceeds stream buffer size " << cur_size; + setg(eback(), eback() + static_cast(index), egptr()); + } + + uint8_t *data() { return reinterpret_cast(eback()); } + + uint32_t size() const { return static_cast(egptr() - eback()); } + + uint32_t reader_index() const { + return static_cast(gptr() - eback()); + } + + uint32_t remaining_size() const { + return static_cast(egptr() - gptr()); + } + +protected: + std::streamsize 
xsgetn(char *s, std::streamsize count) override { + std::streamsize copied = 0; + while (copied < count) { + auto available = static_cast(egptr() - gptr()); + if (available == 0) { + if (!fill_buffer(1).ok()) { + break; + } + available = static_cast(egptr() - gptr()); + } + const std::streamsize n = std::min(available, count - copied); + std::memcpy(s + copied, gptr(), static_cast(n)); + gbump(static_cast(n)); + copied += n; + } + return copied; + } + + int_type underflow() override { + if (gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + if (!fill_buffer(1).ok()) { + return traits_type::eof(); + } + return traits_type::to_int_type(*gptr()); + } + +private: + void reserve(uint32_t new_size) { + const uint32_t old_size = size(); + const uint32_t old_reader_index = reader_index(); + buffer_.resize(new_size); + char *base = buffer_.data(); + setg(base, base + old_reader_index, base + old_size); + } + + std::istream *stream_; + std::vector buffer_; +}; + +class ForyInputStream final : public std::basic_istream { +public: + explicit ForyInputStream(std::istream &stream, uint32_t buffer_size = 4096) + : std::basic_istream(nullptr), streambuf_(stream, buffer_size) { + this->init(&streambuf_); + } + + explicit ForyInputStream(std::shared_ptr stream, + uint32_t buffer_size = 4096) + : std::basic_istream(nullptr), stream_owner_(std::move(stream)), + streambuf_(*stream_owner_, buffer_size) { + FORY_CHECK(stream_owner_ != nullptr) << "stream must not be null"; + this->init(&streambuf_); + } + + Result fill_buffer(uint32_t min_fill_size) { + return streambuf_.fill_buffer(min_fill_size); + } + + void rewind(uint32_t size) { streambuf_.rewind(size); } + + void consume(uint32_t size) { streambuf_.consume(size); } + + void reader_index(uint32_t index) { streambuf_.reader_index(index); } + + uint8_t *data() { return streambuf_.data(); } + + uint32_t size() const { return streambuf_.size(); } + + uint32_t reader_index() const { return streambuf_.reader_index(); } 
+ + uint32_t remaining_size() const { return streambuf_.remaining_size(); } + +private: + std::shared_ptr stream_owner_; + ForyInputStreamBuf streambuf_; +}; + +} // namespace fory diff --git a/python/pyfory/buffer.pyx b/python/pyfory/buffer.pyx index 3f8e0935c6..547f09c1b4 100644 --- a/python/pyfory/buffer.pyx +++ b/python/pyfory/buffer.pyx @@ -34,6 +34,7 @@ from pyfory.includes.libutil cimport( CBuffer, allocate_buffer, get_bit as c_get_bit, set_bit as c_set_bit, clear_bit as c_clear_bit, set_bit_to as c_set_bit_to, CError, CErrorCode, CResultVoidError, utf16_has_surrogate_pairs ) +from pyfory.includes.libpyfory cimport Fory_PyCreateBufferFromStream import os from pyfory.error import raise_fory_error @@ -52,8 +53,25 @@ cdef class _SharedBufferOwner: cdef class Buffer: def __init__(self, data not None, int32_t offset=0, length=None): self.data = data - cdef int32_t buffer_len = len(data) + cdef int32_t buffer_len cdef int length_ + cdef CBuffer* stream_buffer + cdef c_string stream_error + try: + buffer_len = len(data) + except TypeError: + if not hasattr(data, "read"): + raise + if offset != 0 or length is not None: + raise ValueError("offset and length are unsupported for stream input") + if Fory_PyCreateBufferFromStream(data, 4096, &stream_buffer, &stream_error) != 0: + raise ValueError(stream_error.decode("UTF-8")) + if stream_buffer == NULL: + raise ValueError("failed to create stream buffer") + self.c_buffer = move(deref(stream_buffer)) + del stream_buffer + self.data = data + return if length is None: length_ = buffer_len - offset else: @@ -196,7 +214,14 @@ cdef class Buffer: cpdef inline check_bound(self, int32_t offset, int32_t length): cdef int32_t size_ = self.c_buffer.size() + cdef int64_t target = 0 if offset | length | (offset + length) | (size_- (offset + length)) < 0: + if offset >= 0 and length >= 0: + target = offset + length + if target <= 0xFFFFFFFF and self.c_buffer.ensure_size(target, self._error): + return + self._raise_if_error() + size_ = 
self.c_buffer.size() raise_fory_error( CErrorCode.BufferOutOfBound, f"Address range {offset, offset + length} out of bound {0, size_}", diff --git a/python/pyfory/includes/libpyfory.pxd b/python/pyfory/includes/libpyfory.pxd new file mode 100644 index 0000000000..c0e888c18d --- /dev/null +++ b/python/pyfory/includes/libpyfory.pxd @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from cpython.object cimport PyObject +from libc.stdint cimport uint32_t +from libcpp.string cimport string as c_string + +from pyfory.includes.libutil cimport CBuffer + +cdef extern from "fory/python/pyfory.h" namespace "fory": + int Fory_PyCreateBufferFromStream(PyObject* stream, uint32_t buffer_size, + CBuffer** out, c_string* error_message) diff --git a/python/pyfory/includes/libutil.pxd b/python/pyfory/includes/libutil.pxd index 3328889b09..bc14ae5442 100644 --- a/python/pyfory/includes/libutil.pxd +++ b/python/pyfory/includes/libutil.pxd @@ -71,6 +71,8 @@ cdef extern from "fory/util/buffer.h" namespace "fory" nogil: inline uint32_t reader_index() + inline c_bool ensure_size(uint32_t target_size, CError& error) + inline void writer_index(uint32_t writer_index) inline void increase_writer_index(uint32_t diff) diff --git a/python/pyfory/tests/test_buffer.py b/python/pyfory/tests/test_buffer.py index 9080d966b9..992b6bb88f 100644 --- a/python/pyfory/tests/test_buffer.py +++ b/python/pyfory/tests/test_buffer.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
+import pytest + from pyfory.buffer import Buffer from pyfory.tests.core import require_pyarrow from pyfory.utils import lazy_import @@ -22,6 +24,24 @@ pa = lazy_import("pyarrow") +class OneByteStream: + def __init__(self, data: bytes): + self._data = data + self._offset = 0 + + def read(self, size=-1): + if self._offset >= len(self._data): + return b"" + if size < 0: + size = len(self._data) - self._offset + if size == 0: + return b"" + read_size = min(1, size, len(self._data) - self._offset) + start = self._offset + self._offset += read_size + return self._data[start : start + read_size] + + def test_buffer(): buffer = Buffer.allocate(8) buffer.write_bool(True) @@ -248,5 +268,42 @@ def test_read_bytes_as_int64(): buf.read_bytes_as_int64(8) +def test_stream_buffer_read(): + writer = Buffer.allocate(32) + writer.write_uint32(0x01020304) + writer.write_int64(-1234567890) + writer.write_var_uint32(300) + writer.write_varint64(-4567890123) + writer.write_tagged_uint64(0x123456789) + writer.write_var_uint64(0x1FFFF) + writer.write_bytes_and_size(b"stream-data") + writer.write_string("hello-stream") + + data = writer.get_bytes(0, writer.get_writer_index()) + stream = OneByteStream(data) + reader = Buffer(stream) + + assert reader.read_uint32() == 0x01020304 + assert reader.read_int64() == -1234567890 + assert reader.read_var_uint32() == 300 + assert reader.read_varint64() == -4567890123 + assert reader.read_tagged_uint64() == 0x123456789 + assert reader.read_var_uint64() == 0x1FFFF + assert reader.read_bytes_and_size() == b"stream-data" + assert reader.read_string() == "hello-stream" + + +def test_stream_buffer_set_reader_index(): + reader = Buffer(OneByteStream(bytes([0x11, 0x22, 0x33, 0x44, 0x55]))) + reader.set_reader_index(4) + assert reader.read_uint8() == 0x55 + + +def test_stream_buffer_short_read_error(): + reader = Buffer(OneByteStream(b"\x01\x02\x03")) + with pytest.raises(Exception, match="Buffer out of bound"): + reader.read_uint32() + + if __name__ == 
"__main__": test_grow() diff --git a/python/pyfory/tests/test_stream.py b/python/pyfory/tests/test_stream.py new file mode 100644 index 0000000000..c6054c2656 --- /dev/null +++ b/python/pyfory/tests/test_stream.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest + +import pyfory +from pyfory.buffer import Buffer + + +class OneByteStream: + def __init__(self, data: bytes): + self._data = data + self._offset = 0 + + def read(self, size=-1): + if self._offset >= len(self._data): + return b"" + if size < 0: + size = len(self._data) - self._offset + if size == 0: + return b"" + read_size = min(1, size, len(self._data) - self._offset) + start = self._offset + self._offset += read_size + return self._data[start : start + read_size] + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_stream_roundtrip_primitives_and_strings(xlang): + fory = pyfory.Fory(xlang=xlang, ref=True) + values = [ + 0, + -123456789, + 3.1415926, + "stream-hello", + "stream-你好-😀", + b"binary-data" * 8, + [1, 2, 3, 5, 8], + ] + + for value in values: + data = fory.serialize(value) + restored = fory.deserialize(Buffer(OneByteStream(data))) + assert restored == value + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_stream_roundtrip_nested_collections(xlang): + fory = pyfory.Fory(xlang=xlang, ref=True) + value = { + "name": "stream-object", + "items": [1, 2, {"k1": "v1", "k2": [3, 4, 5]}], + "scores": {"a": 10, "b": 20}, + "flags": [True, False, True], + } + + data = fory.serialize(value) + restored = fory.deserialize(Buffer(OneByteStream(data))) + assert restored == value + + +def test_stream_roundtrip_reference_graph_python_mode(): + fory = pyfory.Fory(xlang=False, ref=True) + shared = ["x", 1, 2] + value = {"a": shared, "b": shared} + cycle = [] + cycle.append(cycle) + + data_ref = fory.serialize(value) + restored_ref = fory.deserialize(Buffer(OneByteStream(data_ref))) + assert restored_ref["a"] == shared + assert restored_ref["a"] is restored_ref["b"] + + data_cycle = fory.serialize(cycle) + restored_cycle = fory.deserialize(Buffer(OneByteStream(data_cycle))) + assert restored_cycle[0] is restored_cycle + + +@pytest.mark.parametrize("xlang", [False, True]) +def 
test_stream_deserialize_multiple_objects_from_single_stream(xlang): + fory = pyfory.Fory(xlang=xlang, ref=True) + expected = [ + 2026, + "multi-object-stream", + {"k": [1, 2, 3], "nested": {"x": True}}, + [10, 20, 30, 40], + ] + + write_buffer = Buffer.allocate(1024) + for obj in expected: + fory.serialize(obj, write_buffer) + + reader = Buffer(OneByteStream(write_buffer.get_bytes(0, write_buffer.get_writer_index()))) + for obj in expected: + assert fory.deserialize(reader) == obj + + assert reader.get_reader_index() == write_buffer.get_writer_index() + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_stream_deserialize_truncated_error(xlang): + fory = pyfory.Fory(xlang=xlang, ref=True) + data = fory.serialize({"k": "value", "numbers": [1, 2, 3, 4]}) + truncated = data[:-1] + + with pytest.raises(Exception): + fory.deserialize(Buffer(OneByteStream(truncated)))