From 1c72cc246949b88d08dd795bac8490ccd4981f68 Mon Sep 17 00:00:00 2001
From: Chris Elrod
Date: Thu, 25 Apr 2024 22:51:58 -0400
Subject: [PATCH] A few improvements

---
 src/TriangularSolve.jl | 109 ++++++++++++++++++++++++-----------------
 1 file changed, 65 insertions(+), 44 deletions(-)

diff --git a/src/TriangularSolve.jl b/src/TriangularSolve.jl
index 2846073..c8306fe 100644
--- a/src/TriangularSolve.jl
+++ b/src/TriangularSolve.jl
@@ -1,5 +1,5 @@
 module TriangularSolve
-using Base: @nexprs, @ntuple
+using Base: @nexprs, @ntuple, Flatten
 if isdefined(Base, :Experimental) &&
    isdefined(Base.Experimental, Symbol("@max_methods"))
   @eval Base.Experimental.@max_methods 1
@@ -22,13 +22,15 @@ using Polyester
 using StaticArrayInterface
 
 const LPtr{T} = Core.LLVMPtr{T,0}
-_lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x)
-_lptr(x) = x
-_ptr(x::LPtr{T}) where {T} = reinterpret(Ptr{T}, x)
-_ptr(x) = x
-@inline reassemble_tup(::Type{T}, t) where {T} =
-  LoopVectorization.reassemble_tuple(T, map(_ptr, t))
-@inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))
+# _lptr(x::Ptr{T}) where {T} = Base.bitcast(LPtr{T}, x)::LPtr{T}
+# _ptr(x::LPtr{T}) where {T} = Base.bitcast(Ptr{T}, x)::Ptr{T}
+# _lptr(x) = x
+# _ptr(x) = x
+const reassemble_tup = LoopVectorization.reassemble_tuple
+const flatten_to_tup = LoopVectorization.flatten_to_tuple
+# @inline reassemble_tup(::Type{T}, t) where {T} =
+#   LoopVectorization.reassemble_tuple(T, map(_ptr, t))
+# @inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))
 
 @generated function solve_AU(
   A::VecUnroll{Nm1},
@@ -717,7 +719,7 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
   M::Int
   N::Int
 end
-Base.size(A::Mat) = (A.M, A.N)
+Base.size(A::Mat)::Tuple{Int,Int} = (A.M, A.N)::Tuple{Int,Int}
 Base.axes(A::Mat) = (CloseOpen(A.M), CloseOpen(A.N))
 Base.strides(A::Mat{T,true}) where {T} = (1, getfield(A, :x))
 Base.strides(A::Mat{T,false}) where {T} = (getfield(A, :x), 1)
@@ -730,11 +732,15 @@ StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} =
 StaticArrayInterface.static_strides(A::Mat{T,false}) where {T} =
   (getfield(A, :x), static(1))
 StaticArrayInterface.offsets(::Mat) = (static(0), static(0))
-StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) = (static(1), static(2))
-StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) = (static(2), static(1))
+StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) =
+  (static(1), static(2))
+StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) =
+  (static(2), static(1))
 StaticArrayInterface.contiguous_batch_size(::Type{<:Mat}) = static(0)
-StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) = (static(true),static(false))
-StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) = (static(false),static(true))
+StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) =
+  (static(true), static(false))
+StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) =
+  (static(false), static(true))
 StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,true}}) = static(1)
 StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,false}}) = static(2)
 @inline function Base.getindex(
@@ -771,14 +777,16 @@ end
 # C -= A * B
 @inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{false}) # _turbo_! will not be inlined
-  @turbo warn_check_args = false for n in indices((C, B), 2),
-    m in indices((C, A), 1)
+  @inbounds begin
+    @turbo warn_check_args = false for n in indices((C, B), 2),
+      m in indices((C, A), 1)
 
-    Cmn = C[m, n]
-    for k in indices((A, B), (2, 1))
-      Cmn -= A[m, k] * B[k, n]
+      Cmn = zero(eltype(C))
+      for k in indices((A, B), (2, 1))
+        Cmn -= A[m, k] * B[k, n]
+      end
+      C[m, n] += Cmn
     end
-    C[m, n] = Cmn
   end
 end
 @inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{true})
@@ -786,11 +794,11 @@ end
   @tturbo warn_check_args = false for n in indices((C, B), 2),
     m in indices((C, A), 1)
 
-    Cmn = C[m, n]
+    Cmn = zero(eltype(C))
     for k in indices((A, B), (2, 1))
       Cmn -= A[m, k] * B[k, n]
     end
-    C[m, n] = Cmn
+    C[m, n] += Cmn
   end
 end
 @inline function schur_complement!(
@@ -828,15 +836,17 @@ end
 end
 
 function rdiv_block_N!(
-  spa::AbstractStridedPointer{T},
-  spu,
   M,
   N,
   ::Val{UNIT},
-  Bsize = nothing
-) where {T,UNIT}
+  Bsize,
+  ::Type{Args},
+  args::Vararg{Any,K}
+) where {K,Args,UNIT}
+  spa, spu = reassemble_tup(Args, args)
   spa_base = spa
   n = 0
+  T = eltype(spa)
   W = VectorizationBase.pick_vector_width(T)
   B_normalized =
     Bsize === nothing ?
@@ -863,29 +873,36 @@
     )
   end
 end
+_contig_axis(::AbstractStridedPointer{<:Any,2,X}) where {X} = X
 function rdiv_block_MandN!(
-  spa::AbstractStridedPointer{T,<:Any,XA},
-  spu::AbstractStridedPointer{T,<:Any,XU},
   M, N,
-  ::Val{UNIT}
-) where {T,UNIT,XA,XU}
+  ::Val{UNIT},
+  ::Type{Args},
+  args::Vararg{Any,K}
+) where {UNIT,Args,K}
+  spa, spu = reassemble_tup(Args, args)
+  T = eltype(spa)
   B = block_size(Val(T))
   W = VectorizationBase.pick_vector_width(T)
+  XA = _contig_axis(spa)
+  XU = _contig_axis(spu)
   WUF = XA == XU == 2 ? W : W * unroll_factor(W)
   B_m = VectorizationBase.vcld(M, VectorizationBase.vcld(M, B) * WUF) * WUF
   m = 0
   while m < M
     mu = m + B_m
     Mtemp = min(M, mu) - m
-    rdiv_block_N!(
-      spa,
-      spu,
-      Mtemp,
-      N,
-      Val{UNIT}(),
-      VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W
-    )
+    let tup = (spa, spu), ftup = flatten_to_tup(tup)
+      rdiv_block_N!(
+        Mtemp,
+        N,
+        Val{UNIT}(),
+        VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W,
+        typeof(tup),
+        ftup...
+      )
+    end
     spa = gesp(spa, (B_m, StaticInt{0}()))
     m = mu
  end
@@ -905,13 +922,17 @@ function (f::RDivBlockMandNv2{UNIT})(
 ) where {UNIT}
   spa, spu, N, Mrem, Nblock, mtb = allargs
   for block = blockstart-1:blockstop-1
-    rdiv_block_MandN!(
-      gesp(spa, (mtb * block, StaticInt{0}())),
-      spu,
-      Core.ifelse(block == Nblock - 1, Mrem, mtb),
-      N,
-      Val{UNIT}()
-    )
+    let tup = (gesp(spa, (mtb * block, StaticInt{0}())), spu),
+      ftup = flatten_to_tup(tup)
+
+      rdiv_block_MandN!(
+        Core.ifelse(block == Nblock - 1, Mrem, mtb),
+        N,
+        Val{UNIT}(),
+        typeof(tup),
+        ftup...
+      )
+    end
   end
 end
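
Note on the calling convention: the core of this patch is that `rdiv_block_N!` and `rdiv_block_MandN!` no longer take the two strided pointers directly. The call site flattens them with `flatten_to_tup` (now an alias for `LoopVectorization.flatten_to_tuple`) and splats the components, and the callee rebuilds them with `reassemble_tup` (`LoopVectorization.reassemble_tuple`). Below is a minimal sketch of that round trip, assuming only those two helpers; `my_block_kernel!` and `call_block_kernel!` are hypothetical stand-ins for the package's block solvers, and the sketch omits the Ptr/LLVMPtr conversions that the old `_lptr`/`_ptr` wrappers performed.

# Sketch of the flatten/splat/reassemble pattern used by the patch.
# Assumes only the two LoopVectorization helpers referenced in the diff;
# `my_block_kernel!` and `call_block_kernel!` are hypothetical names.
using LoopVectorization: flatten_to_tuple, reassemble_tuple

function my_block_kernel!(M, N, ::Type{Args}, args::Vararg{Any,K}) where {Args,K}
  # Rebuild the original argument tuple (e.g. two strided pointers)
  # from the flat components splatted at the call site.
  spa, spu = reassemble_tuple(Args, args)
  # ... blocked work using spa, spu, M, N would go here ...
  return nothing
end

function call_block_kernel!(spa, spu, M, N)
  tup = (spa, spu)              # keep the original tuple type around
  ftup = flatten_to_tuple(tup)  # decompose into plain, splat-friendly fields
  my_block_kernel!(M, N, typeof(tup), ftup...)
end

Routing the pointers through `typeof(tup)` plus a flat tuple of fields keeps the callee's argument list made of simple components, which is presumably why the patch adopts this shape for the non-inlined block kernels and the Polyester-threaded `RDivBlockMandNv2` callable.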