
Commit 3fd51e0

Revert "[Inductor] Constrain the shape of other tensor for Conv/Linear + broadcast add fusion. (pytorch#141759)"
This reverts commit 35752cb. Reverted pytorch#141759 on behalf of https://github.com/atalman due to Failing internally ([comment](pytorch#141759 (comment)))
1 parent db13bd9 commit 3fd51e0

File tree: 3 files changed, +23 -79 lines changed


aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 4 additions & 7 deletions
@@ -295,15 +295,12 @@ Tensor mkldnn_linear_pointwise_binary(
         input_reshaped.size(0), weight_t.size(0)};
     output = output.reshape(output_size_reshaped);
     other_reshaped = other_reshaped.reshape(output_size_reshaped);
-    TORCH_CHECK(
-        output.sizes() == other_reshaped.sizes(),
-        "linear_binary_run expects the size of output and other tensor to be the same");
-  } else {
-    TORCH_CHECK(
-        output.dim() == other_reshaped.dim(),
-        "linear_binary_run expects the dimension of output and other tensor to be the same");
   }

+  TORCH_CHECK(
+      output.dim() == other_reshaped.dim(),
+      "linear_binary_run expects the dimension of output and other tensor to be the same");
+
   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
   ideep::tensor mkldnn_output = itensor_from_tensor(output);
   const ideep::tensor mkldnn_other = itensor_from_tensor(other_reshaped);
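The net effect in Linear.cpp is that the strict size-equality check on the binary `other` tensor is dropped and only a rank check remains. As a rough illustration (a minimal eager-mode sketch in plain PyTorch, not a call into the fused `mkldnn_linear_pointwise_binary` kernel), only the rank of `other` has to match the linear output after the revert:

import torch

x = torch.randn(2, 3, 10)                    # (batch, seq, in_features)
weight = torch.randn(30, 10)                 # (out_features, in_features)
out = torch.nn.functional.linear(x, weight)  # shape (2, 3, 30)

other_same = torch.randn(2, 3, 30)   # identical size: accepted either way
other_bcast = torch.randn(1, 1, 30)  # same rank, broadcastable: only the rank is checked after the revert
print((out + other_same).shape)      # torch.Size([2, 3, 30])
print((out + other_bcast).shape)     # torch.Size([2, 3, 30])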

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 9 additions & 25 deletions
@@ -709,34 +709,26 @@ def forward(self, x, x2):
             dtypes.append(torch.float16)
         cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
         test_memory_format = [torch.contiguous_format, cl_format]
-        if dim == 4:
-            input_shapes = [
-                [2, 3, 56, 56],
-            ]
-            other_shapes = [[2, 16, 1, 1], [1, 16, 1, 1], [1, 1, 1, 1]]
-        else:
-            input_shapes = [
-                [2, 3, 20, 56, 56],
-            ]
-            other_shapes = [[2, 16, 1, 1, 1], [1, 16, 1, 1, 1], [1, 1, 1, 1, 1]]
         options = itertools.product(
             binary_list,
-            input_shapes,
-            other_shapes,
             [True, False],
             test_memory_format,
             dtypes,
         )

         for (
             binary_fn,
-            x_shape,
-            other_shape,
             has_relu,
             memory_format,
             dtype,
         ) in options:
             metrics.reset()
+            if dim == 4:
+                x_shape = (1, 3, 56, 56)
+                other_shape = (1, 16, 1, 1)
+            else:
+                x_shape = (1, 3, 20, 56, 56)
+                other_shape = (1, 16, 1, 1, 1)
             mod = M(binary_fn, has_relu).eval()
             x = (
                 torch.randn(x_shape, dtype=torch.float32, requires_grad=True)

@@ -855,23 +847,15 @@ def forward(self, x, y):
         if torch.ops.mkldnn._is_mkldnn_fp16_supported():
             dtypes.append(torch.float16)
         options = itertools.product(
-            binary_list,
-            (
-                ([2, 3, 10], [1, 1, 30]),
-                ([2, 3, 10], [1, 1, 1]),
-                ([2, 10], [1, 30]),
-                ([2, 10], [1, 1]),
-            ),
-            (True, False),
-            dtypes,
+            binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
         )
         out_feature = 30

-        for binary_fn, (input_shape, other_shape), bias, dtype in options:
+        for binary_fn, input_shape, bias, dtype in options:
             metrics.reset()
             mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
             v = torch.randn(input_shape)
-            other = torch.randn(other_shape).to(dtype)
+            other = torch.randn(input_shape[:-1] + [1]).to(dtype)

             def matcher_check_fn():
                 self.assertEqual(
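After the revert the tests sweep a single, fixed shape pair per dimensionality instead of several broadcast variants. A rough stand-in for the conv + broadcast-add pattern those restored 4-D shapes exercise (a hypothetical standalone snippet, not the test's `M` module):

import torch

conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
x = torch.randn(1, 3, 56, 56)      # restored 4-D input shape
other = torch.randn(1, 16, 1, 1)   # restored broadcastable "other" shape

out = torch.relu(conv(x) + other)  # binary add, optionally followed by relu (has_relu)
print(out.shape)                   # torch.Size([1, 16, 56, 56])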

torch/_inductor/fx_passes/mkldnn_fusion.py

Lines changed: 10 additions & 47 deletions
@@ -356,8 +356,8 @@ def fn(match, *args, **kwargs):
         ops.sub: "sub",
     }

-    def _is_valid_binary(match, computation_op, binary_op):
-        binary_nodes = filter_nodes(match.nodes, binary_op)
+    def _is_valid_binary(match, fn):
+        binary_nodes = filter_nodes(match.nodes, fn)
         if len(binary_nodes) < 1:
             return False

@@ -381,50 +381,13 @@ def get_meta_value(argument: torch.fx.node.Argument):
         ):
             return False

-        def _check_input_sizes(n, computation_op):
-            # Check if the tensor shape of the 'other' node is the same as or
-            # can be broadcasted to the tensor shape of the computation node.
-            computation_node = (
-                n.args[0] if n.args[1] is match.kwargs["other"] else n.args[1]
-            )
-            assert computation_node.target == computation_op
-            computation_node_size = get_meta_value(computation_node).size()
-            if computation_op is mkldnn._linear_pointwise.default:
-                if len(computation_node_size) >= 2:
-                    broadcast_sizes = [
-                        torch.Size(
-                            [1 for _ in range(len(computation_node_size) - 1)]
-                            + [computation_node_size[-1]]
-                        ),
-                        torch.Size([1 for _ in range(len(computation_node_size))]),
-                    ]
-                else:
-                    broadcast_sizes = [
-                        torch.Size([1 for _ in range(len(computation_node_size))]),
-                    ]
-            else:
-                assert len(computation_node_size) > 2
-                broadcast_sizes = [
-                    torch.Size(
-                        [computation_node_size[0], computation_node_size[1]]
-                        + [1 for _ in range(len(computation_node_size) - 2)]
-                    ),
-                    torch.Size(
-                        [1, computation_node_size[1]]
-                        + [1 for _ in range(len(computation_node_size) - 2)]
-                    ),
-                    torch.Size([1 for _ in range(len(computation_node_size))]),
-                ]
-            return (
-                get_meta_value(match.kwargs["other"]).size()
-                in [
-                    computation_node_size,
-                ]
-                + broadcast_sizes
-            )
-
         if any(
-            not _check_input_sizes(n, computation_op)
+            get_meta_value(n.args[0]).dim() != get_meta_value(n.args[1]).dim()
+            or not all(
+                get_meta_value(n.args[0]).size(i) == get_meta_value(n.args[1]).size(i)
+                or get_meta_value(match.kwargs["other"]).size(i) == 1
+                for i in range(get_meta_value(n.args[0]).dim())
+            )
             or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
             or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
             for n in binary_nodes

@@ -439,7 +402,7 @@ def _is_valid_computation_binary(computation_op, binary_op, other_index=None):
         def fn(match):
             if not _is_single_computation_op(computation_op)(match):
                 return False
-            if not _is_valid_binary(match, computation_op, binary_op):
+            if not _is_valid_binary(match, binary_op):
                 return False
             return True

@@ -562,7 +525,7 @@ def _can_be_inplace(_other):
         else:
             return not (
                 isinstance(_other.data, ir.ReinterpretView)
-                or len(_other.data.get_inputs_that_alias_output()) > 0
+                or len(_other.get_inputs_that_alias_output()) > 0
             )

     def _register_binary_unary_maybe_inplace_fusion_lowering(
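With `_check_input_sizes` gone, the restored predicate only requires that both inputs of the binary node share rank, dtype, and device, and that each dimension of `other` either matches or is 1. A standalone sketch of that shape test over plain tensors (`broadcast_add_ok` is an illustrative helper name; the real pass evaluates the same conditions on FX node meta values):

import torch

def broadcast_add_ok(out: torch.Tensor, other: torch.Tensor) -> bool:
    # Same rank, device, and dtype are required outright.
    if out.dim() != other.dim():
        return False
    if out.device != other.device or out.dtype != other.dtype:
        return False
    # Per dimension, sizes must match or `other` must be broadcastable (size 1).
    return all(
        other.size(i) == out.size(i) or other.size(i) == 1
        for i in range(out.dim())
    )

print(broadcast_add_ok(torch.randn(2, 16, 56, 56), torch.randn(1, 16, 1, 1)))  # True
print(broadcast_add_ok(torch.randn(2, 16, 56, 56), torch.randn(16, 56, 56)))   # False: rank mismatch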
