-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkmsjwt.go
104 lines (88 loc) · 3.31 KB
/
kmsjwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package kmsjwt
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/golang-jwt/jwt/v4"
)
var (
	// signingMethod is the golang-jwt signing method that KMSJWT delegates to
	// for the algorithm name, the hash function and client-side verification.
	signingMethod = jwt.SigningMethodPS512

	// Compile-time check that KMSJWT satisfies jwt.SigningMethod.
	_ jwt.SigningMethod = KMSJWT{}
)
// KMSJWT implements a JWT signing method backed by an asymmetric AWS KMS key.
//
// Signing is delegated to the KMS service, so every Sign call performs a
// network request. Verification happens locally against the public key that
// New fetched from KMS when the value was constructed.
type KMSJWT struct {
	// client is the KMS API handle used for Sign requests.
	client KMS
	// keyID identifies the KMS key that signs tokens.
	keyID string
	// publicKey is the RSA public key used for local verification.
	publicKey *rsa.PublicKey
}
// New retrieves the public key for keyID from KMS and returns a KMSJWT that
// signs through client and verifies locally with the retrieved key.
//
// ctx is used for the public key lookup performed during construction.
// If the key cannot be retrieved, parsed, or is not an RSA public key, New
// returns the zero KMSJWT and an error wrapping the underlying cause.
func New(ctx context.Context, client KMS, keyID string) (KMSJWT, error) {
	publicKey, err := getPublicKey(ctx, client, keyID)
	if err != nil {
		return KMSJWT{}, fmt.Errorf("kmsjwt new: %w", err)
	}

	// Return a literal nil on success rather than the err variable, so the
	// success path stays correct even if code is later added above.
	return KMSJWT{client: client, keyID: keyID, publicKey: publicKey}, nil
}
// getPublicKey requests the public key of keyID through client and decodes it
// as an RSA public key.
//
// The returned key material is parsed with x509.ParsePKIXPublicKey. An error
// is returned when the request fails, when parsing fails, or when the parsed
// key is not an *rsa.PublicKey.
func getPublicKey(ctx context.Context, client KMS, keyID string) (*rsa.PublicKey, error) {
	input := &kms.GetPublicKeyInput{KeyId: &keyID}
	output, err := client.GetPublicKey(ctx, input)
	if err != nil {
		return nil, fmt.Errorf("could not retrieve public key: %w", err)
	}

	parsed, err := x509.ParsePKIXPublicKey(output.PublicKey)
	if err != nil {
		return nil, fmt.Errorf("could not parse public key: %w", err)
	}

	switch typed := parsed.(type) {
	case *rsa.PublicKey:
		return typed, nil
	default:
		// The typed nil only supplies the wanted type name for the message.
		return nil, fmt.Errorf("public key: cannot assert %T as %T", parsed, (*rsa.PublicKey)(nil))
	}
}
// Alg reports the name of the signing algorithm, as listed in
// https://datatracker.ietf.org/doc/html/rfc7518#section-3.1.
// The value is taken from the package-level signingMethod.
func (k KMSJWT) Alg() string {
	name := signingMethod.Alg()
	return name
}
// Sign produces a signature for signingString using the KMS key stored on k.
//
// The key argument must be a context.Context; it is used for the KMS network
// call. Any other type yields an error wrapping jwt.ErrInvalidKeyType.
//
// signingString is hashed locally with signingMethod's hash function and only
// the digest is sent to KMS. The returned string is the signature encoded
// with unpadded URL-safe base64.
func (k KMSJWT) Sign(signingString string, key interface{}) (string, error) {
	ctx, isContext := key.(context.Context)
	if !isContext {
		return "", fmt.Errorf("kmsjwt sign: %w", jwt.ErrInvalidKeyType)
	}

	hasher := signingMethod.Hash.New()
	if _, err := hasher.Write([]byte(signingString)); err != nil {
		return "", fmt.Errorf("kmsjwt writing hash: %w", err)
	}

	// The message is already a digest, so KMS is told not to hash it again.
	input := &kms.SignInput{
		KeyId:            aws.String(k.keyID),
		Message:          hasher.Sum(nil),
		MessageType:      types.MessageTypeDigest,
		SigningAlgorithm: types.SigningAlgorithmSpecRsassaPssSha512,
	}
	output, err := k.client.Sign(ctx, input)
	if err != nil {
		return "", fmt.Errorf("kmsjwt signing with KMS: %w", err)
	}

	return base64.RawURLEncoding.EncodeToString(output.Signature), nil
}
// Verify checks that stringSignature is a valid signature of signingString.
//
// Verification runs locally through signingMethod with the rsa.PublicKey
// stored on k; no KMS call is made. The key argument must be a
// context.Context, otherwise an error wrapping jwt.ErrInvalidKeyType is
// returned.
func (k KMSJWT) Verify(signingString, stringSignature string, key interface{}) error {
	// The context itself is unused. It is still required so that Verify takes
	// the same key type as Sign, and so that a context can be put to use later
	// without changing what callers pass.
	if _, isContext := key.(context.Context); !isContext {
		return fmt.Errorf("kmsjwt verify: %w", jwt.ErrInvalidKeyType)
	}

	return signingMethod.Verify(signingString, stringSignature, k.publicKey)
}