From 90606550133418779e31f1ffb8f57922b786dd89 Mon Sep 17 00:00:00 2001 From: Yoshihiro Sugi Date: Fri, 19 Jul 2024 16:08:07 +0900 Subject: [PATCH] feat(bsky-sdk): Add record operations (#200) * sdk: Add Record trait and implementations * Add create_record/delete_record to BskyAgent --- bsky-sdk/src/error.rs | 4 + bsky-sdk/src/lib.rs | 3 + bsky-sdk/src/moderation/tests.rs | 3 +- bsky-sdk/src/record.rs | 434 +++++++++++++++++++++++++++++++ bsky-sdk/src/record/agent.rs | 238 +++++++++++++++++ 5 files changed, 680 insertions(+), 2 deletions(-) create mode 100644 bsky-sdk/src/record.rs create mode 100644 bsky-sdk/src/record/agent.rs diff --git a/bsky-sdk/src/error.rs b/bsky-sdk/src/error.rs index 0ab55171..139e85f8 100644 --- a/bsky-sdk/src/error.rs +++ b/bsky-sdk/src/error.rs @@ -7,6 +7,10 @@ use thiserror::Error; /// Error type for this crate. #[derive(Error, Debug)] pub enum Error { + #[error("not logged in")] + NotLoggedIn, + #[error("invalid AT URI")] + InvalidAtUri, #[error("xrpc response error: {0}")] Xrpc(Box), #[error("loading config error: {0}")] diff --git a/bsky-sdk/src/lib.rs b/bsky-sdk/src/lib.rs index 2b2b1e83..a3fac26e 100644 --- a/bsky-sdk/src/lib.rs +++ b/bsky-sdk/src/lib.rs @@ -4,6 +4,7 @@ pub mod agent; mod error; pub mod moderation; pub mod preference; +pub mod record; #[cfg_attr(docsrs, doc(cfg(feature = "rich-text")))] #[cfg(feature = "rich-text")] pub mod rich_text; @@ -19,6 +20,8 @@ mod tests { use atrium_api::xrpc::types::Header; use atrium_api::xrpc::{HttpClient, XrpcClient}; + pub const FAKE_CID: &str = "bafyreiclp443lavogvhj3d2ob2cxbfuscni2k5jk7bebjzg7khl3esabwq"; + pub struct MockClient; #[async_trait] diff --git a/bsky-sdk/src/moderation/tests.rs b/bsky-sdk/src/moderation/tests.rs index f7eb5fba..9666ab0d 100644 --- a/bsky-sdk/src/moderation/tests.rs +++ b/bsky-sdk/src/moderation/tests.rs @@ -7,6 +7,7 @@ use crate::moderation::decision::{DecisionContext, ModerationDecision}; use crate::moderation::types::*; use crate::moderation::util::interpret_label_value_definition; use crate::moderation::Moderator; +use crate::tests::FAKE_CID; use atrium_api::app::bsky::actor::defs::{ProfileViewBasic, ProfileViewBasicData}; use atrium_api::app::bsky::feed::defs::{PostView, PostViewData}; use atrium_api::com::atproto::label::defs::{Label, LabelData, LabelValueDefinitionData}; @@ -14,8 +15,6 @@ use atrium_api::records::{KnownRecord, Record}; use atrium_api::types::string::Datetime; use std::collections::HashMap; -const FAKE_CID: &str = "bafyreiclp443lavogvhj3d2ob2cxbfuscni2k5jk7bebjzg7khl3esabwq"; - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum ResultFlag { Filter, diff --git a/bsky-sdk/src/record.rs b/bsky-sdk/src/record.rs new file mode 100644 index 00000000..f4bb8ba7 --- /dev/null +++ b/bsky-sdk/src/record.rs @@ -0,0 +1,434 @@ +//! Record operations. +mod agent; + +pub use self::agent::*; +use crate::error::{Error, Result}; +use crate::BskyAgent; +use async_trait::async_trait; +use atrium_api::agent::store::SessionStore; +use atrium_api::com::atproto::repo::{create_record, get_record, list_records, put_record}; +use atrium_api::types::{Collection, LimitedNonZeroU8}; +use atrium_api::xrpc::XrpcClient; + +#[async_trait] +pub trait Record +where + T: XrpcClient + Send + Sync, + S: SessionStore + Send + Sync, +{ + async fn list( + agent: &BskyAgent, + cursor: Option, + limit: Option>, + ) -> Result; + async fn get(agent: &BskyAgent, rkey: String) -> Result; + async fn put(self, agent: &BskyAgent, rkey: String) -> Result; + async fn create(self, agent: &BskyAgent) -> Result; + async fn delete(agent: &BskyAgent, rkey: String) -> Result<()>; +} + +macro_rules! record_impl { + ($collection:path, $record:path, $record_data:path, $variant:ident) => { + #[async_trait] + impl Record for $record + where + T: XrpcClient + Send + Sync, + S: SessionStore + Send + Sync, + { + async fn list( + agent: &BskyAgent, + cursor: Option, + limit: Option>, + ) -> Result { + let session = agent.get_session().await.ok_or(Error::NotLoggedIn)?; + Ok(agent + .api + .com + .atproto + .repo + .list_records( + atrium_api::com::atproto::repo::list_records::ParametersData { + collection: <$collection>::nsid(), + cursor, + limit, + repo: session.data.did.into(), + reverse: None, + rkey_end: None, + rkey_start: None, + } + .into(), + ) + .await?) + } + async fn get(agent: &BskyAgent, rkey: String) -> Result { + let session = agent.get_session().await.ok_or(Error::NotLoggedIn)?; + Ok(agent + .api + .com + .atproto + .repo + .get_record( + atrium_api::com::atproto::repo::get_record::ParametersData { + cid: None, + collection: <$collection>::nsid(), + repo: session.data.did.into(), + rkey, + } + .into(), + ) + .await?) + } + async fn put( + self, + agent: &BskyAgent, + rkey: String, + ) -> Result { + let session = agent.get_session().await.ok_or(Error::NotLoggedIn)?; + Ok(agent + .api + .com + .atproto + .repo + .put_record( + atrium_api::com::atproto::repo::put_record::InputData { + collection: <$collection>::nsid(), + record: atrium_api::records::Record::Known( + atrium_api::records::KnownRecord::$variant(Box::new(self)), + ), + repo: session.data.did.into(), + rkey, + swap_commit: None, + swap_record: None, + validate: None, + } + .into(), + ) + .await?) + } + async fn create(self, agent: &BskyAgent) -> Result { + let session = agent.get_session().await.ok_or(Error::NotLoggedIn)?; + Ok(agent + .api + .com + .atproto + .repo + .create_record( + atrium_api::com::atproto::repo::create_record::InputData { + collection: <$collection>::nsid(), + record: atrium_api::records::Record::Known( + atrium_api::records::KnownRecord::$variant(Box::new(self)), + ), + repo: session.data.did.into(), + rkey: None, + swap_commit: None, + validate: None, + } + .into(), + ) + .await?) + } + async fn delete(agent: &BskyAgent, rkey: String) -> Result<()> { + let session = agent.get_session().await.ok_or(Error::NotLoggedIn)?; + Ok(agent + .api + .com + .atproto + .repo + .delete_record( + atrium_api::com::atproto::repo::delete_record::InputData { + collection: <$collection>::nsid(), + repo: session.data.did.into(), + rkey, + swap_commit: None, + swap_record: None, + } + .into(), + ) + .await?) + } + } + + #[async_trait] + impl Record for $record_data + where + T: XrpcClient + Send + Sync, + S: SessionStore + Send + Sync, + { + async fn list( + agent: &BskyAgent, + cursor: Option, + limit: Option>, + ) -> Result { + <$record>::list(agent, cursor, limit).await + } + async fn get(agent: &BskyAgent, rkey: String) -> Result { + <$record>::get(agent, rkey).await + } + async fn put( + self, + agent: &BskyAgent, + rkey: String, + ) -> Result { + <$record>::from(self).put(agent, rkey).await + } + async fn create(self, agent: &BskyAgent) -> Result { + <$record>::from(self).create(agent).await + } + async fn delete(agent: &BskyAgent, rkey: String) -> Result<()> { + <$record>::delete(agent, rkey).await + } + } + }; +} + +record_impl!( + atrium_api::app::bsky::actor::Profile, + atrium_api::app::bsky::actor::profile::Record, + atrium_api::app::bsky::actor::profile::RecordData, + AppBskyActorProfile +); +record_impl!( + atrium_api::app::bsky::feed::Generator, + atrium_api::app::bsky::feed::generator::Record, + atrium_api::app::bsky::feed::generator::RecordData, + AppBskyFeedGenerator +); +record_impl!( + atrium_api::app::bsky::feed::Like, + atrium_api::app::bsky::feed::like::Record, + atrium_api::app::bsky::feed::like::RecordData, + AppBskyFeedLike +); +record_impl!( + atrium_api::app::bsky::feed::Post, + atrium_api::app::bsky::feed::post::Record, + atrium_api::app::bsky::feed::post::RecordData, + AppBskyFeedPost +); +record_impl!( + atrium_api::app::bsky::feed::Repost, + atrium_api::app::bsky::feed::repost::Record, + atrium_api::app::bsky::feed::repost::RecordData, + AppBskyFeedRepost +); +record_impl!( + atrium_api::app::bsky::feed::Threadgate, + atrium_api::app::bsky::feed::threadgate::Record, + atrium_api::app::bsky::feed::threadgate::RecordData, + AppBskyFeedThreadgate +); +record_impl!( + atrium_api::app::bsky::graph::Block, + atrium_api::app::bsky::graph::block::Record, + atrium_api::app::bsky::graph::block::RecordData, + AppBskyGraphBlock +); +record_impl!( + atrium_api::app::bsky::graph::Follow, + atrium_api::app::bsky::graph::follow::Record, + atrium_api::app::bsky::graph::follow::RecordData, + AppBskyGraphFollow +); +record_impl!( + atrium_api::app::bsky::graph::List, + atrium_api::app::bsky::graph::list::Record, + atrium_api::app::bsky::graph::list::RecordData, + AppBskyGraphList +); +record_impl!( + atrium_api::app::bsky::graph::Listblock, + atrium_api::app::bsky::graph::listblock::Record, + atrium_api::app::bsky::graph::listblock::RecordData, + AppBskyGraphListblock +); +record_impl!( + atrium_api::app::bsky::graph::Listitem, + atrium_api::app::bsky::graph::listitem::Record, + atrium_api::app::bsky::graph::listitem::RecordData, + AppBskyGraphListitem +); +record_impl!( + atrium_api::app::bsky::graph::Starterpack, + atrium_api::app::bsky::graph::starterpack::Record, + atrium_api::app::bsky::graph::starterpack::RecordData, + AppBskyGraphStarterpack +); +record_impl!( + atrium_api::app::bsky::labeler::Service, + atrium_api::app::bsky::labeler::service::Record, + atrium_api::app::bsky::labeler::service::RecordData, + AppBskyLabelerService +); +record_impl!( + atrium_api::chat::bsky::actor::Declaration, + atrium_api::chat::bsky::actor::declaration::Record, + atrium_api::chat::bsky::actor::declaration::RecordData, + ChatBskyActorDeclaration +); + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent::BskyAgentBuilder; + use crate::tests::FAKE_CID; + use atrium_api::agent::Session; + use atrium_api::com::atproto::server::create_session::OutputData; + use atrium_api::types::string::Datetime; + use atrium_api::xrpc::http::{Request, Response}; + use atrium_api::xrpc::types::Header; + use atrium_api::xrpc::{HttpClient, XrpcClient}; + + struct MockClient; + + #[async_trait] + impl HttpClient for MockClient { + async fn send_http( + &self, + request: Request>, + ) -> core::result::Result< + Response>, + Box, + > { + match request.uri().path() { + "/xrpc/com.atproto.repo.createRecord" => { + let output = create_record::Output::from(create_record::OutputData { + cid: FAKE_CID.parse().expect("invalid cid"), + uri: String::from("at://did:fake:handle.test/app.bsky.feed.post/somerkey"), + }); + Ok(Response::builder() + .header(Header::ContentType, "application/json") + .status(200) + .body(serde_json::to_vec(&output)?)?) + } + "/xrpc/com.atproto.repo.deleteRecord" => { + Ok(Response::builder().status(200).body(Vec::new())?) + } + _ => unreachable!(), + } + } + } + + #[async_trait] + impl XrpcClient for MockClient { + fn base_uri(&self) -> String { + String::new() + } + } + + struct MockSessionStore; + + #[async_trait] + impl SessionStore for MockSessionStore { + async fn get_session(&self) -> Option { + Some( + OutputData { + access_jwt: String::from("access"), + active: None, + did: "did:fake:handle.test".parse().expect("invalid did"), + did_doc: None, + email: None, + email_auth_factor: None, + email_confirmed: None, + handle: "handle.test".parse().expect("invalid handle"), + refresh_jwt: String::from("refresh"), + status: None, + } + .into(), + ) + } + async fn set_session(&self, _: Session) {} + async fn clear_session(&self) {} + } + + #[tokio::test] + async fn actor_profile() -> Result<()> { + let agent = BskyAgentBuilder::new(MockClient) + .store(MockSessionStore) + .build() + .await?; + // create + let output = atrium_api::app::bsky::actor::profile::RecordData { + avatar: None, + banner: None, + created_at: None, + description: None, + display_name: None, + joined_via_starter_pack: None, + labels: None, + } + .create(&agent) + .await?; + assert_eq!( + output, + create_record::OutputData { + cid: FAKE_CID.parse().expect("invalid cid"), + uri: String::from("at://did:fake:handle.test/app.bsky.feed.post/somerkey"), + } + .into() + ); + // delete + atrium_api::app::bsky::actor::profile::Record::delete(&agent, String::from("somerkey")) + .await?; + Ok(()) + } + + #[tokio::test] + async fn feed_post() -> Result<()> { + let agent = BskyAgentBuilder::new(MockClient) + .store(MockSessionStore) + .build() + .await?; + // create + let output = atrium_api::app::bsky::feed::post::RecordData { + created_at: Datetime::now(), + embed: None, + entities: None, + facets: None, + labels: None, + langs: None, + reply: None, + tags: None, + text: String::from("text"), + } + .create(&agent) + .await?; + assert_eq!( + output, + create_record::OutputData { + cid: FAKE_CID.parse().expect("invalid cid"), + uri: String::from("at://did:fake:handle.test/app.bsky.feed.post/somerkey"), + } + .into() + ); + // delete + atrium_api::app::bsky::feed::post::Record::delete(&agent, String::from("somerkey")).await?; + Ok(()) + } + + #[tokio::test] + async fn graph_follow() -> Result<()> { + let agent = BskyAgentBuilder::new(MockClient) + .store(MockSessionStore) + .build() + .await?; + // create + let output = atrium_api::app::bsky::graph::follow::RecordData { + created_at: Datetime::now(), + subject: "did:fake:handle.test".parse().expect("invalid did"), + } + .create(&agent) + .await?; + assert_eq!( + output, + create_record::OutputData { + cid: FAKE_CID.parse().expect("invalid cid"), + uri: String::from("at://did:fake:handle.test/app.bsky.feed.post/somerkey"), + } + .into() + ); + // delete + atrium_api::app::bsky::graph::follow::Record::delete(&agent, String::from("somerkey")) + .await?; + Ok(()) + } +} diff --git a/bsky-sdk/src/record/agent.rs b/bsky-sdk/src/record/agent.rs new file mode 100644 index 00000000..08c7d888 --- /dev/null +++ b/bsky-sdk/src/record/agent.rs @@ -0,0 +1,238 @@ +use super::Record; +use crate::error::{Error, Result}; +use crate::BskyAgent; +use atrium_api::agent::store::SessionStore; +use atrium_api::com::atproto::repo::create_record; +use atrium_api::records::KnownRecord; +use atrium_api::types::string::RecordKey; +use atrium_api::xrpc::XrpcClient; + +pub enum CreateRecordSubject { + AppBskyActorProfile(Box), + AppBskyFeedGenerator(Box), + AppBskyFeedLike(Box), + AppBskyFeedPost(Box), + AppBskyFeedRepost(Box), + AppBskyFeedThreadgate(Box), + AppBskyGraphBlock(Box), + AppBskyGraphFollow(Box), + AppBskyGraphList(Box), + AppBskyGraphListblock(Box), + AppBskyGraphListitem(Box), + AppBskyGraphStarterpack(Box), + AppBskyLabelerService(Box), + ChatBskyActorDeclaration(Box), +} + +impl TryFrom for CreateRecordSubject { + type Error = (); + + fn try_from(record: atrium_api::records::Record) -> std::result::Result { + match record { + atrium_api::records::Record::Known(record) => Ok(record.into()), + _ => Err(()), + } + } +} + +impl From for CreateRecordSubject { + fn from(value: KnownRecord) -> Self { + match value { + KnownRecord::AppBskyActorProfile(record) => Self::AppBskyActorProfile(record), + KnownRecord::AppBskyFeedGenerator(record) => Self::AppBskyFeedGenerator(record), + KnownRecord::AppBskyFeedLike(record) => Self::AppBskyFeedLike(record), + KnownRecord::AppBskyFeedPost(record) => Self::AppBskyFeedPost(record), + KnownRecord::AppBskyFeedRepost(record) => Self::AppBskyFeedRepost(record), + KnownRecord::AppBskyFeedThreadgate(record) => Self::AppBskyFeedThreadgate(record), + KnownRecord::AppBskyGraphBlock(record) => Self::AppBskyGraphBlock(record), + KnownRecord::AppBskyGraphFollow(record) => Self::AppBskyGraphFollow(record), + KnownRecord::AppBskyGraphList(record) => Self::AppBskyGraphList(record), + KnownRecord::AppBskyGraphListblock(record) => Self::AppBskyGraphListblock(record), + KnownRecord::AppBskyGraphListitem(record) => Self::AppBskyGraphListitem(record), + KnownRecord::AppBskyGraphStarterpack(record) => Self::AppBskyGraphStarterpack(record), + KnownRecord::AppBskyLabelerService(record) => Self::AppBskyLabelerService(record), + KnownRecord::ChatBskyActorDeclaration(record) => Self::ChatBskyActorDeclaration(record), + } + } +} + +macro_rules! into_create_record_subject { + ($record:path, $record_data:path, $variant:ident) => { + impl From<$record> for CreateRecordSubject { + fn from(record: $record) -> Self { + Self::$variant(Box::new(record)) + } + } + + impl From<$record_data> for CreateRecordSubject { + fn from(record_data: $record_data) -> Self { + Self::$variant(Box::new(record_data.into())) + } + } + }; +} + +into_create_record_subject!( + atrium_api::app::bsky::actor::profile::Record, + atrium_api::app::bsky::actor::profile::RecordData, + AppBskyActorProfile +); +into_create_record_subject!( + atrium_api::app::bsky::feed::generator::Record, + atrium_api::app::bsky::feed::generator::RecordData, + AppBskyFeedGenerator +); +into_create_record_subject!( + atrium_api::app::bsky::feed::like::Record, + atrium_api::app::bsky::feed::like::RecordData, + AppBskyFeedLike +); +into_create_record_subject!( + atrium_api::app::bsky::feed::post::Record, + atrium_api::app::bsky::feed::post::RecordData, + AppBskyFeedPost +); +into_create_record_subject!( + atrium_api::app::bsky::feed::repost::Record, + atrium_api::app::bsky::feed::repost::RecordData, + AppBskyFeedRepost +); +into_create_record_subject!( + atrium_api::app::bsky::feed::threadgate::Record, + atrium_api::app::bsky::feed::threadgate::RecordData, + AppBskyFeedThreadgate +); +into_create_record_subject!( + atrium_api::app::bsky::graph::block::Record, + atrium_api::app::bsky::graph::block::RecordData, + AppBskyGraphBlock +); +into_create_record_subject!( + atrium_api::app::bsky::graph::follow::Record, + atrium_api::app::bsky::graph::follow::RecordData, + AppBskyGraphFollow +); +into_create_record_subject!( + atrium_api::app::bsky::graph::list::Record, + atrium_api::app::bsky::graph::list::RecordData, + AppBskyGraphList +); +into_create_record_subject!( + atrium_api::app::bsky::graph::listblock::Record, + atrium_api::app::bsky::graph::listblock::RecordData, + AppBskyGraphListblock +); +into_create_record_subject!( + atrium_api::app::bsky::graph::listitem::Record, + atrium_api::app::bsky::graph::listitem::RecordData, + AppBskyGraphListitem +); +into_create_record_subject!( + atrium_api::app::bsky::graph::starterpack::Record, + atrium_api::app::bsky::graph::starterpack::RecordData, + AppBskyGraphStarterpack +); +into_create_record_subject!( + atrium_api::app::bsky::labeler::service::Record, + atrium_api::app::bsky::labeler::service::RecordData, + AppBskyLabelerService +); +into_create_record_subject!( + atrium_api::chat::bsky::actor::declaration::Record, + atrium_api::chat::bsky::actor::declaration::RecordData, + ChatBskyActorDeclaration +); + +impl BskyAgent +where + T: XrpcClient + Send + Sync, + S: SessionStore + Send + Sync, +{ + /// Create a record with various types of data. + /// For example, the Record families defined in [`KnownRecord`](atrium_api::records::KnownRecord) are supported. + /// + /// # Example + /// + /// ```no_run + /// use bsky_sdk::{BskyAgent, Result}; + /// + /// #[tokio::main] + /// async fn main() -> Result<()> { + /// let agent = BskyAgent::builder().build().await?; + /// let output = agent.create_record(atrium_api::app::bsky::graph::block::RecordData { + /// created_at: atrium_api::types::string::Datetime::now(), + /// subject: "did:fake:handle.test".parse().expect("invalid did"), + /// }).await?; + /// Ok(()) + /// } + /// ``` + pub async fn create_record( + &self, + subject: impl Into, + ) -> Result { + match subject.into() { + CreateRecordSubject::AppBskyActorProfile(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyFeedGenerator(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyFeedLike(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyFeedPost(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyFeedRepost(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyFeedThreadgate(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyGraphBlock(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyGraphFollow(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyGraphList(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyGraphListblock(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyGraphListitem(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyGraphStarterpack(record) => record.data.create(self).await, + CreateRecordSubject::AppBskyLabelerService(record) => record.data.create(self).await, + CreateRecordSubject::ChatBskyActorDeclaration(record) => record.data.create(self).await, + } + } + /// Delete a record with AT URI. + /// + /// # Errors + /// + /// Returns an [`Error::InvalidAtUri`] if the `at_uri` is invalid. + /// + /// # Example + /// + /// ```no_run + /// use bsky_sdk::{BskyAgent, Result}; + /// + /// #[tokio::main] + /// async fn main() -> Result<()> { + /// let agent = BskyAgent::builder().build().await?; + /// agent.delete_record("at://did:fake:handle.test/app.bsky.graph.block/3kxmfwtgfxl2w").await?; + /// Ok(()) + /// } + /// ``` + pub async fn delete_record(&self, at_uri: impl AsRef) -> Result<()> { + let parts = at_uri + .as_ref() + .strip_prefix("at://") + .ok_or(Error::InvalidAtUri)? + .splitn(3, '/') + .collect::>(); + let repo = parts[0].parse().or(Err(Error::InvalidAtUri))?; + let collection = parts[1].parse().or(Err(Error::InvalidAtUri))?; + let rkey = parts[2] + .parse::() + .or(Err(Error::InvalidAtUri))? + .into(); + Ok(self + .api + .com + .atproto + .repo + .delete_record( + atrium_api::com::atproto::repo::delete_record::InputData { + collection, + repo, + rkey, + swap_commit: None, + swap_record: None, + } + .into(), + ) + .await?) + } +}