Conversation
Hi, do you have test cases for your implementation?
tests/test_aiter_mha.py is the file. I attached the shell script that runs all use cases. Results look like: "AITer MHA Test Suite -- Running test_batch_fwd_shape ... PASSED (50 passed) -- RESULTS: Passed: 218".
Force-pushed 0df02f7 to a6041a3.
Code Review: AITER MHA Kernels (ASM+CK) Integration
Commit: a6041a3b2c8ad9d6f86f071855e8523a3a49894a
Author: Zahid Iqbal
Scope: 3,644 lines added across 24 files -- integrates AITER multi-head attention forward/backward kernels for ROCm (CK + ASM v3) into JAX via XLA FFI.
1. Architecture Overview
The commit introduces a well-structured layered integration:
- Python API layer (jax/_src/aiter/aiter_mha.py): public flash_attn_func and flash_attn_varlen with custom_vjp, dispatching to unified forward/backward wrappers.
- FFI bridge (jaxlib/gpu/hip_aiter_mha_fwd.cc, hip_aiter_mha_bwd.cc): C++ handlers bound via XLA FFI, translating JAX buffers into aiter::mha_fwd_args / aiter::mha_bwd_args.
- Third-party headers (third_party/aiter/): AITER's own mha_fwd.h, mha_bwd.h, and the ASM kernel loader infrastructure.
- Nanobind glue (hip_aiter.cc): registration of FFI targets.
The design of a single unified handler detecting batch (4D) vs varlen (3D) from tensor rank is clean and avoids handler proliferation.
2. Thread Safety and Data Race Issues (Critical Section)
2.1 CRITICAL: hipMemcpyAsync from stack-local buffer (use-after-scope risk)
In hip_aiter_mha_common_utils.cc, prepare_rng_state_for_fwd:
uint64_t host_rng[2] = {seed_value, offset_value};
hipError_t err =
hipMemcpyAsync(rng_state->untyped_data(), host_rng,
2 * sizeof(int64_t), hipMemcpyHostToDevice, stream);

host_rng is a stack-allocated array. hipMemcpyAsync with hipMemcpyHostToDevice from pageable host memory is synchronous with respect to the host in current HIP/CUDA runtimes (the DMA engine must stage through a pinned bounce buffer), so this is de facto safe. However, this is implementation-defined behavior -- a future ROCm runtime change could make the copy truly asynchronous, at which point host_rng would be read after it goes out of scope. The same pattern appears in the forward handler:
std::vector<float> neg_inf(n, -std::numeric_limits<float>::infinity());
HIP_CHECK(hipMemcpyAsync(lse->untyped_data(), neg_inf.data(),
n * sizeof(float), hipMemcpyHostToDevice, stream));

Recommendation: either use hipMemcpy (synchronous) explicitly, use pinned host memory, or add a hipStreamSynchronize after the copy.
2.2 CRITICAL: Reading device pointers on the host
In prepare_rng_state_for_fwd:
const auto *gen_data = static_cast<const int64_t *>(gen->untyped_data());
seed_value = static_cast<uint64_t>(gen_data[0]); // Host reads device memory!
offset_value = static_cast<uint64_t>(gen_data[1]); // Host reads device memory!
VLOG(1) << "Using provided generator with seed: " << seed_value << ...;

If gen is a device buffer (the typical case in JAX -- all FFI buffers are device-side), dereferencing gen_data[0] from host code is undefined behavior. It will either segfault or return garbage. The subsequent hipMemcpyAsync(DeviceToDevice) is correct, but the VLOG will log incorrect values. The real concern is the UB itself.
Recommendation: Remove the host-side reads of gen_data[0]/gen_data[1], or use hipMemcpy to bring the values to host first if logging is needed.
2.3 HIGH: BwdDeviceBuffers destructor races with GPU work
In hip_aiter_mha_bwd.cc, BwdDeviceBuffers::~BwdDeviceBuffers() calls hipFree on buffers that were just passed to async kernels:
void free_all() {
if (dbias_expanded) { hipFree(dbias_expanded); ... }
if (dk_expanded) { hipFree(dk_expanded); ... }
...
}

hipFree performs an implicit device synchronization (it waits for ALL streams), so the data won't actually be freed while in use. However:
- This causes a full device sync on every backward pass, destroying any pipeline overlap.
- The implicit sync behavior is an implementation detail that could change.
- It makes the handler incompatible with CUDA Graphs / HIP Graphs and contradicts the kCmdBufferCompatible trait the handler claims.
Recommendation: Use a workspace pattern -- have JAX allocate the scratch space as additional output buffers, or use XLA's buffer allocation to pre-allocate scratch.
2.4 MEDIUM: Non-thread-safe get_gpu_arch() in header
In aiter_hip_common.h:
static const std::string get_gpu_arch() {
hipDevice_t dev;
HIP_CALL(hipGetDevice(&dev));
HIP_CALL(hipGetDeviceProperties(&dev_prop, dev));
...
}

This is marked static in a header, so every translation unit gets its own copy. hipGetDevice returns the thread-local current device. If this is called from different threads with different active devices, each will get a different answer, which is the intended behavior. However, consider the combined pattern with get_num_cu_func():
static uint32_t get_num_cu_func() {
static const uint32_t num_cu = get_num_cu_local(); // initialized once
return num_cu;
}

This caches the CU count from whichever device happens to be current during the first call. If the system has heterogeneous GPUs, subsequent calls from threads using different devices will get the wrong CU count. C++11 guarantees thread-safe initialization of static locals, but the cached value may be wrong for non-default devices.
2.5 MEDIUM: RNG seed collision under concurrency
When no generator is provided, the seed is derived from a timestamp:
seed_value = static_cast<uint64_t>(timestamp) ^ static_cast<uint64_t>(dev_idx);

Two concurrent forward passes on the same device within the same microsecond will produce identical seeds, leading to identical dropout masks. This is a correctness issue for training with dropout.
Recommendation: Use a monotonic atomic counter or incorporate the stream pointer / buffer address as additional entropy.
2.6 LOW: _make_fwd_call / _make_bwd_call are not cached
Each call to mha_fwd_unified / mha_bwd_unified creates a fresh jax.ffi.ffi_call + jax.jit wrapper:
fn = _make_fwd_call(out_shape, lse_shape, p_shape, rng_shape, q.dtype)

While JAX's tracing and compilation have their own caches, repeatedly creating new ffi_call objects and JIT wrappers adds unnecessary overhead. These should be memoized by (out_shape, lse_shape, p_shape, rng_shape, dtype).
3. Memory Management Issues
3.1 hipMalloc / hipFree on every backward call
The backward handler allocates up to 5 device buffers per invocation:
HIP_CHECK(hipMalloc(&bufs.dq_acc, dq_acc_bytes)); // always
HIP_CHECK(hipMalloc(&bufs.dk_expanded, dk_sz)); // if MQA/GQA
HIP_CHECK(hipMalloc(&bufs.dv_expanded, dv_sz)); // if MQA/GQA
HIP_CHECK(hipMalloc(&bufs.dbias_expanded, dbias_sz)); // if has_dbias
HIP_CHECK(hipMalloc(&bufs.dummy_rng, 2*sizeof(uint64_t))); // if no rng

hipMalloc is synchronous and expensive (~microseconds to milliseconds). For training workloads where backward is called millions of times, this is a significant performance bottleneck.
Recommendation: Allocate these as JAX-managed output buffers in the Python wrapper (passing workspace shapes through the FFI), or use a persistent memory pool.
3.2 Large heap allocation for LSE initialization
std::vector<float> neg_inf(n, -std::numeric_limits<float>::infinity());

For large batch/sequence lengths, n can be very large (e.g., batch=64, heads=32, seq=2048 => n=4M floats = 16MB of heap allocation per call). This should use a fill kernel instead.
4. Correctness Issues
4.1 Wrong sequence length passed to v3 eligibility check
In _flash_attn_backward:
_, sq, hq, dq = q.shape
_, sk, hk, _ = k.shape
...
use_v3 = _compute_v3_eligibility_bwd(
dropout_p, hq, hk, dq, causal, wl, wr, bias, sq, gfx # <-- sq, not sk
)The parameter is named sq_or_max_sk, and the check is:
if causal and gfx == "gfx950" and sq_or_max_sk > 256:
use_v3 = False

For the batch path, sq (query sequence length) is being passed where sk (key sequence length) is expected. When sq != sk (cross-attention), this can incorrectly enable or disable ASM v3. Compare with the varlen backward, which correctly passes max_seqlen_k:
use_v3 = _compute_v3_eligibility_bwd(
res_dp, hq, hk, dq, causal, window_size[0], window_size[1],
None, max_seqlen_k, gfx # <-- correct: max_seqlen_k
)

This is a bug for cross-attention cases on gfx950 with sq <= 256 < sk (causal).
4.2 kCmdBufferCompatible trait is incorrect
Both handlers are declared with:
{xla::ffi::Traits::kCmdBufferCompatible}

But the backward handler calls hipMalloc, hipFree, hipPointerGetAttributes, and other operations that are not compatible with command buffer recording. The forward handler similarly calls hipPointerGetAttributes via device_from_ptr. This trait should be removed or the handlers refactored.
4.3 Variable shadowing in mha_fwd_unified
def mha_fwd_unified(q, k, v, ...):
...
dq = q.shape[-1] # shadows name 'dq' which in MHA context means gradient of q
use_v3_fwd = not (get_gfx() == "gfx950" and dq >= 96)

While not a bug, dq meaning "head dimension of q" vs. "gradient of q" is confusing in MHA code. Consider renaming it to hdim_q.
4.4 _flash_attn_func_bwd gradient count vs. custom_vjp nondiff argnums
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
def flash_attn_func(q, k, v, dropout_p, softmax_scale, causal, window_size,
bias, alibi_slopes, deterministic, return_lse,
return_attn_probs, cu_seqlens_q, cu_seqlens_kv):

Differentiable args: indices 0,1,2,7,8 = q, k, v, bias, alibi_slopes. The backward returns:

return (dq, dk, dv, dbias, None)  # 5 values for 5 diff args

This is correct, but note that None is returned for the alibi_slopes gradient. If someone passes alibi_slopes that requires a gradient in the future, this will silently drop gradients. Consider adding a check or a comment.
4.5 Missing softmax_scale default in varlen path differs from batch
In flash_attn_varlen, the default scale uses q.shape[-1], which is the head dimension for 3D input. In flash_attn_func, it also uses q.shape[-1], the head dimension for 4D input. Both are correct, but the defensive check should be unified.
5. Error Handling Issues
5.1 HIP_CHECK calls std::abort() -- no graceful degradation
inline void hipCheck(hipError_t err, const char *file, int line) {
if (err != hipSuccess) {
LOG(ERROR) << ...;
std::abort();
}
}

In the backward handler, if hipMalloc fails (e.g., OOM), the process is killed. For an FFI handler, it would be better to return ffi::Error with an appropriate error code and let JAX handle the failure gracefully.
5.2 Bare catch (...) swallows errors
In hip_aiter_mha_bwd.cc:
try {
auto [s, o] = mha_utils::get_rng_seed_offset_ptrs(rng_state_, dropout_p);
seed_ptr = s; offset_ptr = o;
} catch (...) { /* fallthrough to dummy */ }

This silently swallows any exception, including errors that indicate real problems (e.g., a buffer that is too small). At minimum, log the exception.
5.3 dtype_to_string throws std::runtime_error
inline std::string dtype_to_string(xla::ffi::DataType dtype) {
...
default:
throw std::runtime_error("Unsupported dtype for MHA");
}

Throwing across FFI boundaries is undefined behavior in many contexts. This should return ffi::Error instead.
6. Python-Side Issues
6.1 get_gfx() shells out to rocminfo at import time
@functools.lru_cache(maxsize=1)
def get_gfx() -> str:
...
result = subprocess.run([os.path.realpath(rocminfo)], ...)

This is called during tracing, which happens at JIT compilation time. It's expensive, fragile (it depends on rocminfo being on PATH), and won't work in sandboxed environments. The lru_cache helps, but the first call can still be slow.
Recommendation: Query the GPU architecture through the HIP runtime API in C++ and expose it via an FFI call or Python binding, rather than shelling out.
6.2 FFI target name mismatch
In hip_aiter.cc:
dict[JAX_GPU_PREFIX "_mha_fwd_ffi"] = EncapsulateFfiHandler(aiter_mha_fwd);

With JAX_GPU_PREFIX = "hip", this registers as "hip_mha_fwd_ffi". But the Python side calls:
jax.ffi.ffi_call("hip_mha_fwd_ffi", ...)

The name must match exactly. If JAX_GPU_PREFIX is ever changed (e.g., to "rocm"), this silently breaks. Consider using a shared constant or deriving the name.
6.3 Redundant import
import functools
...
from functools import partial

Both functools and functools.partial are imported. Minor cleanup opportunity.
6.4 Missing newline at end of file
Multiple files (aiter_mha.py, __init__.py, several C++ files) are missing the trailing newline. This causes \ No newline at end of file in diffs and can confuse some tools.
7. Test Coverage Assessment
The test file (test_aiter_mha.py, 551 lines) is quite comprehensive:
Strengths:
- TE-style tolerance computation (eps^(2/3)) is appropriate for mixed precision
- Good parametric coverage: head dims 32-256, both dtypes, MHA/GQA/MQA
- Regression guards for specific historical bugs (good practice)
- Both batch and varlen paths tested
- Edge cases: sq=1 (decode), sq>sk, sq<sk, large batch, single head
Gaps:
- No concurrent execution tests -- no test for thread-safety or multi-stream correctness
- No dropout numerical correctness test -- only shape/crash tests (acknowledged in docstring, but this means dropout correctness is untested)
- No test for the logits_soft_cap parameter -- it's accepted but never tested
- No test for zero_tensors=True -- the flag is accepted but the default False path is always taken
- No test for the min_seqlen_q parameter
- No negative tests -- nothing exercises invalid inputs (wrong dtypes, mismatched shapes, etc.)
- No multi-GPU tests
- Reference implementation in tests doesn't support GQA/MQA, so GQA/MQA accuracy is only tested for shape/finiteness, not numerical correctness
8. Build System Issues
8.1 Missing .so files in third_party/aiter/
The BUILD file references:
cc_import(
name = "mha_fwd_so",
shared_library = "libmha_fwd.so",
)

But libmha_fwd.so and libmha_bwd.so are not included in the commit. There's no documentation for how these shared libraries should be built or obtained.
8.2 linkopts = ["-Wl,-rpath,$$ORIGIN"]
This sets RPATH to look for .so files in the same directory as the binary. This is correct for deployment but needs the .so files to be co-located at runtime.
9. Summary of Findings by Severity
| Severity | Count | Key Items |
|---|---|---|
| Critical | 2 | Host reads of device pointers (UB); hipMemcpyAsync from stack buffers |
| High | 2 | hipFree in destructor syncs device (perf); hipMalloc per backward call |
| Medium | 3 | Wrong sq vs sk in v3 eligibility; incorrect kCmdBufferCompatible; RNG seed collision |
| Low | 6+ | Missing tests for dropout/logits_soft_cap/zero_tensors; std::abort on OOM; swallowed exceptions; style issues |
The most actionable items are:
- Fix the host-side device pointer reads in prepare_rng_state_for_fwd
- Fix sq vs sk in _flash_attn_backward's v3 eligibility check
- Remove kCmdBufferCompatible from both handlers
- Replace per-call hipMalloc/hipFree with a workspace pattern for the backward pass
def _pad_to_multiple_of_8(q, k, v):
    """Pad head dimensions of Q/K/V to the next multiple of 8 if needed.

    Returns (q_padded, k_padded, v_padded, hd_q_original, hd_v_original).
    """
    hd_q = q.shape[-1]
    hd_v = v.shape[-1]
    q_p, k_p, v_p = q, k, v
    ndim = q.ndim
    if hd_q % 8 != 0:
        pad = 8 - hd_q % 8
        pw = tuple((0, 0) for _ in range(ndim - 1)) + ((0, pad),)
        q_p = jnp.pad(q, pw)
        k_p = jnp.pad(k, pw)
    if hd_v % 8 != 0:
        pad = 8 - hd_v % 8
        pw = tuple((0, 0) for _ in range(ndim - 1)) + ((0, pad),)
        v_p = jnp.pad(v, pw)
    return q_p, k_p, v_p, hd_q, hd_v
Is this padding for enabling the ck/aiter flow or for pure performance? We didn't see configs with hdim not 8 multiples from external customers before so we didn't do it in TE. But I was curious about the overall performance comparison between padding-->ck/aiter-->unpadding vs passing original config to ck/aiter
The padding is a hard functional requirement, not a performance optimization. The CK and ASM v3 kernels will produce incorrect results or crash if head dimensions are not multiples of 8.
The upstream AITer mha.py (the PyTorch version) has the identical logic in every public entry point, at line 1760:
And then after the kernel runs, the output is sliced back: line 1845
dq = dq[..., :head_size_q_og] # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
dv = dv[..., :head_size_v_og]
This is done in flash_attn_func, flash_attn_varlen_func, and flash_attn_fp8_pertensor_func — every single public API.
ASM v3 eligibility explicitly requires hdim % 8 == 0
aiter: mha.py:line1543, ret &= hdim_q >= 64 and hdim_q <= 192 and hdim_q % 8 == 0
def _compute_v3_eligibility_bwd(
    dropout_p, hq, hk, dq, causal, wl, wr, bias, sq_or_max_sk, gfx
):
    """Shared ASM v3 eligibility check for backward pass (batch & varlen)."""
    swa = (wl > 0) or (wr >= 0 and wr != -1)
    use_v3 = True
    if dropout_p > 0:
        use_v3 = False
    if hq != hk:
        use_v3 = False
    if bias is not None and bias.size > 0:
        use_v3 = False
    if swa:
        use_v3 = False
    if causal and gfx == "gfx950" and sq_or_max_sk > 256:
        use_v3 = False
    if gfx == "gfx950" and dq >= 96:
        use_v3 = False
    return use_v3
AITER has its own internal v3 API check and will fall back to v2 even if you requested v3 ASM.
33 tests crashed without this check... v3_api_check is user-controlled in aiter, and controlled in the benchmark tests; I think it switches between v3 and v2.
gen = _empty(jnp.int64)

rng_shape = (2,)
bf16_cvt = 0 if get_gfx() == "gfx950" else 1
bf16_cvt can also be set to different values in gfx942, refer to TE readme for more details: https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#aiter-fa-v3-kernels
if (dbias_expanded) { hipFree(dbias_expanded); dbias_expanded = nullptr; }
if (dummy_rng) { hipFree(dummy_rng); dummy_rng = nullptr; }
if (dq_acc) { hipFree(dq_acc); dq_acc = nullptr; }
if (dk_expanded) { hipFree(dk_expanded); dk_expanded = nullptr; }
if (dv_expanded) { hipFree(dv_expanded); dv_expanded = nullptr; }
Is it possible to request a buffer from jax? In other words, the jax will manage those extra buffers like dq_acc, softmax_lse_buffer and so on? Calling hipMalloc and hipFree could be heavy for e2e training if you need to run this every iteration
this is a .cc file (not Python) where the custom call launches kernels; there is no JAX there.
| """ | ||
|
|
||
| try: | ||
| from .aiter_mha import ( |
This doesn't look like the typical 2-space formatting that most other files use, which is specified in pyproject.toml.
Can you run ruff on all your python files to ensure format consistency?
Done, ruff removed few extra lines
Did you push changes? I see the same 4 space indentation...
BwdDeviceBuffers &operator=(const BwdDeviceBuffers &) = delete;
};

static size_t compute_dq_acc_size_unified(
Instead of making functions static to hide them from a linker, it's safer to put them into an anonymous namespace at the topmost namespace level (i.e. not enclosed into any other namespace). This is also how JAX/XLA typically do it.
hipStream_t stream,
ffi::AnyBuffer dout, ffi::AnyBuffer q, ffi::AnyBuffer k, ffi::AnyBuffer v,
ffi::AnyBuffer out, ffi::AnyBuffer softmax_lse,
std::optional<ffi::AnyBuffer> cu_seqlens_q_,
is this file taken from AITER? A trailing underscore in a variable name is a hallmark of a class data member variable in Google C++ style guide. I don't think this can ever be upstreamed and it's better to fix it right now.
code cleaned up, trailing underscore removed from all cc files
Force-pushed 3404a0d to 720431d.
#include <hip/hip_runtime.h>

#ifdef __HIP_PLATFORM_AMD__
typedef hip_bfloat16 hip_bf16_type;
There's a long-standing suggestion not to use the typedef specifier and instead use a using = type alias declaration (https://dev.cppreference.com/cpp/language/type_alias), which is much more natural and easier to understand. Old code doesn't have to be amended, but there's no point increasing the number of typedef uses instead of decreasing it.
if (dtype == xla::ffi::DataType::F16) {
  mqa_gqa_reduce_kernel<__half><<<blocks, threads, 0, stream>>>(
      static_cast<const __half *>(src), static_cast<__half *>(dst),
      batch_size, seqlen_k, num_heads_q, num_heads_k, head_size, groups);
} else if (dtype == xla::ffi::DataType::BF16) {
  mqa_gqa_reduce_kernel<hip_bf16_type><<<blocks, threads, 0, stream>>>(
      static_cast<const hip_bf16_type *>(src),
      static_cast<hip_bf16_type *>(dst), batch_size, seqlen_k, num_heads_q,
      num_heads_k, head_size, groups);
}
What if dtype is neither F16 nor BF16? It'll silently do nothing, and that's not correct. Some kind of error handling is needed here; silent misbehavior is a no-go.
}

out_ptrs.seed = rng_state_ptr;
out_ptrs.offset = rng_state_ptr + 1;
While at this stage there's no UB yet, a read from any of out_ptrs will produce one, since rng_state_ptr isn't a valid pointer to uint64. For example, can you guarantee that the alignment requirements of uint64 are fulfilled by the (likely) byte storage of rng_state->untyped_data()? Bytes have no alignment requirements at all, so the storage won't necessarily start on the 64-bit boundary required by uint64.
There are other risks too, it's a complex issue: https://arech.github.io/2024-08-17-reinterpret_cast-ub-and-a-pointer-casting-in-c++
There are two alternative ways to solve this properly and safely:
- Can you instead make RngStatePointers be based on std::byte* ptr + size_t size, instead of relying on a fixed-width uint64? At least RNGs in the Standard Library are capable of consuming byte* seeds, as are many other properly designed ones. Beyond anything else, a seed of just 64 bits can only properly seed the state of a 64-bit generator, which is criminally tiny. Most modern RNGs have much wider states, so seeding them with just 64 bits leaves a major chunk of randomness uncovered. This would be the best solution, as it addresses both the UB and the RNG initialization issue.
- If you must do type punning in C++ without the C++20 Standard Library facilities, you must do it with std::memcpy(). An example is shown in the linked article.
namespace jax_aiter {
namespace mha_utils {

inline std::string dtype_to_string(xla::ffi::DataType dtype) {
You always return a statically allocated and stored string. Why wrap it in a dynamically allocated std::string when it could simply be a std::string_view? It's just a waste of cycles, potentially a huge one.
default:
  return 4;
Does this default suit all possible xla::ffi::DataType values, including ones not added yet? Surely not; it must throw instead.
if (is_causal) {
  window_size_right = 0;
  std::string mask_identify = "b:" + std::to_string(window_size_left) + ",0";
  mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k);
::decode() is likely badly designed and doesn't actually need the first argument to be std::string, but you could at least have moved the mask_identify into the argument, instead of copying it.
A better solution, though, could be allocating a char _buf[16] on the stack, std::snprintf()-ing into it, and then passing _buf as the argument to let mask_info::decode() construct the string from it. This avoids dumb reallocations from concatenating strings.
} else {
  std::string mask_identify = "b:" + std::to_string(window_size_left) + "," +
                              std::to_string(window_size_right);
  mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k);
same here, std::move(mask_identify) into the arg.
auto *base_ptr =
    static_cast<uint64_t *>(const_cast<void *>(rng_state.untyped_data()));
Use of const_cast<> is a pretty sure sign of bad design, leading to bugs that are very hard to identify and debug. It might be acceptable in some narrow cases, for example when dealing with a badly designed external API, or when an external API must be used in an unexpected manner. But I'm not sure this is the case here; basically, both ends are APIs of this project. Let's have a call where you'll describe to me the gist of it and I'll help you design a better solution. Const correctness is important to get right, otherwise it can lead to very nasty bugs.
int64_t seqlen_k_or_max, int64_t num_heads, int64_t head_size,
bool deterministic, bool use_asm_v3, bool is_v3_atomic_fp32,
ffi::DataType q_dtype, std::vector<int64_t> &out_shape)
noexcept {
What is the value of having noexcept here? Is an exception firewall useful here?
try {
  auto [s, o] = mha_utils::get_rng_seed_offset_ptrs(rng_state, dropout_p);
  seed_ptr = s; offset_ptr = o;
} catch (...) { /* fallthrough to dummy */ }
So instead of presenting an error message to the user, it just silently does something unexpected. Are you sure this is the correct way to handle this?
strides[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--)
  strides[i] = strides[i + 1] * dq_acc_shape[i + 1];
I don't like the idea of triggering dynamic memory allocation for such simple calculations, especially given that you'll use only the first 4 elements. With a very simple modification you can store only 4 stride values, which is trivially done with std::array<index_t, 4>. Can you fix that?
const ck_tile::index_t *cu_seqlen_k_ptr = nullptr;

if (is_varlen) {
  seqstart_q_ptr = reinterpret_cast<const ck_tile::index_t *>(cu_seqlens_q->untyped_data());
what are these cu_seqlens_q & related buffers? Where are pointers such as seqstart_q_ptr dereferenced?
def _is_rocm_platform() -> bool:
    """Return True when *any* ROCm plugin wheel is installed."""
    for name in ("jax_rocm7_plugin", "jax_rocm6_plugin"):
Does this have to support rocm6? Do we even have a pipeline for it? How did you test that it works?
Apparently there is no rocm6 support, but this only warns when ROCm is installed without AITER FFI registration. I modified the code to check for future plugins (such as an incoming jax_rocm8_plugin), so it won't need changes:
from .plugin_support import _PLUGIN_MODULE_NAMES, import_from_plugin

def _is_rocm_platform() -> bool:
    """Return True when any ROCm plugin wheel is installed."""
    return any(
        importlib.util.find_spec(name) is not None
        for name in _PLUGIN_MODULE_NAMES.get("rocm", [])
    )
#include "aiter_logger.h"
#include "ck_tile/core.hpp"
#include <cstdint>
#include <hip/hip_runtime.h>
#include <iostream>
#ifdef AITER_EMBEDDED_HSA_HEADER
The order of includes here and in basically most of the other C++ files is a bit... weird. Please follow this: https://google.github.io/styleguide/cppguide.html#Names_and_Order_of_Includes
This header file is from rocm/aiter, not a good idea to change it in our way ...
printf("\n[AITER] %s:%d fail to call %s ---> [HIP error](%s)\n", \
       __FILE__, \
       __LINE__, \
       #call, \
       hipGetErrorString(err)); \
exit(0); \
Most likely, this print won't be visible to a user, since no stream flush has been done. A better idea is to write to unbuffered STDERR with std::fprintf(stderr, ...) -- it doesn't require a flush before exit() and is more appropriate for error reporting.
Also given that the macro is called from destructors, use of std::terminate() might be more appropriate here than exit(), which also calls other destructors.
This header file is from rocm/aiter, not a good idea to change it in our way ...
};
};

static const std::string get_gpu_arch()
Static functions shouldn't be declared in headers! This duplicates the function in each translation unit that includes the header. It should be made inline instead, to get weak linkage and require the linker to use a single function instance for the whole program.
also const on the return type is meaningless at best. Please remove it.
This header file is from rocm/aiter, not a good idea to change it in our way ...
HIP_CALL(hipGetDeviceProperties(&dev_prop, dev));

std::string arch_full = dev_prop.gcnArchName;
size_t colon_pos = arch_full.find(':');
btw, not just this line, but the whole file should be reformatted according to JAX's format rules...
This header file is from rocm/aiter, not a good idea to change it in our way ...
in addition to formatting there are more important issues with the file:
- All declarations pollute the global namespace. Please move all symbols into some dedicated namespace, such as jax_aiter or something like that.
- The file is universally included and defines several generic macros that could collide with similarly named macros elsewhere. Please use the #undef directive at the end of the file to keep the definitions visible only within the file.
This header file is from rocm/aiter, not a good idea to change it in our way ...
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)
wait-wait-wait, enabling 64 bit floating point affects the whole program and brings in a serious performance penalty. Why does it even have to be here?
Good catch!
The key insight: the C++ FFI side uses AnyBuffer and untyped_data() for all RNG/gen buffers -- it only cares about the total byte count, not the element type. The RNG state is 16 bytes (2 * sizeof(int64_t)). So we can represent it as (4,) int32 (4 * 4 = 16 bytes) instead of (2,) int64.
For the size-0 placeholder arrays (_empty(jnp.int64)), the dtype is irrelevant since 0 elements = 0 bytes regardless.
I eliminated the x64 dependency entirely without affecting the functionality...
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
There was a problem hiding this comment.
FYI: there's no need to use Tuple from typing; the regular builtin tuple works and is faster.
gulsumgudukbay left a comment
workspace.bzl does not exist! Also I have some comments.
xla_workspace3()

load("//third_party/aiter:workspace.bzl", "aiter_mha_whl")
third_party/aiter/workspace.bzl is missing from the PR.
Without it, WORKSPACE parsing fails on the load(...) line before any target can be built.
When it's added, please ensure it:
- pins a specific GitHub release URL (something like https://github.com/ROCm/jax/releases/download/AITER-MHA/, not "latest"),
- specifies a sha256 (for safety),
- supplies a build_file or build_file_content that exposes libmha_fwd.so and libmha_bwd.so (e.g. as filegroup and/or cc_import targets) with visibility = ["//visibility:public"],
- gracefully no-ops on non-Linux / non-ROCm hosts (otherwise CUDA-only builds will try to download it).
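A minimal sketch of what such a workspace.bzl could look like — the URL, archive name, and sha256 below are placeholders to be pinned to the real release artifact, not actual values:

```starlark
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

def aiter_mha_whl():
    # Placeholder URL and sha256 -- pin to a concrete release, never "latest".
    http_archive(
        name = "aiter_mha_wheel",
        urls = ["https://github.com/ROCm/jax/releases/download/AITER-MHA/<pinned-artifact>.whl"],
        type = "zip",  # wheels are zip archives
        sha256 = "<pinned-sha256>",
        build_file_content = """exports_files(
    ["libmha_fwd.so", "libmha_bwd.so"],
    visibility = ["//visibility:public"],
)""",
    )
```

The non-Linux / non-ROCm no-op would typically be handled by the caller (e.g. only invoking aiter_mha_whl() when the ROCm toolchain is configured), since .bzl macros run at WORKSPACE evaluation time.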
| @@ -0,0 +1,41 @@ | |||
| package(default_visibility = ["//visibility:public"]) | |||
|
|
|||
There was a problem hiding this comment.
ck_tile/* headers are not provided!
aiter/mha_fwd.h, mha_bwd.h, and aiter_hip_common.h include ck_tile/core.hpp, ck_tile/ops/fmha_{fwd,bwd}.hpp, and ck_tile/ops/mask.hpp, but nothing in third_party/aiter/BUILD or the :hip_aiter_mha deps provides them.
|
|
||
| cc_import( | ||
| name = "mha_fwd_so", | ||
| shared_library = "@aiter_mha_wheel//:libmha_fwd.so", |
There was a problem hiding this comment.
I am not sure about the validity of this and line 23. If this .so file is exposed by workspace.bzl as a filegroup, this may fail. Make sure the BUILD file uses exports_files(["libmha_fwd.so", "libmha_bwd.so"]).
| "//jaxlib/gpu:hip_aiter.h", | ||
| "//jaxlib/gpu:hip_aiter_mha_common_utils.h", | ||
| ]), | ||
| deps = [ |
There was a problem hiding this comment.
Why didn't you wrap this in if_rocm_is_configured()?
It may break for non-ROCm users when you upstream this PR.
| ) | ||
|
|
||
| if wheel_sources: | ||
| for src in wheel_sources: |
There was a problem hiding this comment.
This loop can silently skip! It only copies a file when its basename matches libmha_{fwd,bwd}.so; if those files are absent, the wheel ships without them and _aiter.so fails to dlopen at runtime. Please fail loudly when ROCm is enabled and the libs aren't found.
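One way to make the copy step fail loudly — a sketch only: `copy_aiter_libs` and its parameters are illustrative stand-ins, not the build script's actual names:

```python
import os

REQUIRED_AITER_LIBS = {"libmha_fwd.so", "libmha_bwd.so"}

def copy_aiter_libs(wheel_sources, dest_dir, rocm_enabled=True):
    """Copy the AITER shared libraries, raising if any are missing."""
    found = set()
    for src in wheel_sources:
        name = os.path.basename(src)
        if name in REQUIRED_AITER_LIBS:
            # shutil.copy(src, dest_dir)  # actual copy elided in this sketch
            found.add(name)
    missing = REQUIRED_AITER_LIBS - found
    if rocm_enabled and missing:
        # Fail at build time rather than at dlopen time on the user's machine.
        raise RuntimeError(
            f"AITER libraries missing from wheel sources: {sorted(missing)}")
    return found
```

With both libraries present the helper returns the found set; with one missing and ROCm enabled it raises instead of shipping a broken wheel.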
| ) | ||
|
|
||
| exports_files(srcs = [ | ||
| "hip_aiter_mha_fwd.cc", |
There was a problem hiding this comment.
please make these alphabetical to match the style of the file.
| @@ -0,0 +1,41 @@ | |||
| package(default_visibility = ["//visibility:public"]) | |||
There was a problem hiding this comment.
missing license header & trailing newline
| deps = [":aiter_headers"], | ||
| ) | ||
|
|
||
| cc_library( |
There was a problem hiding this comment.
aiter_mha doesn't list aiter_headers in its deps explicitly; the headers currently leak in transitively through cc_import -> aiter_headers. Please add aiter_headers to deps directly so the dependency is explicit.
|
|
||
| pytype_strict_library( | ||
| name = "aiter", | ||
| srcs = glob(["aiter/**/*.py"]), |
There was a problem hiding this comment.
Use explicit srcs = ["aiter/__init__.py", "aiter/aiter_mha.py"] instead of glob(["aiter/**/*.py"]) in jax/_src/BUILD.
6f80b73 to
f461fcf
Compare
Integrated AMD's AITER library to provide high-performance multi-head
attention (MHA) forward and backward kernels on ROCm GPUs. AITER
dispatches internally between CK (Composable Kernel) and hand-tuned
ASM v3 assembly kernels for optimal performance on supported
architectures (e.g. gfx942, gfx950+).
Public API: flash_attn_func and flash_attn_varlen.
Both APIs support causal masking, sliding window attention, dropout,
ALiBi, GQA/MQA head layouts, and non-standard head dimensions
(automatically padded to multiples of 8).
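The head-dim padding mentioned above reduces to a round-up-to-multiple computation; `pad_head_dim` below is an illustrative helper name, not the actual function in the integration:

```python
def pad_head_dim(head_dim: int, multiple: int = 8) -> int:
    """Round head_dim up to the next multiple of `multiple` (default 8)."""
    return (head_dim + multiple - 1) // multiple * multiple

assert pad_head_dim(64) == 64    # already aligned, unchanged
assert pad_head_dim(100) == 104  # non-standard dim, padded up
```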
Implementation:
- Wraps aiter::mha_fwd / aiter::mha_bwd with unified batch/varlen dispatch based on tensor rank (4D = batch, 3D = varlen).
- Stride calculation, mask/bias construction, MQA/GQA reduction kernels, and RNG state management.
- Registers the XLA FFI targets hip_mha_fwd_ffi and hip_mha_bwd_ffi.
- Head-dim padding, ASM v3 eligibility checks, and gfx950-specific guards.
- Loaded from the jax-rocm plugin wheel via import_from_plugin.
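The rank-based batch/varlen dispatch can be sketched as follows (`dispatch_mode` is an illustrative name; the real wrapper forwards to the FFI handlers with many more arguments):

```python
import numpy as np

def dispatch_mode(q: np.ndarray) -> str:
    """Pick the batch vs. varlen code path from tensor rank."""
    if q.ndim == 4:   # (batch, seqlen, num_heads, head_dim)
        return "batch"
    if q.ndim == 3:   # (total_tokens, num_heads, head_dim), with cu_seqlens
        return "varlen"
    raise ValueError(f"expected rank 3 or 4 query tensor, got rank {q.ndim}")

assert dispatch_mode(np.zeros((2, 128, 8, 64))) == "batch"
assert dispatch_mode(np.zeros((256, 8, 64))) == "varlen"
```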
ASM v3 kernel eligibility on gfx950:
- causal with sq > sk configurations.
New files:
- jax/_src/aiter/__init__.py
- jax/_src/aiter/aiter_mha.py
- jaxlib/gpu_aiter.py
- jaxlib/gpu/hip_aiter.h
- jaxlib/gpu/hip_aiter.cc
- jaxlib/gpu/hip_aiter_mha_fwd.cc
- jaxlib/gpu/hip_aiter_mha_bwd.cc
- jaxlib/gpu/hip_aiter_mha_common_utils.h
- jaxlib/gpu/hip_aiter_mha_common_utils.cc
- third_party/aiter/{BUILD,include/}
- tests/test_aiter_mha.py
Modified files:
- jax/_src/BUILD (add aiter target)
- jax/_src/lib/__init__.py (import gpu_aiter)
- jaxlib/BUILD (add gpu_aiter.py)
- jaxlib/gpu/BUILD (export aiter sources)
- jaxlib/rocm/BUILD (aiter library + nanobind targets)
- jaxlib/tools/build_wheel.py (include gpu_aiter.py in jaxlib wheel)
- jaxlib/tools/build_gpu_kernels_wheel.py (include _aiter.so in plugin wheel)
- jaxlib/tools/BUILD.bazel (aiter runtime .so in plugin wheel)