Skip to content

Commit 5dcfb0f

Browse files
committed
Add constructor overload to STSProfileCredentialsProvider where the client factory returns a shared pointer.
1 parent 6584675 commit 5dcfb0f

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,47 @@ namespace Aws
5252
*/
5353
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration = std::chrono::minutes(60));
5454

55+
/**
56+
* Use the provided profile name from the shared configuration file and a custom STS client.
57+
*
58+
* @param profileName The name of the profile in the shared configuration file.
59+
* @param duration The duration, in minutes, of the role session, after which the credentials are expired.
60+
* The value can range from 15 minutes up to the maximum session duration setting for the role. By default,
61+
* the duration is set to 1 hour.
62+
* Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That
63+
* ensures the credentials do not expire between the time they're checked and the time they're returned to
64+
* the user.
65+
* If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only
66+
* when they expire.
67+
* @param stsClientFactory A factory function that creates an STSClient with specific credentials.
68+
* Using the overload where the function returns a shared_ptr is preferred.
69+
*
70+
*/
5571
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<Aws::STS::STSClient*(const AWSCredentials&)> &stsClientFactory);
5672

73+
/**
74+
* Use the provided profile name from the shared configuration file and a custom STS client.
75+
*
76+
* @param profileName The name of the profile in the shared configuration file.
77+
* @param duration The duration, in minutes, of the role session, after which the credentials are expired.
78+
* The value can range from 15 minutes up to the maximum session duration setting for the role. By default,
79+
* the duration is set to 1 hour.
80+
* Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That
81+
* ensures the credentials do not expire between the time they're checked and the time they're returned to
82+
* the user.
83+
* If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only
84+
* when they expire.
85+
* @param stsClientFactory A factory function that creates an STSClient with specific credentials.
86+
*
87+
*/
88+
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<std::shared_ptr<Aws::STS::STSClient>(const AWSCredentials&)> &stsClientFactory);
89+
90+
/**
91+
* Compatibility constructor to assist with overload resolution when passing nullptr for the client factory.
92+
*
93+
*/
94+
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t);
95+
5796
/**
5897
* Fetches the credentials set from STS following the rules defined in the shared configuration file.
5998
*/
@@ -74,7 +113,7 @@ namespace Aws
74113
AWSCredentials m_credentials;
75114
const std::chrono::minutes m_duration;
76115
const std::chrono::milliseconds m_reloadFrequency;
77-
std::function<Aws::STS::STSClient*(const AWSCredentials&)> m_stsClientFactory;
116+
std::function<std::shared_ptr<Aws::STS::STSClient>(const AWSCredentials&)> m_stsClientFactory;
78117
};
79118
}
80119
}

src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ using namespace Aws::Auth;
1717

1818
constexpr char CLASS_TAG[] = "STSProfileCredentialsProvider";
1919

20+
template <typename T>
21+
struct NoOpDeleter
22+
{
23+
void operator()(T*) {}
24+
};
25+
2026
STSProfileCredentialsProvider::STSProfileCredentialsProvider()
2127
: STSProfileCredentialsProvider(GetConfigProfileName(), std::chrono::minutes(60)/*duration*/, nullptr/*stsClientFactory*/)
2228
{
@@ -27,8 +33,24 @@ STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String&
2733
{
2834
}
2935

36+
STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t)
37+
: m_profileName(profileName),
38+
m_duration(duration),
39+
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
40+
m_stsClientFactory(nullptr)
41+
{
42+
}
43+
3044
STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<Aws::STS::STSClient*(const AWSCredentials&)> &stsClientFactory)
3145
: m_profileName(profileName),
46+
m_duration(duration),
47+
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
48+
m_stsClientFactory([=](const auto& credentials) {return std::shared_ptr<Aws::STS::STSClient>(stsClientFactory(credentials), NoOpDeleter<Aws::STS::STSClient>()); })
49+
{
50+
}
51+
52+
STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<std::shared_ptr<Aws::STS::STSClient> (const AWSCredentials&)>& stsClientFactory)
53+
: m_profileName(profileName),
3254
m_duration(duration),
3355
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
3456
m_stsClientFactory(stsClientFactory)
@@ -337,7 +359,8 @@ AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCre
337359
{
338360
using namespace Aws::STS::Model;
339361
if (m_stsClientFactory) {
340-
return GetCredentialsFromSTSInternal(roleArn, m_stsClientFactory(credentials));
362+
auto client = m_stsClientFactory(credentials);
363+
return GetCredentialsFromSTSInternal(roleArn, client.get());
341364
}
342365

343366
Aws::STS::STSClient stsClient {credentials};

0 commit comments

Comments
 (0)