diff --git a/atrium-api/src/agent.rs b/atrium-api/src/agent.rs index 3cb1a8c4..eaeddf80 100644 --- a/atrium-api/src/agent.rs +++ b/atrium-api/src/agent.rs @@ -1,8 +1,10 @@ -//! Implementation of [`AtpAgent`]. +//! Implementation of [`Agent`]. #[cfg(feature = "bluesky")] pub mod bluesky; +mod credential_session; mod inner; -pub mod store; +mod session_manager; +mod session_resolver; use crate::client::Service; use crate::did_doc::DidDocument; @@ -11,6 +13,7 @@ use crate::types::TryFromUnknown; use atrium_common::store::SimpleStore; use atrium_xrpc::error::Error; use atrium_xrpc::XrpcClient; +use session_manager::SessionManager; use std::sync::Arc; /// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output) @@ -36,27 +39,30 @@ impl AsRef for AtprotoServiceType { /// An ATP "Agent". /// Manages session token lifecycles and provides convenience methods. -pub struct AtpAgent +pub struct Agent where S: SimpleStore<(), Session> + Send + Sync, + M: SessionManager, T: XrpcClient + Send + Sync, { - store: Arc>, + session_manager: Arc, inner: Arc>, pub api: Service>, } -impl AtpAgent +impl Agent where S: SimpleStore<(), Session> + Send + Sync, T: XrpcClient + Send + Sync, + M: SessionManager, { /// Create a new agent. - pub fn new(xrpc: T, store: S) -> Self { - let store = Arc::new(inner::Store::new(store, xrpc.base_uri())); - let inner = Arc::new(inner::Client::new(Arc::clone(&store), xrpc)); + pub fn new(xrpc: T, store: S, session_manager: M) -> Self { + let inner = Arc::new(inner::Client::new(Arc::clone(&Arc::new(inner::Store(store))), xrpc)); let api = Service::new(Arc::clone(&inner)); - Self { store, inner, api } + let session_manager = Arc::new(session_manager); + + Self { inner, api, session_manager } } /// Start a new session with this agent. pub async fn login( @@ -164,10 +170,10 @@ where #[cfg(test)] mod tests { use super::*; - use atrium_common::store::memory::MemorySimpleStore; use crate::com::atproto::server::create_session::OutputData; use crate::did_doc::{DidDocument, Service, VerificationMethod}; use crate::types::TryIntoUnknown; + use atrium_common::store::memory::MemorySimpleStore; use atrium_xrpc::HttpClient; use http::{HeaderMap, HeaderName, HeaderValue, Request, Response}; use std::collections::HashMap; @@ -295,7 +301,7 @@ mod tests { #[tokio::test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] async fn test_new() { - let agent = AtpAgent::new(MockClient::default(), MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(MockClient::default(), MemorySimpleStore::<(), Session>::default()); assert_eq!(agent.get_session().await, None); } @@ -314,7 +320,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent.login("test", "pass").await.expect("login should be succeeded"); assert_eq!(agent.get_session().await, Some(session_data.into())); } @@ -324,7 +330,7 @@ mod tests { responses: MockResponses { ..Default::default() }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent.login("test", "bad").await.expect_err("login should be failed"); assert_eq!(agent.get_session().await, None); } @@ -350,7 +356,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent.store.set_session(session_data.clone().into()).await; let output = agent .api @@ -384,7 +390,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent.store.set_session(session_data.clone().into()).await; let output = agent .api @@ -423,7 +429,7 @@ mod tests { ..Default::default() }; let counts = Arc::clone(&client.counts); - let agent = Arc::new(AtpAgent::new(client, MemorySimpleStore::<(), Session>::default())); + let agent = Arc::new(Agent::new(client, MemorySimpleStore::<(), Session>::default())); agent.store.set_session(session_data.clone().into()).await; let handles = (0..3).map(|_| { let agent = Arc::clone(&agent); @@ -473,7 +479,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); assert_eq!(agent.get_session().await, None); agent .resume_session( @@ -493,7 +499,7 @@ mod tests { responses: MockResponses { ..Default::default() }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); assert_eq!(agent.get_session().await, None); agent .resume_session(session_data.clone().into()) @@ -523,7 +529,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent .resume_session( OutputData { access_jwt: "expired".into(), ..session_data.clone() }.into(), @@ -572,7 +578,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent.login("test", "pass").await.expect("login should be succeeded"); assert_eq!(agent.get_endpoint().await, "https://bsky.social"); assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "https://bsky.social"); @@ -607,7 +613,7 @@ mod tests { }, ..Default::default() }; - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent.login("test", "pass").await.expect("login should be succeeded"); // not updated assert_eq!(agent.get_endpoint().await, "http://localhost:8080"); @@ -620,7 +626,7 @@ mod tests { async fn test_configure_labelers_header() { let client = MockClient::default(); let headers = Arc::clone(&client.headers); - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent .api @@ -683,7 +689,7 @@ mod tests { async fn test_configure_proxy_header() { let client = MockClient::default(); let headers = Arc::clone(&client.headers); - let agent = AtpAgent::new(client, MemorySimpleStore::<(), Session>::default()); + let agent = Agent::new(client, MemorySimpleStore::<(), Session>::default()); agent .api diff --git a/atrium-api/src/agent/credential_session.rs b/atrium-api/src/agent/credential_session.rs new file mode 100644 index 00000000..4ba5969e --- /dev/null +++ b/atrium-api/src/agent/credential_session.rs @@ -0,0 +1,323 @@ +use crate::client::{com, Service}; +use crate::com::atproto::server::{create_account, create_session, delete_session, get_session}; +use crate::did_doc::DidDocument; +use crate::types::string::Did; +use crate::types::TryFromUnknown; +use atrium_common::store::SimpleStore; +use atrium_xrpc::error::Error; +use atrium_xrpc::XrpcClient; +use std::sync::{Arc, RwLock}; + +use super::{inner, Session}; + +pub struct CredentialSession +where + S: SimpleStore<(), Session> + Send + Sync, + T: XrpcClient + Send + Sync, +{ + pub pds_endpoint: RwLock, + pub store: Arc>, + server: com::Service>, +} + +impl CredentialSession +where + S: SimpleStore<(), Session> + Send + Sync, + T: XrpcClient + Send + Sync, +{ + /// Create a new agent. + pub fn new(pds_endpoint: String, xrpc: T, store: S) -> Self { + let store = Arc::new(inner::Store(store)); + + Self { + pds_endpoint: RwLock::new(pds_endpoint), + store: Arc::clone(&store), + server: com::Service::new(Arc::new(inner::Client::new(store, xrpc))), + } + } + + pub async fn get_did(&self) -> Option { + self.store.get_session().await.map(|session| session.did) + } + + // get dispatchUrl() { + // return this.pdsUrl || this.serviceUrl + // } + + pub async fn create_account( + &mut self, + input: create_account::InputData, + ) -> Result<(), Error> { + let result = self.server.atproto.server.create_account(input.clone().into()).await?; + // TODO + // this.session = undefined + // this.persistSession?.('create-failed', undefined) + + let create_account::OutputData { access_jwt, did, did_doc, handle, refresh_jwt } = *result; + + self.store + .set_session( + create_session::OutputData { + access_jwt, + did, + did_doc, + handle, + refresh_jwt, + active: Some(true), + email: input.email.clone(), + // TODO + // emailConfirmed: false, + // emailAuthFactor: false, + email_auth_factor: None, + email_confirmed: None, + status: None, + } + .into(), + ) + .await; + + if let Ok(Some(did_doc)) = result.did_doc.map(DidDocument::try_from_unknown).transpose() { + self.update_endpoint(&did_doc); + } + Ok(()) + } + + /// Start a new session with this agent. + pub async fn login( + &self, + identifier: impl AsRef, + password: impl AsRef, + ) -> Result> { + let result = self + .server + .atproto + .server + .create_session( + create_session::InputData { + auth_factor_token: None, + identifier: identifier.as_ref().into(), + password: password.as_ref().into(), + } + .into(), + ) + .await?; + + self.store.set_session(result.clone()).await; + + if let Some(did_doc) = result + .did_doc + .as_ref() + .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) + { + self.update_endpoint(&did_doc); + } + + Ok(result) + } + + pub async fn logout(&self) -> Result<(), Error> { + self.store.clear_session().await; + + self.server.atproto.server.delete_session().await + } + + /// Resume a pre-existing session with this agent. + pub async fn resume_session( + &self, + session: Session, + ) -> Result<(), Error> { + self.store.set_session(session.clone()).await; + + match self.server.atproto.server.get_session().await { + Ok(output) => { + // TODO + assert_eq!(output.data.did, session.data.did); + + if let Some(session) = self.store.get_session().await.as_deref().cloned() { + let session = create_session::OutputData { + did_doc: output.data.did_doc.clone(), + email: output.data.email, + email_confirmed: output.data.email_confirmed, + handle: output.data.handle, + ..session + }; + self.store.set_session(session.into()).await; + } + if let Some(did_doc) = output + .data + .did_doc + .as_ref() + .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) + { + self.update_endpoint(&did_doc); + } + Ok(()) + } + Err(err) => { + self.store.clear_session().await; + Err(err) + } + } + } + + // Internal helper to refresh sessions + // - Wraps the actual implementation to ensure only one refresh is attempted at a time. + async fn refresh_session(&self) { + { + let mut is_refreshing = self.is_refreshing.lock().await; + if *is_refreshing { + drop(is_refreshing); + return self.notify.notified().await; + } + *is_refreshing = true; + } + // TODO: Ensure `is_refreshing` is reliably set to false even in the event of unexpected errors within `refresh_session_inner()`. + self.refresh_session_inner().await; + *self.is_refreshing.lock().await = false; + self.notify.notify_waiters(); + } + async fn refresh_session_inner(&self) { + if let Ok(output) = self.call_refresh_session().await { + if let Some(mut session) = self.store.get_session().await { + session.access_jwt = output.data.access_jwt; + session.did = output.data.did; + session.did_doc = output.data.did_doc.clone(); + session.handle = output.data.handle; + session.refresh_jwt = output.data.refresh_jwt; + self.store.set_session(session).await; + } + if let Some(did_doc) = output + .data + .did_doc + .as_ref() + .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) + { + self.store.update_endpoint(&did_doc); + } + } else { + self.store.clear_session().await; + } + } + // same as `crate::client::com::atproto::server::Service::refresh_session()` + async fn call_refresh_session( + &self, + ) -> Result< + crate::com::atproto::server::refresh_session::Output, + crate::com::atproto::server::refresh_session::Error, + > { + let response = self + .inner + .send_xrpc::<(), (), _, _>(&XrpcRequest { + method: Method::POST, + nsid: crate::com::atproto::server::refresh_session::NSID.into(), + parameters: None, + input: None, + encoding: None, + }) + .await?; + match response { + OutputDataOrBytes::Data(data) => Ok(data), + _ => Err(Error::UnexpectedResponseType), + } + } + fn is_expired(result: &Result, E>) -> bool + where + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + if let Err(Error::XrpcResponse(response)) = &result { + if let Some(XrpcErrorKind::Undefined(body)) = &response.error { + if let Some("ExpiredToken") = &body.error.as_deref() { + return true; + } + } + } + false + } + pub fn get_endpoint(&self) -> String { + self.pds_endpoint.read().expect("failed to read endpoint").clone() + } + pub fn update_endpoint(&self, did_doc: &DidDocument) { + if let Some(endpoint) = did_doc.get_pds_endpoint() { + *self.pds_endpoint.write().expect("failed to write endpoint") = endpoint; + } + } +} + +// public pdsUrl?: URL // The PDS URL, driven by the did doc +// public session?: AtpSessionData +// public refreshSessionPromise: Promise | undefined + +// /** +// * Private {@link ComAtprotoServerNS} used to perform session management API +// * calls on the service endpoint. Calls performed by this agent will not be +// * authenticated using the user's session to allow proper manual configuration +// * of the headers when performing session management operations. +// */ +// protected server = new ComAtprotoServerNS( +// // Note that the use of the codegen "schemas" (to instantiate `this.api`), +// // as well as the use of `ComAtprotoServerNS` will cause this class to +// // reference (way) more code than it actually needs. It is not possible, +// // with the current state of the codegen, to generate a client that only +// // includes the methods that are actually used by this class. This is a +// // known limitation that should be addressed in a future version of the +// // codegen. +// new XrpcClient((url, init) => { +// return (0, this.fetch)(new URL(url, this.serviceUrl), init) +// }, schemas), +// ) + +// constructor( +// ) {} + +// get hasSession() { +// return !!this.session +// } + +// /** +// * Create a new account and hydrate its session in this agent. +// */ +// async createAccount( +// } + +// /** +// * Start a new session with this agent. +// */ +// async login( +// } + +// async logout(): Promise { +// } + +// /** +// * Resume a pre-existing session with this agent. +// */ +// async resumeSession( +// } + +// /** +// * Internal helper to refresh sessions +// * - Wraps the actual implementation in a promise-guard to ensure only +// * one refresh is attempted at a time. +// */ +// async refreshSession(): Promise { +// } + +// /** +// * Internal helper to refresh sessions (actual behavior) +// */ +// private async _refreshSessionInner() { +// } + +// /** +// * Helper to update the pds endpoint dynamically. +// * +// * The session methods (create, resume, refresh) may respond with the user's +// * did document which contains the user's canonical PDS endpoint. That endpoint +// * may differ from the endpoint used to contact the server. We capture that +// * PDS endpoint and update the client to use that given endpoint for future +// * requests. (This helps ensure smooth migrations between PDSes, especially +// * when the PDSes are operated by a single org.) +// */ +// private _updateApiEndpoint(didDoc: unknown) { +// } diff --git a/atrium-api/src/agent/inner.rs b/atrium-api/src/agent/inner.rs index 09c1d5c8..f1bd9a6e 100644 --- a/atrium-api/src/agent/inner.rs +++ b/atrium-api/src/agent/inner.rs @@ -1,8 +1,8 @@ -pub use super::store::Store; use super::Session; use crate::did_doc::DidDocument; use crate::types::string::Did; use crate::types::TryFromUnknown; +use atrium_common::resolver::{Resolver, ThrottledResolver}; use atrium_common::store::SimpleStore; use atrium_xrpc::error::{Error, Result, XrpcErrorKind}; use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest}; @@ -94,7 +94,7 @@ where pub struct Client { store: Arc>, inner: WrapperClient, - is_refreshing: Arc>, + is_refreshing: SessionResolver, notify: Arc, } @@ -155,40 +155,20 @@ where } *is_refreshing = true; } - // TODO: Ensure `is_refreshing` is reliably set to false even in the event of unexpected errors within `refresh_session_inner()`. - let this = &self; - - // same as `crate::client::com::atproto::server::Service::refresh_session()` - let result = async move { - let this = &this; - - let response = this - .inner - .send_xrpc::<(), (), Session, crate::com::atproto::server::refresh_session::Error>( - &XrpcRequest { - method: Method::POST, - nsid: crate::com::atproto::server::refresh_session::NSID.into(), - parameters: None, - input: None, - encoding: None, - }, - ) - .await?; - match response { - OutputDataOrBytes::Data(data) => Ok(data), - _ => Err(Error::UnexpectedResponseType), - } - }; - - if let Ok(output) = result.await { - if let Some(mut session) = this.store.get_session().await { + self.refresh_session_inner().await; + *self.is_refreshing.lock().await = false; + self.notify.notify_waiters(); + } + async fn refresh_session_inner(&self) { + if let Ok(output) = self.call_refresh_session().await { + if let Some(mut session) = self.store.get_session().await { session.access_jwt = output.data.access_jwt; session.did = output.data.did; session.did_doc = output.data.did_doc.clone(); session.handle = output.data.handle; session.refresh_jwt = output.data.refresh_jwt; - this.store.set_session(session).await; + self.store.set_session(session).await; } if let Some(did_doc) = output .data @@ -196,17 +176,34 @@ where .as_ref() .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) { - this.store.update_endpoint(&did_doc); + self.store.update_endpoint(&did_doc); } } else { - this.store.clear_session().await; + self.store.clear_session().await; + } + } + // same as `crate::client::com::atproto::server::Service::refresh_session()` + async fn call_refresh_session( + &self, + ) -> Result< + crate::com::atproto::server::refresh_session::Output, + crate::com::atproto::server::refresh_session::Error, + > { + let response = self + .inner + .send_xrpc::<(), (), _, _>(&XrpcRequest { + method: Method::POST, + nsid: crate::com::atproto::server::refresh_session::NSID.into(), + parameters: None, + input: None, + encoding: None, + }) + .await?; + match response { + OutputDataOrBytes::Data(data) => Ok(data), + _ => Err(Error::UnexpectedResponseType), } - - *self.is_refreshing.lock().await = false; - - self.notify.notify_waiters(); } - fn is_expired(result: &Result, E>) -> bool where O: DeserializeOwned + Send + Sync, @@ -280,3 +277,23 @@ where } } } + +pub struct Store(S); + +impl Store +where + S: Resolver, +{ + // pub fn new(inner: S, initial_endpoint: String) -> Self { + // Self { inner, endpoint: RwLock::new(initial_endpoint) } + // } + pub async fn get_session(&self) -> Option { + self.0.get(&()).await.expect("todo") + } + pub async fn set_session(&self, session: Session) { + self.0.set((), session).await.expect("todo") + } + pub async fn clear_session(&self) { + self.0.del(&()).await.expect("todo") + } +} diff --git a/atrium-api/src/agent/session_manager.rs b/atrium-api/src/agent/session_manager.rs new file mode 100644 index 00000000..32b12277 --- /dev/null +++ b/atrium-api/src/agent/session_manager.rs @@ -0,0 +1,7 @@ +use atrium_xrpc::HttpClient; + +pub trait SessionManager +where + T: HttpClient, +{ +} diff --git a/atrium-api/src/agent/session_resolver.rs b/atrium-api/src/agent/session_resolver.rs new file mode 100644 index 00000000..08d27b1a --- /dev/null +++ b/atrium-api/src/agent/session_resolver.rs @@ -0,0 +1,31 @@ +use std::{future::Future, pin::Pin, sync::Arc}; + +use atrium_common::resolver::{Resolver, ThrottledResolver}; +use atrium_xrpc::{HttpClient, XrpcClient}; + +use crate::{ + client::com::Service, + com::atproto::server::{create_account, create_session}, + error::Error, +}; + +use super::Session; + +pub type Resolution<'f, T> = Pin + Send + 'f>>; + +pub struct SessionResolver; + +impl Resolver for SessionResolver +where + F: FnOnce(Option) -> Resolution<'static, Result> + Send + 'static, + E: std::error::Error, +{ + type Input = dyn FnOnce(Option) -> Resolution<'static, Result> + Send + 'static; + type Output = (); + + async fn resolve(&self, input: Box) -> Result> { + let ok = input; + } +} + +pub type ThrottledSessionResolver = ThrottledResolver, Error>; diff --git a/atrium-api/src/agent/store.rs b/atrium-api/src/agent/store.rs deleted file mode 100644 index 5303ccdf..00000000 --- a/atrium-api/src/agent/store.rs +++ /dev/null @@ -1,35 +0,0 @@ -use super::Session; -use crate::did_doc::DidDocument; -use atrium_common::store::SimpleStore; -use std::sync::RwLock; - -pub struct Store { - inner: S, - pub endpoint: RwLock, -} - -impl Store -where - S: SimpleStore<(), Session>, -{ - pub fn new(inner: S, initial_endpoint: String) -> Self { - Self { inner, endpoint: RwLock::new(initial_endpoint) } - } - pub fn get_endpoint(&self) -> String { - self.endpoint.read().expect("failed to read endpoint").clone() - } - pub fn update_endpoint(&self, did_doc: &DidDocument) { - if let Some(endpoint) = did_doc.get_pds_endpoint() { - *self.endpoint.write().expect("failed to write endpoint") = endpoint; - } - } - pub async fn get_session(&self) -> Option { - self.inner.get(&()).await.expect("todo") - } - pub async fn set_session(&self, session: Session) { - self.inner.set((), session).await.expect("todo") - } - pub async fn clear_session(&self) { - self.inner.del(&()).await.expect("todo") - } -} diff --git a/atrium-common/src/resolver.rs b/atrium-common/src/resolver.rs index d0fc3130..ee82f9cb 100644 --- a/atrium-common/src/resolver.rs +++ b/atrium-common/src/resolver.rs @@ -10,58 +10,55 @@ use std::future::Future; use std::hash::Hash; #[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] -pub trait Resolver -where - E: std::error::Error, -{ +pub trait Resolver { type Input: ?Sized; type Output; + type Error; fn resolve( &self, input: &Self::Input, - ) -> impl Future, E>>; + ) -> impl Future, Self::Error>>; } -pub trait Cacheable +pub trait Cacheable where - Self: Sized + Resolver, + Self: Sized + Resolver, Self::Input: Sized, - E: std::error::Error, + Self::Error: std::error::Error, { - fn cached(self, config: CachedResolverConfig) -> CachedResolver; + fn cached(self, config: CachedResolverConfig) -> CachedResolver; } -impl Cacheable for R +impl Cacheable for R where - R: Sized + Resolver, + R: Sized + Resolver, R::Input: Sized + Hash + Eq + Send + Sync + 'static, R::Output: Clone + Send + Sync + 'static, - E: std::error::Error + Send + Sync + 'static, + R::Error: std::error::Error + Send + Sync + 'static, { - fn cached(self, config: CachedResolverConfig) -> CachedResolver { + fn cached(self, config: CachedResolverConfig) -> CachedResolver { CachedResolver::new(self, config) } } -pub trait Throttleable +pub trait Throttleable where - Self: Sized + Resolver, + Self: Sized + Resolver, Self::Input: Sized, - E: std::error::Error, { - fn throttled(self) -> ThrottledResolver; + fn throttled(self) -> ThrottledResolver; } -impl Throttleable for R +impl Throttleable for R where - R: Sized + Resolver, + R: Resolver, R::Input: Clone + Hash + Eq + Send + Sync + 'static, R::Output: Clone + Send + Sync + 'static, - E: std::error::Error + Send + Sync + 'static, + R::Error: std::error::Error + Send + Sync + 'static, { - fn throttled(self) -> ThrottledResolver { - ThrottledResolver::new(self) + fn throttled(self) -> ThrottledResolver { + ThrottledResolver::new(self, Default::default()) } } @@ -90,9 +87,10 @@ mod tests { counts: Arc>>, } - impl Resolver for MockResolver { + impl Resolver for MockResolver { type Input = String; type Output = String; + type Error = Error; async fn resolve(&self, input: &Self::Input) -> Result> { sleep(Duration::from_millis(10)).await; diff --git a/atrium-common/src/resolver/cached_resolver.rs b/atrium-common/src/resolver/cached_resolver.rs index bb6bb36d..e20dcb51 100644 --- a/atrium-common/src/resolver/cached_resolver.rs +++ b/atrium-common/src/resolver/cached_resolver.rs @@ -21,39 +21,39 @@ pub struct CachedResolverConfig { pub time_to_live: Option, } -pub struct CachedResolver +pub struct CachedResolver where - R: Resolver, + R: Resolver, R::Input: Sized, - E: Error, { resolver: R, cache: CacheImpl, } -impl CachedResolver +impl CachedResolver where - R: Resolver, + R: Resolver, R::Input: Sized + Hash + Eq + Send + Sync + 'static, R::Output: Clone + Send + Sync + 'static, - E: Error + Send + Sync + 'static, + R::Error: Error, { pub fn new(resolver: R, config: CachedResolverConfig) -> Self { Self { resolver, cache: CacheImpl::new(config) } } } -impl Resolver for CachedResolver +impl Resolver for CachedResolver where - R: Resolver + Send + Sync + 'static, - R::Input: Clone + Hash + Eq + Send + Sync + 'static + Debug, + R: Resolver + Sync, + R::Input: Sized + Clone + Hash + Eq + Send + Sync + 'static, R::Output: Clone + Send + Sync + 'static, - E: Error + Send + Sync + 'static, + R::Error: Error, { type Input = R::Input; type Output = R::Output; + type Error = R::Error; - async fn resolve(&self, input: &Self::Input) -> Result, E> { + async fn resolve(&self, input: &Self::Input) -> Result, Self::Error> { if let Some(output) = self.cache.get(input).await { return Ok(Some(output)); } diff --git a/atrium-common/src/resolver/throttled_resolver.rs b/atrium-common/src/resolver/throttled_resolver.rs index 45f9772b..e428fda4 100644 --- a/atrium-common/src/resolver/throttled_resolver.rs +++ b/atrium-common/src/resolver/throttled_resolver.rs @@ -1,5 +1,6 @@ +use crate::store::SimpleStore; + use super::Resolver; -use dashmap::{DashMap, Entry}; use std::error::Error; use std::hash::Hash; use std::sync::Arc; @@ -8,52 +9,44 @@ use tokio::sync::Mutex; type SharedSender = Arc>>>; -pub struct ThrottledResolver -where - R: Resolver, - R::Input: Sized, - E: Error, -{ +pub struct ThrottledResolver { resolver: R, - senders: Arc>>, + senders: Arc, } -impl ThrottledResolver -where - R: Resolver, - R::Input: Clone + Hash + Eq + Send + Sync + 'static, - E: Error + Send + Sync + 'static, -{ - pub fn new(resolver: R) -> Self { - Self { resolver, senders: Arc::new(DashMap::new()) } +impl ThrottledResolver { + pub fn new(resolver: R, senders: S) -> Self { + Self { resolver, senders: Arc::new(senders) } } } -impl Resolver for ThrottledResolver +impl Resolver for ThrottledResolver where - R: Resolver + Send + Sync + 'static, - R::Input: Clone + Hash + Eq + Send + Sync + 'static, - R::Output: Clone + Send + Sync + 'static, - E: Error + Send + Sync + 'static, + R: Resolver + Send + Sync, + R::Input: Clone + Hash + Eq + Send + Sync, + R::Output: Clone + Send + Sync, + R::Error: Error + Send + Sync, + S: SimpleStore> + Sync, { type Input = R::Input; type Output = R::Output; + type Error = R::Error; - async fn resolve(&self, input: &Self::Input) -> Result, E> { - match self.senders.entry(input.clone()) { - Entry::Occupied(occupied) => { - let tx = occupied.get().lock().await.clone(); + async fn resolve(&self, input: &Self::Input) -> Result, Self::Error> { + match self.senders.get(input).await? { + Some(occupied) => { + let tx = occupied.lock().await.clone(); drop(occupied); Ok(tx.subscribe().recv().await.expect("recv")) } - Entry::Vacant(vacant) => { + None => { let (tx, _) = channel(1); - vacant.insert(Arc::new(Mutex::new(tx.clone()))); + self.senders.set(input.clone(), Arc::new(Mutex::new(tx.clone()))); let Some(result) = self.resolver.resolve(input).await.transpose() else { return Ok(None); }; tx.send(result.as_ref().ok().cloned()).ok(); - self.senders.remove(input); + self.senders.del(input); result.map(Some) } } diff --git a/atrium-common/src/store.rs b/atrium-common/src/store.rs index 89f51e47..e57895c9 100644 --- a/atrium-common/src/store.rs +++ b/atrium-common/src/store.rs @@ -1,4 +1,3 @@ -pub mod cached; pub mod memory; use std::error::Error; @@ -11,7 +10,7 @@ where K: Eq + Hash, V: Clone, { - type Error: Error + Send + Sync + 'static; + type Error: Error; fn get(&self, key: &K) -> impl Future, Self::Error>>; fn set(&self, key: K, value: V) -> impl Future>; @@ -19,4 +18,55 @@ where fn clear(&self) -> impl Future>; } -// pub trait SessionStore: SimpleStore> + Clone {} +impl SimpleStore<(), T> for std::sync::Mutex> +where + T: Clone + Send, +{ + type Error = std::convert::Infallible; + + async fn get(&self, _: &()) -> Result, Self::Error> { + Ok(self.lock().as_deref().cloned().expect("todo")) + } + + async fn set(&self, _: (), value: T) -> Result<(), Self::Error> { + *self.lock().expect("todo") = Some(value); + Ok(()) + } + + async fn del(&self, _: &()) -> Result<(), Self::Error> { + self.clear().await.expect("todo"); + Ok(()) + } + + async fn clear(&self) -> Result<(), Self::Error> { + *self.lock().expect("todo") = None; + Ok(()) + } +} + +impl SimpleStore for dashmap::DashMap +where + K: Eq + Hash + Clone + Send + Sync, + V: Clone + Send + Sync, +{ + type Error = std::convert::Infallible; + + async fn get(&self, key: &K) -> Result, Self::Error> { + Ok(self.get(key).as_deref().cloned()) + } + + async fn set(&self, key: K, value: V) -> Result<(), Self::Error> { + self.insert(key, value); + Ok(()) + } + + async fn del(&self, key: &K) -> Result<(), Self::Error> { + self.remove(key); + Ok(()) + } + + async fn clear(&self) -> Result<(), Self::Error> { + self.clear(); + Ok(()) + } +} diff --git a/atrium-common/src/store/cached.rs b/atrium-common/src/store/cached.rs index 39324cee..8736fe42 100644 --- a/atrium-common/src/store/cached.rs +++ b/atrium-common/src/store/cached.rs @@ -10,10 +10,14 @@ use std::{ use chrono::{DateTime, FixedOffset, Utc}; use tokio::sync::broadcast; +use crate::resolver::Resolver; + use super::{memory::MemorySimpleStore, SimpleStore}; pub type Getter<'f, T> = Pin + Send + 'f>>; +pub trait SessionResolver: Resolver {} + pub trait CachedStore: SimpleStore> where K: Clone + Debug + Eq + Hash + Send + Sync + 'static,