Skip to content

Commit b77e9cd

Browse files
committed
Adding a token getter to get service account tokens
1 parent 872b7f7 commit b77e9cd

File tree

2 files changed

+194
-0
lines changed

2 files changed

+194
-0
lines changed
+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package authentication
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
authenticationv1 "k8s.io/api/authentication/v1"
9+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
10+
"k8s.io/apimachinery/pkg/types"
11+
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
12+
"k8s.io/utils/ptr"
13+
)
14+
15+
type TokenGetter struct {
16+
client corev1.ServiceAccountsGetter
17+
expirationSeconds int64
18+
tokens map[types.NamespacedName]*authenticationv1.TokenRequestStatus
19+
tokenLocks keyLock[types.NamespacedName]
20+
mu sync.RWMutex
21+
}
22+
23+
type TokenGetterOption func(*TokenGetter)
24+
25+
// Returns a token getter that can fetch tokens given a service account.
26+
// The token getter also caches tokens which helps reduce the number of requests to the API Server.
27+
// In case a cached token is expiring a fresh token is created.
28+
func NewTokenGetter(client corev1.ServiceAccountsGetter, options ...TokenGetterOption) *TokenGetter {
29+
tokenGetter := &TokenGetter{
30+
client: client,
31+
expirationSeconds: int64(5 * time.Minute), // default token ttl
32+
tokens: map[types.NamespacedName]*authenticationv1.TokenRequestStatus{},
33+
tokenLocks: newKeyLock[types.NamespacedName](),
34+
}
35+
36+
for _, opt := range options {
37+
opt(tokenGetter)
38+
}
39+
40+
return tokenGetter
41+
}
42+
43+
func WithExpirationSeconds(expirationSeconds int64) TokenGetterOption {
44+
return func(tg *TokenGetter) {
45+
tg.expirationSeconds = expirationSeconds
46+
}
47+
}
48+
49+
type keyLock[K comparable] struct {
50+
locks map[K]*sync.Mutex
51+
mu sync.Mutex
52+
}
53+
54+
func newKeyLock[K comparable]() keyLock[K] {
55+
return keyLock[K]{locks: map[K]*sync.Mutex{}}
56+
}
57+
58+
func (k *keyLock[K]) Lock(key K) {
59+
k.getLock(key).Lock()
60+
}
61+
62+
func (k *keyLock[K]) Unlock(key K) {
63+
k.getLock(key).Unlock()
64+
}
65+
66+
func (k *keyLock[K]) getLock(key K) *sync.Mutex {
67+
k.mu.Lock()
68+
defer k.mu.Unlock()
69+
70+
lock, ok := k.locks[key]
71+
if !ok {
72+
lock = &sync.Mutex{}
73+
k.locks[key] = lock
74+
}
75+
return lock
76+
}
77+
78+
// Returns a token from the cache if available and not expiring, otherwise creates a new token and caches it.
79+
func (t *TokenGetter) Get(ctx context.Context, key types.NamespacedName) (string, error) {
80+
t.tokenLocks.Lock(key)
81+
defer t.tokenLocks.Unlock(key)
82+
83+
t.mu.RLock()
84+
token, ok := t.tokens[key]
85+
t.mu.RUnlock()
86+
87+
expireTime := time.Time{}
88+
if ok {
89+
expireTime = token.ExpirationTimestamp.Time
90+
}
91+
92+
expirationSecondsAfterNow := metav1.Now().Add(time.Duration(t.expirationSeconds))
93+
if expireTime.Before(expirationSecondsAfterNow) {
94+
var err error
95+
token, err = t.getToken(ctx, key)
96+
if err != nil {
97+
return "", err
98+
}
99+
t.mu.Lock()
100+
t.tokens[key] = token
101+
t.mu.Unlock()
102+
}
103+
104+
return token.Token, nil
105+
}
106+
107+
func (t *TokenGetter) getToken(ctx context.Context, key types.NamespacedName) (*authenticationv1.TokenRequestStatus, error) {
108+
req, err := t.client.ServiceAccounts(key.Namespace).CreateToken(ctx,
109+
key.Name,
110+
&authenticationv1.TokenRequest{
111+
Spec: authenticationv1.TokenRequestSpec{ExpirationSeconds: ptr.To[int64](t.expirationSeconds)},
112+
}, metav1.CreateOptions{})
113+
if err != nil {
114+
return nil, err
115+
}
116+
return &req.Status, nil
117+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package authentication
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
authenticationv1 "k8s.io/api/authentication/v1"
11+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
12+
"k8s.io/apimachinery/pkg/runtime"
13+
"k8s.io/apimachinery/pkg/types"
14+
"k8s.io/client-go/kubernetes/fake"
15+
ctest "k8s.io/client-go/testing"
16+
)
17+
18+
func TestNewTokenGetter(t *testing.T) {
19+
fakeClient := fake.NewSimpleClientset()
20+
fakeClient.PrependReactor("create", "serviceaccounts/token", func(action ctest.Action) (bool, runtime.Object, error) {
21+
act, ok := action.(ctest.CreateActionImpl)
22+
if !ok {
23+
return false, nil, nil
24+
}
25+
tokenRequest := act.GetObject().(*authenticationv1.TokenRequest)
26+
var err error
27+
if act.Name == "test-service-account-1" {
28+
tokenRequest.Status = authenticationv1.TokenRequestStatus{
29+
Token: "test-token-1",
30+
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(5 * time.Minute)),
31+
}
32+
}
33+
if act.Name == "test-service-account-2" {
34+
tokenRequest.Status = authenticationv1.TokenRequestStatus{
35+
Token: "test-token-2",
36+
ExpirationTimestamp: metav1.NewTime(metav1.Now().Add(1 * time.Second)),
37+
}
38+
}
39+
if act.Name == "test-service-account-3" {
40+
tokenRequest = nil
41+
err = fmt.Errorf("error when fetching token")
42+
}
43+
return true, tokenRequest, err
44+
})
45+
46+
tg := NewTokenGetter(fakeClient.CoreV1(), WithExpirationSeconds(int64(5*time.Minute)))
47+
48+
tests := []struct {
49+
testName string
50+
serviceAccountName string
51+
namespace string
52+
want string
53+
errorMsg string
54+
}{
55+
{"Testing NewTokenGetter with fake client", "test-service-account-1",
56+
"test-namespace-1", "test-token-1", "failed to get token"},
57+
{"Testing getting token from cache", "test-service-account-1",
58+
"test-namespace-1", "test-token-1", "failed to get token from cache"},
59+
{"Testing getting short lived token from fake client", "test-service-account-2",
60+
"test-namespace-2", "test-token-2", "failed to get token"},
61+
{"Testing getting expired token from cache", "test-service-account-2",
62+
"test-namespace-2", "test-token-2", "failed to refresh token"},
63+
{"Testing error when getting token from fake client", "test-service-account-3",
64+
"test-namespace-3", "error when fetching token", "error when fetching token"},
65+
}
66+
67+
for _, tc := range tests {
68+
got, err := tg.Get(context.Background(), types.NamespacedName{Namespace: tc.namespace, Name: tc.serviceAccountName})
69+
if err != nil {
70+
t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, err)
71+
assert.EqualError(t, err, tc.errorMsg)
72+
} else {
73+
t.Logf("%s: expected: %v, got: %v", tc.testName, tc.want, got)
74+
assert.Equal(t, tc.want, got, tc.errorMsg)
75+
}
76+
}
77+
}

0 commit comments

Comments
 (0)