Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a lock struct for the mls groups #1413

Closed
wants to merge 16 commits into from
1 change: 0 additions & 1 deletion xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,6 @@ where
);
if mls_group.is_active() {
group.maybe_update_installations(provider, None).await?;

group.sync_with_conn(provider).await?;
active_group_count.fetch_add(1, Ordering::SeqCst);
}
Expand Down
50 changes: 30 additions & 20 deletions xmtp_mls/src/groups/mls_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::{
Installation, PostCommitAction, SendMessageIntentData, SendWelcomesAction,
UpdateAdminListIntentData, UpdateGroupMembershipIntentData, UpdatePermissionIntentData,
},
serial::SerialOpenMlsGroup,
validated_commit::{extract_group_membership, CommitValidationError},
GroupError, HmacKey, IntentError, MlsGroup, ScopedGroupClient,
};
Expand All @@ -16,7 +17,7 @@ use crate::{
groups::device_sync::DeviceSyncContent,
groups::{
device_sync::preference_sync::UserPreferenceUpdate, intents::UpdateMetadataIntentData,
validated_commit::ValidatedCommit,
serial::OpenMlsLock, validated_commit::ValidatedCommit,
},
hpke::{encrypt_welcome, HpkeError},
identity::{parse_credential, IdentityError},
Expand Down Expand Up @@ -368,17 +369,18 @@ where
id: ref msg_id,
..
} = *envelope;
let mut locked_openmls_group = openmls_group.lock().await;

if intent.state == IntentState::Committed {
return Ok(IntentState::Committed);
}
let message_epoch = message.epoch();
let group_epoch = openmls_group.epoch();
let group_epoch = locked_openmls_group.epoch();
debug!(
inbox_id = self.client.inbox_id(),
installation_id = %self.client.installation_id(),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_id,
intent.id,
intent.kind = %intent.kind,
Expand Down Expand Up @@ -407,7 +409,7 @@ where
inbox_id = self.client.inbox_id(),
installation_id = %self.client.installation_id(),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_id,
intent.id,
intent.kind = %intent.kind,
Expand Down Expand Up @@ -436,7 +438,7 @@ where
self.client.as_ref(),
conn,
&pending_commit,
openmls_group,
&locked_openmls_group,
)
.await;

Expand All @@ -458,7 +460,9 @@ where
self.context().inbox_id(),
intent.id
);
if let Err(err) = openmls_group.merge_staged_commit(&provider, pending_commit) {
if let Err(err) =
locked_openmls_group.merge_staged_commit(&provider, pending_commit)
{
tracing::error!("error merging commit: {}", err);
return Ok(IntentState::ToPublish);
} else {
Expand Down Expand Up @@ -498,18 +502,22 @@ where
id: ref msg_id,
..
} = *envelope;
let mut locked_openmls_group = openmls_group.lock().await;

let decrypted_message = openmls_group.process_message(provider, message)?;
let (sender_inbox_id, sender_installation_id) =
extract_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?;
let decrypted_message = locked_openmls_group.process_message(provider, message)?;
let (sender_inbox_id, sender_installation_id) = extract_message_sender(
&mut locked_openmls_group,
&decrypted_message,
envelope_timestamp_ns,
)?;

tracing::info!(
inbox_id = self.client.inbox_id(),
installation_id = %self.client.installation_id(),
sender_inbox_id = sender_inbox_id,
sender_installation_id = hex::encode(&sender_installation_id),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_epoch = decrypted_message.epoch().as_u64(),
msg_group_id = hex::encode(decrypted_message.group_id().as_slice()),
msg_id,
Expand All @@ -530,7 +538,7 @@ where
sender_installation_id = hex::encode(&sender_installation_id),
installation_id = %self.client.installation_id(),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_epoch,
msg_group_id,
msg_id,
Expand Down Expand Up @@ -658,7 +666,7 @@ where
installation_id = %self.client.installation_id(),
sender_installation_id = hex::encode(&sender_installation_id),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_epoch,
msg_group_id,
msg_id,
Expand All @@ -673,7 +681,7 @@ where
self.client.as_ref(),
provider.conn_ref(),
&sc,
openmls_group,
&locked_openmls_group,
)
.await?;
tracing::info!(
Expand All @@ -682,14 +690,14 @@ where
installation_id = %self.client.installation_id(),
sender_installation_id = hex::encode(&sender_installation_id),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_epoch,
msg_group_id,
msg_id,
"[{}] staged commit is valid, will attempt to merge",
self.context().inbox_id()
);
openmls_group.merge_staged_commit(provider, sc)?;
locked_openmls_group.merge_staged_commit(provider, sc)?;
self.save_transcript_message(
provider.conn_ref(),
validated_commit,
Expand Down Expand Up @@ -970,6 +978,7 @@ where
provider: &XmtpOpenMlsProvider,
) -> Result<(), GroupError> {
let mut openmls_group = self.load_mls_group(provider)?;
let mut locked_openmls_group = openmls_group.lock().await;

let intents = provider.conn_ref().find_group_intents(
self.group_id.clone(),
Expand All @@ -981,7 +990,7 @@ where
let result = retry_async!(
Retry::default(),
(async {
self.get_publish_intent_data(provider, &mut openmls_group, &intent)
self.get_publish_intent_data(provider, &mut locked_openmls_group, &intent)
.await
})
);
Expand Down Expand Up @@ -1021,7 +1030,7 @@ where
sha256(payload_slice),
post_commit_action,
staged_commit,
openmls_group.epoch().as_u64() as i64,
locked_openmls_group.epoch().as_u64() as i64,
)?;
tracing::debug!(
inbox_id = self.client.inbox_id(),
Expand Down Expand Up @@ -1075,7 +1084,7 @@ where
async fn get_publish_intent_data(
&self,
provider: &XmtpOpenMlsProvider,
openmls_group: &mut OpenMlsGroup,
openmls_group: &mut SerialOpenMlsGroup<'_>,
intent: &StoredGroupIntent,
) -> Result<Option<PublishIntentData>, GroupError> {
match intent.kind {
Expand Down Expand Up @@ -1292,8 +1301,9 @@ where
inbox_ids_to_add: &[InboxIdRef<'_>],
inbox_ids_to_remove: &[InboxIdRef<'_>],
) -> Result<UpdateGroupMembershipIntentData, GroupError> {
let mls_group = self.load_mls_group(provider)?;
let existing_group_membership = extract_group_membership(mls_group.extensions())?;
let mut mls_group = self.load_mls_group(provider)?;
let locked_mls_group = mls_group.lock().await;
let existing_group_membership = extract_group_membership(locked_mls_group.extensions())?;

// TODO:nm prevent querying for updates on members who are being removed
let mut inbox_ids = existing_group_membership.inbox_ids();
Expand Down
4 changes: 2 additions & 2 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod group_permissions;
pub mod intents;
pub mod members;
pub mod scoped_client;
mod serial;

pub(super) mod mls_sync;
pub(super) mod subscriptions;
Expand Down Expand Up @@ -331,11 +332,11 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
}

// Load the stored OpenMLS group from the OpenMLS provider's keystore
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn load_mls_group(
&self,
provider: impl OpenMlsProvider,
) -> Result<OpenMlsGroup, GroupError> {
// Get the group ID for locking
let mls_group =
OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))
.map_err(|_| GroupError::GroupNotFound)?
Expand Down Expand Up @@ -1147,7 +1148,6 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
provider: impl OpenMlsProvider,
) -> Result<GroupMutableMetadata, GroupError> {
let mls_group = &self.load_mls_group(provider)?;

Ok(mls_group.try_into()?)
}

Expand Down
73 changes: 73 additions & 0 deletions xmtp_mls/src/groups/serial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use openmls::prelude::MlsGroup as OpenMlsGroup;

use std::{
collections::HashMap,
ops::{Deref, DerefMut},
sync::{Arc, LazyLock},
};
use tokio::sync::{Mutex, OwnedMutexGuard};

type CommitLock = parking_lot::Mutex<HashMap<Vec<u8>, Arc<Mutex<()>>>>;
pub static MLS_COMMIT_LOCK: LazyLock<CommitLock> = LazyLock::new(parking_lot::Mutex::default);

pub struct SerialOpenMlsGroup<'a> {
group: &'a mut OpenMlsGroup,
_lock: OwnedMutexGuard<()>,
}

impl Deref for SerialOpenMlsGroup<'_> {
type Target = OpenMlsGroup;
fn deref(&self) -> &Self::Target {
self.group
}
}

impl DerefMut for SerialOpenMlsGroup<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.group
}
}

#[allow(unused)]
pub(crate) trait OpenMlsLock {
fn lock_blocking(&mut self) -> SerialOpenMlsGroup;
async fn lock(&mut self) -> SerialOpenMlsGroup;
}

impl OpenMlsLock for OpenMlsGroup {
#[allow(clippy::needless_lifetimes)]
async fn lock<'a>(&'a mut self) -> SerialOpenMlsGroup<'a> {
// .clone() is important here so that the outer lock gets dropped
let mutex = MLS_COMMIT_LOCK
.lock()
.entry(self.group_id().to_vec())
.or_default()
.clone();

// this may block
let lock = mutex.lock_owned().await;

SerialOpenMlsGroup {
group: self,
_lock: lock,
}
}

#[allow(clippy::needless_lifetimes)]
fn lock_blocking<'a>(&'a mut self) -> SerialOpenMlsGroup<'a> {
// .clone() is important here so that the outer lock gets dropped
let mutex = MLS_COMMIT_LOCK
.lock()
.entry(self.group_id().to_vec())
.or_default()
.clone();

// this may block
let lock = mutex.blocking_lock_owned();

SerialOpenMlsGroup {
group: self,
_lock: lock,
}
}
}
5 changes: 3 additions & 2 deletions xmtp_mls/src/groups/validated_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use super::{
group_permissions::{
extract_group_permissions, GroupMutablePermissions, GroupMutablePermissionsError,
},
serial::SerialOpenMlsGroup,
ScopedGroupClient,
};

Expand Down Expand Up @@ -214,11 +215,11 @@ pub struct ValidatedCommit {
}

impl ValidatedCommit {
pub async fn from_staged_commit(
pub async fn from_staged_commit<'a>(
client: impl ScopedGroupClient,
conn: &DbConnection,
staged_commit: &StagedCommit,
openmls_group: &OpenMlsGroup,
openmls_group: &SerialOpenMlsGroup<'a>,
) -> Result<Self, CommitValidationError> {
// Get the immutable and mutable metadata
let extensions = openmls_group.extensions();
Expand Down