diff --git a/data/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientModel.kt b/data/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientModel.kt index ec562bd9319..e76ffa6d642 100644 --- a/data/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientModel.kt +++ b/data/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientModel.kt @@ -71,6 +71,7 @@ enum class DeviceType { sealed class ClientCapability { data object LegalHoldImplicitConsent : ClientCapability() + data object ConsumableNotifications : ClientCapability() data class Unknown(val name: String) : ClientCapability() } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientMapper.kt index f53feb17f5b..01b7f7ecdc6 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientMapper.kt @@ -198,11 +198,13 @@ class ClientMapper( private fun toClientCapabilityDTO(clientCapability: ClientCapability): ClientCapabilityDTO = when (clientCapability) { ClientCapability.LegalHoldImplicitConsent -> ClientCapabilityDTO.LegalHoldImplicitConsent + ClientCapability.ConsumableNotifications -> ClientCapabilityDTO.ConsumableNotifications is ClientCapability.Unknown -> ClientCapabilityDTO.Unknown(clientCapability.name) } private fun fromClientCapabilityDTO(clientCapabilityDTO: ClientCapabilityDTO): ClientCapability = when (clientCapabilityDTO) { ClientCapabilityDTO.LegalHoldImplicitConsent -> ClientCapability.LegalHoldImplicitConsent + ClientCapabilityDTO.ConsumableNotifications -> ClientCapability.ConsumableNotifications is ClientCapabilityDTO.Unknown -> ClientCapability.Unknown(clientCapabilityDTO.name) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt index d3ecbc13029..9578ebe8676 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt @@ -63,7 +63,7 @@ interface KeyPackageRepository { suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List, cipherSuite: CipherSuite): Either - suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either + suspend fun getAvailableKeyPackageCount(clientId: ClientId, cipherSuite: CipherSuite): Either suspend fun validKeyPackageCount(clientId: ClientId): Either } @@ -138,8 +138,11 @@ class KeyPackageDataSource( } } - override suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either = + override suspend fun getAvailableKeyPackageCount( + clientId: ClientId, + cipherSuite: CipherSuite + ): Either = wrapApiRequest { - keyPackageApi.getAvailableKeyPackageCount(clientId.value) + keyPackageApi.getAvailableKeyPackageCount(clientId.value, cipherSuite.tag) } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt index c2e37c14161..c98db95a4d7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt @@ -104,14 +104,21 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor( val deregisterNativePushToken: DeregisterTokenUseCase get() = DeregisterTokenUseCaseImpl(clientRepository, notificationTokenRepository) val mlsKeyPackageCountUseCase: MLSKeyPackageCountUseCase - get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider, userConfigRepository) + get() = MLSKeyPackageCountUseCaseImpl( + keyPackageRepository = keyPackageRepository, + currentClientIdProvider = clientIdProvider, + mlsClientProvider = mlsClientProvider, + keyPackageLimitsProvider = keyPackageLimitsProvider, + userConfigRepository = userConfigRepository, + ) val restartSlowSyncProcessForRecoveryUseCase: RestartSlowSyncProcessForRecoveryUseCase get() = RestartSlowSyncProcessForRecoveryUseCaseImpl(slowSyncRepository) val refillKeyPackages: RefillKeyPackagesUseCase get() = RefillKeyPackagesUseCaseImpl( - keyPackageRepository, - keyPackageLimitsProvider, - clientIdProvider + keyPackageRepository = keyPackageRepository, + keyPackageLimitsProvider = keyPackageLimitsProvider, + mlsClientProvider = mlsClientProvider, + currentClientIdProvider = clientIdProvider, ) val observeCurrentClientId: ObserveCurrentClientIdUseCase diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt index ab4528d6f9a..dec34dd2d44 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt @@ -21,12 +21,15 @@ package com.wire.kalium.logic.feature.keypackage import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider import com.wire.kalium.logic.data.keypackage.KeyPackageRepository import com.wire.kalium.logic.data.id.CurrentClientIdProvider +import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.functional.fold import com.wire.kalium.logic.functional.getOrElse +import com.wire.kalium.logic.functional.map /** * This use case will return the current number of key packages. @@ -38,6 +41,7 @@ interface MLSKeyPackageCountUseCase { internal class MLSKeyPackageCountUseCaseImpl( private val keyPackageRepository: KeyPackageRepository, private val currentClientIdProvider: CurrentClientIdProvider, + private val mlsClientProvider: MLSClientProvider, private val keyPackageLimitsProvider: KeyPackageLimitsProvider, private val userConfigRepository: UserConfigRepository ) : MLSKeyPackageCountUseCase { @@ -47,19 +51,24 @@ internal class MLSKeyPackageCountUseCaseImpl( false -> validKeyPackagesCountFromMLSClient() } - private suspend fun validKeyPackagesCountFromAPI() = currentClientIdProvider().fold({ - MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it) - }, { selfClient -> - if (userConfigRepository.isMLSEnabled().getOrElse(false)) { - keyPackageRepository.getAvailableKeyPackageCount(selfClient) - .fold( - { MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) }, - { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) } - ) - } else { - MLSKeyPackageCountResult.Failure.NotEnabled + @Suppress("ReturnCount") + private suspend fun validKeyPackagesCountFromAPI(): MLSKeyPackageCountResult { + val selfClientId = currentClientIdProvider().getOrElse { + return MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it) } - }) + + if (!userConfigRepository.isMLSEnabled().getOrElse(false)) { + return MLSKeyPackageCountResult.Failure.NotEnabled + } + + val cipherSuite = mlsClientProvider.getMLSClient().map { CipherSuite.fromTag(it.getDefaultCipherSuite()) } + .getOrElse { return MLSKeyPackageCountResult.Failure.Generic(it) } + + return keyPackageRepository.getAvailableKeyPackageCount(selfClientId, cipherSuite).fold( + { MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) }, + { MLSKeyPackageCountResult.Success(selfClientId, it.count, keyPackageLimitsProvider.needsRefill(it.count)) } + ) + } private suspend fun validKeyPackagesCountFromMLSClient() = currentClientIdProvider().fold({ diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/RefillKeyPackageUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/RefillKeyPackageUseCase.kt index 10a17acccdd..51bc5806d46 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/RefillKeyPackageUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/RefillKeyPackageUseCase.kt @@ -19,12 +19,16 @@ package com.wire.kalium.logic.feature.keypackage import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider import com.wire.kalium.logic.data.keypackage.KeyPackageRepository import com.wire.kalium.logic.data.id.CurrentClientIdProvider +import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.getOrElse +import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.kaliumLogger sealed class RefillKeyPackagesResult { @@ -47,30 +51,32 @@ interface RefillKeyPackagesUseCase { internal class RefillKeyPackagesUseCaseImpl( private val keyPackageRepository: KeyPackageRepository, private val keyPackageLimitsProvider: KeyPackageLimitsProvider, + private val mlsClientProvider: MLSClientProvider, private val currentClientIdProvider: CurrentClientIdProvider, ) : RefillKeyPackagesUseCase { - override suspend operator fun invoke(): RefillKeyPackagesResult = - currentClientIdProvider().flatMap { selfClientId -> - // TODO: Maybe use MLSKeyPackageCountUseCase instead of repository directly, - // and fetch from local instead of remote - keyPackageRepository.getAvailableKeyPackageCount(selfClientId) - .flatMap { - kaliumLogger.i("Key packages: Found ${it.count} available key packages") - if (keyPackageLimitsProvider.needsRefill(it.count)) { - kaliumLogger.i("Key packages: Refilling key packages...") - val amount = keyPackageLimitsProvider.refillAmount() - keyPackageRepository.uploadNewKeyPackages(selfClientId, amount).flatMap { - Either.Right(Unit) - } - } else { - kaliumLogger.i("Key packages: Refill not needed") - Either.Right(Unit) - } + override suspend operator fun invoke(): RefillKeyPackagesResult { + val selfClientId = currentClientIdProvider().getOrElse { + return RefillKeyPackagesResult.Failure(it) + } + + return mlsClientProvider.getMLSClient().map { CipherSuite.fromTag(it.getDefaultCipherSuite()) }.flatMap { cipherSuite -> + keyPackageRepository.getAvailableKeyPackageCount(selfClientId, cipherSuite) + }.flatMap { + kaliumLogger.i("Key packages: Found ${it.count} available key packages") + if (keyPackageLimitsProvider.needsRefill(it.count)) { + kaliumLogger.i("Key packages: Refilling key packages...") + val amount = keyPackageLimitsProvider.refillAmount() + keyPackageRepository.uploadNewKeyPackages(selfClientId, amount).flatMap { + Either.Right(Unit) } + } else { + kaliumLogger.i("Key packages: Refill not needed") + Either.Right(Unit) + } }.fold({ failure -> RefillKeyPackagesResult.Failure(failure) }, { RefillKeyPackagesResult.Success }) - + } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepositoryTest.kt index 1d252a0b6b1..2b64e7a29bb 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepositoryTest.kt @@ -70,13 +70,13 @@ class KeyPackageRepositoryTest { @Test fun givenExistingClient_whenGettingAvailableKeyPackageCount_thenResultShouldBePropagated() = runTest { - + val cipherSuite = CipherSuite.fromTag(1) val (_, keyPackageRepository) = Arrangement() .withMLSClient() .withGetAvailableKeyPackageCountSuccessful() .arrange() - val keyPackageCount = keyPackageRepository.getAvailableKeyPackageCount(Arrangement.SELF_CLIENT_ID) + val keyPackageCount = keyPackageRepository.getAvailableKeyPackageCount(Arrangement.SELF_CLIENT_ID, cipherSuite) assertIs>(keyPackageCount) assertEquals(Arrangement.KEY_PACKAGE_COUNT_DTO.count, keyPackageCount.value.count) @@ -210,7 +210,7 @@ class KeyPackageRepositoryTest { suspend fun withGetAvailableKeyPackageCountSuccessful() = apply { coEvery { - keyPackageApi.getAvailableKeyPackageCount(eq(SELF_CLIENT_ID.value)) + keyPackageApi.getAvailableKeyPackageCount(eq(SELF_CLIENT_ID.value), any()) }.returns(NetworkResponse.Success(KEY_PACKAGE_COUNT_DTO, mapOf(), 200)) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt index 5c8afbeb180..d4781d22bdd 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt @@ -18,12 +18,15 @@ package com.wire.kalium.logic.feature.keypackage +import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.id.CurrentClientIdProvider import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider import com.wire.kalium.logic.data.keypackage.KeyPackageRepository +import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Arrangement.Companion.CLIENT_FETCH_ERROR import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Arrangement.Companion.KEY_PACKAGE_COUNT import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Arrangement.Companion.KEY_PACKAGE_COUNT_DTO @@ -53,16 +56,19 @@ class MLSKeyPackageCountUseCaseTest { @Test fun givenClientIdIsNotRegistered_ThenReturnGenericError() = runTest { + val expectedCipherSuite = CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519 + val (arrangement, keyPackageCountUseCase) = Arrangement() .withClientId(Either.Left(CLIENT_FETCH_ERROR)) - .arrange{ + .withDefaultCipherSuite(expectedCipherSuite) + .arrange { withGetMLSEnabledReturning(true.right()) } val actual = keyPackageCountUseCase() coVerify { - arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID)) + arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID), eq(expectedCipherSuite)) }.wasNotInvoked() assertIs(actual) @@ -71,18 +77,21 @@ class MLSKeyPackageCountUseCaseTest { @Test fun givenClientId_whenCallingKeyPackageCountReturnValue_ThenReturnKeyPackageCountSuccess() = runTest { + val expectedCipherSuite = CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519 + val (arrangement, keyPackageCountUseCase) = Arrangement() .withAvailableKeyPackageCountReturn(Either.Right(KEY_PACKAGE_COUNT_DTO)) .withClientId(Either.Right(TestClient.CLIENT_ID)) + .withDefaultCipherSuite(expectedCipherSuite) .withKeyPackageLimitSucceed() - .arrange{ + .arrange { withGetMLSEnabledReturning(true.right()) } val actual = keyPackageCountUseCase() coVerify { - arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID)) + arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID), eq(expectedCipherSuite)) }.wasInvoked(once) assertIs(actual) assertEquals(actual, MLSKeyPackageCountResult.Success(TestClient.CLIENT_ID, KEY_PACKAGE_COUNT, true)) @@ -90,17 +99,20 @@ class MLSKeyPackageCountUseCaseTest { @Test fun givenClientID_whenCallingKeyPackageCountReturnError_ThenReturnKeyPackageCountFailure() = runTest { + val expectedCipherSuite = CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519 + val (arrangement, keyPackageCountUseCase) = Arrangement() .withAvailableKeyPackageCountReturn(Either.Left(NETWORK_FAILURE)) .withClientId(Either.Right(TestClient.CLIENT_ID)) - .arrange{ + .withDefaultCipherSuite(expectedCipherSuite) + .arrange { withGetMLSEnabledReturning(true.right()) } val actual = keyPackageCountUseCase() coVerify { - arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID)) + arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID), eq(expectedCipherSuite)) }.wasInvoked(once) assertIs(actual) assertEquals(actual.networkFailure, NETWORK_FAILURE) @@ -108,10 +120,12 @@ class MLSKeyPackageCountUseCaseTest { @Test fun givenClientID_whenCallingGetMLSEnabledReturnFalse_ThenReturnKeyPackageCountNotEnabledFailure() = runTest { + val expectedCipherSuite = CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519 val (arrangement, keyPackageCountUseCase) = Arrangement() .withAvailableKeyPackageCountReturn(Either.Right(KEY_PACKAGE_COUNT_DTO)) .withClientId(Either.Right(TestClient.CLIENT_ID)) - .arrange{ + .withDefaultCipherSuite(expectedCipherSuite) + .arrange { withGetMLSEnabledReturning(false.right()) } @@ -122,7 +136,7 @@ class MLSKeyPackageCountUseCaseTest { }.wasInvoked(once) coVerify { - arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID)) + arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID), eq(expectedCipherSuite)) }.wasNotInvoked() assertIs(actual) } @@ -137,6 +151,12 @@ class MLSKeyPackageCountUseCaseTest { @Mock val keyPackageLimitsProvider = mock(KeyPackageLimitsProvider::class) + @Mock + val mlsClientProvider = mock(MLSClientProvider::class) + + @Mock + val mlsClient = mock(MLSClient::class) + suspend fun withClientId(result: Either) = apply { coEvery { currentClientIdProvider.invoke() @@ -151,13 +171,27 @@ class MLSKeyPackageCountUseCaseTest { suspend fun withAvailableKeyPackageCountReturn(result: Either) = apply { coEvery { - keyPackageRepository.getAvailableKeyPackageCount(any()) + keyPackageRepository.getAvailableKeyPackageCount(any(), any()) }.returns(result) } - fun arrange(block: suspend Arrangement.() -> Unit) = apply { runBlocking { block() } }.let { + fun withDefaultCipherSuite(cipherSuite: CipherSuite) = apply { + every { + mlsClient.getDefaultCipherSuite() + }.returns(cipherSuite.tag.toUShort()) + } + + suspend fun arrange(block: suspend Arrangement.() -> Unit) = apply { + coEvery { mlsClientProvider.getMLSClient() }.returns(mlsClient.right()) + }.apply { + runBlocking { block() } + }.let { this to MLSKeyPackageCountUseCaseImpl( - keyPackageRepository, currentClientIdProvider, keyPackageLimitsProvider, userConfigRepository + keyPackageRepository = keyPackageRepository, + currentClientIdProvider = currentClientIdProvider, + keyPackageLimitsProvider = keyPackageLimitsProvider, + userConfigRepository = userConfigRepository, + mlsClientProvider = mlsClientProvider, ) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/RefillKeyPackageUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/RefillKeyPackageUseCaseTest.kt index f078e0ab03f..ef2ff15768a 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/RefillKeyPackageUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/RefillKeyPackageUseCaseTest.kt @@ -18,19 +18,29 @@ package com.wire.kalium.logic.feature.keypackage +import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.data.client.MLSClientProvider +import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.id.CurrentClientIdProvider import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider import com.wire.kalium.logic.data.keypackage.KeyPackageRepository +import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.right import com.wire.kalium.network.api.authenticated.keypackage.KeyPackageCountDTO +import io.mockative.Matchers import io.mockative.Mock import io.mockative.any import io.mockative.coEvery import io.mockative.coVerify import io.mockative.eq import io.mockative.every +import io.mockative.fake.valueOf +import io.mockative.matchers.AnyMatcher +import io.mockative.matchers.Matcher +import io.mockative.matches import io.mockative.mock import io.mockative.once import kotlinx.coroutines.ExperimentalCoroutinesApi @@ -39,7 +49,6 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertIs -@OptIn(ExperimentalCoroutinesApi::class) class RefillKeyPackageUseCaseTest { @Test @@ -51,6 +60,7 @@ class RefillKeyPackageUseCaseTest { .withKeyPackageLimits(true, Arrangement.KEY_PACKAGE_LIMIT - keyPackageCount) .withKeyPackageCount(keyPackageCount) .withUploadKeyPackagesSuccessful() + .withDefaultCipherSuite(CipherSuite.fromTag(1)) .arrange() val actual = refillKeyPackagesUseCase() @@ -70,6 +80,7 @@ class RefillKeyPackageUseCaseTest { .withExistingSelfClientId() .withKeyPackageLimits(false, 0) .withKeyPackageCount(keyPackageCount) + .withDefaultCipherSuite(CipherSuite.fromTag(1)) .arrange() val actual = refillKeyPackagesUseCase() @@ -85,6 +96,7 @@ class RefillKeyPackageUseCaseTest { .withExistingSelfClientId() .withKeyPackageLimits(true, 0) .withGetAvailableKeyPackagesFailing(networkFailure) + .withDefaultCipherSuite(CipherSuite.fromTag(1)) .arrange() val actual = refillKeyPackagesUseCase() @@ -103,12 +115,25 @@ class RefillKeyPackageUseCaseTest { @Mock val currentClientIdProvider = mock(CurrentClientIdProvider::class) + @Mock + val mlsClientProvider = mock(MLSClientProvider::class) + + @Mock + val mlsClient: MLSClient = mock(MLSClient::class) + private var refillKeyPackageUseCase = RefillKeyPackagesUseCaseImpl( keyPackageRepository, keyPackageLimitsProvider, - currentClientIdProvider + mlsClientProvider, + currentClientIdProvider, ) + fun withDefaultCipherSuite(cipherSuite: CipherSuite) = apply { + every { + mlsClient.getDefaultCipherSuite() + }.returns(cipherSuite.tag.toUShort()) + } + suspend fun withExistingSelfClientId() = apply { coEvery { currentClientIdProvider.invoke() @@ -124,9 +149,13 @@ class RefillKeyPackageUseCaseTest { }.returns(refillAmount) } - suspend fun withKeyPackageCount(count: Int) = apply { + suspend fun withKeyPackageCount( + count: Int, + clientId: AnyMatcher = AnyMatcher(valueOf()), + cipherSuite: AnyMatcher = AnyMatcher(valueOf()), + ) = apply { coEvery { - keyPackageRepository.getAvailableKeyPackageCount(any()) + keyPackageRepository.getAvailableKeyPackageCount(matches { clientId.matches(it) }, matches { cipherSuite.matches(it) }) }.returns(Either.Right(KeyPackageCountDTO(count))) } @@ -136,13 +165,23 @@ class RefillKeyPackageUseCaseTest { }.returns(Either.Right(Unit)) } - suspend fun withGetAvailableKeyPackagesFailing(failure: NetworkFailure) = apply { + suspend fun withGetAvailableKeyPackagesFailing( + failure: NetworkFailure, + clientId: AnyMatcher = AnyMatcher(valueOf()), + cipherSuite: AnyMatcher = AnyMatcher(valueOf()), + ) = apply { coEvery { - keyPackageRepository.getAvailableKeyPackageCount(any()) + keyPackageRepository.getAvailableKeyPackageCount(matches { clientId.matches(it) }, matches { cipherSuite.matches(it) }) }.returns(Either.Left(failure)) } - fun arrange() = this to refillKeyPackageUseCase + suspend fun arrange() = apply { + coEvery { + mlsClientProvider.getMLSClient() + }.returns(mlsClient.right()) + }.let { + this to refillKeyPackageUseCase + } companion object { const val KEY_PACKAGE_LIMIT = 100 @@ -150,5 +189,4 @@ class RefillKeyPackageUseCaseTest { } } - } diff --git a/network-model/src/commonMain/kotlin/com/wire/kalium/network/api/authenticated/client/ClientRequest.kt b/network-model/src/commonMain/kotlin/com/wire/kalium/network/api/authenticated/client/ClientRequest.kt index feb0c29fa24..878d1127072 100644 --- a/network-model/src/commonMain/kotlin/com/wire/kalium/network/api/authenticated/client/ClientRequest.kt +++ b/network-model/src/commonMain/kotlin/com/wire/kalium/network/api/authenticated/client/ClientRequest.kt @@ -96,6 +96,9 @@ enum class DeviceTypeDTO { sealed class ClientCapabilityDTO { @SerialName("legalhold-implicit-consent") data object LegalHoldImplicitConsent : ClientCapabilityDTO() + + @SerialName("consumable-notifications") + data object ConsumableNotifications : ClientCapabilityDTO() data class Unknown(val name: String) : ClientCapabilityDTO() } @@ -110,14 +113,19 @@ object ClientCapabilityDTOSerializer : KSerializer { is ClientCapabilityDTO.LegalHoldImplicitConsent -> encoder.encodeString("legalhold-implicit-consent") + ClientCapabilityDTO.ConsumableNotifications -> + encoder.encodeString("consumable-notifications") + is ClientCapabilityDTO.Unknown -> encoder.encodeString(value.name) + } } override fun deserialize(decoder: Decoder): ClientCapabilityDTO { return when (val value = decoder.decodeString()) { "legalhold-implicit-consent" -> ClientCapabilityDTO.LegalHoldImplicitConsent + "consumable-notifications" -> ClientCapabilityDTO.ConsumableNotifications else -> ClientCapabilityDTO.Unknown(value) } } diff --git a/network-model/src/commonTest/kotlin/com/wire/kalium/network/api/authenticated/client/ClientCapabilityDTOSerializerTest.kt b/network-model/src/commonTest/kotlin/com/wire/kalium/network/api/authenticated/client/ClientCapabilityDTOSerializerTest.kt index 408490c1b44..6799c7d590c 100644 --- a/network-model/src/commonTest/kotlin/com/wire/kalium/network/api/authenticated/client/ClientCapabilityDTOSerializerTest.kt +++ b/network-model/src/commonTest/kotlin/com/wire/kalium/network/api/authenticated/client/ClientCapabilityDTOSerializerTest.kt @@ -35,6 +35,13 @@ class ClientCapabilityDTOSerializerTest { assertEquals("\"legalhold-implicit-consent\"", result) } + @Test + fun serialize_consumable_notifications() { + val capability = ClientCapabilityDTO.ConsumableNotifications + val result = json.encodeToString(ClientCapabilityDTO.serializer(), capability) + assertEquals("\"consumable-notifications\"", result) + } + @Test fun serialize_unknown_capability() { val capability = ClientCapabilityDTO.Unknown("custom-capability") @@ -49,6 +56,13 @@ class ClientCapabilityDTOSerializerTest { assertEquals(ClientCapabilityDTO.LegalHoldImplicitConsent, result) } + @Test + fun deserialize_consumable_notifications() { + val jsonString = "\"consumable-notifications\"" + val result = json.decodeFromString(ClientCapabilityDTO.serializer(), jsonString) + assertEquals(ClientCapabilityDTO.ConsumableNotifications, result) + } + @Test fun deserialize_unknown_capability() { val jsonString = "\"unknown-capability\"" diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt index 157cde34fd8..a81b097f236 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt @@ -89,5 +89,8 @@ interface KeyPackageApi { * * @return unclaimed key package count */ - suspend fun getAvailableKeyPackageCount(clientId: String): NetworkResponse + suspend fun getAvailableKeyPackageCount( + clientId: String, + cipherSuite: Int, + ): NetworkResponse } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/ConversationApiV0.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/ConversationApiV0.kt index 600c6102788..a509f0bf99d 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/ConversationApiV0.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/ConversationApiV0.kt @@ -370,7 +370,7 @@ internal open class ConversationApiV0 internal constructor( messageTimer: Long? ): NetworkResponse = wrapKaliumResponse { - httpClient.put("$PATH_CONVERSATIONS/${conversationId.value}/$PATH_MESSAGE_TIMER") { + httpClient.put("$PATH_CONVERSATIONS/${conversationId.domain}/${conversationId.value}/$PATH_MESSAGE_TIMER") { setBody(ConversationMessageTimerDTO(messageTimer)) } } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/KeyPackageApiV0.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/KeyPackageApiV0.kt index 1cdc364e8a0..000f6191aa9 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/KeyPackageApiV0.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/KeyPackageApiV0.kt @@ -47,7 +47,7 @@ internal open class KeyPackageApiV0 internal constructor() : KeyPackageApi { APINotSupported("MLS: replaceKeyPackages api is only available on API V5") ) - override suspend fun getAvailableKeyPackageCount(clientId: String): NetworkResponse = + override suspend fun getAvailableKeyPackageCount(clientId: String, cipherSuite: Int): NetworkResponse = NetworkResponse.Error( APINotSupported("MLS: getAvailableKeyPackageCount api is only available on API V5") ) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt index 9b4c592ef6a..af5c86ef112 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt @@ -79,8 +79,15 @@ internal open class KeyPackageApiV5 internal constructor( } } - override suspend fun getAvailableKeyPackageCount(clientId: String): NetworkResponse = - wrapKaliumResponse { httpClient.get("$PATH_KEY_PACKAGES/$PATH_SELF/$clientId/$PATH_COUNT") } + override suspend fun getAvailableKeyPackageCount( + clientId: String, + cipherSuite: Int, + ): NetworkResponse = + wrapKaliumResponse { + httpClient.get("$PATH_KEY_PACKAGES/$PATH_SELF/$clientId/$PATH_COUNT") { + parameter(QUERY_CIPHER_SUITE, cipherSuite.toHexString()) + } + } private companion object { const val PATH_KEY_PACKAGES = "mls/key-packages" diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v5/KeyPackageApiV5Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v5/KeyPackageApiV5Test.kt index 9f65d0f7ae9..503d0a1c837 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/v5/KeyPackageApiV5Test.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v5/KeyPackageApiV5Test.kt @@ -24,6 +24,7 @@ import com.wire.kalium.network.api.base.authenticated.keypackage.KeyPackageApi import com.wire.kalium.network.api.model.UserId import com.wire.kalium.network.api.v5.authenticated.KeyPackageApiV5 import com.wire.kalium.network.utils.isSuccessful +import com.wire.kalium.util.int.toHexString import io.ktor.http.HttpStatusCode import kotlinx.coroutines.test.runTest import kotlin.test.Test @@ -34,17 +35,21 @@ internal class KeyPackageApiV5Test : ApiTest() { @Test fun givenAValidClientId_whenCallingGetAvailableKeyPackageCountEndpoint_theRequestShouldBeConfiguredCorrectly() = runTest { + val cipherSuite = 0 + val expectedCipherSuite = cipherSuite.toHexString() + val networkClient = mockAuthenticatedNetworkClient( KeyPackageJson.keyPackageCountJson(KEY_PACKAGE_COUNT).rawJson, statusCode = HttpStatusCode.OK, assertion = { assertGet() assertPathEqual(KEY_PACKAGE_COUNT_PATH) + assertQueryParameter(name = "ciphersuite", hasValue = expectedCipherSuite) } ) val keyPackageApi: KeyPackageApi = KeyPackageApiV5(networkClient) - val response = keyPackageApi.getAvailableKeyPackageCount(VALID_CLIENT_ID) + val response = keyPackageApi.getAvailableKeyPackageCount(VALID_CLIENT_ID, cipherSuite) assertTrue(response.isSuccessful()) assertEquals(response.value.count, KEY_PACKAGE_COUNT) }