|
6 | 6 | package s3
|
7 | 7 |
|
8 | 8 | import (
|
| 9 | + "context" |
9 | 10 | "encoding/base64"
|
| 11 | + "io" |
10 | 12 | "net/http"
|
| 13 | + "net/http/httptest" |
| 14 | + "os" |
11 | 15 | "testing"
|
| 16 | + "time" |
| 17 | + |
| 18 | + "github.com/go-kit/log" |
12 | 19 |
|
13 | 20 | "github.com/grafana/dskit/flagext"
|
14 | 21 | "github.com/stretchr/testify/assert"
|
@@ -158,3 +165,127 @@ func TestConfig_Validate(t *testing.T) {
|
158 | 165 | })
|
159 | 166 | }
|
160 | 167 | }
|
| 168 | + |
| 169 | +type testRoundTripper struct { |
| 170 | + roundTrip func(r *http.Request) (*http.Response, error) |
| 171 | +} |
| 172 | + |
| 173 | +func (t *testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { |
| 174 | + return t.roundTrip(r) |
| 175 | +} |
| 176 | + |
| 177 | +func handleSTSRequest(t *testing.T, r *http.Request, w http.ResponseWriter) { |
| 178 | + body, err := io.ReadAll(r.Body) |
| 179 | + require.NoError(t, err) |
| 180 | + |
| 181 | + require.Contains(t, string(body), "RoleArn=arn%3Ahello-world") |
| 182 | + require.Contains(t, string(body), "WebIdentityToken=my-web-token") |
| 183 | + require.Contains(t, string(body), "Action=AssumeRoleWithWebIdentity") |
| 184 | + |
| 185 | + w.WriteHeader(200) |
| 186 | + _, err = w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?> |
| 187 | + <AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/"> |
| 188 | + <AssumeRoleWithWebIdentityResult> |
| 189 | + <Credentials> |
| 190 | + <AccessKeyId>test-key</AccessKeyId> |
| 191 | + <SecretAccessKey>test-secret</SecretAccessKey> |
| 192 | + <SessionToken>test-token</SessionToken> |
| 193 | + <Expiration>` + time.Now().Add(time.Hour).Format(time.RFC3339) + `</Expiration> |
| 194 | + </Credentials> |
| 195 | + </AssumeRoleWithWebIdentityResult> |
| 196 | + <ResponseMetadata> |
| 197 | + <RequestId>test-request-id</RequestId> |
| 198 | + </ResponseMetadata> |
| 199 | + </AssumeRoleWithWebIdentityResponse>`)) |
| 200 | + require.NoError(t, err) |
| 201 | + |
| 202 | +} |
| 203 | + |
| 204 | +func overrideEnv(t testing.TB, kv ...string) { |
| 205 | + old := make([]string, len(kv)) |
| 206 | + for i := 0; i < len(kv); i += 2 { |
| 207 | + k := kv[i] |
| 208 | + v := kv[i+1] |
| 209 | + old[i] = k |
| 210 | + old[i+1] = os.Getenv(k) |
| 211 | + os.Setenv(k, v) |
| 212 | + } |
| 213 | + t.Cleanup(func() { |
| 214 | + for i := 0; i < len(old); i += 2 { |
| 215 | + os.Setenv(old[i], old[i+1]) |
| 216 | + } |
| 217 | + }) |
| 218 | +} |
| 219 | + |
| 220 | +func TestAWSSTSWebIdentity(t *testing.T) { |
| 221 | + logger := log.NewNopLogger() |
| 222 | + tmpDir := t.TempDir() |
| 223 | + |
| 224 | + // override env variables, will be cleaned up by t.Cleanup |
| 225 | + overrideEnv(t, |
| 226 | + "AWS_WEB_IDENTITY_TOKEN_FILE", tmpDir+"/token", |
| 227 | + "AWS_ROLE_ARN", "arn:hello-world", |
| 228 | + "AWS_DEFAULT_REGION", "eu-central-1", |
| 229 | + "AWS_CONFIG_FILE", "/dev/null", // dont accidentally use real config |
| 230 | + "AWS_ACCESS_KEY_ID", "", // dont use real credentials |
| 231 | + "AWS_SECRET_ACCESS_KEY", "", // dont use real credentials |
| 232 | + ) |
| 233 | + |
| 234 | + rt := &testRoundTripper{ |
| 235 | + roundTrip: func(r *http.Request) (*http.Response, error) { |
| 236 | + w := httptest.NewRecorder() |
| 237 | + if r.Body != nil { |
| 238 | + defer r.Body.Close() |
| 239 | + } |
| 240 | + switch r.URL.String() { |
| 241 | + case "https://sts.amazonaws.com": |
| 242 | + handleSTSRequest(t, r, w) |
| 243 | + case "https://eu-central-1.amazonaws.com/pyroscope-test-bucket/test": |
| 244 | + assert.Equal(t, "GET", r.Method) |
| 245 | + assert.Contains(t, r.Header.Get("Authorization"), "AWS4-HMAC-SHA256 Credential=test-key") |
| 246 | + w.Header().Set("Last-Modified", time.Now().Format("Mon, 2 Jan 2006 15:04:05 GMT")) |
| 247 | + w.WriteHeader(200) |
| 248 | + _, err := w.Write([]byte("test")) |
| 249 | + require.NoError(t, err) |
| 250 | + default: |
| 251 | + w.WriteHeader(404) |
| 252 | + _, err := w.Write([]byte("unexpected")) |
| 253 | + require.NoError(t, err) |
| 254 | + t.Errorf("unexpected request: %s", r.URL.Host) |
| 255 | + t.FailNow() |
| 256 | + } |
| 257 | + return w.Result(), nil |
| 258 | + }, |
| 259 | + } |
| 260 | + oldDefaultTransport := http.DefaultTransport |
| 261 | + oldDefaultClient := http.DefaultClient |
| 262 | + http.DefaultTransport = rt |
| 263 | + http.DefaultClient = &http.Client{ |
| 264 | + Transport: rt, |
| 265 | + } |
| 266 | + // restore default transport and client |
| 267 | + t.Cleanup(func() { |
| 268 | + http.DefaultTransport = oldDefaultTransport |
| 269 | + http.DefaultClient = oldDefaultClient |
| 270 | + }) |
| 271 | + |
| 272 | + // mock a web token |
| 273 | + err := os.WriteFile(tmpDir+"/token", []byte("my-web-token"), 0644) |
| 274 | + require.NoError(t, err) |
| 275 | + |
| 276 | + cfg := Config{ |
| 277 | + SignatureVersion: SignatureVersionV4, |
| 278 | + BucketName: "pyroscope-test-bucket", |
| 279 | + Region: "eu-central-1", |
| 280 | + Endpoint: "eu-central-1.amazonaws.com", |
| 281 | + BucketLookupType: AutoLookup, |
| 282 | + } |
| 283 | + |
| 284 | + cfg.HTTP.Transport = rt |
| 285 | + r, err := NewBucketClient(cfg, "test", logger) |
| 286 | + require.NoError(t, err) |
| 287 | + |
| 288 | + _, err = r.Get(context.Background(), "test") |
| 289 | + require.NoError(t, err) |
| 290 | + |
| 291 | +} |
0 commit comments