11module TriangularSolve
2- using Base: @nexprs , @ntuple
2+ using Base: @nexprs , @ntuple , Flatten
33if isdefined (Base, :Experimental ) &&
44 isdefined (Base. Experimental, Symbol (" @max_methods" ))
55 @eval Base. Experimental. @max_methods 1
@@ -22,13 +22,15 @@ using Polyester
2222using StaticArrayInterface
2323
2424const LPtr{T} = Core. LLVMPtr{T,0 }
# Tuple (dis)assembly helpers, aliased directly to the LoopVectorization
# routines. They are used as a pair in this file: a caller computes
# `ftup = flatten_to_tup(tup)` and passes `typeof(tup), ftup...` to the callee,
# which recovers the original tuple with `reassemble_tup(Args, args)`.
const reassemble_tup = LoopVectorization.reassemble_tuple
const flatten_to_tup = LoopVectorization.flatten_to_tuple
3234
3335@generated function solve_AU (
3436 A:: VecUnroll{Nm1} ,
@@ -717,7 +719,7 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
717719 M:: Int
718720 N:: Int
719721end
# Basic array interface for `Mat`.
# `size` is built from the `M` and `N` fields; both the return annotation and
# the tuple assertion pin the result to `Tuple{Int,Int}`.
Base.size(A::Mat)::Tuple{Int,Int} = (A.M, A.N)::Tuple{Int,Int}
# Axes are `CloseOpen` ranges built from the `M` and `N` fields.
Base.axes(A::Mat) = (CloseOpen(A.M), CloseOpen(A.N))
# The `x` field is the stride of the non-unit dimension: the second dimension
# when the second type parameter is `true`, the first when it is `false`.
Base.strides(A::Mat{T,true}) where {T} = (1, getfield(A, :x))
Base.strides(A::Mat{T,false}) where {T} = (getfield(A, :x), 1)
@@ -730,11 +732,15 @@ StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} =
# StaticArrayInterface traits for `Mat`. Each trait that depends on the layout
# has one method per value of the second type parameter (`true` / `false`),
# mirroring the two `Base.strides` methods.

# Strides with the unit stride expressed statically; `x` is the runtime stride.
StaticArrayInterface.static_strides(A::Mat{T,false}) where {T} =
    (getfield(A, :x), static(1))
# Both dimensions have a static offset of 0.
StaticArrayInterface.offsets(::Mat) = (static(0), static(0))
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) =
    (static(1), static(2))
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) =
    (static(2), static(1))
StaticArrayInterface.contiguous_batch_size(::Type{<:Mat}) = static(0)
# Only the unit-stride dimension is reported as dense.
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) =
    (static(true), static(false))
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) =
    (static(false), static(true))
# The contiguous axis is the one whose stride is 1 in `Base.strides`.
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,true}}) = static(1)
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,false}}) = static(2)
740746@inline function Base. getindex (
@@ -771,26 +777,28 @@ end
771777# C -= A * B
# Kernel for the in-place update `C -= A * B` on `Mat` arguments, selected by
# `Val{false}`; this method uses `@turbo`.
# For each `(m, n)` the products `A[m, k] * B[k, n]` are subtracted from a
# zero-initialised accumulator `Cmn`, which is then added to `C[m, n]`, so
# `C[m, n]` is read and written once, outside the reduction over `k`.
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{false})
    # _turbo_! will not be inlined
    @inbounds begin
        @turbo warn_check_args = false for n in indices((C, B), 2),
            m in indices((C, A), 1)

            Cmn = zero(eltype(C))
            for k in indices((A, B), (2, 1))
                Cmn -= A[m, k] * B[k, n]
            end
            C[m, n] += Cmn
        end
    end
end
# Kernel for the in-place update `C -= A * B` on `Mat` arguments, selected by
# `Val{true}`; this method uses `@tturbo` (LoopVectorization's threaded
# variant of `@turbo`).
# For each `(m, n)` the products `A[m, k] * B[k, n]` are subtracted from a
# zero-initialised accumulator `Cmn`, which is then added to `C[m, n]`, so
# `C[m, n]` is read and written once, outside the reduction over `k`.
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{true})
    # _turbo_! will not be inlined
    @tturbo warn_check_args = false for n in indices((C, B), 2),
        m in indices((C, A), 1)

        Cmn = zero(eltype(C))
        for k in indices((A, B), (2, 1))
            Cmn -= A[m, k] * B[k, n]
        end
        C[m, n] += Cmn
    end
end
796804@inline function schur_complement! (
@@ -828,15 +836,17 @@ end
828836end
829837
830838function rdiv_block_N! (
831- spa:: AbstractStridedPointer{T} ,
832- spu,
833839 M,
834840 N,
835841 :: Val{UNIT} ,
836- Bsize = nothing
837- ) where {T,UNIT}
842+ Bsize,
843+ :: Type{Args} ,
844+ args:: Vararg{Any,K}
845+ ) where {K,Args,UNIT}
846+ spa, spu = reassemble_tup (Args, args)
838847 spa_base = spa
839848 n = 0
849+ T = eltype (spa)
840850 W = VectorizationBase. pick_vector_width (T)
841851 B_normalized =
842852 Bsize === nothing ?
@@ -863,29 +873,36 @@ function rdiv_block_N!(
863873 )
864874 end
865875end
# Return the third type parameter of a 2-dimensional `AbstractStridedPointer`.
# NOTE(review): judging by the name, this is the pointer's contiguous-axis
# parameter — confirm against the `AbstractStridedPointer` definition.
function _contig_axis(::AbstractStridedPointer{<:Any,2,X}) where {X}
    return X
end
866877function rdiv_block_MandN! (
867- spa:: AbstractStridedPointer{T,<:Any,XA} ,
868- spu:: AbstractStridedPointer{T,<:Any,XU} ,
869878 M,
870879 N,
871- :: Val{UNIT}
872- ) where {T,UNIT,XA,XU}
880+ :: Val{UNIT} ,
881+ :: Type{Args} ,
882+ args:: Vararg{Any,K}
883+ ) where {UNIT,Args,K}
884+ spa, spu = reassemble_tup (Args, args)
885+ T = eltype (spa)
873886 B = block_size (Val (T))
874887 W = VectorizationBase. pick_vector_width (T)
888+ XA = _contig_axis (spa)
889+ XU = _contig_axis (spu) # contiguous axis of `spu`, kept separate from `XA` (that of `spa`)
875890 WUF = XA == XU == 2 ? W : W * unroll_factor (W) # step is plain `W` only when both contiguous axes are 2
876891 B_m = VectorizationBase. vcld (M, VectorizationBase. vcld (M, B) * WUF) * WUF
877892 m = 0
878893 while m < M
879894 mu = m + B_m
880895 Mtemp = min (M, mu) - m
881- rdiv_block_N! (
882- spa,
883- spu,
884- Mtemp,
885- N,
886- Val {UNIT} (),
887- VectorizationBase. vcld (N, VectorizationBase. vcld (N, B) * W) * W
888- )
896+ let tup = (spa, spu), ftup = flatten_to_tup (tup)
897+ rdiv_block_N! (
898+ Mtemp,
899+ N,
900+ Val {UNIT} (),
901+ VectorizationBase. vcld (N, VectorizationBase. vcld (N, B) * W) * W,
902+ typeof (tup),
903+ ftup...
904+ )
905+ end
889906 spa = gesp (spa, (B_m, StaticInt {0} ()))
890907 m = mu
891908 end
@@ -905,13 +922,17 @@ function (f::RDivBlockMandNv2{UNIT})(
905922) where {UNIT}
906923 spa, spu, N, Mrem, Nblock, mtb = allargs
907924 for block = blockstart- 1 : blockstop- 1
908- rdiv_block_MandN! (
909- gesp (spa, (mtb * block, StaticInt {0} ())),
910- spu,
911- Core. ifelse (block == Nblock - 1 , Mrem, mtb),
912- N,
913- Val {UNIT} ()
914- )
925+ let tup = (gesp (spa, (mtb * block, StaticInt {0} ())), spu),
926+ ftup = flatten_to_tup (tup)
927+
928+ rdiv_block_MandN! (
929+ Core. ifelse (block == Nblock - 1 , Mrem, mtb),
930+ N,
931+ Val {UNIT} (),
932+ typeof (tup),
933+ ftup...
934+ )
935+ end
915936 end
916937end
917938
0 commit comments