489
489
spu:: AbstractStridedPointer ,
490
490
M,
491
491
N,
492
- :: StaticInt{1} ,
493
- :: StaticInt{1} ,
494
492
:: Val{UNIT}
495
493
) where {T,UNIT}
496
494
WS = pick_vector_width (T)
534
532
nothing
535
533
end
536
534
537
- const buffer = Ref {Ptr{Cvoid}} (C_NULL )
538
-
539
- function __init__ ()
540
- bp_size = 2 * sizeof (Int) * Threads. nthreads ()
541
- buffer[] = bp = Libc. malloc (bp_size % UInt)
542
- Libc. memset (bp, 0 , bp_size)
543
- end
544
-
545
- function _get_buffer_pointer (:: StaticInt{UF} , N) where {UF}
546
- RS = VectorizationBase. register_size ()
547
- RSUF = StaticInt {UF} () * RS
548
- L = RSUF * N
549
- tid = Threads. threadid () - 1
550
- bp = Ptr {Pair{Ptr{Cvoid},Int}} (buffer[]) + 2 sizeof (Int) * tid
551
- (p, buff_current) = unsafe_load (bp)
552
- if buff_current < L
553
- p == C_NULL || Libc. free (p)
554
- buff_size = max (RSUF * 128 , L)
555
- p = Libc. malloc ((buff_size + RS - 1 ) % UInt)
556
- unsafe_store! (bp, p => buff_size)
557
- end
558
- return VectorizationBase. align (p, RS)
559
- end
560
-
561
- @inline function lubuffer (:: Val{T} , :: StaticInt{UF} , N) where {T,UF}
562
- RS = VectorizationBase. register_size ()
563
- RSUF = StaticInt {UF} () * RS
564
- ptr = Ptr {T} (_get_buffer_pointer (StaticInt {UF} (), N))
565
- si = StrideIndex {2,(1, 2),1} (
566
- (VectorizationBase. static_sizeof (T), RSUF),
567
- (StaticInt (0 ), StaticInt (0 ))
568
- )
569
- stridedpointer (ptr, si, StaticInt {0} ()), nothing
570
- end
571
- @inline function lubuffer (
572
- :: Val{T} ,
573
- :: StaticInt{UF} ,
574
- :: StaticInt{N}
575
- ) where {T,UF,N}
576
- RSUF = StaticInt {UF} () * VectorizationBase. pick_vector_width (T)
577
- L = RSUF * N
578
- buf = Ref {NTuple{L,T}} ()
579
- ptr = Base. unsafe_convert (Ptr{T}, buf)
580
- si = StrideIndex {2,(1, 2),1} (
581
- (
582
- VectorizationBase. static_sizeof (T),
583
- RSUF * VectorizationBase. static_sizeof (T)
584
- ),
585
- (StaticInt (0 ), StaticInt (0 ))
586
- )
587
- stridedpointer (ptr, si, StaticInt {0} ()), buf
588
- end
589
- @inline _free (p:: Ptr ) = Libc. free (p)
590
535
# Generic fallback: return `signed(x)`.
_canonicalize(x) = signed(x)
591
536
# `StaticInt` arguments yield a `StaticInt{N}` instance instead of going through `signed`.
_canonicalize(::StaticInt{N}) where {N} = StaticInt{N}()
592
537
function div_dispatch! (
@@ -606,17 +551,14 @@ function div_dispatch!(
606
551
spa = zero_offsets (_spa)
607
552
spc = zero_offsets (_spc)
608
553
spu = zero_offsets (_spu)
609
- XC = VectorizationBase. contiguous_axis (C)
610
- XA = VectorizationBase. contiguous_axis (A)
611
554
GC. @preserve spap spcp spup begin
612
555
mtb = m_thread_block_size (M, N, nthread, Val (T))
613
556
if nthread > 1
614
- (M > mtb) &&
615
- return multithread_rdiv! (spc, spa, spu, M, N, mtb, Val (UNIT), XC, XA)
557
+ (M > mtb) && return multithread_rdiv! (spc, spa, spu, M, N, mtb, Val (UNIT))
616
558
elseif N > block_size (Val (T))
617
- return rdiv_block_MandN! (spc, spa, spu, M, N, Val (UNIT), XC, XA )
559
+ return rdiv_block_MandN! (spc, spa, spu, M, N, Val (UNIT))
618
560
end
619
- return rdiv_U! (spc, spa, spu, M, N, XC, XA, Val (UNIT))
561
+ return rdiv_U! (spc, spa, spu, M, N, Val (UNIT))
620
562
end
621
563
end
622
564
@@ -833,10 +775,8 @@ function rdiv_block_N!(
833
775
M,
834
776
N,
835
777
:: Val{UNIT} ,
836
- :: StaticInt{XC} ,
837
- :: StaticInt{XA} ,
838
778
Bsize = nothing
839
- ) where {T,UNIT,XC,XA }
779
+ ) where {T,UNIT}
840
780
spa_rdiv = spa
841
781
spc_base = spc
842
782
n = 0
@@ -857,8 +797,6 @@ function rdiv_block_N!(
857
797
gesp (spu, (n, StaticInt {0} ())),
858
798
M,
859
799
N_temp,
860
- StaticInt {XC} (),
861
- StaticInt {XA} (),
862
800
Val {UNIT} ()
863
801
)
864
802
repeat || break
@@ -873,18 +811,16 @@ function rdiv_block_N!(
873
811
end
874
812
end
875
813
function rdiv_block_MandN! (
876
- spc:: AbstractStridedPointer{T} ,
877
- spa,
878
- spu,
814
+ spc:: AbstractStridedPointer{T,<:Any,XC } ,
815
+ spa:: AbstractStridedPointer{T,<:Any,XA} ,
816
+ spu:: AbstractStridedPointer{T,<:Any,XU} ,
879
817
M,
880
818
N,
881
- :: Val{UNIT} ,
882
- :: StaticInt{XC} ,
883
- :: StaticInt{XA}
884
- ) where {T,UNIT,XC,XA}
819
+ :: Val{UNIT}
820
+ ) where {T,UNIT,XC,XA,XU}
885
821
B = block_size (Val (T))
886
822
W = VectorizationBase. pick_vector_width (T)
887
- WUF = XC == XA == 2 ? W : W * unroll_factor (W)
823
+ WUF = XC == XA == XU == 2 ? W : W * unroll_factor(W)  # plain W only when XC, XA and XU are all 2
888
824
B_m = VectorizationBase. vcld (M, VectorizationBase. vcld (M, B) * WUF) * WUF
889
825
m = 0
890
826
while m < M
@@ -897,8 +833,6 @@ function rdiv_block_MandN!(
897
833
Mtemp,
898
834
N,
899
835
Val {UNIT} (),
900
- StaticInt {XC} (),
901
- StaticInt {XA} (),
902
836
VectorizationBase. vcld (N, VectorizationBase. vcld (N, B) * W) * W
903
837
)
904
838
spa = gesp (spa, (B_m, StaticInt {0} ()))
@@ -913,12 +847,12 @@ function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T}
913
847
min (M, VectorizationBase. vcld (M, nb * W) * W)
914
848
end
915
849
916
- struct RDivBlockMandNv2{UNIT,XC,XA } end
917
- function (f:: RDivBlockMandNv2{UNIT,XC,XA } )(
850
+ struct RDivBlockMandNv2{UNIT} end # fieldless callable type; its method forwards UNIT as Val{UNIT}() to rdiv_block_MandN!
851
+ function (f:: RDivBlockMandNv2{UNIT} )(
918
852
allargs,
919
853
blockstart,
920
854
blockstop
921
- ) where {UNIT,XC,XA }
855
+ ) where {UNIT}
922
856
spc, spa, spu, N, Mrem, Nblock, mtb = allargs
923
857
for block = blockstart- 1 : blockstop- 1
924
858
rdiv_block_MandN! (
@@ -927,9 +861,7 @@ function (f::RDivBlockMandNv2{UNIT,XC,XA})(
927
861
spu,
928
862
Core. ifelse (block == Nblock - 1 , Mrem, mtb),
929
863
N,
930
- Val {UNIT} (),
931
- static (XC),
932
- static (XA)
864
+ Val {UNIT} ()
933
865
)
934
866
end
935
867
end
@@ -941,17 +873,14 @@ function multithread_rdiv!(
941
873
M:: Int ,
942
874
N:: Int ,
943
875
mtb:: Int ,
944
- :: Val{UNIT} ,
945
- :: StaticInt{XC} ,
946
- :: StaticInt{XA}
947
- ) where {XC,XA,UNIT,TC,TA,TU}
876
+ :: Val{UNIT}
877
+ ) where {UNIT,TC,TA,TU}
948
878
# Main._a[] = (spc, spa, spu, M, N, mtb, Val(UNIT), static(X));
949
879
(Md, Mr) = VectorizationBase. vdivrem (M, mtb)
950
880
Nblock = Md + (Mr ≠ 0 )
951
881
Mrem = Core. ifelse (Mr ≠ 0 , Mr, mtb)
952
- f = RDivBlockMandNv2 {UNIT,XC,XA} ()
953
882
batch (
954
- f ,
883
+ RDivBlockMandNv2 {UNIT} () ,
955
884
(Nblock, min (Nblock, Threads. nthreads ())),
956
885
spc,
957
886
spa,
@@ -977,60 +906,6 @@ function unroll_factor(::StaticInt{W}) where {W}
977
906
ifelse (Static. lt (num_blocks, StaticInt {1} ()), StaticInt {1} (), num_blocks)
978
907
end
979
908
980
- function rdiv_U! (
981
- spc:: AbstractStridedPointer{T} ,
982
- spa:: AbstractStridedPointer ,
983
- spu:: AbstractStridedPointer ,
984
- M,
985
- N,
986
- :: StaticInt{var"#UNUSED1#"} ,
987
- :: StaticInt{var"#UNUSED2#"} ,
988
- :: Val{UNIT}
989
- ) where {T,UNIT,var"#UNUSED1#" ,var"#UNUSED2#" }
990
- WS = pick_vector_width (T)
991
- W = Int (WS)
992
- UF = unroll_factor (WS)
993
- WU = UF * WS
994
- Nd, Nr = VectorizationBase. vdivrem (N, WS)
995
- spb, preserve = lubuffer (Val (T), UF, N)
996
- m = 0
997
- GC. @preserve preserve begin
998
- if UF > 1
999
- while m < M - WU + 1
1000
- n = Nr
1001
- if n > 0
1002
- BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT), WS)
1003
- end
1004
- for _ ∈ 1 : Nd
1005
- rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
1006
- n += W
1007
- end
1008
- m += WU
1009
- spa = gesp (spa, (WU, StaticInt (0 )))
1010
- spc = gesp (spc, (WU, StaticInt (0 )))
1011
- end
1012
- end
1013
- finalmask = VectorizationBase. mask (WS, M)
1014
- while m < M
1015
- ubm = m + W
1016
- nomaskiter = ubm < M
1017
- mask = nomaskiter ? VectorizationBase. max_mask (WS) : finalmask
1018
- n = Nr
1019
- if n > 0
1020
- BdivU_small_kern! (spb, spc, spa, spu, n, mask, Val (UNIT))
1021
- end
1022
- for i ∈ 1 : Nd
1023
- rdiv_solve_W! (spb, spc, spa, spu, n, i ≠ Nd, mask, Val (UNIT))
1024
- n += W
1025
- end
1026
- spa = gesp (spa, (WS, StaticInt (0 )))
1027
- spc = gesp (spc, (WS, StaticInt (0 )))
1028
- m = ubm
1029
- end
1030
- end
1031
- nothing
1032
- end
1033
-
1034
909
@generated function _ldiv_remainder! (
1035
910
spc,
1036
911
spa,
@@ -1109,34 +984,50 @@ end
1109
984
) where {W,UNIT}
1110
985
WS = static (W)
1111
986
# US = static(U)
1112
- quote
1113
- # $(Expr(:meta, :inline))
1114
- Base. Cartesian. @nif $ (W - 1 ) w -> m == M - w w -> _ldiv_remainder! (
1115
- spc,
1116
- spa,
1117
- spu,
1118
- M,
1119
- N,
1120
- m,
1121
- Nr,
1122
- $ WS,
1123
- $ (Val (UNIT)),
1124
- StaticInt (w)
1125
- )
987
+ if W == 2
988
+ quote
989
+ $ (Expr (:meta , :inline ))
990
+ _ldiv_remainder! (
991
+ spc,
992
+ spa,
993
+ spu,
994
+ M,
995
+ N,
996
+ m,
997
+ Nr,
998
+ $ WS,
999
+ $ (Val (UNIT)),
1000
+ $ (static (1 ))
1001
+ )
1002
+ end
1003
+ else
1004
+ quote
1005
+ # $(Expr(:meta, :inline))
1006
+ Base. Cartesian. @nif $ (W - 1 ) w -> m == M - w w -> _ldiv_remainder! (
1007
+ spc,
1008
+ spa,
1009
+ spu,
1010
+ M,
1011
+ N,
1012
+ m,
1013
+ Nr,
1014
+ $ WS,
1015
+ $ (Val (UNIT)),
1016
+ StaticInt (w)
1017
+ )
1018
+ end
1126
1019
end
1127
1020
end
1128
1021
1129
1022
# spc = spa / spu
1130
1023
# spc' = (spu' \ spa')'
1131
1024
# This is ldiv
1132
1025
function rdiv_U! (
1133
- spc:: AbstractStridedPointer{T} ,
1134
- spa:: AbstractStridedPointer ,
1135
- spu:: AbstractStridedPointer ,
1026
+ spc:: AbstractStridedPointer{T,2,2 } ,
1027
+ spa:: AbstractStridedPointer{T,2,2} ,
1028
+ spu:: AbstractStridedPointer{T,2,2} ,
1136
1029
M,
1137
1030
N,
1138
- :: StaticInt{2} ,
1139
- :: StaticInt{2} ,
1140
1031
:: Val{UNIT}
1141
1032
) where {T,UNIT}
1142
1033
WS = pick_vector_width (T)
0 commit comments