diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 0fa4d42c1..eb1b65868 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -20,12 +20,9 @@ pub struct RecallMetrics { pub recall_n: usize, /// The number of queries. pub num_queries: usize, - /// The average recall across all queries. + /// The average recall across queries with non-empty groundtruth. + /// Queries with zero groundtruth results are excluded from the average. pub average: f64, - /// The minimum observed recall (max possible value: `recall_n`). - pub minimum: usize, - /// The maximum observed recall (max possible value: `recall_k`). - pub maximum: usize, } #[derive(Debug, Error)] @@ -36,14 +33,12 @@ pub enum ComputeRecallError { DistanceRowsMismatch(usize, usize), #[error("recall k value {0} must be less than or equal to recall n {1}")] RecallKAndNError(usize, usize), - #[error("number of results per query {0} must be at least the specified recall k {1}")] - NotEnoughResults(usize, usize), #[error( "number of groundtruth values per query {0} must be at least the specified recall n {1}" )] NotEnoughGroundTruth(usize, usize), #[error("number of groundtruth distances {0} does not match groundtruth entries {1}")] - NotEnoughGroundTruthDistances(usize, usize), + GroundTruthDistanceMismatch(usize, usize), } /// An abstraction over data-structures such as vector-of-vectors. @@ -121,6 +116,15 @@ pub trait RecallCompatible: Eq + Hash + Clone + std::fmt::Debug {} impl RecallCompatible for T where T: Eq + Hash + Clone + std::fmt::Debug {} +/// Selects how [`knn`] treats the length of each groundtruth row: fixed size +/// (every row must have at least `recall_k` entries, otherwise an error is returned) +/// or flexible size (a row may have any number of entries, including none). +#[derive(Copy, Clone, Debug)] +pub enum GroundTruthMode { + Fixed, + Flexible, +} + /// Compute the K-nearest-neighbors recall value "K-recall-at-N". 
/// /// For each entry in `groundtruth` and `results`, this computes the `recall_k` number of @@ -130,15 +134,19 @@ impl RecallCompatible for T where T: Eq + Hash + Clone + std::fmt::Debug {} /// the last values of each entry of `results`. Values will be counted towards the recall if /// they have the same distance as the last ordered candidate. /// -/// If `allow_insufficient_results`, an error will not be given if an entry in `results` +/// Note that an error will NOT be given if an entry in `results` /// has fewer than `recall_n` candidates. +/// +/// If `ground_truth_mode` is `GroundTruthMode::Fixed`, then an +/// error will be given if any entry in `groundtruth` has fewer +/// than `recall_k` candidates. pub fn knn( groundtruth: &dyn Rows, groundtruth_distances: Option>, results: &dyn Rows, recall_k: usize, recall_n: usize, - allow_insufficient_results: bool, + ground_truth_mode: GroundTruthMode, ) -> Result where T: RecallCompatible, @@ -152,21 +160,22 @@ where return Err(ComputeRecallError::RowsMismatch(nrows, groundtruth.nrows())); } - if let Some(cols) = results.ncols() - && cols < recall_n - && !allow_insufficient_results - { - return Err(ComputeRecallError::NotEnoughResults(cols, recall_n)); - } - - // Validate groundtruth size for fixed-size sources - match groundtruth.ncols() { - Some(ncols) if ncols < recall_k => { - return Err(ComputeRecallError::NotEnoughGroundTruth(ncols, recall_k)); + if let GroundTruthMode::Fixed = ground_truth_mode { + // Validate that all rows in `groundtruth` have at least `recall_k` entries. + for i in 0..nrows { + let gt_row = groundtruth.row(i); + if gt_row.len() < recall_k { + return Err(ComputeRecallError::NotEnoughGroundTruth( + gt_row.len(), + recall_k, + )); + } } - _ => {} } + // If `groundtruth_distances` are present, validate that they have + // exactly `nrows` rows and that each row has the same + // number of entries as the corresponding row in `groundtruth`. 
if let Some(distances) = groundtruth_distances { if nrows != distances.nrows() { return Err(ComputeRecallError::DistanceRowsMismatch( @@ -175,54 +184,52 @@ where )); } - match groundtruth.ncols() { - Some(ncols) if distances.ncols() != ncols => { - return Err(ComputeRecallError::NotEnoughGroundTruthDistances( - distances.ncols(), - ncols, + for i in 0..nrows { + let gt_row = groundtruth.row(i); + let distances_row = distances.row(i); + if gt_row.len() != distances_row.len() { + return Err(ComputeRecallError::GroundTruthDistanceMismatch( + distances_row.len(), + gt_row.len(), )); } - _ => {} } } - // The actual recall computation for fixed-size groundtruth - let mut recall_values: Vec = Vec::new(); + // The actual recall computation for groundtruth + let mut recall_values: Vec = Vec::new(); let mut this_groundtruth = HashSet::new(); let mut this_results = HashSet::new(); for i in 0..results.nrows() { let result = results.row(i); - if !allow_insufficient_results && result.len() < recall_n { - return Err(ComputeRecallError::NotEnoughResults(result.len(), recall_n)); - } let gt_row = groundtruth.row(i); - if gt_row.len() < recall_k { - return Err(ComputeRecallError::NotEnoughGroundTruth( - gt_row.len(), - recall_k, - )); + // `groundtruth` does not have to be fixed-size, + // so we compute `recall_k` for this row based on its gt length + let this_recall_k = gt_row.len().min(recall_k); + + if this_recall_k == 0 { + continue; } // Populate the groundtruth using the top-k this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().take(recall_k).cloned()); + this_groundtruth.extend(gt_row.iter().take(this_recall_k).cloned()); // If we have distances, then continue to append distances as long as the distance // value is constant - if let Some(distances) = groundtruth_distances - && recall_k > 0 - { + if let Some(distances) = groundtruth_distances { let distances_row = distances.row(i); - if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 { - let 
last_distance = distances_row[recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) { - if *d == last_distance { - this_groundtruth.insert(g.clone()); - } else { - break; - } + + // validation above guarantees `distances_row.len() == gt_row.len()`, and + // `1 <= this_recall_k <= gt_row.len()` here, so `distances_row[this_recall_k - 1]` is in bounds + let last_distance = distances_row[this_recall_k - 1]; + for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(this_recall_k) { + if *d == last_distance { + this_groundtruth.insert(g.clone()); + } else { + break; } } } @@ -235,27 +242,26 @@ where .iter() .filter(|i| this_results.contains(i)) .count() - .min(recall_k); + .min(this_recall_k); - recall_values.push(r); - } + let recall = (r as f64) / (this_recall_k as f64); - // Perform post-processing - let total: usize = recall_values.iter().sum(); - let minimum = recall_values.iter().min().unwrap_or(&0); - let maximum = recall_values.iter().max().unwrap_or(&0); + recall_values.push(recall); + } - // We explicitly check that each groundtruth row has at least `recall_k` elements. 
- let div = recall_k * nrows; - let average = (total as f64) / (div as f64); + // Compute the average recall + let total: f64 = recall_values.iter().sum(); + let average = if recall_values.is_empty() { + 0.0 + } else { + total / (recall_values.len() as f64) + }; Ok(RecallMetrics { recall_k, recall_n, num_queries: nrows, average, - minimum: *minimum, - maximum: *maximum, }) } @@ -450,7 +456,7 @@ mod tests { &our_results, expected.recall_k, expected.recall_n, - false, + GroundTruthMode::Fixed, ) .unwrap(); @@ -467,8 +473,6 @@ mod tests { assert_eq!(recall.num_queries, our_results.nrows()); assert_eq!(recall.recall_k, expected.recall_k); assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); } //-----------// @@ -497,7 +501,7 @@ mod tests { &our_results, expected.recall_k, expected.recall_n, - false, + GroundTruthMode::Fixed, ) .unwrap(); @@ -514,116 +518,231 @@ mod tests { assert_eq!(recall.num_queries, our_results.nrows()); assert_eq!(recall.recall_k, expected.recall_k); assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); } } #[test] - fn test_errors() { - // k greater than n - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 10); - let err = knn(&groundtruth, None, &results, 11, 10, false).unwrap_err(); - assert!(matches!(err, ComputeRecallError::RecallKAndNError(..))); - } + fn test_error_recall_k_and_n() { + let groundtruth = Matrix::::new(0, 10, 10); + let results = Matrix::::new(0, 10, 10); + let err = knn(&groundtruth, None, &results, 11, 10, GroundTruthMode::Fixed).unwrap_err(); + assert!(matches!(err, ComputeRecallError::RecallKAndNError(..))); + } - // Unequal rows - { - let groundtruth = Matrix::::new(0, 11, 10); - let results = 
Matrix::::new(0, 10, 10); - let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err(); - assert!(matches!(err, ComputeRecallError::RowsMismatch(..))); - let err_allow_insufficient_results = - knn(&groundtruth, None, &results, 10, 10, true).unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::RowsMismatch(..) - )); - } + #[test] + fn test_error_rows_mismatch() { + let groundtruth = Matrix::::new(0, 11, 10); + let results = Matrix::::new(0, 10, 10); + let err = knn(&groundtruth, None, &results, 10, 10, GroundTruthMode::Fixed).unwrap_err(); + assert!(matches!(err, ComputeRecallError::RowsMismatch(..))); + let err_allow_insufficient_results = + knn(&groundtruth, None, &results, 10, 10, GroundTruthMode::Fixed).unwrap_err(); + assert!(matches!( + err_allow_insufficient_results, + ComputeRecallError::RowsMismatch(..) + )); + } - // Not enough results - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 5); - let err = knn(&groundtruth, None, &results, 5, 10, false).unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughResults(..))); - let _ = knn(&groundtruth, None, &results, 5, 10, true); - } + #[test] + fn test_error_not_enough_groundtruth() { + let groundtruth = Matrix::::new(0, 10, 5); + let results = Matrix::::new(0, 10, 10); + let err = knn(&groundtruth, None, &results, 10, 10, GroundTruthMode::Fixed).unwrap_err(); + assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); + let err_allow_insufficient_results = + knn(&groundtruth, None, &results, 10, 10, GroundTruthMode::Fixed).unwrap_err(); + assert!(matches!( + err_allow_insufficient_results, + ComputeRecallError::NotEnoughGroundTruth(..) 
+ )); + } - // Not enough results - dynamic - { - let groundtruth = Matrix::::new(0, 10, 10); - let results: Vec<_> = (0..10).map(|_| vec![0; 5]).collect(); - let err = knn(&groundtruth, None, &results, 5, 10, false).unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughResults(..))); - let _ = knn(&groundtruth, None, &results, 5, 10, true); - } + #[test] + fn test_dynamic_groundtruth_valid() { + let groundtruth: Vec<_> = (0..10).map(|_| vec![0u32; 5]).collect(); + let results = Matrix::::new(0, 10, 10); + // Should succeed: each row uses this_recall_k = min(5, 10) = 5 + // Should succeed in Flexible mode, but fail in Fixed mode + let recall_flexible = knn( + &groundtruth, + None, + &results, + 10, + 10, + GroundTruthMode::Flexible, + ) + .unwrap(); + assert_eq!(recall_flexible.num_queries, 10); + // Should fail in Fixed mode + let err = knn(&groundtruth, None, &results, 10, 10, GroundTruthMode::Fixed).unwrap_err(); + assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); + assert_eq!(recall_flexible.num_queries, 10); + } - // Not enough groundtruth - { - let groundtruth = Matrix::::new(0, 10, 5); - let results = Matrix::::new(0, 10, 10); - let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); - let err_allow_insufficient_results = - knn(&groundtruth, None, &results, 10, 10, true).unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::NotEnoughGroundTruth(..) 
- )); + #[test] + fn test_dynamic_groundtruth_full_match() { + let gt_row: Vec = (1..=5).collect(); + let groundtruth: Vec<_> = (0..10).map(|_| gt_row.clone()).collect(); + let mut results = Matrix::::new(0, 10, 10); + for i in 0..10 { + for (j, v) in (1u32..=10).enumerate() { + results[(i, j)] = v; + } } + let recall = knn( + &groundtruth, + None, + &results, + 10, + 10, + GroundTruthMode::Flexible, + ) + .unwrap(); + assert!((recall.average - 1.0).abs() < 1e-10); + } - // Not enough groundtruth - dynamic - { - let groundtruth: Vec<_> = (0..10).map(|_| vec![0; 5]).collect(); - let results = Matrix::::new(0, 10, 10); - let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); - let err_allow_insufficient_results = - knn(&groundtruth, None, &results, 10, 10, true).unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::NotEnoughGroundTruth(..) - )); + #[test] + fn test_dynamic_groundtruth_partial_match() { + // groundtruth: [1, 2, 3, 4, 5]; results contain [1, 2, 3, 6, 7, 8, 9, 10, 11, 12] + let gt_row: Vec = (1..=5).collect(); + let groundtruth: Vec<_> = (0..10).map(|_| gt_row.clone()).collect(); + let mut results = Matrix::::new(0, 10, 10); + let res_row: Vec = vec![1, 2, 3, 6, 7, 8, 9, 10, 11, 12]; + for i in 0..10 { + for (j, &v) in res_row.iter().enumerate() { + results[(i, j)] = v; + } } + let recall = knn( + &groundtruth, + None, + &results, + 10, + 10, + GroundTruthMode::Flexible, + ) + .unwrap(); + assert!((recall.average - 0.6).abs() < 1e-10); + } - // Distance Row Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 9, 10); - let results = Matrix::::new(0, 10, 10); - let err = knn( - &groundtruth, - Some(distances.as_view().into()), - &results, - 10, - 10, - false, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::DistanceRowsMismatch(..))); + #[test] + fn 
test_dynamic_groundtruth_mixed_zero_nonzero() { + let mut groundtruth: Vec> = Vec::new(); + // First 5 rows: non-empty groundtruth + for _ in 0..5 { + groundtruth.push((1..=5).collect()); + } + // Last 5 rows: empty groundtruth + for _ in 0..5 { + groundtruth.push(vec![]); } - // Distance Cols Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 10, 9); - let results = Matrix::::new(0, 10, 10); - let err = knn( - &groundtruth, - Some(distances.as_view().into()), - &results, - 10, - 10, - false, - ) - .unwrap_err(); - assert!(matches!( - err, - ComputeRecallError::NotEnoughGroundTruthDistances(..) - )); + let mut results = Matrix::::new(0, 10, 10); + for i in 0..10 { + for (j, v) in (1u32..=10).enumerate() { + results[(i, j)] = v; + } } + + let recall = knn( + &groundtruth, + None, + &results, + 10, + 10, + GroundTruthMode::Flexible, + ) + .unwrap(); + assert_eq!(recall.num_queries, 10); + assert!((recall.average - 1.0).abs() < 1e-10); + } + + #[test] + fn test_dynamic_groundtruth_all_zero() { + let groundtruth: Vec> = (0..10).map(|_| vec![]).collect(); + let results = Matrix::::new(0, 10, 10); + + let recall = knn( + &groundtruth, + None, + &results, + 10, + 10, + GroundTruthMode::Flexible, + ) + .unwrap(); + assert_eq!(recall.num_queries, 10); + assert_eq!(recall.average, 0.0); + assert!(!recall.average.is_nan()); + assert!(!recall.average.is_infinite()); + } + + #[test] + fn test_error_distance_rows_mismatch() { + let groundtruth = Matrix::::new(0, 10, 10); + let distances = Matrix::::new(0.0, 9, 10); + let results = Matrix::::new(0, 10, 10); + let err = knn( + &groundtruth, + Some(distances.as_view().into()), + &results, + 10, + 10, + GroundTruthMode::Fixed, + ) + .unwrap_err(); + assert!(matches!(err, ComputeRecallError::DistanceRowsMismatch(..))); + } + + #[test] + fn test_error_distance_cols_mismatch() { + let groundtruth = Matrix::::new(0, 10, 10); + let distances = Matrix::::new(0.0, 10, 9); + let results = 
Matrix::::new(0, 10, 10); + let err = knn( + &groundtruth, + Some(distances.as_view().into()), + &results, + 10, + 10, + GroundTruthMode::Fixed, + ) + .unwrap_err(); + assert!(matches!( + err, + ComputeRecallError::GroundTruthDistanceMismatch(..) + )); + } + + #[test] + fn test_error_distance_cols_mismatch_variable_size_groundtruth() { + // groundtruth: 2 rows, first row has 3 elements, second row has 2 + let groundtruth: Vec> = vec![vec![1, 2, 3], vec![4, 5]]; + // distances: first row has 2 elements (should be 3), second row has 2 (matches) + let distances: Vec> = vec![vec![0.1, 0.2], vec![0.3, 0.4]]; + let distances = Matrix::try_from( + distances.into_iter().flatten().collect::>().into(), + 2, + 2, + ) + .unwrap(); + // results: 2 rows, each with 3 elements + let results: Vec> = vec![vec![1, 2, 3], vec![4, 5, 6]]; + let err = knn( + &groundtruth, + Some(distances.as_view().into()), + &results, + 3, + 3, + GroundTruthMode::Flexible, + ) + .unwrap_err(); + + println!("{err}"); + + assert!(matches!( + err, + ComputeRecallError::GroundTruthDistanceMismatch(2, 3) + )); } } diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 2149842a8..c46217b41 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -17,6 +17,7 @@ use diskann_utils::{future::AsyncFriendly, views::Matrix}; use crate::{ recall, + recall::GroundTruthMode, search::{self, Search, graph::Strategy}, utils, }; @@ -186,6 +187,7 @@ pub struct Aggregator<'a, I> { groundtruth: &'a dyn crate::recall::Rows, recall_k: usize, recall_n: usize, + groundtruth_mode: GroundTruthMode, } impl<'a, I> Aggregator<'a, I> { @@ -199,11 +201,13 @@ impl<'a, I> Aggregator<'a, I> { groundtruth: &'a dyn crate::recall::Rows, recall_k: usize, recall_n: usize, + groundtruth_mode: GroundTruthMode, ) -> Self { Self { groundtruth, recall_k, recall_n, + groundtruth_mode, } } } @@ -227,7 +231,7 @@ where 
first.ids().as_rows(), self.recall_k, self.recall_n, - true, + self.groundtruth_mode, )?, None => anyhow::bail!("Results must be non-empty"), }; @@ -349,8 +353,12 @@ mod tests { let recall_k = nearest_neighbors; let recall_n = nearest_neighbors; - let all = - search::search_all(knn, parameters, Aggregator::new(rows, recall_k, recall_n)).unwrap(); + let all = search::search_all( + knn, + parameters, + Aggregator::new(rows, recall_k, recall_n, GroundTruthMode::Fixed), + ) + .unwrap(); assert_eq!(all.len(), 2); for summary in all { diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 914f01abe..d05a3bd3d 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -140,6 +140,7 @@ mod tests { use super::*; + use crate::recall::GroundTruthMode; use diskann::graph::{index::QueryLabelProvider, test::provider}; // A simple [`QueryLabelProvider`] that rejects odd indices. 
@@ -224,7 +225,12 @@ mod tests { let all = search::search_all( multihop, parameters, - search::graph::knn::Aggregator::new(rows, recall_k, recall_n), + search::graph::knn::Aggregator::new( + rows, + recall_k, + recall_n, + GroundTruthMode::Flexible, + ), ) .unwrap(); diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index da38f4e9e..4c07f5406 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -136,7 +136,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let mut search_results = Vec::::new(); let threadpool = rayon::ThreadPoolBuilder::new() .num_threads(input.search.num_threads.get()) diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 1242c0aef..28ca28113 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -135,7 +135,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let search_progress = make_progress_bar("running search", queries.nrows(), output.draw_target())?; diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index a4df5b702..be29865cd 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -133,7 +133,7 @@ mod imp { f32::converting_load(datafiles::BinFile(&input.search.queries), 
input.data_type)?; let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?; + datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?; let search_progress = make_progress_bar( "running search", diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 124d0b8e1..e0684923a 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -464,7 +464,11 @@ where let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + // compute the maximum value of k used in any search + let max_k = topk.max_k(); + + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; let knn = benchmark_core::search::graph::KNN::new( index.clone(), @@ -650,10 +654,8 @@ fn full_precision_streaming( where T: bytemuck::Pod + VectorRepr + WithApproximateNorm + SampleableForStart, { - let topk = match &input.search_phase { - SearchPhase::Topk(topk) => topk, - _ => anyhow::bail!("Only TopK is currently supported by the streaming index"), - }; + let topk = input.search_phase.as_topk()?; + let consolidate_threshold: f32 = input.runbook_params.consolidate_threshold; let data = datafiles::load_dataset::(datafiles::BinFile(&input.build.data))?; @@ -688,10 +690,14 @@ where let managed = Managed::new(max_points, consolidate_threshold, managed_stream); - let layered = bigann::WithData::new(managed, data, queries, |path| { - Ok(Box::new(datafiles::load_groundtruth(datafiles::BinFile( - path, - ))?)) + // compute the maximum value of k used in any search + let max_k = topk.max_k(); + + let layered = bigann::WithData::new(managed, data, queries, move |path| { + Ok(Box::new(datafiles::load_groundtruth( + datafiles::BinFile(path), + Some(max_k), + )?)) }); 
Ok(layered) diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index b50e69010..7385f4c30 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -5,6 +5,7 @@ use std::{num::NonZeroUsize, sync::Arc}; +use diskann_benchmark_core::recall::GroundTruthMode; use diskann_benchmark_core::{self as benchmark_core, search as core_search}; use crate::{backend::index::result::SearchResults, inputs::graph_index::GraphSearch}; @@ -97,7 +98,12 @@ where let results = core_search::search_all( self.clone(), parameters.into_iter(), - core_search::graph::knn::Aggregator::new(groundtruth, recall_k, recall_n), + core_search::graph::knn::Aggregator::new( + groundtruth, + recall_k, + recall_n, + GroundTruthMode::Fixed, + ), )?; Ok(results.into_iter().map(SearchResults::new).collect()) @@ -123,7 +129,12 @@ where let results = core_search::search_all( self.clone(), parameters.into_iter(), - core_search::graph::knn::Aggregator::new(groundtruth, recall_k, recall_n), + core_search::graph::knn::Aggregator::new( + groundtruth, + recall_k, + recall_n, + GroundTruthMode::Flexible, + ), )?; Ok(results.into_iter().map(SearchResults::new).collect()) diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index cb2558083..6f4469181 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -378,10 +378,14 @@ mod imp { ) -> anyhow::Result { let topk = phase.as_topk()?; + // compute the maximum value of k used in any search + let max_k = topk.max_k(); + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; let steps = 
search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); @@ -518,7 +522,7 @@ mod imp { ))?); let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; + datafiles::load_range_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; let steps = search::knn::SearchSteps::new(multihop.reps, &multihop.num_threads, &multihop.runs); diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 9df194382..5a3231825 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -115,6 +115,10 @@ pub(crate) struct TopkSearchPhase { } impl TopkSearchPhase { + pub(crate) fn max_k(&self) -> usize { + self.runs.iter().map(|run| run.recall_k).max().unwrap_or(0) + } + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.queries.resolve(checker)?; self.groundtruth.resolve(checker)?; diff --git a/diskann-benchmark/src/utils/datafiles.rs b/diskann-benchmark/src/utils/datafiles.rs index 9c5057488..abfe06a7d 100644 --- a/diskann-benchmark/src/utils/datafiles.rs +++ b/diskann-benchmark/src/utils/datafiles.rs @@ -95,7 +95,7 @@ impl ConvertingLoad for f32 { } /// Load a groundtruth set from disk and return the result as a row-major matrix. 
-pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result> { +pub(crate) fn load_groundtruth(path: BinFile<'_>, k: Option) -> anyhow::Result> { let provider = diskann_providers::storage::FileStorageProvider; let mut file = provider .open_reader(&path.0.to_string_lossy()) @@ -114,6 +114,17 @@ pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result> let mut groundtruth = Matrix::::new(0, num_points, dim); let groundtruth_slice: &mut [u8] = bytemuck::cast_slice_mut(groundtruth.as_mut_slice()); file.read_exact(groundtruth_slice)?; + + if let Some(expected_k) = k { + if groundtruth.ncols() < expected_k { + return Err(anyhow::anyhow!( + "Each row of groundtruth must have at least {} neighbors (got {})", + expected_k, + groundtruth.ncols() + )); + } + } + Ok(groundtruth) } @@ -169,3 +180,35 @@ impl From for BitSet { BitSet::from_bytes(&val.0) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use std::path::PathBuf; + use tempfile::NamedTempFile; + + #[test] + fn test_load_groundtruth_with_expected_k() { + // Prepare a temporary .bin file with a valid groundtruth header and data + let num_points: u32 = 2; + let dim: u32 = 3; + let data: Vec = vec![1, 2, 3, 4, 5, 6]; + let mut file = NamedTempFile::new().expect("Failed to create temp file"); + file.write_all(&num_points.to_le_bytes()).unwrap(); + file.write_all(&dim.to_le_bytes()).unwrap(); + for v in &data { + file.write_all(&v.to_le_bytes()).unwrap(); + } + let path = PathBuf::from(file.path()); + let bin_file = BinFile(&path); + // Should succeed for k <= dim + let mat = load_groundtruth(bin_file, Some(3)).expect("Should succeed for k <= dim"); + assert_eq!(mat.nrows(), 2); + assert_eq!(mat.ncols(), 3); + // Should fail for k > dim + let bin_file = BinFile(&path); + let err = load_groundtruth(bin_file, Some(4)).unwrap_err(); + assert!(err.to_string().contains("at least 4 neighbors")); + } +} diff --git a/diskann-benchmark/src/utils/recall.rs 
b/diskann-benchmark/src/utils/recall.rs index dcbe86d94..2d0fba26a 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -18,10 +18,6 @@ pub(crate) struct RecallMetrics { pub(crate) num_queries: usize, /// The average recall across all queries. pub(crate) average: f64, - /// The minimum observed recall (max possible value: `recall_n`). - pub(crate) minimum: usize, - /// The maximum observed recall (max possible value: `recall_k`). - pub(crate) maximum: usize, } impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { @@ -31,8 +27,6 @@ impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { recall_n: m.recall_n, num_queries: m.num_queries, average: m.average, - minimum: m.minimum, - maximum: m.maximum, } } } @@ -62,7 +56,14 @@ where continue; } - let recall = benchmark_core::recall::knn(groundtruth, None, results, *k, *n, false)?; + let recall = benchmark_core::recall::knn( + groundtruth, + None, + results, + *k, + *n, + benchmark_core::recall::GroundTruthMode::Fixed, + )?; result.push((&recall).into()); } }