diff --git a/Cargo.toml b/Cargo.toml index 5c8be25..d4b13bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,9 @@ categories = ["development-tools::debugging"] license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +stream = ["futures-core"] + [dependencies] coarsetime = "0.1" derive_builder = "0.12" @@ -20,6 +23,7 @@ pin-project = "1" tokio = { version = "1", features = ["rt"] } tracing = "0.1" weak-table = "0.3.2" +futures-core = { version = "0.3", optional = true } [dev-dependencies] criterion = { version = "0.4", features = ["async", "async_tokio"] } diff --git a/src/registry.rs b/src/registry.rs index 9ddafc7..7c33afc 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -51,6 +51,39 @@ impl TreeRoot { } } +#[cfg(feature = "stream")] +impl TreeRoot { + /// Instrument the given stream with the context of this tree root. + pub fn instrument_stream( + self, + stream: S, + ) -> impl futures_core::Stream { + #[pin_project::pin_project] + struct StreamWithContext { + #[pin] + inner: S, + context: Arc, + } + + impl futures_core::Stream for StreamWithContext { + type Item = S::Item; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + CONTEXT.sync_scope(this.context.clone(), || this.inner.poll_next(cx)) + } + } + + StreamWithContext { + inner: stream, + context: self.context, + } + } +} + /// The registry of multiple await-trees. #[derive(Debug)] pub struct Registry {