Skip to content

Commit b7e9640

Browse files
committed
Simplify
1 parent 52532ec commit b7e9640

File tree

2 files changed

+53
-162
lines changed

2 files changed

+53
-162
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TriangularSolve"
22
uuid = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf"
33
authors = ["chriselrod <[email protected]> and contributors"]
4-
version = "0.1.22"
4+
version = "0.2.0"
55

66
[deps]
77
CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9"

src/TriangularSolve.jl

Lines changed: 52 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,6 @@ end
489489
spu::AbstractStridedPointer,
490490
M,
491491
N,
492-
::StaticInt{1},
493-
::StaticInt{1},
494492
::Val{UNIT}
495493
) where {T,UNIT}
496494
WS = pick_vector_width(T)
@@ -534,59 +532,6 @@ end
534532
nothing
535533
end
536534

537-
const buffer = Ref{Ptr{Cvoid}}(C_NULL)
538-
539-
function __init__()
540-
bp_size = 2 * sizeof(Int) * Threads.nthreads()
541-
buffer[] = bp = Libc.malloc(bp_size % UInt)
542-
Libc.memset(bp, 0, bp_size)
543-
end
544-
545-
function _get_buffer_pointer(::StaticInt{UF}, N) where {UF}
546-
RS = VectorizationBase.register_size()
547-
RSUF = StaticInt{UF}() * RS
548-
L = RSUF * N
549-
tid = Threads.threadid() - 1
550-
bp = Ptr{Pair{Ptr{Cvoid},Int}}(buffer[]) + 2sizeof(Int) * tid
551-
(p, buff_current) = unsafe_load(bp)
552-
if buff_current < L
553-
p == C_NULL || Libc.free(p)
554-
buff_size = max(RSUF * 128, L)
555-
p = Libc.malloc((buff_size + RS - 1) % UInt)
556-
unsafe_store!(bp, p => buff_size)
557-
end
558-
return VectorizationBase.align(p, RS)
559-
end
560-
561-
@inline function lubuffer(::Val{T}, ::StaticInt{UF}, N) where {T,UF}
562-
RS = VectorizationBase.register_size()
563-
RSUF = StaticInt{UF}() * RS
564-
ptr = Ptr{T}(_get_buffer_pointer(StaticInt{UF}(), N))
565-
si = StrideIndex{2,(1, 2),1}(
566-
(VectorizationBase.static_sizeof(T), RSUF),
567-
(StaticInt(0), StaticInt(0))
568-
)
569-
stridedpointer(ptr, si, StaticInt{0}()), nothing
570-
end
571-
@inline function lubuffer(
572-
::Val{T},
573-
::StaticInt{UF},
574-
::StaticInt{N}
575-
) where {T,UF,N}
576-
RSUF = StaticInt{UF}() * VectorizationBase.pick_vector_width(T)
577-
L = RSUF * N
578-
buf = Ref{NTuple{L,T}}()
579-
ptr = Base.unsafe_convert(Ptr{T}, buf)
580-
si = StrideIndex{2,(1, 2),1}(
581-
(
582-
VectorizationBase.static_sizeof(T),
583-
RSUF * VectorizationBase.static_sizeof(T)
584-
),
585-
(StaticInt(0), StaticInt(0))
586-
)
587-
stridedpointer(ptr, si, StaticInt{0}()), buf
588-
end
589-
@inline _free(p::Ptr) = Libc.free(p)
590535
_canonicalize(x) = signed(x)
591536
_canonicalize(::StaticInt{N}) where {N} = StaticInt{N}()
592537
function div_dispatch!(
@@ -606,17 +551,14 @@ function div_dispatch!(
606551
spa = zero_offsets(_spa)
607552
spc = zero_offsets(_spc)
608553
spu = zero_offsets(_spu)
609-
XC = VectorizationBase.contiguous_axis(C)
610-
XA = VectorizationBase.contiguous_axis(A)
611554
GC.@preserve spap spcp spup begin
612555
mtb = m_thread_block_size(M, N, nthread, Val(T))
613556
if nthread > 1
614-
(M > mtb) &&
615-
return multithread_rdiv!(spc, spa, spu, M, N, mtb, Val(UNIT), XC, XA)
557+
(M > mtb) && return multithread_rdiv!(spc, spa, spu, M, N, mtb, Val(UNIT))
616558
elseif N > block_size(Val(T))
617-
return rdiv_block_MandN!(spc, spa, spu, M, N, Val(UNIT), XC, XA)
559+
return rdiv_block_MandN!(spc, spa, spu, M, N, Val(UNIT))
618560
end
619-
return rdiv_U!(spc, spa, spu, M, N, XC, XA, Val(UNIT))
561+
return rdiv_U!(spc, spa, spu, M, N, Val(UNIT))
620562
end
621563
end
622564

@@ -833,10 +775,8 @@ function rdiv_block_N!(
833775
M,
834776
N,
835777
::Val{UNIT},
836-
::StaticInt{XC},
837-
::StaticInt{XA},
838778
Bsize = nothing
839-
) where {T,UNIT,XC,XA}
779+
) where {T,UNIT}
840780
spa_rdiv = spa
841781
spc_base = spc
842782
n = 0
@@ -857,8 +797,6 @@ function rdiv_block_N!(
857797
gesp(spu, (n, StaticInt{0}())),
858798
M,
859799
N_temp,
860-
StaticInt{XC}(),
861-
StaticInt{XA}(),
862800
Val{UNIT}()
863801
)
864802
repeat || break
@@ -873,18 +811,16 @@ function rdiv_block_N!(
873811
end
874812
end
875813
function rdiv_block_MandN!(
876-
spc::AbstractStridedPointer{T},
877-
spa,
878-
spu,
814+
spc::AbstractStridedPointer{T,<:Any,XC},
815+
spa::AbstractStridedPointer{T,<:Any,XA},
816+
spu::AbstractStridedPointer{T,<:Any,XU},
879817
M,
880818
N,
881-
::Val{UNIT},
882-
::StaticInt{XC},
883-
::StaticInt{XA}
884-
) where {T,UNIT,XC,XA}
819+
::Val{UNIT}
820+
) where {T,UNIT,XC,XA,XU}
885821
B = block_size(Val(T))
886822
W = VectorizationBase.pick_vector_width(T)
887-
WUF = XC == XA == 2 ? W : W * unroll_factor(W)
823+
WUF = XC == XA == XA == 2 ? W : W * unroll_factor(W)
888824
B_m = VectorizationBase.vcld(M, VectorizationBase.vcld(M, B) * WUF) * WUF
889825
m = 0
890826
while m < M
@@ -897,8 +833,6 @@ function rdiv_block_MandN!(
897833
Mtemp,
898834
N,
899835
Val{UNIT}(),
900-
StaticInt{XC}(),
901-
StaticInt{XA}(),
902836
VectorizationBase.vcld(N, VectorizationBase.vcld(N, B) * W) * W
903837
)
904838
spa = gesp(spa, (B_m, StaticInt{0}()))
@@ -913,12 +847,12 @@ function m_thread_block_size(M, N, nthreads, ::Val{T}) where {T}
913847
min(M, VectorizationBase.vcld(M, nb * W) * W)
914848
end
915849

916-
struct RDivBlockMandNv2{UNIT,XC,XA} end
917-
function (f::RDivBlockMandNv2{UNIT,XC,XA})(
850+
struct RDivBlockMandNv2{UNIT} end
851+
function (f::RDivBlockMandNv2{UNIT})(
918852
allargs,
919853
blockstart,
920854
blockstop
921-
) where {UNIT,XC,XA}
855+
) where {UNIT}
922856
spc, spa, spu, N, Mrem, Nblock, mtb = allargs
923857
for block = blockstart-1:blockstop-1
924858
rdiv_block_MandN!(
@@ -927,9 +861,7 @@ function (f::RDivBlockMandNv2{UNIT,XC,XA})(
927861
spu,
928862
Core.ifelse(block == Nblock - 1, Mrem, mtb),
929863
N,
930-
Val{UNIT}(),
931-
static(XC),
932-
static(XA)
864+
Val{UNIT}()
933865
)
934866
end
935867
end
@@ -941,17 +873,14 @@ function multithread_rdiv!(
941873
M::Int,
942874
N::Int,
943875
mtb::Int,
944-
::Val{UNIT},
945-
::StaticInt{XC},
946-
::StaticInt{XA}
947-
) where {XC,XA,UNIT,TC,TA,TU}
876+
::Val{UNIT}
877+
) where {UNIT,TC,TA,TU}
948878
# Main._a[] = (spc, spa, spu, M, N, mtb, Val(UNIT), static(X));
949879
(Md, Mr) = VectorizationBase.vdivrem(M, mtb)
950880
Nblock = Md + (Mr 0)
951881
Mrem = Core.ifelse(Mr 0, Mr, mtb)
952-
f = RDivBlockMandNv2{UNIT,XC,XA}()
953882
batch(
954-
f,
883+
RDivBlockMandNv2{UNIT}(),
955884
(Nblock, min(Nblock, Threads.nthreads())),
956885
spc,
957886
spa,
@@ -977,60 +906,6 @@ function unroll_factor(::StaticInt{W}) where {W}
977906
ifelse(Static.lt(num_blocks, StaticInt{1}()), StaticInt{1}(), num_blocks)
978907
end
979908

980-
function rdiv_U!(
981-
spc::AbstractStridedPointer{T},
982-
spa::AbstractStridedPointer,
983-
spu::AbstractStridedPointer,
984-
M,
985-
N,
986-
::StaticInt{var"#UNUSED1#"},
987-
::StaticInt{var"#UNUSED2#"},
988-
::Val{UNIT}
989-
) where {T,UNIT,var"#UNUSED1#",var"#UNUSED2#"}
990-
WS = pick_vector_width(T)
991-
W = Int(WS)
992-
UF = unroll_factor(WS)
993-
WU = UF * WS
994-
Nd, Nr = VectorizationBase.vdivrem(N, WS)
995-
spb, preserve = lubuffer(Val(T), UF, N)
996-
m = 0
997-
GC.@preserve preserve begin
998-
if UF > 1
999-
while m < M - WU + 1
1000-
n = Nr
1001-
if n > 0
1002-
BdivU_small_kern_u!(spb, spc, spa, spu, n, UF, Val(UNIT), WS)
1003-
end
1004-
for _ 1:Nd
1005-
rdiv_solve_W_u!(spb, spc, spa, spu, n, WS, UF, Val(UNIT))
1006-
n += W
1007-
end
1008-
m += WU
1009-
spa = gesp(spa, (WU, StaticInt(0)))
1010-
spc = gesp(spc, (WU, StaticInt(0)))
1011-
end
1012-
end
1013-
finalmask = VectorizationBase.mask(WS, M)
1014-
while m < M
1015-
ubm = m + W
1016-
nomaskiter = ubm < M
1017-
mask = nomaskiter ? VectorizationBase.max_mask(WS) : finalmask
1018-
n = Nr
1019-
if n > 0
1020-
BdivU_small_kern!(spb, spc, spa, spu, n, mask, Val(UNIT))
1021-
end
1022-
for i 1:Nd
1023-
rdiv_solve_W!(spb, spc, spa, spu, n, i Nd, mask, Val(UNIT))
1024-
n += W
1025-
end
1026-
spa = gesp(spa, (WS, StaticInt(0)))
1027-
spc = gesp(spc, (WS, StaticInt(0)))
1028-
m = ubm
1029-
end
1030-
end
1031-
nothing
1032-
end
1033-
1034909
@generated function _ldiv_remainder!(
1035910
spc,
1036911
spa,
@@ -1109,34 +984,50 @@ end
1109984
) where {W,UNIT}
1110985
WS = static(W)
1111986
# US = static(U)
1112-
quote
1113-
# $(Expr(:meta, :inline))
1114-
Base.Cartesian.@nif $(W - 1) w -> m == M - w w -> _ldiv_remainder!(
1115-
spc,
1116-
spa,
1117-
spu,
1118-
M,
1119-
N,
1120-
m,
1121-
Nr,
1122-
$WS,
1123-
$(Val(UNIT)),
1124-
StaticInt(w)
1125-
)
987+
if W == 2
988+
quote
989+
$(Expr(:meta, :inline))
990+
_ldiv_remainder!(
991+
spc,
992+
spa,
993+
spu,
994+
M,
995+
N,
996+
m,
997+
Nr,
998+
$WS,
999+
$(Val(UNIT)),
1000+
$(static(1))
1001+
)
1002+
end
1003+
else
1004+
quote
1005+
# $(Expr(:meta, :inline))
1006+
Base.Cartesian.@nif $(W - 1) w -> m == M - w w -> _ldiv_remainder!(
1007+
spc,
1008+
spa,
1009+
spu,
1010+
M,
1011+
N,
1012+
m,
1013+
Nr,
1014+
$WS,
1015+
$(Val(UNIT)),
1016+
StaticInt(w)
1017+
)
1018+
end
11261019
end
11271020
end
11281021

11291022
# spc = spa / spu
11301023
# spc' = (spu' \ spa')'
11311024
# This is ldiv
11321025
function rdiv_U!(
1133-
spc::AbstractStridedPointer{T},
1134-
spa::AbstractStridedPointer,
1135-
spu::AbstractStridedPointer,
1026+
spc::AbstractStridedPointer{T,2,2},
1027+
spa::AbstractStridedPointer{T,2,2},
1028+
spu::AbstractStridedPointer{T,2,2},
11361029
M,
11371030
N,
1138-
::StaticInt{2},
1139-
::StaticInt{2},
11401031
::Val{UNIT}
11411032
) where {T,UNIT}
11421033
WS = pick_vector_width(T)

0 commit comments

Comments
 (0)