1
1
module TriangularSolve
2
- using Base: @nexprs , @ntuple
2
+ using Base: @nexprs , @ntuple , Flatten
3
3
if isdefined (Base, :Experimental ) &&
4
4
isdefined (Base. Experimental, Symbol (" @max_methods" ))
5
5
@eval Base. Experimental. @max_methods 1
@@ -22,13 +22,15 @@ using Polyester
22
22
using StaticArrayInterface
23
23
24
24
const LPtr{T} = Core. LLVMPtr{T,0 }
25
- _lptr (x:: Ptr{T} ) where {T} = reinterpret (LPtr{T}, x)
26
- _lptr (x) = x
27
- _ptr (x:: LPtr{T} ) where {T} = reinterpret (Ptr{T}, x)
28
- _ptr (x) = x
29
- @inline reassemble_tup (:: Type{T} , t) where {T} =
30
- LoopVectorization. reassemble_tuple (T, map (_ptr, t))
31
- @inline flatten_to_tup (t) = map (_lptr, LoopVectorization. flatten_to_tuple (t))
25
+ # _lptr(x::Ptr{T}) where {T} = Base.bitcast(LPtr{T}, x)::LPtr{T}
26
+ # _ptr(x::LPtr{T}) where {T} = Base.bitcast(Ptr{T}, x)::Ptr{T}
27
+ # _lptr(x) = x
28
+ # _ptr(x) = x
29
+ const reassemble_tup = LoopVectorization. reassemble_tuple
30
+ const flatten_to_tup = LoopVectorization. flatten_to_tuple
31
+ # @inline reassemble_tup(::Type{T}, t) where {T} =
32
+ # LoopVectorization.reassemble_tuple(T, map(_ptr, t))
33
+ # @inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))
32
34
33
35
@generated function solve_AU (
34
36
A:: VecUnroll{Nm1} ,
@@ -717,7 +719,7 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
717
719
M:: Int
718
720
N:: Int
719
721
end
720
- Base. size (A:: Mat ) = (A. M, A. N)
722
+ Base. size (A:: Mat ):: Tuple{Int,Int} = (A. M, A. N):: Tuple{Int,Int}
721
723
Base. axes (A:: Mat ) = (CloseOpen (A. M), CloseOpen (A. N))
722
724
Base. strides (A:: Mat{T,true} ) where {T} = (1 , getfield (A, :x ))
723
725
Base. strides (A:: Mat{T,false} ) where {T} = (getfield (A, :x ), 1 )
@@ -730,11 +732,15 @@ StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} =
730
732
StaticArrayInterface. static_strides (A:: Mat{T,false} ) where {T} =
731
733
(getfield (A, :x ), static (1 ))
732
734
StaticArrayInterface. offsets (:: Mat ) = (static (0 ), static (0 ))
733
- StaticArrayInterface. stride_rank (:: Type{<:Mat{<:Any,true}} ) = (static (1 ), static (2 ))
734
- StaticArrayInterface. stride_rank (:: Type{<:Mat{<:Any,false}} ) = (static (2 ), static (1 ))
735
+ StaticArrayInterface. stride_rank (:: Type{<:Mat{<:Any,true}} ) =
736
+ (static (1 ), static (2 ))
737
+ StaticArrayInterface. stride_rank (:: Type{<:Mat{<:Any,false}} ) =
738
+ (static (2 ), static (1 ))
735
739
StaticArrayInterface. contiguous_batch_size (:: Type{<:Mat} ) = static (0 )
736
- StaticArrayInterface. dense_dims (:: Type{<:Mat{<:Any,true}} ) = (static (true ),static (false ))
737
- StaticArrayInterface. dense_dims (:: Type{<:Mat{<:Any,false}} ) = (static (false ),static (true ))
740
+ StaticArrayInterface. dense_dims (:: Type{<:Mat{<:Any,true}} ) =
741
+ (static (true ), static (false ))
742
+ StaticArrayInterface. dense_dims (:: Type{<:Mat{<:Any,false}} ) =
743
+ (static (false ), static (true ))
738
744
StaticArrayInterface. contiguous_axis (:: Type{<:Mat{<:Any,true}} ) = static (1 )
739
745
StaticArrayInterface. contiguous_axis (:: Type{<:Mat{<:Any,false}} ) = static (2 )
740
746
@inline function Base. getindex (
@@ -771,26 +777,28 @@ end
771
777
# C -= A * B
772
778
@inline function _schur_complement! (C:: Mat , A:: Mat , B:: Mat , :: Val{false} )
773
779
# _turbo_! will not be inlined
774
- @turbo warn_check_args = false for n in indices ((C, B), 2 ),
775
- m in indices ((C, A), 1 )
780
+ @inbounds begin
781
+ @turbo warn_check_args = false for n in indices ((C, B), 2 ),
782
+ m in indices ((C, A), 1 )
776
783
777
- Cmn = C[m, n]
778
- for k in indices ((A, B), (2 , 1 ))
779
- Cmn -= A[m, k] * B[k, n]
784
+ Cmn = zero (eltype (C))
785
+ for k in indices ((A, B), (2 , 1 ))
786
+ Cmn -= A[m, k] * B[k, n]
787
+ end
788
+ C[m, n] += Cmn
780
789
end
781
- C[m, n] = Cmn
782
790
end
783
791
end
784
792
@inline function _schur_complement! (C:: Mat , A:: Mat , B:: Mat , :: Val{true} )
785
793
# _turbo_! will not be inlined
786
794
@tturbo warn_check_args = false for n in indices ((C, B), 2 ),
787
795
m in indices ((C, A), 1 )
788
796
789
- Cmn = C[m, n]
797
+ Cmn = zero ( eltype (C))
790
798
for k in indices ((A, B), (2 , 1 ))
791
799
Cmn -= A[m, k] * B[k, n]
792
800
end
793
- C[m, n] = Cmn
801
+ C[m, n] += Cmn
794
802
end
795
803
end
796
804
@inline function schur_complement! (
@@ -828,15 +836,17 @@ end
828
836
end
829
837
830
838
function rdiv_block_N! (
831
- spa:: AbstractStridedPointer{T} ,
832
- spu,
833
839
M,
834
840
N,
835
841
:: Val{UNIT} ,
836
- Bsize = nothing
837
- ) where {T,UNIT}
842
+ Bsize,
843
+ :: Type{Args} ,
844
+ args:: Vararg{Any,K}
845
+ ) where {K,Args,UNIT}
846
+ spa, spu = reassemble_tup (Args, args)
838
847
spa_base = spa
839
848
n = 0
849
+ T = eltype (spa)
840
850
W = VectorizationBase. pick_vector_width (T)
841
851
B_normalized =
842
852
Bsize === nothing ?
@@ -863,29 +873,36 @@ function rdiv_block_N!(
863
873
)
864
874
end
865
875
end
876
+ _contig_axis (:: AbstractStridedPointer{<:Any,2,X} ) where {X} = X
866
877
function rdiv_block_MandN! (
867
- spa:: AbstractStridedPointer{T,<:Any,XA} ,
868
- spu:: AbstractStridedPointer{T,<:Any,XU} ,
869
878
M,
870
879
N,
871
- :: Val{UNIT}
872
- ) where {T,UNIT,XA,XU}
880
+ :: Val{UNIT} ,
881
+ :: Type{Args} ,
882
+ args:: Vararg{Any,K}
883
+ ) where {UNIT,Args,K}
884
+ spa, spu = reassemble_tup (Args, args)
885
+ T = eltype (spa)
873
886
B = block_size (Val (T))
874
887
W = VectorizationBase. pick_vector_width (T)
888
+ XA = _contig_axis (spa)
889
+ XU = _contig_axis (spu)
875
890
WUF = XA == XU == 2 ? W : W * unroll_factor (W)
876
891
B_m = VectorizationBase. vcld (M, VectorizationBase. vcld (M, B) * WUF) * WUF
877
892
m = 0
878
893
while m < M
879
894
mu = m + B_m
880
895
Mtemp = min (M, mu) - m
881
- rdiv_block_N! (
882
- spa,
883
- spu,
884
- Mtemp,
885
- N,
886
- Val {UNIT} (),
887
- VectorizationBase. vcld (N, VectorizationBase. vcld (N, B) * W) * W
888
- )
896
+ let tup = (spa, spu), ftup = flatten_to_tup (tup)
897
+ rdiv_block_N! (
898
+ Mtemp,
899
+ N,
900
+ Val {UNIT} (),
901
+ VectorizationBase. vcld (N, VectorizationBase. vcld (N, B) * W) * W,
902
+ typeof (tup),
903
+ ftup...
904
+ )
905
+ end
889
906
spa = gesp (spa, (B_m, StaticInt {0} ()))
890
907
m = mu
891
908
end
@@ -905,13 +922,17 @@ function (f::RDivBlockMandNv2{UNIT})(
905
922
) where {UNIT}
906
923
spa, spu, N, Mrem, Nblock, mtb = allargs
907
924
for block = blockstart- 1 : blockstop- 1
908
- rdiv_block_MandN! (
909
- gesp (spa, (mtb * block, StaticInt {0} ())),
910
- spu,
911
- Core. ifelse (block == Nblock - 1 , Mrem, mtb),
912
- N,
913
- Val {UNIT} ()
914
- )
925
+ let tup = (gesp (spa, (mtb * block, StaticInt {0} ())), spu),
926
+ ftup = flatten_to_tup (tup)
927
+
928
+ rdiv_block_MandN! (
929
+ Core. ifelse (block == Nblock - 1 , Mrem, mtb),
930
+ N,
931
+ Val {UNIT} (),
932
+ typeof (tup),
933
+ ftup...
934
+ )
935
+ end
915
936
end
916
937
end
917
938
0 commit comments