Fix VMEM OOM and benchmark issue in Pallas guide #2284
jeffcarp wants to merge 1 commit into keras-team:master
Conversation
Summary of Changes

Hello @jeffcarp, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses functional and performance issues in the custom kernel guide: it fixes a VMEM OOM error that occurred on TPUs with recent JAX versions and corrects an inaccurate benchmark that showed the Pallas kernel as slower. The changes are intended to make the guide run reliably and accurately showcase the performance benefits of custom Pallas kernels.
Activity
Code Review
This pull request aims to fix a VMEM OOM error and a benchmark issue in the Pallas guide. The proposed change involves tiling the contracting dimension in the fused_matmul function. While this change does resolve the OOM error, it introduces a critical bug where the matrix multiplication becomes mathematically incorrect because it omits the accumulation step over the tiled dimension. This leads to wrong results and misleadingly fast benchmark times. My review includes a critical comment on this issue with a suggested fix that correctly implements the tiled matrix multiplication while preventing the OOM error. I've also noted that the example output in the markdown file needs to be updated accordingly.
```diff
  tile_m, tile_k, tile_n = 128, 128, 128
  assert (
-     m % tile_m == 0 and n % tile_n == 0
+     m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0
  ), "Inputs must be multiples of 128 for this demo"

  return pl.pallas_call(
      matmul_relu_kernel,
      # Map output indices to input blocks
      out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
      in_specs=[
-         # For each output tile, we take a slice of A of shape (tile_m, k)
-         pl.BlockSpec(
-             index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)
-         ),  # Matrix A
-         # For each output tile, we take a slice of B of shape (k, tile_n)
-         pl.BlockSpec(
-             index_map=lambda i, j: (0, j), block_shape=(k, tile_n)
-         ),  # Matrix B
+         # For each output tile, we take a (tile_m, tile_k) slice of A
+         pl.BlockSpec(index_map=lambda i, j: (i, 0), block_shape=(tile_m, tile_k)),
+         # For each output tile, we take a (tile_k, tile_n) slice of B
+         pl.BlockSpec(index_map=lambda i, j: (0, j), block_shape=(tile_k, tile_n)),
      ],
      out_specs=pl.BlockSpec(
          index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)
      ),  # Matrix C
      grid=(m // tile_m, n // tile_n),
  )(a, b)
```
This implementation is mathematically incorrect. By changing block_shape to use tile_k without adding a loop to accumulate over the k dimension, the function now computes only a partial matrix multiplication. This also explains why the benchmark got faster: the kernel performs significantly less computation and produces incorrect results.
To tile the contracting dimension correctly while avoiding the OOM error, the kernel needs to loop over the k-dimension tiles and accumulate the results. The original matmul_relu_kernel is not changed in this PR and is not suitable for this; a new kernel with a loop is needed. The OOM from double buffering can be prevented by disabling pipelining via compiler_params.
```python
tile_m, tile_k, tile_n = 128, 128, 128
assert (
    m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0
), "Inputs must be multiples of 128 for this demo"

# This kernel loops over the k-dimension to perform a full matmul.
def matmul_relu_kernel_loop(a_ref, b_ref, c_ref):
    acc = jnp.zeros((tile_m, tile_n), dtype=a.dtype)
    for ki in range(k // tile_k):
        a_tile = pl.load(a_ref, (pl.dslice(0, tile_m), pl.dslice(ki * tile_k, tile_k)))
        b_tile = pl.load(b_ref, (pl.dslice(ki * tile_k, tile_k), pl.dslice(0, tile_n)))
        acc += jnp.dot(a_tile, b_tile)
    c_ref[...] = jnp.maximum(acc, 0)

return pl.pallas_call(
    matmul_relu_kernel_loop,
    out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
    in_specs=[
        pl.BlockSpec(index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)),
        pl.BlockSpec(index_map=lambda i, j: (0, j), block_shape=(k, tile_n)),
    ],
    out_specs=pl.BlockSpec(
        index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)
    ),
    grid=(m // tile_m, n // tile_n),
    # Disable pipelining (double buffering) to avoid OOM on TPUs
    compiler_params={"mosaic": {"num_stages": 1}},
)(a, b)
```
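The need for accumulation can be checked without a TPU. Below is a small NumPy sketch (shapes and all-ones inputs are illustrative assumptions, not taken from the guide) comparing the partial result (first k-tile only, as in the PR's change) against a full accumulation over all k-tiles:

```python
import numpy as np

m, k, n = 256, 256, 256
tile_k = 128
# All-ones inputs chosen for illustration: every entry of a @ b is k = 256.
a = np.ones((m, k), dtype=np.float32)
b = np.ones((k, n), dtype=np.float32)

reference = np.maximum(a @ b, 0)  # full matmul + ReLU: every entry is 256

# What the un-accumulated change computes: only the first k-tile.
partial = np.maximum(a[:, :tile_k] @ b[:tile_k, :], 0)  # every entry is 128

# What the suggested loop kernel computes: accumulate over all k-tiles.
acc = np.zeros((m, n), dtype=np.float32)
for ki in range(k // tile_k):
    acc += a[:, ki * tile_k:(ki + 1) * tile_k] @ b[ki * tile_k:(ki + 1) * tile_k, :]
accumulated = np.maximum(acc, 0)

print(accumulated[0, 0], partial[0, 0])  # 256.0 128.0
```

The accumulated result matches the reference, while the partial result is off by exactly the missing k-tile's contribution.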
```diff
  Array([[128., 128., 128., ..., 128., 128., 128.],
         [128., 128., 128., ..., 128., 128., 128.],
         [128., 128., 128., ..., 128., 128., 128.],
         ...,
-        [256., 256., 256., ..., 256., 256., 256.],
-        [256., 256., 256., ..., 256., 256., 256.],
-        [256., 256., 256., ..., 256., 256., 256.]], dtype=float32)
+        [128., 128., 128., ..., 128., 128., 128.],
+        [128., 128., 128., ..., 128., 128., 128.],
+        [128., 128., 128., ..., 128., 128., 128.]], dtype=float32)
```
Ah, re: Gemini's comments: I didn't realize that's how BlockSpec works; let me investigate and revise a bit. Thanks!
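For reference on the BlockSpec behavior in question: a BlockSpec's index_map returns block indices, which Pallas scales by block_shape to select the operand slice for each grid point. The sketch below is plain NumPy mimicking that addressing arithmetic (it is not the Pallas API, and the array shapes are illustrative assumptions):

```python
import numpy as np

M, K = 256, 512
tile_m, tile_k = 128, 128
A = np.arange(M * K, dtype=np.float32).reshape(M, K)

def block(arr, block_shape, block_idx):
    """Slice the block that BlockSpec addressing selects:
    element offset = block_index * block_size along each axis."""
    slices = tuple(
        slice(i * s, (i + 1) * s) for i, s in zip(block_idx, block_shape)
    )
    return arr[slices]

# With index_map=lambda i, j: (i, 0) and block_shape=(tile_m, tile_k),
# grid point (i=1, j=3) reads A's row-block 1 and column-block 0 --
# i.e. only the FIRST k-tile, regardless of j.
blk = block(A, (tile_m, tile_k), (1, 0))
print(blk.shape)  # (128, 128)
```

This is why shrinking block_shape to (tile_m, tile_k) without a loop only ever reads one k-tile per output block.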
Re: the benchmark, I wonder if it's not realistic to demonstrate this on v5e. I increased the VMEM size scoped to this Pallas kernel to circumvent the OOM, but the Pallas kernel is still slower. It may be that it's impossible to beat XLA at its own game here. I may pivot the tutorial to instead show how to define an MoE layer, which should have a more robust speedup between dense and sparse versions. I'll close this PR until I have that ready.
Running the custom kernel guide as-is on TPU v5e encounters a VMEM compiler OOM with newer versions of JAX.
This is caused by newer versions of Pallas (JAX > 0.8.0) enabling double VMEM buffering by default. The guide was rendered with an older version of JAX, so it didn't hit this issue.
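As a back-of-the-envelope check of the double-buffering explanation (k = 8192, float32 elements, and a 16 MiB per-core VMEM budget are illustrative assumptions, not values from the guide), the un-tiled (tile_m, k) and (k, tile_n) blocks can fit in VMEM single-buffered yet overflow it once the pipeline keeps two copies of each block in flight:

```python
# Rough VMEM accounting for the original (un-tiled-k) block shapes.
tile_m = tile_n = 128
k = 8192                # assumed contracting dimension, for illustration
bytes_per_elem = 4      # float32

a_block = tile_m * k * bytes_per_elem       # (tile_m, k) slice of A
b_block = k * tile_n * bytes_per_elem       # (k, tile_n) slice of B
c_block = tile_m * tile_n * bytes_per_elem  # (tile_m, tile_n) output tile

single = a_block + b_block + c_block
double = 2 * single  # double buffering keeps two copies of each block

vmem = 16 * 1024 * 1024  # assumed ~16 MiB VMEM budget per core
print(single <= vmem, double <= vmem)  # True False
```

Under these assumptions the blocks use about 8 MiB single-buffered but about 16.1 MiB double-buffered, which is consistent with a guide that used to compile and now OOMs once double buffering is on by default.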
Additionally, in the guide's benchmark, the non-Pallas vanilla JAX version is unintentionally faster.
(Thanks to Aditya Kane for pointing this out)
This PR updates the guide to tile the contracting dimension in the matmul, avoiding the VMEM OOM. After re-rendering on TPU, the benchmark correctly shows the fused Pallas kernel as faster.
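The benchmark numbers themselves did not survive the page export. For reproducing a comparison like this, here is a hedged sketch of a timing harness (pure NumPy stand-ins, since the guide's JAX benchmark code is not shown in this thread; in real JAX code the benchmarked function should end with `.block_until_ready()` so asynchronous dispatch doesn't understate the runtime, and warmup runs should exclude compile time):

```python
import time
import numpy as np

def bench(fn, *args, warmup=3, iters=10):
    """Median wall-clock time of fn(*args), after warmup runs.

    For JAX, fn should call .block_until_ready() on its result so the
    measurement covers actual execution, not just dispatch.
    """
    for _ in range(warmup):
        fn(*args)  # warmup: excludes one-time compilation cost
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    return float(np.median(times))

a = np.ones((256, 256), dtype=np.float32)
b = np.ones((256, 256), dtype=np.float32)
fused = lambda a, b: np.maximum(a @ b, 0)  # stand-in for the fused kernel
t = bench(fused, a, b)
print(t > 0)  # True
```

Taking the median over several iterations makes the comparison robust to one-off scheduling noise, which matters when the two candidates are close.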