Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix streaming initial response for windows clients #3347

Merged
merged 3 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmake/sdksCommon.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ list(APPEND SDK_TEST_PROJECT_LIST "transcribestreaming:tests/aws-cpp-sdk-transcr
list(APPEND SDK_TEST_PROJECT_LIST "eventbridge:tests/aws-cpp-sdk-eventbridge-tests")
list(APPEND SDK_TEST_PROJECT_LIST "timestream-query:tests/aws-cpp-sdk-timestream-query-integration-tests")
list(APPEND SDK_TEST_PROJECT_LIST "dsql:tests/aws-cpp-sdk-dsql-unit-tests")
list(APPEND SDK_TEST_PROJECT_LIST "bedrock-runtime:tests/aws-cpp-sdk-bedrock-runtime-integration-tests")


build_sdk_list()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,13 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&
{
hashIterator.second->Update(reinterpret_cast<unsigned char*>(dst), static_cast<size_t>(read));
}

auto& headersHandler = request->GetHeadersReceivedEventHandler();
if (headersHandler)
{
headersHandler(request.get(), response.get());
}

if (readLimiter != nullptr)
{
readLimiter->ApplyAndPayForCost(read);
Expand Down
35 changes: 35 additions & 0 deletions tests/aws-cpp-sdk-bedrock-runtime-integration-tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
add_project(aws-cpp-sdk-bedrock-runtime-integration-tests
"Tests for Bedrock Runtime C++ SDK"
aws-cpp-sdk-bedrock-runtime
testing-resources
aws-cpp-sdk-core
)

# Headers are included in the source so that they show up in Visual Studio.
# They are included elsewhere for consistency.

file(GLOB AWS_BEDROCK_RUNTIME_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
)

file(GLOB AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC
${AWS_BEDROCK_RUNTIME_SRC}
)

if(MSVC AND BUILD_SHARED_LIBS)
add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1)
endif()

enable_testing()

if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS)
add_library(${PROJECT_NAME} ${AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC})
else()
add_executable(${PROJECT_NAME} ${AWS_BEDROCK_RUNTIME_INTEGRATION_TESTS_SRC})
endif()

set_compiler_flags(${PROJECT_NAME})
set_compiler_warnings(${PROJECT_NAME})

target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS})

Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#include <gtest/gtest.h>
#include <aws/testing/AwsTestHelpers.h>
#include <aws/testing/MemoryTesting.h>
#include <algorithm>
#include <thread>

#include <aws/bedrock-runtime/BedrockRuntimeClient.h>
#include <aws/bedrock-runtime/BedrockRuntimeErrors.h>
#include <aws/bedrock-runtime/model/ConverseStreamRequest.h>
#include <aws/bedrock-runtime/model/ConverseStreamHandler.h>
#include <aws/core/client/CoreErrors.h>
#include <aws/core/utils/json/JsonSerializer.h>
#include <aws/core/utils/Outcome.h>
#include <aws/testing/TestingEnvironment.h>
#include <aws/core/platform/Environment.h>
#include <condition_variable>
#include <chrono>

using namespace Aws::BedrockRuntime;
using namespace Aws::BedrockRuntime::Model;
using namespace Aws::Client;
using namespace Aws::Region;


namespace
{
static const char* ALLOCATION_TAG = "BedrockRuntimeTests";

class BedrockRuntimeTests : public ::testing::Test
{
protected:
std::shared_ptr<BedrockRuntimeClient> m_client;

void SetUp()
{
Aws::Client::ClientConfiguration config;
config.region = AWS_TEST_REGION;
m_client = Aws::MakeShared<BedrockRuntimeClient>(ALLOCATION_TAG, config);
}

};

TEST_F(BedrockRuntimeTests, TestStreaming)
{
std::shared_ptr<Aws::BedrockRuntime::Model::ConverseStreamHandler> streamHandler = Aws::MakeShared<Aws::BedrockRuntime::Model::ConverseStreamHandler>(ALLOCATION_TAG);

Aws::BedrockRuntime::Model::ConverseStreamRequest bedrockRequest;
// other request setup
std::mutex mutex;
std::condition_variable cv;
auto startTime = std::chrono::system_clock::now();
bool responseReceived = false;
streamHandler->SetInitialResponseCallbackEx([&](const Aws::BedrockRuntime::Model::ConverseStreamInitialResponse& , const Aws::Utils::Event::InitialResponseType awsResponseType)
{
std::unique_lock<std::mutex> lock(mutex);
if (awsResponseType == Aws::Utils::Event::InitialResponseType::ON_RESPONSE) {
responseReceived = true;
cv.notify_one();
}
});

bedrockRequest.SetEventStreamHandler(*streamHandler);
bedrockRequest.SetModelId("dummy model");
bedrockRequest.SetMessages({});
Aws::BedrockRuntime::Model::ConverseStreamOutcome outcome = m_client->ConverseStream(bedrockRequest);
ASSERT_FALSE(outcome.IsSuccess());
std::unique_lock<std::mutex> lock(mutex);
cv.wait_until(lock, startTime + std::chrono::seconds(10), [&] {
return responseReceived; });
ASSERT_TRUE(responseReceived);
}
}
30 changes: 30 additions & 0 deletions tests/aws-cpp-sdk-bedrock-runtime-integration-tests/RunTests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#include <gtest/gtest.h>
#include <aws/core/Aws.h>
#include <aws/testing/platform/PlatformTesting.h>
#include <aws/testing/TestingEnvironment.h>
#include <aws/testing/MemoryTesting.h>

int main(int argc, char** argv)
{
Aws::Testing::SetDefaultSigPipeHandler();
Aws::SDKOptions options;
options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Trace;
AWS_BEGIN_MEMORY_TEST_EX(options, 1024, 128);

Aws::Testing::InitPlatformTest(options);
Aws::Testing::ParseArgs(argc, argv);

Aws::InitAPI(options);
::testing::InitGoogleTest(&argc, argv);
int exitCode = RUN_ALL_TESTS();

Aws::ShutdownAPI(options);
AWS_END_MEMORY_TEST_EX;
Aws::Testing::ShutdownPlatformTest(options);
return exitCode;
}
4 changes: 3 additions & 1 deletion tools/scripts/run_integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def main():
"aws-cpp-sdk-elasticfilesystem-integration-tests",
"aws-cpp-sdk-rds-integration-tests",
"aws-cpp-sdk-ec2-integration-tests",
"aws-cpp-sdk-timestream-query-integration-tests"]
"aws-cpp-sdk-timestream-query-integration-tests",
"aws-cpp-sdk-bedrock-runtime-integration-tests"
]

# check for existence of these binaries before adding them to tests
# as they will not always be present
Expand Down
Loading