Skip to content

Commit 61a7c83

Browse files
BoyuanFengpytorchmergebot
authored andcommitted
[Inductor] fix device error for NopKernelSchedulerNode (pytorch#141372)
This PR adds device guard support for NopKernelSchedulerNode which may create a tensor. Prior to this PR, we do not codegen device guard for NopKernelSchedulerNode, leading to errors. Prior to the PR: ```python def call(args): arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args args.clear() assert_size_stride(arg0_1, (1, 1, 2048, 128), (262144, 262144, 128, 1)) assert_size_stride(arg1_1, (1, 1, 2048, 128), (262144, 262144, 128, 1)) assert_size_stride(arg2_1, (1, 1, 2048, 128), (262144, 262144, 128, 1)) assert_size_stride(arg3_1, (1, 1, 16), (16, 16, 1)) assert_size_stride(arg4_1, (1, 1, 16, 16), (256, 256, 16, 1)) assert_size_stride(arg5_1, (1, 1, 16), (16, 16, 1)) assert_size_stride(arg6_1, (1, 1, 16, 16), (256, 256, 16, 1)) assert_size_stride(arg7_1, (1, 1, 16), (16, 16, 1)) assert_size_stride(arg8_1, (1, 1, 16, 16), (256, 256, 16, 1)) assert_size_stride(arg9_1, (1, 1, 16), (16, 16, 1)) assert_size_stride(arg10_1, (1, 1, 16, 16), (256, 256, 16, 1)) buf0 = empty_strided_cuda((1, 1, 2048), (2048, 2048, 1), torch.float32) # TODO: ERROR here. Should be cuda:1 with torch.cuda._DeviceGuard(1): torch.cuda.set_device(1) buf1 = empty_strided_cuda((1, 1, 2048, 128), (262144, 262144, 128, 1), torch.bfloat16) # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] stream1 = get_raw_stream(1) breakpoint() triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, arg3_1, arg4_1, arg5_1, arg6_1, buf1, grid=torch._inductor.kernel.flex_attention.flex_attention_grid(1, 1, 2048, 128, meta0), stream=stream1) del arg0_1 del arg1_1 del arg2_1 del arg3_1 del arg4_1 del arg5_1 del arg6_1 del buf0 return (buf1, ) ``` After the PR: ```python def call(args): arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args args.clear() assert_size_stride(arg0_1, (1, 1, 2048, 128), (262144, 262144, 128, 1)) assert_size_stride(arg1_1, (1, 1, 2048, 128), (262144, 262144, 128, 1)) assert_size_stride(arg2_1, (1, 1, 2048, 128), (262144, 262144, 128, 1)) assert_size_stride(arg3_1, (1, 1, 16), (16, 16, 1)) assert_size_stride(arg4_1, (1, 1, 16, 16), (256, 256, 16, 1)) assert_size_stride(arg5_1, (1, 1, 16), (16, 16, 1)) assert_size_stride(arg6_1, (1, 1, 16, 16), (256, 256, 16, 1)) assert_size_stride(arg7_1, (1, 1, 16), (16, 16, 1)) assert_size_stride(arg8_1, (1, 1, 16, 16), (256, 256, 16, 1)) assert_size_stride(arg9_1, (1, 1, 16), (16, 16, 1)) assert_size_stride(arg10_1, (1, 1, 16, 16), (256, 256, 16, 1)) with torch.cuda._DeviceGuard(1): torch.cuda.set_device(1) buf0 = empty_strided_cuda((1, 1, 2048), (2048, 2048, 1), torch.float32) # New: move into device guard buf1 = empty_strided_cuda((1, 1, 2048, 128), (262144, 262144, 128, 1), torch.bfloat16) # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] stream1 = get_raw_stream(1) triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, arg3_1, arg4_1, arg5_1, arg6_1, buf1, grid=torch._inductor.kernel.flex_attention.flex_attention_grid(1, 1, 2048, 128, meta0), stream=stream1) del arg0_1 del arg1_1 del arg2_1 del arg3_1 del arg4_1 del arg5_1 del arg6_1 del buf0 return (buf1, ) ``` Fixes pytorch#141010 Pull Request resolved: pytorch#141372 Approved by: https://github.com/eellison
1 parent 3fd51e0 commit 61a7c83

File tree

6 files changed

+52
-9
lines changed

6 files changed

+52
-9
lines changed

test/inductor/test_flex_attention.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3303,6 +3303,27 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
33033303
""", # noqa: B950
33043304
)
33053305

3306+
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
3307+
def test_device_cuda_1(self):
3308+
class TestModule(torch.nn.Module):
3309+
def forward(self, q, k, v, block_mask):
3310+
return flex_attention(q, k, v, block_mask=block_mask)
3311+
3312+
q = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
3313+
k = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
3314+
v = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
3315+
mask = create_block_mask(
3316+
lambda b, h, q_idx, kv_idx: q_idx >= kv_idx,
3317+
B=None,
3318+
H=None,
3319+
Q_LEN=256,
3320+
KV_LEN=256,
3321+
device="cuda:1",
3322+
)
3323+
mod = torch.compile(TestModule())
3324+
attn_output = mod(q, k, v, mask)
3325+
self.assertEqual(attn_output.device, torch.device("cuda:1"))
3326+
33063327

33073328
class TestBlockMask(InductorTestCase):
33083329
@supported_platform

test/inductor/test_torchinductor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2494,6 +2494,13 @@ def fn(x):
24942494

24952495
self.common(fn, (torch.Tensor([]),))
24962496

2497+
@requires_multigpu()
2498+
def test_linspace4(self):
2499+
def fn(x):
2500+
return torch.linspace(0, 2, 0, device=f"{GPU_TYPE}:1")
2501+
2502+
self.common(fn, (torch.Tensor([]),))
2503+
24972504
def test_tensor1(self):
24982505
def fn(x):
24992506
return torch.tensor([1], device=x.device) + x, torch.tensor(

torch/_inductor/graph.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -850,8 +850,12 @@ def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
850850
device = buffer.get_device()
851851
if (
852852
# Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
853-
not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements())
854-
and device is not None
853+
device is not None
854+
and not (
855+
isinstance(buffer, ir.ComputedBuffer)
856+
and buffer.is_zero_elements()
857+
and device == torch.device("cpu")
858+
)
855859
):
856860
self.add_device_info(device)
857861

torch/_inductor/ir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3176,8 +3176,10 @@ def __str__(self) -> str:
31763176
offset = ""
31773177
if self.offset != 0:
31783178
offset = f", offset={self.offset}"
3179+
3180+
device_index_str = "" if self.device.index is None else f":{self.device.index}"
31793181
return (
3180-
f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
3182+
f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, "
31813183
f"size={self.size}, stride={self.stride}{offset})"
31823184
)
31833185

torch/_inductor/lowering.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3033,7 +3033,7 @@ def _new_constant(
30333033
dtype = decode_dtype(dtype) or x.get_dtype()
30343034
device = device or x.get_device()
30353035
size = [sympy.Integer(s) for s in size]
3036-
return _full(fill_value, device, dtype, size)
3036+
return _full(fill_value, decode_device(device), dtype, size)
30373037

30383038
return _new_constant
30393039

@@ -3045,7 +3045,12 @@ def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None)
30453045
if device is None:
30463046
device = x.get_device()
30473047
return empty_strided(
3048-
size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
3048+
size,
3049+
None,
3050+
dtype=dtype,
3051+
layout=layout,
3052+
device=decode_device(device),
3053+
pin_memory=pin_memory,
30493054
)
30503055

30513056

@@ -3059,6 +3064,7 @@ def empty_strided(
30593064
assert_nyi(layout in (None, torch.strided), f"layout={layout}")
30603065
dtype = decode_dtype(dtype) or torch.get_default_dtype()
30613066
device = device or torch.tensor(0.0).device
3067+
device = decode_device(device)
30623068
pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
30633069
pointwise.realize()
30643070
buffer = pointwise.data.data
@@ -3089,7 +3095,12 @@ def new_empty_strided(
30893095
if device is None:
30903096
device = x.get_device()
30913097
return empty_strided(
3092-
size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
3098+
size,
3099+
stride,
3100+
dtype=dtype,
3101+
layout=layout,
3102+
device=decode_device(device),
3103+
pin_memory=pin_memory,
30933104
)
30943105

30953106

torch/_inductor/scheduler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3511,9 +3511,7 @@ def _codegen(self) -> None:
35113511

35123512
self.enter_context(node)
35133513

3514-
if not isinstance(node, NopKernelSchedulerNode) and (
3515-
device := node.get_device()
3516-
):
3514+
if device := node.get_device():
35173515
if (
35183516
device != self.current_device
35193517
or node.is_extern()

0 commit comments

Comments
 (0)