Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async websockets related ehnancements #193

Merged
merged 4 commits into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion granian/_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ def future_watcher_wrapper(inner):
async def future_watcher(watcher):
try:
await inner(watcher.scope, watcher.proto)
except Exception:
except BaseException:
watcher.err()
raise
watcher.done()
Expand Down
6 changes: 3 additions & 3 deletions src/asgi/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use pyo3_asyncio::TaskLocals;
use tokio::sync::oneshot;

use super::{
io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol},
io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport},
types::ASGIScope as Scope,
};
use crate::{
Expand Down Expand Up @@ -336,7 +336,7 @@ macro_rules! call_impl_rtb_ws {
ws: HyperWebsocket,
upgrade: UpgradeData,
scope: Scope,
) -> oneshot::Receiver<bool> {
) -> oneshot::Receiver<WebsocketDetachedTransport> {
let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);

Expand All @@ -357,7 +357,7 @@ macro_rules! call_impl_rtt_ws {
ws: HyperWebsocket,
upgrade: UpgradeData,
scope: Scope,
) -> oneshot::Receiver<bool> {
) -> oneshot::Receiver<WebsocketDetachedTransport> {
let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);

Expand Down
27 changes: 16 additions & 11 deletions src/asgi/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,22 @@ macro_rules! handle_request_with_ws {
let tx_ref = restx.clone();

match $handler_ws(callback, rt, ws, UpgradeData::new(res, restx), scope).await {
Ok(consumed) => {
if !consumed {
let _ = tx_ref
.send(
ResponseBuilder::new()
.status(StatusCode::FORBIDDEN)
.header(HK_SERVER, HV_SERVER)
.body(empty_body())
.unwrap(),
)
.await;
Ok(mut detached) => {
match detached.consumed {
false => {
let _ = tx_ref
.send(
ResponseBuilder::new()
.status(StatusCode::FORBIDDEN)
.header(HK_SERVER, HV_SERVER)
.body(empty_body())
.unwrap(),
)
.await;
}
true => {
detached.close().await;
}
};
}
_ => {
Expand Down
192 changes: 125 additions & 67 deletions src/asgi/io.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use futures::{
sink::SinkExt,
stream::{SplitSink, SplitStream},
StreamExt, TryStreamExt,
};
use anyhow::Result;
use futures::{sink::SinkExt, StreamExt, TryStreamExt};
use http_body_util::BodyExt;
use hyper::{
body,
Expand All @@ -14,9 +11,8 @@ use pyo3::types::{PyBytes, PyDict};
use std::{borrow::Cow, sync::Arc};
use tokio::{
fs::File,
sync::{mpsc, oneshot, Mutex},
sync::{mpsc, oneshot, Mutex, RwLock},
};
use tokio_tungstenite::WebSocketStream;
use tokio_util::io::ReaderStream;
use tungstenite::Message;

Expand All @@ -28,7 +24,7 @@ use crate::{
conversion::BytesToPy,
http::{response_404, HTTPRequest, HTTPResponse, HTTPResponseBody, HV_SERVER},
runtime::{empty_future_into_py, future_into_py_futlike, future_into_py_iter, RuntimeRef},
ws::{HyperWebsocket, UpgradeData},
ws::{HyperWebsocket, UpgradeData, WSRxStream, WSTxStream},
};

const EMPTY_BYTES: Cow<[u8]> = Cow::Borrowed(b"");
Expand Down Expand Up @@ -231,50 +227,81 @@ impl ASGIHTTPProtocol {
}
}

pub(crate) struct WebsocketDetachedTransport {
pub consumed: bool,
rx: Option<WSRxStream>,
tx: Option<WSTxStream>,
}

impl WebsocketDetachedTransport {
pub fn new(consumed: bool, rx: Option<WSRxStream>, tx: Option<WSTxStream>) -> Self {
Self { consumed, rx, tx }
}

pub async fn close(&mut self) {
if let Some(mut tx) = self.tx.take() {
if let Err(err) = tx.close().await {
log::info!("Failed to close websocket with error {:?}", err);
}
}
drop(self.rx.take());
}
}

#[pyclass(module = "granian._granian")]
pub(crate) struct ASGIWebsocketProtocol {
rt: RuntimeRef,
tx: Option<oneshot::Sender<bool>>,
tx: Option<oneshot::Sender<WebsocketDetachedTransport>>,
websocket: Option<HyperWebsocket>,
upgrade: Option<UpgradeData>,
ws_tx: Arc<Mutex<Option<SplitSink<WebSocketStream<hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>>, Message>>>>,
ws_rx: Arc<Mutex<Option<SplitStream<WebSocketStream<hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>>>>>>,
accepted: Arc<Mutex<bool>>,
closed: bool,
ws_rx: Arc<Mutex<Option<WSRxStream>>>,
ws_tx: Arc<Mutex<Option<WSTxStream>>>,
accepted: Arc<tokio::sync::RwLock<bool>>,
closed: Arc<tokio::sync::RwLock<bool>>,
}

impl ASGIWebsocketProtocol {
pub fn new(rt: RuntimeRef, tx: oneshot::Sender<bool>, websocket: HyperWebsocket, upgrade: UpgradeData) -> Self {
pub fn new(
rt: RuntimeRef,
tx: oneshot::Sender<WebsocketDetachedTransport>,
websocket: HyperWebsocket,
upgrade: UpgradeData,
) -> Self {
Self {
rt,
tx: Some(tx),
websocket: Some(websocket),
upgrade: Some(upgrade),
ws_tx: Arc::new(Mutex::new(None)),
ws_rx: Arc::new(Mutex::new(None)),
accepted: Arc::new(Mutex::new(false)),
closed: false,
ws_tx: Arc::new(Mutex::new(None)),
accepted: Arc::new(RwLock::new(false)),
closed: Arc::new(RwLock::new(false)),
}
}

#[inline(always)]
fn accept<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> {
let mut upgrade = self.upgrade.take().unwrap();
let websocket = self.websocket.take().unwrap();
let upgrade = self.upgrade.take();
let websocket = self.websocket.take();
let accepted = self.accepted.clone();
let tx = self.ws_tx.clone();
let rx = self.ws_rx.clone();
let tx = self.ws_tx.clone();

future_into_py_iter(self.rt.clone(), py, async move {
if (upgrade.send().await).is_ok() {
if let Ok(stream) = websocket.await {
let mut wtx = tx.lock().await;
let mut wrx = rx.lock().await;
let mut accepted = accepted.lock().await;
let (tx, rx) = stream.split();
*wtx = Some(tx);
*wrx = Some(rx);
*accepted = true;
return Ok(());
if let Some(mut upgrade) = upgrade {
if (upgrade.send().await).is_ok() {
if let Some(websocket) = websocket {
if let Ok(stream) = websocket.await {
let mut wtx = tx.lock().await;
let mut wrx = rx.lock().await;
let mut accepted = accepted.write().await;
let (tx, rx) = stream.split();
*wtx = Some(tx);
*wrx = Some(rx);
*accepted = true;
return Ok(());
}
}
}
}
error_flow!()
Expand All @@ -285,64 +312,96 @@ impl ASGIWebsocketProtocol {
fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
let transport = self.ws_tx.clone();
let message = ws_message_into_rs(py, data);
future_into_py_iter(self.rt.clone(), py, async move {
if let Ok(message) = message {
if let Some(ws) = &mut *(transport.lock().await) {
if (ws.send(message).await).is_ok() {
return Ok(());
}
};
};
error_flow!()
let closed = self.closed.clone();

future_into_py_futlike(self.rt.clone(), py, async move {
match message {
Ok(message) => {
if let Some(ws) = &mut *(transport.lock().await) {
match ws.send(message).await {
Ok(()) => return Ok(()),
_ => {
let closed = closed.read().await;
if *closed {
log::info!("Attempted to write to a closed websocket");
return Ok(());
}
}
};
};
error_flow!()
}
Err(err) => Err(err),
}
})
}

#[inline(always)]
fn close<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> {
self.closed = true;
let transport = self.ws_tx.clone();
let closed = self.closed.clone();
let ws_rx = self.ws_rx.clone();
let ws_tx = self.ws_tx.clone();

future_into_py_iter(self.rt.clone(), py, async move {
if let Some(ws) = &mut *(transport.lock().await) {
if (ws.close().await).is_ok() {
return Ok(());
match ws_tx.lock().await.take() {
Some(tx) => {
let mut closed = closed.write().await;
*closed = true;
WebsocketDetachedTransport::new(true, ws_rx.lock().await.take(), Some(tx))
.close()
.await;
Ok(())
}
};
error_flow!()
_ => error_flow!(),
}
})
}

fn consumed(&self) -> bool {
self.upgrade.is_none()
}

pub fn tx(&mut self) -> (Option<oneshot::Sender<bool>>, bool) {
(self.tx.take(), self.consumed())
pub fn tx(
&mut self,
) -> (
Option<oneshot::Sender<WebsocketDetachedTransport>>,
WebsocketDetachedTransport,
) {
let mut ws_rx = self.ws_rx.blocking_lock();
let mut ws_tx = self.ws_tx.blocking_lock();
(
self.tx.take(),
WebsocketDetachedTransport::new(self.consumed(), ws_rx.take(), ws_tx.take()),
)
}
}

#[pymethods]
impl ASGIWebsocketProtocol {
fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> {
let transport = self.ws_rx.clone();
let accepted = self.accepted.clone();
let closed = self.closed;
let closed = self.closed.clone();
let transport = self.ws_rx.clone();

future_into_py_futlike(self.rt.clone(), py, async move {
let accepted = accepted.lock().await;
match (*accepted, closed) {
(false, false) => {
return Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?;
Ok(dict.to_object(py))
})
}
(true, false) => {}
_ => return error_flow!(),
let accepted = accepted.read().await;
if !*accepted {
return Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?;
Ok(dict.to_object(py))
});
}

if let Some(ws) = &mut *(transport.lock().await) {
while let Some(recv) = ws.next().await {
match recv {
Ok(Message::Ping(_)) => continue,
Ok(message @ Message::Close(_)) => {
let mut closed = closed.write().await;
*closed = true;
return ws_message_into_py(message);
}
Ok(message) => return ws_message_into_py(message),
_ => break,
}
Expand All @@ -353,12 +412,11 @@ impl ASGIWebsocketProtocol {
}

fn send<'p>(&mut self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> {
match (adapt_message_type(data), self.closed) {
(Ok(ASGIMessageType::WSAccept), _) => self.accept(py),
(Ok(ASGIMessageType::WSClose), false) => self.close(py),
(Ok(ASGIMessageType::WSMessage), false) => self.send_message(py, data),
(Err(err), _) => Err(err.into()),
_ => error_message!(),
match adapt_message_type(data) {
Ok(ASGIMessageType::WSAccept) => self.accept(py),
Ok(ASGIMessageType::WSClose) => self.close(py),
Ok(ASGIMessageType::WSMessage) => self.send_message(py, data),
_ => future_into_py_iter::<_, _, PyErr>(self.rt.clone(), py, async { error_message!() }),
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,18 @@ impl PyFutureAwaitable {
}

fn result(pyself: PyRef<'_, Self>) -> PyResult<PyObject> {
if pyself.py_cancelled {
return Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled."));
}

match &pyself.result {
Some(res) => {
let py = pyself.py();
res.as_ref().map(|v| v.clone_ref(py)).map_err(|err| err.clone_ref(py))
}
_ => Ok(pyself.py().None()),
_ => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err(
"Result is not ready.",
)),
}
}

Expand Down
Loading