From 8591d35b376777d92b64a8ea060505e0eee5b7b2 Mon Sep 17 00:00:00 2001
From: Egor Olefirenko <egor.olefirenko892@gmail.com>
Date: Fri, 13 Sep 2024 21:52:46 +0300
Subject: [PATCH 1/2] oauth2/google/stsexchange: provide response body using
 RetrieveError

---
 google/externalaccount/basecredentials.go   | 6 ++++++
 google/internal/stsexchange/sts_exchange.go | 6 +++++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/google/externalaccount/basecredentials.go b/google/externalaccount/basecredentials.go
index 6c81a6872..b5c0b1dbc 100644
--- a/google/externalaccount/basecredentials.go
+++ b/google/externalaccount/basecredentials.go
@@ -109,6 +109,7 @@ package externalaccount
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"net/http"
 	"regexp"
@@ -119,6 +120,7 @@ import (
 	"golang.org/x/oauth2"
 	"golang.org/x/oauth2/google/internal/impersonate"
 	"golang.org/x/oauth2/google/internal/stsexchange"
+	"golang.org/x/oauth2/internal"
 )
 
 const (
@@ -464,6 +466,10 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
 	}
 	stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
 	if err != nil {
+		var rErr *internal.RetrieveError
+		if errors.As(err, &rErr) {
+			return nil, (*oauth2.RetrieveError)(rErr)
+		}
 		return nil, err
 	}
 
diff --git a/google/internal/stsexchange/sts_exchange.go b/google/internal/stsexchange/sts_exchange.go
index 1a0bebd15..9eb16e1c0 100644
--- a/google/internal/stsexchange/sts_exchange.go
+++ b/google/internal/stsexchange/sts_exchange.go
@@ -16,6 +16,7 @@ import (
 	"strings"
 
 	"golang.org/x/oauth2"
+	"golang.org/x/oauth2/internal"
 )
 
 func defaultHeader() http.Header {
@@ -87,7 +88,10 @@ func makeRequest(ctx context.Context, endpoint string, data url.Values, authenti
 		return nil, err
 	}
 	if c := resp.StatusCode; c < 200 || c > 299 {
-		return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
+		return nil, &internal.RetrieveError{
+			Response: resp,
+			Body:     body,
+		}
 	}
 	var stsResp Response
 	err = json.Unmarshal(body, &stsResp)

From 880daa6c59338918f9bca9455b9b4e0bba8f573b Mon Sep 17 00:00:00 2001
From: Egor Olefirenko <egor.olefirenko892@gmail.com>
Date: Fri, 13 Sep 2024 22:00:05 +0300
Subject: [PATCH 2/2] oauth2/google/stsexchange: implement basic test

---
 sts_exchange_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)
 create mode 100644 sts_exchange_test.go

diff --git a/sts_exchange_test.go b/sts_exchange_test.go
new file mode 100644
index 000000000..8f4a6b7c9
--- /dev/null
+++ b/sts_exchange_test.go
@@ -0,0 +1,75 @@
+package oauth2_test
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"golang.org/x/oauth2"
+	"io"
+	"net/http"
+	"testing"
+
+	"golang.org/x/oauth2/google/externalaccount"
+)
+
+var _ externalaccount.SubjectTokenSupplier = fakeSupplier{}
+
+type fakeSupplier struct{}
+
+func (f fakeSupplier) SubjectToken(_ context.Context, _ externalaccount.SupplierOptions) (string, error) {
+	return "test-token", nil
+}
+
+var _ http.RoundTripper = fakeRT{}
+
+type fakeRT struct {
+	body string
+}
+
+func (f fakeRT) RoundTrip(_ *http.Request) (*http.Response, error) {
+	status := http.StatusUnauthorized
+	return &http.Response{
+		StatusCode: status,
+		Body:       io.NopCloser(bytes.NewReader([]byte(f.body))),
+	}, nil
+}
+
+func TestSTSExchange_error_handling(t *testing.T) {
+	t.Parallel()
+
+	// Arrange
+	body := `{"reason": "client does not exist"}`
+	client := &http.Client{Transport: fakeRT{body: body}}
+	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client)
+
+	source, err := externalaccount.NewTokenSource(ctx, externalaccount.Config{
+		Audience:             "aud",
+		SubjectTokenType:     "test-token",
+		TokenURL:             "url",
+		Scopes:               []string{},
+		SubjectTokenSupplier: fakeSupplier{},
+	})
+	if err != nil {
+		t.Errorf("got unexpected error while token source building: %s", err)
+	}
+
+	// Act
+	_, err = source.Token()
+
+	// Assert
+	if err == nil {
+		t.Errorf("expected token issuance error")
+	}
+	var retrieveErr *oauth2.RetrieveError
+	if !errors.As(err, &retrieveErr) {
+		t.Errorf("expected an instance of RetrieveError, got error: %s", err)
+	}
+
+	if string(retrieveErr.Body) != body {
+		t.Errorf("expected body content `%s`, got: `%s`", body, retrieveErr.Body)
+	}
+
+	if retrieveErr.Response.StatusCode != http.StatusUnauthorized {
+		t.Errorf("expected unathorized status code, got: %s", retrieveErr.ErrorCode)
+	}
+}