Implement all Ops in mlx #1350

Open
48 tasks
williambdean opened this issue Apr 8, 2025 · 1 comment

Comments

@williambdean
Contributor

williambdean commented Apr 8, 2025

Description

Support the Ops listed below in the mlx backend to unleash Apple Silicon's full potential.

Similar to #821

Tensor creation Ops

  • Alloc and AllocEmpty
  • Arange
  • Eye
  • ScalarFromTensor
  • TensorFromScalar
  • Repeat
  • Unique
  • Sort / Argsort
  • Tri

Shape Ops

  • Dimshuffle
  • Reshape
  • Shape, Shape_i
  • SpecifyShape
  • Unbroadcast
  • Join
  • Split
  • MakeVector

Math Ops

  • Elemwise
  • CAReduce (Sum, All, Any...)
  • CumOp
  • Softmax, LogSoftmax and Grads
  • Dot
  • BatchedDot
  • Argmax

Indexing Ops

  • Subtensor
  • Inc/SetSubtensor
  • AdvancedSubtensor[1]
  • AdvancedIncSubtensor[1]

Branching Ops

  • CheckAndRaise
  • Ifelse
  • ScalarLoop
  • Scan
  • OpFromGraph
  • Blockwise

Linalg Ops

  • SVD
  • Det
  • Eig
  • Eigh
  • MatrixInverse
  • MatrixPinv
  • QRFull
  • SLogDet
  • BlockDiagonal
  • Cholesky
  • Solve
  • SolveTriangular

Sparse Ops

  • ... (to be filled)

RandomVariable Ops

  • ... (to be filled)

If you need an Op that's not in this list, comment below and we'll add it!

@ricardoV94
Member

I suggest starting with a POC with a new linker to compile a simple graph like:

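import pytensor
import pytensor.tensor as pt
import numpy as np

# Symbolic graph: a single matrix product
x = pt.matrix("x")
y = pt.matrix("y")
out = pt.dot(x, y)
# "MLX" is the mode the proposed linker would provide
fn = pytensor.function([x, y], out, mode="MLX")

# Compare the compiled function against NumPy on random inputs
test_x = np.random.normal(size=(3, 2))
test_y = np.random.normal(size=(2, 4))
result = fn(test_x, test_y)
np.testing.assert_allclose(result, np.dot(test_x, test_y))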
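
One way such a POC could grow to cover the checklist is the per-Op single-dispatch pattern that PyTensor's existing JAX and Numba backends use. The sketch below only illustrates that pattern with the standard library: the `Op`, `Dot`, and `Eye` classes are toy stand-ins rather than PyTensor's classes, `mlx_funcify` and `missing_ops` are hypothetical names, and the matrix product runs on nested lists where a real implementation would call into mlx.

```python
from functools import singledispatch


# Toy stand-ins for PyTensor Op classes (assumption: not the real classes).
class Op:
    """Base class for the toy Ops in this sketch."""


class Dot(Op):
    """Matrix product of two 2-D inputs."""


class Eye(Op):
    """Deliberately left without an implementation."""


@singledispatch
def mlx_funcify(op):
    """Return a callable implementing `op` (hypothetical name)."""
    raise NotImplementedError(f"No MLX conversion for Op: {type(op).__name__}")


@mlx_funcify.register(Dot)
def _(op):
    def dot(x, y):
        # Pure-Python matmul on nested lists; a real backend would
        # delegate to the array library here instead.
        return [
            [sum(a * b for a, b in zip(row, col)) for col in zip(*y)]
            for row in x
        ]

    return dot


def missing_ops(op_types):
    """Names of Op classes that still fall back to the default handler."""
    fallback = mlx_funcify.dispatch(object)
    return [t.__name__ for t in op_types if mlx_funcify.dispatch(t) is fallback]


dot_fn = mlx_funcify(Dot())
result = dot_fn([[1, 2], [3, 4]], [[5, 6], [7, 8]])
todo = missing_ops([Dot, Eye])
```

With this layout, each item in the checklist corresponds to one registration, and a helper like `missing_ops` gives a quick view of what is still unimplemented.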

@williambdean williambdean mentioned this issue Apr 11, 2025
11 tasks