
[Bug/Correctness] Hardcoded device_id=0 + missing CUDAGuard can break multi-GPU correctness (wrong hw_info / stream mismatch) #158

@red1239109-cmd


Hi maintainers,

While reviewing the FMHA forward runner integration, I noticed two correctness issues that can break execution on multi-GPU setups (and can also create subtle stream/device mismatches):

  1. Hardcoded GPU selection (device_id = 0)

In run_fmha_fwd, the hardware info is pinned to GPU0:

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

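If the tensors (q/k/v/o/lse) live on any device other than GPU 0, this queries GPU 0's SM count instead of the tensors' GPU (the two differ on nodes with mixed GPU models), so the kernel may be launched with hardware assumptions that do not match the device it actually runs on.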

  2. Missing CUDA device guard (at::cuda::CUDAGuard) and stream/device alignment

The code uses at::cuda::getCurrentCUDAStream() at the end:

CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));

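but never guards or sets the current device to match q.device() (or any other input tensor). Called with no argument, getCurrentCUDAStream() returns the current stream of the current device, not of the tensors' device. In multi-GPU scenarios the "current device" may differ from the tensors' device, leading to: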

- the wrong stream/device being used
- incorrect hw_info.device_id / sm_count
- potential launch failures or silent misbehavior
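To make the mechanism concrete, here is a GPU-free C++ model of it. This is not PyTorch or CUDA code: `model::set_device`, `model::current_stream`, `model::DeviceGuard`, and `launch_stream_device` are hypothetical stand-ins for cudaSetDevice, a no-argument at::cuda::getCurrentCUDAStream(), and at::cuda::CUDAGuard. It assumes the caller left device 0 current while the tensors live on device 1.

```cpp
#include <cassert>

// GPU-free model (all names hypothetical): a per-thread "current device" and a
// stream lookup keyed on it, mimicking a no-argument getCurrentCUDAStream().
namespace model {

thread_local int g_current_device = 0;  // CUDA's per-thread default is device 0

struct Stream {
  int device;  // the device this stream belongs to
};

void set_device(int d) { g_current_device = d; }
int current_device() { return g_current_device; }

// Stand-in for the no-argument "current stream": resolved from the current device.
Stream current_stream() { return Stream{g_current_device}; }

// RAII guard: switch to `d` on construction, restore the previous device on exit.
class DeviceGuard {
 public:
  explicit DeviceGuard(int d) : prev_(current_device()) {
    assert(d >= 0 && "device index must be non-negative");
    set_device(d);
  }
  ~DeviceGuard() { set_device(prev_); }
  DeviceGuard(const DeviceGuard&) = delete;
  DeviceGuard& operator=(const DeviceGuard&) = delete;

 private:
  int prev_;
};

// Device of the stream a launch would use for tensors living on `tensor_device`.
int launch_stream_device(int tensor_device, bool use_guard) {
  if (use_guard) {
    DeviceGuard guard(tensor_device);
    return current_stream().device;  // evaluated before the guard restores
  }
  return current_stream().device;  // whatever device happens to be current
}

}  // namespace model
```

With device 0 current and tensors on device 1, the unguarded lookup yields a device-0 stream, the guarded one yields a device-1 stream, and the caller's current device is back to 0 afterwards.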

Suggested fix

Use a device guard based on an input tensor (e.g., q) and set hw_info.device_id accordingly:

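#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAGuard.h>     // at::cuda::CUDAGuard

// Make q's device current for this scope; the previous device is restored on exit.
at::cuda::CUDAGuard device_guard(q.device());
const int dev = q.get_device();  // equals the now-current device

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = dev;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(dev);

// With the guard active, this is the current stream of q's device.
CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));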

This ensures that:

- the correct device is active
- the stream belongs to the tensors' device
- the hardware info is queried from the right GPU
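Guarding on q is only sufficient if every tensor the kernel touches lives on the same device as q. A minimal sketch of that precondition, assuming device indices as returned by `Tensor::get_device()` (which is -1 for CPU tensors); `common_device` is a hypothetical helper, and in the extension itself the same check would more naturally be a series of `TORCH_CHECK(k.device() == q.device(), ...)` calls placed before the guard.

```cpp
#include <cassert>
#include <initializer_list>
#include <stdexcept>
#include <string>

// Hypothetical helper: given the device index of each tensor (q, k, v, o, lse),
// return the single CUDA device they share, or throw if they disagree.
// Precondition: at least one index is passed.
int common_device(std::initializer_list<int> tensor_devices) {
  assert(tensor_devices.size() > 0 && "need at least one tensor");
  const int dev = *tensor_devices.begin();
  if (dev < 0) {
    throw std::invalid_argument("tensor is not on a CUDA device");
  }
  for (int d : tensor_devices) {
    if (d != dev) {
      throw std::invalid_argument("tensors span devices " + std::to_string(dev) +
                                  " and " + std::to_string(d));
    }
  }
  return dev;  // safe to build the device guard and hw_info from this index
}
```

Failing fast here turns a mixed-device call into an explicit error instead of a launch on the wrong GPU.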

Why this matters

Even if most users run on a single GPU, multi-GPU setups are common on training and inference servers. Hardcoding GPU 0 and omitting the device guard can produce correctness issues that are hard to diagnose, especially when the failure is not immediate.

If you'd like, I can provide a small repro snippet that places q/k/v/o on cuda:1 and shows the mismatch.

Thanks!
