diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 6db9747da..f9c82aab9 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -888,7 +888,6 @@ where ); if mls_group.is_active() { group.maybe_update_installations(provider, None).await?; - group.sync_with_conn(provider).await?; active_group_count.fetch_add(1, Ordering::SeqCst); } diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 6a642bc41..2a31a4ca2 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -5,6 +5,7 @@ use super::{ Installation, PostCommitAction, SendMessageIntentData, SendWelcomesAction, UpdateAdminListIntentData, UpdateGroupMembershipIntentData, UpdatePermissionIntentData, }, + serial::SerialOpenMlsGroup, validated_commit::{extract_group_membership, CommitValidationError}, GroupError, HmacKey, IntentError, MlsGroup, ScopedGroupClient, }; @@ -16,7 +17,7 @@ use crate::{ groups::device_sync::DeviceSyncContent, groups::{ device_sync::preference_sync::UserPreferenceUpdate, intents::UpdateMetadataIntentData, - validated_commit::ValidatedCommit, + serial::OpenMlsLock, validated_commit::ValidatedCommit, }, hpke::{encrypt_welcome, HpkeError}, identity::{parse_credential, IdentityError}, @@ -368,17 +369,18 @@ where id: ref msg_id, .. 
} = *envelope; + let mut locked_openmls_group = openmls_group.lock().await; if intent.state == IntentState::Committed { return Ok(IntentState::Committed); } let message_epoch = message.epoch(); - let group_epoch = openmls_group.epoch(); + let group_epoch = locked_openmls_group.epoch(); debug!( inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), + current_epoch = locked_openmls_group.epoch().as_u64(), msg_id, intent.id, intent.kind = %intent.kind, @@ -407,7 +409,7 @@ where inbox_id = self.client.inbox_id(), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), + current_epoch = locked_openmls_group.epoch().as_u64(), msg_id, intent.id, intent.kind = %intent.kind, @@ -436,7 +438,7 @@ where self.client.as_ref(), conn, &pending_commit, - openmls_group, + &locked_openmls_group, ) .await; @@ -458,7 +460,9 @@ where self.context().inbox_id(), intent.id ); - if let Err(err) = openmls_group.merge_staged_commit(&provider, pending_commit) { + if let Err(err) = + locked_openmls_group.merge_staged_commit(&provider, pending_commit) + { tracing::error!("error merging commit: {}", err); return Ok(IntentState::ToPublish); } else { @@ -498,10 +502,14 @@ where id: ref msg_id, .. 
} = *envelope; + let mut locked_openmls_group = openmls_group.lock().await; - let decrypted_message = openmls_group.process_message(provider, message)?; - let (sender_inbox_id, sender_installation_id) = - extract_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?; + let decrypted_message = locked_openmls_group.process_message(provider, message)?; + let (sender_inbox_id, sender_installation_id) = extract_message_sender( + &mut locked_openmls_group, + &decrypted_message, + envelope_timestamp_ns, + )?; tracing::info!( inbox_id = self.client.inbox_id(), @@ -509,7 +517,7 @@ where sender_inbox_id = sender_inbox_id, sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), + current_epoch = locked_openmls_group.epoch().as_u64(), msg_epoch = decrypted_message.epoch().as_u64(), msg_group_id = hex::encode(decrypted_message.group_id().as_slice()), msg_id, @@ -530,7 +538,7 @@ where sender_installation_id = hex::encode(&sender_installation_id), installation_id = %self.client.installation_id(), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), + current_epoch = locked_openmls_group.epoch().as_u64(), msg_epoch, msg_group_id, msg_id, @@ -658,7 +666,7 @@ where installation_id = %self.client.installation_id(), sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), + current_epoch = locked_openmls_group.epoch().as_u64(), msg_epoch, msg_group_id, msg_id, @@ -673,7 +681,7 @@ where self.client.as_ref(), provider.conn_ref(), &sc, - openmls_group, + &locked_openmls_group, ) .await?; tracing::info!( @@ -682,14 +690,14 @@ where installation_id = %self.client.installation_id(), sender_installation_id = hex::encode(&sender_installation_id), group_id = hex::encode(&self.group_id), - current_epoch = openmls_group.epoch().as_u64(), + 
current_epoch = locked_openmls_group.epoch().as_u64(), msg_epoch, msg_group_id, msg_id, "[{}] staged commit is valid, will attempt to merge", self.context().inbox_id() ); - openmls_group.merge_staged_commit(provider, sc)?; + locked_openmls_group.merge_staged_commit(provider, sc)?; self.save_transcript_message( provider.conn_ref(), validated_commit, @@ -970,6 +978,7 @@ where provider: &XmtpOpenMlsProvider, ) -> Result<(), GroupError> { let mut openmls_group = self.load_mls_group(provider)?; + let mut locked_openmls_group = openmls_group.lock().await; let intents = provider.conn_ref().find_group_intents( self.group_id.clone(), @@ -981,7 +990,7 @@ where let result = retry_async!( Retry::default(), (async { - self.get_publish_intent_data(provider, &mut openmls_group, &intent) + self.get_publish_intent_data(provider, &mut locked_openmls_group, &intent) .await }) ); @@ -1021,7 +1030,7 @@ where sha256(payload_slice), post_commit_action, staged_commit, - openmls_group.epoch().as_u64() as i64, + locked_openmls_group.epoch().as_u64() as i64, )?; tracing::debug!( inbox_id = self.client.inbox_id(), @@ -1075,7 +1084,7 @@ where async fn get_publish_intent_data( &self, provider: &XmtpOpenMlsProvider, - openmls_group: &mut OpenMlsGroup, + openmls_group: &mut SerialOpenMlsGroup<'_>, intent: &StoredGroupIntent, ) -> Result, GroupError> { match intent.kind { @@ -1292,8 +1301,9 @@ where inbox_ids_to_add: &[InboxIdRef<'_>], inbox_ids_to_remove: &[InboxIdRef<'_>], ) -> Result { - let mls_group = self.load_mls_group(provider)?; - let existing_group_membership = extract_group_membership(mls_group.extensions())?; + let mut mls_group = self.load_mls_group(provider)?; + let locked_mls_group = mls_group.lock().await; + let existing_group_membership = extract_group_membership(locked_mls_group.extensions())?; // TODO:nm prevent querying for updates on members who are being removed let mut inbox_ids = existing_group_membership.inbox_ids(); diff --git a/xmtp_mls/src/groups/mod.rs 
b/xmtp_mls/src/groups/mod.rs index 9213c38ef..144029dc1 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -6,6 +6,7 @@ pub mod group_permissions; pub mod intents; pub mod members; pub mod scoped_client; +mod serial; pub(super) mod mls_sync; pub(super) mod subscriptions; @@ -331,11 +332,11 @@ impl MlsGroup { } // Load the stored OpenMLS group from the OpenMLS provider's keystore - #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn load_mls_group( &self, provider: impl OpenMlsProvider, ) -> Result { + // Get the group ID for locking let mls_group = OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) .map_err(|_| GroupError::GroupNotFound)? @@ -1147,7 +1148,6 @@ impl MlsGroup { provider: impl OpenMlsProvider, ) -> Result { let mls_group = &self.load_mls_group(provider)?; - Ok(mls_group.try_into()?) }
diff --git a/xmtp_mls/src/groups/serial.rs b/xmtp_mls/src/groups/serial.rs
new file mode 100644
index 000000000..66a72e1ba
--- /dev/null
+++ b/xmtp_mls/src/groups/serial.rs
@@ -0,0 +1,92 @@
+//! Serialized (one-at-a-time) access to OpenMLS groups, keyed by group id.
+
+use openmls::prelude::MlsGroup as OpenMlsGroup;
+
+use std::{
+    collections::HashMap,
+    ops::{Deref, DerefMut},
+    sync::{Arc, LazyLock},
+};
+use tokio::sync::{Mutex, OwnedMutexGuard};
+
+/// Registry type: group id bytes mapped to the async mutex guarding that group.
+type CommitLock = parking_lot::Mutex<HashMap<Vec<u8>, Arc<Mutex<()>>>>;
+
+/// Process-wide registry of per-group locks.
+///
+/// Entries are inserted on first use and are never removed by this module.
+pub static MLS_COMMIT_LOCK: LazyLock<CommitLock> = LazyLock::new(parking_lot::Mutex::default);
+
+/// Returns the mutex registered for `group_id`, inserting a fresh one on first use.
+///
+/// The `Arc` is cloned out of the registry, so the registry guard is already
+/// released by the time the caller waits on the per-group mutex.
+fn group_mutex(group_id: &[u8]) -> Arc<Mutex<()>> {
+    MLS_COMMIT_LOCK
+        .lock()
+        .entry(group_id.to_vec())
+        .or_default()
+        .clone()
+}
+
+/// A mutably borrowed OpenMLS group bundled with the guard of its per-group lock.
+///
+/// Dereferences to the wrapped [`OpenMlsGroup`]; the guard is released on drop.
+pub struct SerialOpenMlsGroup<'a> {
+    group: &'a mut OpenMlsGroup,
+    _lock: OwnedMutexGuard<()>,
+}
+
+impl Deref for SerialOpenMlsGroup<'_> {
+    type Target = OpenMlsGroup;
+
+    fn deref(&self) -> &Self::Target {
+        self.group
+    }
+}
+
+impl DerefMut for SerialOpenMlsGroup<'_> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.group
+    }
+}
+
+/// Acquires the per-group lock registered in [`MLS_COMMIT_LOCK`].
+#[allow(unused)]
+pub(crate) trait OpenMlsLock {
+    /// Blocks the current thread until the lock for this group is held.
+    ///
+    /// # Panics
+    ///
+    /// Tokio's `blocking_lock_owned` panics when called from an async execution context.
+    fn lock_blocking(&mut self) -> SerialOpenMlsGroup<'_>;
+
+    /// Waits asynchronously until the lock for this group is held.
+    async fn lock(&mut self) -> SerialOpenMlsGroup<'_>;
+}
+
+impl OpenMlsLock for OpenMlsGroup {
+    async fn lock(&mut self) -> SerialOpenMlsGroup<'_> {
+        let mutex = group_mutex(self.group_id().as_slice());
+
+        // this may wait until the current holder for the same group id releases it
+        let lock = mutex.lock_owned().await;
+
+        SerialOpenMlsGroup {
+            group: self,
+            _lock: lock,
+        }
+    }
+
+    fn lock_blocking(&mut self) -> SerialOpenMlsGroup<'_> {
+        let mutex = group_mutex(self.group_id().as_slice());
+
+        // this may block the calling thread
+        let lock = mutex.blocking_lock_owned();
+
+        SerialOpenMlsGroup {
+            group: self,
+            _lock: lock,
+        }
+    }
+}
diff --git a/xmtp_mls/src/groups/validated_commit.rs b/xmtp_mls/src/groups/validated_commit.rs index f13557236..78661af64 100644 --- a/xmtp_mls/src/groups/validated_commit.rs +++ b/xmtp_mls/src/groups/validated_commit.rs @@ -37,6 +37,7 @@ use super::{ group_permissions::{ extract_group_permissions, GroupMutablePermissions, GroupMutablePermissionsError, }, + serial::SerialOpenMlsGroup, ScopedGroupClient, }; @@ -214,11 +215,11 @@ pub struct ValidatedCommit { } impl ValidatedCommit { - pub async fn from_staged_commit( + pub async fn from_staged_commit<'a>( client: impl ScopedGroupClient, conn: &DbConnection, staged_commit: &StagedCommit, - openmls_group: &OpenMlsGroup, + openmls_group: &SerialOpenMlsGroup<'a>, ) -> Result { // Get the immutable and mutable metadata let extensions = openmls_group.extensions();