Skip to content

Commit 9812856

Browse files
author
Kubernetes Submit Queue
authored
Merge pull request kubernetes#45317 from ericchiang/oidc-client-update
Automatic merge from submit-queue oidc client plugin: reduce round trips and fix scopes requested This PR attempts to simplify the OpenID Connect client plugin to reduce round trips. The steps taken by the client are now: * If ID Token isn't expired: * Do nothing. * If ID Token is expired: * Query /.well-known discovery URL to find token_endpoint. * Use an OAuth2 client and refresh token to request new ID token. This avoids the previous pattern of always initializing a client, which would hit the /.well-known endpoint several times. The client no longer does token validation since the server already does this. As a result, this code no longer imports github.com/coreos/go-oidc, instead just using golang.org/x/oauth2 for refreshing. Overall reduction in tests because we're not verify as many things on the client side. For example, we're no longer validating the id_token signature (again, because it's being done on the server side). This has been manually tested against dex, and I hope to continue to test this over the 1.7 release cycle. cc @mlbiam @frodenas @curtisallen @jsloyer @rithujohn191 @philips @kubernetes/sig-auth-pr-reviews ```release-note NONE ``` Updates kubernetes#42654 Closes kubernetes#37875 Closes kubernetes#37874
2 parents ee0de5f + 6915f85 commit 9812856

File tree

3 files changed

+192
-412
lines changed

3 files changed

+192
-412
lines changed

staging/src/k8s.io/client-go/plugin/pkg/client/auth/oidc/BUILD

+1-9
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,15 @@ go_test(
1313
srcs = ["oidc_test.go"],
1414
library = ":go_default_library",
1515
tags = ["automanaged"],
16-
deps = [
17-
"//vendor/github.com/coreos/go-oidc/jose:go_default_library",
18-
"//vendor/github.com/coreos/go-oidc/key:go_default_library",
19-
"//vendor/github.com/coreos/go-oidc/oauth2:go_default_library",
20-
"//vendor/k8s.io/client-go/plugin/pkg/auth/authenticator/token/oidc/testing:go_default_library",
21-
],
2216
)
2317

2418
go_library(
2519
name = "go_default_library",
2620
srcs = ["oidc.go"],
2721
tags = ["automanaged"],
2822
deps = [
29-
"//vendor/github.com/coreos/go-oidc/jose:go_default_library",
30-
"//vendor/github.com/coreos/go-oidc/oauth2:go_default_library",
31-
"//vendor/github.com/coreos/go-oidc/oidc:go_default_library",
3223
"//vendor/github.com/golang/glog:go_default_library",
24+
"//vendor/golang.org/x/oauth2:go_default_library",
3325
"//vendor/k8s.io/client-go/rest:go_default_library",
3426
],
3527
)

staging/src/k8s.io/client-go/plugin/pkg/client/auth/oidc/oidc.go

+128-94
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@ limitations under the License.
1717
package oidc
1818

1919
import (
20+
"context"
2021
"encoding/base64"
22+
"encoding/json"
2123
"errors"
2224
"fmt"
25+
"io/ioutil"
2326
"net/http"
2427
"strings"
2528
"sync"
2629
"time"
2730

28-
"github.com/coreos/go-oidc/jose"
29-
"github.com/coreos/go-oidc/oauth2"
30-
"github.com/coreos/go-oidc/oidc"
3131
"github.com/golang/glog"
32-
32+
"golang.org/x/oauth2"
3333
restclient "k8s.io/client-go/rest"
3434
)
3535

@@ -39,9 +39,11 @@ const (
3939
cfgClientSecret = "client-secret"
4040
cfgCertificateAuthority = "idp-certificate-authority"
4141
cfgCertificateAuthorityData = "idp-certificate-authority-data"
42-
cfgExtraScopes = "extra-scopes"
4342
cfgIDToken = "id-token"
4443
cfgRefreshToken = "refresh-token"
44+
45+
// Unused. Scopes aren't sent during refreshing.
46+
cfgExtraScopes = "extra-scopes"
4547
)
4648

4749
func init() {
@@ -59,9 +61,12 @@ const expiryDelta = 10 * time.Second
5961

6062
var cache = newClientCache()
6163

62-
// Like TLS transports, keep a cache of OIDC clients indexed by issuer URL.
64+
// Like TLS transports, keep a cache of OIDC clients indexed by issuer URL. This ensures
65+
// current requests from different clients don't concurrently attempt to refresh the same
66+
// set of credentials.
6367
type clientCache struct {
64-
mu sync.RWMutex
68+
mu sync.RWMutex
69+
6570
cache map[cacheKey]*oidcAuthProvider
6671
}
6772

@@ -72,27 +77,22 @@ func newClientCache() *clientCache {
7277
type cacheKey struct {
7378
// Canonical issuer URL string of the provider.
7479
issuerURL string
75-
76-
clientID string
77-
clientSecret string
78-
79-
// Don't use CA as cache key because we only add a cache entry if we can connect
80-
// to the issuer in the first place. A valid CA is a prerequisite.
80+
clientID string
8181
}
8282

83-
func (c *clientCache) getClient(issuer, clientID, clientSecret string) (*oidcAuthProvider, bool) {
83+
func (c *clientCache) getClient(issuer, clientID string) (*oidcAuthProvider, bool) {
8484
c.mu.RLock()
8585
defer c.mu.RUnlock()
86-
client, ok := c.cache[cacheKey{issuer, clientID, clientSecret}]
86+
client, ok := c.cache[cacheKey{issuer, clientID}]
8787
return client, ok
8888
}
8989

9090
// setClient attempts to put the client in the cache but may return any clients
9191
// with the same keys set before. This is so there's only ever one client for a provider.
92-
func (c *clientCache) setClient(issuer, clientID, clientSecret string, client *oidcAuthProvider) *oidcAuthProvider {
92+
func (c *clientCache) setClient(issuer, clientID string, client *oidcAuthProvider) *oidcAuthProvider {
9393
c.mu.Lock()
9494
defer c.mu.Unlock()
95-
key := cacheKey{issuer, clientID, clientSecret}
95+
key := cacheKey{issuer, clientID}
9696

9797
// If another client has already initialized a client for the given provider we want
9898
// to use that client instead of the one we're trying to set. This is so all transports
@@ -117,16 +117,16 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
117117
return nil, fmt.Errorf("Must provide %s", cfgClientID)
118118
}
119119

120-
clientSecret := cfg[cfgClientSecret]
121-
if clientSecret == "" {
122-
return nil, fmt.Errorf("Must provide %s", cfgClientSecret)
123-
}
124-
125120
// Check cache for existing provider.
126-
if provider, ok := cache.getClient(issuer, clientID, clientSecret); ok {
121+
if provider, ok := cache.getClient(issuer, clientID); ok {
127122
return provider, nil
128123
}
129124

125+
if len(cfg[cfgExtraScopes]) > 0 {
126+
glog.V(2).Infof("%s auth provider field depricated, refresh request don't send scopes",
127+
cfgExtraScopes)
128+
}
129+
130130
var certAuthData []byte
131131
var err error
132132
if cfg[cfgCertificateAuthorityData] != "" {
@@ -149,41 +149,20 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
149149
}
150150
hc := &http.Client{Transport: trans}
151151

152-
providerCfg, err := oidc.FetchProviderConfig(hc, issuer)
153-
if err != nil {
154-
return nil, fmt.Errorf("error fetching provider config: %v", err)
155-
}
156-
157-
scopes := strings.Split(cfg[cfgExtraScopes], ",")
158-
oidcCfg := oidc.ClientConfig{
159-
HTTPClient: hc,
160-
Credentials: oidc.ClientCredentials{
161-
ID: clientID,
162-
Secret: clientSecret,
163-
},
164-
ProviderConfig: providerCfg,
165-
Scope: append(scopes, oidc.DefaultScope...),
166-
}
167-
client, err := oidc.NewClient(oidcCfg)
168-
if err != nil {
169-
return nil, fmt.Errorf("error creating OIDC Client: %v", err)
170-
}
171-
172152
provider := &oidcAuthProvider{
173-
client: &oidcClient{client},
153+
client: hc,
154+
now: time.Now,
174155
cfg: cfg,
175156
persister: persister,
176-
now: time.Now,
177157
}
178158

179-
return cache.setClient(issuer, clientID, clientSecret, provider), nil
159+
return cache.setClient(issuer, clientID, provider), nil
180160
}
181161

182162
type oidcAuthProvider struct {
183-
// Interface rather than a raw *oidc.Client for testing.
184-
client OIDCClient
163+
client *http.Client
185164

186-
// Stubbed out for testing.
165+
// Method for determining the current time.
187166
now func() time.Time
188167

189168
// Mutex guards persisting to the kubeconfig file and allows synchronized
@@ -205,11 +184,6 @@ func (p *oidcAuthProvider) Login() error {
205184
return errors.New("not yet implemented")
206185
}
207186

208-
type OIDCClient interface {
209-
refreshToken(rt string) (oauth2.TokenResponse, error)
210-
verifyJWT(jwt *jose.JWT) error
211-
}
212-
213187
type roundTripper struct {
214188
provider *oidcAuthProvider
215189
wrapped http.RoundTripper
@@ -243,7 +217,7 @@ func (p *oidcAuthProvider) idToken() (string, error) {
243217
defer p.mu.Unlock()
244218

245219
if idToken, ok := p.cfg[cfgIDToken]; ok && len(idToken) > 0 {
246-
valid, err := verifyJWTExpiry(p.now(), idToken)
220+
valid, err := idTokenExpired(p.now, idToken)
247221
if err != nil {
248222
return "", err
249223
}
@@ -259,17 +233,27 @@ func (p *oidcAuthProvider) idToken() (string, error) {
259233
return "", errors.New("No valid id-token, and cannot refresh without refresh-token")
260234
}
261235

262-
tokens, err := p.client.refreshToken(rt)
236+
// Determine provider's OAuth2 token endpoint.
237+
tokenURL, err := tokenEndpoint(p.client, p.cfg[cfgIssuerUrl])
263238
if err != nil {
264-
return "", fmt.Errorf("could not refresh token: %v", err)
239+
return "", err
240+
}
241+
242+
config := oauth2.Config{
243+
ClientID: p.cfg[cfgClientID],
244+
ClientSecret: p.cfg[cfgClientSecret],
245+
Endpoint: oauth2.Endpoint{TokenURL: tokenURL},
265246
}
266-
jwt, err := jose.ParseJWT(tokens.IDToken)
247+
248+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, p.client)
249+
token, err := config.TokenSource(ctx, &oauth2.Token{RefreshToken: rt}).Token()
267250
if err != nil {
268-
return "", err
251+
return "", fmt.Errorf("failed to refresh token: %v", err)
269252
}
270253

271-
if err := p.client.verifyJWT(&jwt); err != nil {
272-
return "", err
254+
idToken, ok := token.Extra("id_token").(string)
255+
if !ok {
256+
return "", fmt.Errorf("token response did not contain an id_token")
273257
}
274258

275259
// Create a new config to persist.
@@ -278,59 +262,109 @@ func (p *oidcAuthProvider) idToken() (string, error) {
278262
newCfg[key] = val
279263
}
280264

281-
if tokens.RefreshToken != "" && tokens.RefreshToken != rt {
282-
newCfg[cfgRefreshToken] = tokens.RefreshToken
265+
// Update the refresh token if the server returned another one.
266+
if token.RefreshToken != "" && token.RefreshToken != rt {
267+
newCfg[cfgRefreshToken] = token.RefreshToken
283268
}
269+
newCfg[cfgIDToken] = idToken
284270

285-
newCfg[cfgIDToken] = tokens.IDToken
271+
// Persist new config and if successful, update the in memory config.
286272
if err = p.persister.Persist(newCfg); err != nil {
287273
return "", fmt.Errorf("could not perist new tokens: %v", err)
288274
}
289-
290-
// Update the in memory config to reflect the on disk one.
291275
p.cfg = newCfg
292276

293-
return tokens.IDToken, nil
294-
}
295-
296-
// oidcClient is the real implementation of the OIDCClient interface, which is
297-
// used for testing.
298-
type oidcClient struct {
299-
client *oidc.Client
277+
return idToken, nil
300278
}
301279

302-
func (o *oidcClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
303-
oac, err := o.client.OAuthClient()
280+
// tokenEndpoint uses OpenID Connect discovery to determine the OAuth2 token
281+
// endpoint for the provider, the endpoint the client will use the refresh
282+
// token against.
283+
func tokenEndpoint(client *http.Client, issuer string) (string, error) {
284+
// Well known URL for getting OpenID Connect metadata.
285+
//
286+
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
287+
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
288+
resp, err := client.Get(wellKnown)
304289
if err != nil {
305-
return oauth2.TokenResponse{}, err
290+
return "", err
306291
}
292+
defer resp.Body.Close()
307293

308-
return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt)
309-
}
294+
body, err := ioutil.ReadAll(resp.Body)
295+
if err != nil {
296+
return "", err
297+
}
298+
if resp.StatusCode != http.StatusOK {
299+
// Don't produce an error that's too huge (e.g. if we get HTML back for some reason).
300+
const n = 80
301+
if len(body) > n {
302+
body = append(body[:n], []byte("...")...)
303+
}
304+
return "", fmt.Errorf("oidc: failed to query metadata endpoint %s: %q", resp.Status, body)
305+
}
310306

311-
func (o *oidcClient) verifyJWT(jwt *jose.JWT) error {
312-
return o.client.VerifyJWT(*jwt)
307+
// Metadata object. We only care about the token_endpoint, the thing endpoint
308+
// we'll be refreshing against.
309+
//
310+
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
311+
var metadata struct {
312+
TokenURL string `json:"token_endpoint"`
313+
}
314+
if err := json.Unmarshal(body, &metadata); err != nil {
315+
return "", fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
316+
}
317+
if metadata.TokenURL == "" {
318+
return "", fmt.Errorf("oidc: discovery object doesn't contain a token_endpoint")
319+
}
320+
return metadata.TokenURL, nil
313321
}
314322

315-
func verifyJWTExpiry(now time.Time, s string) (valid bool, err error) {
316-
jwt, err := jose.ParseJWT(s)
317-
if err != nil {
318-
return false, fmt.Errorf("invalid %q", cfgIDToken)
323+
func idTokenExpired(now func() time.Time, idToken string) (bool, error) {
324+
parts := strings.Split(idToken, ".")
325+
if len(parts) != 3 {
326+
return false, fmt.Errorf("ID Token is not a valid JWT")
319327
}
320-
claims, err := jwt.Claims()
328+
329+
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
321330
if err != nil {
322331
return false, err
323332
}
333+
var claims struct {
334+
Expiry jsonTime `json:"exp"`
335+
}
336+
if err := json.Unmarshal(payload, &claims); err != nil {
337+
return false, fmt.Errorf("parsing claims: %v", err)
338+
}
339+
340+
return now().Add(expiryDelta).Before(time.Time(claims.Expiry)), nil
341+
}
324342

325-
exp, ok, err := claims.TimeClaim("exp")
326-
switch {
327-
case err != nil:
328-
return false, fmt.Errorf("failed to parse 'exp' claim: %v", err)
329-
case !ok:
330-
return false, errors.New("missing required 'exp' claim")
331-
case exp.After(now.Add(expiryDelta)):
332-
return true, nil
343+
// jsonTime is a json.Unmarshaler that parses a unix timestamp.
344+
// Because JSON numbers don't differentiate between ints and floats,
345+
// we want to ensure we can parse either.
346+
type jsonTime time.Time
347+
348+
func (j *jsonTime) UnmarshalJSON(b []byte) error {
349+
var n json.Number
350+
if err := json.Unmarshal(b, &n); err != nil {
351+
return err
352+
}
353+
var unix int64
354+
355+
if t, err := n.Int64(); err == nil {
356+
unix = t
357+
} else {
358+
f, err := n.Float64()
359+
if err != nil {
360+
return err
361+
}
362+
unix = int64(f)
333363
}
364+
*j = jsonTime(time.Unix(unix, 0))
365+
return nil
366+
}
334367

335-
return false, nil
368+
func (j jsonTime) MarshalJSON() ([]byte, error) {
369+
return json.Marshal(time.Time(j).Unix())
336370
}

0 commit comments

Comments
 (0)