A few improvements
chriselrod committed Apr 26, 2024
1 parent 1d2b77b commit 1c72cc2
Showing 1 changed file with 65 additions and 44 deletions.
109 changes: 65 additions & 44 deletions src/TriangularSolve.jl
@@ -1,5 +1,5 @@
module TriangularSolve
using Base: @nexprs, @ntuple
using Base: @nexprs, @ntuple, Flatten
if isdefined(Base, :Experimental) &&
isdefined(Base.Experimental, Symbol("@max_methods"))
@eval Base.Experimental.@max_methods 1
@@ -22,13 +22,15 @@ using Polyester
using StaticArrayInterface

const LPtr{T} = Core.LLVMPtr{T,0}
_lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x)
_lptr(x) = x
_ptr(x::LPtr{T}) where {T} = reinterpret(Ptr{T}, x)
_ptr(x) = x
@inline reassemble_tup(::Type{T}, t) where {T} =
LoopVectorization.reassemble_tuple(T, map(_ptr, t))
@inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))
# _lptr(x::Ptr{T}) where {T} = Base.bitcast(LPtr{T}, x)::LPtr{T}
# _ptr(x::LPtr{T}) where {T} = Base.bitcast(Ptr{T}, x)::Ptr{T}
# _lptr(x) = x
# _ptr(x) = x
const reassemble_tup = LoopVectorization.reassemble_tuple
const flatten_to_tup = LoopVectorization.flatten_to_tuple
# @inline reassemble_tup(::Type{T}, t) where {T} =
# LoopVectorization.reassemble_tuple(T, map(_ptr, t))
# @inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))

@generated function solve_AU(
A::VecUnroll{Nm1},
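A sketch of what the hunk above relies on (not part of the diff): the hand-rolled `_lptr`/`_ptr` reinterpret wrappers are commented out and `reassemble_tup`/`flatten_to_tup` become plain aliases of LoopVectorization's `reassemble_tuple`/`flatten_to_tuple`, whose job is to pass structured arguments across a function barrier as a flat tuple of simple leaves and rebuild them on the other side. Below is a minimal, self-contained illustration of that round trip; `flatten`, `reassemble`, and `kernel!` are stand-in names, not the package's API.

flatten(t::Tuple) = t                         # stand-in: the real helper lowers fields to bits/pointers
reassemble(::Type{T}, leaves::Tuple) where {T} = leaves::T  # stand-in: rebuilds the original tuple type
function kernel!(::Type{Args}, args::Vararg{Any,K}) where {Args,K}
  a, b = reassemble(Args, args)               # recover the structured arguments inside the barrier
  a .+= b                                     # stand-in for the real work
  return a
end
a, b = rand(4, 4), rand(4, 4)
tup = (a, b)
kernel!(typeof(tup), flatten(tup)...)         # splat only simple leaves across the call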
@@ -717,7 +719,7 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
M::Int
N::Int
end
Base.size(A::Mat) = (A.M, A.N)
Base.size(A::Mat)::Tuple{Int,Int} = (A.M, A.N)::Tuple{Int,Int}
Base.axes(A::Mat) = (CloseOpen(A.M), CloseOpen(A.N))
Base.strides(A::Mat{T,true}) where {T} = (1, getfield(A, :x))
Base.strides(A::Mat{T,false}) where {T} = (getfield(A, :x), 1)
@@ -730,11 +732,15 @@ StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} =
StaticArrayInterface.static_strides(A::Mat{T,false}) where {T} =
(getfield(A, :x), static(1))
StaticArrayInterface.offsets(::Mat) = (static(0), static(0))
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) = (static(1), static(2))
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) = (static(2), static(1))
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) =
(static(1), static(2))
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) =
(static(2), static(1))
StaticArrayInterface.contiguous_batch_size(::Type{<:Mat}) = static(0)
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) = (static(true),static(false))
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) = (static(false),static(true))
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) =
(static(true), static(false))
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) =
(static(false), static(true))
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,true}}) = static(1)
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,false}}) = static(2)
@inline function Base.getindex(
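As context for the trait methods above (a sketch, not part of the diff): `Mat` uses zero-based `CloseOpen` axes and reports strides `(1, x)` when `ColMajor` is true and `(x, 1)` when it is false; the `stride_rank`, `dense_dims`, and `contiguous_axis` methods restate that layout for StaticArrayInterface. The index arithmetic those strides imply, with an assumed leading dimension of 8:

offset_colmajor(i, j, x) = i + x * j   # strides (1, x): column-major
offset_rowmajor(i, j, x) = x * i + j   # strides (x, 1): row-major
offset_colmajor(2, 3, 8)               # 26
offset_rowmajor(2, 3, 8)               # 19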
@@ -771,26 +777,28 @@ end
# C -= A * B
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{false})
# _turbo_! will not be inlined
@turbo warn_check_args = false for n in indices((C, B), 2),
m in indices((C, A), 1)
@inbounds begin
@turbo warn_check_args = false for n in indices((C, B), 2),
m in indices((C, A), 1)

Cmn = C[m, n]
for k in indices((A, B), (2, 1))
Cmn -= A[m, k] * B[k, n]
Cmn = zero(eltype(C))
for k in indices((A, B), (2, 1))
Cmn -= A[m, k] * B[k, n]
end
C[m, n] += Cmn
end
C[m, n] = Cmn
end
end
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{true})
# _turbo_! will not be inlined
@tturbo warn_check_args = false for n in indices((C, B), 2),
m in indices((C, A), 1)

Cmn = C[m, n]
Cmn = zero(eltype(C))
for k in indices((A, B), (2, 1))
Cmn -= A[m, k] * B[k, n]
end
C[m, n] = Cmn
C[m, n] += Cmn
end
end
@inline function schur_complement!(
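Both `_schur_complement!` methods in the hunk above now start the accumulator at zero and add the accumulated correction into `C` at the end, rather than loading `C[m, n]` into the accumulator before the inner loop. A plain-Julia reference for the same `C -= A * B` update, without `@turbo`/`@tturbo` (a sketch, not part of the diff):

function schur_complement_ref!(C, A, B)
  @inbounds for n in axes(C, 2), m in axes(C, 1)
    Cmn = zero(eltype(C))
    for k in axes(A, 2)                # requires axes(A, 2) == axes(B, 1)
      Cmn -= A[m, k] * B[k, n]
    end
    C[m, n] += Cmn                     # apply the -A*B correction once per element
  end
  return C
end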
@@ -828,15 +836,17 @@ end
end

function rdiv_block_N!(
spa::AbstractStridedPointer{T},
spu,
M,
N,
::Val{UNIT},
Bsize = nothing
) where {T,UNIT}
Bsize,
::Type{Args},
args::Vararg{Any,K}
) where {K,Args,UNIT}
spa, spu = reassemble_tup(Args, args)
spa_base = spa
n = 0
T = eltype(spa)
W = VectorizationBase.pick_vector_width(T)
B_normalized =
Bsize === nothing ?
@@ -863,29 +873,36 @@
)
end
end
_contig_axis(::AbstractStridedPointer{<:Any,2,X}) where {X} = X
function rdiv_block_MandN!(
spa::AbstractStridedPointer{T,<:Any,XA},
spu::AbstractStridedPointer{T,<:Any,XU},
M,
N,
::Val{UNIT}
) where {T,UNIT,XA,XU}
::Val{UNIT},
::Type{Args},
args::Vararg{Any,K}
) where {UNIT,Args,K}
spa, spu = reassemble_tup(Args, args)
T = eltype(spa)
B = block_size(Val(T))
W = VectorizationBase.pick_vector_width(T)
XA = _contig_axis(spa)
XU = _contig_axis(spu)
WUF = XA == XU == 2 ? W : W * unroll_factor(W)
B_m = VectorizationBase.vcld(M, VectorizationBase.vcld(M, B) * WUF) * WUF
m = 0
while m < M
mu = m + B_m
Mtemp = min(M, mu) - m
rdiv_block_N!(
spa,
spu,
Mtemp,
N,
Val{UNIT}(),
VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W
)
let tup = (spa, spu), ftup = flatten_to_tup(tup)
rdiv_block_N!(
Mtemp,
N,
Val{UNIT}(),
VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W,
typeof(tup),
ftup...
)
end
spa = gesp(spa, (B_m, StaticInt{0}()))
m = mu
end
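In the hunk above, `B_m` splits the `M` rows into `cld(M, B)` blocks and rounds the per-block row count up to a multiple of `WUF`, so every block is a whole number of unrolled vector strips (`vcld` is VectorizationBase's ceiling division). The same arithmetic with plain `cld` and illustrative numbers (a sketch, not part of the diff):

block_rows(M, B, WUF) = cld(M, cld(M, B) * WUF) * WUF
block_rows(1000, 96, 32)   # 96: about 1000 / cld(1000, 96) ≈ 91 rows per block, rounded up to a multiple of 32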
Expand All @@ -905,13 +922,17 @@ function (f::RDivBlockMandNv2{UNIT})(
) where {UNIT}
spa, spu, N, Mrem, Nblock, mtb = allargs
for block = blockstart-1:blockstop-1
rdiv_block_MandN!(
gesp(spa, (mtb * block, StaticInt{0}())),
spu,
Core.ifelse(block == Nblock - 1, Mrem, mtb),
N,
Val{UNIT}()
)
let tup = (gesp(spa, (mtb * block, StaticInt{0}())), spu),
ftup = flatten_to_tup(tup)

rdiv_block_MandN!(
Core.ifelse(block == Nblock - 1, Mrem, mtb),
N,
Val{UNIT}(),
typeof(tup),
ftup...
)
end
end
end
