
Commit ed44bd4

pre-commit fix

1 parent: 308d53a

2 files changed: +27 -57 lines

Diff for: nncf/experimental/torch2/function_hook/handle_inner_functions.py

+18-54
@@ -63,9 +63,7 @@ def multi_head_attention_forward(
     is_causal: bool = False,
 ) -> Tuple[Tensor, Optional[Tensor]]:
 
-    is_batched = _mha_shape_check(
-        query, key, value, key_padding_mask, attn_mask, num_heads
-    )
+    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
 
     # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
     # is batched, run the computation and before returning squeeze the
@@ -126,37 +124,25 @@ def multi_head_attention_forward(
         head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
     else:
         head_dim = embed_dim // num_heads
-    assert (
-        head_dim * num_heads == embed_dim
-    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
     if use_separate_proj_weight:
         # allow MHA to have different embedding dimensions when separate projection weights are used
         assert (
             key.shape[:2] == value.shape[:2]
         ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
     else:
-        assert (
-            key.shape == value.shape
-        ), f"key shape {key.shape} does not match value shape {value.shape}"
+        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
 
     #
     # compute in-projection
     #
     if not use_separate_proj_weight:
-        assert (
-            in_proj_weight is not None
-        ), "use_separate_proj_weight is False but in_proj_weight is None"
+        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
     else:
-        assert (
-            q_proj_weight is not None
-        ), "use_separate_proj_weight is True but q_proj_weight is None"
-        assert (
-            k_proj_weight is not None
-        ), "use_separate_proj_weight is True but k_proj_weight is None"
-        assert (
-            v_proj_weight is not None
-        ), "use_separate_proj_weight is True but v_proj_weight is None"
+        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
         if in_proj_bias is None:
             b_q = b_k = b_v = None
         else:
@@ -191,9 +177,7 @@ def multi_head_attention_forward(
                     f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                 )
         else:
-            raise RuntimeError(
-                f"attn_mask's dimension {attn_mask.dim()} is not supported"
-            )
+            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
 
     # add bias along batch dimension (currently second)
     if bias_k is not None and bias_v is not None:
@@ -220,9 +204,7 @@ def multi_head_attention_forward(
         assert (
             static_k.size(0) == bsz * num_heads
         ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
-        assert (
-            static_k.size(2) == head_dim
-        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+        assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
         k = static_k
     if static_v is None:
         v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
@@ -231,20 +213,14 @@ def multi_head_attention_forward(
         assert (
             static_v.size(0) == bsz * num_heads
         ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
-        assert (
-            static_v.size(2) == head_dim
-        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+        assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
         v = static_v
 
     # add zero attention along batch dimension (now first)
     if add_zero_attn:
         zero_attn_shape = (bsz * num_heads, 1, head_dim)
-        k = torch.cat(
-            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
-        )
-        v = torch.cat(
-            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
-        )
+        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
+        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
         if attn_mask is not None:
             attn_mask = pad(attn_mask, (0, 1))
         if key_padding_mask is not None:
@@ -259,9 +235,7 @@ def multi_head_attention_forward(
         _check_key_padding_mask(key_padding_mask, src_len, bsz)
 
         key_padding_mask = (
-            key_padding_mask.view(bsz, 1, 1, src_len)
-            .expand(-1, num_heads, -1, -1)
-            .reshape(bsz * num_heads, 1, src_len)
+            key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
         )
         if attn_mask is None:
             attn_mask = key_padding_mask
@@ -280,14 +254,10 @@ def multi_head_attention_forward(
         _B, _Nt, E = q.shape
         q_scaled = q * math.sqrt(1.0 / float(E))
 
-        assert not (
-            is_causal and attn_mask is None
-        ), "FIXME: is_causal not implemented for need_weights"
+        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
 
         if attn_mask is not None:
-            attn_output_weights = torch.baddbmm(
-                attn_mask, q_scaled, k.transpose(-2, -1)
-            )
+            attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
         else:
             attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
         attn_output_weights = softmax(attn_output_weights, dim=-1)
@@ -296,9 +266,7 @@ def multi_head_attention_forward(
 
         attn_output = torch.bmm(attn_output_weights, v)
 
-        attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
-        )
+        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
 
@@ -326,12 +294,8 @@ def multi_head_attention_forward(
         k = k.view(bsz, num_heads, src_len, head_dim)
         v = v.view(bsz, num_heads, src_len, head_dim)
 
-        attn_output = scaled_dot_product_attention(
-            q, k, v, attn_mask, dropout_p, is_causal
-        )
-        attn_output = (
-            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
-        )
+        attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
 
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
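
For context on the hunk at -259,9 +235,7: the chained view/expand/reshape that it joins onto one line only broadcasts the key padding mask across attention heads; the behavior is unchanged. Below is a minimal, self-contained shape check of that chain, where bsz, num_heads, and src_len are illustrative stand-ins, not values from the diff:

import torch

# Stand-in sizes for illustration only.
bsz, num_heads, src_len = 2, 4, 5
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)

# view -> (bsz, 1, 1, src_len); expand broadcasts over heads -> (bsz, num_heads, 1, src_len);
# reshape flattens batch and heads -> (bsz * num_heads, 1, src_len), the layout in which the
# mask is later combined with the (bsz * num_heads, tgt_len, src_len) attention weights.
mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
assert mask.shape == (bsz * num_heads, 1, src_len)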

Diff for: tests/torch/pruning/experimental/test_nodes_grouping.py

+9-3
@@ -180,7 +180,9 @@ def __str__(self) -> str:
         model_desc=GeneralModelDesc(
             model_name="1_layer_BERT",
             input_info=[dict(sample_size=[1, 10], type="long")],
-            model_builder=partial(AutoModelForQuestionAnswering.from_config, BertConfig(num_hidden_layers=1), attn_implementation="eager"),
+            model_builder=partial(
+                AutoModelForQuestionAnswering.from_config, BertConfig(num_hidden_layers=1), attn_implementation="eager"
+            ),
         ),
         ref_groups=[
             PruningGroup(
@@ -226,7 +228,11 @@ def __str__(self) -> str:
         model_desc=GeneralModelDesc(
             model_name="RoBERTa",
             input_info=[dict(sample_size=[1, 10], type="long")],
-            model_builder=partial(AutoModelForQuestionAnswering.from_config, RobertaConfig(num_hidden_layers=1), attn_implementation="eager"),
+            model_builder=partial(
+                AutoModelForQuestionAnswering.from_config,
+                RobertaConfig(num_hidden_layers=1),
+                attn_implementation="eager",
+            ),
         ),
         ref_groups=[
             PruningGroup(
@@ -483,7 +489,7 @@ def test_groups(desc: GroupTestDesc, mocker, tmp_path):
     not_filtered_groups = get_pruning_groups(
         nncf_network.nncf.get_graph(), PT_EXPERIMENTAL_PRUNING_OPERATOR_METATYPES, pruning_producing_types, tmp_path
     )
-    nncf_network.nncf.get_graph().visualize_graph('transformers38.dot')
+    nncf_network.nncf.get_graph().visualize_graph("transformers38.dot")
     nx_graph = get_graph_spy.spy_return
     path_to_dot = get_full_path_to_the_graph(f"{str(desc)}.dot", "pruning_groups")
     compare_nx_graph_with_reference(nx_graph, path_to_dot, sort_dot_graph=False)
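
For context on the model_builder hunks above: the formatter only re-wraps the partial(...) call across lines, the arguments are unchanged, and calling the partial still builds a one-layer question-answering model with eager attention. A minimal sketch of the equivalent call, assuming the transformers version used by the test suite:

from functools import partial

from transformers import AutoModelForQuestionAnswering, BertConfig

# Same call as in the 1_layer_BERT descriptor, just wrapped differently by the formatter.
model_builder = partial(
    AutoModelForQuestionAnswering.from_config, BertConfig(num_hidden_layers=1), attn_implementation="eager"
)
model = model_builder()  # a BERT question-answering model with a single encoder layer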
