
Fix VMEM OOM and benchmark issue in Pallas guide#2284

Closed
jeffcarp wants to merge 1 commit into keras-team:master from jeffcarp:fix-pallas-oom

Conversation

@jeffcarp (Member)

@jeffcarp jeffcarp commented Feb 7, 2026

Running the custom kernel guide as-is on TPU v5e hits a VMEM compiler OOM with newer versions of JAX:

[Screenshot: VMEM OOM compiler error]

This is caused by Pallas in JAX >0.8.0 enabling double VMEM buffering by default. The guide was rendered with an older version of JAX, so it didn't hit this issue.
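For intuition, here is a rough, pure-Python estimate of the per-step VMEM footprint (the sizes come from the guide's 8192x8192 benchmark with 128-wide tiles; treating double buffering as "two copies of each input block" is an illustrative assumption, not Pallas's exact accounting):

```python
# Back-of-the-envelope VMEM footprint per grid step (float32 = 4 bytes).
def block_bytes(shape, dtype_bytes=4):
    total = dtype_bytes
    for dim in shape:
        total *= dim
    return total

k = 8192
tile_m = tile_k = tile_n = 128

# Untiled k with double buffering: two copies each of a (tile_m, k) slice
# of A and a (k, tile_n) slice of B, plus one (tile_m, tile_n) output tile.
untiled_double = 2 * (block_bytes((tile_m, k)) + block_bytes((k, tile_n))) \
    + block_bytes((tile_m, tile_n))

# With k tiled to 128, the input blocks shrink to (tile_m, tile_k) and
# (tile_k, tile_n), so the footprint drops by roughly k / tile_k.
tiled_double = 2 * (block_bytes((tile_m, tile_k)) + block_bytes((tile_k, tile_n))) \
    + block_bytes((tile_m, tile_n))

print(untiled_double / 2**20)  # 16.0625 MiB per grid step
print(tiled_double / 2**20)    # 0.3125 MiB per grid step
```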

Additionally, in the guide's benchmark, the non-Pallas vanilla JAX version is unintentionally faster:

```
Benchmarking Matrix Size: 8192x8192
------------------------------
Standard Keras (Matmul + ReLU) Average Latency: 7.811 ms
Pallas Fused (Matmul + ReLU) Average Latency: 35.039 ms
```

(Thanks to Aditya Kane for pointing this out)

This PR updates the guide to tile the contracting dimension in the matmul, avoiding the VMEM OOM. After re-rendering on TPU, the benchmark correctly shows the fused Pallas kernel as being faster:

```
Benchmarking Matrix Size: 8192x8192
------------------------------
Standard Keras (Matmul + ReLU) Average Latency: 7.807 ms
Pallas Fused (Matmul + ReLU) Average Latency: 2.240 ms
```
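Average-latency figures like those above can be produced with a harness along these lines (a hypothetical sketch, not the guide's exact code; `average_latency_ms`, the warm-up count, and the `sorted` stand-in workload are all illustrative — with JAX you would also call `.block_until_ready()` on the result so device work is included in the timing):

```python
import time

def average_latency_ms(fn, *args, warmup=3, iters=10):
    """Return mean wall-clock latency of fn(*args) in milliseconds."""
    for _ in range(warmup):  # warm-up runs: trigger compilation/caches
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)  # with JAX, follow with .block_until_ready()
    return (time.perf_counter() - start) / iters * 1000.0

# Example with a cheap stand-in workload:
lat = average_latency_ms(lambda xs: sorted(xs), list(range(1000)))
print(f"Average Latency: {lat:.3f} ms")
```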

@gemini-code-assist (Contributor)

Summary of Changes

Hello @jeffcarp, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses critical functional and performance issues within the custom kernel guide. It rectifies a VMEM OOM error that occurred on TPUs with recent JAX versions and corrects an inaccurate benchmark that showed the Pallas kernel as slower. The changes ensure the guide operates reliably and accurately showcases the intended performance benefits of custom Pallas kernels.

Highlights

  • VMEM OOM Resolution: The pull request resolves a VMEM Out-Of-Memory (OOM) error encountered when running the custom kernel guide on TPU v5e with newer JAX versions (JAX >0.8.0). This OOM was caused by Pallas enabling double VMEM buffering by default.
  • Benchmark Correction: The PR fixes an issue where the non-Pallas vanilla JAX version was unintentionally faster than the Pallas fused kernel in the guide's benchmark. After the changes, the Pallas fused kernel now correctly demonstrates superior performance.
  • Tiling Implementation: The core solution involves updating the fused_matmul function to tile the contracting dimension (k) in the matrix multiplication, which prevents the VMEM OOM and improves performance.


Changelog
  • guides/define_custom_kernel.py
    • Modified the fused_matmul function to introduce tile_k for tiling the contracting dimension.
    • Updated the assertion to include k % tile_k == 0 for input validation.
    • Adjusted in_specs for pl.BlockSpec to utilize tile_k for the block shapes of input matrices A and B.
  • guides/ipynb/define_custom_kernel.ipynb
    • Added newlines to the source arrays of several code cells for improved formatting.
    • Applied the same tile_k modifications to the fused_matmul function as in the .py file.
    • Updated the expected output values for fused_matmul and FusedDense in the markdown cells.
    • Reflected the corrected benchmark results, showing the Pallas fused kernel as significantly faster.
    • Updated the model.fit output with new loss and history object ID values.
  • guides/md/define_custom_kernel.md
    • Removed an extraneous blank line in the import section.
    • Implemented the tile_k changes in the fused_matmul function.
    • Updated the expected output values for fused_matmul and FusedDense.
    • Reflected the corrected benchmark results, demonstrating the performance improvement of the Pallas fused kernel.
Activity
  • The author, jeffcarp, identified a VMEM OOM issue in the custom kernel guide on TPU v5e, caused by newer JAX versions enabling double VMEM buffering by default.
  • jeffcarp also noted an unintended benchmark discrepancy where the non-Pallas JAX version was faster.
  • The author implemented a fix by introducing tiling for the contracting dimension in the matrix multiplication within the fused_matmul function.
  • The changes were verified by re-rendering the guide on TPU, confirming both the resolution of the OOM and the correct, improved benchmark performance of the Pallas fused kernel.

@gemini-code-assist (bot) left a comment

Code Review

This pull request aims to fix a VMEM OOM error and a benchmark issue in the Pallas guide. The proposed change involves tiling the contracting dimension in the fused_matmul function. While this change does resolve the OOM error, it introduces a critical bug where the matrix multiplication becomes mathematically incorrect because it omits the accumulation step over the tiled dimension. This leads to wrong results and misleadingly fast benchmark times. My review includes a critical comment on this issue with a suggested fix that correctly implements the tiled matrix multiplication while preventing the OOM error. I've also noted that the example output in the markdown file needs to be updated accordingly.

Comment on lines +206 to 225

```diff
     tile_m, tile_k, tile_n = 128, 128, 128
     assert (
-        m % tile_m == 0 and n % tile_n == 0
+        m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0
     ), "Inputs must be multiples of 128 for this demo"

     return pl.pallas_call(
         matmul_relu_kernel,
         # Map output indices to input blocks
         out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
         in_specs=[
-            # For each output tile, we take a slice of A of shape (tile_m, k)
-            pl.BlockSpec(
-                index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)
-            ),  # Matrix A
-            # For each output tile, we take a slice of B of shape (k, tile_n)
-            pl.BlockSpec(
-                index_map=lambda i, j: (0, j), block_shape=(k, tile_n)
-            ),  # Matrix B
+            # For each output tile, we take a (tile_m, tile_k) slice of A
+            pl.BlockSpec(index_map=lambda i, j: (i, 0), block_shape=(tile_m, tile_k)),
+            # For each output tile, we take a (tile_k, tile_n) slice of B
+            pl.BlockSpec(index_map=lambda i, j: (0, j), block_shape=(tile_k, tile_n)),
         ],
         out_specs=pl.BlockSpec(
             index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)
         ),  # Matrix C
         grid=(m // tile_m, n // tile_n),
     )(a, b)
```
critical

This implementation is mathematically incorrect. By changing block_shape to use tile_k but not adding a loop to accumulate over the k dimension, the function now computes a partial matrix multiplication. This explains why the benchmark is faster—it's performing significantly less computation, leading to incorrect results.

To correctly implement tiling for the contracting dimension while avoiding the OOM error, the kernel needs to loop over the k dimension tiles and accumulate the results. The original matmul_relu_kernel is not changed in this PR and is not suitable for this. A new kernel with a loop is needed. The OOM from double-buffering can be prevented by disabling pipelining via compiler_params.

```python
    tile_m, tile_k, tile_n = 128, 128, 128
    assert (
        m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0
    ), "Inputs must be multiples of 128 for this demo"

    # This kernel loops over the k-dimension to perform a full matmul.
    def matmul_relu_kernel_loop(a_ref, b_ref, c_ref):
        acc = jnp.zeros((tile_m, tile_n), dtype=a.dtype)
        for ki in range(k // tile_k):
            a_tile = pl.load(a_ref, (pl.dslice(0, tile_m), pl.dslice(ki * tile_k, tile_k)))
            b_tile = pl.load(b_ref, (pl.dslice(ki * tile_k, tile_k), pl.dslice(0, tile_n)))
            acc += jnp.dot(a_tile, b_tile)
        c_ref[...] = jnp.maximum(acc, 0)

    return pl.pallas_call(
        matmul_relu_kernel_loop,
        out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
        in_specs=[
            pl.BlockSpec(index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)),
            pl.BlockSpec(index_map=lambda i, j: (0, j), block_shape=(k, tile_n)),
        ],
        out_specs=pl.BlockSpec(
            index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)
        ),
        grid=(m // tile_m, n // tile_n),
        # Disable pipelining (double buffering) to avoid OOM on TPUs
        compiler_params={"mosaic": {"num_stages": 1}},
    )(a, b)
```
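The accumulation point is easy to check outside Pallas. This NumPy sketch (hypothetical small shapes, not from the guide) confirms that summing per-tile products over the k dimension reproduces the full matmul, while a single k-tile product — effectively what the BlockSpec-only change computes — does not:

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n, tile_k = 8, 16, 8, 4
a = rng.standard_normal((m, k))
b = rng.standard_normal((k, n))

# Accumulate partial products over k tiles (what the kernel must do).
acc = np.zeros((m, n))
for ki in range(k // tile_k):
    acc += a[:, ki * tile_k:(ki + 1) * tile_k] @ b[ki * tile_k:(ki + 1) * tile_k, :]

full = a @ b
partial = a[:, :tile_k] @ b[:tile_k, :]  # only the first k tile

print(np.allclose(acc, full))      # True: tiled accumulation is exact
print(np.allclose(partial, full))  # False: a single-tile product is wrong
```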

Comment on lines +322 to 341

```diff
     "    tile_m, tile_k, tile_n = 128, 128, 128\n",
     "    assert (\n",
-    "        m % tile_m == 0 and n % tile_n == 0\n",
+    "        m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0\n",
     "    ), \"Inputs must be multiples of 128 for this demo\"\n",
     "\n",
     "    return pl.pallas_call(\n",
     "        matmul_relu_kernel,\n",
     "        # Map output indices to input blocks\n",
     "        out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),\n",
     "        in_specs=[\n",
-    "            # For each output tile, we take a slice of A of shape (tile_m, k)\n",
-    "            pl.BlockSpec(\n",
-    "                index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)\n",
-    "            ),  # Matrix A\n",
-    "            # For each output tile, we take a slice of B of shape (k, tile_n)\n",
-    "            pl.BlockSpec(\n",
-    "                index_map=lambda i, j: (0, j), block_shape=(k, tile_n)\n",
-    "            ),  # Matrix B\n",
+    "            # For each output tile, we take a (tile_m, tile_k) slice of A\n",
+    "            pl.BlockSpec(index_map=lambda i, j: (i, 0), block_shape=(tile_m, tile_k)),\n",
+    "            # For each output tile, we take a (tile_k, tile_n) slice of B\n",
+    "            pl.BlockSpec(index_map=lambda i, j: (0, j), block_shape=(tile_k, tile_n)),\n",
     "        ],\n",
     "        out_specs=pl.BlockSpec(\n",
     "            index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)\n",
     "        ),  # Matrix C\n",
     "        grid=(m // tile_m, n // tile_n),\n",
     "    )(a, b)\n",
```
critical

This implementation is mathematically incorrect. By changing block_shape to use tile_k but not adding a loop to accumulate over the k dimension, the function now computes a partial matrix multiplication. This explains why the benchmark is faster—it's performing significantly less computation, leading to incorrect results.

To correctly implement tiling for the contracting dimension while avoiding the OOM error, the kernel needs to loop over the k dimension tiles and accumulate the results. The original matmul_relu_kernel is not changed in this PR and is not suitable for this. A new kernel with a loop is needed. The OOM from double-buffering can be prevented by disabling pipelining via compiler_params.

```
    "    tile_m, tile_k, tile_n = 128, 128, 128\n",
    "    assert (\n",
    "        m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0\n",
    "    ), \"Inputs must be multiples of 128 for this demo\"\n",
    "\n",
    "    # This kernel loops over the k-dimension to perform a full matmul.\n",
    "    def matmul_relu_kernel_loop(a_ref, b_ref, c_ref):\n",
    "        acc = jnp.zeros((tile_m, tile_n), dtype=a.dtype)\n",
    "        for ki in range(k // tile_k):\n",
    "            a_tile = pl.load(a_ref, (pl.dslice(0, tile_m), pl.dslice(ki * tile_k, tile_k)))\n",
    "            b_tile = pl.load(b_ref, (pl.dslice(ki * tile_k, tile_k), pl.dslice(0, tile_n)))\n",
    "            acc += jnp.dot(a_tile, b_tile)\n",
    "        c_ref[...] = jnp.maximum(acc, 0)\n",
    "\n",
    "    return pl.pallas_call(\n",
    "        matmul_relu_kernel_loop,\n",
    "        out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),\n",
    "        in_specs=[\n",
    "            pl.BlockSpec(index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)),\n",
    "            pl.BlockSpec(index_map=lambda i, j: (0, j), block_shape=(k, tile_n)),\n",
    "        ],\n",
    "        out_specs=pl.BlockSpec(\n",
    "            index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)\n",
    "        ),\n",
    "        grid=(m // tile_m, n // tile_n),\n",
    "        # Disable pipelining (double buffering) to avoid OOM on TPUs\n",
    "        compiler_params={\"mosaic\": {\"num_stages\": 1}},\n",
    "    )(a, b)\n"
```

Comment on lines +243 to 262

```diff
     tile_m, tile_k, tile_n = 128, 128, 128
     assert (
-        m % tile_m == 0 and n % tile_n == 0
+        m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0
     ), "Inputs must be multiples of 128 for this demo"

     return pl.pallas_call(
         matmul_relu_kernel,
         # Map output indices to input blocks
         out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
         in_specs=[
-            # For each output tile, we take a slice of A of shape (tile_m, k)
-            pl.BlockSpec(
-                index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)
-            ),  # Matrix A
-            # For each output tile, we take a slice of B of shape (k, tile_n)
-            pl.BlockSpec(
-                index_map=lambda i, j: (0, j), block_shape=(k, tile_n)
-            ),  # Matrix B
+            # For each output tile, we take a (tile_m, tile_k) slice of A
+            pl.BlockSpec(index_map=lambda i, j: (i, 0), block_shape=(tile_m, tile_k)),
+            # For each output tile, we take a (tile_k, tile_n) slice of B
+            pl.BlockSpec(index_map=lambda i, j: (0, j), block_shape=(tile_k, tile_n)),
         ],
         out_specs=pl.BlockSpec(
             index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)
         ),  # Matrix C
         grid=(m // tile_m, n // tile_n),
     )(a, b)
```
critical

This implementation is mathematically incorrect. By changing block_shape to use tile_k but not adding a loop to accumulate over the k dimension, the function now computes a partial matrix multiplication. This explains why the benchmark is faster—it's performing significantly less computation, leading to incorrect results.

To correctly implement tiling for the contracting dimension while avoiding the OOM error, the kernel needs to loop over the k dimension tiles and accumulate the results. The original matmul_relu_kernel is not changed in this PR and is not suitable for this. A new kernel with a loop is needed. The OOM from double-buffering can be prevented by disabling pipelining via compiler_params.

```python
    tile_m, tile_k, tile_n = 128, 128, 128
    assert (
        m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0
    ), "Inputs must be multiples of 128 for this demo"

    # This kernel loops over the k-dimension to perform a full matmul.
    def matmul_relu_kernel_loop(a_ref, b_ref, c_ref):
        acc = jnp.zeros((tile_m, tile_n), dtype=a.dtype)
        for ki in range(k // tile_k):
            a_tile = pl.load(a_ref, (pl.dslice(0, tile_m), pl.dslice(ki * tile_k, tile_k)))
            b_tile = pl.load(b_ref, (pl.dslice(ki * tile_k, tile_k), pl.dslice(0, tile_n)))
            acc += jnp.dot(a_tile, b_tile)
        c_ref[...] = jnp.maximum(acc, 0)

    return pl.pallas_call(
        matmul_relu_kernel_loop,
        out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
        in_specs=[
            pl.BlockSpec(index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)),
            pl.BlockSpec(index_map=lambda i, j: (0, j), block_shape=(k, tile_n)),
        ],
        out_specs=pl.BlockSpec(
            index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)
        ),
        grid=(m // tile_m, n // tile_n),
        # Disable pipelining (double buffering) to avoid OOM on TPUs
        compiler_params={"mosaic": {"num_stages": 1}},
    )(a, b)
```

Comment on lines +273 to +279

```diff
 Array([[128., 128., 128., ..., 128., 128., 128.],
        [128., 128., 128., ..., 128., 128., 128.],
        [128., 128., 128., ..., 128., 128., 128.],
        ...,
-       [256., 256., 256., ..., 256., 256., 256.],
-       [256., 256., 256., ..., 256., 256., 256.],
-       [256., 256., 256., ..., 256., 256., 256.]], dtype=float32)
+       [128., 128., 128., ..., 128., 128., 128.],
+       [128., 128., 128., ..., 128., 128., 128.],
+       [128., 128., 128., ..., 128., 128., 128.]], dtype=float32)
```
medium

This output is a result of the incorrect matrix multiplication. With the corrected implementation, the output for a (256, 256) matrix of ones should have values of 256., not 128.. This example output should be updated once the calculation is fixed.
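The expected values follow directly from the shapes: with all-ones inputs, each output element is a sum of k ones. A quick NumPy check, assuming (as the comment states) a (256, 256) all-ones example:

```python
import numpy as np

a = np.ones((256, 256), dtype=np.float32)
b = np.ones((256, 256), dtype=np.float32)

out = np.maximum(a @ b, 0)  # fused matmul + ReLU
print(out[0, 0])            # 256.0: each element sums k=256 ones

# A partial product over a single 128-wide k tile gives 128 instead:
partial = a[:, :128] @ b[:128, :]
print(partial[0, 0])        # 128.0
```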

@hertschuh (Contributor) left a comment

Thanks!

@jeffcarp (Member, Author)

jeffcarp commented Feb 9, 2026

Ah re: Gemini's comments - I didn't realize that's how BlockSpec works, let me investigate and revise a bit. Thanks!

@jeffcarp (Member, Author)

Re: the benchmark, I wonder if it's not realistic to demonstrate this on v5e. I increased the VMEM size scoped for this Pallas kernel to circumvent the OOM via:

```python
os.environ['JAX_XLA_FLAGS'] = '--xla_tpu_scoped_vmem_limit_kib=20480'
# And within pl.pallas_call:
compiler_params=pl.tpu.CompilerParams(vmem_limit_bytes=20480*1000),
```

And the Pallas kernel is still slower:

```
Benchmarking Matrix Size: 8192x8192
------------------------------
Standard Keras (Matmul + ReLU) Average Latency: 8.009 ms
Pallas Fused (Matmul + ReLU) Average Latency: 26.169 ms
```

It may be that it's impossible to beat XLA at its own game here. I may pivot the tutorial to instead show how to define an MoE layer, which should have a more robust speedup between dense and sparse versions. I'll close this PR until I have that ready.

@jeffcarp closed this Feb 10, 2026
