diff --git a/google/externalaccount/basecredentials.go b/google/externalaccount/basecredentials.go index 6c81a6872..b5c0b1dbc 100644 --- a/google/externalaccount/basecredentials.go +++ b/google/externalaccount/basecredentials.go @@ -109,6 +109,7 @@ package externalaccount import ( "context" + "errors" "fmt" "net/http" "regexp" @@ -119,6 +120,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google/internal/impersonate" "golang.org/x/oauth2/google/internal/stsexchange" + "golang.org/x/oauth2/internal" ) const ( @@ -464,6 +466,10 @@ func (ts tokenSource) Token() (*oauth2.Token, error) { } stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options) if err != nil { + var rErr *internal.RetrieveError + if errors.As(err, &rErr) { + return nil, (*oauth2.RetrieveError)(rErr) + } return nil, err } diff --git a/google/internal/stsexchange/sts_exchange.go b/google/internal/stsexchange/sts_exchange.go index 1a0bebd15..9eb16e1c0 100644 --- a/google/internal/stsexchange/sts_exchange.go +++ b/google/internal/stsexchange/sts_exchange.go @@ -16,6 +16,7 @@ import ( "strings" "golang.org/x/oauth2" + "golang.org/x/oauth2/internal" ) func defaultHeader() http.Header { @@ -87,7 +88,10 @@ func makeRequest(ctx context.Context, endpoint string, data url.Values, authenti return nil, err } if c := resp.StatusCode; c < 200 || c > 299 { - return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body) + return nil, &internal.RetrieveError{ + Response: resp, + Body: body, + } } var stsResp Response err = json.Unmarshal(body, &stsResp) diff --git a/sts_exchange_test.go b/sts_exchange_test.go new file mode 100644 index 000000000..8f4a6b7c9 --- /dev/null +++ b/sts_exchange_test.go @@ -0,0 +1,75 @@ +package oauth2_test + +import ( + "bytes" + "context" + "errors" + "golang.org/x/oauth2" + "io" + "net/http" + "testing" + + "golang.org/x/oauth2/google/externalaccount" +) + +var _ externalaccount.SubjectTokenSupplier = fakeSupplier{} + +type fakeSupplier struct{} + +func (f fakeSupplier) SubjectToken(_ context.Context, _ externalaccount.SupplierOptions) (string, error) { + return "test-token", nil +} + +var _ http.RoundTripper = fakeRT{} + +type fakeRT struct { + body string +} + +func (f fakeRT) RoundTrip(_ *http.Request) (*http.Response, error) { + status := http.StatusUnauthorized + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(bytes.NewReader([]byte(f.body))), + }, nil +} + +func TestSTSExchange_error_handling(t *testing.T) { + t.Parallel() + + // Arrange + body := `{"reason": "client does not exist"}` + client := &http.Client{Transport: fakeRT{body: body}} + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) + + source, err := externalaccount.NewTokenSource(ctx, externalaccount.Config{ + Audience: "aud", + SubjectTokenType: "test-token", + TokenURL: "url", + Scopes: []string{}, + SubjectTokenSupplier: fakeSupplier{}, + }) + if err != nil { + t.Errorf("got unexpected error while token source building: %s", err) + } + + // Act + _, err = source.Token() + + // Assert + if err == nil { + t.Errorf("expected token issuance error") + } + var retrieveErr *oauth2.RetrieveError + if !errors.As(err, &retrieveErr) { + t.Errorf("expected an instance of RetrieveError, got error: %s", err) + } + + if string(retrieveErr.Body) != body { + t.Errorf("expected body content `%s`, got: `%s`", body, retrieveErr.Body) + } + + if retrieveErr.Response.StatusCode != http.StatusUnauthorized { + t.Errorf("expected unathorized status code, got: %s", retrieveErr.ErrorCode) + } +}