Skip to content

Commit bd40b61

Browse files
committed
add more tests
1 parent 397cbdb commit bd40b61

File tree

5 files changed

+105
-52
lines changed

5 files changed

+105
-52
lines changed

azure_default_identity_provider.go

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,31 @@ type DefaultAzureIdentityProviderOptions struct {
1515
AzureOptions *azidentity.DefaultAzureCredentialOptions
1616
// Scopes is the list of scopes used to request a token from the identity provider.
1717
Scopes []string
18+
19+
// credFactory is a factory for creating the default Azure credential.
20+
// This is used for testing purposes, to allow mocking the credential creation.
21+
// If not provided, the default implementation - azidentity.NewDefaultAzureCredential will be used
22+
credFactory credFactory
23+
}
24+
25+
type credFactory interface {
26+
NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (defaultAzureCredential, error)
1827
}
1928

2029
type defaultAzureCredential interface {
2130
GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error)
2231
}
2332

33+
type defaultCredFactory struct{}
34+
35+
func (d *defaultCredFactory) NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (defaultAzureCredential, error) {
36+
return azidentity.NewDefaultAzureCredential(options)
37+
}
38+
2439
type DefaultAzureIdentityProvider struct {
25-
options *azidentity.DefaultAzureCredentialOptions
26-
cred defaultAzureCredential
27-
scopes []string
40+
options *azidentity.DefaultAzureCredentialOptions
41+
credFactory credFactory
42+
scopes []string
2843
}
2944

3045
// NewDefaultAzureIdentityProvider creates a new DefaultAzureIdentityProvider.
@@ -33,21 +48,26 @@ func NewDefaultAzureIdentityProvider(opts DefaultAzureIdentityProviderOptions) (
3348
opts.Scopes = []string{RedisScopeDefault}
3449
}
3550

36-
return &DefaultAzureIdentityProvider{options: opts.AzureOptions, scopes: opts.Scopes}, nil
51+
return &DefaultAzureIdentityProvider{
52+
options: opts.AzureOptions,
53+
scopes: opts.Scopes,
54+
credFactory: opts.credFactory,
55+
}, nil
3756
}
3857

3958
// RequestToken requests a token from the Azure Default Identity provider.
4059
// It returns the token, the expiration time, and an error if any.
4160
func (a *DefaultAzureIdentityProvider) RequestToken() (IdentityProviderResponse, error) {
42-
var err error
43-
if a.cred == nil {
44-
a.cred, err = azidentity.NewDefaultAzureCredential(a.options)
45-
if err != nil {
46-
return nil, fmt.Errorf("failed to create default azure credential: %w", err)
47-
}
61+
credFactory := a.credFactory
62+
if credFactory == nil {
63+
credFactory = &defaultCredFactory{}
64+
}
65+
cred, err := credFactory.NewDefaultAzureCredential(a.options)
66+
if err != nil {
67+
return nil, fmt.Errorf("failed to create default azure credential: %w", err)
4868
}
4969

50-
token, err := a.cred.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: a.scopes})
70+
token, err := cred.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: a.scopes})
5171
if err != nil {
5272
return nil, fmt.Errorf("failed to get token: %w", err)
5373
}

azure_default_identity_provider_test.go

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,19 @@ func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) {
4949
assert.Error(t, err, "failed to request token")
5050

5151
// use mockAzureCredential to simulate the environment
52-
mockCreds := &mockAzureCredential{}
53-
provider.cred = mockCreds
54-
mockToken := azcore.AccessToken{
52+
mToken := azcore.AccessToken{
5553
Token: testJWTtoken,
5654
}
57-
mockCreds.On("GetToken", mock.Anything, mock.Anything).Return(mockToken, nil)
58-
55+
mCreds := &mockAzureCredential{}
56+
mCreds.On("GetToken", mock.Anything, mock.Anything).Return(mToken, nil)
57+
mCredFactory := &mockCredFactory{}
58+
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil)
59+
provider.credFactory = mCredFactory
5960
token, err = provider.RequestToken()
6061
assert.NotNil(t, token, "token should not be nil")
6162
assert.NoError(t, err, "failed to request token")
6263
assert.Equal(t, ResponseTypeAccessToken, token.Type(), "token type should be access token")
63-
assert.Equal(t, mockToken, token.AccessToken(), "access token should be equal to testJWTtoken")
64+
assert.Equal(t, mToken, token.AccessToken(), "access token should be equal to testJWTtoken")
6465
}
6566

6667
func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) {
@@ -73,22 +74,34 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) {
7374
t.Fatalf("failed to create DefaultAzureIdentityProvider: %v", err)
7475
}
7576

76-
// Request a token from the provider
77-
token, err := provider.RequestToken()
78-
assert.Nil(t, token, "token should be nil")
79-
assert.Error(t, err, "failed to request token")
80-
81-
// use mockAzureCredential to simulate the environment
82-
mockCreds := &mockAzureCredential{}
83-
provider.cred = mockCreds
84-
mockToken := azcore.AccessToken{
85-
Token: testJWTtoken,
86-
}
87-
mockCreds.On("GetToken", mock.Anything, policy.TokenRequestOptions{Scopes: scopes}).Return(mockToken, nil)
77+
t.Run("RequestToken with custom scopes", func(t *testing.T) {
78+
// Request a token from the provider
79+
token, err := provider.RequestToken()
80+
assert.Nil(t, token, "token should be nil")
81+
assert.Error(t, err, "failed to request token")
8882

89-
token, err = provider.RequestToken()
90-
assert.NotNil(t, token, "token should not be nil")
91-
assert.NoError(t, err, "failed to request token")
92-
assert.Equal(t, ResponseTypeAccessToken, token.Type(), "token type should be access token")
93-
assert.Equal(t, mockToken, token.AccessToken(), "access token should be equal to testJWTtoken")
83+
// use mockAzureCredential to simulate the environment
84+
mToken := azcore.AccessToken{
85+
Token: testJWTtoken,
86+
}
87+
mCreds := &mockAzureCredential{}
88+
mCreds.On("GetToken", mock.Anything, policy.TokenRequestOptions{Scopes: scopes}).Return(mToken, nil)
89+
mCredFactory := &mockCredFactory{}
90+
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil)
91+
provider.credFactory = mCredFactory
92+
token, err = provider.RequestToken()
93+
assert.NotNil(t, token, "token should not be nil")
94+
assert.NoError(t, err, "failed to request token")
95+
assert.Equal(t, ResponseTypeAccessToken, token.Type(), "token type should be access token")
96+
assert.Equal(t, mToken, token.AccessToken(), "access token should be equal to testJWTtoken")
97+
})
98+
t.Run("RequestToken with error from credFactory", func(t *testing.T) {
99+
// use mockAzureCredential to simulate the environment
100+
mCredFactory := &mockCredFactory{}
101+
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(nil, assert.AnError)
102+
provider.credFactory = mCredFactory
103+
token, err := provider.RequestToken()
104+
assert.Nil(t, token, "token should be nil")
105+
assert.Error(t, err, "failed to request token")
106+
})
94107
}

entraid_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77

88
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
99
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
10+
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
1011
"github.com/stretchr/testify/mock"
1112
)
1213

@@ -94,5 +95,21 @@ type mockAzureCredential struct {
9495

9596
func (m *mockAzureCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
9697
args := m.Called(ctx, options)
98+
if args.Get(0) == nil {
99+
return azcore.AccessToken{}, args.Error(1)
100+
}
97101
return args.Get(0).(azcore.AccessToken), args.Error(1)
98102
}
103+
104+
type mockCredFactory struct {
105+
// Mock implementation of the credFactory interface
106+
mock.Mock
107+
}
108+
109+
func (m *mockCredFactory) NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (defaultAzureCredential, error) {
110+
args := m.Called(options)
111+
if args.Get(0) == nil {
112+
return nil, args.Error(1)
113+
}
114+
return args.Get(0).(defaultAzureCredential), args.Error(1)
115+
}

token_manager.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,10 @@ var defaultIdentityProviderResponseParser IdentityProviderResponseParserFunc = f
9090
switch response.Type() {
9191
case ResponseTypeAuthResult:
9292
authResult := response.AuthResult()
93-
if authResult.IDToken.RawToken == "" {
94-
return nil, fmt.Errorf("auth result id token is empty")
93+
if authResult.ExpiresOn.Before(time.Now()) {
94+
return nil, fmt.Errorf("auth result expired or invalid")
9595
}
9696
rawToken = authResult.IDToken.RawToken
97-
9897
username = authResult.IDToken.Oid
9998
password = rawToken
10099
expiresOn = authResult.ExpiresOn.UTC()

token_manager_test.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
15+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
1516
"github.com/stretchr/testify/assert"
1617
"github.com/stretchr/testify/mock"
1718
)
@@ -379,27 +380,30 @@ func TestTokenManager_Start(t *testing.T) {
379380

380381
func TestDefaultIdentityProviderResponseParser(t *testing.T) {
381382
t.Parallel()
382-
/*
383-
t.Run("Default IdentityProviderResponseParser with type AuthResult", func(t *testing.T) {
384-
idpResponse, err := NewIDPResponse(ResponseTypeAuthResult,
385-
&public.AuthResult{
386-
ExpiresOn: time.Now().Add(time.Hour),
387-
})
388-
assert.NoError(t, err)
389-
//_, err := defaultIdentityProviderResponseParser(idpResponse)
390-
//assert.NoError(t, err)
391-
//assert.NotNil(t, token)
392-
})
393-
*/
383+
t.Run("Default IdentityProviderResponseParser with type AuthResult", func(t *testing.T) {
384+
authResult := &public.AuthResult{
385+
ExpiresOn: time.Now().Add(time.Hour).UTC(),
386+
}
387+
idpResponse, err := NewIDPResponse(ResponseTypeAuthResult,
388+
authResult)
389+
assert.NoError(t, err)
390+
token, err := defaultIdentityProviderResponseParser(idpResponse)
391+
assert.NoError(t, err)
392+
assert.NotNil(t, token)
393+
assert.Equal(t, authResult.ExpiresOn, token.ExpirationOn())
394+
})
394395
t.Run("Default IdentityProviderResponseParser with type AccessToken", func(t *testing.T) {
395-
idpResponse, err := NewIDPResponse(ResponseTypeAccessToken, &azcore.AccessToken{
396+
accessToken := &azcore.AccessToken{
396397
Token: testJWTtoken,
397-
ExpiresOn: time.Now().Add(time.Hour),
398-
})
398+
ExpiresOn: time.Now().Add(time.Hour).UTC(),
399+
}
400+
idpResponse, err := NewIDPResponse(ResponseTypeAccessToken, accessToken)
399401
assert.NoError(t, err)
400402
token, err := defaultIdentityProviderResponseParser(idpResponse)
401403
assert.NoError(t, err)
402404
assert.NotNil(t, token)
405+
assert.Equal(t, accessToken.ExpiresOn, token.ExpirationOn())
406+
assert.Equal(t, accessToken.Token, token.RawCredentials())
403407
})
404408
t.Run("Default IdentityProviderResponseParser with type RawToken", func(t *testing.T) {
405409
idpResponse, err := NewIDPResponse(ResponseTypeRawToken, testJWTtoken)

0 commit comments

Comments
 (0)