feat(vertexai): Add Live streaming feature #16991

Merged
44 commits merged on Mar 25, 2025
Changes from 1 commit
Commits (44):
f53b00c
init
cynthiajoan Jan 13, 2025
0e5a21d
more testing for the audio player
cynthiajoan Jan 20, 2025
cd59d74
organize example page and functions
cynthiajoan Jan 23, 2025
8b3285d
Merge branch 'vertexai/example_update' into vertexai/bidi
cynthiajoan Jan 23, 2025
c5bb860
get bidi demo working under new example layout
cynthiajoan Jan 23, 2025
9575c5e
Merge branch 'main' into vertexai/bidi
cynthiajoan Jan 27, 2025
3ff550e
kinda playable
cynthiajoan Jan 28, 2025
d2cbaff
working audio output
cynthiajoan Jan 29, 2025
76aa22b
use firebase project for the connection
cynthiajoan Jan 30, 2025
c1695e5
record to file and load file to stream, no response yet
cynthiajoan Feb 3, 2025
b361009
try different ways to test sending audio
cynthiajoan Feb 3, 2025
9988f72
try realtime input
cynthiajoan Feb 6, 2025
2d37054
Merge branch 'main' into vertexai/bidi
cynthiajoan Feb 12, 2025
c9476e0
stream recording, try local file dump
cynthiajoan Feb 18, 2025
89c7190
first kinda working version
cynthiajoan Feb 21, 2025
8b3bc15
much better
cynthiajoan Feb 21, 2025
fe2073b
Merge branch 'main' into vertexai/bidi
cynthiajoan Feb 21, 2025
d53655f
Make function calling working
cynthiajoan Feb 26, 2025
fd128b2
add new autopush backend
cynthiajoan Feb 27, 2025
2541969
Some update after api doc
cynthiajoan Mar 9, 2025
1326935
more update on api
cynthiajoan Mar 11, 2025
7ee188c
more clean up for api
cynthiajoan Mar 11, 2025
0ca9d5a
Let's still keep image modality
cynthiajoan Mar 11, 2025
b675e6a
minor json fix
cynthiajoan Mar 12, 2025
e5920c2
Use IOWebSocketChannel class to pass in headers
cynthiajoan Mar 12, 2025
eff222c
unit test for live_api
cynthiajoan Mar 14, 2025
8275927
more controllable receive logic
cynthiajoan Mar 17, 2025
2a8164b
test fix after model library refactor
cynthiajoan Mar 17, 2025
07e18f9
documentation and try catch in the example
cynthiajoan Mar 17, 2025
202418e
Somehow working continuously conversation
cynthiajoan Mar 17, 2025
6c53ab3
function calling fix
cynthiajoan Mar 18, 2025
fdb006a
changes after bugbash
cynthiajoan Mar 19, 2025
04d1b40
Merge branch 'main' into vertexai/bidi
cynthiajoan Mar 19, 2025
22ae77e
fix after merge main
cynthiajoan Mar 19, 2025
9689825
more fixes after merge
cynthiajoan Mar 19, 2025
7de0ebb
fix analyzing issues
cynthiajoan Mar 24, 2025
3d78da0
more fix for analyzer
cynthiajoan Mar 24, 2025
7113e44
fix analyzer
cynthiajoan Mar 24, 2025
104714e
two more analyzer
cynthiajoan Mar 24, 2025
bda672a
fix the live test for function call
cynthiajoan Mar 24, 2025
d03df57
Address review comment
cynthiajoan Mar 25, 2025
ce683ee
More review comments to address
cynthiajoan Mar 25, 2025
c8bf5fe
review comments on audio player and recorder
cynthiajoan Mar 25, 2025
d2951e7
Update code to handle optional id for FunctionCall parse
cynthiajoan Mar 25, 2025
Some update after api doc
cynthiajoan committed Mar 9, 2025
commit 2541969dff72d5e87b35a696e82c3f360c897075
(diff: example app bidi page)
@@ -46,10 +46,29 @@ class _BidiPageState extends State<BidiPage> {
bool _loading = false;
bool _session_opening = false;
bool _recording = false;
late AsyncSession _session;
late LiveGenerativeModel _liveModel;
late LiveSession _session;
final _audioManager = AudioStreamManager();
final _audioRecorder = InMemoryAudioRecorder();

@override
void initState() {
super.initState();

final config = LiveGenerationConfig(
speechConfig: SpeechConfig(voice: Voices.Charon),
responseModalities: [ResponseModalities.Audio],
);

_liveModel = FirebaseVertexAI.instance.liveGenerativeModel(
model: 'gemini-2.0-flash-exp',
liveGenerationConfig: config,
tools: [
Tool.functionDeclarations([lightControlTool]),
],
);
}

void _scrollDown() {
WidgetsBinding.instance.addPostFrameCallback(
(_) => _scrollController.animateTo(
@@ -224,21 +243,8 @@ class _BidiPageState extends State<BidiPage> {
_loading = true;
});

const modelName = 'gemini-2.0-flash-exp';

final config = LiveGenerationConfig(
speechConfig: SpeechConfig(voice: Voices.Charon),
responseModalities: [ResponseModalities.Audio],
);

if (!_session_opening) {
_session = await widget.model.connect(
model: modelName,
config: config,
tools: [
Tool.functionDeclarations([lightControlTool]),
],
);
_session = await _liveModel.connect();
_session_opening = true;
unawaited(_handle_response());
} else {
@@ -365,7 +371,7 @@ class _BidiPageState extends State<BidiPage> {
final data = InlineDataPart('audio/pcm', audio, willContinue: true);
media_chunks.add(data);

await _session!.stream(mediaChunks: media_chunks);
await _session!.sendMediaChunks(mediaChunks: media_chunks);
// print('Stream realtime audio in one chunk to server in one request');
//_session.printWsStatus();
setState(() {
@@ -442,8 +448,8 @@ class _BidiPageState extends State<BidiPage> {
var chunkBuilder = BytesBuilder();
var audioIndex = 0;
await for (var response in responseStream) {
if (response.serverContent?.modelTurn != null) {
final partList = response.serverContent?.modelTurn?.parts;
if (response is LiveServerContent && response.modelTurn != null) {
final partList = response.modelTurn?.parts;
if (partList != null) {
for (var part in partList) {
if (part is TextPart) {
@@ -483,7 +489,9 @@ class _BidiPageState extends State<BidiPage> {
}

// Check if the turn is complete
if (response.serverContent?.turnComplete ?? false) {
if (response is LiveServerContent &&
response.turnComplete != null &&
response.turnComplete!) {
print('Turn complete!');
if (chunkBuilder.isNotEmpty) {
Uint8List chunk = await AudioUtil.audioChunkWithHeader(
@@ -496,9 +504,8 @@ class _BidiPageState extends State<BidiPage> {
}
}

if (response.toolCall != null &&
response.toolCall!.functionCalls != null) {
final functionCalls = response.toolCall!.functionCalls!.toList();
if (response is LiveServerToolCall && response.functionCalls != null) {
final functionCalls = response.functionCalls!.toList();
// When the model response with a function call, invoke the function.
if (functionCalls.isNotEmpty) {
final functionCall = functionCalls.first;
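Read together, these hunks move the example from widget.model.connect(model: ..., config: ..., tools: ...) to a LiveGenerativeModel created once in initState, plus a typed response stream. The sketch below condenses that flow; it is illustrative only: error handling, recording, and UI updates are omitted, runBidiDemo and audioBytes are made-up names, and Firebase.initializeApp() is assumed to have run.

import 'dart:typed_data';

import 'package:firebase_vertexai/firebase_vertexai.dart';

Future<void> runBidiDemo(
    Uint8List audioBytes, FunctionDeclaration lightControlTool) async {
  // Configure the live model once, as the example now does in initState.
  final liveModel = FirebaseVertexAI.instance.liveGenerativeModel(
    model: 'gemini-2.0-flash-exp',
    liveGenerationConfig: LiveGenerationConfig(
      speechConfig: SpeechConfig(voice: Voices.Charon),
      responseModalities: [ResponseModalities.Audio],
    ),
    tools: [
      Tool.functionDeclarations([lightControlTool]),
    ],
  );

  // connect() no longer takes the model/config/tools; they live on the model.
  final session = await liveModel.connect();

  // Stream recorded PCM audio to the server as realtime input.
  await session.sendMediaChunks(
    mediaChunks: [InlineDataPart('audio/pcm', audioBytes, willContinue: true)],
  );

  // receive() now yields typed messages instead of one message type with
  // nullable serverContent/toolCall fields.
  await for (final response in session.receive()) {
    if (response is LiveServerContent) {
      final parts = response.modelTurn?.parts;
      if (parts != null) {
        // Handle TextPart and audio InlineDataPart entries here.
      }
      if (response.turnComplete ?? false) {
        // Flush any buffered audio to the player.
      }
    } else if (response is LiveServerToolCall &&
        response.functionCalls != null) {
      // Invoke the matching local function and send the result back.
    }
  }

  await session.close();
}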
(diff: package exports)
@@ -56,9 +56,6 @@ export 'src/function_calling.dart'
FunctionDeclaration,
Tool,
ToolConfig;
export 'src/live.dart' show AsyncLive, AsyncSession;
export 'src/live_api.dart'
show LiveGenerationConfig, SpeechConfig, Voices, ResponseModalities;
export 'src/generative_model.dart' show GenerativeModel;
export 'src/imagen_api.dart'
show
@@ -70,4 +67,15 @@ export 'src/imagen_api.dart'
ImagenAspectRatio;
export 'src/imagen_content.dart' show ImagenInlineImage;
export 'src/imagen_model.dart' show ImagenModel;
export 'src/live_api.dart'
show
LiveGenerationConfig,
SpeechConfig,
Voices,
ResponseModalities,
LiveServerContent,
LiveServerToolCall,
LiveServerToolCallCancellation;
export 'src/live_model.dart' show LiveGenerativeModel;
export 'src/live_session.dart' show LiveSession;
export 'src/schema.dart' show Schema, SchemaType;
157 changes: 132 additions & 25 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
@@ -607,31 +607,22 @@ enum HarmBlockThreshold {
Object toJson() => _jsonString;
}

/// Configuration options for model generation and outputs.
final class GenerationConfig {
abstract class BaseGenerationConfig {
// ignore: public_member_api_docs
GenerationConfig(
{this.candidateCount,
this.stopSequences,
this.maxOutputTokens,
this.temperature,
this.topP,
this.topK,
this.responseMimeType,
this.responseSchema});
BaseGenerationConfig({
this.candidateCount,
this.maxOutputTokens,
this.temperature,
this.topP,
this.topK,
});

/// Number of generated responses to return.
///
/// This value must be between [1, 8], inclusive. If unset, this will default
/// to 1.
final int? candidateCount;

/// The set of character sequences (up to 5) that will stop output generation.
///
/// If specified, the API will stop at the first appearance of a stop
/// sequence. The stop sequence will not be included as part of the response.
final List<String>? stopSequences;

/// The maximum number of tokens to include in a candidate.
///
/// If unset, this will default to output_token_limit specified in the `Model`
@@ -665,6 +656,38 @@ final class GenerationConfig {
/// Note: The default value varies by model.
final int? topK;

// ignore: public_member_api_docs
Map<String, Object?> toJson() => {
if (candidateCount case final candidateCount?)
'candidateCount': candidateCount,
if (maxOutputTokens case final maxOutputTokens?)
'maxOutputTokens': maxOutputTokens,
if (temperature case final temperature?) 'temperature': temperature,
if (topP case final topP?) 'topP': topP,
if (topK case final topK?) 'topK': topK,
};
}

/// Configuration options for model generation and outputs.
final class GenerationConfig extends BaseGenerationConfig {
// ignore: public_member_api_docs
GenerationConfig({
super.candidateCount,
this.stopSequences,
super.maxOutputTokens,
super.temperature,
super.topP,
super.topK,
this.responseMimeType,
this.responseSchema,
});

/// The set of character sequences (up to 5) that will stop output generation.
///
/// If specified, the API will stop at the first appearance of a stop
/// sequence. The stop sequence will not be included as part of the response.
final List<String>? stopSequences;

/// Output response mimetype of the generated candidate text.
///
/// Supported mimetype:
@@ -678,25 +701,109 @@ final class GenerationConfig {
/// a schema; currently this is limited to `application/json`.
final Schema? responseSchema;

/// Convert to json format
@override
Map<String, Object?> toJson() => {
if (candidateCount case final candidateCount?)
'candidateCount': candidateCount,
...super.toJson(),
if (stopSequences case final stopSequences?
when stopSequences.isNotEmpty)
'stopSequences': stopSequences,
if (maxOutputTokens case final maxOutputTokens?)
'maxOutputTokens': maxOutputTokens,
if (temperature case final temperature?) 'temperature': temperature,
if (topP case final topP?) 'topP': topP,
if (topK case final topK?) 'topK': topK,
if (responseMimeType case final responseMimeType?)
'responseMimeType': responseMimeType,
if (responseSchema case final responseSchema?)
'responseSchema': responseSchema,
};
}

/// Configuration options for model generation and outputs.
// final class GenerationConfig {
// // ignore: public_member_api_docs
// GenerationConfig(
// {this.candidateCount,
// this.stopSequences,
// this.maxOutputTokens,
// this.temperature,
// this.topP,
// this.topK,
// this.responseMimeType,
// this.responseSchema});

// /// Number of generated responses to return.
// ///
// /// This value must be between [1, 8], inclusive. If unset, this will default
// /// to 1.
// final int? candidateCount;

// /// The set of character sequences (up to 5) that will stop output generation.
// ///
// /// If specified, the API will stop at the first appearance of a stop
// /// sequence. The stop sequence will not be included as part of the response.
// final List<String>? stopSequences;

// /// The maximum number of tokens to include in a candidate.
// ///
// /// If unset, this will default to output_token_limit specified in the `Model`
// /// specification.
// final int? maxOutputTokens;

// /// Controls the randomness of the output.
// ///
// /// Note: The default value varies by model.
// ///
// /// Values can range from `[0.0, infinity]`, inclusive. A value temperature
// /// must be greater than 0.0.
// final double? temperature;

// /// The maximum cumulative probability of tokens to consider when sampling.
// ///
// /// The model uses combined Top-k and nucleus sampling. Tokens are sorted
// /// based on their assigned probabilities so that only the most likely tokens
// /// are considered. Top-k sampling directly limits the maximum number of
// /// tokens to consider, while Nucleus sampling limits number of tokens based
// /// on the cumulative probability.
// ///
// /// Note: The default value varies by model.
// final double? topP;

// /// The maximum number of tokens to consider when sampling.
// ///
// /// The model uses combined Top-k and nucleus sampling. Top-k sampling
// /// considers the set of `top_k` most probable tokens. Defaults to 40.
// ///
// /// Note: The default value varies by model.
// final int? topK;

// /// Output response mimetype of the generated candidate text.
// ///
// /// Supported mimetype:
// /// - `text/plain`: (default) Text output.
// /// - `application/json`: JSON response in the candidates.
// final String? responseMimeType;

// /// Output response schema of the generated candidate text.
// ///
// /// - Note: This only applies when the [responseMimeType] supports
// /// a schema; currently this is limited to `application/json`.
// final Schema? responseSchema;

// /// Convert to json format
// Map<String, Object?> toJson() => {
// if (candidateCount case final candidateCount?)
// 'candidateCount': candidateCount,
// if (stopSequences case final stopSequences?
// when stopSequences.isNotEmpty)
// 'stopSequences': stopSequences,
// if (maxOutputTokens case final maxOutputTokens?)
// 'maxOutputTokens': maxOutputTokens,
// if (temperature case final temperature?) 'temperature': temperature,
// if (topP case final topP?) 'topP': topP,
// if (topK case final topK?) 'topK': topK,
// if (responseMimeType case final responseMimeType?)
// 'responseMimeType': responseMimeType,
// if (responseSchema case final responseSchema?)
// 'responseSchema': responseSchema,
// };
// }

/// Type of task for which the embedding will be used.
enum TaskType {
/// Unset value, which will default to one of the other enum values.
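The point of carving BaseGenerationConfig out of GenerationConfig is that the live API's config can now reuse the shared sampling fields and extend toJson() with a spread instead of repeating every `if ... case final ...?` entry. The sketch below shows that pattern; the real LiveGenerationConfig sits in live_api.dart, whose diff is not rendered in this view, so the extra field, JSON key, and enum serialization here are assumptions for illustration only.

import 'api.dart'; // BaseGenerationConfig (sketch assumes same-library access)

final class LiveGenerationConfig extends BaseGenerationConfig {
  LiveGenerationConfig({
    this.responseModalities,
    super.maxOutputTokens,
    super.temperature,
    super.topP,
    super.topK,
  });

  /// Illustrative extra field; the real class also carries a speechConfig.
  final List<ResponseModalities>? responseModalities;

  @override
  Map<String, Object?> toJson() => {
        // Shared fields (candidateCount, maxOutputTokens, temperature,
        // topP, topK) are serialized once by BaseGenerationConfig.toJson().
        ...super.toJson(),
        if (responseModalities case final responseModalities?)
          // Assumes the enum exposes toJson() like HarmBlockThreshold does.
          'responseModalities':
              responseModalities.map((m) => m.toJson()).toList(),
      };
}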
(diff: BaseModel)
@@ -40,29 +40,23 @@ enum Task {
/// Do not instantiate directly.
abstract class BaseModel {
// ignore: public_member_api_docs
BaseModel({
required String model,
required String location,
required FirebaseApp app,
required ApiClient client,
}) : _model = normalizeModelName(model),
_projectUri = _vertexUri(app, location),
_client = client;
BaseModel(
{required String model,
required String location,
required FirebaseApp app})
: _model = normalizeModelName(model),
_projectUri = _vertexUri(app, location);

static const _baseUrl = 'firebasevertexai.googleapis.com';
static const _apiVersion = 'v1beta';

final ({String prefix, String name}) _model;

final Uri _projectUri;
final ApiClient _client;

/// The normalized model name.
({String prefix, String name}) get model => _model;

/// The API client.
ApiClient get client => _client;

/// Returns the model code for a user friendly model name.
///
/// If the model name is already a model code (contains a `/`), use the parts
@@ -109,6 +103,21 @@ abstract class BaseModel {
Uri taskUri(Task task) => _projectUri.replace(
pathSegments: _projectUri.pathSegments
.followedBy([_model.prefix, '${_model.name}:${task.name}']));
}

abstract class BaseApiClientModel extends BaseModel {
// ignore: public_member_api_docs
BaseApiClientModel({
required super.model,
required super.location,
required super.app,
required ApiClient client,
}) : _client = client;

final ApiClient _client;

/// The API client.
ApiClient get client => _client;

/// Make a unary request for [task] with JSON encodable [params].
Future<T> makeRequest<T>(Task task, Map<String, Object?> params,
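The net effect of this hunk: BaseModel keeps only the normalized model name and project URI, while the ApiClient and makeRequest move into the new BaseApiClientModel. GenerativeModel and ImagenModel (below) switch their superclass accordingly, and the WebSocket-based LiveGenerativeModel can extend BaseModel directly without dragging in an HTTP client. A toy illustration of the shape, not the real signatures:

// Toy hierarchy only; the real constructors take model/location/app/client.
abstract class BaseModel {
  BaseModel(this.modelName);
  final String modelName; // normalized name + project URI in the real class
}

abstract class BaseApiClientModel extends BaseModel {
  BaseApiClientModel(super.modelName, this.client);
  final Object client; // stands in for ApiClient; makeRequest lives here
}

final class HttpBackedModel extends BaseApiClientModel {
  HttpBackedModel(super.modelName, super.client); // GenerativeModel, ImagenModel
}

final class SocketBackedModel extends BaseModel {
  SocketBackedModel(super.modelName); // LiveGenerativeModel
}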
(diff: FirebaseVertexAI class)
@@ -18,12 +18,10 @@ import 'package:firebase_core/firebase_core.dart';
import 'package:firebase_core_platform_interface/firebase_core_platform_interface.dart'
show FirebasePluginPlatform;

import 'api.dart';
import 'content.dart';
import 'function_calling.dart';
import '../firebase_vertexai.dart';
import 'generative_model.dart';
import 'imagen_api.dart';
import 'imagen_model.dart';
import 'live_model.dart';

const _defaultLocation = 'us-central1';

@@ -130,4 +128,22 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
appCheck: appCheck,
auth: auth);
}

LiveGenerativeModel liveGenerativeModel({
required String model,
LiveGenerationConfig? liveGenerationConfig,
List<Tool>? tools,
Content? systemInstruction,
}) {
return createLiveGenerativeModel(
app: app,
location: location,
model: model,
liveGenerationConfig: liveGenerationConfig,
tools: tools,
systemInstruction: systemInstruction,
appCheck: appCheck,
auth: auth,
);
}
}
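A minimal sketch of calling the new factory from app code. The model name mirrors the example page earlier in this diff; the systemInstruction value and the openSession name are made up to show the optional parameters.

import 'package:firebase_vertexai/firebase_vertexai.dart';

Future<LiveSession> openSession() async {
  // Assumes Firebase.initializeApp() has already completed.
  final liveModel = FirebaseVertexAI.instance.liveGenerativeModel(
    model: 'gemini-2.0-flash-exp',
    systemInstruction: Content.system('You are a friendly voice assistant.'),
  );
  return liveModel.connect();
}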
(diff: GenerativeModel)
@@ -15,45 +15,29 @@
// ignore_for_file: use_late_for_private_fields_and_variables

import 'dart:async';
import 'dart:convert';

import 'package:firebase_app_check/firebase_app_check.dart';
import 'package:firebase_auth/firebase_auth.dart';
import 'package:firebase_core/firebase_core.dart';
import 'package:http/http.dart' as http;
import 'package:web_socket_channel/web_socket_channel.dart';

import 'api.dart';
import 'base_model.dart';
import 'client.dart';
import 'content.dart';
import 'function_calling.dart';
import 'live.dart';
import 'live_session.dart';
import 'live_api.dart';
import 'vertex_version.dart';

const _baseUrl = 'firebasevertexai.googleapis.com';
const _apiVersion = 'v1beta';

const _baseDailyUrl = 'daily-firebaseml.sandbox.googleapis.com';
const _apiUrl =
'ws/google.firebase.machinelearning.v2beta.LlmBidiService/BidiGenerateContent?key=';

const _baseAutopushUrl = 'autopush-firebasevertexai.sandbox.googleapis.com';
const _apiAutopushUrl =
'ws/google.firebase.vertexai.v1beta.LlmBidiService/BidiGenerateContent/locations';

const _baseGAIUrl = 'generativelanguage.googleapis.com';
const _apiGAIUrl =
'ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key=';

const _bidiGoogleAI = false;

/// A multimodel generative model (like Gemini).
///
/// Allows generating content, creating embeddings, and counting the number of
/// tokens in a piece of content.
final class GenerativeModel extends BaseModel {
final class GenerativeModel extends BaseApiClientModel {
/// Create a [GenerativeModel] backed by the generative model named [model].
///
/// The [model] argument can be a model name (such as `'gemini-pro'`) or a
@@ -240,45 +224,6 @@ final class GenerativeModel extends BaseModel {
};
return makeRequest(Task.countTokens, parameters, parseCountTokensResponse);
}

Future<AsyncSession> connect({
required String model,
LiveGenerationConfig? config,
Content? systemInstruction,
List<Tool>? tools,
}) async {
late String uri;
late String modelString;
if (_bidiGoogleAI) {
uri = 'wss://$_baseGAIUrl/$_apiGAIUrl${_app.options.apiKey}';
modelString = 'models/$model';
} else {
// uri = 'wss://$_baseDailyUrl/$_apiUrl${_app.options.apiKey}';
uri =
'wss://$_baseAutopushUrl/$_apiAutopushUrl/$_location?key=${_app.options.apiKey}';
modelString =
'projects/${_app.options.projectId}/locations/$_location/publishers/google/models/$model';
}

final requestJson = {
'setup': {
'model': modelString,
if (config != null) 'generation_config': config.toJson(),
if (systemInstruction != null)
'system_instruction': systemInstruction.toJson(),
if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
}
};

final request = jsonEncode(requestJson);
var ws = WebSocketChannel.connect(Uri.parse(uri));
await ws.ready;
print(uri);
print(request);

ws.sink.add(request);
return AsyncSession(ws: ws);
}
}

/// Returns a [GenerativeModel] using it's private constructor.
(diff: ImagenModel)
@@ -30,7 +30,7 @@ import 'imagen_content.dart';
/// > Warning: For Vertex AI in Firebase, image generation using Imagen 3 models
/// is in Public Preview, which means that the feature is not subject to any SLA
/// or deprecation policy and could change in backwards-incompatible ways.
final class ImagenModel extends BaseModel {
final class ImagenModel extends BaseApiClientModel {
ImagenModel._(
{required FirebaseApp app,
required String model,
292 changes: 181 additions & 111 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/live_api.dart

Large diffs are not rendered by default.

132 changes: 132 additions & 0 deletions packages/firebase_vertexai/firebase_vertexai/lib/src/live_model.dart
@@ -0,0 +1,132 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import 'dart:convert';
import 'package:firebase_app_check/firebase_app_check.dart';
import 'package:firebase_auth/firebase_auth.dart';
import 'package:firebase_core/firebase_core.dart';
import 'package:web_socket_channel/web_socket_channel.dart';

import 'base_model.dart';
import 'live_api.dart';
import 'function_calling.dart';
import 'content.dart';
import 'live_session.dart';

const _baseDailyUrl = 'daily-firebaseml.sandbox.googleapis.com';
const _apiUrl =
'ws/google.firebase.machinelearning.v2beta.LlmBidiService/BidiGenerateContent?key=';

const _baseAutopushUrl = 'autopush-firebasevertexai.sandbox.googleapis.com';
const _apiAutopushUrl =
'ws/google.firebase.vertexai.v1beta.LlmBidiService/BidiGenerateContent/locations';

const _baseGAIUrl = 'generativelanguage.googleapis.com';
const _apiGAIUrl =
'ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key=';

const _bidiGoogleAI = false;

final class LiveGenerativeModel extends BaseModel {
LiveGenerativeModel._(
{required String model,
required String location,
required FirebaseApp app,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
LiveGenerationConfig? liveGenerationConfig,
List<Tool>? tools,
Content? systemInstruction})
: _app = app,
_location = location,
_liveGenerationConfig = liveGenerationConfig,
_tools = tools,
_systemInstruction = systemInstruction,
super(
model: model,
app: app,
location: location,
);

final FirebaseApp _app;
final String _location;
final LiveGenerationConfig? _liveGenerationConfig;
final List<Tool>? _tools;
final Content? _systemInstruction;

/// Establishes a connection to a live generation service.
///
/// This function handles the WebSocket connection setup and returns an [LiveSession]
/// object that can be used to communicate with the service.
///
/// Returns a [Future] that resolves to an [LiveSession] object upon successful
/// connection.
Future<LiveSession> connect() async {
late String uri;
late String modelString;

if (_bidiGoogleAI) {
uri = 'wss://$_baseGAIUrl/$_apiGAIUrl${_app.options.apiKey}';
modelString = '${model.prefix}/${model.name}}';
} else {
// uri = 'wss://$_baseDailyUrl/$_apiUrl${_app.options.apiKey}';
uri =
'wss://$_baseAutopushUrl/$_apiAutopushUrl/$_location?key=${_app.options.apiKey}';
modelString =
'projects/${_app.options.projectId}/locations/$_location/publishers/google/models/${model.name}';
}

final requestJson = {
'setup': {
'model': modelString,
if (_liveGenerationConfig != null)
'generation_config': _liveGenerationConfig.toJson(),
if (_systemInstruction != null)
'system_instruction': _systemInstruction.toJson(),
if (_tools != null) 'tools': _tools.map((t) => t.toJson()).toList(),
}
};

final request = jsonEncode(requestJson);
var ws = WebSocketChannel.connect(Uri.parse(uri));
await ws.ready;
print(uri);
print(request);

ws.sink.add(request);
return LiveSession(ws: ws);
}
}

/// Returns a [LiveGenerativeModel] using it's private constructor.
LiveGenerativeModel createLiveGenerativeModel({
required FirebaseApp app,
required String location,
required String model,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
LiveGenerationConfig? liveGenerationConfig,
List<Tool>? tools,
Content? systemInstruction,
}) =>
LiveGenerativeModel._(
model: model,
app: app,
appCheck: appCheck,
auth: auth,
location: location,
liveGenerationConfig: liveGenerationConfig,
tools: tools,
systemInstruction: systemInstruction,
);
(diff: LiveSession)
@@ -25,11 +25,17 @@ const _FUNCTION_RESPONSE_REQUIRES_ID =
'FunctionResponse request must have an `id` field from the'
' response of a ToolCall.FunctionalCalls in Google AI.';

class AsyncSession {
/// Manages asynchronous communication with Gemini model over a WebSocket
/// connection.
class LiveSession {
final WebSocketChannel _ws;

AsyncSession({required WebSocketChannel ws}) : _ws = ws;
LiveSession({required WebSocketChannel ws}) : _ws = ws;

/// Sends content to the server.
///
/// [input] (optional): The content to send.
/// [turnComplete] (optional): Indicates if the turn is complete. Defaults to false.
Future<void> send({
Content? input,
bool turnComplete = false,
@@ -43,7 +49,10 @@ class AsyncSession {
_ws.sink.add(clientJson);
}

Future<void> stream({
/// Sends realtime input (media chunks) to the server.
///
/// [mediaChunks]: The list of media chunks to send.
Future<void> sendMediaChunks({
required List<InlineDataPart> mediaChunks,
}) async {
var clientMessage = LiveClientRealtimeInput(mediaChunks: mediaChunks);
@@ -53,57 +62,22 @@ class AsyncSession {
_ws.sink.add(clientJson);
}

/// Receives messages from the server.
///
/// Returns a [Stream] of [LiveServerMessage] objects representing the
/// messages received from the server.
Stream<LiveServerMessage> receive() async* {
await for (var message in _ws.stream) {
var jsonString = utf8.decode(message);
var response = json.decode(jsonString);
// print(response);
Map<String, dynamic> responseDict;
//print(response);

responseDict = _LiveServerMessageFromVertex(response);

var result = parseServerMessage(responseDict);
var result = parseServerMessage(response);

yield result;
}
}

Stream<LiveServerMessage> startStream({
required Stream<InlineDataPart> stream,
required String mimeType,
}) async* {
print('beginning of startStream');
var completer = Completer();
// Start the send loop. When stream is complete, complete the completer.
unawaited(_sendLoop(stream, mimeType, completer));

// Listen for messages from the WebSocket.
// await for (final message in _ws.stream) {
// var jsonString = utf8.decode(message);
// var response = json.decode(jsonString);
// print(response);
// Map<String, dynamic> responseDict;

// responseDict = _LiveServerMessageFromVertex(response);

// var result = parseServerMessage(responseDict);

// if (result.serverContent?.turnComplete ?? false) {
// yield result;
// break;
// }
// yield result;
// }

// Wait for the send loop to complete or the websocket to close.
await Future.any([completer.future]);

// Close the websocket if it's not already closed.
// if (_ws.closeCode == null) {
// await _ws.sink.close();
// }
}

Future<void> _sendLoop(
Stream<InlineDataPart> dataStream,
String mimeType,
@@ -114,7 +88,7 @@ class AsyncSession {
await for (final data in dataStream) {
print('send audio data with size ${data.bytes.length}');

await stream(mediaChunks: [data]);
await sendMediaChunks(mediaChunks: [data]);

// await send(input: Content.inlineData(mimeType, data.bytes));
// Give a chance for the receive loop to process responses.
@@ -135,40 +109,6 @@ class AsyncSession {
printWsStatus();
}

Map<String, dynamic> _LiveServerContentFromMldev(dynamic fromObject) {
var toObject = <String, dynamic>{};
if (fromObject is Map && fromObject.containsKey('modelTurn')) {
toObject['model_turn'] = parseContent(fromObject['modelTurn']);
}
if (fromObject is Map && fromObject.containsKey('turnComplete')) {
toObject['turn_complete'] = fromObject['turnComplete'];
}
return toObject;
}

Map<String, dynamic> _LiveToolCallFromMldev(dynamic fromObject) {
var toObject = <String, dynamic>{};
if (fromObject is Map && fromObject.containsKey('functionCalls')) {
toObject['function_calls'] = fromObject['functionCalls'];
}
return toObject;
}

// Map<String, dynamic> _LiveServerMessageFromMldev(dynamic fromObject) {
// var toObject = <String, dynamic>{};
// if (fromObject is Map && fromObject.containsKey('serverContent')) {
// toObject['server_content'] =
// _LiveServerContentFromMldev(fromObject['serverContent']);
// }
// if (fromObject is Map && fromObject.containsKey('toolCall')) {
// toObject['tool_call'] = _LiveToolCallFromMldev(fromObject['toolCall']);
// }
// if (fromObject is Map && fromObject.containsKey('toolCallCancellation')) {
// toObject['tool_call_cancellation'] = fromObject['toolCallCancellation'];
// }
// return toObject;
// }

Map<String, dynamic> _LiveServerContentFromVertex(dynamic fromObject) {
var toObject = <String, dynamic>{};
if (fromObject is Map && fromObject.containsKey('modelTurn')) {
@@ -281,6 +221,7 @@ class AsyncSession {
}
}

/// Closes the WebSocket connection.
Future<void> close() async {
await _ws.sink.close();
}
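Finally, a sketch of driving a LiveSession with plain text via send(), the path the audio example does not exercise. The prompt text and the askOnce name are illustrative only.

import 'package:firebase_vertexai/firebase_vertexai.dart';

Future<void> askOnce(LiveSession session) async {
  // Send one user turn and mark it complete so the model starts responding.
  await session.send(
    input: Content.text('Turn the lights to a warm white, please.'),
    turnComplete: true,
  );

  // Read typed server messages until the model signals the turn is done.
  await for (final message in session.receive()) {
    if (message is LiveServerContent) {
      final parts = message.modelTurn?.parts;
      if (parts != null) {
        for (final part in parts) {
          if (part is TextPart) print(part.text);
        }
      }
      if (message.turnComplete ?? false) break;
    }
  }

  await session.close();
}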