Skip to content

Commit 035d9ee

Browse files
authored
ml-dsa: use Barrett reduction instead of integer division to prevent side-channels (#1144)
1 parent 73474a9 commit 035d9ee

File tree

2 files changed

+119
-32
lines changed

2 files changed

+119
-32
lines changed

ml-dsa/src/algebra.rs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,65 @@ pub(crate) trait Decompose {
5454
fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem);
5555
}
5656

57+
/// Constant-time division by a compile-time constant divisor.
58+
///
59+
/// This trait provides a constant-time alternative to the hardware division
60+
/// instruction, which has variable timing based on operand values.
61+
/// Uses Barrett reduction to compute `x / M` where M is a compile-time constant.
62+
pub(crate) trait ConstantTimeDiv: Unsigned {
63+
/// Bit shift for Barrett reduction, chosen to provide sufficient precision
64+
const CT_DIV_SHIFT: usize;
65+
/// Precomputed multiplier: ceil(2^SHIFT / M)
66+
const CT_DIV_MULTIPLIER: u64;
67+
68+
/// Perform constant-time division of `x` by `Self::U32`.
69+
/// Requires: x < Q (the field modulus, ~2^23)
70+
#[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
71+
#[inline(always)]
72+
fn ct_div(x: u32) -> u32 {
73+
// Barrett reduction: q = (x * MULTIPLIER) >> SHIFT
74+
// This equals floor(x / M) whenever x * (MULTIPLIER * M - 2^SHIFT) < 2^SHIFT
75+
let x64 = u64::from(x);
76+
let quotient = (x64 * Self::CT_DIV_MULTIPLIER) >> Self::CT_DIV_SHIFT;
77+
// SAFETY: quotient is guaranteed to fit in u32 because:
78+
// - x < Q (~2^23), so quotient = x / M <= x < 2^23 < 2^32
79+
#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
80+
let result = quotient as u32;
81+
result
82+
}
83+
}
84+
85+
impl<M> ConstantTimeDiv for M
86+
where
87+
M: Unsigned,
88+
{
89+
// Use a shift that provides enough precision for the ML-DSA field (Q ~ 2^23)
90+
// We need SHIFT > log2(Q) + log2(M) to ensure accuracy
91+
// With Q < 2^24 and M < 2^20, SHIFT = 48 is sufficient
92+
const CT_DIV_SHIFT: usize = 48;
93+
94+
// Precompute the multiplier at compile time
95+
// `div_ceil` rounds the multiplier up (ceiling division), ensuring we never underestimate the quotient
96+
#[allow(clippy::integer_division_remainder_used)]
97+
const CT_DIV_MULTIPLIER: u64 = (1u64 << Self::CT_DIV_SHIFT).div_ceil(M::U64);
98+
}
99+
57100
impl Decompose for Elem {
58101
// Algorithm 36 Decompose
102+
//
103+
// This implementation uses constant-time division to avoid timing side-channels.
104+
// The original algorithm used hardware division which has variable timing based
105+
// on operand values, potentially leaking secret information during signing.
59106
fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem) {
60107
let r_plus = self.clone();
61108
let r0 = r_plus.mod_plus_minus::<TwoGamma2>();
62109

63110
if r_plus - r0 == Elem::new(BaseField::Q - 1) {
64111
(Elem::new(0), r0 - Elem::new(1))
65112
} else {
66-
let mut r1 = r_plus - r0;
67-
r1.0 /= TwoGamma2::U32;
113+
let diff = r_plus - r0;
114+
// Use constant-time division instead of hardware division
115+
let r1 = Elem::new(TwoGamma2::ct_div(diff.0));
68116
(r1, r0)
69117
}
70118
}

ml-dsa/src/ntt.rs

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,46 @@ pub(crate) trait Ntt {
5050
fn ntt(&self) -> Self::Output;
5151
}
5252

53+
/// Constant-time NTT butterfly layer.
54+
///
55+
/// Uses const generics to ensure loop bounds are compile-time constants,
56+
/// avoiding UDIV instructions from runtime `step_by` calculations.
57+
#[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
58+
#[inline(always)]
59+
fn ntt_layer<const LEN: usize, const ITERATIONS: usize>(w: &mut [Elem; 256], m: &mut usize) {
60+
for i in 0..ITERATIONS {
61+
let start = i * 2 * LEN;
62+
*m += 1;
63+
let z = ZETA_POW_BITREV[*m];
64+
for j in start..(start + LEN) {
65+
let t = z * w[j + LEN];
66+
w[j + LEN] = w[j] - t;
67+
w[j] = w[j] + t;
68+
}
69+
}
70+
}
71+
5372
impl Ntt for Polynomial {
5473
type Output = NttPolynomial;
5574

5675
// Algorithm 41 NTT
76+
//
77+
// This implementation uses const-generic helper functions to ensure all loop
78+
// bounds are compile-time constants, avoiding potential UDIV instructions.
5779
fn ntt(&self) -> Self::Output {
58-
let mut w = self.0.clone();
59-
80+
let mut w: [Elem; 256] = self.0.clone().into();
6081
let mut m = 0;
61-
for len in [128, 64, 32, 16, 8, 4, 2, 1] {
62-
for start in (0..256).step_by(2 * len) {
63-
m += 1;
64-
let z = ZETA_POW_BITREV[m];
65-
66-
for j in start..(start + len) {
67-
let t = z * w[j + len];
68-
w[j + len] = w[j] - t;
69-
w[j] = w[j] + t;
70-
}
71-
}
72-
}
7382

74-
NttPolynomial::new(w)
83+
ntt_layer::<128, 1>(&mut w, &mut m);
84+
ntt_layer::<64, 2>(&mut w, &mut m);
85+
ntt_layer::<32, 4>(&mut w, &mut m);
86+
ntt_layer::<16, 8>(&mut w, &mut m);
87+
ntt_layer::<8, 16>(&mut w, &mut m);
88+
ntt_layer::<4, 32>(&mut w, &mut m);
89+
ntt_layer::<2, 64>(&mut w, &mut m);
90+
ntt_layer::<1, 128>(&mut w, &mut m);
91+
92+
NttPolynomial::new(w.into())
7593
}
7694
}
7795

@@ -89,30 +107,51 @@ pub(crate) trait NttInverse {
89107
fn ntt_inverse(&self) -> Self::Output;
90108
}
91109

110+
/// Constant-time inverse NTT butterfly layer.
111+
///
112+
/// Uses const generics to ensure loop bounds are compile-time constants,
113+
/// avoiding UDIV instructions from runtime `step_by` calculations.
114+
#[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
115+
#[inline(always)]
116+
fn ntt_inverse_layer<const LEN: usize, const ITERATIONS: usize>(
117+
w: &mut [Elem; 256],
118+
m: &mut usize,
119+
) {
120+
for i in 0..ITERATIONS {
121+
let start = i * 2 * LEN;
122+
*m -= 1;
123+
let z = -ZETA_POW_BITREV[*m];
124+
for j in start..(start + LEN) {
125+
let t = w[j];
126+
w[j] = t + w[j + LEN];
127+
w[j + LEN] = z * (t - w[j + LEN]);
128+
}
129+
}
130+
}
131+
92132
impl NttInverse for NttPolynomial {
93133
type Output = Polynomial;
94134

95135
// Algorithm 42 NTT^{−1}
136+
//
137+
// This implementation uses const-generic helper functions to ensure all loop
138+
// bounds are compile-time constants, avoiding potential UDIV instructions.
96139
fn ntt_inverse(&self) -> Self::Output {
97140
const INVERSE_256: Elem = Elem::new(8_347_681);
98141

99-
let mut w = self.0.clone();
100-
142+
let mut w: [Elem; 256] = self.0.clone().into();
101143
let mut m = 256;
102-
for len in [1, 2, 4, 8, 16, 32, 64, 128] {
103-
for start in (0..256).step_by(2 * len) {
104-
m -= 1;
105-
let z = -ZETA_POW_BITREV[m];
106-
107-
for j in start..(start + len) {
108-
let t = w[j];
109-
w[j] = t + w[j + len];
110-
w[j + len] = z * (t - w[j + len]);
111-
}
112-
}
113-
}
114144

115-
INVERSE_256 * &Polynomial::new(w)
145+
ntt_inverse_layer::<1, 128>(&mut w, &mut m);
146+
ntt_inverse_layer::<2, 64>(&mut w, &mut m);
147+
ntt_inverse_layer::<4, 32>(&mut w, &mut m);
148+
ntt_inverse_layer::<8, 16>(&mut w, &mut m);
149+
ntt_inverse_layer::<16, 8>(&mut w, &mut m);
150+
ntt_inverse_layer::<32, 4>(&mut w, &mut m);
151+
ntt_inverse_layer::<64, 2>(&mut w, &mut m);
152+
ntt_inverse_layer::<128, 1>(&mut w, &mut m);
153+
154+
INVERSE_256 * &Polynomial::new(w.into())
116155
}
117156
}
118157

0 commit comments

Comments
 (0)