Skip to content

Commit 168d962

Browse files
committed
LPtr
1 parent 3f5b155 commit 168d962

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

src/TriangularSolve.jl

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ using IfElse: ifelse
2020
using LoopVectorization
2121
using Polyester
2222

23+
# Alias for `Core.LLVMPtr{T,0}`: an LLVM pointer to `T` whose second type
# parameter (the address space) is fixed to 0.
const LPtr{T} = Core.LLVMPtr{T,0}
24+
# Reinterpret a `Ptr{T}` as the corresponding `LPtr{T}`.
_lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x)
25+
# Fallback method: any argument that is not a `Ptr` is returned unchanged.
_lptr(x) = x
26+
# Reinterpret an `LPtr{T}` as the corresponding `Ptr{T}`.
_ptr(x::LPtr{T}) where {T} = reinterpret(Ptr{T}, x)
27+
# Fallback method: any argument that is not an `LPtr` is returned unchanged.
_ptr(x) = x
28+
# Map `_ptr` over every element of `t`, then pass the resulting tuple to
# `LoopVectorization.reassemble_tuple` together with the type `T`.
@inline reassemble_tup(::Type{T}, t) where {T} =
29+
LoopVectorization.reassemble_tuple(T, map(_ptr, t))
30+
# Flatten `t` with `LoopVectorization.flatten_to_tuple`, then map `_lptr` over
# every element of the flattened tuple.
@inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t))
31+
2332
@generated function solve_AU(
2433
A::VecUnroll{Nm1},
2534
spu::AbstractStridedPointer,
@@ -65,7 +74,7 @@ end
6574
quote
6675
$(Expr(:meta, :inline))
6776
mask = $(VectorizationBase.Mask{W})(_mask)
68-
spa, spu = LoopVectorization.reassemble_tuple($Args, args)
77+
spa, spu = reassemble_tup($Args, args)
6978
vstore!(spa, $Amn, $i, mask)
7079
end
7180
else
@@ -74,7 +83,7 @@ end
7483
scale = UNIT ? nothing : :(Amn_n /= vload(spu, (n - 1, n - 1)))
7584
quote
7685
$(Expr(:meta, :inline))
77-
spa, spu = LoopVectorization.reassemble_tuple($Args, args)
86+
spa, spu = reassemble_tup($Args, args)
7887
mask = $(VectorizationBase.Mask{W})(_mask)
7988
Amn = getfield(vload(spa, $unroll, mask), :data)
8089
Base.Cartesian.@nexprs $N n -> begin
@@ -120,7 +129,7 @@ end
120129
end
121130
quote
122131
$(Expr(:meta, :inline))
123-
spa, spu = LoopVectorization.reassemble_tuple($Args, args)
132+
spa, spu = reassemble_tup($Args, args)
124133
vstore!(spa, $Amn, $unroll)
125134
end
126135
else
@@ -130,7 +139,7 @@ end
130139
scale = UNIT ? nothing : :(Amn_n /= vload(spu, (n - 1, n - 1)))
131140
quote
132141
$(Expr(:meta, :inline))
133-
spa, spu = LoopVectorization.reassemble_tuple($Args, args)
142+
spa, spu = reassemble_tup($Args, args)
134143
Amn = getfield(vload(spa, $double_unroll), :data)
135144
Base.Cartesian.@nexprs $N n -> begin
136145
Amn_n = getfield(Amn, n)
@@ -468,7 +477,7 @@ end
468477
while m < M - WU + 1
469478
n = Nr
470479
if n > 0
471-
let t = (spa, spu), ft = LoopVectorization.flatten_to_tuple(t)
480+
let t = (spa, spu), ft = flatten_to_tup(t)
472481
BdivU_small_kern_u!(n, UF, Val(UNIT), WS, typeof(t), ft...)
473482
end
474483
end
@@ -488,7 +497,7 @@ end
488497
n = Nr
489498
if n > 0
490499
let t = (spa, spu),
491-
ft = LoopVectorization.flatten_to_tuple(t),
500+
ft = flatten_to_tup(t),
492501
mask = getfield(mask, :u) % UInt32
493502

494503
BdivU_small_kern!(n, mask, WS, Val(UNIT), typeof(t), ft...)
@@ -733,22 +742,17 @@ end
733742
end
734743
@inline function Mat(A::AbstractMatrix{T}) where {T}
735744
r, c = LoopVectorization.ArrayInterface.stride_rank(A)
736-
M, N = size(A)
745+
M, N = size(A)
737746
if r === static(1)
738-
Mat{T,true}(pointer(A), stride(A,2), M, N)
747+
Mat{T,true}(pointer(A), stride(A, 2), M, N)
739748
else
740-
@assert c === static(1)
741-
Mat{T,false}(pointer(A), stride(A,1), M, N)
749+
@assert c === static(1)
750+
Mat{T,false}(pointer(A), stride(A, 1), M, N)
742751
end
743752
end
744753

745754
# C -= A * B
746-
@inline function _schur_complement!(
747-
C::Mat,
748-
A::Mat,
749-
B::Mat,
750-
::Val{false}
751-
)
755+
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{false})
752756
# _turbo_! will not be inlined
753757
@turbo warn_check_args = false for n in indices((C, B), 2),
754758
m in indices((C, A), 1)
@@ -760,12 +764,7 @@ end
760764
C[m, n] = Cmn
761765
end
762766
end
763-
@inline function _schur_complement!(
764-
C::Mat,
765-
A::Mat,
766-
B::Mat,
767-
::Val{true}
768-
)
767+
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{true})
769768
# _turbo_! will not be inlined
770769
@tturbo warn_check_args = false for n in indices((C, B), 2),
771770
m in indices((C, A), 1)
@@ -839,7 +838,12 @@ function rdiv_block_N!(
839838
n += B_normalized
840839
repeat = n + B_normalized < N
841840
N_temp = repeat ? N_temp : N - n
842-
schur_complement!(Mat(spa, M, N_temp), Mat(spa_base, M, n), Mat(spu, n, N_temp), Val(false))
841+
schur_complement!(
842+
Mat(spa, M, N_temp),
843+
Mat(spa_base, M, n),
844+
Mat(spu, n, N_temp),
845+
Val(false)
846+
)
843847
end
844848
end
845849
function rdiv_block_MandN!(
@@ -974,7 +978,7 @@ end
974978
n = Nr # non factor of W remainder
975979
if n > 0
976980
let t = (spa, spu),
977-
ft = LoopVectorization.flatten_to_tuple(t),
981+
ft = flatten_to_tup(t),
978982
mask = $(getfield(_mask(WS, r), :u) % UInt32)
979983

980984
BdivU_small_kern!(n, mask, $WS, $(Val(UNIT)), typeof(t), ft...)
@@ -1007,14 +1011,14 @@ end
10071011
if W == 2
10081012
quote
10091013
$(Expr(:meta, :inline))
1010-
spa, spu = LoopVectorization.reassemble_tuple(Args, args)
1014+
spa, spu = reassemble_tup(Args, args)
10111015
_ldiv_remainder!(spa, spu, N, Nr, $WS, $(Val(UNIT)), $(static(1)))
10121016
nothing
10131017
end
10141018
elseif W == 8
10151019
quote
10161020
# $(Expr(:meta, :inline))
1017-
spa, spu = LoopVectorization.reassemble_tuple(Args, args)
1021+
spa, spu = reassemble_tup(Args, args)
10181022
if m == M - 1
10191023
_ldiv_remainder!(spa, spu, N, Nr, static(8), $(Val(UNIT)), StaticInt(1))
10201024
else
@@ -1093,7 +1097,7 @@ end
10931097
else
10941098
quote
10951099
# $(Expr(:meta, :inline))
1096-
spa, spu = LoopVectorization.reassemble_tuple(Args, args)
1100+
spa, spu = reassemble_tup(Args, args)
10971101
Base.Cartesian.@nif $(W - 1) w -> m == M - w w ->
10981102
_ldiv_remainder!(spa, spu, N, Nr, $WS, $(Val(UNIT)), static(w))
10991103
nothing
@@ -1108,7 +1112,7 @@ end
11081112
::Val{UNIT}
11091113
) where {T,UNIT}
11101114
tup = (spa, spu)
1111-
ftup = LoopVectorization.flatten_to_tuple(tup)
1115+
ftup = flatten_to_tup(tup)
11121116
_ldiv_L!(M, N, Val(UNIT), typeof(tup), ftup...)
11131117
end
11141118

@@ -1122,7 +1126,7 @@ function _ldiv_L!(
11221126
::Type{Args},
11231127
args::Vararg{Any,K}
11241128
) where {UNIT,Args,K}
1125-
spa, spu = LoopVectorization.reassemble_tuple(Args, args)
1129+
spa, spu = reassemble_tup(Args, args)
11261130
T = eltype(spa)
11271131
WS = pick_vector_width(T)
11281132
W = Int(WS)
@@ -1134,7 +1138,7 @@ function _ldiv_L!(
11341138
while m < M - WS + 1
11351139
n = Nr # non factor of W remainder
11361140
if n > 0
1137-
let t = (spa, spu), ft = LoopVectorization.flatten_to_tuple(t)
1141+
let t = (spa, spu), ft = flatten_to_tup(t)
11381142
BdivU_small_kern_u!(n, StaticInt(1), Val(UNIT), WS, typeof(t), ft...)
11391143
end
11401144
end
@@ -1151,7 +1155,7 @@ function _ldiv_L!(
11511155
end
11521156
# remainder on `m`
11531157
if m < M
1154-
let tup = (spa, spu), ftup = LoopVectorization.flatten_to_tuple(tup)
1158+
let tup = (spa, spu), ftup = flatten_to_tup(tup)
11551159
ldiv_remainder!(M, N, m, Nr, WS, Val(UNIT), typeof(tup), ftup...)
11561160
end
11571161
end

0 commit comments

Comments
 (0)