Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: mls 1on1 race condition [WPB-15395] #3237

Merged
merged 6 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
if (!featureSupport.isMLSSupported ||
!clientRepository.hasRegisteredMLSClient().getOrElse(false)
) {
kaliumLogger.d("Skip re-join existing MLS conversation, since MLS is not supported.")
kaliumLogger.d("$TAG: Skip re-join existing MLS conversation, since MLS is not supported.")
Either.Right(Unit)
} else {
conversationRepository.getConversationById(conversationId).fold({
Expand Down Expand Up @@ -115,7 +115,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}
}
} else if (failure.kaliumException.isMlsMissingGroupInfo()) {
kaliumLogger.w("conversation has no group info, ignoring...")
kaliumLogger.w("$TAG: conversation has no group info, ignoring...")
Either.Right(Unit)
} else {
Either.Left(failure)
Expand All @@ -135,6 +135,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
protocol.epoch != 0UL -> {
// TODO(refactor): don't use conversationAPI directly
// we could use mlsConversationRepository to solve this
kaliumLogger.d("$TAG: Joining group by external commit ${conversation.id.toLogString()}")
wrapApiRequest {
conversationApi.fetchGroupInfo(conversation.id.toApi())
}.flatMap { groupInfo ->
Expand Down Expand Up @@ -185,6 +186,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}

type == Conversation.Type.SELF -> {
kaliumLogger.d("$TAG: Establish Self MLS Conversation ${conversation.id.toLogString()}")
mlsConversationRepository.establishMLSGroup(
protocol.groupId,
emptyList()
Expand All @@ -203,6 +205,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}

type == Conversation.Type.ONE_ON_ONE -> {
kaliumLogger.d("$TAG: Establish 1on1 MLS Conversation ${conversation.id.toLogString()}")
conversationRepository.getConversationMembers(conversation.id).flatMap { members ->
mlsConversationRepository.establishMLSGroup(
protocol.groupId,
Expand All @@ -226,4 +229,8 @@ internal class JoinExistingMLSConversationUseCaseImpl(
else -> Either.Right(Unit)
}
}

companion object {
private const val TAG = "[JoinExistingMLSConversationUseCase]"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ internal class MLSConversationDataSource(
private suspend fun sendCommitBundle(groupID: GroupID, bundle: CommitBundle): Either<CoreFailure, Unit> {
return mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapApiRequest {
kaliumLogger.d("Sending commit bundle for ${groupID.toLogString()}")
mlsMessageApi.sendCommitBundle(mlsCommitBundleMapper.toDTO(bundle))
}.flatMap { response ->
processCommitBundleEvents(response.events)
Expand Down Expand Up @@ -376,6 +377,7 @@ internal class MLSConversationDataSource(
}

private suspend fun processCommitBundleEvents(events: List<EventContentDTO>) {
kaliumLogger.d("Processing commit bundle events")
events.forEach { eventContentDTO ->
val event =
MapperProvider.eventMapper(selfUserId).fromEventContentDTO(
Expand Down Expand Up @@ -454,7 +456,8 @@ internal class MLSConversationDataSource(
retryOnStaleMessage = true,
allowPartialMemberList = false,
cipherSuite = cipherSuite
).map { Unit }
)
.map { Unit }

private suspend fun internalAddMemberToMLSGroup(
groupID: GroupID,
Expand All @@ -464,7 +467,7 @@ internal class MLSConversationDataSource(
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
commitPendingProposals(groupID).flatMap {
kaliumLogger.d("adding ${userIdList.count()} users to MLS group")
kaliumLogger.d("adding ${userIdList.count()} users to MLS group ${groupID.toLogString()}")
produceAndSendCommitWithRetryAndResult(groupID, retryOnStaleMessage = retryOnStaleMessage) {
keyPackageRepository.claimKeyPackages(userIdList, cipherSuite).flatMap { result ->
if (result.usersWithoutKeyPackagesAvailable.isNotEmpty() && !allowPartialMemberList) {
Expand All @@ -485,12 +488,15 @@ internal class MLSConversationDataSource(
// We are creating a group with only our self client which technically
// doesn't need be added with a commit, but our backend API requires one,
// so we create a commit by updating our key material.
kaliumLogger.d("add members to MLS Group: updating keying material for self client")
updateKeyingMaterial(idMapper.toCryptoModel(groupID))
} else {
kaliumLogger.d("add members to MLS Group: executing for groupID ${groupID.toLogString()}")
addMember(idMapper.toCryptoModel(groupID), clientKeyPackageList)
}
}.onSuccess { commitBundle ->
commitBundle?.crlNewDistributionPoints?.let { revocationList ->
kaliumLogger.d("add members to MLS Group: checking revocation list")
checkRevocationList(revocationList)
}
}.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,7 @@ class UserSessionScope internal constructor(
clientIdProvider,
messages.messageSender,
teamRepository,
slowSyncRepository,
userId,
selfConversationIdProvider,
persistMessage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import com.wire.kalium.logic.data.id.QualifiedIdMapper
import com.wire.kalium.logic.data.id.SelfTeamIdProvider
import com.wire.kalium.logic.data.message.PersistMessageUseCase
import com.wire.kalium.logic.data.properties.UserPropertyRepository
import com.wire.kalium.logic.data.sync.SlowSyncRepository
import com.wire.kalium.logic.data.team.TeamRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
Expand Down Expand Up @@ -96,6 +97,7 @@ class ConversationScope internal constructor(
private val currentClientIdProvider: CurrentClientIdProvider,
private val messageSender: MessageSender,
private val teamRepository: TeamRepository,
private val slowSyncRepository: SlowSyncRepository,
private val selfUserId: UserId,
private val selfConversationIdProvider: SelfConversationIdProvider,
private val persistMessage: PersistMessageUseCase,
Expand Down Expand Up @@ -151,6 +153,7 @@ class ConversationScope internal constructor(
oneOnOneResolver,
conversationRepository,
deleteEphemeralMessageEndDate,
slowSyncRepository,
kaliumLogger
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,18 @@ import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.data.conversation.ConversationDetails
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.sync.SlowSyncRepository
import com.wire.kalium.logic.data.sync.SlowSyncStatus
import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessagesAfterEndDateUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.util.KaliumDispatcher
import com.wire.kalium.util.KaliumDispatcherImpl
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

/**
* Used by the UI to notify Kalium that a conversation is open.
Expand All @@ -45,12 +51,26 @@ internal class NotifyConversationIsOpenUseCaseImpl(
private val oneOnOneResolver: OneOnOneResolver,
private val conversationRepository: ConversationRepository,
private val deleteEphemeralMessageEndDate: DeleteEphemeralMessagesAfterEndDateUseCase,
private val kaliumLogger: KaliumLogger
private val slowSyncRepository: SlowSyncRepository,
private val kaliumLogger: KaliumLogger,
private val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) : NotifyConversationIsOpenUseCase {

override suspend operator fun invoke(conversationId: ConversationId) {
override suspend operator fun invoke(conversationId: ConversationId) = withContext(dispatcher.io) {
val ephemeralCleanupJob = launch {
kaliumLogger.v("$TAG: Starting ephemeral messages deletion in background")
deleteEphemeralMessageEndDate()
}

val slowSyncStatus = slowSyncRepository.slowSyncStatus.first()

if (slowSyncStatus != SlowSyncStatus.Complete) {
kaliumLogger.v("$TAG: Slow sync is not completed yet, skipping further steps")
return@withContext
}

kaliumLogger.v(
"Notifying that conversation with ID: ${conversationId.toLogString()} is open"
"$TAG: Notifying that conversation with ID: ${conversationId.toLogString()} is open"
)
val conversation = conversationRepository.observeConversationDetailsById(conversationId)
.filterIsInstance<Either.Right<ConversationDetails>>()
Expand All @@ -59,15 +79,18 @@ internal class NotifyConversationIsOpenUseCaseImpl(

if (conversation is ConversationDetails.OneOne) {
kaliumLogger.v(
"Reevaluating protocol for 1:1 conversation with ID: ${conversationId.toLogString()}"
"$TAG: Reevaluating protocol for 1:1 conversation with ID: ${conversationId.toLogString()}"
)
oneOnOneResolver.resolveOneOnOneConversationWithUser(
user = conversation.otherUser,
invalidateCurrentKnownProtocols = true
)
}

// Delete Ephemeral Messages that has passed the end date
deleteEphemeralMessageEndDate()
ephemeralCleanupJob.join()
}

companion object {
private const val TAG = "[NotifyConversationIsOpenUseCase]"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

package com.wire.kalium.logic.sync.receiver.conversation

import com.wire.kalium.logger.obfuscateId
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationDetails
Expand All @@ -33,6 +35,7 @@ import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesResult
import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
Expand Down Expand Up @@ -61,22 +64,33 @@ internal class MLSWelcomeEventHandlerImpl(
mlsClientProvider.getMLSClient()
}
.flatMap { client ->
kaliumLogger.d("$TAG: Processing MLS welcome message")
wrapMLSRequest {
client.processWelcomeMessage(event.message.decodeBase64Bytes())
}
}.flatMap { welcomeBundle ->
welcomeBundle.crlNewDistributionPoints?.let {
kaliumLogger.d("$TAG: checking revocation list")
checkRevocationList(it)
}
kaliumLogger.d("$TAG: Marking conversation as established ${welcomeBundle.groupId.obfuscateId()}")
markConversationAsEstablished(GroupID(welcomeBundle.groupId))
}.flatMap {
kaliumLogger.d("$TAG: Resolving conversation if one-on-one ${event.conversationId.toLogString()}")
resolveConversationIfOneOnOne(event.conversationId)
}
.flatMapLeft {
if (it is MLSFailure.ConversationAlreadyExists) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add warning here then? "Discarding welcome since the conversation already exists "

Either.Right(Unit)
} else {
Either.Left(it)
}
}
.onSuccess {
val didSucceedRefillingKeyPackages = when (val refillResult = refillKeyPackages()) {
is RefillKeyPackagesResult.Failure -> {
val exception = (refillResult.failure as? CoreFailure.Unknown)?.rootCause
kaliumLogger.w("Failed to refill key packages; Failure: ${refillResult.failure}", exception)
kaliumLogger.w("$TAG: Failed to refill key packages; Failure: ${refillResult.failure}", exception)
false
}

Expand Down Expand Up @@ -119,4 +133,8 @@ internal class MLSWelcomeEventHandlerImpl(
}
}

companion object {
private const val TAG = "[MLSWelcomeEventHandler]"
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package com.wire.kalium.logic.feature.conversation

import com.wire.kalium.logic.data.sync.SlowSyncRepository
import com.wire.kalium.logic.data.sync.SlowSyncStatus
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessagesAfterEndDateUseCase
import com.wire.kalium.logic.framework.TestConversationDetails
import com.wire.kalium.logic.functional.Either
Expand All @@ -30,8 +32,11 @@ import io.mockative.any
import io.mockative.coEvery
import io.mockative.coVerify
import io.mockative.eq
import io.mockative.every
import io.mockative.mock
import io.mockative.once
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.runTest
import kotlin.test.Test

Expand Down Expand Up @@ -102,19 +107,29 @@ class NotifyConversationIsOpenUseCaseTest {
@Mock
private val deleteEphemeralMessageEndDate = mock(DeleteEphemeralMessagesAfterEndDateUseCase::class)

@Mock
private val slowSyncRepository = mock(SlowSyncRepository::class)

suspend fun withDeleteEphemeralMessageEndDateSuccess() {
coEvery {
deleteEphemeralMessageEndDate.invoke()
}.returns(Unit)
}

init {
every {
slowSyncRepository.slowSyncStatus
}.returns(MutableStateFlow(SlowSyncStatus.Complete))
}

suspend fun arrange(): Pair<Arrangement, NotifyConversationIsOpenUseCase> = run {
configure()
this@Arrangement to NotifyConversationIsOpenUseCaseImpl(
oneOnOneResolver = oneOnOneResolver,
conversationRepository = conversationRepository,
kaliumLogger = kaliumLogger,
deleteEphemeralMessageEndDate = deleteEphemeralMessageEndDate
deleteEphemeralMessageEndDate = deleteEphemeralMessageEndDate,
slowSyncRepository = slowSyncRepository
)
}
}
Expand Down
Loading