Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
443 changes: 281 additions & 162 deletions diskann-benchmark-core/src/recall.rs

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions diskann-benchmark-core/src/search/graph/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use diskann_utils::{future::AsyncFriendly, views::Matrix};

use crate::{
recall,
recall::GroundTruthMode,
search::{self, Search, graph::Strategy},
utils,
};
Expand Down Expand Up @@ -186,6 +187,7 @@ pub struct Aggregator<'a, I> {
groundtruth: &'a dyn crate::recall::Rows<I>,
recall_k: usize,
recall_n: usize,
groundtruth_mode: GroundTruthMode,
}

impl<'a, I> Aggregator<'a, I> {
Expand All @@ -199,11 +201,13 @@ impl<'a, I> Aggregator<'a, I> {
groundtruth: &'a dyn crate::recall::Rows<I>,
recall_k: usize,
recall_n: usize,
groundtruth_mode: GroundTruthMode,
) -> Self {
Self {
groundtruth,
recall_k,
recall_n,
groundtruth_mode,
}
}
}
Expand All @@ -227,7 +231,7 @@ where
first.ids().as_rows(),
self.recall_k,
self.recall_n,
true,
self.groundtruth_mode,
)?,
None => anyhow::bail!("Results must be non-empty"),
};
Expand Down Expand Up @@ -349,8 +353,12 @@ mod tests {
let recall_k = nearest_neighbors;
let recall_n = nearest_neighbors;

let all =
search::search_all(knn, parameters, Aggregator::new(rows, recall_k, recall_n)).unwrap();
let all = search::search_all(
knn,
parameters,
Aggregator::new(rows, recall_k, recall_n, GroundTruthMode::Fixed),
)
.unwrap();

assert_eq!(all.len(), 2);
for summary in all {
Expand Down
8 changes: 7 additions & 1 deletion diskann-benchmark-core/src/search/graph/multihop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ mod tests {

use super::*;

use crate::recall::GroundTruthMode;
use diskann::graph::{index::QueryLabelProvider, test::provider};

// A simple [`QueryLabelProvider`] that rejects odd indices.
Expand Down Expand Up @@ -224,7 +225,12 @@ mod tests {
let all = search::search_all(
multihop,
parameters,
search::graph::knn::Aggregator::new(rows, recall_k, recall_n),
search::graph::knn::Aggregator::new(
rows,
recall_k,
recall_n,
GroundTruthMode::Flexible,
),
)
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion diskann-benchmark/src/backend/exhaustive/minmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ mod imp {
f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?;

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?;
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?;
let mut search_results = Vec::<SearchResults>::new();
let threadpool = rayon::ThreadPoolBuilder::new()
.num_threads(input.search.num_threads.get())
Expand Down
2 changes: 1 addition & 1 deletion diskann-benchmark/src/backend/exhaustive/product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ mod imp {
f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?;

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?;
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?;

let search_progress =
make_progress_bar("running search", queries.nrows(), output.draw_target())?;
Expand Down
2 changes: 1 addition & 1 deletion diskann-benchmark/src/backend/exhaustive/spherical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ mod imp {
f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?;

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?;
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?;

let search_progress = make_progress_bar(
"running search",
Expand Down
24 changes: 15 additions & 9 deletions diskann-benchmark/src/backend/index/benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,11 @@ where
let queries: Arc<Matrix<DP::Element>> =
Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?);

let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?;
// compute the maximum value of k used in any search
let max_k = topk.max_k();

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?;

let knn = benchmark_core::search::graph::KNN::new(
index.clone(),
Expand Down Expand Up @@ -650,10 +654,8 @@ fn full_precision_streaming<T>(
where
T: bytemuck::Pod + VectorRepr + WithApproximateNorm + SampleableForStart,
{
let topk = match &input.search_phase {
SearchPhase::Topk(topk) => topk,
_ => anyhow::bail!("Only TopK is currently supported by the streaming index"),
};
let topk = input.search_phase.as_topk()?;

let consolidate_threshold: f32 = input.runbook_params.consolidate_threshold;

let data = datafiles::load_dataset::<T>(datafiles::BinFile(&input.build.data))?;
Expand Down Expand Up @@ -688,10 +690,14 @@ where

let managed = Managed::new(max_points, consolidate_threshold, managed_stream);

let layered = bigann::WithData::new(managed, data, queries, |path| {
Ok(Box::new(datafiles::load_groundtruth(datafiles::BinFile(
path,
))?))
// compute the maximum value of k used in any search
let max_k = topk.max_k();

let layered = bigann::WithData::new(managed, data, queries, move |path| {
Ok(Box::new(datafiles::load_groundtruth(
datafiles::BinFile(path),
Some(max_k),
)?))
});

Ok(layered)
Expand Down
15 changes: 13 additions & 2 deletions diskann-benchmark/src/backend/index/search/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

use std::{num::NonZeroUsize, sync::Arc};

use diskann_benchmark_core::recall::GroundTruthMode;
use diskann_benchmark_core::{self as benchmark_core, search as core_search};

use crate::{backend::index::result::SearchResults, inputs::graph_index::GraphSearch};
Expand Down Expand Up @@ -97,7 +98,12 @@ where
let results = core_search::search_all(
self.clone(),
parameters.into_iter(),
core_search::graph::knn::Aggregator::new(groundtruth, recall_k, recall_n),
core_search::graph::knn::Aggregator::new(
groundtruth,
recall_k,
recall_n,
GroundTruthMode::Fixed,
),
)?;

Ok(results.into_iter().map(SearchResults::new).collect())
Expand All @@ -123,7 +129,12 @@ where
let results = core_search::search_all(
self.clone(),
parameters.into_iter(),
core_search::graph::knn::Aggregator::new(groundtruth, recall_k, recall_n),
core_search::graph::knn::Aggregator::new(
groundtruth,
recall_k,
recall_n,
GroundTruthMode::Flexible,
),
)?;

Ok(results.into_iter().map(SearchResults::new).collect())
Expand Down
8 changes: 6 additions & 2 deletions diskann-benchmark/src/backend/index/spherical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,14 @@ mod imp {
) -> anyhow::Result<AggregatedSearchResults> {
let topk = phase.as_topk()?;

// compute the maximum value of k used in any search
let max_k = topk.max_k();

let queries: Arc<Matrix<f32>> =
Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?);

let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?;
let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?;

let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs);

Expand Down Expand Up @@ -518,7 +522,7 @@ mod imp {
))?);

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&multihop.groundtruth))?;
datafiles::load_range_groundtruth(datafiles::BinFile(&multihop.groundtruth))?;

let steps =
search::knn::SearchSteps::new(multihop.reps, &multihop.num_threads, &multihop.runs);
Expand Down
4 changes: 4 additions & 0 deletions diskann-benchmark/src/inputs/graph_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ pub(crate) struct TopkSearchPhase {
}

impl TopkSearchPhase {
pub(crate) fn max_k(&self) -> usize {
self.runs.iter().map(|run| run.recall_k).max().unwrap_or(0)
}

pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> {
self.queries.resolve(checker)?;
self.groundtruth.resolve(checker)?;
Expand Down
45 changes: 44 additions & 1 deletion diskann-benchmark/src/utils/datafiles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl ConvertingLoad for f32 {
}

/// Load a groundtruth set from disk and return the result as a row-major matrix.
pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result<Matrix<u32>> {
pub(crate) fn load_groundtruth(path: BinFile<'_>, k: Option<usize>) -> anyhow::Result<Matrix<u32>> {
Comment thread
magdalendobson marked this conversation as resolved.
let provider = diskann_providers::storage::FileStorageProvider;
let mut file = provider
.open_reader(&path.0.to_string_lossy())
Expand All @@ -114,6 +114,17 @@ pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result<Matrix<u32>>
let mut groundtruth = Matrix::<u32>::new(0, num_points, dim);
let groundtruth_slice: &mut [u8] = bytemuck::cast_slice_mut(groundtruth.as_mut_slice());
file.read_exact(groundtruth_slice)?;

if let Some(expected_k) = k {
if groundtruth.ncols() < expected_k {
Comment thread
magdalendobson marked this conversation as resolved.
return Err(anyhow::anyhow!(
"Each row of groundtruth must have at least {} neighbors (got {})",
expected_k,
groundtruth.ncols()
));
}
}
Comment thread
magdalendobson marked this conversation as resolved.

Ok(groundtruth)
}

Expand Down Expand Up @@ -169,3 +180,35 @@ impl From<SerializableBitSet> for BitSet {
BitSet::from_bytes(&val.0)
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use std::path::PathBuf;
use tempfile::NamedTempFile;

#[test]
fn test_load_groundtruth_with_expected_k() {
// Prepare a temporary .bin file with a valid groundtruth header and data
let num_points: u32 = 2;
let dim: u32 = 3;
let data: Vec<u32> = vec![1, 2, 3, 4, 5, 6];
let mut file = NamedTempFile::new().expect("Failed to create temp file");
file.write_all(&num_points.to_le_bytes()).unwrap();
file.write_all(&dim.to_le_bytes()).unwrap();
for v in &data {
file.write_all(&v.to_le_bytes()).unwrap();
}
let path = PathBuf::from(file.path());
let bin_file = BinFile(&path);
// Should succeed for k <= dim
let mat = load_groundtruth(bin_file, Some(3)).expect("Should succeed for k <= dim");
assert_eq!(mat.nrows(), 2);
assert_eq!(mat.ncols(), 3);
// Should fail for k > dim
let bin_file = BinFile(&path);
let err = load_groundtruth(bin_file, Some(4)).unwrap_err();
assert!(err.to_string().contains("at least 4 neighbors"));
}
}
15 changes: 8 additions & 7 deletions diskann-benchmark/src/utils/recall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ pub(crate) struct RecallMetrics {
pub(crate) num_queries: usize,
/// The average recall across all queries.
pub(crate) average: f64,
/// The minimum observed recall (max possible value: `recall_n`).
pub(crate) minimum: usize,
/// The maximum observed recall (max possible value: `recall_k`).
pub(crate) maximum: usize,
}

impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics {
Expand All @@ -31,8 +27,6 @@ impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics {
recall_n: m.recall_n,
num_queries: m.num_queries,
average: m.average,
minimum: m.minimum,
maximum: m.maximum,
}
}
}
Expand Down Expand Up @@ -62,7 +56,14 @@ where
continue;
}

let recall = benchmark_core::recall::knn(groundtruth, None, results, *k, *n, false)?;
let recall = benchmark_core::recall::knn(
groundtruth,
None,
results,
*k,
*n,
benchmark_core::recall::GroundTruthMode::Fixed,
)?;
result.push((&recall).into());
}
}
Expand Down
Loading