diff --git a/Project.toml b/Project.toml index 2ee8d80..15ef9e9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TriangularSolve" uuid = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf" authors = ["chriselrod and contributors"] -version = "0.1.21" +version = "0.2.0" [deps] CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9" diff --git a/README.md b/README.md index 0796969..2e5150a 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,74 @@ Platform Info: Environment: JULIA_NUM_THREADS = 8 ``` +Single-threaded benchmarks on an M1 mac: +```julia +julia> N = 100; + +julia> A = rand(N,N); B = rand(N,N); C = similar(A); + +julia> @benchmark TriangularSolve.rdiv!($C, $A, UpperTriangular($B), Val(false)) # false means single threaded +BenchmarkTools.Trial: 10000 samples with 1 evaluation. + Range (min … max): 21.416 μs … 34.458 μs ┊ GC (min … max): 0.00% … 0.00% + Time (median): 21.624 μs ┊ GC (median): 0.00% + Time (mean ± σ): 21.767 μs ± 491.788 ns ┊ GC (mean ± σ): 0.00% ± 0.00% + + ▃ ▆██ ▆▄ ▁ ▃▄ ▄▂ ▁ ▂▃▁ ▂ + ▃▇█▁███▁██▁█▆▁▁▁▁▁▁▁▁▁▁▁▁▁▃█▁██▁███▁▆▃▁▁▆▇▁██▁█▆▅▁▄▃▁▃▃▇▁███ █ + 21.4 μs Histogram: log(frequency) by time 23.2 μs < + + Memory estimate: 0 bytes, allocs estimate: 0. + +julia> @benchmark rdiv!(copyto!($C, $A), UpperTriangular($B)) +BenchmarkTools.Trial: 10000 samples with 1 evaluation. + Range (min … max): 39.124 μs … 57.749 μs ┊ GC (min … max): 0.00% … 0.00% + Time (median): 46.166 μs ┊ GC (median): 0.00% + Time (mean ± σ): 46.274 μs ± 1.766 μs ┊ GC (mean ± σ): 0.00% ± 0.00% + + ▁▁▄▂▆▃█▅▇▄▇▅▃▃▁▃▁▂ + ▂▁▁▂▂▂▂▂▁▂▂▂▂▂▂▃▃▃▃▃▄▄▅▅▆▅▇▇████████████████████▆▇▆▆▅▆▅▅▄▃▃ ▅ + 39.1 μs Histogram: frequency by time 50.2 μs < + Memory estimate: 0 bytes, allocs estimate: 0. + +julia> @benchmark ldiv!($C, LowerTriangular($B), $A) +BenchmarkTools.Trial: 10000 samples with 1 evaluation. 
+ Range (min … max): 48.291 μs … 57.833 μs ┊ GC (min … max): 0.00% … 0.00% + Time (median): 49.124 μs ┊ GC (median): 0.00% + Time (mean ± σ): 49.306 μs ± 802.143 ns ┊ GC (mean ± σ): 0.00% ± 0.00% + + ▁▃▅▆▇██▇██▇▇▆▅▄▂▂▁▁▁▂▁▁▁▁▁▁▁ ▁▁▁ ▃ + ▃████████████████████████████████████▇▆▄▂▄▃▂▃▃▄▄▃▆▅▇▇▇██▇█▇▇ █ + 48.3 μs Histogram: log(frequency) by time 53 μs < + + Memory estimate: 0 bytes, allocs estimate: 0. + +julia> @benchmark TriangularSolve.ldiv!($C, LowerTriangular($B), $A, Val(false)) # false means single threaded +BenchmarkTools.Trial: 10000 samples with 1 evaluation. + Range (min … max): 34.249 μs … 40.208 μs ┊ GC (min … max): 0.00% … 0.00% + Time (median): 34.375 μs ┊ GC (median): 0.00% + Time (mean ± σ): 34.748 μs ± 774.675 ns ┊ GC (mean ± σ): 0.00% ± 0.00% + + ▆██▆▃▄▅▃ ▁▁▄▅▅▃▂▁ ▂▃▂ ▁▂ ▂ + ████████▁▁▃▁▁▁▁▁▃▄▃▁▁▃██████████▇▅▄▅▅▆▄▄▄▄▄▅▄▄▃▅▃▄▃▅█████▇██ █ + 34.2 μs Histogram: log(frequency) by time 37.1 μs < + + Memory estimate: 0 bytes, allocs estimate: 0. +``` +Or +```julia +julia> @benchmark TriangularSolve.ldiv!($C, LowerTriangular($B), $A, Val(false)) # false means single threaded +BenchmarkTools.Trial: 10000 samples with 1 evaluation. + Range (min … max): 23.750 μs … 30.541 μs ┊ GC (min … max): 0.00% … 0.00% + Time (median): 23.875 μs ┊ GC (median): 0.00% + Time (mean ± σ): 23.948 μs ± 316.293 ns ┊ GC (mean ± σ): 0.00% ± 0.00% + + ▃▁▆ █ ▇▆▆ ▄ ▁ ▁ ▁ ▁ ▁ ▂ + ▅███▆█▁███▄█▁██▇▁▄▁▁▁▁▁▃▁▁▁▁▁▁▁▃▁▁▁▃▁▁▁▁▁▆▁▇▆█▁█▁▇▆▅▁▅▁▇▆█▁█ █ + 23.8 μs Histogram: log(frequency) by time 25 μs < + + Memory estimate: 0 bytes, allocs estimate: 0. 
+``` For editing convenience (you can copy/paste the above into a REPL and it should automatically strip `julia> `s and outputs, but the above is less convenient to edit if you want to try changing the benchmarks): ```julia diff --git a/src/TriangularSolve.jl b/src/TriangularSolve.jl index ff9c3c4..5b59214 100644 --- a/src/TriangularSolve.jl +++ b/src/TriangularSolve.jl @@ -1,4 +1,5 @@ module TriangularSolve +using Base: @nexprs, @ntuple if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods")) @eval Base.Experimental.@max_methods 1 @@ -38,7 +39,7 @@ using Polyester end Base.Cartesian.@nexprs $N n -> begin Base.Cartesian.@nexprs n - 1 k -> begin - A_n = Base.FastMath.sub_fast(A_n, Base.FastMath.mul_fast(A_k, U_k_n)) + A_n = vfnmadd_fast(A_k, U_k_n, A_n) end $A_n_expr end @@ -52,92 +53,128 @@ end @inline maybestore!(p, v, i, m) = vstore!(p, v, i, m) @inline maybestore!(::Nothing, v, i, m) = nothing -@inline function store_small_kern!(spa, sp, v, _, i, n, mask, ::Val{true}) +@inline function store_small_kern!(spa, sp, v, i, mask) vstore!(spa, v, i, mask) vstore!(sp, v, i, mask) end -@inline store_small_kern!(spa, ::Nothing, v, spu, i, n, mask, ::Val{true}) = - vstore!(spa, v, i, mask) - -@inline function store_small_kern!(spa, sp, v, spu, i, n, mask, ::Val{false}) - x = v / vload(spu, (n, n)) - vstore!(spa, x, i, mask) - vstore!(sp, x, i, mask) -end -@inline store_small_kern!(spa, ::Nothing, v, spu, i, n, mask, ::Val{false}) = - vstore!(spa, v / vload(spu, (n, n)), i, mask) +@inline store_small_kern!(spa, ::Nothing, v, i, mask) = vstore!(spa, v, i, mask) -@inline function store_small_kern!(spa, sp, v, spu, i, n, ::Val{true}) +@inline function store_small_kern!(spa, sp, v, i) vstore!(spa, v, i) vstore!(sp, v, i) end -@inline store_small_kern!(spa, ::Nothing, v, spu, i, n, ::Val{true}) = - vstore!(spa, v, i) - -@inline function store_small_kern!(spa, sp, v, spu, i, n, ::Val{false}) - x = v / vload(spu, (n, n)) - vstore!(spa, x, i) - 
vstore!(sp, x, i) -end -@inline store_small_kern!(spa, ::Nothing, v, spu, i, n, ::Val{false}) = - vstore!(spa, v / vload(spu, (n, n)), i) +@inline store_small_kern!(spa, ::Nothing, v, i) = vstore!(spa, v, i) -@inline function BdivU_small_kern!( +@generated function BdivU_small_kern!( spa::AbstractStridedPointer{T}, sp, spb::AbstractStridedPointer{T}, spu::AbstractStridedPointer{T}, - N, + ::StaticInt{N}, mask::AbstractMask{W}, ::Val{UNIT} -) where {T,UNIT,W} - # W = VectorizationBase.pick_vector_width(T) - for n ∈ CloseOpen(N) - Amn = vload(spb, (MM{W}(StaticInt(0)), n), mask) - for k ∈ SafeCloseOpen(n) - Amn = vfnmadd_fast( - vload(spa, (MM{W}(StaticInt(0)), k), mask), - vload(spu, (k, n)), - Amn - ) +) where {T,UNIT,W,N} + z = static(0) + if N == 1 + i = (MM{W}(z), z) + Amn = :(vload(spb, $i, mask)) + if !UNIT + Amn = :($Amn / vload(spu, $((z, z)))) + end + quote + $(Expr(:meta, :inline)) + store_small_kern!(spa, sp, $Amn, $i, mask) + end + else + unroll = Unroll{2,1,N,1,W,(-1 % UInt),1}((z, z)) + tostore = :(VecUnroll(Base.Cartesian.@ntuple $N Amn)) + scale = UNIT ? 
nothing : :(Amn_n /= vload(spu, (n - 1, n - 1))) + quote + $(Expr(:meta, :inline)) + Amn = getfield(vload(spb, $unroll, mask), :data) + Base.Cartesian.@nexprs $N n -> begin + Amn_n = getfield(Amn, n) + Base.Cartesian.@nexprs (n - 1) k -> begin + Amn_n = vfnmadd_fast(Amn_k, vload(spu, (k - 1, n - 1)), Amn_n) + end + $scale + end + store_small_kern!(spa, sp, $tostore, $unroll, mask) end - store_small_kern!( - spa, - sp, - Amn, - spu, - (MM{W}(StaticInt(0)), n), - n, - mask, - Val{UNIT}() - ) end end -@inline function BdivU_small_kern_u!( +@generated function BdivU_small_kern_u!( spa::AbstractStridedPointer{T}, sp, spb::AbstractStridedPointer{T}, spu::AbstractStridedPointer{T}, - N, + ::StaticInt{N}, ::StaticInt{U}, - ::Val{UNIT} -) where {T,U,UNIT} - W = Int(VectorizationBase.pick_vector_width(T)) - for n ∈ CloseOpen(N) - Amn = vload(spb, Unroll{1,W,U,1,W,zero(UInt),1}((StaticInt(0), n))) - for k ∈ SafeCloseOpen(n) - Amk = vload(spa, Unroll{1,W,U,1,W,zero(UInt),1}((StaticInt(0), k))) - Amn = vfnmadd_fast(Amk, vload(spu, (k, n)), Amn) + ::Val{UNIT}, + ::StaticInt{W} +) where {T,U,UNIT,N,W} + z = static(0) + if N == 1 + unroll = Unroll{1,W,U,1,W,zero(UInt),1}((z, z)) + Amn = :(vload(spb, $unroll)) + if !UNIT + Amn = :($Amn / vload(spu, $((z, z)))) + end + quote + $(Expr(:meta, :inline)) + store_small_kern!(spa, sp, $Amn, $unroll) + end + else + double_unroll = + Unroll{2,1,N,1,W,zero(UInt),1}(Unroll{1,W,U,1,W,zero(UInt),1}((z, z))) + tostore = :(VecUnroll(Base.Cartesian.@ntuple $N Amn)) + scale = UNIT ? 
nothing : :(Amn_n /= vload(spu, (n - 1, n - 1))) + quote + $(Expr(:meta, :inline)) + Amn = getfield(vload(spb, $double_unroll), :data) + Base.Cartesian.@nexprs $N n -> begin + Amn_n = getfield(Amn, n) + Base.Cartesian.@nexprs (n - 1) k -> begin + Amn_n = vfnmadd_fast(Amn_k, vload(spu, (k - 1, n - 1)), Amn_n) + end + $scale + end + store_small_kern!(spa, sp, $tostore, $double_unroll) end - store_small_kern!( - spa, - sp, - Amn, - spu, - Unroll{1,W,U,1,W,zero(UInt),1}((StaticInt(0), n)), - n, - Val{UNIT}() - ) + end +end +@generated function BdivU_small_kern!( + spa::AbstractStridedPointer{T}, + sp, + spb::AbstractStridedPointer{T}, + spu::AbstractStridedPointer{T}, + Nr::Int, + mask::AbstractMask{W}, + ::Val{UNIT} +) where {T,UNIT,W} + quote + # $(Expr(:meta, :inline)) + Base.Cartesian.@nif $(W - 1) n -> n == Nr n -> + BdivU_small_kern!(spa, sp, spb, spu, static(n), mask, $(Val(UNIT))) + end +end +@generated function BdivU_small_kern_u!( + spa::AbstractStridedPointer{T}, + sp, + spb::AbstractStridedPointer{T}, + spu::AbstractStridedPointer{T}, + Nr::Int, + ::StaticInt{U}, + ::Val{UNIT}, + ::StaticInt{W} +) where {T,U,UNIT,W} + su = static(U) + vu = Val(UNIT) + sw = static(W) + quote + # $(Expr(:meta, :inline)) + Base.Cartesian.@nif $(W - 1) n -> n == Nr n -> + BdivU_small_kern_u!(spa, sp, spb, spu, static(n), $su, $vu, $sw) end end @@ -151,27 +188,34 @@ end ::StaticInt{U}, ::Val{UNIT} ) where {W,U,UNIT} + z = static(0) quote $(Expr(:meta, :inline)) - # here, we just want to load the vectors + # C = A / U; C * U = A + # A_{i,j} = C_{i,j}U_{j,j} + \sum_{k=1}^{j-1} C_{i,k}U_{k,j} + # C_{i,j} = (A_{i,j} - \sum_{k=1}^{j-1} C_{i,k}U_{k,j}) / U_{j,j} + # Load A_{i,j} + # Actually: (A_{i+[0,W*U), j+[0,W)}): + # outer unroll are `W` columns + # Inner unroll are `W*U` rows (U simd vecs) C11 = VectorizationBase.data( vload( spa, - Unroll{2,1,$W,1,$W,zero(UInt),1}( - Unroll{1,$W,$U,1,$W,zero(UInt),1}((StaticInt(0), n)) + $(Unroll{2,1,W,1,W,zero(UInt),1})( + 
$(Unroll{1,W,U,1,W,zero(UInt),1})(($z, n)) ) ) ) Base.Cartesian.@nexprs $W c -> C11_c = C11[c] for nk ∈ SafeCloseOpen(n) # nmuladd - A11 = vload(spc, Unroll{1,$W,$U,1,$W,zero(UInt),1}((StaticInt(0), nk))) + A11 = vload(spc, $(Unroll{1,W,U,1,W,zero(UInt),1})(($(StaticInt(0)), nk))) Base.Cartesian.@nexprs $W c -> C11_c = vfnmadd_fast(A11, vload(spu, (nk, n + (c - 1))), C11_c) end C11vu = - solve_AU(VecUnroll((Base.Cartesian.@ntuple $W C11)), spu, n, Val{$UNIT}()) - i = Unroll{2,1,$W,1,$W,zero(UInt),1}( - Unroll{1,$W,$U,1,$W,zero(UInt),1}((StaticInt(0), n)) + solve_AU(VecUnroll((Base.Cartesian.@ntuple $W C11)), spu, n, $(Val(UNIT))) + i = $(Unroll{2,1,W,1,W,zero(UInt),1})( + $(Unroll{1,W,U,1,W,zero(UInt),1})(($z, n)) ) vstore!(spc, C11vu, i) maybestore!(spb, C11vu, i) @@ -192,54 +236,281 @@ end else :(vstore!(spc, C11, i, mask)) end + z = static(0) quote $(Expr(:meta, :inline)) # here, we just want to load the vectors C11 = VectorizationBase.data( - vload(spa, Unroll{2,1,$W,1,$W,(-1 % UInt),1}((StaticInt(0), n)), mask) + vload(spa, $(Unroll{2,1,W,1,W,(-1 % UInt),1})(($z, n)), mask) ) Base.Cartesian.@nexprs $W c -> C11_c = C11[c] for nk ∈ SafeCloseOpen(n) # nmuladd - A11 = vload(spc, (MM{$W}(StaticInt(0)), nk), mask) + A11 = vload(spc, ($(MM{W}(z)), nk), mask) Base.Cartesian.@nexprs $W c -> C11_c = vfnmadd_fast(A11, vload(spu, (nk, n + (c - 1))), C11_c) end C11 = VecUnroll((Base.Cartesian.@ntuple $W C11)) - C11 = solve_AU(C11, spu, n, Val{$UNIT}()) - i = Unroll{2,1,$W,1,$W,(-1 % UInt),1}((StaticInt(0), n)) + C11 = solve_AU(C11, spu, n, $(Val(UNIT))) + i = $(Unroll{2,1,W,1,W,(-1 % UInt),1})(($z, n)) $storecexpr maybestore!(spb, C11, i, mask) end end +@generated function ldiv_solve_W_u!( + spc, + spa, + spu, + n, + ::StaticInt{W}, + ::StaticInt{U}, + ::Val{UNIT} +) where {W,U,UNIT} + z = static(0) + quote + # $(Expr(:meta, :inline)) + # C = L \ A; L * C = A + # A_{i,j} = L_{i,i}*C_{i,j} + \sum_{k=1}^{i-1}L_{i,k}C_{k,j} + # C_{i,j} = L_{i,i} \ (A_{i,j} - 
\sum_{k=1}^{i-1}L_{i,k}C_{k,j}) + # The inputs here are transposed, as the library was formulated in terms of `rdiv!`, + # so we have + # C_{j,i} = (A_{j,i} - \sum_{k=1}^{i-1}C_{j,k}U_{k,i}) / L_{i,i} + # This solves for the block: C_{j+[0,W],i+[0,W*U)} + # This can be viewed as `U` blocks that are each `W`x`W` + # E.g. U=3, rough alg: + # r=[0,W); c=[0,WU) + # X = A_{j+r,i+c} - \sum_{k=1}^{i-1}C_{j+r,k}*U_{k,i+c} + # C_{j+r,i+r} = X[:, r] / U_{i+r,i+r} + # C_{j+r,i+W+r} = (X[:, W+r] - C_{j+r,i+r}*U_{i+r,i+W+r}) / U_{i+W+r,i+W+r} + # C_{j+r,i+2W+r} = (X[:, 2W+r] - C_{j+r,i+r}*U_{i+r,i+2W+r} - C_{j+r,i+W+r}*U_{i+W+r,i+2W+r}) / U_{i+2W+r,i+2W+r} + # + # outer unroll are `W` rows + # Inner unroll are `W*U` columns (U simd vecs) + # + A11 = getfield( + vload( + spa, + $(Unroll{1,1,W,2,W,zero(UInt),1})( + $(Unroll{2,W,U,2,W,zero(UInt),1})(($z, n)) + ) + ), + :data + ) + # The `W` rows + Base.Cartesian.@nexprs $W c -> A11_c = getfield(A11, c) + # compute + # A_{j,i} - \sum_{k=1}^{i-1}U_{k,i}C_{j,k}) + # Each iter: + # A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)} + for nk ∈ SafeCloseOpen(n) # nmuladd + U_ki = vload(spu, $(Unroll{2,W,U,2,W,zero(UInt),1})((nk, n))) + Base.Cartesian.@nexprs $W c -> + A11_c = vfnmadd_fast(U_ki, vload(spc, (static(c - 1), nk)), A11_c) + end + # solve AU wants: + # outer unroll are `W` columns + # Inner unroll are `W` rows (U simd vecs) + # So, we'll use `U = 1`, and transpose blocks + # We then have column-major multiplies + Base.Cartesian.@nexprs $U u -> begin + # take A[(u-1)*W,u*W), [0,W)] + X_u = getfield( + VectorizationBase.transpose_vecunroll( + VecUnroll( + Base.Cartesian.@ntuple $W w -> + getfield(getfield(A11_w, :data), u) + ) + ), + :data + ) + Base.Cartesian.@nexprs $W c -> X_u_c = getfield(X_u, c) + Base.Cartesian.@nexprs (u - 1) j -> begin + # subtract + # r = W*(j-1)+[0,W) + # A_{j+[0,W),i+r} -= C_{j+[0,W),r}*U_{r,i+r} + # W x W matmul + Base.Cartesian.@nexprs $W k -> begin # reduction + Base.Cartesian.@nexprs $W c -> 
begin # cols + U_u_j_k_c = vload( + spu, + (n + ((k - 1) + ((j - 1) * $W)), n + ((c - 1) + ((u - 1) * $W))) + ) + X_u_c = vfnmadd_fast(C_j_k, U_u_j_k_c, X_u_c) + end + end + end + C_u = solve_AU( + VecUnroll(Base.Cartesian.@ntuple $W X_u), + spu, + n + ((u - 1) * $W), + $(Val(UNIT)) + ) + Cdata_u = getfield(C_u, :data) + Base.Cartesian.@nexprs $W c -> C_u_c = getfield(Cdata_u, c) + end + # store at end (no aliasing) + Base.Cartesian.@nexprs $U u -> begin + vstore!(spc, C_u, $(Unroll{2,1,W,1,W,zero(UInt),1})(($z, n + (u - 1) * $W))) + end + end +end +@generated function ldiv_solve_W!( + spc, + spa, + spu, + n, + ::StaticInt{W}, + ::Val{UNIT} +) where {W,UNIT} + z = static(0) + quote + # $(Expr(:meta, :inline)) + # Like `ldiv_solve_W_u!`, except no unrolling, just a `W`x`W` block + # + # C = L \ A; L * C = A + # A_{i,j} = L_{i,i}*C_{i,j} + \sum_{k=1}^{i-1}L_{i,k}C_{k,j} + # C_{i,j} = L_{i,i} \ (A_{i,j} - \sum_{k=1}^{i-1}L_{i,k}C_{k,j}) + # The inputs here are transposed, as the library was formulated in terms of `rdiv!`, + # so we have + # C_{j,i} = (A_{j,i} - \sum_{k=1}^{i-1}C_{j,k}U_{k,i}) / L_{i,i} + # This solves for the block: C_{j+[0,W],i+[0,W)} + # Rough alg: + # r=[0,W); + # X = A_{j+r,i+r} - \sum_{k=1}^{i-1}C_{j+r,k}*U_{k,i+r} + # C_{j+r,i+r} = X / U_{i+r,i+r} + # + # Load the `W`x`W` block... + # what about masking? 
+ A11 = + getfield(vload(spa, $(Unroll{1,1,W,2,W,zero(UInt),1})(($z, n))), :data) + # The `W` rows + Base.Cartesian.@nexprs $W c -> A11_c = getfield(A11, c) + # compute + # A_{j,i} - \sum_{k=1}^{i-1}U_{k,i}C_{j,k}) + # Each iter: + # A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)} + for nk ∈ SafeCloseOpen(n) # nmuladd + U_ki = vload(spu, (nk, $(MM{W})(n))) + Base.Cartesian.@nexprs $W c -> + A11_c = vfnmadd_fast(U_ki, vload(spc, (static(c - 1), nk)), A11_c) + end + # Base.Cartesian.@nexprs $W c -> @show A11_c + # solve AU wants us to transpose + # We then have column-major multiplies + # take A[(u-1)*W,u*W), [0,W)] + X = VectorizationBase.transpose_vecunroll( + VecUnroll(Base.Cartesian.@ntuple $W A11) + ) + # @show X + C_u = solve_AU(X, spu, n, $(Val(UNIT))) + vstore!(spc, C_u, $(Unroll{2,1,W,1,W,zero(UInt),1})(($z, n))) + end +end +@generated function ldiv_solve_W!( + spc, + spa, + spu, + n, + ::StaticInt{W}, + ::Val{UNIT}, + ::StaticInt{R} +) where {W,UNIT,R} + R <= 1 && throw("Remainder of `<= 1` shouldn't be called, but had $R.") + R >= W && throw("Remainder of `>= $W` shouldn't be called, but had $R.") + z = static(0) + q = quote + # $(Expr(:meta, :inline)) + # Like `ldiv_solve_W_u!`, except no unrolling, just an `R`x`W` block + # + # C = L \ A; L * C = A + # A_{i,j} = L_{i,i}*C_{i,j} + \sum_{k=1}^{i-1}L_{i,k}C_{k,j} + # C_{i,j} = L_{i,i} \ (A_{i,j} - \sum_{k=1}^{i-1}L_{i,k}C_{k,j}) + # The inputs here are transposed, as the library was formulated in terms of `rdiv!`, + # so we have + # C_{j,i} = (A_{j,i} - \sum_{k=1}^{i-1}C_{j,k}U_{k,i}) / L_{i,i} + # This solves for the block: C_{j+[0,R],i+[0,W)} + # Rough alg: + # r=[0,R); w=[0,W); + # X = A_{j+r,i+w} - \sum_{k=1}^{i-1}C_{j+r,k}*U_{k,i+w} + # C_{j+r,i+w} = X / U_{i+r,i+w} + # + # Load the `R`x`W` block... + # what about masking? 
+ A11 = + getfield(vload(spa, $(Unroll{1,1,R,2,W,zero(UInt),1})(($z, n))), :data) + # The `W` rows + Base.Cartesian.@nexprs $R r -> A11_r = getfield(A11, r) + # compute + # A_{j,i} - \sum_{k=1}^{i-1}U_{k,i}C_{j,k}) + # Each iter: + # A_{j+[0,W), i+[0,W*U)} -= C_{j+[0,W),k}*U_{k,i+[0,W*U)} + for nk ∈ SafeCloseOpen(n) # nmuladd + U_ki = vload(spu, (nk, $(MM{W})(n))) + Base.Cartesian.@nexprs $R r -> + A11_r = vfnmadd_fast(U_ki, vload(spc, (static(r - 1), nk)), A11_r) + end + end + # pad with zeros + Wpad = VectorizationBase.nextpow2(R) + t = Expr(:tuple) + for r = 1:R + push!(t.args, Symbol(:A11_, r)) + end + for _ = R+1:Wpad + push!(t.args, :(zero(A11_1))) + end + q2 = quote + # solve AU wants us to transpose + # We then have column-major multiplies + # take A[(u-1)*W,u*W), [0,W)] + X = VectorizationBase.transpose_vecunroll(VecUnroll($t)) + C_u = solve_AU(X, spu, n, $(Val(UNIT))) + end + push!(q.args, q2) + q3 = if R == Wpad + quote + i = $(Unroll{2,1,W,1,Wpad,zero(UInt),1})(($z, n)) + vstore!(spc, C_u, i) + end + else + quote + mask = VectorizationBase.mask($(static(Wpad)), $(static(R))) + i = $(Unroll{2,1,W,1,Wpad,(-1 % UInt),1})(($z, n)) + vstore!(spc, C_u, i, mask) + end + end + push!(q.args, q3) + return q +end + @inline function rdiv_U!( spc::AbstractStridedPointer{T}, spa::AbstractStridedPointer, spu::AbstractStridedPointer, M, N, - ::StaticInt{1}, ::Val{UNIT} ) where {T,UNIT} WS = pick_vector_width(T) W = Int(WS) UF = unroll_factor(WS) WU = UF * WS - MU = UF > 1 ? 
M : 0 Nd, Nr = VectorizationBase.vdivrem(N, WS) m = 0 - while m < MU - WU + 1 - n = Nr - if n > 0 - BdivU_small_kern_u!(spc, nothing, spa, spu, n, UF, Val(UNIT)) - end - for i ∈ 1:Nd - rdiv_solve_W_u!(spc, nothing, spa, spu, n, WS, UF, Val(UNIT)) - n += W + if UF > 1 + while m < M - WU + 1 + n = Nr + if n > 0 + BdivU_small_kern_u!(spc, nothing, spa, spu, n, UF, Val(UNIT), WS) + end + for _ ∈ 1:Nd + rdiv_solve_W_u!(spc, nothing, spa, spu, n, WS, UF, Val(UNIT)) + n += W + end + m += WU + spa = gesp(spa, (WU, StaticInt(0))) + spc = gesp(spc, (WU, StaticInt(0))) end - m += WU - spa = gesp(spa, (WU, StaticInt(0))) - spc = gesp(spc, (WU, StaticInt(0))) end finalmask = VectorizationBase.mask(WS, M) while m < M @@ -261,59 +532,6 @@ end nothing end -const buffer = Ref{Ptr{Cvoid}}(C_NULL) - -function __init__() - bp_size = 2 * sizeof(Int) * Threads.nthreads() - buffer[] = bp = Libc.malloc(bp_size) - Libc.memset(bp, 0, bp_size) -end - -function _get_buffer_pointer(::StaticInt{UF}, N) where {UF} - RS = VectorizationBase.register_size() - RSUF = StaticInt{UF}() * RS - L = RSUF * N - tid = Threads.threadid() - 1 - bp = Ptr{Pair{Ptr{Cvoid},Int}}(buffer[]) + 2sizeof(Int) * tid - (p, buff_current) = unsafe_load(bp) - if buff_current < L - p == C_NULL || Libc.free(p) - buff_size = max(RSUF * 128, L) - p = Libc.malloc(Int(buff_size + RS - 1)) - unsafe_store!(bp, p => buff_size) - end - return VectorizationBase.align(p, RS) -end - -@inline function lubuffer(::Val{T}, ::StaticInt{UF}, N) where {T,UF} - RS = VectorizationBase.register_size() - RSUF = StaticInt{UF}() * RS - ptr = Ptr{T}(_get_buffer_pointer(StaticInt{UF}(), N)) - si = StrideIndex{2,(1, 2),1}( - (VectorizationBase.static_sizeof(T), RSUF), - (StaticInt(0), StaticInt(0)) - ) - stridedpointer(ptr, si, StaticInt{0}()), nothing -end -@inline function lubuffer( - ::Val{T}, - ::StaticInt{UF}, - ::StaticInt{N} -) where {T,UF,N} - RSUF = StaticInt{UF}() * VectorizationBase.pick_vector_width(T) - L = RSUF * N - buf = 
Ref{NTuple{L,T}}() - ptr = Base.unsafe_convert(Ptr{T}, buf) - si = StrideIndex{2,(1, 2),1}( - ( - VectorizationBase.static_sizeof(T), - RSUF * VectorizationBase.static_sizeof(T) - ), - (StaticInt(0), StaticInt(0)) - ) - stridedpointer(ptr, si, StaticInt{0}()), buf -end -@inline _free(p::Ptr) = Libc.free(p) _canonicalize(x) = signed(x) _canonicalize(::StaticInt{N}) where {N} = StaticInt{N}() function div_dispatch!( @@ -336,36 +554,11 @@ function div_dispatch!( GC.@preserve spap spcp spup begin mtb = m_thread_block_size(M, N, nthread, Val(T)) if nthread > 1 - (M > mtb) && return multithread_rdiv!( - spc, - spa, - spu, - M, - N, - mtb, - Val(UNIT), - VectorizationBase.contiguous_axis(A) - ) + (M > mtb) && return multithread_rdiv!(spc, spa, spu, M, N, mtb, Val(UNIT)) elseif N > block_size(Val(T)) - return rdiv_block_MandN!( - spc, - spa, - spu, - M, - N, - Val(UNIT), - VectorizationBase.contiguous_axis(A) - ) + return rdiv_block_MandN!(spc, spa, spu, M, N, Val(UNIT)) end - return rdiv_U!( - spc, - spa, - spu, - M, - N, - VectorizationBase.contiguous_axis(A), - Val(UNIT) - ) + return rdiv_U!(spc, spa, spu, M, N, Val(UNIT)) end end @@ -582,9 +775,8 @@ function rdiv_block_N!( M, N, ::Val{UNIT}, - ::StaticInt{X}, Bsize = nothing -) where {T,UNIT,X} +) where {T,UNIT} spa_rdiv = spa spc_base = spc n = 0 @@ -605,7 +797,6 @@ function rdiv_block_N!( gesp(spu, (n, StaticInt{0}())), M, N_temp, - StaticInt{X}(), Val{UNIT}() ) repeat || break @@ -620,17 +811,16 @@ function rdiv_block_N!( end end function rdiv_block_MandN!( - spc::AbstractStridedPointer{T}, - spa, - spu, + spc::AbstractStridedPointer{T,<:Any,XC}, + spa::AbstractStridedPointer{T,<:Any,XA}, + spu::AbstractStridedPointer{T,<:Any,XU}, M, N, - ::Val{UNIT}, - ::StaticInt{X} -) where {T,UNIT,X} + ::Val{UNIT} +) where {T,UNIT,XC,XA,XU} B = block_size(Val(T)) W = VectorizationBase.pick_vector_width(T) - WUF = W * unroll_factor(W) + WUF = XC == XA == XU == 2 ? 
W : W * unroll_factor(W) B_m = VectorizationBase.vcld(M, VectorizationBase.vcld(M, B) * WUF) * WUF m = 0 while m < M @@ -643,7 +833,6 @@ function rdiv_block_MandN!( Mtemp, N, Val{UNIT}(), - StaticInt{X}(), VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W ) spa = gesp(spa, (B_m, StaticInt{0}())) @@ -658,12 +847,12 @@ function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T} min(M, VectorizationBase.vcld(M, nb * W) * W) end -struct RDivBlockMandNv2{UNIT,X} end -function (f::RDivBlockMandNv2{UNIT,X})( +struct RDivBlockMandNv2{UNIT} end +function (f::RDivBlockMandNv2{UNIT})( allargs, blockstart, blockstop -) where {UNIT,X} +) where {UNIT} spc, spa, spu, N, Mrem, Nblock, mtb = allargs for block = blockstart-1:blockstop-1 rdiv_block_MandN!( @@ -672,8 +861,7 @@ function (f::RDivBlockMandNv2{UNIT,X})( spu, Core.ifelse(block == Nblock - 1, Mrem, mtb), N, - Val{UNIT}(), - static(X) + Val{UNIT}() ) end end @@ -685,16 +873,14 @@ function multithread_rdiv!( M::Int, N::Int, mtb::Int, - ::Val{UNIT}, - ::StaticInt{X} -) where {X,UNIT,TC,TA,TU} + ::Val{UNIT} +) where {UNIT,TC,TA,TU} # Main._a[] = (spc, spa, spu, M, N, mtb, Val(UNIT), static(X)); (Md, Mr) = VectorizationBase.vdivrem(M, mtb) Nblock = Md + (Mr ≠ 0) Mrem = Core.ifelse(Mr ≠ 0, Mr, mtb) - f = RDivBlockMandNv2{UNIT,X}() batch( - f, + RDivBlockMandNv2{UNIT}(), (Nblock, min(Nblock, Threads.nthreads())), spc, spa, @@ -720,56 +906,168 @@ function unroll_factor(::StaticInt{W}) where {W} ifelse(Static.lt(num_blocks, StaticInt{1}()), StaticInt{1}(), num_blocks) end +@generated function _ldiv_remainder!( + spc, + spa, + spu, + M, + N, + m, + Nr, + ::StaticInt{W}, + ::Val{UNIT}, + ::StaticInt{r} +) where {W,UNIT,r} + r <= 0 && throw("Remainder of `<= 0` shouldn't be called, but had $r.") + r >= W && throw("Remainder of `>= $W` shouldn't be called, but had $r.") + if r == 1 + z = static(0) + sub = Base.FastMath.sub_fast + mul = Base.FastMath.mul_fast + div = Base.FastMath.div_fast + vlxj = :(vload(spc, ($z, 
j))) + if UNIT + vlxj = :(xj = $vlxj) + else + vlxj = quote + xj = $div($vlxj, vload(spu, (j, j))) + vstore!(spc, xj, ($z, j)) + end + end + quote + $(Expr(:meta, :inline)) + if pointer(spc) != pointer(spa) + for n = 0:N-1 + vstore!(spc, vload(spa, ($z, n)), ($z, n)) + end + end + for j = 0:N-1 + $vlxj + for i = (j+1):N-1 + xi = vload(spc, ($z, i)) + Uji = vload(spu, (j, i)) + vstore!(spc, $sub(xi, $mul(xj, Uji)), ($z, i)) + end + end + end + else + WS = static(W) + quote + $(Expr(:meta, :inline)) + n = Nr # non factor of W remainder + if n > 0 + mask = $(VectorizationBase.mask(WS, r)) + BdivU_small_kern!(spc, nothing, spa, spu, n, mask, $(Val(UNIT))) + end + # while n < N - $(W * U - 1) + # ldiv_solve_W_u!(spc, spa, spu, n, $WS, $US, Val(UNIT), Val(r)) + # n += $(W * U) + # end + while n != N + ldiv_solve_W!(spc, spa, spu, n, $WS, $(Val(UNIT)), $(StaticInt(r))) + n += $W + end + end + end +end +@generated function ldiv_remainder!( + spc, + spa, + spu, + M, + N, + m, + Nr, + ::StaticInt{W}, + # ::Val{U}, + ::Val{UNIT} +) where {W,UNIT} + WS = static(W) + # US = static(U) + if W == 2 + quote + $(Expr(:meta, :inline)) + _ldiv_remainder!( + spc, + spa, + spu, + M, + N, + m, + Nr, + $WS, + $(Val(UNIT)), + $(static(1)) + ) + end + else + quote + # $(Expr(:meta, :inline)) + Base.Cartesian.@nif $(W - 1) w -> m == M - w w -> _ldiv_remainder!( + spc, + spa, + spu, + M, + N, + m, + Nr, + $WS, + $(Val(UNIT)), + StaticInt(w) + ) + end + end +end + +# spc = spa / spu +# spc' = (spu' \ spa')' +# This is ldiv function rdiv_U!( - spc::AbstractStridedPointer{T}, - spa::AbstractStridedPointer, - spu::AbstractStridedPointer, + spc::AbstractStridedPointer{T,2,2}, + spa::AbstractStridedPointer{T,2,2}, + spu::AbstractStridedPointer{T,2,2}, M, N, - ::StaticInt{var"#UNUSED#"}, ::Val{UNIT} -) where {T,UNIT,var"#UNUSED#"} +) where {T,UNIT} WS = pick_vector_width(T) W = Int(WS) UF = unroll_factor(WS) WU = UF * WS MU = UF > 1 ? 
M : 0 Nd, Nr = VectorizationBase.vdivrem(N, WS) - spb, preserve = lubuffer(Val(T), UF, N) m = 0 - GC.@preserve preserve begin - while m < MU - WU + 1 - n = Nr - if n > 0 - BdivU_small_kern_u!(spb, spc, spa, spu, n, UF, Val(UNIT)) - end - for _ ∈ 1:Nd - rdiv_solve_W_u!(spb, spc, spa, spu, n, WS, UF, Val(UNIT)) - n += W - end - m += WU - spa = gesp(spa, (WU, StaticInt(0))) - spc = gesp(spc, (WU, StaticInt(0))) + # @show M,N + # m, no remainder + while m < M - WS + 1 + n = Nr # non factor of W remainder + if n > 0 + BdivU_small_kern_u!( + spc, + nothing, + spa, + spu, + n, + StaticInt(1), + Val(UNIT), + WS + ) end - finalmask = VectorizationBase.mask(WS, M) - while m < M - ubm = m + W - nomaskiter = ubm < M - mask = nomaskiter ? VectorizationBase.max_mask(WS) : finalmask - n = Nr - if n > 0 - BdivU_small_kern!(spb, spc, spa, spu, n, mask, Val(UNIT)) - end - for i ∈ 1:Nd - # @show C, n - rdiv_solve_W!(spb, spc, spa, spu, n, i ≠ Nd, mask, Val(UNIT)) - n += W - end - spa = gesp(spa, (WS, StaticInt(0))) - spc = gesp(spc, (WS, StaticInt(0))) - m = ubm + while n < N - (WU - 1) + ldiv_solve_W_u!(spc, spa, spu, n, WS, UF, Val(UNIT)) + n += WU + end + while n != N + ldiv_solve_W!(spc, spa, spu, n, WS, Val(UNIT)) + n += W end + m += W + spa = gesp(spa, (W, StaticInt(0))) + spc = gesp(spc, (W, StaticInt(0))) end + # remainder on `m` + m < M && ldiv_remainder!(spc, spa, spu, M, N, m, Nr, WS, Val(UNIT)) + # m < M && ldiv_remainder!(spc, spa, spu, M, N, m, Nr, WS, UF, Val(UNIT)) nothing end diff --git a/test/runtests.jl b/test/runtests.jl index d9725c6..f8e1b7d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,31 @@ using TriangularSolve, LinearAlgebra using Test +function check_box_for_nans(A, M, N) + # blocks start at 17, and are MxN + @test all(isnan, @view(A[1:16, :])) + @test all(isnan, @view(A[17+M:end, :])) + @test all(isnan, @view(A[17:16+M, 1:16])) + @test all(isnan, @view(A[17:16+M, 17+N:end])) +end + function test_solve(::Type{T}) where {T} - for n ∈ 1:(T 
=== Float32 ? 100 : 200) + maxN = (T === Float32 ? 100 : 200) + maxM = maxN + 10 + AA = fill(T(NaN), maxM + 32, maxM + 32) + RR = fill(T(NaN), maxM + 32, maxM + 32) + BB = fill(T(NaN), maxN + 32, maxN + 32) + for n ∈ 1:maxN @show n for m ∈ max(1, n - 10):n+10 - A = rand(T, m, n) - res = similar(A) - B = rand(T, n, n) + I + A = @view AA[17:16+m, 17:16+n] + res = @view RR[17:16+m, 17:16+n] + B = @view BB[17:16+n, 17:16+n] + + A .= rand.(T) + B .= rand.(T) + @view(B[diagind(B)]) .+= one(T) + @test TriangularSolve.rdiv!(res, A, UpperTriangular(B)) * UpperTriangular(B) ≈ A @test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B)) * @@ -16,8 +34,15 @@ function test_solve(::Type{T}) where {T} UpperTriangular(B) ≈ A @test TriangularSolve.rdiv!(res, A, UnitUpperTriangular(B), Val(false)) * UnitUpperTriangular(B) ≈ A - A = rand(T, n, m) - res = similar(A) + + check_box_for_nans(RR, m, n) + res .= NaN + A .= NaN + + A = @view AA[17:16+n, 17:16+m] + res = @view RR[17:16+n, 17:16+m] + A .= rand.(T) + @test LowerTriangular(B) * TriangularSolve.ldiv!(res, LowerTriangular(B), A) ≈ A @test UnitLowerTriangular(B) * @@ -27,6 +52,10 @@ function test_solve(::Type{T}) where {T} @test UnitLowerTriangular(B) * TriangularSolve.ldiv!(res, UnitLowerTriangular(B), A, Val(false)) ≈ A + check_box_for_nans(RR, n, m) + res .= NaN + A .= NaN + B .= NaN end end end @@ -41,8 +70,5 @@ end end using Aqua -Aqua.test_all( - TriangularSolve; - ambiguities = false -) +Aqua.test_all(TriangularSolve; ambiguities = false) @test isempty(Test.detect_ambiguities(TriangularSolve))