Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions include/svs/core/data/simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class GenericSerializer {
}

template <data::ImmutableMemoryDataset Data>
static lib::SaveTable save_table(const Data& data) {
static lib::SaveTable metadata(const Data& data) {
using T = typename Data::element_type;
auto table = lib::SaveTable(
serialization_schema,
Expand All @@ -92,8 +92,8 @@ class GenericSerializer {

template <data::ImmutableMemoryDataset Data, class FileName_t>
static lib::SaveTable
save_table(const Data& data, const FileName_t& filename, const lib::UUID& uuid) {
auto table = save_table(data);
metadata(const Data& data, const FileName_t& filename, const lib::UUID& uuid) {
auto table = metadata(data);
table.insert("binary_file", filename);
table.insert("uuid", uuid.str());
return table;
Expand All @@ -105,7 +105,7 @@ class GenericSerializer {
auto uuid = lib::UUID{};
auto filename = ctx.generate_name("data");
io::save(data, io::NativeFile(filename), uuid);
return save_table(data, lib::save(filename.filename()), uuid);
return metadata(data, lib::save(filename.filename()), uuid);
}

template <data::ImmutableMemoryDataset Data>
Expand Down Expand Up @@ -452,7 +452,7 @@ class SimpleData {

// Write this dataset to the stream ``os`` by delegating to ``GenericSerializer::save``.
void save(std::ostream& os) const { return GenericSerializer::save(*this, os); }

lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); }
// Return the table produced by ``GenericSerializer::metadata`` for this dataset.
lib::SaveTable metadata() const { return GenericSerializer::metadata(*this); }

static bool check_load_compatibility(std::string_view schema, lib::Version version) {
return GenericSerializer::check_compatibility(schema, version);
Expand Down Expand Up @@ -871,7 +871,7 @@ class SimpleData<T, Extent, Blocked<Alloc>> {

// Write this dataset to the stream ``os`` by delegating to ``GenericSerializer::save``.
void save(std::ostream& os) const { return GenericSerializer::save(*this, os); }

lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); }
// Return the table produced by ``GenericSerializer::metadata`` for this dataset.
lib::SaveTable metadata() const { return GenericSerializer::metadata(*this); }

static bool check_load_compatibility(std::string_view schema, lib::Version version) {
return GenericSerializer::check_compatibility(schema, version);
Expand Down
79 changes: 72 additions & 7 deletions include/svs/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,22 +276,36 @@ template <std::unsigned_integral Idx, data::MemoryDataset Data> class SimpleGrap
///// Saving
static constexpr lib::Version save_version = lib::Version(0, 0, 0);
static constexpr std::string_view serialization_schema = "default_graph";
lib::SaveTable save(const lib::SaveContext& ctx) const {
auto uuid = lib::UUID{};
auto filename = ctx.generate_name("graph");
io::save(data_, io::NativeFile(filename), uuid);
return lib::SaveTable(

lib::SaveTable metadata() const {
auto table = lib::SaveTable(
serialization_schema,
save_version,
{{"name", "graph"},
{"binary_file", lib::save(filename.filename())},
{"max_degree", lib::save(max_degree())},
{"num_vertices", lib::save(n_nodes())},
{"uuid", lib::save(uuid.str())},
{"eltype", lib::save(datatype_v<Idx>)}}
);
return table;
}

/// Build the graph metadata table and attach file-specific entries.
///
/// Starts from the table returned by the zero-argument ``metadata()`` overload, then
/// inserts ``filename`` under the key "binary_file" and ``uuid.str()`` under "uuid".
///
/// @param filename Value stored under "binary_file".
/// @param uuid Identifier whose string form is stored under "uuid".
/// @returns The augmented table.
template <class FileName>
lib::SaveTable metadata(const FileName& filename, const lib::UUID& uuid) const {
    lib::SaveTable result = metadata();
    result.insert("binary_file", filename);
    result.insert("uuid", uuid.str());
    return result;
}

/// Save the graph data to a file named by ``ctx`` and return the describing table.
///
/// A fresh ``lib::UUID`` is created, the data is written with ``io::save`` to the path
/// obtained from ``ctx.generate_name("graph")``, and the returned table records the
/// file's name (without directory) together with that UUID.
lib::SaveTable save(const lib::SaveContext& ctx) const {
    auto uuid = lib::UUID{};
    auto path = ctx.generate_name("graph");
    io::save(data_, io::NativeFile(path), uuid);

    // Only the trailing file name component is recorded in the table.
    auto binary_name = lib::save(path.filename());
    return metadata(binary_name, uuid);
}

// Write the graph's underlying data (``data_``) to the stream ``os`` via ``io::save``.
void save(std::ostream& os) const { io::save(data_, os); }

protected:
template <lib::LazyInvocable<data_type> F, typename... Args>
static lib::lazy_result_t<F, data_type>
Expand All @@ -317,6 +331,39 @@ template <std::unsigned_integral Idx, data::MemoryDataset Data> class SimpleGrap
return lazy(data_type::load(binaryfile.value(), std::forward<Args>(args)...));
}

/// Load graph data from a stream, using ``table`` for the graph's shape.
///
/// @param table Metadata table; "eltype", "num_vertices" and "max_degree" are read.
/// @param lazy Callable applied to the populated ``data_type``; its result is returned.
/// @param deserializer Used to consume the name, size and ``io::v1::Header`` entries
///     that precede the payload in the stream.
/// @param is Stream positioned at the start of the graph's entry.
/// @param alloc_args Extra arguments perfectly forwarded to the ``data_type``
///     constructor (consistent with the file-based ``load`` overload).
/// @returns The result of invoking ``lazy`` on the loaded data.
/// @throws ANNException if the stored "eltype" does not match ``Idx``.
template <lib::LazyInvocable<data_type> F, typename... AllocArgs>
static lib::lazy_result_t<F, data_type> load(
    const lib::ContextFreeLoadTable& table,
    const F& lazy,
    const lib::detail::Deserializer& deserializer,
    std::istream& is,
    AllocArgs&&... alloc_args
) {
    // Perform a sanity check on the element type.
    // Make sure we're loading the correct kind.
    auto eltype = lib::load_at<DataType>(table, "eltype");
    if (eltype != datatype_v<Idx>) {
        throw ANNEXCEPTION(
            "Trying to load a graph with adjacency list types {} to a graph with "
            "adjacency list types {}.",
            name(eltype),
            name<datatype_v<Idx>>()
        );
    }

    size_t num_vertices = lib::load_at<size_t>(table, "num_vertices");
    size_t max_degree = lib::load_at<size_t>(table, "max_degree");

    // Skip legacy stream metadata (no-op for native serialization).
    deserializer.read_name(is);
    deserializer.read_size(is);
    deserializer.read_binary<io::v1::Header>(is);

    // NOTE(review): the "+ 1" presumably reserves one slot per vertex for the neighbor
    // count -- confirm against the graph storage layout.
    // Forward the allocator arguments so rvalues are moved rather than copied.
    auto data = data_type(
        num_vertices, max_degree + 1, std::forward<AllocArgs>(alloc_args)...
    );
    io::populate(is, data);
    return lazy(std::move(data));
}

protected:
data_type data_;
Idx max_degree_;
Expand Down Expand Up @@ -366,6 +413,16 @@ class SimpleGraph : public SimpleGraphBase<Idx, data::SimpleData<Idx, Dynamic, A
return parent_type::load(table, lazy, allocator);
}

/// Load a ``SimpleGraph`` from a stream using already-parsed metadata.
///
/// Delegates to the stream-based ``parent_type::load`` overload, supplying a lazy
/// callable that wraps the loaded data in a ``SimpleGraph``.
///
/// @param table Metadata table describing the graph.
/// @param deserializer Deserializer used to read from ``is``.
/// @param is Input stream containing the graph.
/// @param allocator Allocator passed through to the data constructor.
static constexpr SimpleGraph load(
    const lib::ContextFreeLoadTable& table,
    const lib::detail::Deserializer& deserializer,
    std::istream& is,
    const Alloc& allocator = {}
) {
    auto wrap_in_graph =
        lib::Lazy([](data_type loaded) { return SimpleGraph(std::move(loaded)); });
    return parent_type::load(table, wrap_in_graph, deserializer, is, allocator);
}

static constexpr SimpleGraph
load(const std::filesystem::path& path, const Alloc& allocator = {}) {
if (data::detail::is_likely_reload(path)) {
Expand All @@ -374,6 +431,14 @@ class SimpleGraph : public SimpleGraphBase<Idx, data::SimpleData<Idx, Dynamic, A
return SimpleGraph(data_type::load(path, allocator));
}
}

// Load a ``SimpleGraph`` from the stream ``is`` by delegating to
// ``lib::load_from_stream<SimpleGraph>`` with the given deserializer and allocator.
static constexpr SimpleGraph load(
    const lib::detail::Deserializer& deserializer,
    std::istream& is,
    const Alloc& allocator = {}
) {
    return lib::load_from_stream<SimpleGraph>(deserializer, is, allocator);
}
};

template <typename Idx, typename A1, typename A2>
Expand Down
15 changes: 8 additions & 7 deletions include/svs/core/translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,16 +380,17 @@ class IDTranslator {
return translator;
}

static IDTranslator load(
const lib::ContextFreeLoadTable& table,
const lib::detail::Deserializer& deserializer,
std::istream& is
) {
IDTranslator::validate(table);
static IDTranslator
load(const lib::detail::Deserializer& deserializer, std::istream& is) {
auto table = lib::detail::read_metadata(deserializer, is);
auto translation = table.template cast<toml::table>()
.at("translation")
.template cast<toml::table>();
IDTranslator::validate(translation);
deserializer.read_name(is);
deserializer.read_size(is);

return IDTranslator::load(table, is);
return IDTranslator::load(translation, is);
}

static IDTranslator load(const lib::LoadTable& table) {
Expand Down
60 changes: 28 additions & 32 deletions include/svs/index/flat/dynamic_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -788,16 +788,6 @@ auto auto_dynamic_assemble(
);
}

auto load_translator(const lib::detail::Deserializer& deserializer, std::istream& is) {
auto table = lib::detail::begin_deserialization(deserializer, is);
auto translator = IDTranslator::load(
table.template cast<toml::table>().at("translation").template cast<toml::table>(),
deserializer,
is
);
return translator;
}

template <typename LazyDataLoader, typename Distance, typename ThreadPoolProto>
auto auto_dynamic_assemble(
const lib::detail::Deserializer& deserializer,
Expand All @@ -813,38 +803,44 @@ auto auto_dynamic_assemble(
bool SVS_UNUSED(debug_load_from_static) = false,
svs::logging::logger_ptr logger = svs::logging::get()
) {
IDTranslator translator;
// In legacy deserialization the order of directories isn't determined.
auto name = deserializer.read_name_in_advance(is);

// We have to hardcode the file_name for legacy mode, since it was hardcoded when legacy
// model was serialized
bool translator_before_data =
(name == "config/svs_config.toml") || deserializer.is_native();
if (translator_before_data) {
translator = load_translator(deserializer, is);
}

// Load the dataset
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
auto data = svs::detail::dispatch_load(data_loader(), threadpool);
auto datasize = data.size();

if (!translator_before_data) {
translator = load_translator(deserializer, is);
using Data = decltype(data_loader());
auto config_loader = [&] { return IDTranslator::load(deserializer, is); };

std::optional<IDTranslator> config;
std::optional<Data> data;

if (deserializer.is_native()) {
// Order is always config->data.
config.emplace(config_loader());
data.emplace(data_loader());
} else {
// Directory packing order is filesystem-dependent.
// Read 2 data blocks: config and data in a corresponding order.
for (int data_block_idx = 0; data_block_idx < 2; ++data_block_idx) {
auto name = deserializer.read_name_in_advance(is);
if (name.starts_with("config/")) {
config.emplace(config_loader());
} else if (name.starts_with("data/")) {
data.emplace(data_loader());
} else {
throw ANNEXCEPTION("The stream is corrupted!");
}
}
}

// Validate the translator
auto translator_size = translator.size();
auto translator_size = config->size();
auto datasize = data->size();
if (translator_size != datasize) {
throw ANNEXCEPTION(
"Translator has {} IDs but should have {}", translator_size, datasize
);
}

auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
return DynamicFlatIndex(
std::move(data),
std::move(translator),
std::move(*data),
std::move(*config),
std::move(distance),
std::move(threadpool),
std::move(logger)
Expand Down
49 changes: 38 additions & 11 deletions include/svs/index/vamana/dynamic_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,18 @@ class MutableVamanaIndex {

///// Saving

/// Collect the index's current parameters into a ``VamanaIndexParameters`` value.
///
/// The result is built from, in order: the first entry point, a nested group holding
/// alpha, the graph's max degree, the construction window size, the max candidates,
/// ``prune_to_`` and the full-search-history flag, and finally the current search
/// parameters.
VamanaIndexParameters parameters() const {
    auto index_params = VamanaIndexParameters{
        entry_point_.front(),
        {alpha_,
         graph_.max_degree(),
         get_construction_window_size(),
         get_max_candidates(),
         prune_to_,
         get_full_search_history()},
        get_search_parameters()};
    return index_params;
}

static constexpr lib::Version save_version = lib::Version(0, 0, 0);
void save(
const std::filesystem::path& config_directory,
Expand All @@ -1003,22 +1015,12 @@ class MutableVamanaIndex {
lib::save_to_disk(
lib::SaveOverride([&](const lib::SaveContext& ctx) {
// Save the construction parameters.
auto parameters = VamanaIndexParameters{
entry_point_.front(),
{alpha_,
graph_.max_degree(),
get_construction_window_size(),
get_max_candidates(),
prune_to_,
get_full_search_history()},
get_search_parameters()};

return lib::SaveTable(
"vamana_dynamic_auxiliary_parameters",
save_version,
{
{"name", lib::save(name())},
{"parameters", lib::save(parameters, ctx)},
{"parameters", lib::save(parameters(), ctx)},
{"translation", lib::save(translator_, ctx)},
}
);
Expand All @@ -1032,6 +1034,31 @@ class MutableVamanaIndex {
lib::save_to_disk(graph_, graph_directory);
}

// Serialize the whole index to the stream ``os``.
//
// Not ``const``: ``consolidate()`` and ``compact()`` are invoked before anything is
// written.
//
// Write order after ``lib::begin_serialization(os)``:
//   1. the "vamana_dynamic_auxiliary_parameters" table (name, parameters, translation),
//   2. the translator via ``translator_.save(os)``,
//   3. the dataset,
//   4. the graph.
// NOTE(review): a stream loader presumably has to consume these in the same order --
// confirm against the matching load path.
void save(std::ostream& os) {
    // Post-consolidation, all entries should be "valid".
    // Therefore, we don't need to save the slot metadata.
    consolidate();
    compact();

    lib::begin_serialization(os);
    // Same schema name and version as the directory-based ``save`` overload.
    auto save_table = lib::SaveTable(
        "vamana_dynamic_auxiliary_parameters",
        save_version,
        {
            {"name", lib::save(name())},
            {"parameters", lib::save(parameters())},
            {"translation", lib::detail::exit_hook(translator_.save_table())},
        }
    );
    lib::save_to_stream(save_table, os);
    // The translator's own stream payload follows the table.
    translator_.save(os);

    // Save the dataset.
    lib::save_to_stream(data_, os);
    // Save the graph.
    lib::save_to_stream(graph_, os);
}

/////
///// Calibrate
/////
Expand Down
Loading
Loading