Add MLX backend support and optimize inference for Apple Silicon #1426
hkmoon wants to merge 5 commits into
Conversation
Implement MLX (Apple's ML framework) as an alternative backend for cellpose inference on Apple Silicon Macs (M1/M2/M3/M4), providing native GPU acceleration without PyTorch MPS overhead.

New files:
- cellpose/mlx_net.py: MLX implementation of the CP-SAM Transformer (ViT encoder, attention with relative position embeddings, neck, readout head with pixel shuffle)
- cellpose/mlx_utils.py: PyTorch-to-MLX weight conversion utilities (key mapping, Conv2d weight transposition OIHW->OHWI)

Modified files:
- cellpose/models.py: Add use_mlx parameter to CellposeModel
- cellpose/core.py: Add _forward_mlx() and pass use_mlx_backend through run_net/run_3D
- cellpose/cli.py: Add --use_mlx CLI flag
- cellpose/__main__.py: Wire use_mlx to model creation
- setup.py: Add optional 'mlx' extras_require

Usage:
    model = CellposeModel(use_mlx=True)
    # or CLI: cellpose --use_mlx --dir /path/to/images

Falls back gracefully to PyTorch if MLX is not installed.

https://claude.ai/code/session_01EfKu1kx3mC9ZWXvzrmPgy5
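For context, the OIHW->OHWI transposition mentioned for cellpose/mlx_utils.py boils down to a single permute before handing the tensor to MLX. A minimal sketch, assuming a CPU checkpoint tensor (the helper name is hypothetical, not the PR's actual function):

```python
import mlx.core as mx
import torch

def conv2d_weight_torch_to_mlx(w: torch.Tensor) -> mx.array:
    # PyTorch Conv2d weights are (out_ch, in_ch, kH, kW) = OIHW;
    # MLX's nn.Conv2d expects (out_ch, kH, kW, in_ch) = OHWI.
    return mx.array(w.detach().cpu().permute(0, 2, 3, 1).contiguous().numpy())
```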
- Fix _add_decomposed_rel_pos: add proper dimension expansion (unsqueeze) for rel_h and rel_w when adding to 5D attention tensor, matching SAM's original implementation exactly
- Pre-compute interpolated relative position embeddings at weight load time to avoid repeated scipy interpolation during inference (48x per forward pass)
- Guard torch.cuda.empty_cache() with is_available() check to prevent crash on non-CUDA systems (Apple Silicon)
- Add use_mlx="auto" mode for automatic MLX detection on Apple Silicon when CUDA is not available
- Update CLI --use_mlx to accept optional 'auto' value

https://claude.ai/code/session_01EfKu1kx3mC9ZWXvzrmPgy5
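The unsqueeze fix mirrors SAM's reference add_decomposed_rel_pos, which expands rel_h and rel_w so they broadcast against the attention tensor reshaped to (B, q_h, q_w, k_h, k_w). A hedged MLX sketch of that step (argument names and shapes assumed from SAM's implementation, not copied from this PR):

```python
import mlx.core as mx

def add_rel_pos_bias(attn, rel_h, rel_w, B, q_h, q_w, k_h, k_w):
    # attn:  (B, q_h*q_w, k_h*k_w)
    # rel_h: (B, q_h, q_w, k_h);  rel_w: (B, q_h, q_w, k_w)
    attn = attn.reshape(B, q_h, q_w, k_h, k_w)
    attn = attn + mx.expand_dims(rel_h, axis=4)  # broadcast over k_w
    attn = attn + mx.expand_dims(rel_w, axis=3)  # broadcast over k_h
    return attn.reshape(B, q_h * q_w, k_h * k_w)
```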
Claude/cellpose mlx support
MLX's Module.update() expects Python lists for list-based submodules
(e.g. self.blocks = [Block(...), ...]). The weight conversion built
plain dicts with string keys {"0": {...}, "1": {...}} which caused
ValueError: Module does not have parameter named "0".
Add _dicts_to_lists() to recursively convert any dict whose keys are
consecutive integers starting at 0 into a list before calling update().
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
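A minimal sketch of what such a helper could look like (the PR's actual _dicts_to_lists may differ in details):

```python
def _dicts_to_lists(tree):
    """Recursively convert dicts whose keys are the consecutive strings
    "0", "1", ... into lists, so mlx's Module.update() accepts them."""
    if isinstance(tree, dict):
        tree = {k: _dicts_to_lists(v) for k, v in tree.items()}
        if tree and all(isinstance(k, str) and k.isdigit() for k in tree):
            idx = sorted(int(k) for k in tree)
            if idx == list(range(len(idx))):
                return [tree[str(i)] for i in idx]
    return tree
```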
Key optimizations:
- mx.compile() the forward pass to fuse operations into a single graph
- Replace broadcasted multiply+sum with mx.einsum for rel_pos bias
- Precompute rel_pos index tables to eliminate numpy ops during forward
- Rewrite LayerNorm2d for NHWC layout, eliminating 4 transposes in Neck
- Keep data in NHWC throughout (transformer blocks -> neck -> readout)
- Store head_dim as attribute to avoid recomputation

Benchmark (256x256 tile, ViT-L, Apple Silicon):
Before: ~639 ms/tile
After: ~156 ms/tile (~4.1x faster)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
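The NHWC LayerNorm2d rewrite is the most self-contained of these: with channels last, the channel-wise normalization reduces to a plain reduction over the final axis. A sketch under that assumption (class name and eps default are illustrative, not the PR's exact code):

```python
import mlx.core as mx
import mlx.nn as nn

class LayerNorm2dNHWC(nn.Module):
    """Channel-wise LayerNorm for (B, H, W, C) tensors; no transposes
    are needed because the channel axis is already last."""
    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((num_channels,))
        self.bias = mx.zeros((num_channels,))
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        mu = x.mean(axis=-1, keepdims=True)
        var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
        return (x - mu) * mx.rsqrt(var + self.eps) * self.weight + self.bias
```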
Pull request overview
This PR adds an optional MLX-based inference backend for Cellpose-SAM to enable native Apple Silicon acceleration, wiring it through the CLI and core inference dispatch so users can select MLX instead of PyTorch/MPS.
Changes:
- Introduces an MLX implementation of the Cellpose-SAM Transformer plus PyTorch→MLX weight conversion utilities.
- Adds a --use_mlx CLI option and passes it through to CellposeModel for backend selection/auto-detection.
- Updates core inference (run_net/run_3D) to dispatch to an MLX forward path when requested.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| setup.py | Adds an optional mlx extra dependency group. |
| cellpose/models.py | Adds use_mlx parameter, MLX auto-detection, and MLX model loading/dispatch. |
| cellpose/core.py | Adds MLX availability detection, MLX forward helper, and a dispatch flag in run_net/run_3D. |
| cellpose/cli.py | Adds --use_mlx CLI argument. |
| cellpose/__main__.py | Plumbs CLI use_mlx argument into CellposeModel. |
| cellpose/mlx_net.py | New MLX-based Transformer implementation for inference. |
| cellpose/mlx_utils.py | New utilities for converting/loading/saving weights for MLX. |
| "--gpu_device", required=False, default="0", type=str, | ||
| help="which gpu device to use, use an integer for torch, or mps for M1") | ||
| hardware_args.add_argument( | ||
| "--use_mlx", nargs="?", const=True, default=False, |
--use_mlx uses nargs='?' without type/choices, so any provided string becomes truthy (e.g. --use_mlx false sets 'false' and will be treated as enabled). Consider restricting allowed values (e.g. choices=['auto'] plus a separate --use_mlx flag) or parsing common boolean strings into a real bool to avoid surprising behavior.
| "--use_mlx", nargs="?", const=True, default=False, | |
| "--use_mlx", nargs="?", const=True, default=False, choices=["auto"], |
From cellpose/models.py:

```python
elif use_mlx:
    self.use_mlx = MLX_AVAILABLE
    if not MLX_AVAILABLE:
        models_logger.warning(
            "MLX backend requested but MLX is not available. "
```
use_mlx is treated as a truthy/falsy value (elif use_mlx:), so programmatic callers passing strings like 'false' / '0' will inadvertently request MLX. It would be safer to validate the input type/value (accept only True/False/'auto', or normalize string values) and raise a clear ValueError for unsupported inputs.
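A sketch of the normalization this comment asks for (the helper name and integration point are illustrative):

```python
def _normalize_use_mlx(use_mlx):
    """Accept only True, False, or 'auto'; raise on anything else so
    strings like 'false' or '0' cannot silently request MLX."""
    if use_mlx is True or use_mlx is False or use_mlx == "auto":
        return use_mlx
    raise ValueError(
        f"use_mlx must be True, False, or 'auto', got {use_mlx!r}")
```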
From cellpose/core.py:

```python
        y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
        style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
    """
```
run_net(..., use_mlx_backend=True) selects _forward_mlx unconditionally, but _forward_mlx depends on mx existing. If MLX isn't installed, callers can hit a NameError instead of a clear message. Add a guard when use_mlx_backend is requested (e.g. if use_mlx_backend and not MLX_AVAILABLE: raise ImportError(...)) before setting forward_fn.
| """ | |
| """ | |
| if use_mlx_backend and not MLX_AVAILABLE: | |
| raise ImportError( | |
| "MLX backend was requested (use_mlx_backend=True), but the 'mlx' package is not " | |
| "installed. Please install 'mlx' or set use_mlx_backend=False." | |
| ) |
From cellpose/mlx_utils.py:

```python
try:
    import mlx.core as mx
    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False
```
If MLX is not installed, mx is never defined, but later functions (e.g. convert_pytorch_to_mlx_weights / save_mlx_weights) call mx.array / mx.savez unconditionally. To avoid NameError, consider defining mx = None in the except block and raising an ImportError with an actionable message at the start of MLX-dependent functions when MLX_AVAILABLE is false.
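A sketch of the guarded-import pattern this comment describes (the _require_mlx helper name is illustrative):

```python
try:
    import mlx.core as mx
    MLX_AVAILABLE = True
except ImportError:
    mx = None  # keep the name defined so module-level references don't NameError
    MLX_AVAILABLE = False

def _require_mlx():
    # Call at the top of MLX-dependent functions (e.g.
    # convert_pytorch_to_mlx_weights, save_mlx_weights).
    if not MLX_AVAILABLE:
        raise ImportError(
            "This function requires the 'mlx' package; install it with "
            "`pip install mlx` (Apple Silicon only).")
```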
From cellpose/mlx_utils.py:

```python
def save_mlx_weights(state_dict_path, output_path):
    """Convert a PyTorch checkpoint to MLX safetensors format.

    Args:
        state_dict_path: Path to PyTorch model file.
        output_path: Path to save MLX weights (.safetensors or .npz).
    """
```
save_mlx_weights claims to write “MLX safetensors” and suggests using a .safetensors extension, but the implementation always uses np.savez/mx.savez (NPZ-style archives). This is likely to produce an incorrectly formatted file when output_path ends with .safetensors, and the docstring is misleading. Either switch to the MLX safetensors save API (if available) for .safetensors, or update the function name/docs and enforce a .npz extension.
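One possible fix is to dispatch on the extension. A sketch, assuming mlx.core exposes save_safetensors and savez and that the weights form a flat dict of mx.array:

```python
import mlx.core as mx

def save_weights(weights: dict, output_path: str):
    if output_path.endswith(".safetensors"):
        mx.save_safetensors(output_path, weights)  # real safetensors format
    elif output_path.endswith(".npz"):
        mx.savez(output_path, **weights)
    else:
        raise ValueError("output_path must end with .safetensors or .npz")
```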
From cellpose/cli.py:

```python
hardware_args.add_argument(
    "--use_mlx", nargs="?", const=True, default=False,
    help="use MLX backend for Apple Silicon acceleration. "
         "Pass without value to enable, or pass 'auto' to auto-detect "
         "(requires macOS + Apple Silicon + mlx package)")
```
New MLX backend selection introduces several new behaviors (CLI parsing of --use_mlx, auto-detection logic, and the fallback path when MLX is unavailable) but there are no corresponding tests. Since the repo already has CLI tests (e.g. tests/test_output.py), please add coverage that asserts: (1) --use_mlx auto does not crash on non-MLX systems, (2) invalid values are rejected/normalized, and (3) --use_mlx enables the backend only when available.
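A sketch of that coverage (the get_arg_parser entry point is assumed from cellpose/cli.py; the rejection test presumes the stricter value parsing suggested above is adopted):

```python
import pytest
from cellpose.cli import get_arg_parser  # assumed parser factory

def test_use_mlx_default_off():
    assert get_arg_parser().parse_args([]).use_mlx is False

def test_use_mlx_bare_flag_enables():
    assert get_arg_parser().parse_args(["--use_mlx"]).use_mlx is True

def test_use_mlx_auto():
    assert get_arg_parser().parse_args(["--use_mlx", "auto"]).use_mlx == "auto"

def test_use_mlx_rejects_invalid_value():
    with pytest.raises(SystemExit):  # argparse errors exit via parser.error()
        get_arg_parser().parse_args(["--use_mlx", "banana"])
```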
This pull request adds support for running Cellpose inference with the MLX backend, enabling native Apple Silicon GPU acceleration. The changes introduce a new MLX-based implementation of the Cellpose-SAM Transformer model, update the CLI and core logic to allow users to select the MLX backend, and add device detection and model dispatching for MLX. This allows users on macOS with Apple Silicon and the MLX package installed to run Cellpose models without PyTorch/MPS.
Based on my local test, segmenting 2048x2048 H&E image tiles is about 1.5x faster with MLX than with MPS:

MLX: Segmenting tiles: [17:07<02:51, 17.17s/tile]
MPS: Segmenting tiles: [04:47<26:02, 26.48s/tile]

For local development on an Apple Silicon notebook, this approach might be useful.
MLX backend support for Apple Silicon:
- New cellpose/mlx_net.py implementing the Cellpose-SAM Transformer model using MLX for native Apple Silicon acceleration, including all necessary transformer, attention, and convolutional layers, as well as PyTorch-to-MLX weight conversion and support for relative positional embeddings.
- MLX availability detection in cellpose/core.py and a use_mlx() function to check for MLX support at runtime. [1] [2]

CLI and argument updates:
- New --use_mlx command-line argument to enable or auto-detect MLX backend usage in cellpose/cli.py.

Core inference logic updates:
- The use_mlx argument is passed from the CLI to model creation in cellpose/__main__.py.
- New _forward_mlx() for MLX-based inference, with run_net() and run_3D() updated to support a use_mlx_backend flag and select the appropriate forward function. [1] [2] [3] [4] [5] [6]

These changes provide a good performance boost for Apple Silicon users and make Cellpose more accessible on modern Mac hardware.