@@ -116,19 +116,24 @@ mod tests {
116116 fn test_m4_scalar_without_x_parallel_correct ( ) {
117117 let arr = ( 0 ..100 ) . map ( |x| x as f32 ) . collect :: < Vec < f32 > > ( ) ;
118118 let arr = Array1 :: from ( arr) ;
119- let half_n_threads: usize = available_parallelism ( ) . map ( |x| x. get ( ) ) . unwrap_or ( 2 ) / 2 ;
120-
121- let sampled_indices = m4_scalar_without_x_parallel ( arr. view ( ) , 12 , half_n_threads) ;
122- let sampled_values = sampled_indices. mapv ( |x| arr[ x] ) ;
123119
124120 let expected_indices = vec ! [ 0 , 0 , 33 , 33 , 34 , 34 , 66 , 66 , 67 , 67 , 99 , 99 ] ;
121+ let expected_indices = Array1 :: from ( expected_indices) ;
125122 let expected_values = expected_indices
126123 . iter ( )
127124 . map ( |x| * x as f32 )
128125 . collect :: < Vec < f32 > > ( ) ;
126+ let expected_values = Array1 :: from ( expected_values) ;
129127
130- assert_eq ! ( sampled_indices, Array1 :: from( expected_indices) ) ;
131- assert_eq ! ( sampled_values, Array1 :: from( expected_values) ) ;
128+ let all_threads = available_parallelism ( ) . map ( |x| x. get ( ) ) . unwrap_or ( 2 ) ;
129+ let nb_threads = vec ! [ 1 , all_threads / 2 , all_threads, all_threads + 1 ] ;
130+
131+ for n_threads in nb_threads {
132+ let sampled_indices = m4_scalar_without_x_parallel ( arr. view ( ) , 12 , n_threads) ;
133+ let sampled_values = sampled_indices. mapv ( |x| arr[ x] ) ;
134+ assert_eq ! ( sampled_indices, expected_indices) ;
135+ assert_eq ! ( sampled_values, expected_values) ;
136+ }
132137 }
133138
134139 #[ test]
@@ -157,19 +162,23 @@ mod tests {
157162 let x = Array1 :: from ( x) ;
158163 let arr = ( 0 ..100 ) . map ( |x| x as f32 ) . collect :: < Vec < f32 > > ( ) ;
159164 let arr = Array1 :: from ( arr) ;
160- let half_n_threads: usize = available_parallelism ( ) . map ( |x| x. get ( ) ) . unwrap_or ( 2 ) / 2 ;
161-
162- let sampled_indices = m4_scalar_with_x_parallel ( x. view ( ) , arr. view ( ) , 12 , half_n_threads) ;
163- let sampled_values = sampled_indices. mapv ( |x| arr[ x] ) ;
164165
165166 let expected_indices = vec ! [ 0 , 0 , 33 , 33 , 34 , 34 , 66 , 66 , 67 , 67 , 99 , 99 ] ;
167+ let expected_indices = Array1 :: from ( expected_indices) ;
166168 let expected_values = expected_indices
167169 . iter ( )
168170 . map ( |x| * x as f32 )
169171 . collect :: < Vec < f32 > > ( ) ;
170-
171- assert_eq ! ( sampled_indices, Array1 :: from( expected_indices) ) ;
172- assert_eq ! ( sampled_values, Array1 :: from( expected_values) ) ;
172+ let expected_values = Array1 :: from ( expected_values) ;
173+
174+ let all_threads = available_parallelism ( ) . map ( |x| x. get ( ) ) . unwrap_or ( 2 ) ;
175+ let nb_threads = vec ! [ 1 , all_threads / 2 , all_threads, all_threads + 1 ] ;
176+ for n_threads in nb_threads {
177+ let sampled_indices = m4_scalar_with_x_parallel ( x. view ( ) , arr. view ( ) , 12 , n_threads) ;
178+ let sampled_values = sampled_indices. mapv ( |x| arr[ x] ) ;
179+ assert_eq ! ( sampled_indices, expected_indices) ;
180+ assert_eq ! ( sampled_values, expected_values) ;
181+ }
173182 }
174183
175184 #[ test]
@@ -247,16 +256,19 @@ mod tests {
247256 let n_out: usize = 204 ;
248257 let x = ( 0 ..n as i32 ) . collect :: < Vec < i32 > > ( ) ;
249258 let x = Array1 :: from ( x) ;
250- let half_n_threads: usize = available_parallelism ( ) . map ( |x| x. get ( ) ) . unwrap_or ( 2 ) / 2 ;
259+ let all_threads = available_parallelism ( ) . map ( |x| x. get ( ) ) . unwrap_or ( 2 ) ;
260+ let nb_threads = vec ! [ 1 , all_threads / 2 , all_threads, all_threads + 1 ] ;
251261 for _ in 0 ..100 {
252262 let arr = get_array_f32 ( n) ;
253263 let idxs1 = m4_scalar_without_x ( arr. view ( ) , n_out) ;
254- let idxs2 = m4_scalar_without_x_parallel ( arr. view ( ) , n_out, half_n_threads) ;
255- let idxs3 = m4_scalar_with_x ( x. view ( ) , arr. view ( ) , n_out) ;
256- let idxs4 = m4_scalar_with_x_parallel ( x. view ( ) , arr. view ( ) , n_out, half_n_threads) ;
264+ let idxs2 = m4_scalar_with_x ( x. view ( ) , arr. view ( ) , n_out) ;
257265 assert_eq ! ( idxs1, idxs2) ;
258- assert_eq ! ( idxs1, idxs3) ;
259- assert_eq ! ( idxs1, idxs4) ;
266+ for & n_threads in nb_threads. iter ( ) {
267+ let idxs3 = m4_scalar_without_x_parallel ( arr. view ( ) , n_out, n_threads) ;
268+ let idxs4 = m4_scalar_with_x_parallel ( x. view ( ) , arr. view ( ) , n_out, n_threads) ;
269+ assert_eq ! ( idxs1, idxs3) ;
270+ assert_eq ! ( idxs1, idxs4) ; // TODO: this should not fail
271+ }
260272 }
261273 }
262274}
0 commit comments