File tree 1 file changed +5
-3
lines changed
1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -7036,14 +7036,16 @@ def MultiTaskProjection(
7036
7036
# o - output_dim
7037
7037
7038
7038
if einsum_order == 'select_and_multiply' :
7039
- # Weights quantization:
7040
- weights = quant_layer .QWeight (weights , domain = w_q_domain )
7041
- weights = quant_layer .ToAqtWeight (w_q_name , weights , feature_axis = - 1 )
7042
7039
# select..
7043
7040
# [{batch,} {time,} input_dim, output_dim]
7044
7041
selected_weights = tf .einsum (
7045
7042
f'{ b_task } { t_task } k,kio->{ b_task } { t_task } io' , tasks_onehot , weights
7046
7043
)
7044
+ # Weights quantization:
7045
+ selected_weights = quant_layer .QWeight (selected_weights , domain = w_q_domain )
7046
+ selected_weights = quant_layer .ToAqtWeight (
7047
+ w_q_name , selected_weights , feature_axis = - 1
7048
+ )
7047
7049
if qat_output :
7048
7050
# .. and multiply
7049
7051
# [batch, {time,} output_dim]
You can’t perform that action at this time.
0 commit comments