Skip to content

Commit 4afdbf5

Browse files
committed
🤖 parameterize tests
1 parent 10a13a2 commit 4afdbf5

File tree

4 files changed

+74
-31
lines changed

4 files changed

+74
-31
lines changed

downsample_rs/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ num-traits = { version = "0.2.15", default-features = false }
1616
rayon = { version = "1.6.0", default-features = false }
1717

1818
[dev-dependencies]
19+
rstest = { version = "0.16", default-features = false }
20+
rstest_reuse = { version = "0.5", default-features = false }
1921
criterion = "0.4.0"
2022
dev_utils = { path = "dev_utils" }
2123

downsample_rs/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
// It is necessary to import this at the root of the crate
2+
// See: https://github.com/la10736/rstest/tree/master/rstest_reuse#use-rstest_resuse-at-the-top-of-your-crate
3+
#[cfg(test)]
4+
use rstest_reuse;
5+
16
pub mod minmax;
27
pub use minmax::*;
38
pub mod lttb;

downsample_rs/src/m4/scalar.rs

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,24 @@ mod tests {
116116
fn test_m4_scalar_without_x_parallel_correct() {
117117
let arr = (0..100).map(|x| x as f32).collect::<Vec<f32>>();
118118
let arr = Array1::from(arr);
119-
let half_n_threads: usize = available_parallelism().map(|x| x.get()).unwrap_or(2) / 2;
120-
121-
let sampled_indices = m4_scalar_without_x_parallel(arr.view(), 12, half_n_threads);
122-
let sampled_values = sampled_indices.mapv(|x| arr[x]);
123119

124120
let expected_indices = vec![0, 0, 33, 33, 34, 34, 66, 66, 67, 67, 99, 99];
121+
let expected_indices = Array1::from(expected_indices);
125122
let expected_values = expected_indices
126123
.iter()
127124
.map(|x| *x as f32)
128125
.collect::<Vec<f32>>();
126+
let expected_values = Array1::from(expected_values);
129127

130-
assert_eq!(sampled_indices, Array1::from(expected_indices));
131-
assert_eq!(sampled_values, Array1::from(expected_values));
128+
let all_threads = available_parallelism().map(|x| x.get()).unwrap_or(2);
129+
let nb_threads = vec![1, all_threads / 2, all_threads, all_threads + 1];
130+
131+
for n_threads in nb_threads {
132+
let sampled_indices = m4_scalar_without_x_parallel(arr.view(), 12, n_threads);
133+
let sampled_values = sampled_indices.mapv(|x| arr[x]);
134+
assert_eq!(sampled_indices, expected_indices);
135+
assert_eq!(sampled_values, expected_values);
136+
}
132137
}
133138

134139
#[test]
@@ -157,19 +162,23 @@ mod tests {
157162
let x = Array1::from(x);
158163
let arr = (0..100).map(|x| x as f32).collect::<Vec<f32>>();
159164
let arr = Array1::from(arr);
160-
let half_n_threads: usize = available_parallelism().map(|x| x.get()).unwrap_or(2) / 2;
161-
162-
let sampled_indices = m4_scalar_with_x_parallel(x.view(), arr.view(), 12, half_n_threads);
163-
let sampled_values = sampled_indices.mapv(|x| arr[x]);
164165

165166
let expected_indices = vec![0, 0, 33, 33, 34, 34, 66, 66, 67, 67, 99, 99];
167+
let expected_indices = Array1::from(expected_indices);
166168
let expected_values = expected_indices
167169
.iter()
168170
.map(|x| *x as f32)
169171
.collect::<Vec<f32>>();
170-
171-
assert_eq!(sampled_indices, Array1::from(expected_indices));
172-
assert_eq!(sampled_values, Array1::from(expected_values));
172+
let expected_values = Array1::from(expected_values);
173+
174+
let all_threads = available_parallelism().map(|x| x.get()).unwrap_or(2);
175+
let nb_threads = vec![1, all_threads / 2, all_threads, all_threads + 1];
176+
for n_threads in nb_threads {
177+
let sampled_indices = m4_scalar_with_x_parallel(x.view(), arr.view(), 12, n_threads);
178+
let sampled_values = sampled_indices.mapv(|x| arr[x]);
179+
assert_eq!(sampled_indices, expected_indices);
180+
assert_eq!(sampled_values, expected_values);
181+
}
173182
}
174183

175184
#[test]
@@ -247,16 +256,19 @@ mod tests {
247256
let n_out: usize = 204;
248257
let x = (0..n as i32).collect::<Vec<i32>>();
249258
let x = Array1::from(x);
250-
let half_n_threads: usize = available_parallelism().map(|x| x.get()).unwrap_or(2) / 2;
259+
let all_threads = available_parallelism().map(|x| x.get()).unwrap_or(2);
260+
let nb_threads = vec![1, all_threads / 2, all_threads, all_threads + 1];
251261
for _ in 0..100 {
252262
let arr = get_array_f32(n);
253263
let idxs1 = m4_scalar_without_x(arr.view(), n_out);
254-
let idxs2 = m4_scalar_without_x_parallel(arr.view(), n_out, half_n_threads);
255-
let idxs3 = m4_scalar_with_x(x.view(), arr.view(), n_out);
256-
let idxs4 = m4_scalar_with_x_parallel(x.view(), arr.view(), n_out, half_n_threads);
264+
let idxs2 = m4_scalar_with_x(x.view(), arr.view(), n_out);
257265
assert_eq!(idxs1, idxs2);
258-
assert_eq!(idxs1, idxs3);
259-
assert_eq!(idxs1, idxs4);
266+
for &n_threads in nb_threads.iter() {
267+
let idxs3 = m4_scalar_without_x_parallel(arr.view(), n_out, n_threads);
268+
let idxs4 = m4_scalar_with_x_parallel(x.view(), arr.view(), n_out, n_threads);
269+
assert_eq!(idxs1, idxs3);
270+
assert_eq!(idxs1, idxs4); // TODO: this should not fail
271+
}
260272
}
261273
}
262274
}

downsample_rs/src/searchsorted.rs

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ use ndarray::ArrayView1;
22

33
use rayon::iter::IndexedParallelIterator;
44
use rayon::prelude::*;
5-
use std::thread::available_parallelism;
65

76
use super::types::Num;
87
use num_traits::{AsPrimitive, FromPrimitive};
@@ -191,12 +190,29 @@ where
191190

192191
#[cfg(test)]
193192
mod tests {
193+
use rstest::rstest;
194+
use rstest_reuse::{self, *};
195+
194196
use super::*;
195197
use ndarray::Array1;
198+
use std::thread::available_parallelism;
196199

197200
extern crate dev_utils;
198201
use dev_utils::utils::get_random_array;
199202

203+
fn get_all_threads() -> usize {
204+
available_parallelism().map(|x| x.get()).unwrap_or(1)
205+
}
206+
207+
// Template for the n_threads matrix
208+
#[template]
209+
#[rstest]
210+
#[case(1)]
211+
#[case(get_all_threads() / 2)]
212+
#[case(get_all_threads())]
213+
#[case(get_all_threads() * 2)]
214+
fn threads(#[case] n_threads: usize) {}
215+
200216
#[test]
201217
fn test_search_sorted_identicial_to_np_linspace_searchsorted() {
202218
// Create a 0..9999 array
@@ -281,42 +297,50 @@ mod tests {
281297
// assert_eq!(binary_search_with_mid(arr.view(), 11, 0, arr.len() - 1, 9), 10);
282298
}
283299

284-
#[test]
285-
fn test_get_equidistant_bin_idxs() {
300+
#[apply(threads)]
301+
fn test_get_equidistant_bin_idxs(n_threads: usize) {
302+
let expected_indices = vec![0, 4, 7];
303+
286304
let arr = Array1::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
287305
let bin_idxs_iter = get_equidistant_bin_idx_iterator(arr.view(), 3);
288306
let bin_idxs = bin_idxs_iter.map(|x| x.unwrap().0).collect::<Vec<usize>>();
289-
let half_n_threads: usize = available_parallelism().map(|x| x.get()).unwrap_or(2) / 2;
290-
assert_eq!(bin_idxs, vec![0, 4, 7]);
291-
let bin_idxs_iter =
292-
get_equidistant_bin_idx_iterator_parallel(arr.view(), 3, half_n_threads);
307+
assert_eq!(bin_idxs, expected_indices);
308+
309+
let bin_idxs_iter = get_equidistant_bin_idx_iterator_parallel(arr.view(), 3, n_threads);
293310
let bin_idxs = bin_idxs_iter
294311
.map(|x| x.map(|x| x.unwrap().0).collect::<Vec<usize>>())
295312
.flatten()
296313
.collect::<Vec<usize>>();
297-
assert_eq!(bin_idxs, vec![0, 4, 7]);
314+
assert_eq!(bin_idxs, expected_indices);
298315
}
299316

300-
#[test]
301-
fn test_many_random_same_result() {
317+
#[apply(threads)]
318+
fn test_many_random_same_result(n_threads: usize) {
302319
let n = 5_000;
303320
let nb_bins = 100;
304-
let half_n_threads: usize = available_parallelism().map(|x| x.get()).unwrap_or(2) / 2;
321+
let all_threads = available_parallelism().map(|x| x.get()).unwrap_or(2);
322+
let nb_threads = vec![1, all_threads / 2, all_threads, all_threads + 1];
323+
305324
for _ in 0..100 {
306325
let arr = get_random_array::<i32>(n, i32::MIN, i32::MAX);
307326
// Sort the array
308327
let mut arr = arr.to_vec();
309328
arr.sort_by(|a, b| a.partial_cmp(b).unwrap());
310329
let arr = Array1::from(arr);
330+
311331
// Calculate the bin indexes
312332
let bin_idxs_iter = get_equidistant_bin_idx_iterator(arr.view(), nb_bins);
313333
let bin_idxs = bin_idxs_iter.map(|x| x.unwrap().0).collect::<Vec<usize>>();
334+
335+
// Calculate the bin indexes in parallel
314336
let bin_idxs_iter =
315-
get_equidistant_bin_idx_iterator_parallel(arr.view(), nb_bins, half_n_threads);
337+
get_equidistant_bin_idx_iterator_parallel(arr.view(), nb_bins, n_threads);
316338
let bin_idxs_parallel = bin_idxs_iter
317339
.map(|x| x.map(|x| x.unwrap().0).collect::<Vec<usize>>())
318340
.flatten()
319341
.collect::<Vec<usize>>();
342+
343+
// Check that the results are the same
320344
assert_eq!(bin_idxs, bin_idxs_parallel);
321345
}
322346
}

0 commit comments

Comments
 (0)