1
1
module TriangularSolve
2
- using Base: @nexprs
2
+ using Base: @nexprs , @ntuple
3
3
if isdefined (Base, :Experimental ) &&
4
4
isdefined (Base. Experimental, Symbol (" @max_methods" ))
5
5
@eval Base. Experimental. @max_methods 1
53
53
# Store `v` through pointer `p` at index `i` under mask `m`; a plain forward to `vstore!`.
@inline maybestore! (p, v, i, m) = vstore! (p, v, i, m)
54
54
# Method selected when the destination is `nothing`: no store is issued and
# `nothing` is returned.
@inline maybestore! (:: Nothing , v, i, m) = nothing
55
55
56
# Masked store of `v` at index `i` into both destinations, `spa` and `sp`.
# The `-` line is the previous signature (extra `_`, `n` and `Val{true}` arguments);
# the `+` line is the shorter signature that replaces it. The body is unchanged.
- @inline function store_small_kern! (spa, sp, v, _, i, n, mask, :: Val{true} )
56
+ @inline function store_small_kern! (spa, sp, v, i, mask)
57
57
vstore! (spa, v, i, mask)
58
58
vstore! (sp, v, i, mask)
59
59
end
60
# Removed method: masked `Val{true}` variant for a `nothing` second destination;
# it stored `v` into `spa` only.
- @inline store_small_kern! (spa, :: Nothing , v, spu, i, n, mask, :: Val{true} ) =
61
- vstore! (spa, v, i, mask)
62
-
63
# Removed methods: the masked `Val{false}` variants divided `v` by
# `vload(spu, (n, n))` before storing.
- @inline function store_small_kern! (spa, sp, v, spu, i, n, mask, :: Val{false} )
64
- x = v / vload (spu, (n, n))
65
- vstore! (spa, x, i, mask)
66
- vstore! (sp, x, i, mask)
67
- end
68
- @inline store_small_kern! (spa, :: Nothing , v, spu, i, n, mask, :: Val{false} ) =
69
- vstore! (spa, v / vload (spu, (n, n)), i, mask)
60
# Added method: masked store into `spa` only, selected when the second
# destination is `nothing`. It takes no `spu`, `n` or `Val` argument.
+ @inline store_small_kern! (spa, :: Nothing , v, i, mask) = vstore! (spa, v, i, mask)
70
61
71
# Unmasked store of `v` at index `i` into both destinations, `spa` and `sp`.
# The `-` line is the previous signature (with `spu`, `n` and `Val{true}`);
# the `+` line is the reduced signature that replaces it.
- @inline function store_small_kern! (spa, sp, v, spu, i, n, :: Val{true} )
62
+ @inline function store_small_kern! (spa, sp, v, i )
72
63
vstore! (spa, v, i)
73
64
vstore! (sp, v, i)
74
65
end
75
# Unmasked store into `spa` only, selected when the second destination is
# `nothing`. The old two-line form is replaced by the one-line form.
- @inline store_small_kern! (spa, :: Nothing , v, spu, i, n, :: Val{true} ) =
76
- vstore! (spa, v, i)
66
+ @inline store_small_kern! (spa, :: Nothing , v, i) = vstore! (spa, v, i)
77
67
78
# Removed methods: the unmasked `Val{false}` variants divided `v` by
# `vload(spu, (n, n))` and stored the quotient. The first wrote to both `spa`
# and `sp`; the second (for `sp === nothing`) wrote to `spa` only.
- @inline function store_small_kern! (spa, sp, v, spu, i, n, :: Val{false} )
79
- x = v / vload (spu, (n, n))
80
- vstore! (spa, x, i)
81
- vstore! (sp, x, i)
82
- end
83
- @inline store_small_kern! (spa, :: Nothing , v, spu, i, n, :: Val{false} ) =
84
- vstore! (spa, v / vload (spu, (n, n)), i)
85
-
86
# Masked small kernel. The `-` lines are the previous `@inline` version, which
# looped over a runtime column count `N`. The `+` lines are the `@generated`
# version, which takes the column count as `StaticInt{N}` and emits unrolled code.
- @inline function BdivU_small_kern! (
68
+ @generated function BdivU_small_kern! (
87
69
spa:: AbstractStridedPointer{T} ,
88
70
sp,
89
71
spb:: AbstractStridedPointer{T} ,
90
72
spu:: AbstractStridedPointer{T} ,
91
- N ,
73
+ :: StaticInt{N} ,
92
74
mask:: AbstractMask{W} ,
93
75
:: Val{UNIT}
94
- ) where {T,UNIT,W}
95
- # W = VectorizationBase.pick_vector_width(T)
96
# Removed loop: for each `n`, load a masked vector from `spb`, update it with
# `vfnmadd_fast` against every earlier column `k` read back from `spa`, then
# pass it to `store_small_kern!`.
- for n ∈ CloseOpen (N)
97
- Amn = vload (spb, (MM {W} (StaticInt (0 )), n), mask)
98
- for k ∈ SafeCloseOpen (n)
99
- Amn = vfnmadd_fast (
100
- vload (spa, (MM {W} (StaticInt (0 )), k), mask),
101
- vload (spu, (k, n)),
102
- Amn
103
- )
76
+ ) where {T,UNIT,W,N}
77
+ z = static (0 )
78
# Single-column case: build the load expression for index `(MM{W}(z), z)` and,
# unless `UNIT` is true, wrap it in a division by `vload(spu, (z, z))`.
+ if N == 1
79
+ i = (MM {W} (z), z)
80
+ Amn = :(vload (spb, $ i, mask))
81
+ if ! UNIT
82
+ Amn = :($ Amn / vload (spu, $ ((z, z))))
83
+ end
84
+ quote
85
+ $ (Expr (:meta , :inline ))
86
+ store_small_kern! (spa, sp, $ Amn, $ i, mask)
87
+ end
88
+ else
89
# Multi-column case: one `Unroll`-indexed masked `vload` fetches all `N`
# columns; its `:data` field is then taken apart column by column.
# `tostore` rebuilds a `VecUnroll` from the generated names `Amn_1 … Amn_N`.
# `scale` is `nothing` when `UNIT` is true, otherwise an in-place division of
# column `n` by `vload(spu, (n - 1, n - 1))`.
+ unroll = Unroll {2,1,N,1,W,(-1 % UInt),1} ((z, z))
90
+ tostore = :(VecUnroll (Base. Cartesian. @ntuple $ N Amn))
91
+ scale = UNIT ? nothing : :(Amn_n /= vload (spu, (n - 1 , n - 1 )))
92
+ quote
93
+ $ (Expr (:meta , :inline ))
94
+ Amn = getfield (vload (spb, $ unroll, mask), :data )
95
# `@nexprs` counts from 1, so `spu` is indexed with `k - 1` / `n - 1`.
# The inner `@nexprs (n - 1)` updates column `n` using the already finished
# columns `Amn_1 … Amn_{n-1}`, which stay in local variables and are not
# reloaded from `spa`.
# NOTE(review): `vfnmadd_fast(a, b, c)` is presumably `c - a * b`; confirm
# against VectorizationBase.
+ Base. Cartesian. @nexprs $ N n -> begin
96
+ Amn_n = getfield (Amn, n)
97
+ Base. Cartesian. @nexprs (n - 1 ) k -> begin
98
+ Amn_n = vfnmadd_fast (Amn_k, vload (spu, (k - 1 , n - 1 )), Amn_n)
99
+ end
100
+ $ scale
101
+ end
102
# All `N` result columns are written with one `store_small_kern!` call.
+ store_small_kern! (spa, sp, $ tostore, $ unroll, mask)
104
103
end
105
- store_small_kern! (
106
- spa,
107
- sp,
108
- Amn,
109
- spu,
110
- (MM {W} (StaticInt (0 )), n),
111
- n,
112
- mask,
113
- Val {UNIT} ()
114
- )
115
104
end
116
105
end
117
# Unmasked, `U`-times-unrolled small kernel. The `-` lines are the previous
# `@inline` version (runtime `N`, with `W` taken from `pick_vector_width(T)`).
# The `+` lines are the `@generated` version, which takes the column count as
# `StaticInt{N}` and the vector width as a trailing `StaticInt{W}` argument.
- @inline function BdivU_small_kern_u! (
106
+ @generated function BdivU_small_kern_u! (
118
107
spa:: AbstractStridedPointer{T} ,
119
108
sp,
120
109
spb:: AbstractStridedPointer{T} ,
121
110
spu:: AbstractStridedPointer{T} ,
122
- N ,
111
+ :: StaticInt{N} ,
123
112
:: StaticInt{U} ,
124
- :: Val{UNIT}
125
- ) where {T,U,UNIT}
126
- W = Int (VectorizationBase. pick_vector_width (T))
127
# Removed loop: for each `n`, load an `Unroll`ed block from `spb`, update it
# with `vfnmadd_fast` against every earlier column `k` reloaded from `spa`,
# then pass it to `store_small_kern!`.
- for n ∈ CloseOpen (N)
128
- Amn = vload (spb, Unroll {1,W,U,1,W,zero(UInt),1} ((StaticInt (0 ), n)))
129
- for k ∈ SafeCloseOpen (n)
130
- Amk = vload (spa, Unroll {1,W,U,1,W,zero(UInt),1} ((StaticInt (0 ), k)))
131
- Amn = vfnmadd_fast (Amk, vload (spu, (k, n)), Amn)
113
+ :: Val{UNIT} ,
114
+ :: StaticInt{W}
115
+ ) where {T,U,UNIT,N,W}
116
+ z = static (0 )
117
# Single-column case: one `Unroll`ed load from `spb` at `(z, z)`, divided by
# `vload(spu, (z, z))` unless `UNIT` is true.
+ if N == 1
118
+ unroll = Unroll {1,W,U,1,W,zero(UInt),1} ((z, z))
119
+ Amn = :(vload (spb, $ unroll))
120
+ if ! UNIT
121
+ Amn = :($ Amn / vload (spu, $ ((z, z))))
132
122
end
133
- store_small_kern! (
134
- spa,
135
- sp,
136
- Amn,
137
- spu,
138
- Unroll {1,W,U,1,W,zero(UInt),1} ((StaticInt (0 ), n)),
139
- n,
140
- Val {UNIT} ()
141
- )
123
+ quote
124
+ $ (Expr (:meta , :inline ))
125
+ store_small_kern! (spa, sp, $ Amn, $ unroll)
126
+ end
127
+ else
128
# Multi-column case: a nested `Unroll` (an outer `Unroll` wrapping the
# single-column `Unroll`) loads all `N` columns at once.
# `tostore` rebuilds a `VecUnroll` from the generated names `Amn_1 … Amn_N`.
# `scale` is `nothing` when `UNIT` is true, otherwise an in-place division of
# column `n` by `vload(spu, (n - 1, n - 1))`.
+ double_unroll =
129
+ Unroll {2,1,N,1,W,zero(UInt),1} (Unroll {1,W,U,1,W,zero(UInt),1} ((z, z)))
130
+ tostore = :(VecUnroll (Base. Cartesian. @ntuple $ N Amn))
131
+ scale = UNIT ? nothing : :(Amn_n /= vload (spu, (n - 1 , n - 1 )))
132
+ quote
133
+ $ (Expr (:meta , :inline ))
134
+ Amn = getfield (vload (spb, $ double_unroll), :data )
135
# `@nexprs` counts from 1, so `spu` is indexed with `k - 1` / `n - 1`.
# Column `n` is updated from the already finished `Amn_1 … Amn_{n-1}`.
# NOTE(review): `vfnmadd_fast(a, b, c)` is presumably `c - a * b`; confirm
# against VectorizationBase.
+ Base. Cartesian. @nexprs $ N n -> begin
136
+ Amn_n = getfield (Amn, n)
137
+ Base. Cartesian. @nexprs (n - 1 ) k -> begin
138
+ Amn_n = vfnmadd_fast (Amn_k, vload (spu, (k - 1 , n - 1 )), Amn_n)
139
+ end
140
+ $ scale
141
+ end
142
+ store_small_kern! (spa, sp, $ tostore, $ double_unroll)
143
+ end
144
+ end
145
+ end
146
# Runtime-count entry point for the masked kernel. It turns the integer `Nr`
# into a `StaticInt` and forwards to the `StaticInt` method of
# `BdivU_small_kern!`.
# `@nif $(W - 1)` expands to an if/elseif chain testing `n == Nr` for the
# leading values of `n`; the last value, `n = W - 1`, is the unconditional
# `else` branch.
# NOTE(review): an `Nr` outside `1:(W - 1)` would therefore run the
# `n = W - 1` branch; presumably callers guarantee the range — confirm.
# The inline meta is left commented out, so inlining is not forced here.
+ @generated function BdivU_small_kern! (
147
+ spa:: AbstractStridedPointer{T} ,
148
+ sp,
149
+ spb:: AbstractStridedPointer{T} ,
150
+ spu:: AbstractStridedPointer{T} ,
151
+ Nr:: Int ,
152
+ mask:: AbstractMask{W} ,
153
+ :: Val{UNIT}
154
+ ) where {T,UNIT,W}
155
+ quote
156
+ # $(Expr(:meta, :inline))
157
+ Base. Cartesian. @nif $ (W - 1 ) n -> n == Nr n ->
158
+ BdivU_small_kern! (spa, sp, spb, spu, static (n), mask, $ (Val (UNIT)))
159
+ end
160
+ end
161
# Runtime-count entry point for the unmasked, unrolled kernel. It turns the
# integer `Nr` into a `StaticInt` and forwards to the `StaticInt` method of
# `BdivU_small_kern_u!`.
# `su`, `vu` and `sw` are instances rebuilt from the type parameters `U`,
# `UNIT` and `W`, so they can be spliced into the generated call.
# `@nif $(W - 1)` expands to an if/elseif chain testing `n == Nr`; the last
# value, `n = W - 1`, is the unconditional `else` branch.
# NOTE(review): an `Nr` outside `1:(W - 1)` would run the `n = W - 1` branch;
# presumably callers guarantee the range — confirm.
# The inline meta is left commented out, so inlining is not forced here.
+ @generated function BdivU_small_kern_u! (
162
+ spa:: AbstractStridedPointer{T} ,
163
+ sp,
164
+ spb:: AbstractStridedPointer{T} ,
165
+ spu:: AbstractStridedPointer{T} ,
166
+ Nr:: Int ,
167
+ :: StaticInt{U} ,
168
+ :: Val{UNIT} ,
169
+ :: StaticInt{W}
170
+ ) where {T,U,UNIT,W}
171
+ su = static (U)
172
+ vu = Val (UNIT)
173
+ sw = static (W)
174
+ quote
175
+ # $(Expr(:meta, :inline))
176
+ Base. Cartesian. @nif $ (W - 1 ) n -> n == Nr n ->
177
+ BdivU_small_kern_u! (spa, sp, spb, spu, static (n), $ su, $ vu, $ sw)
142
178
end
143
179
end
144
180
232
268
) where {W,U,UNIT}
233
269
z = static (0 )
234
270
quote
235
- $ (Expr (:meta , :inline ))
271
+ # $(Expr(:meta, :inline))
236
272
# C = L \ A; L * C = A
237
273
# A_{i,j} = L_{i,i}*C_{i,j} + \sum_{k=1}^{i-1}L_{i,k}C_{k,j}
238
274
# C_{i,j} = L_{i,i} \ (A_{i,j} - \sum_{k=1}^{i-1}L_{i,k}C_{k,j})
328
364
) where {W,UNIT}
329
365
z = static (0 )
330
366
quote
331
- $ (Expr (:meta , :inline ))
367
+ # $(Expr(:meta, :inline))
332
368
# Like `ldiv_solve_W_u!`, except no unrolling, just a `W`x`W` block
333
369
#
334
370
# C = L \ A; L * C = A
382
418
R <= 1 && throw (" Remainder of `<= 1` shouldn't be called, but had $R ." )
383
419
R >= W && throw (" Reaminderof `>= $W ` shouldn't be called, but had $R ." )
384
420
z = static (0 )
385
- WS = static (W)
386
421
q = quote
387
- $ (Expr (:meta , :inline ))
422
+ # $(Expr(:meta, :inline))
388
423
# Like `ldiv_solve_W_u!`, except no unrolling, just a `W`x`W` block
389
424
#
390
425
# C = L \ A; L * C = A
447
482
push! (q. args, q3)
448
483
return q
449
484
end
485
+
450
486
@inline function rdiv_U! (
451
487
spc:: AbstractStridedPointer{T} ,
452
488
spa:: AbstractStridedPointer ,
467
503
while m < M - WU + 1
468
504
n = Nr
469
505
if n > 0
470
- BdivU_small_kern_u! (spc, nothing , spa, spu, n, UF, Val (UNIT))
506
+ BdivU_small_kern_u! (spc, nothing , spa, spu, n, UF, Val (UNIT), WS )
471
507
end
472
508
for _ ∈ 1 : Nd
473
509
rdiv_solve_W_u! (spc, nothing , spa, spu, n, WS, UF, Val (UNIT))
@@ -963,7 +999,7 @@ function rdiv_U!(
963
999
while m < M - WU + 1
964
1000
n = Nr
965
1001
if n > 0
966
- BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT))
1002
+ BdivU_small_kern_u! (spb, spc, spa, spu, n, UF, Val (UNIT), WS )
967
1003
end
968
1004
for _ ∈ 1 : Nd
969
1005
rdiv_solve_W_u! (spb, spc, spa, spu, n, WS, UF, Val (UNIT))
@@ -995,7 +1031,7 @@ function rdiv_U!(
995
1031
nothing
996
1032
end
997
1033
998
- @generated function ldiv_remainder ! (
1034
+ @generated function _ldiv_remainder ! (
999
1035
spc,
1000
1036
spa,
1001
1037
spu,
@@ -1011,16 +1047,20 @@ end
1011
1047
r >= W && throw (" Reaminderof `>= $W ` shouldn't be called, but had $r ." )
1012
1048
if r == 1
1013
1049
z = static (0 )
1050
+ sub = Base. FastMath. sub_fast
1051
+ mul = Base. FastMath. mul_fast
1052
+ div = Base. FastMath. div_fast
1014
1053
vlxj = :(vload (spc, ($ z, j)))
1015
1054
if UNIT
1016
1055
vlxj = :(xj = $ vlxj)
1017
1056
else
1018
1057
vlxj = quote
1019
- xj = $ vlxj / vload (spu, (j, j))
1058
+ xj = $ div ( $ vlxj, vload (spu, (j, j) ))
1020
1059
vstore! (spc, xj, ($ z, j))
1021
1060
end
1022
1061
end
1023
1062
quote
1063
+ $ (Expr (:meta , :inline ))
1024
1064
if pointer (spc) != pointer (spa)
1025
1065
for n = 0 : N- 1
1026
1066
vstore! (spc, vload (spa, ($ z, n)), ($ z, n))
@@ -1031,7 +1071,7 @@ end
1031
1071
for i = (j+ 1 ): N- 1
1032
1072
xi = vload (spc, ($ z, i))
1033
1073
Uji = vload (spu, (j, i))
1034
- vstore! (spc, xi - xj * Uji, ($ z, i))
1074
+ vstore! (spc, $ sub (xi, $ mul (xj, Uji)) , ($ z, i))
1035
1075
end
1036
1076
end
1037
1077
end
@@ -1070,8 +1110,8 @@ end
1070
1110
WS = static (W)
1071
1111
# US = static(U)
1072
1112
quote
1073
- $ (Expr (:meta , :inline ))
1074
- Base. Cartesian. @nif $ W w -> m == M - w w -> ldiv_remainder ! (
1113
+ # $(Expr(:meta, :inline))
1114
+ Base. Cartesian. @nif $ (W - 1 ) w -> m == M - w w -> _ldiv_remainder ! (
1075
1115
spc,
1076
1116
spa,
1077
1117
spu,
@@ -1111,7 +1151,16 @@ function rdiv_U!(
1111
1151
while m < M - WS + 1
1112
1152
n = Nr # non factor of W remainder
1113
1153
if n > 0
1114
- BdivU_small_kern_u! (spc, nothing , spa, spu, n, StaticInt (1 ), Val (UNIT))
1154
+ BdivU_small_kern_u! (
1155
+ spc,
1156
+ nothing ,
1157
+ spa,
1158
+ spu,
1159
+ n,
1160
+ StaticInt (1 ),
1161
+ Val (UNIT),
1162
+ WS
1163
+ )
1115
1164
end
1116
1165
while n < N - (WU - 1 )
1117
1166
ldiv_solve_W_u! (spc, spa, spu, n, WS, UF, Val (UNIT))
0 commit comments