Skip to content

Commit 215eb35

Browse files
committed
Rename join_background_tasks to flush_output
This builds on the prior commit to redefine the `join_background_tasks` as purely a "flush" operation that flushes all output streams. This no longer attempts to gracefully join output tasks one-by-one since that should no longer be necessary and everything will get cleaned up during drop when `abort` calls are invoked.
1 parent 3ecbc62 commit 215eb35

3 files changed

Lines changed: 70 additions & 137 deletions

File tree

crates/wasi/src/preview2/ctx.rs

Lines changed: 14 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@ use super::clocks::host::{monotonic_clock, wall_clock};
22
use crate::preview2::{
33
clocks::{self, HostMonotonicClock, HostWallClock},
44
filesystem::{Dir, TableFsExt},
5-
pipe,
6-
poll::TablePollableExt,
7-
random, stdio,
5+
pipe, random, stdio,
86
stream::{HostInputStream, HostOutputStream, TableStreamExt},
97
DirPerms, FilePerms, Table,
108
};
119
use cap_rand::{Rng, RngCore, SeedableRng};
12-
use std::future::Future;
1310
use std::mem;
1411

1512
pub struct WasiCtxBuilder {
@@ -263,8 +260,7 @@ pub struct WasiCtx {
263260
}
264261

265262
impl WasiCtx {
266-
/// Wait for all background tasks to join (complete) gracefully, after flushing any
267-
/// buffered output.
263+
/// Flush all buffered output and wait for it to reach the destination.
268264
///
269265
/// NOTE: This function should be used when [`WasiCtx`] is used in an async embedding
270266
/// (i.e. with [`crate::preview2::command::add_to_linker`]). Use its counterpart
@@ -274,92 +270,40 @@ impl WasiCtx {
274270
/// In order to implement non-blocking streams, we often often need to offload async
275271
/// operations to background `tokio::task`s. These tasks are aborted when the resources
276272
/// in the `Table` referencing them are dropped. In some cases, this abort may occur before
277-
/// buffered output has been flushed. Use this function to wait for all background tasks to
278-
/// join gracefully.
273+
/// buffered output has been flushed. Use this function to wait for all
274+
/// written data to reach its destination gracefully, even if wasm didn't
275+
/// explicitly wait for this.
279276
///
280277
/// In some embeddings, a misbehaving client might cause this graceful exit to await for an
281278
/// unbounded amount of time, so we recommend bounding this with a timeout or other mechanism.
282-
pub fn join_background_tasks<'a>(&mut self, table: &'a mut Table) -> impl Future<Output = ()> {
283-
use std::pin::Pin;
284-
use std::task::{Context, Poll};
279+
pub async fn flush_output(&mut self, table: &mut Table) {
285280
let keys = table.keys().cloned().collect::<Vec<u32>>();
286-
let mut set = Vec::new();
287-
// we can't remove an stream from the table if it has any child pollables,
288-
// so first delete all pollables from the table.
289-
for k in keys.iter() {
290-
let _ = table.delete_host_pollable(*k);
291-
}
292281
for k in keys {
293282
match table.delete_output_stream(k) {
294283
Ok(mut ostream) => {
295-
// async block takes ownership of the ostream and flushes it
296-
let f = async move { ostream.join_background_tasks().await };
297-
set.push(Box::pin(f) as _)
298-
}
299-
_ => {}
300-
}
301-
match table.delete_input_stream(k) {
302-
Ok(mut istream) => {
303-
// async block takes ownership of the istream and flushes it
304-
let f = async move { istream.join_background_tasks().await };
305-
set.push(Box::pin(f) as _)
284+
let _ = ostream.flush().await;
306285
}
307286
_ => {}
308287
}
309288
}
310-
// poll futures until all are ready.
311-
// We can't write this as an `async fn` because we want to eagerly poll on each possible
312-
// join, rather than sequentially awaiting on them.
313-
struct JoinAll(Vec<Pin<Box<dyn Future<Output = ()> + Send>>>);
314-
impl Future for JoinAll {
315-
type Output = ();
316-
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
317-
// Iterate through the set, polling each future and removing it from the set if it
318-
// is ready:
319-
self.as_mut()
320-
.0
321-
.retain_mut(|fut| match fut.as_mut().poll(cx) {
322-
Poll::Ready(_) => false,
323-
_ => true,
324-
});
325-
// Ready if set is empty:
326-
if self.as_mut().0.is_empty() {
327-
Poll::Ready(())
328-
} else {
329-
Poll::Pending
330-
}
331-
}
332-
}
333-
334-
JoinAll(set)
335289
}
336290

337-
/// Wait for all background tasks to join (complete) gracefully, after flushing any
338-
/// buffered output.
339-
///
340-
/// NOTE: This function should be used when [`WasiCtx`] is used in an synchronous embedding
341-
/// (i.e. with [`crate::preview2::command::sync::add_to_linker`]). Use its counterpart
342-
/// `join_background_tasks` in an async embedding (i.e. with
343-
/// [`crate::preview2::command::add_to_linker`].
291+
/// Same as [`WasiCtx::flush_output`] except suitable for synchronous
292+
/// embeddings.
344293
///
345-
/// In order to implement non-blocking streams, we often often need to offload async
346-
/// operations to background `tokio::task`s. These tasks are aborted when the resources
347-
/// in the `Table` referencing them are dropped. In some cases, this abort may occur before
348-
/// buffered output has been flushed. Use this function to wait for all background tasks to
349-
/// join gracefully.
350-
///
351-
/// In some embeddings, a misbehaving client might cause this graceful exit to await for an
352-
/// unbounded amount of time, so we recommend providing a timeout for this method.
294+
/// This will block the current thread up to the `timeout` specified, or
295+
/// forever if `None` is specified, until all wasm output has been flushed
296+
/// out.
353297
pub fn sync_join_background_tasks<'a>(
354298
&mut self,
355299
table: &'a mut Table,
356300
timeout: Option<std::time::Duration>,
357301
) {
358302
crate::preview2::in_tokio(async move {
359303
if let Some(timeout) = timeout {
360-
let _ = tokio::time::timeout(timeout, self.join_background_tasks(table)).await;
304+
let _ = tokio::time::timeout(timeout, self.flush_output(table)).await;
361305
} else {
362-
self.join_background_tasks(table).await
306+
self.flush_output(table).await
363307
}
364308
})
365309
}

crates/wasi/src/preview2/pipe.rs

Lines changed: 48 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ impl HostOutputStream for MemoryOutputPipe {
8787
// This stream is always ready for writing.
8888
Ok(())
8989
}
90+
91+
async fn flush(&mut self) -> Result<(), Error> {
92+
Ok(())
93+
}
9094
}
9195

9296
/// TODO
@@ -105,9 +109,7 @@ pub struct AsyncReadStream {
105109
state: StreamState,
106110
buffer: Option<Result<Bytes, std::io::Error>>,
107111
receiver: tokio::sync::mpsc::Receiver<Result<(Bytes, StreamState), std::io::Error>>,
108-
// the join handle for the background task is Some until join_background_tasks, after which
109-
// further use of the AsyncReadStream is not allowed.
110-
join_handle: Option<tokio::task::JoinHandle<()>>,
112+
join_handle: tokio::task::JoinHandle<()>,
111113
}
112114

113115
impl AsyncReadStream {
@@ -136,17 +138,13 @@ impl AsyncReadStream {
136138
state: StreamState::Open,
137139
buffer: None,
138140
receiver,
139-
join_handle: Some(join_handle),
141+
join_handle,
140142
}
141143
}
142144
// stdio implementation uses this to determine if the backing tokio runtime has been shutdown and
143145
// restarted:
144146
pub(crate) fn is_finished(&self) -> bool {
145-
assert!(
146-
self.join_handle.is_some(),
147-
"illegal use of AsyncReadStream after join_background_tasks"
148-
);
149-
self.join_handle.as_ref().unwrap().is_finished()
147+
self.join_handle.is_finished()
150148
}
151149
}
152150

@@ -155,18 +153,14 @@ impl AsyncReadStream {
155153
// on reader.read_buf's await it could hold the reader open indefinitely.
156154
impl Drop for AsyncReadStream {
157155
fn drop(&mut self) {
158-
self.join_handle.take().map(|h| h.abort());
156+
self.join_handle.abort();
159157
}
160158
}
161159

162160
#[async_trait::async_trait]
163161
impl HostInputStream for AsyncReadStream {
164162
fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> {
165163
use tokio::sync::mpsc::error::TryRecvError;
166-
assert!(
167-
self.join_handle.is_some(),
168-
"illegal use of AsyncReadStream after join_background_tasks"
169-
);
170164

171165
match self.buffer.take() {
172166
Some(Ok(mut bytes)) => {
@@ -209,11 +203,6 @@ impl HostInputStream for AsyncReadStream {
209203
}
210204

211205
async fn ready(&mut self) -> Result<(), Error> {
212-
assert!(
213-
self.join_handle.is_some(),
214-
"illegal use of AsyncReadStream after join_background_tasks"
215-
);
216-
217206
if self.buffer.is_some() || self.state == StreamState::Closed {
218207
return Ok(());
219208
}
@@ -233,9 +222,6 @@ impl HostInputStream for AsyncReadStream {
233222
}
234223
Ok(())
235224
}
236-
async fn join_background_tasks(&mut self) {
237-
self.join_handle.take().map(|h| h.abort());
238-
}
239225
}
240226

241227
#[derive(Debug)]
@@ -348,40 +334,6 @@ impl AsyncWriteStream {
348334
fn has_pending_op(&self) -> bool {
349335
matches!(self.state, Some(WriteState::Pending))
350336
}
351-
352-
async fn flush(&mut self) -> anyhow::Result<()> {
353-
// NB: This method needs to be "cancel safe" where it can be cancelled
354-
// at any `.await` point but the flush operation still needs to be able
355-
// to be restarted successfully.
356-
357-
// First wait for any pending operation to complete to have the ability
358-
// to send another message.
359-
self.ready().await?;
360-
361-
// Queue up a flush operation in our background task, and if it's
362-
// already gone then that's ok as flushing has completed anyway.
363-
//
364-
// Note that this may end up returning a queued error from a previous
365-
// write or flush.
366-
if !self.send(WriteMesage::Flush)? {
367-
return Ok(());
368-
}
369-
370-
// Wait again for the flush to fully complete before considering this
371-
// request to flush as fully complete.
372-
self.ready().await?;
373-
374-
// Extract the error, if any, that occurred.
375-
match mem::replace(&mut self.state, Some(WriteState::Ready)) {
376-
Some(WriteState::Err(e)) => Err(e.into()),
377-
Some(WriteState::Pending) => unreachable!(),
378-
Some(WriteState::Ready) => Ok(()),
379-
None => {
380-
self.state = None;
381-
Ok(())
382-
}
383-
}
384-
}
385337
}
386338

387339
// Make sure the background task does not outlive the AsyncWriteStream handle.
@@ -429,8 +381,38 @@ impl HostOutputStream for AsyncWriteStream {
429381
Ok(())
430382
}
431383

432-
async fn join_background_tasks(&mut self) {
433-
let _ = self.flush().await;
384+
async fn flush(&mut self) -> anyhow::Result<()> {
385+
// NB: This method needs to be "cancel safe" where it can be cancelled
386+
// at any `.await` point but the flush operation still needs to be able
387+
// to be restarted successfully.
388+
389+
// First wait for any pending operation to complete to have the ability
390+
// to send another message.
391+
self.ready().await?;
392+
393+
// Queue up a flush operation in our background task, and if it's
394+
// already gone then that's ok as flushing has completed anyway.
395+
//
396+
// Note that this may end up returning a queued error from a previous
397+
// write or flush.
398+
if !self.send(WriteMesage::Flush)? {
399+
return Ok(());
400+
}
401+
402+
// Wait again for the flush to fully complete before considering this
403+
// request to flush as fully complete.
404+
self.ready().await?;
405+
406+
// Extract the error, if any, that occurred.
407+
match mem::replace(&mut self.state, Some(WriteState::Ready)) {
408+
Some(WriteState::Err(e)) => Err(e.into()),
409+
Some(WriteState::Pending) => unreachable!(),
410+
Some(WriteState::Ready) => Ok(()),
411+
None => {
412+
self.state = None;
413+
Ok(())
414+
}
415+
}
434416
}
435417
}
436418

@@ -446,6 +428,10 @@ impl HostOutputStream for SinkOutputStream {
446428
async fn ready(&mut self) -> Result<(), Error> {
447429
Ok(())
448430
}
431+
432+
async fn flush(&mut self) -> Result<(), Error> {
433+
Ok(())
434+
}
449435
}
450436

451437
/// A stream that is ready immediately, but will always report that it's closed.
@@ -474,6 +460,10 @@ impl HostOutputStream for ClosedOutputStream {
474460
async fn ready(&mut self) -> Result<(), Error> {
475461
Ok(())
476462
}
463+
464+
async fn flush(&mut self) -> Result<(), Error> {
465+
Ok(())
466+
}
477467
}
478468

479469
#[cfg(test)]

crates/wasi/src/preview2/stream.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ pub trait HostInputStream: Send + Sync {
4343
/// Check for read readiness: this method blocks until the stream is ready
4444
/// for reading.
4545
async fn ready(&mut self) -> Result<(), Error>;
46-
47-
/// Terminate all background tasks. Exposed only to the host, not accessible from WebAssembly.
48-
/// Must cancel background tasks even if dropped before completion.
49-
/// No other methods may be used after calling this method.
50-
async fn join_background_tasks(&mut self) {}
5146
}
5247

5348
/// Host trait for implementing the `wasi:io/streams.output-stream` resource:
@@ -94,10 +89,11 @@ pub trait HostOutputStream: Send + Sync {
9489
/// ready for writing.
9590
async fn ready(&mut self) -> Result<(), Error>;
9691

97-
/// Flush any output which has been buffered, and terminate all background tasks. Exposed only
98-
/// to the host, not accessible from WebAssembly. Must cancel background tasks even if dropped
99-
/// before completion. No other methods may be used after calling this method.
100-
async fn join_background_tasks(&mut self) {}
92+
/// Flush any output which has been buffered.
93+
///
94+
/// This will attempt to flush buffers and wait for all bytes to reach the
95+
/// "end", whatever the "end" means in the context of this stream.
96+
async fn flush(&mut self) -> Result<(), Error>;
10197
}
10298

10399
pub(crate) enum InternalInputStream {
@@ -288,6 +284,9 @@ mod test {
288284
async fn ready(&mut self) -> Result<(), Error> {
289285
unimplemented!();
290286
}
287+
async fn flush(&mut self) -> Result<(), Error> {
288+
Ok(())
289+
}
291290
}
292291

293292
let dummy = DummyOutputStream;

0 commit comments

Comments
 (0)