Skip to content

Commit 82b3375

Browse files
committed
fix initial response for windows clients
1 parent a241c97 commit 82b3375

File tree

6 files changed

+169
-1
lines changed

6 files changed

+169
-1
lines changed

cmake/sdksCommon.cmake

+2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ list(APPEND SDK_TEST_PROJECT_LIST "transcribestreaming:tests/aws-cpp-sdk-transcr
112112
list(APPEND SDK_TEST_PROJECT_LIST "eventbridge:tests/aws-cpp-sdk-eventbridge-tests")
113113
list(APPEND SDK_TEST_PROJECT_LIST "timestream-query:tests/aws-cpp-sdk-timestream-query-integration-tests")
114114
list(APPEND SDK_TEST_PROJECT_LIST "dsql:tests/aws-cpp-sdk-dsql-unit-tests")
115+
list(APPEND SDK_TEST_PROJECT_LIST "bedrock-runtime:tests/aws-cpp-sdk-bedrock-runtime-integration-tests")
116+
115117

116118
build_sdk_list()
117119

src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,13 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&
289289
{
290290
hashIterator.second->Update(reinterpret_cast<unsigned char*>(dst), static_cast<size_t>(read));
291291
}
292+
293+
auto& headersHandler = request->GetHeadersReceivedEventHandler();
294+
if (headersHandler)
295+
{
296+
headersHandler(request.get(), response.get());
297+
}
298+
292299
if (readLimiter != nullptr)
293300
{
294301
readLimiter->ApplyAndPayForCost(read);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
add_project(aws-cpp-sdk-bedrock-runtime-integration-tests
2+
"Tests for Bedrock Runtime C++ SDK"
3+
#aws-cpp-sdk-access-management
4+
aws-cpp-sdk-bedrock-runtime
5+
testing-resources
6+
aws-cpp-sdk-core
7+
)
8+
9+
# Headers are included in the source so that they show up in Visual Studio.
10+
# They are included elsewhere for consistency.
11+
12+
file(GLOB AWS_BEDROCK_RUNTIME_SRC
13+
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
14+
)
15+
16+
file(GLOB AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC
17+
${AWS_BEDROCK_RUNTIME_SRC}
18+
)
19+
20+
if(MSVC AND BUILD_SHARED_LIBS)
21+
add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1)
22+
endif()
23+
24+
enable_testing()
25+
26+
if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS)
27+
add_library(${PROJECT_NAME} ${AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC})
28+
else()
29+
add_executable(${PROJECT_NAME} ${AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC})
30+
endif()
31+
32+
set_compiler_flags(${PROJECT_NAME})
33+
set_compiler_warnings(${PROJECT_NAME})
34+
35+
target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS})
36+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/**
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
#include <gtest/gtest.h>
7+
#include <aws/testing/AwsTestHelpers.h>
8+
#include <aws/testing/MemoryTesting.h>
9+
#include <algorithm>
10+
#include <thread>
11+
12+
#include <aws/bedrock-runtime/BedrockRuntimeClient.h>
13+
#include <aws/bedrock-runtime/BedrockRuntimeErrors.h>
14+
#include <aws/bedrock-runtime/model/ConverseStreamRequest.h>
15+
#include <aws/bedrock-runtime/model/ConverseStreamHandler.h>
16+
#include <aws/core/client/CoreErrors.h>
17+
#include <aws/core/utils/json/JsonSerializer.h>
18+
#include <aws/core/utils/Outcome.h>
19+
#include <aws/testing/TestingEnvironment.h>
20+
#include <aws/core/platform/Environment.h>
21+
#include <condition_variable>
22+
#include <chrono>
23+
24+
using namespace Aws::BedrockRuntime;
25+
using namespace Aws::BedrockRuntime::Model;
26+
using namespace Aws::Client;
27+
using namespace Aws::Region;
28+
29+
30+
namespace
31+
{
32+
static const char* ALLOCATION_TAG = "BedrockRuntimeTests";
33+
34+
class BedrockRuntimeTests : public ::testing::Test
35+
{
36+
public:
37+
38+
std::shared_ptr<BedrockRuntimeClient> m_client;
39+
Aws::String testTrace;
40+
41+
protected:
42+
void SetUp()
43+
{
44+
Aws::Client::ClientConfiguration config;
45+
config.region = AWS_TEST_REGION;
46+
m_client = Aws::MakeShared<BedrockRuntimeClient>(ALLOCATION_TAG, config);
47+
}
48+
49+
void TearDown()
50+
{
51+
if (::testing::Test::HasFailure())
52+
{
53+
std::cout << "Test traces: " << testTrace << "\n";
54+
}
55+
testTrace.erase();
56+
}
57+
58+
};
59+
60+
TEST_F(BedrockRuntimeTests, TestStreaming)
61+
{
62+
std::shared_ptr<Aws::BedrockRuntime::Model::ConverseStreamHandler> streamHandler = Aws::MakeShared<Aws::BedrockRuntime::Model::ConverseStreamHandler>(ALLOCATION_TAG);
63+
64+
Aws::BedrockRuntime::Model::ConverseStreamRequest bedrockRequest;
65+
// other request setup
66+
std::mutex mutex;
67+
std::condition_variable cv;
68+
auto startTime = std::chrono::system_clock::now();
69+
bool responseReceived = false;
70+
streamHandler->SetInitialResponseCallbackEx([&](const Aws::BedrockRuntime::Model::ConverseStreamInitialResponse& , const Aws::Utils::Event::InitialResponseType awsResponseType)
71+
{
72+
std::unique_lock<std::mutex> lock(mutex);
73+
cv.wait(lock, [&] {
74+
return !responseReceived && awsResponseType == Aws::Utils::Event::InitialResponseType::ON_RESPONSE; });
75+
responseReceived = true;
76+
cv.notify_one();
77+
mutex.unlock();
78+
});
79+
80+
bedrockRequest.SetEventStreamHandler(*streamHandler);
81+
bedrockRequest.SetModelId("dummy model");
82+
bedrockRequest.SetMessages({});
83+
Aws::BedrockRuntime::Model::ConverseStreamOutcome outcome = m_client->ConverseStream(bedrockRequest);
84+
ASSERT_FALSE(outcome.IsSuccess());
85+
std::unique_lock<std::mutex> lock(mutex);
86+
cv.wait_until(lock, startTime + std::chrono::seconds(10), [&] {
87+
return responseReceived; });
88+
ASSERT_TRUE(responseReceived);
89+
lock.unlock();
90+
}
91+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/**
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
#include <gtest/gtest.h>
7+
#include <aws/core/Aws.h>
8+
#include <aws/testing/platform/PlatformTesting.h>
9+
#include <aws/testing/TestingEnvironment.h>
10+
#include <aws/testing/MemoryTesting.h>
11+
12+
int main(int argc, char** argv)
13+
{
14+
Aws::Testing::SetDefaultSigPipeHandler();
15+
Aws::SDKOptions options;
16+
options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Trace;
17+
AWS_BEGIN_MEMORY_TEST_EX(options, 1024, 128);
18+
19+
Aws::Testing::InitPlatformTest(options);
20+
Aws::Testing::ParseArgs(argc, argv);
21+
22+
Aws::InitAPI(options);
23+
::testing::InitGoogleTest(&argc, argv);
24+
int exitCode = RUN_ALL_TESTS();
25+
26+
Aws::ShutdownAPI(options);
27+
AWS_END_MEMORY_TEST_EX;
28+
Aws::Testing::ShutdownPlatformTest(options);
29+
return exitCode;
30+
}

tools/scripts/run_integration_tests.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def main():
5454
"aws-cpp-sdk-elasticfilesystem-integration-tests",
5555
"aws-cpp-sdk-rds-integration-tests",
5656
"aws-cpp-sdk-ec2-integration-tests",
57-
"aws-cpp-sdk-timestream-query-integration-tests"]
57+
"aws-cpp-sdk-timestream-query-integration-tests",
58+
"aws-cpp-sdk-bedrock-runtime-integration-tests"
59+
]
5860

5961
# check for existence of these binaries before adding them to tests
6062
# as they will not always be present

0 commit comments

Comments
 (0)