Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions distr_test/tests/cdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,55 @@ fn poisson() {
}
}

/// Kolmogorov–Smirnov check of `NormalTruncated` against the analytic CDF of a
/// truncated normal, with one parameter group per internal sampling method.
#[test]
fn truncated_normal() {
    // Each tuple is (mean, stddev, lower, upper). The conditions quoted below
    // are expressed in standardized coordinates, i.e. (bound - mean) / stddev.
    let parameters = [
        // Rejection sampling: interval at least 1 stddev wide, overlapping the bulk
        // of the distribution (std_lower <= 0.3 and std_upper >= -0.3)
        (0.0, 1.0, -1.0, 1.0),
        (0.0, 1.0, 0.0, 2.0),
        (1.0, 2.0, -1.0, 3.0),
        (5.0, 0.5, 4.0, 6.0),
        (10.0, 1.0, 8.0, 12.0),
        // OneSided (lower bound only): upper = +inf, std_lower > 0.3
        (0.0, 1.0, 1.0, f64::INFINITY),
        (2.0, 0.5, 3.0, f64::INFINITY),
        // OneSided (upper bound only): lower = -inf, std_upper < -0.3
        (0.0, 1.0, f64::NEG_INFINITY, -1.0),
        (2.0, 0.5, f64::NEG_INFINITY, 1.0),
        // TailInterval (positive tail): std_lower >= 0.5, diff >= 1.0, two-sided
        (0.0, 1.0, 1.0, 3.0),
        (5.0, 1.0, 6.0, 8.0),
        // TailInterval (negative tail, mirrored): std_upper <= -0.5, diff >= 1.0, two-sided
        (0.0, 1.0, -3.0, -1.0),
        (5.0, 1.0, 2.0, 4.0),
        // TwoSided: interval not matching any of the other conditions
        (0.0, 1.0, 0.1, 0.9),
        (0.0, 1.0, 0.35, 1.5),
    ];

    // The reference CDF is always expressed through the standard normal, so
    // build it once instead of on every CDF evaluation.
    let standard_normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap();

    for (seed, (mu, sigma, lower, upper)) in parameters.into_iter().enumerate() {
        let dist = rand_distr::NormalTruncated::new(mu, sigma, lower, upper).unwrap();

        // These only depend on the parameters, not on the evaluation point.
        let cdf_lower = standard_normal.cdf((lower - mu) / sigma);
        let cdf_upper = standard_normal.cdf((upper - mu) / sigma);
        // Probability mass of the untruncated normal inside [lower, upper].
        let mass = cdf_upper - cdf_lower;

        let analytic = |x: f64| {
            if x < lower {
                0.0
            } else if x > upper {
                1.0
            } else {
                let cdf_x = standard_normal.cdf((x - mu) / sigma);
                (cdf_x - cdf_lower) / mass
            }
        };
        test_continuous(seed as u64, dist, analytic);
    }
}

/// Natural logarithm of `n!`, evaluated as the log-gamma function at `n + 1`.
fn ln_factorial(n: u64) -> f64 {
    let argument = n as f64 + 1.0;
    // `lgamma` yields a pair; only its first component is used here.
    let (ln_gamma, _) = argument.lgamma();
    ln_gamma
}
Expand Down
11 changes: 10 additions & 1 deletion distr_test/tests/ks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,15 @@ where

/// Tests a distribution against an analytical CDF.
/// The CDF has to be continuous.
pub fn test_continuous(seed: u64, dist: impl Distribution<f64>, cdf: impl Fn(f64) -> f64) {
pub fn test_continuous(
seed: u64,
dist: impl Distribution<f64> + std::fmt::Debug,
cdf: impl Fn(f64) -> f64,
) {
let time = std::time::Instant::now();
println!("Testing distribution: {:?}", &dist);
let ecdf = sample_ecdf(seed, dist);
println!("Sampling took {} seconds", time.elapsed().as_secs_f64());
let ks_statistic = kolmogorov_smirnov_statistic_continuous(ecdf, cdf);

let critical_value = critical_value();
Expand All @@ -125,7 +132,9 @@ where
D: Distribution<I>,
F: Fn(i64) -> f64,
{
let time = std::time::Instant::now();
let ecdf = sample_ecdf(seed, dist);
println!("Sampling took {} seconds", time.elapsed().as_secs_f64());
let ks_statistic = kolmogorov_smirnov_statistic_discrete(ecdf, cdf);

// This critical value is bigger than it could be for discrete distributions, but because of large sample sizes this should not matter too much
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
//! - Misc. distributions
//! - [`InverseGaussian`] distribution
//! - [`NormalInverseGaussian`] distribution
//! - [`NormalTruncated`] distribution

#[cfg(feature = "alloc")]
extern crate alloc;
Expand Down Expand Up @@ -110,6 +111,7 @@ pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal};
pub use self::normal_inverse_gaussian::{
Error as NormalInverseGaussianError, NormalInverseGaussian,
};
pub use self::normal_truncated::{Error as NormalTruncatedError, NormalTruncated};
pub use self::pareto::{Error as ParetoError, Pareto};
pub use self::pert::{Pert, PertBuilder, PertError};
pub use self::poisson::{Error as PoissonError, Poisson};
Expand Down Expand Up @@ -211,6 +213,7 @@ mod hypergeometric;
mod inverse_gaussian;
mod normal;
mod normal_inverse_gaussian;
mod normal_truncated;
mod pareto;
mod pert;
pub(crate) mod poisson;
Expand Down
283 changes: 283 additions & 0 deletions src/normal_truncated.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
#[allow(unused_imports)]
use num_traits::Float;
use rand::{Rng, RngExt, distr::Distribution};

/// The [truncated normal distribution](https://en.wikipedia.org/wiki/Truncated_normal_distribution).
///
/// A value is built with `NormalTruncated::new` from a mean, a standard
/// deviation, a lower bound and an upper bound. The constructor picks one of
/// several private sampling methods depending on where the bounds lie, and
/// every call to `sample` dispatches to that method.
///
/// # Current Implementation
/// We follow the approach described in
/// Robert, Christian P. (1995). "Simulation of truncated normal variables".
/// Statistics and Computing. 5 (2): 121–125.
#[derive(Debug)]
pub struct NormalTruncated(Method);

/// Sampling strategy selected by `NormalTruncated::new` and dispatched on by
/// `NormalTruncated::sample`.
#[derive(Debug)]
enum Method {
    /// Draw from the untruncated normal and retry until the draw is in bounds.
    Rejection(NormalTruncatedRejection),
    /// Only one bound is finite. The bool indicates if the lower bound is used;
    /// when it is `false` the sampler was built for the mirrored problem and
    /// its draw is negated.
    OneSided(bool, NormalTruncatedOneSided),
    /// Both bounds are finite and lie on the same side, in a tail. The bool
    /// indicates a mirrored proposal: when it is `true` the sampler was built
    /// for the mirrored problem and its draw is negated.
    TailInterval(bool, NormalTruncatedTailInterval),
    /// Both bounds are finite and none of the other methods was selected.
    TwoSided(NormalTruncatedTwoSided),
}

/// Errors that can occur when constructing a `NormalTruncated` distribution.
///
/// The type is a plain field-less enum, so it is cheap to copy and can be
/// compared directly in tests and in caller code.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The standard deviation was not positive.
    InvalidStdDev,
    /// The lower bound was not less than the upper bound.
    InvalidBounds,
}

impl core::fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            Error::InvalidStdDev => {
                "standard deviation is not positive in truncated normal distribution"
            }
            Error::InvalidBounds => {
                "lower bound is not less than upper bound in truncated normal distribution"
            }
        })
    }
}

// Only available with `std`, so that the type stays usable in `no_std` builds.
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl NormalTruncated {
/// Constructs a new `NormalTruncated` distribution with the given
/// mean, standard deviation, lower bound, and upper bound.
pub fn new(mean: f64, stddev: f64, lower: f64, upper: f64) -> Result<Self, Error> {
if !(stddev > 0.0) {
return Err(Error::InvalidStdDev);
}
Comment on lines +35 to +38
Copy link
Contributor

@mstoeckl mstoeckl Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may slightly improve API consistency to have NormalTruncated generalize Normal and allow stddev = 0.0 iff mean is in [lower, upper). The wikipedia page is also careful to define the real-valued truncated normal distribution in a way that allows this.

if !(lower < upper) {
return Err(Error::InvalidBounds);
}

let std_lower = (lower - mean) / stddev;
let std_upper = (upper - mean) / stddev;
Comment on lines +35 to +44
Copy link
Contributor

@mstoeckl mstoeckl Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The specific first issue that I noticed causing panics when fuzzing is that even if lower < upper, std_lower may equal std_upper, making the sampling in NormalTruncatedTwoSided fail.

(I'm not certain how best to resolve this. Perhaps make NormalTruncatedTwoSided::sample sample based on the original lower..upper range?)


if upper == f64::INFINITY {
// This threshold depends on how fast normal vs exponential sampling is. This value was found empirically, but it can probably be tuned better.
if std_lower > 0.3 {
// One sided truncation, lower bound only
Ok(NormalTruncated(Method::OneSided(
true,
NormalTruncatedOneSided::new(mean, stddev, std_lower),
)))
} else {
// We use naive rejection sampling
// Also catches the case where both bounds are infinite
Ok(NormalTruncated(Method::Rejection(
NormalTruncatedRejection {
normal: crate::Normal::new(mean, stddev).unwrap(),
lower,
upper,
},
)))
}
} else if lower == f64::NEG_INFINITY {
// This threshold depends on how fast normal vs exponential sampling is. This value was found empirically, but it can probably be tuned better.
if std_upper < -0.3 {
// One sided truncation, upper bound only
Ok(NormalTruncated(Method::OneSided(
false,
NormalTruncatedOneSided::new(-mean, stddev, -std_upper),
)))
} else {
// We use naive rejection sampling
Ok(NormalTruncated(Method::Rejection(
NormalTruncatedRejection {
normal: crate::Normal::new(mean, stddev).unwrap(),
lower,
upper,
},
)))
}
} else {
// Two sided truncation
let diff = std_upper - std_lower;
// Threshold can probably be tuned better for performance
if diff >= 1.0 && std_lower <= 0.3 && std_upper >= -0.3 {
// Naive rejection sampling
Ok(NormalTruncated(Method::Rejection(
NormalTruncatedRejection {
normal: crate::Normal::new(mean, stddev).unwrap(),
lower,
upper,
},
)))
} else if std_lower >= 0.5 && diff >= 1.0 {
// Two sided truncation in the upper tail.
// Use the one-sided sampler as a proposal and reject past the upper bound.
Ok(NormalTruncated(Method::TailInterval(
false,
NormalTruncatedTailInterval::new(
NormalTruncatedOneSided::new(mean, stddev, std_lower),
upper,
),
)))
} else if std_upper <= -0.5 && diff >= 1.0 {
// Mirror the lower-tail case to reuse the same one-sided sampler.
Ok(NormalTruncated(Method::TailInterval(
true,
NormalTruncatedTailInterval::new(
NormalTruncatedOneSided::new(-mean, stddev, -std_upper),
-lower,
),
)))
} else {
// Two sided truncation
Ok(NormalTruncated(Method::TwoSided(
NormalTruncatedTwoSided::new(mean, stddev, std_lower, std_upper),
)))
}
}
}
}

impl Distribution<f64> for NormalTruncated {
    /// Draws one value using the method chosen at construction time.
    ///
    /// Methods that were built for the mirrored problem have their draw negated
    /// here to map it back to the requested interval.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        match &self.0 {
            Method::Rejection(sampler) => sampler.sample(rng),
            Method::TwoSided(sampler) => sampler.sample(rng),
            Method::OneSided(uses_lower_bound, sampler) => {
                let draw = sampler.sample(rng);
                if *uses_lower_bound {
                    draw
                } else {
                    -draw
                }
            }
            Method::TailInterval(mirrored, sampler) => {
                let draw = sampler.sample(rng);
                if *mirrored {
                    -draw
                } else {
                    draw
                }
            }
        }
    }
}

/// A truncated normal distribution using naive rejection sampling.
/// We use this when the acceptance rate is high enough.
#[derive(Debug)]
struct NormalTruncatedRejection {
    /// The untruncated normal that proposals are drawn from.
    normal: crate::Normal<f64>,
    /// Inclusive lower bound, in the original (unstandardized) coordinates.
    lower: f64,
    /// Inclusive upper bound, in the original (unstandardized) coordinates.
    upper: f64,
}

impl Distribution<f64> for NormalTruncatedRejection {
    /// Repeatedly draws from the untruncated normal and returns the first draw
    /// that lies inside the closed interval `[lower, upper]`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let accepted = self.lower..=self.upper;
        loop {
            let candidate = self.normal.sample(rng);
            if accepted.contains(&candidate) {
                return candidate;
            }
        }
    }
}

/// Sampler for a normal distribution truncated from below only, using a
/// shifted exponential proposal.
#[derive(Debug)]
struct NormalTruncatedOneSided {
    /// Rate of the exponential proposal; also the centre of the acceptance
    /// probability computed in `sample`.
    alpha_star: f64,
    /// Lower bound in standard normal coordinates.
    lower_bound: f64,
    /// Exponential distribution with rate `alpha_star`.
    exp_distribution: crate::Exp<f64>,
    /// Mean used to map an accepted standardized draw back to the output scale.
    mu: f64,
    /// Standard deviation used to map an accepted standardized draw back to
    /// the output scale.
    sigma: f64,
}

impl NormalTruncatedOneSided {
    /// Creates the sampler for a lower bound given in standard normal
    /// coordinates.
    ///
    /// The proposal rate is `(a + sqrt(a^2 + 4)) / 2`, where `a` is
    /// `standard_lower_bound`.
    fn new(mu: f64, sigma: f64, standard_lower_bound: f64) -> Self {
        let under_root = standard_lower_bound.powi(2) + 4.0;
        let alpha_star = (standard_lower_bound + under_root.sqrt()) / 2.0;
        // The exponential proposal uses the same value as its rate.
        let exp_distribution = crate::Exp::new(alpha_star).unwrap();

        Self {
            alpha_star,
            lower_bound: standard_lower_bound,
            exp_distribution,
            mu,
            sigma,
        }
    }
}

impl Distribution<f64> for NormalTruncatedOneSided {
    /// Proposes `lower_bound + Exp(alpha_star)` and accepts the proposal with
    /// probability `exp(-(z - alpha_star)^2 / 2)`; an accepted proposal is
    /// returned as `mu + sigma * z`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        loop {
            // The exponential draw comes first, then the uniform draw.
            let candidate = self.lower_bound + self.exp_distribution.sample(rng);
            let threshold: f64 = rng.random();

            let distance = candidate - self.alpha_star;
            let acceptance = (-0.5 * distance.powi(2)).exp();
            if threshold <= acceptance {
                return self.mu + self.sigma * candidate;
            }
        }
    }
}

/// Sampler for a two-sided truncation: a one-sided sampler provides proposals
/// and proposals above `upper_bound` are rejected.
#[derive(Debug)]
struct NormalTruncatedTailInterval {
    /// One-sided sampler used as the proposal distribution.
    proposal: NormalTruncatedOneSided,
    /// Largest proposal value that is accepted, on the same scale as the
    /// values returned by `proposal`.
    upper_bound: f64,
}

impl NormalTruncatedTailInterval {
fn new(proposal: NormalTruncatedOneSided, upper_bound: f64) -> Self {
NormalTruncatedTailInterval {
proposal,
upper_bound,
}
}
}

impl Distribution<f64> for NormalTruncatedTailInterval {
    /// Draws from the one-sided proposal until a draw does not exceed
    /// `upper_bound`, and returns that draw.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        loop {
            match self.proposal.sample(rng) {
                accepted if accepted <= self.upper_bound => return accepted,
                _ => continue,
            }
        }
    }
}

/// Sampler for a two-sided truncation that proposes uniformly over the
/// standardized interval and accepts with a Gaussian-shaped probability.
#[derive(Debug)]
struct NormalTruncatedTwoSided {
    /// Mean used to map an accepted standardized draw back to the output scale.
    mu: f64,
    /// Standard deviation used to map an accepted standardized draw back to
    /// the output scale.
    sigma: f64,
    // In standard normal coordinates
    standard_lower: f64,
    // In standard normal coordinates
    standard_upper: f64,
}

impl NormalTruncatedTwoSided {
fn new(mu: f64, sigma: f64, standard_lower: f64, standard_upper: f64) -> Self {
NormalTruncatedTwoSided {
mu,
sigma,
standard_lower,
standard_upper,
}
}
}

impl Distribution<f64> for NormalTruncatedTwoSided {
    /// Proposes a point uniformly in `standard_lower..standard_upper` and
    /// accepts it with a probability proportional to the standard normal
    /// density, scaled so that the density's maximum over the interval maps to
    /// one. An accepted point is returned as `mu + sigma * z`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // Square of the point where the density peaks on the interval, or `None`
        // when the interval contains zero. It only depends on the bounds, so it
        // is determined once rather than on every iteration.
        let peak_squared = if self.standard_lower <= 0.0 && self.standard_upper >= 0.0 {
            None
        } else if self.standard_upper < 0.0 {
            Some(self.standard_upper.powi(2))
        } else {
            Some(self.standard_lower.powi(2))
        };

        loop {
            // The uniform proposal is drawn first, then the acceptance threshold.
            let candidate = rng.random_range(self.standard_lower..self.standard_upper);
            let threshold: f64 = rng.random();

            let acceptance = match peak_squared {
                None => (-0.5 * candidate.powi(2)).exp(),
                Some(peak) => (0.5 * (peak - candidate.powi(2))).exp(),
            };
            if threshold <= acceptance {
                return self.mu + self.sigma * candidate;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the method selected for a standard normal (mean 0, stddev 1)
    /// truncated to the interval from `lower` to `upper`.
    fn selected_method(lower: f64, upper: f64) -> Method {
        let NormalTruncated(method) = NormalTruncated::new(0.0, 1.0, lower, upper).unwrap();
        method
    }

    #[test]
    fn uses_tail_interval_method_for_positive_tail() {
        let method = selected_method(2.0, 3.0);
        assert!(matches!(method, Method::TailInterval(false, _)));
    }

    #[test]
    fn uses_tail_interval_method_for_negative_tail() {
        let method = selected_method(-3.0, -2.0);
        assert!(matches!(method, Method::TailInterval(true, _)));
    }

    #[test]
    fn keeps_uniform_two_sided_method_for_narrow_positive_interval() {
        let method = selected_method(0.1, 0.2);
        assert!(matches!(method, Method::TwoSided(_)));
    }
}
Loading