Skip to content

Commit 31e0d06

Browse files
authored
Migrate VertexAI serialization to be localized (#6631)
There are some considerations to how this should be finalized. Current implementation details that I've decided on which we can change: * Based on the example doc, all classes `Foo` have been renamed `InternalFoo` * All internal serialization classes, where relevant, are moved to inner classes of their API counterparts * All classes only used as child fields for serialization classes have been moved to inner classes of those classes * All `toPublic` and `toInternal` methods on API and serialization classes have been moved inside of those classes and `conversions.kt` has been mostly emptied. * A few serialization classes do not have API equivalents and are left in a `Types.kt` file Possible changes: * Change all `InternalFoo` classes to have the same name, referenced as `Foo.Internal` rather than `Foo.InternalFoo`. This will probably make the codebase feel cleaner, but I'll wait for opinions on it * Move serialization only classes out of the internal serialization classes, either to inner classes of the API classes or top level classes themselves. * For classes that have serializers, rename serializers from `InternalFooSerializer` to `Serializer` for example `Foo.InternalFoo.Serializer` or `Foo.Internal.Serializer` instead of `Foo.InternalFoo.InternalFooSerializer` or `Foo.Internal.InternalFooSerializer`
1 parent f2d05d6 commit 31e0d06

38 files changed

+823
-898
lines changed

firebase-vertexai/consumer-rules.pro

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
# hide the original source file name.
2121
#-renamesourcefileattribute SourceFile
2222

23+
-keep class com.google.firebase.vertexai.type.** { *; }
2324
-keep class com.google.firebase.vertexai.common.** { *; }

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ import com.google.firebase.vertexai.common.APIController
2424
import com.google.firebase.vertexai.common.CountTokensRequest
2525
import com.google.firebase.vertexai.common.GenerateContentRequest
2626
import com.google.firebase.vertexai.common.HeaderProvider
27-
import com.google.firebase.vertexai.internal.util.toInternal
28-
import com.google.firebase.vertexai.internal.util.toPublic
2927
import com.google.firebase.vertexai.type.Content
3028
import com.google.firebase.vertexai.type.CountTokensResponse
3129
import com.google.firebase.vertexai.type.FinishReason

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ package com.google.firebase.vertexai.common
1919
import android.util.Log
2020
import com.google.firebase.Firebase
2121
import com.google.firebase.options
22-
import com.google.firebase.vertexai.common.server.FinishReason
23-
import com.google.firebase.vertexai.common.server.GRpcError
24-
import com.google.firebase.vertexai.common.server.GRpcErrorDetails
2522
import com.google.firebase.vertexai.common.util.decodeToFlow
2623
import com.google.firebase.vertexai.common.util.fullModelName
24+
import com.google.firebase.vertexai.type.CountTokensResponse
25+
import com.google.firebase.vertexai.type.FinishReason
26+
import com.google.firebase.vertexai.type.GRpcErrorResponse
27+
import com.google.firebase.vertexai.type.GenerateContentResponse
2728
import com.google.firebase.vertexai.type.RequestOptions
29+
import com.google.firebase.vertexai.type.Response
2830
import io.ktor.client.HttpClient
2931
import io.ktor.client.call.body
3032
import io.ktor.client.engine.HttpClientEngine
@@ -106,31 +108,33 @@ internal constructor(
106108
install(ContentNegotiation) { json(JSON) }
107109
}
108110

109-
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
111+
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse.Internal =
110112
try {
111113
client
112114
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") {
113115
applyCommonConfiguration(request)
114116
applyHeaderProvider()
115117
}
116118
.also { validateResponse(it) }
117-
.body<GenerateContentResponse>()
119+
.body<GenerateContentResponse.Internal>()
118120
.validate()
119121
} catch (e: Throwable) {
120122
throw FirebaseCommonAIException.from(e)
121123
}
122124

123-
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
125+
fun generateContentStream(
126+
request: GenerateContentRequest
127+
): Flow<GenerateContentResponse.Internal> =
124128
client
125-
.postStream<GenerateContentResponse>(
129+
.postStream<GenerateContentResponse.Internal>(
126130
"${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
127131
) {
128132
applyCommonConfiguration(request)
129133
}
130134
.map { it.validate() }
131135
.catch { throw FirebaseCommonAIException.from(it) }
132136

133-
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
137+
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse.Internal =
134138
try {
135139
client
136140
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") {
@@ -275,19 +279,21 @@ private suspend fun validateResponse(response: HttpResponse) {
275279
throw ServerException(message)
276280
}
277281

278-
private fun getServiceDisabledErrorDetailsOrNull(error: GRpcError): GRpcErrorDetails? {
282+
private fun getServiceDisabledErrorDetailsOrNull(
283+
error: GRpcErrorResponse.GRpcError
284+
): GRpcErrorResponse.GRpcError.GRpcErrorDetails? {
279285
return error.details?.firstOrNull {
280286
it.reason == "SERVICE_DISABLED" && it.domain == "googleapis.com"
281287
}
282288
}
283289

284-
private fun GenerateContentResponse.validate() = apply {
290+
private fun GenerateContentResponse.Internal.validate() = apply {
285291
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
286292
throw SerializationException("Error deserializing response, found no valid fields")
287293
}
288294
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
289295
candidates
290296
?.mapNotNull { it.finishReason }
291-
?.firstOrNull { it != FinishReason.STOP }
297+
?.firstOrNull { it != FinishReason.Internal.STOP }
292298
?.let { throw ResponseStoppedException(this) }
293299
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.firebase.vertexai.common
1818

19+
import com.google.firebase.vertexai.type.GenerateContentResponse
1920
import io.ktor.serialization.JsonConvertException
2021
import kotlinx.coroutines.TimeoutCancellationException
2122

@@ -66,7 +67,7 @@ internal class InvalidAPIKeyException(message: String, cause: Throwable? = null)
6667
* @property response the full server response for the request.
6768
*/
6869
internal class PromptBlockedException(
69-
val response: GenerateContentResponse,
70+
val response: GenerateContentResponse.Internal,
7071
cause: Throwable? = null
7172
) :
7273
FirebaseCommonAIException(
@@ -98,7 +99,7 @@ internal class InvalidStateException(message: String, cause: Throwable? = null)
9899
* @property response the full server response for the request
99100
*/
100101
internal class ResponseStoppedException(
101-
val response: GenerateContentResponse,
102+
val response: GenerateContentResponse.Internal,
102103
cause: Throwable? = null
103104
) :
104105
FirebaseCommonAIException(
@@ -125,3 +126,18 @@ internal class ServiceDisabledException(message: String, cause: Throwable? = nul
125126
/** Catch all case for exceptions not explicitly expected. */
126127
internal class UnknownException(message: String, cause: Throwable? = null) :
127128
FirebaseCommonAIException(message, cause)
129+
130+
internal fun makeMissingCaseException(
131+
source: String,
132+
ordinal: Int
133+
): com.google.firebase.vertexai.type.SerializationException {
134+
return com.google.firebase.vertexai.type.SerializationException(
135+
"""
136+
|Missing case for a $source: $ordinal
137+
|This error indicates that one of the `toInternal` conversions needs updating.
138+
|If you're a developer seeing this exception, please file an issue on our GitHub repo:
139+
|https://github.com/firebase/firebase-android-sdk
140+
"""
141+
.trimMargin()
142+
)
143+
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
package com.google.firebase.vertexai.common
1818

19-
import com.google.firebase.vertexai.common.client.GenerationConfig
20-
import com.google.firebase.vertexai.common.client.Tool
21-
import com.google.firebase.vertexai.common.client.ToolConfig
22-
import com.google.firebase.vertexai.common.shared.Content
23-
import com.google.firebase.vertexai.common.shared.SafetySetting
2419
import com.google.firebase.vertexai.common.util.fullModelName
20+
import com.google.firebase.vertexai.type.Content
21+
import com.google.firebase.vertexai.type.GenerationConfig
22+
import com.google.firebase.vertexai.type.SafetySetting
23+
import com.google.firebase.vertexai.type.Tool
24+
import com.google.firebase.vertexai.type.ToolConfig
2525
import kotlinx.serialization.SerialName
2626
import kotlinx.serialization.Serializable
2727

@@ -30,21 +30,21 @@ internal sealed interface Request
3030
@Serializable
3131
internal data class GenerateContentRequest(
3232
val model: String? = null,
33-
val contents: List<Content>,
34-
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
35-
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
36-
val tools: List<Tool>? = null,
37-
@SerialName("tool_config") var toolConfig: ToolConfig? = null,
38-
@SerialName("system_instruction") val systemInstruction: Content? = null,
33+
val contents: List<Content.Internal>,
34+
@SerialName("safety_settings") val safetySettings: List<SafetySetting.Internal>? = null,
35+
@SerialName("generation_config") val generationConfig: GenerationConfig.Internal? = null,
36+
val tools: List<Tool.Internal>? = null,
37+
@SerialName("tool_config") var toolConfig: ToolConfig.Internal? = null,
38+
@SerialName("system_instruction") val systemInstruction: Content.Internal? = null,
3939
) : Request
4040

4141
@Serializable
4242
internal data class CountTokensRequest(
4343
val generateContentRequest: GenerateContentRequest? = null,
4444
val model: String? = null,
45-
val contents: List<Content>? = null,
46-
val tools: List<Tool>? = null,
47-
@SerialName("system_instruction") val systemInstruction: Content? = null,
45+
val contents: List<Content.Internal>? = null,
46+
val tools: List<Tool.Internal>? = null,
47+
@SerialName("system_instruction") val systemInstruction: Content.Internal? = null,
4848
) : Request {
4949
companion object {
5050
fun forGenAI(generateContentRequest: GenerateContentRequest) =

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Response.kt

Lines changed: 0 additions & 46 deletions
This file was deleted.

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt

Lines changed: 0 additions & 80 deletions
This file was deleted.

0 commit comments

Comments
 (0)