diff --git a/bindings/cpp/include/svs/runtime/flat_index.h b/bindings/cpp/include/svs/runtime/flat_index.h index 87410e61..efb59081 100644 --- a/bindings/cpp/include/svs/runtime/flat_index.h +++ b/bindings/cpp/include/svs/runtime/flat_index.h @@ -52,8 +52,13 @@ struct SVS_RUNTIME_API FlatIndex { // Load from a memory buffer. // The buffer is expected to be in the format produced by save(). - static Status - map_to_memory(FlatIndex** index, void* data, size_t size, MetricType metric) noexcept; + static Status map_to_memory( + FlatIndex** index, + void* data, + size_t size, + MetricType metric, + size_t* read_bytes = nullptr + ) noexcept; }; } // namespace v0 diff --git a/bindings/cpp/include/svs/runtime/vamana_index.h b/bindings/cpp/include/svs/runtime/vamana_index.h index b70d4777..722b413b 100644 --- a/bindings/cpp/include/svs/runtime/vamana_index.h +++ b/bindings/cpp/include/svs/runtime/vamana_index.h @@ -124,7 +124,8 @@ struct SVS_RUNTIME_API VamanaIndex { void* data, size_t size, MetricType metric, - StorageKind storage_kind + StorageKind storage_kind, + size_t* read_bytes = nullptr ) noexcept; }; diff --git a/bindings/cpp/src/flat_index.cpp b/bindings/cpp/src/flat_index.cpp index bce25b5e..2514bdec 100644 --- a/bindings/cpp/src/flat_index.cpp +++ b/bindings/cpp/src/flat_index.cpp @@ -130,7 +130,7 @@ FlatIndex::map_to_file(FlatIndex** index, const char* path, MetricType metric) n } Status FlatIndex::map_to_memory( - FlatIndex** index, void* data, size_t size, MetricType metric + FlatIndex** index, void* data, size_t size, MetricType metric, size_t* read_bytes ) noexcept { *index = nullptr; return runtime_error_wrapper([&] { @@ -138,6 +138,10 @@ Status FlatIndex::map_to_memory( auto is = std::make_unique(sp); std::unique_ptr impl{ FlatIndexImpl::map_to_stream(std::move(is), metric)}; + if (read_bytes) { + auto pos = impl->get_mapped_stream()->tellg(); + *read_bytes = static_cast(pos); + } *index = new FlatIndexManager{std::move(impl)}; return Status_Ok; }); diff --git a/bindings/cpp/src/flat_index_impl.h b/bindings/cpp/src/flat_index_impl.h index cdfa2200..921563fd 100644 --- a/bindings/cpp/src/flat_index_impl.h +++ b/bindings/cpp/src/flat_index_impl.h @@ -134,6 +134,8 @@ class FlatIndexImpl { }); } + std::istream* get_mapped_stream() const { return mapped_stream_.get(); } + protected: // Constructor used during loading FlatIndexImpl( diff --git a/bindings/cpp/src/vamana_index.cpp b/bindings/cpp/src/vamana_index.cpp index 3d4acccd..195164c5 100644 --- a/bindings/cpp/src/vamana_index.cpp +++ b/bindings/cpp/src/vamana_index.cpp @@ -179,7 +179,8 @@ Status VamanaIndex::map_to_memory( void* data, size_t size, MetricType metric, - StorageKind storage_kind + StorageKind storage_kind, + size_t* read_bytes ) noexcept { using Impl = VamanaIndexImpl; *index = nullptr; @@ -188,6 +189,10 @@ Status VamanaIndex::map_to_memory( auto is = std::make_unique(sp); std::unique_ptr impl{ Impl::map_to_stream(std::move(is), metric, storage_kind)}; + if (read_bytes) { + auto pos = impl->get_mapped_stream()->tellg(); + *read_bytes = static_cast(pos); + } *index = new VamanaIndexManagerBase{std::move(impl)}; }); } diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 3102400f..c3927a16 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -533,6 +533,8 @@ class VamanaIndexImpl { ); } + std::istream* get_mapped_stream() const { return mapped_stream_.get(); } + // Data members protected: size_t dim_; diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index db68a519..f0972b7f 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -230,6 +230,119 @@ void write_and_map_index( Index::destroy(loaded); } +// Template function to write and map an index from an in-memory buffer +template +void write_and_map_index_from_memory( + BuildFunc build_func, + const std::vector& xb, + size_t n, + size_t d, + std::optional storage_kind = std::nullopt, + svs::runtime::v0::MetricType metric = svs::runtime::v0::MetricType::L2 +) { + // Build index + Index* index = nullptr; + svs::runtime::v0::Status status = build_func(&index); + + // Stop here if storage kind is not supported on this platform + if constexpr (std::is_base_of_v) { + if (storage_kind.has_value()) { + if (!Index::check_storage_kind(*storage_kind).ok()) { + CATCH_REQUIRE(!status.ok()); + return; + } + } + } + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data to index + if constexpr (std::is_same_v || std::is_same_v) { + status = index->add(n, xb.data()); + } else { + std::vector labels(n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(n, labels.data(), xb.data()); + } + CATCH_REQUIRE(status.ok()); + + std::stringstream ss{}; + status = index->save(ss); + CATCH_REQUIRE(status.ok()); + + auto payload = ss.str(); + CATCH_REQUIRE(!payload.empty()); + const size_t serialized_size = payload.size(); + + std::vector serialized(payload.size() + 1); // +1 for sentinel + std::copy(payload.begin(), payload.end(), serialized.begin()); + // Keep a trailing sentinel to validate read_bytes for map_to_memory. + serialized[payload.size()] = static_cast(0x5A); + + Index* loaded = nullptr; + size_t read_bytes = 0; + if constexpr (std::is_same_v) { + status = Index::map_to_memory( + &loaded, serialized.data(), serialized.size(), metric, &read_bytes + ); + } else { + CATCH_REQUIRE(storage_kind.has_value()); + status = Index::map_to_memory( + &loaded, + serialized.data(), + serialized.size(), + metric, + *storage_kind, + &read_bytes + ); + } + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(loaded != nullptr); + + CATCH_REQUIRE(read_bytes < serialized.size()); + CATCH_REQUIRE(read_bytes == serialized_size); + CATCH_REQUIRE(serialized[read_bytes] == static_cast(0x5A)); + + // Basic query against mapped index + const int nq = 5; + const float* xq = xb.data(); + const int k = 10; + std::vector distances(nq * k); + std::vector result_labels(nq * k); + status = loaded->search(nq, xq, k, distances.data(), result_labels.data()); + CATCH_REQUIRE(status.ok()); + + // Verify nullptr read_bytes path for map_to_memory. + Index* loaded_without_read_bytes = nullptr; + if constexpr (std::is_same_v) { + status = Index::map_to_memory( + &loaded_without_read_bytes, + serialized.data(), + serialized.size(), + metric, + nullptr + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(loaded_without_read_bytes != nullptr); + } else { + status = Index::map_to_memory( + &loaded_without_read_bytes, + serialized.data(), + serialized.size(), + metric, + *storage_kind, + nullptr + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(loaded_without_read_bytes != nullptr); + } + + // Clean up + Index::destroy(index); + Index::destroy(loaded); + Index::destroy(loaded_without_read_bytes); +} + // Helper that writes and reads and index of requested size // Reports memory usage UsageInfo run_save_and_load_test( @@ -540,6 +653,18 @@ CATCH_TEST_CASE("FlatIndexWriteAndMap", "[runtime]") { write_and_map_index(build_func, test_data, test_n, test_d); } +CATCH_TEST_CASE("FlatIndexWriteAndMapToMemory", "[runtime]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::FlatIndex** index) { + return svs::runtime::v0::FlatIndex::build( + index, test_d, svs::runtime::v0::MetricType::L2 + ); + }; + write_and_map_index_from_memory( + build_func, test_data, test_n, test_d + ); +} + CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") { const auto& test_data = get_test_data(); // Build index @@ -910,6 +1035,23 @@ CATCH_TEST_CASE("WriteAndMapStaticIndexSVS", "[runtime][static_vamana]") { ); } +CATCH_TEST_CASE("WriteAndMapToMemoryStaticIndexSVS", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + }; + write_and_map_index_from_memory( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP32 + ); +} + CATCH_TEST_CASE("WriteAndReadStaticIndexSVSFP16", "[runtime][static_vamana]") { const auto& test_data = get_test_data(); auto build_func = [](svs::runtime::v0::VamanaIndex** index) { @@ -1012,6 +1154,23 @@ CATCH_TEST_CASE("WriteAndMapStaticIndexSVSLVQ4x4", "[runtime][static_vamana]") { ); } +CATCH_TEST_CASE("WriteAndMapToMemoryStaticIndexSVSLVQ4x4", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::LVQ4x4, + build_params + ); + }; + write_and_map_index_from_memory( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LVQ4x4 + ); +} + CATCH_TEST_CASE("WriteAndReadStaticIndexSVSVamanaLeanVec4x4", "[runtime][static_vamana]") { const auto& test_data = get_test_data(); auto build_func = [](svs::runtime::v0::VamanaIndex** index) { @@ -1048,6 +1207,24 @@ CATCH_TEST_CASE("WriteAndMapStaticIndexSVSVamanaLeanVec4x4", "[runtime][static_v ); } +CATCH_TEST_CASE("WriteAndMapToMemoryStaticIndexSVSLeanVec4x4", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndexLeanVec::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::LeanVec4x4, + 32, + build_params + ); + }; + write_and_map_index_from_memory( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LeanVec4x4 + ); +} + CATCH_TEST_CASE("StaticIndexLeanVecWithTrainingData", "[runtime][static_vamana]") { const auto& test_data = get_test_data(); const size_t leanvec_dims = 32;