Skip to content

Commit

Permalink
deprecate CellStore
Browse files Browse the repository at this point in the history
  • Loading branch information
avdb13 committed Nov 20, 2024
1 parent cf88233 commit 8147a40
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 160 deletions.
76 changes: 38 additions & 38 deletions atrium-api/src/agent/atp_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
did_doc::DidDocument,
types::{string::Did, TryFromUnknown},
};
use atrium_common::store::CellStore;
use atrium_common::store::MapStore;
use atrium_xrpc::{Error, HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
use http::{Request, Response};
use serde::{de::DeserializeOwned, Serialize};
Expand All @@ -20,7 +20,7 @@ pub type AtpSession = crate::com::atproto::server::create_session::Output;

pub struct CredentialSession<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
store: Arc<inner::Store<S>>,
Expand All @@ -30,7 +30,7 @@ where

impl<S, T> CredentialSession<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
pub fn new(xrpc: T, store: S) -> Self {
Expand Down Expand Up @@ -60,7 +60,7 @@ where
.into(),
)
.await?;
self.store.set(result.clone()).await.expect("todo");
self.store.set((), result.clone()).await.expect("todo");
if let Some(did_doc) = result
.did_doc
.as_ref()
Expand All @@ -75,17 +75,17 @@ where
&self,
session: AtpSession,
) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
self.store.set(session.clone()).await.expect("todo");
self.store.set((), session.clone()).await.expect("todo");
let result = self.atproto_service.server.get_session().await;
match result {
Ok(output) => {
assert_eq!(output.data.did, session.data.did);
if let Some(mut session) = self.store.get().await.expect("todo") {
if let Some(mut session) = self.store.get(&()).await.expect("todo") {
session.did_doc = output.data.did_doc.clone();
session.email = output.data.email;
session.email_confirmed = output.data.email_confirmed;
session.handle = output.data.handle;
self.store.set(session).await.expect("todo");
self.store.set((), session).await.expect("todo");
}
if let Some(did_doc) = output
.data
Expand Down Expand Up @@ -127,7 +127,7 @@ where
}
/// Get the current session.
pub async fn get_session(&self) -> Option<AtpSession> {
self.store.get().await.expect("todo")
self.store.get(&()).await.expect("todo")
}
/// Get the current endpoint.
pub async fn get_endpoint(&self) -> String {
Expand All @@ -145,7 +145,7 @@ where

impl<S, T> HttpClient for CredentialSession<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
async fn send_http(
Expand All @@ -158,7 +158,7 @@ where

impl<S, T> XrpcClient for CredentialSession<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
fn base_uri(&self) -> String {
Expand All @@ -180,11 +180,11 @@ where

impl<S, T> SessionManager for CredentialSession<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
async fn did(&self) -> Option<Did> {
self.store.get().await.expect("todo").map(|session| session.data.did)
self.store.get(&()).await.expect("todo").map(|session| session.data.did)
}
}

Expand All @@ -196,18 +196,18 @@ where
/// ```
/// use atrium_api::agent::atp_agent::CredentialSession;
/// use atrium_api::agent::Agent;
/// use atrium_common::store::{memory::MemoryCellStore, CellStore};
/// use atrium_common::store::{memory::MemoryMapStore, MapStore};
/// use atrium_xrpc_client::reqwest::ReqwestClient;
///
/// let session = CredentialSession::new(
/// ReqwestClient::new("https://bsky.social"),
/// MemoryCellStore::default(),
/// MemoryMapStore::default(),
/// );
/// let agent = Agent::new(session);
/// ```
pub struct AtpAgent<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
session_manager: Wrapper<CredentialSession<S, T>>,
Expand All @@ -216,7 +216,7 @@ where

impl<S, T> AtpAgent<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
/// Create a new agent.
Expand Down Expand Up @@ -282,7 +282,7 @@ where

impl<S, T> Deref for AtpAgent<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
type Target = Agent<Wrapper<CredentialSession<S, T>>>;
Expand All @@ -299,7 +299,7 @@ mod tests {
use crate::com::atproto::server::create_session::OutputData;
use crate::did_doc::{DidDocument, Service, VerificationMethod};
use crate::types::TryIntoUnknown;
use atrium_common::store::memory::MemoryCellStore;
use atrium_common::store::memory::MemoryMapStore;
use atrium_xrpc::HttpClient;
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
use std::collections::HashMap;
Expand Down Expand Up @@ -427,7 +427,7 @@ mod tests {
#[tokio::test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
async fn test_new() {
let agent = AtpAgent::new(MockClient::default(), MemoryCellStore::default());
let agent = AtpAgent::new(MockClient::default(), MemoryMapStore::default());
assert_eq!(agent.get_session().await, None);
}

Expand All @@ -446,7 +446,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent.login("test", "pass").await.expect("login should be succeeded");
assert_eq!(agent.get_session().await, Some(session_data.into()));
}
Expand All @@ -456,7 +456,7 @@ mod tests {
responses: MockResponses { ..Default::default() },
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent.login("test", "bad").await.expect_err("login should be failed");
assert_eq!(agent.get_session().await, None);
}
Expand All @@ -482,8 +482,8 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
agent.session_manager.store.set(session_data.clone().into()).await.expect("todo");
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent.session_manager.store.set((), session_data.clone().into()).await.expect("todo");
let output = agent
.api
.com
Expand Down Expand Up @@ -516,8 +516,8 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
agent.session_manager.store.set(session_data.clone().into()).await.expect("todo");
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent.session_manager.store.set((), session_data.clone().into()).await.expect("todo");
let output = agent
.api
.com
Expand All @@ -531,7 +531,7 @@ mod tests {
agent
.session_manager
.store
.get()
.get(&())
.await
.expect("todo")
.map(|session| session.data.access_jwt),
Expand Down Expand Up @@ -561,8 +561,8 @@ mod tests {
..Default::default()
};
let counts = Arc::clone(&client.counts);
let agent = Arc::new(AtpAgent::new(client, MemoryCellStore::default()));
agent.session_manager.store.set(session_data.clone().into()).await.expect("todo");
let agent = Arc::new(AtpAgent::new(client, MemoryMapStore::default()));
agent.session_manager.store.set((), session_data.clone().into()).await.expect("todo");
let handles = (0..3).map(|_| {
let agent = Arc::clone(&agent);
tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
Expand All @@ -580,7 +580,7 @@ mod tests {
agent
.session_manager
.store
.get()
.get(&())
.await
.expect("todo")
.map(|session| session.data.access_jwt),
Expand Down Expand Up @@ -617,7 +617,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());
assert_eq!(agent.get_session().await, None);
agent
.resume_session(
Expand All @@ -637,7 +637,7 @@ mod tests {
responses: MockResponses { ..Default::default() },
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());
assert_eq!(agent.get_session().await, None);
agent
.resume_session(session_data.clone().into())
Expand Down Expand Up @@ -667,7 +667,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent
.resume_session(
OutputData { access_jwt: "expired".into(), ..session_data.clone() }.into(),
Expand Down Expand Up @@ -716,7 +716,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent.login("test", "pass").await.expect("login should be succeeded");
assert_eq!(agent.get_endpoint().await, "https://bsky.social");
assert_eq!(agent.api.com.atproto.server.xrpc.base_uri(), "https://bsky.social");
Expand Down Expand Up @@ -751,7 +751,7 @@ mod tests {
},
..Default::default()
};
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent.login("test", "pass").await.expect("login should be succeeded");
// not updated
assert_eq!(agent.get_endpoint().await, "http://localhost:8080");
Expand All @@ -764,7 +764,7 @@ mod tests {
async fn test_configure_labelers_header() {
let client = MockClient::default();
let headers = Arc::clone(&client.headers);
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());

agent
.api
Expand Down Expand Up @@ -827,7 +827,7 @@ mod tests {
async fn test_configure_proxy_header() {
let client = MockClient::default();
let headers = Arc::clone(&client.headers);
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());

agent
.api
Expand Down Expand Up @@ -925,9 +925,9 @@ mod tests {
async fn test_agent_did() {
let session_data = session_data();
let client = MockClient { responses: MockResponses::default(), ..Default::default() };
let agent = AtpAgent::new(client, MemoryCellStore::default());
let agent = AtpAgent::new(client, MemoryMapStore::default());
assert_eq!(agent.did().await, None);
agent.session_manager.store.set(session_data.clone().into()).await.expect("todo");
agent.session_manager.store.set((), session_data.clone().into()).await.expect("todo");
assert_eq!(agent.did().await, Some(session_data.did));
}
}
32 changes: 17 additions & 15 deletions atrium-api/src/agent/atp_agent/inner.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::did_doc::DidDocument;
use crate::types::string::Did;
use crate::types::TryFromUnknown;
use atrium_common::store::CellStore;
use atrium_common::store::MapStore;
use atrium_xrpc::error::{Error, Result, XrpcErrorKind};
use atrium_xrpc::types::JwtTokenType;
use atrium_xrpc::{HttpClient, OutputDataOrBytes, XrpcClient, XrpcRequest};
Expand Down Expand Up @@ -69,14 +69,14 @@ where

impl<S, T> XrpcClient for WrapperClient<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
fn base_uri(&self) -> String {
self.store.get_endpoint()
}
async fn authentication_token(&self, is_refresh: bool) -> Option<(JwtTokenType, String)> {
self.store.get().await.expect("todo").map(|session| {
self.store.get(&()).await.expect("todo").map(|session| {
if is_refresh {
(JwtTokenType::Bearer, session.data.refresh_jwt)
} else {
Expand All @@ -101,7 +101,7 @@ pub struct Client<S, T> {

impl<S, T> Client<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
pub fn new(store: Arc<Store<S>>, xrpc: T) -> Self {
Expand Down Expand Up @@ -156,13 +156,13 @@ where
}
async fn refresh_session_inner(&self) {
if let Ok(output) = self.call_refresh_session().await {
if let Some(mut session) = self.store.get().await.expect("todo") {
if let Some(mut session) = self.store.get(&()).await.expect("todo") {
session.access_jwt = output.data.access_jwt;
session.did = output.data.did;
session.did_doc = output.data.did_doc.clone();
session.handle = output.data.handle;
session.refresh_jwt = output.data.refresh_jwt;
self.store.set(session).await.expect("todo");
self.store.set((), session).await.expect("todo");
}
if let Some(did_doc) = output
.data
Expand Down Expand Up @@ -216,7 +216,7 @@ where

impl<S, T> Clone for Client<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
fn clone(&self) -> Self {
Expand Down Expand Up @@ -245,7 +245,7 @@ where

impl<S, T> XrpcClient for Client<S, T>
where
S: CellStore<AtpSession> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
T: XrpcClient + Send + Sync,
{
fn base_uri(&self) -> String {
Expand Down Expand Up @@ -291,18 +291,20 @@ impl<S> Store<S> {
}
}

impl<S, V> CellStore<V> for Store<S>
impl<S> MapStore<(), AtpSession> for Store<S>
where
V: Clone + Send + Sync,
S: CellStore<V> + Send + Sync,
S: MapStore<(), AtpSession> + Send + Sync,
{
type Error = S::Error;

async fn get(&self) -> core::result::Result<Option<V>, Self::Error> {
self.inner.get().await
async fn get(&self, key: &()) -> core::result::Result<Option<AtpSession>, Self::Error> {
self.inner.get(key).await
}
async fn set(&self, value: V) -> core::result::Result<(), Self::Error> {
self.inner.set(value).await
async fn set(&self, key: (), value: AtpSession) -> core::result::Result<(), Self::Error> {
self.inner.set(key, value).await
}
async fn del(&self, key: &()) -> core::result::Result<(), Self::Error> {
self.inner.del(key).await
}
async fn clear(&self) -> core::result::Result<(), Self::Error> {
self.inner.clear().await
Expand Down
Loading

0 comments on commit 8147a40

Please sign in to comment.