Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions bindings/cpp/include/svs/runtime/flat_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,13 @@ struct SVS_RUNTIME_API FlatIndex {

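// Load from a memory buffer.
// The buffer is expected to be in the format produced by save().
// If read_bytes is non-null, it receives the stream position reached after loading,
// i.e. the number of bytes of the buffer that were consumed.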
static Status
map_to_memory(FlatIndex** index, void* data, size_t size, MetricType metric) noexcept;
static Status map_to_memory(
FlatIndex** index,
void* data,
size_t size,
MetricType metric,
size_t* read_bytes = nullptr
) noexcept;
};

} // namespace v0
Expand Down
3 changes: 2 additions & 1 deletion bindings/cpp/include/svs/runtime/vamana_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ struct SVS_RUNTIME_API VamanaIndex {
void* data,
size_t size,
MetricType metric,
StorageKind storage_kind
StorageKind storage_kind,
size_t* read_bytes = nullptr
) noexcept;
};

Expand Down
6 changes: 5 additions & 1 deletion bindings/cpp/src/flat_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,18 @@ FlatIndex::map_to_file(FlatIndex** index, const char* path, MetricType metric) n
}

Status FlatIndex::map_to_memory(
FlatIndex** index, void* data, size_t size, MetricType metric
FlatIndex** index, void* data, size_t size, MetricType metric, size_t* read_bytes
) noexcept {
*index = nullptr;
return runtime_error_wrapper([&] {
auto sp = std::span(reinterpret_cast<char*>(data), size);
auto is = std::make_unique<svs::io::ispanstream>(sp);
std::unique_ptr<FlatIndexImpl> impl{
FlatIndexImpl::map_to_stream(std::move(is), metric)};
if (read_bytes) {
auto pos = impl->get_mapped_stream()->tellg();
*read_bytes = static_cast<size_t>(pos);
}
*index = new FlatIndexManager{std::move(impl)};
return Status_Ok;
});
Expand Down
2 changes: 2 additions & 0 deletions bindings/cpp/src/flat_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class FlatIndexImpl {
});
}

// Non-owning accessor: hands back whatever mapped_stream_.get() yields.
// NOTE(review): presumably null when the index was not created from a stream --
// confirm before dereferencing at call sites.
std::istream* get_mapped_stream() const {
    std::istream* stream = mapped_stream_.get();
    return stream;
}

protected:
// Constructor used during loading
FlatIndexImpl(
Expand Down
7 changes: 6 additions & 1 deletion bindings/cpp/src/vamana_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ Status VamanaIndex::map_to_memory(
void* data,
size_t size,
MetricType metric,
StorageKind storage_kind
StorageKind storage_kind,
size_t* read_bytes
) noexcept {
using Impl = VamanaIndexImpl;
*index = nullptr;
Expand All @@ -188,6 +189,10 @@ Status VamanaIndex::map_to_memory(
auto is = std::make_unique<svs::io::ispanstream>(sp);
std::unique_ptr<Impl> impl{
Impl::map_to_stream(std::move(is), metric, storage_kind)};
if (read_bytes) {
auto pos = impl->get_mapped_stream()->tellg();
*read_bytes = static_cast<size_t>(pos);
}
*index = new VamanaIndexManagerBase<Impl>{std::move(impl)};
});
}
Expand Down
2 changes: 2 additions & 0 deletions bindings/cpp/src/vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ class VamanaIndexImpl {
);
}

// Non-owning accessor: hands back whatever mapped_stream_.get() yields.
// NOTE(review): presumably null when the index was not created from a stream --
// confirm before dereferencing at call sites.
std::istream* get_mapped_stream() const {
    std::istream* stream = mapped_stream_.get();
    return stream;
}

// Data members
protected:
size_t dim_;
Expand Down
177 changes: 177 additions & 0 deletions bindings/cpp/tests/runtime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,119 @@ void write_and_map_index(
Index::destroy(loaded);
}

// Template function to write and map an index from an in-memory buffer
template <typename Index, typename BuildFunc>
void write_and_map_index_from_memory(
BuildFunc build_func,
const std::vector<float>& xb,
size_t n,
size_t d,
std::optional<svs::runtime::v0::StorageKind> storage_kind = std::nullopt,
svs::runtime::v0::MetricType metric = svs::runtime::v0::MetricType::L2
) {
// Build index
Index* index = nullptr;
svs::runtime::v0::Status status = build_func(&index);

// Stop here if storage kind is not supported on this platform
if constexpr (std::is_base_of_v<svs::runtime::v0::VamanaIndex, Index>) {
if (storage_kind.has_value()) {
if (!Index::check_storage_kind(*storage_kind).ok()) {
CATCH_REQUIRE(!status.ok());
return;
}
}
}
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(index != nullptr);

// Add data to index
if constexpr (std::is_same_v<Index, svs::runtime::v0::FlatIndex> || std::is_same_v<Index, svs::runtime::v0::VamanaIndex>) {
status = index->add(n, xb.data());
} else {
std::vector<size_t> labels(n);
std::iota(labels.begin(), labels.end(), 0);
status = index->add(n, labels.data(), xb.data());
}
CATCH_REQUIRE(status.ok());

std::stringstream ss{};
status = index->save(ss);
CATCH_REQUIRE(status.ok());

auto payload = ss.str();
CATCH_REQUIRE(!payload.empty());
const size_t serialized_size = payload.size();

std::vector<char> serialized(payload.size() + 1); // +1 for sentinel
std::copy(payload.begin(), payload.end(), serialized.begin());
// Keep a trailing sentinel to validate read_bytes for map_to_memory.
serialized[payload.size()] = static_cast<char>(0x5A);

Index* loaded = nullptr;
size_t read_bytes = 0;
if constexpr (std::is_same_v<Index, svs::runtime::v0::FlatIndex>) {
status = Index::map_to_memory(
&loaded, serialized.data(), serialized.size(), metric, &read_bytes
);
} else {
CATCH_REQUIRE(storage_kind.has_value());
status = Index::map_to_memory(
&loaded,
serialized.data(),
serialized.size(),
metric,
*storage_kind,
&read_bytes
);
}
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(loaded != nullptr);

CATCH_REQUIRE(read_bytes < serialized.size());
CATCH_REQUIRE(read_bytes == serialized_size);
CATCH_REQUIRE(serialized[read_bytes] == static_cast<char>(0x5A));

// Basic query against mapped index
const int nq = 5;
const float* xq = xb.data();
const int k = 10;
std::vector<float> distances(nq * k);
std::vector<size_t> result_labels(nq * k);
status = loaded->search(nq, xq, k, distances.data(), result_labels.data());
CATCH_REQUIRE(status.ok());

// Verify nullptr read_bytes path for map_to_memory.
Index* loaded_without_read_bytes = nullptr;
if constexpr (std::is_same_v<Index, svs::runtime::v0::FlatIndex>) {
status = Index::map_to_memory(
&loaded_without_read_bytes,
serialized.data(),
serialized.size(),
metric,
nullptr
);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(loaded_without_read_bytes != nullptr);
} else {
status = Index::map_to_memory(
&loaded_without_read_bytes,
serialized.data(),
serialized.size(),
metric,
*storage_kind,
nullptr
);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(loaded_without_read_bytes != nullptr);
}

// Clean up
Index::destroy(index);
Index::destroy(loaded);
Index::destroy(loaded_without_read_bytes);
}

// Helper that writes and reads an index of requested size
// Reports memory usage
UsageInfo run_save_and_load_test(
Expand Down Expand Up @@ -540,6 +653,18 @@ CATCH_TEST_CASE("FlatIndexWriteAndMap", "[runtime]") {
write_and_map_index<svs::runtime::v0::FlatIndex>(build_func, test_data, test_n, test_d);
}

// Flat index: save to a stream, then map it back from an in-memory buffer.
CATCH_TEST_CASE("FlatIndexWriteAndMapToMemory", "[runtime]") {
    namespace rt = svs::runtime::v0;

    // Builder for an L2 flat index of dimension test_d.
    auto make_flat_index = [](rt::FlatIndex** out) {
        return rt::FlatIndex::build(out, test_d, rt::MetricType::L2);
    };

    const auto& vectors = get_test_data();
    write_and_map_index_from_memory<rt::FlatIndex>(
        make_flat_index, vectors, test_n, test_d
    );
}

CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") {
const auto& test_data = get_test_data();
// Build index
Expand Down Expand Up @@ -910,6 +1035,23 @@ CATCH_TEST_CASE("WriteAndMapStaticIndexSVS", "[runtime][static_vamana]") {
);
}

// Static Vamana index with FP32 storage: save, then map back from an in-memory buffer.
CATCH_TEST_CASE("WriteAndMapToMemoryStaticIndexSVS", "[runtime][static_vamana]") {
    namespace rt = svs::runtime::v0;

    // Builder for an L2 Vamana index of dimension test_d using FP32 storage.
    auto make_vamana_index = [](rt::VamanaIndex** out) {
        rt::VamanaIndex::BuildParams params{64};
        return rt::VamanaIndex::build(
            out, test_d, rt::MetricType::L2, rt::StorageKind::FP32, params
        );
    };

    const auto& vectors = get_test_data();
    write_and_map_index_from_memory<rt::VamanaIndex>(
        make_vamana_index, vectors, test_n, test_d, rt::StorageKind::FP32
    );
}

CATCH_TEST_CASE("WriteAndReadStaticIndexSVSFP16", "[runtime][static_vamana]") {
const auto& test_data = get_test_data();
auto build_func = [](svs::runtime::v0::VamanaIndex** index) {
Expand Down Expand Up @@ -1012,6 +1154,23 @@ CATCH_TEST_CASE("WriteAndMapStaticIndexSVSLVQ4x4", "[runtime][static_vamana]") {
);
}

// Static Vamana index with LVQ4x4 storage: save, then map back from an in-memory buffer.
CATCH_TEST_CASE("WriteAndMapToMemoryStaticIndexSVSLVQ4x4", "[runtime][static_vamana]") {
    namespace rt = svs::runtime::v0;

    // Builder for an L2 Vamana index of dimension test_d using LVQ4x4 storage.
    auto make_vamana_index = [](rt::VamanaIndex** out) {
        rt::VamanaIndex::BuildParams params{64};
        return rt::VamanaIndex::build(
            out, test_d, rt::MetricType::L2, rt::StorageKind::LVQ4x4, params
        );
    };

    const auto& vectors = get_test_data();
    write_and_map_index_from_memory<rt::VamanaIndex>(
        make_vamana_index, vectors, test_n, test_d, rt::StorageKind::LVQ4x4
    );
}

CATCH_TEST_CASE("WriteAndReadStaticIndexSVSVamanaLeanVec4x4", "[runtime][static_vamana]") {
const auto& test_data = get_test_data();
auto build_func = [](svs::runtime::v0::VamanaIndex** index) {
Expand Down Expand Up @@ -1048,6 +1207,24 @@ CATCH_TEST_CASE("WriteAndMapStaticIndexSVSVamanaLeanVec4x4", "[runtime][static_v
);
}

// Static LeanVec4x4 Vamana index: save, then map back from an in-memory buffer.
CATCH_TEST_CASE("WriteAndMapToMemoryStaticIndexSVSLeanVec4x4", "[runtime][static_vamana]") {
    namespace rt = svs::runtime::v0;

    // Builder for an L2 LeanVec index of dimension test_d; 32 is forwarded to
    // VamanaIndexLeanVec::build() ahead of the build parameters.
    auto make_leanvec_index = [](rt::VamanaIndex** out) {
        rt::VamanaIndex::BuildParams params{64};
        return rt::VamanaIndexLeanVec::build(
            out, test_d, rt::MetricType::L2, rt::StorageKind::LeanVec4x4, 32, params
        );
    };

    const auto& vectors = get_test_data();
    write_and_map_index_from_memory<rt::VamanaIndex>(
        make_leanvec_index, vectors, test_n, test_d, rt::StorageKind::LeanVec4x4
    );
}

CATCH_TEST_CASE("StaticIndexLeanVecWithTrainingData", "[runtime][static_vamana]") {
const auto& test_data = get_test_data();
const size_t leanvec_dims = 32;
Expand Down
Loading