-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathIntegrationTests.cpp
77 lines (66 loc) · 2.57 KB
/
IntegrationTests.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <gtest/gtest.h>
#include <aws/testing/AwsTestHelpers.h>
#include <aws/testing/MemoryTesting.h>
#include <algorithm>
#include <thread>
#include <aws/bedrock-runtime/BedrockRuntimeClient.h>
#include <aws/bedrock-runtime/BedrockRuntimeErrors.h>
#include <aws/bedrock-runtime/model/ConverseStreamRequest.h>
#include <aws/bedrock-runtime/model/ConverseStreamHandler.h>
#include <aws/core/client/CoreErrors.h>
#include <aws/core/utils/json/JsonSerializer.h>
#include <aws/core/utils/Outcome.h>
#include <aws/testing/TestingEnvironment.h>
#include <aws/core/platform/Environment.h>
#include <condition_variable>
#include <mutex>
#include <chrono>
using namespace Aws::BedrockRuntime;
using namespace Aws::BedrockRuntime::Model;
using namespace Aws::Client;
using namespace Aws::Region;
namespace
{
static const char* ALLOCATION_TAG = "BedrockRuntimeTests";
class BedrockRuntimeTests : public ::testing::Test
{
protected:
std::shared_ptr<BedrockRuntimeClient> m_client;
void SetUp()
{
Aws::Client::ClientConfiguration config;
config.region = AWS_TEST_REGION;
m_client = Aws::MakeShared<BedrockRuntimeClient>(ALLOCATION_TAG, config);
}
};
TEST_F(BedrockRuntimeTests, TestStreaming)
{
std::shared_ptr<Aws::BedrockRuntime::Model::ConverseStreamHandler> streamHandler = Aws::MakeShared<Aws::BedrockRuntime::Model::ConverseStreamHandler>(ALLOCATION_TAG);
Aws::BedrockRuntime::Model::ConverseStreamRequest bedrockRequest;
// other request setup
std::mutex mutex;
std::condition_variable cv;
auto startTime = std::chrono::system_clock::now();
bool responseReceived = false;
streamHandler->SetInitialResponseCallbackEx([&](const Aws::BedrockRuntime::Model::ConverseStreamInitialResponse& , const Aws::Utils::Event::InitialResponseType awsResponseType)
{
std::unique_lock<std::mutex> lock(mutex);
if (awsResponseType == Aws::Utils::Event::InitialResponseType::ON_RESPONSE) {
responseReceived = true;
cv.notify_one();
}
});
bedrockRequest.SetEventStreamHandler(*streamHandler);
bedrockRequest.SetModelId("dummy model");
bedrockRequest.SetMessages({});
Aws::BedrockRuntime::Model::ConverseStreamOutcome outcome = m_client->ConverseStream(bedrockRequest);
ASSERT_FALSE(outcome.IsSuccess());
std::unique_lock<std::mutex> lock(mutex);
cv.wait_until(lock, startTime + std::chrono::seconds(10), [&] {
return responseReceived; });
ASSERT_TRUE(responseReceived);
}
}