Skip to content

Commit

Permalink
feat: Add headers() to XrpcClient (#170)
Browse files Browse the repository at this point in the history
* Update XRPC

* Add headers to XrpcClient
  • Loading branch information
sugyan authored May 17, 2024
1 parent d4a3cbb commit 2022414
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 158 deletions.
19 changes: 12 additions & 7 deletions atrium-api/src/agent/inner.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::{Session, SessionStore};
use crate::did_doc::DidDocument;
use async_trait::async_trait;
use atrium_xrpc::error::{Error, XrpcErrorKind};
use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, XrpcResult};
use atrium_xrpc::error::{Error, Result, XrpcErrorKind};
use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
use http::{Method, Request, Response, Uri};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::{Arc, RwLock};
Expand All @@ -25,7 +25,8 @@ where
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
) -> core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>
{
self.inner.send_http(request).await
}
}
Expand Down Expand Up @@ -113,7 +114,7 @@ where
&self,
) -> Result<
crate::com::atproto::server::refresh_session::Output,
Error<crate::com::atproto::server::refresh_session::Error>,
crate::com::atproto::server::refresh_session::Error,
> {
let response = self
.inner
Expand All @@ -130,7 +131,7 @@ where
_ => Err(Error::UnexpectedResponseType),
}
}
fn is_expired<O, E>(result: &XrpcResult<O, E>) -> bool
fn is_expired<O, E>(result: &Result<OutputDataOrBytes<O>, E>) -> bool
where
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync,
Expand All @@ -156,7 +157,8 @@ where
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
) -> core::result::Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>
{
self.inner.send_http(request).await
}
}
Expand All @@ -171,7 +173,10 @@ where
fn base_uri(&self) -> String {
self.inner.base_uri()
}
async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E>
async fn send_xrpc<P, I, O, E>(
&self,
request: &XrpcRequest<P, I>,
) -> Result<OutputDataOrBytes<O>, E>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
Expand Down
42 changes: 23 additions & 19 deletions atrium-xrpc/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
#![doc = "Error types."]
use http::StatusCode;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Debug, Display};

/// An enum of possible error kinds.
#[derive(thiserror::Error, Debug)]
pub enum Error<E> {
#[error("xrpc response error: {0}")]
XrpcResponse(XrpcError<E>),
#[error("http request error: {0}")]
HttpRequest(#[from] http::Error),
#[error("http client error: {0}")]
HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("serde_json error: {0}")]
SerdeJson(#[from] serde_json::Error),
#[error("serde_html_form error: {0}")]
SerdeHtmlForm(#[from] serde_html_form::ser::Error),
#[error("unexpected response type")]
UnexpectedResponseType,
}

/// Type alias to use this library's [`Error`] type in a [`Result`](core::result::Result).
pub type Result<T, E> = core::result::Result<T, Error<E>>;

/// [A standard error response schema](https://atproto.com/specs/xrpc#error-responses)
///
/// ```typescript
Expand All @@ -11,7 +32,7 @@ use std::fmt::{self, Debug, Display};
/// })
/// export type ErrorResponseBody = z.infer<typeof errorResponseBody>
/// ```
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct ErrorResponseBody {
pub error: Option<String>,
pub message: Option<String>,
Expand All @@ -30,7 +51,7 @@ impl Display for ErrorResponseBody {
/// An enum of possible XRPC error kinds.
///
/// Error defined in Lexicon schema or other standard error.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum XrpcErrorKind<E> {
Custom(E),
Expand Down Expand Up @@ -66,20 +87,3 @@ impl<E: Display> Display for XrpcError<E> {
Ok(())
}
}

/// An enum of possible error kinds.
#[derive(thiserror::Error, Debug)]
pub enum Error<E> {
#[error("xrpc response error: {0}")]
XrpcResponse(XrpcError<E>),
#[error("http request error: {0}")]
HttpRequest(#[from] http::Error),
#[error("http client error: {0}")]
HttpClient(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("serde_json error: {0}")]
SerdeJson(#[from] serde_json::Error),
#[error("serde_html_form error: {0}")]
SerdeHtmlForm(#[from] serde_html_form::ser::Error),
#[error("unexpected response type")]
UnexpectedResponseType,
}
149 changes: 17 additions & 132 deletions atrium-xrpc/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,134 +1,19 @@
#![doc = include_str!("../README.md")]
pub mod error;
mod traits;
mod types;

use crate::error::{Error, XrpcError, XrpcErrorKind};
use async_trait::async_trait;
use http::{Method, Request, Response};
use serde::{de::DeserializeOwned, Serialize};

/// A type which can be used as a parameter of [`XrpcRequest`].
///
/// JSON serializable data or raw bytes.
pub enum InputDataOrBytes<T>
where
T: Serialize,
{
Data(T),
Bytes(Vec<u8>),
}

/// A type which can be used as a return value of [`XrpcClient::send_xrpc()`].
///
/// JSON deserializable data or raw bytes.
pub enum OutputDataOrBytes<T>
where
T: DeserializeOwned,
{
Data(T),
Bytes(Vec<u8>),
}

/// An abstract HTTP client.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait HttpClient {
async fn send_http(
&self,
request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>>;
}

/// A request which can be executed with [`XrpcClient::send_xrpc()`].
pub struct XrpcRequest<P, I>
where
I: Serialize,
{
pub method: Method,
pub path: String,
pub parameters: Option<P>,
pub input: Option<InputDataOrBytes<I>>,
pub encoding: Option<String>,
}

pub type XrpcResult<O, E> = Result<OutputDataOrBytes<O>, self::Error<E>>;

/// An abstract XRPC client.
///
/// [`send_xrpc()`](XrpcClient::send_xrpc) method has a default implementation,
/// which wraps the [`HttpClient::send_http()`]` method to handle input and output as an XRPC Request.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait XrpcClient: HttpClient {
fn base_uri(&self) -> String;
#[allow(unused_variables)]
async fn auth(&self, is_refresh: bool) -> Option<String> {
None
}
async fn send_xrpc<P, I, O, E>(&self, request: &XrpcRequest<P, I>) -> XrpcResult<O, E>
where
P: Serialize + Send + Sync,
I: Serialize + Send + Sync,
O: DeserializeOwned + Send + Sync,
E: DeserializeOwned + Send + Sync,
{
let mut uri = format!("{}/xrpc/{}", self.base_uri(), request.path);
if let Some(p) = &request.parameters {
serde_html_form::to_string(p).map(|qs| {
uri += "?";
uri += &qs;
})?;
};
let mut builder = Request::builder().method(&request.method).uri(&uri);
if let Some(encoding) = &request.encoding {
builder = builder.header(http::header::CONTENT_TYPE, encoding);
}
if let Some(token) = self
.auth(
request.method == Method::POST
&& request.path == "com.atproto.server.refreshSession",
)
.await
{
builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {}", token));
}
let body = if let Some(input) = &request.input {
match input {
InputDataOrBytes::Data(data) => serde_json::to_vec(&data)?,
InputDataOrBytes::Bytes(bytes) => bytes.clone(),
}
} else {
Vec::new()
};
let (parts, body) = self
.send_http(builder.body(body)?)
.await
.map_err(Error::HttpClient)?
.into_parts();
if parts.status.is_success() {
if parts
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map_or(false, |content_type| {
content_type.starts_with("application/json")
})
{
Ok(OutputDataOrBytes::Data(serde_json::from_slice(&body)?))
} else {
Ok(OutputDataOrBytes::Bytes(body))
}
} else {
Err(Error::XrpcResponse(XrpcError {
status: parts.status,
error: serde_json::from_slice::<XrpcErrorKind<E>>(&body).ok(),
}))
}
}
}
pub use crate::error::{Error, Result};
pub use crate::traits::{HttpClient, XrpcClient};
pub use crate::types::{InputDataOrBytes, OutputDataOrBytes, XrpcRequest};

#[cfg(test)]
mod tests {
use super::*;
use crate::error::{XrpcError, XrpcErrorKind};
use crate::{HttpClient, XrpcClient};
use async_trait::async_trait;
use http::{Request, Response};
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*;

Expand All @@ -144,7 +29,10 @@ mod tests {
async fn send_http(
&self,
_request: Request<Vec<u8>>,
) -> Result<Response<Vec<u8>>, Box<dyn std::error::Error + Send + Sync + 'static>> {
) -> core::result::Result<
Response<Vec<u8>>,
Box<dyn std::error::Error + Send + Sync + 'static>,
> {
let mut builder = Response::builder().status(self.status);
if self.json {
builder = builder.header(http::header::CONTENT_TYPE, "application/json")
Expand All @@ -162,7 +50,7 @@ mod tests {
mod errors {
use super::*;

async fn get_example<T>(xrpc: &T, params: Parameters) -> Result<Output, crate::Error<Error>>
async fn get_example<T>(xrpc: &T, params: Parameters) -> Result<Output, Error>
where
T: crate::XrpcClient + Send + Sync,
{
Expand Down Expand Up @@ -305,10 +193,7 @@ mod tests {
mod bytes {
use super::*;

async fn get_bytes<T>(
xrpc: &T,
params: Parameters,
) -> Result<Vec<u8>, crate::Error<Error>>
async fn get_bytes<T>(xrpc: &T, params: Parameters) -> Result<Vec<u8>, Error>
where
T: crate::XrpcClient + Send + Sync,
{
Expand Down Expand Up @@ -387,7 +272,7 @@ mod tests {
mod no_content {
use super::*;

async fn create_data<T>(xrpc: &T, input: Input) -> Result<(), crate::Error<Error>>
async fn create_data<T>(xrpc: &T, input: Input) -> Result<(), Error>
where
T: crate::XrpcClient + Send + Sync,
{
Expand Down Expand Up @@ -449,7 +334,7 @@ mod tests {
mod bytes {
use super::*;

async fn create_data<T>(xrpc: &T, input: Vec<u8>) -> Result<Output, crate::Error<Error>>
async fn create_data<T>(xrpc: &T, input: Vec<u8>) -> Result<Output, Error>
where
T: crate::XrpcClient + Send + Sync,
{
Expand Down
Loading

0 comments on commit 2022414

Please sign in to comment.