Skip to content

Commit

Permalink
feat: adapt api v8 changes [WPB-15722] (#3288)
Browse files Browse the repository at this point in the history
* feat: adapt api v8 changes

* tests

* adapt changes to api making cipher suite mandatory

* detekt
  • Loading branch information
MohamadJaara authored Feb 11, 2025
1 parent a5d42e4 commit 23e2243
Show file tree
Hide file tree
Showing 16 changed files with 202 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ enum class DeviceType {

sealed class ClientCapability {
data object LegalHoldImplicitConsent : ClientCapability()
data object ConsumableNotifications : ClientCapability()
data class Unknown(val name: String) : ClientCapability()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,13 @@ class ClientMapper(

private fun toClientCapabilityDTO(clientCapability: ClientCapability): ClientCapabilityDTO = when (clientCapability) {
ClientCapability.LegalHoldImplicitConsent -> ClientCapabilityDTO.LegalHoldImplicitConsent
ClientCapability.ConsumableNotifications -> ClientCapabilityDTO.ConsumableNotifications
is ClientCapability.Unknown -> ClientCapabilityDTO.Unknown(clientCapability.name)
}

private fun fromClientCapabilityDTO(clientCapabilityDTO: ClientCapabilityDTO): ClientCapability = when (clientCapabilityDTO) {
ClientCapabilityDTO.LegalHoldImplicitConsent -> ClientCapability.LegalHoldImplicitConsent
ClientCapabilityDTO.ConsumableNotifications -> ClientCapability.ConsumableNotifications
is ClientCapabilityDTO.Unknown -> ClientCapability.Unknown(clientCapabilityDTO.name)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ interface KeyPackageRepository {

suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>, cipherSuite: CipherSuite): Either<CoreFailure, Unit>

suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either<NetworkFailure, KeyPackageCountDTO>
suspend fun getAvailableKeyPackageCount(clientId: ClientId, cipherSuite: CipherSuite): Either<NetworkFailure, KeyPackageCountDTO>

suspend fun validKeyPackageCount(clientId: ClientId): Either<CoreFailure, Int>
}
Expand Down Expand Up @@ -138,8 +138,11 @@ class KeyPackageDataSource(
}
}

override suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either<NetworkFailure, KeyPackageCountDTO> =
override suspend fun getAvailableKeyPackageCount(
clientId: ClientId,
cipherSuite: CipherSuite
): Either<NetworkFailure, KeyPackageCountDTO> =
wrapApiRequest {
keyPackageApi.getAvailableKeyPackageCount(clientId.value)
keyPackageApi.getAvailableKeyPackageCount(clientId.value, cipherSuite.tag)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,21 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
val deregisterNativePushToken: DeregisterTokenUseCase
get() = DeregisterTokenUseCaseImpl(clientRepository, notificationTokenRepository)
val mlsKeyPackageCountUseCase: MLSKeyPackageCountUseCase
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider, userConfigRepository)
get() = MLSKeyPackageCountUseCaseImpl(
keyPackageRepository = keyPackageRepository,
currentClientIdProvider = clientIdProvider,
mlsClientProvider = mlsClientProvider,
keyPackageLimitsProvider = keyPackageLimitsProvider,
userConfigRepository = userConfigRepository,
)
val restartSlowSyncProcessForRecoveryUseCase: RestartSlowSyncProcessForRecoveryUseCase
get() = RestartSlowSyncProcessForRecoveryUseCaseImpl(slowSyncRepository)
val refillKeyPackages: RefillKeyPackagesUseCase
get() = RefillKeyPackagesUseCaseImpl(
keyPackageRepository,
keyPackageLimitsProvider,
clientIdProvider
keyPackageRepository = keyPackageRepository,
keyPackageLimitsProvider = keyPackageLimitsProvider,
mlsClientProvider = mlsClientProvider,
currentClientIdProvider = clientIdProvider,
)

val observeCurrentClientId: ObserveCurrentClientIdUseCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ package com.wire.kalium.logic.feature.keypackage
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.map

/**
* This use case will return the current number of key packages.
Expand All @@ -38,6 +41,7 @@ interface MLSKeyPackageCountUseCase {
internal class MLSKeyPackageCountUseCaseImpl(
private val keyPackageRepository: KeyPackageRepository,
private val currentClientIdProvider: CurrentClientIdProvider,
private val mlsClientProvider: MLSClientProvider,
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val userConfigRepository: UserConfigRepository
) : MLSKeyPackageCountUseCase {
Expand All @@ -47,19 +51,24 @@ internal class MLSKeyPackageCountUseCaseImpl(
false -> validKeyPackagesCountFromMLSClient()
}

private suspend fun validKeyPackagesCountFromAPI() = currentClientIdProvider().fold({
MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it)
}, { selfClient ->
if (userConfigRepository.isMLSEnabled().getOrElse(false)) {
keyPackageRepository.getAvailableKeyPackageCount(selfClient)
.fold(
{ MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) },
{ MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) }
)
} else {
MLSKeyPackageCountResult.Failure.NotEnabled
@Suppress("ReturnCount")
private suspend fun validKeyPackagesCountFromAPI(): MLSKeyPackageCountResult {
val selfClientId = currentClientIdProvider().getOrElse {
return MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it)
}
})

if (!userConfigRepository.isMLSEnabled().getOrElse(false)) {
return MLSKeyPackageCountResult.Failure.NotEnabled
}

val cipherSuite = mlsClientProvider.getMLSClient().map { CipherSuite.fromTag(it.getDefaultCipherSuite()) }
.getOrElse { return MLSKeyPackageCountResult.Failure.Generic(it) }

return keyPackageRepository.getAvailableKeyPackageCount(selfClientId, cipherSuite).fold(
{ MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) },
{ MLSKeyPackageCountResult.Success(selfClientId, it.count, keyPackageLimitsProvider.needsRefill(it.count)) }
)
}

private suspend fun validKeyPackagesCountFromMLSClient() =
currentClientIdProvider().fold({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
package com.wire.kalium.logic.feature.keypackage

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.kaliumLogger

sealed class RefillKeyPackagesResult {
Expand All @@ -47,30 +51,32 @@ interface RefillKeyPackagesUseCase {
internal class RefillKeyPackagesUseCaseImpl(
private val keyPackageRepository: KeyPackageRepository,
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val mlsClientProvider: MLSClientProvider,
private val currentClientIdProvider: CurrentClientIdProvider,
) : RefillKeyPackagesUseCase {
override suspend operator fun invoke(): RefillKeyPackagesResult =
currentClientIdProvider().flatMap { selfClientId ->
// TODO: Maybe use MLSKeyPackageCountUseCase instead of repository directly,
// and fetch from local instead of remote
keyPackageRepository.getAvailableKeyPackageCount(selfClientId)
.flatMap {
kaliumLogger.i("Key packages: Found ${it.count} available key packages")
if (keyPackageLimitsProvider.needsRefill(it.count)) {
kaliumLogger.i("Key packages: Refilling key packages...")
val amount = keyPackageLimitsProvider.refillAmount()
keyPackageRepository.uploadNewKeyPackages(selfClientId, amount).flatMap {
Either.Right(Unit)
}
} else {
kaliumLogger.i("Key packages: Refill not needed")
Either.Right(Unit)
}
override suspend operator fun invoke(): RefillKeyPackagesResult {
val selfClientId = currentClientIdProvider().getOrElse {
return RefillKeyPackagesResult.Failure(it)
}

return mlsClientProvider.getMLSClient().map { CipherSuite.fromTag(it.getDefaultCipherSuite()) }.flatMap { cipherSuite ->
keyPackageRepository.getAvailableKeyPackageCount(selfClientId, cipherSuite)
}.flatMap {
kaliumLogger.i("Key packages: Found ${it.count} available key packages")
if (keyPackageLimitsProvider.needsRefill(it.count)) {
kaliumLogger.i("Key packages: Refilling key packages...")
val amount = keyPackageLimitsProvider.refillAmount()
keyPackageRepository.uploadNewKeyPackages(selfClientId, amount).flatMap {
Either.Right(Unit)
}
} else {
kaliumLogger.i("Key packages: Refill not needed")
Either.Right(Unit)
}
}.fold({ failure ->
RefillKeyPackagesResult.Failure(failure)
}, {
RefillKeyPackagesResult.Success
})

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ class KeyPackageRepositoryTest {

@Test
fun givenExistingClient_whenGettingAvailableKeyPackageCount_thenResultShouldBePropagated() = runTest {

val cipherSuite = CipherSuite.fromTag(1)
val (_, keyPackageRepository) = Arrangement()
.withMLSClient()
.withGetAvailableKeyPackageCountSuccessful()
.arrange()

val keyPackageCount = keyPackageRepository.getAvailableKeyPackageCount(Arrangement.SELF_CLIENT_ID)
val keyPackageCount = keyPackageRepository.getAvailableKeyPackageCount(Arrangement.SELF_CLIENT_ID, cipherSuite)

assertIs<Either.Right<KeyPackageCountDTO>>(keyPackageCount)
assertEquals(Arrangement.KEY_PACKAGE_COUNT_DTO.count, keyPackageCount.value.count)
Expand Down Expand Up @@ -210,7 +210,7 @@ class KeyPackageRepositoryTest {

suspend fun withGetAvailableKeyPackageCountSuccessful() = apply {
coEvery {
keyPackageApi.getAvailableKeyPackageCount(eq(SELF_CLIENT_ID.value))
keyPackageApi.getAvailableKeyPackageCount(eq(SELF_CLIENT_ID.value), any())
}.returns(NetworkResponse.Success(KEY_PACKAGE_COUNT_DTO, mapOf(), 200))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@

package com.wire.kalium.logic.feature.keypackage

import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Arrangement.Companion.CLIENT_FETCH_ERROR
import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Arrangement.Companion.KEY_PACKAGE_COUNT
import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Arrangement.Companion.KEY_PACKAGE_COUNT_DTO
Expand Down Expand Up @@ -53,16 +56,19 @@ class MLSKeyPackageCountUseCaseTest {

@Test
fun givenClientIdIsNotRegistered_ThenReturnGenericError() = runTest {
val expectedCipherSuite = CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519

val (arrangement, keyPackageCountUseCase) = Arrangement()
.withClientId(Either.Left(CLIENT_FETCH_ERROR))
.arrange{
.withDefaultCipherSuite(expectedCipherSuite)
.arrange {
withGetMLSEnabledReturning(true.right())
}

val actual = keyPackageCountUseCase()

coVerify {
arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID))
arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID), eq(expectedCipherSuite))
}.wasNotInvoked()

assertIs<MLSKeyPackageCountResult.Failure.FetchClientIdFailure>(actual)
Expand All @@ -71,47 +77,55 @@ class MLSKeyPackageCountUseCaseTest {

@Test
fun givenClientId_whenCallingKeyPackageCountReturnValue_ThenReturnKeyPackageCountSuccess() = runTest {
val expectedCipherSuite = CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519

val (arrangement, keyPackageCountUseCase) = Arrangement()
.withAvailableKeyPackageCountReturn(Either.Right(KEY_PACKAGE_COUNT_DTO))
.withClientId(Either.Right(TestClient.CLIENT_ID))
.withDefaultCipherSuite(expectedCipherSuite)
.withKeyPackageLimitSucceed()
.arrange{
.arrange {
withGetMLSEnabledReturning(true.right())
}

val actual = keyPackageCountUseCase()

coVerify {
arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID))
arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID), eq(expectedCipherSuite))
}.wasInvoked(once)
assertIs<MLSKeyPackageCountResult.Success>(actual)
assertEquals(actual, MLSKeyPackageCountResult.Success(TestClient.CLIENT_ID, KEY_PACKAGE_COUNT, true))
}

@Test
fun givenClientID_whenCallingKeyPackageCountReturnError_ThenReturnKeyPackageCountFailure() = runTest {
val expectedCipherSuite = CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519

val (arrangement, keyPackageCountUseCase) = Arrangement()
.withAvailableKeyPackageCountReturn(Either.Left(NETWORK_FAILURE))
.withClientId(Either.Right(TestClient.CLIENT_ID))
.arrange{
.withDefaultCipherSuite(expectedCipherSuite)
.arrange {
withGetMLSEnabledReturning(true.right())
}

val actual = keyPackageCountUseCase()

coVerify {
arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID))
arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID), eq(expectedCipherSuite))
}.wasInvoked(once)
assertIs<MLSKeyPackageCountResult.Failure.NetworkCallFailure>(actual)
assertEquals(actual.networkFailure, NETWORK_FAILURE)
}

@Test
fun givenClientID_whenCallingGetMLSEnabledReturnFalse_ThenReturnKeyPackageCountNotEnabledFailure() = runTest {
val expectedCipherSuite = CipherSuite.MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519
val (arrangement, keyPackageCountUseCase) = Arrangement()
.withAvailableKeyPackageCountReturn(Either.Right(KEY_PACKAGE_COUNT_DTO))
.withClientId(Either.Right(TestClient.CLIENT_ID))
.arrange{
.withDefaultCipherSuite(expectedCipherSuite)
.arrange {
withGetMLSEnabledReturning(false.right())
}

Expand All @@ -122,7 +136,7 @@ class MLSKeyPackageCountUseCaseTest {
}.wasInvoked(once)

coVerify {
arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID))
arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID), eq(expectedCipherSuite))
}.wasNotInvoked()
assertIs<MLSKeyPackageCountResult.Failure.NotEnabled>(actual)
}
Expand All @@ -137,6 +151,12 @@ class MLSKeyPackageCountUseCaseTest {
@Mock
val keyPackageLimitsProvider = mock(KeyPackageLimitsProvider::class)

@Mock
val mlsClientProvider = mock(MLSClientProvider::class)

@Mock
val mlsClient = mock(MLSClient::class)

suspend fun withClientId(result: Either<CoreFailure, ClientId>) = apply {
coEvery {
currentClientIdProvider.invoke()
Expand All @@ -151,13 +171,27 @@ class MLSKeyPackageCountUseCaseTest {

suspend fun withAvailableKeyPackageCountReturn(result: Either<NetworkFailure, KeyPackageCountDTO>) = apply {
coEvery {
keyPackageRepository.getAvailableKeyPackageCount(any())
keyPackageRepository.getAvailableKeyPackageCount(any(), any())
}.returns(result)
}

fun arrange(block: suspend Arrangement.() -> Unit) = apply { runBlocking { block() } }.let {
fun withDefaultCipherSuite(cipherSuite: CipherSuite) = apply {
every {
mlsClient.getDefaultCipherSuite()
}.returns(cipherSuite.tag.toUShort())
}

suspend fun arrange(block: suspend Arrangement.() -> Unit) = apply {
coEvery { mlsClientProvider.getMLSClient() }.returns(mlsClient.right())
}.apply {
runBlocking { block() }
}.let {
this to MLSKeyPackageCountUseCaseImpl(
keyPackageRepository, currentClientIdProvider, keyPackageLimitsProvider, userConfigRepository
keyPackageRepository = keyPackageRepository,
currentClientIdProvider = currentClientIdProvider,
keyPackageLimitsProvider = keyPackageLimitsProvider,
userConfigRepository = userConfigRepository,
mlsClientProvider = mlsClientProvider,
)
}

Expand Down
Loading

0 comments on commit 23e2243

Please sign in to comment.