
Start to cleanup/unify accelerate and common back-ends (Part 1/N) #1777

Merged: awni merged 8 commits into main from unify_cpu_part_1 on Jan 29, 2025

Conversation

awni (Member) commented Jan 19, 2025

Begin to clean up the accelerate and common back-ends. This is just a small step and not the final API by any means, but I think it is a step in the right direction.

  • Adds an mlx::core::simd::Simd<T, N> generic type and many free functions that operate on it (see the sketch after this list)
  • Gets rid of the unary primitives in accelerate and instead uses simd ops
  • Gets rid of the binary primitives in accelerate and instead uses simd ops
  • Gets rid of lots of unnecessary eval_cpu overrides for primitives in both accelerate and common
  • Unify softmax into a common simd implementation
  • Unify quantized into a common simd implementation and use the simd implementation for more types / quant parameters
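
To make the first bullet concrete, here is a minimal, hedged sketch of what a generic SIMD wrapper plus free functions can look like. The names (`Simd`, `load`, `store`, `maximum`, `binary_maximum`) and the scalar fallback are illustrative assumptions and do not reproduce the actual mlx::core::simd API:

```cpp
// Sketch only: a generic N-wide vector type with free functions, so kernels
// can be written once and back-ends can specialize the pieces they support.
#include <algorithm>
#include <cstddef>

namespace sketch {

template <typename T, int N>
struct Simd {
  T vals[N];
};

// Free functions operate element-wise; a back-end can replace them with
// NEON/AVX intrinsics without changing any call sites.
template <typename T, int N>
Simd<T, N> load(const T* x) {
  Simd<T, N> out;
  for (int i = 0; i < N; ++i) out.vals[i] = x[i];
  return out;
}

template <typename T, int N>
void store(T* dst, Simd<T, N> v) {
  for (int i = 0; i < N; ++i) dst[i] = v.vals[i];
}

template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
  Simd<T, N> out;
  for (int i = 0; i < N; ++i) out.vals[i] = std::max(a.vals[i], b.vals[i]);
  return out;
}

// A binary kernel written once against the generic type.
template <typename T, int N>
void binary_maximum(const T* a, const T* b, T* out, std::size_t size) {
  std::size_t i = 0;
  for (; i + N <= size; i += N) {
    store(out + i, maximum(load<T, N>(a + i), load<T, N>(b + i)));
  }
  for (; i < size; ++i) out[i] = std::max(a[i], b[i]);  // scalar tail
}

} // namespace sketch
```

A call site then just picks a width, e.g. `sketch::binary_maximum<float, 8>(a, b, out, n)`; writing kernels against free functions like these means a back-end only has to specialize the vector type and the operations it supports.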

TODOs for this PR:

  • Fix CPU compilation
  • Simd half type
  • Unify fast exp implementation
  • Check that benchmarks are up to standard; ideally no performance is lost

TODOs for future PRs

  • Try and remove the rest of the Accelerate back-end by using Simd<T, N>
  • Add more custom implementations of functions using SIMD ops
  • Add more SIMD back-ends (AVX, NEON for all types, ...); see the sketch after this list
  • Use simd ops in compilation
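
As a hedged illustration of the "more SIMD back-ends" bullet above, a NEON-backed 4-wide float wrapper could look roughly like the following. The type and function names here are assumptions for the sketch, not the MLX API:

```cpp
// Sketch only: a NEON-backed 4-wide float vector with the same free-function
// surface as the generic wrapper, guarded so it compiles away elsewhere.
#if defined(__ARM_NEON)
#include <arm_neon.h>

namespace neon_sketch {

struct Simd4f {
  float32x4_t v;  // one NEON register holding 4 floats
};

inline Simd4f load(const float* x) { return {vld1q_f32(x)}; }
inline void store(float* dst, Simd4f a) { vst1q_f32(dst, a.v); }
inline Simd4f maximum(Simd4f a, Simd4f b) { return {vmaxq_f32(a.v, b.v)}; }
inline Simd4f operator+(Simd4f a, Simd4f b) { return {vaddq_f32(a.v, b.v)}; }

} // namespace neon_sketch
#endif // __ARM_NEON
```

An AVX back-end would look the same with `__m256` and the corresponding `_mm256_*` intrinsics in place of the NEON types.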

awni force-pushed the unify_cpu_part_1 branch from 64ff27d to b65d908 on January 19, 2025 03:02
awni force-pushed the unify_cpu_part_1 branch from b65d908 to 7689ab8 on January 24, 2025 01:09
awni marked this pull request as draft on January 24, 2025 01:09
awni force-pushed the unify_cpu_part_1 branch 3 times, most recently from 062ef77 to a990d4b, on January 24, 2025 23:56
awni force-pushed the unify_cpu_part_1 branch 2 times, most recently from 25206a1 to a0de7d4, on January 25, 2025 17:35
awni force-pushed the unify_cpu_part_1 branch 2 times, most recently from 4512bd8 to 32ac258, on January 26, 2025 01:29
awni force-pushed the unify_cpu_part_1 branch 2 times, most recently from e58969a to cc74427, on January 26, 2025 21:26
awni force-pushed the unify_cpu_part_1 branch from cc74427 to 681dc8a on January 27, 2025 03:21
awni (Member, Author) commented Jan 27, 2025

Sorry for the massive diff. I'm going to stop adding to this one as I think it's mergeable. I will share some benchmarks shortly. There is still some work to do to remove the rest of accelerate, but it's nearly there. After that, I think it would be good to split the "CPU" part of common out into a CPU back-end (which mirrors no_cpu) and will be simpler from a CMake standpoint moving forward. Then common is only what is actually common between all back-ends.

awni marked this pull request as ready for review on January 27, 2025 16:58
awni force-pushed the unify_cpu_part_1 branch from bda8d7f to e7c8351 on January 27, 2025 17:18
awni (Member, Author) commented Jan 27, 2025

QMM benchmarks, M3 Max

| Benchmark | Pre | Post |
| --- | --- | --- |
| float32 qmv 4 32 | 7.55702 msec | 1.04029 msec |
| float16 qmv 4 32 | 20.29642 msec | 1.17468 msec |
| float32 qmv 4 64 | 1.42040 msec | 1.01460 msec |
| float16 qmv 4 64 | 20.26675 msec | 1.15008 msec |
| float32 qmv 8 32 | 7.64031 msec | 1.08267 msec |
| float16 qmv 8 32 | 20.30266 msec | 1.23946 msec |
| float32 qmv 8 64 | 5.43884 msec | 1.08137 msec |
| float16 qmv 8 64 | 20.29582 msec | 1.21884 msec |
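
For context on what a qmv kernel spends its time on: the core step is unpacking several low-bit weights from a packed word and rescaling them with a per-group scale and bias. A minimal, hedged scalar illustration (the function name and layout are assumptions, not the MLX kernel) is:

```cpp
#include <cstdint>

// Unpack eight 4-bit values from one packed 32-bit word and dequantize them
// with a per-group scale and bias. Illustration only; not the MLX kernel.
inline void unpack_u4(uint32_t packed, float scale, float bias, float out[8]) {
  for (int i = 0; i < 8; ++i) {
    uint32_t q = (packed >> (4 * i)) & 0xFu;  // extract the i-th 4-bit value
    out[i] = scale * static_cast<float>(q) + bias;
  }
}
```

A simd implementation applies the same shifts and masks to whole vectors at a time instead of one value per iteration.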

awni (Member, Author) commented Jan 27, 2025

Unary ops, M3 Max

TL;DR

  • Mostly a win for fp16
  • Mostly no change or speedups for fp32
  • Exception: sine/cos have faster vector implementations in accelerate (simd is slow for some reason)
  • SIMD erf is very slow... I might replace this one with our custom one in this PR since it's a big hit. (Edit: added a fast implementation.)
| Dtype | Op | Pre | Post |
| --- | --- | --- | --- |
| mlx.core.float32 | abs | 0.65580 msec | 0.58867 msec |
| mlx.core.float32 | exp | 2.92985 msec | 2.98762 msec |
| mlx.core.float32 | cos | 2.78080 msec | 4.97950 msec |
| mlx.core.float32 | erf | 11.16833 msec | 5.75603 msec |
| mlx.core.float32 | sign | 0.58961 msec | 0.80921 msec |
| mlx.core.float32 | round | 0.58400 msec | 0.57695 msec |
| mlx.core.float32 | floor | 0.58393 msec | 0.59424 msec |
| mlx.core.float16 | abs | 0.30060 msec | 0.30210 msec |
| mlx.core.float16 | exp | 6.06438 msec | 2.76648 msec |
| mlx.core.float16 | cos | 23.14954 msec | 5.37498 msec |
| mlx.core.float16 | erf | 14.86014 msec | 5.90261 msec |
| mlx.core.float16 | sign | 0.87095 msec | 0.30026 msec |
| mlx.core.float16 | round | 0.25200 msec | 0.28248 msec |
| mlx.core.float16 | floor | 0.25147 msec | 0.28243 msec |
| mlx.core.int8 | abs | 18.04146 msec | 1.16671 msec |
| mlx.core.int16 | abs | 1.17901 msec | 1.16131 msec |
| mlx.core.int32 | abs | 1.16406 msec | 1.16534 msec |
| mlx.core.int64 | abs | 1.17637 msec | 1.17697 msec |

awni (Member, Author) commented Jan 27, 2025

Binary ops, M3 Max

Observations:

  • Many simple ops are unchanged... seems like the compiler was already doing a decent job
  • Some ops are way faster (remainder, scalar-vector atan2, and fp16 maximum)
  • Only a couple of cases where the VV ops were faster, and not by much
| Case | Dtype | Op | Pre | Post |
| --- | --- | --- | --- | --- |
| vector-vector | mlx.core.float32 | arctan2 | 6.80063 msec | 8.70261 msec |
| vector-vector | mlx.core.float32 | power | 10.08954 msec | 12.25302 msec |
| vector-vector | mlx.core.float32 | remainder | 70.03200 msec | 4.85313 msec |
| vector-vector | mlx.core.float32 | divide | 1.09933 msec | 0.86591 msec |
| vector-vector | mlx.core.float32 | add | 0.89543 msec | 0.86532 msec |
| vector-vector | mlx.core.float32 | multiply | 0.83800 msec | 0.86622 msec |
| vector-vector | mlx.core.float32 | equal | 0.61173 msec | 0.63793 msec |
| vector-vector | mlx.core.float32 | greater | 0.60726 msec | 0.63786 msec |
| vector-vector | mlx.core.float32 | less | 0.60600 msec | 0.63802 msec |
| vector-vector | mlx.core.float32 | logical_and | 3.20360 msec | 3.07354 msec |
| vector-vector | mlx.core.float32 | maximum | 0.86219 msec | 0.80908 msec |
| scalar-vector | mlx.core.float32 | arctan2 | 37.16476 msec | 8.47581 msec |
| scalar-vector | mlx.core.float32 | divide | 1.08531 msec | 0.56309 msec |
| scalar-vector | mlx.core.float32 | add | 0.65866 msec | 0.57344 msec |
| scalar-vector | mlx.core.float32 | multiply | 0.64534 msec | 0.58218 msec |
| scalar-vector | mlx.core.float32 | equal | 0.36215 msec | 0.38280 msec |
| scalar-vector | mlx.core.float32 | greater | 0.36176 msec | 0.38227 msec |
| scalar-vector | mlx.core.float32 | less | 0.36376 msec | 0.38213 msec |
| scalar-vector | mlx.core.float32 | logical_and | 2.56100 msec | 2.70192 msec |
| scalar-vector | mlx.core.float32 | maximum | 0.96238 msec | 0.60532 msec |
| vector-scalar | mlx.core.float32 | arctan2 | 37.15536 msec | 8.51265 msec |
| vector-scalar | mlx.core.float32 | divide | 0.67501 msec | 0.58936 msec |
| vector-scalar | mlx.core.float32 | add | 0.68699 msec | 0.60117 msec |
| vector-scalar | mlx.core.float32 | multiply | 0.67430 msec | 0.60541 msec |
| vector-scalar | mlx.core.float32 | equal | 0.35143 msec | 0.38434 msec |
| vector-scalar | mlx.core.float32 | greater | 0.34812 msec | 0.38447 msec |
| vector-scalar | mlx.core.float32 | less | 0.34799 msec | 0.38467 msec |
| vector-scalar | mlx.core.float32 | logical_and | 2.55481 msec | 2.71672 msec |
| vector-scalar | mlx.core.float32 | maximum | 0.84953 msec | 0.61606 msec |
| vector-vector | mlx.core.float16 | arctan2 | 40.40737 msec | 8.79723 msec |
| vector-vector | mlx.core.float16 | power | 23.76465 msec | 12.39425 msec |
| vector-vector | mlx.core.float16 | remainder | 41.67515 msec | 4.88845 msec |
| vector-vector | mlx.core.float16 | divide | 0.29231 msec | 0.31124 msec |
| vector-vector | mlx.core.float16 | add | 0.23684 msec | 0.30132 msec |
| vector-vector | mlx.core.float16 | multiply | 0.23680 msec | 0.30149 msec |
| vector-vector | mlx.core.float16 | equal | 0.18635 msec | 0.28252 msec |
| vector-vector | mlx.core.float16 | greater | 0.18646 msec | 0.28600 msec |
| vector-vector | mlx.core.float16 | less | 0.18715 msec | 0.28376 msec |
| vector-vector | mlx.core.float16 | logical_and | 2.35622 msec | 2.50903 msec |
| vector-vector | mlx.core.float16 | maximum | 0.41909 msec | 0.28236 msec |
| scalar-vector | mlx.core.float16 | arctan2 | 37.71406 msec | 8.85521 msec |
| scalar-vector | mlx.core.float16 | divide | 1.44899 msec | 0.91199 msec |
| scalar-vector | mlx.core.float16 | add | 0.84486 msec | 0.88652 msec |
| scalar-vector | mlx.core.float16 | multiply | 0.85067 msec | 0.89773 msec |
| scalar-vector | mlx.core.float16 | equal | 0.85125 msec | 0.80171 msec |
| scalar-vector | mlx.core.float16 | greater | 0.85024 msec | 0.80474 msec |
| scalar-vector | mlx.core.float16 | less | 0.83678 msec | 0.80777 msec |
| scalar-vector | mlx.core.float16 | logical_and | 2.35563 msec | 2.51308 msec |
| scalar-vector | mlx.core.float16 | maximum | 3.53275 msec | 0.89007 msec |
| vector-scalar | mlx.core.float16 | arctan2 | 37.74590 msec | 8.89078 msec |
| vector-scalar | mlx.core.float16 | divide | 0.87528 msec | 0.93131 msec |
| vector-scalar | mlx.core.float16 | add | 0.86858 msec | 0.90258 msec |
| vector-scalar | mlx.core.float16 | multiply | 0.88051 msec | 0.90902 msec |
| vector-scalar | mlx.core.float16 | equal | 0.84893 msec | 0.80818 msec |
| vector-scalar | mlx.core.float16 | greater | 0.84330 msec | 0.80636 msec |
| vector-scalar | mlx.core.float16 | less | 0.84232 msec | 0.80553 msec |
| vector-scalar | mlx.core.float16 | logical_and | 2.37919 msec | 2.51553 msec |
| vector-scalar | mlx.core.float16 | maximum | 3.52773 msec | 0.89038 msec |

awni force-pushed the unify_cpu_part_1 branch from 2530efd to 9829c12 on January 27, 2025 21:29
angeloskath (Member) left a comment:

This looks great to me!!! Fantastic job.

Can't wait to add this to the CPU compile as well. It's gonna be beautiful.

awni force-pushed the unify_cpu_part_1 branch from 9829c12 to b38b394 on January 28, 2025 15:20
awni merged commit 4758c8b into main on Jan 29, 2025
5 checks passed
awni deleted the unify_cpu_part_1 branch on January 29, 2025 22:34
The code under discussion (excerpt from the PR diff):

```cpp
// Bit shifts used to unpack the packed values from each 32-bit word.
constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
// Reinterpret the shift table as an S-wide simd vector.
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
// Broadcast the next two packed 32-bit words into 4-wide vectors.
auto l = simd::Simd<uint32_t, 4>(*w++);
auto r = simd::Simd<uint32_t, 4>(*w);
```
A contributor commented:

Sorry for posting post-merge, but the following is easier to describe in the context of this PR:

Clang on the latest FreeBSD releases (13.4 and 14.2) is unhappy with the simd::Simd calls in L167-L168: `error: implicit instantiation of undefined template 'mlx::core::simd::Simd<unsigned int, 4>'`, cf.
https://buildkite.com/julialang/yggdrasil/builds/17042#0194c6a1-5b8d-4a7c-8cef-8e098d29b750/6-32274

Not sure how to resolve this...
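
For readers hitting the same thing, a minimal, self-contained reproduction of this class of Clang error (a sketch, not the MLX sources; the widths are arbitrary) is:

```cpp
// The primary template is declared but never defined, and only one width has
// a definition, so instantiating any other width fails at the point of use.
template <typename T, int N>
struct Simd;  // declared, not defined

template <typename T>
struct Simd<T, 8> {  // only an 8-wide partial specialization is defined
  T vals[8];
};

int main() {
  Simd<unsigned int, 8> ok{};  // fine: matches the defined specialization
  // Simd<unsigned int, 4> bad{};  // error: implicit instantiation of
  //                               // undefined template 'Simd<unsigned int, 4>'
  (void)ok;
  return 0;
}
```

The error generally means that, on that platform, no definition or specialization of the template is visible for that particular width; a likely fix is to make sure a generic fallback definition of Simd<T, N> is available wherever the specialized widths are not.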

awni (Member, Author) replied:

Might be good to file an issue. That way we won't forget about this and we can help debug it.
