Skip to content

Commit 3b2949f

Browse files
committed
Adding a token getter to get service account tokens
1 parent 872b7f7 commit 3b2949f

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed
+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package authentication
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
authenticationv1 "k8s.io/api/authentication/v1"
9+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
10+
"k8s.io/apimachinery/pkg/types"
11+
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
12+
"k8s.io/utils/ptr"
13+
)
14+
15+
type TokenGetter struct {
16+
client corev1.ServiceAccountsGetter
17+
expirationDuration time.Duration
18+
removeAfterExpiredDuration time.Duration
19+
tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus
20+
mu sync.RWMutex
21+
}
22+
23+
type TokenGetterOption func(*TokenGetter)
24+
25+
const (
26+
RotationThresholdPercentage = 10
27+
DefaultExpirationDuration = 5 * time.Minute
28+
DefaultRemoveAfterExpiredDuration = 90 * time.Minute
29+
)
30+
31+
// Returns a token getter that can fetch tokens given a service account.
32+
// The token getter also caches tokens which helps reduce the number of requests to the API Server.
33+
// In case a cached token is expiring a fresh token is created.
34+
func NewTokenGetter(client corev1.ServiceAccountsGetter, options ...TokenGetterOption) *TokenGetter {
35+
tokenGetter := &TokenGetter{
36+
client: client,
37+
expirationDuration: DefaultExpirationDuration,
38+
removeAfterExpiredDuration: DefaultRemoveAfterExpiredDuration,
39+
tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{},
40+
}
41+
42+
for _, opt := range options {
43+
opt(tokenGetter)
44+
}
45+
46+
return tokenGetter
47+
}
48+
49+
func WithExpirationDuration(expirationDuration time.Duration) TokenGetterOption {
50+
return func(tg *TokenGetter) {
51+
tg.expirationDuration = expirationDuration
52+
}
53+
}
54+
55+
func WithRemoveAfterExpiredDuration(removeAfterExpiredDuration time.Duration) TokenGetterOption {
56+
return func(tg *TokenGetter) {
57+
tg.removeAfterExpiredDuration = removeAfterExpiredDuration
58+
}
59+
}
60+
61+
// Get returns a token from the cache if available and not expiring, otherwise creates a new token
62+
func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) {
63+
t.mu.RLock()
64+
token, ok := t.tokens[key]
65+
t.mu.RUnlock()
66+
67+
expireTime := time.Time{}
68+
if ok {
69+
expireTime = token.ExpirationTimestamp.Time
70+
}
71+
72+
// Create a new token if the cached token expires within DurationPercentage of expirationDuration from now
73+
rotationThresholdAfterNow := metav1.Now().Add(t.expirationDuration * (RotationThresholdPercentage / 100))
74+
if expireTime.Before(rotationThresholdAfterNow) {
75+
var err error
76+
token, err = t.getToken(ctx, key)
77+
if err != nil {
78+
return "", err
79+
}
80+
t.mu.Lock()
81+
t.tokens[key] = token
82+
t.mu.Unlock()
83+
}
84+
85+
// Delete tokens that have been expired for more than ExpiredDuration
86+
t.reapExpiredTokens(t.removeAfterExpiredDuration)
87+
88+
return token.Token, nil
89+
}
90+
91+
func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authenticationv1.TokenRequestStatus, error) {
92+
req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx,
93+
key.Name,
94+
&authenticationv1.TokenRequest{
95+
Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](int64(t.expirationDuration))},
96+
}, metav1.CreateOptions{})
97+
if err != nil {
98+
return nil, err
99+
}
100+
return &req.Status, nil
101+
}
102+
103+
func (t *TokenGetter) reapExpiredTokens(expiredDuration time.Duration) {
104+
t.mu.Lock()
105+
defer t.mu.Unlock()
106+
for key, token := range t.tokens {
107+
if metav1.Now().Sub(token.ExpirationTimestamp.Time) > expiredDuration {
108+
delete(t.tokens, key)
109+
}
110+
}
111+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package authentication
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
authenticationv1 "k8s.io/api/authentication/v1"
11+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
12+
"k8s.io/apimachinery/pkg/runtime"
13+
"k8s.io/apimachinery/pkg/types"
14+
"k8s.io/client-go/kubernetes/fake"
15+
ctest "k8s.io/client-go/testing"
16+
)
17+
18+
func TestTokenGetterGet(t *testing.T) {
19+
fakeClient := fake.NewSimpleClientset()
20+
fakeClient.PrependReactor("create", "serviceaccounts/token",
21+
func(action ctest.Action) (bool, runtime.Object, error) {
22+
act, ok := action.(ctest.CreateActionImpl)
23+
if !ok {
24+
return false, nil, nil
25+
}
26+
tokenRequest := act.GetObject().(*authenticationv1.TokenRequest)
27+
var err error
28+
if act.Name == "test-service-account-1" {
29+
tokenRequest.Status = authenticationv1.TokenRequestStatus{
30+
Token: "test-token-1",
31+
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(DefaultExpirationDuration)),
32+
}
33+
}
34+
if act.Name == "test-service-account-2" {
35+
tokenRequest.Status = authenticationv1.TokenRequestStatus{
36+
Token: "test-token-2",
37+
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)),
38+
}
39+
}
40+
if act.Name == "test-service-account-3" {
41+
tokenRequest.Status = authenticationv1.TokenRequestStatus{
42+
Token: "test-token-3",
43+
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(-DefaultRemoveAfterExpiredDuration)),
44+
}
45+
}
46+
if act.Name == "test-service-account-4" {
47+
tokenRequest = nil
48+
err = fmt.Errorf("error when fetching token")
49+
}
50+
return true, tokenRequest, err
51+
})
52+
53+
tg := NewTokenGetter(fakeClient.CoreV1(),
54+
WithExpirationDuration(DefaultExpirationDuration),
55+
WithRemoveAfterExpiredDuration(DefaultRemoveAfterExpiredDuration))
56+
57+
tests := []struct {
58+
testName string
59+
serviceAccountName string
60+
namespace string
61+
want string
62+
errorMsg string
63+
}{
64+
{"Testing getting token with fake client", "test-service-account-1",
65+
"test-namespace-1", "test-token-1", "failed to get token"},
66+
{"Testing getting token from cache", "test-service-account-1",
67+
"test-namespace-1", "test-token-1", "failed to get token"},
68+
{"Testing getting short lived token from fake client", "test-service-account-2",
69+
"test-namespace-2", "test-token-2", "failed to get token"},
70+
{"Testing getting expired token from cache", "test-service-account-2",
71+
"test-namespace-2", "test-token-2", "failed to refresh token"},
72+
{"Testing token that expired 90 minutes ago", "test-service-account-3",
73+
"test-namespace-3", "test-token-3", "failed to get token"},
74+
{"Testing error when getting token from fake client", "test-service-account-4",
75+
"test-namespace-4", "error when fetching token", "error when fetching token"},
76+
}
77+
78+
for _, tc := range tests {
79+
got, err := tg.Get(context.Background(), types.NamespacedName{Namespace: tc.namespace, Name: tc.serviceAccountName})
80+
if err != nil {
81+
t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, err)
82+
assert.EqualError(t, err, tc.errorMsg)
83+
} else {
84+
t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, got)
85+
assert.Equal(t, tc.want, got, tc.errorMsg)
86+
}
87+
}
88+
}

0 commit comments

Comments
 (0)