Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds GTensor.where function to gtensor and unit tests. #55

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions animated-transformer/src/lib/gtensor/gtensor.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -905,4 +905,109 @@ describe('gtensor', () => {
[0, 0, 0],
]);
});

it('where', async () => {
const g1 = new gtensor.GTensor(
tf.tensor(
[
[
[1, 2],
[3, 4],
[5, 6],
],
[
[1, 2],
[3, 4],
[5, 6],
],
],
),
['example', 'pos', 'repSize'],
);

const g2 = new gtensor.GTensor(
tf.tensor(
[
[0, 0],
[0, 0],
[0, 0],
],
),
['pos', 'repSize'],
);

const condition = tf.tensor([1, 0, 0, 1, 1, 0], [3, 2], 'bool');

const g1WhereCondition = g1.where(condition, g2);

expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']);
tf.test_util.expectArraysEqual(g1WhereCondition.tensor.arraySync(), [
[
[1, 0],
[0, 4],
[5, 0],
], // example = 1
[
[1, 0],
[0, 4],
[5, 0],
],
]);
});

it('where no broadcast over g2', async () => {
const g1 = new gtensor.GTensor(
tf.tensor(
[
[
[1, 2],
[3, 4],
[5, 6],
],
[
[1, 2],
[3, 4],
[5, 6],
],
],
),
['example', 'pos', 'repSize'],
);

const g2 = new gtensor.GTensor(
tf.tensor(
[
[
[0, 0],
[0, 0],
[0, 0],
], // example = 1
[
[0, 0],
[0, 0],
[0, 0],
],
], // example = 2
),
['example', 'pos', 'repSize'],
);

const condition = tf.tensor([1, 0, 0, 1, 1, 0], [3, 2], 'bool');

const g1WhereCondition = g1.where(condition, g2);

expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']);
tf.test_util.expectArraysEqual(g1WhereCondition.tensor.arraySync(), [
[
[1, 0],
[0, 4],
[5, 0],
], // example = 1
[
[1, 0],
[0, 4],
[5, 0],
],
]);
});
});
22 changes: 20 additions & 2 deletions animated-transformer/src/lib/gtensor/gtensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,20 @@ export class GTensor<G extends DName> {
this.dimNames,
);
}

/* Returns the elements, either of the gtensor or g2 depending on the condition.
If the condition is true, select from the gtensor, otherwise select from g2.
if gtensor.dims != g2.dims substracted dimentions are broadcasted */
public where<G2 extends DName>(condition: tf.Tensor, g2: GTensor<G2>): GTensor<G> {
const g2big = g2.broadcastToCombinedShape(this);
const g1big = this.broadcastToCombinedShape(g2);
const g1bigLikeG2 = g1big.transposeLike(g2big);

const shape = Object.values(g1bigLikeG2.gshape()) as number[];
const conditionBig = condition.broadcastTo(shape);

return new GTensor(g1bigLikeG2.tensor.where(conditionBig, g2big.tensor), this.dimNames);
}
}

export class GVariable<G extends DName> extends GTensor<G> {
Expand Down Expand Up @@ -964,13 +978,17 @@ export function makeRange<T extends DName>(
* - dtype : The type of an element in the resulting tensor. Defaults to 'float32'
* // TODO add optianal broadcastTo dimensions/GTensor
* */
export function makeTriangularMatrix<N1 extends string, N2 extends string, T extends string | number>(
export function makeTriangularMatrix<
N1 extends string,
N2 extends string,
T extends string | number,
>(
size: number,
d1Name: N1,
d2Name: N2,
lowerLeftValue: T,
upperRightValue: T,
dtype: 'float32' | 'int32' | 'bool' | 'complex64' | 'string' = 'float32'
dtype: 'float32' | 'int32' | 'bool' | 'complex64' | 'string' = 'float32',
): GTensor<N1 | N2> {
// Create a range tensor for row indices
const rowIndices = tf.range(0, size, 1, 'int32');
Expand Down
Loading