Skip to content

Commit 52532ec

Browse files
committed
faster BdivU_small_kern
1 parent 0135554 commit 52532ec

File tree

1 file changed

+125
-76
lines changed

1 file changed

+125
-76
lines changed

src/TriangularSolve.jl

Lines changed: 125 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module TriangularSolve
2-
using Base: @nexprs
2+
using Base: @nexprs, @ntuple
33
if isdefined(Base, :Experimental) &&
44
isdefined(Base.Experimental, Symbol("@max_methods"))
55
@eval Base.Experimental.@max_methods 1
@@ -53,92 +53,128 @@ end
5353
@inline maybestore!(p, v, i, m) = vstore!(p, v, i, m)
5454
@inline maybestore!(::Nothing, v, i, m) = nothing
5555

56-
@inline function store_small_kern!(spa, sp, v, _, i, n, mask, ::Val{true})
56+
# Masked store of `v` at index `i` into two destinations: the same value,
# index and mask are passed to `vstore!` first for `spa` and then for `sp`.
@inline function store_small_kern!(spa, sp, v, i, mask)
5757
vstore!(spa, v, i, mask)
5858
vstore!(sp, v, i, mask)
5959
end
60-
@inline store_small_kern!(spa, ::Nothing, v, spu, i, n, mask, ::Val{true}) =
61-
vstore!(spa, v, i, mask)
62-
63-
@inline function store_small_kern!(spa, sp, v, spu, i, n, mask, ::Val{false})
64-
x = v / vload(spu, (n, n))
65-
vstore!(spa, x, i, mask)
66-
vstore!(sp, x, i, mask)
67-
end
68-
@inline store_small_kern!(spa, ::Nothing, v, spu, i, n, mask, ::Val{false}) =
69-
vstore!(spa, v / vload(spu, (n, n)), i, mask)
60+
# Masked store when the second destination is `nothing`: only `spa` is written.
@inline store_small_kern!(spa, ::Nothing, v, i, mask) = vstore!(spa, v, i, mask)
7061

71-
@inline function store_small_kern!(spa, sp, v, spu, i, n, ::Val{true})
62+
# Unmasked store of `v` at index `i` into two destinations: `spa` first,
# then `sp`, both via `vstore!` with the same value and index.
@inline function store_small_kern!(spa, sp, v, i)
7263
vstore!(spa, v, i)
7364
vstore!(sp, v, i)
7465
end
75-
@inline store_small_kern!(spa, ::Nothing, v, spu, i, n, ::Val{true}) =
76-
vstore!(spa, v, i)
66+
# Unmasked store when the second destination is `nothing`: only `spa` is written.
@inline store_small_kern!(spa, ::Nothing, v, i) = vstore!(spa, v, i)
7767

78-
@inline function store_small_kern!(spa, sp, v, spu, i, n, ::Val{false})
79-
x = v / vload(spu, (n, n))
80-
vstore!(spa, x, i)
81-
vstore!(sp, x, i)
82-
end
83-
@inline store_small_kern!(spa, ::Nothing, v, spu, i, n, ::Val{false}) =
84-
vstore!(spa, v / vload(spu, (n, n)), i)
85-
86-
@inline function BdivU_small_kern!(
68+
@generated function BdivU_small_kern!(
8769
spa::AbstractStridedPointer{T},
8870
sp,
8971
spb::AbstractStridedPointer{T},
9072
spu::AbstractStridedPointer{T},
91-
N,
73+
::StaticInt{N},
9274
mask::AbstractMask{W},
9375
::Val{UNIT}
94-
) where {T,UNIT,W}
95-
# W = VectorizationBase.pick_vector_width(T)
96-
for n ∈ CloseOpen(N)
97-
Amn = vload(spb, (MM{W}(StaticInt(0)), n), mask)
98-
for k ∈ SafeCloseOpen(n)
99-
Amn = vfnmadd_fast(
100-
vload(spa, (MM{W}(StaticInt(0)), k), mask),
101-
vload(spu, (k, n)),
102-
Amn
103-
)
76+
) where {T,UNIT,W,N}
77+
z = static(0)
78+
if N == 1
79+
i = (MM{W}(z), z)
80+
Amn = :(vload(spb, $i, mask))
81+
if !UNIT
82+
Amn = :($Amn / vload(spu, $((z, z))))
83+
end
84+
quote
85+
$(Expr(:meta, :inline))
86+
store_small_kern!(spa, sp, $Amn, $i, mask)
87+
end
88+
else
89+
unroll = Unroll{2,1,N,1,W,(-1 % UInt),1}((z, z))
90+
tostore = :(VecUnroll(Base.Cartesian.@ntuple $N Amn))
91+
scale = UNIT ? nothing : :(Amn_n /= vload(spu, (n - 1, n - 1)))
92+
quote
93+
$(Expr(:meta, :inline))
94+
Amn = getfield(vload(spb, $unroll, mask), :data)
95+
Base.Cartesian.@nexprs $N n -> begin
96+
Amn_n = getfield(Amn, n)
97+
Base.Cartesian.@nexprs (n - 1) k -> begin
98+
Amn_n = vfnmadd_fast(Amn_k, vload(spu, (k - 1, n - 1)), Amn_n)
99+
end
100+
$scale
101+
end
102+
store_small_kern!(spa, sp, $tostore, $unroll, mask)
104103
end
105-
store_small_kern!(
106-
spa,
107-
sp,
108-
Amn,
109-
spu,
110-
(MM{W}(StaticInt(0)), n),
111-
n,
112-
mask,
113-
Val{UNIT}()
114-
)
115104
end
116105
end
117-
@inline function BdivU_small_kern_u!(
106+
@generated function BdivU_small_kern_u!(
118107
spa::AbstractStridedPointer{T},
119108
sp,
120109
spb::AbstractStridedPointer{T},
121110
spu::AbstractStridedPointer{T},
122-
N,
111+
::StaticInt{N},
123112
::StaticInt{U},
124-
::Val{UNIT}
125-
) where {T,U,UNIT}
126-
W = Int(VectorizationBase.pick_vector_width(T))
127-
for n ∈ CloseOpen(N)
128-
Amn = vload(spb, Unroll{1,W,U,1,W,zero(UInt),1}((StaticInt(0), n)))
129-
for k ∈ SafeCloseOpen(n)
130-
Amk = vload(spa, Unroll{1,W,U,1,W,zero(UInt),1}((StaticInt(0), k)))
131-
Amn = vfnmadd_fast(Amk, vload(spu, (k, n)), Amn)
113+
::Val{UNIT},
114+
::StaticInt{W}
115+
) where {T,U,UNIT,N,W}
116+
z = static(0)
117+
if N == 1
118+
unroll = Unroll{1,W,U,1,W,zero(UInt),1}((z, z))
119+
Amn = :(vload(spb, $unroll))
120+
if !UNIT
121+
Amn = :($Amn / vload(spu, $((z, z))))
132122
end
133-
store_small_kern!(
134-
spa,
135-
sp,
136-
Amn,
137-
spu,
138-
Unroll{1,W,U,1,W,zero(UInt),1}((StaticInt(0), n)),
139-
n,
140-
Val{UNIT}()
141-
)
123+
quote
124+
$(Expr(:meta, :inline))
125+
store_small_kern!(spa, sp, $Amn, $unroll)
126+
end
127+
else
128+
double_unroll =
129+
Unroll{2,1,N,1,W,zero(UInt),1}(Unroll{1,W,U,1,W,zero(UInt),1}((z, z)))
130+
tostore = :(VecUnroll(Base.Cartesian.@ntuple $N Amn))
131+
scale = UNIT ? nothing : :(Amn_n /= vload(spu, (n - 1, n - 1)))
132+
quote
133+
$(Expr(:meta, :inline))
134+
Amn = getfield(vload(spb, $double_unroll), :data)
135+
Base.Cartesian.@nexprs $N n -> begin
136+
Amn_n = getfield(Amn, n)
137+
Base.Cartesian.@nexprs (n - 1) k -> begin
138+
Amn_n = vfnmadd_fast(Amn_k, vload(spu, (k - 1, n - 1)), Amn_n)
139+
end
140+
$scale
141+
end
142+
store_small_kern!(spa, sp, $tostore, $double_unroll)
143+
end
144+
end
145+
end
146+
# Runtime-`Nr` entry point for the masked small kernel.
#
# Converts the runtime integer `Nr` into a compile-time `StaticInt` by
# generating a `Base.Cartesian.@nif` chain with `W - 1` branches. Branch `n`
# tests `n == Nr` and forwards every argument unchanged, except that `Nr` is
# replaced by `static(n)`, to the `BdivU_small_kern!` method that takes the
# size as a `StaticInt`.
#
# Per `@nif` semantics the last branch (`n == W - 1`) is the unconditional
# `else`, so any `Nr` not matched by an earlier branch is handled as `W - 1`.
# NOTE(review): callers presumably guarantee `1 <= Nr <= W - 1` — confirm.
@generated function BdivU_small_kern!(
147+
spa::AbstractStridedPointer{T},
148+
sp,
149+
spb::AbstractStridedPointer{T},
150+
spu::AbstractStridedPointer{T},
151+
Nr::Int,
152+
mask::AbstractMask{W},
153+
::Val{UNIT}
154+
) where {T,UNIT,W}
155+
quote
156+
# The inline hint is commented out, so this dispatcher is not force-inlined.
# $(Expr(:meta, :inline))
157+
Base.Cartesian.@nif $(W - 1) n -> n == Nr n ->
158+
BdivU_small_kern!(spa, sp, spb, spu, static(n), mask, $(Val(UNIT)))
159+
end
160+
end
161+
# Runtime-`Nr` entry point for the unmasked, `U`-unrolled small kernel.
#
# Generates a `Base.Cartesian.@nif` chain with `W - 1` branches. Branch `n`
# tests `n == Nr` and calls the `BdivU_small_kern_u!` method that takes the
# size as a `StaticInt`, passing `static(n)` together with the re-materialised
# static arguments `static(U)`, `Val(UNIT)` and `static(W)`.
#
# Per `@nif` semantics the last branch (`n == W - 1`) is the unconditional
# `else`, so any `Nr` not matched by an earlier branch is handled as `W - 1`.
# NOTE(review): callers presumably guarantee `1 <= Nr <= W - 1` — confirm.
@generated function BdivU_small_kern_u!(
162+
spa::AbstractStridedPointer{T},
163+
sp,
164+
spb::AbstractStridedPointer{T},
165+
spu::AbstractStridedPointer{T},
166+
Nr::Int,
167+
::StaticInt{U},
168+
::Val{UNIT},
169+
::StaticInt{W}
170+
) where {T,U,UNIT,W}
171+
# Rebuild the type-level parameters as values so they can be spliced into
# the generated call below.
su = static(U)
172+
vu = Val(UNIT)
173+
sw = static(W)
174+
quote
175+
# The inline hint is commented out, so this dispatcher is not force-inlined.
# $(Expr(:meta, :inline))
176+
Base.Cartesian.@nif $(W - 1) n -> n == Nr n ->
177+
BdivU_small_kern_u!(spa, sp, spb, spu, static(n), $su, $vu, $sw)
142178
end
143179
end
144180

@@ -232,7 +268,7 @@ end
232268
) where {W,U,UNIT}
233269
z = static(0)
234270
quote
235-
$(Expr(:meta, :inline))
271+
# $(Expr(:meta, :inline))
236272
# C = L \ A; L * C = A
237273
# A_{i,j} = L_{i,i}*C_{i,j} + \sum_{k=1}^{i-1}L_{i,k}C_{k,j}
238274
# C_{i,j} = L_{i,i} \ (A_{i,j} - \sum_{k=1}^{i-1}L_{i,k}C_{k,j})
@@ -328,7 +364,7 @@ end
328364
) where {W,UNIT}
329365
z = static(0)
330366
quote
331-
$(Expr(:meta, :inline))
367+
# $(Expr(:meta, :inline))
332368
# Like `ldiv_solve_W_u!`, except no unrolling, just a `W`x`W` block
333369
#
334370
# C = L \ A; L * C = A
@@ -382,9 +418,8 @@ end
382418
R <= 1 && throw("Remainder of `<= 1` shouldn't be called, but had $R.")
383419
R >= W && throw("Reaminderof `>= $W` shouldn't be called, but had $R.")
384420
z = static(0)
385-
WS = static(W)
386421
q = quote
387-
$(Expr(:meta, :inline))
422+
# $(Expr(:meta, :inline))
388423
# Like `ldiv_solve_W_u!`, except no unrolling, just a `W`x`W` block
389424
#
390425
# C = L \ A; L * C = A
@@ -447,6 +482,7 @@ end
447482
push!(q.args, q3)
448483
return q
449484
end
485+
450486
@inline function rdiv_U!(
451487
spc::AbstractStridedPointer{T},
452488
spa::AbstractStridedPointer,
@@ -467,7 +503,7 @@ end
467503
while m < M - WU + 1
468504
n = Nr
469505
if n > 0
470-
BdivU_small_kern_u!(spc, nothing, spa, spu, n, UF, Val(UNIT))
506+
BdivU_small_kern_u!(spc, nothing, spa, spu, n, UF, Val(UNIT), WS)
471507
end
472508
for _ ∈ 1:Nd
473509
rdiv_solve_W_u!(spc, nothing, spa, spu, n, WS, UF, Val(UNIT))
@@ -963,7 +999,7 @@ function rdiv_U!(
963999
while m < M - WU + 1
9641000
n = Nr
9651001
if n > 0
966-
BdivU_small_kern_u!(spb, spc, spa, spu, n, UF, Val(UNIT))
1002+
BdivU_small_kern_u!(spb, spc, spa, spu, n, UF, Val(UNIT), WS)
9671003
end
9681004
for _ ∈ 1:Nd
9691005
rdiv_solve_W_u!(spb, spc, spa, spu, n, WS, UF, Val(UNIT))
@@ -995,7 +1031,7 @@ function rdiv_U!(
9951031
nothing
9961032
end
9971033

998-
@generated function ldiv_remainder!(
1034+
@generated function _ldiv_remainder!(
9991035
spc,
10001036
spa,
10011037
spu,
@@ -1011,16 +1047,20 @@ end
10111047
r >= W && throw("Reaminderof `>= $W` shouldn't be called, but had $r.")
10121048
if r == 1
10131049
z = static(0)
1050+
sub = Base.FastMath.sub_fast
1051+
mul = Base.FastMath.mul_fast
1052+
div = Base.FastMath.div_fast
10141053
vlxj = :(vload(spc, ($z, j)))
10151054
if UNIT
10161055
vlxj = :(xj = $vlxj)
10171056
else
10181057
vlxj = quote
1019-
xj = $vlxj / vload(spu, (j, j))
1058+
xj = $div($vlxj, vload(spu, (j, j)))
10201059
vstore!(spc, xj, ($z, j))
10211060
end
10221061
end
10231062
quote
1063+
$(Expr(:meta, :inline))
10241064
if pointer(spc) != pointer(spa)
10251065
for n = 0:N-1
10261066
vstore!(spc, vload(spa, ($z, n)), ($z, n))
@@ -1031,7 +1071,7 @@ end
10311071
for i = (j+1):N-1
10321072
xi = vload(spc, ($z, i))
10331073
Uji = vload(spu, (j, i))
1034-
vstore!(spc, xi - xj * Uji, ($z, i))
1074+
vstore!(spc, $sub(xi, $mul(xj, Uji)), ($z, i))
10351075
end
10361076
end
10371077
end
@@ -1070,8 +1110,8 @@ end
10701110
WS = static(W)
10711111
# US = static(U)
10721112
quote
1073-
$(Expr(:meta, :inline))
1074-
Base.Cartesian.@nif $W w -> m == M - w w -> ldiv_remainder!(
1113+
# $(Expr(:meta, :inline))
1114+
Base.Cartesian.@nif $(W - 1) w -> m == M - w w -> _ldiv_remainder!(
10751115
spc,
10761116
spa,
10771117
spu,
@@ -1111,7 +1151,16 @@ function rdiv_U!(
11111151
while m < M - WS + 1
11121152
n = Nr # non factor of W remainder
11131153
if n > 0
1114-
BdivU_small_kern_u!(spc, nothing, spa, spu, n, StaticInt(1), Val(UNIT))
1154+
BdivU_small_kern_u!(
1155+
spc,
1156+
nothing,
1157+
spa,
1158+
spu,
1159+
n,
1160+
StaticInt(1),
1161+
Val(UNIT),
1162+
WS
1163+
)
11151164
end
11161165
while n < N - (WU - 1)
11171166
ldiv_solve_W_u!(spc, spa, spu, n, WS, UF, Val(UNIT))

0 commit comments

Comments
 (0)