diff --git a/granian/_futures.py b/granian/_futures.py index 83006ea8..21e4936c 100644 --- a/granian/_futures.py +++ b/granian/_futures.py @@ -2,7 +2,7 @@ def future_watcher_wrapper(inner): async def future_watcher(watcher): try: await inner(watcher.scope, watcher.proto) - except Exception: + except BaseException: watcher.err() raise watcher.done() diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index 327bad46..2f01c44f 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -3,7 +3,7 @@ use pyo3_asyncio::TaskLocals; use tokio::sync::oneshot; use super::{ - io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol}, + io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport}, types::ASGIScope as Scope, }; use crate::{ @@ -336,7 +336,7 @@ macro_rules! call_impl_rtb_ws { ws: HyperWebsocket, upgrade: UpgradeData, scope: Scope, - ) -> oneshot::Receiver { + ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); @@ -357,7 +357,7 @@ macro_rules! call_impl_rtt_ws { ws: HyperWebsocket, upgrade: UpgradeData, scope: Scope, - ) -> oneshot::Receiver { + ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); diff --git a/src/asgi/http.rs b/src/asgi/http.rs index 679cee66..c41f575a 100644 --- a/src/asgi/http.rs +++ b/src/asgi/http.rs @@ -82,17 +82,22 @@ macro_rules! handle_request_with_ws { let tx_ref = restx.clone(); match $handler_ws(callback, rt, ws, UpgradeData::new(res, restx), scope).await { - Ok(consumed) => { - if !consumed { - let _ = tx_ref - .send( - ResponseBuilder::new() - .status(StatusCode::FORBIDDEN) - .header(HK_SERVER, HV_SERVER) - .body(empty_body()) - .unwrap(), - ) - .await; + Ok(mut detached) => { + match detached.consumed { + false => { + let _ = tx_ref + .send( + ResponseBuilder::new() + .status(StatusCode::FORBIDDEN) + .header(HK_SERVER, HV_SERVER) + .body(empty_body()) + .unwrap(), + ) + .await; + } + true => { + detached.close().await; + } }; } _ => { diff --git a/src/asgi/io.rs b/src/asgi/io.rs index 01a874d1..c7c04bf9 100644 --- a/src/asgi/io.rs +++ b/src/asgi/io.rs @@ -1,8 +1,5 @@ -use futures::{ - sink::SinkExt, - stream::{SplitSink, SplitStream}, - StreamExt, TryStreamExt, -}; +use anyhow::Result; +use futures::{sink::SinkExt, StreamExt, TryStreamExt}; use http_body_util::BodyExt; use hyper::{ body, @@ -14,9 +11,8 @@ use pyo3::types::{PyBytes, PyDict}; use std::{borrow::Cow, sync::Arc}; use tokio::{ fs::File, - sync::{mpsc, oneshot, Mutex}, + sync::{mpsc, oneshot, Mutex, RwLock}, }; -use tokio_tungstenite::WebSocketStream; use tokio_util::io::ReaderStream; use tungstenite::Message; @@ -28,7 +24,7 @@ use crate::{ conversion::BytesToPy, http::{response_404, HTTPRequest, HTTPResponse, HTTPResponseBody, HV_SERVER}, runtime::{empty_future_into_py, future_into_py_futlike, future_into_py_iter, RuntimeRef}, - ws::{HyperWebsocket, UpgradeData}, + ws::{HyperWebsocket, UpgradeData, WSRxStream, WSTxStream}, }; const EMPTY_BYTES: Cow<[u8]> = Cow::Borrowed(b""); @@ -231,50 +227,81 @@ impl ASGIHTTPProtocol { } } +pub(crate) struct WebsocketDetachedTransport { + pub consumed: bool, + rx: Option, + tx: Option, +} + +impl WebsocketDetachedTransport { + pub fn new(consumed: bool, rx: Option, tx: Option) -> Self { + Self { consumed, rx, tx } + } + + pub async fn close(&mut self) { + if let Some(mut tx) = self.tx.take() { + if let Err(err) = tx.close().await { + log::info!("Failed to close websocket with error {:?}", err); + } + } + drop(self.rx.take()); + } +} + #[pyclass(module = "granian._granian")] pub(crate) struct ASGIWebsocketProtocol { rt: RuntimeRef, - tx: Option>, + tx: Option>, websocket: Option, upgrade: Option, - ws_tx: Arc>, Message>>>>, - ws_rx: Arc>>>>>, - accepted: Arc>, - closed: bool, + ws_rx: Arc>>, + ws_tx: Arc>>, + accepted: Arc>, + closed: Arc>, } impl ASGIWebsocketProtocol { - pub fn new(rt: RuntimeRef, tx: oneshot::Sender, websocket: HyperWebsocket, upgrade: UpgradeData) -> Self { + pub fn new( + rt: RuntimeRef, + tx: oneshot::Sender, + websocket: HyperWebsocket, + upgrade: UpgradeData, + ) -> Self { Self { rt, tx: Some(tx), websocket: Some(websocket), upgrade: Some(upgrade), - ws_tx: Arc::new(Mutex::new(None)), ws_rx: Arc::new(Mutex::new(None)), - accepted: Arc::new(Mutex::new(false)), - closed: false, + ws_tx: Arc::new(Mutex::new(None)), + accepted: Arc::new(RwLock::new(false)), + closed: Arc::new(RwLock::new(false)), } } #[inline(always)] fn accept<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { - let mut upgrade = self.upgrade.take().unwrap(); - let websocket = self.websocket.take().unwrap(); + let upgrade = self.upgrade.take(); + let websocket = self.websocket.take(); let accepted = self.accepted.clone(); - let tx = self.ws_tx.clone(); let rx = self.ws_rx.clone(); + let tx = self.ws_tx.clone(); + future_into_py_iter(self.rt.clone(), py, async move { - if (upgrade.send().await).is_ok() { - if let Ok(stream) = websocket.await { - let mut wtx = tx.lock().await; - let mut wrx = rx.lock().await; - let mut accepted = accepted.lock().await; - let (tx, rx) = stream.split(); - *wtx = Some(tx); - *wrx = Some(rx); - *accepted = true; - return Ok(()); + if let Some(mut upgrade) = upgrade { + if (upgrade.send().await).is_ok() { + if let Some(websocket) = websocket { + if let Ok(stream) = websocket.await { + let mut wtx = tx.lock().await; + let mut wrx = rx.lock().await; + let mut accepted = accepted.write().await; + let (tx, rx) = stream.split(); + *wtx = Some(tx); + *wrx = Some(rx); + *accepted = true; + return Ok(()); + } + } } } error_flow!() @@ -285,29 +312,48 @@ impl ASGIWebsocketProtocol { fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { let transport = self.ws_tx.clone(); let message = ws_message_into_rs(py, data); - future_into_py_iter(self.rt.clone(), py, async move { - if let Ok(message) = message { - if let Some(ws) = &mut *(transport.lock().await) { - if (ws.send(message).await).is_ok() { - return Ok(()); - } - }; - }; - error_flow!() + let closed = self.closed.clone(); + + future_into_py_futlike(self.rt.clone(), py, async move { + match message { + Ok(message) => { + if let Some(ws) = &mut *(transport.lock().await) { + match ws.send(message).await { + Ok(()) => return Ok(()), + _ => { + let closed = closed.read().await; + if *closed { + log::info!("Attempted to write to a closed websocket"); + return Ok(()); + } + } + }; + }; + error_flow!() + } + Err(err) => Err(err), + } }) } #[inline(always)] fn close<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { - self.closed = true; - let transport = self.ws_tx.clone(); + let closed = self.closed.clone(); + let ws_rx = self.ws_rx.clone(); + let ws_tx = self.ws_tx.clone(); + future_into_py_iter(self.rt.clone(), py, async move { - if let Some(ws) = &mut *(transport.lock().await) { - if (ws.close().await).is_ok() { - return Ok(()); + match ws_tx.lock().await.take() { + Some(tx) => { + let mut closed = closed.write().await; + *closed = true; + WebsocketDetachedTransport::new(true, ws_rx.lock().await.take(), Some(tx)) + .close() + .await; + Ok(()) } - }; - error_flow!() + _ => error_flow!(), + } }) } @@ -315,34 +361,47 @@ impl ASGIWebsocketProtocol { self.upgrade.is_none() } - pub fn tx(&mut self) -> (Option>, bool) { - (self.tx.take(), self.consumed()) + pub fn tx( + &mut self, + ) -> ( + Option>, + WebsocketDetachedTransport, + ) { + let mut ws_rx = self.ws_rx.blocking_lock(); + let mut ws_tx = self.ws_tx.blocking_lock(); + ( + self.tx.take(), + WebsocketDetachedTransport::new(self.consumed(), ws_rx.take(), ws_tx.take()), + ) } } #[pymethods] impl ASGIWebsocketProtocol { fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { - let transport = self.ws_rx.clone(); let accepted = self.accepted.clone(); - let closed = self.closed; + let closed = self.closed.clone(); + let transport = self.ws_rx.clone(); + future_into_py_futlike(self.rt.clone(), py, async move { - let accepted = accepted.lock().await; - match (*accepted, closed) { - (false, false) => { - return Python::with_gil(|py| { - let dict = PyDict::new(py); - dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?; - Ok(dict.to_object(py)) - }) - } - (true, false) => {} - _ => return error_flow!(), + let accepted = accepted.read().await; + if !*accepted { + return Python::with_gil(|py| { + let dict = PyDict::new(py); + dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?; + Ok(dict.to_object(py)) + }); } + if let Some(ws) = &mut *(transport.lock().await) { while let Some(recv) = ws.next().await { match recv { Ok(Message::Ping(_)) => continue, + Ok(message @ Message::Close(_)) => { + let mut closed = closed.write().await; + *closed = true; + return ws_message_into_py(message); + } Ok(message) => return ws_message_into_py(message), _ => break, } @@ -353,12 +412,11 @@ impl ASGIWebsocketProtocol { } fn send<'p>(&mut self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { - match (adapt_message_type(data), self.closed) { - (Ok(ASGIMessageType::WSAccept), _) => self.accept(py), - (Ok(ASGIMessageType::WSClose), false) => self.close(py), - (Ok(ASGIMessageType::WSMessage), false) => self.send_message(py, data), - (Err(err), _) => Err(err.into()), - _ => error_message!(), + match adapt_message_type(data) { + Ok(ASGIMessageType::WSAccept) => self.accept(py), + Ok(ASGIMessageType::WSClose) => self.close(py), + Ok(ASGIMessageType::WSMessage) => self.send_message(py, data), + _ => future_into_py_iter::<_, _, PyErr>(self.rt.clone(), py, async { error_message!() }), } } } diff --git a/src/callbacks.rs b/src/callbacks.rs index 88cb663d..b06308c8 100644 --- a/src/callbacks.rs +++ b/src/callbacks.rs @@ -158,12 +158,18 @@ impl PyFutureAwaitable { } fn result(pyself: PyRef<'_, Self>) -> PyResult { + if pyself.py_cancelled { + return Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled.")); + } + match &pyself.result { Some(res) => { let py = pyself.py(); res.as_ref().map(|v| v.clone_ref(py)).map_err(|err| err.clone_ref(py)) } - _ => Ok(pyself.py().None()), + _ => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err( + "Result is not ready.", + )), } } diff --git a/src/rsgi/callbacks.rs b/src/rsgi/callbacks.rs index 714057e7..8f64780f 100644 --- a/src/rsgi/callbacks.rs +++ b/src/rsgi/callbacks.rs @@ -3,7 +3,7 @@ use pyo3_asyncio::TaskLocals; use tokio::sync::oneshot; use super::{ - io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol}, + io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport}, types::{PyResponse, PyResponseBody, RSGIScope as Scope}, }; use crate::{ @@ -176,9 +176,7 @@ impl CallbackRunnerWebsocket { macro_rules! callback_impl_done_ws { ($self:expr, $py:expr) => { if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() { - if let (Some(tx), res) = proto.tx() { - let _ = tx.send(res); - } + let _ = proto.close($py, None); } }; } @@ -313,7 +311,7 @@ macro_rules! call_impl_rtb_ws { ws: HyperWebsocket, upgrade: UpgradeData, scope: Scope, - ) -> oneshot::Receiver<(i32, bool)> { + ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); @@ -334,7 +332,7 @@ macro_rules! call_impl_rtt_ws { ws: HyperWebsocket, upgrade: UpgradeData, scope: Scope, - ) -> oneshot::Receiver<(i32, bool)> { + ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); diff --git a/src/rsgi/http.rs b/src/rsgi/http.rs index ff3e1d24..61c188ad 100644 --- a/src/rsgi/http.rs +++ b/src/rsgi/http.rs @@ -84,22 +84,23 @@ macro_rules! handle_request_with_ws { let tx_ref = restx.clone(); match $handler_ws(callback, rt, ws, UpgradeData::new(res, restx), scope).await { - Ok((status, consumed)) => { - if !consumed { + Ok((status, consumed, handle)) => match (consumed, handle) { + (false, _) => { let _ = tx_ref .send( ResponseBuilder::new() - .status( - StatusCode::from_u16(status as u16) - .unwrap_or(StatusCode::FORBIDDEN), - ) + .status(status as u16) .header(HK_SERVER, HV_SERVER) .body(empty_body()) .unwrap(), ) .await; } - } + (true, Some(handle)) => { + let _ = handle.await; + } + _ => {} + }, _ => { log::error!("RSGI protocol failure"); let _ = tx_ref.send(response_500()).await; diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index f9bad68f..73f59436 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -1,15 +1,10 @@ -use futures::{ - sink::SinkExt, - stream::{SplitSink, SplitStream}, - StreamExt, TryStreamExt, -}; +use futures::{sink::SinkExt, StreamExt, TryStreamExt}; use http_body_util::BodyExt; use hyper::body; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyString}; use std::{borrow::Cow, sync::Arc}; use tokio::sync::{mpsc, oneshot, Mutex}; -use tokio_tungstenite::WebSocketStream; use tungstenite::Message; use super::{ @@ -20,9 +15,11 @@ use crate::{ conversion::BytesToPy, http::HTTPRequest, runtime::{future_into_py_futlike, future_into_py_iter, Runtime, RuntimeRef}, - ws::{HyperWebsocket, UpgradeData}, + ws::{HyperWebsocket, UpgradeData, WSRxStream, WSStream, WSTxStream}, }; +pub(crate) type WebsocketDetachedTransport = (i32, bool, Option>); + #[pyclass(module = "granian._granian")] pub(crate) struct RSGIHTTPStreamTransport { rt: RuntimeRef, @@ -180,27 +177,37 @@ impl RSGIHTTPProtocol { #[pyclass(module = "granian._granian")] pub(crate) struct RSGIWebsocketTransport { rt: RuntimeRef, - tx: Arc>, Message>>>, - rx: Arc>>>>, + tx: Arc>, + rx: Arc>, + closed: bool, } impl RSGIWebsocketTransport { - pub fn new(rt: RuntimeRef, transport: WebSocketStream>) -> Self { + pub fn new(rt: RuntimeRef, transport: WSStream) -> Self { let (tx, rx) = transport.split(); Self { rt, tx: Arc::new(Mutex::new(tx)), rx: Arc::new(Mutex::new(rx)), + closed: false, } } - pub fn close(&self) { - let stream = self.tx.clone(); - self.rt.spawn(async move { - if let Ok(mut stream) = stream.try_lock() { - let _ = stream.close().await; + pub fn close(&mut self) -> Option> { + if self.closed { + return None; + } + self.closed = true; + + let tx = self.tx.clone(); + let handle = self.rt.spawn(async move { + if let Ok(mut tx) = tx.try_lock() { + if let Err(err) = tx.close().await { + log::info!("Failed to close websocket with error {:?}", err); + } } }); + Some(handle) } } @@ -254,7 +261,7 @@ impl RSGIWebsocketTransport { #[pyclass(module = "granian._granian")] pub(crate) struct RSGIWebsocketProtocol { rt: RuntimeRef, - tx: Option>, + tx: Option>, websocket: Arc>, upgrade: Option, transport: Arc>>>, @@ -264,7 +271,7 @@ pub(crate) struct RSGIWebsocketProtocol { impl RSGIWebsocketProtocol { pub fn new( rt: RuntimeRef, - tx: oneshot::Sender<(i32, bool)>, + tx: oneshot::Sender, websocket: HyperWebsocket, upgrade: UpgradeData, ) -> Self { @@ -281,10 +288,6 @@ impl RSGIWebsocketProtocol { fn consumed(&self) -> bool { self.upgrade.is_none() } - - pub fn tx(&mut self) -> (Option>, (i32, bool)) { - (self.tx.take(), (self.status, self.consumed())) - } } enum WebsocketMessageType { @@ -344,18 +347,19 @@ impl WebsocketInboundTextMessage { #[pymethods] impl RSGIWebsocketProtocol { #[pyo3(signature = (status=None))] - fn close(&mut self, py: Python, status: Option) -> PyResult<()> { + pub fn close(&mut self, py: Python, status: Option) -> PyResult<()> { self.status = status.unwrap_or(0); if let Some(tx) = self.tx.take() { + let mut handle = None; if let Ok(mut transport) = self.transport.try_lock() { if let Some(transport) = transport.take() { - if let Ok(trx) = transport.try_borrow_mut(py) { - trx.close(); + if let Ok(mut trx) = transport.try_borrow_mut(py) { + handle = trx.close(); } } } - let _ = tx.send((self.status, self.consumed())); + let _ = tx.send((self.status, self.consumed(), handle)); } Ok(()) } diff --git a/src/ws.rs b/src/ws.rs index 8035078d..b1a98c38 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -18,8 +18,13 @@ use tungstenite::{ protocol::{Role, WebSocketConfig}, }; +use super::http::HTTPResponse; use super::utils::header_contains_value; +pub(crate) type WSStream = WebSocketStream>; +pub(crate) type WSRxStream = futures::stream::SplitStream; +pub(crate) type WSTxStream = futures::stream::SplitSink; + #[pin_project] #[derive(Debug)] pub struct HyperWebsocket { @@ -29,7 +34,7 @@ pub struct HyperWebsocket { } impl Future for HyperWebsocket { - type Output = Result>, tungstenite::Error>; + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let this = self.project(); @@ -53,16 +58,12 @@ impl Future for HyperWebsocket { pub(crate) struct UpgradeData { response_builder: Option, - response_tx: - Option>>>, + response_tx: Option>, pub consumed: bool, } impl UpgradeData { - pub fn new( - response_builder: Builder, - response_tx: mpsc::Sender>>, - ) -> Self { + pub fn new(response_builder: Builder, response_tx: mpsc::Sender) -> Self { Self { response_builder: Some(response_builder), response_tx: Some(response_tx), @@ -70,12 +71,7 @@ impl UpgradeData { } } - pub async fn send( - &mut self, - ) -> Result< - (), - mpsc::error::SendError>>, - > { + pub async fn send(&mut self) -> Result<(), mpsc::error::SendError> { let res = self .response_builder .take()