
Add MLX backend support and optimize inference for Apple Silicon#1426

Open
hkmoon wants to merge 5 commits into MouseLand:main from hkmoon:main

Conversation

@hkmoon hkmoon commented Apr 1, 2026

This pull request adds support for running Cellpose inference with the MLX backend, enabling native Apple Silicon GPU acceleration. The changes introduce a new MLX-based implementation of the Cellpose-SAM Transformer model, update the CLI and core logic to allow users to select the MLX backend, and add device detection and model dispatching for MLX. This allows users on macOS with Apple Silicon and the MLX package installed to run Cellpose models without PyTorch/MPS.

Based on my local test, segmentation of 2048x2048 H&E image tiles is about 1.5x faster:

  MLX: Segmenting tiles: [17:07<02:51, 17.17s/tile]
  MPS: Segmenting tiles: [04:47<26:02, 26.48s/tile]

This approach may be useful for local development on an Apple Silicon notebook.

MLX backend support for Apple Silicon:

  • Added a new module cellpose/mlx_net.py implementing the Cellpose-SAM Transformer model using MLX for native Apple Silicon acceleration, including all necessary transformer, attention, and convolutional layers, as well as PyTorch-to-MLX weight conversion and support for relative positional embeddings.
  • Added detection of MLX availability in cellpose/core.py and a use_mlx() function to check for MLX support at runtime.
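As a rough sketch, runtime detection of this kind usually combines an import probe with a platform check; the actual logic in cellpose/core.py may differ:

```python
import platform

# Hedged sketch: the real use_mlx() in cellpose/core.py may differ.
try:
    import mlx.core as mx  # only importable where the mlx package is installed
    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False

def use_mlx():
    """Return True only if mlx imports and we are on an arm64 (Apple Silicon) machine."""
    return MLX_AVAILABLE and platform.machine() == "arm64"
```

On machines without the mlx package this degrades gracefully to False rather than raising.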

CLI and argument updates:

  • Added a --use_mlx command-line argument to enable or auto-detect MLX backend usage in cellpose/cli.py.

Core inference logic updates:

  • Modified model and inference logic to accept and dispatch to the MLX backend:
    • Passed use_mlx argument from CLI to model creation in cellpose/__main__.py.
    • Added _forward_mlx() for MLX-based inference, and updated run_net() and run_3D() to support a use_mlx_backend flag and select the appropriate forward function.
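The dispatch described above boils down to selecting a forward function by flag. An illustrative pattern (function names other than run_net()/_forward_mlx() are assumptions about the PR's internals):

```python
# Illustrative dispatch pattern; not the PR's exact code.
def _forward_torch(net, x):
    return net(x)  # PyTorch path

def _forward_mlx(net, x):
    # the real code would convert x to mx.array, run the MLX model,
    # and convert the result back to a numpy array
    return net(x)

def run_net(net, x, use_mlx_backend=False):
    forward_fn = _forward_mlx if use_mlx_backend else _forward_torch
    return forward_fn(net, x)
```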

These changes provide a good performance boost for Apple Silicon users and make Cellpose more accessible on modern Mac hardware.

claude and others added 5 commits March 31, 2026 18:47
Implement MLX (Apple's ML framework) as an alternative backend for
cellpose inference on Apple Silicon Macs (M1/M2/M3/M4), providing
native GPU acceleration without PyTorch MPS overhead.

New files:
- cellpose/mlx_net.py: MLX implementation of the CP-SAM Transformer
  (ViT encoder, attention with relative position embeddings, neck,
  readout head with pixel shuffle)
- cellpose/mlx_utils.py: PyTorch-to-MLX weight conversion utilities
  (key mapping, Conv2d weight transposition OIHW->OHWI)
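The Conv2d weight transposition mentioned above can be illustrated with NumPy (shapes here are illustrative; the PR's converter may handle additional cases):

```python
import numpy as np

# PyTorch Conv2d stores weights as (O, I, H, W); MLX expects (O, H, W, I).
w_torch = np.arange(8 * 4 * 3 * 3, dtype=np.float32).reshape(8, 4, 3, 3)
w_mlx = np.transpose(w_torch, (0, 2, 3, 1))  # OIHW -> OHWI
print(w_mlx.shape)  # (8, 3, 3, 4)
```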

Modified files:
- cellpose/models.py: Add use_mlx parameter to CellposeModel
- cellpose/core.py: Add _forward_mlx() and pass use_mlx_backend through
  run_net/run_3D
- cellpose/cli.py: Add --use_mlx CLI flag
- cellpose/__main__.py: Wire use_mlx to model creation
- setup.py: Add optional 'mlx' extras_require

Usage:
  model = CellposeModel(use_mlx=True)
  # or CLI: cellpose --use_mlx --dir /path/to/images

Falls back gracefully to PyTorch if MLX is not installed.

https://claude.ai/code/session_01EfKu1kx3mC9ZWXvzrmPgy5
- Fix _add_decomposed_rel_pos: add proper dimension expansion (unsqueeze)
  for rel_h and rel_w when adding to 5D attention tensor, matching SAM's
  original implementation exactly
- Pre-compute interpolated relative position embeddings at weight load
  time to avoid repeated scipy interpolation during inference (48x per
  forward pass)
- Guard torch.cuda.empty_cache() with is_available() check to prevent
  crash on non-CUDA systems (Apple Silicon)
- Add use_mlx="auto" mode for automatic MLX detection on Apple Silicon
  when CUDA is not available
- Update CLI --use_mlx to accept optional 'auto' value

https://claude.ai/code/session_01EfKu1kx3mC9ZWXvzrmPgy5
MLX's Module.update() expects Python lists for list-based submodules
(e.g. self.blocks = [Block(...), ...]). The weight conversion built
plain dicts with string keys {"0": {...}, "1": {...}} which caused
ValueError: Module does not have parameter named "0".

Add _dicts_to_lists() to recursively convert any dict whose keys are
consecutive integers starting at 0 into a list before calling update().

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Key optimizations:
- mx.compile() the forward pass to fuse operations into a single graph
- Replace broadcasted multiply+sum with mx.einsum for rel_pos bias
- Precompute rel_pos index tables to eliminate numpy ops during forward
- Rewrite LayerNorm2d for NHWC layout, eliminating 4 transposes in Neck
- Keep data in NHWC throughout (transformer blocks -> neck -> readout)
- Store head_dim as attribute to avoid recomputation

Benchmark (256x256 tile, ViT-L, Apple Silicon):
  Before: ~639 ms/tile
  After:  ~156 ms/tile (~4.1x faster)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
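The multiply+sum to einsum rewrite for the rel_pos bias can be checked in NumPy; the shapes below are illustrative, not the exact ones used in the PR:

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.random((2, 5, 7, 16))   # (B, q_h, q_w, C)
Rh = rng.random((5, 5, 16))     # (q_h, k_h, C)

# einsum contraction over the channel axis...
rel_h = np.einsum("bhwc,hkc->bhwk", q, Rh)
# ...is equivalent to the broadcasted multiply+sum it replaces
ref = (q[:, :, :, None, :] * Rh[None, :, None, :, :]).sum(-1)
assert np.allclose(rel_h, ref)
```

The einsum form avoids materializing the large broadcasted intermediate, which is where the speedup comes from.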
Copilot AI review requested due to automatic review settings April 1, 2026 17:40

Copilot AI left a comment


Pull request overview

This PR adds an optional MLX-based inference backend for Cellpose-SAM to enable native Apple Silicon acceleration, wiring it through the CLI and core inference dispatch so users can select MLX instead of PyTorch/MPS.

Changes:

  • Introduces an MLX implementation of the Cellpose-SAM Transformer plus PyTorch→MLX weight conversion utilities.
  • Adds a --use_mlx CLI option and passes it through to CellposeModel for backend selection/auto-detection.
  • Updates core inference (run_net/run_3D) to dispatch to an MLX forward path when requested.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

Summary per file:

  • setup.py: adds an optional mlx extras dependency group.
  • cellpose/models.py: adds the use_mlx parameter, MLX auto-detection, and MLX model loading/dispatch.
  • cellpose/core.py: adds MLX availability detection, an MLX forward helper, and a dispatch flag in run_net/run_3D.
  • cellpose/cli.py: adds the --use_mlx CLI argument.
  • cellpose/__main__.py: plumbs the CLI use_mlx argument into CellposeModel.
  • cellpose/mlx_net.py: new MLX-based Transformer implementation for inference.
  • cellpose/mlx_utils.py: new utilities for converting/loading/saving weights for MLX.


Comment thread cellpose/cli.py
"--gpu_device", required=False, default="0", type=str,
help="which gpu device to use, use an integer for torch, or mps for M1")
hardware_args.add_argument(
"--use_mlx", nargs="?", const=True, default=False,
Copy link

Copilot AI Apr 1, 2026


--use_mlx uses nargs='?' without type/choices, so any provided string becomes truthy (e.g. --use_mlx false sets 'false' and will be treated as enabled). Consider restricting allowed values (e.g. choices=['auto'] plus a separate --use_mlx flag) or parsing common boolean strings into a real bool to avoid surprising behavior.

Suggested change:

  - "--use_mlx", nargs="?", const=True, default=False,
  + "--use_mlx", nargs="?", const=True, default=False, choices=["auto"],

Comment thread cellpose/models.py
Comment on lines +138 to +142
    elif use_mlx:
        self.use_mlx = MLX_AVAILABLE
        if not MLX_AVAILABLE:
            models_logger.warning(
                "MLX backend requested but MLX is not available. "

Copilot AI Apr 1, 2026


use_mlx is treated as a truthy/falsy value (elif use_mlx:), so programmatic callers passing strings like 'false' / '0' will inadvertently request MLX. It would be safer to validate the input type/value (accept only True/False/'auto', or normalize string values) and raise a clear ValueError for unsupported inputs.
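A normalization helper along the lines Copilot suggests might look like this (a sketch, not the PR's code):

```python
def normalize_use_mlx(value):
    """Accept only True, False, 'auto', or common boolean strings;
    raise ValueError for anything else. Sketch of the suggested fix."""
    if isinstance(value, bool) or value == "auto":
        return value
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ("true", "1", "yes"):
            return True
        if lowered in ("false", "0", "no"):
            return False
    raise ValueError(f"use_mlx must be True, False, or 'auto', got {value!r}")
```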

Comment thread cellpose/core.py
    y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
    style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
    """

Copilot AI Apr 1, 2026


run_net(..., use_mlx_backend=True) selects _forward_mlx unconditionally, but _forward_mlx depends on mx existing. If MLX isn't installed, callers can hit a NameError instead of a clear message. Add a guard when use_mlx_backend is requested (e.g. if use_mlx_backend and not MLX_AVAILABLE: raise ImportError(...)) before setting forward_fn.

Suggested change:

        """
  +     if use_mlx_backend and not MLX_AVAILABLE:
  +         raise ImportError(
  +             "MLX backend was requested (use_mlx_backend=True), but the 'mlx' package is not "
  +             "installed. Please install 'mlx' or set use_mlx_backend=False."
  +         )

Comment thread cellpose/mlx_utils.py
Comment on lines +12 to +16
    try:
        import mlx.core as mx
        MLX_AVAILABLE = True
    except ImportError:
        MLX_AVAILABLE = False

Copilot AI Apr 1, 2026


If MLX is not installed, mx is never defined, but later functions (e.g. convert_pytorch_to_mlx_weights / save_mlx_weights) call mx.array / mx.savez unconditionally. To avoid NameError, consider defining mx = None in the except block and raising an ImportError with an actionable message at the start of MLX-dependent functions when MLX_AVAILABLE is false.

Comment thread cellpose/mlx_utils.py
Comment on lines +218 to +224
    def save_mlx_weights(state_dict_path, output_path):
        """Convert a PyTorch checkpoint to MLX safetensors format.

        Args:
            state_dict_path: Path to PyTorch model file.
            output_path: Path to save MLX weights (.safetensors or .npz).
        """

Copilot AI Apr 1, 2026


save_mlx_weights claims to write “MLX safetensors” and suggests using a .safetensors extension, but the implementation always uses np.savez/mx.savez (NPZ-style archives). This is likely to produce an incorrectly formatted file when output_path ends with .safetensors, and the docstring is misleading. Either switch to the MLX safetensors save API (if available) for .safetensors, or update the function name/docs and enforce a .npz extension.
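One way to resolve this is extension-aware dispatch. mx.save_safetensors and mx.savez are real MLX APIs, but this wrapper itself is an assumption (the backend is passed in here so the sketch runs without mlx installed):

```python
import os

def save_weights(weights, output_path, backend):
    """Save weights via the format matching the file extension.
    `backend` stands in for mlx.core (mx); sketch, not the PR's code."""
    ext = os.path.splitext(output_path)[1]
    if ext == ".safetensors":
        backend.save_safetensors(output_path, weights)
    elif ext == ".npz":
        backend.savez(output_path, **weights)
    else:
        raise ValueError(f"unsupported weight file extension: {ext!r}")
```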

Comment thread cellpose/cli.py
Comment on lines +33 to +37
    hardware_args.add_argument(
        "--use_mlx", nargs="?", const=True, default=False,
        help="use MLX backend for Apple Silicon acceleration. "
        "Pass without value to enable, or pass 'auto' to auto-detect "
        "(requires macOS + Apple Silicon + mlx package)")

Copilot AI Apr 1, 2026


New MLX backend selection introduces several new behaviors (CLI parsing of --use_mlx, auto-detection logic, and the fallback path when MLX is unavailable) but there are no corresponding tests. Since the repo already has CLI tests (e.g. tests/test_output.py), please add coverage that asserts: (1) --use_mlx auto does not crash on non-MLX systems, (2) invalid values are rejected/normalized, and (3) --use_mlx enables the backend only when available.
