Skip to content

Commit ea2d84c

Browse files
authored
feat(auth): discover accounts/roles linked to IdC (AWS SSO) #3023
Problem: Users can add connections to IAM Identity Center but there's no easy way to use them for the AWS explorer. Solution: Automatically discover linked accounts/roles and list them in the connections picker.
1 parent 2a1c728 commit ea2d84c

File tree

11 files changed

+497
-77
lines changed

11 files changed

+497
-77
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "Feature",
3+
"description": "auth: AWS accounts and roles from IAM Identity Center are automatically discovered by the Toolkit when selecting a connection."
4+
}

src/credentials/auth.ts

Lines changed: 215 additions & 54 deletions
Large diffs are not rendered by default.

src/credentials/providers/credentials.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ export function credentialsProviderToTelemetryType(o: CredentialsProviderType):
8686
return 'envVars'
8787
case 'profile':
8888
return 'sharedCredentials'
89+
case 'sso':
90+
return 'iamIdentityCenter'
8991
default:
9092
return 'other'
9193
}

src/credentials/providers/credentialsProviderManager.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ export class CredentialsProviderManager {
8787
}
8888

8989
public addProvider(provider: CredentialsProvider) {
90+
this.removeProvider(provider.getCredentialsId())
9091
this.providers.push(provider)
9192
}
9293

@@ -102,6 +103,13 @@ export class CredentialsProviderManager {
102103
this.providerFactories.push(...factory)
103104
}
104105

106+
public removeProvider(id: CredentialsId) {
107+
const idx = this.providers.findIndex(p => isEqual(id, p.getCredentialsId()))
108+
if (idx !== -1) {
109+
this.providers.splice(idx, 1)
110+
}
111+
}
112+
105113
public static getInstance(): CredentialsProviderManager {
106114
if (!CredentialsProviderManager.INSTANCE) {
107115
CredentialsProviderManager.INSTANCE = new CredentialsProviderManager()
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*!
2+
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
import { Credentials } from '@aws-sdk/types'
7+
import { CredentialType } from '../../shared/telemetry/telemetry.gen'
8+
import { getStringHash } from '../../shared/utilities/textUtilities'
9+
import { CredentialsId, CredentialsProvider, CredentialsProviderType } from './credentials'
10+
import { SsoClient } from '../sso/clients'
11+
import { SsoAccessTokenProvider } from '../sso/ssoAccessTokenProvider'
12+
13+
export class SsoCredentialsProvider implements CredentialsProvider {
14+
public constructor(
15+
private readonly id: CredentialsId,
16+
private readonly client: SsoClient,
17+
private readonly tokenProvider: SsoAccessTokenProvider,
18+
private readonly accountId: string,
19+
private readonly roleName: string
20+
) {}
21+
22+
public async isAvailable(): Promise<boolean> {
23+
return true
24+
}
25+
26+
public getCredentialsId(): CredentialsId {
27+
return this.id
28+
}
29+
30+
public getProviderType(): CredentialsProviderType {
31+
return this.id.credentialSource
32+
}
33+
34+
public getTelemetryType(): CredentialType {
35+
return 'ssoProfile'
36+
}
37+
38+
public getDefaultRegion(): string | undefined {
39+
return this.client.region
40+
}
41+
42+
public getHashCode(): string {
43+
return getStringHash(this.accountId + this.roleName)
44+
}
45+
46+
public async canAutoConnect(): Promise<boolean> {
47+
return this.hasToken()
48+
}
49+
50+
public async getCredentials(): Promise<Credentials> {
51+
if (!(await this.hasToken())) {
52+
await this.tokenProvider.createToken()
53+
}
54+
55+
return this.client.getRoleCredentials({
56+
accountId: this.accountId,
57+
roleName: this.roleName,
58+
})
59+
}
60+
61+
private async hasToken() {
62+
return (await this.tokenProvider.getToken()) !== undefined
63+
}
64+
}

src/credentials/sso/clients.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ type PromisifyClient<T> = {
9595
}
9696

9797
export class SsoClient {
98+
public get region() {
99+
const region = this.client.config.region
100+
101+
return typeof region === 'string' ? (region as string) : undefined
102+
}
103+
98104
public constructor(
99105
private readonly client: PromisifyClient<SSO>,
100106
private readonly provider: SsoAccessTokenProvider

src/shared/utilities/asyncCollection.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ export interface AsyncCollection<T> extends AsyncIterable<T> {
3131
find<U extends T>(predicate: (item: T) => item is U): Promise<U | undefined>
3232
find<U extends T>(predicate: (item: T) => boolean): Promise<U | undefined>
3333

34+
/**
35+
* Catches all errors from the underlying iterable(s).
36+
*
37+
* Note that currently the contiuation behavior is highly dependent on the
38+
* underlying implementations. For example, a `for await` loop cannot be
39+
* continued if any of the resulting values are rejected.
40+
*/
41+
catch<U>(handler: (error: unknown) => Promise<U> | U): AsyncCollection<T | U>
42+
3443
/**
3544
* Uses only the first 'count' number of values returned by the generator.
3645
*/
@@ -86,6 +95,8 @@ export function toCollection<T>(generator: () => AsyncGenerator<T, T | undefined
8695
flatten: () => toCollection<SafeUnboxIterable<T>>(() => delegateGenerator(generator(), flatten)),
8796
filter: <U extends T>(predicate: Predicate<T, U>) =>
8897
toCollection<U>(() => filterGenerator<T, U>(generator(), predicate)),
98+
catch: <U>(fn: (error: unknown) => Promise<U> | U) =>
99+
toCollection<T | U>(() => catchGenerator(generator(), fn)),
89100
map: <U>(fn: (item: T) => Promise<U> | U) => toCollection<U>(() => mapGenerator(generator(), fn)),
90101
limit: (count: number) => toCollection(() => delegateGenerator(generator(), takeFrom(count))),
91102
promise: () => promise(iterable),
@@ -244,3 +255,23 @@ async function find<T, U extends T>(iterable: AsyncIterable<T>, predicate: (item
244255
}
245256
}
246257
}
258+
259+
async function* catchGenerator<T, U, R = T>(
260+
generator: AsyncGenerator<T, R | undefined | void>,
261+
fn: (error: unknown) => Promise<U> | U
262+
): AsyncGenerator<T | U, R | U | undefined | void> {
263+
while (true) {
264+
try {
265+
const { value, done } = await generator.next()
266+
if (done) {
267+
return value
268+
}
269+
yield value
270+
} catch (err) {
271+
// Catching an error when the generator would normally
272+
// report 'done' means that the 'done' value would be
273+
// replaced by `undefined`.
274+
yield fn(err)
275+
}
276+
}
277+
}

src/shared/utilities/tsUtils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export function selectFrom<T, K extends keyof T>(obj: T, ...props: K[]): { [P in
3737
return props.map(p => [p, obj[p]] as const).reduce((a, [k, v]) => ((a[k] = v), a), {} as { [P in K]: T[P] })
3838
}
3939

40-
export function isNonNullable<T>(obj: T): obj is NonNullable<T> {
40+
export function isNonNullable<T>(obj: T | void): obj is NonNullable<T> {
4141
// eslint-disable-next-line no-null/no-null
4242
return obj !== undefined && obj !== null
4343
}

src/test/credentials/auth.test.ts

Lines changed: 149 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,22 @@
44
*/
55

66
import * as assert from 'assert'
7-
import { AuthNode, isSsoConnection, promptForConnection, SsoConnection } from '../../credentials/auth'
7+
import * as sinon from 'sinon'
8+
import {
9+
AuthNode,
10+
Connection,
11+
isIamConnection,
12+
isSsoConnection,
13+
promptForConnection,
14+
ssoAccountAccessScopes,
15+
} from '../../credentials/auth'
816
import { ToolkitError } from '../../shared/errors'
917
import { assertTreeItem } from '../shared/treeview/testUtil'
1018
import { getTestWindow } from '../shared/vscode/window'
1119
import { captureEventOnce } from '../testUtil'
12-
import { createSsoProfile, createTestAuth } from './testUtil'
20+
import { createBuilderIdProfile, createSsoProfile, createTestAuth } from './testUtil'
21+
import { toCollection } from '../../shared/utilities/asyncCollection'
22+
import globals from '../../shared/extensionGlobals'
1323

1424
const ssoProfile = createSsoProfile()
1525
const scopedSsoProfile = createSsoProfile({ scopes: ['foo'] })
@@ -181,15 +191,15 @@ describe('Auth', function () {
181191
assert.ok(await conn.getToken())
182192
})
183193

184-
describe('SSO Connections', function () {
185-
async function runExpiredGetTokenFlow(conn: SsoConnection, selection: string | RegExp) {
186-
const token = conn.getToken()
187-
const message = await getTestWindow().waitForMessage(expiredConnPattern)
188-
message.selectItem(selection)
194+
async function runExpiredConnectionFlow(conn: Connection, selection: string | RegExp) {
195+
const creds = conn.type === 'sso' ? conn.getToken() : conn.getCredentials()
196+
const message = await getTestWindow().waitForMessage(expiredConnPattern)
197+
message.selectItem(selection)
189198

190-
return token
191-
}
199+
return creds
200+
}
192201

202+
describe('SSO Connections', function () {
193203
it('creates a new token if one does not exist', async function () {
194204
const conn = await auth.createConnection(ssoProfile)
195205
const provider = auth.getTestTokenProvider(conn)
@@ -198,7 +208,7 @@ describe('Auth', function () {
198208

199209
it('prompts the user if the token is invalid or expired', async function () {
200210
const conn = await auth.createInvalidSsoConnection(ssoProfile)
201-
const token = await runExpiredGetTokenFlow(conn, /login/i)
211+
const token = await runExpiredConnectionFlow(conn, /login/i)
202212
assert.notStrictEqual(token, undefined)
203213
})
204214

@@ -207,7 +217,7 @@ describe('Auth', function () {
207217
await auth.useConnection(conn)
208218
await auth.invalidateCachedCredentials(conn)
209219

210-
const token = runExpiredGetTokenFlow(conn, /no/i)
220+
const token = runExpiredConnectionFlow(conn, /no/i)
211221
await assert.rejects(token, ToolkitError)
212222

213223
assert.strictEqual(auth.activeConnection?.state, 'invalid')
@@ -217,12 +227,139 @@ describe('Auth', function () {
217227
const err1 = new ToolkitError('test', { code: 'test' })
218228
const conn = await auth.createConnection(ssoProfile)
219229
auth.getTestTokenProvider(conn)?.getToken.rejects(err1)
220-
const err2 = await runExpiredGetTokenFlow(conn, /no/i).catch(e => e)
230+
const err2 = await runExpiredConnectionFlow(conn, /no/i).catch(e => e)
221231
assert.ok(err2 instanceof ToolkitError)
222232
assert.strictEqual(err2.cause, err1)
223233
})
224234
})
225235

236+
describe('Linked Connections', function () {
237+
const linkedSsoProfile = createSsoProfile({ scopes: ssoAccountAccessScopes })
238+
const accountRoles = [
239+
{ accountId: '1245678910', roleName: 'foo' },
240+
{ accountId: '9876543210', roleName: 'foo' },
241+
{ accountId: '9876543210', roleName: 'bar' },
242+
]
243+
244+
beforeEach(function () {
245+
auth.ssoClient.listAccounts.returns(
246+
toCollection(async function* () {
247+
yield [{ accountId: '1245678910' }, { accountId: '9876543210' }]
248+
})
249+
)
250+
251+
auth.ssoClient.listAccountRoles.callsFake(req =>
252+
toCollection(async function* () {
253+
yield accountRoles.filter(i => i.accountId === req.accountId)
254+
})
255+
)
256+
257+
auth.ssoClient.getRoleCredentials.resolves({
258+
accessKeyId: 'xxx',
259+
secretAccessKey: 'xxx',
260+
expiration: new Date(Date.now() + 1000000),
261+
})
262+
263+
sinon.stub(globals.loginManager, 'validateCredentials').resolves('')
264+
})
265+
266+
afterEach(function () {
267+
sinon.restore()
268+
})
269+
270+
it('lists linked conections for SSO connections', async function () {
271+
await auth.createConnection(linkedSsoProfile)
272+
const connections = await auth.listAndTraverseConnections().promise()
273+
assert.deepStrictEqual(
274+
connections.map(c => c.type),
275+
['sso', 'iam', 'iam', 'iam']
276+
)
277+
})
278+
279+
it('does not gather linked accounts when calling `listConnections`', async function () {
280+
await auth.createConnection(linkedSsoProfile)
281+
const connections = await auth.listConnections()
282+
assert.deepStrictEqual(
283+
connections.map(c => c.type),
284+
['sso']
285+
)
286+
})
287+
288+
it('caches linked conections when the source connection becomes invalid', async function () {
289+
const conn = await auth.createConnection(linkedSsoProfile)
290+
await auth.listAndTraverseConnections().promise()
291+
await auth.invalidateCachedCredentials(conn)
292+
293+
const connections = await auth.listConnections()
294+
assert.deepStrictEqual(
295+
connections.map(c => c.type),
296+
['sso', 'iam', 'iam', 'iam']
297+
)
298+
})
299+
300+
it('gracefully handles source connections becoming invalid when discovering linked accounts', async function () {
301+
await auth.createConnection(linkedSsoProfile)
302+
auth.ssoClient.listAccounts.rejects(new Error('No access'))
303+
const connections = await auth.listAndTraverseConnections().promise()
304+
assert.deepStrictEqual(
305+
connections.map(c => c.type),
306+
['sso']
307+
)
308+
})
309+
310+
it('removes linked connections when the source connection is deleted', async function () {
311+
const conn = await auth.createConnection(linkedSsoProfile)
312+
await auth.listAndTraverseConnections().promise()
313+
await auth.deleteConnection(conn)
314+
315+
assert.deepStrictEqual(await auth.listAndTraverseConnections().promise(), [])
316+
})
317+
318+
it('prompts the user to reauthenticate if the source connection becomes invalid', async function () {
319+
const source = await auth.createConnection(linkedSsoProfile)
320+
const conn = await auth.listAndTraverseConnections().find(c => isIamConnection(c) && c.id.includes('sso'))
321+
assert.ok(conn)
322+
await auth.useConnection(conn)
323+
await auth.reauthenticate(conn)
324+
await auth.invalidateCachedCredentials(conn)
325+
await auth.invalidateCachedCredentials(source)
326+
327+
await runExpiredConnectionFlow(conn, /login/i)
328+
assert.strictEqual(auth.getConnectionState(source), 'valid')
329+
assert.strictEqual(auth.getConnectionState(conn), 'valid')
330+
})
331+
332+
describe('Multiple Connections', function () {
333+
const otherProfile = createBuilderIdProfile({ scopes: ssoAccountAccessScopes })
334+
335+
// Equivalent profiles from multiple sources is a fairly rare situation right now
336+
// Ideally they would be de-duped although the implementation can be tricky
337+
it('can handle multiple SSO connection and does not de-dupe', async function () {
338+
await auth.createConnection(linkedSsoProfile)
339+
await auth.createConnection(otherProfile)
340+
341+
const connections = await auth.listAndTraverseConnections().promise()
342+
assert.deepStrictEqual(
343+
connections.map(c => c.type),
344+
['sso', 'sso', 'iam', 'iam', 'iam', 'iam', 'iam', 'iam'],
345+
'Expected two SSO connections and 3 IAM connections for each SSO connection'
346+
)
347+
})
348+
349+
it('does not stop discovery if one connection fails', async function () {
350+
const otherProfile = createBuilderIdProfile({ scopes: ssoAccountAccessScopes })
351+
await auth.createConnection(linkedSsoProfile)
352+
await auth.createConnection(otherProfile)
353+
auth.ssoClient.listAccounts.onFirstCall().rejects(new Error('No access'))
354+
const connections = await auth.listAndTraverseConnections().promise()
355+
assert.deepStrictEqual(
356+
connections.map(c => c.type),
357+
['sso', 'sso', 'iam', 'iam', 'iam']
358+
)
359+
})
360+
})
361+
})
362+
226363
describe('AuthNode', function () {
227364
it('shows a message to create a connection if no connections exist', async function () {
228365
const node = new AuthNode(auth)

src/test/credentials/provider/sharedCredentialsProvider.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ describe('SharedCredentialsProvider', async function () {
304304
`
305305

306306
beforeEach(function () {
307-
const client = stub(SsoClient)
307+
const client = stub(SsoClient, { region: 'foo' })
308308
client.getRoleCredentials.callsFake(async request => {
309309
assert.strictEqual(request.accountId, '012345678910')
310310
assert.strictEqual(request.roleName, 'MyRole')

0 commit comments

Comments
 (0)