Skip to content

Commit 169315b

Browse files
lingvo-bot authored and copybara-github committed
num_tasks and max_task_id *are* the same. Don't pass max_task_id; infer the value from the weights' shape.
PiperOrigin-RevId: 593825138
1 parent 015fc3a commit 169315b

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

lingvo/core/layers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,7 +1525,6 @@ def _ApplyProjectionKernel(self, theta, inputs, tasks):
15251525
biases=b,
15261526
inputs=inputs,
15271527
tasks=tasks,
1528-
max_task_id=p.num_tasks,
15291528
einsum_order=p.einsum_order,
15301529
quant_layer=self,
15311530
w_q_name='w',
@@ -6500,7 +6499,6 @@ def FProp(self, theta, inputs, tasks):
65006499
biases=theta.down_b,
65016500
inputs=norm_inputs,
65026501
tasks=tasks,
6503-
max_task_id=p.num_tasks,
65046502
einsum_order=p.einsum_order,
65056503
quant_layer=self,
65066504
w_q_name='down_w',
@@ -6514,7 +6512,6 @@ def FProp(self, theta, inputs, tasks):
65146512
biases=theta.up_b,
65156513
inputs=down_projected,
65166514
tasks=tasks,
6517-
max_task_id=p.num_tasks,
65186515
einsum_order=p.einsum_order,
65196516
quant_layer=self,
65206517
w_q_name='up_w',

lingvo/core/py_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6929,7 +6929,6 @@ def MultiTaskProjection(
69296929
biases: Optional[tf.Tensor],
69306930
inputs: tf.Tensor,
69316931
tasks: tf.Tensor,
6932-
max_task_id: int,
69336932
einsum_order: str,
69346933
quant_layer, # quant_utils.QuantizableLayer, would be circular import
69356934
w_q_name: str,
@@ -6948,8 +6947,7 @@ def MultiTaskProjection(
69486947
input_dim]
69496948
tasks: An int32 tensor containing the task ID for each input. Tensor size is
69506949
[batch_dim] or [batch_dim, time_dim] (allowed only when inputs also has a
6951-
time dimension), no elements are larger than max_task_id.
6952-
max_task_id: the highest task id allowed. (Note, different from num_tasks.)
6950+
time dimension), no elements are larger than num_tasks.
69536951
einsum_order: the algorithm to use, either 'select_and_multiply' or
69546952
'multiply_and_select'.
69556953
quant_layer: QuantizableLayer used for AQT (pass `self`)
@@ -6985,8 +6983,8 @@ def MultiTaskProjection(
69856983
tasks = HasShape(tasks, [batch_size, time_size])
69866984
t_task = 't'
69876985

6988-
# [batch, max_task_id] or [batch, time, max_task_id]
6989-
tasks_onehot = tf.one_hot(tasks, max_task_id, axis=-1, dtype=inputs.dtype)
6986+
# [batch, num_tasks] or [batch, time, num_tasks]
6987+
tasks_onehot = tf.one_hot(tasks, num_tasks, axis=-1, dtype=inputs.dtype)
69906988

69916989
# Einsum axis names:
69926990
# b - batch

0 commit comments

Comments
 (0)