Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
oestradiol committed Sep 19, 2024
1 parent 5ddbf33 commit 02ef85f
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 165 deletions.
15 changes: 10 additions & 5 deletions atrium-streams-client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ type StreamKind = WebSocketStream<MaybeTlsStream<TcpStream>>;
impl<P: Serialize + Send + Sync> EventStreamClient<<StreamKind as Stream>::Item, Error>
for WssClient<P>
{
async fn connect(&self, mut uri: String) -> Result<impl Stream<Item = <StreamKind as Stream>::Item>, Error> {
async fn connect(
&self,
mut uri: String,
) -> Result<impl Stream<Item = <StreamKind as Stream>::Item>, Error> {
let Self { params } = self;

// Query parameters
Expand All @@ -66,17 +69,19 @@ impl<P: Serialize + Send + Sync> EventStreamClient<<StreamKind as Stream>::Item,
fn get_host(uri: &str) -> Result<(Uri, Box<str>), Error> {
let uri = Uri::from_str(uri).map_err(|_| Error::InvalidUri)?;
let authority = uri.authority().ok_or_else(|| Error::InvalidUri)?.as_str();
let host = authority
.find('@')
.map_or_else(|| authority, |idx| authority.split_at(idx + 1).1);
let host = authority.find('@').map_or_else(|| authority, |idx| authority.split_at(idx + 1).1);
let host = Box::from(host);
Ok((uri, host))
}

/// Generate a request for the given URI and host.
/// It sets the necessary headers for a WebSocket connection,
/// plus the client's `AtprotoProxy` and `AtprotoAcceptLabelers` headers.
async fn gen_request<P: Serialize + Send + Sync>(client: &WssClient<P>, uri: &Uri, host: &str) -> Result<Request<()>, Error> {
async fn gen_request<P: Serialize + Send + Sync>(
client: &WssClient<P>,
uri: &Uri,
host: &str,
) -> Result<Request<()>, Error> {
let mut request = Request::builder()
.uri(uri)
.method("GET")
Expand Down
56 changes: 31 additions & 25 deletions atrium-streams-client/src/client/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@ use std::net::{Ipv4Addr, SocketAddr};
use atrium_streams::{atrium_api::com::atproto::sync::subscribe_repos, client::EventStreamClient};
use atrium_xrpc::http::{header::SEC_WEBSOCKET_KEY, HeaderMap, HeaderValue};
use futures::{SinkExt, StreamExt};
use tokio::{net::{TcpListener, TcpStream}, runtime::Runtime};
use tokio_tungstenite::{tungstenite::{handshake::server::{ErrorResponse, Request, Response}, Message}, WebSocketStream};
use tokio::{
net::{TcpListener, TcpStream},
runtime::Runtime,
};
use tokio_tungstenite::{
tungstenite::{
handshake::server::{ErrorResponse, Request, Response},
Message,
},
WebSocketStream,
};

use crate::WssClient;

Expand Down Expand Up @@ -35,15 +44,13 @@ fn client() {
Runtime::new().unwrap().block_on(fut);
}

async fn wss_client(uri: &str) -> (WssClient<subscribe_repos::ParametersData>, HeaderMap<HeaderValue>) {
let params = subscribe_repos::ParametersData {
cursor: None,
};
async fn wss_client(
uri: &str,
) -> (WssClient<subscribe_repos::ParametersData>, HeaderMap<HeaderValue>) {
let params = subscribe_repos::ParametersData { cursor: None };

let client = WssClient::builder().params(params).build();

let client = WssClient::builder()
.params(params)
.build();

let (uri, host) = get_host(uri).unwrap();
let req = gen_request(&client, &uri, &host).await.unwrap();
let headers = req.headers();
Expand All @@ -54,11 +61,9 @@ async fn wss_client(uri: &str) -> (WssClient<subscribe_repos::ParametersData>, H
async fn mock_wss_server() -> (WebSocketStream<TcpStream>, HeaderMap, String) {
let sock_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 3000));

let listener = TcpListener::bind(sock_addr)
.await
.expect("Failed to bind to port!");
let listener = TcpListener::bind(sock_addr).await.expect("Failed to bind to port!");

let headers: HeaderMap;
let headers: HeaderMap;
let route: String;
let (stream, _) = listener.accept().await.unwrap();
let (headers_, route_, stream) = extract_headers(stream).await;
Expand All @@ -68,21 +73,22 @@ async fn mock_wss_server() -> (WebSocketStream<TcpStream>, HeaderMap, String) {
(stream, headers, route)
}

async fn extract_headers(raw_stream: TcpStream) -> (HeaderMap<HeaderValue>, String, WebSocketStream<TcpStream>) {
async fn extract_headers(
raw_stream: TcpStream,
) -> (HeaderMap<HeaderValue>, String, WebSocketStream<TcpStream>) {
let mut headers: Option<HeaderMap<HeaderValue>> = None;
let mut route: Option<String> = None;

let copy_headers_callback = |request: &Request, response: Response| -> Result<Response, ErrorResponse> {
headers = Some(request.headers().clone());
route = Some(request.uri().path().to_owned());
Ok(response)
};
let copy_headers_callback =
|request: &Request, response: Response| -> Result<Response, ErrorResponse> {
headers = Some(request.headers().clone());
route = Some(request.uri().path().to_owned());
Ok(response)
};

let stream = tokio_tungstenite::accept_hdr_async(
raw_stream,
copy_headers_callback,
).await
let stream = tokio_tungstenite::accept_hdr_async(raw_stream, copy_headers_callback)
.await
.expect("Error during the websocket handshake occurred");

(headers.unwrap(), route.unwrap(), stream)
}
}
82 changes: 12 additions & 70 deletions atrium-streams-client/src/subscriptions/repositories/firehose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,8 @@ impl Handler for Firehose {
&self,
payload: subscribe_repos::Commit,
) -> Result<Option<ProcessedPayload<Self::ProcessedCommitData>>, Self::HandlingError> {
let CommitData {
blobs,
blocks,
commit,
ops,
repo,
rev,
seq,
since,
time,
too_big,
..
} = payload.data;
let CommitData { blobs, blocks, commit, ops, repo, rev, seq, since, time, too_big, .. } =
payload.data;

// If it is too big, the blocks and ops are not sent, so we skip the processing.
let ops_opt = if too_big {
Expand All @@ -172,15 +161,7 @@ impl Handler for Firehose {

Ok(Some(ProcessedPayload {
seq: Some(seq),
data: Self::ProcessedCommitData {
ops: ops_opt,
blobs,
commit,
repo,
rev,
since,
time,
},
data: Self::ProcessedCommitData { ops: ops_opt, blobs, commit, repo, rev, since, time },
}))
}

Expand All @@ -189,12 +170,7 @@ impl Handler for Firehose {
&self,
payload: subscribe_repos::Identity,
) -> Result<Option<ProcessedPayload<Self::ProcessedIdentityData>>, Self::HandlingError> {
let IdentityData {
did,
handle,
seq,
time,
} = payload.data;
let IdentityData { did, handle, seq, time } = payload.data;
Ok(Some(ProcessedPayload {
seq: Some(seq),
data: Self::ProcessedIdentityData { did, handle, time },
Expand All @@ -206,21 +182,10 @@ impl Handler for Firehose {
&self,
payload: subscribe_repos::Account,
) -> Result<Option<ProcessedPayload<Self::ProcessedAccountData>>, Self::HandlingError> {
let AccountData {
did,
seq,
time,
active,
status,
} = payload.data;
let AccountData { did, seq, time, active, status } = payload.data;
Ok(Some(ProcessedPayload {
seq: Some(seq),
data: Self::ProcessedAccountData {
did,
active,
status,
time,
},
data: Self::ProcessedAccountData { did, active, status, time },
}))
}

Expand All @@ -229,12 +194,7 @@ impl Handler for Firehose {
&self,
payload: subscribe_repos::Handle,
) -> Result<Option<ProcessedPayload<Self::ProcessedHandleData>>, Self::HandlingError> {
let HandleData {
did,
handle,
seq,
time,
} = payload.data;
let HandleData { did, handle, seq, time } = payload.data;
Ok(Some(ProcessedPayload {
seq: Some(seq),
data: Self::ProcessedHandleData { did, handle, time },
Expand All @@ -246,19 +206,10 @@ impl Handler for Firehose {
&self,
payload: subscribe_repos::Migrate,
) -> Result<Option<ProcessedPayload<Self::ProcessedMigrateData>>, Self::HandlingError> {
let MigrateData {
did,
migrate_to,
seq,
time,
} = payload.data;
let MigrateData { did, migrate_to, seq, time } = payload.data;
Ok(Some(ProcessedPayload {
seq: Some(seq),
data: Self::ProcessedMigrateData {
did,
migrate_to,
time,
},
data: Self::ProcessedMigrateData { did, migrate_to, time },
}))
}

Expand All @@ -279,10 +230,7 @@ impl Handler for Firehose {
&self,
payload: subscribe_repos::Info,
) -> Result<Option<ProcessedPayload<Self::ProcessedInfoData>>, Self::HandlingError> {
Ok(Some(ProcessedPayload {
seq: None,
data: payload.data,
}))
Ok(Some(ProcessedPayload { seq: None, data: payload.data }))
}
}

Expand Down Expand Up @@ -315,15 +263,9 @@ fn process_op(
// Finds in the map the `Record` with the operation's CID and deserializes it.
// If the item is not found, returns `None`.
let record = match cid.as_ref().and_then(|c| map.get_mut(&c.0)) {
Some(item) => Some(serde_ipld_dagcbor::from_reader::<KnownRecord, _>(
Cursor::new(item),
)?),
Some(item) => Some(serde_ipld_dagcbor::from_reader::<KnownRecord, _>(Cursor::new(item))?),
None => None,
};

Ok(Operation {
action,
path,
record,
})
Ok(Operation { action, path, record })
}
Loading

0 comments on commit 02ef85f

Please sign in to comment.