Skip to content

Commit 213ab6d

Browse files
committed
Add Stream::wait_event
1 parent ec6be00 commit 213ab6d

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

src/event.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,10 @@ impl Event {
279279
}
280280
}
281281

282+
pub(crate) fn get_inner(&self) -> CUevent {
283+
self.0
284+
}
285+
282286
/// Destroy an `Event` returning an error.
283287
///
284288
/// Destroying an event can return errors from previous asynchronous work.

src/stream.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//! a stream to be completed.
1212
1313
use crate::error::{CudaResult, DropResult, ToResult};
14+
use crate::event::Event;
1415
use crate::function::{BlockSize, Function, GridSize};
1516
use cuda_sys::cuda::{self, cudaError_t, CUstream};
1617
use std::ffi::c_void;
@@ -39,6 +40,16 @@ bitflags! {
3940
}
4041
}
4142

43+
bitflags! {
    /// Bit flags that configure how a CUDA stream waits on a CUDA event.
    ///
    /// Current versions of CUDA support only the default flag.
    pub struct StreamWaitEventFlags: u32 {
        /// The default behavior: no flag bits are set.
        const DEFAULT = 0x0;
    }
}
52+
4253
/// A stream of work for the device to perform.
4354
///
4455
/// See the module-level documentation for more information.
@@ -211,6 +222,40 @@ impl Stream {
211222
unsafe { cuda::cuStreamSynchronize(self.inner).to_result() }
212223
}
213224

225+
/// Make the stream wait on an event.
226+
///
227+
/// All future work submitted to the stream will wait for the event to
228+
/// complete. Synchronization is performed on the device, if possible. The
229+
/// event may originate from different context or device than the stream.
230+
///
231+
/// # Example:
232+
///
233+
/// ```
234+
/// # use rustacuda::quick_init;
235+
/// # use std::error::Error;
236+
/// # fn main() -> Result<(), Box<dyn Error>> {
237+
/// # let _context = quick_init()?;
238+
/// use rustacuda::stream::{Stream, StreamFlags, StreamWaitEventFlags};
239+
/// use rustacuda::event::{Event, EventFlags};
240+
///
241+
/// let stream_0 = Stream::new(StreamFlags::NON_BLOCKING, None)?;
242+
/// let stream_1 = Stream::new(StreamFlags::NON_BLOCKING, None)?;
243+
/// let event = Event::new(EventFlags::DEFAULT)?;
244+
///
245+
/// // do some work on stream_0 ...
246+
///
247+
/// // record an event
248+
/// event.record(&stream_0)?;
249+
///
250+
/// // wait until the work on stream_0 is finished before continuing stream_1
251+
/// stream_1.wait_event(event)?;
252+
/// # Ok(())
253+
/// }
254+
/// ```
255+
pub fn wait_event(&self, event: Event, flags: StreamWaitEventFlags) -> CudaResult<()> {
256+
unsafe { cuda::cuStreamWaitEvent(self.inner, event.get_inner(), flags.bits()).to_result() }
257+
}
258+
214259
// Hidden implementation detail function. Highly unsafe. Use the `launch!` macro instead.
215260
#[doc(hidden)]
216261
pub unsafe fn launch<G, B>(

0 commit comments

Comments
 (0)