Skip to content

Commit 0c28513

Browse files
fix(sdk): Disable concurrency on rewrap
- Adds new `concurrencyLimit` decrypt param, which sets a thread pool (kinda) - Defaults value to 1
1 parent 8b1de24 commit 0c28513

File tree

5 files changed

+206
-61
lines changed

5 files changed

+206
-61
lines changed

lib/src/concurrency.ts

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
type LabelledSuccess<T> = { lid: string; value: Promise<T> };
2+
type LabelledFailure = { lid: string; e: any };
3+
4+
async function labelPromise<T>(label: string, promise: Promise<T>): Promise<LabelledSuccess<T>> {
5+
try {
6+
const value = await promise;
7+
return { lid: label, value: Promise.resolve(value) };
8+
} catch (e) {
9+
throw { lid: label, e };
10+
}
11+
}
12+
13+
// Pooled variant of Promise.all; implements most of the logic of the real all,
14+
// but with a pool size of n. Rejects on first reject, or returns a list
15+
// of all successful responses. Operates with at most n 'active' promises at a time.
16+
// For tracking purposes, all promises must have a unique identifier.
17+
export async function allPool<T>(n: number, p: Record<string, Promise<T>>): Promise<Awaited<T>[]> {
18+
const pool: Record<string, Promise<LabelledSuccess<T>>> = {};
19+
const resolved: Awaited<T>[] = [];
20+
for (const [id, job] of Object.entries(p)) {
21+
// while the size of jobs to do is greater than n,
22+
// let n jobs run and take the first one to finish out of the pool
23+
pool[id] = labelPromise(id, job);
24+
if (Object.keys(pool).length > n - 1) {
25+
const promises = Object.values(pool);
26+
try {
27+
const { lid, value } = await Promise.race(promises);
28+
resolved.push(await value);
29+
console.log(`succeeded on promise ${lid}`, value);
30+
delete pool[lid];
31+
} catch (err) {
32+
const { lid, e } = err as LabelledFailure;
33+
console.warn(`failed on promise ${lid}`, err);
34+
throw e;
35+
}
36+
}
37+
}
38+
try {
39+
for (const labelled of await Promise.all(Object.values(pool))) {
40+
console.log(`real.all succeeded on promise ${labelled.lid}`, labelled);
41+
resolved.push(await labelled.value);
42+
}
43+
} catch (err) {
44+
if ('lid' in err && 'e' in err) {
45+
throw err.e;
46+
} else {
47+
throw err;
48+
}
49+
}
50+
return resolved;
51+
}
52+
53+
// Pooled variant of promise.any; implements most of the logic of the real any,
54+
// but with a pool size of n, and returns the first successful promise,
55+
// operating with at most n 'active' promises at a time.
56+
export async function anyPool<T>(n: number, p: Record<string, Promise<T>>): Promise<Awaited<T>> {
57+
const pool: Record<string, Promise<LabelledSuccess<T>>> = {};
58+
const rejections = [];
59+
for (const [id, job] of Object.entries(p)) {
60+
// while the size of jobs to do is greater than n,
61+
// let n jobs run and take the first one to finish out of the pool
62+
pool[id] = labelPromise(id, job);
63+
if (Object.keys(pool).length > n - 1) {
64+
const promises = Object.values(pool);
65+
try {
66+
const { lid, value } = await Promise.race(promises);
67+
console.log(`any succeeded on promise ${lid}`, value);
68+
return await value;
69+
} catch (error) {
70+
const { lid, e } = error;
71+
rejections.push(e);
72+
delete pool[lid];
73+
console.log(`any failed on promise ${lid}`, e);
74+
}
75+
}
76+
}
77+
try {
78+
const { lid, value } = await Promise.any(Object.values(pool));
79+
console.log(`real.any succeeded on promise ${lid}`);
80+
return await value;
81+
} catch (errors) {
82+
console.log(`real.any failed`, errors);
83+
if (errors instanceof AggregateError) {
84+
for (const error of errors.errors) {
85+
if ('lid' in error && 'e' in error) {
86+
rejections.push(error.e);
87+
} else {
88+
rejections.push(error);
89+
}
90+
}
91+
} else {
92+
rejections.push(errors);
93+
}
94+
}
95+
throw new AggregateError(rejections);
96+
}

lib/tdf3/src/client/builders.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ export type DecryptParams = {
519519
keyMiddleware?: DecryptKeyMiddleware;
520520
streamMiddleware?: DecryptStreamMiddleware;
521521
assertionVerificationKeys?: AssertionVerificationKeys;
522+
concurrencyLimit?: number;
522523
noVerifyAssertions?: boolean;
523524
};
524525

@@ -685,6 +686,11 @@ class DecryptParamsBuilder {
685686
return freeze({ ..._params });
686687
}
687688

689+
withConcurrencyLimit(limit: number): DecryptParamsBuilder {
690+
this._params.concurrencyLimit = limit;
691+
return this;
692+
}
693+
688694
/**
689695
* Generate a parameters object in the form expected by <code>{@link Client#decrypt|decrypt}</code>.
690696
* <br/><br/>

lib/tdf3/src/client/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ export class Client {
562562
streamMiddleware = async (stream: DecoratedReadableStream) => stream,
563563
assertionVerificationKeys,
564564
noVerifyAssertions,
565+
concurrencyLimit = 1,
565566
}: DecryptParams): Promise<DecoratedReadableStream> {
566567
const dpopKeys = await this.dpopKeys;
567568
let entityObject;
@@ -587,6 +588,7 @@ export class Client {
587588
allowList: this.allowedKases,
588589
authProvider: this.authProvider,
589590
chunker,
591+
concurrencyLimit,
590592
cryptoService: this.cryptoService,
591593
dpopKeys,
592594
entity: entityObject,

lib/tdf3/src/tdf.ts

Lines changed: 37 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ import PolicyObject from '../../src/tdf/PolicyObject.js';
6565
import { type CryptoService, type DecryptResult } from './crypto/declarations.js';
6666
import { CentralDirectory } from './utils/zip-reader.js';
6767
import { SymmetricCipher } from './ciphers/symmetric-cipher-base.js';
68+
import { allPool, anyPool } from '../../src/concurrency.js';
6869

6970
// TODO: input validation on manifest JSON
7071
const DEFAULT_SEGMENT_SIZE = 1024 * 1024;
@@ -163,6 +164,7 @@ export type DecryptConfiguration = {
163164
fileStreamServiceWorker?: string;
164165
assertionVerificationKeys?: AssertionVerificationKeys;
165166
noVerifyAssertions?: boolean;
167+
concurrencyLimit?: number;
166168
};
167169

168170
export type UpsertConfiguration = {
@@ -904,17 +906,24 @@ export function splitLookupTableFactory(
904906
return splitPotentials;
905907
}
906908

909+
type RewrapResponseData = {
910+
key: Uint8Array;
911+
metadata: Record<string, unknown>;
912+
};
913+
907914
async function unwrapKey({
908915
manifest,
909916
allowedKases,
910917
authProvider,
911918
dpopKeys,
919+
concurrencyLimit,
912920
entity,
913921
cryptoService,
914922
}: {
915923
manifest: Manifest;
916924
allowedKases: OriginAllowList;
917925
authProvider: AuthProvider | AppIdAuthProvider;
926+
concurrencyLimit?: number;
918927
dpopKeys: CryptoKeyPair;
919928
entity: EntityObject | undefined;
920929
cryptoService: CryptoService;
@@ -928,7 +937,7 @@ async function unwrapKey({
928937
const splitPotentials = splitLookupTableFactory(keyAccess, allowedKases);
929938
const isAppIdProvider = authProvider && isAppIdProviderCheck(authProvider);
930939

931-
async function tryKasRewrap(keySplitInfo: KeyAccessObject) {
940+
async function tryKasRewrap(keySplitInfo: KeyAccessObject): Promise<RewrapResponseData> {
932941
const url = `${keySplitInfo.url}/${isAppIdProvider ? '' : 'v2/'}rewrap`;
933942
const ephemeralEncryptionKeys = await cryptoService.cryptoToPemPair(
934943
await cryptoService.generateKeyPair()
@@ -982,77 +991,44 @@ async function unwrapKey({
982991
};
983992
}
984993

985-
// Get unique split IDs to determine if we have an OR or AND condition
986-
const splitIds = new Set(Object.keys(splitPotentials));
987-
988-
// If we have only one split ID, it's an OR condition
989-
if (splitIds.size === 1) {
990-
const [splitId] = splitIds;
994+
const poolSize = concurrencyLimit === undefined ? 1 : concurrencyLimit > 1 ? concurrencyLimit : 1;
995+
const splitPromises: Record<string, Promise<RewrapResponseData>> = {};
996+
for (const splitId of Object.keys(splitPotentials)) {
991997
const potentials = splitPotentials[splitId];
992-
993-
try {
994-
// OR condition: Try all KAS servers for this split, take first success
995-
const result = await Promise.any(
996-
Object.values(potentials).map(async (keySplitInfo) => {
997-
try {
998-
return await tryKasRewrap(keySplitInfo);
999-
} catch (e) {
1000-
// Rethrow with more context
1001-
throw handleRewrapError(e as Error | AxiosError);
1002-
}
1003-
})
998+
if (!potentials || !Object.keys(potentials).length) {
999+
throw new UnsafeUrlError(
1000+
`Unreconstructable key - no valid KAS found for split ${JSON.stringify(splitId)}`,
1001+
''
10041002
);
1005-
1006-
const reconstructedKey = keyMerge([result.key]);
1007-
return {
1008-
reconstructedKeyBinary: Binary.fromArrayBuffer(reconstructedKey),
1009-
metadata: result.metadata,
1010-
};
1011-
} catch (error) {
1012-
if (error instanceof AggregateError) {
1013-
// All KAS servers failed
1014-
throw error.errors[0]; // Throw the first error since we've already wrapped them
1015-
}
1016-
throw error;
10171003
}
1018-
} else {
1019-
// AND condition: We need successful results from all different splits
1020-
const splitResults = await Promise.all(
1021-
Object.entries(splitPotentials).map(async ([splitId, potentials]) => {
1022-
if (!potentials || !Object.keys(potentials).length) {
1023-
throw new UnsafeUrlError(
1024-
`Unreconstructable key - no valid KAS found for split ${JSON.stringify(splitId)}`,
1025-
''
1026-
);
1027-
}
1028-
1004+
const anyPromises: Record<string, Promise<RewrapResponseData>> = {};
1005+
for (const [kas, keySplitInfo] of Object.entries(potentials)) {
1006+
anyPromises[kas] = (async () => {
10291007
try {
1030-
// For each split, try all potential KAS servers until one succeeds
1031-
return await Promise.any(
1032-
Object.values(potentials).map(async (keySplitInfo) => {
1033-
try {
1034-
return await tryKasRewrap(keySplitInfo);
1035-
} catch (e) {
1036-
throw handleRewrapError(e as Error | AxiosError);
1037-
}
1038-
})
1039-
);
1040-
} catch (error) {
1041-
if (error instanceof AggregateError) {
1042-
// All KAS servers for this split failed
1043-
throw error.errors[0]; // Throw the first error since we've already wrapped them
1044-
}
1045-
throw error;
1008+
return await tryKasRewrap(keySplitInfo);
1009+
} catch (e) {
1010+
throw handleRewrapError(e as Error | AxiosError);
10461011
}
1047-
})
1048-
);
1049-
1012+
})();
1013+
}
1014+
splitPromises[splitId] = anyPool(poolSize, anyPromises);
1015+
}
1016+
try {
1017+
const splitResults = await allPool(poolSize, splitPromises);
10501018
// Merge all the split keys
10511019
const reconstructedKey = keyMerge(splitResults.map((r) => r.key));
10521020
return {
10531021
reconstructedKeyBinary: Binary.fromArrayBuffer(reconstructedKey),
10541022
metadata: splitResults[0].metadata, // Use metadata from first split
10551023
};
1024+
} catch (e) {
1025+
if (e instanceof AggregateError) {
1026+
const errors = e.errors;
1027+
if (errors.length === 1) {
1028+
throw errors[0];
1029+
}
1030+
}
1031+
throw e;
10561032
}
10571033
}
10581034

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import { allPool, anyPool } from '../../../src/concurrency.js';
2+
import { expect } from 'chai';
3+
4+
describe('concurrency', () => {
5+
for (const n of [1, 2, 3, 4]) {
6+
describe(`allPool(${n})`, () => {
7+
it(`should resolve all promises with a pool size of ${n}`, async () => {
8+
const promises = {
9+
a: Promise.resolve(1),
10+
b: Promise.resolve(2),
11+
c: Promise.resolve(3),
12+
};
13+
const result = await allPool(n, promises);
14+
expect(result).to.have.members([1, 2, 3]);
15+
});
16+
it(`should reject if any promise rejects, n=${n}`, async () => {
17+
const promises = {
18+
a: Promise.resolve(1),
19+
b: Promise.reject(new Error('failure')),
20+
c: Promise.resolve(3),
21+
};
22+
try {
23+
await allPool(n, promises);
24+
} catch (e) {
25+
expect(e).to.contain({ message: 'failure' });
26+
}
27+
});
28+
});
29+
describe(`anyPool(${n})`, () => {
30+
it('should resolve with the first resolved promise', async () => {
31+
const startTime = Date.now();
32+
const promises = {
33+
a: new Promise((resolve) => setTimeout(() => resolve(1), 500)),
34+
b: new Promise((resolve) => setTimeout(() => resolve(2), 50)),
35+
c: new Promise((resolve) => setTimeout(() => resolve(3), 1500)),
36+
};
37+
const result = await anyPool(n, promises);
38+
const endTime = Date.now();
39+
const elapsed = endTime - startTime;
40+
if (n > 1) {
41+
expect(elapsed).to.be.lessThan(500);
42+
expect(result).to.equal(2);
43+
} else {
44+
expect(elapsed).to.be.greaterThan(50);
45+
expect(elapsed).to.be.lessThan(1000);
46+
expect(result).to.equal(1);
47+
}
48+
});
49+
50+
it('should reject if all promises reject', async () => {
51+
const promises = {
52+
a: Promise.reject(new Error('failure1')),
53+
b: Promise.reject(new Error('failure2')),
54+
c: Promise.reject(new Error('failure3')),
55+
};
56+
try {
57+
await anyPool(n, promises);
58+
} catch (e) {
59+
expect(e).to.be.instanceOf(AggregateError);
60+
expect(e.errors).to.have.lengthOf(3);
61+
}
62+
});
63+
});
64+
}
65+
});

0 commit comments

Comments
 (0)