diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index 890cda372..8f15ff717 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -75,7 +75,7 @@ class GenericSerializer { } template - static lib::SaveTable save_table(const Data& data) { + static lib::SaveTable metadata(const Data& data) { using T = typename Data::element_type; auto table = lib::SaveTable( serialization_schema, @@ -92,8 +92,8 @@ class GenericSerializer { template static lib::SaveTable - save_table(const Data& data, const FileName_t& filename, const lib::UUID& uuid) { - auto table = save_table(data); + metadata(const Data& data, const FileName_t& filename, const lib::UUID& uuid) { + auto table = metadata(data); table.insert("binary_file", filename); table.insert("uuid", uuid.str()); return table; @@ -105,7 +105,7 @@ class GenericSerializer { auto uuid = lib::UUID{}; auto filename = ctx.generate_name("data"); io::save(data, io::NativeFile(filename), uuid); - return save_table(data, lib::save(filename.filename()), uuid); + return metadata(data, lib::save(filename.filename()), uuid); } template @@ -452,7 +452,7 @@ class SimpleData { void save(std::ostream& os) const { return GenericSerializer::save(*this, os); } - lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); } + lib::SaveTable metadata() const { return GenericSerializer::metadata(*this); } static bool check_load_compatibility(std::string_view schema, lib::Version version) { return GenericSerializer::check_compatibility(schema, version); @@ -871,7 +871,7 @@ class SimpleData> { void save(std::ostream& os) const { return GenericSerializer::save(*this, os); } - lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); } + lib::SaveTable metadata() const { return GenericSerializer::metadata(*this); } static bool check_load_compatibility(std::string_view schema, lib::Version version) { return GenericSerializer::check_compatibility(schema, version); diff --git a/include/svs/core/graph/graph.h b/include/svs/core/graph/graph.h index 48d3e8048..68357d157 100644 --- a/include/svs/core/graph/graph.h +++ b/include/svs/core/graph/graph.h @@ -276,22 +276,36 @@ template class SimpleGrap ///// Saving static constexpr lib::Version save_version = lib::Version(0, 0, 0); static constexpr std::string_view serialization_schema = "default_graph"; - lib::SaveTable save(const lib::SaveContext& ctx) const { - auto uuid = lib::UUID{}; - auto filename = ctx.generate_name("graph"); - io::save(data_, io::NativeFile(filename), uuid); - return lib::SaveTable( + + lib::SaveTable metadata() const { + auto table = lib::SaveTable( serialization_schema, save_version, {{"name", "graph"}, - {"binary_file", lib::save(filename.filename())}, {"max_degree", lib::save(max_degree())}, {"num_vertices", lib::save(n_nodes())}, - {"uuid", lib::save(uuid.str())}, {"eltype", lib::save(datatype_v)}} ); + return table; + } + + template + lib::SaveTable metadata(const FileName& filename, const lib::UUID& uuid) const { + auto table = metadata(); + table.insert("binary_file", filename); + table.insert("uuid", uuid.str()); + return table; + } + + lib::SaveTable save(const lib::SaveContext& ctx) const { + auto uuid = lib::UUID{}; + auto filename = ctx.generate_name("graph"); + io::save(data_, io::NativeFile(filename), uuid); + return metadata(lib::save(filename.filename()), uuid); } + void save(std::ostream& os) const { io::save(data_, os); } + protected: template F, typename... Args> static lib::lazy_result_t @@ -317,6 +331,39 @@ template class SimpleGrap return lazy(data_type::load(binaryfile.value(), std::forward(args)...)); } + template F, typename... AllocArgs> + static lib::lazy_result_t load( + const lib::ContextFreeLoadTable& table, + const F& lazy, + const lib::detail::Deserializer& deserializer, + std::istream& is, + AllocArgs&&... alloc_args + ) { + // Perform a sanity check on the element type. + // Make sure we're loading the correct kind. + auto eltype = lib::load_at(table, "eltype"); + if (eltype != datatype_v) { + throw ANNEXCEPTION( + "Trying to load a graph with adjacency list types {} to a graph with " + "adjacency list types {}.", + name(eltype), + name>() + ); + } + + size_t num_vertices = lib::load_at(table, "num_vertices"); + size_t max_degree = lib::load_at(table, "max_degree"); + + // Skip legacy stream metadata (no-op for native serialization). + deserializer.read_name(is); + deserializer.read_size(is); + deserializer.read_binary(is); + + auto data = data_type(num_vertices, max_degree + 1, alloc_args...); + io::populate(is, data); + return lazy(std::move(data)); + } + protected: data_type data_; Idx max_degree_; @@ -366,6 +413,16 @@ class SimpleGraph : public SimpleGraphBase(deserializer, is, allocator); + } }; template diff --git a/include/svs/core/translation.h b/include/svs/core/translation.h index 1b1fde9c6..ae2dc4880 100644 --- a/include/svs/core/translation.h +++ b/include/svs/core/translation.h @@ -380,16 +380,17 @@ class IDTranslator { return translator; } - static IDTranslator load( - const lib::ContextFreeLoadTable& table, - const lib::detail::Deserializer& deserializer, - std::istream& is - ) { - IDTranslator::validate(table); + static IDTranslator + load(const lib::detail::Deserializer& deserializer, std::istream& is) { + auto table = lib::detail::read_metadata(deserializer, is); + auto translation = table.template cast() + .at("translation") + .template cast(); + IDTranslator::validate(translation); deserializer.read_name(is); deserializer.read_size(is); - return IDTranslator::load(table, is); + return IDTranslator::load(translation, is); } static IDTranslator load(const lib::LoadTable& table) { diff --git a/include/svs/index/flat/dynamic_flat.h b/include/svs/index/flat/dynamic_flat.h index 26946a220..849b80b7f 100644 --- a/include/svs/index/flat/dynamic_flat.h +++ b/include/svs/index/flat/dynamic_flat.h @@ -788,16 +788,6 @@ auto auto_dynamic_assemble( ); } -auto load_translator(const lib::detail::Deserializer& deserializer, std::istream& is) { - auto table = lib::detail::begin_deserialization(deserializer, is); - auto translator = IDTranslator::load( - table.template cast().at("translation").template cast(), - deserializer, - is - ); - return translator; -} - template auto auto_dynamic_assemble( const lib::detail::Deserializer& deserializer, @@ -813,38 +803,44 @@ auto auto_dynamic_assemble( bool SVS_UNUSED(debug_load_from_static) = false, svs::logging::logger_ptr logger = svs::logging::get() ) { - IDTranslator translator; - // In legacy deserialization the order of directories isn't determined. - auto name = deserializer.read_name_in_advance(is); - - // We have to hardcode the file_name for legacy mode, since it was hardcoded when legacy - // model was serialized - bool translator_before_data = - (name == "config/svs_config.toml") || deserializer.is_native(); - if (translator_before_data) { - translator = load_translator(deserializer, is); - } - - // Load the dataset - auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); - auto data = svs::detail::dispatch_load(data_loader(), threadpool); - auto datasize = data.size(); - - if (!translator_before_data) { - translator = load_translator(deserializer, is); + using Data = decltype(data_loader()); + auto config_loader = [&] { return IDTranslator::load(deserializer, is); }; + + std::optional config; + std::optional data; + + if (deserializer.is_native()) { + // Order is always config->data. + config.emplace(config_loader()); + data.emplace(data_loader()); + } else { + // Directory packing order is filesystem-dependent. + // Read 2 data blocks: config and data in a corresponding order. + for (int data_block_idx = 0; data_block_idx < 2; ++data_block_idx) { + auto name = deserializer.read_name_in_advance(is); + if (name.starts_with("config/")) { + config.emplace(config_loader()); + } else if (name.starts_with("data/")) { + data.emplace(data_loader()); + } else { + throw ANNEXCEPTION("The stream is corrupted!"); + } + } } // Validate the translator - auto translator_size = translator.size(); + auto translator_size = config->size(); + auto datasize = data->size(); if (translator_size != datasize) { throw ANNEXCEPTION( "Translator has {} IDs but should have {}", translator_size, datasize ); } + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); return DynamicFlatIndex( - std::move(data), - std::move(translator), + std::move(*data), + std::move(*config), std::move(distance), std::move(threadpool), std::move(logger) diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 169be1995..3a1778fe1 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -988,6 +988,18 @@ class MutableVamanaIndex { ///// Saving + VamanaIndexParameters parameters() const { + return { + entry_point_.front(), + {alpha_, + graph_.max_degree(), + get_construction_window_size(), + get_max_candidates(), + prune_to_, + get_full_search_history()}, + get_search_parameters()}; + } + static constexpr lib::Version save_version = lib::Version(0, 0, 0); void save( const std::filesystem::path& config_directory, @@ -1003,22 +1015,12 @@ class MutableVamanaIndex { lib::save_to_disk( lib::SaveOverride([&](const lib::SaveContext& ctx) { // Save the construction parameters. - auto parameters = VamanaIndexParameters{ - entry_point_.front(), - {alpha_, - graph_.max_degree(), - get_construction_window_size(), - get_max_candidates(), - prune_to_, - get_full_search_history()}, - get_search_parameters()}; - return lib::SaveTable( "vamana_dynamic_auxiliary_parameters", save_version, { {"name", lib::save(name())}, - {"parameters", lib::save(parameters, ctx)}, + {"parameters", lib::save(parameters(), ctx)}, {"translation", lib::save(translator_, ctx)}, } ); @@ -1032,6 +1034,31 @@ class MutableVamanaIndex { lib::save_to_disk(graph_, graph_directory); } + void save(std::ostream& os) { + // Post-consolidation, all entries should be "valid". + // Therefore, we don't need to save the slot metadata. + consolidate(); + compact(); + + lib::begin_serialization(os); + auto save_table = lib::SaveTable( + "vamana_dynamic_auxiliary_parameters", + save_version, + { + {"name", lib::save(name())}, + {"parameters", lib::save(parameters())}, + {"translation", lib::detail::exit_hook(translator_.save_table())}, + } + ); + lib::save_to_stream(save_table, os); + translator_.save(os); + + // Save the dataset. + lib::save_to_stream(data_, os); + // Save the graph. + lib::save_to_stream(graph_, os); + } + ///// ///// Calibrate ///// diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index b7c136645..02a93c615 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -85,8 +85,7 @@ struct VamanaIndexParameters { static constexpr lib::Version save_version = lib::Version(0, 0, 3); static constexpr std::string_view serialization_schema = "vamana_index_parameters"; - // Save and Reload. - lib::SaveTable save() const { + lib::SaveTable metadata() const { return lib::SaveTable( serialization_schema, save_version, @@ -97,6 +96,9 @@ struct VamanaIndexParameters { ); } + // Save to Table and Reload. + lib::SaveTable save() const { return metadata(); } + static bool check_load_compatibility(std::string_view schema, lib::Version version) { return schema == serialization_schema && version <= save_version; } @@ -814,6 +816,20 @@ class VamanaIndex { lib::save_to_disk(graph_, graph_directory); } + void save(std::ostream& os) const { + // Construct and save runtime parameters. + auto parameters = VamanaIndexParameters{ + entry_point_.front(), build_parameters_, get_search_parameters()}; + + lib::begin_serialization(os); + // Config + lib::save_to_stream(parameters, os); + // Data + lib::save_to_stream(data_, os); + // // Graph + lib::save_to_stream(graph_, os); + } + ///// Calibration // Return the maximum degree of the graph. @@ -1006,6 +1022,66 @@ auto auto_assemble( return index; } +template < + typename LazyGraphLoader, + typename LazyDataLoader, + typename Distance, + typename ThreadPoolProto> +auto auto_assemble( + const lib::detail::Deserializer& deserializer, + std::istream& is, + LazyGraphLoader graph_loader, + LazyDataLoader data_loader, + Distance distance, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() +) { + using Data = decltype(data_loader()); + using Graph = decltype(graph_loader()); + auto config_loader = [&] { + return lib::load_from_stream(deserializer, is); + }; + + std::optional config; + std::optional data; + std::optional graph; + + if (deserializer.is_native()) { + // Order is always config->data->graph. + config.emplace(config_loader()); + data.emplace(data_loader()); + graph.emplace(graph_loader()); + } else { + // Directory packing order is filesystem-dependent. + // Read 3 data blocks: config, data and graph in a corresponding order + for (int data_block_idx = 0; data_block_idx < 3; ++data_block_idx) { + auto name = deserializer.read_name_in_advance(is); + if (name.starts_with("config/")) { + config.emplace(config_loader()); + } else if (name.starts_with("data/")) { + data.emplace(data_loader()); + } else if (name.starts_with("graph/")) { + graph.emplace(graph_loader()); + } else { + throw ANNEXCEPTION("The stream is corrupted!"); + } + } + } + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + // Extract the index type of the provided graph. + using I = typename Graph::index_type; + auto index = VamanaIndex{ + std::move(*graph), + std::move(*data), + I{}, + std::move(distance), + std::move(threadpool), + std::move(logger)}; + index.apply(*config); + return index; +} + /// @brief Verify parameters and set defaults if needed template void verify_and_set_default_index_parameters( diff --git a/include/svs/lib/saveload/load.h b/include/svs/lib/saveload/load.h index 09e072dee..c3e5591b8 100644 --- a/include/svs/lib/saveload/load.h +++ b/include/svs/lib/saveload/load.h @@ -883,6 +883,10 @@ class Deserializer { } void read_size(std::istream& stream) const { + if (skip_next_name_) { + throw ANNEXCEPTION("Error in deserialization: read_size() shouldn't follow " + "read_name_in_advance()!"); + } if (scheme_ == SerializationScheme::legacy) { lib::StreamArchiver::size_type size = 0; lib::StreamArchiver::read_size(stream, size); @@ -897,7 +901,7 @@ class Deserializer { }; inline ContextFreeSerializedObject -begin_deserialization(const Deserializer& deserializer, std::istream& stream) { +read_metadata(const Deserializer& deserializer, std::istream& stream) { deserializer.read_name(stream); if (!stream) { throw ANNEXCEPTION("Error reading from stream!"); @@ -957,6 +961,10 @@ T load_from_disk(const std::filesystem::path& path, Args&&... args) { return lib::load_from_disk(Loader(), path, SVS_FWD(args)...); } +template +concept LoadableFromTable = requires(const T& x +) { T::load(std::declval(), std::declval()...); }; + ///// load_from_stream template T load_from_stream( @@ -964,12 +972,25 @@ T load_from_stream( const detail::Deserializer& deserializer, std::istream& stream, Args&&... args -) { - // At this point, we will try the saving/loading framework to load the object. - // Here we go! +) + requires LoadableFromTable +{ + // Object is loadable from it's toml::table + return lib::load(loader, detail::read_metadata(deserializer, stream), SVS_FWD(args)...); +} + +template +T load_from_stream( + const Loader& loader, + const detail::Deserializer& deserializer, + std::istream& stream, + Args&&... args +) + requires(!LoadableFromTable) +{ return lib::load( loader, - detail::begin_deserialization(deserializer, stream), + detail::read_metadata(deserializer, stream), deserializer, stream, SVS_FWD(args)... diff --git a/include/svs/lib/saveload/save.h b/include/svs/lib/saveload/save.h index c609151e1..edd627b75 100644 --- a/include/svs/lib/saveload/save.h +++ b/include/svs/lib/saveload/save.h @@ -326,7 +326,7 @@ void save_node_to_stream( Nodelike&& node, std::ostream& os, const lib::Version& version = CURRENT_SAVE_VERSION ) { auto top_table = toml::table( - {{config_version_key, version.str()}, {config_object_key, SVS_FWD(node)}} + {{config_version_key, version.str()}, {config_object_key, SVS_FWD(exit_hook(node))}} ); StreamArchiver::write_table(os, top_table); @@ -381,13 +381,16 @@ inline void begin_serialization(std::ostream& os) { lib::StreamArchiver::write_size(os, lib::StreamArchiver::magic_number); } +inline void save_to_stream(const lib::SaveTable& x, std::ostream& os) { + detail::save_node_to_stream(x, os); +} + template void save_to_stream(const T& x, std::ostream& os) { - if constexpr (requires { x.save_table(); }) { - auto save_table = x.save_table(); - detail::save_node_to_stream(detail::exit_hook(save_table), os); - x.save(os); - } else if constexpr (std::is_same_v) { - detail::save_node_to_stream(detail::exit_hook(x), os); + if constexpr (requires { x.metadata(); }) { + save_to_stream(x.metadata(), os); + if constexpr (requires { x.save(os); }) { + x.save(os); + } } else { static_assert(sizeof(T) == 0, "Type not stream-serializable"); } diff --git a/include/svs/orchestrators/vamana.h b/include/svs/orchestrators/vamana.h index 6b698c4f9..65905cb8c 100644 --- a/include/svs/orchestrators/vamana.h +++ b/include/svs/orchestrators/vamana.h @@ -177,15 +177,7 @@ class VamanaImpl : public manager::ManagerImpl { void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_vamana_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto graph_dir = tempdir.get() / "graph"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(graph_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, graph_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current Vamana backend doesn't support saving!"); } @@ -474,32 +466,45 @@ class Vamana : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); - - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!"); - } - - const auto graph_path = tempdir.get() / "graph"; - if (!fs::is_directory(graph_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!"); - } - - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + using GraphType = svs::GraphLoader<>::return_type; + if constexpr (std::is_same_v) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return make_vamana>( + AssembleTag(), + deserializer, + stream, + // lazy-loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream( + deserializer, stream, SVS_FWD(data_args)... + ); + }, + distance_function, + std::move(threadpool) + ); + }); + } else { + return make_vamana>( + AssembleTag(), + deserializer, + stream, + // lazy-loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream( + deserializer, stream, SVS_FWD(data_args)... + ); + }, + distance, + std::move(threadpool) + ); } - - return assemble( - config_path, - svs::GraphLoader{graph_path}, - lib::load_from_disk(data_path, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)) - ); } /// diff --git a/include/svs/quantization/scalar/scalar.h b/include/svs/quantization/scalar/scalar.h index a2244d1fd..992d533cf 100644 --- a/include/svs/quantization/scalar/scalar.h +++ b/include/svs/quantization/scalar/scalar.h @@ -497,6 +497,18 @@ class SQDataset { ); } + lib::SaveTable metadata() const { + return lib::SaveTable( + serialization_schema, + save_version, + {{"data", lib::detail::exit_hook(data_.metadata())}, + {"scale", lib::save(scale_)}, + {"bias", lib::save(bias_)}} + ); + } + + void save(std::ostream& os) const { data_.save(os); } + /// @brief Load dataset from a file. static SQDataset load(const lib::LoadTable& table, const allocator_type& allocator = {}) { @@ -506,6 +518,19 @@ class SQDataset { lib::load_at(table, "bias")}; } + /// @brief Load dataset from a stream. + static SQDataset load( + const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, + std::istream& is, + const allocator_type& allocator = {} + ) { + return SQDataset{ + SVS_LOAD_MEMBER_AT_(table, data, deserializer, is, allocator), + lib::load_at(table, "scale"), + lib::load_at(table, "bias")}; + } + /// @brief Prefetch data in the dataset. void prefetch(size_t i) const { data_.prefetch(i); } }; diff --git a/tests/svs/index/vamana/index.cpp b/tests/svs/index/vamana/index.cpp index b94b902b8..67c5aedd4 100644 --- a/tests/svs/index/vamana/index.cpp +++ b/tests/svs/index/vamana/index.cpp @@ -37,6 +37,7 @@ // svsbenchmark #include "svs-benchmark/benchmark.h" // stl +#include #include namespace { @@ -181,6 +182,126 @@ CATCH_TEST_CASE("Static VamanaIndex Per-Index Logging", "[logging]") { CATCH_REQUIRE(captured_logs[2].find("Batch Size:") != std::string::npos); } +CATCH_TEST_CASE("Vamana Index Save and Load", "[vamana][index][saveload]") { + const size_t N = 128; + using Eltype = float; + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto graph = svs::graphs::SimpleGraph(data.size(), 64); + svs::distance::DistanceL2 distance_function; + uint32_t entry_point = 0; + auto threadpool = svs::threads::DefaultThreadPool(1); + + // Build the VamanaIndex with the test logger + svs::index::vamana::VamanaBuildParameters buildParams(1.2, 64, 10, 20, 10, true); + svs::index::vamana::VamanaIndex index( + buildParams, + std::move(graph), + std::move(data), + entry_point, + distance_function, + std::move(threadpool) + ); + + const size_t NUM_NEIGHBORS = 10; + auto queries = test_dataset::queries(); + auto search_params = svs::index::vamana::VamanaSearchParameters{}; + search_params.buffer_config_ = svs::index::vamana::SearchBufferConfig{NUM_NEIGHBORS}; + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search(results.view(), queries.cview(), search_params); + + CATCH_SECTION("Load Vamana Index being serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_vamana_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto graph_dir = tempdir.get() / "graph"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(graph_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, graph_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + + using Data_t = svs::data::SimpleData; + using GraphType = svs::GraphLoader<>::return_type; + + auto loaded_index = svs::index::vamana::auto_assemble( + deserializer, + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy data loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + distance_function, + svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load Vamana Index being serialized natively to stream") { + std::stringstream stream; + index.save(stream); + + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + + using Data_t = svs::data::SimpleData; + using GraphType = svs::GraphLoader<>::return_type; + + auto loaded_index = svs::index::vamana::auto_assemble( + deserializer, + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy data loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + distance_function, + svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } +} + CATCH_TEST_CASE("Vamana Index Default Parameters", "[long][parameter][vamana]") { using Catch::Approx; std::filesystem::path data_path = test_dataset::data_svs_file();