diff --git a/examples/pub.rs b/examples/pub.rs index 607aa27..0a5d05d 100644 --- a/examples/pub.rs +++ b/examples/pub.rs @@ -1,4 +1,31 @@ +//! Publish to a redis channel example. +//! +//! A simple client that connects to a mini-redis server, and +//! publishes a message on `foo` channel +//! +//! You can test this out by running: +//! +//! cargo run --bin server +//! +//! Then in another terminal run: +//! +//! cargo run --example sub +//! +//! And then in another terminal run: +//! +//! cargo run --example pub + +#![warn(rust_2018_idioms)] + +use mini_redis::{client, Result}; + #[tokio::main] -async fn main() { - unimplemented!(); +async fn main() -> Result<()> { + // Open a connection to the mini-redis address. + let mut client = client::connect("127.0.0.1:6379").await?; + + // publish message `bar` on channel foo + client.publish("foo", "bar".into()).await?; + + Ok(()) } diff --git a/examples/sub.rs b/examples/sub.rs index eda175c..97823e0 100644 --- a/examples/sub.rs +++ b/examples/sub.rs @@ -1,6 +1,38 @@ -/// Subscribe to a redis channel +//! Subscribe to a redis channel example. +//! +//! A simple client that connects to a mini-redis server, subscribes to "foo" and "bar" channels +//! and awaits messages published on those channels +//! +//! You can test this out by running: +//! +//! cargo run --bin server +//! +//! Then in another terminal run: +//! +//! cargo run --example sub +//! +//! And then in another terminal run: +//! +//! cargo run --example pub + +#![warn(rust_2018_idioms)] + +use mini_redis::{client, Result}; +use tokio::stream::StreamExt; #[tokio::main] -async fn main() { - unimplemented!(); +pub async fn main() -> Result<()> { + // Open a connection to the mini-redis address. + let client = client::connect("127.0.0.1:6379").await?; + + + // subscribe to channel foo + let mut result = client.subscribe(vec!["foo".into()]).await?; + + // await messages on channel foo + while let Some(Ok(msg)) = result.next().await { + println!("got message from the channel: {}; message = {:?}", msg.channel, msg.content); + } + + Ok(()) } diff --git a/src/client.rs b/src/client.rs index 0568976..beb3963 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,10 +1,14 @@ +use crate::cmd::{Get, Publish, Set, Subscribe, Unsubscribe}; use crate::{Connection, Frame}; -use crate::cmd::{Get, Set}; use bytes::Bytes; +use std::future::Future; use std::io::{Error, ErrorKind}; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio::stream::Stream; use tracing::{debug, instrument}; /// Mini asynchronous Redis client @@ -47,7 +51,31 @@ impl Client { key: key.to_string(), value: value, expire: None, - }).await + }) + .await + } + + /// publish `message` on the `channel` + #[instrument(skip(self))] + pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result { + self.publish_cmd(Publish { + channel: channel.to_string(), + message: message, + }) + .await + } + + /// subscribe to the list of channels + /// when client sends `SUBSCRIBE`, server's handle for client start's accepting only + /// `SUBSCRIBE` and `UNSUBSCRIBE` commands so we consume client and return Subscribe + #[instrument(skip(self))] + pub async fn subscribe(mut self, channels: Vec) -> crate::Result { + let subscribed_channels = self.subscribe_cmd(Subscribe { channels: channels }).await?; + + Ok(Subscriber { + conn: self.conn, + subscribed_channels, + }) } /// Set the value of a key to `value`. The value expires after `expiration`. @@ -62,7 +90,8 @@ impl Client { key: key.to_string(), value: value.into(), expire: Some(expiration), - }).await + }) + .await } async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> { @@ -81,6 +110,52 @@ impl Client { } } + async fn publish_cmd(&mut self, cmd: Publish) -> crate::Result { + // Convert the `Publish` command into a frame + let frame = cmd.into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket + self.conn.write_frame(&frame).await?; + + // Read the response + match self.read_response().await? { + Frame::Integer(response) => Ok(response), + frame => Err(frame.to_error()), + } + } + + async fn subscribe_cmd(&mut self, cmd: Subscribe) -> crate::Result> { + // Convert the `Subscribe` command into a frame + let channels = cmd.channels.clone(); + let frame = cmd.into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket + self.conn.write_frame(&frame).await?; + + // Read the response + for channel in &channels { + let response = self.read_response().await?; + match response { + Frame::Array(ref frame) => match frame.as_slice() { + [subscribe, schannel] + if subscribe.to_string() == "subscribe" + && &schannel.to_string() == channel => + { + () + } + _ => return Err(response.to_error()), + }, + frame => return Err(frame.to_error()), + }; + } + + Ok(channels) + } + /// Reads a response frame from the socket. If an `Error` frame is read, it /// is converted to `Err`. async fn read_response(&mut self) -> crate::Result { @@ -89,20 +164,155 @@ impl Client { debug!(?response); match response { - Some(Frame::Error(msg)) => { - Err(msg.into()) + Some(Frame::Error(msg)) => Err(msg.into()), + Some(frame) => Ok(frame), + None => { + // Receiving `None` here indicates the server has closed the + // connection without sending a frame. This is unexpected and is + // represented as a "connection reset by peer" error. + let err = Error::new(ErrorKind::ConnectionReset, "connection reset by server"); + + Err(err.into()) } + } + } +} + +pub struct Subscriber { + conn: Connection, + subscribed_channels: Vec, +} + +impl Subscriber { + /// Subscribe to a list of new channels + #[instrument(skip(self))] + pub async fn subscribe(&mut self, channels: Vec) -> crate::Result<()> { + let cmd = Subscribe { channels: channels }; + + let channels = cmd.channels.clone(); + let frame = cmd.into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket + self.conn.write_frame(&frame).await?; + + // Read the response + for channel in &channels { + let response = self.read_response().await?; + match response { + Frame::Array(ref frame) => match frame.as_slice() { + [subscribe, schannel] + if &subscribe.to_string() == "subscribe" + && &schannel.to_string() == channel => + { + () + } + _ => return Err(response.to_error()), + }, + frame => return Err(frame.to_error()), + }; + } + + self.subscribed_channels.extend(channels); + + Ok(()) + } + + /// Unsubscribe to a list of new channels + #[instrument(skip(self))] + pub async fn unsubscribe(&mut self, channels: Vec) -> crate::Result<()> { + let cmd = Unsubscribe { channels: channels }; + + let mut channels = cmd.channels.clone(); + let frame = cmd.into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket + self.conn.write_frame(&frame).await?; + + // if the input channel list is empty, server acknowledges as unsubscribing + // from all subscribed channels, so we assert that the unsubscribe list received + // matches the client subscribed one + if channels.is_empty() { + channels = self.subscribed_channels.clone(); + } + + // Read the response + for channel in &channels { + let response = self.read_response().await?; + match response { + Frame::Array(ref frame) => match frame.as_slice() { + [unsubscribe, uchannel] + if &unsubscribe.to_string() == "unsubscribe" + && &uchannel.to_string() == channel => + { + self.subscribed_channels + .retain(|channel| channel != &uchannel.to_string()); + } + _ => return Err(response.to_error()), + }, + frame => return Err(frame.to_error()), + }; + } + + Ok(()) + } + + /// Reads a response frame from the socket. If an `Error` frame is read, it + /// is converted to `Err`. + async fn read_response(&mut self) -> crate::Result { + let response = self.conn.read_frame().await?; + + debug!(?response); + + match response { + Some(Frame::Error(msg)) => Err(msg.into()), Some(frame) => Ok(frame), None => { // Receiving `None` here indicates the server has closed the // connection without sending a frame. This is unexpected and is // represented as a "connection reset by peer" error. - let err = Error::new( - ErrorKind::ConnectionReset, - "connection reset by server"); + let err = Error::new(ErrorKind::ConnectionReset, "connection reset by server"); Err(err.into()) } } } } + +impl Stream for Subscriber { + type Item = crate::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let mut read_frame = Box::pin(self.conn.read_frame()); + match Pin::new(&mut read_frame).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(None)) => Poll::Ready(None), + Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err.into()))), + Poll::Ready(Ok(Some(mframe))) => { + debug!(?mframe); + match mframe { + Frame::Array(ref frame) => match frame.as_slice() { + [message, channel, content] if &message.to_string() == "message" => { + Poll::Ready(Some(Ok(Message { + channel: channel.to_string(), + content: Bytes::from(content.to_string()), + }))) + } + _ => Poll::Ready(Some(Err(mframe.to_error()))), + }, + frame => Poll::Ready(Some(Err(frame.to_error()))), + } + } + } + } +} + +/// A message received on a subscribed channel +#[derive(Debug, Clone)] +pub struct Message { + pub channel: String, + pub content: Bytes, +} diff --git a/src/cmd/publish.rs b/src/cmd/publish.rs index 7e937f2..dc13a7e 100644 --- a/src/cmd/publish.rs +++ b/src/cmd/publish.rs @@ -4,8 +4,8 @@ use bytes::Bytes; #[derive(Debug)] pub struct Publish { - channel: String, - message: Bytes, + pub(crate) channel: String, + pub(crate) message: Bytes, } impl Publish { @@ -24,4 +24,13 @@ impl Publish { dst.write_frame(&response).await?; Ok(()) } + + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("publish".as_bytes())); + frame.push_bulk(Bytes::from(self.channel.into_bytes())); + frame.push_bulk(self.message); + + frame + } } diff --git a/src/cmd/subscribe.rs b/src/cmd/subscribe.rs index d97202e..ed6bce5 100644 --- a/src/cmd/subscribe.rs +++ b/src/cmd/subscribe.rs @@ -1,5 +1,5 @@ -use crate::{Command, Connection, Db, Frame, Shutdown}; use crate::cmd::{Parse, ParseError}; +use crate::{Command, Connection, Db, Frame, Shutdown}; use bytes::Bytes; use tokio::select; @@ -7,12 +7,12 @@ use tokio::stream::{StreamExt, StreamMap}; #[derive(Debug)] pub struct Subscribe { - channels: Vec, + pub(crate) channels: Vec, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Unsubscribe { - channels: Vec, + pub(crate) channels: Vec, } impl Subscribe { @@ -147,6 +147,15 @@ impl Subscribe { }; } } + + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("subscribe".as_bytes())); + for channel in self.channels { + frame.push_bulk(Bytes::from(channel.into_bytes())); + } + frame + } } impl Unsubscribe { @@ -166,4 +175,13 @@ impl Unsubscribe { Ok(Unsubscribe { channels }) } + + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("unsubscribe".as_bytes())); + for channel in self.channels { + frame.push_bulk(Bytes::from(channel.into_bytes())); + } + frame + } } diff --git a/src/server.rs b/src/server.rs index 4710563..2f76f65 100644 --- a/src/server.rs +++ b/src/server.rs @@ -57,8 +57,8 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result< tokio::select! { res = server.run() => { if let Err(err) = res { - // TODO: gracefully handle this error error!(cause = %err, "failed to accept"); + return Err(err.into()); } } _ = shutdown => {