Skip to content

Commit 1c72cc2

Browse files
committed
A few improvements
1 parent 1d2b77b commit 1c72cc2

File tree

1 file changed

+65
-44
lines changed

1 file changed

+65
-44
lines changed

src/TriangularSolve.jl

Lines changed: 65 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module TriangularSolve
2-
using Base: @nexprs, @ntuple
2+
using Base: @nexprs, @ntuple, Flatten
33
if isdefined(Base, :Experimental) &&
44
isdefined(Base.Experimental, Symbol("@max_methods"))
55
@eval Base.Experimental.@max_methods 1
@@ -22,13 +22,15 @@ using Polyester
2222
using StaticArrayInterface
2323

2424
const LPtr{T} = Core.LLVMPtr{T,0}
25-
_lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x)
26-
_lptr(x) = x
27-
_ptr(x::LPtr{T}) where {T} = reinterpret(Ptr{T}, x)
28-
_ptr(x) = x
29-
@inline reassemble_tup(::Type{T}, t) where {T} =
30-
LoopVectorization.reassemble_tuple(T, map(_ptr, t))
31-
@inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))
25+
# _lptr(x::Ptr{T}) where {T} = Base.bitcast(LPtr{T}, x)::LPtr{T}
26+
# _ptr(x::LPtr{T}) where {T} = Base.bitcast(Ptr{T}, x)::Ptr{T}
27+
# _lptr(x) = x
28+
# _ptr(x) = x
29+
const reassemble_tup = LoopVectorization.reassemble_tuple
30+
const flatten_to_tup = LoopVectorization.flatten_to_tuple
31+
# @inline reassemble_tup(::Type{T}, t) where {T} =
32+
# LoopVectorization.reassemble_tuple(T, map(_ptr, t))
33+
# @inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))
3234

3335
@generated function solve_AU(
3436
A::VecUnroll{Nm1},
@@ -717,7 +719,7 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
717719
M::Int
718720
N::Int
719721
end
720-
Base.size(A::Mat) = (A.M, A.N)
722+
function Base.size(A::Mat)::Tuple{Int,Int}
    # Dimensions come straight from the `M` and `N` fields; the assertion on the
    # tuple and the declared return type both pin the result to `Tuple{Int,Int}`.
    return (A.M, A.N)::Tuple{Int,Int}
end
721723
Base.axes(A::Mat) = (CloseOpen(A.M), CloseOpen(A.N))
722724
Base.strides(A::Mat{T,true}) where {T} = (1, getfield(A, :x))
723725
Base.strides(A::Mat{T,false}) where {T} = (getfield(A, :x), 1)
@@ -730,11 +732,15 @@ StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} =
730732
StaticArrayInterface.static_strides(A::Mat{T,false}) where {T} =
731733
(getfield(A, :x), static(1))
732734
StaticArrayInterface.offsets(::Mat) = (static(0), static(0))
733-
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) = (static(1), static(2))
734-
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) = (static(2), static(1))
735+
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) =
736+
(static(1), static(2))
737+
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) =
738+
(static(2), static(1))
735739
StaticArrayInterface.contiguous_batch_size(::Type{<:Mat}) = static(0)
736-
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) = (static(true),static(false))
737-
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) = (static(false),static(true))
740+
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) =
741+
(static(true), static(false))
742+
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) =
743+
(static(false), static(true))
738744
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,true}}) = static(1)
739745
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,false}}) = static(2)
740746
@inline function Base.getindex(
@@ -771,26 +777,28 @@ end
771777
# C -= A * B
772778
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{false})
773779
# _turbo_! will not be inlined
774-
@turbo warn_check_args = false for n in indices((C, B), 2),
775-
m in indices((C, A), 1)
780+
@inbounds begin
781+
@turbo warn_check_args = false for n in indices((C, B), 2),
782+
m in indices((C, A), 1)
776783

777-
Cmn = C[m, n]
778-
for k in indices((A, B), (2, 1))
779-
Cmn -= A[m, k] * B[k, n]
784+
Cmn = zero(eltype(C))
785+
for k in indices((A, B), (2, 1))
786+
Cmn -= A[m, k] * B[k, n]
787+
end
788+
C[m, n] += Cmn
780789
end
781-
C[m, n] = Cmn
782790
end
783791
end
784792
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{true})
785793
# _turbo_! will not be inlined
786794
@tturbo warn_check_args = false for n in indices((C, B), 2),
787795
m in indices((C, A), 1)
788796

789-
Cmn = C[m, n]
797+
Cmn = zero(eltype(C))
790798
for k in indices((A, B), (2, 1))
791799
Cmn -= A[m, k] * B[k, n]
792800
end
793-
C[m, n] = Cmn
801+
C[m, n] += Cmn
794802
end
795803
end
796804
@inline function schur_complement!(
@@ -828,15 +836,17 @@ end
828836
end
829837

830838
function rdiv_block_N!(
831-
spa::AbstractStridedPointer{T},
832-
spu,
833839
M,
834840
N,
835841
::Val{UNIT},
836-
Bsize = nothing
837-
) where {T,UNIT}
842+
Bsize,
843+
::Type{Args},
844+
args::Vararg{Any,K}
845+
) where {K,Args,UNIT}
846+
spa, spu = reassemble_tup(Args, args)
838847
spa_base = spa
839848
n = 0
849+
T = eltype(spa)
840850
W = VectorizationBase.pick_vector_width(T)
841851
B_normalized =
842852
Bsize === nothing ?
@@ -863,29 +873,36 @@ function rdiv_block_N!(
863873
)
864874
end
865875
end
876+
# Return the third type parameter `X` of an `AbstractStridedPointer` whose second
# type parameter is 2. The value is read from the type alone, not from the instance.
# NOTE(review): presumably `X` is the contiguous-axis index, as the name suggests —
# confirm against the `AbstractStridedPointer` definition.
_contig_axis(::AbstractStridedPointer{<:Any,2,X}) where {X} = X
866877
function rdiv_block_MandN!(
867-
spa::AbstractStridedPointer{T,<:Any,XA},
868-
spu::AbstractStridedPointer{T,<:Any,XU},
869878
M,
870879
N,
871-
::Val{UNIT}
872-
) where {T,UNIT,XA,XU}
880+
::Val{UNIT},
881+
::Type{Args},
882+
args::Vararg{Any,K}
883+
) where {UNIT,Args,K}
884+
spa, spu = reassemble_tup(Args, args)
885+
T = eltype(spa)
873886
B = block_size(Val(T))
874887
W = VectorizationBase.pick_vector_width(T)
888+
XA = _contig_axis(spa) # `_contig_axis` result for the `spa` pointer
889+
XU = _contig_axis(spu) # own name for `spu`, so `XA` is no longer overwritten
875890
WUF = XA == XU == 2 ? W : W * unroll_factor(W) # plain `W` only when both results are 2
876891
B_m = VectorizationBase.vcld(M, VectorizationBase.vcld(M, B) * WUF) * WUF
877892
m = 0
878893
while m < M
879894
mu = m + B_m
880895
Mtemp = min(M, mu) - m
881-
rdiv_block_N!(
882-
spa,
883-
spu,
884-
Mtemp,
885-
N,
886-
Val{UNIT}(),
887-
VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W
888-
)
896+
let tup = (spa, spu), ftup = flatten_to_tup(tup)
897+
rdiv_block_N!(
898+
Mtemp,
899+
N,
900+
Val{UNIT}(),
901+
VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W,
902+
typeof(tup),
903+
ftup...
904+
)
905+
end
889906
spa = gesp(spa, (B_m, StaticInt{0}()))
890907
m = mu
891908
end
@@ -905,13 +922,17 @@ function (f::RDivBlockMandNv2{UNIT})(
905922
) where {UNIT}
906923
spa, spu, N, Mrem, Nblock, mtb = allargs
907924
for block = blockstart-1:blockstop-1
908-
rdiv_block_MandN!(
909-
gesp(spa, (mtb * block, StaticInt{0}())),
910-
spu,
911-
Core.ifelse(block == Nblock - 1, Mrem, mtb),
912-
N,
913-
Val{UNIT}()
914-
)
925+
let tup = (gesp(spa, (mtb * block, StaticInt{0}())), spu),
926+
ftup = flatten_to_tup(tup)
927+
928+
rdiv_block_MandN!(
929+
Core.ifelse(block == Nblock - 1, Mrem, mtb),
930+
N,
931+
Val{UNIT}(),
932+
typeof(tup),
933+
ftup...
934+
)
935+
end
915936
end
916937
end
917938

0 commit comments

Comments
 (0)