Skip to content

Commit 7498982

Browse files
committed
refactor: clean up double buffer to fix a Miri warning
1 parent bbf76d8 commit 7498982

3 files changed

Lines changed: 146 additions & 127 deletions

File tree

src/double_buffer.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
use core::{mem::MaybeUninit, slice};
2+
3+
use alloc::{boxed::Box, vec::Vec};
4+
5+
/// Double buffer. Wraps a mutable slice and allocates a scratch memory of the same size, so that
/// elements can be freely scattered from buffer to buffer.
///
/// # Drop behavior
///
/// Drop ensures that the mutable slice this buffer was constructed with contains all the original
/// elements.
pub struct DoubleBuffer<'a, T> {
    // The caller's slice, viewed as possibly-uninitialized storage. The Drop
    // impl guarantees it is fully initialized again before the borrow ends.
    slice: &'a mut [MaybeUninit<T>],
    // Heap-allocated second buffer, same length as `slice`. Starts uninitialized.
    scratch: Box<[MaybeUninit<T>]>,
    // Orientation flag: when true, `slice` is the write buffer and `scratch`
    // holds the last committed (fully initialized) state; when false, the
    // roles are reversed.
    slice_is_write: bool,
}
17+
18+
impl<'a, T> DoubleBuffer<'a, T> {
    /// Creates a double buffer, allocating a scratch buffer of the same length as the input slice.
    ///
    /// The supplied slice becomes the read buffer, the scratch buffer becomes the write buffer.
    pub fn new(slice: &'a mut [T]) -> Self {
        // Reinterpret the caller's initialized slice as uninitialized storage so
        // that both buffers share one element type.
        // SAFETY: The Drop impl ensures that the slice is initialized.
        let slice = unsafe { slice_as_uninit_mut(slice) };
        let scratch = {
            let mut v = Vec::with_capacity(slice.len());
            // SAFETY: we just allocated this capacity and MaybeUninit can be garbage.
            unsafe {
                v.set_len(slice.len());
            }
            v.into_boxed_slice()
        };
        DoubleBuffer {
            slice,
            scratch,
            slice_is_write: false,
        }
    }

    /// Scatters the elements from the read buffer to the computed indices in
    /// the write buffer. The read buffer is iterated from the beginning.
    ///
    /// Call `swap` after this function to commit the write buffer state.
    ///
    /// If `indexer` returns an out-of-bounds index, the scatter is abandoned
    /// early; the write buffer is then only partially written and must not be
    /// committed (see the safety contract of `swap`).
    pub fn scatter<F>(&mut self, mut indexer: F)
    where
        F: FnMut(&T) -> usize,
    {
        let (read, write) = self.as_read_write();

        let len = write.len();

        for t in read {
            let index = indexer(t);
            // Defensive bounds check: never write past the end of the write
            // buffer, even if the indexer misbehaves.
            if index >= len {
                return;
            }
            let write_ptr = write[index].as_mut_ptr();
            unsafe {
                // SAFETY: both pointers are valid for T, aligned, and nonoverlapping
                // (the read and write buffers are distinct allocations).
                write_ptr.copy_from_nonoverlapping(t as *const T, 1);
            }
        }
    }

    /// Returns the current read and write buffers.
    fn as_read_write(&mut self) -> (&[T], &mut [MaybeUninit<T>]) {
        // Select which physical buffer plays which role based on the flag.
        let (read, write): (&[MaybeUninit<T>], &mut [MaybeUninit<T>]) = if self.slice_is_write {
            (self.scratch.as_ref(), self.slice)
        } else {
            (self.slice, self.scratch.as_mut())
        };

        // SAFETY: The read buffer is always initialized.
        let read = unsafe { slice_assume_init_ref(read) };

        (read, write)
    }

    /// Swaps the read and write buffer, committing the write buffer state.
    ///
    /// # Safety
    ///
    /// The caller must ensure that every element of the write buffer was
    /// written to before calling this function.
    pub unsafe fn swap(&mut self) {
        // Only the orientation flag flips; no data moves. The safety contract
        // makes the former write buffer a valid, fully initialized read buffer.
        self.slice_is_write = !self.slice_is_write;
    }
}
89+
90+
/// Ensures that the input slice contains all the original elements.
impl<'a, T> Drop for DoubleBuffer<'a, T> {
    fn drop(&mut self) {
        if self.slice_is_write {
            // The input slice is the write buffer, copy the consistent state from the read buffer
            unsafe {
                // SAFETY: `scratch` is the read buffer, it is initialized. The length is the same.
                // The two buffers are separate allocations, so they cannot overlap.
                self.slice
                    .as_mut_ptr()
                    .copy_from_nonoverlapping(self.scratch.as_ptr(), self.slice.len());
            }
            // Record that the slice now holds the committed state. Not strictly
            // required during drop, but it keeps the invariant locally true.
            self.slice_is_write = false;
        }
    }
}
105+
106+
/// Get a slice of the initialized items.
///
/// # Safety
///
/// The caller must ensure that all the items are initialized.
#[inline(always)]
pub unsafe fn slice_assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    let data = slice.as_ptr().cast::<T>();
    let len = slice.len();
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`, so `data` and `len`
    // describe a valid `[T]` given the caller's initialization guarantee.
    unsafe { slice::from_raw_parts(data, len) }
}
116+
117+
/// View the mutable slice of `T` as a slice of `MaybeUninit<T>`.
///
/// # Safety
///
/// The caller must ensure that all the items of the returned slice are
/// initialized before dropping it.
#[inline(always)]
pub unsafe fn slice_as_uninit_mut<T>(slice: &mut [T]) -> &mut [MaybeUninit<T>] {
    let data = slice.as_mut_ptr().cast::<MaybeUninit<T>>();
    let len = slice.len();
    // SAFETY: `T` and `MaybeUninit<T>` have identical layout, and the pointer
    // and length come from the exclusive borrow itself, so the new view
    // aliases nothing else.
    unsafe { slice::from_raw_parts_mut(data, len) }
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ extern crate alloc;
9494

9595
use alloc::vec::Vec;
9696

97+
mod double_buffer;
9798
mod scalar;
9899
mod sort;
99100

src/sort.rs

Lines changed: 18 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
//! Implementations of radix keys and sorting functions.
22
3-
use alloc::vec::Vec;
43
use core::mem;
54

6-
use crate::Key;
5+
use crate::{double_buffer::DoubleBuffer, Key};
76

87
/// Unsigned integers used as sorting keys for radix sort.
98
///
@@ -139,41 +138,32 @@ macro_rules! sort_impl {
139138
// digit, our elements are sorted.
140139
for digit in 0..DIGIT_COUNT {
141140
if !(digit_skip_enabled && skip_digit[digit]) {
142-
// Copy the offsets. We need the original later for a consistency check.
143-
// As we write elements into each bucket, we increment the bucket offset
144-
// so that it points to the next empty slot.
145-
let mut working_offsets: [$offset_type; BUCKET_COUNT] = offsets[digit];
146-
147-
for r_pos in 0..len {
148-
let t: &T = unsafe {
149-
// This is safe, r_pos is in (0..len)
150-
buffer.read(r_pos)
151-
};
141+
// Initial offset of each bucket.
142+
let init_offsets = &offsets[digit];
143+
// Offset of the first empty index in each bucket.
144+
let mut working_offsets = *init_offsets;
152145

146+
buffer.scatter(|t| {
153147
let key = key_fn(t);
154148
let bucket = extract_digit(key, digit);
155149

156150
let offset = &mut working_offsets[bucket];
157151

158-
unsafe {
159-
// Make sure the offset is in bounds. An unreliable key function, which
160-
// returns different keys for the same element when called repeatedly,
161-
// can cause offsets to go out of bounds.
162-
let w_pos = usize::min(*offset as usize, len - 1);
163-
164-
// This is safe, w_pos is in (0..len)
165-
buffer.write(w_pos, t);
166-
}
152+
// Make sure the offset is in bounds. An unreliable key function, which
153+
// returns different keys for the same element when called repeatedly,
154+
// can cause offsets to go out of bounds.
155+
let clamped_offset = usize::min(*offset as usize, len - 1);
167156

168157
// Increment the offset of the bucket. Use wrapping add in case the
169158
// key function is unreliable and the bucket overflowed.
170159
*offset = offset.wrapping_add(1);
171-
}
160+
161+
clamped_offset
162+
});
172163

173164
// Check that each bucket had the same number of insertions as we expected.
174-
// If this is not true, then the key function is unreliable and the write buffer
175-
// is not consistent: some elements were overwritten, some were not written to
176-
// and contain garbage.
165+
// If this is not true, then the key function is unreliable and some elements
166+
// in the write buffer were not written to.
177167
//
178168
// If the key function is unreliable, but the sizes of buckets ended up being
179169
// the same, it would not get detected. This is sound, the only consequence is
@@ -191,9 +181,7 @@ macro_rules! sort_impl {
191181
// The bucket sizes do not match expected sizes, the key function is
192182
// unreliable (programming mistake).
193183
//
194-
// Drop impl of the double buffer will make sure that the input slice is
195-
// consistent. This would happen automatically, but I'm making it
196-
// explicit for clarity.
184+
// The Drop impl will copy the last completed buffer into the slice.
197185
drop(buffer);
198186
panic!(
199187
"The key function is not reliable: when called repeatedly, \
@@ -203,15 +191,13 @@ macro_rules! sort_impl {
203191
}
204192

205193
unsafe {
206-
// This is safe, we just ensured that the write buffer is consistent.
194+
// SAFETY: we just ensured that every index was written to.
207195
buffer.swap();
208196
}
209197
}
210198
}
211199

212-
// In case the result ended up in the temporary buffer, the Drop impl will copy it over
213-
// to the input slice. This would happen automatically, but I'm making it explicit for
214-
// clarity.
200+
// The Drop impl will copy the last completed buffer into the slice.
215201
drop(buffer);
216202
}
217203
};
@@ -230,98 +216,3 @@ macro_rules! radix_key_impl {
230216
}
231217

232218
radix_key_impl! { u8 u16 u32 u64 u128 }
233-
234-
/// Double buffer. Allocates a temporary memory the size of the slice, so that
235-
/// elements can be freely reordered from buffer to buffer.
236-
///
237-
/// # Consistency
238-
///
239-
/// For the purposes of this struct, buffer in a consistent state contains a
240-
/// permutation of elements from the original slice. In other words, elements
241-
/// can be reordered, but not duplicated or lost.
242-
///
243-
/// `read_buf` is always consistent. Before calling `swap`, the caller must
244-
/// ensure that `write_buf` is also consistent.
245-
///
246-
/// # Drop behavior
247-
///
248-
/// Drop impl ensures that the slice this buffer was constructed with is left in
249-
/// a consistent state. If the input slice ended up as `write_buf`, the
250-
/// temporary memory (which is now `read_buf` and therefore consistent) is
251-
/// copied into the slice and the buffers are swapped.
252-
struct DoubleBuffer<'a, T> {
253-
slice: &'a mut [T],
254-
_aux: Vec<T>,
255-
256-
/// Read buffer is read-only and always consistent.
257-
read_buf: *const T,
258-
259-
/// Write buffer is write-only. Elements can be present multiple times or
260-
/// not at all. The caller must ensure that it is consistent before calling
261-
/// `swap`.
262-
write_buf: *mut T,
263-
}
264-
265-
impl<'a, T> DoubleBuffer<'a, T> {
266-
fn new(slice: &'a mut [T]) -> DoubleBuffer<'a, T> {
267-
let mut aux = Vec::with_capacity(slice.len());
268-
let read_buf = slice.as_ptr();
269-
let write_buf = aux.as_mut_ptr();
270-
DoubleBuffer {
271-
// Hold on to the &mut slice to make sure it outlives the pointer
272-
// and to prevent writes from the outside
273-
slice,
274-
// Hold on to the Vec to make sure it outlives the pointer
275-
_aux: aux,
276-
read_buf,
277-
write_buf,
278-
}
279-
}
280-
281-
/// Returns a ref to an element from the read buffer.
282-
///
283-
/// Caller must ensure that `index` is in (0..len).
284-
#[inline(always)]
285-
unsafe fn read(&self, index: usize) -> &T {
286-
&*self.read_buf.add(index)
287-
}
288-
289-
/// Copies the referenced element into the write buffer.
290-
///
291-
/// Caller must ensure that `index` is in (0..len).
292-
#[inline(always)]
293-
unsafe fn write(&self, index: usize, t: &T) {
294-
self.write_buf
295-
.add(index)
296-
.copy_from_nonoverlapping(t as *const T, 1);
297-
}
298-
299-
/// Swaps the read and write buffers.
300-
///
301-
/// Caller must ensure that the write buffer is consistent before calling
302-
/// this function.
303-
unsafe fn swap(&mut self) {
304-
// The cast is ok, we have an exclusive access to both buffers
305-
// (&mut [T] and Vec<T>). The caller guarantees that the write buffer is
306-
// consistent and therefore it's safe to read from it and use it as a
307-
// read buffer.
308-
let temp = self.write_buf as *const T;
309-
self.write_buf = self.read_buf as *mut T;
310-
self.read_buf = temp;
311-
}
312-
}
313-
314-
impl<'a, T> Drop for DoubleBuffer<'a, T> {
315-
fn drop(&mut self) {
316-
let input_slice_is_write = self.write_buf as *const T == self.slice.as_ptr();
317-
if input_slice_is_write {
318-
// Input slice is the write buffer, copy the consistent state from the read buffer
319-
unsafe {
320-
// This is safe, `read_buf` is always consistent and the length is the same.
321-
self.write_buf
322-
.copy_from_nonoverlapping(self.read_buf, self.slice.len());
323-
self.swap();
324-
}
325-
}
326-
}
327-
}

0 commit comments

Comments
 (0)