
Commit 5ab00f5

fix comments

1 parent fcc79d1 commit 5ab00f5

2 files changed, +5 -45 lines changed

nncf/experimental/torch2/function_hook/handle_inner_functions.py (+1 -41)
@@ -64,19 +64,13 @@ def multi_head_attention_forward(
 ) -> Tuple[Tensor, Optional[Tensor]]:

     is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
-
-    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
-    # is batched, run the computation and before returning squeeze the
-    # batch dimension so that the output doesn't carry this temporary batch dimension.
     if not is_batched:
-        # unsqueeze if the input is unbatched
         query = query.unsqueeze(1)
         key = key.unsqueeze(1)
         value = value.unsqueeze(1)
         if key_padding_mask is not None:
             key_padding_mask = key_padding_mask.unsqueeze(0)

-    # set up shape vars
     tgt_len, bsz, embed_dim = query.shape
     src_len, _, _ = key.shape

@@ -96,9 +90,6 @@ def multi_head_attention_forward(
     )

     if is_causal and key_padding_mask is None and not need_weights:
-        # when we have a kpm or need weights, we need attn_mask
-        # Otherwise, we use the is_causal hint go as is_causal
-        # indicator to SDPA.
         attn_mask = None
     else:
         attn_mask = _canonical_mask(
@@ -111,31 +102,23 @@ def multi_head_attention_forward(
         )

         if key_padding_mask is not None:
-            # We have the attn_mask, and use that to merge kpm into it.
-            # Turn off use of is_causal hint, as the merged mask is no
-            # longer causal.
             is_causal = False

     assert (
         embed_dim == embed_dim_to_check
     ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
     if isinstance(embed_dim, torch.Tensor):
-        # embed_dim can be a tensor when JIT tracing
         head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
     else:
         head_dim = embed_dim // num_heads
     assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
     if use_separate_proj_weight:
-        # allow MHA to have different embedding dimensions when separate projection weights are used
         assert (
             key.shape[:2] == value.shape[:2]
         ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
     else:
         assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

-    #
-    # compute in-projection
-    #
     if not use_separate_proj_weight:
         assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
@@ -159,10 +142,7 @@ def multi_head_attention_forward(
             b_v,
         )

-    # prep attention mask
-
     if attn_mask is not None:
-        # ensure attn_mask's dim is 3
         if attn_mask.dim() == 2:
             correct_2d_size = (tgt_len, src_len)
             if attn_mask.shape != correct_2d_size:
@@ -179,7 +159,6 @@ def multi_head_attention_forward(
         else:
             raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

-    # add bias along batch dimension (currently second)
     if bias_k is not None and bias_v is not None:
         assert static_k is None, "bias cannot be added to static key."
         assert static_v is None, "bias cannot be added to static value."
@@ -193,14 +172,10 @@ def multi_head_attention_forward(
         assert bias_k is None
         assert bias_v is None

-    #
-    # reshape q, k, v for multihead attention and make them batch first
-    #
     q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
     if static_k is None:
         k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
     else:
-        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
         assert (
             static_k.size(0) == bsz * num_heads
         ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
@@ -209,14 +184,12 @@ def multi_head_attention_forward(
     if static_v is None:
         v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
     else:
-        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
         assert (
             static_v.size(0) == bsz * num_heads
         ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
         assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
         v = static_v

-    # add zero attention along batch dimension (now first)
     if add_zero_attn:
         zero_attn_shape = (bsz * num_heads, 1, head_dim)
         k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
@@ -226,10 +199,8 @@ def multi_head_attention_forward(
         if key_padding_mask is not None:
             key_padding_mask = pad(key_padding_mask, (0, 1))

-    # update source sequence length after adjustments
     src_len = k.size(1)

-    # merge key padding and attention masks
     if key_padding_mask is not None:
         if not torch.jit.is_scripting() and not torch.jit.is_tracing():
             _check_key_padding_mask(key_padding_mask, src_len, bsz)
@@ -242,16 +213,11 @@ def multi_head_attention_forward(
         else:
             attn_mask = attn_mask + key_padding_mask

-    # adjust dropout probability
     if not training:
         dropout_p = 0.0

-    #
-    # (deep breath) calculate attention and out projection
-    #
-
     if need_weights:
-        _B, _Nt, E = q.shape
+        _B, _Nt, E = q.shape  # noqa: F841
         q_scaled = q * math.sqrt(1.0 / float(E))

         assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
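Aside, not part of the commit: the single added line in the hunk above suppresses F841 ("local variable is assigned to but never used") for _B and _Nt, which are bound by the tuple unpacking but never read afterwards; only E is used. A minimal, hypothetical illustration with made-up shapes:

# Standalone sketch: only E from the unpacking is read on the following line.
import math

import torch

q = torch.randn(2 * 4, 10, 16)  # (bsz * num_heads, tgt_len, head_dim), illustrative sizes
_B, _Nt, E = q.shape  # noqa: F841
q_scaled = q * math.sqrt(1.0 / float(E))  # only E is actually needed
print(q_scaled.shape)  # torch.Size([8, 10, 16])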
@@ -270,20 +236,15 @@ def multi_head_attention_forward(
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

-        # optionally average attention weights over heads
         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
         if average_attn_weights:
             attn_output_weights = attn_output_weights.mean(dim=1)

         if not is_batched:
-            # squeeze the output if input was unbatched
             attn_output = attn_output.squeeze(1)
             attn_output_weights = attn_output_weights.squeeze(0)
         return attn_output, attn_output_weights
     else:
-        # attn_mask can be either (L,S) or (N*num_heads, L, S)
-        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
-        # in order to match the input for SDPA of (N, num_heads, L, S)
         if attn_mask is not None:
             if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                 attn_mask = attn_mask.unsqueeze(0)
@@ -300,7 +261,6 @@ def multi_head_attention_forward(
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
         if not is_batched:
-            # squeeze the output if input was unbatched
             attn_output = attn_output.squeeze(1)
         return attn_output, None

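For orientation only, not part of the diff: the function above appears to mirror PyTorch's inner multi_head_attention_forward, and after the comment removal it still distinguishes the same call modes (batched vs. unbatched input, need_weights True vs. False). A minimal sketch using plain torch.nn.MultiheadAttention; the sizes (embed_dim=8, num_heads=2, sequence length 5, batch 3) are illustrative assumptions:

# Sequence-first layout (tgt_len, bsz, embed_dim), as in the code above.
import torch
from torch import nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

x = torch.randn(5, 3, 8)                         # batched input
out, weights = mha(x, x, x, need_weights=True)   # explicit attention-weights path
print(out.shape, weights.shape)                  # torch.Size([5, 3, 8]) torch.Size([3, 5, 5])

u = torch.randn(5, 8)                            # unbatched input: a temporary batch dim is added and squeezed
out, weights = mha(u, u, u, need_weights=False)  # SDPA path, no weights returned
print(out.shape, weights)                        # torch.Size([5, 8]) None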
tests/torch/experimental/search_building_blocks/test_transformer_blocks.py (+4 -4)
@@ -91,22 +91,22 @@ def forward(self, x):
     TransformerSearchBBlockParamsCase(
         name="BERT",
         input_info=[dict(sample_size=[1, 10], type="long")],
-        model_creator=partial(AutoModelForQuestionAnswering.from_config, BertConfig()),
+        model_creator=partial(AutoModelForQuestionAnswering.from_config, BertConfig(), attn_implementation="eager"),
     ),
     TransformerSearchBBlockParamsCase(
         name="ViT",
         input_info=dict(sample_size=[1, 3, 224, 224]),
-        model_creator=partial(AutoModelForImageClassification.from_config, ViTConfig()),
+        model_creator=partial(AutoModelForImageClassification.from_config, ViTConfig(), attn_implementation="eager"),
     ),
     TransformerSearchBBlockParamsCase(
         name="wave2vec 2.0",
         input_info=dict(sample_size=[1, 400]),
-        model_creator=partial(AutoModelForAudioClassification.from_config, Wav2Vec2Config()),
+        model_creator=partial(AutoModelForAudioClassification.from_config, Wav2Vec2Config(), attn_implementation="eager"),
     ),
     TransformerSearchBBlockParamsCase(
         name="SWIN MS",
         input_info=dict(sample_size=[1, 3, 224, 224]),
-        model_creator=partial(AutoModelForImageClassification.from_config, SwinConfig()),
+        model_creator=partial(AutoModelForImageClassification.from_config, SwinConfig(), attn_implementation="eager"),
     ),
     TransformerSearchBBlockParamsCase(
         name="one MHSA",
