Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions diskann-tools/src/utils/build_disk_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ mod tests {

#[test]
fn test_build_disk_index_with_num_of_pq_chunks() {
let storage_provider = VirtualStorageProvider::new(MemoryFS::new());
Comment thread
arrayka marked this conversation as resolved.
Outdated
let storage_provider = VirtualStorageProvider::new_memory();
let parameters = BuildDiskIndexParameters {
metric: Metric::L2,
data_path: "test_data_path",
Expand All @@ -220,7 +220,7 @@ mod tests {

#[test]
fn test_build_disk_index_with_zero_num_of_pq_chunks() {
let storage_provider = VirtualStorageProvider::new(MemoryFS::new());
let storage_provider = VirtualStorageProvider::new_memory();
let parameters = BuildDiskIndexParameters {
metric: Metric::L2,
data_path: "test_data_path",
Expand Down
79 changes: 79 additions & 0 deletions diskann-tools/src/utils/cmd_tool_error.rs
Comment thread
arrayka marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,82 @@ where
ann_error.into()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_cmd_tool_error_display() {
let error = CMDToolError {
details: "test error".to_string(),
};
assert_eq!(format!("{}", error), "test error");
}

#[test]
fn test_cmd_tool_error_debug() {
let error = CMDToolError {
details: "test error".to_string(),
};
assert_eq!(format!("{:?}", error), "test error");
}

#[test]
fn test_cmd_tool_error_description() {
let error = CMDToolError {
details: "test error".to_string(),
};
#[allow(deprecated)]
{
assert_eq!(error.description(), "test error");
}
Comment thread
arrayka marked this conversation as resolved.
}

#[test]
fn test_from_io_error() {
let io_error = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
let cmd_error: CMDToolError = io_error.into();
assert!(cmd_error.details.contains("file not found"));
}

#[test]
fn test_from_normal_error() {
let normal_error = rand_distr::NormalError::BadVariance;
let cmd_error: CMDToolError = normal_error.into();
// Just verify the error was converted and has some details
assert!(!cmd_error.details.is_empty());
}

#[test]
fn test_from_ann_error() {
use diskann::ANNErrorKind;
let ann_error = diskann::ANNError::new(
ANNErrorKind::IndexError,
std::io::Error::other("test error"),
);
let cmd_error: CMDToolError = ann_error.into();
assert!(cmd_error.details.contains("test error"));
}

#[test]
fn test_from_config_error() {
// We can't easily construct a ConfigError directly, so we test the conversion
// by testing that a string error message can be converted
let io_error = std::io::Error::other("config error");
let ann_error = diskann::ANNError::new(diskann::ANNErrorKind::IndexConfigError, io_error);
let cmd_error: CMDToolError = ann_error.into();
assert!(cmd_error.details.contains("config error"));
Comment thread
arrayka marked this conversation as resolved.
Outdated
}

#[test]
fn test_from_jsonl_read_error() {
use diskann_label_filter::JsonlReadError;
let jsonl_error = JsonlReadError::IoError(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"invalid jsonl",
));
let cmd_error: CMDToolError = jsonl_error.into();
assert!(cmd_error.details.contains("invalid jsonl"));
}
}
63 changes: 63 additions & 0 deletions diskann-tools/src/utils/filter_search_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,67 @@ mod tests {
assert_eq!(bitmaps.len(), 1);
assert!(bitmaps[0].is_empty());
}

#[test]
fn test_serializable_bitset_conversion() {
let mut bitset = BitSet::new();
bitset.insert(0);
bitset.insert(5);
bitset.insert(10);

let serializable = SerializableBitSet::from(&bitset);
let converted_back: BitSet = serializable.into();

assert!(converted_back.contains(0));
assert!(converted_back.contains(5));
assert!(converted_back.contains(10));
assert!(!converted_back.contains(1));
}

#[test]
fn test_serializable_bitset_empty() {
let bitset = BitSet::new();
let serializable = SerializableBitSet::from(&bitset);
let converted_back: BitSet = serializable.into();
assert!(converted_back.is_empty());
}

#[test]
fn test_process_bitmap_single_query_single_metadata() {
let query_strings = vec![String::from("CAT=Automotive")];
let metadata_strings = vec![String::from("CAT=Automotive,RATING=5")];

let bitmaps = process_bitmap_for_labels(query_strings, metadata_strings, &POOL);
assert_eq!(bitmaps.len(), 1);
assert!(bitmaps[0].contains(0));
}

#[test]
fn test_process_bitmap_no_match() {
let query_strings = vec![String::from("CAT=Electronics")];
let metadata_strings = vec![
String::from("CAT=Automotive,RATING=5"),
String::from("CAT=Fashion,RATING=4"),
];

let bitmaps = process_bitmap_for_labels(query_strings, metadata_strings, &POOL);
assert_eq!(bitmaps.len(), 1);
assert!(bitmaps[0].is_empty());
}

#[test]
fn test_process_bitmap_multiple_matches() {
let query_strings = vec![String::from("RATING=5")];
let metadata_strings = vec![
String::from("CAT=Automotive,RATING=5"),
String::from("CAT=Fashion,RATING=4"),
String::from("CAT=Electronics,RATING=5"),
];

let bitmaps = process_bitmap_for_labels(query_strings, metadata_strings, &POOL);
assert_eq!(bitmaps.len(), 1);
assert!(bitmaps[0].contains(0));
assert!(!bitmaps[0].contains(1));
assert!(bitmaps[0].contains(2));
}
}
88 changes: 88 additions & 0 deletions diskann-tools/src/utils/gen_associated_data_from_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,91 @@ pub fn gen_associated_data_from_range(

Ok(())
}

Comment thread
arrayka marked this conversation as resolved.
#[cfg(test)]
Comment thread
arrayka marked this conversation as resolved.
mod tests {
use super::*;
use byteorder::{LittleEndian, ReadBytesExt};
use diskann_providers::storage::StorageReadProvider;

#[test]
fn test_gen_associated_data_from_range() {
let storage_provider = FileStorageProvider;
let path = "/tmp/test_gen_associated_data_from_range.bin";

// Clean up if file exists
let _ = std::fs::remove_file(path);

// Generate data from range 0 to 9
gen_associated_data_from_range(&storage_provider, path, 0, 9).unwrap();

// Read back and verify
let mut file = storage_provider.open_reader(path).unwrap();

// Read metadata
let num_ints = file.read_u32::<LittleEndian>().unwrap();
let int_length = file.read_u32::<LittleEndian>().unwrap();

assert_eq!(num_ints, 10);
assert_eq!(int_length, 1);

// Read integers
for expected in 0u32..=9 {
let actual = file.read_u32::<LittleEndian>().unwrap();
assert_eq!(actual, expected);
}

// Clean up
std::fs::remove_file(path).unwrap();
}

#[test]
fn test_gen_associated_data_from_range_single_value() {
let storage_provider = FileStorageProvider;
let path = "/tmp/test_gen_associated_data_single.bin";

let _ = std::fs::remove_file(path);

// Generate data for a single value
gen_associated_data_from_range(&storage_provider, path, 42, 42).unwrap();

let mut file = storage_provider.open_reader(path).unwrap();

let num_ints = file.read_u32::<LittleEndian>().unwrap();
let int_length = file.read_u32::<LittleEndian>().unwrap();

assert_eq!(num_ints, 1);
assert_eq!(int_length, 1);

let value = file.read_u32::<LittleEndian>().unwrap();
assert_eq!(value, 42);

std::fs::remove_file(path).unwrap();
}

#[test]
fn test_gen_associated_data_from_range_large() {
let storage_provider = FileStorageProvider;
let path = "/tmp/test_gen_associated_data_large.bin";

let _ = std::fs::remove_file(path);

// Generate data for range 100 to 199
gen_associated_data_from_range(&storage_provider, path, 100, 199).unwrap();

let mut file = storage_provider.open_reader(path).unwrap();

let num_ints = file.read_u32::<LittleEndian>().unwrap();
let int_length = file.read_u32::<LittleEndian>().unwrap();

assert_eq!(num_ints, 100);
assert_eq!(int_length, 1);

for expected in 100u32..=199 {
let actual = file.read_u32::<LittleEndian>().unwrap();
assert_eq!(actual, expected);
}

std::fs::remove_file(path).unwrap();
}
}
57 changes: 57 additions & 0 deletions diskann-tools/src/utils/generate_synthetic_labels_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ pub fn generate_labels(
#[cfg(test)]
mod test {
use std::fs;
use std::io::BufRead;

use super::generate_labels;

Expand Down Expand Up @@ -165,4 +166,60 @@ mod test {
fs::remove_file(label_file2).expect("Failed to delete file");
fs::remove_file(label_file3).expect("Failed to delete file");
}

#[test]
fn test_generate_labels_small_dataset() {
let label_file = "/tmp/test_labels_small.txt";
let result = generate_labels(label_file, "zipf", 10, 5);

assert!(result.is_ok());
assert!(fs::metadata(label_file).is_ok());

// Verify we have 10 lines
let file = fs::File::open(label_file).unwrap();
let reader = std::io::BufReader::new(file);
let lines: Vec<_> = reader.lines().collect();
assert_eq!(lines.len(), 10);

fs::remove_file(label_file).ok();
}

#[test]
fn test_generate_labels_random_distribution() {
let label_file = "/tmp/test_labels_random.txt";
let result = generate_labels(label_file, "random", 100, 10);

assert!(result.is_ok());
assert!(fs::metadata(label_file).is_ok());

fs::remove_file(label_file).ok();
}

#[test]
fn test_generate_labels_one_per_point() {
let label_file = "/tmp/test_labels_one_per_point.txt";
let result = generate_labels(label_file, "one_per_point", 50, 20);

assert!(result.is_ok());
assert!(fs::metadata(label_file).is_ok());

// Verify we have 50 lines
let file = fs::File::open(label_file).unwrap();
let reader = std::io::BufReader::new(file);
let lines: Vec<_> = reader.lines().collect();
assert_eq!(lines.len(), 50);

fs::remove_file(label_file).ok();
}

#[test]
fn test_generate_labels_single_point() {
let label_file = "/tmp/test_labels_single.txt";
let result = generate_labels(label_file, "zipf", 1, 5);

assert!(result.is_ok());
assert!(fs::metadata(label_file).is_ok());

fs::remove_file(label_file).ok();
}
}
21 changes: 21 additions & 0 deletions diskann-tools/src/utils/parameter_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,24 @@ pub fn get_num_threads(num_threads: Option<usize>) -> usize {
None => num_cpus::get(),
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_get_num_threads_with_some() {
assert_eq!(get_num_threads(Some(4)), 4);
assert_eq!(get_num_threads(Some(1)), 1);
assert_eq!(get_num_threads(Some(16)), 16);
}

#[test]
fn test_get_num_threads_with_none() {
let result = get_num_threads(None);
// Should return the number of CPUs, which is at least 1
assert!(result >= 1);
// Should match num_cpus::get()
assert_eq!(result, num_cpus::get());
}
}
Loading
Loading