Skip to content

Commit

Permalink
fix(sdk): Let pools take callables.
Browse files Browse the repository at this point in the history
Lets the pool start the promises itself, instead of taking in promises that have already been scheduled on the task queue (!)
  • Loading branch information
dmihalcik-virtru committed Nov 18, 2024
1 parent beb3c06 commit c6116ed
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 21 deletions.
18 changes: 14 additions & 4 deletions lib/src/concurrency.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
type LabelledSuccess<T> = { lid: string; value: Promise<T> };
type LabelledFailure = { lid: string; e: any };

async function labelPromise<T>(label: string, promise: Promise<T>): Promise<LabelledSuccess<T>> {
async function labelPromise<T>(
label: string,
promise: () => Promise<T>
): Promise<LabelledSuccess<T>> {
try {
const value = await promise;
const value = await promise();
return { lid: label, value: Promise.resolve(value) };
} catch (e) {
throw { lid: label, e };
Expand All @@ -14,14 +17,18 @@ async function labelPromise<T>(label: string, promise: Promise<T>): Promise<Labe
// but with a pool size of n. Rejects on first reject, or returns a list
// of all successful responses. Operates with at most n 'active' promises at a time.
// For tracking purposes, all promises must have a unique identifier.
export async function allPool<T>(n: number, p: Record<string, Promise<T>>): Promise<Awaited<T>[]> {
export async function allPool<T>(
n: number,
p: Record<string, () => Promise<T>>
): Promise<Awaited<T>[]> {
const pool: Record<string, Promise<LabelledSuccess<T>>> = {};
const resolved: Awaited<T>[] = [];
for (const [id, job] of Object.entries(p)) {
// while the size of jobs to do is greater than n,
// let n jobs run and take the first one to finish out of the pool
pool[id] = labelPromise(id, job);
if (Object.keys(pool).length > n - 1) {
// When pool is full, wait for one to resolve, and remove it from the pool.
const promises = Object.values(pool);
try {
const { lid, value } = await Promise.race(promises);
Expand Down Expand Up @@ -50,7 +57,10 @@ export async function allPool<T>(n: number, p: Record<string, Promise<T>>): Prom
// Pooled variant of promise.any; implements most of the logic of the real any,
// but with a pool size of n, and returns the first successful promise,
// operating with at most n 'active' promises at a time.
export async function anyPool<T>(n: number, p: Record<string, Promise<T>>): Promise<Awaited<T>> {
export async function anyPool<T>(
n: number,
p: Record<string, () => Promise<T>>
): Promise<Awaited<T>> {
const pool: Record<string, Promise<LabelledSuccess<T>>> = {};
const rejections = [];
for (const [id, job] of Object.entries(p)) {
Expand Down
10 changes: 5 additions & 5 deletions lib/tdf3/src/tdf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ async function unwrapKey({
if (concurrencyLimit !== undefined && concurrencyLimit > 1) {
poolSize = concurrencyLimit;
}
const splitPromises: Record<string, Promise<RewrapResponseData>> = {};
const splitPromises: Record<string, () => Promise<RewrapResponseData>> = {};
for (const splitId of Object.keys(splitPotentials)) {
const potentials = splitPotentials[splitId];
if (!potentials || !Object.keys(potentials).length) {
Expand All @@ -1004,17 +1004,17 @@ async function unwrapKey({
''
);
}
const anyPromises: Record<string, Promise<RewrapResponseData>> = {};
const anyPromises: Record<string, () => Promise<RewrapResponseData>> = {};
for (const [kas, keySplitInfo] of Object.entries(potentials)) {
anyPromises[kas] = (async () => {
anyPromises[kas] = async () => {
try {
return await tryKasRewrap(keySplitInfo);
} catch (e) {
throw handleRewrapError(e as Error | AxiosError);
}
})();
};
}
splitPromises[splitId] = anyPool(poolSize, anyPromises);
splitPromises[splitId] = () => anyPool(poolSize, anyPromises);
}
try {
const splitResults = await allPool(poolSize, splitPromises);
Expand Down
45 changes: 33 additions & 12 deletions lib/tests/mocha/unit/concurrency.spec.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import { allPool, anyPool } from '../../../src/concurrency.js';
import { expect } from 'chai';

async function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}

describe('concurrency', () => {
for (const n of [1, 2, 3, 4]) {
describe(`allPool(${n})`, () => {
it(`should resolve all promises with a pool size of ${n}`, async () => {
const promises = {
a: Promise.resolve(1),
b: Promise.resolve(2),
c: Promise.resolve(3),
a: () => Promise.resolve(1),
b: () => Promise.resolve(2),
c: () => Promise.resolve(3),
};
const result = await allPool(n, promises);
expect(result).to.have.members([1, 2, 3]);
});
it(`should reject if any promise rejects, n=${n}`, async () => {
const promises = {
a: Promise.resolve(1),
b: Promise.reject(new Error('failure')),
c: Promise.resolve(3),
a: () => Promise.resolve(1),
b: () => Promise.reject(new Error('failure')),
c: () => Promise.resolve(3),
};
try {
await allPool(n, promises);
Expand All @@ -29,29 +33,46 @@ describe('concurrency', () => {
describe(`anyPool(${n})`, () => {
it('should resolve with the first resolved promise', async () => {
const startTime = Date.now();
const started = new Set<string>();
const promises = {
a: new Promise((resolve) => setTimeout(() => resolve(1), 500)),
b: new Promise((resolve) => setTimeout(() => resolve(2), 50)),
c: new Promise((resolve) => setTimeout(() => resolve(3), 1500)),
a: async () => {
started.add('a');
await sleep(500);
return 1;
},
b: async () => {
started.add('b');
await sleep(50);
return 2;
},
c: async () => {
started.add('c');
await sleep(1500);
return 3;
},
};
const result = await anyPool(n, promises);
const endTime = Date.now();
const elapsed = endTime - startTime;
if (n > 1) {
expect(elapsed).to.be.lessThan(500);
expect(result).to.equal(2);
expect(started).to.include('b');
expect(started).to.have.length(n == 2 ? 2 : 3);
} else {
expect(elapsed).to.be.greaterThan(50);
expect(elapsed).to.be.lessThan(1000);
expect(result).to.equal(1);
expect(started).to.have.length(1);
expect(started).to.include('a');
}
});

it('should reject if all promises reject', async () => {
const promises = {
a: Promise.reject(new Error('failure1')),
b: Promise.reject(new Error('failure2')),
c: Promise.reject(new Error('failure3')),
a: () => Promise.reject(new Error('failure1')),
b: () => Promise.reject(new Error('failure2')),
c: () => Promise.reject(new Error('failure3')),
};
try {
await anyPool(n, promises);
Expand Down

0 comments on commit c6116ed

Please sign in to comment.