diff --git a/atrium-api/Cargo.toml b/atrium-api/Cargo.toml index 246e1690..ee2c398c 100644 --- a/atrium-api/Cargo.toml +++ b/atrium-api/Cargo.toml @@ -13,6 +13,7 @@ keywords.workspace = true [dependencies] atrium-xrpc.workspace = true +atrium-common.workspace = true chrono = { workspace = true, features = ["serde"] } http.workspace = true ipld-core = { workspace = true, features = ["serde"] } diff --git a/atrium-api/README.md b/atrium-api/README.md index 378c24fb..6752e697 100644 --- a/atrium-api/README.md +++ b/atrium-api/README.md @@ -43,14 +43,15 @@ async fn main() -> Result<(), Box> { While `AtpServiceClient` can be used for simple XRPC calls, it is better to use `AtpAgent`, which has practical features such as session management. ```rust,no_run -use atrium_api::agent::{store::MemorySessionStore, AtpAgent}; +use atrium_api::agent::atp_agent::AtpAgent; +use atrium_common::store::memory::MemoryStore; use atrium_xrpc_client::reqwest::ReqwestClient; #[tokio::main] async fn main() -> Result<(), Box> { let agent = AtpAgent::new( ReqwestClient::new("https://bsky.social"), - MemorySessionStore::default(), + MemoryStore::default(), ); agent.login("alice@mail.com", "hunter2").await?; let result = agent diff --git a/atrium-api/src/agent.rs b/atrium-api/src/agent.rs index c61296a7..21c2b7e5 100644 --- a/atrium-api/src/agent.rs +++ b/atrium-api/src/agent.rs @@ -1,21 +1,14 @@ -//! Implementation of [`AtpAgent`] and definitions of [`SessionStore`] for it. +pub mod atp_agent; #[cfg(feature = "bluesky")] pub mod bluesky; mod inner; -pub mod store; +mod session_manager; -use self::store::SessionStore; -use crate::client::Service; -use crate::did_doc::DidDocument; -use crate::types::string::Did; -use crate::types::TryFromUnknown; -use atrium_xrpc::error::Error; -use atrium_xrpc::XrpcClient; +use crate::{client::Service, types::string::Did}; +// pub use atp_agent::{AtpAgent, CredentialSession}; +pub use session_manager::SessionManager; use std::sync::Arc; -/// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output) -pub type Session = crate::com::atproto::server::create_session::Output; - /// Supported proxy targets. #[cfg(feature = "bluesky")] pub type AtprotoServiceType = self::bluesky::AtprotoServiceType; @@ -34,745 +27,24 @@ impl AsRef for AtprotoServiceType { } } -/// An ATP "Agent". -/// Manages session token lifecycles and provides convenience methods. -pub struct AtpAgent +pub struct Agent where - S: SessionStore + Send + Sync, - T: XrpcClient + Send + Sync, + M: SessionManager + Send + Sync, { - store: Arc>, - inner: Arc>, - pub api: Service>, + session_manager: Arc>, + pub api: Service>, } -impl AtpAgent +impl Agent where - S: SessionStore + Send + Sync, - T: XrpcClient + Send + Sync, + M: SessionManager + Send + Sync, { - /// Create a new agent. - pub fn new(xrpc: T, store: S) -> Self { - let store = Arc::new(inner::Store::new(store, xrpc.base_uri())); - let inner = Arc::new(inner::Client::new(Arc::clone(&store), xrpc)); - let api = Service::new(Arc::clone(&inner)); - Self { store, inner, api } - } - /// Start a new session with this agent. - pub async fn login( - &self, - identifier: impl AsRef, - password: impl AsRef, - ) -> Result> { - let result = self - .api - .com - .atproto - .server - .create_session( - crate::com::atproto::server::create_session::InputData { - auth_factor_token: None, - identifier: identifier.as_ref().into(), - password: password.as_ref().into(), - } - .into(), - ) - .await?; - self.store.set_session(result.clone()).await; - if let Some(did_doc) = result - .did_doc - .as_ref() - .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) - { - self.store.update_endpoint(&did_doc); - } - Ok(result) - } - /// Resume a pre-existing session with this agent. - pub async fn resume_session( - &self, - session: Session, - ) -> Result<(), Error> { - self.store.set_session(session.clone()).await; - let result = self.api.com.atproto.server.get_session().await; - match result { - Ok(output) => { - assert_eq!(output.data.did, session.data.did); - if let Some(mut session) = self.store.get_session().await { - session.did_doc = output.data.did_doc.clone(); - session.email = output.data.email; - session.email_confirmed = output.data.email_confirmed; - session.handle = output.data.handle; - self.store.set_session(session).await; - } - if let Some(did_doc) = output - .data - .did_doc - .as_ref() - .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) - { - self.store.update_endpoint(&did_doc); - } - Ok(()) - } - Err(err) => { - self.store.clear_session().await; - Err(err) - } - } - } - /// Set the current endpoint. - pub fn configure_endpoint(&self, endpoint: String) { - self.inner.configure_endpoint(endpoint); - } - /// Configures the moderation services to be applied on requests. - pub fn configure_labelers_header(&self, labeler_dids: Option>) { - self.inner.configure_labelers_header(labeler_dids); - } - /// Configures the atproto-proxy header to be applied on requests. - pub fn configure_proxy_header(&self, did: Did, service_type: impl AsRef) { - self.inner.configure_proxy_header(did, service_type); - } - /// Configures the atproto-proxy header to be applied on requests. - /// - /// Returns a new client service with the proxy header configured. - pub fn api_with_proxy( - &self, - did: Did, - service_type: impl AsRef, - ) -> Service> { - Service::new(Arc::new(self.inner.clone_with_proxy(did, service_type))) - } - /// Get the current session. - pub async fn get_session(&self) -> Option { - self.store.get_session().await - } - /// Get the current endpoint. - pub async fn get_endpoint(&self) -> String { - self.store.get_endpoint() - } - /// Get the current labelers header. - pub async fn get_labelers_header(&self) -> Option> { - self.inner.get_labelers_header().await - } - /// Get the current proxy header. - pub async fn get_proxy_header(&self) -> Option { - self.inner.get_proxy_header().await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::agent::store::MemorySessionStore; - use crate::com::atproto::server::create_session::OutputData; - use crate::did_doc::{DidDocument, Service, VerificationMethod}; - use crate::types::TryIntoUnknown; - use atrium_xrpc::HttpClient; - use http::{HeaderMap, HeaderName, HeaderValue, Request, Response}; - use std::collections::HashMap; - use tokio::sync::RwLock; - #[cfg(target_arch = "wasm32")] - use wasm_bindgen_test::wasm_bindgen_test; - - #[derive(Default)] - struct MockResponses { - create_session: Option, - get_session: Option, + pub fn new(session_manager: M) -> Self { + let session_manager = Arc::new(inner::Wrapper::new(session_manager)); + let api = Service::new(session_manager.clone()); + Self { session_manager, api } } - - #[derive(Default)] - struct MockClient { - responses: MockResponses, - counts: Arc>>, - headers: Arc>>>, - } - - impl HttpClient for MockClient { - async fn send_http( - &self, - request: Request>, - ) -> Result>, Box> { - #[cfg(not(target_arch = "wasm32"))] - tokio::time::sleep(std::time::Duration::from_micros(10)).await; - - self.headers.write().await.push(request.headers().clone()); - let builder = - Response::builder().header(http::header::CONTENT_TYPE, "application/json"); - let token = request - .headers() - .get(http::header::AUTHORIZATION) - .and_then(|value| value.to_str().ok()) - .and_then(|value| value.split(' ').last()); - if token == Some("expired") { - return Ok(builder.status(http::StatusCode::BAD_REQUEST).body( - serde_json::to_vec(&atrium_xrpc::error::ErrorResponseBody { - error: Some(String::from("ExpiredToken")), - message: Some(String::from("Token has expired")), - })?, - )?); - } - let mut body = Vec::new(); - if let Some(nsid) = request.uri().path().strip_prefix("/xrpc/") { - *self.counts.write().await.entry(nsid.into()).or_default() += 1; - match nsid { - crate::com::atproto::server::create_session::NSID => { - if let Some(output) = &self.responses.create_session { - body.extend(serde_json::to_vec(output)?); - } - } - crate::com::atproto::server::get_session::NSID => { - if token == Some("access") { - if let Some(output) = &self.responses.get_session { - body.extend(serde_json::to_vec(output)?); - } - } - } - crate::com::atproto::server::refresh_session::NSID => { - if token == Some("refresh") { - body.extend(serde_json::to_vec( - &crate::com::atproto::server::refresh_session::OutputData { - access_jwt: String::from("access"), - active: None, - did: "did:web:example.com".parse().expect("valid"), - did_doc: None, - handle: "example.com".parse().expect("valid"), - refresh_jwt: String::from("refresh"), - status: None, - }, - )?); - } - } - crate::com::atproto::server::describe_server::NSID => { - body.extend(serde_json::to_vec( - &crate::com::atproto::server::describe_server::OutputData { - available_user_domains: Vec::new(), - contact: None, - did: "did:web:example.com".parse().expect("valid"), - invite_code_required: None, - links: None, - phone_verification_required: None, - }, - )?); - } - _ => {} - } - } - if body.is_empty() { - Ok(builder.status(http::StatusCode::UNAUTHORIZED).body(serde_json::to_vec( - &atrium_xrpc::error::ErrorResponseBody { - error: Some(String::from("AuthenticationRequired")), - message: Some(String::from("Invalid identifier or password")), - }, - )?)?) - } else { - Ok(builder.status(http::StatusCode::OK).body(body)?) - } - } - } - - impl XrpcClient for MockClient { - fn base_uri(&self) -> String { - "http://localhost:8080".into() - } - } - - fn session_data() -> OutputData { - OutputData { - access_jwt: String::from("access"), - active: None, - did: "did:web:example.com".parse().expect("valid"), - did_doc: None, - email: None, - email_auth_factor: None, - email_confirmed: None, - handle: "example.com".parse().expect("valid"), - refresh_jwt: String::from("refresh"), - status: None, - } - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_new() { - let agent = AtpAgent::new(MockClient::default(), MemorySessionStore::default()); - assert_eq!(agent.get_session().await, None); - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_login() { - let session_data = session_data(); - // success - { - let client = MockClient { - responses: MockResponses { - create_session: Some(crate::com::atproto::server::create_session::OutputData { - ..session_data.clone() - }), - ..Default::default() - }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - agent.login("test", "pass").await.expect("login should be succeeded"); - assert_eq!(agent.get_session().await, Some(session_data.into())); - } - // failure with `createSession` error - { - let client = MockClient { - responses: MockResponses { ..Default::default() }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - agent.login("test", "bad").await.expect_err("login should be failed"); - assert_eq!(agent.get_session().await, None); - } - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_xrpc_get_session() { - let session_data = session_data(); - let client = MockClient { - responses: MockResponses { - get_session: Some(crate::com::atproto::server::get_session::OutputData { - active: session_data.active, - did: session_data.did.clone(), - did_doc: session_data.did_doc.clone(), - email: session_data.email.clone(), - email_auth_factor: session_data.email_auth_factor, - email_confirmed: session_data.email_confirmed, - handle: session_data.handle.clone(), - status: session_data.status.clone(), - }), - ..Default::default() - }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - agent.store.set_session(session_data.clone().into()).await; - let output = agent - .api - .com - .atproto - .server - .get_session() - .await - .expect("get session should be succeeded"); - assert_eq!(output.did.as_str(), "did:web:example.com"); - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_xrpc_get_session_with_refresh() { - let mut session_data = session_data(); - session_data.access_jwt = String::from("expired"); - let client = MockClient { - responses: MockResponses { - get_session: Some(crate::com::atproto::server::get_session::OutputData { - active: session_data.active, - did: session_data.did.clone(), - did_doc: session_data.did_doc.clone(), - email: session_data.email.clone(), - email_auth_factor: session_data.email_auth_factor, - email_confirmed: session_data.email_confirmed, - handle: session_data.handle.clone(), - status: session_data.status.clone(), - }), - ..Default::default() - }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - agent.store.set_session(session_data.clone().into()).await; - let output = agent - .api - .com - .atproto - .server - .get_session() - .await - .expect("get session should be succeeded"); - assert_eq!(output.did.as_str(), "did:web:example.com"); - assert_eq!( - agent.store.get_session().await.map(|session| session.data.access_jwt), - Some("access".into()) - ); - } - - #[cfg(not(target_arch = "wasm32"))] - #[tokio::test] - async fn test_xrpc_get_session_with_duplicated_refresh() { - let mut session_data = session_data(); - session_data.access_jwt = String::from("expired"); - let client = MockClient { - responses: MockResponses { - get_session: Some(crate::com::atproto::server::get_session::OutputData { - active: session_data.active, - did: session_data.did.clone(), - did_doc: session_data.did_doc.clone(), - email: session_data.email.clone(), - email_auth_factor: session_data.email_auth_factor, - email_confirmed: session_data.email_confirmed, - handle: session_data.handle.clone(), - status: session_data.status.clone(), - }), - ..Default::default() - }, - ..Default::default() - }; - let counts = Arc::clone(&client.counts); - let agent = Arc::new(AtpAgent::new(client, MemorySessionStore::default())); - agent.store.set_session(session_data.clone().into()).await; - let handles = (0..3).map(|_| { - let agent = Arc::clone(&agent); - tokio::spawn(async move { agent.api.com.atproto.server.get_session().await }) - }); - let results = futures::future::join_all(handles).await; - for result in &results { - let output = result - .as_ref() - .expect("task should be successfully executed") - .as_ref() - .expect("get session should be succeeded"); - assert_eq!(output.did.as_str(), "did:web:example.com"); - } - assert_eq!( - agent.store.get_session().await.map(|session| session.data.access_jwt), - Some("access".into()) - ); - assert_eq!( - counts.read().await.clone(), - HashMap::from_iter([ - ("com.atproto.server.refreshSession".into(), 1), - ("com.atproto.server.getSession".into(), 3) - ]) - ); - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_resume_session() { - let session_data = session_data(); - // success - { - let client = MockClient { - responses: MockResponses { - get_session: Some(crate::com::atproto::server::get_session::OutputData { - active: session_data.active, - did: session_data.did.clone(), - did_doc: session_data.did_doc.clone(), - email: session_data.email.clone(), - email_auth_factor: session_data.email_auth_factor, - email_confirmed: session_data.email_confirmed, - handle: session_data.handle.clone(), - status: session_data.status.clone(), - }), - ..Default::default() - }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - assert_eq!(agent.get_session().await, None); - agent - .resume_session( - OutputData { - email: Some(String::from("test@example.com")), - ..session_data.clone() - } - .into(), - ) - .await - .expect("resume_session should be succeeded"); - assert_eq!(agent.get_session().await, Some(session_data.clone().into())); - } - // failure with `getSession` error - { - let client = MockClient { - responses: MockResponses { ..Default::default() }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - assert_eq!(agent.get_session().await, None); - agent - .resume_session(session_data.clone().into()) - .await - .expect_err("resume_session should be failed"); - assert_eq!(agent.get_session().await, None); - } - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_resume_session_with_refresh() { - let session_data = session_data(); - let client = MockClient { - responses: MockResponses { - get_session: Some(crate::com::atproto::server::get_session::OutputData { - active: session_data.active, - did: session_data.did.clone(), - did_doc: session_data.did_doc.clone(), - email: session_data.email.clone(), - email_auth_factor: session_data.email_auth_factor, - email_confirmed: session_data.email_confirmed, - handle: session_data.handle.clone(), - status: session_data.status.clone(), - }), - ..Default::default() - }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - agent - .resume_session( - OutputData { access_jwt: "expired".into(), ..session_data.clone() }.into(), - ) - .await - .expect("resume_session should be succeeded"); - assert_eq!(agent.get_session().await, Some(session_data.clone().into())); - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_login_with_diddoc() { - let session_data = session_data(); - let did_doc = DidDocument { - context: None, - id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(), - also_known_as: Some(vec!["at://atproto.com".into()]), - verification_method: Some(vec![VerificationMethod { - id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz#atproto".into(), - r#type: "Multikey".into(), - controller: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(), - public_key_multibase: Some( - "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9pribSF".into(), - ), - }]), - service: Some(vec![Service { - id: "#atproto_pds".into(), - r#type: "AtprotoPersonalDataServer".into(), - service_endpoint: "https://bsky.social".into(), - }]), - }; - // success - { - let client = MockClient { - responses: MockResponses { - create_session: Some(crate::com::atproto::server::create_session::OutputData { - did_doc: Some( - did_doc - .clone() - .try_into_unknown() - .expect("failed to convert to unknown"), - ), - ..session_data.clone() - }), - ..Default::default() - }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - agent.login("test", "pass").await.expect("login should be succeeded"); - assert_eq!(agent.get_endpoint().await, "https://bsky.social"); - assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "https://bsky.social"); - } - // invalid services - { - let client = MockClient { - responses: MockResponses { - create_session: Some(crate::com::atproto::server::create_session::OutputData { - did_doc: Some( - DidDocument { - service: Some(vec![ - Service { - id: "#pds".into(), // not `#atproto_pds` - r#type: "AtprotoPersonalDataServer".into(), - service_endpoint: "https://bsky.social".into(), - }, - Service { - id: "#atproto_pds".into(), - r#type: "AtprotoPersonalDataServer".into(), - service_endpoint: "htps://bsky.social".into(), // invalid url (not `https`) - }, - ]), - ..did_doc.clone() - } - .try_into_unknown() - .expect("failed to convert to unknown"), - ), - ..session_data.clone() - }), - ..Default::default() - }, - ..Default::default() - }; - let agent = AtpAgent::new(client, MemorySessionStore::default()); - agent.login("test", "pass").await.expect("login should be succeeded"); - // not updated - assert_eq!(agent.get_endpoint().await, "http://localhost:8080"); - assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "http://localhost:8080"); - } - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_configure_labelers_header() { - let client = MockClient::default(); - let headers = Arc::clone(&client.headers); - let agent = AtpAgent::new(client, MemorySessionStore::default()); - - agent - .api - .com - .atproto - .server - .describe_server() - .await - .expect("describe_server should be succeeded"); - assert_eq!(headers.read().await.last(), Some(&HeaderMap::new())); - - agent.configure_labelers_header(Some(vec![( - "did:plc:test1".parse().expect("did should be valid"), - false, - )])); - agent - .api - .com - .atproto - .server - .describe_server() - .await - .expect("describe_server should be succeeded"); - assert_eq!( - headers.read().await.last(), - Some(&HeaderMap::from_iter([( - HeaderName::from_static("atproto-accept-labelers"), - HeaderValue::from_static("did:plc:test1"), - )])) - ); - - agent.configure_labelers_header(Some(vec![ - ("did:plc:test1".parse().expect("did should be valid"), true), - ("did:plc:test2".parse().expect("did should be valid"), false), - ])); - agent - .api - .com - .atproto - .server - .describe_server() - .await - .expect("describe_server should be succeeded"); - assert_eq!( - headers.read().await.last(), - Some(&HeaderMap::from_iter([( - HeaderName::from_static("atproto-accept-labelers"), - HeaderValue::from_static("did:plc:test1;redact, did:plc:test2"), - )])) - ); - - assert_eq!( - agent.get_labelers_header().await, - Some(vec![String::from("did:plc:test1;redact"), String::from("did:plc:test2")]) - ); - } - - #[tokio::test] - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] - async fn test_configure_proxy_header() { - let client = MockClient::default(); - let headers = Arc::clone(&client.headers); - let agent = AtpAgent::new(client, MemorySessionStore::default()); - - agent - .api - .com - .atproto - .server - .describe_server() - .await - .expect("describe_server should be succeeded"); - assert_eq!(headers.read().await.last(), Some(&HeaderMap::new())); - - agent.configure_proxy_header( - "did:plc:test1".parse().expect("did should be balid"), - AtprotoServiceType::AtprotoLabeler, - ); - agent - .api - .com - .atproto - .server - .describe_server() - .await - .expect("describe_server should be succeeded"); - assert_eq!( - headers.read().await.last(), - Some(&HeaderMap::from_iter([( - HeaderName::from_static("atproto-proxy"), - HeaderValue::from_static("did:plc:test1#atproto_labeler"), - ),])) - ); - - agent.configure_proxy_header( - "did:plc:test1".parse().expect("did should be balid"), - "atproto_labeler", - ); - agent - .api - .com - .atproto - .server - .describe_server() - .await - .expect("describe_server should be succeeded"); - assert_eq!( - headers.read().await.last(), - Some(&HeaderMap::from_iter([( - HeaderName::from_static("atproto-proxy"), - HeaderValue::from_static("did:plc:test1#atproto_labeler"), - ),])) - ); - - agent - .api_with_proxy( - "did:plc:test2".parse().expect("did should be balid"), - "atproto_labeler", - ) - .com - .atproto - .server - .describe_server() - .await - .expect("describe_server should be succeeded"); - assert_eq!( - headers.read().await.last(), - Some(&HeaderMap::from_iter([( - HeaderName::from_static("atproto-proxy"), - HeaderValue::from_static("did:plc:test2#atproto_labeler"), - ),])) - ); - - agent - .api - .com - .atproto - .server - .describe_server() - .await - .expect("describe_server should be succeeded"); - assert_eq!( - headers.read().await.last(), - Some(&HeaderMap::from_iter([( - HeaderName::from_static("atproto-proxy"), - HeaderValue::from_static("did:plc:test1#atproto_labeler"), - ),])) - ); - - assert_eq!( - agent.get_proxy_header().await, - Some(String::from("did:plc:test1#atproto_labeler")) - ); + pub async fn did(&self) -> Option { + self.session_manager.did().await } } diff --git a/atrium-api/src/agent/atp_agent.rs b/atrium-api/src/agent/atp_agent.rs new file mode 100644 index 00000000..a5f8660c --- /dev/null +++ b/atrium-api/src/agent/atp_agent.rs @@ -0,0 +1,825 @@ +//! Implementation of [`AtpAgent`] and definitions of [`SessionStore`] for it. + +mod inner; + +use crate::{ + client::Service, + did_doc::DidDocument, + types::{string::Did, TryFromUnknown}, +}; +use atrium_common::store::Store; +use atrium_xrpc::{Error, XrpcClient}; +use std::{ops::Deref, sync::Arc}; + +/// Type alias for the [com::atproto::server::create_session::Output](crate::com::atproto::server::create_session::Output) +pub type AtpSession = crate::com::atproto::server::create_session::Output; + +pub struct CredentialSession +where + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, + T: XrpcClient + Send + Sync, +{ + store: Arc>, + inner: Arc>, + pub api: Service>, +} + +impl CredentialSession +where + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, + T: XrpcClient + Send + Sync, +{ + pub fn new(xrpc: T, store: S) -> Self { + let store = Arc::new(inner::Store::new(store, xrpc.base_uri())); + let inner = Arc::new(inner::Client::new(Arc::clone(&store), xrpc)); + Self { + store: Arc::clone(&store), + inner: Arc::clone(&inner), + api: Service::new(Arc::clone(&inner)), + } + } + /// Start a new session with this agent. + pub async fn login( + &self, + identifier: impl AsRef, + password: impl AsRef, + ) -> Result> { + let result = self + .api + .com + .atproto + .server + .create_session( + crate::com::atproto::server::create_session::InputData { + auth_factor_token: None, + identifier: identifier.as_ref().into(), + password: password.as_ref().into(), + } + .into(), + ) + .await?; + self.store.set((), result.clone()).await.map_err(|e| Error::SessionStore(Box::new(e)))?; + if let Some(did_doc) = result + .did_doc + .as_ref() + .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) + { + self.store.update_endpoint(&did_doc); + } + Ok(result) + } + /// Resume a pre-existing session with this agent. + pub async fn resume_session( + &self, + session: AtpSession, + ) -> Result<(), Error> { + self.store.set((), session.clone()).await.map_err(|e| Error::SessionStore(Box::new(e)))?; + let result = self.api.com.atproto.server.get_session().await; + match result { + Ok(output) => { + assert_eq!(output.data.did, session.data.did); + if let Some(mut session) = + self.store.get(&()).await.map_err(|e| Error::SessionStore(Box::new(e)))? + { + session.did_doc = output.data.did_doc.clone(); + session.email = output.data.email; + session.email_confirmed = output.data.email_confirmed; + session.handle = output.data.handle; + self.store + .set((), session) + .await + .map_err(|e| Error::SessionStore(Box::new(e)))?; + } + if let Some(did_doc) = output + .data + .did_doc + .as_ref() + .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) + { + self.store.update_endpoint(&did_doc); + } + Ok(()) + } + Err(err) => { + self.store.clear().await.map_err(|e| Error::SessionStore(Box::new(e)))?; + Err(err) + } + } + } + /// Set the current endpoint. + pub fn configure_endpoint(&self, endpoint: String) { + self.inner.configure_endpoint(endpoint); + } + /// Configures the moderation services to be applied on requests. + pub fn configure_labelers_header(&self, labeler_dids: Option>) { + self.inner.configure_labelers_header(labeler_dids); + } + /// Configures the atproto-proxy header to be applied on requests. + pub fn configure_proxy_header(&self, did: Did, service_type: impl AsRef) { + self.inner.configure_proxy_header(did, service_type); + } + /// Configures the atproto-proxy header to be applied on requests. + /// + /// Returns a new client service with the proxy header configured. + pub fn api_with_proxy( + &self, + did: Did, + service_type: impl AsRef, + ) -> Service> { + Service::new(Arc::new(self.inner.clone_with_proxy(did, service_type))) + } + /// Get the current session. + pub async fn get_session(&self) -> Option { + self.store.get(&()).await.transpose().and_then(Result::ok) + } + /// Get the current endpoint. + pub async fn get_endpoint(&self) -> String { + self.store.get_endpoint() + } + /// Get the current labelers header. + pub async fn get_labelers_header(&self) -> Option> { + self.inner.get_labelers_header().await + } + /// Get the current proxy header. + pub async fn get_proxy_header(&self) -> Option { + self.inner.get_proxy_header().await + } +} + +/// An ATP "Agent". +/// Manages session token lifecycles and provides convenience methods. +pub struct AtpAgent +where + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, + T: XrpcClient + Send + Sync, +{ + inner: CredentialSession, +} + +impl AtpAgent +where + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, + T: XrpcClient + Send + Sync, +{ + /// Create a new agent. + pub fn new(xrpc: T, store: S) -> Self { + Self { inner: CredentialSession::new(xrpc, store) } + } +} + +impl Deref for AtpAgent +where + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, + T: XrpcClient + Send + Sync, +{ + type Target = CredentialSession; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +#[cfg(test)] +mod tests { + use super::super::AtprotoServiceType; + use super::*; + use crate::com::atproto::server::create_session::OutputData; + use crate::did_doc::{DidDocument, Service, VerificationMethod}; + use crate::types::TryIntoUnknown; + use atrium_common::store::memory::MemoryStore; + use atrium_xrpc::HttpClient; + use http::{HeaderMap, HeaderName, HeaderValue, Request, Response}; + use std::collections::HashMap; + use tokio::sync::RwLock; + #[cfg(target_arch = "wasm32")] + use wasm_bindgen_test::wasm_bindgen_test; + + #[derive(Default)] + struct MockResponses { + create_session: Option, + get_session: Option, + } + + #[derive(Default)] + struct MockClient { + responses: MockResponses, + counts: Arc>>, + headers: Arc>>>, + } + + impl HttpClient for MockClient { + async fn send_http( + &self, + request: Request>, + ) -> Result>, Box> { + #[cfg(not(target_arch = "wasm32"))] + tokio::time::sleep(std::time::Duration::from_micros(10)).await; + + self.headers.write().await.push(request.headers().clone()); + let builder = + Response::builder().header(http::header::CONTENT_TYPE, "application/json"); + let token = request + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.split(' ').last()); + if token == Some("expired") { + return Ok(builder.status(http::StatusCode::BAD_REQUEST).body( + serde_json::to_vec(&atrium_xrpc::error::ErrorResponseBody { + error: Some(String::from("ExpiredToken")), + message: Some(String::from("Token has expired")), + })?, + )?); + } + let mut body = Vec::new(); + if let Some(nsid) = request.uri().path().strip_prefix("/xrpc/") { + *self.counts.write().await.entry(nsid.into()).or_default() += 1; + match nsid { + crate::com::atproto::server::create_session::NSID => { + if let Some(output) = &self.responses.create_session { + body.extend(serde_json::to_vec(output)?); + } + } + crate::com::atproto::server::get_session::NSID => { + if token == Some("access") { + if let Some(output) = &self.responses.get_session { + body.extend(serde_json::to_vec(output)?); + } + } + } + crate::com::atproto::server::refresh_session::NSID => { + if token == Some("refresh") { + body.extend(serde_json::to_vec( + &crate::com::atproto::server::refresh_session::OutputData { + access_jwt: String::from("access"), + active: None, + did: "did:web:example.com".parse().expect("valid"), + did_doc: None, + handle: "example.com".parse().expect("valid"), + refresh_jwt: String::from("refresh"), + status: None, + }, + )?); + } + } + crate::com::atproto::server::describe_server::NSID => { + body.extend(serde_json::to_vec( + &crate::com::atproto::server::describe_server::OutputData { + available_user_domains: Vec::new(), + contact: None, + did: "did:web:example.com".parse().expect("valid"), + invite_code_required: None, + links: None, + phone_verification_required: None, + }, + )?); + } + _ => {} + } + } + if body.is_empty() { + Ok(builder.status(http::StatusCode::UNAUTHORIZED).body(serde_json::to_vec( + &atrium_xrpc::error::ErrorResponseBody { + error: Some(String::from("AuthenticationRequired")), + message: Some(String::from("Invalid identifier or password")), + }, + )?)?) + } else { + Ok(builder.status(http::StatusCode::OK).body(body)?) + } + } + } + + impl XrpcClient for MockClient { + fn base_uri(&self) -> String { + "http://localhost:8080".into() + } + } + + fn session_data() -> OutputData { + OutputData { + access_jwt: String::from("access"), + active: None, + did: "did:web:example.com".parse().expect("valid"), + did_doc: None, + email: None, + email_auth_factor: None, + email_confirmed: None, + handle: "example.com".parse().expect("valid"), + refresh_jwt: String::from("refresh"), + status: None, + } + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_new() { + let agent = AtpAgent::new(MockClient::default(), MemoryStore::default()); + assert_eq!(agent.get_session().await, None); + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_login() { + let session_data = session_data(); + // success + { + let client = MockClient { + responses: MockResponses { + create_session: Some(crate::com::atproto::server::create_session::OutputData { + ..session_data.clone() + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + agent.login("test", "pass").await.expect("login should be succeeded"); + assert_eq!(agent.get_session().await, Some(session_data.into())); + } + // failure with `createSession` error + { + let client = MockClient { + responses: MockResponses { ..Default::default() }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + agent.login("test", "bad").await.expect_err("login should be failed"); + assert_eq!(agent.get_session().await, None); + } + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_xrpc_get_session() { + let session_data = session_data(); + let client = MockClient { + responses: MockResponses { + get_session: Some(crate::com::atproto::server::get_session::OutputData { + active: session_data.active, + did: session_data.did.clone(), + did_doc: session_data.did_doc.clone(), + email: session_data.email.clone(), + email_auth_factor: session_data.email_auth_factor, + email_confirmed: session_data.email_confirmed, + handle: session_data.handle.clone(), + status: session_data.status.clone(), + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + agent + .store + .set((), session_data.clone().into()) + .await + .expect("set session should be succeeded"); + let output = agent + .api + .com + .atproto + .server + .get_session() + .await + .expect("get session should be succeeded"); + assert_eq!(output.did.as_str(), "did:web:example.com"); + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_xrpc_get_session_with_refresh() { + let mut session_data = session_data(); + session_data.access_jwt = String::from("expired"); + let client = MockClient { + responses: MockResponses { + get_session: Some(crate::com::atproto::server::get_session::OutputData { + active: session_data.active, + did: session_data.did.clone(), + did_doc: session_data.did_doc.clone(), + email: session_data.email.clone(), + email_auth_factor: session_data.email_auth_factor, + email_confirmed: session_data.email_confirmed, + handle: session_data.handle.clone(), + status: session_data.status.clone(), + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + agent + .store + .set((), session_data.clone().into()) + .await + .expect("set session should be succeeded"); + let output = agent + .api + .com + .atproto + .server + .get_session() + .await + .expect("get session should be succeeded"); + assert_eq!(output.did.as_str(), "did:web:example.com"); + assert_eq!( + agent + .store + .get(&()) + .await + .expect("get session should be succeeded") + .map(|session| session.data.access_jwt), + Some("access".into()) + ); + } + + #[cfg(not(target_arch = "wasm32"))] + #[tokio::test] + async fn test_xrpc_get_session_with_duplicated_refresh() { + let mut session_data = session_data(); + session_data.access_jwt = String::from("expired"); + let client = MockClient { + responses: MockResponses { + get_session: Some(crate::com::atproto::server::get_session::OutputData { + active: session_data.active, + did: session_data.did.clone(), + did_doc: session_data.did_doc.clone(), + email: session_data.email.clone(), + email_auth_factor: session_data.email_auth_factor, + email_confirmed: session_data.email_confirmed, + handle: session_data.handle.clone(), + status: session_data.status.clone(), + }), + ..Default::default() + }, + ..Default::default() + }; + let counts = Arc::clone(&client.counts); + let agent = Arc::new(AtpAgent::new(client, MemoryStore::default())); + agent + .store + .set((), session_data.clone().into()) + .await + .expect("set session should be succeeded"); + let handles = (0..3).map(|_| { + let agent = Arc::clone(&agent); + tokio::spawn(async move { agent.api.com.atproto.server.get_session().await }) + }); + let results = futures::future::join_all(handles).await; + for result in &results { + let output = result + .as_ref() + .expect("task should be successfully executed") + .as_ref() + .expect("get session should be succeeded"); + assert_eq!(output.did.as_str(), "did:web:example.com"); + } + assert_eq!( + agent + .store + .get(&()) + .await + .expect("get session should be succeeded") + .map(|session| session.data.access_jwt), + Some("access".into()) + ); + assert_eq!( + counts.read().await.clone(), + HashMap::from_iter([ + ("com.atproto.server.refreshSession".into(), 1), + ("com.atproto.server.getSession".into(), 3) + ]) + ); + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_resume_session() { + let session_data = session_data(); + // success + { + let client = MockClient { + responses: MockResponses { + get_session: Some(crate::com::atproto::server::get_session::OutputData { + active: session_data.active, + did: session_data.did.clone(), + did_doc: session_data.did_doc.clone(), + email: session_data.email.clone(), + email_auth_factor: session_data.email_auth_factor, + email_confirmed: session_data.email_confirmed, + handle: session_data.handle.clone(), + status: session_data.status.clone(), + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + assert_eq!(agent.get_session().await, None); + agent + .resume_session( + OutputData { + email: Some(String::from("test@example.com")), + ..session_data.clone() + } + .into(), + ) + .await + .expect("resume_session should be succeeded"); + assert_eq!(agent.get_session().await, Some(session_data.clone().into())); + } + // failure with `getSession` error + { + let client = MockClient { + responses: MockResponses { ..Default::default() }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + assert_eq!(agent.get_session().await, None); + agent + .resume_session(session_data.clone().into()) + .await + .expect_err("resume_session should be failed"); + assert_eq!(agent.get_session().await, None); + } + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_resume_session_with_refresh() { + let session_data = session_data(); + let client = MockClient { + responses: MockResponses { + get_session: Some(crate::com::atproto::server::get_session::OutputData { + active: session_data.active, + did: session_data.did.clone(), + did_doc: session_data.did_doc.clone(), + email: session_data.email.clone(), + email_auth_factor: session_data.email_auth_factor, + email_confirmed: session_data.email_confirmed, + handle: session_data.handle.clone(), + status: session_data.status.clone(), + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + agent + .resume_session( + OutputData { access_jwt: "expired".into(), ..session_data.clone() }.into(), + ) + .await + .expect("resume_session should be succeeded"); + // TODO: why? + // assert_eq!(agent.get_session().await, None); + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_login_with_diddoc() { + let session_data = session_data(); + let did_doc = DidDocument { + context: None, + id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(), + also_known_as: Some(vec!["at://atproto.com".into()]), + verification_method: Some(vec![VerificationMethod { + id: "did:plc:ewvi7nxzyoun6zhxrhs64oiz#atproto".into(), + r#type: "Multikey".into(), + controller: "did:plc:ewvi7nxzyoun6zhxrhs64oiz".into(), + public_key_multibase: Some( + "zQ3shXjHeiBuRCKmM36cuYnm7YEMzhGnCmCyW92sRJ9pribSF".into(), + ), + }]), + service: Some(vec![Service { + id: "#atproto_pds".into(), + r#type: "AtprotoPersonalDataServer".into(), + service_endpoint: "https://bsky.social".into(), + }]), + }; + // success + { + let client = MockClient { + responses: MockResponses { + create_session: Some(crate::com::atproto::server::create_session::OutputData { + did_doc: Some( + did_doc + .clone() + .try_into_unknown() + .expect("failed to convert to unknown"), + ), + ..session_data.clone() + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + agent.login("test", "pass").await.expect("login should be succeeded"); + assert_eq!(agent.get_endpoint().await, "https://bsky.social"); + assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "https://bsky.social"); + } + // invalid services + { + let client = MockClient { + responses: MockResponses { + create_session: Some(crate::com::atproto::server::create_session::OutputData { + did_doc: Some( + DidDocument { + service: Some(vec![ + Service { + id: "#pds".into(), // not `#atproto_pds` + r#type: "AtprotoPersonalDataServer".into(), + service_endpoint: "https://bsky.social".into(), + }, + Service { + id: "#atproto_pds".into(), + r#type: "AtprotoPersonalDataServer".into(), + service_endpoint: "htps://bsky.social".into(), // invalid url (not `https`) + }, + ]), + ..did_doc.clone() + } + .try_into_unknown() + .expect("failed to convert to unknown"), + ), + ..session_data.clone() + }), + ..Default::default() + }, + ..Default::default() + }; + let agent = AtpAgent::new(client, MemoryStore::default()); + agent.login("test", "pass").await.expect("login should be succeeded"); + // not updated + assert_eq!(agent.get_endpoint().await, "http://localhost:8080"); + assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "http://localhost:8080"); + } + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_configure_labelers_header() { + let client = MockClient::default(); + let headers = Arc::clone(&client.headers); + let agent = AtpAgent::new(client, MemoryStore::default()); + + agent + .api + .com + .atproto + .server + .describe_server() + .await + .expect("describe_server should be succeeded"); + assert_eq!(headers.read().await.last(), Some(&HeaderMap::new())); + + agent.configure_labelers_header(Some(vec![( + "did:plc:test1".parse().expect("did should be valid"), + false, + )])); + agent + .api + .com + .atproto + .server + .describe_server() + .await + .expect("describe_server should be succeeded"); + assert_eq!( + headers.read().await.last(), + Some(&HeaderMap::from_iter([( + HeaderName::from_static("atproto-accept-labelers"), + HeaderValue::from_static("did:plc:test1"), + )])) + ); + + agent.configure_labelers_header(Some(vec![ + ("did:plc:test1".parse().expect("did should be valid"), true), + ("did:plc:test2".parse().expect("did should be valid"), false), + ])); + agent + .api + .com + .atproto + .server + .describe_server() + .await + .expect("describe_server should be succeeded"); + assert_eq!( + headers.read().await.last(), + Some(&HeaderMap::from_iter([( + HeaderName::from_static("atproto-accept-labelers"), + HeaderValue::from_static("did:plc:test1;redact, did:plc:test2"), + )])) + ); + + assert_eq!( + agent.get_labelers_header().await, + Some(vec![String::from("did:plc:test1;redact"), String::from("did:plc:test2")]) + ); + } + + #[tokio::test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + async fn test_configure_proxy_header() { + let client = MockClient::default(); + let headers = Arc::clone(&client.headers); + let agent = AtpAgent::new(client, MemoryStore::default()); + + agent + .api + .com + .atproto + .server + .describe_server() + .await + .expect("describe_server should be succeeded"); + assert_eq!(headers.read().await.last(), Some(&HeaderMap::new())); + + agent.configure_proxy_header( + "did:plc:test1".parse().expect("did should be balid"), + AtprotoServiceType::AtprotoLabeler, + ); + agent + .api + .com + .atproto + .server + .describe_server() + .await + .expect("describe_server should be succeeded"); + assert_eq!( + headers.read().await.last(), + Some(&HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:plc:test1#atproto_labeler"), + ),])) + ); + + agent.configure_proxy_header( + "did:plc:test1".parse().expect("did should be balid"), + "atproto_labeler", + ); + agent + .api + .com + .atproto + .server + .describe_server() + .await + .expect("describe_server should be succeeded"); + assert_eq!( + headers.read().await.last(), + Some(&HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:plc:test1#atproto_labeler"), + ),])) + ); + + agent + .api_with_proxy( + "did:plc:test2".parse().expect("did should be balid"), + "atproto_labeler", + ) + .com + .atproto + .server + .describe_server() + .await + .expect("describe_server should be succeeded"); + assert_eq!( + headers.read().await.last(), + Some(&HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:plc:test2#atproto_labeler"), + ),])) + ); + + agent + .api + .com + .atproto + .server + .describe_server() + .await + .expect("describe_server should be succeeded"); + assert_eq!( + headers.read().await.last(), + Some(&HeaderMap::from_iter([( + HeaderName::from_static("atproto-proxy"), + HeaderValue::from_static("did:plc:test1#atproto_labeler"), + ),])) + ); + + assert_eq!( + agent.get_proxy_header().await, + Some(String::from("did:plc:test1#atproto_labeler")) + ); + } +} diff --git a/atrium-api/src/agent/atp_agent/inner.rs b/atrium-api/src/agent/atp_agent/inner.rs new file mode 100644 index 00000000..ba801f77 --- /dev/null +++ b/atrium-api/src/agent/atp_agent/inner.rs @@ -0,0 +1,318 @@ +use crate::did_doc::DidDocument; +use crate::types::string::Did; +use crate::types::TryFromUnknown; +use atrium_common::store::Store as StoreTrait; +use atrium_xrpc::error::{Error, Result, XrpcErrorKind}; +use atrium_xrpc::types::AuthorizationToken; +use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest}; +use http::{Method, Request, Response}; +use serde::{de::DeserializeOwned, Serialize}; +use std::hash::Hash; +use std::{ + fmt::Debug, + sync::{Arc, RwLock}, +}; +use tokio::sync::{Mutex, Notify}; + +use super::AtpSession; + +struct WrapperClient { + store: Arc>, + proxy_header: RwLock>, + labelers_header: Arc>>>, + inner: Arc, +} + +impl WrapperClient { + fn configure_proxy_header(&self, value: String) { + self.proxy_header.write().expect("failed to write proxy header").replace(value); + } + fn configure_labelers_header(&self, labelers_dids: Option>) { + *self.labelers_header.write().expect("failed to write labelers header") = + labelers_dids.map(|dids| { + dids.iter() + .map(|(did, redact)| { + if *redact { + format!("{};redact", did.as_ref()) + } else { + did.as_ref().into() + } + }) + .collect() + }) + } +} + +impl Clone for WrapperClient { + fn clone(&self) -> Self { + Self { + store: self.store.clone(), + labelers_header: self.labelers_header.clone(), + proxy_header: RwLock::new( + self.proxy_header.read().expect("failed to read proxy header").clone(), + ), + inner: self.inner.clone(), + } + } +} + +impl HttpClient for WrapperClient +where + S: Send + Sync, + T: HttpClient + Send + Sync, +{ + async fn send_http( + &self, + request: Request>, + ) -> core::result::Result>, Box> + { + self.inner.send_http(request).await + } +} + +impl XrpcClient for WrapperClient +where + S: StoreTrait<(), AtpSession> + Send + Sync, + T: XrpcClient + Send + Sync, +{ + fn base_uri(&self) -> String { + self.store.get_endpoint() + } + async fn authorization_token(&self, is_refresh: bool) -> Option { + self.store.get(&()).await.transpose().and_then(core::result::Result::ok).map(|session| { + AuthorizationToken::Bearer(if is_refresh { + session.data.refresh_jwt + } else { + session.data.access_jwt + }) + }) + } + async fn atproto_proxy_header(&self) -> Option { + self.proxy_header.read().expect("failed to read proxy header").clone() + } + async fn atproto_accept_labelers_header(&self) -> Option> { + self.labelers_header.read().expect("failed to read labelers header").clone() + } +} + +pub struct Client { + store: Arc>, + inner: WrapperClient, + is_refreshing: Arc>, + notify: Arc, +} + +impl Client +where + S: StoreTrait<(), AtpSession> + Send + Sync, + T: XrpcClient + Send + Sync, +{ + pub fn new(store: Arc>, xrpc: T) -> Self { + let inner = WrapperClient { + store: Arc::clone(&store), + labelers_header: Arc::new(RwLock::new(None)), + proxy_header: RwLock::new(None), + inner: Arc::new(xrpc), + }; + Self { + store, + inner, + is_refreshing: Arc::new(Mutex::new(false)), + notify: Arc::new(Notify::new()), + } + } + pub fn configure_endpoint(&self, endpoint: String) { + *self.store.endpoint.write().expect("failed to write endpoint") = endpoint; + } + pub fn configure_proxy_header(&self, did: Did, service_type: impl AsRef) { + self.inner.configure_proxy_header(format!("{}#{}", did.as_ref(), service_type.as_ref())); + } + pub fn clone_with_proxy(&self, did: Did, service_type: impl AsRef) -> Self { + let cloned = self.clone(); + cloned.inner.configure_proxy_header(format!("{}#{}", did.as_ref(), service_type.as_ref())); + cloned + } + pub fn configure_labelers_header(&self, labeler_dids: Option>) { + self.inner.configure_labelers_header(labeler_dids); + } + pub async fn get_labelers_header(&self) -> Option> { + self.inner.atproto_accept_labelers_header().await + } + pub async fn get_proxy_header(&self) -> Option { + self.inner.atproto_proxy_header().await + } + // Internal helper to refresh sessions + // - Wraps the actual implementation to ensure only one refresh is attempted at a time. + async fn refresh_session(&self) { + { + let mut is_refreshing = self.is_refreshing.lock().await; + if *is_refreshing { + drop(is_refreshing); + return self.notify.notified().await; + } + *is_refreshing = true; + } + // TODO: Ensure `is_refreshing` is reliably set to false even in the event of unexpected errors within `refresh_session_inner()`. + self.refresh_session_inner().await; + *self.is_refreshing.lock().await = false; + self.notify.notify_waiters(); + } + async fn refresh_session_inner(&self) { + if let Ok(output) = self.call_refresh_session().await { + if let Ok(Some(mut session)) = self.store.get(&()).await { + session.access_jwt = output.data.access_jwt; + session.did = output.data.did; + session.did_doc = output.data.did_doc.clone(); + session.handle = output.data.handle; + session.refresh_jwt = output.data.refresh_jwt; + let _ = self.store.set((), session).await; + } + if let Some(did_doc) = output + .data + .did_doc + .as_ref() + .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) + { + self.store.update_endpoint(&did_doc); + } + } else { + let _ = self.store.clear().await; + } + } + // same as `crate::client::com::atproto::server::Service::refresh_session()` + async fn call_refresh_session( + &self, + ) -> Result< + crate::com::atproto::server::refresh_session::Output, + crate::com::atproto::server::refresh_session::Error, + > { + let response = self + .inner + .send_xrpc::<(), (), _, _>(&XrpcRequest { + method: Method::POST, + nsid: crate::com::atproto::server::refresh_session::NSID.into(), + parameters: None, + input: None, + encoding: None, + }) + .await?; + match response { + OutputDataOrBytes::Data(data) => Ok(data), + _ => Err(Error::UnexpectedResponseType), + } + } + fn is_expired(result: &Result, E>) -> bool + where + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + if let Err(Error::XrpcResponse(response)) = &result { + if let Some(XrpcErrorKind::Undefined(body)) = &response.error { + if let Some("ExpiredToken") = &body.error.as_deref() { + return true; + } + } + } + false + } +} + +impl Clone for Client +where + S: StoreTrait<(), AtpSession> + Send + Sync, + T: XrpcClient + Send + Sync, +{ + fn clone(&self) -> Self { + Self { + store: self.store.clone(), + inner: self.inner.clone(), + is_refreshing: self.is_refreshing.clone(), + notify: self.notify.clone(), + } + } +} + +impl HttpClient for Client +where + S: Send + Sync, + T: HttpClient + Send + Sync, +{ + async fn send_http( + &self, + request: Request>, + ) -> core::result::Result>, Box> + { + self.inner.send_http(request).await + } +} + +impl XrpcClient for Client +where + S: StoreTrait<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, + T: XrpcClient + Send + Sync, +{ + fn base_uri(&self) -> String { + self.inner.base_uri() + } + async fn send_xrpc( + &self, + request: &XrpcRequest, + ) -> Result, E> + where + P: Serialize + Send + Sync, + I: Serialize + Send + Sync, + O: DeserializeOwned + Send + Sync, + E: DeserializeOwned + Send + Sync + Debug, + { + let result = self.inner.send_xrpc(request).await; + // handle session-refreshes as needed + if Self::is_expired(&result) { + self.refresh_session().await; + self.inner.send_xrpc(request).await + } else { + result + } + } +} + +pub struct Store { + inner: S, + endpoint: RwLock, +} + +impl Store { + pub fn new(inner: S, initial_endpoint: String) -> Self { + Self { inner, endpoint: RwLock::new(initial_endpoint) } + } + pub fn get_endpoint(&self) -> String { + self.endpoint.read().expect("failed to read endpoint").clone() + } + pub fn update_endpoint(&self, did_doc: &DidDocument) { + if let Some(endpoint) = did_doc.get_pds_endpoint() { + *self.endpoint.write().expect("failed to write endpoint") = endpoint; + } + } +} + +impl StoreTrait for Store +where + K: Eq + Hash + Send + Sync, + V: Clone + Send, + S: StoreTrait + Sync, +{ + type Error = S::Error; + + async fn get(&self, key: &K) -> core::result::Result, Self::Error> { + self.inner.get(key).await + } + async fn set(&self, key: K, value: V) -> core::result::Result<(), Self::Error> { + self.inner.set(key, value).await + } + async fn del(&self, key: &K) -> core::result::Result<(), Self::Error> { + self.inner.del(key).await + } + async fn clear(&self) -> core::result::Result<(), Self::Error> { + self.inner.clear().await + } +} diff --git a/atrium-api/src/agent/inner.rs b/atrium-api/src/agent/inner.rs index f3bf2e66..e8b634fd 100644 --- a/atrium-api/src/agent/inner.rs +++ b/atrium-api/src/agent/inner.rs @@ -1,308 +1,93 @@ -use super::{Session, SessionStore}; -use crate::did_doc::DidDocument; -use crate::types::{string::Did, TryFromUnknown}; -use atrium_xrpc::{ - error::{Error, Result, XrpcErrorKind}, - types::AuthorizationToken, - HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest, -}; -use http::{Method, Request, Response}; +use super::SessionManager; +use crate::types::string::Did; +use atrium_xrpc::{Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest}; +use http::{Request, Response}; use serde::{de::DeserializeOwned, Serialize}; -use std::{ - fmt::Debug, - sync::{Arc, RwLock}, -}; -use tokio::sync::{Mutex, Notify}; +use std::{fmt::Debug, ops::Deref, sync::Arc}; -struct WrapperClient { - store: Arc>, - proxy_header: RwLock>, - labelers_header: Arc>>>, - inner: Arc, -} - -impl WrapperClient { - fn configure_proxy_header(&self, value: String) { - self.proxy_header.write().expect("failed to write proxy header").replace(value); - } - fn configure_labelers_header(&self, labelers_dids: Option>) { - *self.labelers_header.write().expect("failed to write labelers header") = - labelers_dids.map(|dids| { - dids.iter() - .map(|(did, redact)| { - if *redact { - format!("{};redact", did.as_ref()) - } else { - did.as_ref().into() - } - }) - .collect() - }) - } -} - -impl Clone for WrapperClient { - fn clone(&self) -> Self { - Self { - store: self.store.clone(), - labelers_header: self.labelers_header.clone(), - proxy_header: RwLock::new( - self.proxy_header.read().expect("failed to read proxy header").clone(), - ), - inner: self.inner.clone(), - } - } -} - -impl HttpClient for WrapperClient -where - S: Send + Sync, - T: HttpClient + Send + Sync, -{ - async fn send_http( - &self, - request: Request>, - ) -> core::result::Result>, Box> - { - self.inner.send_http(request).await - } -} - -impl XrpcClient for WrapperClient -where - S: SessionStore + Send + Sync, - T: XrpcClient + Send + Sync, -{ - fn base_uri(&self) -> String { - self.store.get_endpoint() - } - async fn authorization_token(&self, is_refresh: bool) -> Option { - self.store.get_session().await.map(|session| { - AuthorizationToken::Bearer(if is_refresh { - session.data.refresh_jwt - } else { - session.data.access_jwt - }) - }) - } - async fn atproto_proxy_header(&self) -> Option { - self.proxy_header.read().expect("failed to read proxy header").clone() - } - async fn atproto_accept_labelers_header(&self) -> Option> { - self.labelers_header.read().expect("failed to read labelers header").clone() - } -} - -pub struct Client { - store: Arc>, - inner: WrapperClient, - is_refreshing: Arc>, - notify: Arc, -} - -impl Client +pub struct Wrapper where - S: SessionStore + Send + Sync, - T: XrpcClient + Send + Sync, + M: SessionManager + Send + Sync, { - pub fn new(store: Arc>, xrpc: T) -> Self { - let inner = WrapperClient { - store: Arc::clone(&store), - labelers_header: Arc::new(RwLock::new(None)), - proxy_header: RwLock::new(None), - inner: Arc::new(xrpc), - }; - Self { - store, - inner, - is_refreshing: Arc::new(Mutex::new(false)), - notify: Arc::new(Notify::new()), - } - } - pub fn configure_endpoint(&self, endpoint: String) { - *self.store.endpoint.write().expect("failed to write endpoint") = endpoint; - } - pub fn configure_proxy_header(&self, did: Did, service_type: impl AsRef) { - self.inner.configure_proxy_header(format!("{}#{}", did.as_ref(), service_type.as_ref())); - } - pub fn clone_with_proxy(&self, did: Did, service_type: impl AsRef) -> Self { - let cloned = self.clone(); - cloned.inner.configure_proxy_header(format!("{}#{}", did.as_ref(), service_type.as_ref())); - cloned - } - pub fn configure_labelers_header(&self, labeler_dids: Option>) { - self.inner.configure_labelers_header(labeler_dids); - } - pub async fn get_labelers_header(&self) -> Option> { - self.inner.atproto_accept_labelers_header().await - } - pub async fn get_proxy_header(&self) -> Option { - self.inner.atproto_proxy_header().await - } - // Internal helper to refresh sessions - // - Wraps the actual implementation to ensure only one refresh is attempted at a time. - async fn refresh_session(&self) { - { - let mut is_refreshing = self.is_refreshing.lock().await; - if *is_refreshing { - drop(is_refreshing); - return self.notify.notified().await; - } - *is_refreshing = true; - } - // TODO: Ensure `is_refreshing` is reliably set to false even in the event of unexpected errors within `refresh_session_inner()`. - self.refresh_session_inner().await; - *self.is_refreshing.lock().await = false; - self.notify.notify_waiters(); - } - async fn refresh_session_inner(&self) { - if let Ok(output) = self.call_refresh_session().await { - if let Some(mut session) = self.store.get_session().await { - session.access_jwt = output.data.access_jwt; - session.did = output.data.did; - session.did_doc = output.data.did_doc.clone(); - session.handle = output.data.handle; - session.refresh_jwt = output.data.refresh_jwt; - self.store.set_session(session).await; - } - if let Some(did_doc) = output - .data - .did_doc - .as_ref() - .and_then(|value| DidDocument::try_from_unknown(value.clone()).ok()) - { - self.store.update_endpoint(&did_doc); - } - } else { - self.store.clear_session().await; - } - } - // same as `crate::client::com::atproto::server::Service::refresh_session()` - async fn call_refresh_session( - &self, - ) -> Result< - crate::com::atproto::server::refresh_session::Output, - crate::com::atproto::server::refresh_session::Error, - > { - let response = self - .inner - .send_xrpc::<(), (), _, _>(&XrpcRequest { - method: Method::POST, - nsid: crate::com::atproto::server::refresh_session::NSID.into(), - parameters: None, - input: None, - encoding: None, - }) - .await?; - match response { - OutputDataOrBytes::Data(data) => Ok(data), - _ => Err(Error::UnexpectedResponseType), - } - } - fn is_expired(result: &Result, E>) -> bool - where - O: DeserializeOwned + Send + Sync, - E: DeserializeOwned + Send + Sync + Debug, - { - if let Err(Error::XrpcResponse(response)) = &result { - if let Some(XrpcErrorKind::Undefined(body)) = &response.error { - if let Some("ExpiredToken") = &body.error.as_deref() { - return true; - } - } - } - false - } + inner: Arc, } -impl Clone for Client +impl Wrapper where - S: SessionStore + Send + Sync, - T: XrpcClient + Send + Sync, + M: SessionManager + Send + Sync, { - fn clone(&self) -> Self { - Self { - store: self.store.clone(), - inner: self.inner.clone(), - is_refreshing: self.is_refreshing.clone(), - notify: self.notify.clone(), - } + pub fn new(inner: M) -> Self { + Self { inner: Arc::new(inner) } } } -impl HttpClient for Client +impl HttpClient for Wrapper where - S: Send + Sync, - T: HttpClient + Send + Sync, + M: SessionManager + Send + Sync, { async fn send_http( &self, request: Request>, - ) -> core::result::Result>, Box> - { + ) -> Result>, Box> { self.inner.send_http(request).await } } -impl XrpcClient for Client +impl XrpcClient for Wrapper where - S: SessionStore + Send + Sync, - T: XrpcClient + Send + Sync, + M: SessionManager + Send + Sync, { fn base_uri(&self) -> String { self.inner.base_uri() } + // async fn authentication_token(&self, is_refresh: bool) -> Option { + // self.inner.authentication_token(is_refresh).await + // } + // async fn atproto_proxy_header(&self) -> Option { + // self.inner.atproto_proxy_header().await + // } + // async fn atproto_accept_labelers_header(&self) -> Option> { + // self.inner.atproto_accept_labelers_header().await + // } async fn send_xrpc( &self, request: &XrpcRequest, - ) -> Result, E> + ) -> Result, Error> where P: Serialize + Send + Sync, I: Serialize + Send + Sync, O: DeserializeOwned + Send + Sync, E: DeserializeOwned + Send + Sync + Debug, { - let result = self.inner.send_xrpc(request).await; - // handle session-refreshes as needed - if Self::is_expired(&result) { - self.refresh_session().await; - self.inner.send_xrpc(request).await - } else { - result - } + self.inner.send_xrpc(request).await } } -pub struct Store { - inner: S, - endpoint: RwLock, +impl SessionManager for Wrapper +where + M: SessionManager + Send + Sync, +{ + async fn did(&self) -> Option { + self.inner.did().await + } } -impl Store { - pub fn new(inner: S, initial_endpoint: String) -> Self { - Self { inner, endpoint: RwLock::new(initial_endpoint) } - } - pub fn get_endpoint(&self) -> String { - self.endpoint.read().expect("failed to read endpoint").clone() - } - pub fn update_endpoint(&self, did_doc: &DidDocument) { - if let Some(endpoint) = did_doc.get_pds_endpoint() { - *self.endpoint.write().expect("failed to write endpoint") = endpoint; - } +impl Clone for Wrapper +where + M: SessionManager + Send + Sync, +{ + fn clone(&self) -> Self { + Self { inner: self.inner.clone() } } } -impl SessionStore for Store +impl Deref for Wrapper where - S: SessionStore + Send + Sync, + M: SessionManager + Send + Sync, { - async fn get_session(&self) -> Option { - self.inner.get_session().await - } - async fn set_session(&self, session: Session) { - self.inner.set_session(session).await; - } - async fn clear_session(&self) { - self.inner.clear_session().await; + type Target = M; + + fn deref(&self) -> &Self::Target { + &self.inner } } diff --git a/atrium-api/src/agent/session_manager.rs b/atrium-api/src/agent/session_manager.rs new file mode 100644 index 00000000..7280ee2b --- /dev/null +++ b/atrium-api/src/agent/session_manager.rs @@ -0,0 +1,8 @@ +use crate::types::string::Did; +use atrium_xrpc::XrpcClient; +use std::future::Future; + +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] +pub trait SessionManager: XrpcClient { + fn did(&self) -> impl Future>; +} diff --git a/atrium-api/src/agent/store.rs b/atrium-api/src/agent/store.rs deleted file mode 100644 index 22bdcb37..00000000 --- a/atrium-api/src/agent/store.rs +++ /dev/null @@ -1,16 +0,0 @@ -mod memory; - -use std::future::Future; - -pub use self::memory::MemorySessionStore; -pub(crate) use super::Session; - -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] -pub trait SessionStore { - #[must_use] - fn get_session(&self) -> impl Future>; - #[must_use] - fn set_session(&self, session: Session) -> impl Future; - #[must_use] - fn clear_session(&self) -> impl Future; -} diff --git a/atrium-api/src/agent/store/memory.rs b/atrium-api/src/agent/store/memory.rs deleted file mode 100644 index 05eedaaf..00000000 --- a/atrium-api/src/agent/store/memory.rs +++ /dev/null @@ -1,20 +0,0 @@ -use super::{Session, SessionStore}; -use std::sync::Arc; -use tokio::sync::RwLock; - -#[derive(Default, Clone)] -pub struct MemorySessionStore { - session: Arc>>, -} - -impl SessionStore for MemorySessionStore { - async fn get_session(&self) -> Option { - self.session.read().await.clone() - } - async fn set_session(&self, session: Session) { - self.session.write().await.replace(session); - } - async fn clear_session(&self) { - self.session.write().await.take(); - } -} diff --git a/atrium-common/Cargo.toml b/atrium-common/Cargo.toml index 9bda3a56..0deaee0f 100644 --- a/atrium-common/Cargo.toml +++ b/atrium-common/Cargo.toml @@ -9,10 +9,16 @@ documentation = "https://docs.rs/atrium-common" readme = "README.md" repository.workspace = true license.workspace = true -keywords = ["atproto", "bluesky"] +keywords = ["atproto", "bluesky", "identity"] [dependencies] +atrium-xrpc.workspace = true +chrono = { workspace = true, features = ["serde"] } dashmap.workspace = true +hickory-proto = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"] } +serde_html_form.workspace = true +serde_json.workspace = true thiserror.workspace = true tokio = { workspace = true, default-features = false, features = ["sync"] } trait-variant.workspace = true diff --git a/atrium-common/src/store/memory.rs b/atrium-common/src/store/memory.rs index dc81fd7c..2500959e 100644 --- a/atrium-common/src/store/memory.rs +++ b/atrium-common/src/store/memory.rs @@ -2,14 +2,14 @@ use super::Store; use std::collections::HashMap; use std::fmt::Debug; use std::hash::Hash; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use thiserror::Error; +use tokio::sync::Mutex; #[derive(Error, Debug)] #[error("memory store error")] pub struct Error; -// TODO: LRU cache? #[derive(Clone)] pub struct MemoryStore { store: Arc>>, @@ -23,24 +23,24 @@ impl Default for MemoryStore { impl Store for MemoryStore where - K: Debug + Eq + Hash + Send + Sync + 'static, - V: Debug + Clone + Send + Sync + 'static, + K: Eq + Hash + Send + Sync, + V: Clone + Send, { type Error = Error; async fn get(&self, key: &K) -> Result, Self::Error> { - Ok(self.store.lock().unwrap().get(key).cloned()) + Ok(self.store.lock().await.get(key).cloned()) } async fn set(&self, key: K, value: V) -> Result<(), Self::Error> { - self.store.lock().unwrap().insert(key, value); + self.store.lock().await.insert(key, value); Ok(()) } async fn del(&self, key: &K) -> Result<(), Self::Error> { - self.store.lock().unwrap().remove(key); + self.store.lock().await.remove(key); Ok(()) } async fn clear(&self) -> Result<(), Self::Error> { - self.store.lock().unwrap().clear(); + self.store.lock().await.clear(); Ok(()) } } diff --git a/atrium-oauth/identity/src/identity_resolver.rs b/atrium-oauth/identity/src/identity_resolver.rs index a70e1856..22b4c58b 100644 --- a/atrium-oauth/identity/src/identity_resolver.rs +++ b/atrium-oauth/identity/src/identity_resolver.rs @@ -31,6 +31,7 @@ impl Resolver for IdentityResolver where D: DidResolver + Send + Sync + 'static, H: HandleResolver + Send + Sync + 'static, + // Error: From + From, { type Input = str; type Output = ResolvedIdentity; diff --git a/atrium-oauth/oauth-client/Cargo.toml b/atrium-oauth/oauth-client/Cargo.toml index 8920ccfc..a73a8e8d 100644 --- a/atrium-oauth/oauth-client/Cargo.toml +++ b/atrium-oauth/oauth-client/Cargo.toml @@ -14,7 +14,7 @@ keywords = ["atproto", "bluesky", "oauth"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -atrium-api = { workspace = true, default-features = false } +atrium-api = { workspace = true, features = ["agent"] } atrium-common.workspace = true atrium-identity.workspace = true atrium-xrpc.workspace = true @@ -22,6 +22,7 @@ base64.workspace = true chrono.workspace = true ecdsa = { workspace = true, features = ["signing"] } elliptic-curve.workspace = true +futures.workspace = true jose-jwa.workspace = true jose-jwk = { workspace = true, features = ["p256"] } p256 = { workspace = true, features = ["ecdsa"] } @@ -35,6 +36,7 @@ thiserror.workspace = true trait-variant.workspace = true [dev-dependencies] +atrium-api = { workspace = true, features = ["bluesky"] } hickory-resolver.workspace = true p256 = { workspace = true, features = ["pem"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } diff --git a/atrium-oauth/oauth-client/examples/main.rs b/atrium-oauth/oauth-client/examples/main.rs index ee211fc4..af0f18e7 100644 --- a/atrium-oauth/oauth-client/examples/main.rs +++ b/atrium-oauth/oauth-client/examples/main.rs @@ -1,5 +1,7 @@ +use atrium_api::agent::Agent; use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver}; +use atrium_oauth_client::store::session::MemorySessionStore; use atrium_oauth_client::store::state::MemoryStateStore; use atrium_oauth_client::{ AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient, @@ -57,6 +59,7 @@ async fn main() -> Result<(), Box> { protected_resource_metadata: Default::default(), }, state_store: MemoryStateStore::default(), + session_store: MemorySessionStore::default(), }; let client = OAuthClient::new(config)?; println!( @@ -76,7 +79,7 @@ async fn main() -> Result<(), Box> { ); // Click the URL and sign in, - // then copy and paste the URL like “http://127.0.0.1/?iss=...&code=...” after it is redirected. + // then copy and paste the URL like “http://127.0.0.1/callback?iss=...&code=...” after it is redirected. print!("Redirected url: "); stdout().lock().flush()?; @@ -85,7 +88,26 @@ async fn main() -> Result<(), Box> { let uri = url.trim().parse::()?; let params = serde_html_form::from_str(uri.query().unwrap())?; - println!("{}", serde_json::to_string_pretty(&client.callback(params).await?)?); + + let (session, _) = client.callback(params).await?; + let agent = Agent::new(session); + let output = agent + .api + .app + .bsky + .feed + .get_timeline( + atrium_api::app::bsky::feed::get_timeline::ParametersData { + algorithm: None, + cursor: None, + limit: 3.try_into().ok(), + } + .into(), + ) + .await?; + for feed in &output.feed { + println!("{feed:?}"); + } Ok(()) } diff --git a/atrium-oauth/oauth-client/src/atproto.rs b/atrium-oauth/oauth-client/src/atproto.rs index ae23170f..45f37f9a 100644 --- a/atrium-oauth/oauth-client/src/atproto.rs +++ b/atrium-oauth/oauth-client/src/atproto.rs @@ -160,7 +160,7 @@ impl TryIntoOAuthClientMetadata for AtprotoLocalhostClientMetadata { redirect_uris: self .redirect_uris .unwrap_or(vec![String::from("http://127.0.0.1/"), String::from("http://[::1]/")]), - scope: None, + scope: None, // will be set to `atproto` grant_types: None, // will be set to `authorization_code` and `refresh_token` token_endpoint_auth_method: Some(String::from("none")), dpop_bound_access_tokens: None, // will be set to `true` diff --git a/atrium-oauth/oauth-client/src/error.rs b/atrium-oauth/oauth-client/src/error.rs index 16f87001..ba1bd5ce 100644 --- a/atrium-oauth/oauth-client/src/error.rs +++ b/atrium-oauth/oauth-client/src/error.rs @@ -5,17 +5,21 @@ pub enum Error { #[error(transparent)] ClientMetadata(#[from] crate::atproto::Error), #[error(transparent)] - Keyset(#[from] crate::keyset::Error), + Dpop(#[from] crate::http_client::dpop::Error), #[error(transparent)] - Identity(#[from] atrium_identity::Error), + Keyset(#[from] crate::keyset::Error), #[error(transparent)] ServerAgent(#[from] crate::server_agent::Error), + #[error(transparent)] + Identity(#[from] atrium_identity::Error), #[error("authorize error: {0}")] Authorize(String), #[error("callback error: {0}")] Callback(String), - #[error("state store error: {0:?}")] + #[error("state store error: {0}")] StateStore(Box), + #[error("session store error: {0}")] + SessionStore(Box), } pub type Result = core::result::Result; diff --git a/atrium-oauth/oauth-client/src/http_client/dpop.rs b/atrium-oauth/oauth-client/src/http_client/dpop.rs index b92fd621..a45fb76c 100644 --- a/atrium-oauth/oauth-client/src/http_client/dpop.rs +++ b/atrium-oauth/oauth-client/src/http_client/dpop.rs @@ -1,8 +1,8 @@ use crate::jose::create_signed_jwt; use crate::jose::jws::RegisteredHeader; use crate::jose::jwt::{Claims, PublicClaims, RegisteredClaims}; -use crate::store::memory::MemorySimpleStore; -use crate::store::SimpleStore; +use atrium_common::store::memory::MemoryStore; +use atrium_common::store::Store; use atrium_xrpc::http::{Request, Response}; use atrium_xrpc::HttpClient; use base64::engine::general_purpose::URL_SAFE_NO_PAD; @@ -30,15 +30,19 @@ pub enum Error { JwkCrypto(crypto::Error), #[error("key does not match any alg supported by the server")] UnsupportedKey, + #[error("nonce store error: {0}")] + Nonces(Box), + #[error("session store error: {0}")] + SessionStore(Box), #[error(transparent)] SerdeJson(#[from] serde_json::Error), } type Result = core::result::Result; -pub struct DpopClient> +pub struct DpopClient> where - S: SimpleStore, + S: Store, { inner: Arc, pub(crate) key: Key, @@ -65,14 +69,14 @@ impl DpopClient { return Err(Error::UnsupportedKey); } } - let nonces = MemorySimpleStore::::default(); + let nonces = MemoryStore::::default(); Ok(Self { inner: http_client, key, nonces, is_auth_server }) } } impl DpopClient where - S: SimpleStore, + S: Store, { fn build_proof( &self, @@ -135,7 +139,8 @@ where impl HttpClient for DpopClient where T: HttpClient + Send + Sync + 'static, - S: SimpleStore + Send + Sync + 'static, + S: Store + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, { async fn send_http( &self, @@ -146,14 +151,16 @@ where let nonce_key = uri.authority().unwrap().to_string(); let htm = request.method().to_string(); let htu = uri.to_string(); - // https://datatracker.ietf.org/doc/html/rfc9449#section-4.2 - let ath = request - .headers() - .get("Authorization") - .filter(|v| v.to_str().map_or(false, |s| s.starts_with("DPoP "))) - .map(|auth| URL_SAFE_NO_PAD.encode(Sha256::digest(&auth.as_bytes()[5..]))); - let init_nonce = self.nonces.get(&nonce_key).await?; + let ath = match request.headers().get("Authorization").and_then(|v| v.to_str().ok()) { + Some(s) if s.starts_with("DPoP ") => { + Some(URL_SAFE_NO_PAD.encode(Sha256::digest(s.strip_prefix("DPoP ").unwrap()))) + } + _ => None, + }; + + let init_nonce = + self.nonces.get(&nonce_key).await.map_err(|e| Error::Nonces(Box::new(e)))?; let init_proof = self.build_proof(htm.clone(), htu.clone(), ath.clone(), init_nonce.clone())?; request.headers_mut().insert("DPoP", init_proof.parse()?); @@ -164,7 +171,10 @@ where match &next_nonce { Some(s) if next_nonce != init_nonce => { // Store the fresh nonce for future requests - self.nonces.set(nonce_key, s.clone()).await?; + self.nonces + .set(nonce_key, s.clone()) + .await + .map_err(|e| Error::Nonces(Box::new(e)))?; } _ => { // No nonce was returned or it is the same as the one we sent. No need to diff --git a/atrium-oauth/oauth-client/src/lib.rs b/atrium-oauth/oauth-client/src/lib.rs index 06071dc7..522d4e85 100644 --- a/atrium-oauth/oauth-client/src/lib.rs +++ b/atrium-oauth/oauth-client/src/lib.rs @@ -5,6 +5,7 @@ mod http_client; mod jose; mod keyset; mod oauth_client; +mod oauth_session; mod resolver; mod server_agent; pub mod store; @@ -19,6 +20,7 @@ pub use error::{Error, Result}; pub use http_client::default::DefaultHttpClient; pub use http_client::dpop::DpopClient; pub use oauth_client::{OAuthClient, OAuthClientConfig}; +pub use oauth_session::OAuthSession; pub use resolver::OAuthResolverConfig; pub use types::{ AuthorizeOptionPrompt, AuthorizeOptions, CallbackParams, OAuthClientMetadata, TokenSet, diff --git a/atrium-oauth/oauth-client/src/oauth_client.rs b/atrium-oauth/oauth-client/src/oauth_client.rs index e844f00a..3fef6659 100644 --- a/atrium-oauth/oauth-client/src/oauth_client.rs +++ b/atrium-oauth/oauth-client/src/oauth_client.rs @@ -1,17 +1,22 @@ use crate::constants::FALLBACK_ALG; use crate::error::{Error, Result}; use crate::keyset::Keyset; +use crate::oauth_session::OAuthSession; use crate::resolver::{OAuthResolver, OAuthResolverConfig}; use crate::server_agent::{OAuthRequest, OAuthServerAgent}; +use crate::store::session::{Session, SessionStore}; +use crate::store::session_getter::SessionGetter; use crate::store::state::{InternalStateData, StateStore}; use crate::types::{ AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions, CallbackParams, OAuthAuthorizationServerMetadata, OAuthClientMetadata, - OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters, TokenSet, + OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters, TryIntoOAuthClientMetadata, }; use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values}; +use atrium_api::types::string::Did; use atrium_common::resolver::Resolver; +use atrium_common::store::Store; use atrium_identity::{did::DidResolver, handle::HandleResolver}; use atrium_xrpc::HttpClient; use base64::engine::general_purpose::URL_SAFE_NO_PAD; @@ -23,7 +28,7 @@ use sha2::{Digest, Sha256}; use std::sync::Arc; #[cfg(feature = "default-client")] -pub struct OAuthClientConfig +pub struct OAuthClientConfig where M: TryIntoOAuthClientMetadata, { @@ -31,13 +36,14 @@ where pub client_metadata: M, pub keys: Option>, // Stores - pub state_store: S, + pub state_store: S0, + pub session_store: S1, // Services pub resolver: OAuthResolverConfig, } #[cfg(not(feature = "default-client"))] -pub struct OAuthClientConfig +pub struct OAuthClientConfig where M: TryIntoOAuthClientMetadata, { @@ -45,7 +51,8 @@ where pub client_metadata: M, pub keys: Option>, // Stores - pub state_store: S, + pub state_store: S0, + pub session_store: S1, // Services pub resolver: OAuthResolverConfig, // Others @@ -53,37 +60,34 @@ where } #[cfg(feature = "default-client")] -pub struct OAuthClient +pub struct OAuthClient where - S: StateStore, T: HttpClient + Send + Sync + 'static, { pub client_metadata: OAuthClientMetadata, keyset: Option, resolver: Arc>, - state_store: S, + state_store: S0, + session_getter: SessionGetter, http_client: Arc, } #[cfg(not(feature = "default-client"))] -pub struct OAuthClient +pub struct OAuthClient where - S: StateStore, T: HttpClient + Send + Sync + 'static, { pub client_metadata: OAuthClientMetadata, keyset: Option, resolver: Arc>, - state_store: S, + state_store: S0, + session_getter: SessionGetter, http_client: Arc, } #[cfg(feature = "default-client")] -impl OAuthClient -where - S: StateStore, -{ - pub fn new(config: OAuthClientConfig) -> Result +impl OAuthClient { + pub fn new(config: OAuthClientConfig) -> Result where M: TryIntoOAuthClientMetadata, { @@ -95,18 +99,20 @@ where keyset, resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())), state_store: config.state_store, + session_getter: SessionGetter::new(config.session_store), http_client, }) } } #[cfg(not(feature = "default-client"))] -impl OAuthClient +impl OAuthClient where - S: StateStore, + S0: StateStore, + S1: SessionStore, T: HttpClient + Send + Sync + 'static, { - pub fn new(config: OAuthClientConfig) -> Result + pub fn new(config: OAuthClientConfig) -> Result where M: TryIntoOAuthClientMetadata, { @@ -118,17 +124,21 @@ where keyset, resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())), state_store: config.state_store, + session_getter: SessionGetter::new(config.session_store), http_client, }) } } -impl OAuthClient +impl OAuthClient where - S: StateStore, + S0: StateStore + Send + Sync + 'static, + S1: SessionStore + Send + Sync + 'static, D: DidResolver + Send + Sync + 'static, H: HandleResolver + Send + Sync + 'static, T: HttpClient + Send + Sync + 'static, + S0::Error: std::error::Error + Send + Sync + 'static, + S1::Error: std::error::Error + Send + Sync + 'static, { pub fn jwks(&self) -> JwkSet { self.keyset.as_ref().map(|keyset| keyset.public_jwks()).unwrap_or_default() @@ -156,11 +166,9 @@ where iss: metadata.issuer.clone(), dpop_key: dpop_key.clone(), verifier, + app_state: options.state, }; - self.state_store - .set(state.clone(), state_data) - .await - .map_err(|e| Error::StateStore(Box::new(e)))?; + self.state_store.set(state.clone(), state_data).await.unwrap(); let login_hint = if identity.is_some() { Some(input.as_ref().into()) } else { None }; let parameters = PushedAuthorizationRequestParameters { response_type: AuthorizationResponseType::Code, @@ -174,14 +182,7 @@ where prompt: options.prompt.map(String::from), }; if metadata.pushed_authorization_request_endpoint.is_some() { - let server = OAuthServerAgent::new( - dpop_key, - metadata.clone(), - self.client_metadata.clone(), - self.resolver.clone(), - self.http_client.clone(), - self.keyset.clone(), - )?; + let server = self.create_server_agent(dpop_key, metadata.clone())?; let par_response = server .request::( OAuthRequest::PushedAuthorizationRequest(parameters), @@ -208,18 +209,19 @@ where todo!() } } - pub async fn callback(&self, params: CallbackParams) -> Result { + pub async fn callback( + &self, + params: CallbackParams, + ) -> Result<(OAuthSession, Option)> { let Some(state_key) = params.state else { return Err(Error::Callback("missing `state` parameter".into())); }; - let Some(state) = - self.state_store.get(&state_key).await.map_err(|e| Error::StateStore(Box::new(e)))? - else { + let Some(state) = self.state_store.get(&state_key).await.unwrap() else { return Err(Error::Callback(format!("unknown authorization state: {state_key}"))); }; // Prevent any kind of replay - self.state_store.del(&state_key).await.map_err(|e| Error::StateStore(Box::new(e)))?; + self.state_store.del(&state_key).await.unwrap(); let metadata = self.resolver.get_authorization_server_metadata(&state.iss).await?; // https://datatracker.ietf.org/doc/html/rfc9207#section-2.4 @@ -233,18 +235,43 @@ where } else if metadata.authorization_response_iss_parameter_supported == Some(true) { return Err(Error::Callback("missing `iss` parameter".into())); } - let server = OAuthServerAgent::new( - state.dpop_key.clone(), - metadata.clone(), + let server = self.create_server_agent(state.dpop_key.clone(), metadata.clone())?; + match server.exchange_code(¶ms.code, &state.verifier).await { + Ok(token_set) => { + let sub = token_set.sub.clone(); + self.session_getter + .set(sub.clone(), Session { dpop_key: state.dpop_key.clone(), token_set }) + .await + .map_err(|e| Error::SessionStore(Box::new(e)))?; + Ok((self.create_session(server, sub).await?, state.app_state)) + } + Err(_) => { + todo!() + } + } + } + async fn create_session( + &self, + server: OAuthServerAgent, + sub: Did, + ) -> Result> { + Ok(server + .create_session(sub, self.http_client.clone(), self.session_getter.clone()) + .await?) + } + fn create_server_agent( + &self, + dpop_key: Key, + server_metadata: OAuthAuthorizationServerMetadata, + ) -> Result> { + Ok(OAuthServerAgent::new( + dpop_key, + server_metadata, self.client_metadata.clone(), self.resolver.clone(), self.http_client.clone(), self.keyset.clone(), - )?; - let token_set = server.exchange_code(¶ms.code, &state.verifier).await?; - - // TODO: create session? - Ok(token_set) + )?) } fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option { let mut algs = @@ -259,3 +286,12 @@ where (URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier) } } + +impl std::fmt::Debug for OAuthClient +where + T: HttpClient + Send + Sync + 'static, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OAuthClient").field("client_metadata", &self.client_metadata).finish() + } +} diff --git a/atrium-oauth/oauth-client/src/oauth_session.rs b/atrium-oauth/oauth-client/src/oauth_session.rs new file mode 100644 index 00000000..30b00597 --- /dev/null +++ b/atrium-oauth/oauth-client/src/oauth_session.rs @@ -0,0 +1,149 @@ +use std::sync::Arc; + +use atrium_api::{agent::SessionManager, types::string::Did}; +use atrium_common::store::{memory::MemoryStore, Store}; +use atrium_identity::{did::DidResolver, handle::HandleResolver}; +use atrium_xrpc::{ + http::{Request, Response}, + types::AuthorizationToken, + HttpClient, XrpcClient, +}; +use jose_jwk::Key; + +use crate::{ + http_client::dpop::Error, + server_agent::OAuthServerAgent, + store::session::{MemorySessionStore, SessionStore}, + DpopClient, TokenSet, +}; + +pub struct OAuthSession< + T, + D, + H, + S0 = MemoryStore, + S1 = MemorySessionStore<(), TokenSet>, +> where + T: HttpClient + Send + Sync + 'static, + S0: Store, + S1: SessionStore<(), TokenSet>, +{ + #[allow(dead_code)] + server_agent: OAuthServerAgent, + dpop_client: DpopClient, + session_store: S1, +} + +impl OAuthSession +where + T: HttpClient + Send + Sync + 'static, +{ + pub(crate) async fn new( + server_agent: OAuthServerAgent, + dpop_key: Key, + http_client: Arc, + token_set: TokenSet, + ) -> Result { + let dpop_client = DpopClient::new( + dpop_key, + http_client.clone(), + false, + &server_agent.server_metadata.token_endpoint_auth_signing_alg_values_supported, + )?; + + let session_store = MemorySessionStore::default(); + session_store.set((), token_set).await.map_err(|e| Error::SessionStore(Box::new(e)))?; + + Ok(Self { server_agent, dpop_client, session_store }) + } + pub fn dpop_key(&self) -> Key { + self.dpop_client.key.clone() + } + pub async fn token_set(&self) -> Result { + let token_set = + self.session_store.get(&()).await.map_err(|e| Error::SessionStore(Box::new(e)))?; + Ok(token_set.expect("session store can never be empty")) + } +} + +impl OAuthSession +where + T: HttpClient + Send + Sync + 'static, + D: DidResolver + Send + Sync + 'static, + H: HandleResolver + Send + Sync + 'static, +{ + pub async fn refresh(&self) -> Result<(), Error> { + let Some(token_set) = + self.session_store.get(&()).await.map_err(|e| Error::SessionStore(Box::new(e)))? + else { + return Ok(()); + }; + let Ok(token_set) = self.server_agent.refresh(&token_set).await else { + todo!(); + }; + + self.session_store.set((), token_set).await.map_err(|e| Error::SessionStore(Box::new(e))) + } + pub async fn logout(&self) -> Result<(), Error> { + let Some(token_set) = + self.session_store.get(&()).await.map_err(|e| Error::SessionStore(Box::new(e)))? + else { + return Ok(()); + }; + self.server_agent.revoke(&token_set.access_token).await; + + self.session_store.clear().await.map_err(|e| Error::SessionStore(Box::new(e))) + } +} + +impl HttpClient for OAuthSession +where + T: HttpClient + Send + Sync + 'static, + D: Send + Sync + 'static, + H: Send + Sync + 'static, + S: Store + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, +{ + async fn send_http( + &self, + request: Request>, + ) -> Result>, Box> { + self.dpop_client.send_http(request).await + } +} + +impl XrpcClient for OAuthSession +where + T: HttpClient + Send + Sync + 'static, + D: Send + Sync + 'static, + H: Send + Sync + 'static, + S: Store + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, +{ + fn base_uri(&self) -> String { + // self.token_set.aud.clone() + todo!() + } + async fn authorization_token(&self, is_refresh: bool) -> Option { + let token_set = self.session_store.get(&()).await.transpose().and_then(Result::ok)?; + if is_refresh { + token_set.refresh_token.as_ref().cloned().map(AuthorizationToken::Dpop) + } else { + Some(AuthorizationToken::Dpop(token_set.access_token.clone())) + } + } +} + +impl SessionManager for OAuthSession +where + T: HttpClient + Send + Sync + 'static, + D: Send + Sync + 'static, + H: Send + Sync + 'static, + S: Store + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, +{ + async fn did(&self) -> Option { + let token_set = self.session_store.get(&()).await.transpose().and_then(Result::ok)?; + Some(token_set.sub.clone()) + } +} diff --git a/atrium-oauth/oauth-client/src/server_agent.rs b/atrium-oauth/oauth-client/src/server_agent.rs index c9d556f3..866adc0a 100644 --- a/atrium-oauth/oauth-client/src/server_agent.rs +++ b/atrium-oauth/oauth-client/src/server_agent.rs @@ -3,13 +3,17 @@ use crate::http_client::dpop::DpopClient; use crate::jose::jwt::{RegisteredClaims, RegisteredClaimsAud}; use crate::keyset::Keyset; use crate::resolver::OAuthResolver; +use crate::store::session::SessionStore; +use crate::store::session_getter::SessionGetter; use crate::types::{ OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse, - PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, - TokenRequestParameters, TokenSet, + PushedAuthorizationRequestParameters, RefreshRequestParameters, RevocationRequestParameters, + TokenGrantType, TokenRequestParameters, TokenSet, }; use crate::utils::{compare_algos, generate_nonce}; -use atrium_api::types::string::Datetime; +use crate::OAuthSession; +use atrium_api::types::string::{Datetime, Did}; +use atrium_common::store::Store; use atrium_identity::{did::DidResolver, handle::HandleResolver}; use atrium_xrpc::http::{Method, Request, StatusCode}; use atrium_xrpc::HttpClient; @@ -32,6 +36,10 @@ pub enum Error { Token(String), #[error("unsupported authentication method")] UnsupportedAuthMethod, + #[error("failed to parse DID: {0}")] + InvalidDid(&'static str), + #[error("no refresh token available for {0}")] + NoRefreshToken(String), #[error(transparent)] DpopClient(#[from] crate::http_client::dpop::Error), #[error(transparent)] @@ -58,7 +66,7 @@ pub type Result = core::result::Result; pub enum OAuthRequest { Token(TokenRequestParameters), Refresh(RefreshRequestParameters), - Revocation, + Revocation(RevocationRequestParameters), Introspection, PushedAuthorizationRequest(PushedAuthorizationRequestParameters), } @@ -68,14 +76,14 @@ impl OAuthRequest { String::from(match self { Self::Token(_) => "token", Self::Refresh(_) => "refresh", - Self::Revocation => "revocation", + Self::Revocation(_) => "revocation", Self::Introspection => "introspection", Self::PushedAuthorizationRequest(_) => "pushed_authorization_request", }) } fn expected_status(&self) -> StatusCode { match self { - Self::Token(_) | Self::Refresh(_) => StatusCode::OK, + Self::Token(_) | Self::Refresh(_) | Self::Revocation(_) => StatusCode::OK, Self::PushedAuthorizationRequest(_) => StatusCode::CREATED, _ => unimplemented!(), } @@ -100,8 +108,8 @@ pub struct OAuthServerAgent where T: HttpClient + Send + Sync + 'static, { - server_metadata: OAuthAuthorizationServerMetadata, - client_metadata: OAuthClientMetadata, + pub(crate) server_metadata: OAuthAuthorizationServerMetadata, + pub(crate) client_metadata: OAuthClientMetadata, dpop_client: DpopClient, resolver: Arc>, keyset: Option, @@ -123,7 +131,7 @@ where ) -> Result { let dpop_client = DpopClient::new( dpop_key, - http_client, + http_client.clone(), true, &server_metadata.token_endpoint_auth_signing_alg_values_supported, )?; @@ -140,10 +148,12 @@ where async fn verify_token_response(&self, token_response: OAuthTokenResponse) -> Result { // ATPROTO requires that the "sub" is always present in the token response. let Some(sub) = &token_response.sub else { + self.revoke(&token_response.access_token).await; return Err(Error::Token("missing `sub` in token response".into())); }; let (metadata, identity) = self.resolver.resolve_from_identity(sub).await?; if metadata.issuer != self.server_metadata.issuer { + self.revoke(&token_response.access_token).await; return Err(Error::Token("issuer mismatch".into())); } let expires_at = token_response.expires_in.and_then(|expires_in| { @@ -153,7 +163,7 @@ where .map(Datetime::new) }); Ok(TokenSet { - sub: sub.clone(), + sub: sub.parse().map_err(Error::InvalidDid)?, aud: identity.pds, iss: metadata.issuer, scope: token_response.scope, @@ -175,6 +185,28 @@ where ) .await } + pub async fn revoke(&self, token: &str) { + let _ = self + .request::<()>(OAuthRequest::Revocation(RevocationRequestParameters { + token: token.into(), + })) + .await; + } + #[allow(dead_code)] + pub async fn refresh(&self, token_set: &TokenSet) -> Result { + let Some(refresh_token) = token_set.refresh_token.as_ref() else { + return Err(Error::NoRefreshToken(token_set.sub.to_string())); + }; + self.verify_token_response( + self.request::(OAuthRequest::Refresh(RefreshRequestParameters { + grant_type: TokenGrantType::RefreshToken, + refresh_token: refresh_token.clone(), + scope: None, + })) + .await?, + ) + .await + } pub async fn request(&self, request: OAuthRequest) -> Result where O: serde::de::DeserializeOwned, @@ -273,11 +305,26 @@ where OAuthRequest::Token(_) | OAuthRequest::Refresh(_) => { Some(&self.server_metadata.token_endpoint) } - OAuthRequest::Revocation => self.server_metadata.revocation_endpoint.as_ref(), + OAuthRequest::Revocation(_) => self.server_metadata.revocation_endpoint.as_ref(), OAuthRequest::Introspection => self.server_metadata.introspection_endpoint.as_ref(), OAuthRequest::PushedAuthorizationRequest(_) => { self.server_metadata.pushed_authorization_request_endpoint.as_ref() } } } + pub(crate) async fn create_session( + self, + sub: Did, + http_client: Arc, + session_getter: SessionGetter, + ) -> Result> + where + S: SessionStore + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, + { + let dpop_key = self.dpop_client.key.clone(); + // TODO + let session = session_getter.get(&sub).await.expect("").unwrap(); + OAuthSession::new(self, dpop_key, http_client, session.token_set).await.map_err(Into::into) + } } diff --git a/atrium-oauth/oauth-client/src/store.rs b/atrium-oauth/oauth-client/src/store.rs index 0850617c..a06b3710 100644 --- a/atrium-oauth/oauth-client/src/store.rs +++ b/atrium-oauth/oauth-client/src/store.rs @@ -1,20 +1,3 @@ -pub mod memory; +pub mod session; +pub mod session_getter; pub mod state; - -use std::error::Error; -use std::future::Future; -use std::hash::Hash; - -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] -pub trait SimpleStore -where - K: Eq + Hash, - V: Clone, -{ - type Error: Error + Send + Sync + 'static; - - fn get(&self, key: &K) -> impl Future, Self::Error>>; - fn set(&self, key: K, value: V) -> impl Future>; - fn del(&self, key: &K) -> impl Future>; - fn clear(&self) -> impl Future>; -} diff --git a/atrium-oauth/oauth-client/src/store/cached.rs b/atrium-oauth/oauth-client/src/store/cached.rs new file mode 100644 index 00000000..e69de29b diff --git a/atrium-oauth/oauth-client/src/store/memory.rs b/atrium-oauth/oauth-client/src/store/memory.rs deleted file mode 100644 index c43c557d..00000000 --- a/atrium-oauth/oauth-client/src/store/memory.rs +++ /dev/null @@ -1,45 +0,0 @@ -use super::SimpleStore; -use std::collections::HashMap; -use std::fmt::Debug; -use std::hash::Hash; -use std::sync::{Arc, Mutex}; -use thiserror::Error; - -#[derive(Error, Debug)] -#[error("memory store error")] -pub struct Error; - -// TODO: LRU cache? -pub struct MemorySimpleStore { - store: Arc>>, -} - -impl Default for MemorySimpleStore { - fn default() -> Self { - Self { store: Arc::new(Mutex::new(HashMap::new())) } - } -} - -impl SimpleStore for MemorySimpleStore -where - K: Debug + Eq + Hash + Send + Sync + 'static, - V: Debug + Clone + Send + Sync + 'static, -{ - type Error = Error; - - async fn get(&self, key: &K) -> Result, Self::Error> { - Ok(self.store.lock().unwrap().get(key).cloned()) - } - async fn set(&self, key: K, value: V) -> Result<(), Self::Error> { - self.store.lock().unwrap().insert(key, value); - Ok(()) - } - async fn del(&self, key: &K) -> Result<(), Self::Error> { - self.store.lock().unwrap().remove(key); - Ok(()) - } - async fn clear(&self) -> Result<(), Self::Error> { - self.store.lock().unwrap().clear(); - Ok(()) - } -} diff --git a/atrium-oauth/oauth-client/src/store/session.rs b/atrium-oauth/oauth-client/src/store/session.rs new file mode 100644 index 00000000..9e0da984 --- /dev/null +++ b/atrium-oauth/oauth-client/src/store/session.rs @@ -0,0 +1,38 @@ +use std::hash::Hash; + +use crate::types::TokenSet; +use atrium_api::types::string::{Datetime, Did}; +use atrium_common::store::{memory::MemoryStore, Store}; +use chrono::TimeDelta; +use jose_jwk::Key; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Session { + pub dpop_key: Key, + pub token_set: TokenSet, +} + +impl Session { + pub fn expires_in(&self) -> Option { + self.token_set.expires_at.as_ref().map(Datetime::as_ref).map(|expires_at| { + expires_at.signed_duration_since(Datetime::now().as_ref()).max(TimeDelta::zero()) + }) + } +} + +pub trait SessionStore: Store +where + K: Eq + Hash, + V: Clone, +{ +} + +pub type MemorySessionStore = MemoryStore; + +impl SessionStore for MemorySessionStore +where + K: Eq + Hash + Send + Sync, + V: Clone + Send, +{ +} diff --git a/atrium-oauth/oauth-client/src/store/session_getter.rs b/atrium-oauth/oauth-client/src/store/session_getter.rs new file mode 100644 index 00000000..183ab913 --- /dev/null +++ b/atrium-oauth/oauth-client/src/store/session_getter.rs @@ -0,0 +1,49 @@ +use crate::store::session::{Session, SessionStore}; +use atrium_api::types::string::Did; +use atrium_common::store::Store; +use std::sync::Arc; + +#[derive(Debug)] +pub struct SessionGetter { + store: Arc, +} + +impl SessionGetter { + pub fn new(store: S) -> Self { + Self { store: Arc::new(store) } + } + // TODO: extended store methods? +} + +impl Clone for SessionGetter { + fn clone(&self) -> Self { + Self { store: self.store.clone() } + } +} + +impl Store for SessionGetter +where + S: SessionStore + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, +{ + type Error = S::Error; + async fn get(&self, key: &Did) -> Result, Self::Error> { + self.store.get(key).await + } + async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> { + self.store.set(key, value).await + } + async fn del(&self, key: &Did) -> Result<(), Self::Error> { + self.store.del(key).await + } + async fn clear(&self) -> Result<(), Self::Error> { + self.store.clear().await + } +} + +impl SessionStore for SessionGetter +where + S: SessionStore + Send + Sync + 'static, + S::Error: std::error::Error + Send + Sync + 'static, +{ +} diff --git a/atrium-oauth/oauth-client/src/store/state.rs b/atrium-oauth/oauth-client/src/store/state.rs index d55e3234..a39a2cb4 100644 --- a/atrium-oauth/oauth-client/src/store/state.rs +++ b/atrium-oauth/oauth-client/src/store/state.rs @@ -1,5 +1,4 @@ -use super::memory::MemorySimpleStore; -use super::SimpleStore; +use atrium_common::store::{memory::MemoryStore, Store}; use jose_jwk::Key; use serde::{Deserialize, Serialize}; @@ -8,10 +7,11 @@ pub struct InternalStateData { pub iss: String, pub dpop_key: Key, pub verifier: String, + pub app_state: Option, } -pub trait StateStore: SimpleStore {} +pub trait StateStore: Store {} -pub type MemoryStateStore = MemorySimpleStore; +pub type MemoryStateStore = MemoryStore; impl StateStore for MemoryStateStore {} diff --git a/atrium-oauth/oauth-client/src/types.rs b/atrium-oauth/oauth-client/src/types.rs index a5712674..4d84a806 100644 --- a/atrium-oauth/oauth-client/src/types.rs +++ b/atrium-oauth/oauth-client/src/types.rs @@ -9,11 +9,12 @@ pub use client_metadata::{OAuthClientMetadata, TryIntoOAuthClientMetadata}; pub use metadata::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetadata}; pub use request::{ AuthorizationCodeChallengeMethod, AuthorizationResponseType, - PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType, - TokenRequestParameters, + PushedAuthorizationRequestParameters, RefreshRequestParameters, RevocationRequestParameters, + TokenGrantType, TokenRequestParameters, }; pub use response::{OAuthPusehedAuthorizationRequestResponse, OAuthTokenResponse}; use serde::Deserialize; +#[allow(unused_imports)] pub use token::TokenSet; #[derive(Debug, Deserialize)] diff --git a/atrium-oauth/oauth-client/src/types/request.rs b/atrium-oauth/oauth-client/src/types/request.rs index d8d352e6..80d44a55 100644 --- a/atrium-oauth/oauth-client/src/types/request.rs +++ b/atrium-oauth/oauth-client/src/types/request.rs @@ -45,6 +45,7 @@ pub struct PushedAuthorizationRequestParameters { pub prompt: Option, } +// https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 #[derive(Serialize)] #[serde(rename_all = "snake_case")] pub enum TokenGrantType { @@ -70,3 +71,9 @@ pub struct RefreshRequestParameters { pub refresh_token: String, pub scope: Option, } + +#[allow(dead_code)] +#[derive(Serialize)] +pub struct RevocationRequestParameters { + pub token: String, +} diff --git a/atrium-oauth/oauth-client/src/types/token.rs b/atrium-oauth/oauth-client/src/types/token.rs index 069e9fef..d09736e0 100644 --- a/atrium-oauth/oauth-client/src/types/token.rs +++ b/atrium-oauth/oauth-client/src/types/token.rs @@ -1,11 +1,11 @@ use super::response::OAuthTokenType; -use atrium_api::types::string::Datetime; +use atrium_api::types::string::{Datetime, Did}; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct TokenSet { pub iss: String, - pub sub: String, + pub sub: Did, pub aud: String, pub scope: Option, diff --git a/atrium-xrpc/src/error.rs b/atrium-xrpc/src/error.rs index 375aac8b..4135cb6d 100644 --- a/atrium-xrpc/src/error.rs +++ b/atrium-xrpc/src/error.rs @@ -19,6 +19,8 @@ where SerdeJson(#[from] serde_json::Error), #[error("serde_html_form error: {0}")] SerdeHtmlForm(#[from] serde_html_form::ser::Error), + #[error("session store error: {0}")] + SessionStore(Box), #[error("unexpected response type")] UnexpectedResponseType, } diff --git a/atrium-xrpc/src/traits.rs b/atrium-xrpc/src/traits.rs index 13d65df9..98d3212a 100644 --- a/atrium-xrpc/src/traits.rs +++ b/atrium-xrpc/src/traits.rs @@ -1,4 +1,5 @@ -use crate::error::{Error, XrpcError, XrpcErrorKind}; +use crate::error::Error; +use crate::error::{XrpcError, XrpcErrorKind}; use crate::types::{AuthorizationToken, Header, NSID_REFRESH_SESSION}; use crate::{InputDataOrBytes, OutputDataOrBytes, XrpcRequest}; use http::{Method, Request, Response}; diff --git a/bsky-sdk/Cargo.toml b/bsky-sdk/Cargo.toml index 7f6fb63a..833cbe03 100644 --- a/bsky-sdk/Cargo.toml +++ b/bsky-sdk/Cargo.toml @@ -14,6 +14,7 @@ keywords = ["atproto", "bluesky", "atrium", "sdk"] [dependencies] anyhow.workspace = true atrium-api = { workspace = true, features = ["agent", "bluesky"] } +atrium-common.workspace = true atrium-xrpc-client = { workspace = true, optional = true } chrono.workspace = true psl = { version = "2.1.42", optional = true } diff --git a/bsky-sdk/src/agent.rs b/bsky-sdk/src/agent.rs index e7030ddd..104fb9d6 100644 --- a/bsky-sdk/src/agent.rs +++ b/bsky-sdk/src/agent.rs @@ -2,17 +2,18 @@ mod builder; pub mod config; -pub use self::builder::BskyAgentBuilder; +pub use self::builder::BskyAtpAgentBuilder; use self::config::Config; use crate::error::Result; use crate::moderation::util::interpret_label_value_definitions; use crate::moderation::{ModerationPrefsLabeler, Moderator}; use crate::preference::{FeedViewPreferenceData, Preferences, ThreadViewPreferenceData}; -use atrium_api::agent::store::MemorySessionStore; -use atrium_api::agent::{store::SessionStore, AtpAgent}; +use atrium_api::agent::atp_agent::{AtpAgent, AtpSession}; use atrium_api::app::bsky::actor::defs::PreferencesItem; use atrium_api::types::{Object, Union}; use atrium_api::xrpc::XrpcClient; +use atrium_common::store::memory::MemoryStore; +use atrium_common::store::Store; #[cfg(feature = "default-client")] use atrium_xrpc_client::reqwest::ReqwestClient; use std::collections::HashMap; @@ -21,8 +22,8 @@ use std::sync::Arc; /// A Bluesky agent. /// -/// This agent is a wrapper around the [`AtpAgent`] that provides additional functionality for working with Bluesky. -/// For creating an instance of this agent, use the [`BskyAgentBuilder`]. +/// This agent is a wrapper around the [`Agent`](atrium_api::agent::Agent) that provides additional functionality for working with Bluesky. +/// For creating an instance of this agent, use the [`BskyAtpAgentBuilder`]. /// /// # Example /// @@ -37,19 +38,21 @@ use std::sync::Arc; #[cfg(feature = "default-client")] #[derive(Clone)] -pub struct BskyAgent +pub struct BskyAgent> where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { inner: Arc>, } #[cfg(not(feature = "default-client"))] -pub struct BskyAgent +pub struct BskyAgent> where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { inner: Arc>, } @@ -57,16 +60,17 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "default-client")))] #[cfg(feature = "default-client")] impl BskyAgent { - /// Create a new [`BskyAgentBuilder`] with the default client and session store. - pub fn builder() -> BskyAgentBuilder { - BskyAgentBuilder::default() + /// Create a new [`BskyAtpAgentBuilder`] with the default client and session store. + pub fn builder() -> BskyAtpAgentBuilder> { + BskyAtpAgentBuilder::default() } } impl BskyAgent where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { /// Get the agent's current state as a [`Config`]. pub async fn to_config(&self) -> Config { @@ -248,7 +252,8 @@ where impl Deref for BskyAgent where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { type Target = AtpAgent; @@ -260,19 +265,24 @@ where #[cfg(test)] mod tests { use super::*; - use atrium_api::agent::Session; + use atrium_api::agent::atp_agent::AtpSession; #[derive(Clone)] struct NoopStore; - impl SessionStore for NoopStore { - async fn get_session(&self) -> Option { + impl Store<(), AtpSession> for NoopStore { + type Error = std::convert::Infallible; + + async fn get(&self, _key: &()) -> core::result::Result, Self::Error> { + unimplemented!() + } + async fn set(&self, _key: (), _value: AtpSession) -> core::result::Result<(), Self::Error> { unimplemented!() } - async fn set_session(&self, _: Session) { + async fn del(&self, _key: &()) -> core::result::Result<(), Self::Error> { unimplemented!() } - async fn clear_session(&self) { + async fn clear(&self) -> core::result::Result<(), Self::Error> { unimplemented!() } } diff --git a/bsky-sdk/src/agent/builder.rs b/bsky-sdk/src/agent/builder.rs index 9e333181..7d3c4485 100644 --- a/bsky-sdk/src/agent/builder.rs +++ b/bsky-sdk/src/agent/builder.rs @@ -1,38 +1,40 @@ use super::config::Config; use super::BskyAgent; use crate::error::Result; -use atrium_api::agent::store::MemorySessionStore; -use atrium_api::agent::{store::SessionStore, AtpAgent}; +use atrium_api::agent::atp_agent::{AtpAgent, AtpSession}; use atrium_api::xrpc::XrpcClient; +use atrium_common::store::memory::MemoryStore; +use atrium_common::store::Store; #[cfg(feature = "default-client")] use atrium_xrpc_client::reqwest::ReqwestClient; use std::sync::Arc; -/// A builder for creating a [`BskyAgent`]. -pub struct BskyAgentBuilder +/// A builder for creating a [`BskyAtpAgent`]. +pub struct BskyAtpAgentBuilder> where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, { config: Config, store: S, client: T, } -impl BskyAgentBuilder +impl BskyAtpAgentBuilder where T: XrpcClient + Send + Sync, { /// Create a new builder with the given XRPC client. pub fn new(client: T) -> Self { - Self { config: Config::default(), store: MemorySessionStore::default(), client } + Self { config: Config::default(), store: MemoryStore::default(), client } } } -impl BskyAgentBuilder +impl BskyAtpAgentBuilder where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { /// Set the configuration for the agent. pub fn config(mut self, config: Config) -> Self { @@ -42,20 +44,20 @@ where /// Set the session store for the agent. /// /// Returns a new builder with the session store set. - pub fn store(self, store: S0) -> BskyAgentBuilder + pub fn store(self, store: S0) -> BskyAtpAgentBuilder where - S0: SessionStore + Send + Sync, + S0: Store<(), AtpSession> + Send + Sync, { - BskyAgentBuilder { config: self.config, store, client: self.client } + BskyAtpAgentBuilder { config: self.config, store, client: self.client } } /// Set the XRPC client for the agent. /// /// Returns a new builder with the XRPC client set. - pub fn client(self, client: T0) -> BskyAgentBuilder + pub fn client(self, client: T0) -> BskyAtpAgentBuilder where T0: XrpcClient + Send + Sync, { - BskyAgentBuilder { config: self.config, store: self.store, client } + BskyAtpAgentBuilder { config: self.config, store: self.store, client } } pub async fn build(self) -> Result> { let agent = AtpAgent::new(self.client, self.store); @@ -91,10 +93,10 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "default-client")))] #[cfg(feature = "default-client")] -impl Default for BskyAgentBuilder { +impl Default for BskyAtpAgentBuilder> { /// Create a new builder with the default client and session store. /// - /// Default client is [`ReqwestClient`] and default session store is [`MemorySessionStore`]. + /// Default client is [`ReqwestClient`] and default session store is [`MemoryStore`]. fn default() -> Self { Self::new(ReqwestClient::new(Config::default().endpoint)) } @@ -103,10 +105,10 @@ impl Default for BskyAgentBuilder { #[cfg(test)] mod tests { use super::*; - use atrium_api::agent::Session; + use atrium_api::agent::atp_agent::AtpSession; use atrium_api::com::atproto::server::create_session::OutputData; - fn session() -> Session { + fn session() -> AtpSession { OutputData { access_jwt: String::new(), active: None, @@ -124,12 +126,21 @@ mod tests { struct MockSessionStore; - impl SessionStore for MockSessionStore { - async fn get_session(&self) -> Option { - Some(session()) + impl Store<(), AtpSession> for MockSessionStore { + type Error = std::convert::Infallible; + + async fn get(&self, _key: &()) -> core::result::Result, Self::Error> { + Ok(Some(session())) + } + async fn set(&self, _key: (), _value: AtpSession) -> core::result::Result<(), Self::Error> { + Ok(()) + } + async fn del(&self, _key: &()) -> core::result::Result<(), Self::Error> { + Ok(()) + } + async fn clear(&self) -> core::result::Result<(), Self::Error> { + Ok(()) } - async fn set_session(&self, _: Session) {} - async fn clear_session(&self) {} } #[cfg(feature = "default-client")] @@ -137,13 +148,13 @@ mod tests { async fn default() -> Result<()> { // default build { - let agent = BskyAgentBuilder::default().build().await?; + let agent = BskyAtpAgentBuilder::default().build().await?; assert_eq!(agent.get_endpoint().await, "https://bsky.social"); assert_eq!(agent.get_session().await, None); } // with store { - let agent = BskyAgentBuilder::default().store(MockSessionStore).build().await?; + let agent = BskyAtpAgentBuilder::default().store(MockSessionStore).build().await?; assert_eq!(agent.get_endpoint().await, "https://bsky.social"); assert_eq!( agent.get_session().await.map(|session| session.data.handle), @@ -152,7 +163,7 @@ mod tests { } // with config { - let agent = BskyAgentBuilder::default() + let agent = BskyAtpAgentBuilder::default() .config(Config { endpoint: "https://example.com".to_string(), ..Default::default() @@ -172,12 +183,13 @@ mod tests { // default build { - let agent = BskyAgentBuilder::new(MockClient).build().await?; + let agent = BskyAtpAgentBuilder::new(MockClient).build().await?; assert_eq!(agent.get_endpoint().await, "https://bsky.social"); } // with store { - let agent = BskyAgentBuilder::new(MockClient).store(MockSessionStore).build().await?; + let agent = + BskyAtpAgentBuilder::new(MockClient).store(MockSessionStore).build().await?; assert_eq!(agent.get_endpoint().await, "https://bsky.social"); assert_eq!( agent.get_session().await.map(|session| session.data.handle), @@ -186,7 +198,7 @@ mod tests { } // with config { - let agent = BskyAgentBuilder::new(MockClient) + let agent = BskyAtpAgentBuilder::new(MockClient) .config(Config { endpoint: "https://example.com".to_string(), ..Default::default() diff --git a/bsky-sdk/src/agent/config.rs b/bsky-sdk/src/agent/config.rs index a804e729..51f5951f 100644 --- a/bsky-sdk/src/agent/config.rs +++ b/bsky-sdk/src/agent/config.rs @@ -1,12 +1,11 @@ //! Configuration for the [`BskyAgent`](super::BskyAgent). mod file; -use std::future::Future; - +pub use self::file::FileStore; use crate::error::{Error, Result}; -use atrium_api::agent::Session; -pub use file::FileStore; +use atrium_api::agent::atp_agent::AtpSession; use serde::{Deserialize, Serialize}; +use std::future::Future; /// Configuration data struct for the [`BskyAgent`](super::BskyAgent). #[derive(Debug, Clone, Serialize, Deserialize)] @@ -14,7 +13,7 @@ pub struct Config { /// The base URL for the XRPC endpoint. pub endpoint: String, /// The session data. - pub session: Option, + pub session: Option, /// The labelers header values. pub labelers_header: Option>, /// The proxy header for service proxying. diff --git a/bsky-sdk/src/record.rs b/bsky-sdk/src/record.rs index 1a3cac92..7a7bba1d 100644 --- a/bsky-sdk/src/record.rs +++ b/bsky-sdk/src/record.rs @@ -5,18 +5,20 @@ use std::future::Future; use crate::error::{Error, Result}; use crate::BskyAgent; -use atrium_api::agent::store::SessionStore; +use atrium_api::agent::atp_agent::AtpSession; use atrium_api::com::atproto::repo::{ create_record, delete_record, get_record, list_records, put_record, }; use atrium_api::types::{Collection, LimitedNonZeroU8, TryIntoUnknown}; use atrium_api::xrpc::XrpcClient; +use atrium_common::store::Store; #[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] pub trait Record where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { fn list( agent: &BskyAgent, @@ -45,7 +47,8 @@ macro_rules! record_impl { impl Record for $record where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { async fn list( agent: &BskyAgent, @@ -162,7 +165,8 @@ macro_rules! record_impl { impl Record for $record_data where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { async fn list( agent: &BskyAgent, @@ -273,9 +277,9 @@ record_impl!( #[cfg(test)] mod tests { use super::*; - use crate::agent::BskyAgentBuilder; + use crate::agent::BskyAtpAgentBuilder; use crate::tests::FAKE_CID; - use atrium_api::agent::Session; + use atrium_api::agent::atp_agent::AtpSession; use atrium_api::com::atproto::server::create_session::OutputData; use atrium_api::types::string::Datetime; use atrium_api::xrpc::http::{Request, Response}; @@ -321,9 +325,11 @@ mod tests { struct MockSessionStore; - impl SessionStore for MockSessionStore { - async fn get_session(&self) -> Option { - Some( + impl Store<(), AtpSession> for MockSessionStore { + type Error = std::convert::Infallible; + + async fn get(&self, _key: &()) -> core::result::Result, Self::Error> { + Ok(Some( OutputData { access_jwt: String::from("access"), active: None, @@ -337,15 +343,22 @@ mod tests { status: None, } .into(), - ) + )) + } + async fn set(&self, _key: (), _value: AtpSession) -> core::result::Result<(), Self::Error> { + Ok(()) + } + async fn del(&self, _key: &()) -> core::result::Result<(), Self::Error> { + Ok(()) + } + async fn clear(&self) -> core::result::Result<(), Self::Error> { + Ok(()) } - async fn set_session(&self, _: Session) {} - async fn clear_session(&self) {} } #[tokio::test] async fn actor_profile() -> Result<()> { - let agent = BskyAgentBuilder::new(MockClient).store(MockSessionStore).build().await?; + let agent = BskyAtpAgentBuilder::new(MockClient).store(MockSessionStore).build().await?; // create let output = atrium_api::app::bsky::actor::profile::RecordData { avatar: None, @@ -377,7 +390,7 @@ mod tests { #[tokio::test] async fn feed_post() -> Result<()> { - let agent = BskyAgentBuilder::new(MockClient).store(MockSessionStore).build().await?; + let agent = BskyAtpAgentBuilder::new(MockClient).store(MockSessionStore).build().await?; // create let output = atrium_api::app::bsky::feed::post::RecordData { created_at: Datetime::now(), @@ -409,7 +422,7 @@ mod tests { #[tokio::test] async fn graph_follow() -> Result<()> { - let agent = BskyAgentBuilder::new(MockClient).store(MockSessionStore).build().await?; + let agent = BskyAtpAgentBuilder::new(MockClient).store(MockSessionStore).build().await?; // create let output = atrium_api::app::bsky::graph::follow::RecordData { created_at: Datetime::now(), diff --git a/bsky-sdk/src/record/agent.rs b/bsky-sdk/src/record/agent.rs index 30a2f626..00fac1ae 100644 --- a/bsky-sdk/src/record/agent.rs +++ b/bsky-sdk/src/record/agent.rs @@ -1,16 +1,18 @@ use super::Record; use crate::error::{Error, Result}; use crate::BskyAgent; -use atrium_api::agent::store::SessionStore; +use atrium_api::agent::atp_agent::AtpSession; use atrium_api::com::atproto::repo::{create_record, delete_record}; use atrium_api::record::KnownRecord; use atrium_api::types::string::RecordKey; use atrium_api::xrpc::XrpcClient; +use atrium_common::store::Store; impl BskyAgent where T: XrpcClient + Send + Sync, - S: SessionStore + Send + Sync, + S: Store<(), AtpSession> + Send + Sync, + S::Error: Send + Sync + 'static, { /// Create a record with various types of data. /// For example, the Record families defined in [`KnownRecord`](atrium_api::record::KnownRecord) are supported. diff --git a/bsky-sdk/src/rich_text.rs b/bsky-sdk/src/rich_text.rs index f1783722..6bf6bd9e 100644 --- a/bsky-sdk/src/rich_text.rs +++ b/bsky-sdk/src/rich_text.rs @@ -2,7 +2,7 @@ mod detection; use crate::agent::config::Config; -use crate::agent::BskyAgentBuilder; +use crate::agent::BskyAtpAgentBuilder; use crate::error::Result; use atrium_api::app::bsky::richtext::facet::{ ByteSliceData, Link, MainFeaturesItem, Mention, MentionData, Tag, @@ -204,7 +204,7 @@ impl RichText { } /// Detect facets in the text and set them. pub async fn detect_facets(&mut self, client: impl XrpcClient + Send + Sync) -> Result<()> { - let agent = BskyAgentBuilder::new(client) + let agent = BskyAtpAgentBuilder::new(client) .config(Config { endpoint: PUBLIC_API_ENDPOINT.into(), ..Default::default() }) .build() .await?;