Skip to content

Commit 2270f5d

Browse files
author
Julien Pivotto
authored
Read oauth2 secret from file (#293)
* Read oauth2 secret from file Signed-off-by: Julien Pivotto <[email protected]>
1 parent 10f0b67 commit 2270f5d

File tree

3 files changed

+206
-20
lines changed

3 files changed

+206
-20
lines changed

config/http_config.go

Lines changed: 86 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,20 @@ func (u URL) MarshalYAML() (interface{}, error) {
112112

113113
// OAuth2 is the oauth2 client configuration.
114114
type OAuth2 struct {
115-
ClientID string `yaml:"client_id"`
116-
ClientSecret Secret `yaml:"client_secret"`
117-
Scopes []string `yaml:"scopes,omitempty"`
118-
TokenURL string `yaml:"token_url"`
119-
EndpointParams map[string]string `yaml:"endpoint_params,omitempty"`
115+
ClientID string `yaml:"client_id"`
116+
ClientSecret Secret `yaml:"client_secret"`
117+
ClientSecretFile string `yaml:"client_secret_file"`
118+
Scopes []string `yaml:"scopes,omitempty"`
119+
TokenURL string `yaml:"token_url"`
120+
EndpointParams map[string]string `yaml:"endpoint_params,omitempty"`
121+
}
122+
123+
// SetDirectory joins any relative file paths with dir.
124+
func (a *OAuth2) SetDirectory(dir string) {
125+
if a == nil {
126+
return
127+
}
128+
a.ClientSecretFile = JoinDir(dir, a.ClientSecretFile)
120129
}
121130

122131
// HTTPClientConfig configures an HTTP client.
@@ -151,6 +160,7 @@ func (c *HTTPClientConfig) SetDirectory(dir string) {
151160
c.TLSConfig.SetDirectory(dir)
152161
c.BasicAuth.SetDirectory(dir)
153162
c.Authorization.SetDirectory(dir)
163+
c.OAuth2.SetDirectory(dir)
154164
c.BearerTokenFile = JoinDir(dir, c.BearerTokenFile)
155165
}
156166

@@ -196,8 +206,13 @@ func (c *HTTPClientConfig) Validate() error {
196206
c.BearerTokenFile = ""
197207
}
198208
}
199-
if c.BasicAuth != nil && c.OAuth2 != nil {
200-
return fmt.Errorf("at most one of basic_auth, oauth2 & authorization must be configured")
209+
if c.OAuth2 != nil {
210+
if c.BasicAuth != nil {
211+
return fmt.Errorf("at most one of basic_auth, oauth2 & authorization must be configured")
212+
}
213+
if len(c.OAuth2.ClientSecret) > 0 && len(c.OAuth2.ClientSecretFile) > 0 {
214+
return fmt.Errorf("at most one of oauth2 client_secret & client_secret_file must be configured")
215+
}
201216
}
202217
return nil
203218
}
@@ -347,7 +362,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
347362
}
348363

349364
if cfg.OAuth2 != nil {
350-
rt = cfg.OAuth2.NewOAuth2RoundTripper(context.Background(), rt)
365+
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt)
351366
}
352367
// Return a new configured RoundTripper.
353368
return rt, nil
@@ -462,20 +477,72 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() {
462477
}
463478
}
464479

465-
func (c *OAuth2) NewOAuth2RoundTripper(ctx context.Context, next http.RoundTripper) http.RoundTripper {
466-
config := &clientcredentials.Config{
467-
ClientID: c.ClientID,
468-
ClientSecret: string(c.ClientSecret),
469-
Scopes: c.Scopes,
470-
TokenURL: c.TokenURL,
471-
EndpointParams: mapToValues(c.EndpointParams),
480+
type oauth2RoundTripper struct {
481+
config *OAuth2
482+
rt http.RoundTripper
483+
next http.RoundTripper
484+
secret string
485+
mtx sync.RWMutex
486+
}
487+
488+
func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper {
489+
return &oauth2RoundTripper{
490+
config: config,
491+
next: next,
472492
}
493+
}
473494

474-
tokenSource := config.TokenSource(ctx)
495+
func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
496+
var (
497+
secret string
498+
changed bool
499+
)
475500

476-
return &oauth2.Transport{
477-
Base: next,
478-
Source: tokenSource,
501+
if rt.config.ClientSecretFile != "" {
502+
data, err := ioutil.ReadFile(rt.config.ClientSecretFile)
503+
if err != nil {
504+
return nil, fmt.Errorf("unable to read oauth2 client secret file %s: %s", rt.config.ClientSecretFile, err)
505+
}
506+
secret = strings.TrimSpace(string(data))
507+
rt.mtx.RLock()
508+
changed = secret != rt.secret
509+
rt.mtx.RUnlock()
510+
}
511+
512+
if changed || rt.rt == nil {
513+
if rt.config.ClientSecret != "" {
514+
secret = string(rt.config.ClientSecret)
515+
}
516+
517+
config := &clientcredentials.Config{
518+
ClientID: rt.config.ClientID,
519+
ClientSecret: secret,
520+
Scopes: rt.config.Scopes,
521+
TokenURL: rt.config.TokenURL,
522+
EndpointParams: mapToValues(rt.config.EndpointParams),
523+
}
524+
525+
tokenSource := config.TokenSource(context.Background())
526+
527+
rt.mtx.Lock()
528+
rt.secret = secret
529+
rt.rt = &oauth2.Transport{
530+
Base: rt.next,
531+
Source: tokenSource,
532+
}
533+
rt.mtx.Unlock()
534+
}
535+
536+
rt.mtx.RLock()
537+
currentRT := rt.rt
538+
rt.mtx.RUnlock()
539+
return currentRT.RoundTrip(req)
540+
}
541+
542+
func (rt *oauth2RoundTripper) CloseIdleConnections() {
543+
// OAuth2 RT does not support CloseIdleConnections() but the next RT might.
544+
if ci, ok := rt.next.(closeIdler); ok {
545+
ci.CloseIdleConnections()
479546
}
480547
}
481548

config/http_config_test.go

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ var invalidHTTPClientConfigs = []struct {
103103
httpClientConfigFile: "testdata/http.conf.auth-creds-no-basic.bad.yaml",
104104
errMsg: `authorization type cannot be set to "basic", use "basic_auth" instead`,
105105
},
106+
{
107+
httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml",
108+
errMsg: "at most one of oauth2 client_secret & client_secret_file must be configured",
109+
},
106110
}
107111

108112
func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, error) {
@@ -1136,7 +1140,7 @@ endpoint_params:
11361140
t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig)
11371141
}
11381142

1139-
rt := expectedConfig.NewOAuth2RoundTripper(context.Background(), http.DefaultTransport)
1143+
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
11401144

11411145
client := http.Client{
11421146
Transport: rt,
@@ -1148,3 +1152,115 @@ endpoint_params:
11481152
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
11491153
}
11501154
}
1155+
1156+
func TestOAuth2WithFile(t *testing.T) {
1157+
var expectedAuth *string
1158+
var previousAuth string
1159+
tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1160+
auth := r.Header.Get("Authorization")
1161+
if auth != *expectedAuth {
1162+
t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth)
1163+
}
1164+
if auth == previousAuth {
1165+
t.Fatal("token endpoint called twice")
1166+
}
1167+
previousAuth = auth
1168+
res, _ := json.Marshal(testServerResponse{
1169+
AccessToken: "12345",
1170+
TokenType: "Bearer",
1171+
})
1172+
w.Header().Add("Content-Type", "application/json")
1173+
_, _ = w.Write(res)
1174+
}))
1175+
defer tokenTS.Close()
1176+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1177+
auth := r.Header.Get("Authorization")
1178+
if auth != "Bearer 12345" {
1179+
t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth)
1180+
}
1181+
fmt.Fprintln(w, "Hello, client")
1182+
}))
1183+
defer ts.Close()
1184+
1185+
secretFile, err := ioutil.TempFile("", "oauth2_secret")
1186+
if err != nil {
1187+
t.Fatal(err)
1188+
}
1189+
defer os.Remove(secretFile.Name())
1190+
1191+
var yamlConfig = fmt.Sprintf(`
1192+
client_id: 1
1193+
client_secret_file: %s
1194+
scopes:
1195+
- A
1196+
- B
1197+
token_url: %s
1198+
endpoint_params:
1199+
hi: hello
1200+
`, secretFile.Name(), tokenTS.URL)
1201+
expectedConfig := OAuth2{
1202+
ClientID: "1",
1203+
ClientSecretFile: secretFile.Name(),
1204+
Scopes: []string{"A", "B"},
1205+
EndpointParams: map[string]string{"hi": "hello"},
1206+
TokenURL: tokenTS.URL,
1207+
}
1208+
1209+
var unmarshalledConfig OAuth2
1210+
err = yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig)
1211+
if err != nil {
1212+
t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
1213+
}
1214+
if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
1215+
t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig)
1216+
}
1217+
1218+
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
1219+
1220+
client := http.Client{
1221+
Transport: rt,
1222+
}
1223+
1224+
tk := "Basic MToxMjM0NTY="
1225+
expectedAuth = &tk
1226+
if _, err := secretFile.Write([]byte("123456")); err != nil {
1227+
t.Fatal(err)
1228+
}
1229+
resp, err := client.Get(ts.URL)
1230+
if err != nil {
1231+
t.Fatal(err)
1232+
}
1233+
1234+
authorization := resp.Request.Header.Get("Authorization")
1235+
if authorization != "Bearer 12345" {
1236+
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
1237+
}
1238+
1239+
// Making a second request with the same file content should not re-call the token API.
1240+
resp, err = client.Get(ts.URL)
1241+
if err != nil {
1242+
t.Fatal(err)
1243+
}
1244+
1245+
tk = "Basic MToxMjM0NTY3"
1246+
expectedAuth = &tk
1247+
if _, err := secretFile.Write([]byte("7")); err != nil {
1248+
t.Fatal(err)
1249+
}
1250+
1251+
_, err = client.Get(ts.URL)
1252+
if err != nil {
1253+
t.Fatal(err)
1254+
}
1255+
1256+
// Making a second request with the same file content should not re-call the token API.
1257+
_, err = client.Get(ts.URL)
1258+
if err != nil {
1259+
t.Fatal(err)
1260+
}
1261+
1262+
authorization = resp.Request.Header.Get("Authorization")
1263+
if authorization != "Bearer 12345" {
1264+
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
1265+
}
1266+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
oauth2:
2+
client_secret: "mysecret"
3+
client_secret_file: "mysecret"

0 commit comments

Comments
 (0)