@@ -17,19 +17,19 @@ limitations under the License.
17
17
package oidc
18
18
19
19
import (
20
+ "context"
20
21
"encoding/base64"
22
+ "encoding/json"
21
23
"errors"
22
24
"fmt"
25
+ "io/ioutil"
23
26
"net/http"
24
27
"strings"
25
28
"sync"
26
29
"time"
27
30
28
- "github.com/coreos/go-oidc/jose"
29
- "github.com/coreos/go-oidc/oauth2"
30
- "github.com/coreos/go-oidc/oidc"
31
31
"github.com/golang/glog"
32
-
32
+ "golang.org/x/oauth2"
33
33
restclient "k8s.io/client-go/rest"
34
34
)
35
35
@@ -39,9 +39,11 @@ const (
39
39
cfgClientSecret = "client-secret"
40
40
cfgCertificateAuthority = "idp-certificate-authority"
41
41
cfgCertificateAuthorityData = "idp-certificate-authority-data"
42
- cfgExtraScopes = "extra-scopes"
43
42
cfgIDToken = "id-token"
44
43
cfgRefreshToken = "refresh-token"
44
+
45
+ // Unused. Scopes aren't sent during refreshing.
46
+ cfgExtraScopes = "extra-scopes"
45
47
)
46
48
47
49
func init () {
@@ -59,9 +61,12 @@ const expiryDelta = 10 * time.Second
59
61
60
62
var cache = newClientCache ()
61
63
62
- // Like TLS transports, keep a cache of OIDC clients indexed by issuer URL.
64
+ // Like TLS transports, keep a cache of OIDC clients indexed by issuer URL. This ensures
65
+ // current requests from different clients don't concurrently attempt to refresh the same
66
+ // set of credentials.
63
67
type clientCache struct {
64
- mu sync.RWMutex
68
+ mu sync.RWMutex
69
+
65
70
cache map [cacheKey ]* oidcAuthProvider
66
71
}
67
72
@@ -72,27 +77,22 @@ func newClientCache() *clientCache {
72
77
type cacheKey struct {
73
78
// Canonical issuer URL string of the provider.
74
79
issuerURL string
75
-
76
- clientID string
77
- clientSecret string
78
-
79
- // Don't use CA as cache key because we only add a cache entry if we can connect
80
- // to the issuer in the first place. A valid CA is a prerequisite.
80
+ clientID string
81
81
}
82
82
83
- func (c * clientCache ) getClient (issuer , clientID , clientSecret string ) (* oidcAuthProvider , bool ) {
83
+ func (c * clientCache ) getClient (issuer , clientID string ) (* oidcAuthProvider , bool ) {
84
84
c .mu .RLock ()
85
85
defer c .mu .RUnlock ()
86
- client , ok := c .cache [cacheKey {issuer , clientID , clientSecret }]
86
+ client , ok := c .cache [cacheKey {issuer , clientID }]
87
87
return client , ok
88
88
}
89
89
90
90
// setClient attempts to put the client in the cache but may return any clients
91
91
// with the same keys set before. This is so there's only ever one client for a provider.
92
- func (c * clientCache ) setClient (issuer , clientID , clientSecret string , client * oidcAuthProvider ) * oidcAuthProvider {
92
+ func (c * clientCache ) setClient (issuer , clientID string , client * oidcAuthProvider ) * oidcAuthProvider {
93
93
c .mu .Lock ()
94
94
defer c .mu .Unlock ()
95
- key := cacheKey {issuer , clientID , clientSecret }
95
+ key := cacheKey {issuer , clientID }
96
96
97
97
// If another client has already initialized a client for the given provider we want
98
98
// to use that client instead of the one we're trying to set. This is so all transports
@@ -117,16 +117,16 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
117
117
return nil , fmt .Errorf ("Must provide %s" , cfgClientID )
118
118
}
119
119
120
- clientSecret := cfg [cfgClientSecret ]
121
- if clientSecret == "" {
122
- return nil , fmt .Errorf ("Must provide %s" , cfgClientSecret )
123
- }
124
-
125
120
// Check cache for existing provider.
126
- if provider , ok := cache .getClient (issuer , clientID , clientSecret ); ok {
121
+ if provider , ok := cache .getClient (issuer , clientID ); ok {
127
122
return provider , nil
128
123
}
129
124
125
+ if len (cfg [cfgExtraScopes ]) > 0 {
126
+ glog .V (2 ).Infof ("%s auth provider field depricated, refresh request don't send scopes" ,
127
+ cfgExtraScopes )
128
+ }
129
+
130
130
var certAuthData []byte
131
131
var err error
132
132
if cfg [cfgCertificateAuthorityData ] != "" {
@@ -149,41 +149,20 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
149
149
}
150
150
hc := & http.Client {Transport : trans }
151
151
152
- providerCfg , err := oidc .FetchProviderConfig (hc , issuer )
153
- if err != nil {
154
- return nil , fmt .Errorf ("error fetching provider config: %v" , err )
155
- }
156
-
157
- scopes := strings .Split (cfg [cfgExtraScopes ], "," )
158
- oidcCfg := oidc.ClientConfig {
159
- HTTPClient : hc ,
160
- Credentials : oidc.ClientCredentials {
161
- ID : clientID ,
162
- Secret : clientSecret ,
163
- },
164
- ProviderConfig : providerCfg ,
165
- Scope : append (scopes , oidc .DefaultScope ... ),
166
- }
167
- client , err := oidc .NewClient (oidcCfg )
168
- if err != nil {
169
- return nil , fmt .Errorf ("error creating OIDC Client: %v" , err )
170
- }
171
-
172
152
provider := & oidcAuthProvider {
173
- client : & oidcClient {client },
153
+ client : hc ,
154
+ now : time .Now ,
174
155
cfg : cfg ,
175
156
persister : persister ,
176
- now : time .Now ,
177
157
}
178
158
179
- return cache .setClient (issuer , clientID , clientSecret , provider ), nil
159
+ return cache .setClient (issuer , clientID , provider ), nil
180
160
}
181
161
182
162
type oidcAuthProvider struct {
183
- // Interface rather than a raw *oidc.Client for testing.
184
- client OIDCClient
163
+ client * http.Client
185
164
186
- // Stubbed out for testing .
165
+ // Method for determining the current time .
187
166
now func () time.Time
188
167
189
168
// Mutex guards persisting to the kubeconfig file and allows synchronized
@@ -205,11 +184,6 @@ func (p *oidcAuthProvider) Login() error {
205
184
return errors .New ("not yet implemented" )
206
185
}
207
186
208
- type OIDCClient interface {
209
- refreshToken (rt string ) (oauth2.TokenResponse , error )
210
- verifyJWT (jwt * jose.JWT ) error
211
- }
212
-
213
187
type roundTripper struct {
214
188
provider * oidcAuthProvider
215
189
wrapped http.RoundTripper
@@ -243,7 +217,7 @@ func (p *oidcAuthProvider) idToken() (string, error) {
243
217
defer p .mu .Unlock ()
244
218
245
219
if idToken , ok := p .cfg [cfgIDToken ]; ok && len (idToken ) > 0 {
246
- valid , err := verifyJWTExpiry (p .now () , idToken )
220
+ valid , err := idTokenExpired (p .now , idToken )
247
221
if err != nil {
248
222
return "" , err
249
223
}
@@ -259,17 +233,27 @@ func (p *oidcAuthProvider) idToken() (string, error) {
259
233
return "" , errors .New ("No valid id-token, and cannot refresh without refresh-token" )
260
234
}
261
235
262
- tokens , err := p .client .refreshToken (rt )
236
+ // Determine provider's OAuth2 token endpoint.
237
+ tokenURL , err := tokenEndpoint (p .client , p .cfg [cfgIssuerUrl ])
263
238
if err != nil {
264
- return "" , fmt .Errorf ("could not refresh token: %v" , err )
239
+ return "" , err
240
+ }
241
+
242
+ config := oauth2.Config {
243
+ ClientID : p .cfg [cfgClientID ],
244
+ ClientSecret : p .cfg [cfgClientSecret ],
245
+ Endpoint : oauth2.Endpoint {TokenURL : tokenURL },
265
246
}
266
- jwt , err := jose .ParseJWT (tokens .IDToken )
247
+
248
+ ctx := context .WithValue (context .Background (), oauth2 .HTTPClient , p .client )
249
+ token , err := config .TokenSource (ctx , & oauth2.Token {RefreshToken : rt }).Token ()
267
250
if err != nil {
268
- return "" , err
251
+ return "" , fmt . Errorf ( "failed to refresh token: %v" , err )
269
252
}
270
253
271
- if err := p .client .verifyJWT (& jwt ); err != nil {
272
- return "" , err
254
+ idToken , ok := token .Extra ("id_token" ).(string )
255
+ if ! ok {
256
+ return "" , fmt .Errorf ("token response did not contain an id_token" )
273
257
}
274
258
275
259
// Create a new config to persist.
@@ -278,59 +262,109 @@ func (p *oidcAuthProvider) idToken() (string, error) {
278
262
newCfg [key ] = val
279
263
}
280
264
281
- if tokens .RefreshToken != "" && tokens .RefreshToken != rt {
282
- newCfg [cfgRefreshToken ] = tokens .RefreshToken
265
+ // Update the refresh token if the server returned another one.
266
+ if token .RefreshToken != "" && token .RefreshToken != rt {
267
+ newCfg [cfgRefreshToken ] = token .RefreshToken
283
268
}
269
+ newCfg [cfgIDToken ] = idToken
284
270
285
- newCfg [ cfgIDToken ] = tokens . IDToken
271
+ // Persist new config and if successful, update the in memory config.
286
272
if err = p .persister .Persist (newCfg ); err != nil {
287
273
return "" , fmt .Errorf ("could not perist new tokens: %v" , err )
288
274
}
289
-
290
- // Update the in memory config to reflect the on disk one.
291
275
p .cfg = newCfg
292
276
293
- return tokens .IDToken , nil
294
- }
295
-
296
- // oidcClient is the real implementation of the OIDCClient interface, which is
297
- // used for testing.
298
- type oidcClient struct {
299
- client * oidc.Client
277
+ return idToken , nil
300
278
}
301
279
302
- func (o * oidcClient ) refreshToken (rt string ) (oauth2.TokenResponse , error ) {
303
- oac , err := o .client .OAuthClient ()
280
+ // tokenEndpoint uses OpenID Connect discovery to determine the OAuth2 token
281
+ // endpoint for the provider, the endpoint the client will use the refresh
282
+ // token against.
283
+ func tokenEndpoint (client * http.Client , issuer string ) (string , error ) {
284
+ // Well known URL for getting OpenID Connect metadata.
285
+ //
286
+ // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
287
+ wellKnown := strings .TrimSuffix (issuer , "/" ) + "/.well-known/openid-configuration"
288
+ resp , err := client .Get (wellKnown )
304
289
if err != nil {
305
- return oauth2. TokenResponse {} , err
290
+ return "" , err
306
291
}
292
+ defer resp .Body .Close ()
307
293
308
- return oac .RequestToken (oauth2 .GrantTypeRefreshToken , rt )
309
- }
294
+ body , err := ioutil .ReadAll (resp .Body )
295
+ if err != nil {
296
+ return "" , err
297
+ }
298
+ if resp .StatusCode != http .StatusOK {
299
+ // Don't produce an error that's too huge (e.g. if we get HTML back for some reason).
300
+ const n = 80
301
+ if len (body ) > n {
302
+ body = append (body [:n ], []byte ("..." )... )
303
+ }
304
+ return "" , fmt .Errorf ("oidc: failed to query metadata endpoint %s: %q" , resp .Status , body )
305
+ }
310
306
311
- func (o * oidcClient ) verifyJWT (jwt * jose.JWT ) error {
312
- return o .client .VerifyJWT (* jwt )
307
+ // Metadata object. We only care about the token_endpoint, the thing endpoint
308
+ // we'll be refreshing against.
309
+ //
310
+ // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
311
+ var metadata struct {
312
+ TokenURL string `json:"token_endpoint"`
313
+ }
314
+ if err := json .Unmarshal (body , & metadata ); err != nil {
315
+ return "" , fmt .Errorf ("oidc: failed to decode provider discovery object: %v" , err )
316
+ }
317
+ if metadata .TokenURL == "" {
318
+ return "" , fmt .Errorf ("oidc: discovery object doesn't contain a token_endpoint" )
319
+ }
320
+ return metadata .TokenURL , nil
313
321
}
314
322
315
- func verifyJWTExpiry (now time.Time , s string ) (valid bool , err error ) {
316
- jwt , err := jose . ParseJWT ( s )
317
- if err != nil {
318
- return false , fmt .Errorf ("invalid %q" , cfgIDToken )
323
+ func idTokenExpired (now func () time.Time , idToken string ) (bool , error ) {
324
+ parts := strings . Split ( idToken , "." )
325
+ if len ( parts ) != 3 {
326
+ return false , fmt .Errorf ("ID Token is not a valid JWT" )
319
327
}
320
- claims , err := jwt .Claims ()
328
+
329
+ payload , err := base64 .RawURLEncoding .DecodeString (parts [1 ])
321
330
if err != nil {
322
331
return false , err
323
332
}
333
+ var claims struct {
334
+ Expiry jsonTime `json:"exp"`
335
+ }
336
+ if err := json .Unmarshal (payload , & claims ); err != nil {
337
+ return false , fmt .Errorf ("parsing claims: %v" , err )
338
+ }
339
+
340
+ return now ().Add (expiryDelta ).Before (time .Time (claims .Expiry )), nil
341
+ }
324
342
325
- exp , ok , err := claims .TimeClaim ("exp" )
326
- switch {
327
- case err != nil :
328
- return false , fmt .Errorf ("failed to parse 'exp' claim: %v" , err )
329
- case ! ok :
330
- return false , errors .New ("missing required 'exp' claim" )
331
- case exp .After (now .Add (expiryDelta )):
332
- return true , nil
343
+ // jsonTime is a json.Unmarshaler that parses a unix timestamp.
344
+ // Because JSON numbers don't differentiate between ints and floats,
345
+ // we want to ensure we can parse either.
346
+ type jsonTime time.Time
347
+
348
+ func (j * jsonTime ) UnmarshalJSON (b []byte ) error {
349
+ var n json.Number
350
+ if err := json .Unmarshal (b , & n ); err != nil {
351
+ return err
352
+ }
353
+ var unix int64
354
+
355
+ if t , err := n .Int64 (); err == nil {
356
+ unix = t
357
+ } else {
358
+ f , err := n .Float64 ()
359
+ if err != nil {
360
+ return err
361
+ }
362
+ unix = int64 (f )
333
363
}
364
+ * j = jsonTime (time .Unix (unix , 0 ))
365
+ return nil
366
+ }
334
367
335
- return false , nil
368
+ func (j jsonTime ) MarshalJSON () ([]byte , error ) {
369
+ return json .Marshal (time .Time (j ).Unix ())
336
370
}
0 commit comments