Skip to content

Commit

Permalink
Migrate to Go v2 SDK
Browse files Browse the repository at this point in the history
* Pull in files related to Go v2 SDK for Roles Anywhere
* Downstream changes to code calling CreateSession with new client
  • Loading branch information
13ajay committed Feb 19, 2025
1 parent f695756 commit 24d4dae
Show file tree
Hide file tree
Showing 36 changed files with 5,474 additions and 1,000 deletions.
80 changes: 60 additions & 20 deletions aws_signing_helper/credentials.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
package aws_signing_helper

import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"runtime"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/rolesanywhere-credential-helper/rolesanywhere"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)

type CredentialsOpts struct {
Expand All @@ -38,6 +41,18 @@ type CredentialsOpts struct {
RoleSessionName string
}

// Middleware to set a custom user agent header
func createCredHelperUserAgentMiddleware(userAgent string) middleware.BuildMiddleware {
return middleware.BuildMiddlewareFunc("UserAgent", func(
ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler,
) (middleware.BuildOutput, middleware.Metadata, error) {
if req, ok := input.Request.(*smithyhttp.Request); ok {
req.Header.Set("User-Agent", userAgent)
}
return next.HandleBuild(ctx, input)
})
}

// Function to create session and generate credentials
func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (CredentialProcessOutput, error) {
// Assign values to region and endpoint if they haven't already been assigned
Expand All @@ -58,15 +73,12 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
opts.Region = trustAnchorArn.Region
}

mySession := session.Must(session.NewSession())

var logLevel aws.LogLevelType
var logMode aws.ClientLogMode = 0
if Debug {
logLevel = aws.LogDebug
} else {
logLevel = aws.LogOff
logMode = aws.LogSigning | aws.LogRetries | aws.LogRequestWithBody | aws.LogResponseWithBody | aws.LogRequestEventMessage | aws.LogResponseEventMessage
}

// Custom HTTP client with proxy and TLS settings
var tr *http.Transport
if opts.WithProxy {
tr = &http.Transport{
Expand All @@ -78,15 +90,26 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: opts.NoVerifySSL},
}
}
client := &http.Client{Transport: tr}
config := aws.NewConfig().WithRegion(opts.Region).WithHTTPClient(client).WithLogLevel(logLevel)
httpClient := &http.Client{Transport: tr}
ctx := context.TODO()
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(opts.Region), config.WithHTTPClient(httpClient), config.WithClientLogMode(logMode))
if err != nil {
return CredentialProcessOutput{}, err
}

// Override endpoint if specified
if opts.Endpoint != "" {
config.WithEndpoint(opts.Endpoint)
cfg.BaseEndpoint = aws.String(opts.Endpoint)
}
rolesAnywhereClient := rolesanywhere.New(mySession, config)
rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler")
rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: request.MakeAddToUserAgentHandler("CredHelper", opts.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)})
rolesAnywhereClient.Handlers.Sign.Clear()

// Set a custom user agent
userAgentStr := fmt.Sprintf("CredHelper/%s (%s; %s; %s)", opts.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error {
stack.Build.Remove("UserAgent")
return stack.Build.Add(createCredHelperUserAgentMiddleware(userAgentStr), middleware.After)
})

// Add custom request signer, implementing SigV4-X509
certificate, err := signer.Certificate()
if err != nil {
return CredentialProcessOutput{}, errors.New("unable to find certificate")
Expand All @@ -98,10 +121,27 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
log.Println(err)
}
}
rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: CreateRequestSignFunction(signer, signatureAlgorithm, certificate, certificateChain)})
cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error {
for _, name := range stack.Finalize.List() {
fmt.Println(name)
}
// Remove middleware related to SigV4 signing
stack.Finalize.Remove("Signing")
stack.Finalize.Remove("setLegacyContextSigningOptions")
stack.Finalize.Remove("GetIdentity")
// Add middleware for SigV4-X509 signing
stack.Finalize.Add(middleware.FinalizeMiddlewareFunc("Signing", CreateRequestSignFinalizeFunction(signer, opts.Region, signatureAlgorithm, certificate, certificateChain)), middleware.After)
for _, name := range stack.Finalize.List() {
fmt.Println(name)
}
return nil
})

// Create the Roles Anywhere client using the above-constructed Config
rolesAnywhereClient := rolesanywhere.NewFromConfig(cfg)

certificateStr := base64.StdEncoding.EncodeToString(certificate.Raw)
durationSeconds := int64(opts.SessionDuration)
durationSeconds := int32(opts.SessionDuration)
createSessionRequest := rolesanywhere.CreateSessionInput{
Cert: &certificateStr,
ProfileArn: &opts.ProfileArnStr,
Expand All @@ -114,7 +154,7 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
if opts.RoleSessionName != "" {
createSessionRequest.RoleSessionName = &opts.RoleSessionName
}
output, err := rolesAnywhereClient.CreateSession(&createSessionRequest)
output, err := rolesAnywhereClient.CreateSession(ctx, &createSessionRequest)
if err != nil {
return CredentialProcessOutput{}, err
}
Expand Down
96 changes: 36 additions & 60 deletions aws_signing_helper/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws_signing_helper

import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/rand"
Expand All @@ -24,8 +25,9 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"golang.org/x/crypto/pkcs12"
"golang.org/x/term"
)
Expand Down Expand Up @@ -57,6 +59,9 @@ var (
"Trust",
"CA",
}

// Signing name for the IAM Roles Anywhere service
ROLESANYWHERE_SIGNING_NAME = "rolesanywhere"
)

// Interface that all signers will have to implement
Expand Down Expand Up @@ -425,75 +430,46 @@ func certificateChainToString(certificateChain []*x509.Certificate) string {
return x509ChainString.String()
}

func CreateRequestSignFunction(signer crypto.Signer, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate) func(*request.Request) {
return func(req *request.Request) {
region := req.ClientInfo.SigningRegion
if region == "" {
region = aws.StringValue(req.Config.Region)
}

name := req.ClientInfo.SigningName
if name == "" {
name = req.ClientInfo.ServiceName
}

signerParams := SignerParams{time.Now(), region, name, signingAlgorithm}

// Set headers that are necessary for signing
req.HTTPRequest.Header.Set(host, req.HTTPRequest.URL.Host)
req.HTTPRequest.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime())
req.HTTPRequest.Header.Set(x_amz_x509, certificateToString(certificate))
if certificateChain != nil {
req.HTTPRequest.Header.Set(x_amz_x509_chain, certificateChainToString(certificateChain))
}

contentSha256 := calculateContentHash(req.HTTPRequest, req.Body)
if req.HTTPRequest.Header.Get(x_amz_content_sha256) == "required" {
req.HTTPRequest.Header.Set(x_amz_content_sha256, contentSha256)
func CreateRequestSignFinalizeFunction(signer crypto.Signer, signingRegion string, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate) func(context.Context, middleware.FinalizeInput, middleware.FinalizeHandler) (middleware.FinalizeOutput, middleware.Metadata, error) {
return func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, errors.New(fmt.Sprintf("unexpected request middleware type %T", in.Request))
}

canonicalRequest, signedHeadersString := createCanonicalRequest(req.HTTPRequest, req.Body, contentSha256)
payloadHash := v4.GetPayloadHash(ctx)
signRequest(signer, signingRegion, signingAlgorithm, certificate, certificateChain, req.Request, payloadHash)

stringToSign := CreateStringToSign(canonicalRequest, signerParams)
signatureBytes, err := signer.Sign(rand.Reader, []byte(stringToSign), crypto.SHA256)
if err != nil {
log.Println("could not sign", err)
os.Exit(1)
}
signature := hex.EncodeToString(signatureBytes)

req.HTTPRequest.Header.Set(authorization, BuildAuthorizationHeader(req.HTTPRequest, req.Body, signedHeadersString, signature, certificate, signerParams))
req.SignedHeaderVals = req.HTTPRequest.Header
return next.HandleFinalize(ctx, in)
}
}

// Find the SHA256 hash of the provided request body as a io.ReadSeeker
func makeSha256Reader(reader io.ReadSeeker) []byte {
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
func signRequest(signer crypto.Signer, signingRegion string, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate, req *http.Request, payloadHash string) {
signerParams := SignerParams{time.Now(), signingRegion, ROLESANYWHERE_SIGNING_NAME, signingAlgorithm}

io.Copy(hash, reader)
return hash.Sum(nil)
}
// Set headers that are necessary for signing
req.Header.Set(host, req.URL.Host)
req.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime())
req.Header.Set(x_amz_x509, certificateToString(certificate))
if certificateChain != nil {
req.Header.Set(x_amz_x509_chain, certificateChainToString(certificateChain))
}

// Calculate the hash of the request body
func calculateContentHash(r *http.Request, body io.ReadSeeker) string {
hash := r.Header.Get(x_amz_content_sha256)
canonicalRequest, signedHeadersString := createCanonicalRequest(req, payloadHash)

if hash == "" {
if body == nil {
hash = emptyStringSHA256
} else {
hash = hex.EncodeToString(makeSha256Reader(body))
}
stringToSign := CreateStringToSign(canonicalRequest, signerParams)
signatureBytes, err := signer.Sign(rand.Reader, []byte(stringToSign), crypto.SHA256)
if err != nil {
log.Println("could not sign request", err)
os.Exit(1)
}
signature := hex.EncodeToString(signatureBytes)

return hash
req.Header.Set(authorization, BuildAuthorizationHeader(req, signedHeadersString, signature, certificate, signerParams))
}

// Create the canonical query string.
func createCanonicalQueryString(r *http.Request, body io.ReadSeeker) string {
func createCanonicalQueryString(r *http.Request) string {
rawQuery := strings.Replace(r.URL.Query().Encode(), "+", "%20", -1)
return rawQuery
}
Expand Down Expand Up @@ -573,14 +549,14 @@ func stripExcessSpaces(vals []string) {
}

// Create the canonical request.
func createCanonicalRequest(r *http.Request, body io.ReadSeeker, contentSha256 string) (string, string) {
func createCanonicalRequest(r *http.Request, contentSha256 string) (string, string) {
var canonicalRequestStrBuilder strings.Builder
canonicalHeaderString, signedHeadersString := createCanonicalHeaderString(r)
canonicalRequestStrBuilder.WriteString("POST")
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString("/sessions")
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString(createCanonicalQueryString(r, body))
canonicalRequestStrBuilder.WriteString(createCanonicalQueryString(r))
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString(canonicalHeaderString)
canonicalRequestStrBuilder.WriteString("\n\n")
Expand All @@ -607,7 +583,7 @@ func CreateStringToSign(canonicalRequest string, signerParams SignerParams) stri
}

// Builds the complete authorization header
func BuildAuthorizationHeader(request *http.Request, body io.ReadSeeker, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string {
func BuildAuthorizationHeader(request *http.Request, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string {
signingCredentials := certificate.SerialNumber.String() + "/" + signerParams.GetScope()
credential := "Credential=" + signingCredentials
signerHeaders := "SignedHeaders=" + signedHeadersString
Expand Down
11 changes: 4 additions & 7 deletions aws_signing_helper/signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"testing"
"time"
"unicode/utf8"

"github.com/aws/aws-sdk-go/aws/request"
)

const TestCredentialsFilePath = "/tmp/credentials"
Expand Down Expand Up @@ -125,7 +123,6 @@ func TestBuildAuthorizationHeader(t *testing.T) {
certificate1 := certificateList1[0]
pkPath := "../tst/certs/rsa-2048-key.pem"

awsRequest := request.Request{HTTPRequest: testRequest}
signer, signingAlgorithm, err := GetFileSystemSigner(pkPath, "", path, false)
if err != nil {
t.Log(err)
Expand All @@ -151,8 +148,9 @@ func TestBuildAuthorizationHeader(t *testing.T) {
t.Fail()
}
}
requestSignFunction := CreateRequestSignFunction(signer, signingAlgorithm, certificate, certificateChain)
requestSignFunction(&awsRequest)
signingRegion := "us-west-2"
emptyStringSHA256 := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
signRequest(signer, signingRegion, signingAlgorithm, certificate, certificateChain, testRequest, emptyStringSHA256)

certificateList2, _ := ReadCertificateBundleData("../tst/certs/rsa-4096-sha256-cert.pem")
certificate2 := certificateList2[0]
Expand Down Expand Up @@ -180,8 +178,7 @@ func TestBuildAuthorizationHeader(t *testing.T) {
}
os.Rename("../tst/certs/rsa-2048-sha256-cert.pem", "../tst/certs/rsa-4096-sha256-cert.pem")
os.Rename("../tst/certs/rsa-2048-sha256-cert.pem.bak", "../tst/certs/rsa-2048-sha256-cert.pem")
requestSignFunction2 := CreateRequestSignFunction(signer, signingAlgorithm, certificate, certificateChain)
requestSignFunction2(&awsRequest)
signRequest(signer, signingRegion, signingAlgorithm, certificate, certificateChain, testRequest, emptyStringSHA256)
}

func TestSign(t *testing.T) {
Expand Down
15 changes: 13 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ go 1.22.5

require (
github.com/aws/aws-sdk-go v1.55.5
github.com/aws/aws-sdk-go-v2 v1.36.1
github.com/aws/aws-sdk-go-v2/config v1.29.6
github.com/aws/smithy-go v1.22.2
github.com/google/go-tpm v0.3.3
github.com/miekg/pkcs11 v1.1.1
github.com/spf13/cobra v1.8.1
Expand All @@ -15,8 +18,16 @@ require (
)

require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.59 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.28 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.32 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.32 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.13 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.15 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.14 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.14 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
Loading

0 comments on commit 24d4dae

Please sign in to comment.