Skip to content

matrix multiplication: optimizations around matmul2x2or3x3_nonzeroalpha!#1563

Open
nsajko wants to merge 20 commits into
JuliaLang:masterfrom
nsajko:s
Open

matrix multiplication: optimizations around matmul2x2or3x3_nonzeroalpha!#1563
nsajko wants to merge 20 commits into
JuliaLang:masterfrom
nsajko:s

Conversation

@nsajko

@nsajko nsajko commented Mar 17, 2026

Copy link
Copy Markdown
Member

Improvements:

  • Possibly slightly speeds up the general case; multiplying non-small matrices. Presumably because of decreasing the number of branches that non-special-cased/non-small matrices are subject to. However this speedup seems to only be significant for small-enough matrices.

  • Cover matrix products which are empty with the special case pure Julia code for small matrices. Speeds up that case. This includes matrix products where one of the "matrices" is an AbstractVector.

  • Cover matrix products of two 1x1 matrices with the special case pure Julia code for small matrices. Speeds up that case. This includes matrix products where one of the "matrices" is an AbstractVector.

The commit history is tidy for easier review.

nsajko added 7 commits March 17, 2026 00:48
Lays the groundwork for adding more small-matrix special cases while
keeping complexity low.
In case `n > 3`, which I feel is most of the time, this change
makes it so there is less branching (one branch instead of two).

The `code_typed` should also be smaller after this change, as Julia
does not do common subexpression elimination, not without LLVM.

This lays the groundwork to add special cases for `n < 2`, too, without
causing extra branching in case `n > 3`.
All square matrices with less than four rows should now be implemented
in pure Julia, without `ccall`/FFI.

One-element vectors are included, too, being treated as 1x1 matrices.
@codecov

codecov Bot commented Mar 17, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 85.00000% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 94.30%. Comparing base (17325b7) to head (03c168e).
⚠️ Report is 12 commits behind head on master.

Files with missing lines Patch % Lines
src/matmul.jl 85.00% 9 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1563      +/-   ##
==========================================
- Coverage   94.33%   94.30%   -0.04%     
==========================================
  Files          35       35              
  Lines       16007    16043      +36     
==========================================
+ Hits        15100    15129      +29     
- Misses        907      914       +7     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@nsajko

nsajko commented Mar 17, 2026

Copy link
Copy Markdown
Member Author

Benchmarking

Example: small square matrices, including Matrix from Base and FixedSizeMatrixDefault from FixedSizeArrays.jl.

Benchmark script

bench.jl

using LinearAlgebra: LinearAlgebra
using BenchmarkTools: @btime
using Random: seed!, rand!
using FixedSizeArrays: FixedSizeMatrixDefault

function square_arrs(typ::Type, m::Int, n::Int)
    ntuple((_ -> typ(undef, (n, n))), m)
end

const m = parse(Int, ARGS[1])
const max_n = parse(Int, ARGS[2])
const seed = parse(Int, ARGS[3])
const samples = parse(Int, ARGS[4])
const evals = parse(Int, ARGS[5])

@show m
@show max_n
@show seed
@show samples
@show evals

const elt = Float32
const typ_arr = Matrix{elt}
const typ_fsa = FixedSizeMatrixDefault{elt}

global arrs::NTuple{m, typ_arr}
global fsas::NTuple{m, typ_fsa}

for n  0:max_n
    global arrs
    global fsas
    arrs = square_arrs(typ_arr, m, n)
    fsas = square_arrs(typ_fsa, m, n)
    seed!(seed)
    foreach(rand!, arrs)
    for i  eachindex(arrs, fsas)
        fsas[i] .= arrs[i]
    end
    print(' ' ^ 2)
    @show n
    print(' ' ^ 2)
    @btime prod(arrs) seconds=Inf samples=samples evals=evals
    print(' ' ^ 2)
    @btime prod(fsas) seconds=Inf samples=samples evals=evals
end

versioninfo()

versioninfo

Julia Version 1.14.0-DEV.1893
Commit b4aba01002b (2026-03-15 04:25 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × AMD Ryzen 3 5300U with Radeon Graphics
  WORD_SIZE: 64
  LLVM: libLLVM-20.1.8 (ORCJIT, znver2)
  GC: Built with stock GC
Threads: 7 default, 1 interactive, 7 GC (on 8 virtual cores)
Environment:
  JULIA_PRUNE_OLD_LA = true
  JULIA_NUM_PRECOMPILE_TASKS = 2
  JULIA_PKG_PRECOMPILE_AUTO = 0

Performance measurements

Git branch master, ea5b648:

master branch

m = 30
max_n = 6
seed = 123
samples = 2000
evals = 200
  n = 0
    484.810 ns (29 allocations: 1.36 KiB)
    503.540 ns (0 allocations: 0 bytes)
  n = 1
    4.121 μs (58 allocations: 2.27 KiB)
    4.479 μs (29 allocations: 928 bytes)
  n = 2
    1.386 μs (58 allocations: 2.72 KiB)
    1.635 μs (29 allocations: 1.36 KiB)
  n = 3
    1.577 μs (58 allocations: 3.17 KiB)
    1.796 μs (29 allocations: 1.81 KiB)
  n = 4
    4.971 μs (58 allocations: 4.08 KiB)
    5.176 μs (29 allocations: 2.72 KiB)
  n = 5
    6.000 μs (58 allocations: 4.98 KiB)
    6.225 μs (29 allocations: 3.62 KiB)
  n = 6
    6.608 μs (58 allocations: 6.34 KiB)
    6.837 μs (29 allocations: 4.98 KiB)

PR branch, dc7a9d3:

PR branch

m = 30
max_n = 6
seed = 123
samples = 2000
evals = 200
  n = 0
    464.370 ns (29 allocations: 1.36 KiB)
    493.575 ns (0 allocations: 0 bytes)
  n = 1
    1.395 μs (58 allocations: 2.27 KiB)
    1.420 μs (29 allocations: 928 bytes)
  n = 2
    1.524 μs (58 allocations: 2.72 KiB)
    1.721 μs (29 allocations: 1.36 KiB)
  n = 3
    1.703 μs (58 allocations: 3.17 KiB)
    1.800 μs (29 allocations: 1.81 KiB)
  n = 4
    4.887 μs (58 allocations: 4.08 KiB)
    4.992 μs (29 allocations: 2.72 KiB)
  n = 5
    5.833 μs (58 allocations: 4.98 KiB)
    5.888 μs (29 allocations: 3.62 KiB)
  n = 6
    6.521 μs (58 allocations: 6.34 KiB)
    6.560 μs (29 allocations: 4.98 KiB)

Interpretation

Speedups for all n from 0:6, except for 2:3. There should be no change for 2:3.

@dkarrasch dkarrasch left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Do we have tests for the 1x1 case?

Comment thread src/matmul.jl Outdated
Comment thread src/matmul.jl
Comment thread src/matmul.jl
Comment on lines +1239 to +1247
elseif tA_uc == 'S'
if isuppercase(tA) # tA == 'S'
A11 = symmetric(a11, :U)
else
A11 = symmetric(a11, :L)
end
elseif tA_uc == 'H'
if isuppercase(tA) # tA == 'H'
A11 = hermitian(a11, :U)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not covered by the added tests? Not immediately sure how to proceed. I suppose these cases are handled somewhere else, never dispatching to this part of the code.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what that is worth, the 2x2 and 3x3 correspondents (pre-existing) also lack complete coverage in the corresponding method body parts.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be that those never run by the small matmul branch, and always go through the symm/hemm route?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed the Int elements. Nevermind.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants