Skip to content

Commit

Permalink
Cleanup EngineConsumer and duplicate test data
Browse files Browse the repository at this point in the history
  • Loading branch information
darkdh committed Feb 28, 2025
1 parent 6daef00 commit 6baaa87
Show file tree
Hide file tree
Showing 16 changed files with 133 additions and 99 deletions.
2 changes: 2 additions & 0 deletions components/ai_chat/core/browser/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ if (!is_ios) {
"//base/test:test_support",
"//brave/components/ai_chat/core/browser",
"//brave/components/ai_chat/core/common",
"//brave/components/ai_chat/core/common:test_support",
"//brave/components/ai_chat/core/common/mojom",
"//brave/components/api_request_helper",
"//brave/components/brave_component_updater/browser:test_support",
Expand Down Expand Up @@ -200,6 +201,7 @@ source_set("test_support") {

deps = [
"//brave/components/ai_chat/core/browser",
"//brave/components/ai_chat/core/common:test_support",
"//brave/components/ai_chat/core/common/mojom",
"//services/network/public/cpp",
"//testing/gmock",
Expand Down
30 changes: 8 additions & 22 deletions components/ai_chat/core/browser/conversation_handler_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "brave/components/ai_chat/core/common/pref_names.h"
#include "brave/components/ai_chat/core/common/test_utils.h"
#include "components/grit/brave_components_strings.h"
#include "components/os_crypt/async/browser/os_crypt_async.h"
#include "components/os_crypt/async/browser/test_utils.h"
Expand Down Expand Up @@ -1301,18 +1302,7 @@ TEST_F(ConversationHandlerUnitTest, UploadImage) {
loop2.Run();
testing::Mock::VerifyAndClearExpectations(&client);

const std::vector<std::vector<uint8_t>> test_images = {
{0x01, 0x02, 0x03, 0x04, 0x05},
{0xde, 0xed, 0xbe, 0xef},
{0xff, 0xff, 0xff},
};
std::vector<mojom::UploadedImagePtr> uploaded_images;
uploaded_images.emplace_back(mojom::UploadedImage::New(
"filename1", sizeof(test_images[0]), test_images[0]));
uploaded_images.emplace_back(mojom::UploadedImage::New(
"filename2", sizeof(test_images[1]), test_images[1]));
uploaded_images.emplace_back(mojom::UploadedImage::New(
"filename3", sizeof(test_images[2]), test_images[2]));
auto uploaded_images = CreateSampleUploadedImages(3);

// There are uploaded images.
// Note that this will need to be put at the end of this test suite
Expand All @@ -1321,7 +1311,7 @@ TEST_F(ConversationHandlerUnitTest, UploadImage) {
EXPECT_CALL(delegate, GetUploadedImagesSize()).WillOnce(testing::Return(3));
EXPECT_CALL(delegate, GetUploadedImages())
.Times(1)
.WillOnce(testing::Return(std::move(uploaded_images)));
.WillOnce(testing::Return(Clone(uploaded_images)));
EXPECT_CALL(delegate, ClearUploadedImages()).Times(1);
base::RunLoop loop3;
EXPECT_CALL(client, OnModelDataChanged)
Expand All @@ -1338,15 +1328,11 @@ TEST_F(ConversationHandlerUnitTest, UploadImage) {
auto& last_entry = conversation_handler_->GetConversationHistory().back();
EXPECT_TRUE(last_entry->uploaded_images);
const auto& images = last_entry->uploaded_images.value();
EXPECT_EQ(images[0]->filename, "filename1");
EXPECT_EQ(images[0]->filesize, static_cast<int64_t>(sizeof(test_images[0])));
EXPECT_EQ(images[0]->image_data, test_images[0]);
EXPECT_EQ(images[1]->filename, "filename2");
EXPECT_EQ(images[1]->filesize, static_cast<int64_t>(sizeof(test_images[1])));
EXPECT_EQ(images[1]->image_data, test_images[1]);
EXPECT_EQ(images[2]->filename, "filename3");
EXPECT_EQ(images[2]->filesize, static_cast<int64_t>(sizeof(test_images[2])));
EXPECT_EQ(images[2]->image_data, test_images[2]);
for (size_t i = 0; i < images.size(); ++i) {
EXPECT_EQ(images[i]->filename, uploaded_images[i]->filename);
EXPECT_EQ(images[i]->filesize, uploaded_images[i]->filesize);
EXPECT_EQ(images[i]->image_data, uploaded_images[i]->image_data);
}
}

TEST_F(ConversationHandlerUnitTest_NoAssociatedContent,
Expand Down
8 changes: 7 additions & 1 deletion components/ai_chat/core/browser/engine/engine_consumer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

#include <optional>
#include <string>
#include <string_view>

#include "base/base64.h"
#include "base/strings/strcat.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"

namespace ai_chat {
Expand All @@ -29,6 +30,11 @@ bool EngineConsumer::SupportsDeltaTextResponses() const {
return false;
}

std::string EngineConsumer::GetImageDataURL(base::span<uint8_t> image_data) {
constexpr char kDataUrlPrefix[] = "data:image/png;base64,";
return base::StrCat({kDataUrlPrefix, base::Base64Encode(image_data)});
}

bool EngineConsumer::CanPerformCompletionRequest(
const ConversationHistory& conversation_history) const {
if (conversation_history.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions components/ai_chat/core/browser/engine/engine_consumer.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class EngineConsumer {
max_associated_content_length_ = max_associated_content_length;
}

static std::string GetImageDataURL(base::span<uint8_t> image_data);

protected:
// Check if we should call GenerationCompletedCallback early based on the
// conversation history. Ex. empty history, or if the last entry is not a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/memory/weak_ptr.h"
#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <type_traits>
#include <vector>

#include "base/base64.h"
#include "base/check.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
Expand All @@ -20,7 +19,6 @@
#include "base/memory/weak_ptr.h"
#include "base/numerics/clamped_math.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/time/time.h"
#include "base/types/expected.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
Expand Down Expand Up @@ -121,17 +119,15 @@ void EngineConsumerConversationAPI::GenerateAssistantResponse(
const auto& last_entry = conversation_history.back();
if (last_entry->uploaded_images) {
size_t counter = 0;
constexpr char kImageUrl[] = R"(data:image/png;base64,$1)";
for (const auto& uploaded_image : last_entry->uploaded_images.value()) {
// Only send the first uploaded_image becasue llama-vision seems to take
// the last one if there are multiple images
if (counter++ > 0) {
break;
}
const std::string image_url = base::ReplaceStringPlaceholders(
kImageUrl, {base::Base64Encode(uploaded_image->image_data)}, nullptr);
conversation.push_back({mojom::CharacterType::HUMAN,
ConversationEventType::UploadImage, image_url});
ConversationEventType::UploadImage,
GetImageDataURL(uploaded_image->image_data)});
}
}
// history
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <type_traits>
#include <vector>

#include "base/base64.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/json/json_writer.h"
Expand All @@ -19,13 +20,15 @@
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/test/values_test_util.h"
#include "base/time/time.h"
#include "base/types/expected.h"
#include "base/values.h"
#include "brave/components/ai_chat/core/browser/engine/conversation_api_client.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "brave/components/ai_chat/core/common/test_utils.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
Expand Down Expand Up @@ -435,11 +438,7 @@ TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_SummarizePage) {
}

TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_UploadImage) {
const std::vector<std::vector<uint8_t>> test_images = {
{0x01, 0x02, 0x03, 0x04, 0x05},
{0xde, 0xed, 0xbe, 0xef},
{0xff, 0xff, 0xff},
};
auto uploaded_images = CreateSampleUploadedImages(3);
constexpr char kTestPrompt[] = "Tell the user what is in the image?";
constexpr char kAssistantResponse[] = "It's a lion!";
auto* mock_api_client = GetMockConversationAPIClient();
Expand All @@ -452,36 +451,28 @@ TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_UploadImage) {
// Only support one image for now.
ASSERT_EQ(conversation.size(), 2u);
EXPECT_EQ(conversation[0].role, mojom::CharacterType::HUMAN);
EXPECT_EQ(conversation[0].content, "");
EXPECT_EQ(
conversation[0].content,
base::StrCat({"data:image/png;base64,",
base::Base64Encode(uploaded_images[0]->image_data)}));
EXPECT_EQ(conversation[0].type, ConversationAPIClient::UploadImage);
EXPECT_EQ(conversation[1].role, mojom::CharacterType::HUMAN);
EXPECT_EQ(conversation[1].content, kTestPrompt);
EXPECT_EQ(conversation[1].type, ConversationAPIClient::ChatMessage);
std::move(callback).Run(kAssistantResponse);
});

std::vector<mojom::UploadedImagePtr> uploaded_images;
uploaded_images.emplace_back(
mojom::UploadedImage::New("filename1", 1, test_images[0]));
uploaded_images.emplace_back(
mojom::UploadedImage::New("filename", 2, test_images[1]));
uploaded_images.emplace_back(
mojom::UploadedImage::New("filename", 3, test_images[2]));

std::vector<mojom::ConversationTurnPtr> history;
history.push_back(mojom::ConversationTurn::New(
std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::UNSPECIFIED,
"What is this image?", kTestPrompt, std::nullopt, std::nullopt,
base::Time::Now(), std::nullopt, std::move(uploaded_images), false));
base::Time::Now(), std::nullopt, CloneUpdatedImages(uploaded_images),
false));

engine_->GenerateAssistantResponse(
false, "", history, "", base::DoNothing(),
base::BindLambdaForTesting([&run_loop, kAssistantResponse](
EngineConsumer::GenerationResult result) {
EXPECT_STREQ(result.value().c_str(), kAssistantResponse);
run_loop.Quit();
}));
run_loop.Run();
base::test::TestFuture<EngineConsumer::GenerationResult> future;
engine_->GenerateAssistantResponse(false, "", history, "", base::DoNothing(),
future.GetCallback());
EXPECT_STREQ(future.Take()->c_str(), kAssistantResponse);
testing::Mock::VerifyAndClearExpectations(mock_api_client);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/memory/weak_ptr.h"
#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h"
Expand Down
7 changes: 2 additions & 5 deletions components/ai_chat/core/browser/engine/engine_consumer_oai.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <type_traits>
#include <vector>

#include "base/base64.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
Expand Down Expand Up @@ -101,7 +100,6 @@ base::Value::List BuildMessages(
user_message.Set("text", "These images are uploaded by the users");
content.Append(std::move(user_message));
size_t counter = 0;
constexpr char kImageUrl[] = R"(data:image/png;base64,$1)";
// Only send the first uploaded_image becasue llama-vision seems to take the
// last one if there are multiple uploaded_images
for (const auto& uploaded_image : last_entry->uploaded_images.value()) {
Expand All @@ -110,10 +108,9 @@ base::Value::List BuildMessages(
}
base::Value::Dict image;
image.Set("type", "image_url");
const std::string image_url = base::ReplaceStringPlaceholders(
kImageUrl, {base::Base64Encode(uploaded_image->image_data)}, nullptr);
base::Value::Dict image_url_dict;
image_url_dict.Set("url", image_url);
image_url_dict.Set(
"url", EngineConsumer::GetImageDataURL(uploaded_image->image_data));
image.Set("image_url", std::move(image_url_dict));
content.Append(std::move(image));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/memory/weak_ptr.h"
#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <type_traits>
#include <vector>

#include "base/base64.h"
#include "base/containers/checked_iterators.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
Expand All @@ -22,12 +23,14 @@
#include "base/strings/utf_string_conversions.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/test/values_test_util.h"
#include "base/time/time.h"
#include "base/values.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer.h"
#include "brave/components/ai_chat/core/browser/engine/test_utils.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"
#include "brave/components/ai_chat/core/common/test_utils.h"
#include "components/grit/brave_components_strings.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "testing/gmock/include/gmock/gmock.h"
Expand Down Expand Up @@ -446,35 +449,35 @@ TEST_F(EngineConsumerOAIUnitTest, GenerateAssistantResponseEarlyReturn) {
TEST_F(EngineConsumerOAIUnitTest, GenerateAssistantResponseUploadImage) {
EngineConsumer::ConversationHistory history;
auto* client = GetClient();
auto run_loop = std::make_unique<base::RunLoop>();
const std::vector<std::vector<uint8_t>> test_images = {
{0x01, 0x02, 0x03, 0x04, 0x05},
{0xde, 0xed, 0xbe, 0xef},
{0xff, 0xff, 0xff},
};
auto uploaded_images = CreateSampleUploadedImages(3);
constexpr char kTestPrompt[] = "Tell the user what is in the image?";
constexpr char kAssistantResponse[] = "It's a lion!";
EXPECT_CALL(*client, PerformRequest(_, _, _, _))
.WillOnce(
[kTestPrompt, kAssistantResponse](
[kTestPrompt, kAssistantResponse, &uploaded_images](
const mojom::CustomModelOptions, base::Value::List messages,
EngineConsumer::GenerationDataCallback,
EngineConsumer::GenerationCompletedCallback completed_callback) {
EXPECT_EQ(*messages[0].GetDict().Find("role"), "system");

auto expected_dict = ParseJsonDict(R"({
constexpr char kJsonTemplate[] = R"({
"content": [ {
"text": "These images are uploaded by the users",
"type": "text"
}, {
"image_url": {
"url": ""
"url": "data:image/png;base64,$1"
},
"type": "image_url"
} ],
"role": "user"
}
)");
)";
const std::string json_str = base::ReplaceStringPlaceholders(
kJsonTemplate,
{base::Base64Encode(uploaded_images[0]->image_data)}, nullptr);
auto expected_dict = ParseJsonDict(json_str);

EXPECT_EQ(messages[1].GetDict(), expected_dict);

EXPECT_EQ(*messages[2].GetDict().Find("role"), "user");
Expand All @@ -484,27 +487,15 @@ TEST_F(EngineConsumerOAIUnitTest, GenerateAssistantResponseUploadImage) {
.Run(EngineConsumer::GenerationResult(kAssistantResponse));
});

std::vector<mojom::UploadedImagePtr> uploaded_images;
uploaded_images.emplace_back(
mojom::UploadedImage::New("filename1", 1, test_images[0]));
uploaded_images.emplace_back(
mojom::UploadedImage::New("filename", 2, test_images[1]));
uploaded_images.emplace_back(
mojom::UploadedImage::New("filename", 3, test_images[2]));

history.push_back(mojom::ConversationTurn::New(
std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::UNSPECIFIED,
"What is this image?", kTestPrompt, std::nullopt, std::nullopt,
base::Time::Now(), std::nullopt, std::move(uploaded_images), false));

engine_->GenerateAssistantResponse(
false, "", history, "", base::DoNothing(),
base::BindLambdaForTesting([&run_loop, kAssistantResponse](
EngineConsumer::GenerationResult result) {
EXPECT_STREQ(result.value().c_str(), kAssistantResponse);
run_loop->Quit();
}));
run_loop->Run();
base::Time::Now(), std::nullopt, CloneUpdatedImages(uploaded_images),
false));
base::test::TestFuture<EngineConsumer::GenerationResult> future;
engine_->GenerateAssistantResponse(false, "", history, "", base::DoNothing(),
future.GetCallback());
EXPECT_STREQ(future.Take()->c_str(), kAssistantResponse);
testing::Mock::VerifyAndClearExpectations(client);
}

Expand Down
Loading

0 comments on commit 6baaa87

Please sign in to comment.