diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9677615206..46800a1e13 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -82,6 +82,7 @@ jobs:
         install-numba: [0]
         install-jax: [0]
         install-torch: [0]
+        install-mlx: [0]
         part:
           - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
           - "tests/scan"
@@ -115,6 +116,7 @@ jobs:
             install-numba: 0
             install-jax: 0
             install-torch: 0
+            install-mlx: 0
           - install-numba: 1
             os: "ubuntu-latest"
             python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
             fast-compile: 0
             float32: 0
             part: "tests/link/pytorch"
+          - install-mlx: 1
+            os: "ubuntu-latest"
+            python-version: "3.10"
+            numpy-version: ">=2.0"
+            fast-compile: 0
+            float32: 0
+            part: "tests/link/mlx"
           - os: macos-15
             python-version: "3.13"
             numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
           if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
           if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
           if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
+          if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi
           pip install pytest-sphinx
           pip install -e ./
@@ -212,6 +222,7 @@ jobs:
           INSTALL_NUMBA: ${{ matrix.install-numba }}
           INSTALL_JAX: ${{ matrix.install-jax }}
           INSTALL_TORCH: ${{ matrix.install-torch}}
+          INSTALL_MLX: ${{ matrix.install-mlx }}
           OS: ${{ matrix.os}}

       - name: Run tests
diff --git a/.gitignore b/.gitignore
index dfe862b868..ebe8e61bd0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,7 +27,6 @@ __pycache__
 \#*\#
 build
 compiled/*.cpp
-core.*
 cutils_ext.cpp
 dist
 doc/.build/
diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index f80dfaaf5c..8dc7c742bc 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -27,6 +27,7 @@
 from pytensor.link.basic import Linker, PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.link.jax.linker import JAXLinker
+from pytensor.link.mlx.linker import MLXLinker
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.link.vm import VMLinker
@@ -50,6 +51,7 @@
     "jax": JAXLinker(),
     "pytorch": PytorchLinker(),
     "numba": NumbaLinker(),
+    "mlx": MLXLinker(),
 }
@@ -494,6 +496,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
         ),
     )

+MLX = Mode(
+    MLXLinker(),
+    RewriteDatabaseQuery(
+        include=["fast_run"],
+        exclude=[
+            "cxx_only",
+            "BlasOpt",
+            "fusion",
+            "inplace",
+            "scan_save_mem_prealloc",
+        ],
+    ),
+)
+
@@ -501,6 +517,7 @@
 predefined_modes = {
     "FAST_COMPILE": FAST_COMPILE,
     "FAST_RUN": FAST_RUN,
     "JAX": JAX,
     "NUMBA": NUMBA,
     "PYTORCH": PYTORCH,
+    "MLX": MLX,
 }

 _CACHED_RUNTIME_MODES: dict[str, Mode] = {}
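With the MLX mode registered in `predefined_modes` above, the backend can be selected by name. A minimal usage sketch (assuming a machine with `mlx` installed; the graph and shapes are illustrative):

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.matrix("y")

# "MLX" resolves to the Mode defined above (MLXLinker + fast_run rewrites)
fn = pytensor.function([x, y], x.dot(y), mode="MLX")

res = fn(np.ones((2, 3), dtype="float32"), np.ones((3, 4), dtype="float32"))
print(type(res))  # <class 'mlx.core.array'>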
diff --git a/pytensor/link/mlx/__init__.py b/pytensor/link/mlx/__init__.py
new file mode 100644
index 0000000000..d5a6ab19ff
--- /dev/null
+++ b/pytensor/link/mlx/__init__.py
@@ -0,0 +1 @@
+from pytensor.link.mlx.linker import MLXLinker
diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py
new file mode 100644
index 0000000000..f039263a37
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/__init__.py
@@ -0,0 +1,13 @@
+# isort: off
+from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify
+
+import pytensor.link.mlx.dispatch.math
+import pytensor.link.mlx.dispatch.basic
+import pytensor.link.mlx.dispatch.elemwise
+import pytensor.link.mlx.dispatch.shape
+import pytensor.link.mlx.dispatch.subtensor
+import pytensor.link.mlx.dispatch.core
+import pytensor.link.mlx.dispatch.signal
+import pytensor.link.mlx.dispatch.signal.conv
+import pytensor.link.mlx.dispatch.blockwise
+# isort: on
diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py
new file mode 100644
index 0000000000..d0b3d451f5
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/basic.py
@@ -0,0 +1,78 @@
+import warnings
+from copy import deepcopy
+from functools import singledispatch
+from types import NoneType
+
+import mlx.core as mx
+import numpy as np
+
+from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.fg import FunctionGraph
+from pytensor.link.utils import fgraph_to_python
+from pytensor.raise_op import Assert, CheckAndRaise
+
+
+@singledispatch
+def mlx_typify(data, **kwargs):
+    raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")
+
+
+@mlx_typify.register(np.ndarray)
+@mlx_typify.register(mx.array)
+def mlx_typify_tensor(data, dtype=None, **kwargs):
+    return mx.array(data, dtype=dtype)
+
+
+@mlx_typify.register(slice)
+@mlx_typify.register(NoneType)
+@mlx_typify.register(np.number)
+def mlx_typify_no_conversion_needed(data, **kwargs):
+    return data
+
+
+@singledispatch
+def mlx_funcify(op, node=None, storage_map=None, **kwargs):
+    """Create an MLX-compatible function from a PyTensor `Op`."""
+    raise NotImplementedError(
+        f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress, or to request that this operation be prioritized."
+    )
+
+
+@mlx_funcify.register(FunctionGraph)
+def mlx_funcify_FunctionGraph(
+    fgraph,
+    node=None,
+    fgraph_name="mlx_funcified_fgraph",
+    conversion_func=mlx_funcify,
+    **kwargs,
+):
+    built_kwargs = {"conversion_func": conversion_func, **kwargs}
+    return fgraph_to_python(
+        fgraph,
+        conversion_func,
+        type_conversion_fn=mlx_typify,
+        fgraph_name=fgraph_name,
+        **built_kwargs,
+    )
+
+
+@mlx_funcify.register(DeepCopyOp)
+def mlx_funcify_DeepCopyOp(op, **kwargs):
+    def deepcopyop(x):
+        return deepcopy(x)
+
+    return deepcopyop
+
+
+@mlx_funcify.register(Assert)
+@mlx_funcify.register(CheckAndRaise)
+def mlx_funcify_CheckAndRaise(op, **kwargs):
+    warnings.warn(
+        f"Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.",
+        stacklevel=2,
+    )
+
+    def assert_fn(x, *inputs):
+        return x
+
+    return assert_fn
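Because `mlx_funcify` is a `singledispatch` function, coverage can grow Op by Op. A hypothetical sketch of registering a dispatch for an Op this PR does not yet cover (`SortOp` is chosen purely for illustration, not part of this change):

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.sort import SortOp


@mlx_funcify.register(SortOp)
def mlx_funcify_SortOp(op, node=None, **kwargs):
    def sort(x, axis):
        # MLX needs a concrete axis; mx.sort otherwise mirrors np.sort
        return mx.sort(x, axis=int(axis))

    return sort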
diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py
new file mode 100644
index 0000000000..74bb018a68
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/blockwise.py
@@ -0,0 +1,99 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.signal.conv import Conv1d
+
+
+def blockwise_conv1d(op, node, **kwargs):
+    """Custom implementation of Blockwise(Conv1d) for MLX."""
+
+    def batched_conv1d(
+        x: mx.array,
+        kernels: mx.array,
+        mode: str = op.core_op.mode,
+        stride: int = 1,
+        dilation: int = 1,
+    ) -> mx.array:
+        """Apply B separate 1D convolutions (full or valid) to B sequences in parallel.
+
+        Parameters
+        ----------
+        x : array of shape (B, T)
+            B sequences of length T.
+        kernels : array of shape (B, K)
+            B kernels of length K.
+        mode : {"valid", "full"}
+            "valid" → no padding, output length = T - K + 1
+            "full"  → zero-pad so output length = T + K - 1
+        stride : int, convolution stride (default=1)
+        dilation : int, convolution dilation (default=1)
+
+        Returns
+        -------
+        out : array of shape (B, L)
+            where L = T - K + 1 if mode="valid", or T + K - 1 if mode="full".
+        """
+        # --- 1) shape checks ---
+        B, T = x.shape
+        Bk, K = kernels.shape
+        if B != Bk:
+            raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")
+
+        # --- 2) flip kernels for convolution ---
+        kernels_flipped = kernels[:, ::-1]  # shape (B, K)
+
+        # --- 3) decide padding ---
+        if mode == "valid":
+            pad = 0
+        elif mode == "full":
+            pad = (K - 1) * dilation
+        else:
+            raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'")
+
+        # --- 4) reshape into MLX conv1d form ---
+        # input: (N=1, H=T, C_in=B)
+        x_in = x.T[None, :, :]
+
+        # weight: (C_out=B, H_f=K, C_in=1)
+        w = kernels_flipped[:, :, None]
+
+        # --- 5) run grouped conv1d ---
+        y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B)
+        # y shape: (1, H_out, B)
+
+        # --- 6) return shape (B, H_out) ---
+        return y[0].T
+
+    return batched_conv1d
+
+
+@mlx_funcify.register(Blockwise)
+def funcify_Blockwise(op: Blockwise, node, **kwargs):
+    # 1) If it's a Conv1d Blockwise, use the custom implementation
+    if isinstance(op.core_op, Conv1d):
+        return blockwise_conv1d(op, node, **kwargs)
+
+    # 2) Otherwise, get the core python function for this Blockwise
+    core_node = op._create_dummy_core_node(node.inputs)
+    core_f = mlx_funcify(op.core_op, core_node)
+
+    # 3) Build in_axes: map axis 0 for every input that carries a batch
+    #    dimension (ndim above the core ndim); keep purely-core inputs static
+    in_axes = tuple(
+        0 if inp.type.ndim > core_inp.type.ndim else None
+        for inp, core_inp in zip(node.inputs, core_node.inputs, strict=True)
+    )
+
+    # 4) Vectorize (vmap) with in_axes; note that a single batch dimension is
+    #    handled here, which is the common case after the blockwise rewrites
+    blockwise_f = mx.vmap(core_f, in_axes=in_axes)
+
+    # 5) Return the mapped function
+    def blockwise_fun(*inputs):
+        return blockwise_f(*inputs)
+
+    return blockwise_fun
+""" + +from __future__ import annotations + +import warnings + +import mlx.core as mx # MLX +import numpy as np + +from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX +from pytensor.tensor import get_vector_length +from pytensor.tensor.basic import ( + Alloc, + AllocEmpty, + ExtractDiag, + Eye, + Join, + MakeVector, + ScalarFromTensor, + Split, + TensorFromScalar, + Tri, + get_scalar_constant_value, +) +from pytensor.tensor.exceptions import NotScalarConstantError + + +# ------------------------------------------------------------------ +# Join +# ------------------------------------------------------------------ +@mlx_funcify.register(Join) # MLX +def mlx_funcify_Join(op, **kwargs): + def join(axis, *tensors): + view = op.view + if (view != -1) and all( + tensors[i].shape[axis] == 0 # MLX + for i in list(range(view)) + list(range(view + 1, len(tensors))) + ): + return tensors[view] + + return mx.concatenate(tensors, axis=axis) # MLX + + return join + + +# ------------------------------------------------------------------ +# Split +# ------------------------------------------------------------------ +@mlx_funcify.register(Split) # MLX +def mlx_funcify_Split(op: Split, node, **kwargs): + _, axis_sym, splits_sym = node.inputs + + try: + constant_axis = get_scalar_constant_value(axis_sym) + except NotScalarConstantError: + constant_axis = None + warnings.warn( + "Split node does not have a constant axis. MLX implementation may fail." + ) + + try: + constant_splits = np.array( + [ + get_scalar_constant_value(splits_sym[i]) + for i in range(get_vector_length(splits_sym)) + ] + ) + except (ValueError, NotScalarConstantError): + constant_splits = None + warnings.warn( + "Split node does not have constant split positions. MLX implementation may fail." + ) + + def split(x, axis, splits): + # Resolve constants (avoids tracing extra ops) + if constant_axis is not None: + axis = int(constant_axis) + + if constant_splits is not None: + splits = constant_splits + cumsum_splits = np.cumsum(splits[:-1]) + else: + # dynamic - keep in graph + splits_arr = mx.array(splits) # MLX + cumsum_splits = mx.cumsum( + splits_arr[:-1] + ).tolist() # python list for mx.split + + if len(splits) != op.len_splits: + raise ValueError("Length of 'splits' is not equal to n_splits") + if np.sum(np.asarray(splits)) != x.shape[axis]: + raise ValueError( + "Split sizes do not sum to the input length on the chosen axis." + ) + if np.any(np.asarray(splits) < 0): + raise ValueError("Split sizes cannot be negative.") + + return mx.split(x, cumsum_splits, axis=axis) # MLX + + return split + + +# ------------------------------------------------------------------ +# ExtractDiag +# ------------------------------------------------------------------ +@mlx_funcify.register(ExtractDiag) # MLX +def mlx_funcify_ExtractDiag(op, **kwargs): + offset, axis1, axis2 = op.offset, op.axis1, op.axis2 + + def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): + return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX + + return extract_diag + + +# ------------------------------------------------------------------ +# Eye +# ------------------------------------------------------------------ +@mlx_funcify.register(Eye) # MLX +def mlx_funcify_Eye(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + + def eye(N, M, k): + return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX + + return eye + + +def convert_dtype_to_mlx(dtype_str): + """Convert PyTensor dtype strings to MLX dtype objects. 
+ + MLX expects dtype objects rather than string literals for type conversion. + This function maps common dtype strings to their MLX equivalents. + """ + if isinstance(dtype_str, str): + if dtype_str == "bool": + return mx.bool_ + elif dtype_str == "int8": + return mx.int8 + elif dtype_str == "int16": + return mx.int16 + elif dtype_str == "int32": + return mx.int32 + elif dtype_str == "int64": + return mx.int64 + elif dtype_str == "uint8": + return mx.uint8 + elif dtype_str == "uint16": + return mx.uint16 + elif dtype_str == "uint32": + return mx.uint32 + elif dtype_str == "uint64": + return mx.uint64 + elif dtype_str == "float16": + return mx.float16 + elif dtype_str == "float32": + return mx.float32 + elif dtype_str == "float64": + return mx.float64 + elif dtype_str == "bfloat16": + return mx.bfloat16 + elif dtype_str == "complex64": + return mx.complex64 + elif dtype_str == "complex128": + return mx.complex128 + # Return as is if it's already an MLX dtype or not a recognized string + return dtype_str + + +# ------------------------------------------------------------------ +# MakeVector +# ------------------------------------------------------------------ +@mlx_funcify.register(MakeVector) # MLX +def mlx_funcify_MakeVector(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + + def makevector(*x): + return mx.array(x, dtype=dtype) # MLX + + return makevector + + +# ------------------------------------------------------------------ +# TensorFromScalar (identity for MLX) +# ------------------------------------------------------------------ +@mlx_funcify.register(TensorFromScalar) # MLX +def mlx_funcify_TensorFromScalar(op, **kwargs): + def tensor_from_scalar(x): + return x # already an MLX array / scalar + + return tensor_from_scalar + + +# ------------------------------------------------------------------ +# ScalarFromTensor +# ------------------------------------------------------------------ +@mlx_funcify.register(ScalarFromTensor) # MLX +def mlx_funcify_ScalarFromTensor(op, **kwargs): + def scalar_from_tensor(x): + return mx.array(x).reshape(-1)[0] # MLX + + return scalar_from_tensor + + +# ------------------------------------------------------------------ +# Tri +# ------------------------------------------------------------------ +@mlx_funcify.register(Tri) # MLX +def mlx_funcify_Tri(op, node, **kwargs): + # node.inputs -> N, M, k + const_args = [getattr(inp, "data", None) for inp in node.inputs] + dtype = convert_dtype_to_mlx(op.dtype) + + def tri(*args): + # Replace args with compile-time constants when available + args = [ + arg if const_a is None else const_a + for arg, const_a in zip(args, const_args, strict=True) + ] + return mx.tri(*args, dtype=dtype) # MLX + + return tri + + +@mlx_funcify.register(AllocEmpty) +def mlx_funcify_AllocEmpty(op, **kwargs): + dtype = convert_dtype_to_mlx(op.dtype) + + def allocempty(*shape): + return mx.zeros(shape, dtype=dtype) + + return allocempty + + +@mlx_funcify.register(Alloc) +def mlx_funcify_Alloc(op, node, **kwargs): + def alloc(x, *shape): + # Convert x to an MLX array with the correct dtype if it's a scalar + x_array = mx.array(x) + res = mx.broadcast_to(x_array, shape) + Alloc._check_runtime_broadcast(node, x_array, res.shape) + return res + + return alloc diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py new file mode 100644 index 0000000000..aaf04968de --- /dev/null +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -0,0 +1,112 @@ +import mlx.core as mx +import numpy as np + +from 
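A quick sanity check of the `convert_dtype_to_mlx` helper used throughout this file (a sketch, runnable wherever `mlx` is importable):

import mlx.core as mx

from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx

assert convert_dtype_to_mlx("float32") is mx.float32
assert convert_dtype_to_mlx("bool") is mx.bool_
# dtypes that are already MLX objects pass through unchanged
assert convert_dtype_to_mlx(mx.int64) is mx.int64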
diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py
new file mode 100644
index 0000000000..aaf04968de
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/elemwise.py
@@ -0,0 +1,112 @@
+import mlx.core as mx
+import numpy as np
+
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
+from pytensor.scalar import Softplus
+from pytensor.scalar.basic import (
+    AND,
+    OR,
+    Add,
+    Cast,
+    Mul,
+)
+from pytensor.tensor.elemwise import CAReduce, DimShuffle
+from pytensor.tensor.special import Softmax, SoftmaxGrad
+
+
+@mlx_funcify.register(DimShuffle)
+def mlx_funcify_DimShuffle(op, **kwargs):
+    def dimshuffle(x):
+        # Convert scalar to array if needed
+        if isinstance(x, int | float) or (
+            isinstance(x, np.number) and not isinstance(x, np.ndarray)
+        ):
+            x = mx.array(x)
+        res = mx.transpose(x, op.transposition)
+        shape = list(res.shape[: len(op.shuffle)])
+        for augm in op.augment:
+            shape.insert(augm, 1)
+        return mx.reshape(res, shape)
+
+    return dimshuffle
+
+
+@mlx_funcify.register(CAReduce)
+def mlx_funcify_CAReduce(op, **kwargs):
+    if isinstance(op.scalar_op, Add):
+
+        def sum(x):
+            return mx.sum(x, axis=op.axis)
+
+        return sum
+    elif isinstance(op.scalar_op, Mul):
+
+        def prod(x):
+            return mx.prod(x, axis=op.axis)
+
+        return prod
+    elif isinstance(op.scalar_op, AND):
+
+        def all(x):
+            return x.all(axis=op.axis)
+
+        return all
+    elif isinstance(op.scalar_op, OR):
+
+        def any(x):
+            return mx.any(x, axis=op.axis)
+
+        return any
+    else:
+        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
+
+
+@mlx_funcify.register(Softmax)
+def mlx_funcify_Softmax(op, **kwargs):
+    axis = op.axis
+
+    def softmax(x):
+        return mx.softmax(x, axis=axis)
+
+    return softmax
+
+
+@mlx_funcify.register(SoftmaxGrad)
+def mlx_funcify_SoftmaxGrad(op, **kwargs):
+    axis = op.axis
+
+    def softmax_grad(dy, sm):
+        dy_times_sm = dy * sm
+        return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm
+
+    return softmax_grad
+
+
+@mlx_funcify.register(Softplus)
+def mlx_funcify_Softplus(op, **kwargs):
+    def softplus(x):
+        return mx.where(
+            x < -37.0,
+            mx.exp(x),
+            mx.where(
+                x < 18.0,
+                mx.log1p(mx.exp(x)),
+                mx.where(
+                    x < 33.3,
+                    x + mx.exp(-x),
+                    x,
+                ),
+            ),
+        )
+
+    return softplus
+
+
+@mlx_funcify.register(Cast)
+def mlx_funcify_Cast(op, **kwargs):
+    def cast(x):
+        dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype)
+        return x.astype(dtype)
+
+    return cast
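To make the `DimShuffle` translation above concrete: a `('x', 1, 0)` shuffle of a matrix transposes it and then prepends a broadcastable axis, which is exactly the transpose-then-reshape sequence the dispatcher emits. A small sketch run through the new backend (shapes are illustrative):

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = x.dimshuffle(("x", 1, 0))  # transpose, then insert a length-1 axis in front

fn = pytensor.function([x], out, mode="MLX")
assert fn(np.ones((3, 4), dtype="float32")).shape == (1, 4, 3)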
diff --git a/pytensor/link/mlx/dispatch/math.py b/pytensor/link/mlx/dispatch/math.py
new file mode 100644
index 0000000000..153f049b0e
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/math.py
@@ -0,0 +1,250 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify
+from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
+from pytensor.scalar import Softplus
+from pytensor.scalar.basic import (
+    AND,
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NEQ,
+    OR,
+    Abs,
+    Add,
+    Cast,
+    Cos,
+    Exp,
+    Log,
+    Log1p,
+    Mul,
+    Neg,
+    Pow,
+    ScalarMaximum,
+    ScalarMinimum,
+    Sign,
+    Sin,
+    Sqr,
+    Sqrt,
+    Sub,
+    Switch,
+    TrueDiv,
+)
+from pytensor.scalar.math import Sigmoid
+from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.math import Dot
+
+
+@mlx_typify.register(int)
+@mlx_typify.register(float)
+def mlx_typify_python_scalar(data, **kwargs):
+    return mx.array(data)
+
+
+@mlx_funcify.register(Dot)
+def mlx_funcify_Dot(op, **kwargs):
+    def dot(x, y):
+        return mx.matmul(x, y)
+
+    return dot
+
+
+@mlx_funcify.register(Elemwise)
+def mlx_funcify_Elemwise(op, **kwargs):
+    if isinstance(op.scalar_op, Add):
+
+        def add(*args):
+            result = args[0]
+            for arg in args[1:]:
+                result = mx.add(result, arg)
+            return result
+
+        return add
+    elif isinstance(op.scalar_op, Sub):
+
+        def sub(x, y):
+            return mx.subtract(x, y)
+
+        return sub
+    elif isinstance(op.scalar_op, Mul):
+
+        def mul(*args):
+            result = args[0]
+            for arg in args[1:]:
+                result = mx.multiply(result, arg)
+            return result
+
+        return mul
+    elif isinstance(op.scalar_op, Exp):
+
+        def exp(x):
+            return mx.exp(x)
+
+        return exp
+    elif isinstance(op.scalar_op, Log):
+
+        def log(x):
+            return mx.log(x)
+
+        return log
+    elif isinstance(op.scalar_op, Sin):
+
+        def sin(x):
+            return mx.sin(x)
+
+        return sin
+    elif isinstance(op.scalar_op, Cos):
+
+        def cos(x):
+            return mx.cos(x)
+
+        return cos
+    elif isinstance(op.scalar_op, Sigmoid):
+
+        def sigmoid(x):
+            return mx.sigmoid(x)
+
+        return sigmoid
+    elif isinstance(op.scalar_op, LE):
+
+        def le(x, y):
+            return mx.less_equal(x, y)
+
+        return le
+    elif isinstance(op.scalar_op, LT):
+
+        def lt(x, y):
+            return mx.less(x, y)
+
+        return lt
+    elif isinstance(op.scalar_op, GE):
+
+        def ge(x, y):
+            return mx.greater_equal(x, y)
+
+        return ge
+    elif isinstance(op.scalar_op, GT):
+
+        def gt(x, y):
+            return mx.greater(x, y)
+
+        return gt
+    elif isinstance(op.scalar_op, EQ):
+
+        def eq(x, y):
+            return mx.equal(x, y)
+
+        return eq
+    elif isinstance(op.scalar_op, NEQ):
+
+        def neq(x, y):
+            return mx.not_equal(x, y)
+
+        return neq
+    elif isinstance(op.scalar_op, Switch):
+
+        def switch(cond, x, y):
+            return mx.where(cond, x, y)
+
+        return switch
+    elif isinstance(op.scalar_op, Pow):
+
+        def pow(x, y):
+            return mx.power(x, y)
+
+        return pow
+    elif isinstance(op.scalar_op, TrueDiv):
+
+        def true_div(x, y):
+            return mx.divide(x, y)
+
+        return true_div
+    elif isinstance(op.scalar_op, Sqr):
+
+        def sqr(x):
+            return mx.square(x)
+
+        return sqr
+    elif isinstance(op.scalar_op, Sqrt):
+
+        def sqrt(x):
+            return mx.sqrt(x)
+
+        return sqrt
+    elif isinstance(op.scalar_op, Abs):
+
+        def abs(x):
+            return mx.abs(x)
+
+        return abs
+    elif isinstance(op.scalar_op, Softplus):
+
+        def softplus(x):
+            return mx.where(
+                x < -37.0,
+                mx.exp(x),
+                mx.where(
+                    x < 18.0,
+                    mx.log1p(mx.exp(x)),
+                    mx.where(
+                        x < 33.3,
+                        x + mx.exp(-x),
+                        x,
+                    ),
+                ),
+            )
+
+        return softplus
+    elif isinstance(op.scalar_op, Neg):
+
+        def neg(x):
+            return mx.negative(x)
+
+        return neg
+    elif isinstance(op.scalar_op, AND):
+
+        def all(x, y):
+            return mx.bitwise_and(x, y)
+
+        return all
+    elif isinstance(op.scalar_op, OR):
+
+        def any(x, y):
+            return mx.bitwise_or(x, y)
+
+        return any
+    elif isinstance(op.scalar_op, ScalarMaximum):
+
+        def max(x, y):
+            return mx.maximum(x, y)
+
+        return max
+    elif isinstance(op.scalar_op, ScalarMinimum):
+
+        def min(x, y):
+            return mx.minimum(x, y)
+
+        return min
+    elif isinstance(op.scalar_op, Cast):
+
+        def cast(x):
+            dtype = convert_dtype_to_mlx(op.scalar_op.o_type.dtype)
+            return x.astype(dtype)
+
+        return cast
+    elif isinstance(op.scalar_op, Sign):
+
+        def sign(x):
+            return mx.sign(x)
+
+        return sign
+    elif isinstance(op.scalar_op, Log1p):
+
+        def log1p(x):
+            return mx.log1p(x)
+
+        return log1p
+    else:
+        raise NotImplementedError(f"MLX does not support {op.scalar_op}")
diff --git a/pytensor/link/mlx/dispatch/shape.py b/pytensor/link/mlx/dispatch/shape.py
new file mode 100644
index 0000000000..bd5b5941d9
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/shape.py
@@ -0,0 +1,24 @@
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.tensor.shape import Shape_i, SpecifyShape
+
+
+@mlx_funcify.register(SpecifyShape)
+def mlx_funcify_SpecifyShape(op, node, **kwargs):
+    def specifyshape(x, *shape):
+        assert x.ndim == len(shape)
+        for actual, expected in zip(x.shape, shape, strict=True):
+            if expected is None:
+                continue
+            if actual != expected:
+                raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
+        return x
+
+    return specifyshape
+
+
+@mlx_funcify.register(Shape_i)
+def mlx_funcify_Shape_i(op, node, **kwargs):
+    def shape_i(x):
+        return x.shape[op.i]
+
+    return shape_i
diff --git a/pytensor/link/mlx/dispatch/signal/__init__.py b/pytensor/link/mlx/dispatch/signal/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pytensor/link/mlx/dispatch/signal/conv.py b/pytensor/link/mlx/dispatch/signal/conv.py
new file mode 100644
index 0000000000..8f84ebb42f
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/signal/conv.py
@@ -0,0 +1,14 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.tensor.signal.conv import Conv1d
+
+
+@mlx_funcify.register(Conv1d)
+def mlx_funcify_Conv1d(op, node=None, **kwargs):
+    mode = op.mode
+
+    def conv1d(data, kernel):
+        return mx.convolve(data, kernel, mode=mode)
+
+    return conv1d
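The `Conv1d` dispatch above assumes `mx.convolve` follows NumPy's convention (including the implicit kernel flip). A sketch checking that assumption for both modes PyTensor uses:

import mlx.core as mx
import numpy as np

data = np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")
kernel = np.array([1.0, 0.0, -1.0], dtype="float32")

for mode in ("full", "valid"):
    expected = np.convolve(data, kernel, mode=mode)
    got = mx.convolve(mx.array(data), mx.array(kernel), mode=mode)
    np.testing.assert_allclose(np.array(got), expected, rtol=1e-6)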
diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py
new file mode 100644
index 0000000000..ce14d08246
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/subtensor.py
@@ -0,0 +1,105 @@
+from copy import deepcopy
+
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedIncSubtensor1,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    IncSubtensor,
+    Subtensor,
+    indices_from_subtensor,
+)
+from pytensor.tensor.type_other import MakeSlice
+
+
+@mlx_funcify.register(Subtensor)
+def mlx_funcify_Subtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
+
+    def subtensor(x, *ilists):
+        # Basic indices must be concrete Python ints for MLX indexing
+        indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
+        if len(indices) == 1:
+            indices = indices[0]
+
+        return x.__getitem__(indices)
+
+    return subtensor
+
+
+@mlx_funcify.register(AdvancedSubtensor)
+@mlx_funcify.register(AdvancedSubtensor1)
+def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
+
+    def advanced_subtensor(x, *ilists):
+        indices = indices_from_subtensor(ilists, idx_list)
+        if len(indices) == 1:
+            indices = indices[0]
+
+        return x.__getitem__(indices)
+
+    return advanced_subtensor
+
+
+@mlx_funcify.register(IncSubtensor)
+@mlx_funcify.register(AdvancedIncSubtensor1)
+def mlx_funcify_IncSubtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
+
+    if getattr(op, "set_instead_of_inc", False):
+
+        def mlx_fn(x, indices, y):
+            if not op.inplace:
+                x = deepcopy(x)
+            x[indices] = y
+            return x
+
+    else:
+
+        def mlx_fn(x, indices, y):
+            if not op.inplace:
+                x = deepcopy(x)
+            x[indices] += y
+            return x
+
+    def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
+        indices = indices_from_subtensor(ilist, idx_list)
+        if len(indices) == 1:
+            indices = indices[0]
+
+        return mlx_fn(x, indices, y)
+
+    return incsubtensor
+
+
+@mlx_funcify.register(AdvancedIncSubtensor)
+def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    if getattr(op, "set_instead_of_inc", False):
+
+        def mlx_fn(x, indices, y):
+            if not op.inplace:
+                x = deepcopy(x)
+            x[indices] = y
+            return x
+
+    else:
+
+        def mlx_fn(x, indices, y):
+            if not op.inplace:
+                x = deepcopy(x)
+            x[indices] += y
+            return x
+
+    def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
+        return mlx_fn(x, ilist, y)
+
+    return advancedincsubtensor
+
+
+@mlx_funcify.register(MakeSlice)
+def mlx_funcify_MakeSlice(op, **kwargs):
+    def makeslice(*x):
+        return slice(*x)
+
+    return makeslice
diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py
new file mode 100644
index 0000000000..e057bb942c
--- /dev/null
+++ b/pytensor/link/mlx/linker.py
@@ -0,0 +1,71 @@
+from pytensor.link.basic import JITLinker
+
+
+class MLXLinker(JITLinker):
+    """A `Linker` that converts PyTensor graphs to MLX functions and JIT-compiles them with `mx.compile`."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gen_functors = []
+
+    def fgraph_convert(self, fgraph, **kwargs):
+        """Convert a PyTensor FunctionGraph to an MLX-compatible function.
+
+        Parameters
+        ----------
+        fgraph : FunctionGraph
+            The function graph to convert
+
+        Returns
+        -------
+        callable
+            An MLX-compatible function
+        """
+        from pytensor.link.mlx.dispatch import mlx_funcify
+
+        return mlx_funcify(
+            fgraph,
+            **kwargs,
+        )
+
+    def jit_compile(self, fn):
+        import mlx.core as mx
+
+        from pytensor.link.mlx.dispatch import mlx_typify
+
+        inner_fn = mx.compile(fn)
+
+        def fn_with_typify(*inputs, inner_fn=inner_fn):
+            return inner_fn(*(mlx_typify(inp) for inp in inputs))
+
+        return fn_with_typify
+
+    def create_thunk_inputs(self, storage_map):
+        """Create inputs for the MLX thunk.
+
+        Parameters
+        ----------
+        storage_map : dict
+            Map from variables to their storage
+
+        Returns
+        -------
+        list
+            The inputs for the thunk
+        """
+        from numpy.random import Generator, RandomState
+
+        from pytensor.link.mlx.dispatch import mlx_typify
+
+        thunk_inputs = []
+        for n in self.fgraph.inputs:
+            sinput = storage_map[n]
+            # Handle random number generators specially
+            if isinstance(sinput[0], RandomState | Generator):
+                new_value = mlx_typify(
+                    sinput[0], dtype=getattr(sinput[0], "dtype", None)
+                )
+                sinput[0] = new_value
+            thunk_inputs.append(sinput)
+
+        return thunk_inputs
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index b8475e3157..18824a5b71 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -31,13 +31,15 @@ def conversion_func_register(*args, **kwargs):
                 **kwargs,
             }
             return pytorch_funcify(
-                fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
+                fgraph,
+                input_storage=input_storage,
+                storage_map=storage_map,
+                **built_kwargs,
             )

     def jit_compile(self, fn):
         import torch

-        # flag that tend to help our graphs
         torch._dynamo.config.capture_dynamic_output_shape_ops = True

         from pytensor.link.pytorch.dispatch import pytorch_typify
diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py
new file mode 100644
index 0000000000..4a6e67f406
--- /dev/null
+++ b/tests/link/mlx/test_basic.py
@@ -0,0 +1,80 @@
+from collections.abc import Callable, Iterable
+from functools import partial
+
+import numpy as np
+import pytest
+
+from pytensor.compile.function import function
+from pytensor.compile.mode import MLX, Mode
+from pytensor.graph import RewriteDatabaseQuery
+from pytensor.graph.basic import Variable
+from pytensor.link.mlx import MLXLinker
+
+
+mx = pytest.importorskip("mlx.core")
+
+optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude)
+mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer)
+py_mode = Mode(linker="py", optimizer=None)
+
+
+def compare_mlx_and_py(
+    graph_inputs: Iterable[Variable],
+    graph_outputs: Variable | Iterable[Variable],
+    test_inputs: Iterable,
+    *,
+    assert_fn: Callable | None = None,
+    must_be_device_array: bool = True,
+    mlx_mode=mlx_mode,
+    py_mode=py_mode,
+):
+    """Compare the output of an MLX-compiled function against its Python counterpart.
+
+    The graph is compiled twice, once in MLX mode and once in Python mode; both
+    functions are run on `test_inputs` and their results are checked for equality.
+
+    Parameters
+    ----------
+    graph_inputs:
+        Symbolic inputs to the graph.
+    graph_outputs:
+        Symbolic outputs of the graph.
+    test_inputs: iter
+        Numerical inputs with which to test the function.
+    assert_fn: Callable, optional
+        Assert function used to check for equality between python and mlx.
+        Defaults to `np.testing.assert_allclose`.
+    must_be_device_array: bool
+        If True, check that the MLX results are `mx.array` instances, i.e. that
+        the computation was actually performed by MLX.
+
+    Returns
+    -------
+    The compiled MLX function and its results.
+    """
+    if assert_fn is None:
+        assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
+
+    if any(inp.owner is not None for inp in graph_inputs):
+        raise ValueError("Inputs must be root variables")
+
+    pytensor_mlx_fn = function(graph_inputs, graph_outputs, mode=mlx_mode)
+    mlx_res = pytensor_mlx_fn(*test_inputs)
+
+    if must_be_device_array:
+        if isinstance(mlx_res, list):
+            assert all(isinstance(res, mx.array) for res in mlx_res)
+        else:
+            assert isinstance(mlx_res, mx.array)
+
+    pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
+    py_res = pytensor_py_fn(*test_inputs)
+
+    if isinstance(graph_outputs, list | tuple):
+        for j, p in zip(mlx_res, py_res, strict=True):
+            assert_fn(j, p)
+    else:
+        assert_fn(mlx_res, py_res)
+
+    return pytensor_mlx_fn, mlx_res
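New MLX tests can lean on the helper above. A hypothetical example of the pattern (not part of this PR's test suite):

import pytensor.tensor as pt
from tests.link.mlx.test_basic import compare_mlx_and_py


def test_exp_plus_one():
    x = pt.vector("x")
    out = pt.exp(x) + 1.0
    # Compiles with both the MLX and Python linkers and compares the results
    compare_mlx_and_py([x], out, [[0.0, 1.0, 2.0]])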
diff --git a/tests/link/mlx/test_blockwise.py b/tests/link/mlx/test_blockwise.py
new file mode 100644
index 0000000000..9b271186c9
--- /dev/null
+++ b/tests/link/mlx/test_blockwise.py
@@ -0,0 +1,64 @@
+import numpy as np
+
+import pytensor.tensor as pt
+from pytensor.tensor import tensor
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.math import Dot
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+# Equivalent blockwise to matmul but with dumb signature
+odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
+
+
+# TODO: Port the JAX `test_matmul` here (parametrized over `matmul` and
+# `odd_matmul`) once the generic Blockwise fallback is exercised; the JAX
+# version also inspects the jaxpr, which has no MLX equivalent.
+
+
+def test_blockwise_conv1d():
+    rng = np.random.default_rng(14)
+    a = tensor("a", shape=(2, 100))
+    b = tensor("b", shape=(2, 8))
+
+    a_test = rng.normal(size=(2, 100))
+    b_test = rng.normal(size=(2, 8))
+    test_values = [a_test, b_test]
+
+    out = pt.signal.convolve1d(a, b, mode="valid")
+
+    compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True)
diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py
new file mode 100644
index 0000000000..7819df06be
--- /dev/null
+++ b/tests/link/mlx/test_elemwise.py
@@ -0,0 +1,13 @@
+import pytest
+
+import pytensor.tensor as pt
+from tests.link.mlx.test_basic import compare_mlx_and_py, mx
+
+
+@pytest.mark.parametrize("op", [pt.any, pt.all, pt.max, pt.min])
+def test_bool_reductions(op) -> None:
+    x = pt.vector("x")
+    out = op(x > 0)
+    x_test = mx.array([1.0, 2.0, 3.0])
+
+    compare_mlx_and_py([x], out, [x_test])
diff --git a/tests/link/mlx/test_math.py b/tests/link/mlx/test_math.py
new file mode 100644
index 0000000000..2c08d986c9
--- /dev/null
+++ b/tests/link/mlx/test_math.py
@@ -0,0 +1,101 @@
+import numpy as np
+import pytest
+
+import pytensor
+import pytensor.tensor as pt
+from pytensor.tensor.math import Argmax, Max
+from tests.link.mlx.test_basic import compare_mlx_and_py, mx
+
+
+def test_dot():
+    x = pt.matrix("x")
+    y = pt.matrix("y")
+
+    out = x.dot(y)
+    fn = pytensor.function([x, y], out, mode="MLX")
+
+    seed = sum(map(ord, "test_mlx_dot"))
+    rng = np.random.default_rng(seed)
+
+    test_x = rng.normal(size=(3, 2))
+    test_y = rng.normal(size=(2, 4))
+
+    actual = fn(test_x, test_y)
+    assert isinstance(actual, mx.array)
+    expected = np.dot(test_x, test_y)
+    np.testing.assert_allclose(actual, expected, rtol=1e-6)
+
+
+@pytest.mark.parametrize(
+    "op",
+    [
+        pytest.param(pt.exp, id="exp"),
+        pytest.param(pt.log, id="log"),
+        pytest.param(pt.sin, id="sin"),
+        pytest.param(pt.cos, id="cos"),
+        pytest.param(pt.sigmoid, id="sigmoid"),
+    ],
+)
+def test_elemwise_one_input(op) -> None:
+    x = pt.vector("x")
+    out = op(x)
+    x_test = mx.array([1.0, 2.0, 3.0])
+    compare_mlx_and_py([x], out, [x_test])
+
+
+def test_switch() -> None:
+    x = pt.vector("x")
+    y = pt.vector("y")
+
+    out = pt.switch(x > 0, y, x)
+
+    x_test = mx.array([-1.0, 2.0, 3.0])
+    y_test = mx.array([4.0, 5.0, 6.0])
+
+    compare_mlx_and_py([x, y], out, [x_test, y_test])
+
+
+@pytest.mark.parametrize("op", [pt.sum, pt.prod])
+def test_sum_prod(op) -> None:
+    x = pt.vector("x")
+    y = pt.vector("y")
+    out = op([x, y, x + y])
+    x_test = mx.array([1.0, 2.0, 3.0])
+    y_test = mx.array([4.0, 5.0, 6.0])
+    compare_mlx_and_py([x, y], out, [x_test, y_test])
+
+
+@pytest.mark.parametrize(
+    "op",
+    [
+        pytest.param(pt.add, id="add"),
+        pytest.param(pt.sub, id="sub"),
+        pytest.param(pt.mul, id="mul"),
+        pytest.param(pt.power, id="power"),
+        pytest.param(pt.le, id="le"),
+        pytest.param(pt.lt, id="lt"),
+        pytest.param(pt.ge, id="ge"),
+        pytest.param(pt.gt, id="gt"),
+        pytest.param(pt.eq, id="eq"),
+        pytest.param(pt.neq, id="neq"),
+        pytest.param(pt.true_div, id="true_div"),
+    ],
+)
+def test_elemwise_two_inputs(op) -> None:
+    x = pt.vector("x")
+    y = pt.vector("y")
+    out = op(x, y)
+    x_test = mx.array([1.0, 2.0, 3.0])
+    y_test = mx.array([4.0, 5.0, 6.0])
+    compare_mlx_and_py([x, y], out, [x_test, y_test])
+
+
+@pytest.mark.xfail(reason="Argmax not implemented yet")
+def test_mlx_max_and_argmax():
+    # Test that a single output of a multi-output `Op` can be used as input to
+    # another `Op`
+    x = pt.dvector()
+    max_x = Max([0])(x)
+    argmax_x = Argmax([0])(x)
+    out = max_x * argmax_x
+    compare_mlx_and_py([x], [out], [np.r_[1, 2]])
diff --git a/tests/link/mlx/test_shape.py b/tests/link/mlx/test_shape.py
new file mode 100644
index 0000000000..7a548df8f8
--- /dev/null
+++ b/tests/link/mlx/test_shape.py
@@ -0,0 +1,78 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.compile.ops import DeepCopyOp, ViewOp
+from pytensor.configdefaults import config
+from pytensor.tensor.shape import Shape, Shape_i, reshape
+from pytensor.tensor.type import iscalar, vector
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+@pytest.mark.xfail(reason="Shape Op is not supported yet")
+def test_mlx_shape_ops():
+    x_np = np.zeros((20, 3))
+    x = Shape()(pt.as_tensor_variable(x_np))
+
+    compare_mlx_and_py([], [x], [], must_be_device_array=False)
+
+    x = Shape_i(1)(pt.as_tensor_variable(x_np))
+
+    compare_mlx_and_py([], [x], [], must_be_device_array=False)
+
+
+@pytest.mark.xfail(reason="Shape Op is not supported yet")
+def test_mlx_specify_shape():
+    in_pt = pt.matrix("in")
+    x = pt.specify_shape(in_pt, (4, None))
+    compare_mlx_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)])
+
+    # When used to assert two arrays have similar shapes
+    in_pt = pt.matrix("in")
+    shape_pt = pt.matrix("shape")
+    x = pt.specify_shape(in_pt, shape_pt.shape)
+
+    compare_mlx_and_py(
+        [in_pt, shape_pt],
+        [x],
+        [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
+    )
+
+
+@pytest.mark.xfail(reason="Reshape Op is not supported yet")
+def test_mlx_Reshape_constant():
+    a = vector("a")
+    x = reshape(a, (2, 2))
+    compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
+
+
+@pytest.mark.xfail(reason="Reshape Op is not supported yet")
+def test_mlx_Reshape_concrete_shape():
+    """MLX should compile when a concrete value is passed for the `shape` parameter."""
+    a = vector("a")
+    x = reshape(a, a.shape)
+    compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
+
+    x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
+    compare_mlx_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
+
+
+@pytest.mark.xfail(reason="`shape_pt` should be specified as a static argument")
+def test_mlx_Reshape_shape_graph_input():
+    a = vector("a")
+    shape_pt = iscalar("b")
+    x = reshape(a, (shape_pt, shape_pt))
+    compare_mlx_and_py(
+        [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
+    )
+
+
+@pytest.mark.xfail(reason="ViewOp Op is not supported yet")
+def test_mlx_compile_ops():
+    x = DeepCopyOp()(pt.as_tensor_variable(1.1))
+    compare_mlx_and_py([], [x], [])
+
+    x_np = np.zeros((20, 1, 1))
+    x = ViewOp()(pt.as_tensor_variable(x_np))
+
+    compare_mlx_and_py([], [x], [])
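For reference, the grouped-convolution trick that `blockwise_conv1d` relies on can be checked directly against NumPy. A standalone sketch (assuming an MLX build where `mx.conv1d` supports `groups`; shapes and names are illustrative):

import mlx.core as mx
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 50)).astype("float32")  # B=4 sequences of length T=50
k = rng.normal(size=(4, 5)).astype("float32")   # B=4 kernels of length K=5

# Same layout as blockwise_conv1d: input (1, T, B), weight (B, K, 1), groups=B
x_in = mx.array(x.T[None, :, :])
w = mx.array(np.ascontiguousarray(k[:, ::-1])[:, :, None])  # flip kernels
y = mx.conv1d(x_in, w, groups=4)[0].T  # -> (B, T - K + 1)

expected = np.stack([np.convolve(xi, ki, mode="valid") for xi, ki in zip(x, k)])
np.testing.assert_allclose(np.array(y), expected, rtol=1e-4)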