Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bindings/cpp/include/svs/runtime/vamana_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ struct SVS_RUNTIME_API VamanaIndex {
IDFilter* filter = nullptr
) const noexcept = 0;

// Compute distance between stored vector `id` and `query` (dim floats).
//   distance: out-parameter; implementations store the computed distance in `*distance`.
//   id:       identifier of the stored vector.
//             NOTE(review): for dynamic indices this is presumably the external label
//             rather than an internal slot — confirm against the implementation.
//   query:    pointer to the query vector, read as `dim` consecutive floats where
//             `dim` is the dimensionality of the index.
// Returns a Status; the function is noexcept, so failures are expected to be
// reported through the returned Status rather than by throwing.
virtual Status
get_distance(float* distance, size_t id, const float* query) const noexcept = 0;

// Reconstruct `n` vectors by ID into `output` buffer (n * dim floats).
//   n:      number of vectors to reconstruct; also the number of entries read from `ids`.
//   ids:    pointer to `n` identifiers of stored vectors.
//   output: caller-provided buffer, treated as an n-by-dim array of floats
//           (presumably row i receives the vector for ids[i] — confirm).
// Returns a Status; the function is noexcept. Unlike get_distance(), this member is
// not const-qualified.
virtual Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept = 0;

// Utility function to check storage kind support.
// Static, so it can be called without an index instance; the outcome is conveyed
// through the returned Status.
// NOTE(review): presumably an OK status means `storage_kind` is supported by this
// build — confirm against the definition.
static Status check_storage_kind(StorageKind storage_kind) noexcept;

Expand Down
16 changes: 16 additions & 0 deletions bindings/cpp/src/dynamic_vamana_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,22 @@ struct DynamicVamanaIndexManagerBase : public DynamicVamanaIndex {
// Write the index to `out` by delegating to the implementation object.
// The call is routed through runtime_error_wrapper, which yields the Status
// returned to the caller (this function itself is noexcept).
Status save(std::ostream& out) const noexcept override {
    auto write_index = [&] { impl_->save(out); };
    return runtime_error_wrapper(write_index);
}

// Compute the distance between the stored vector `id` and `query`, storing it in
// `*distance`.
//   distance: out-parameter receiving the result, narrowed to float.
//   id:       identifier forwarded unchanged to the implementation.
//   query:    read as `impl_->dimensions()` consecutive floats.
// Returns the Status produced by runtime_error_wrapper.
Status
get_distance(float* distance, size_t id, const float* query) const noexcept override {
    return runtime_error_wrapper([&] {
        const auto dim = impl_->dimensions();
        const auto computed = impl_->get_distance(id, std::span<const float>{query, dim});
        *distance = static_cast<float>(computed);
    });
}

// Reconstruct the vectors named by `ids` into `output`.
//   n:      number of identifiers in `ids`, and number of rows in `output`.
//   ids:    pointer to `n` identifiers, forwarded as a span.
//   output: caller buffer viewed as an n-by-`impl_->dimensions()` float matrix.
// Returns the Status produced by runtime_error_wrapper.
Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override {
    return runtime_error_wrapper([&] {
        const auto dim = impl_->dimensions();
        impl_->reconstruct_at(
            svs::data::SimpleDataView<float>{output, n, dim},
            std::span<const size_t>{ids, n}
        );
    });
}
};
} // namespace

Expand Down
15 changes: 15 additions & 0 deletions bindings/cpp/src/dynamic_vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,117 +84,117 @@
impl_->add_points(data, labels);
}

// Search for the nearest neighbors of every query and write them into `result`.
//
//   result:  output view; `result.n_neighbors()` is the number of neighbors (k)
//            requested per query.
//   queries: query vectors; an empty set returns immediately without touching `result`.
//   params:  optional per-call search parameters (may be nullptr).
//   filter:  optional ID filter; when non-null only IDs accepted by
//            `filter->is_member()` are reported.
//
// Throws StatusException with NOT_INITIALIZED when the index has not been built
// (after filling `result` with unspecified sentinels) and INVALID_ARGUMENT when k == 0.
//
// The filtered path temporarily installs the per-call search parameters on the
// underlying index. The previous parameters are restored on every exit path,
// including when the sampling step, the filter or the parallel search throws.
void search(
    svs::QueryResultView<size_t> result,
    svs::data::ConstSimpleDataView<float> queries,
    const VamanaIndex::SearchParams* params = nullptr,
    IDFilter* filter = nullptr
) const {
    if (!impl_) {
        // Leave the caller with well-defined "no result" sentinels before failing.
        auto& dists = result.distances();
        std::fill(dists.begin(), dists.end(), Unspecify<float>());
        auto& inds = result.indices();
        std::fill(inds.begin(), inds.end(), Unspecify<size_t>());
        throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
    }

    if (queries.size() == 0) {
        return;
    }

    const size_t k = result.n_neighbors();
    if (k == 0) {
        throw StatusException{ErrorCode::INVALID_ARGUMENT, "k must be greater than 0"};
    }

    auto sp = make_search_parameters(params);

    // Simple search
    if (filter == nullptr) {
        impl_->search(result, queries, sp);
        return;
    }

    // Selective search with IDSelector
    auto old_sp = impl_->get_search_parameters();

    // Everything that runs while the per-call parameters are installed lives in this
    // closure so that a single try/catch below can guarantee they are restored.
    auto filtered_search = [&]() {
        impl_->set_search_parameters(sp);
        float filter_stop = 0.0f;
        bool filter_estimate_batch = true;
        if (params) {
            set_if_specified(filter_stop, params->filter_stop);
            set_if_specified(filter_estimate_batch, params->filter_estimate_batch);
        }
        const auto max_batch_size = impl_->size();

        // Pre-search filter sampling: estimate hit rate before graph traversal.
        size_t sampled = 0;
        size_t sample_hits = 0;
        const auto sws = sp.buffer_config_.get_search_window_size();
        const auto initial_batch_hint = std::max(k, sws);
        auto initial_batch_size = initial_batch_hint;
        if (filter_estimate_batch) {
            std::tie(sampled, sample_hits) = sample_filter_hits(
                *filter,
                max_batch_size,
                [this](size_t id) { return impl_->has_id(id); },
                sample_size_for_filter_stop(filter_stop)
            );
            if (should_stop_filtered_search(sampled, sample_hits, filter_stop)) {
                // Sampling triggered the stop criterion: report empty results.
                pad_empty_results(result, queries.size(), k);
                return;
            }
            initial_batch_size = predict_further_processing(
                sampled, sample_hits, k, initial_batch_hint, max_batch_size
            );
        }

        auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) {
            for (auto i : range) {
                auto query = queries.get_datum(i);
                auto iterator = impl_->batch_iterator(query);
                size_t found = 0;
                size_t total_checked = 0;
                auto batch_size = initial_batch_size;
                do {
                    batch_size = predict_further_processing(
                        total_checked, found, k, batch_size, max_batch_size
                    );
                    iterator.next(batch_size);
                    total_checked += iterator.size();
                    for (auto& neighbor : iterator.results()) {
                        if (filter->is_member(neighbor.id())) {
                            result.set(neighbor, i, found);
                            found++;
                            if (found == k) {
                                break;
                            }
                        }
                    }
                    if (should_stop_filtered_search(total_checked, found, filter_stop)) {
                        // Give up on this query; the padding below overwrites all k
                        // slots because `found` is reset.
                        found = 0;
                        break;
                    }
                } while (found < k && !iterator.done());

                // Pad results if not enough neighbors found
                if (found < k) {
                    for (size_t j = found; j < k; ++j) {
                        result.set(Neighbor{Unspecify<size_t>(), Unspecify<float>()}, i, j);
                    }
                }
            }
        };

        auto threadpool = default_threadpool();

        svs::threads::parallel_for(
            threadpool, svs::threads::StaticPartition{queries.size()}, search_closure
        );
    };

    try {
        filtered_search();
    } catch (...) {
        // Do not leave the per-call parameters installed on the index after a failure.
        impl_->set_search_parameters(old_sp);
        throw;
    }
    impl_->set_search_parameters(old_sp);
}

Check notice on line 197 in bindings/cpp/src/dynamic_vamana_index_impl.h

View check run for this annotation

codefactor.io / CodeFactor

bindings/cpp/src/dynamic_vamana_index_impl.h#L87-L197

Complex Method
void range_search(
svs::data::ConstSimpleDataView<float> queries,
float radius,
Expand Down Expand Up @@ -344,6 +344,21 @@
return remove(ids_to_delete);
}

// Return the distance between the stored vector `id` and `query`, as computed by
// the underlying index.
// Throws StatusException (NOT_INITIALIZED) when no index has been created yet.
double get_distance(size_t id, std::span<const float> query) const {
    if (impl_) {
        return impl_->get_distance(id, query);
    }
    throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}

// Reconstruct the vectors identified by `ids` into `dst` via the underlying index.
// Throws StatusException (NOT_INITIALIZED) when no index has been created yet.
void reconstruct_at(svs::data::SimpleDataView<float> dst, std::span<const size_t> ids) {
    if (!impl_) {
        throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
    }
    // The underlying index is handed a span of uint64_t, so the identifiers are
    // copied into a temporary buffer of that element type first.
    std::vector<uint64_t> wide_ids(ids.begin(), ids.end());
    impl_->reconstruct_at(dst, std::span<const uint64_t>{wide_ids});
}

void reset() {
impl_.reset();
ntotal_soft_deleted = 0;
Expand Down
16 changes: 16 additions & 0 deletions bindings/cpp/src/vamana_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ struct VamanaIndexManagerBase : public VamanaIndex {
// Write the index to `out` by delegating to the implementation object.
// The call is routed through runtime_error_wrapper, which yields the Status
// returned to the caller (this function itself is noexcept).
Status save(std::ostream& out) const noexcept override {
    auto write_index = [&] { impl_->save(out); };
    return runtime_error_wrapper(write_index);
}

// Compute the distance between the stored vector `id` and `query`, storing it in
// `*distance`.
//   distance: out-parameter receiving the result, narrowed to float.
//   id:       identifier forwarded unchanged to the implementation.
//   query:    read as `impl_->dimensions()` consecutive floats.
// Returns the Status produced by runtime_error_wrapper.
Status
get_distance(float* distance, size_t id, const float* query) const noexcept override {
    return runtime_error_wrapper([&] {
        const auto dim = impl_->dimensions();
        const auto computed = impl_->get_distance(id, std::span<const float>{query, dim});
        *distance = static_cast<float>(computed);
    });
}

// Reconstruct the vectors named by `ids` into `output`.
//   n:      number of identifiers in `ids`, and number of rows in `output`.
//   ids:    pointer to `n` identifiers, forwarded as a span.
//   output: caller buffer viewed as an n-by-`impl_->dimensions()` float matrix.
// Returns the Status produced by runtime_error_wrapper.
Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override {
    return runtime_error_wrapper([&] {
        const auto dim = impl_->dimensions();
        impl_->reconstruct_at(
            svs::data::SimpleDataView<float>{output, n, dim},
            std::span<const size_t>{ids, n}
        );
    });
}
};
} // namespace

Expand Down
15 changes: 15 additions & 0 deletions bindings/cpp/src/vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,117 +87,117 @@
"Vamana index does not support adding points after initialization"};
}

// Search for the nearest neighbors of every query and write them into `result`.
//
//   result:  output view; `result.n_neighbors()` is the number of neighbors (k)
//            requested per query.
//   queries: query vectors; an empty set returns immediately without touching `result`.
//   params:  optional per-call search parameters (may be nullptr).
//   filter:  optional ID filter; when non-null only IDs accepted by
//            `filter->is_member()` are reported.
//
// Throws StatusException with NOT_INITIALIZED when the index has not been built
// (after filling `result` with unspecified sentinels) and INVALID_ARGUMENT when k == 0.
void search(
svs::QueryResultView<size_t> result,
svs::data::ConstSimpleDataView<float> queries,
const VamanaIndex::SearchParams* params = nullptr,
IDFilter* filter = nullptr
) const {
if (!impl_) {
// Fill the outputs with "unspecified" sentinels before reporting the error.
auto& dists = result.distances();
std::fill(dists.begin(), dists.end(), Unspecify<float>());
auto& inds = result.indices();
std::fill(inds.begin(), inds.end(), Unspecify<size_t>());
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}

// Nothing to do for an empty query set.
if (queries.size() == 0) {
return;
}

const size_t k = result.n_neighbors();
if (k == 0) {
throw StatusException{ErrorCode::INVALID_ARGUMENT, "k must be greater than 0"};
}

auto sp = make_search_parameters(params);

// Simple search: without a filter, delegate directly to the underlying index.
if (filter == nullptr) {
get_impl()->search(result, queries, sp);
return;
}

// Selective search with IDSelector.
// The per-call parameters are installed on the underlying index for the duration
// of the filtered search; the scope guard puts the previous ones back when this
// function exits.
auto old_sp = get_impl()->get_search_parameters();
auto sp_restore = svs::lib::make_scope_guard([&]() noexcept {
get_impl()->set_search_parameters(old_sp);
});
get_impl()->set_search_parameters(sp);
// Defaults used when `params` is null or leaves these fields unspecified.
float filter_stop = 0.0f;
bool filter_estimate_batch = true;
if (params) {
set_if_specified(filter_stop, params->filter_stop);
set_if_specified(filter_estimate_batch, params->filter_estimate_batch);
}
// A batch is capped at the number of elements in the index.
const auto max_batch_size = get_impl()->size();

// Pre-search filter sampling: estimate hit rate before graph traversal.
size_t sampled = 0;
size_t sample_hits = 0;
const auto sws = sp.buffer_config_.get_search_window_size();
const auto initial_batch_hint = std::max(k, sws);
auto initial_batch_size = initial_batch_hint;
if (filter_estimate_batch) {
// The ID predicate passed here accepts every ID.
std::tie(sampled, sample_hits) = sample_filter_hits(
*filter,
max_batch_size,
[](size_t) { return true; },
sample_size_for_filter_stop(filter_stop)
);
// If the sampled hit rate already triggers the stop criterion, report empty
// results for all queries and skip the graph search entirely.
if (should_stop_filtered_search(sampled, sample_hits, filter_stop)) {
pad_empty_results(result, queries.size(), k);
return;
}
initial_batch_size = predict_further_processing(
sampled, sample_hits, k, initial_batch_hint, max_batch_size
);
}

// Per-partition worker: processes the query indices in `range`.
auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) {
for (auto i : range) {
auto query = queries.get_datum(i);
auto iterator = get_impl()->batch_iterator(query);
size_t found = 0;          // accepted neighbors written for query i so far
size_t total_checked = 0;  // candidates produced by the iterator so far
auto batch_size = initial_batch_size;
do {
// Re-estimate the next batch size from the hit rate observed so far.
batch_size = predict_further_processing(
total_checked, found, k, batch_size, max_batch_size
);
iterator.next(batch_size);
total_checked += iterator.size();
for (auto& neighbor : iterator.results()) {
if (filter->is_member(neighbor.id())) {
result.set(neighbor, i, found);
found++;
if (found == k) {
break;
}
}
}
// Stop criterion met: reset `found` so the padding below overwrites all
// k slots for this query.
if (should_stop_filtered_search(total_checked, found, filter_stop)) {
found = 0;
break;
}
} while (found < k && !iterator.done());

// Pad results if not enough neighbors found
if (found < k) {
for (size_t j = found; j < k; ++j) {
result.set(Neighbor{Unspecify<size_t>(), Unspecify<float>()}, i, j);
}
}
}
};

auto threadpool = default_threadpool();

// Run the worker over a static partition of the query indices.
svs::threads::parallel_for(
threadpool, svs::threads::StaticPartition{queries.size()}, search_closure
);
}

Check notice on line 200 in bindings/cpp/src/vamana_index_impl.h

View check run for this annotation

codefactor.io / CodeFactor

bindings/cpp/src/vamana_index_impl.h#L90-L200

Complex Method
void range_search(
svs::data::ConstSimpleDataView<float> queries,
float radius,
Expand Down Expand Up @@ -307,6 +307,21 @@
}
}

// Return the distance between the stored vector `id` and `query`, as computed by
// the underlying index.
// Throws StatusException (NOT_INITIALIZED) when no index has been created yet.
double get_distance(size_t id, std::span<const float> query) const {
    if (impl_) {
        return get_impl()->get_distance(id, query);
    }
    throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}

// Reconstruct the vectors identified by `ids` into `dst` via the underlying index.
// Throws StatusException (NOT_INITIALIZED) when no index has been created yet.
void reconstruct_at(svs::data::SimpleDataView<float> dst, std::span<const size_t> ids) {
    if (!impl_) {
        throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
    }
    // The underlying index is handed a span of uint64_t, so the identifiers are
    // copied into a temporary buffer of that element type first.
    std::vector<uint64_t> wide_ids(ids.begin(), ids.end());
    get_impl()->reconstruct_at(dst, std::span<const uint64_t>{wide_ids});
}

// Release the underlying index by resetting `impl_`.
void reset() {
    impl_.reset();
}

// Write the index to `out` by forwarding to the underlying index obtained from
// get_impl().
void save(std::ostream& out) const {
    get_impl()->save(out);
}
Expand Down
138 changes: 138 additions & 0 deletions bindings/cpp/tests/runtime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,3 +997,141 @@ CATCH_TEST_CASE("RangeSearchFunctionalStatic", "[runtime][static_vamana]") {

svs::runtime::v0::VamanaIndex::destroy(index);
}

// Verifies DynamicVamanaIndex::get_distance on an L2 / FP32 index populated with the
// shared test data, using labels 0..test_n-1.
CATCH_TEST_CASE("GetDistanceDynamic", "[runtime]") {
    const auto& test_data = get_test_data();
    svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
    svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
    auto status = svs::runtime::v0::DynamicVamanaIndex::build(
        &index,
        test_d,
        svs::runtime::v0::MetricType::L2,
        svs::runtime::v0::StorageKind::FP32,
        build_params
    );
    CATCH_REQUIRE(status.ok());

    std::vector<size_t> labels(test_n);
    std::iota(labels.begin(), labels.end(), 0);
    status = index->add(test_n, labels.data(), test_data.data());
    CATCH_REQUIRE(status.ok());

    // Self-distance should be approximately 0.
    // `dist` starts at a negative sentinel, so it is bounded from both sides: with
    // only an upper bound, an implementation that never writes `*distance` would
    // leave -1.0f in place and still pass.
    float dist = -1.0f;
    const float* vec0 = test_data.data();
    status = index->get_distance(&dist, 0, vec0);
    CATCH_REQUIRE(status.ok());
    CATCH_REQUIRE(dist > -1e-6);
    CATCH_REQUIRE(dist < 1e-6);

    // Distance to a different vector should be positive.
    // Re-arm the sentinel so a missing write is detected here as well.
    dist = -1.0f;
    const float* vec1 = test_data.data() + test_d;
    status = index->get_distance(&dist, 0, vec1);
    CATCH_REQUIRE(status.ok());
    CATCH_REQUIRE(dist > 0.0);

    svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

// Verifies VamanaIndex::get_distance on a static L2 / FP32 index populated with the
// shared test data.
CATCH_TEST_CASE("GetDistanceStatic", "[runtime][static_vamana]") {
    const auto& test_data = get_test_data();
    svs::runtime::v0::VamanaIndex* index = nullptr;
    svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
    auto status = svs::runtime::v0::VamanaIndex::build(
        &index,
        test_d,
        svs::runtime::v0::MetricType::L2,
        svs::runtime::v0::StorageKind::FP32,
        build_params
    );
    CATCH_REQUIRE(status.ok());

    status = index->add(test_n, test_data.data());
    CATCH_REQUIRE(status.ok());

    // Self-distance should be approximately 0.
    // `dist` starts at a negative sentinel, so it is bounded from both sides: with
    // only an upper bound, an implementation that never writes `*distance` would
    // leave -1.0f in place and still pass.
    float dist = -1.0f;
    const float* vec0 = test_data.data();
    status = index->get_distance(&dist, 0, vec0);
    CATCH_REQUIRE(status.ok());
    CATCH_REQUIRE(dist > -1e-6);
    CATCH_REQUIRE(dist < 1e-6);

    // Distance to a different vector should be positive.
    // Re-arm the sentinel so a missing write is detected here as well.
    dist = -1.0f;
    const float* vec1 = test_data.data() + test_d;
    status = index->get_distance(&dist, 0, vec1);
    CATCH_REQUIRE(status.ok());
    CATCH_REQUIRE(dist > 0.0);

    svs::runtime::v0::VamanaIndex::destroy(index);
}

// Checks that DynamicVamanaIndex::reconstruct_at returns the originally added vectors
// for an L2 / FP32 index whose labels are 0..test_n-1.
CATCH_TEST_CASE("ReconstructAtDynamic", "[runtime]") {
    const auto& test_data = get_test_data();

    svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
    svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
    auto status = svs::runtime::v0::DynamicVamanaIndex::build(
        &index,
        test_d,
        svs::runtime::v0::MetricType::L2,
        svs::runtime::v0::StorageKind::FP32,
        build_params
    );
    CATCH_REQUIRE(status.ok());

    // Populate the index, labelling vector i with i.
    std::vector<size_t> all_labels(test_n);
    std::iota(all_labels.begin(), all_labels.end(), 0);
    status = index->add(test_n, all_labels.data(), test_data.data());
    CATCH_REQUIRE(status.ok());

    // Request the first five vectors (IDs 0..4) into a zero-initialised buffer.
    constexpr size_t nrecon = 5;
    std::vector<size_t> wanted_ids(nrecon);
    std::iota(wanted_ids.begin(), wanted_ids.end(), 0);
    std::vector<float> reconstructed(nrecon * test_d, 0.0f);

    status = index->reconstruct_at(nrecon, wanted_ids.data(), reconstructed.data());
    CATCH_REQUIRE(status.ok());

    // With FP32 storage the reconstruction must equal the input exactly. Because the
    // requested IDs are 0..nrecon-1, the flattened buffers line up element by element.
    for (size_t flat = 0; flat < nrecon * test_d; ++flat) {
        CATCH_REQUIRE(reconstructed[flat] == test_data[flat]);
    }

    svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

// Checks that VamanaIndex::reconstruct_at returns the originally added vectors for a
// static L2 / FP32 index.
CATCH_TEST_CASE("ReconstructAtStatic", "[runtime][static_vamana]") {
    const auto& test_data = get_test_data();

    svs::runtime::v0::VamanaIndex* index = nullptr;
    svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
    auto status = svs::runtime::v0::VamanaIndex::build(
        &index,
        test_d,
        svs::runtime::v0::MetricType::L2,
        svs::runtime::v0::StorageKind::FP32,
        build_params
    );
    CATCH_REQUIRE(status.ok());

    // Populate the index with the shared test vectors.
    status = index->add(test_n, test_data.data());
    CATCH_REQUIRE(status.ok());

    // Request the first five vectors (IDs 0..4) into a zero-initialised buffer.
    constexpr size_t nrecon = 5;
    std::vector<size_t> wanted_ids(nrecon);
    std::iota(wanted_ids.begin(), wanted_ids.end(), 0);
    std::vector<float> reconstructed(nrecon * test_d, 0.0f);

    status = index->reconstruct_at(nrecon, wanted_ids.data(), reconstructed.data());
    CATCH_REQUIRE(status.ok());

    // With FP32 storage the reconstruction must equal the input exactly. Because the
    // requested IDs are 0..nrecon-1, the flattened buffers line up element by element.
    for (size_t flat = 0; flat < nrecon * test_d; ++flat) {
        CATCH_REQUIRE(reconstructed[flat] == test_data[flat]);
    }

    svs::runtime::v0::VamanaIndex::destroy(index);
}
Loading