@@ -20,6 +20,15 @@ using IfElse: ifelse
20
20
using LoopVectorization
21
21
using Polyester
22
22
23
# Alias for an LLVM pointer to `T` in address space 0.
const LPtr{T} = Core.LLVMPtr{T,0}

# `_lptr` reinterprets a `Ptr{T}` as an `LPtr{T}`; every other value is
# returned unchanged, so it can be mapped over a heterogeneous tuple.
@inline _lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x)
@inline _lptr(x) = x

# `_ptr` is the inverse of `_lptr`: it reinterprets an `LPtr{T}` as a `Ptr{T}`
# and returns every other value unchanged.
@inline _ptr(x::LPtr{T}) where {T} = reinterpret(Ptr{T}, x)
@inline _ptr(x) = x
28
# Rebuild a value of type `T` from the flat tuple `t`. Each element is first
# passed through `_ptr`, so any `LPtr` entries become `Ptr`s before
# `LoopVectorization.reassemble_tuple` sees them.
@inline reassemble_tup(::Type{T}, t) where {T} =
  LoopVectorization.reassemble_tuple(T, map(_ptr, t))

# Flatten `t` with `LoopVectorization.flatten_to_tuple`, then pass each element
# of the result through `_lptr`, so any `Ptr` entries become `LPtr`s.
@inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))
31
+
23
32
@generated function solve_AU (
24
33
A:: VecUnroll{Nm1} ,
25
34
spu:: AbstractStridedPointer ,
65
74
quote
66
75
$ (Expr (:meta , :inline ))
67
76
mask = $ (VectorizationBase. Mask{W})(_mask)
68
- spa, spu = LoopVectorization . reassemble_tuple ($ Args, args)
77
+ spa, spu = reassemble_tup ($ Args, args)
69
78
vstore! (spa, $ Amn, $ i, mask)
70
79
end
71
80
else
74
83
scale = UNIT ? nothing : :(Amn_n /= vload (spu, (n - 1 , n - 1 )))
75
84
quote
76
85
$ (Expr (:meta , :inline ))
77
- spa, spu = LoopVectorization . reassemble_tuple ($ Args, args)
86
+ spa, spu = reassemble_tup ($ Args, args)
78
87
mask = $ (VectorizationBase. Mask{W})(_mask)
79
88
Amn = getfield (vload (spa, $ unroll, mask), :data )
80
89
Base. Cartesian. @nexprs $ N n -> begin
120
129
end
121
130
quote
122
131
$ (Expr (:meta , :inline ))
123
- spa, spu = LoopVectorization . reassemble_tuple ($ Args, args)
132
+ spa, spu = reassemble_tup ($ Args, args)
124
133
vstore! (spa, $ Amn, $ unroll)
125
134
end
126
135
else
130
139
scale = UNIT ? nothing : :(Amn_n /= vload (spu, (n - 1 , n - 1 )))
131
140
quote
132
141
$ (Expr (:meta , :inline ))
133
- spa, spu = LoopVectorization . reassemble_tuple ($ Args, args)
142
+ spa, spu = reassemble_tup ($ Args, args)
134
143
Amn = getfield (vload (spa, $ double_unroll), :data )
135
144
Base. Cartesian. @nexprs $ N n -> begin
136
145
Amn_n = getfield (Amn, n)
468
477
while m < M - WU + 1
469
478
n = Nr
470
479
if n > 0
471
- let t = (spa, spu), ft = LoopVectorization . flatten_to_tuple (t)
480
+ let t = (spa, spu), ft = flatten_to_tup (t)
472
481
BdivU_small_kern_u! (n, UF, Val (UNIT), WS, typeof (t), ft... )
473
482
end
474
483
end
488
497
n = Nr
489
498
if n > 0
490
499
let t = (spa, spu),
491
- ft = LoopVectorization . flatten_to_tuple (t),
500
+ ft = flatten_to_tup (t),
492
501
mask = getfield (mask, :u ) % UInt32
493
502
494
503
BdivU_small_kern! (n, mask, WS, Val (UNIT), typeof (t), ft... )
@@ -733,22 +742,17 @@ end
733
742
end
734
743
# Construct a `Mat` from an `AbstractMatrix`.
#
# The second type parameter is `true` when the first dimension has stride
# rank 1 (in which case the stride of dimension 2 is stored) and `false` when
# the second dimension has stride rank 1 (the stride of dimension 1 is stored).
# The pointer to `A` and its size `(M, N)` are stored in both cases.
@inline function Mat(A::AbstractMatrix{T}) where {T}
  r, c = LoopVectorization.ArrayInterface.stride_rank(A)
  M, N = size(A)
  if r === static(1)
    Mat{T,true}(pointer(A), stride(A, 2), M, N)
  else
    # Internal invariant: if the first dimension does not have rank 1,
    # the second one must.
    @assert c === static(1)
    Mat{T,false}(pointer(A), stride(A, 1), M, N)
  end
end
744
753
745
754
# C -= A * B
746
- @inline function _schur_complement! (
747
- C:: Mat ,
748
- A:: Mat ,
749
- B:: Mat ,
750
- :: Val{false}
751
- )
755
+ @inline function _schur_complement! (C:: Mat , A:: Mat , B:: Mat , :: Val{false} )
752
756
# _turbo_! will not be inlined
753
757
@turbo warn_check_args = false for n in indices ((C, B), 2 ),
754
758
m in indices ((C, A), 1 )
760
764
C[m, n] = Cmn
761
765
end
762
766
end
763
- @inline function _schur_complement! (
764
- C:: Mat ,
765
- A:: Mat ,
766
- B:: Mat ,
767
- :: Val{true}
768
- )
767
+ @inline function _schur_complement! (C:: Mat , A:: Mat , B:: Mat , :: Val{true} )
769
768
# _turbo_! will not be inlined
770
769
@tturbo warn_check_args = false for n in indices ((C, B), 2 ),
771
770
m in indices ((C, A), 1 )
@@ -839,7 +838,12 @@ function rdiv_block_N!(
839
838
n += B_normalized
840
839
repeat = n + B_normalized < N
841
840
N_temp = repeat ? N_temp : N - n
842
- schur_complement! (Mat (spa, M, N_temp), Mat (spa_base, M, n), Mat (spu, n, N_temp), Val (false ))
841
+ schur_complement! (
842
+ Mat (spa, M, N_temp),
843
+ Mat (spa_base, M, n),
844
+ Mat (spu, n, N_temp),
845
+ Val (false )
846
+ )
843
847
end
844
848
end
845
849
function rdiv_block_MandN! (
974
978
n = Nr # non factor of W remainder
975
979
if n > 0
976
980
let t = (spa, spu),
977
- ft = LoopVectorization . flatten_to_tuple (t),
981
+ ft = flatten_to_tup (t),
978
982
mask = $ (getfield (_mask (WS, r), :u ) % UInt32)
979
983
980
984
BdivU_small_kern! (n, mask, $ WS, $ (Val (UNIT)), typeof (t), ft... )
@@ -1007,14 +1011,14 @@ end
1007
1011
if W == 2
1008
1012
quote
1009
1013
$ (Expr (:meta , :inline ))
1010
- spa, spu = LoopVectorization . reassemble_tuple (Args, args)
1014
+ spa, spu = reassemble_tup (Args, args)
1011
1015
_ldiv_remainder! (spa, spu, N, Nr, $ WS, $ (Val (UNIT)), $ (static (1 )))
1012
1016
nothing
1013
1017
end
1014
1018
elseif W == 8
1015
1019
quote
1016
1020
# $(Expr(:meta, :inline))
1017
- spa, spu = LoopVectorization . reassemble_tuple (Args, args)
1021
+ spa, spu = reassemble_tup (Args, args)
1018
1022
if m == M - 1
1019
1023
_ldiv_remainder! (spa, spu, N, Nr, static (8 ), $ (Val (UNIT)), StaticInt (1 ))
1020
1024
else
@@ -1093,7 +1097,7 @@ end
1093
1097
else
1094
1098
quote
1095
1099
# $(Expr(:meta, :inline))
1096
- spa, spu = LoopVectorization . reassemble_tuple (Args, args)
1100
+ spa, spu = reassemble_tup (Args, args)
1097
1101
Base. Cartesian. @nif $ (W - 1 ) w -> m == M - w w ->
1098
1102
_ldiv_remainder! (spa, spu, N, Nr, $ WS, $ (Val (UNIT)), static (w))
1099
1103
nothing
@@ -1108,7 +1112,7 @@ end
1108
1112
:: Val{UNIT}
1109
1113
) where {T,UNIT}
1110
1114
tup = (spa, spu)
1111
- ftup = LoopVectorization . flatten_to_tuple (tup)
1115
+ ftup = flatten_to_tup (tup)
1112
1116
_ldiv_L! (M, N, Val (UNIT), typeof (tup), ftup... )
1113
1117
end
1114
1118
@@ -1122,7 +1126,7 @@ function _ldiv_L!(
1122
1126
:: Type{Args} ,
1123
1127
args:: Vararg{Any,K}
1124
1128
) where {UNIT,Args,K}
1125
- spa, spu = LoopVectorization . reassemble_tuple (Args, args)
1129
+ spa, spu = reassemble_tup (Args, args)
1126
1130
T = eltype (spa)
1127
1131
WS = pick_vector_width (T)
1128
1132
W = Int (WS)
@@ -1134,7 +1138,7 @@ function _ldiv_L!(
1134
1138
while m < M - WS + 1
1135
1139
n = Nr # non factor of W remainder
1136
1140
if n > 0
1137
- let t = (spa, spu), ft = LoopVectorization . flatten_to_tuple (t)
1141
+ let t = (spa, spu), ft = flatten_to_tup (t)
1138
1142
BdivU_small_kern_u! (n, StaticInt (1 ), Val (UNIT), WS, typeof (t), ft... )
1139
1143
end
1140
1144
end
@@ -1151,7 +1155,7 @@ function _ldiv_L!(
1151
1155
end
1152
1156
# remainder on `m`
1153
1157
if m < M
1154
- let tup = (spa, spu), ftup = LoopVectorization . flatten_to_tuple (tup)
1158
+ let tup = (spa, spu), ftup = flatten_to_tup (tup)
1155
1159
ldiv_remainder! (M, N, m, Nr, WS, Val (UNIT), typeof (tup), ftup... )
1156
1160
end
1157
1161
end
0 commit comments