From 12a6f34f1ece9958361cc7b346bfdd2dd1bd2d90 Mon Sep 17 00:00:00 2001 From: Ferenc Szabo Date: Thu, 23 Jan 2025 14:54:05 +0100 Subject: [PATCH] test: add new suite of tests after dropping Goblin tests Test that error conditions do not blow up the library. Since dropping the Goblin suite we did not test error handling. --- kms_test.go | 21 ++++++++++ ksmjwt_test.go | 103 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 kms_test.go diff --git a/kms_test.go b/kms_test.go new file mode 100644 index 0000000..69d951a --- /dev/null +++ b/kms_test.go @@ -0,0 +1,21 @@ +package kmsjwt_test + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/kms" +) + +type KMSStub struct { + Err error + PublicKey []byte +} + +func (k KMSStub) GetPublicKey(context.Context, *kms.GetPublicKeyInput, ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + return &kms.GetPublicKeyOutput{PublicKey: k.PublicKey}, k.Err +} + +func (k KMSStub) Sign(context.Context, *kms.SignInput, ...func(*kms.Options)) (*kms.SignOutput, error) { + // The message is already hashed, so we cannot produce a valid signature here. + return &kms.SignOutput{Signature: []byte("invalid")}, k.Err +} diff --git a/ksmjwt_test.go b/ksmjwt_test.go index eb3aa86..07d623c 100644 --- a/ksmjwt_test.go +++ b/ksmjwt_test.go @@ -2,8 +2,11 @@ package kmsjwt_test import ( "context" + "crypto/ed25519" + "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" "testing" "github.com/aws/aws-sdk-go-v2/config" @@ -92,9 +95,103 @@ func (c Client) GetPublicKey(t *testing.T, ctx context.Context, id string) *rsa. return key.(*rsa.PublicKey) } -func TestAlg(t *testing.T) { +func TestNew(t *testing.T) { + const keyID = "dummy" + + t.Run("happy", func(t *testing.T) { + _, _ = newSignerAndStub(t) + }) + + t.Run("error preserved in chain from KMS", func(t *testing.T) { + ctx := context.Background() + want := errors.New("something went wrong") + + _, err := kmsjwt.New(ctx, KMSStub{Err: want}, keyID) + assert.ErrorIs(t, err, want) + }) + + t.Run("wrong key type", func(t *testing.T) { + ctx := context.Background() + publicKey := encodedED25519PublicKey(t) + + _, err := kmsjwt.New(ctx, KMSStub{PublicKey: publicKey}, keyID) + assert.ErrorContains(t, err, "cannot assert") + }) + + t.Run("key not parsable", func(t *testing.T) { + ctx := context.Background() + publicKey := []byte("something unexpected") + + _, err := kmsjwt.New(ctx, KMSStub{PublicKey: publicKey}, keyID) + assert.ErrorContains(t, err, "could not parse") + }) +} + +func newSignerAndStub(t *testing.T) (*kmsjwt.KMSJWT, *KMSStub) { + t.Helper() + const keyID = "dummy" + ctx := context.Background() + stub := &KMSStub{PublicKey: encodedRSAPublicKey(t)} + signer, err := kmsjwt.New(ctx, stub, keyID) + require.NoError(t, err, "creating signer") + return signer, stub +} + +func encodedRSAPublicKey(t *testing.T) []byte { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err, "generating RSA key") + return encode(t, &key.PublicKey) +} + +func encode(t *testing.T, publicKey any) []byte { + t.Helper() + encoded, err := x509.MarshalPKIXPublicKey(publicKey) + require.NoError(t, err, "encoding public key") + return encoded +} + +func encodedED25519PublicKey(t *testing.T) []byte { + t.Helper() + publicKey, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err, "generating ed25519 key") + return encode(t, publicKey) +} + +func TestKMSJWT_Alg(t *testing.T) { // Valid values: https://datatracker.ietf.org/doc/html/rfc7518#section-3.1 const want = "PS512" - got := kmsjwt.KMSJWT{}.Alg() - assert.Equal(t, want, got, "algorithm changed, that's MAJOR change") + signer, _ := newSignerAndStub(t) + assert.Equal(t, want, signer.Alg(), "algorithm changed, that's MAJOR change") +} + +func TestKMSJWT_Sign(t *testing.T) { + const signMe = "sign me, please" + + t.Run("invalid key type", func(t *testing.T) { + signer, _ := newSignerAndStub(t) + + _, err := signer.Sign(signMe, "foo") + assert.ErrorIs(t, err, jwt.ErrInvalidKeyType) + }) + + t.Run("error preserved in chain", func(t *testing.T) { + ctx := context.Background() + signer, stub := newSignerAndStub(t) + stub.Err = errors.New("something went wrong") + + _, err := signer.Sign(signMe, ctx) + assert.ErrorIs(t, err, stub.Err) + }) +} + +func TestKMSJWT_Verify(t *testing.T) { + const signMe = "sign me, please" + + t.Run("invalid key type", func(t *testing.T) { + signer, _ := newSignerAndStub(t) + + err := signer.Verify(signMe, "invalid signature", "foo") + assert.ErrorIs(t, err, jwt.ErrInvalidKeyType) + }) }