diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt index 4ac6972a87..b148a10560 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt @@ -49,12 +49,12 @@ import com.wire.kalium.common.error.wrapNullableFlowStorageRequest import com.wire.kalium.common.error.wrapStorageRequest import com.wire.kalium.network.api.authenticated.conversation.AddConversationMembersRequest import com.wire.kalium.network.api.authenticated.conversation.AddServiceRequest -import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi import com.wire.kalium.network.api.authenticated.conversation.ConversationMemberAddedResponse import com.wire.kalium.network.api.authenticated.conversation.ConversationMemberRemovedResponse import com.wire.kalium.network.api.authenticated.conversation.ConversationResponse import com.wire.kalium.network.api.authenticated.conversation.model.ConversationCodeInfo import com.wire.kalium.network.api.authenticated.notification.EventContentDTO +import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi import com.wire.kalium.network.api.model.ServiceAddedResponse import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.exceptions.isConversationHasNoCode @@ -493,13 +493,8 @@ internal class ConversationGroupRepositoryImpl( is ConversationEntity.ProtocolInfo.Proteus -> deleteMemberFromCloudAndStorage(userId, conversationId) - is ConversationEntity.ProtocolInfo.Mixed -> - deleteMemberFromCloudAndStorage(userId, conversationId) - .flatMap { deleteMemberFromMlsGroup(userId, conversationId, protocol) } - - is ConversationEntity.ProtocolInfo.MLS -> { + is ConversationEntity.ProtocolInfo.MLSCapable -> deleteMemberFromMlsGroup(userId, conversationId, protocol) - } } } @@ -547,15 +542,24 @@ internal class ConversationGroupRepositoryImpl( userId: UserId, conversationId: ConversationId, protocol: ConversationEntity.ProtocolInfo.MLSCapable - ) = - if (userId == selfUserId) { + ) = when (protocol) { + is ConversationEntity.ProtocolInfo.MLS -> { + if (userId == selfUserId) { + deleteMemberFromCloudAndStorage(userId, conversationId).flatMap { + mlsConversationRepository.leaveGroup(GroupID(protocol.groupId)) + } + } else { + // when removing a member from an MLS group, don't need to call the api + mlsConversationRepository.removeMembersFromMLSGroup(GroupID(protocol.groupId), listOf(userId)) + } + } + + is ConversationEntity.ProtocolInfo.Mixed -> { deleteMemberFromCloudAndStorage(userId, conversationId).flatMap { - mlsConversationRepository.leaveGroup(GroupID(protocol.groupId)) + mlsConversationRepository.removeMembersFromMLSGroup(GroupID(protocol.groupId), listOf(userId)) } - } else { - // when removing a member from an MLS group, don't need to call the api - mlsConversationRepository.removeMembersFromMLSGroup(GroupID(protocol.groupId), listOf(userId)) } + } private suspend fun deleteMemberFromCloudAndStorage(userId: UserId, conversationId: ConversationId) = wrapApiRequest { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index b855903cbb..2b98f19c9e 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -19,9 +19,15 @@ package com.wire.kalium.logic.feature +import com.wire.kalium.common.error.CoreFailure +import com.wire.kalium.common.error.wrapStorageNullableRequest +import com.wire.kalium.common.functional.Either +import com.wire.kalium.common.functional.isRight +import com.wire.kalium.common.functional.map +import com.wire.kalium.common.functional.onSuccess +import com.wire.kalium.common.logger.kaliumLogger import com.wire.kalium.logger.KaliumLogger import com.wire.kalium.logger.obfuscateId -import com.wire.kalium.common.error.CoreFailure import com.wire.kalium.logic.GlobalKaliumScope import com.wire.kalium.logic.cache.MLSSelfConversationIdProvider import com.wire.kalium.logic.cache.MLSSelfConversationIdProviderImpl @@ -354,11 +360,6 @@ import com.wire.kalium.logic.feature.user.webSocketStatus.PersistPersistentWebSo import com.wire.kalium.logic.featureFlags.FeatureSupport import com.wire.kalium.logic.featureFlags.FeatureSupportImpl import com.wire.kalium.logic.featureFlags.KaliumConfigs -import com.wire.kalium.common.functional.Either -import com.wire.kalium.common.functional.isRight -import com.wire.kalium.common.functional.map -import com.wire.kalium.common.functional.onSuccess -import com.wire.kalium.common.logger.kaliumLogger import com.wire.kalium.logic.network.ApiMigrationManager import com.wire.kalium.logic.network.ApiMigrationV3 import com.wire.kalium.logic.network.SessionManagerImpl @@ -452,7 +453,6 @@ import com.wire.kalium.logic.sync.slow.SlowSyncWorkerImpl import com.wire.kalium.logic.sync.slow.migration.SyncMigrationStepsProvider import com.wire.kalium.logic.sync.slow.migration.SyncMigrationStepsProviderImpl import com.wire.kalium.logic.util.MessageContentEncoder -import com.wire.kalium.common.error.wrapStorageNullableRequest import com.wire.kalium.network.NetworkStateObserver import com.wire.kalium.network.networkContainer.AuthenticatedNetworkContainer import com.wire.kalium.network.session.SessionManager @@ -1448,6 +1448,8 @@ class UserSessionScope internal constructor( updateConversationClientsForCurrentCall = updateConversationClientsForCurrentCall, legalHoldHandler = legalHoldHandler, selfTeamIdProvider = selfTeamId, + mlsClientProvider = mlsClientProvider, + conversationDAO = userStorage.database.conversationDAO, selfUserId = userId ) private val memberChangeHandler: MemberChangeEventHandler diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandler.kt index 9167d1f3d1..5fa3714ded 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandler.kt @@ -19,6 +19,17 @@ package com.wire.kalium.logic.sync.receiver.conversation import com.wire.kalium.common.error.CoreFailure +import com.wire.kalium.common.error.wrapMLSRequest +import com.wire.kalium.common.error.wrapStorageRequest +import com.wire.kalium.common.functional.Either +import com.wire.kalium.common.functional.flatMap +import com.wire.kalium.common.functional.getOrElse +import com.wire.kalium.common.functional.getOrNull +import com.wire.kalium.common.functional.map +import com.wire.kalium.common.functional.onFailure +import com.wire.kalium.common.functional.onSuccess +import com.wire.kalium.common.logger.kaliumLogger +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.event.MemberLeaveReason @@ -31,16 +42,10 @@ import com.wire.kalium.logic.data.message.PersistMessageUseCase import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.feature.call.usecase.UpdateConversationClientsForCurrentCallUseCase -import com.wire.kalium.common.functional.Either -import com.wire.kalium.common.functional.flatMap -import com.wire.kalium.common.functional.getOrElse -import com.wire.kalium.common.functional.getOrNull -import com.wire.kalium.common.functional.onFailure -import com.wire.kalium.common.functional.onSuccess -import com.wire.kalium.common.logger.kaliumLogger import com.wire.kalium.logic.sync.receiver.handler.legalhold.LegalHoldHandler import com.wire.kalium.logic.util.createEventProcessingLogger -import com.wire.kalium.common.error.wrapStorageRequest +import com.wire.kalium.persistence.dao.conversation.ConversationDAO +import com.wire.kalium.persistence.dao.conversation.ConversationEntity import com.wire.kalium.persistence.dao.member.MemberDAO interface MemberLeaveEventHandler { @@ -53,6 +58,8 @@ internal class MemberLeaveEventHandlerImpl( private val userRepository: UserRepository, private val conversationRepository: ConversationRepository, private val persistMessage: PersistMessageUseCase, + private val mlsClientProvider: MLSClientProvider, + private val conversationDAO: ConversationDAO, // TODO: refactor to not have DAO here private val updateConversationClientsForCurrentCall: Lazy, private val legalHoldHandler: LegalHoldHandler, private val selfTeamIdProvider: SelfTeamIdProvider, @@ -123,17 +130,6 @@ internal class MemberLeaveEventHandlerImpl( } } - private suspend fun deleteMembers( - userIDList: List, - conversationID: ConversationId - ): Either = - wrapStorageRequest { - memberDAO.deleteMembersByQualifiedID( - userIDList.map { it.toDao() }, - conversationID.toDao() - ) - } - private suspend fun deleteConversationIfNeeded(event: Event.Conversation.MemberLeave) { val isSelfUserLeftConversation = event.removedList == listOf(selfUserId) && event.reason == MemberLeaveReason.Left if (!isSelfUserLeftConversation) return @@ -145,4 +141,33 @@ internal class MemberLeaveEventHandlerImpl( conversationRepository.deleteConversation(event.conversationId) conversationRepository.removeConversationFromDeleteQueue(event.conversationId) } + + private suspend fun deleteMembers( + userIDList: List, + conversationID: ConversationId + ): Either = + wrapStorageRequest { conversationDAO.getConversationProtocolInfo(conversationID.toDao()) } + .onSuccess { protocol -> + when (protocol) { + is ConversationEntity.ProtocolInfo.MLSCapable -> { + if (userIDList.contains(selfUserId)) { + mlsClientProvider.getMLSClient().map { mlsClient -> + wrapMLSRequest { + mlsClient.wipeConversation(protocol.groupId) + } + } + } + } + + ConversationEntity.ProtocolInfo.Proteus -> {} + } + } + .flatMap { + wrapStorageRequest { + memberDAO.deleteMembersByQualifiedID( + userIDList.map { it.toDao() }, + conversationID.toDao() + ) + } + } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandlerTest.kt index 8654cb3eb9..c7bf85be78 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandlerTest.kt @@ -18,6 +18,9 @@ package com.wire.kalium.logic.sync.receiver.conversation import com.wire.kalium.common.error.CoreFailure +import com.wire.kalium.common.functional.Either +import com.wire.kalium.cryptography.MLSClient +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.event.MemberLeaveReason import com.wire.kalium.logic.data.id.ConversationId @@ -27,7 +30,6 @@ import com.wire.kalium.logic.data.message.Message import com.wire.kalium.logic.data.message.MessageContent import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.feature.call.usecase.UpdateConversationClientsForCurrentCallUseCase -import com.wire.kalium.common.functional.Either import com.wire.kalium.logic.sync.receiver.handler.legalhold.LegalHoldHandler import com.wire.kalium.logic.util.arrangement.dao.MemberDAOArrangement import com.wire.kalium.logic.util.arrangement.dao.MemberDAOArrangementImpl @@ -40,6 +42,8 @@ import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangeme import com.wire.kalium.logic.util.arrangement.usecase.PersistMessageUseCaseArrangement import com.wire.kalium.logic.util.arrangement.usecase.PersistMessageUseCaseArrangementImpl import com.wire.kalium.persistence.dao.QualifiedIDEntity +import com.wire.kalium.persistence.dao.conversation.ConversationDAO +import com.wire.kalium.persistence.dao.conversation.ConversationEntity import com.wire.kalium.util.time.UNIX_FIRST_DATE import io.mockative.Mock import io.mockative.any @@ -71,6 +75,7 @@ class MemberLeaveEventHandlerTest { conversationId = EqualsMatcher(event.conversationId.toDao()), memberIdList = EqualsMatcher(list) ) + withGetConversationProtocolInfoReturns(ConversationEntity.ProtocolInfo.Proteus) } memberLeaveEventHandler.handle(event) @@ -101,6 +106,7 @@ class MemberLeaveEventHandlerTest { ) withPersistingMessage(Either.Left(failure)) withDeleteMembersByQualifiedIDThrows(throws = IllegalArgumentException()) + withGetConversationProtocolInfoReturns(ConversationEntity.ProtocolInfo.Proteus) } memberLeaveEventHandler.handle(event) @@ -131,6 +137,7 @@ class MemberLeaveEventHandlerTest { withPersistingMessage(Either.Right(Unit), messageMatcher = EqualsMatcher(message)) withTeamId(Either.Right(TeamId("teamId"))) withIsAtLeastOneUserATeamMember(Either.Right(true)) + withGetConversationProtocolInfoReturns(ConversationEntity.ProtocolInfo.Proteus) } memberLeaveEventHandler.handle(event) @@ -162,6 +169,7 @@ class MemberLeaveEventHandlerTest { withFetchUsersIfUnknownByIdsReturning(Either.Right(Unit), userIdList = EqualsMatcher(event.removedList.toSet())) withTeamId(Either.Right(null)) withPersistingMessage(Either.Right(Unit)) + withGetConversationProtocolInfoReturns(ConversationEntity.ProtocolInfo.Proteus) } memberLeaveEventHandler.handle(event) @@ -206,6 +214,7 @@ class MemberLeaveEventHandlerTest { memberIdList = EqualsMatcher(list) ) withIsAtLeastOneUserATeamMember(Either.Right(false)) + withGetConversationProtocolInfoReturns(ConversationEntity.ProtocolInfo.Proteus) } memberLeaveEventHandler.handle(event) @@ -244,6 +253,7 @@ class MemberLeaveEventHandlerTest { conversationId = EqualsMatcher(event.conversationId.toDao()), memberIdList = EqualsMatcher(list) ) + withGetConversationProtocolInfoReturns(ConversationEntity.ProtocolInfo.Proteus) } // when memberLeaveEventHandler.handle(event) @@ -269,6 +279,7 @@ class MemberLeaveEventHandlerTest { withPersistingMessage(Either.Right(Unit)) withGetConversationsDeleteQueue(listOf(event.conversationId)) withDeletingConversationSucceeding(EqualsMatcher(event.conversationId)) + withGetConversationProtocolInfoReturns(ConversationEntity.ProtocolInfo.Proteus) } memberLeaveEventHandler.handle(event) @@ -281,6 +292,60 @@ class MemberLeaveEventHandlerTest { coVerify { arrangement.conversationRepository.removeConversationFromDeleteQueue(event.conversationId) }.wasInvoked(once) } + @Test + fun givenUserLeavesMLSGroup_whenHandlingMemberLeaveEvent_thenMLSClientShouldWipeConversation() = runTest { + val event = memberLeaveEvent(reason = MemberLeaveReason.Left).copy( + conversationId = conversationId, + removedList = listOf(selfUserId), removedBy = selfUserId + ) + + val (arrangement, memberLeaveEventHandler) = Arrangement() + .arrange { + withDeleteMembersByQualifiedID( + result = event.removedList.size.toLong(), + conversationId = EqualsMatcher(event.conversationId.toDao()), + memberIdList = EqualsMatcher(event.removedList.map { QualifiedIDEntity(it.value, it.domain) }) + ) + withFetchUsersIfUnknownByIdsReturning(Either.Right(Unit), userIdList = EqualsMatcher(event.removedList.toSet())) + withTeamId(Either.Right(null)) + withPersistingMessage(Either.Right(Unit)) + withGetConversationsDeleteQueue(listOf(event.conversationId)) + withDeletingConversationSucceeding(EqualsMatcher(event.conversationId)) + withGetConversationProtocolInfoReturns(mlsProtocolInfo1) + } + + memberLeaveEventHandler.handle(event) + + coVerify { arrangement.updateConversationClientsForCurrentCall.invoke(eq(event.conversationId)) }.wasInvoked(exactly = once) + coVerify { arrangement.mlsClient.wipeConversation(any()) }.wasInvoked(once) + } + + @Test + fun givenOtherUsersRemainInMLSGroup_whenHandlingMemberLeaveEvent_thenDoNotWipeConversation() = runTest { + val event = memberLeaveEvent(reason = MemberLeaveReason.Removed).copy( + removedList = listOf(UserId("userId1", "domain"), UserId("userId2", "domain")) + ) + + val (arrangement, memberLeaveEventHandler) = Arrangement() + .arrange { + withFetchUsersIfUnknownByIdsReturning(Either.Right(Unit), userIdList = EqualsMatcher(event.removedList.toSet())) + withPersistingMessage(Either.Right(Unit)) + withTeamId(Either.Right(null)) + withDeleteMembersByQualifiedID( + result = list.size.toLong(), + conversationId = EqualsMatcher(event.conversationId.toDao()), + memberIdList = EqualsMatcher(list) + ) + withGetConversationsDeleteQueue(listOf(event.conversationId)) + withDeletingConversationSucceeding(EqualsMatcher(event.conversationId)) + withGetConversationProtocolInfoReturns(mlsProtocolInfo1) + } + + memberLeaveEventHandler.handle(event) + + coVerify { arrangement.mlsClient.wipeConversation(any()) }.wasNotInvoked() + } + private class Arrangement : UserRepositoryArrangement by UserRepositoryArrangementImpl(), PersistMessageUseCaseArrangement by PersistMessageUseCaseArrangementImpl(), @@ -294,13 +359,37 @@ class MemberLeaveEventHandlerTest { @Mock val legalHoldHandler = mock(LegalHoldHandler::class) + @Mock + val mlsClientProvider = mock(MLSClientProvider::class) + + @Mock + val conversationDAO = mock(ConversationDAO::class) + + @Mock + val mlsClient = mock(MLSClient::class) + private lateinit var memberLeaveEventHandler: MemberLeaveEventHandler + suspend fun withGetConversationProtocolInfoReturns(protocolInfo: ConversationEntity.ProtocolInfo) = apply { + coEvery { + conversationDAO.getConversationProtocolInfo(any()) + }.returns(protocolInfo) + } + suspend fun arrange(block: suspend Arrangement.() -> Unit): Pair = run { coEvery { legalHoldHandler.handleConversationMembersChanged(any()) }.returns(Either.Right(Unit)) withRemoveConversationToDeleteQueue() + coEvery { + mlsClientProvider.getMLSClient(any()) + }.returns(Either.Right(mlsClient)) + coEvery { + mlsClient.wipeConversation(any()) + }.returns(Unit) + coEvery { + updateConversationClientsForCurrentCall.invoke(any()) + }.returns(Unit) block() memberLeaveEventHandler = MemberLeaveEventHandlerImpl( memberDAO = memberDAO, @@ -310,7 +399,9 @@ class MemberLeaveEventHandlerTest { updateConversationClientsForCurrentCall = lazy { updateConversationClientsForCurrentCall }, legalHoldHandler = legalHoldHandler, selfTeamIdProvider = selfTeamIdProvider, - selfUserId = selfUserId + selfUserId = selfUserId, + mlsClientProvider = mlsClientProvider, + conversationDAO = conversationDAO, ) this to memberLeaveEventHandler } @@ -326,6 +417,14 @@ class MemberLeaveEventHandlerTest { val conversationId = ConversationId("conversationId", "domain") val list = listOf(qualifiedUserIdEntity) + val mlsProtocolInfo1 = ConversationEntity.ProtocolInfo.MLS( + "group2", + ConversationEntity.GroupState.ESTABLISHED, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ) + fun memberLeaveEvent(reason: MemberLeaveReason) = Event.Conversation.MemberLeave( id = "id", conversationId = conversationId,