11module  TriangularSolve
2- using  Base:  @nexprs , @ntuple 
2+ using  Base:  @nexprs , @ntuple , Flatten 
33if  isdefined (Base, :Experimental ) && 
44   isdefined (Base. Experimental, Symbol (" @max_methods"  ))
55  @eval  Base. Experimental. @max_methods  1 
@@ -22,13 +22,15 @@ using Polyester
2222using  StaticArrayInterface
2323
# Alias for an LLVM pointer to `T` in address space 0.
const LPtr{T} = Core.LLVMPtr{T,0}
# Direct aliases of the LoopVectorization tuple helpers: values are flattened and
# reassembled as-is, with no `Ptr` <-> `LPtr` conversion applied in between.
const reassemble_tup = LoopVectorization.reassemble_tuple
const flatten_to_tup = LoopVectorization.flatten_to_tuple
3234
3335@generated  function  solve_AU (
3436  A:: VecUnroll{Nm1} ,
@@ -717,7 +719,7 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
717719  M:: Int 
718720  N:: Int 
719721end 
# Extents of a `Mat`, read from its `M` and `N` fields. Both the declared return
# type and the inner assertion pin the result to `Tuple{Int,Int}`.
function Base.size(A::Mat)::Tuple{Int,Int}
  return (A.M, A.N)::Tuple{Int,Int}
end

# One `CloseOpen` range per dimension, built from the same two extents.
function Base.axes(A::Mat)
  return (CloseOpen(A.M), CloseOpen(A.N))
end

# Strides depend on the layout flag: with `ColMajor == true` the first dimension
# has unit stride and the field `x` is the stride of the second dimension ...
function Base.strides(A::Mat{T,true}) where {T}
  return (1, getfield(A, :x))
end

# ... and with `ColMajor == false` the two roles are swapped.
function Base.strides(A::Mat{T,false}) where {T}
  return (getfield(A, :x), 1)
end
@@ -730,11 +732,15 @@ StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} =
# Static strides for the `ColMajor == false` layout: field `x` is the stride of
# the first dimension and the second dimension has a static unit stride.
function StaticArrayInterface.static_strides(A::Mat{T,false}) where {T}
  return (getfield(A, :x), static(1))
end

# Both dimensions are reported with a static offset of zero.
function StaticArrayInterface.offsets(::Mat)
  return (static(0), static(0))
end

# Stride rank per layout flag.
function StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}})
  return (static(1), static(2))
end
function StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}})
  return (static(2), static(1))
end

# Every `Mat` reports a contiguous batch size of static zero.
function StaticArrayInterface.contiguous_batch_size(::Type{<:Mat})
  return static(0)
end

# Only the unit-stride dimension is declared dense.
function StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}})
  return (static(true), static(false))
end
function StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}})
  return (static(false), static(true))
end

# The contiguous axis is the first dimension when `ColMajor == true`, otherwise
# the second.
function StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,true}})
  return static(1)
end
function StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,false}})
  return static(2)
end
740746@inline  function  Base. getindex (
@@ -771,26 +777,28 @@ end
# In-place update C -= A * B, `@turbo` variant selected by `Val{false}`.
#
# For every (m, n) the product term is accumulated from zero with a negative
# sign and then added onto the existing C[m, n], so C is read and written only
# once per entry, outside the reduction over k.
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{false})
  # Only this wrapper is inlined; the `_turbo_!` call is not.
  @inbounds begin
    @turbo warn_check_args = false for n in indices((C, B), 2),
      m in indices((C, A), 1)

      acc = zero(eltype(C))
      for k in indices((A, B), (2, 1))
        acc -= A[m, k] * B[k, n]
      end
      C[m, n] += acc
    end
  end
end
# In-place update C -= A * B, `@tturbo` variant selected by `Val{true}`.
#
# Mirrors the `Val{false}` method: the negated product term for each (m, n) is
# accumulated from zero and then added onto the existing C[m, n].
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{true})
  # _turbo_! will not be inlined
  @tturbo warn_check_args = false for n in indices((C, B), 2),
    m in indices((C, A), 1)

    Cmn = zero(eltype(C))
    for k in indices((A, B), (2, 1))
      Cmn -= A[m, k] * B[k, n]
    end
    # Must be the single `+=` token; `+ =` does not parse.
    C[m, n] += Cmn
  end
end
796804@inline  function  schur_complement! (
@@ -828,15 +836,17 @@ end
828836end 
829837
830838function  rdiv_block_N! (
831-   spa:: AbstractStridedPointer{T} ,
832-   spu,
833839  M,
834840  N,
835841  :: Val{UNIT} ,
836-   Bsize =  nothing 
837- ) where  {T,UNIT}
842+   Bsize,
843+   :: Type{Args} ,
844+   args:: Vararg{Any,K} 
845+ ) where  {K,Args,UNIT}
846+   spa, spu =  reassemble_tup (Args, args)
838847  spa_base =  spa
839848  n =  0 
849+   T =  eltype (spa)
840850  W =  VectorizationBase. pick_vector_width (T)
841851  B_normalized = 
842852    Bsize ===  nothing  ? 
@@ -863,29 +873,36 @@ function rdiv_block_N!(
863873    )
864874  end 
865875end 
# Return the third type parameter of a 2-dimensional strided pointer.
# NOTE(review): going by the helper's name this is the contiguous-axis parameter;
# confirm against the `AbstractStridedPointer` definition.
function _contig_axis(::AbstractStridedPointer{<:Any,2,X}) where {X}
  return X
end
# Walk the M dimension in chunks of `B_m` rows and hand each chunk to
# `rdiv_block_N!`.
#
# The two strided pointers arrive flattened: `Args` is the type of the original
# `(spa, spu)` tuple and `args` holds its flattened contents; they are rebuilt
# with `reassemble_tup` on entry and flattened again for every inner call.
#
# Arguments:
# - `M`, `N`: problem extents; `M` is the dimension chunked here.
# - `Val{UNIT}`: forwarded unchanged to `rdiv_block_N!`.
# - `Args`, `args...`: flattened form of `(spa, spu)` as described above.
function rdiv_block_MandN!(
  M,
  N,
  ::Val{UNIT},
  ::Type{Args},
  args::Vararg{Any,K}
) where {UNIT,Args,K}
  spa, spu = reassemble_tup(Args, args)
  T = eltype(spa)
  B = block_size(Val(T))
  W = VectorizationBase.pick_vector_width(T)
  # Each pointer gets its own axis variable. Binding both to `XA` dropped the
  # axis of `spa`, so the test below only ever looked at `spu`.
  XA = _contig_axis(spa)
  XU = _contig_axis(spu)
  # Chained comparison: true only when both axes equal 2.
  WUF = XA == XU == 2 ? W : W * unroll_factor(W)
  # Chunk height: M split into roughly B-sized pieces, rounded up to a multiple
  # of WUF.
  B_m = VectorizationBase.vcld(M, VectorizationBase.vcld(M, B) * WUF) * WUF
  m = 0
  while m < M
    mu = m + B_m
    # The last chunk may be shorter than B_m.
    Mtemp = min(M, mu) - m
    let tup = (spa, spu), ftup = flatten_to_tup(tup)
      rdiv_block_N!(
        Mtemp,
        N,
        Val{UNIT}(),
        VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W,
        typeof(tup),
        ftup...
      )
    end
    # Advance `spa` by B_m along its first dimension for the next chunk.
    spa = gesp(spa, (B_m, StaticInt{0}()))
    m = mu
  end
@@ -905,13 +922,17 @@ function (f::RDivBlockMandNv2{UNIT})(
905922) where  {UNIT}
906923  spa, spu, N, Mrem, Nblock, mtb =  allargs
907924  for  block =  blockstart- 1 : blockstop- 1 
908-     rdiv_block_MandN! (
909-       gesp (spa, (mtb *  block, StaticInt {0} ())),
910-       spu,
911-       Core. ifelse (block ==  Nblock -  1 , Mrem, mtb),
912-       N,
913-       Val {UNIT} ()
914-     )
925+     let  tup =  (gesp (spa, (mtb *  block, StaticInt {0} ())), spu),
926+       ftup =  flatten_to_tup (tup)
927+ 
928+       rdiv_block_MandN! (
929+         Core. ifelse (block ==  Nblock -  1 , Mrem, mtb),
930+         N,
931+         Val {UNIT} (),
932+         typeof (tup),
933+         ftup... 
934+       )
935+     end 
915936  end 
916937end 
917938
0 commit comments