diff --git a/.bleep b/.bleep index 2911a4d3..b2b6ca2a 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -67cc768b717d3865a73bcd917c905d7d9aeb4c62 +e2089546a5962c0f65c081211d604dadd9330195 \ No newline at end of file diff --git a/.cargo/audit.toml b/.cargo/audit.toml new file mode 100644 index 00000000..7c6e098f --- /dev/null +++ b/.cargo/audit.toml @@ -0,0 +1,3 @@ +[advisories] +# Temp before internal sync applies dependency bumps +ignore = ["RUSTSEC-2026-0097", "RUSTSEC-2026-0098", "RUSTSEC-2026-0099"] diff --git a/Cargo.toml b/Cargo.toml index d3c8603b..c78de1f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ members = [ "pingora-ketama", "pingora-load-balancing", "pingora-memory-cache", + "pingora-prometheus", "tinyufo", ] diff --git a/docs/user_guide/conf.md b/docs/user_guide/conf.md index 1f55859e..70a8f569 100644 --- a/docs/user_guide/conf.md +++ b/docs/user_guide/conf.md @@ -29,6 +29,9 @@ group: webusers | s2n_config_cache_size | The maximum number of unique s2n configs to cache. A value of 0 disables the cache. Default: 10 (s2n-tls only) | number | | work_stealing | Enable work stealing runtime (default true). See Pingora runtime (WIP) section for more info | bool | | upstream_keepalive_pool_size | The number of total connections to keep in the connection pool | number | +| daemon_wait_for_ready | When `true` and `daemon` is `true`, the parent process waits for the daemon to signal readiness (via `SIGUSR1`) before exiting. This causes systemd to delay sending `SIGQUIT` to the old process until the new instance is fully bootstrapped. Default: `false` | bool | +| daemon_ready_timeout_seconds | How long (in seconds) the parent waits for the daemon to signal readiness when `daemon_wait_for_ready` is `true`. If the daemon does not signal in time the parent exits with a non-zero code, causing systemd to abort the reload. 
Default: `600` | number | +| daemon_notify_timeout_seconds | How long (in seconds) the daemon retries sending `SIGUSR1` to the parent when the attempt fails with a permission error. This covers the brief window after the fork where the parent has not yet dropped its UID to match the daemon. Default: `60` | number | ## Extension Any unknown settings will be ignored. This allows extending the conf file to add and pass user defined settings. See User defined configuration section. diff --git a/docs/user_guide/modify_filter.md b/docs/user_guide/modify_filter.md index 3e5378fb..a833fc27 100644 --- a/docs/user_guide/modify_filter.md +++ b/docs/user_guide/modify_filter.md @@ -123,8 +123,7 @@ impl ProxyHttp for MyGateway { fn main() { ... - let mut prometheus_service_http = - pingora::services::listening::Service::prometheus_http_service(); + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6192"); my_server.add_service(prometheus_service_http); diff --git a/docs/user_guide/prom.md b/docs/user_guide/prom.md index b1868f12..1e83c0f3 100644 --- a/docs/user_guide/prom.md +++ b/docs/user_guide/prom.md @@ -1,29 +1,21 @@ # Prometheus -Pingora has a built-in prometheus HTTP metric server for scraping. +The [`pingora-prometheus`](https://docs.rs/pingora-prometheus) crate provides a +Prometheus HTTP metrics server for scraping. -## Enabling Prometheus Support +## Adding the Dependency -Prometheus support is an optional feature in Pingora. 
To use it, you need to enable the `prometheus` feature in your `Cargo.toml`: +Add `pingora-prometheus` to your `Cargo.toml`: ```toml -# If using the main pingora crate -pingora = { version = "0.8.0", features = ["prometheus"] } - -# If using pingora-core directly -pingora-core = { version = "0.8.0", features = ["prometheus"] } - -# If using pingora-proxy crate -pingora-proxy = { version = "0.8.0", features = ["prometheus"] } +pingora-prometheus = "0.8.0" ``` ## Setting up a Prometheus Metrics Endpoint -Once the feature is enabled, you can set up a Prometheus metrics endpoint like this: - ```rust ... - let mut prometheus_service_http = Service::prometheus_http_service(); + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("0.0.0.0:1234"); my_server.add_service(prometheus_service_http); my_server.run_forever(); diff --git a/pingora-cache/src/eviction/lru.rs b/pingora-cache/src/eviction/lru.rs index d241ee69..96285700 100644 --- a/pingora-cache/src/eviction/lru.rs +++ b/pingora-cache/src/eviction/lru.rs @@ -85,6 +85,15 @@ impl Manager { (u64key(key) % N as u64) as usize } + /// Peek at the least-recently-used key in the given shard without evicting it. + /// + /// Returns the cache key at the LRU tail of the shard, or `None` if empty. + /// Useful for reporting the eviction frontier (the age of the next item + /// that would be evicted). 
+ pub fn peek_lru(&self, shard: usize) -> Option { + self.0.peek_lru(shard).map(|(key, _weight)| key) + } + /// Serialize the given shard pub fn serialize_shard(&self, shard: usize) -> Result> { use rmp_serde::encode::Serializer; @@ -614,4 +623,42 @@ mod test { // Cleanup test directory std::fs::remove_dir_all(dir_path).unwrap(); } + + #[test] + fn test_peek_lru() { + let lru = Manager::<1>::with_capacity(20, 20); + let until = SystemTime::now(); + + // empty shard returns None + assert!(lru.peek_lru(0).is_none()); + + let key1 = CacheKey::new("", "a", "1").to_compact(); + lru.admit(key1.clone(), 1, until); + // single item: it's both the head and the tail + assert_eq!(lru.peek_lru(0).unwrap(), key1); + + // admit more keys to push key1 to the tail + let key2 = CacheKey::new("", "b", "1").to_compact(); + lru.admit(key2.clone(), 1, until); + for i in 0..5 { + lru.admit( + CacheKey::new("", format!("f{i}"), "1").to_compact(), + 1, + until, + ); + } + // key1 is the LRU tail (admitted first) + assert_eq!(lru.peek_lru(0).unwrap(), key1); + + // promote key1 — now key2 becomes the tail + lru.access(&key1, 1, until); + assert_eq!(lru.peek_lru(0).unwrap(), key2); + + // peek_lru should not remove the item + assert_eq!(lru.peek_lru(0).unwrap(), key2); + assert!(lru.peek(&key2)); + + // out-of-bounds shard returns None + assert!(lru.peek_lru(999).is_none()); + } } diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index e2854966..12ff7a23 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -47,7 +47,6 @@ strum = "0.26.2" strum_macros = "0.26.2" libc = "0.2.70" chrono = { version = "~0.4.31", features = ["alloc"], default-features = false } -prometheus = { version = "0.14", optional = true } sentry = { version = "0.36", features = [ "backtrace", "contexts", @@ -108,4 +107,3 @@ openssl_derived = ["any_tls"] any_tls = [] sentry = ["dep:sentry"] connection_filter = [] -prometheus = ["dep:prometheus"] diff --git a/pingora-core/src/apps/mod.rs 
b/pingora-core/src/apps/mod.rs index 8c087489..82989e5c 100644 --- a/pingora-core/src/apps/mod.rs +++ b/pingora-core/src/apps/mod.rs @@ -15,8 +15,6 @@ //! The abstraction and implementation interface for service application logic pub mod http_app; -#[cfg(feature = "prometheus")] -pub mod prometheus_http_app; use crate::server::ShutdownWatch; use async_trait::async_trait; diff --git a/pingora-core/src/apps/prometheus_http_app.rs b/pingora-core/src/apps/prometheus_http_app.rs deleted file mode 100644 index f06cce7d..00000000 --- a/pingora-core/src/apps/prometheus_http_app.rs +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2026 Cloudflare, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! An HTTP application that reports Prometheus metrics. - -#[cfg(feature = "prometheus")] -mod prometheus_impl { - use async_trait::async_trait; - use http::Response; - use prometheus::{Encoder, TextEncoder}; - - use super::super::http_app::HttpServer; - use crate::apps::http_app::ServeHttp; - use crate::modules::http::compression::ResponseCompressionBuilder; - use crate::protocols::http::ServerSession; - - /// An HTTP application that reports Prometheus metrics. 
- /// - /// This application will report all the [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) - /// collected via the [Prometheus](https://docs.rs/prometheus/) crate; - pub struct PrometheusHttpApp; - - #[async_trait] - impl ServeHttp for PrometheusHttpApp { - async fn response(&self, _http_session: &mut ServerSession) -> Response> { - let encoder = TextEncoder::new(); - let metric_families = prometheus::gather(); - let mut buffer = vec![]; - encoder.encode(&metric_families, &mut buffer).unwrap(); - Response::builder() - .status(200) - .header(http::header::CONTENT_TYPE, encoder.format_type()) - .header(http::header::CONTENT_LENGTH, buffer.len()) - .body(buffer) - .unwrap() - } - } - - /// The [HttpServer] for [PrometheusHttpApp] - /// - /// This type provides the functionality of [PrometheusHttpApp] with compression enabled - pub type PrometheusServer = HttpServer; - - impl PrometheusServer { - pub fn new() -> Self { - let mut server = Self::new_app(PrometheusHttpApp); - // enable gzip level 7 compression - server.add_module(ResponseCompressionBuilder::enable(7)); - server - } - } -} - -#[cfg(feature = "prometheus")] -pub use prometheus_impl::*; diff --git a/pingora-core/src/connectors/http/mod.rs b/pingora-core/src/connectors/http/mod.rs index 2545cf7c..5a671ef2 100644 --- a/pingora-core/src/connectors/http/mod.rs +++ b/pingora-core/src/connectors/http/mod.rs @@ -21,6 +21,8 @@ use crate::protocols::http::client::HttpSession; use crate::protocols::http::v1::client::HttpSession as Http1Session; use crate::upstreams::peer::Peer; use pingora_error::Result; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; use std::time::Duration; pub mod custom; @@ -151,6 +153,17 @@ where pub fn prefer_h1(&self, peer: &impl Peer) { self.h2.prefer_h1(peer); } + + /// Return the number of times a pooled connection was found to contain + /// unexpected data from the server. 
+ pub fn unexpected_data_connection_count(&self) -> u64 { + self.h1.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. + pub fn unexpected_data_connection_counter(&self) -> Arc { + self.h1.unexpected_data_connection_counter() + } } #[cfg(test)] diff --git a/pingora-core/src/connectors/http/v1.rs b/pingora-core/src/connectors/http/v1.rs index 62ecfcb6..ab04b2f6 100644 --- a/pingora-core/src/connectors/http/v1.rs +++ b/pingora-core/src/connectors/http/v1.rs @@ -17,6 +17,8 @@ use crate::protocols::http::v1::client::HttpSession; use crate::upstreams::peer::Peer; use pingora_error::Result; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; use std::time::Duration; pub struct Connector { @@ -60,6 +62,17 @@ impl Connector { .release_stream(stream, peer.reuse_hash(), idle_timeout); } } + + /// Return the number of times a pooled connection was found to contain + /// unexpected data from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.transport.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. 
+ pub fn unexpected_data_connection_counter(&self) -> Arc { + self.transport.unexpected_data_connection_counter() + } } #[cfg(test)] diff --git a/pingora-core/src/connectors/http/v2.rs b/pingora-core/src/connectors/http/v2.rs index dd1d2b27..0b70b66e 100644 --- a/pingora-core/src/connectors/http/v2.rs +++ b/pingora-core/src/connectors/http/v2.rs @@ -343,8 +343,13 @@ impl Connector { // the caller that the server speaks h2c } } - let max_h2_stream = peer.get_peer_options().map_or(1, |o| o.max_h2_streams); - let conn = handshake(stream, max_h2_stream, peer.h2_ping_interval()).await?; + let peer_options = peer.get_peer_options(); + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = peer_options.map_or(1, |o| o.max_h2_streams); + settings.ping_interval = peer.h2_ping_interval(); + settings.stream_window_size = peer_options.and_then(|o| o.h2_stream_window_size); + settings.connection_window_size = peer_options.and_then(|o| o.h2_connection_window_size); + let conn = handshake(stream, settings).await?; let h2_stream = conn .spawn_stream() .await? @@ -484,19 +489,84 @@ impl Connector { // 8 Mbytes = 80 Mbytes X 100ms, which should be enough for most links. const H2_WINDOW_SIZE: u32 = 1 << 23; -pub async fn handshake( - stream: Stream, - max_streams: usize, - h2_ping_interval: Option, -) -> Result { +/// Maximum allowed H2 window size per [RFC 9113 §6.9.1](https://datatracker.ietf.org/doc/html/rfc9113#section-6.9.1-7). +const H2_MAX_WINDOW_SIZE: u32 = (1u32 << 31) - 1; + +/// Settings for HTTP/2 handshake. 
+/// +/// # Example +/// +/// ```rust,ignore +/// use pingora_core::connectors::http::v2::{handshake, H2HandshakeSettings}; +/// +/// // With custom window sizes +/// let mut settings = H2HandshakeSettings::new(); +/// settings.max_streams = 100; +/// settings.stream_window_size = Some(1 << 20); // 1MiB +/// settings.connection_window_size = Some(1 << 24); // 16MiB +/// let conn = handshake(stream, settings).await?; +/// ``` +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct H2HandshakeSettings { + /// The maximum number of concurrent streams allowed on this connection. + pub max_streams: usize, + /// Optional interval for sending H2 ping frames to keep the connection alive. + pub ping_interval: Option, + /// Optional initial per-stream receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub stream_window_size: Option, + /// Optional initial connection-level receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub connection_window_size: Option, +} + +impl H2HandshakeSettings { + /// Create a new `H2HandshakeSettings` with all defaults. + pub fn new() -> Self { + Self::default() + } +} + +/// Perform an HTTP/2 handshake on the given stream with the given settings. 
+pub async fn handshake(stream: Stream, settings: H2HandshakeSettings) -> Result { use h2::client::Builder; use pingora_runtime::current_handle; + let max_streams = settings.max_streams; + // Safe guard: new_http_session() assumes there should be at least one free stream if max_streams == 0 { return Error::e_explain(H2Error, "zero max_stream configured"); } + // Validate window sizes against RFC 9113 §6.9.1 limit + // https://datatracker.ietf.org/doc/html/rfc9113#section-6.9.1-7 + if settings + .stream_window_size + .is_some_and(|w| w == 0 || w > H2_MAX_WINDOW_SIZE) + { + return Error::e_explain( + H2Error, + format!( + "stream_window_size must be between 1 and {} (2^31-1)", + H2_MAX_WINDOW_SIZE + ), + ); + } + if settings + .connection_window_size + .is_some_and(|w| w == 0 || w > H2_MAX_WINDOW_SIZE) + { + return Error::e_explain( + H2Error, + format!( + "connection_window_size must be between 1 and {} (2^31-1)", + H2_MAX_WINDOW_SIZE + ), + ); + } + let id = stream.id(); let digest = Digest { // NOTE: this field is always false because the digest is shared across all streams @@ -507,16 +577,16 @@ pub async fn handshake( proxy_digest: stream.get_proxy_digest(), socket_digest: stream.get_socket_digest(), }; - // TODO: make these configurable + let stream_window = settings.stream_window_size.unwrap_or(H2_WINDOW_SIZE); + let conn_window = settings.connection_window_size.unwrap_or(H2_WINDOW_SIZE); let (send_req, connection) = Builder::new() .enable_push(false) .initial_max_send_streams(max_streams) // The limit for the server. Server push is not allowed, so this value doesn't matter .max_concurrent_streams(1) .max_frame_size(64 * 1024) // advise server to send larger frames - .initial_window_size(H2_WINDOW_SIZE) - // should this be max_streams * H2_WINDOW_SIZE? 
- .initial_connection_window_size(H2_WINDOW_SIZE) + .initial_window_size(stream_window) + .initial_connection_window_size(conn_window) .handshake(stream) .await .or_err(HandshakeError, "during H2 handshake")?; @@ -538,7 +608,7 @@ pub async fn handshake( connection, id, closed_tx, - h2_ping_interval, + settings.ping_interval, ping_timeout_clone, ) .await; @@ -558,6 +628,9 @@ pub async fn handshake( mod tests { use super::*; use crate::upstreams::peer::HttpPeer; + use bytes::Bytes; + use http::{Response, StatusCode}; + use pingora_http::RequestHeader; #[tokio::test] #[cfg(feature = "any_tls")] @@ -818,4 +891,110 @@ mod tests { .unwrap() .is_none()); } + + #[tokio::test] + async fn test_h2_handshake_settings_validation() { + use super::H2HandshakeSettings; + + // Test zero stream window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.stream_window_size = Some(0); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("stream_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for stream_window_size = 0"), + } + + // Test zero connection window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.connection_window_size = Some(0); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("connection_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for connection_window_size = 0"), + } + + // Test exceeding max stream window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.stream_window_size = Some(super::H2_MAX_WINDOW_SIZE + 1); + let (client, _server) = tokio::io::duplex(65536); + match 
handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("stream_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for stream_window_size > max"), + } + + // Test exceeding max connection window size is rejected + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.connection_window_size = Some(super::H2_MAX_WINDOW_SIZE + 1); + let (client, _server) = tokio::io::duplex(65536); + match handshake(Box::new(client), settings).await { + Err(e) => assert!( + e.to_string() + .contains("connection_window_size must be between 1"), + "Unexpected error: {}", + e + ), + Ok(_) => panic!("Expected error for connection_window_size > max"), + } + } + + #[tokio::test] + async fn test_h2_handshake_custom_window_sizes() { + // Test that valid custom window sizes are accepted and handshake succeeds + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 100; + settings.stream_window_size = Some(1 << 20); // 1MiB + settings.connection_window_size = Some(1 << 24); // 16MiB + + let (client, server) = tokio::io::duplex(65536); + + // Spawn server side + tokio::spawn(async move { + let mut server_conn = h2::server::handshake(server).await.unwrap(); + if let Some(result) = server_conn.accept().await { + let (_request, mut respond) = result.unwrap(); + let resp = Response::builder().status(StatusCode::OK).body(()).unwrap(); + let mut stream = respond.send_response(resp, false).unwrap(); + stream.send_data(Bytes::from("ok"), true).unwrap(); + server_conn.graceful_shutdown(); + } + // Drive the server connection until the client closes + while let Some(_res) = server_conn.accept().await {} + }); + + // Client side - should succeed with custom window sizes + let conn = handshake(Box::new(client), settings).await.unwrap(); + + // Verify we can spawn a stream and complete a request/response cycle + let mut stream = 
conn.spawn_stream().await.unwrap().unwrap(); + let mut request = RequestHeader::build("GET", b"/", None).unwrap(); + request + .insert_header(http::header::HOST, "example.com") + .unwrap(); + stream + .write_request_header(Box::new(request), true) + .unwrap(); + + stream.read_response_header().await.unwrap(); + assert_eq!(stream.response_header().unwrap().status, 200); + } } diff --git a/pingora-core/src/connectors/mod.rs b/pingora-core/src/connectors/mod.rs index e5e987cb..0e3c727c 100644 --- a/pingora-core/src/connectors/mod.rs +++ b/pingora-core/src/connectors/mod.rs @@ -37,6 +37,7 @@ use pingora_error::{Error, ErrorType::*, OrErr, Result}; use pingora_pool::{ConnectionMeta, ConnectionPool}; use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use tls::TlsConnector; use tokio::sync::Mutex; @@ -146,6 +147,9 @@ pub struct TransportConnector { bind_to_v4: Vec, bind_to_v6: Vec, preferred_http_version: PreferredHttpVersion, + /// Wrapped in `Arc` so external consumers (e.g. proxy services) can clone a reference + /// for periodic metric reporting without needing access to the connector itself. + unexpected_data_conn_count: Arc, } const DEFAULT_POOL_SIZE: usize = 128; @@ -172,6 +176,7 @@ impl TransportConnector { bind_to_v4, bind_to_v6, preferred_http_version: PreferredHttpVersion::new(), + unexpected_data_conn_count: Arc::new(AtomicU64::new(0)), } } @@ -212,7 +217,9 @@ impl TransportConnector { // test_reusable_stream: we assume server would never actively send data // first on an idle stream. 
#[cfg(unix)] - if peer.matches_fd(stream.id()) && test_reusable_stream(&mut stream) { + if peer.matches_fd(stream.id()) + && test_reusable_stream(&mut stream, &self.unexpected_data_conn_count) + { Some(stream) } else { None @@ -227,7 +234,10 @@ impl TransportConnector { } } if peer.matches_sock(WrappedRawSocket(stream.id() as RawSocket)) - && test_reusable_stream(&mut stream) + && test_reusable_stream( + &mut stream, + &self.unexpected_data_conn_count, + ) { Some(stream) } else { @@ -261,7 +271,7 @@ impl TransportConnector { key: u64, // usually peer.reuse_hash() idle_timeout: Option, ) { - if !test_reusable_stream(&mut stream) { + if !test_reusable_stream(&mut stream, &self.unexpected_data_conn_count) { return; } let id = stream.id(); @@ -301,6 +311,21 @@ impl TransportConnector { pub fn prefer_h1(&self, peer: &impl Peer) { self.preferred_http_version.add(peer, 1); } + + /// Return the number of times a pooled connection was found to contain unexpected data + /// from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.unexpected_data_conn_count.load(Ordering::Relaxed) + } + + /// Return a shared reference to the unexpected data connection counter. + /// + /// This allows external consumers (e.g. proxy services) to clone the `Arc` and + /// periodically read the counter for metric reporting without needing ongoing + /// access to the connector. 
+ pub fn unexpected_data_connection_counter(&self) -> Arc { + self.unexpected_data_conn_count.clone() + } } // Perform the actual L4 and tls connection steps while respecting the peer's @@ -376,7 +401,7 @@ use futures::future::FutureExt; use tokio::io::AsyncReadExt; /// Test whether a stream is already closed or not reusable (server sent unexpected data) -fn test_reusable_stream(stream: &mut Stream) -> bool { +fn test_reusable_stream(stream: &mut Stream, unexpected_data_conn_count: &AtomicU64) -> bool { let mut buf = [0; 1]; // tokio::task::unconstrained because now_or_never may yield None when the future is ready let result = tokio::task::unconstrained(stream.read(&mut buf[..])).now_or_never(); @@ -387,6 +412,7 @@ fn test_reusable_stream(stream: &mut Stream) -> bool { debug!("Idle connection is closed"); } else { warn!("Unexpected data read in idle connection"); + unexpected_data_conn_count.fetch_add(1, Ordering::Relaxed); } } Err(e) => { @@ -644,4 +670,32 @@ mod tests { let (etype, context) = get_do_connect_failure_with_peer(&peer).await; assert!(etype != ConnectTimedout || !context.contains("total-connection timeout")); } + + #[tokio::test] + async fn test_unexpected_data_connection_count_increments() { + // Create a duplex stream where we control both ends + let (mut server, client) = tokio::io::duplex(64); + + let counter = AtomicU64::new(0); + let mut stream: Stream = Box::new(client); + + // With no data available, the stream should be considered reusable + assert!(test_reusable_stream(&mut stream, &counter)); + assert_eq!(counter.load(Ordering::Relaxed), 0); + + // Write unexpected data from the server side + use tokio::io::AsyncWriteExt; + server.write_all(b"unexpected").await.unwrap(); + + // Give the data a moment to be buffered + tokio::task::yield_now().await; + + // Now test_reusable_stream should detect the unexpected data + assert!(!test_reusable_stream(&mut stream, &counter)); + assert_eq!( + counter.load(Ordering::Relaxed), + 1, + 
"unexpected_data_connection_count should have incremented" + ); + } } diff --git a/pingora-core/src/protocols/http/server.rs b/pingora-core/src/protocols/http/server.rs index 78852939..438f3cb0 100644 --- a/pingora-core/src/protocols/http/server.rs +++ b/pingora-core/src/protocols/http/server.rs @@ -811,4 +811,67 @@ impl Session { Self::Custom(_) => None, } } + + /// Check if this session supports the cancel-safe proxy task API. + /// + /// For HTTP/1.x, this can be toggled per-session via + /// [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). + pub fn supports_proxy_task_api(&self) -> bool { + match self { + Self::H1(s) => s.proxy_tasks_enabled(), + _ => false, + } + } + + /// Enable or disable the cancel-safe proxy task API for this session. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + if let Self::H1(s) = self { + s.set_proxy_tasks_enabled(enabled); + } + } + + /// Queue a downstream proxy task for cancel-safe writing. + /// + /// # Panics + /// Panics if called on a session that doesn't support the proxy task API. + /// Check [`supports_proxy_task_api`](Self::supports_proxy_task_api) first, + /// or use `write_response_header()` / `write_response_body()` for other + /// session types. + pub fn send_downstream_proxy_task(&mut self, task: HttpTask) { + match self { + Self::H1(s) => s.send_proxy_task(task), + Self::H2(_) => panic!("H2 proxy task API not yet implemented"), + Self::Subrequest(_) => panic!("Subrequest proxy task API not yet implemented"), + Self::Custom(_) => panic!("Custom proxy task API not yet implemented"), + } + } + + /// Check if there are pending downstream proxy tasks queued for writing. + /// + /// Returns false for sessions that don't support the proxy task API. 
+ pub fn has_pending_downstream_proxy_tasks(&self) -> bool { + match self { + Self::H1(s) => s.has_pending_proxy_tasks(), + Self::H2(_) => false, // TODO: implement for H2 + Self::Subrequest(_) => false, // TODO: implement for subrequests + Self::Custom(_) => false, // TODO: implement for custom + } + } + + /// Write all queued downstream proxy tasks in a cancel-safe manner. + /// Returns `Ok(true)` if this was the end of the response stream. + /// + /// # Panics + /// Panics if called on a session that doesn't support the proxy task API. + /// Check [`supports_proxy_task_api`](Self::supports_proxy_task_api) first, + /// or use `write_response_header()` / `write_response_body()` for other + /// session types. + pub async fn write_downstream_proxy_tasks(&mut self) -> Result { + match self { + Self::H1(s) => s.write_proxy_tasks().await, + Self::H2(_) => panic!("H2 proxy task API not yet implemented"), + Self::Subrequest(_) => panic!("Subrequest proxy task API not yet implemented"), + Self::Custom(_) => panic!("Custom proxy task API not yet implemented"), + } + } } diff --git a/pingora-core/src/protocols/http/v1/body.rs b/pingora-core/src/protocols/http/v1/body.rs index 72899257..61872af6 100644 --- a/pingora-core/src/protocols/http/v1/body.rs +++ b/pingora-core/src/protocols/http/v1/body.rs @@ -20,9 +20,14 @@ use pingora_error::{ OrErr, Result, }; use std::fmt::Debug; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::protocols::l4::stream::AsyncWriteVec; +use crate::protocols::l4::stream::{ + async_write_vec::{poll_write_all_buf, poll_write_vec_all_buf}, + AsyncWriteVec, +}; use crate::utils::BufRef; // TODO: make this dynamically adjusted @@ -905,14 +910,188 @@ pub enum BodyMode { type BM = BodyMode; +// ============================================================================ +// Cancel-safe body writing types +// 
============================================================================ + +impl BodyMode { + /// Extract `(total, written)` from `ContentLength`, panicking on mismatch. + fn expect_content_length(&self) -> (usize, usize) { + match self { + BodyMode::ContentLength(total, written) => (*total, *written), + _ => panic!("wrong body mode: expected ContentLength, got {:?}", self), + } + } + + /// Extract `written` from `ChunkedEncoding`, panicking on mismatch. + fn expect_chunked(&self) -> usize { + match self { + BodyMode::ChunkedEncoding(written) => *written, + _ => panic!("wrong body mode: expected ChunkedEncoding, got {:?}", self), + } + } + + /// Extract `written` from `UntilClose`, panicking on mismatch. + fn expect_until_close(&self) -> usize { + match self { + BodyMode::UntilClose(written) => *written, + _ => panic!("wrong body mode: expected UntilClose, got {:?}", self), + } + } +} + +/// Type alias for the chunked encoding buffer chain +type ChunkedBuf = bytes::buf::Chain, &'static [u8]>; + +enum WriteBuf { + /// Simple bytes buffer + Simple(Bytes), + /// Chained buffer for chunked encoding or other complex writes + Chained(C), +} + +// Implement Buf for WriteBuf to delegate to the inner buffer +impl Buf for WriteBuf { + fn remaining(&self) -> usize { + match self { + WriteBuf::Simple(b) => b.remaining(), + WriteBuf::Chained(c) => c.remaining(), + } + } + + fn chunk(&self) -> &[u8] { + match self { + WriteBuf::Simple(b) => b.chunk(), + WriteBuf::Chained(c) => c.chunk(), + } + } + + fn advance(&mut self, cnt: usize) { + match self { + WriteBuf::Simple(b) => b.advance(cnt), + WriteBuf::Chained(c) => c.advance(cnt), + } + } +} + +enum WriteState { + /// No write in progress + Idle, + /// Writing data (original size, bytes remaining to write) + Writing(usize, WriteBuf), + /// Flushing after write (original size to return) + Flushing(usize), + /// Write complete (bytes written in this task) + Done(usize), + /// Write timed out - cannot be reused + TimedOut, +} + 
+// Custom Debug implementation since we can't derive it with futures +impl std::fmt::Debug for WriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + WriteState::Idle => write!(f, "Idle"), + WriteState::Writing(size, _buf) => { + write!(f, "Writing(size: {})", size) + } + WriteState::Flushing(size) => write!(f, "Flushing(size: {})", size), + WriteState::Done(size) => write!(f, "Done(size: {})", size), + WriteState::TimedOut => write!(f, "TimedOut"), + } + } +} + +enum FinishWriteState { + /// No finish task queued + NotStarted, + /// Finish queued but not started yet + Idle, + /// Writing last chunk marker (for chunked encoding) + WritingLastChunk(WriteBuf), + /// Flushing after writing last chunk + Flushing, + /// Finish complete + Done, +} + +// Custom Debug implementation since WriteBuf doesn't implement Debug +impl std::fmt::Debug for FinishWriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FinishWriteState::NotStarted => write!(f, "NotStarted"), + FinishWriteState::Idle => write!(f, "Idle"), + FinishWriteState::WritingLastChunk(_) => write!(f, "WritingLastChunk"), + FinishWriteState::Flushing => write!(f, "Flushing"), + FinishWriteState::Done => write!(f, "Done"), + } + } +} + +/// Internal state for the cancel-safe body write state machine. +/// +/// Tracks the pending body bytes, write progress +/// (idle → writing → flushing → done), and an optional timeout. 
+struct SendBodyState { + /// Application bytes queued to be written + pending_bytes: Option, + /// Current write state for cancel-safe operations + write_state: WriteState, + /// Timeout duration for this write task + timeout_duration: Option, + /// Timeout future (only created if write returns Pending) + timeout_fut: Option + Send + Sync>>>, +} + +impl std::fmt::Debug for SendBodyState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendBodyState") + .field("pending_bytes", &self.pending_bytes) + .field("write_state", &self.write_state) + .field("timeout_duration", &self.timeout_duration) + .field( + "timeout_fut", + &self.timeout_fut.as_ref().map(|_| "Some(Future)"), + ) + .finish() + } +} + +impl SendBodyState { + fn new() -> Self { + SendBodyState { + pending_bytes: None, + write_state: WriteState::Idle, + timeout_duration: None, + timeout_fut: None, + } + } +} + +/// Tracks how response body bytes are framed and written to the wire. +/// +/// Supports both a legacy async API (`write_body` / `finish`) and a cancel-safe +/// task API that can be driven inside a `tokio::select!` loop without losing +/// write progress. pub struct BodyWriter { pub body_mode: BodyMode, + // Boxed to reduce inline size. Only used by the cancel-safe proxy task API. 
+ send_body_state: Box, + send_finish_state: FinishWriteState, +} + +impl Default for BodyWriter { + fn default() -> Self { + Self::new() + } } impl BodyWriter { pub fn new() -> Self { BodyWriter { body_mode: BM::ToSelect, + send_body_state: Box::new(SendBodyState::new()), + send_finish_state: FinishWriteState::NotStarted, } } @@ -1109,6 +1288,543 @@ impl BodyWriter { _ => panic!("wrong body mode: {:?}", self.body_mode), } } + + // ======================================================================== + // Cancel-safe body task API + // ======================================================================== + + #[cfg(test)] + pub fn has_pending_body_task(&self) -> bool { + self.send_body_state.pending_bytes.is_some() + || !matches!( + self.send_body_state.write_state, + WriteState::Idle | WriteState::Done(_) | WriteState::TimedOut + ) + } + + /// Queue application bytes as a body write task with an optional timeout. + /// This is a non-async function that just saves the bytes. + /// Call `write_current_body_task()` to actually perform the write. + /// + /// The timeout, if provided, will be enforced internally across all + /// write attempts, even if the write is cancelled and resumed via `tokio::select!`. + pub fn send_body_task(&mut self, bytes: Bytes, timeout: Option) { + assert!( + matches!( + self.send_body_state.write_state, + WriteState::Idle | WriteState::Done(_) + ), + "send_body_task called while previous task is still in progress: {:?}", + self.send_body_state.write_state + ); + self.send_body_state.pending_bytes = Some(bytes); + self.send_body_state.write_state = WriteState::Idle; + self.send_body_state.timeout_duration = timeout; + self.send_body_state.timeout_fut = None; + } + + /// Writes the current queued body task to the stream. + /// + /// ## Cancel-safety + /// + /// This function can be safely used in a `tokio::select!` loop. + /// Returns `Ok(Some(bytes_written))` when complete, `Ok(None)` if no bytes to write. 
+ pub async fn write_current_body_task(&mut self, stream: &mut S) -> Result> + where + S: AsyncWrite + Unpin + Send, + { + // Use poll_fn to wrap our poll-based implementation + std::future::poll_fn(|cx| self.poll_write_current_body_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for writing body tasks. + /// This is the core implementation that maintains state across cancellations. + fn poll_write_current_body_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Check if already timed out - don't allow reuse + if matches!(self.send_body_state.write_state, WriteState::TimedOut) { + return Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )); + } + + // Lazy timeout optimization: Poll write first, create timeout only if needed. + // + // This follows the pattern from `pingora_timeout::Timeout` to avoid allocating + // and registering timeout futures when writes complete immediately (the common case). 
+ // + // Fast path: Write completes → return immediately, no timeout future created + // Slow path: Write blocks → lazily create timeout future and poll both + + // First, try the write operation + // Dispatch to the appropriate body mode handler + let result = match self.body_mode { + BM::Complete(_) => Poll::Ready(Ok(None)), + BM::ContentLength(_, _) => self.poll_write_content_length_body_task(cx, stream), + BM::ChunkedEncoding(_) => self.poll_write_chunked_body_task(cx, stream), + BM::UntilClose(_) => self.poll_write_until_close_body_task(cx, stream), + BM::ToSelect => Poll::Ready(Ok(None)), + }; + + // If write completed immediately, return without ever creating/polling timeout + if result.is_ready() { + return result; + } + + // Write returned Pending - lazily create and check timeout if duration is set + if let Some(duration) = self.send_body_state.timeout_duration { + let timeout = self.send_body_state.timeout_fut.get_or_insert_with(|| { + Box::pin(pingora_timeout::sleep(duration)) + as std::pin::Pin + Send + Sync>> + }); + + if timeout.as_mut().poll(cx).is_ready() { + // Timeout fired! Mark state as timed out and clear the timeout future + self.send_body_state.write_state = WriteState::TimedOut; + self.send_body_state.timeout_fut = None; + return Poll::Ready(Error::e_explain( + WriteTimedout, + "writing body task timed out", + )); + } + } + + // Both write and timeout are pending + Poll::Pending + } + + // ======================================================================== + // Cancel-safe finish task API + // ======================================================================== + + #[cfg(test)] + pub fn has_pending_finish_task(&self) -> bool { + !matches!( + self.send_finish_state, + FinishWriteState::NotStarted | FinishWriteState::Done + ) + } + + /// Queue a finish operation as a task. + /// This is a non-async function that just marks the finish as pending. + /// Call `write_current_finish_task()` to actually perform the finish. 
+ /// + /// This API is stateful and cancel-safe - use it when you need to finish + /// the body in a `tokio::select!` loop or other cancellable context. + pub fn send_finish_task(&mut self) { + self.send_finish_state = FinishWriteState::Idle; + } + + /// Async function that performs the current queued finish task on the stream. + /// This function is cancel-safe and can be called in a `tokio::select!` loop. + /// Returns `Ok(Some(bytes_written))` when complete, `Ok(None)` if already complete. + /// + /// This API is stateful - it tracks progress across cancellations and can be + /// safely resumed after being dropped mid-execution. + pub async fn write_current_finish_task(&mut self, stream: &mut S) -> Result> + where + S: AsyncWrite + Unpin + Send, + { + // Use poll_fn to wrap our poll-based implementation + std::future::poll_fn(|cx| self.poll_write_current_finish_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for finish tasks. + /// This is the core implementation that maintains state across cancellations. + fn poll_write_current_finish_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // If no finish queued, return None + if matches!( + self.send_finish_state, + FinishWriteState::NotStarted | FinishWriteState::Done + ) { + return Poll::Ready(Ok(None)); + } + + // Route to body-mode-specific implementation + match self.body_mode { + BM::Complete(_) => Poll::Ready(Ok(None)), + BM::ContentLength(_, _) => self.poll_finish_content_length_task(cx, stream), + BM::ChunkedEncoding(_) => self.poll_finish_chunked_task(cx, stream), + BM::UntilClose(_) => self.poll_finish_until_close_task(cx, stream), + BM::ToSelect => Poll::Ready(Ok(None)), + } + } + + /// Finish content-length body - just validates and updates state. + /// No I/O needed since body write tasks already flushed after the last write. 
+ fn poll_finish_content_length_task( + &mut self, + _cx: &mut Context<'_>, + _stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::ContentLength(total, w) => { + if w < total { + self.send_finish_state = FinishWriteState::Done; + return Poll::Ready(Error::e_explain( + PREMATURE_BODY_END, + format!("Content-length: {total} bytes written: {w}"), + )); + } + w + } + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + // All bytes written - just update state to Complete + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + Poll::Ready(Ok(Some(written))) + } + + /// Poll-based helper to finish chunked encoding body + fn poll_finish_chunked_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::ChunkedEncoding(w) => w, + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + loop { + match &mut self.send_finish_state { + FinishWriteState::Idle => { + // Start writing last chunk marker "0\r\n\r\n" + let buf = WriteBuf::Simple(Bytes::from_static(&LAST_CHUNK[..])); + self.send_finish_state = FinishWriteState::WritingLastChunk(buf); + } + FinishWriteState::WritingLastChunk(buf) => { + // Poll write_vec_all - write until all bytes are written + ready!(poll_write_vec_all_buf(cx, stream.as_mut(), buf)) + .map_err(|e| Error::because(WriteError, "while writing last chunk", e))?; + + // All bytes written, move to flushing state + self.send_finish_state = FinishWriteState::Flushing; + } + FinishWriteState::Flushing => { + // Poll flush + ready!(stream.as_mut().poll_flush(cx)) + .map_err(|e| Error::because(WriteError, "flushing after last chunk", e))?; + + // Flush complete! 
Update body_mode and mark done + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + return Poll::Ready(Ok(Some(written))); + } + FinishWriteState::Done => { + unreachable!( + "Done state should have been handled in poll_write_current_finish_task" + ) + } + FinishWriteState::NotStarted => { + unreachable!("NotStarted state should have been handled in poll_write_current_finish_task") + } + } + } + } + + /// Finish until-close body - just updates state. + /// No I/O needed since body write tasks already flushed after each write. + fn poll_finish_until_close_task( + &mut self, + _cx: &mut Context<'_>, + _stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + let written = match self.body_mode { + BM::UntilClose(w) => w, + _ => panic!("wrong body mode: {:?}", self.body_mode), + }; + + // Just update state to Complete + self.body_mode = BM::Complete(written); + self.send_finish_state = FinishWriteState::Done; + Poll::Ready(Ok(Some(written))) + } + + // ======================================================================== + // Internal helpers + // ======================================================================== + + /// Internal helper to poll a body task that writes in content-length mode + /// and flushes at end. 
+ fn poll_write_content_length_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(mut bytes) = self.send_body_state.pending_bytes.take() { + let (total, written) = self.body_mode.expect_content_length(); + + // Check if we've already written everything + if written >= total { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + + let original_size = bytes.len(); + let remaining = total - written; + + // Truncate bytes if they exceed content-length + if original_size > remaining { + warn!( + "Trying to write {} bytes over content-length: {}, truncating to {}", + original_size, total, remaining + ); + bytes.truncate(remaining); + } + + let bytes_to_write = bytes.len(); + self.send_body_state.write_state = + WriteState::Writing(bytes_to_write, WriteBuf::Simple(bytes)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do the write, transition to Flushing or Done + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt write + match ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode to track bytes written + let (total, written) = self.body_mode.expect_content_length(); + self.body_mode = BM::ContentLength(total, written + bytes_written); + + if written + bytes_written >= total { + // All content-length bytes written, flush needed + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } else { + // More bytes to come, no flush needed + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while 
writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } + + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) => { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } + + /// Poll-based implementation for chunked encoding mode + fn poll_write_chunked_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(bytes) = self.send_body_state.pending_bytes.take() { + let application_bytes_size = bytes.len(); + + // Format the chunk: size\r\ndata\r\n + let chunk_size_header = format!("{:X}\r\n", application_bytes_size); + let output_buf = Bytes::from(chunk_size_header) + .chain(bytes) + .chain(&b"\r\n"[..]); + + // Store the chained buffer directly to avoid copying + self.send_body_state.write_state = + WriteState::Writing(application_bytes_size, WriteBuf::Chained(output_buf)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do 
the write using vectored I/O, transition to Flushing + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt vectored write for chained buffer (chunk size + data + CRLF) + match ready!(poll_write_vec_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode with application bytes (not wire bytes) + let written = self.body_mode.expect_chunked(); + self.body_mode = BM::ChunkedEncoding(written + bytes_written); + + // Chunked encoding always flushes + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } + + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) 
=> { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } + + /// Poll-based implementation for UntilClose (close-delimited) body mode + fn poll_write_until_close_body_task( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll>> + where + S: AsyncWrite + Unpin + Send, + { + // Move to Writing state if we're Idle + if matches!(self.send_body_state.write_state, WriteState::Idle) { + if let Some(bytes) = self.send_body_state.pending_bytes.take() { + let original_size = bytes.len(); + self.send_body_state.write_state = + WriteState::Writing(original_size, WriteBuf::Simple(bytes)); + } else { + self.send_body_state.write_state = WriteState::Done(0); + return Poll::Ready(Ok(None)); + } + } + + // Handle Writing state - do the write, transition to Flushing + if let WriteState::Writing(size, ref mut buf) = &mut self.send_body_state.write_state { + let bytes_written = *size; + + // Attempt write + match ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) { + Ok(()) => { + // Write completed - update body_mode to track bytes written + let written = self.body_mode.expect_until_close(); + self.body_mode = BM::UntilClose(written + bytes_written); + + // Close-delimited mode always flushes + self.send_body_state.write_state = WriteState::Flushing(bytes_written); + } + Err(e) => { + return Poll::Ready(Error::e_because(WriteError, "while writing body", e)) + } + } + } + + // Handle Flushing state - do the flush, transition to Done + if let WriteState::Flushing(size) = self.send_body_state.write_state { + let bytes_written = size; + + // Attempt flush + match ready!(stream.poll_flush(cx)) { + Ok(()) => { + // Flush completed - transition to Done + self.send_body_state.write_state = WriteState::Done(bytes_written); + } + Err(e) => return Poll::Ready(Error::e_because(WriteError, "flushing body", e)), + } + } 
+ + // Return based on final state + match self.send_body_state.write_state { + WriteState::Done(size) => { + self.send_body_state.timeout_fut = None; + Poll::Ready(Ok(Some(size))) + } + WriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "write task previously timed out", + )), + WriteState::Writing(..) | WriteState::Flushing(..) => { + unreachable!("Writing/Flushing states should have been handled above or returned Pending via ready!") + } + WriteState::Idle => { + unreachable!("Idle state should have been handled in setup") + } + } + } } #[cfg(test)] @@ -1717,33 +2433,41 @@ mod tests { let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(0, 0)); assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); - let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 - assert_eq!(&input2[1..2], body_reader.get_body(&res)); - assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 11, 0)); - let res = body_reader.read_body(&mut mock_io).await.unwrap(); - assert_eq!(res, None); - assert_eq!(body_reader.body_state, ParseState::Complete(1)); - assert_eq!(body_reader.get_body_overread(), None); + let _res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); } #[tokio::test] - async fn read_with_body_partial_head_terminal_crlf() { + async fn read_with_body_partial_head_chunk_incomplete() { init_log(); let input1 = b"1\r"; - let input2 = b"\na\r\n0\r\n\r"; - let input3 = b"\n"; - let mut mock_io = Builder::new() - .read(&input1[..]) - .read(&input2[..]) - .read(&input3[..]) - .build(); + let mut mock_io = Builder::new().read(&input1[..]).build(); let mut body_reader = BodyReader::new(false); body_reader.init_chunked(b""); let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(0, 0)); assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); - let res = 
body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + let res = body_reader.read_body(&mut mock_io).await; + assert!(res.is_err()); + assert_eq!(body_reader.body_state, ParseState::Done(0)); + } + + #[tokio::test] + async fn read_with_body_partial_head_terminal_crlf() { + init_log(); + let input1 = b"1\r"; + let input2 = b"\na\r\n0\r\n\r"; + let input3 = b"\n"; + let mut mock_io = Builder::new() + .read(&input1[..]) + .read(&input2[..]) + .read(&input3[..]) + .build(); + let mut body_reader = BodyReader::new(false); + body_reader.init_chunked(b""); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); + assert_eq!(res, BufRef::new(0, 0)); + assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); + let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); assert_eq!(res, BufRef::new(3, 1)); // input1 concat input2 assert_eq!(&input2[1..2], body_reader.get_body(&res)); assert_eq!(body_reader.body_state, ParseState::Chunked(1, 6, 10, 0)); @@ -1925,21 +2649,6 @@ mod tests { assert_eq!(body_reader.get_body_overread(), Some(&b"abc"[..])); } - #[tokio::test] - async fn read_with_body_partial_head_chunk_incomplete() { - init_log(); - let input1 = b"1\r"; - let mut mock_io = Builder::new().read(&input1[..]).build(); - let mut body_reader = BodyReader::new(false); - body_reader.init_chunked(b""); - let res = body_reader.read_body(&mut mock_io).await.unwrap().unwrap(); - assert_eq!(res, BufRef::new(0, 0)); - assert_eq!(body_reader.body_state, ParseState::Chunked(0, 0, 2, 2)); - let res = body_reader.read_body(&mut mock_io).await; - assert!(res.is_err()); - assert_eq!(body_reader.body_state, ParseState::Done(0)); - } - #[tokio::test] async fn read_with_body_trailers() { init_log(); @@ -2319,7 +3028,7 @@ mod tests { } #[tokio::test] - async fn write_body_http10() { + async fn write_body_until_close() { init_log(); let data = b"a"; let mut mock_io = Builder::new().write(&data[..]).write(&data[..]).build(); @@ -2345,3 
+3054,799 @@ mod tests { assert_eq!(body_writer.body_mode, BodyMode::Complete(2)); } } + +#[cfg(test)] +mod test_body_task_api { + use super::*; + use crate::protocols::http::v1::test_util::FlushTrackingMock; + use tokio_test::io::Builder; + + // Cancel-safety tests use tokio::select! to race a short sleep against a mock + // I/O wait, simulating cancellation. We use #[tokio::test(start_paused = true)] + // on these tests so that tokio auto-advances time deterministically rather than + // relying on wall-clock timing. + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[tokio::test] + async fn test_has_pending_body_task() { + init_log(); + let data = b"test data"; + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Initially should have no pending task + assert!(!body_writer.has_pending_body_task()); + + // After queuing bytes, should have pending task + body_writer.send_body_task(Bytes::from_static(data), None); + assert!(body_writer.has_pending_body_task()); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_content_length_write() { + init_log(); + let data = b"Hello, World!"; + + // Create a mock stream that will block to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the bytes to write + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use tokio::select! loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + // Break if no pending writes + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + // Timeout fires first, cancelling the write + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + // Write completed + assert!(result.is_ok(), "Write should succeed"); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!( + cancel_count > 0, + "At least one cancellation should have occurred" + ); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + assert_eq!( + body_writer.body_mode, + BodyMode::ContentLength(data.len(), data.len()) + ); + + // Now test finish() in a select loop as well + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => { + // Allow cancellation attempts + } + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert_eq!(body_writer.body_mode, BodyMode::Complete(data.len())); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_chunked_write() { + init_log(); + let data = b"abcdefghij"; + let expected_output = b"A\r\nabcdefghij\r\n"; + + // Mock stream that blocks to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(expected_output) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_chunked(); + + // Queue bytes + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use select loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + assert_eq!(body_writer.body_mode, BodyMode::ChunkedEncoding(data.len())); + + // Test finish() with select loop - must write terminating chunk + let mut mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&LAST_CHUNK[..]) // Expect 0\r\n\r\n + .build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert_eq!(body_writer.body_mode, BodyMode::Complete(data.len())); + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_until_close_write() { + init_log(); + let data = b"test data"; + + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_close_delimited(); + + body_writer.send_body_task(Bytes::from_static(data), None); + + // Use select loop - keep looping until write completes + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish() with select loop + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_multiple_cancellations() { + init_log(); + let data = b"Long test data that requires multiple writes"; + + // Create a mock that blocks multiple times + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&data[..15]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[15..30]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[30..]) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + body_writer.send_body_task(Bytes::from_static(data), None); + + // Loop until write completes, allowing cancellations + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count >= 2, "Should have multiple cancellations"); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish with select loop + let mut mock_io_finish = Builder::new().build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok()); + break; + } + } + } + } + + #[tokio::test(start_paused = true)] + async fn cancel_safe_partial_writes() { + init_log(); + let data = b"12345678901234567890"; // 20 bytes + + // Simulate partial writes with blocking + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(50)) + .write(&data[..7]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[7..14]) + .wait(std::time::Duration::from_millis(50)) + .write(&data[14..]) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + body_writer.send_body_task(Bytes::from_static(data), None); + + let mut cancel_count = 0; + let mut total_bytes_written = 0; + + loop { + if !body_writer.has_pending_body_task() { + break; + } + + tokio::select! 
{ + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + assert!(result.is_ok()); + if let Ok(Some(n)) = result { + total_bytes_written += n; + } + } + } + } + + assert!(cancel_count > 0); + assert_eq!( + total_bytes_written, + data.len(), + "Should have written all application bytes" + ); + + // Test finish in select loop + let mut mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(30)) + .build(); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(5)) => {} + result = body_writer.finish(&mut mock_io_finish) => { + assert!(result.is_ok(), "Finish should succeed after cancel-safe writes"); + break; + } + } + } + } + + #[tokio::test] + async fn test_task_write_timeout() { + init_log(); + let data = b"test data"; + + // Create a mock that blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the task with a timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(50)), + ); + + // The write should timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err(), "Write should timeout"); + + // Check that it's a timeout error + if let Err(e) = result { + assert_eq!(e.etype(), &WriteTimedout); + } + } + + // Even if the user's select! cancels the write, the internal timeout + // should continue counting across cancellations. 
+ #[tokio::test] + async fn test_task_timeout_persists_across_cancellations() { + init_log(); + let data = b"test data"; + + // Create a mock that blocks for a while + // Since timeout is 100ms and this waits 200ms, the write should never happen + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(200)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue the task with a 100ms timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(100)), + ); + + let mut attempts = 0; + let mut timedout = false; + + // Try to write in a loop, but cancel early each time + // The timeout should still fire even though we're cancelling + loop { + if !body_writer.has_pending_body_task() { + break; + } + + attempts += 1; + + tokio::select! { + // Cancel after just 10ms each time + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + // Cancelled by our select, continue looping + continue; + } + result = body_writer.write_current_body_task(&mut mock_io) => { + match result { + Ok(_) => { + // Write succeeded before timeout + break; + } + Err(e) if e.etype() == &WriteTimedout => { + // Timeout fired! 
+ timedout = true; + break; + } + Err(e) => { + panic!("Unexpected error: {:?}", e); + } + } + } + } + } + + assert!(timedout, "Timeout should have fired despite cancellations"); + assert!( + attempts >= 5, + "Should have had multiple attempts before timeout" + ); + } + + #[tokio::test] + async fn test_task_write_succeeds_within_timeout() { + init_log(); + let data = b"Hello, World!"; + + // Create a mock that completes quickly + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(20)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue with a generous timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(500)), + ); + + // Write should succeed + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok(), "Write should succeed: {:?}", result); + assert_eq!(result.unwrap(), Some(data.len())); + } + + #[tokio::test] + async fn test_task_write_no_timeout() { + init_log(); + let data = b"test data"; + + // Create a mock that takes a bit of time but eventually succeeds + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + // Queue without timeout + body_writer.send_body_task(Bytes::from_static(data), None); + + // Write should eventually succeed + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok(), "Write should succeed without timeout"); + assert_eq!(result.unwrap(), Some(data.len())); + } + + #[tokio::test] + async fn test_task_chunked_write_timeout() { + init_log(); + let data = b"chunked data"; + + // Create a mock that blocks + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + 
body_writer.init_chunked(); + + // Queue with short timeout + body_writer.send_body_task( + Bytes::from_static(data), + Some(std::time::Duration::from_millis(50)), + ); + + // Should timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!(e.etype(), &WriteTimedout); + } + } + + #[tokio::test] + async fn test_task_timeout_reset_on_new_task() { + init_log(); + let data1 = b"first"; + let data2 = b"second"; + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data1.len() + data2.len()); + + // Queue first task with short timeout + body_writer.send_body_task( + Bytes::from_static(data1), + Some(std::time::Duration::from_millis(50)), + ); + + // Wait a bit but don't let it timeout yet + tokio::time::sleep(std::time::Duration::from_millis(30)).await; + + // Queue a new task with a longer timeout + // This should reset/replace the timeout + body_writer.send_body_task( + Bytes::from_static(data2), + Some(std::time::Duration::from_millis(500)), + ); + + // Create a mock that takes some time + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(data2) + .build(); + + // The second write should succeed with its own timeout + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!( + result.is_ok(), + "Second task should succeed with new timeout" + ); + } + + #[tokio::test] + async fn test_task_timeout_with_partial_writes() { + init_log(); + let data1 = b"first"; + let data2 = b"second"; + let data3 = b"third"; + + // Mock that writes data1 quickly, data2 with delay, data3 blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(10)) + .write(data1) + .wait(std::time::Duration::from_millis(40)) + .write(data2) + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data1.len() + 
data2.len() + data3.len()); + + let mut total_written = 0; + + // First write - should succeed within timeout + body_writer.send_body_task( + Bytes::from_static(data1), + Some(std::time::Duration::from_millis(100)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok()); + total_written += result.unwrap().unwrap(); + + // Second write - should succeed within timeout + body_writer.send_body_task( + Bytes::from_static(data2), + Some(std::time::Duration::from_millis(100)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_ok()); + total_written += result.unwrap().unwrap(); + + // Third write - should timeout + body_writer.send_body_task( + Bytes::from_static(data3), + Some(std::time::Duration::from_millis(50)), + ); + let result = body_writer.write_current_body_task(&mut mock_io).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + + // We should have written data1 and data2 but not data3 + assert_eq!(total_written, data1.len() + data2.len()); + assert!( + total_written < data1.len() + data2.len() + data3.len(), + "Should not have written all data" + ); + } + + // Cancel-safe finish task for chunked encoding: send_finish_task() queues + // the terminating chunk, write_current_finish_task() writes it and can be + // cancelled and resumed in a select! loop. + // Verifies that the finish flushes the stream exactly once. 
+ #[tokio::test(start_paused = true)] + async fn cancel_safe_finish_task_chunked() { + init_log(); + + let data = Bytes::from("hello"); + let expected_chunk = b"5\r\nhello\r\n"; + + let mock_io = Builder::new().write(expected_chunk).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + + let mut body_writer = BodyWriter::new(); + body_writer.init_chunked(); + + // Write body data via task API + body_writer.send_body_task(data, None); + body_writer + .write_current_body_task(&mut flush_mock) + .await + .unwrap(); + + // Chunked body writes always flush after each chunk + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "Chunked body data write should flush once" + ); + + // Queue the finish task + body_writer.send_finish_task(); + assert!(body_writer.has_pending_finish_task()); + + // Write the finish in a select! loop with cancellations + let mock_io_finish = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(b"0\r\n\r\n") + .build(); + let (mut flush_mock_finish, flush_count_finish) = FlushTrackingMock::new(mock_io_finish); + + let mut cancel_count = 0; + + loop { + if !body_writer.has_pending_finish_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = body_writer.write_current_finish_task(&mut flush_mock_finish) => { + assert!(result.is_ok()); + break; + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count_finish), + 1, + "Chunked finish should flush exactly once" + ); + } + + // Finish task for content-length is a no-op (no terminating chunk needed), + // but it should still transition body_mode to Complete. + // Verifies that no flush occurs (content-length finish has no I/O). 
+ #[tokio::test] + async fn finish_task_content_length() { + init_log(); + + let data = b"hello"; + let mock_io = Builder::new().write(data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(data.len()); + + body_writer.send_body_task(Bytes::from_static(data), None); + body_writer + .write_current_body_task(&mut flush_mock) + .await + .unwrap(); + + // Content-length body write flushes when all bytes are written + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "Content-length body write should flush once (all bytes written)" + ); + + body_writer.send_finish_task(); + let mock_io_finish = Builder::new().build(); + let (mut flush_mock_finish, flush_count_finish) = FlushTrackingMock::new(mock_io_finish); + let result = body_writer + .write_current_finish_task(&mut flush_mock_finish) + .await; + assert!(result.is_ok()); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count_finish), + 0, + "Content-length finish should not flush (no I/O needed)" + ); + } + + // Verifies that body_mode byte tracking is correct when writing + // content-length body in multiple chunks. Each intermediate chunk + // does not trigger a flush; the body_mode must still accumulate + // bytes correctly so that finish_task succeeds. 
+ #[tokio::test] + async fn content_length_body_mode_tracks_across_chunks() { + init_log(); + + let chunk1 = b"Hello"; + let chunk2 = b", World!"; + let total_len = chunk1.len() + chunk2.len(); // 13 + + // Mock expects both writes; the final write triggers a flush internally + let mut mock_io = Builder::new().write(chunk1).write(chunk2).build(); + + let mut body_writer = BodyWriter::new(); + body_writer.init_content_length(total_len); + + // Write first chunk (intermediate, no flush expected) + body_writer.send_body_task(Bytes::from_static(chunk1), None); + let result = body_writer + .write_current_body_task(&mut mock_io) + .await + .unwrap(); + assert_eq!(result, Some(chunk1.len())); + assert!( + !body_writer.finished(), + "Should not be finished after first chunk" + ); + + // Verify body_mode tracks the bytes from the first chunk + assert!( + matches!(body_writer.body_mode, BodyMode::ContentLength(total, written) + if total == total_len && written == chunk1.len()), + "body_mode should reflect bytes written so far, got: {:?}", + body_writer.body_mode + ); + + // Write second chunk (final, completes content-length) + body_writer.send_body_task(Bytes::from_static(chunk2), None); + let result = body_writer + .write_current_body_task(&mut mock_io) + .await + .unwrap(); + assert_eq!(result, Some(chunk2.len())); + assert!( + body_writer.finished(), + "Should be finished after all bytes written" + ); + + // Finish should succeed since all content-length bytes were written + body_writer.send_finish_task(); + let mut mock_io_finish = Builder::new().build(); + let result = body_writer + .write_current_finish_task(&mut mock_io_finish) + .await; + assert!( + result.is_ok(), + "finish_task should succeed when all content-length bytes written" + ); + assert!(matches!(body_writer.body_mode, BodyMode::Complete(_))); + } +} diff --git a/pingora-core/src/protocols/http/v1/header.rs b/pingora-core/src/protocols/http/v1/header.rs new file mode 100644 index 00000000..b6abdb71 --- 
/dev/null +++ b/pingora-core/src/protocols/http/v1/header.rs @@ -0,0 +1,459 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Cancel-safe header writing for HTTP/1.x + +use bytes::Bytes; +use pingora_error::{Error, ErrorType::*, Result}; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; +use tokio::io::AsyncWrite; + +use crate::protocols::l4::stream::async_write_vec::poll_write_all_buf; + +enum HeaderWriteState { + /// No write in progress + Idle, + /// Writing header bytes (original size, buffer) + Writing(usize, Bytes), + /// Flushing after write (original size to return) + Flushing(usize), + /// Write complete + Done, + /// Write timed out - cannot be reused + TimedOut, +} + +// Custom Debug implementation +impl std::fmt::Debug for HeaderWriteState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HeaderWriteState::Idle => write!(f, "Idle"), + HeaderWriteState::Writing(size, _) => write!(f, "Writing(size: {})", size), + HeaderWriteState::Flushing(size) => write!(f, "Flushing(size: {})", size), + HeaderWriteState::Done => write!(f, "Done"), + HeaderWriteState::TimedOut => write!(f, "TimedOut"), + } + } +} + +/// Internal state for the cancel-safe header write state machine. +/// +/// Tracks the pending header bytes, write progress (idle → writing → flushing → done), +/// and an optional timeout that is lazily created on the first `Pending` poll. 
+struct SendHeaderState { + /// Serialized header bytes ready to be written + pending_header: Option, + /// Whether to flush after writing + should_flush: bool, + /// Current write state + write_state: HeaderWriteState, + /// Timeout duration for this write task + timeout_duration: Option, + /// Timeout future (only created if write returns Pending) + timeout_fut: Option + Send + Sync>>>, +} + +// Custom Debug implementation since timeout_fut doesn't implement Debug +impl std::fmt::Debug for SendHeaderState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendHeaderState") + .field("pending_header", &self.pending_header) + .field("should_flush", &self.should_flush) + .field("write_state", &self.write_state) + .field("timeout_duration", &self.timeout_duration) + .field( + "timeout_fut", + &self.timeout_fut.as_ref().map(|_| "Some(Future)"), + ) + .finish() + } +} + +impl SendHeaderState { + fn new() -> Self { + SendHeaderState { + pending_header: None, + should_flush: false, + write_state: HeaderWriteState::Idle, + timeout_duration: None, + timeout_fut: None, + } + } +} + +/// Cancel-safe header writer for HTTP/1.x response headers. +/// +/// This writer allows response headers to be written to a downstream connection +/// inside a `tokio::select!` loop without losing progress. If the write is +/// cancelled (e.g. because another branch of the select fires first), the +/// partially-written state is preserved and will be resumed on the next call to +/// [`write_current_header_task`](Self::write_current_header_task). +/// +/// ## Usage +/// +/// 1. Call [`send_header_task`](Self::send_header_task) with pre-serialized +/// header bytes, a flush flag, and an optional timeout. +/// 2. Await [`write_current_header_task`](Self::write_current_header_task) +/// (possibly inside `tokio::select!`). The method returns `Ok(bytes_written)` +/// on success. 
+/// +/// A timeout, if set, is enforced *across* cancellations — the clock keeps +/// ticking even when the future is dropped and re-polled. +pub struct HeaderWriter { + // Boxed to reduce inline size. Only used by the cancel-safe proxy task API. + send_header_state: Box, +} + +impl Default for HeaderWriter { + fn default() -> Self { + Self::new() + } +} + +impl HeaderWriter { + pub fn new() -> Self { + HeaderWriter { + send_header_state: Box::new(SendHeaderState::new()), + } + } + + #[cfg(test)] + pub fn has_pending_header_task(&self) -> bool { + self.send_header_state.pending_header.is_some() + || !matches!( + self.send_header_state.write_state, + HeaderWriteState::Idle | HeaderWriteState::Done | HeaderWriteState::TimedOut + ) + } + + /// Queue serialized header bytes as a write task with an optional timeout. + /// This is a non-async function that just saves the bytes. + /// Call [`write_current_header_task`](Self::write_current_header_task) to actually perform the write. + pub fn send_header_task( + &mut self, + header_bytes: Bytes, + should_flush: bool, + timeout: Option, + ) { + assert!( + matches!( + self.send_header_state.write_state, + HeaderWriteState::Idle | HeaderWriteState::Done + ), + "send_header_task called while previous task is still in progress: {:?}", + self.send_header_state.write_state + ); + self.send_header_state.pending_header = Some(header_bytes); + self.send_header_state.should_flush = should_flush; + self.send_header_state.write_state = HeaderWriteState::Idle; + self.send_header_state.timeout_duration = timeout; + self.send_header_state.timeout_fut = None; + } + + /// Async function that writes the current queued header task to the stream. + /// This function is cancel-safe and can be called in a `tokio::select!` loop. + /// Returns `Ok(bytes_written)` when complete, `Ok(0)` if no bytes to write. 
+ pub async fn write_current_header_task(&mut self, stream: &mut S) -> Result + where + S: AsyncWrite + Unpin + Send, + { + std::future::poll_fn(|cx| self.poll_write_current_header_task(cx, Pin::new(stream))).await + } + + /// Poll-based implementation for writing the current header task. + fn poll_write_current_header_task( + &mut self, + cx: &mut Context<'_>, + stream: Pin<&mut S>, + ) -> Poll> + where + S: AsyncWrite + Unpin + Send, + { + // Check if already timed out - don't allow reuse + if matches!( + self.send_header_state.write_state, + HeaderWriteState::TimedOut + ) { + return Poll::Ready(Error::e_explain( + WriteTimedout, + "header write task previously timed out", + )); + } + + // First, try the write operation + match self.poll_do_write_header_and_flush(cx, stream) { + Poll::Ready(Ok(size)) => { + // Write completed! Clear timeout and return + if matches!(self.send_header_state.write_state, HeaderWriteState::Done) { + self.send_header_state.timeout_fut = None; + } + return Poll::Ready(Ok(size)); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + // Write is pending - now check timeout + } + } + + // Lazy timeout optimization: Polls write first, creates timeout only if needed. + // This follows the pattern from `pingora_timeout::Timeout` to avoid allocating + // timeout futures when writes complete immediately (the common case). + if let Some(duration) = self.send_header_state.timeout_duration { + let timeout = self.send_header_state.timeout_fut.get_or_insert_with(|| { + Box::pin(pingora_timeout::sleep(duration)) + as std::pin::Pin + Send + Sync>> + }); + + if timeout.as_mut().poll(cx).is_ready() { + // Timeout fired! 
+ self.send_header_state.write_state = HeaderWriteState::TimedOut; + self.send_header_state.timeout_fut = None; + return Poll::Ready(Error::e_explain( + WriteTimedout, + "writing header task timed out", + )); + } + } + + // Both write and timeout are pending + Poll::Pending + } + + /// Poll-based helper to write header bytes and optionally flush. + /// Handles state transitions explicitly. + fn poll_do_write_header_and_flush( + &mut self, + cx: &mut Context<'_>, + mut stream: Pin<&mut S>, + ) -> Poll> + where + S: AsyncWrite + Unpin + Send, + { + // Handle Idle state - take pending header and transition to Writing + if matches!(self.send_header_state.write_state, HeaderWriteState::Idle) { + if let Some(header_bytes) = self.send_header_state.pending_header.take() { + let size = header_bytes.len(); + self.send_header_state.write_state = HeaderWriteState::Writing(size, header_bytes); + } else { + // No pending header + self.send_header_state.write_state = HeaderWriteState::Done; + return Poll::Ready(Ok(0)); + } + } + + // Write if in Writing state + if let HeaderWriteState::Writing(original_size, ref mut buf) = + self.send_header_state.write_state + { + let size = original_size; + ready!(poll_write_all_buf(cx, stream.as_mut(), buf)) + .map_err(|e| Error::because(WriteError, "writing response header", e))?; + + // Write complete - transition to next state + if self.send_header_state.should_flush { + self.send_header_state.write_state = HeaderWriteState::Flushing(size); + } else { + self.send_header_state.write_state = HeaderWriteState::Done; + return Poll::Ready(Ok(size)); + } + } + + // Handle the state after writing (or if we started in a non-Writing state) + match self.send_header_state.write_state { + HeaderWriteState::Flushing(size) => { + ready!(stream.as_mut().poll_flush(cx)) + .map_err(|e| Error::because(WriteError, "flushing response header", e))?; + // Flush complete - transition to Done + self.send_header_state.write_state = HeaderWriteState::Done; + 
Poll::Ready(Ok(size)) + } + HeaderWriteState::Done => Poll::Ready(Ok(0)), + HeaderWriteState::TimedOut => Poll::Ready(Error::e_explain( + WriteTimedout, + "header write task previously timed out", + )), + HeaderWriteState::Idle => { + unreachable!("Idle state should have been handled above") + } + HeaderWriteState::Writing(..) => { + unreachable!("Writing state should have been handled above") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocols::http::v1::test_util::FlushTrackingMock; + use tokio_test::io::Builder; + + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[tokio::test] + async fn test_simple_header_write_no_flush() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + + let mock_io = Builder::new().write(header_data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), false, None); + + let result = header_writer + .write_current_header_task(&mut flush_mock) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 0, + "should_flush=false should not flush" + ); + } + + #[tokio::test] + async fn test_header_write_with_flush() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + let mock_io = Builder::new().write(header_data).build(); + let (mut flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), true, None); + + let result = header_writer + .write_current_header_task(&mut flush_mock) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + assert_eq!( + FlushTrackingMock::flush_count(&flush_count), + 1, + "should_flush=true should flush exactly once" + ); + } 
+ + // Uses start_paused for deterministic timer-based cancellation in select! + #[tokio::test(start_paused = true)] + async fn test_cancel_safe_header_write() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\nServer: pingora\r\n\r\n"; + + // Mock that blocks to allow cancellation + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(100)) + .write(header_data) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task(Bytes::from_static(header_data), false, None); + + let mut cancel_count = 0; + + loop { + if !header_writer.has_pending_header_task() { + break; + } + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + cancel_count += 1; + } + result = header_writer.write_current_header_task(&mut mock_io) => { + assert!(result.is_ok()); + assert_eq!(result.unwrap(), header_data.len()); + break; + } + } + } + + assert!(cancel_count > 0, "Should have cancelled at least once"); + } + + #[tokio::test] + async fn test_header_write_timeout() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + // Mock that blocks forever + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_secs(1000)) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task( + Bytes::from_static(header_data), + false, + Some(std::time::Duration::from_millis(50)), + ); + + let result = header_writer.write_current_header_task(&mut mock_io).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + } + + #[tokio::test] + async fn test_header_write_timeout_persists() { + init_log(); + let header_data = b"HTTP/1.1 200 OK\r\n\r\n"; + + // Mock that blocks for a while + let mut mock_io = Builder::new() + .wait(std::time::Duration::from_millis(200)) + .build(); + + let mut header_writer = HeaderWriter::new(); + header_writer.send_header_task( + Bytes::from_static(header_data), + false, + 
Some(std::time::Duration::from_millis(100)), + ); + + let mut attempts = 0; + let mut timedout = false; + + loop { + if !header_writer.has_pending_header_task() { + break; + } + + attempts += 1; + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => { + continue; + } + result = header_writer.write_current_header_task(&mut mock_io) => { + match result { + Ok(_) => break, + Err(e) if e.etype() == &WriteTimedout => { + timedout = true; + break; + } + Err(e) => panic!("Unexpected error: {:?}", e), + } + } + } + } + + assert!(timedout, "Timeout should have fired"); + assert!(attempts >= 5, "Should have had multiple attempts"); + } +} diff --git a/pingora-core/src/protocols/http/v1/mod.rs b/pingora-core/src/protocols/http/v1/mod.rs index 19602491..6f085a70 100644 --- a/pingora-core/src/protocols/http/v1/mod.rs +++ b/pingora-core/src/protocols/http/v1/mod.rs @@ -17,4 +17,110 @@ pub(crate) mod body; pub mod client; pub mod common; +pub(crate) mod header; pub mod server; + +/// Test utilities shared across HTTP/1.x unit tests +#[cfg(test)] +pub(crate) mod test_util { + use std::pin::Pin; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::task::{Context, Poll}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_test::io::Mock; + + /// A wrapper around [`Mock`] that counts flush calls. + /// + /// `tokio_test::io::Mock`'s `poll_flush` always returns `Ready(Ok(()))`, + /// so we can't detect flush calls via mock alone. This wrapper counts them. 
+ #[derive(Debug)] + pub(crate) struct FlushTrackingMock { + inner: Mock, + flush_count: Arc, + } + + impl FlushTrackingMock { + pub(crate) fn new(mock: Mock) -> (Self, Arc) { + let flush_count = Arc::new(AtomicUsize::new(0)); + ( + FlushTrackingMock { + inner: mock, + flush_count: flush_count.clone(), + }, + flush_count, + ) + } + + pub(crate) fn flush_count(counter: &Arc) -> usize { + counter.load(Ordering::Relaxed) + } + } + + impl AsyncRead for FlushTrackingMock { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_read(cx, buf) + } + } + + impl AsyncWrite for FlushTrackingMock { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let result = Pin::new(&mut this.inner).poll_flush(cx); + if let Poll::Ready(Ok(())) = &result { + this.flush_count.fetch_add(1, Ordering::Relaxed); + } + result + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) + } + } + + // Implement IO-required traits so FlushTrackingMock can be used as Box + // in HttpSession tests (server.rs). 
+ use crate::protocols::{ + raw_connect::ProxyDigest, GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, + SocketDigest, Ssl, TimingDigest, UniqueID, UniqueIDType, + }; + + #[async_trait::async_trait] + impl Shutdown for FlushTrackingMock { + async fn shutdown(&mut self) -> () {} + } + impl UniqueID for FlushTrackingMock { + fn id(&self) -> UniqueIDType { + 0 + } + } + impl Ssl for FlushTrackingMock {} + impl GetTimingDigest for FlushTrackingMock { + fn get_timing_digest(&self) -> Vec> { + vec![] + } + } + impl GetProxyDigest for FlushTrackingMock { + fn get_proxy_digest(&self) -> Option> { + None + } + } + impl GetSocketDigest for FlushTrackingMock { + fn get_socket_digest(&self) -> Option> { + None + } + } + impl Peek for FlushTrackingMock {} +} diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index 7e648ca5..9144c6e5 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -28,15 +28,48 @@ use pingora_http::{IntoCaseHeaderName, RequestHeader, ResponseHeader}; use pingora_timeout::timeout; use regex::bytes::Regex; use std::any::Any; +use std::collections::VecDeque; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::body::{BodyReader, BodyWriter}; use super::common::*; +use super::header::HeaderWriter; use crate::protocols::http::{body_buffer::FixedBuffer, date, HttpTask}; use crate::protocols::{Digest, SocketAddr, Stream}; use crate::utils::{BufRef, KVRef}; +/// Tracks which writer is currently processing a task. +/// +/// This enables resuming writes after cancellation. Each variant stores the +/// minimal data needed for cleanup after write completes. +#[derive(Debug)] +enum ProxyTaskWriter { + /// Currently writing a header task. + /// Stores: (header for `response_written`, end_stream flag) + WritingHeader(Box, bool), + /// Currently writing a body task (`Body` or `UpgradedBody`). 
+ /// Stores: (end_stream flag) + WritingBody(bool), + /// Currently finishing the body (writing last chunk + flush). + FinishingBody, +} + +/// State for the cancel-safe proxy task write API. +#[derive(Default)] +struct ProxyTaskState { + /// Lazily initialized — `HeaderWriter::new()` heap-allocates. + header_writer: Option, + tasks: VecDeque, + current_writer: Option, +} + +impl ProxyTaskState { + fn header_writer(&mut self) -> &mut HeaderWriter { + self.header_writer.get_or_insert_with(HeaderWriter::new) + } +} + /// The HTTP 1.x server session pub struct HttpSession { underlying_stream: Stream, @@ -52,6 +85,8 @@ pub struct HttpSession { body_reader: BodyReader, /// A state machine to track how to write the response body body_writer: BodyWriter, + /// Cancel-safe proxy task state. + proxy_task_state: ProxyTaskState, /// An internal buffer to buf multiple body writes to reduce the underlying syscalls body_write_buf: BytesMut, /// Track how many application (not on the wire) body bytes already sent @@ -98,6 +133,9 @@ pub struct HttpSession { /// close is tolerated and `read_body_or_idle` stays pending so the proxy can /// finish delivering the upstream response (RFC 9112 Section 9.6). abort_on_close: bool, + /// Whether the cancel-safe proxy task API is enabled for this session. + /// Defaults to false. Can be enabled via [`set_proxy_tasks_enabled`](Self::set_proxy_tasks_enabled). + proxy_tasks_enabled: bool, } impl HttpSession { @@ -120,6 +158,7 @@ impl HttpSession { preread_body: None, body_reader: BodyReader::new(false), body_writer: BodyWriter::new(), + proxy_task_state: ProxyTaskState::default(), body_write_buf: BytesMut::new(), keepalive_timeout: KeepaliveStatus::Off, update_resp_headers: true, @@ -141,6 +180,7 @@ impl HttpSession { connection_user_context: None, half_closed: false, abort_on_close: true, + proxy_tasks_enabled: false, } } @@ -510,101 +550,12 @@ impl HttpSession { /// Write the response header to the client. 
/// This function can be called more than once to send 1xx informational headers excluding 101. pub async fn write_response_header(&mut self, mut header: Box) -> Result<()> { - if header.status.is_informational() && self.ignore_info_resp(header.status.into()) { - debug!("ignoring informational headers"); + // Prepare header (handle upgrades, set headers, initialize body writer, serialize to bytes) + let Some((write_buf, flush)) = self.prepare_response_header(&mut header)? else { + // Header already sent or should be ignored return Ok(()); - } - - if let Some(resp) = self.response_written.as_ref() { - if !resp.status.is_informational() || self.upgraded { - warn!("Respond header is already sent, cannot send again"); - return Ok(()); - } - } - - // if body unfinished, or request header was not finished reading - if self.close_on_response_before_downstream_finish - && (self.request_header.is_none() || !self.is_body_done()) - { - debug!("set connection close before downstream finish"); - self.set_keepalive(None); - } - - // no need to add these headers to 1xx responses - if !header.status.is_informational() && self.update_resp_headers { - /* update headers */ - header.insert_header(header::DATE, date::get_cached_date())?; - - // TODO: make these lazy static - let connection_value = if self.will_keepalive() { - "keep-alive" - } else { - "close" - }; - header.insert_header(header::CONNECTION, connection_value)?; - } - - if header.status == 101 { - // make sure the connection is closed at the end when 101/upgrade is used - self.set_keepalive(None); - } - - // Allow informational header (excluding 101) to pass through without affecting the state - // of the request - if header.status == 101 || !header.status.is_informational() { - // reset request body to done for incomplete upgrade handshakes - if let Some(upgrade_ok) = self.is_upgrade(&header) { - if upgrade_ok { - debug!("ok upgrade handshake"); - // For ws we use HTTP1_0 do_read_body_until_closed - // - // On ws close 
the initiator sends a close frame and - // then waits for a response from the peer, once it receives - // a response it closes the conn. After receiving a - // control frame indicating the connection should be closed, - // a peer discards any further data received. - // https://www.rfc-editor.org/rfc/rfc6455#section-1.4 - self.upgraded = true; - // Now that the upgrade was successful, we need to change - // how we interpret the rest of the body as pass-through. - if self.body_reader.need_init() { - self.init_body_reader(); - } else { - // already initialized - // immediately start reading the rest of the body as upgraded - // (in practice most upgraded requests shouldn't have any body) - // - // TODO: https://datatracker.ietf.org/doc/html/rfc9110#name-upgrade - // the most spec-compliant behavior is to switch interpretation - // after sending the former body, - // we immediately switch interpretation to match nginx - self.body_reader.convert_to_close_delimited(); - } - } else { - // this was a request that requested Upgrade, - // but upstream did not comply - debug!("bad upgrade handshake!"); - // continue to read body as-is, this is now just a regular request - } - } - self.init_body_writer(&header); - } - - // Defense-in-depth: if response body is close-delimited, mark session - // as un-reusable - if self.body_writer.is_close_delimited() { - self.set_keepalive(None); - } - - // Don't have to flush response with content length because it is less - // likely to be real time communication. 
So do flush when - // 1.1xx response: client needs to see it before the rest of response - // 2.No content length: the response could be generated in real time - let flush = header.status.is_informational() - || header.headers.get(header::CONTENT_LENGTH).is_none(); + }; - let mut write_buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE); - http_resp_header_to_buf(&header, &mut write_buf).unwrap(); match self.underlying_stream.write_all(&write_buf).await { Ok(()) => { // flush the stream if 1xx header or there is no response body @@ -615,7 +566,6 @@ impl HttpSession { .or_err(WriteError, "flushing response header")?; } self.response_written = Some(header); - self.body_bytes_sent += write_buf.len(); Ok(()) } Err(e) => Error::e_because(WriteError, "writing response header", e), @@ -759,6 +709,117 @@ impl HttpSession { } } + /// Prepare response header for writing: handle upgrades, set headers, initialize body writer. + /// This contains all the synchronous logic that should happen before writing the header. + /// Returns Ok(Some((bytes, should_flush))) if the header should be written, Ok(None) if should skip. 
+ fn prepare_response_header( + &mut self, + header: &mut ResponseHeader, + ) -> Result> { + // Check if we should ignore informational responses + if header.status.is_informational() && self.ignore_info_resp(header.status.into()) { + debug!("ignoring informational headers"); + return Ok(None); + } + + // Check if we already sent a response header + if let Some(ref resp) = self.response_written { + if !resp.status.is_informational() || self.upgraded { + warn!("Respond header is already sent, cannot send again"); + return Ok(None); + } + } + + // if body unfinished, or request header was not finished reading + if self.close_on_response_before_downstream_finish + && (self.request_header.is_none() || !self.is_body_done()) + { + debug!("set connection close before downstream finish"); + self.set_keepalive(None); + } + + // no need to add these headers to 1xx responses + if !header.status.is_informational() && self.update_resp_headers { + /* update headers */ + header.insert_header(header::DATE, date::get_cached_date())?; + + // TODO: make these lazy static + let connection_value = if self.will_keepalive() { + "keep-alive" + } else { + "close" + }; + header.insert_header(header::CONNECTION, connection_value)?; + } + + if header.status == 101 { + // make sure the connection is closed at the end when 101/upgrade is used + self.set_keepalive(None); + } + + // Allow informational header (excluding 101) to pass through without affecting the state + // of the request + if header.status == 101 || !header.status.is_informational() { + // reset request body to done for incomplete upgrade handshakes + if let Some(upgrade_ok) = self.is_upgrade(header) { + if upgrade_ok { + debug!("ok upgrade handshake"); + // For ws we use HTTP1_0 do_read_body_until_closed + // + // On ws close the initiator sends a close frame and + // then waits for a response from the peer, once it receives + // a response it closes the conn. 
After receiving a + // control frame indicating the connection should be closed, + // a peer discards any further data received. + // https://www.rfc-editor.org/rfc/rfc6455#section-1.4 + self.upgraded = true; + // Now that the upgrade was successful, we need to change + // how we interpret the rest of the body as pass-through. + if self.body_reader.need_init() { + self.init_body_reader(); + } else { + // already initialized + // immediately start reading the rest of the body as upgraded + // (in practice most upgraded requests shouldn't have any body) + // + // TODO: https://datatracker.ietf.org/doc/html/rfc9110#name-upgrade + // the most spec-compliant behavior is to switch interpretation + // after sending the former body, + // we immediately switch interpretation to match nginx + self.body_reader.convert_to_close_delimited(); + } + } else { + // this was a request that requested Upgrade, + // but upstream did not comply + debug!("bad upgrade handshake!"); + // continue to read body as-is, this is now just a regular request + } + } + self.init_body_writer(header); + } + + // Defense-in-depth: if response body is close-delimited, mark session + // as un-reusable + if self.body_writer.is_close_delimited() { + self.set_keepalive(None); + } + + // Serialize header to bytes + let mut write_buf = BytesMut::with_capacity(INIT_HEADER_BUF_SIZE); + http_resp_header_to_buf(header, &mut write_buf) + .map_err(|_| Error::explain(WriteError, "serializing response header"))?; + + // Determine if we should flush + // Don't have to flush response with content length because it is less + // likely to be real time communication. So do flush when + // 1. 1xx response: client needs to see it before the rest of response + // 2. 
No content length: the response could be generated in real time + let should_flush = header.status.is_informational() + || header.headers.get(header::CONTENT_LENGTH).is_none(); + + Ok(Some((write_buf.freeze(), should_flush))) + } + fn init_body_writer(&mut self, header: &ResponseHeader) { use http::StatusCode; /* the following responses don't have body 204, 304, and HEAD */ @@ -814,6 +875,16 @@ impl HttpSession { } } + /// Whether the cancel-safe proxy task API is enabled for this session. + pub fn proxy_tasks_enabled(&self) -> bool { + self.proxy_tasks_enabled + } + + /// Enable or disable the cancel-safe proxy task API for this session. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + self.proxy_tasks_enabled = enabled; + } + async fn do_write_body_buf(&mut self) -> Result> { // Don't flush empty chunks, they are considered end of body for chunks if self.body_write_buf.is_empty() { @@ -1320,6 +1391,152 @@ impl HttpSession { Ok(end_stream || self.body_writer.finished()) } + /// Queue a proxy task for cancel-safe writing with the current write_timeout. + /// The task will be written when `write_proxy_tasks()` is called. + /// + /// A write canceled mid-operation can be resumed via `write_proxy_tasks()`. + pub fn send_proxy_task(&mut self, task: HttpTask) { + self.proxy_task_state.tasks.push_back(task); + } + + /// Check if there are pending proxy tasks queued for writing. + pub fn has_pending_proxy_tasks(&self) -> bool { + self.proxy_task_state.current_writer.is_some() || !self.proxy_task_state.tasks.is_empty() + } + + /// Write all queued proxy tasks (response `HttpTask`s from `send_proxy_task`) + /// in a cancel-safe manner. + /// + /// If cancelled mid-write, the next call will resume the in-progress write. + /// + /// Returns `Ok(true)` if this was the end of the response stream. + // Leverages the cancel-safe `HeaderWriter` and `BodyWriter` primitives. + // TODO: we can do the same for the non-cancel-safe APIs. 
+ pub async fn write_proxy_tasks(&mut self) -> Result { + let mut end_stream = false; + + // TODO: buffer body data like response_duplex_vec + loop { + // - Resume any in-progress write + if let Some(ref writer_state) = self.proxy_task_state.current_writer { + match writer_state { + ProxyTaskWriter::WritingHeader(_, _) => { + let _bytes_written = self + .proxy_task_state + .header_writer() + .write_current_header_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + } + ProxyTaskWriter::WritingBody(_) => { + let written = self + .body_writer + .write_current_body_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + if let Some(n) = written { + self.body_bytes_sent += n; + } + } + ProxyTaskWriter::FinishingBody => { + self.body_writer + .write_current_finish_task(&mut self.underlying_stream) + .await + .map_err(|e| e.into_down())?; + } + } + + match self + .proxy_task_state + .current_writer + .take() + .expect("writer state present") + { + ProxyTaskWriter::WritingHeader(header, end) => { + self.response_written = Some(header); + end_stream = end; + } + ProxyTaskWriter::WritingBody(end) => { + end_stream = end; + } + ProxyTaskWriter::FinishingBody => { + end_stream = true; + self.maybe_force_close_body_reader(); + break; // fine to break after finish, no tasks should be queued after + } + } + continue; + } + + // - Send tasks, set state. + // Pop next task + let Some(task) = self.proxy_task_state.tasks.pop_front() else { + if end_stream { + self.body_writer.send_finish_task(); + self.proxy_task_state.current_writer = Some(ProxyTaskWriter::FinishingBody); + continue; + } + break; + }; + + match task { + HttpTask::Header(mut header, end) => { + let Some((write_buf, should_flush)) = + self.prepare_response_header(&mut header)? 
+ else { + end_stream = end; + continue; + }; + // header only responses will want to flush + let flush = should_flush || self.body_writer.finished(); + self.proxy_task_state + .header_writer() + .send_header_task(write_buf, flush, None); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingHeader(header, end)); + } + HttpTask::Body(ref data, end) => { + if self.upgraded { + panic!("Unexpected Body task received on upgraded downstream session"); + } + if let Some(d) = data.as_ref() { + if !d.is_empty() { + let body_timeout = self.write_timeout(d.len()); + self.body_writer.send_body_task(d.clone(), body_timeout); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingBody(end)); + continue; + } + } + end_stream = end; + } + HttpTask::UpgradedBody(ref data, end) => { + if !self.upgraded { + panic!("Unexpected UpgradedBody task received on un-upgraded downstream session"); + } + if let Some(d) = data.as_ref() { + if !d.is_empty() { + let body_timeout = self.write_timeout(d.len()); + self.body_writer.send_body_task(d.clone(), body_timeout); + self.proxy_task_state.current_writer = + Some(ProxyTaskWriter::WritingBody(end)); + continue; + } + } + end_stream = end; + } + HttpTask::Trailer(_) | HttpTask::Done => { + end_stream = true; + } + HttpTask::Failed(e) => { + return Err(e); + } + } + } + + Ok(end_stream || self.body_writer.finished()) + } + /// Get the reference of the [Stream] that this HTTP session is operating upon. 
pub fn stream(&self) -> &Stream { &self.underlying_stream @@ -2381,6 +2598,30 @@ mod tests_stream { assert_eq!(wire_body.len(), n); } + #[tokio::test] + async fn body_bytes_sent_excludes_response_header() { + let read_wire = b"GET / HTTP/1.1\r\n\r\n"; + let wire_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let wire_body = b"hello"; + let mock_io = Builder::new() + .read(read_wire) + .write(wire_header) + .write(wire_body) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap(); + new_response.append_header("Content-Length", "5").unwrap(); + http_stream.update_resp_headers = false; + http_stream + .write_response_header(Box::new(new_response)) + .await + .unwrap(); + assert_eq!(http_stream.body_bytes_sent(), 0); + http_stream.write_body(wire_body).await.unwrap(); + assert_eq!(http_stream.body_bytes_sent(), wire_body.len()); + } + #[tokio::test] async fn write_body_http10() { let read_wire = b"GET / HTTP/1.1\r\n\r\n"; @@ -2852,25 +3093,30 @@ mod test_sync { } #[cfg(test)] -mod test_timeouts { +mod test_proxy_tasks { use super::*; + use http::StatusCode; use std::future::IntoFuture; use tokio_test::io::{Builder, Mock}; - /// An upper limit for any read within any test to prevent tests from hanging forever if - /// an internal read call never returns, etc. + fn init_log() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + // An upper limit for any read within any test to prevent tests from hanging forever if + // an internal read call never returns, etc. 
const TEST_MAX_WAIT_FOR_READ: Duration = Duration::from_secs(3); - /// The duration of 600 seconds is chosen to be "effectively forever" for the purpose of testing + // The duration of 600 seconds is chosen to be "effectively forever" for the purpose of testing const TEST_FOREVER_DURATION: Duration = Duration::from_secs(600); - /// The read_timeout to use, when we want to test that a read operation times out + // The read_timeout to use, when we want to test that a read operation times out const TEST_READ_TIMEOUT: Duration = Duration::from_secs(1); #[derive(Debug)] struct ReadBlockedForeverError; - /// Returns a client stream that will "never" send any bytes / return from a read operation + // Returns a client stream that will "never" send any bytes / return from a read operation fn mocked_blocking_headers_forever_stream() -> Box { Box::new(Builder::new().wait(TEST_FOREVER_DURATION).build()) } @@ -2887,8 +3133,8 @@ mod test_timeouts { ) } - /// Helper function to test a read operation with a tokio timeout - /// to prevent tests from hanging forever in case of a bug + // Helper function to test a read operation with a tokio timeout + // to prevent tests from hanging forever in case of a bug async fn test_read_with_tokio_timeout( read_future: F, ) -> Result>, ReadBlockedForeverError> @@ -2932,6 +3178,352 @@ mod test_timeouts { assert!(res.is_ok()); assert_eq!(res.unwrap().unwrap_err().etype(), &ReadTimedout); } + + #[tokio::test] + async fn test_send_proxy_task_and_write() { + init_log(); + + // We need to know exact bytes that will be written + // "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello" + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; // Disable automatic headers + + // Queue header task + let mut 
header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue body task + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Write all tasks + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream); + } + + #[tokio::test] + async fn test_proxy_task_with_timeout() { + init_log(); + + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + http_stream.write_timeout = Some(Duration::from_secs(1)); // Set write timeout + + // Queue tasks + let mut header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Verify initial state + assert_eq!( + http_stream.body_bytes_sent(), + 0, + "Should start with 0 bytes sent" + ); + + // Write all tasks with timeout + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream); + + // Verify body bytes were counted correctly (not double counted) + assert_eq!( + http_stream.body_bytes_sent(), + 5, + "Should count exactly 5 bytes (application level), not double counted" + ); + } + + // Test that write_proxy_tasks is cancel-safe: if the future is dropped mid-execution, + // unwritten tasks should remain in the queue. 
+ #[tokio::test] + async fn test_proxy_task_cancel_safety() { + init_log(); + + let expected_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + // First chunk: "5\r\nhello\r\n" + let expected_chunk1 = b"5\r\nhello\r\n"; + + // Create a mock IO that will write the header and first chunk, + // but will block indefinitely on the second chunk + let mock_io = Builder::new() + .write(expected_header) + .write(expected_chunk1) + .wait(Duration::from_secs(999)) // This will cause timeout + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + http_stream.write_timeout = Some(Duration::from_millis(100)); + + // Queue 3 tasks: header + 2 body chunks + let mut header = ResponseHeader::build(StatusCode::OK, None).unwrap(); + header + .insert_header("Transfer-Encoding", "chunked") + .unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("world")), true)); + + // Verify we have 3 tasks queued + assert_eq!(http_stream.proxy_task_state.tasks.len(), 3); + + // Try to write all tasks - this should timeout while writing the second body chunk + let result = http_stream.write_proxy_tasks().await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().etype(), &WriteTimedout); + + // With the refactored cancel-safe design: + // - First task (header) was written successfully and removed from queue + // - Second task (first body "hello") was removed and sent to BodyWriter, write succeeded, state cleared + // - Third task (second body "world") was removed and sent to BodyWriter, timed out mid-write + // - The in-progress write state is tracked in current_writer, NOT in the queue + assert_eq!( + http_stream.proxy_task_state.tasks.len(), + 0, + "Queue should be empty - tasks are owned by writers once sent" + ); + + // The task being 
written should be tracked in current_writer + assert!( + matches!( + http_stream.proxy_task_state.current_writer, + Some(ProxyTaskWriter::WritingBody(_)) + ), + "Should be mid-write of body task - writer owns the 'world' task state" + ); + + // Verify body_bytes_sent only counts the successfully written "hello" (5 bytes) + // not the timed-out "world" + assert_eq!( + http_stream.body_bytes_sent(), + 5, + "Should only count the 5 bytes from 'hello', not the incomplete 'world' write" + ); + + // On next call to write_proxy_tasks(), Step 1 will resume the "world" write + } + + use crate::protocols::http::v1::test_util::FlushTrackingMock; + + // Test that write_continue_response can be called before write_proxy_tasks + // and both work correctly together. + #[tokio::test] + async fn test_continue_response_before_proxy_tasks() { + init_log(); + + // Expected bytes written: + // 1. 100 Continue response + // 2. 200 OK response header + // 3. Body data + let expected_continue = b"HTTP/1.1 100 Continue\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n"; + let expected_body = b"hello"; + + let mock_io = Builder::new() + .write(expected_continue) + .write(expected_header) + .write(expected_body) + .build(); + + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; // Disable automatic headers + + // First, write the 100 Continue response + http_stream.write_continue_response().await.unwrap(); + + // Verify that 100 Continue was recorded + assert!( + http_stream.response_written().is_some(), + "100 Continue should be recorded in response_written" + ); + assert_eq!( + http_stream.response_written().unwrap().status, + StatusCode::CONTINUE, + "Should have recorded 100 Continue" + ); + + // Now queue the actual response using proxy tasks + let mut header = ResponseHeader::build(StatusCode::OK, Some(5)).unwrap(); + header.insert_header("Content-Length", "5").unwrap(); + 
http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // Write all proxy tasks + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + assert!(end_stream, "Should indicate end of stream"); + + // Verify final response is 200 OK, not 100 Continue + assert_eq!( + http_stream.response_written().unwrap().status, + StatusCode::OK, + "Final response should be 200 OK, overwriting 100 Continue" + ); + } + + #[tokio::test] + async fn test_head_response_with_content_length_flushes() { + init_log(); + + // HEAD request line + headers + let request = b"HEAD / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n"; + + let mock_io = Builder::new().read(request).write(expected_header).build(); + let (flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + let mut http_stream = HttpSession::new(Box::new(flush_mock)); + http_stream.update_resp_headers = false; + + // Read the HEAD request + http_stream.read_request().await.unwrap(); + assert_eq!(http_stream.get_method(), Some(&Method::HEAD)); + + // Queue header with Content-Length (body will be empty for HEAD) + let mut header = ResponseHeader::build(StatusCode::OK, Some(2)).unwrap(); + header.insert_header("Content-Length", "100").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), true)); + + let flush_before = FlushTrackingMock::flush_count(&flush_count); + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + let flush_after = FlushTrackingMock::flush_count(&flush_count); + + assert!(end_stream, "HEAD response should be end of stream"); + assert!( + flush_after > flush_before, + "Should flush after writing HEAD response header with Content-Length \ + (body_writer.finished() is true). 
Got flush_before={flush_before}, \ + flush_after={flush_after}" + ); + } + + #[tokio::test] + async fn test_204_response_with_content_length_flushes() { + init_log(); + + let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n"; + + let mock_io = Builder::new().read(request).write(expected_header).build(); + let (flush_mock, flush_count) = FlushTrackingMock::new(mock_io); + let mut http_stream = HttpSession::new(Box::new(flush_mock)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + + let mut header = ResponseHeader::build(StatusCode::NO_CONTENT, Some(2)).unwrap(); + header.insert_header("Content-Length", "0").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), true)); + + let flush_before = FlushTrackingMock::flush_count(&flush_count); + let end_stream = http_stream.write_proxy_tasks().await.unwrap(); + let flush_after = FlushTrackingMock::flush_count(&flush_count); + + assert!(end_stream, "204 response should be end of stream"); + assert!( + flush_after > flush_before, + "Should flush after writing 204 response header with Content-Length \ + (body_writer.finished() is true). 
Got flush_before={flush_before}, \ + flush_after={flush_after}" + ); + } + + #[tokio::test] + #[should_panic( + expected = "Unexpected UpgradedBody task received on un-upgraded downstream session" + )] + async fn test_upgraded_body_on_non_upgraded_session_panics() { + init_log(); + + let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + let expected_header = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"; + // UpgradedBody on a non-upgraded session should panic before writing, + // but if the bug exists, BodyWriter would encode it as a chunk: + let expected_chunk = b"5\r\nhello\r\n"; + let expected_finish = b"0\r\n\r\n"; + + let mock_io = Builder::new() + .read(request) + .write(expected_header) + // If the panic check is missing, the body gets written as a chunk + .write(expected_chunk) + .write(expected_finish) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + assert!( + !http_stream.was_upgraded(), + "Session should NOT be upgraded" + ); + + // Queue a normal header + let mut header = ResponseHeader::build(StatusCode::OK, Some(2)).unwrap(); + header + .insert_header("Transfer-Encoding", "chunked") + .unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue an UpgradedBody task on a non-upgraded session — should panic + http_stream.send_proxy_task(HttpTask::UpgradedBody(Some(Bytes::from("hello")), true)); + + // This should panic before/during the body write + let _ = http_stream.write_proxy_tasks().await; + } + + #[tokio::test] + #[should_panic(expected = "Unexpected Body task received on upgraded downstream session")] + async fn test_body_on_upgraded_session_panics() { + init_log(); + + // Upgrade request + let request = + b"GET / HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + // 101 Switching Protocols response + let expected_header = + b"HTTP/1.1 
101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + // If the panic check is missing, Body data would be written raw (close-delimited) + let expected_body = b"hello"; + + let mock_io = Builder::new() + .read(request) + .write(expected_header) + .write(expected_body) + .build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.update_resp_headers = false; + + http_stream.read_request().await.unwrap(); + + // Queue 101 header to complete the upgrade + let mut header = ResponseHeader::build(StatusCode::SWITCHING_PROTOCOLS, Some(3)).unwrap(); + header.insert_header("Upgrade", "websocket").unwrap(); + header.insert_header("Connection", "Upgrade").unwrap(); + http_stream.send_proxy_task(HttpTask::Header(Box::new(header), false)); + + // Queue a regular Body task on what will be an upgraded session — should panic + http_stream.send_proxy_task(HttpTask::Body(Some(Bytes::from("hello")), true)); + + // This should panic (after writing the header, session becomes upgraded, + // then the Body task should be rejected) + let _ = http_stream.write_proxy_tasks().await; + } } #[cfg(test)] diff --git a/pingora-core/src/protocols/http/v2/mod.rs b/pingora-core/src/protocols/http/v2/mod.rs index 01711807..615fcee5 100644 --- a/pingora-core/src/protocols/http/v2/mod.rs +++ b/pingora-core/src/protocols/http/v2/mod.rs @@ -111,7 +111,10 @@ mod test { // Client handles.push(tokio::spawn(async move { - let conn = crate::connectors::http::v2::handshake(Box::new(client), 500, None) + use crate::connectors::http::v2::H2HandshakeSettings; + let mut settings = H2HandshakeSettings::new(); + settings.max_streams = 500; + let conn = crate::connectors::http::v2::handshake(Box::new(client), settings) .await .unwrap(); diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs index 4aa70f70..ddbaceb1 100644 --- a/pingora-core/src/protocols/l4/stream.rs +++ b/pingora-core/src/protocols/l4/stream.rs @@ -814,14 
+814,67 @@ pub mod async_write_vec { fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { let me = &mut *self; - while me.buf.has_remaining() { - let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?; - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } + poll_write_vec_all_buf(ctx, Pin::new(&mut *me.writer), me.buf) + } + } + + /// Primitive poll function to write ALL bytes from a buffer using vectored writes. + /// Keeps polling `poll_write_vec` until the entire buffer is written. + /// The buffer is advanced as bytes are written. + /// + /// Returns Poll::Ready(Ok(())) when all bytes are written. + /// Returns WriteZero error if poll_write_vec returns 0. + /// + /// This is essentially a polling form of tokio's + /// [`write_all_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWriteExt.html#method.write_all_buf). + // TODO: we should be able to switch over to polling the future from tokio AsyncWriteExt directly, + // for now we continue to use the old trait. + pub fn poll_write_vec_all_buf( + ctx: &mut Context<'_>, + mut writer: Pin<&mut W>, + buf: &mut B, + ) -> Poll> + where + W: AsyncWriteVec + ?Sized, + B: Buf, + { + while buf.has_remaining() { + let n = ready!(writer.as_mut().poll_write_vec(ctx, buf))?; + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } - Poll::Ready(Ok(())) } + Poll::Ready(Ok(())) + } + + /// Primitive poll function to write ALL bytes from a buffer using regular writes. + /// Keeps polling `poll_write` until the entire buffer is written. + /// The buffer is advanced as bytes are written. + /// + /// Returns Poll::Ready(Ok(())) when all bytes are written. + /// Returns WriteZero error if poll_write returns 0. 
+ /// + /// This is essentially a polling form of tokio's + /// [`write_all_buf`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWriteExt.html#method.write_all_buf) + /// though we explicitly use non-vectored writes in this case for strict parity with the + /// original `write_all` method. + pub fn poll_write_all_buf( + ctx: &mut Context<'_>, + mut writer: Pin<&mut W>, + buf: &mut B, + ) -> Poll> + where + W: AsyncWrite + ?Sized, + B: Buf, + { + while buf.has_remaining() { + let n = ready!(writer.as_mut().poll_write(ctx, buf.chunk()))?; + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + buf.advance(n); + } + Poll::Ready(Ok(())) } /* from https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/lib.rs#L177 */ diff --git a/pingora-core/src/server/bootstrap_services.rs b/pingora-core/src/server/bootstrap_services.rs index ca5bfad0..74c81d79 100644 --- a/pingora-core/src/server/bootstrap_services.rs +++ b/pingora-core/src/server/bootstrap_services.rs @@ -23,8 +23,12 @@ use tokio::sync::broadcast; #[cfg(feature = "sentry")] use sentry::ClientOptions; +#[cfg(unix)] +use crate::server::daemon::notify_parent_ready_for_fds; #[cfg(unix)] use crate::server::ListenFds; +#[cfg(unix)] +use std::time::Duration; use crate::{ prelude::Opt, @@ -32,6 +36,10 @@ use crate::{ services::{background::BackgroundService, ServiceReadyNotifier}, }; +/// Default timeout for retrying `SIGUSR1` to the parent when it fails with `EPERM`. +#[cfg(unix)] +const DEFAULT_DAEMON_NOTIFY_TIMEOUT: Duration = Duration::from_secs(60); + /// Service that allows the bootstrap process to be delayed until after /// dependencies are ready pub struct BootstrapService { @@ -60,6 +68,16 @@ pub struct Bootstrap { #[cfg(unix)] listen_fds: ListenFds, + /// PID of the original parent process to notify via `SIGUSR1` after bootstrap completes. + /// Set when [`ServerConf::daemon_wait_for_ready`] is `true`. 
+ #[cfg(unix)] + notify_parent_pid: Option, + + /// How long to keep retrying `SIGUSR1` to the parent when it fails with `EPERM`. + /// See [`ServerConf::daemon_notify_timeout_seconds`]. + #[cfg(unix)] + daemon_notify_timeout: std::time::Duration, + #[cfg(feature = "sentry")] #[cfg_attr(docsrs, doc(cfg(feature = "sentry")))] /// The Sentry ClientOptions. @@ -96,6 +114,13 @@ impl Bootstrap { upgrade_sock, #[cfg(unix)] listen_fds: Arc::new(Mutex::new(Fds::new())), + #[cfg(unix)] + notify_parent_pid: None, + #[cfg(unix)] + daemon_notify_timeout: conf + .daemon_notify_timeout_seconds + .map(|n| Duration::from_secs(n.get())) + .unwrap_or(DEFAULT_DAEMON_NOTIFY_TIMEOUT), execution_phase_watch: execution_phase_watch.clone(), completed: false, #[cfg(feature = "sentry")] @@ -110,6 +135,13 @@ impl Bootstrap { self.sentry = sentry_config; } + /// Store the parent process PID to notify via `SIGUSR1` after bootstrap completes. + /// Only relevant when [`ServerConf::daemon_wait_for_ready`] is `true`. + #[cfg(unix)] + pub fn set_notify_parent_pid(&mut self, pid: u32) { + self.notify_parent_pid = Some(pid); + } + /// Initialize the Sentry client from the configured [`ClientOptions`] and /// store the resulting guard. /// @@ -161,7 +193,17 @@ impl Bootstrap { std::process::exit(0); } - // load fds + // Notify the parent process that it can exit. It might seem like we should load the file + // descriptors from the old process first, but the purpose of this notification is to + // release the parent so that the process managing it (e.g. systemd) can continue and send + // a quit signal to the old process. That quit signal is required before the old process + // will start trying to send its file descriptors to us — so if we called load_fds first, + // we would be guaranteeing a timeout. 
+ #[cfg(unix)] + if let Some(pid) = self.notify_parent_pid { + notify_parent_ready_for_fds(pid, self.daemon_notify_timeout); + } + #[cfg(unix)] match self.load_fds(self.upgrade) { Ok(_) => { diff --git a/pingora-core/src/server/configuration/mod.rs b/pingora-core/src/server/configuration/mod.rs index 8ab02bf3..1f410892 100644 --- a/pingora-core/src/server/configuration/mod.rs +++ b/pingora-core/src/server/configuration/mod.rs @@ -25,6 +25,7 @@ use pingora_error::{Error, ErrorType::*, OrErr, Result}; use serde::{Deserialize, Serialize}; use std::ffi::OsString; use std::fs; +use std::num::NonZeroU64; // default maximum upstream retries for retry-able proxy errors const DEFAULT_MAX_RETRIES: usize = 16; @@ -125,6 +126,45 @@ pub struct ServerConf { /// /// When not set, the tokio default (10 seconds) is used. pub blocking_threads_ttl_seconds: Option, + /// When `daemon` is `true`, controls whether the parent process of the daemon fork waits for + /// the child to signal readiness before exiting. + /// + /// When `false` (default), the parent exits immediately after the daemon fork, matching the + /// traditional daemonization behavior. Systemd will consider the service started as soon as + /// the parent exits, which may be before the child has finished bootstrapping. + /// + /// When `true`, the parent waits (up to [`Self::daemon_ready_timeout_seconds`]) for the child + /// to send `SIGUSR1` after bootstrap completes. This causes systemd to delay any subsequent + /// steps (such as sending `SIGQUIT` to the old process) until the new instance is fully ready + /// to serve traffic. If the child does not signal in time, the parent exits with a non-zero + /// exit code, causing systemd to abort the reload. + pub daemon_wait_for_ready: bool, + /// Timeout in seconds for the parent process to wait for the child to signal readiness during + /// daemonization when [`Self::daemon_wait_for_ready`] is `true`. 
+ /// + /// If the child does not send `SIGUSR1` within this timeout, the parent exits with a non-zero + /// exit code. + /// + /// Defaults to 600 seconds (10 minutes). + pub daemon_ready_timeout_seconds: Option, + /// How long the child process will keep retrying `SIGUSR1` to the parent when the signal + /// fails with a permission error (`EPERM`) during daemonization. + /// + /// After the daemon fork, the parent always drops its credentials to the configured user and + /// group (see [`Self::user`], [`Self::group`]). Because the privilege drop happens after the + /// fork, there is a small window where the child may attempt to signal the parent before the + /// parent has finished changing its credentials. During this window the kernel will reject the + /// signal with `EPERM` because the child and parent are running as different users. The child + /// retries every 100 ms until this timeout elapses. + /// + /// In practice this window is very small, so the default of 60 seconds is far more than + /// enough to account for it. + /// + /// Only retries on `EPERM`; any other error (e.g. `ESRCH` — parent no longer exists) is + /// treated as fatal and logged without retrying. + /// + /// Defaults to 60 seconds. 
+ pub daemon_notify_timeout_seconds: Option, } impl Default for ServerConf { @@ -155,6 +195,9 @@ impl Default for ServerConf { upgrade_sock_connect_accept_max_retries: None, max_blocking_threads: None, blocking_threads_ttl_seconds: None, + daemon_ready_timeout_seconds: None, + daemon_wait_for_ready: false, + daemon_notify_timeout_seconds: None, } } } @@ -326,6 +369,9 @@ mod tests { upgrade_sock_connect_accept_max_retries: None, max_blocking_threads: None, blocking_threads_ttl_seconds: None, + daemon_ready_timeout_seconds: None, + daemon_wait_for_ready: false, + daemon_notify_timeout_seconds: None, }; // cargo test -- --nocapture not_a_test_i_cannot_write_yaml_by_hand println!("{}", conf.to_yaml()); diff --git a/pingora-core/src/server/daemon.rs b/pingora-core/src/server/daemon.rs index 7381fc93..b6c95cb0 100644 --- a/pingora-core/src/server/daemon.rs +++ b/pingora-core/src/server/daemon.rs @@ -12,18 +12,71 @@ // See the License for the specific language governing permissions and // limitations under the License. -use daemonize::{Daemonize, Stdio}; -use log::{debug, error}; +use daemonize::{Daemonize, Outcome, Stdio}; +use log::{debug, error, info}; +use pingora_error::{Error, ErrorType, OrErr, Result}; use std::ffi::CString; use std::fs::{self, OpenOptions}; use std::os::unix::prelude::OpenOptionsExt; use std::path::Path; +use std::process; +use std::thread; +use std::time::{Duration, Instant}; use crate::server::configuration::ServerConf; +/// Error returned by [`send_signal`]. +#[derive(Debug)] +pub(crate) enum SignalError { + /// The caller does not have permission to send the signal to the target process (`EPERM`). + PermissionDenied, + /// Any other error from `kill(2)`. Contains the raw `errno` value. 
+ OtherSignalError(i32), +} + +impl std::fmt::Display for SignalError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SignalError::PermissionDenied => write!(f, "permission denied (EPERM)"), + SignalError::OtherSignalError(errno) => { + write!(f, "kill failed with errno {errno}") + } + } + } +} + +/// Send `signal` to the process identified by `pid`. +/// +/// Returns `Ok(())` on success. On failure, maps `errno` to [`SignalError`]: +/// - `EPERM` → [`SignalError::PermissionDenied`] +/// - anything else → [`SignalError::OtherSignalError`] containing the raw errno value. +fn send_signal(pid: libc::pid_t, signal: libc::c_int) -> Result<(), SignalError> { + // SAFETY: `kill(2)` is safe to call with any pid/signal combination — invalid values + // simply return an error via errno rather than causing undefined behavior. + let ret = unsafe { libc::kill(pid, signal) }; + if ret == 0 { + return Ok(()); + } + let errno = std::io::Error::last_os_error().raw_os_error().unwrap_or(-1); + if errno == libc::EPERM { + Err(SignalError::PermissionDenied) + } else { + Err(SignalError::OtherSignalError(errno)) + } +} + // Utilities to daemonize a pingora server, i.e. run the process in the background, possibly // under a different running user and/or group. +/// Default timeout for the parent to wait for the daemon child to signal readiness. +const DEFAULT_DAEMON_READY_TIMEOUT: Duration = Duration::from_secs(600); + +/// How long to sleep between `SIGUSR1` send attempts when `EPERM` is returned. +const NOTIFY_RETRY_INTERVAL: Duration = Duration::from_millis(100); + +/// How long to sleep between pid-file liveness checks in the async wait loop. +const LIVENESS_CHECK_INTERVAL: Duration = Duration::from_millis(100); + // XXX: this operation should have been done when the old service is exiting. 
// Now the new pid file just kick the old one out of the way fn move_old_pid(path: &str) { @@ -45,7 +98,15 @@ fn move_old_pid(path: &str) { } } +/// # Safety +/// +/// `name` must be a valid, null-terminated C string. The returned `gid_t` is read from the +/// `passwd` struct returned by `getpwnam(3)`, which points to a static buffer that may be +/// overwritten by subsequent calls to `getpwnam` or `getpwuid`. The caller must not hold the +/// pointer across such calls. unsafe fn gid_for_username(name: &CString) -> Option { + // SAFETY: `name` is a valid CString; `getpwnam` returns a pointer to a static buffer + // or null. We read `pw_gid` immediately and do not retain the pointer. let passwd = libc::getpwnam(name.as_ptr() as *const libc::c_char); if !passwd.is_null() { return Some((*passwd).pw_gid); @@ -53,9 +114,277 @@ unsafe fn gid_for_username(name: &CString) -> Option { None } +/// Drop the parent process's UID to the user specified in [`ServerConf::user`]. +/// +/// The kernel only permits a process to send a signal to another if they share the same UID (or +/// the sender is root). Since the daemon child sends `SIGUSR1` to the parent to signal readiness, +/// the parent must be running as the same UID as the child by the time that signal arrives — +/// otherwise the kernel will reject it with `EPERM`. +/// +/// This function is called in the `Outcome::Parent` path immediately after `execute()` returns, +/// before the parent enters its readiness wait loop, so the parent's UID matches the child's as +/// quickly as possible after the fork. +/// +/// Only the UID is changed; the GID is left as-is. Signal permission checks are based on UID, +/// so changing the GID is not necessary for this purpose. +/// +/// Logs an error and continues if the user cannot be resolved or `setuid` fails — the parent +/// is short-lived and about to exit, so a failed privilege drop is non-fatal. 
The child's +/// `EPERM` retry window (see [`ServerConf::daemon_notify_timeout_seconds`]) exists precisely to +/// cover the small gap between the fork and the parent completing this UID change. +fn drop_privileges_in_parent(conf: &ServerConf) -> Result<()> { + let Some(user) = conf.user.as_ref() else { + return Ok(()); + }; + + let user_cstr = CString::new(user.as_str()).or_err_with(ErrorType::Custom("Daemon"), || { + format!("drop_privileges_in_parent: user '{user}' invalid") + })?; + + // SAFETY: `user_cstr` is a valid CString. `getpwnam` returns a pointer to a static + // buffer or null. We read `pw_uid` immediately and do not retain the pointer. + let passwd = unsafe { libc::getpwnam(user_cstr.as_ptr() as *const libc::c_char) }; + if passwd.is_null() { + return Error::e_explain( + ErrorType::Custom("Daemon"), + format!("drop_privileges_in_parent: user '{user}' not found"), + ); + } + + // SAFETY: `passwd` was checked for null above. We dereference it once to read `pw_uid`. + let uid = unsafe { (*passwd).pw_uid }; + // SAFETY: `setuid(2)` is safe to call with any uid — invalid values return an error. + let ret = unsafe { libc::setuid(uid) }; + if ret == 0 { + Ok(()) + } else { + Error::e_explain( + ErrorType::Custom("Daemon"), + format!( + "drop_privileges_in_parent: setuid({uid}) failed: {}", + std::io::Error::last_os_error() + ), + ) + } +} + +/// Outcome of calling [`daemonize`]. +/// +/// When [`ServerConf::daemon_wait_for_ready`] is `true`, the child process must call +/// [`notify_parent_ready_for_fds`] after bootstrap completes to unblock the parent's wait loop. +pub struct DaemonizeResult { + /// The PID of the original parent process to notify via `SIGUSR1` after bootstrap completes. + /// + /// `Some` when [`ServerConf::daemon_wait_for_ready`] is `true`, `None` otherwise. + pub notify_parent_pid: Option, +} + /// Start a server instance as a daemon. 
-#[cfg(unix)] -pub fn daemonize(conf: &ServerConf) { +/// +/// Both code paths use [`daemonize::Daemonize::execute()`] rather than calling `fork()` directly. +/// `execute()` returns an [`Outcome`] to the caller in each process rather than having the parent +/// exit inside the crate, which gives us the opportunity to run additional logic in the parent +/// before it exits. +/// +/// When [`ServerConf::daemon_wait_for_ready`] is `false` (the default), the parent exits +/// immediately — matching the behavior of `start()`. +/// +/// When `daemon_wait_for_ready` is `true`, the parent registers a `SIGUSR1` handler before +/// forking, then waits (in a sleep loop polling the pid file and the signal flag) for up to +/// [`ServerConf::daemon_ready_timeout_seconds`] (default 600 s) for the grandchild to send +/// `SIGUSR1`. On success the parent exits with code 0. On timeout, or if the daemon process +/// exits before signaling, the parent exits with code 1, causing systemd to abort the reload. +/// +/// Returns a [`DaemonizeResult`] that is only meaningful to the child process. The parent always +/// exits before returning. +pub fn daemonize(conf: &ServerConf) -> DaemonizeResult { + // Capture the parent PID before forking so it can be passed to the grandchild. The + // grandchild sends SIGUSR1 to this PID after bootstrap completes. + let parent_pid = if conf.daemon_wait_for_ready { + Some(process::id()) + } else { + None + }; + + move_old_pid(&conf.pid_file); + + match build_daemonize(conf).execute() { + Outcome::Parent(result) => { + result.unwrap_or_else(|e| panic!("Daemonize failed: {e}")); + } + Outcome::Child(result) => { + result.unwrap_or_else(|e| panic!("Daemonize child setup failed: {e}")); + return DaemonizeResult { + notify_parent_pid: parent_pid, + }; + } + } + + if conf.daemon_wait_for_ready { + // Drop root privileges before waiting so the parent does not linger as root. 
+ if let Err(e) = drop_privileges_in_parent(conf) { + error!("drop_privileges_in_parent failed: {e}"); + + // Exiting the parent process should be fine because if downgrading + // the user's privileges fails here, it will fail in the child and + // the child will exit too + process::exit(1); + } + + let timeout = conf + .daemon_ready_timeout_seconds + .map(|n| Duration::from_secs(n.get())) + .unwrap_or(DEFAULT_DAEMON_READY_TIMEOUT); + + info!( + "Waiting up to {:?} for daemon to signal readiness via SIGUSR1", + timeout + ); + + wait_for_ready_or_exit(&conf.pid_file, timeout); + } + + process::exit(0); +} + +/// Build a single-threaded tokio runtime for the parent's signal wait loop. +/// +/// The parent process is short-lived and only needs to wait for a signal and check the pid file. +/// A current-thread runtime avoids spawning worker threads in a process that is about to exit. +fn build_parent_runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to build tokio runtime for parent signal wait") +} + +/// Wait for the daemon grandchild to send `SIGUSR1`, up to `timeout`. +/// +/// Uses a local tokio runtime with [`tokio::signal::unix`] to listen for `SIGUSR1` instead of +/// raw signal handlers and polling loops. The daemon's PID is checked periodically via the pid +/// file — if the process exits before signaling, the parent aborts. +/// +/// Exits the process directly: +/// - exit code 0 if `SIGUSR1` is received (daemon is ready). +/// - exit code 1 if `timeout` elapses (daemon took too long). +/// - exit code 1 if the pid file exists and the process is no longer running. 
+fn wait_for_ready_or_exit(pid_file: &str, timeout: Duration) { + let rt = build_parent_runtime(); + let pid_file = pid_file.to_owned(); + + rt.block_on(async move { + use tokio::signal::unix::{signal, SignalKind}; + use tokio::time::{interval, timeout as tokio_timeout}; + + let mut sigusr1 = + signal(SignalKind::user_defined1()).expect("failed to register SIGUSR1 listener"); + + let mut liveness_check = interval(LIVENESS_CHECK_INTERVAL); + let mut daemon_pid: Option = None; + + let result = tokio_timeout(timeout, async { + loop { + tokio::select! { + _ = sigusr1.recv() => { + info!("Daemon signaled readiness, parent exiting"); + return; + } + _ = liveness_check.tick() => { + if daemon_pid.is_none() { + daemon_pid = try_read_pid_file(&pid_file); + } + if let Some(pid) = daemon_pid { + if !process_is_running(pid) { + error!( + "Daemon process (pid {pid}) is no longer running \ + before signaling readiness, aborting" + ); + process::exit(1); + } + } + } + } + } + }) + .await; + + if result.is_err() { + error!("Daemon did not signal readiness within {timeout:?}, aborting"); + process::exit(1); + } + }); +} + +/// Notify the parent process that the daemon is ready to serve traffic by sending `SIGUSR1`. +/// +/// Should be called by the daemon process after bootstrap is complete when +/// [`ServerConf::daemon_wait_for_ready`] is `true`. `parent_pid` is the PID of the original +/// process captured before the fork and stored in [`DaemonizeResult::notify_parent_pid`]. +/// +/// `SIGUSR1` sets an atomic flag that the parent's wait loop checks, causing it to exit with +/// code 0 and allowing systemd to proceed with the next step of the service reload. 
+/// +/// If `kill(2)` returns `EPERM` — which can happen transiently when the child's UID has just +/// been changed by `setuid` and the kernel hasn't yet updated the credential check — the +/// function sleeps for [`NOTIFY_RETRY_INTERVAL`] (100 ms) and retries until `notify_timeout` +/// elapses, at which point it logs an error and returns. Any other error (e.g. `ESRCH`, +/// meaning the parent no longer exists) is logged and the function returns immediately without +/// retrying. +pub fn notify_parent_ready_for_fds(parent_pid: u32, notify_timeout: Duration) { + let parent_pid = parent_pid as libc::pid_t; + info!( + "Sending SIGUSR1 to parent process (pid {}) to signal daemon readiness", + parent_pid + ); + + let start = Instant::now(); + + while start.elapsed() < notify_timeout { + match send_signal(parent_pid, libc::SIGUSR1) { + Ok(()) => return, + Err(SignalError::PermissionDenied) => { + debug!( + "Permission denied sending SIGUSR1 to parent (pid {}), retrying in {:?}", + parent_pid, NOTIFY_RETRY_INTERVAL + ); + thread::sleep(NOTIFY_RETRY_INTERVAL); + } + Err(SignalError::OtherSignalError(errno)) => { + error!( + "Failed to send SIGUSR1 to parent (pid {}): errno {errno}", + parent_pid + ); + return; + } + } + } + + error!( + "Permission denied sending SIGUSR1 to parent (pid {}), giving up after {:?}", + parent_pid, notify_timeout + ); +} + +/// Try to read a PID from `pid_file`. Returns `None` if the file does not exist or cannot be +/// parsed. +fn try_read_pid_file(pid_file: &str) -> Option { + fs::read_to_string(pid_file) + .ok() + .and_then(|c| c.trim().parse().ok()) +} + +/// Returns `true` if a process with `pid` is currently running. +fn process_is_running(pid: libc::pid_t) -> bool { + // Signal 0 does not send a signal; it just checks whether the process exists and whether + // we have permission to signal it. 
EPERM (no permission) is not possible here because + // drop_privileges_in_parent guarantees the parent has already dropped to the same user as + // the daemon child before this function is called. + send_signal(pid, 0).is_ok() +} + +/// Build a [`Daemonize`] instance configured from `conf`, without calling `start()` or +/// `execute()`. The caller is responsible for driving execution. +fn build_daemonize(conf: &ServerConf) -> Daemonize<()> { // TODO: customize working dir let daemonize = Daemonize::new() @@ -82,6 +411,7 @@ pub fn daemonize(conf: &ServerConf) { Some(user) => { let user_cstr = CString::new(user.as_str()).unwrap(); + // SAFETY: `user_cstr` is a valid CString. See `gid_for_username` safety docs. #[cfg(target_os = "macos")] let group_id = unsafe { gid_for_username(&user_cstr).map(|gid| gid as i32) }; #[cfg(target_os = "freebsd")] @@ -92,7 +422,8 @@ pub fn daemonize(conf: &ServerConf) { daemonize .privileged_action(move || { if let Some(gid) = group_id { - // Set the supplemental group privileges for the child process. + // SAFETY: `user_cstr` is a valid CString captured by the closure. + // `initgroups(3)` is safe to call with a valid username and gid. 
unsafe { libc::initgroups(user_cstr.as_ptr() as *const libc::c_char, gid); } @@ -104,12 +435,8 @@ pub fn daemonize(conf: &ServerConf) { None => daemonize, }; - let daemonize = match conf.group.as_ref() { + match conf.group.as_ref() { Some(group) => daemonize.group(group.as_str()), None => daemonize, - }; - - move_old_pid(&conf.pid_file); - - daemonize.start().unwrap(); // hard crash when fail + } } diff --git a/pingora-core/src/server/mod.rs b/pingora-core/src/server/mod.rs index ffedf665..0d3a105e 100644 --- a/pingora-core/src/server/mod.rs +++ b/pingora-core/src/server/mod.rs @@ -624,8 +624,13 @@ impl Server { if conf.daemon { info!("Daemonizing the server"); fast_timeout::pause_for_fork(); - daemonize(&self.configuration); + let daemonize_result = daemonize(&self.configuration); fast_timeout::unpause(); + // If daemon_wait_for_ready is enabled, pass the parent PID to bootstrap so it + // can send SIGUSR1 to the parent after bootstrap completes. + if let Some(pid) = daemonize_result.notify_parent_pid { + self.bootstrap.lock().set_notify_parent_pid(pid); + } } #[cfg(windows)] diff --git a/pingora-core/src/services/listening.rs b/pingora-core/src/services/listening.rs index b6886c21..7b718b9b 100644 --- a/pingora-core/src/services/listening.rs +++ b/pingora-core/src/services/listening.rs @@ -309,19 +309,3 @@ impl ServiceTrait for Service { self.threads } } - -#[cfg(feature = "prometheus")] -use crate::apps::prometheus_http_app::PrometheusServer; - -#[cfg(feature = "prometheus")] -impl Service { - /// The Prometheus HTTP server - /// - /// The HTTP server endpoint that reports Prometheus metrics collected in the entire service - pub fn prometheus_http_service() -> Self { - Service::new( - "Prometheus metric HTTP".to_string(), - PrometheusServer::new(), - ) - } -} diff --git a/pingora-core/src/upstreams/peer.rs b/pingora-core/src/upstreams/peer.rs index c9ae0a66..78c6dbcc 100644 --- a/pingora-core/src/upstreams/peer.rs +++ b/pingora-core/src/upstreams/peer.rs @@ 
-431,8 +431,14 @@ pub struct PeerOptions { pub s2n_security_policy: Option, #[cfg(feature = "s2n")] pub max_blinding_delay: Option, - // how many concurrent h2 stream are allowed in the same connection + /// How many concurrent h2 streams are allowed in the same connection. pub max_h2_streams: usize, + /// Initial per-stream H2 receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub h2_stream_window_size: Option, + /// Initial connection-level H2 receive window size in bytes. + /// If `None`, the default of 8MB is used. + pub h2_connection_window_size: Option, /// Allow invalid Content-Length in HTTP/1 responses (non-RFC compliant). /// /// When enabled, invalid Content-Length responses are treated as close-delimited responses. @@ -494,6 +500,8 @@ impl PeerOptions { #[cfg(feature = "s2n")] max_blinding_delay: None, max_h2_streams: 1, + h2_stream_window_size: None, + h2_connection_window_size: None, allow_h1_response_invalid_content_length: false, extra_proxy_headers: BTreeMap::new(), curves: None, @@ -685,6 +693,10 @@ impl Hash for HttpPeer { self.group_key.hash(state); // max h2 stream settings self.options.max_h2_streams.hash(state); + // h2_stream_window_size and h2_connection_window_size are intentionally excluded + // from the reuse hash for now. These are per-connection settings applied at handshake + // time and may be revisited alongside other h2 settings that could be dynamically + // adjusted over the lifetime of a connection. } } diff --git a/pingora-core/tests/bootstrap_as_a_service.rs b/pingora-core/tests/bootstrap_as_a_service.rs new file mode 100644 index 00000000..fa88ff20 --- /dev/null +++ b/pingora-core/tests/bootstrap_as_a_service.rs @@ -0,0 +1,136 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Integration tests for `bootstrap_as_a_service`. +//! +//! Verifies that when `bootstrap_as_a_service()` dependencies are declared, the +//! `BootstrapComplete` execution phase is not reached until all dependency services have +//! finished their initialization work. + +use async_trait::async_trait; +use pingora_core::server::ShutdownWatch; +use pingora_core::server::{configuration::ServerConf, ExecutionPhase, RunArgs, Server}; +use pingora_core::services::background::{background_service, BackgroundService}; +use pingora_core::services::ServiceReadyNotifier; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +/// A background service that sets a flag when it completes, after an optional delay. +/// +/// Signals readiness only after completing its initialization work so that dependent +/// services (like `BootstrapService`) cannot start until this service is truly done. +struct TrackableService { + delay: Duration, + completed: Arc, +} + +#[async_trait] +impl BackgroundService for TrackableService { + async fn start_with_ready_notifier( + &self, + _shutdown: ShutdownWatch, + ready_notifier: ServiceReadyNotifier, + ) { + if !self.delay.is_zero() { + tokio::time::sleep(self.delay).await; + } + self.completed.store(true, Ordering::SeqCst); + // Signal readiness only after work is done — this is what the dependency + // mechanism waits on before allowing BootstrapService to proceed. 
+ ready_notifier.notify_ready(); + } +} + +/// Verifies that `bootstrap_as_a_service` does not reach `BootstrapComplete` until all +/// declared dependency services have finished their initialization work. +#[test] +fn test_bootstrap_waits_for_dependencies() { + let conf = ServerConf { + grace_period_seconds: Some(1), + graceful_shutdown_timeout_seconds: Some(1), + ..Default::default() + }; + + let mut server = Server::new_with_opt_and_conf(None, conf); + let mut phase = server.watch_execution_phase(); + + // Two dependency services with delays. The second (150 ms) sets the pace. + let dep1_done = Arc::new(AtomicBool::new(false)); + let dep2_done = Arc::new(AtomicBool::new(false)); + + let dep1_handle = server.add_service(background_service( + "dep1", + TrackableService { + delay: Duration::from_millis(50), + completed: dep1_done.clone(), + }, + )); + let dep2_handle = server.add_service(background_service( + "dep2", + TrackableService { + delay: Duration::from_millis(150), + completed: dep2_done.clone(), + }, + )); + + // BootstrapService must not reach BootstrapComplete until dep1 and dep2 are done. + let bootstrap_handle = server.bootstrap_as_a_service(); + bootstrap_handle.add_dependencies([&dep1_handle, &dep2_handle]); + + // When using bootstrap_as_a_service, do NOT call server.bootstrap() separately — + // the BootstrapService runs as a background service during run(), and emits + // BootstrapComplete only after all its declared dependencies are ready. + let _join = std::thread::spawn(move || { + server.run(RunArgs::default()); + }); + + let mut received_bootstrap = false; + let mut received_bootstrap_complete = false; + + // Collect phases until BootstrapComplete is seen. Running may arrive + // before or after Bootstrap/BootstrapComplete since main_loop starts + // concurrently with the service runtimes. 
+ loop { + match phase.blocking_recv() { + Ok(ExecutionPhase::Bootstrap) => { + received_bootstrap = true; + } + Ok(ExecutionPhase::BootstrapComplete) => { + // Both dependencies must have set their flags before bootstrap completes. + assert!( + dep1_done.load(Ordering::SeqCst), + "dep1 should be done before BootstrapComplete" + ); + assert!( + dep2_done.load(Ordering::SeqCst), + "dep2 should be done before BootstrapComplete" + ); + received_bootstrap_complete = true; + break; + } + Ok(_) => {} + Err(_) => break, + } + } + + assert!(received_bootstrap, "should have seen Bootstrap phase"); + assert!( + received_bootstrap_complete, + "should have seen BootstrapComplete phase" + ); + + // Shut down cleanly. + std::process::exit(0); +} diff --git a/pingora-lru/benches/bench_lru.rs b/pingora-lru/benches/bench_lru.rs index c0bdc776..02dadec8 100644 --- a/pingora-lru/benches/bench_lru.rs +++ b/pingora-lru/benches/bench_lru.rs @@ -12,137 +12,237 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Benchmark for `Lru::promote()` vs `Lru::promote_top_n()`. +//! +//! Tests both small (original) and production-scale LRU sizes to show how +//! the `promote_top_n` optimization behaves at different scales. +//! +//! Run with: `cargo bench -p pingora-lru --bench bench_lru` +//! +//! ## Results (Apple M3 Max, 2026-04-03) +//! +//! Benchmark tiers simulate production-level data sizes (items/shard +//! ranging from ~100 to ~500K) with both uniform-hot and heavy-hitter +//! access patterns. +//! +//! ### 8-threaded — uniform hot set (10% of items are 100x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-------------|------------|----------|----------|-----------|-----------|------------| +//! | 10 (orig) | 366ns | 476ns | 271ns | **164ns** | 164ns | 164ns | +//! | 100K (typ) | **457ns** | 480ns | 437ns | 520ns | 1227ns | 2394ns | +//! +//! 
### 8-threaded — heavy hitters (10 or 100 items are 10,000x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-----------------|------------|----------|----------|-----------|-----------|------------| +//! | 100K, 10 hot | **649ns** | 688ns | 652ns | 773ns | 1811ns | 3534ns | +//! | 100K, 100 hot | **607ns** | 632ns | 607ns | 716ns | 1493ns | 2759ns | +//! +//! ### Single-threaded — uniform hot set (10% of items are 100x hotter) +//! +//! | Items/shard | promote | top_n(0) | top_n(3) | top_n(10) | top_n(50) | top_n(100) | +//! |-------------|------------|----------|----------|-----------|-----------|------------| +//! | 10 (orig) | 22ns | 20ns | 29ns | **30ns** | 30ns | 30ns | +//! | 100K (typ) | **297ns** | 306ns | 314ns | 332ns | 663ns | 1092ns | +//! +//! **Conclusions**: +//! +//! - `promote_top_n(0)` is strictly worse than `promote()` — it takes a +//! wasted read lock before falling through to the write lock every time. +//! +//! - `promote_top_n(n)` for n > 0 only wins at the original small scale +//! (10 items/shard) where the threshold covers the entire shard. +//! +//! - Even with heavy-hitter patterns (10 items at 10,000x weight), +//! `promote()` ties or wins at production scale. With 10 hot items +//! across 32 shards, most shards have 0-1 hot items, so the read-lock +//! scan is wasted on the majority of cold-item accesses. +//! +//! - At production scale (~100K+ items/shard), plain `promote()` is fastest +//! regardless of access pattern. 
+ use rand::distributions::WeightedIndex; use rand::prelude::*; -use std::sync::Arc; +use std::sync::{Arc, Barrier}; use std::thread; use std::time::Instant; -// Non-uniform distributions, 100 items, 10 of them are 100x more likely to appear -const WEIGHTS: &[usize] = &[ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, -]; - const ITERATIONS: usize = 5_000_000; const THREADS: usize = 8; -fn main() { - let lru = parking_lot::Mutex::new(lru::LruCache::::unbounded()); - - let plru = pingora_lru::Lru::<(), 10>::with_capacity(1000, 100); - // populate first, then we bench access/promotion - for i in 0..WEIGHTS.len() { - lru.lock().put(i as u64, ()); - } - for i in 0..WEIGHTS.len() { - plru.admit(i as u64, (), 1); +/// Build a weight distribution where the first `hot_count` items have +/// `hot_weight`x the access probability. 
+fn make_weights(n: usize, hot_count: usize, hot_weight: usize) -> Vec { + let mut weights = vec![1usize; n]; + for w in weights.iter_mut().take(hot_count) { + *w = hot_weight; } + weights +} - // single thread - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); +fn bench_config(label: &str, items: usize, shards: usize, hot_pct: usize, hot_weight: usize) { + let hot_count = items * hot_pct / 100; + bench_config_abs(label, items, shards, hot_count, hot_weight); +} - let before = Instant::now(); - for _ in 0..ITERATIONS { - lru.lock().get(&(dist.sample(&mut rng) as u64)); - } - let elapsed = before.elapsed(); +fn bench_config_abs(label: &str, items: usize, shards: usize, hot_count: usize, hot_weight: usize) { println!( - "lru promote total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 + "\n=== {label}: {items} items, {shards} shards ({} per shard), \ + {hot_count} items are {hot_weight}x hotter ===", + items / shards ); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote(dist.sample(&mut rng) as u64); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 - ); + let weights = make_weights(items, hot_count, hot_weight); + let dist = Arc::new(WeightedIndex::new(&weights).unwrap()); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote_top_n(dist.sample(&mut rng) as u64, 10); + match shards { + 10 => bench_shards::<10>(items, &dist), + 32 => bench_shards::<32>(items, &dist), + _ => panic!("unsupported shard count: {shards}"), } - let elapsed = before.elapsed(); - println!( - "pingora lru promote_top_10 total {elapsed:?}, {:?} avg per operation", - elapsed / ITERATIONS as u32 - ); +} - // concurrent - - let lru = Arc::new(lru); - let mut handlers = vec![]; - for i in 0..THREADS { - let lru = lru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = 
WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - lru.lock().get(&(dist.sample(&mut rng) as u64)); - } - let elapsed = before.elapsed(); - println!( - "lru promote total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); +/// Populate a fresh LRU with `items` entries. +fn make_lru(items: usize) -> pingora_lru::Lru<(), N> { + let lru = pingora_lru::Lru::<(), N>::with_capacity(items, items / N); + for i in 0..items { + lru.admit(i as u64, (), 1); } - for thread in handlers { - thread.join().unwrap(); + lru +} + +fn bench_shards(items: usize, dist: &Arc>) { + // Each variant gets a fresh LRU to avoid state contamination from + // prior runs warming hot items to the head. + + // --- Single-threaded --- + println!(" Single-threaded:"); + { + let lru = make_lru::(items); + let mut rng = thread_rng(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote(dist.sample(&mut rng) as u64); + } + let elapsed = before.elapsed(); + println!( + " promote: {elapsed:?} total, {:?} avg", + elapsed / ITERATIONS as u32, + ); } - let plru = Arc::new(plru); - - let mut handlers = vec![]; - for i in 0..THREADS { - let plru = plru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote(dist.sample(&mut rng) as u64); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); + for top_n in [0, 3, 10, 50, 100] { + let lru = make_lru::(items); + let mut rng = thread_rng(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote_top_n(dist.sample(&mut rng) as u64, top_n); + } + let elapsed = before.elapsed(); + println!( + " promote_top_{top_n:<3} {elapsed:?} total, {:?} 
avg", + elapsed / ITERATIONS as u32, + ); } - for thread in handlers { - thread.join().unwrap(); + + // --- Multi-threaded --- + println!(" {THREADS}-threaded:"); + + { + let lru = Arc::new(make_lru::(items)); + let barrier = Arc::new(Barrier::new(THREADS)); + let mut handlers = vec![]; + for _ in 0..THREADS { + let lru = lru.clone(); + let dist = Arc::clone(dist); + let barrier = barrier.clone(); + handlers.push(thread::spawn(move || { + let mut rng = thread_rng(); + barrier.wait(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote(dist.sample(&mut rng) as u64); + } + before.elapsed() + })); + } + let elapsed: Vec<_> = handlers.into_iter().map(|h| h.join().unwrap()).collect(); + let avg = elapsed.iter().sum::() / THREADS as u32; + println!( + " promote: avg {avg:?}, {:?} avg per op", + avg / ITERATIONS as u32, + ); } - let mut handlers = vec![]; - for i in 0..THREADS { - let plru = plru.clone(); - let handler = thread::spawn(move || { - let mut rng = thread_rng(); - let dist = WeightedIndex::new(WEIGHTS).unwrap(); - let before = Instant::now(); - for _ in 0..ITERATIONS { - plru.promote_top_n(dist.sample(&mut rng) as u64, 10); - } - let elapsed = before.elapsed(); - println!( - "pingora lru promote_top_10 total {elapsed:?}, {:?} avg per operation thread {i}", - elapsed / ITERATIONS as u32 - ); - }); - handlers.push(handler); + for top_n in [0, 3, 10, 50, 100] { + let lru = Arc::new(make_lru::(items)); + let barrier = Arc::new(Barrier::new(THREADS)); + let mut handlers = vec![]; + for _ in 0..THREADS { + let lru = lru.clone(); + let dist = Arc::clone(dist); + let barrier = barrier.clone(); + handlers.push(thread::spawn(move || { + let mut rng = thread_rng(); + barrier.wait(); + let before = Instant::now(); + for _ in 0..ITERATIONS { + lru.promote_top_n(dist.sample(&mut rng) as u64, top_n); + } + before.elapsed() + })); + } + let elapsed: Vec<_> = handlers.into_iter().map(|h| h.join().unwrap()).collect(); + let avg = elapsed.iter().sum::() / 
THREADS as u32; + println!( + " promote_top_{top_n:<3} avg {avg:?}, {:?} avg per op", + avg / ITERATIONS as u32, + ); } - for thread in handlers { - thread.join().unwrap(); +} + +fn main() { + // Benchmark tiers to simulate production-level data sizes: + // Small = original bench scale (10 items/shard) + // Typical = ~100K items/shard (3.2M total across 32 shards) + // Large = ~500K items/shard (16M total) — gated behind + // BENCH_LARGE=1 to avoid OOM on CI runners (~1.5GB heap) + // + // Note: the Typical tier allocates ~150MB per make_lru() call. With + // multiple variants (promote + 5 top_n values) × configs, total peak + // memory is ~1GB. Well within CI limits but notable for constrained machines. + + // Original benchmark scale (100 items, 10 shards = 10 per shard) + bench_config("Small (original bench scale)", 100, 10, 10, 100); + + // Typical (~100K items/shard), 10% hot + bench_config("Typical (100K/shard, 10% hot)", 3_200_000, 32, 10, 100); + + // Typical (~100K items/shard), heavy-hitter: only 10 items dominate + // Simulates viral content / popular API endpoints where a handful of + // assets receive the vast majority of traffic. 
+ bench_config_abs( + "Typical (100K/shard, 10 heavy hitters)", + 3_200_000, + 32, + 10, + 10_000, + ); + + // Typical (~100K items/shard), moderate hot set: 100 items dominate + bench_config_abs( + "Typical (100K/shard, 100 heavy hitters)", + 3_200_000, + 32, + 100, + 10_000, + ); + + // Large (~500K items/shard, ~1.5GB heap) + if std::env::var("BENCH_LARGE").is_ok() { + bench_config("Large (500K/shard, 10% hot)", 16_000_000, 32, 10, 100); + } else { + println!("\n=== Skipping large bench (set BENCH_LARGE=1 to enable) ==="); } } diff --git a/pingora-lru/src/lib.rs b/pingora-lru/src/lib.rs index 23728c4f..67f59230 100644 --- a/pingora-lru/src/lib.rs +++ b/pingora-lru/src/lib.rs @@ -128,9 +128,18 @@ impl Lru { /// Promote to the top n of the LRU /// - /// This function is a bit more efficient in terms of reducing lock contention because it - /// will acquire a write lock only if the key is outside top n but only acquires a read lock - /// when the key is already in the top n. + /// This function acquires a read lock first to check if the key is already + /// in the top `n` positions. If so, it returns early without a write lock. + /// Otherwise it falls through to a write lock for the actual promotion. + /// + /// **Performance note**: this optimization only helps when `n` covers a + /// significant fraction of the shard. At production scale (~100K+ items + /// per shard), hot items are rarely in the top N positions, so the + /// read-lock scan is usually wasted work that adds latency without + /// reducing contention. Benchmarks (`cargo bench --bench bench_lru`) + /// show that plain [`promote()`](Self::promote) is faster at scale. + /// Consider using `promote()` directly unless profiling shows a clear + /// benefit for your workload. 
/// /// Return false if the item doesn't exist pub fn promote_top_n(&self, key: u64, top: usize) -> bool { @@ -226,6 +235,21 @@ impl Lru { self.units[get_shard(key, N)].read().peek_weight(key) } + /// Peek at the least-recently-used item in the given shard without removing it. + /// + /// Returns a clone of the data and the weight, or `None` if the shard is empty + /// or `shard >= N`. + pub fn peek_lru(&self, shard: usize) -> Option<(T, usize)> + where + T: Clone, + { + self.units + .get(shard)? + .read() + .peek_lru() + .map(|(data, weight)| (data.clone(), weight)) + } + /// Return the current total weight. pub fn weight(&self) -> usize { self.weight.load(Ordering::Relaxed) @@ -374,6 +398,19 @@ impl LruUnit { (node.data, node.weight) }) } + + /// Peek at the least-recently-used item without removing it. + /// + /// Returns a reference to the data and weight of the tail item, or `None` + /// if empty. + pub fn peek_lru(&self) -> Option<(&T, usize)> { + self.order + .tail() + .and_then(|idx| self.order.peek(idx)) + .and_then(|key| self.lookup_table.get(&key)) + .map(|node| (&node.data, node.weight)) + } + // TODO: scan the tail up to K elements to decide which ones to evict pub fn remove(&mut self, key: u64) -> Option<(T, usize)> { @@ -696,6 +733,29 @@ mod test_lru { assert_eq!(evicted.len(), 2); assert_eq!(lru.evicted_len(), 2); } + + #[test] + fn test_peek_lru() { + let lru = Lru::::with_capacity(10, 10); + + // empty shard + assert!(lru.peek_lru(0).is_none()); + + lru.admit(1, 10, 1); + assert_eq!(lru.peek_lru(0).unwrap(), (10, 1)); + + lru.admit(2, 20, 2); + // key 1 is LRU tail + assert_eq!(lru.peek_lru(0).unwrap(), (10, 1)); + + // promote key 1 + lru.promote(1); + // key 2 is now LRU tail + assert_eq!(lru.peek_lru(0).unwrap(), (20, 2)); + + // out-of-bounds returns None + assert!(lru.peek_lru(999).is_none()); + } } #[cfg(test)] @@ -865,4 +925,33 @@ mod test_lru_unit { assert_eq!(lru.used_weight(), 1 + 3 + 4 + 5); assert_lru(&lru, &[2, 3, 4, 5]); } + + 
#[test] + fn test_peek_lru() { + let mut lru = LruUnit::with_capacity(10); + + // empty returns None + assert!(lru.peek_lru().is_none()); + + // single item is both head and tail + lru.admit(1, 10, 1); + let (data, weight) = lru.peek_lru().unwrap(); + assert_eq!(*data, 10); + assert_eq!(weight, 1); + + // second admission pushes first to tail + lru.admit(2, 20, 2); + let (data, _) = lru.peek_lru().unwrap(); + assert_eq!(*data, 10); // key 1 is LRU tail + + // promote key 1 — now key 2 is tail + lru.access(1); + let (data, _) = lru.peek_lru().unwrap(); + assert_eq!(*data, 20); // key 2 is now LRU tail + + // peek doesn't remove + assert!(lru.peek_lru().is_some()); + assert!(lru.peek(1).is_some()); + assert!(lru.peek(2).is_some()); + } } diff --git a/pingora-prometheus/Cargo.toml b/pingora-prometheus/Cargo.toml new file mode 100644 index 00000000..d9701213 --- /dev/null +++ b/pingora-prometheus/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "pingora-prometheus" +version = "0.8.0" +authors = ["Pingora Team at Cloudflare "] +license = "Apache-2.0" +edition = "2021" +repository = "https://github.com/cloudflare/pingora" +categories = ["asynchronous", "network-programming"] +keywords = ["async", "http", "prometheus", "pingora"] +description = """ +A Prometheus metrics HTTP server for pingora services. +""" + +[lib] +name = "pingora_prometheus" +path = "src/lib.rs" + +[dependencies] +pingora-core = { version = "0.8.0", path = "../pingora-core", default-features = false } +prometheus = "0.14" +async-trait = { workspace = true } +http = { workspace = true } diff --git a/pingora-prometheus/src/lib.rs b/pingora-prometheus/src/lib.rs new file mode 100644 index 00000000..cfd90b89 --- /dev/null +++ b/pingora-prometheus/src/lib.rs @@ -0,0 +1,131 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![warn(clippy::all)] + +//! A Prometheus metrics HTTP server for [pingora](https://docs.rs/pingora) services. +//! +//! This crate provides [`PrometheusHttpApp`] and [`PrometheusServer`], which serve +//! all [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) +//! collected via the [`prometheus`] crate as an HTTP endpoint. +//! +//! # Example +//! +//! ```rust,ignore +//! use pingora_core::services::listening::Service; +//! use pingora_prometheus::new_prometheus_server; +//! +//! let mut prometheus_service = Service::new( +//! "Prometheus HTTP".to_string(), +//! new_prometheus_server(), +//! ); +//! prometheus_service.add_tcp("127.0.0.1:6150"); +//! server.add_service(prometheus_service); +//! ``` +//! +//! Or use the convenience function: +//! +//! ```rust,ignore +//! let mut prometheus_service = pingora_prometheus::prometheus_http_service(); +//! prometheus_service.add_tcp("127.0.0.1:6150"); +//! server.add_service(prometheus_service); +//! ``` + +use async_trait::async_trait; +use http::Response; +use prometheus::{Encoder, TextEncoder}; + +use pingora_core::apps::http_app::{HttpServer, ServeHttp}; +use pingora_core::modules::http::compression::ResponseCompressionBuilder; +use pingora_core::protocols::http::ServerSession; +use pingora_core::services::listening::Service; + +/// Re-export of the [`prometheus`] crate. 
+/// +/// Use this re-export to ensure your metrics are registered in the same +/// global registry that [`PrometheusHttpApp`] gathers from, avoiding +/// version mismatches that would cause metrics to silently not appear. +/// +/// # Example +/// +/// ```rust,ignore +/// use pingora_prometheus::prometheus::{self, register_int_counter, IntCounter}; +/// use once_cell::sync::Lazy; +/// +/// static REQUESTS: Lazy = Lazy::new(|| { +/// register_int_counter!("requests_total", "Total requests").unwrap() +/// }); +/// ``` +pub use prometheus; + +/// An HTTP application that reports Prometheus metrics. +/// +/// This application will report all the [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics) +/// collected via the [Prometheus](https://docs.rs/prometheus/) crate. +/// +/// Currently serves metrics on all request paths. By convention, Prometheus +/// scrapers expect metrics at `/metrics`. Since this app is typically bound +/// to a dedicated listener address, this works in practice, but callers +/// should be aware of this if sharing the listener with other routes. +// TODO: consider restricting to `/metrics` and returning 404 for other paths +pub struct PrometheusHttpApp; + +#[async_trait] +impl ServeHttp for PrometheusHttpApp { + async fn response(&self, _http_session: &mut ServerSession) -> Response> { + let encoder = TextEncoder::new(); + let metric_families = prometheus::gather(); + let mut buffer = vec![]; + encoder.encode(&metric_families, &mut buffer).unwrap(); + Response::builder() + .status(200) + .header(http::header::CONTENT_TYPE, encoder.format_type()) + .header(http::header::CONTENT_LENGTH, buffer.len()) + .body(buffer) + .unwrap() + } +} + +/// The [`HttpServer`] for [`PrometheusHttpApp`]. +/// +/// This type provides the functionality of [`PrometheusHttpApp`] with gzip +/// compression enabled (level 7). +pub type PrometheusServer = HttpServer; + +/// Create a new [`PrometheusServer`] with compression enabled. 
+pub fn new_prometheus_server() -> PrometheusServer { + let mut server = PrometheusServer::new_app(PrometheusHttpApp); + // enable gzip level 7 compression + server.add_module(ResponseCompressionBuilder::enable(7)); + server +} + +/// Create a Prometheus HTTP [`Service`] ready to have endpoints added. +/// +/// This is a convenience function that creates a [`Service`] wrapping a +/// [`PrometheusServer`] with compression enabled. +/// +/// # Example +/// +/// ```rust,ignore +/// let mut prometheus_service = pingora_prometheus::prometheus_http_service(); +/// prometheus_service.add_tcp("127.0.0.1:6150"); +/// server.add_service(prometheus_service); +/// ``` +pub fn prometheus_http_service() -> Service { + Service::new( + "Prometheus metric HTTP".to_string(), + new_prometheus_server(), + ) +} diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml index d4df1378..d82179cb 100644 --- a/pingora-proxy/Cargo.toml +++ b/pingora-proxy/Cargo.toml @@ -47,6 +47,7 @@ hyper = "0.14" tokio-tungstenite = "0.20.1" pingora-limits = { version = "0.8.0", path = "../pingora-limits" } pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing", default-features=false } +pingora-prometheus = { version = "0.8.0", path = "../pingora-prometheus" } prometheus = "0" futures-util = "0.3" serde = { version = "1.0", features = ["derive"] } @@ -71,7 +72,6 @@ any_tls = [] sentry = ["pingora-core/sentry"] adjust_upstream_modules = [] connection_filter = ["pingora-core/connection_filter"] -prometheus = ["pingora-core/prometheus"] trace = ["pingora-cache/trace"] [[example]] diff --git a/pingora-proxy/examples/gateway.rs b/pingora-proxy/examples/gateway.rs index e320688f..79c1646c 100644 --- a/pingora-proxy/examples/gateway.rs +++ b/pingora-proxy/examples/gateway.rs @@ -129,12 +129,8 @@ fn main() { my_proxy.add_tcp("0.0.0.0:6191"); my_server.add_service(my_proxy); - #[cfg(feature = "prometheus")] - let mut prometheus_service_http = - 
pingora_core::services::listening::Service::prometheus_http_service(); - #[cfg(feature = "prometheus")] + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6192"); - #[cfg(feature = "prometheus")] my_server.add_service(prometheus_service_http); my_server.run_forever(); diff --git a/pingora-proxy/src/lib.rs b/pingora-proxy/src/lib.rs index 52a89cbd..e5433efa 100644 --- a/pingora-proxy/src/lib.rs +++ b/pingora-proxy/src/lib.rs @@ -46,7 +46,7 @@ use pingora_http::{RequestHeader, ResponseHeader}; use std::fmt::Debug; use std::str; use std::sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }; use std::time::Duration; @@ -190,6 +190,17 @@ where } } + /// Return the number of times a pooled upstream connection was found to contain + /// unexpected data from the server. + pub fn unexpected_data_connection_count(&self) -> u64 { + self.client_upstream.unexpected_data_connection_count() + } + + /// Return a shared reference to the unexpected data connection counter for periodic metric reporting. + pub fn unexpected_data_connection_counter(&self) -> Arc { + self.client_upstream.unexpected_data_connection_counter() + } + /// Initialize the downstream modules for this proxy. 
/// /// This method must be called after creating an [`HttpProxy`] with [`HttpProxy::new()`] @@ -587,57 +598,115 @@ impl Session { .await } - pub async fn write_response_tasks(&mut self, mut tasks: Vec) -> Result { - let mut seen_upgraded = self.was_upgraded(); - for task in tasks.iter_mut() { - match task { - HttpTask::Header(resp, end) => { - self.downstream_modules_ctx - .response_header_filter(resp, *end) - .await?; - } - HttpTask::Body(data, end) => { - self.downstream_modules_ctx - .response_body_filter(data, *end)?; - } - HttpTask::UpgradedBody(data, end) => { - seen_upgraded = true; - self.downstream_modules_ctx - .response_body_filter(data, *end)?; - } - HttpTask::Trailer(trailers) => { - if let Some(buf) = self - .downstream_modules_ctx - .response_trailer_filter(trailers)? - { - // Write the trailers into the body if the filter - // returns a buffer. - // - // Note, this will not work if end of stream has already - // been seen or we've written content-length bytes. - // (Trailers should never come after upgraded body) - *task = HttpTask::Body(Some(buf), true); - } - } - HttpTask::Done => { - // `Done` can be sent in certain response paths to mark end - // of response if not already done via trailers or body with - // end flag set. - // If the filter returns body bytes on Done, - // write them into the response. + // Run downstream module response filters on a single task, updating + // `seen_upgraded` to track whether an upgrade has been seen. Used by both + // `send_downstream_proxy_task` and `write_response_tasks`. 
+ async fn downstream_response_task_filter( + &mut self, + task: &mut HttpTask, + seen_upgraded: &mut bool, + ) -> Result<()> { + match task { + HttpTask::Header(resp, end) => { + self.downstream_modules_ctx + .response_header_filter(resp, *end) + .await?; + } + HttpTask::Body(data, end) => { + self.downstream_modules_ctx + .response_body_filter(data, *end)?; + } + HttpTask::UpgradedBody(data, end) => { + *seen_upgraded = true; + self.downstream_modules_ctx + .response_body_filter(data, *end)?; + } + HttpTask::Trailer(trailers) => { + if let Some(buf) = self + .downstream_modules_ctx + .response_trailer_filter(trailers)? + { + // Write the trailers into the body if the filter + // returns a buffer. // // Note, this will not work if end of stream has already // been seen or we've written content-length bytes. - if let Some(buf) = self.downstream_modules_ctx.response_done_filter()? { - if seen_upgraded { - *task = HttpTask::UpgradedBody(Some(buf), true); - } else { - *task = HttpTask::Body(Some(buf), true); - } + // (Trailers should never come after upgraded body) + *task = HttpTask::Body(Some(buf), true); + } + } + HttpTask::Done => { + // `Done` can be sent in certain response paths to mark end + // of response if not already done via trailers or body with + // end flag set. + // If the filter returns body bytes on Done, + // write them into the response. + // + // Note, this will not work if end of stream has already + // been seen or we've written content-length bytes. + if let Some(buf) = self.downstream_modules_ctx.response_done_filter()? { + if *seen_upgraded { + *task = HttpTask::UpgradedBody(Some(buf), true); + } else { + *task = HttpTask::Body(Some(buf), true); } } - _ => { /* Failed */ } } + _ => { /* Failed */ } + } + Ok(()) + } + + /// Queue a downstream proxy task for cancel-safe writing after running + /// downstream module filters. This allows decoupling cache writes from + /// downstream writes. 
+ /// + /// Only works with sessions that support the proxy task API (currently H1). + /// + /// # Panics + /// Panics if the session doesn't support the proxy task API. + /// Use `write_response_tasks()` for sessions that don't support the proxy task API. + pub async fn send_downstream_proxy_task(&mut self, mut task: HttpTask) -> Result<()> { + let mut seen_upgraded = self.was_upgraded(); + self.downstream_response_task_filter(&mut task, &mut seen_upgraded) + .await?; + self.downstream_session.send_downstream_proxy_task(task); + Ok(()) + } + + /// Enable or disable the cancel-safe proxy task API for this session. + /// + /// When disabled, the proxy falls back to the blocking `write_response_tasks` + /// path. This can be called from request filters to opt out on a per-request + /// basis. + pub fn set_proxy_tasks_enabled(&mut self, enabled: bool) { + self.downstream_session.set_proxy_tasks_enabled(enabled); + } + + /// Check if there are pending downstream tasks queued for writing. + /// Used for backpressure - don't queue more cache tasks if we have pending writes. + /// Returns false for sessions that don't support the proxy task API. + pub fn has_pending_downstream_tasks(&self) -> bool { + self.downstream_session.supports_proxy_task_api() + && self.downstream_session.has_pending_downstream_proxy_tasks() + } + + /// Write all queued downstream proxy tasks. This is cancel-safe and can be called + /// in a select! loop while waiting for upstream tasks. + /// For sessions that don't support the proxy task API, this is a no-op. 
+ pub async fn write_downstream_proxy_tasks(&mut self) -> Result { + if self.downstream_session.supports_proxy_task_api() { + self.downstream_session.write_downstream_proxy_tasks().await + } else { + Ok(false) + } + } + + pub async fn write_response_tasks(&mut self, mut tasks: Vec) -> Result { + let mut seen_upgraded = self.was_upgraded(); + for task in tasks.iter_mut() { + self.downstream_response_task_filter(task, &mut seen_upgraded) + .await?; } self.downstream_session.response_duplex_vec(tasks).await } diff --git a/pingora-proxy/src/proxy_common.rs b/pingora-proxy/src/proxy_common.rs index e1d36f69..6c40760c 100644 --- a/pingora-proxy/src/proxy_common.rs +++ b/pingora-proxy/src/proxy_common.rs @@ -1,3 +1,17 @@ +// Copyright 2026 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /// Possible downstream states during request multiplexing #[derive(Debug, Clone, Copy)] pub(crate) enum DownstreamStateMachine { @@ -36,19 +50,28 @@ impl DownstreamStateMachine { matches!(self, Self::Errored) } - /// Move the state machine to Finished state if `set` is true + /// Move the state machine to Finished state if `set` is true. + /// + /// No-op when the current state is [`Errored`](Self::Errored) — once errored the + /// downstream connection must not be reused, and late upstream chunks arriving + /// via `rx.recv()` must not overwrite that decision. 
pub fn maybe_finished(&mut self, set: bool) { - if set { + if set && !self.is_errored() { *self = Self::ReadingFinished } } - /// Reset if we should continue reading from the downstream again. - /// Only used with upgraded connections when body mode changes. + /// Reset to [`Reading`](Self::Reading) for upgraded connections when body mode changes. + /// + /// No-op when the current state is [`Errored`](Self::Errored). pub fn reset(&mut self) { - *self = Self::Reading; + if !self.is_errored() { + *self = Self::Reading; + } } + /// Transition to [`Errored`](Self::Errored). This is a terminal state: once entered, + /// no other state transition is permitted and the connection must not be reused. pub fn to_errored(&mut self) { *self = Self::Errored } @@ -97,3 +120,61 @@ impl ResponseStateMachine { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normal_lifecycle() { + let mut ds = DownstreamStateMachine::new(false); + assert!(ds.is_reading()); + assert!(ds.can_poll()); + assert!(!ds.is_errored()); + + ds.maybe_finished(true); + assert!(!ds.is_reading()); + assert!(ds.is_done()); + assert!(ds.can_poll()); // ReadingFinished still allows polling (for idle) + assert!(!ds.is_errored()); + } + + #[test] + fn errored_is_terminal() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + assert!(ds.is_done()); + } + + /// `maybe_finished(false)` is always a no-op regardless of state. + #[test] + fn maybe_finished_false_is_noop() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.maybe_finished(false); // must not panic + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } + + /// `maybe_finished(true)` on `Errored` is a no-op — `Errored` is terminal. 
+ #[test] + fn maybe_finished_true_noop_on_errored() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.maybe_finished(true); // must not overwrite Errored + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } + + /// `reset()` on `Errored` is a no-op — `Errored` is terminal. + #[test] + fn reset_noop_on_errored() { + let mut ds = DownstreamStateMachine::new(false); + ds.to_errored(); + ds.reset(); // must not overwrite Errored + assert!(ds.is_errored()); + assert!(!ds.can_poll()); + } +} diff --git a/pingora-proxy/src/proxy_custom.rs b/pingora-proxy/src/proxy_custom.rs index b571b3ce..b7ee1d50 100644 --- a/pingora-proxy/src/proxy_custom.rs +++ b/pingora-proxy/src/proxy_custom.rs @@ -257,7 +257,88 @@ where } } - // returns whether server (downstream) session can be reused + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks_custom( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? 
+ return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + while let Ok(task) = rx.try_recv() { + tasks.push(task); + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + session.upstream_compression.response_filter(&mut t); + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = t { + return Err(e); + } + } + filtered_tasks.push( + self.custom_response_filter( + session, + t, + ctx, + serve_from_cache, + range_body_filter, + false, + ) + .await?, + ); + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + + // TODO: pre-existing inconsistency with proxy_h1/proxy_h2 to address in a follow-up: + // upstream task rx.recv() branch is missing + // downstream_state.maybe_finished(session.is_body_done()) after processing. proxy_h1 has + // this because upgrade responses can force the body done — since custom upstreams can + // serve H1 downstreams that support upgrades, the same may be needed here. 
+ // Returns whether server (downstream) session can be reused #[allow(clippy::too_many_arguments)] async fn custom_bidirection_down_to_up( &self, @@ -303,6 +384,8 @@ where let mut serve_from_cache = ServeFromCache::new(); let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + let mut upstream_custom = true; let mut downstream_custom = true; @@ -361,99 +444,142 @@ where }; }, - task = rx.recv(), if !response_state.upstream_done() => { - debug!("upstream event"); - + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - debug!("upstream event custom: {:?}", t); - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? - continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - while let Ok(task) = rx.try_recv() { - tasks.push(task); - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - #[cfg(feature = "adjust_upstream_modules")] - if let HttpTask::Header(header, end_of_stream) = &t { - self.inner - .adjust_upstream_modules(session, header, *end_of_stream, ctx) - .await?; - } - session.upstream_compression.response_filter(&mut t); - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = t { - return Err(e); - } - } - filtered_tasks.push( - self.custom_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut 
range_body_filter, false).await?); - if serve_from_cache.is_miss_header() { - response_state.enable_cached_response(); - } - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case downstream failed already + let Some(response_done) = self.process_upstream_tasks_custom( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache continue; - } + }; + response_state.maybe_set_upstream_done(response_done); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { let upgraded = session.was_upgraded(); - let response_done = session.write_response_tasks(filtered_tasks).await?; + let Some(response_done) = self.process_upstream_tasks_custom( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. 
serve_from_cache + continue; + }; if !upgraded && session.was_upgraded() && downstream_state.can_poll() { // just upgraded, the downstream state should be reset to continue to // poll body trace!("reset downstream state on upgrade"); downstream_state.reset(); } - response_state.maybe_set_upstream_done(response_done); } else { debug!("empty upstream event"); response_state.maybe_set_upstream_done(true); } - } + }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes + let task = self.custom_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - warn!( - "Downstream Error ignored during caching: {}, {}", - e, - self.inner.request_summary(session, ctx) - ); - continue; - } else { - return Err(e); + + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + 
continue; + } else { + return Err(e); + } + } + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. + _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! { + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it. 
+ upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/src/proxy_h1.rs b/pingora-proxy/src/proxy_h1.rs index 9f498aa0..dbf6e5ca 100644 --- a/pingora-proxy/src/proxy_h1.rs +++ b/pingora-proxy/src/proxy_h1.rs @@ -267,6 +267,81 @@ where Ok(()) } + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? 
+ return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + // tokio::task::unconstrained because now_or_never may yield None when the future is ready + while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + debug!("upstream event now: {:?}", maybe_task); + if let Some(t) = maybe_task { + tasks.push(t); + } else { + break; // upstream closed + } + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + session.upstream_compression.response_filter(&mut t); + let task = self + .h1_response_filter(session, t, ctx, serve_from_cache, range_body_filter, false) + .await?; + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = task { + return Err(e); + } + } + filtered_tasks.push(task); + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + // todo use this function to replace bidirection_1to2() // returns whether this server (downstream) session can be reused async fn proxy_handle_downstream( @@ -329,6 +404,8 @@ where let mut 
serve_from_cache = proxy_cache::ServeFromCache::new(); let mut range_body_filter = proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + /* duplex mode without caching * Read body from downstream while reading response from upstream * If response is done, only read body from downstream @@ -424,74 +501,56 @@ where // If tx is closed, the upstream has already finished its job. downstream_state.maybe_finished(tx.is_closed()); debug!("waiting for permit {send_permit:?}, upstream closed {}", tx.is_closed()); - /* No permit, wait on more capacity to avoid starving. + /* No permit, wait on more capacity to avoid starving. * Otherwise this select only blocks on rx, which might send no data * before the entire body is uploaded. * once more capacity arrives we just loop back */ }, - task = rx.recv(), if !response_state.upstream_done() => { - debug!("upstream event: {:?}", task); + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? 
- continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - // tokio::task::unconstrained because now_or_never may yield None when the future is ready - while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { - debug!("upstream event now: {:?}", maybe_task); - if let Some(t) = maybe_task { - tasks.push(t); - } else { - break; // upstream closed - } - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - #[cfg(feature = "adjust_upstream_modules")] - if let HttpTask::Header(header, end_of_stream) = &t { - self.inner - .adjust_upstream_modules(session, header, *end_of_stream, ctx) - .await?; - } - session.upstream_compression.response_filter(&mut t); - let task = self.h1_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut range_body_filter, false).await?; - if serve_from_cache.is_miss_header() { - response_state.enable_cached_response(); - } - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = task { - return Err(e); - } - } - filtered_tasks.push(task); - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case downstream failed already + let Some(response_done) = self.process_upstream_tasks( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. 
serve_from_cache continue; - } + }; + response_state.maybe_set_upstream_done(response_done); + // unsuccessful upgrade response may force the request done + downstream_state.maybe_finished(session.is_body_done()); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, - // set to downstream + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { let upgraded = session.was_upgraded(); - let response_done = session.write_response_tasks(filtered_tasks).await?; + let Some(response_done) = self.process_upstream_tasks( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache + continue; + }; if !upgraded && session.was_upgraded() && downstream_state.can_poll() { + // TODO: write can happen async now // just upgraded, the downstream state should be reset to continue to // poll body trace!("reset downstream state on upgrade"); @@ -508,35 +567,96 @@ where }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes let task = self.h1_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; debug!("serve_from_cache task {task:?}"); - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - 
warn!( - "Downstream Error ignored during caching: {}, {}", - e, - self.inner.request_summary(session, ctx) - ); - // This will not be treated as a final error, but we should signal to - // downstream session regardless - session.downstream_session.on_proxy_failure(e); - continue; - } else { - return Err(e); + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + continue; + } else { + return Err(e); + } + } + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. 
+ _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! { + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it. + // Only poll if there is no buffered task already waiting to be processed. 
+ upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + // Store this upstream task to be processed next iteration + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index acf61f07..afe58a0b 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -265,6 +265,87 @@ where (server_session_reuse, error) } + #[allow(clippy::too_many_arguments)] + async fn process_upstream_tasks_h2( + &self, + session: &mut Session, + ctx: &mut SV::CTX, + initial_task: HttpTask, + rx: &mut mpsc::Receiver, + serve_from_cache: &mut ServeFromCache, + range_body_filter: &mut proxy_cache::range_filter::RangeBodyFilter, + response_state: &mut ResponseStateMachine, + ) -> Result> + where + SV: ProxyHttp + Send + Sync, + SV::CTX: Send + Sync, + { + if serve_from_cache.should_discard_upstream() { + // just drain, do we need to do anything else? 
+ return Ok(None); + } + + // Batch: pull as many tasks as we can from rx + let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + tasks.push(initial_task); + // tokio::task::unconstrained because now_or_never may yield None when the future is ready + while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { + if let Some(t) = maybe_task { + tasks.push(t); + } else { + break; // upstream closed + } + } + + /* run filters before sending to downstream */ + let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); + for mut t in tasks { + if self.revalidate_or_stale(session, &mut t, ctx).await { + serve_from_cache.enable(); + response_state.enable_cached_response(); + // skip downstream filtering entirely as the 304 will not be sent + break; + } + #[cfg(feature = "adjust_upstream_modules")] + if let HttpTask::Header(header, end_of_stream) = &t { + self.inner + .adjust_upstream_modules(session, header, *end_of_stream, ctx) + .await?; + } + session.upstream_compression.response_filter(&mut t); + // check error and abort + // otherwise the error is surfaced via write_response_tasks() + if !serve_from_cache.should_send_to_downstream() { + if let HttpTask::Failed(e) = t { + return Err(e); + } + } + filtered_tasks.push( + self.h2_response_filter( + session, + t, + ctx, + serve_from_cache, + range_body_filter, + false, + ) + .await?, + ); + if serve_from_cache.is_miss_header() { + response_state.enable_cached_response(); + } + } + + if !serve_from_cache.should_send_to_downstream() { + // TODO: need to derive response_done from filtered_tasks in case downstream failed already + return Ok(None); + } + + let response_done = session.write_response_tasks(filtered_tasks).await?; + + Ok(Some(response_done)) + } + // returns whether server (downstream) session can be reused async fn bidirection_down_to_up( &self, @@ -322,6 +403,8 @@ where let mut serve_from_cache = ServeFromCache::new(); let mut range_body_filter = 
proxy_cache::range_filter::RangeBodyFilter::new(); + let mut next_upstream_task: Option = None; + /* duplex mode * see the Same function for h1 for more comments */ @@ -388,63 +471,47 @@ where }; }, - task = rx.recv(), if !response_state.upstream_done() => { + // Handle buffered upstream task from previous iteration + task = async { next_upstream_task.take() }, if next_upstream_task.is_some() => { + debug!("buffered upstream event: {:?}", task); if let Some(t) = task { - debug!("upstream event: {:?}", t); - if serve_from_cache.should_discard_upstream() { - // just drain, do we need to do anything else? - continue; - } - // pull as many tasks as we can - let mut tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - tasks.push(t); - // tokio::task::unconstrained because now_or_never may yield None when the future is ready - while let Some(maybe_task) = tokio::task::unconstrained(rx.recv()).now_or_never() { - if let Some(t) = maybe_task { - tasks.push(t); - } else { - break - } - } - - /* run filters before sending to downstream */ - let mut filtered_tasks = Vec::with_capacity(TASK_BUFFER_SIZE); - for mut t in tasks { - if self.revalidate_or_stale(session, &mut t, ctx).await { - serve_from_cache.enable(); - response_state.enable_cached_response(); - // skip downstream filtering entirely as the 304 will not be sent - break; - } - #[cfg(feature = "adjust_upstream_modules")] - if let HttpTask::Header(header, end_of_stream) = &t { - self.inner - .adjust_upstream_modules(session, header, *end_of_stream, ctx) - .await?; - } - session.upstream_compression.response_filter(&mut t); - // check error and abort - // otherwise the error is surfaced via write_response_tasks() - if !serve_from_cache.should_send_to_downstream() { - if let HttpTask::Failed(e) = t { - return Err(e); - } - } - filtered_tasks.push( - self.h2_response_filter(session, t, ctx, - &mut serve_from_cache, - &mut range_body_filter, false).await?); - if serve_from_cache.is_miss_header() { - 
response_state.enable_cached_response(); - } - } - - if !serve_from_cache.should_send_to_downstream() { - // TODO: need to derive response_done from filtered_tasks in case downstream failed already + let Some(response_done) = self.process_upstream_tasks_h2( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. serve_from_cache continue; + }; + if session.was_upgraded() { + return Error::e_explain(H2Error, "upgraded while proxying to h2 session"); } + response_state.maybe_set_upstream_done(response_done); + } else { + debug!("empty upstream event"); + response_state.maybe_set_upstream_done(true); + } + }, - let response_done = session.write_response_tasks(filtered_tasks).await?; + task = rx.recv(), if !response_state.upstream_done() && next_upstream_task.is_none() => { + debug!("upstream event: {:?}", task); + if let Some(t) = task { + let Some(response_done) = self.process_upstream_tasks_h2( + session, + ctx, + t, + &mut rx, + &mut serve_from_cache, + &mut range_body_filter, + &mut response_state, + ).await? else { + // nothing sent downstream e.g. 
serve_from_cache + continue; + }; if session.was_upgraded() { // it is very weird if the downstream session decides to upgrade // since the client h2 session cannot, return an error on this case @@ -455,37 +522,95 @@ where debug!("empty upstream event"); response_state.maybe_set_upstream_done(true); } - } + }, task = serve_from_cache.next_http_task(&mut session.cache, &mut range_body_filter, upgraded), - if !response_state.cached_done() && !downstream_state.is_errored() && serve_from_cache.is_on() => { + if !response_state.cached_done() + && !downstream_state.is_errored() + && serve_from_cache.is_on() + && !session.has_pending_downstream_tasks() => { // backpressure: don't queue if pending writes + let task = self.h2_response_filter(session, task?, ctx, &mut serve_from_cache, &mut range_body_filter, true).await?; debug!("serve_from_cache task {task:?}"); - match session.write_response_tasks(vec![task]).await { - Ok(b) => response_state.maybe_set_cache_done(b), - Err(e) => if serve_from_cache.is_miss() { - // give up writing to downstream but wait for upstream cache write to finish - downstream_state.to_errored(); - response_state.maybe_set_cache_done(true); - warn!( - "Downstream Error ignored during caching: {}, {}", - e, - self.inner.request_summary(session, ctx) - ); - // This will not be treated as a final error, but we should signal to - // downstream session regardless - session.downstream_session.on_proxy_failure(e); - continue; - } else { - return Err(e); + if session.downstream_session.supports_proxy_task_api() { + session.send_downstream_proxy_task(task).await?; + } else { + match session.write_response_tasks(vec![task]).await { + Ok(b) => response_state.maybe_set_cache_done(b), + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream Error ignored during caching: {}, {}", + e, + 
self.inner.request_summary(session, ctx) + ); + // This will not be treated as a final error, but we should signal to + // downstream session regardless + session.downstream_session.on_proxy_failure(e); + continue; + } else { + return Err(e); + } + } + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } } } - if response_state.cached_done() { - if let Err(e) = session.cache.finish_hit_handler().await { - warn!("Error during finish_hit_handler: {}", e); + } + + // Write queued downstream proxy tasks while also polling for upstream tasks. + // This allows cache writes to continue even when downstream is stalled. + // + // "Gate" branch: ready(()) resolves immediately, so the guard controls + // whether we enter. This is not a busy-loop because every path through + // the inner select either (a) drains all pending tasks via + // write_downstream_proxy_tasks (making the guard false), (b) stores an + // upstream task in next_upstream_task (making the guard false), or + // (c) blocks on real I/O inside the nested select. + _ = std::future::ready(()), if session.has_pending_downstream_tasks() && next_upstream_task.is_none() => { + tokio::select! 
{ + // Try to write downstream proxy tasks (cancel-safe) + write_result = session.write_downstream_proxy_tasks() => { + match write_result { + Ok(end) => { + response_state.maybe_set_cache_done(end); + if response_state.cached_done() { + if let Err(e) = session.cache.finish_hit_handler().await { + warn!("Error during finish_hit_handler: {}", e); + } + } + } + Err(e) => if serve_from_cache.is_miss() { + // give up writing to downstream but wait for upstream cache write to finish + downstream_state.to_errored(); + response_state.maybe_set_cache_done(true); + warn!( + "Downstream write error ignored during caching: {}, {}", + e, + self.inner.request_summary(session, ctx) + ); + session.downstream_session.on_proxy_failure(e); + } else { + return Err(e); + } + } + } + + // Also poll for upstream tasks - if we get one, cancel the write and handle it. + upstream_task = rx.recv(), if !response_state.upstream_done() && serve_from_cache.is_on() && next_upstream_task.is_none() => { + if let Some(t) = upstream_task { + next_upstream_task = Some(t); + continue; + } else { + response_state.maybe_set_upstream_done(true); + } } } } diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs index 9ae4511e..ff6453d4 100644 --- a/pingora-proxy/tests/test_upstream.rs +++ b/pingora-proxy/tests/test_upstream.rs @@ -2913,6 +2913,154 @@ mod test_cache { assert_eq!(res.text().await.unwrap(), "hello world"); } + #[tokio::test] + #[ignore = "flaky in CI due to timing/resource contention"] + async fn test_caching_when_downstream_stalls() { + use std::net::ToSocketAddrs; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + + init(); + let url = "http://127.0.0.1:6148/unique/test_caching_when_downstream_stalls/download/"; + + // Connection 1: read 10KiB then stall, holding the cache lock while + // the proxy populates cache from upstream. 
+ let slow_task = tokio::spawn(async move { + let addr = "127.0.0.1:6148".to_socket_addrs().unwrap().next().unwrap(); + let mut stream = TcpStream::connect(&addr).await.unwrap(); + + let request = concat!( + "GET /unique/test_caching_when_downstream_stalls/download/ HTTP/1.1\r\n", + "Host: 127.0.0.1:6148\r\n", + "x-lock: true\r\n", + "x-set-cache-control: public, max-age=60\r\n", + "\r\n", + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut buf = [0; 10 * 1024]; + let mut b = &mut buf[..]; + while !b.is_empty() { + let n = stream.read(b).await.unwrap(); + b = &mut b[n..] + } + + // Hold the stalled connection open long enough + sleep(Duration::from_secs(10)).await; + }); + + // Give connection 1 time to acquire the cache lock + sleep(Duration::from_secs(1)).await; + + // Connection 2: should get a cache hit once the proxy finishes + // populating cache from upstream (independent of stall). + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(url) + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .timeout(Duration::from_secs(8)) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "hit"); + + // If the cache was populated fast enough (before connection 2 arrived), + // there is no lock contention and x-cache-lock-time-ms is absent. + // If there was contention, the wait should be short. 
+ if let Some(lock_ms) = headers.get("x-cache-lock-time-ms") { + let ms: u64 = lock_ms.to_str().unwrap().parse().unwrap(); + assert!( + ms < 2000, + "lock wait {ms}ms should be well under the 2s timeout" + ); + } + + assert_eq!( + res.text().await.unwrap(), + String::from("A").repeat(4 * 1024 * 1024) + ); + + let elapsed = start.elapsed(); + assert!( + elapsed < Duration::from_secs(5), + "second request took {elapsed:?}, should be fast" + ); + + // Don't wait for the slow connection + slow_task.abort(); + } + + // Same as test_caching_when_downstream_stalls but the proxy connects + // to the origin over H2 (via the x-h2 header). + // + #[tokio::test] + #[ignore = "flaky in CI due to timing/resource contention"] + async fn test_caching_h2_upstream_when_downstream_stalls() { + use std::net::ToSocketAddrs; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + + init(); + let url = "http://127.0.0.1:6148/unique/test_caching_h2_upstream_when_downstream_stalls/download/"; + + let slow_task = tokio::spawn(async move { + let addr = "127.0.0.1:6148".to_socket_addrs().unwrap().next().unwrap(); + let mut stream = TcpStream::connect(&addr).await.unwrap(); + + let request = concat!( + "GET /unique/test_caching_h2_upstream_when_downstream_stalls/download/ HTTP/1.1\r\n", + "Host: 127.0.0.1:6148\r\n", + "x-h2: true\r\n", + "x-lock: true\r\n", + "x-set-cache-control: public, max-age=60\r\n", + "\r\n", + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut buf = [0; 10 * 1024]; + let mut b = &mut buf[..]; + while !b.is_empty() { + let n = stream.read(b).await.unwrap(); + b = &mut b[n..] 
+ } + + sleep(Duration::from_secs(10)).await; + }); + + sleep(Duration::from_secs(1)).await; + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(url) + .header("x-h2", "true") + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .timeout(Duration::from_secs(8)) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + let headers = res.headers(); + assert_eq!(headers["x-cache-status"], "hit"); + assert_eq!( + res.text().await.unwrap(), + String::from("A").repeat(4 * 1024 * 1024) + ); + + let elapsed = start.elapsed(); + assert!( + elapsed < Duration::from_secs(5), + "second request took {elapsed:?}, should be fast (upstream-speed-bound)" + ); + + slow_task.abort(); + } + async fn send_vary_req_with_headers_with_dups( url: &str, vary_field: &str, @@ -3536,4 +3684,145 @@ mod test_cache { assert_eq!(headers["x-cache-status"], "hit"); assert_eq!(res.text().await.unwrap(), "hello world"); } + + // Ignored until H2 downstream gets the proxy task API + // (write_response_tasks blocks on flow control today). + // multi_thread needed for h2 connection driver tasks. 
+ #[tokio::test(flavor = "multi_thread")] + #[ignore] + async fn test_cache_h2_downstream_stalls() { + init(); + + use h2::client; + use http::Request; + use tokio::net::TcpStream; + use tokio::time::{timeout, Duration}; + + // Step 1: Connection 1 - Open h2 connection to h2c cache proxy (port 6154) and STALL + let tcp1 = TcpStream::connect("127.0.0.1:6154").await.unwrap(); + let (mut h2_client1, h2_conn1) = client::handshake(tcp1).await.unwrap(); + + tokio::spawn(async move { + if let Err(e) = h2_conn1.await { + eprintln!("H2 connection 1 error: {:?}", e); + } + }); + + // Request the cached resource on connection 1 + let request1 = Request::builder() + .uri("http://127.0.0.1/unique/test_h2_stall/download/") + .body(()) + .unwrap(); + + let (response1, _) = h2_client1.send_request(request1, true).unwrap(); + let response1 = response1.await.unwrap(); + assert_eq!(response1.status(), 200); + assert_eq!(response1.headers()["x-cache-status"], "miss"); + + let mut body1 = response1.into_body(); + + // Read first chunk but don't release flow control to stall connection 1 + let first_chunk = body1.data().await.unwrap().unwrap(); + assert!(!first_chunk.is_empty()); + + // Connection 2 - While conn 1 is stalled, try to get the same cached resource + let tcp2 = TcpStream::connect("127.0.0.1:6154").await.unwrap(); + let (mut h2_client2, h2_conn2) = client::handshake(tcp2).await.unwrap(); + + tokio::spawn(async move { + if let Err(e) = h2_conn2.await { + eprintln!("H2 connection 2 error: {:?}", e); + } + }); + + let request2 = Request::builder() + .uri("http://127.0.0.1/unique/test_h2_stall/download/") + .body(()) + .unwrap(); + + let (response2, _) = h2_client2.send_request(request2, true).unwrap(); + + // Try to read, proxy should not be blocked + let response2 = match timeout(Duration::from_secs(5), response2).await { + Ok(Ok(resp)) => resp, + Ok(Err(e)) => panic!("Connection 2 failed: {:?}", e), + Err(_) => panic!("Connection 2 timed out - proxy blocked without proxy task 
API!"), + }; + + assert_eq!(response2.status(), 200); + assert_eq!(response2.headers()["x-cache-status"], "hit"); + + // Read full response from connection 2 + let mut body2 = response2.into_body(); + let mut received2 = Vec::new(); + while let Some(Ok(chunk)) = timeout(Duration::from_secs(5), body2.data()) + .await + .expect("should not time out waiting for data") + { + let len = chunk.len(); + received2.extend_from_slice(&chunk); + body2.flow_control().release_capacity(len).unwrap(); + } + + assert_eq!( + received2.len(), + 4 * 1024 * 1024, + "Connection 2 should receive full cached response" + ); + + // Clean up: unstall connection 1 + body1 + .flow_control() + .release_capacity(first_chunk.len()) + .unwrap(); + } + + // Test cache population from H2 upstream origin with H1 downstream. + #[tokio::test] + async fn test_cache_upstream_h2_downstream_h1() { + init(); + + let test_url = "http://127.0.0.1:6148/unique/test_h2_upstream/download/"; + + // Step 1: Populate cache from H2 origin (cache miss) + let client = reqwest::Client::new(); + let res = client + .get(test_url) + .header("x-h2", "true") + .header("x-lock", "true") + .header("x-set-cache-control", "public, max-age=60") + .send() + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!(res.headers()["x-cache-status"], "miss"); + assert_eq!(res.headers()["origin-http2"], "h2c"); + + let body = res.bytes().await.unwrap(); + assert_eq!( + body.len(), + 4 * 1024 * 1024, + "Should receive full 4MB response" + ); + + // Step 2: Request again and verify cache hit + let res = client + .get(test_url) + .header("x-h2", "true") + .header("x-set-cache-control", "public, max-age=60") + .send() + .await + .unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!(res.headers()["x-cache-status"], "hit"); + + let body = res.bytes().await.unwrap(); + assert_eq!( + body.len(), + 4 * 1024 * 1024, + "Should receive full 4MB from cache" + ); + } } diff --git a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf 
b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf index f19c974c..969695eb 100644 --- a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf +++ b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf @@ -311,7 +311,6 @@ http { location /download/ { content_by_lua_block { - ngx.req.read_body() local body = string.rep("A", 4194304) ngx.header["Content-Length"] = #body ngx.print(body) diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs index 9361182e..0dccb6dd 100644 --- a/pingora-proxy/tests/utils/server_utils.rs +++ b/pingora-proxy/tests/utils/server_utils.rs @@ -253,8 +253,19 @@ impl ProxyHttp for ExampleProxyHttp { session: &mut Session, _ctx: &mut Self::CTX, ) -> Result<()> { - let req = session.req_header(); - let downstream_compression = req.headers.get("x-downstream-compression").is_some(); + let proxy_tasks_enabled = session + .req_header() + .headers + .get("x-proxy-tasks-enabled") + .is_some(); + if proxy_tasks_enabled { + session.downstream_session.set_proxy_tasks_enabled(true); + } + let downstream_compression = session + .req_header() + .headers + .get("x-downstream-compression") + .is_some(); if downstream_compression { session .downstream_modules_ctx @@ -658,6 +669,12 @@ impl ProxyHttp for ExampleProxyCache { upstream_response.remove_header(&CONTENT_LENGTH); upstream_response.remove_header(&TRANSFER_ENCODING); } + // Allow tests to inject Cache-Control into the upstream response + if let Some(cc) = session.req_header().headers.get("x-set-cache-control") { + upstream_response + .insert_header(http::header::CACHE_CONTROL, cc) + .unwrap(); + } Ok(()) } @@ -823,6 +840,15 @@ fn test_main() { pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyCache {}); proxy_service_cache.add_tcp("0.0.0.0:6148"); + // H2C-enabled cache proxy on port 6154 + let mut proxy_service_cache_h2c = + pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyCache {}); + let cache_h2c_logic = 
proxy_service_cache_h2c.app_logic_mut().unwrap(); + let mut cache_h2c_options = HttpServerOptions::default(); + cache_h2c_options.h2c = true; + cache_h2c_logic.server_options = Some(cache_h2c_options); + proxy_service_cache_h2c.add_tcp("0.0.0.0:6154"); + #[cfg(feature = "any_tls")] { let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")); @@ -839,6 +865,7 @@ fn test_main() { Box::new(proxy_service_http), Box::new(proxy_service_http_connect), Box::new(proxy_service_cache), + Box::new(proxy_service_cache_h2c), ]; if let Some(proxy_service_https) = proxy_service_https_opt { diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index d9fb57c2..7de8640f 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -29,6 +29,7 @@ pingora-load-balancing = { version = "0.8.0", path = "../pingora-load-balancing" pingora-proxy = { version = "0.8.0", path = "../pingora-proxy", optional = true, default-features = false } pingora-cache = { version = "0.8.0", path = "../pingora-cache", optional = true, default-features = false } + # Only used for documenting features, but doesn't work in any other dependency # group :( document-features = { version = "0.2.10", optional = true } @@ -42,6 +43,7 @@ hyper = "0.14" async-trait = { workspace = true } http = { workspace = true } log = { workspace = true } +pingora-prometheus = { version = "0.8.0", path = "../pingora-prometheus" } prometheus = "0.14" once_cell = { workspace = true } bytes = { workspace = true } @@ -152,5 +154,4 @@ document-features = [ "sentry", "connection_filter" ] -prometheus = ["pingora-core/prometheus"] trace = ["pingora-cache?/trace", "pingora-proxy?/trace"] diff --git a/pingora/examples/graceful_upgrade.rs b/pingora/examples/graceful_upgrade.rs new file mode 100644 index 00000000..5a64ff7f --- /dev/null +++ b/pingora/examples/graceful_upgrade.rs @@ -0,0 +1,186 @@ +// Copyright 2026 Cloudflare, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! # Graceful Upgrade Example +//! +//! Demonstrates the `daemon_wait_for_ready` feature, which coordinates graceful process upgrades +//! by ensuring the new process is fully bootstrapped before the old one begins shutting down. +//! +//! ## Background +//! +//! In a standard daemonized pingora service, the parent process exits immediately after the +//! daemon fork. During a graceful upgrade, the process manager sends SIGQUIT to the old process +//! as soon as the new process's parent exits — potentially before the new process has finished +//! initializing its backends, consistent hash rings, or other state. This can cause a brief +//! window of 502s. +//! +//! With `daemon_wait_for_ready = true`, the parent instead waits for the daemon to send SIGUSR1 +//! before exiting. The process manager only proceeds to stop the old process once the new one +//! signals that it is ready to serve traffic. +//! +//! ## Service startup order +//! +//! This example sets up the following dependency chain: +//! +//! ```text +//! BackendDiscoveryService HashRingService +//! \ / +//! \ / +//! BootstrapService (socket transfer + SIGUSR1 to parent) +//! ``` +//! +//! The bootstrap service — which handles transferring listening sockets from the old process and +//! sending SIGUSR1 to the parent to signal readiness — only runs after both slow initialization +//! services have completed. 
This ensures the parent never exits until the new process is truly +//! ready to serve traffic. +//! +//! ## Usage +//! +//! ```bash +//! # Run interactively (no daemonization) +//! cargo run --example graceful_upgrade -p pingora +//! +//! # Run as a daemon +//! cargo run --example graceful_upgrade -p pingora -- -d +//! +//! # Graceful upgrade of a running daemon instance +//! cargo run --example graceful_upgrade -p pingora -- -d -u +//! ``` + +use async_trait::async_trait; +use bytes::Bytes; +use clap::Parser; +use http::{Response, StatusCode}; +use log::info; +use std::num::NonZeroU64; +use std::time::Duration; +use tokio::time::sleep; + +use pingora::apps::http_app::ServeHttp; +use pingora::prelude::Opt; +use pingora::protocols::http::ServerSession; +use pingora::server::configuration::ServerConf; +use pingora::server::{Server, ShutdownWatch}; +use pingora::services::background::{background_service, BackgroundService}; +use pingora::services::listening::Service as ListeningService; + +/// Simulates slow backend discovery — e.g. resolving upstream endpoints from a service registry. +pub struct BackendDiscoveryService; + +#[async_trait] +impl BackgroundService for BackendDiscoveryService { + async fn start(&self, _shutdown: ShutdownWatch) { + info!("BackendDiscoveryService: discovering backends..."); + sleep(Duration::from_secs(2)).await; + info!("BackendDiscoveryService: backends ready"); + } +} + +/// Simulates slow consistent hash ring construction. Runs in parallel with +/// `BackendDiscoveryService`; bootstrap waits for both to complete. +pub struct HashRingService; + +#[async_trait] +impl BackgroundService for HashRingService { + async fn start(&self, _shutdown: ShutdownWatch) { + info!("HashRingService: building consistent hash ring..."); + sleep(Duration::from_secs(3)).await; + info!("HashRingService: hash ring ready"); + } +} + +/// A minimal HTTP service that responds to every request with 200 OK. 
+/// +/// Accepts an optional `sleep` query parameter specifying how many seconds to wait before +/// responding (e.g. `GET /?sleep=20`). This makes in-flight requests easy to observe during a +/// graceful upgrade: a request with a long sleep that arrives just before the upgrade begins will +/// still be running when the new process starts up, demonstrating that the old process keeps +/// serving until all connections are drained. +pub struct HelloApp; + +#[async_trait] +impl ServeHttp for HelloApp { + async fn response(&self, http_stream: &mut ServerSession) -> Response<Vec<u8>> { + let delay_secs = http_stream + .req_header() + .uri + .query() + .and_then(|q| { + q.split('&').find_map(|pair| { + let (key, val) = pair.split_once('=')?; + if key == "sleep" { + val.parse::<u64>().ok() + } else { + None + } + }) + }) + .unwrap_or(0); + + if delay_secs > 0 { + sleep(Duration::from_secs(delay_secs)).await; + } + + let body = Bytes::from("hello from graceful_upgrade example\n"); + Response::builder() + .status(StatusCode::OK) + .header(http::header::CONTENT_TYPE, "text/plain") + .header(http::header::CONTENT_LENGTH, body.len()) + .body(body.to_vec()) + .unwrap() + } +} + +fn main() { + env_logger::init(); + + let opt = Some(Opt::parse()); + + // Build a ServerConf with daemon_wait_for_ready enabled. + // + // When the server is started with -d (daemon mode), the parent process waits for SIGUSR1 + // before exiting. The daemon sends SIGUSR1 only after the bootstrap service completes — + // which in this example means after both slow services have signaled readiness. + let conf = ServerConf { + daemon: true, + daemon_wait_for_ready: true, + daemon_ready_timeout_seconds: NonZeroU64::new(60), + ..ServerConf::default() + }; + + let mut server = Server::new_with_opt_and_conf(opt, conf); + + // Add the slow initialization services and retain their handles so bootstrap can depend + // on them. Both run in parallel; the slowest (HashRingService at 3s) sets the pace. 
+ let backend_handle = server.add_service(background_service( + "backend_discovery", + BackendDiscoveryService, + )); + let hash_ring_handle = server.add_service(background_service("hash_ring", HashRingService)); + + // bootstrap_as_a_service() registers the bootstrap service (socket transfer from the old + // process + SIGUSR1 to the parent) and returns its ServiceHandle. Declaring the slow + // services as dependencies ensures bootstrap only runs once both are ready. + let bootstrap_handle = server.bootstrap_as_a_service(); + bootstrap_handle.add_dependencies([&backend_handle, &hash_ring_handle]); + + let mut http_service = ListeningService::new("hello_http".to_string(), HelloApp); + http_service.add_tcp("0.0.0.0:8000"); + + server + .add_service(http_service) + .add_dependency(backend_handle); + + server.run_forever(); +} diff --git a/pingora/examples/server.rs b/pingora/examples/server.rs index 1e299140..37a246cd 100644 --- a/pingora/examples/server.rs +++ b/pingora/examples/server.rs @@ -20,8 +20,6 @@ use pingora::protocols::TcpKeepalive; use pingora::server::configuration::Opt; use pingora::server::{Server, ShutdownWatch}; use pingora::services::background::{background_service, BackgroundService}; -#[cfg(feature = "prometheus")] -use pingora::services::listening::Service as ListeningService; use pingora::services::ServiceWithDependents; use async_trait::async_trait; @@ -187,9 +185,7 @@ pub fn main() { &key_path, ); - #[cfg(feature = "prometheus")] - let mut prometheus_service_http = ListeningService::prometheus_http_service(); - #[cfg(feature = "prometheus")] + let mut prometheus_service_http = pingora_prometheus::prometheus_http_service(); prometheus_service_http.add_tcp("127.0.0.1:6150"); let background_service = background_service("example", ExampleBackgroundService {}); @@ -199,7 +195,6 @@ pub fn main() { Box::new(echo_service_http), Box::new(proxy_service), Box::new(proxy_service_ssl), - #[cfg(feature = "prometheus")] Box::new(prometheus_service_http), 
Box::new(background_service), ];