Skip to content

Commit d0bd551

Browse files
lingvo-botcopybara-github
authored andcommitted
Move the quantization/select order, this prevents accidentally accumulating size from unused weights.
PiperOrigin-RevId: 654933747
1 parent b162979 commit d0bd551

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

lingvo/core/py_utils.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -7036,14 +7036,16 @@ def MultiTaskProjection(
70367036
# o - output_dim
70377037

70387038
if einsum_order == 'select_and_multiply':
7039-
# Weights quantization:
7040-
weights = quant_layer.QWeight(weights, domain=w_q_domain)
7041-
weights = quant_layer.ToAqtWeight(w_q_name, weights, feature_axis=-1)
70427039
# select..
70437040
# [{batch,} {time,} input_dim, output_dim]
70447041
selected_weights = tf.einsum(
70457042
f'{b_task}{t_task}k,kio->{b_task}{t_task}io', tasks_onehot, weights
70467043
)
7044+
# Weights quantization:
7045+
selected_weights = quant_layer.QWeight(selected_weights, domain=w_q_domain)
7046+
selected_weights = quant_layer.ToAqtWeight(
7047+
w_q_name, selected_weights, feature_axis=-1
7048+
)
70477049
if qat_output:
70487050
# .. and multiply
70497051
# [batch, {time,} output_dim]

0 commit comments

Comments
 (0)