Skip to content

Commit de18c4f

Browse files
authored
feat: Extract common abstractions (#244)
* initialize crate * add `Cached` trait * add `Throttled` trait * add resolvers * add store * add workflows * fix `atrium-oauth` * add error conversions * change type visibility * fix identity crate * fix oauth-client crate * small fix * mofify crate authors * change `Resolver` type signature * apply suggestions * fix `Throttled` tests * fix wasm `CacheTrait`
1 parent ac1ad3f commit de18c4f

37 files changed

+574
-427
lines changed

Diff for: .github/workflows/common.yml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name: Common
2+
on:
3+
push:
4+
branches: ["main"]
5+
pull_request:
6+
branches: ["main"]
7+
env:
8+
CARGO_TERM_COLOR: always
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
- name: Build
15+
run: |
16+
cargo build -p atrium-common --verbose
17+
- name: Run tests
18+
run: |
19+
cargo test -p atrium-common --lib

Diff for: .github/workflows/wasm.yml

+1
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,4 @@ jobs:
6767
- run: wasm-pack test --node atrium-xrpc
6868
- run: wasm-pack test --node atrium-xrpc-client
6969
- run: wasm-pack test --node atrium-oauth/identity
70+
- run: wasm-pack test --node atrium-common

Diff for: Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[workspace]
22
members = [
33
"atrium-api",
4+
"atrium-common",
45
"atrium-crypto",
56
"atrium-xrpc",
67
"atrium-xrpc-client",
@@ -26,6 +27,7 @@ keywords = ["atproto", "bluesky"]
2627
[workspace.dependencies]
2728
# Intra-workspace dependencies
2829
atrium-api = { version = "0.24.8", path = "atrium-api", default-features = false }
30+
atrium-common = { version = "0.1.0", path = "atrium-common" }
2931
atrium-identity = { version = "0.1.0", path = "atrium-oauth/identity" }
3032
atrium-xrpc = { version = "0.12.0", path = "atrium-xrpc" }
3133
atrium-xrpc-client = { version = "0.5.10", path = "atrium-xrpc-client" }

Diff for: atrium-common/Cargo.toml

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
[package]
2+
name = "atrium-common"
3+
version = "0.1.0"
4+
authors = ["sugyan <[email protected]>", "avdb13 <[email protected]>"]
5+
edition.workspace = true
6+
rust-version.workspace = true
7+
description = "Utility library for common abstractions in atproto"
8+
documentation = "https://docs.rs/atrium-common"
9+
readme = "README.md"
10+
repository.workspace = true
11+
license.workspace = true
12+
keywords = ["atproto", "bluesky"]
13+
14+
[dependencies]
15+
dashmap.workspace = true
16+
thiserror.workspace = true
17+
tokio = { workspace = true, default-features = false, features = ["sync"] }
18+
trait-variant.workspace = true
19+
20+
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
21+
moka = { workspace = true, features = ["future"] }
22+
23+
[target.'cfg(target_arch = "wasm32")'.dependencies]
24+
lru.workspace = true
25+
web-time.workspace = true
26+
27+
[dev-dependencies]
28+
futures.workspace = true
29+
30+
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
31+
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] }
32+
33+
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
34+
gloo-timers.workspace = true
35+
tokio = { workspace = true, features = ["time"] }
36+
wasm-bindgen-test.workspace = true

Diff for: atrium-common/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pub mod resolver;
2+
pub mod store;
3+
pub mod types;

Diff for: atrium-common/src/resolver.rs

+222
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
mod cached;
2+
mod throttled;
3+
4+
pub use self::cached::CachedResolver;
5+
pub use self::throttled::ThrottledResolver;
6+
use std::future::Future;
7+
8+
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
9+
pub trait Resolver {
10+
type Input: ?Sized;
11+
type Output;
12+
type Error;
13+
14+
fn resolve(
15+
&self,
16+
input: &Self::Input,
17+
) -> impl Future<Output = core::result::Result<Self::Output, Self::Error>>;
18+
}
19+
20+
#[cfg(test)]
21+
mod tests {
22+
use super::*;
23+
use crate::types::cached::r#impl::{Cache, CacheImpl};
24+
use crate::types::cached::{CacheConfig, Cacheable};
25+
use crate::types::throttled::Throttleable;
26+
use std::collections::HashMap;
27+
use std::sync::Arc;
28+
use std::time::Duration;
29+
use tokio::sync::RwLock;
30+
#[cfg(target_arch = "wasm32")]
31+
use wasm_bindgen_test::wasm_bindgen_test;
32+
33+
#[cfg(not(target_arch = "wasm32"))]
34+
async fn sleep(duration: Duration) {
35+
tokio::time::sleep(duration).await;
36+
}
37+
38+
#[cfg(target_arch = "wasm32")]
39+
async fn sleep(duration: Duration) {
40+
gloo_timers::future::sleep(duration).await;
41+
}
42+
43+
#[derive(Debug, PartialEq)]
44+
struct Error;
45+
46+
type Result<T> = core::result::Result<T, Error>;
47+
48+
struct MockResolver {
49+
data: HashMap<String, String>,
50+
counts: Arc<RwLock<HashMap<String, usize>>>,
51+
}
52+
53+
impl Resolver for MockResolver {
54+
type Input = String;
55+
type Output = String;
56+
type Error = Error;
57+
58+
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
59+
sleep(Duration::from_millis(10)).await;
60+
*self.counts.write().await.entry(input.clone()).or_default() += 1;
61+
if let Some(value) = self.data.get(input) {
62+
Ok(value.clone())
63+
} else {
64+
Err(Error)
65+
}
66+
}
67+
}
68+
69+
fn mock_resolver(counts: Arc<RwLock<HashMap<String, usize>>>) -> MockResolver {
70+
MockResolver {
71+
data: [
72+
(String::from("k1"), String::from("v1")),
73+
(String::from("k2"), String::from("v2")),
74+
]
75+
.into_iter()
76+
.collect(),
77+
counts,
78+
}
79+
}
80+
81+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
82+
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
83+
async fn test_no_cached() {
84+
let counts = Arc::new(RwLock::new(HashMap::new()));
85+
let resolver = mock_resolver(counts.clone());
86+
for (input, expected) in [
87+
("k1", Some("v1")),
88+
("k2", Some("v2")),
89+
("k2", Some("v2")),
90+
("k1", Some("v1")),
91+
("k3", None),
92+
("k1", Some("v1")),
93+
("k3", None),
94+
] {
95+
let result = resolver.resolve(&input.to_string()).await;
96+
match expected {
97+
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
98+
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
99+
}
100+
}
101+
assert_eq!(
102+
*counts.read().await,
103+
[(String::from("k1"), 3), (String::from("k2"), 2), (String::from("k3"), 2),]
104+
.into_iter()
105+
.collect()
106+
);
107+
}
108+
109+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
110+
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
111+
async fn test_cached() {
112+
let counts = Arc::new(RwLock::new(HashMap::new()));
113+
let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig::default()));
114+
for (input, expected) in [
115+
("k1", Some("v1")),
116+
("k2", Some("v2")),
117+
("k2", Some("v2")),
118+
("k1", Some("v1")),
119+
("k3", None),
120+
("k1", Some("v1")),
121+
("k3", None),
122+
] {
123+
let result = resolver.resolve(&input.to_string()).await;
124+
match expected {
125+
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
126+
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
127+
}
128+
}
129+
assert_eq!(
130+
*counts.read().await,
131+
[(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 2),]
132+
.into_iter()
133+
.collect()
134+
);
135+
}
136+
137+
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
138+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
139+
async fn test_cached_with_max_capacity() {
140+
let counts = Arc::new(RwLock::new(HashMap::new()));
141+
let resolver = mock_resolver(counts.clone())
142+
.cached(CacheImpl::new(CacheConfig { max_capacity: Some(1), ..Default::default() }));
143+
for (input, expected) in [
144+
("k1", Some("v1")),
145+
("k2", Some("v2")),
146+
("k2", Some("v2")),
147+
("k1", Some("v1")),
148+
("k3", None),
149+
("k1", Some("v1")),
150+
("k3", None),
151+
] {
152+
let result = resolver.resolve(&input.to_string()).await;
153+
match expected {
154+
Some(value) => assert_eq!(result.expect("failed to resolve"), value),
155+
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
156+
}
157+
}
158+
assert_eq!(
159+
*counts.read().await,
160+
[(String::from("k1"), 2), (String::from("k2"), 1), (String::from("k3"), 2),]
161+
.into_iter()
162+
.collect()
163+
);
164+
}
165+
166+
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
167+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
168+
async fn test_cached_with_time_to_live() {
169+
let counts = Arc::new(RwLock::new(HashMap::new()));
170+
let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig {
171+
time_to_live: Some(Duration::from_millis(10)),
172+
..Default::default()
173+
}));
174+
for _ in 0..10 {
175+
let result = resolver.resolve(&String::from("k1")).await;
176+
assert_eq!(result.expect("failed to resolve"), "v1");
177+
}
178+
sleep(Duration::from_millis(10)).await;
179+
for _ in 0..10 {
180+
let result = resolver.resolve(&String::from("k1")).await;
181+
assert_eq!(result.expect("failed to resolve"), "v1");
182+
}
183+
assert_eq!(*counts.read().await, [(String::from("k1"), 2)].into_iter().collect());
184+
}
185+
186+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
187+
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
188+
async fn test_throttled() {
189+
let counts = Arc::new(RwLock::new(HashMap::new()));
190+
let resolver = Arc::new(mock_resolver(counts.clone()).throttled());
191+
192+
let mut handles = Vec::new();
193+
for (input, expected) in [
194+
("k1", Some("v1")),
195+
("k2", Some("v2")),
196+
("k2", Some("v2")),
197+
("k1", Some("v1")),
198+
("k3", None),
199+
("k1", Some("v1")),
200+
("k3", None),
201+
] {
202+
let resolver = resolver.clone();
203+
handles.push(async move { (resolver.resolve(&input.to_string()).await, expected) });
204+
}
205+
for (result, expected) in futures::future::join_all(handles).await {
206+
let result = result.and_then(|opt| opt.ok_or(Error));
207+
208+
match expected {
209+
Some(value) => {
210+
assert_eq!(result.expect("failed to resolve"), value)
211+
}
212+
None => assert_eq!(result.expect_err("succesfully resolved"), Error),
213+
}
214+
}
215+
assert_eq!(
216+
*counts.read().await,
217+
[(String::from("k1"), 1), (String::from("k2"), 1), (String::from("k3"), 1),]
218+
.into_iter()
219+
.collect()
220+
);
221+
}
222+
}

Diff for: atrium-common/src/resolver/cached.rs

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use std::hash::Hash;
2+
3+
use crate::types::cached::r#impl::{Cache, CacheImpl};
4+
use crate::types::cached::Cached;
5+
6+
use super::Resolver;
7+
8+
pub type CachedResolver<R> = Cached<R, CacheImpl<<R as Resolver>::Input, <R as Resolver>::Output>>;
9+
10+
impl<R, C> Resolver for Cached<R, C>
11+
where
12+
R: Resolver + Send + Sync + 'static,
13+
R::Input: Clone + Hash + Eq + Send + Sync + 'static,
14+
R::Output: Clone + Send + Sync + 'static,
15+
C: Cache<Input = R::Input, Output = R::Output> + Send + Sync + 'static,
16+
C::Input: Clone + Hash + Eq + Send + Sync + 'static,
17+
C::Output: Clone + Send + Sync + 'static,
18+
{
19+
type Input = R::Input;
20+
type Output = R::Output;
21+
type Error = R::Error;
22+
23+
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
24+
if let Some(output) = self.cache.get(input).await {
25+
return Ok(output);
26+
}
27+
let output = self.inner.resolve(input).await?;
28+
self.cache.set(input.clone(), output.clone()).await;
29+
Ok(output)
30+
}
31+
}

Diff for: atrium-common/src/resolver/throttled.rs

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use std::{hash::Hash, sync::Arc};
2+
3+
use dashmap::{DashMap, Entry};
4+
use tokio::sync::broadcast::{channel, Sender};
5+
use tokio::sync::Mutex;
6+
7+
use crate::types::throttled::Throttled;
8+
9+
use super::Resolver;
10+
11+
pub type SenderMap<R> =
12+
DashMap<<R as Resolver>::Input, Arc<Mutex<Sender<Option<<R as Resolver>::Output>>>>>;
13+
14+
pub type ThrottledResolver<R> = Throttled<R, SenderMap<R>>;
15+
16+
impl<R> Resolver for Throttled<R, SenderMap<R>>
17+
where
18+
R: Resolver + Send + Sync + 'static,
19+
R::Input: Clone + Hash + Eq + Send + Sync + 'static,
20+
R::Output: Clone + Send + Sync + 'static,
21+
{
22+
type Input = R::Input;
23+
type Output = Option<R::Output>;
24+
type Error = R::Error;
25+
26+
async fn resolve(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
27+
match self.pending.entry(input.clone()) {
28+
Entry::Occupied(occupied) => {
29+
let tx = occupied.get().lock().await.clone();
30+
drop(occupied);
31+
Ok(tx.subscribe().recv().await.expect("recv"))
32+
}
33+
Entry::Vacant(vacant) => {
34+
let (tx, _) = channel(1);
35+
vacant.insert(Arc::new(Mutex::new(tx.clone())));
36+
let result = self.inner.resolve(input).await;
37+
tx.send(result.as_ref().ok().cloned()).ok();
38+
self.pending.remove(input);
39+
result.map(Some)
40+
}
41+
}
42+
}
43+
}

0 commit comments

Comments
 (0)