From d8a6a4e681c95220af2a1e91c2be5f4822bf5ae9 Mon Sep 17 00:00:00 2001 From: Soumava Bera Date: Mon, 17 Mar 2025 15:25:54 -0400 Subject: [PATCH] fix initial response for windows clients --- cmake/sdksCommon.cmake | 2 + .../source/http/windows/WinSyncHttpClient.cpp | 7 ++ .../CMakeLists.txt | 35 +++++++++ .../IntegrationTests.cpp | 77 +++++++++++++++++++ .../RunTests.cpp | 30 ++++++++ tools/scripts/run_integration_tests.py | 4 +- 6 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 tests/aws-cpp-sdk-bedrock-runtime-integration-tests/CMakeLists.txt create mode 100644 tests/aws-cpp-sdk-bedrock-runtime-integration-tests/IntegrationTests.cpp create mode 100644 tests/aws-cpp-sdk-bedrock-runtime-integration-tests/RunTests.cpp diff --git a/cmake/sdksCommon.cmake b/cmake/sdksCommon.cmake index 7eb26f8e40b..258db006342 100644 --- a/cmake/sdksCommon.cmake +++ b/cmake/sdksCommon.cmake @@ -112,6 +112,8 @@ list(APPEND SDK_TEST_PROJECT_LIST "transcribestreaming:tests/aws-cpp-sdk-transcr list(APPEND SDK_TEST_PROJECT_LIST "eventbridge:tests/aws-cpp-sdk-eventbridge-tests") list(APPEND SDK_TEST_PROJECT_LIST "timestream-query:tests/aws-cpp-sdk-timestream-query-integration-tests") list(APPEND SDK_TEST_PROJECT_LIST "dsql:tests/aws-cpp-sdk-dsql-unit-tests") +list(APPEND SDK_TEST_PROJECT_LIST "bedrock-runtime:tests/aws-cpp-sdk-bedrock-runtime-integration-tests") + build_sdk_list() diff --git a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp index a0bb78004d0..b76ff25583a 100644 --- a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp @@ -289,6 +289,13 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr& { hashIterator.second->Update(reinterpret_cast(dst), static_cast(read)); } + + auto& headersHandler = request->GetHeadersReceivedEventHandler(); + if (headersHandler) + { + headersHandler(request.get(), response.get()); + } + if (readLimiter != nullptr) { readLimiter->ApplyAndPayForCost(read); diff --git a/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/CMakeLists.txt b/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/CMakeLists.txt new file mode 100644 index 00000000000..1bc1ef81abb --- /dev/null +++ b/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/CMakeLists.txt @@ -0,0 +1,35 @@ +add_project(aws-cpp-sdk-bedrock-runtime-integration-tests + "Tests for Bedrock Runtime C++ SDK" + aws-cpp-sdk-bedrock-runtime + testing-resources + aws-cpp-sdk-core +) + +# Headers are included in the source so that they show up in Visual Studio. +# They are included elsewhere for consistency. + +file(GLOB AWS_BEDROCK_RUNTIME_SRC + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +file(GLOB AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC + ${AWS_BEDROCK_RUNTIME_SRC} +) + +if(MSVC AND BUILD_SHARED_LIBS) + add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1) +endif() + +enable_testing() + +if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS) + add_library(${PROJECT_NAME} ${AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC}) +else() + add_executable(${PROJECT_NAME} ${AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC}) +endif() + +set_compiler_flags(${PROJECT_NAME}) +set_compiler_warnings(${PROJECT_NAME}) + +target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS}) + diff --git a/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/IntegrationTests.cpp b/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/IntegrationTests.cpp new file mode 100644 index 00000000000..4e6613d8255 --- /dev/null +++ b/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/IntegrationTests.cpp @@ -0,0 +1,77 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Aws::BedrockRuntime; +using namespace Aws::BedrockRuntime::Model; +using namespace Aws::Client; +using namespace Aws::Region; + + +namespace +{ +static const char* ALLOCATION_TAG = "BedrockRuntimeTests"; + +class BedrockRuntimeTests : public ::testing::Test +{ +protected: + std::shared_ptr m_client; + + void SetUp() + { + Aws::Client::ClientConfiguration config; + config.region = AWS_TEST_REGION; + m_client = Aws::MakeShared(ALLOCATION_TAG, config); + } + +}; + +TEST_F(BedrockRuntimeTests, TestStreaming) +{ + std::shared_ptr streamHandler = Aws::MakeShared(ALLOCATION_TAG); + + Aws::BedrockRuntime::Model::ConverseStreamRequest bedrockRequest; + // other request setup + std::mutex mutex; + std::condition_variable cv; + auto startTime = std::chrono::system_clock::now(); + bool responseReceived = false; + streamHandler->SetInitialResponseCallbackEx([&](const Aws::BedrockRuntime::Model::ConverseStreamInitialResponse& , const Aws::Utils::Event::InitialResponseType awsResponseType) + { + std::unique_lock lock(mutex); + if (awsResponseType == Aws::Utils::Event::InitialResponseType::ON_RESPONSE) { + responseReceived = true; + cv.notify_one(); + } + }); + + bedrockRequest.SetEventStreamHandler(*streamHandler); + bedrockRequest.SetModelId("dummy model"); + bedrockRequest.SetMessages({}); + Aws::BedrockRuntime::Model::ConverseStreamOutcome outcome = m_client->ConverseStream(bedrockRequest); + ASSERT_FALSE(outcome.IsSuccess()); + std::unique_lock lock(mutex); + cv.wait_until(lock, startTime + std::chrono::seconds(10), [&] { + return responseReceived; }); + ASSERT_TRUE(responseReceived); +} +} diff --git a/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/RunTests.cpp b/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/RunTests.cpp new file mode 100644 index 00000000000..a6f2d5668dd --- /dev/null +++ b/tests/aws-cpp-sdk-bedrock-runtime-integration-tests/RunTests.cpp @@ -0,0 +1,30 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include +#include +#include +#include +#include + +int main(int argc, char** argv) +{ + Aws::Testing::SetDefaultSigPipeHandler(); + Aws::SDKOptions options; + options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Trace; + AWS_BEGIN_MEMORY_TEST_EX(options, 1024, 128); + + Aws::Testing::InitPlatformTest(options); + Aws::Testing::ParseArgs(argc, argv); + + Aws::InitAPI(options); + ::testing::InitGoogleTest(&argc, argv); + int exitCode = RUN_ALL_TESTS(); + + Aws::ShutdownAPI(options); + AWS_END_MEMORY_TEST_EX; + Aws::Testing::ShutdownPlatformTest(options); + return exitCode; +} diff --git a/tools/scripts/run_integration_tests.py b/tools/scripts/run_integration_tests.py index 32783ffff01..790487d9436 100644 --- a/tools/scripts/run_integration_tests.py +++ b/tools/scripts/run_integration_tests.py @@ -54,7 +54,9 @@ def main(): "aws-cpp-sdk-elasticfilesystem-integration-tests", "aws-cpp-sdk-rds-integration-tests", "aws-cpp-sdk-ec2-integration-tests", - "aws-cpp-sdk-timestream-query-integration-tests"] + "aws-cpp-sdk-timestream-query-integration-tests", + "aws-cpp-sdk-bedrock-runtime-integration-tests" + ] # check for existence of these binaries before adding them to tests # as they will not always be present