
Commit 81fab9c

Authored by arkrishn94, Mark Hildebrand (hildebrandmw), and Copilot
[quantization] 8bit distance kernels and ZipUnzip (#798)
This PR introduces heterogeneous inner-product kernels for 8-bit bitslices paired with 4-bit, 2-bit, and 1-bit bitslices. The goal is to enable fast kernels for full-precision-like queries against quantized vectors (spherical, minmax, etc.). In the benchmark, the `u8xu4` kernel is ~2x faster than its `f32xu4` counterpart.

For AVX2-capable architectures, the 4-bit and 2-bit kernels are implemented using the [`_mm256_maddubs_epi16`](https://doc.rust-lang.org/beta/core/arch/x86_64/fn._mm256_maddubs_epi16.html) intrinsic, acting on blocks of 32 byte-sized dimensions for the `u8xu4` kernel and 64 dimensions for the `u8xu2` kernel. Some care was needed to ensure that, for these specific kernels, the intrinsic does not saturate when performing the multiply-adds. For the 1-bit kernel, we implement a simple masked horizontal-add strategy on blocks of size 32. A `Scalar` fallback is implemented for `Neon`, and for now the `V4` architecture is retargeted to `V3` for these kernels. Support for computing `u8xu4`, `u8xu2`, and `u8xu1` distances with minmax-quantized vectors is available mostly out of the box.

## ZipUnzip

A new trait `ZipUnzip` has been added to diskann-wide to implement vectorized zipping and unzipping logic: zipping merges two halved vectors into a full vector by interleaving elements from each half, and unzipping performs the inverse transformation on the full vector.

- It is currently implemented for `i8x32`, `i16x16`, `i32x8`, `u8x32`, `u32x8`, and `f16x16`.
- It is implemented for the `Scalar`, `V3`, `V4`, and `Neon` architectures.

## Benchmark

We ran the benchmark as a flat scan of vectors, clearing the cache at every run and using a vector count that exceeds the machine's L3 cache size.
```
Total latency in ms, COUNT=150K, AMD EPYC 7763

Kernel          dim=256   dim=384   dim=896
─────────────────────────────────────────────
u8×u4 (new)        8.92      9.93     17.71
u8×u2 (new)        9.87     13.10     21.53
u8×u1 (new)        6.09      7.52     14.62
f32×u4            15.96     20.99     40.67
f32×u2            13.88     18.46     35.64
f32×u1            13.53     17.93     34.65
u8×u8              8.40     10.00     19.72
u4×u4              7.33     11.58     15.49
f32×f32           16.56     23.16     47.76
```

---------

Co-authored-by: Mark Hildebrand <hildebrandmw@gmail.com>
Co-authored-by: Mark Hildebrand <mhildebrand@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
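The saturation caveat around `_mm256_maddubs_epi16` can be checked with plain integer arithmetic: the intrinsic multiplies unsigned bytes by signed bytes and sums adjacent pairs into a *saturating* signed 16-bit lane. The sketch below is a scalar model of one such pairwise lane (an illustration for this page, not the PR's actual kernel) showing why 4-bit operands can never saturate while full 8-bit operands can:

```rust
/// Scalar model of one `_mm256_maddubs_epi16` pair lane: multiply unsigned
/// bytes by signed bytes, add the adjacent pair, saturating to i16 range.
fn maddubs_pair(a0: u8, b0: i8, a1: u8, b1: i8) -> i16 {
    let full = (a0 as i32) * (b0 as i32) + (a1 as i32) * (b1 as i32);
    full.clamp(i16::MIN as i32, i16::MAX as i32) as i16
}

fn main() {
    // Worst case for u8 x u4 codes: 255 * 15 + 255 * 15 = 7650, well below
    // i16::MAX, so the pairwise madd cannot saturate for 4-bit (or 2-bit) codes.
    assert_eq!(maddubs_pair(255, 15, 255, 15), 7650);

    // By contrast, a full u8 x i8 pairing can saturate:
    // 255 * 127 + 255 * 127 = 64770 > 32767.
    assert_eq!(maddubs_pair(255, 127, 255, 127), i16::MAX);

    println!("u8 x u4 worst-case pair sum = {}", maddubs_pair(255, 15, 255, 15));
}
```

This is the headroom argument behind the blocked `u8xu4`/`u8xu2` kernels described above.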
1 parent 3cd8ac2 commit 81fab9c

26 files changed

Lines changed: 1514 additions & 107 deletions

diskann-quantization/src/bits/distances.rs

Lines changed: 814 additions & 28 deletions
Large diffs are not rendered by default.

diskann-quantization/src/minmax/vectors.rs

Lines changed: 108 additions & 69 deletions
```diff
@@ -203,16 +203,16 @@ pub type FullQueryMut<'a> = slice::SliceMut<'a, f32, FullQueryMeta>;
 // Compensated Distances //
 ///////////////////////////
 #[inline(always)]
-fn kernel<const NBITS: usize, F>(
-    x: DataRef<'_, NBITS>,
-    y: DataRef<'_, NBITS>,
+fn kernel<const N: usize, const M: usize, F>(
+    x: DataRef<'_, N>,
+    y: DataRef<'_, M>,
     f: F,
 ) -> distances::MathematicalResult<f32>
 where
-    Unsigned: Representation<NBITS>,
+    Unsigned: Representation<N> + Representation<M>,
     InnerProduct: for<'a, 'b> PureDistanceFunction<
-        BitSlice<'a, NBITS, Unsigned>,
-        BitSlice<'b, NBITS, Unsigned>,
+        BitSlice<'a, N, Unsigned>,
+        BitSlice<'b, M, Unsigned>,
         distances::MathematicalResult<u32>,
     >,
     F: Fn(f32, &MinMaxCompensation, &MinMaxCompensation) -> f32,
@@ -477,6 +477,50 @@ mod minmax_vector_tests {
     use super::*;
     use crate::{alloc::GlobalAllocator, scalar::bit_scale};

+    /// Builds a random MinMax quantized vector and its full-precision reconstruction.
+    ///
+    /// Returns `(compressed, original)` where `compressed` has its `MinMaxCompensation`
+    /// metadata fully populated and `original` is the dequantized f32 vector.
+    fn random_minmax_vector<const NBITS: usize>(
+        dim: usize,
+        rng: &mut impl Rng,
+    ) -> (Data<NBITS>, Vec<f32>)
+    where
+        Unsigned: Representation<NBITS>,
+    {
+        let mut v = Data::<NBITS>::new_boxed(dim);
+
+        let domain = Unsigned::domain_const::<NBITS>();
+        let code_dist = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();
+
+        {
+            let mut bs = v.vector_mut();
+            for i in 0..dim {
+                bs.set(i, code_dist.sample(rng)).unwrap();
+            }
+        }
+
+        let a: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);
+        let b: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);
+
+        let original: Vec<f32> = (0..dim)
+            .map(|i| a * v.vector().get(i).unwrap() as f32 + b)
+            .collect();
+
+        let code_sum: f32 = (0..dim).map(|i| v.vector().get(i).unwrap() as f32).sum();
+        let norm_squared: f32 = original.iter().map(|x| x * x).sum();
+
+        v.set_meta(MinMaxCompensation {
+            a,
+            b,
+            n: a * code_sum,
+            norm_squared,
+            dim: dim as u32,
+        });
+
+        (v, original)
+    }
+
     fn test_minmax_compensated_vectors<const NBITS: usize, R>(dim: usize, rng: &mut R)
     where
         Unsigned: Representation<NBITS>,
@@ -494,70 +538,11 @@
     {
         assert!(dim <= bit_scale::<NBITS>() as usize);

-        // Create two vectors with known compensation values
-        let mut v1 = Data::<NBITS>::new_boxed(dim);
-        let mut v2 = Data::<NBITS>::new_boxed(dim);
-
-        let domain = Unsigned::domain_const::<NBITS>();
-        let code_distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();
-
-        // Set bit values
-        {
-            let mut bitslice1 = v1.vector_mut();
-            let mut bitslice2 = v2.vector_mut();
-
-            for i in 0..dim {
-                bitslice1.set(i, code_distribution.sample(rng)).unwrap();
-                bitslice2.set(i, code_distribution.sample(rng)).unwrap();
-            }
-        }
-        let a_rnd = Uniform::new_inclusive(0.0, 2.0).unwrap();
-        let b_rnd = Uniform::new_inclusive(0.0, 2.0).unwrap();
-
-        // Set compensation coefficients
-        // v1: X = a1 * X' + b1
-        // v2: Y = a2 * Y' + b2
-        let a1 = a_rnd.sample(rng);
-        let b1 = b_rnd.sample(rng);
-        let a2 = a_rnd.sample(rng);
-        let b2 = b_rnd.sample(rng);
-
-        // Calculate sum of vector elements for n values
-        let sum1: f32 = (0..dim).map(|i| v1.vector().get(i).unwrap() as f32).sum();
-        let sum2: f32 = (0..dim).map(|i| v2.vector().get(i).unwrap() as f32).sum();
-
-        // Create original full-precision vectors for reference calculations
-        let mut original1 = Vec::with_capacity(dim);
-        let mut original2 = Vec::with_capacity(dim);
-
-        // Calculate the reconstructed original vectors and their norms
-        for i in 0..dim {
-            let val1 = a1 * v1.vector().get(i).unwrap() as f32 + b1;
-            let val2 = a2 * v2.vector().get(i).unwrap() as f32 + b2;
-            original1.push(val1);
-            original2.push(val2);
-        }
-
-        // Calculate squared norms
-        let norm1_squared: f32 = original1.iter().map(|x| x * x).sum();
-        let norm2_squared: f32 = original2.iter().map(|x| x * x).sum();
-
-        // Set compensation coefficients
-        v1.set_meta(MinMaxCompensation {
-            a: a1,
-            b: b1,
-            n: a1 * sum1,
-            norm_squared: norm1_squared,
-            dim: dim as u32,
-        });
+        let (v1, original1) = random_minmax_vector::<NBITS>(dim, rng);
+        let (v2, original2) = random_minmax_vector::<NBITS>(dim, rng);

-        v2.set_meta(MinMaxCompensation {
-            a: a2,
-            b: b2,
-            n: a2 * sum2,
-            norm_squared: norm2_squared,
-            dim: dim as u32,
-        });
+        let norm1_squared = v1.meta().norm_squared;
+        let norm2_squared = v2.meta().norm_squared;

         // Calculate raw integer dot product
         let expected_ip = (0..dim).map(|i| original1[i] * original2[i]).sum::<f32>();
@@ -741,4 +726,58 @@
     test_minmax_compensated!(unsigned_minmax_compensated_test_u2, 2, 0xaedf3d2a223b7b77);
     test_minmax_compensated!(unsigned_minmax_compensated_test_u4, 4, 0xf60c0c8d1aadc126);
     test_minmax_compensated!(unsigned_minmax_compensated_test_u8, 8, 0x09fa14c42a9d7d98);
+
+    /// Test the heterogeneous MinMax kernel for N-bit queries × M-bit database vectors.
+    ///
+    /// Verifies that `kernel::<N, M, _>` produces inner-product and squared-L2
+    /// results matching the full-precision reference, for random codes and
+    /// random compensation coefficients.
+    fn test_minmax_heterogeneous_kernel<const N: usize, const M: usize, R>(dim: usize, rng: &mut R)
+    where
+        Unsigned: Representation<N> + Representation<M>,
+        InnerProduct: for<'a, 'b> PureDistanceFunction<
+            BitSlice<'a, N, Unsigned>,
+            BitSlice<'b, M, Unsigned>,
+            distances::MathematicalResult<u32>,
+        >,
+        R: Rng,
+    {
+        let (v_query, original1) = random_minmax_vector::<N>(dim, rng);
+        let (v_data, original2) = random_minmax_vector::<M>(dim, rng);
+
+        // ── Inner Product ──
+        let expected_ip: f32 = original1.iter().zip(&original2).map(|(x, y)| x * y).sum();
+        let computed_ip = kernel(v_query.reborrow(), v_data.reborrow(), |v, _, _| v)
+            .unwrap()
+            .into_inner();
+        assert!(
+            (expected_ip - computed_ip).abs() / expected_ip.abs().max(1e-10) < 1e-6,
+            "Heterogeneous IP ({},{}) failed: expected {}, got {} on dim: {}",
+            N,
+            M,
+            expected_ip,
+            computed_ip,
+            dim,
+        );
+    }
+
+    macro_rules! test_minmax_heterogeneous {
+        ($name:ident, $N:literal, $M:literal, $seed:literal) => {
+            #[test]
+            fn $name() {
+                let mut rng = StdRng::seed_from_u64($seed);
+                // Use the smaller bit width's scale as max dimension.
+                const MAX_DIM: usize = bit_scale::<$M>() as usize;
+                for dim in 1..=MAX_DIM {
+                    for _ in 0..TRIALS {
+                        test_minmax_heterogeneous_kernel::<$N, $M, _>(dim, &mut rng);
+                    }
+                }
+            }
+        };
+    }
+
+    test_minmax_heterogeneous!(minmax_heterogeneous_8x4, 8, 4, 0xb7c3d9e5f1a20864);
+    test_minmax_heterogeneous!(minmax_heterogeneous_8x2, 8, 2, 0x4e8f2c6a1d3b5079);
+    test_minmax_heterogeneous!(minmax_heterogeneous_8x1, 8, 1, 0x1b0f2c614d2a7141);
 }
```
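The `MinMaxCompensation` metadata built above makes the compensated kernel an affine correction of the raw integer inner product: with `x[i] = a1*c[i] + b1` and `y[i] = a2*d[i] + b2`, the full-precision inner product expands to `a1*a2*<c,d> + (a1*sum(c))*b2 + (a2*sum(d))*b1 + dim*b1*b2`, where `a * sum(codes)` is exactly the `n` field stored in the diff. A standalone scalar sketch of the identity the test validates (hypothetical example values, not the crate's API):

```rust
/// Returns (full-precision reference IP, integer IP with affine compensation).
/// Codes `c`, `d` model quantized values; (a1, b1), (a2, b2) the MinMax affine maps.
fn compensated_ip(c: &[u32], d: &[u32], a1: f32, b1: f32, a2: f32, b2: f32) -> (f32, f32) {
    let dim = c.len() as f32;

    // Full-precision reference: dequantize both sides, then take the dot product.
    let expected: f32 = c
        .iter()
        .zip(d.iter())
        .map(|(&ci, &di)| (a1 * ci as f32 + b1) * (a2 * di as f32 + b2))
        .sum();

    // Raw integer inner product plus compensation terms (n = a * sum(codes)).
    let int_ip: u32 = c.iter().zip(d.iter()).map(|(&ci, &di)| ci * di).sum();
    let n1 = a1 * c.iter().sum::<u32>() as f32;
    let n2 = a2 * d.iter().sum::<u32>() as f32;
    let compensated = a1 * a2 * int_ip as f32 + n1 * b2 + n2 * b1 + dim * b1 * b2;

    (expected, compensated)
}

fn main() {
    let (expected, compensated) =
        compensated_ip(&[3, 0, 7, 15, 9, 1], &[2, 14, 5, 6, 0, 11], 0.5, 1.25, 1.5, 0.75);
    // The two paths agree up to floating-point rounding.
    assert!((expected - compensated).abs() < 1e-3);
    println!("expected = {expected}, compensated = {compensated}");
}
```

Nothing in the identity requires the two code vectors to share a bit width, which is why the heterogeneous `u8xu4`/`u8xu2`/`u8xu1` distances work "mostly out of the box" once the integer kernel exists.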

diskann-wide/src/arch/aarch64/double.rs

Lines changed: 33 additions & 0 deletions
```diff
@@ -3,6 +3,8 @@
  * Licensed under the MIT license.
  */

+use std::arch::aarch64::*;
+
 use half::f16;

 use crate::{
@@ -75,6 +77,17 @@ doubled::double_scalar_shift!(Doubled<Doubled<i8x16>>);
 doubled::double_scalar_shift!(Doubled<Doubled<i16x8>>);
 doubled::double_scalar_shift!(Doubled<Doubled<i32x4>>);

+//////////////
+// ZipUnzip //
+//////////////
+
+super::macros::aarch64_zipunzip!(i8x16, vzip1q_s8, vzip2q_s8, vuzp1q_s8, vuzp2q_s8);
+super::macros::aarch64_zipunzip!(i16x8, vzip1q_s16, vzip2q_s16, vuzp1q_s16, vuzp2q_s16);
+super::macros::aarch64_zipunzip!(i32x4, vzip1q_s32, vzip2q_s32, vuzp1q_s32, vuzp2q_s32);
+super::macros::aarch64_zipunzip!(u8x16, vzip1q_u8, vzip2q_u8, vuzp1q_u8, vuzp2q_u8);
+super::macros::aarch64_zipunzip!(u32x4, vzip1q_u32, vzip2q_u32, vuzp1q_u32, vuzp2q_u32);
+super::macros::aarch64_zipunzip!(f16x8, vzip1q_u16, vzip2q_u16, vuzp1q_u16, vuzp2q_u16);
+
 //-------------//
 // Conversions //
 //-------------//
@@ -230,6 +243,8 @@ mod tests {

     // Bit ops
     test_utils::ops::test_bitops!(u8x32, 0xd62d8de09f82ed4e, test_neon());
+    test_utils::ops::test_splitjoin!(u8x32 => u8x16, 0x2e301b7e12090d5c, test_neon());
+    test_utils::ops::test_zipunzip!(u8x32 => u8x16, 0xa1b2c3d4e5f67890, test_neon());
 }

 mod test_u8x64 {
@@ -238,6 +253,7 @@ mod tests {

     // Bit ops
     test_utils::ops::test_bitops!(u8x64, 0xd62d8de09f82ed4e, test_neon());
+    test_utils::ops::test_splitjoin!(u8x64 => u8x32, 0x2e301b7e12090d5c, test_neon());
 }

 // u32s
@@ -250,6 +266,8 @@ mod tests {

     // Reductions
     test_utils::ops::test_sumtree!(u32x8, 0x90a59e23ad545de1, test_neon());
+    test_utils::ops::test_splitjoin!(u32x8 => u32x4, 0x2e301b7e12090d5c, test_neon());
+    test_utils::ops::test_zipunzip!(u32x8 => u32x4, 0x4e7c0a3d5b9f2816, test_neon());
 }

 mod test_u32x16 {
@@ -261,6 +279,7 @@ mod tests {

     // Reductions
     test_utils::ops::test_sumtree!(u32x16, 0x90a59e23ad545de1, test_neon());
+    test_utils::ops::test_splitjoin!(u32x16 => u32x8, 0x2e301b7e12090d5c, test_neon());
 }

 // u64s
@@ -270,6 +289,7 @@ mod tests {

     // Bit ops
     test_utils::ops::test_bitops!(u64x4, 0xc4491a44af4aa58e, test_neon());
+    test_utils::ops::test_splitjoin!(u64x4 => u64x2, 0x2e301b7e12090d5c, test_neon());
 }

 // i8s
@@ -280,6 +300,8 @@ mod tests {
     // Bit ops
     test_utils::ops::test_bitops!(i8x32, 0xd62d8de09f82ed4e, test_neon());
     test_utils::ops::test_abs!(i8x32, 0xd62d8de09f82ed4e, test_neon());
+    test_utils::ops::test_splitjoin!(i8x32 => i8x16, 0x2e301b7e12090d5c, test_neon());
+    test_utils::ops::test_zipunzip!(i8x32 => i8x16, 0xc7e3a92f1d8b5604, test_neon());
 }

 mod test_i8x64 {
@@ -289,6 +311,7 @@ mod tests {
     // Bit ops
     test_utils::ops::test_bitops!(i8x64, 0xd62d8de09f82ed4e, test_neon());
     test_utils::ops::test_abs!(i8x64, 0xd62d8de09f82ed4e, test_neon());
+    test_utils::ops::test_splitjoin!(i8x64 => i8x32, 0x2e301b7e12090d5c, test_neon());
 }

 // i16s
@@ -299,6 +322,8 @@ mod tests {
     // Bit ops
     test_utils::ops::test_bitops!(i16x16, 0x9167644fc4ad5cfa, test_neon());
     test_utils::ops::test_abs!(i16x16, 0x9167644fc4ad5cfa, test_neon());
+    test_utils::ops::test_splitjoin!(i16x16 => i16x8, 0x2e301b7e12090d5c, test_neon());
+    test_utils::ops::test_zipunzip!(i16x16 => i16x8, 0x3f84d1b6e7a20c59, test_neon());
 }

 mod test_i16x32 {
@@ -308,6 +333,7 @@ mod tests {
     // Bit ops
     test_utils::ops::test_bitops!(i16x32, 0x9167644fc4ad5cfa, test_neon());
     test_utils::ops::test_abs!(i16x32, 0x9167644fc4ad5cfa, test_neon());
+    test_utils::ops::test_splitjoin!(i16x32 => i16x16, 0x2e301b7e12090d5c, test_neon());
 }

 // i32s
@@ -340,6 +366,8 @@ mod tests {

     // Reductions
     test_utils::ops::test_sumtree!(i32x8, 0x90a59e23ad545de1, test_neon());
+    test_utils::ops::test_splitjoin!(i32x8 => i32x4, 0x2e301b7e12090d5c, test_neon());
+    test_utils::ops::test_zipunzip!(i32x8 => i32x4, 0x92d5f4a83e1b07c6, test_neon());
 }

 mod test_i32x16 {
@@ -371,6 +399,7 @@ mod tests {

     // Reductions
     test_utils::ops::test_sumtree!(i32x16, 0x90a59e23ad545de1, test_neon());
+    test_utils::ops::test_splitjoin!(i32x16 => i32x8, 0x2e301b7e12090d5c, test_neon());
 }

 // Conversions
@@ -388,4 +417,8 @@ mod tests {
     test_utils::ops::test_cast!(f32x16 => f16x16, 0xba8fe343fc9dbeff, test_neon());

     test_utils::ops::test_cast!(i32x8 => f32x8, 0xba8fe343fc9dbeff, test_neon());
+
+    // SplitJoin + ZipUnzip for f16x16
+    test_utils::ops::test_splitjoin!(f16x16 => f16x8, 0x2e301b7e12090d5c, test_neon());
+    test_utils::ops::test_zipunzip!(f16x16 => f16x8, 0x6b2e0f9d8a41c573, test_neon());
 }
```
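The `ZipUnzip` semantics these tests exercise can be modeled in plain scalar code: zip interleaves two half-vectors into one full vector, and unzip recovers the even- and odd-indexed streams. A minimal sketch with hypothetical free functions (not the diskann-wide trait itself, which operates on SIMD registers):

```rust
/// Interleave two half-vectors into one full vector: [a0, b0, a1, b1, ...].
fn zip(lo: &[i32], hi: &[i32]) -> Vec<i32> {
    lo.iter().zip(hi.iter()).flat_map(|(&a, &b)| [a, b]).collect()
}

/// Inverse of `zip`: split a full vector back into even- and odd-indexed streams.
fn unzip(full: &[i32]) -> (Vec<i32>, Vec<i32>) {
    let lo: Vec<i32> = full.iter().step_by(2).copied().collect();
    let hi: Vec<i32> = full.iter().skip(1).step_by(2).copied().collect();
    (lo, hi)
}

fn main() {
    let lo = vec![0, 1, 2, 3];
    let hi = vec![4, 5, 6, 7];
    let zipped = zip(&lo, &hi);
    assert_eq!(zipped, vec![0, 4, 1, 5, 2, 6, 3, 7]);
    // Round-trip: unzip undoes zip.
    assert_eq!(unzip(&zipped), (lo, hi));
    println!("zip/unzip round-trip ok");
}
```

On Neon the same effect comes from the `vzip1q`/`vzip2q` and `vuzp1q`/`vuzp2q` intrinsic pairs, which is exactly what the `aarch64_zipunzip!` macro in the next file wires up.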

diskann-wide/src/arch/aarch64/macros.rs

Lines changed: 48 additions & 0 deletions
```diff
@@ -569,3 +569,51 @@ pub(crate) use aarch64_define_loadstore;
 pub(crate) use aarch64_define_register;
 pub(crate) use aarch64_define_splat;
 pub(crate) use aarch64_splitjoin;
+
+/// Implement [`ZipUnzip`] for a [`Doubled`] type using Neon zip/unzip intrinsics.
+///
+/// ## Parameters
+///
+/// * `$half` — the native 128-bit Neon type (e.g. `i8x16`)
+/// * `$zip1` — `vzip1q_*` intrinsic (interleave lower halves)
+/// * `$zip2` — `vzip2q_*` intrinsic (interleave upper halves)
+/// * `$uzp1` — `vuzp1q_*` intrinsic (collect even-indexed elements)
+/// * `$uzp2` — `vuzp2q_*` intrinsic (collect odd-indexed elements)
+///
+/// ## Safety
+///
+/// The caller must ensure the provided intrinsics match the element type of `$half`.
+macro_rules! aarch64_zipunzip {
+    ($half:path, $zip1:ident, $zip2:ident, $uzp1:ident, $uzp2:ident) => {
+        impl $crate::ZipUnzip for $crate::doubled::Doubled<$half> {
+            #[inline(always)]
+            fn zip(halves: $crate::LoHi<<Self as $crate::SplitJoin>::Halved>) -> Self {
+                use $crate::SIMDVector;
+                // SAFETY: Caller asserts that these intrinsics match the element type.
+                unsafe {
+                    let lo_raw = halves.lo.to_underlying();
+                    let hi_raw = halves.hi.to_underlying();
+                    $crate::doubled::Doubled(
+                        <$half>::from_underlying(halves.lo.arch(), $zip1(lo_raw, hi_raw)),
+                        <$half>::from_underlying(halves.lo.arch(), $zip2(lo_raw, hi_raw)),
+                    )
+                }
+            }
+
+            #[inline(always)]
+            fn unzip(self) -> $crate::LoHi<<Self as $crate::SplitJoin>::Halved> {
+                use $crate::SIMDVector;
+                // SAFETY: Caller asserts that these intrinsics match the element type.
+                unsafe {
+                    let lo_raw = self.0.to_underlying();
+                    let hi_raw = self.1.to_underlying();
+                    $crate::LoHi::new(
+                        <$half>::from_underlying(self.0.arch(), $uzp1(lo_raw, hi_raw)),
+                        <$half>::from_underlying(self.0.arch(), $uzp2(lo_raw, hi_raw)),
+                    )
+                }
+            }
+        }
+    };
+}
+pub(crate) use aarch64_zipunzip;
```
