1
1
using LoopVectorization
2
2
using TriangularSolve: ldiv!
3
- using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS, LinearAlgebra, Adjoint, Transpose
3
+ using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4
+ LinearAlgebra, Adjoint, Transpose
4
5
using StrideArraysCore
5
6
using Polyester: @batch
6
7
8
"""
    _unit_lower_triangular(B::AbstractMatrix)

Wrap `B` in a `UnitLowerTriangular` without going through the public constructor.

The wrapper is built with a raw `Expr(:new, ...)`, so `B` is stored as-is and the
concrete wrapper type `UnitLowerTriangular{T, A}` is fixed at code-generation time.
"""
@generated function _unit_lower_triangular(B::A) where {T, A <: AbstractMatrix{T}}
    wrapper_type = UnitLowerTriangular{T, A}
    # `:B` refers to the runtime argument; only the type is known here.
    return Expr(:new, wrapper_type, :B)
end
7
11
# 1.7 compat
8
# A `Val` pivot selector is already in the form the kernels dispatch on, so it is
# passed through untouched.
function normalize_pivot(t::Val{T}) where {T}
    return t
end
9
13
# Fallback: anything without a more specific method is handed to the stdlib as-is.
function to_stdlib_pivot(t)
    return t
end
10
14
if VERSION >= v " 1.7.0-DEV.1188"
11
15
normalize_pivot (:: LinearAlgebra.RowMaximum ) = Val (true )
@@ -18,19 +22,20 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
18
22
return lu! (copy (A), normalize_pivot (pivot), thread; kwargs... )
19
23
end
20
24
21
"""
    lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)

Factorize `A` in place and return the factorization object.

When `min(size(A)...)` of the first two dimensions is below 10, the work is delegated
to `LinearAlgebra.generic_lufact!` (extra `kwargs` are not forwarded on that path).
Otherwise a pivot vector of that length is allocated and the `lu!(A, ipiv, ...)`
method is called with `check` and all remaining keyword arguments.
"""
function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
    nrows, ncols = size(A)
    npivots = min(nrows, ncols)
    if npivots >= 10
        ipiv = Vector{BlasInt}(undef, npivots)
        return lu!(A, ipiv, normalize_pivot(pivot), thread; check = check, kwargs...)
    end
    # avx introduces small performance degradation on tiny problems
    return LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check = check)
end
31
36
32
37
# Generate `lu`/`lu!` methods for lazily wrapped matrices: factorize the parent and
# re-apply the same wrapper (`adjoint` or `transpose`) to the result.
for (rewrap, Wrapper) in ((:adjoint, :Adjoint), (:transpose, :Transpose))
    for factorize in (:lu, :lu!)
        @eval function $factorize(A::$Wrapper, args...; kwargs...)
            return $rewrap($factorize(parent(A), args...; kwargs...))
        end
    end
end
35
40
36
41
# Mutable, concretely typed module-level setting, initialised to -1.
# NOTE(review): -1 presumably means "not set"; confirm against the code that reads it.
const RECURSION_THRESHOLD = Ref{Int}(-1)
44
49
# Trait deciding whether the recursive kernel may be used: only strided arrays qualify.
function recurse(::StridedArray)
    return true
end

function recurse(_)
    return false
end
46
51
47
- function lu! (
48
- A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
49
- pivot = Val (true ), thread = Val (true );
50
- check:: Bool = true ,
51
- # the performance is not sensitive wrt blocksize, and 8 is a good default
52
- blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
53
- threshold:: Integer = pick_threshold ()
54
- ) where T
52
+ function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
53
+ pivot = Val (true ), thread = Val (true );
54
+ check:: Bool = true ,
55
+ # the performance is not sensitive wrt blocksize, and 8 is a good default
56
+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
57
+ threshold:: Integer = pick_threshold ()) where {T}
55
58
pivot = normalize_pivot (pivot)
56
59
info = zero (BlasInt)
57
60
m, n = size (A)
58
61
mnmin = min (m, n)
59
62
if recurse (A) && mnmin > threshold
60
- if T <: Union{Float32,Float64}
61
- GC. @preserve ipiv A begin
62
- info = recurse! ( PtrArray (A), pivot, m, n, mnmin, PtrArray (ipiv), info, blocksize, thread)
63
- end
63
+ if T <: Union{Float32, Float64}
64
+ GC. @preserve ipiv A begin info = recurse! ( PtrArray (A), pivot, m, n, mnmin,
65
+ PtrArray (ipiv), info, blocksize,
66
+ thread) end
64
67
else
65
68
info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
66
69
end
@@ -71,30 +74,33 @@ function lu!(
71
74
LU {T, typeof(A)} (A, ipiv, info)
72
75
end
73
76
74
# Threading was requested: keep it only when the matrix footprint
# (`length(A) * _sizeof(eltype(A))`) exceeds 92% of `cache_size(Val(1))`;
# otherwise fall back to the non-threaded kernel.
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
        ::Val{true}) where {Pivot}
    footprint = length(A) * _sizeof(eltype(A))
    cutoff = 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1))
    footprint > cutoff ||
        return _recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
    return _recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true))
end
81
# Threading was not requested: forward straight to the non-threaded kernel.
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
        ::Val{false}) where {Pivot}
    serial = Val(false)
    return _recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, serial)
end
84
# Run the recursive kernel on the leading `m × mnmin` columns of `A`, then, for a
# "fat" matrix (`m < n`), finish the remaining columns: apply the recorded row
# permutation to them and solve against the unit-lower-triangular left block.
#
# Returns `info` as produced by `reckernel!`. The result is asserted as `BlasInt`,
# matching the return type `reckernel!` declares, so the assertion also holds on
# builds where `BlasInt` differs from `Int`.
@inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
        ::Val{Thread}) where {Pivot, Thread}
    info = reckernel!(
        A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::BlasInt
    @inbounds if m < n # fat matrix
        # [AL AR]
        AL = @view A[:, 1:m]
        AR = @view A[:, (m + 1):n]
        apply_permutation!(ipiv, AR, Val(Thread))
        ldiv!(_unit_lower_triangular(AL), AR, Val(Thread))
    end
    return info
end
95
102
96
-
97
- @inline function nsplit (:: Type{T} , n) where T
103
+ @inline function nsplit (:: Type{T} , n) where {T}
98
104
k = 512 ÷ (isbitstype (T) ? sizeof (T) : 8 )
99
105
k_2 = k ÷ 2
100
106
return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
@@ -125,8 +131,8 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
125
131
end
126
132
nothing
127
133
end
128
-
129
- function reckernel! (A :: AbstractMatrix{T} , pivot :: Val{Pivot} , m, n, ipiv, info, blocksize, thread):: BlasInt where {T,Pivot}
134
+ function reckernel! (A :: AbstractMatrix{T} , pivot :: Val{Pivot} , m, n, ipiv, info, blocksize,
135
+ thread):: BlasInt where {T, Pivot}
130
136
@inbounds begin
131
137
if n <= max (blocksize, 1 )
132
138
info = _generic_lufact! (A, Val (Pivot), ipiv, info)
@@ -147,18 +153,18 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
147
153
# Partition the matrix A
148
154
# [AL AR]
149
155
AL = @view A[:, 1 : n1]
150
- AR = @view A[:, n1 + 1 : n]
156
+ AR = @view A[:, (n1 + 1 ) : n]
151
157
# AL AR
152
158
# [A11 A12]
153
159
# [A21 A22]
154
160
A11 = @view A[1 : n1, 1 : n1]
155
- A12 = @view A[1 : n1, n1 + 1 : n]
156
- A21 = @view A[n1 + 1 : m, 1 : n1]
157
- A22 = @view A[n1 + 1 : m, n1 + 1 : n]
161
+ A12 = @view A[1 : n1, (n1 + 1 ) : n]
162
+ A21 = @view A[(n1 + 1 ) : m, 1 : n1]
163
+ A22 = @view A[(n1 + 1 ) : m, (n1 + 1 ) : n]
158
164
# [P1]
159
165
# [P2]
160
166
P1 = @view ipiv[1 : n1]
161
- P2 = @view ipiv[n1 + 1 : n]
167
+ P2 = @view ipiv[(n1 + 1 ) : n]
162
168
# ========================================
163
169
164
170
# [ A11 ] [ L11 ]
@@ -170,7 +176,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
170
176
# [ A22 ] [ 0 ] [ A22 ]
171
177
Pivot && apply_permutation! (P1, AR, thread)
172
178
# A12 = L11 U12 => U12 = L11 \ A12
173
- ldiv! (UnitLowerTriangular (A11), A12, thread)
179
+ ldiv! (_unit_lower_triangular (A11), A12, thread)
174
180
# Schur complement:
175
181
# We have A22 = L21 U12 + A′22, hence
176
182
# A′22 = A22 - L21 U12
@@ -191,23 +197,23 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
191
197
end # inbounds
192
198
end
193
199
194
# In-place Schur complement update `𝐂 .-= 𝐀 * 𝐁`, i.e. `mul!(𝐂, 𝐀, 𝐁, -1, 1)`.
# `Val(true)` selects the `@tturbo` loop nest, `Val(false)` the `@turbo` one.
function schur_complement!(𝐂, 𝐀, 𝐁, ::Val{THREAD} = Val(true)) where {THREAD}
    if THREAD
        @tturbo warn_check_args=false for row in 1:size(𝐀, 1), col in 1:size(𝐁, 2)
            acc = zero(eltype(𝐂))
            for inner in 1:size(𝐀, 2)
                acc -= 𝐀[row, inner] * 𝐁[inner, col]
            end
            𝐂[row, col] = acc + 𝐂[row, col]
        end
    else
        @turbo warn_check_args=false for row in 1:size(𝐀, 1), col in 1:size(𝐁, 2)
            acc = zero(eltype(𝐂))
            for inner in 1:size(𝐀, 2)
                acc -= 𝐀[row, inner] * 𝐁[inner, col]
            end
            𝐂[row, col] = acc + 𝐂[row, col]
        end
    end
end
@@ -216,49 +222,47 @@ end
216
222
Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
217
223
License is MIT: https://julialang.org/license
218
224
=#
219
# Unblocked in-place LU elimination of `A`, running `length(ipiv)` steps.
#
# With `Pivot == true` each step picks the entry of largest absolute value in the
# current column (rows `k:m`) as pivot; otherwise the diagonal entry is used. The
# chosen row index is stored in `ipiv[k]`. `info` is returned unchanged unless it is
# 0 on entry and a zero pivot is met, in which case it becomes the index of the
# first such step.
function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
    nrows, ncols = size(A)
    nsteps = length(ipiv)
    @inbounds for step in 1:nsteps
        # Locate the pivot row for this column.
        prow = step
        if Pivot
            best = abs(zero(eltype(A)))
            for r in step:nrows
                candidate = abs(A[r, step])
                if candidate > best
                    prow = r
                    best = candidate
                end
            end
        end
        ipiv[step] = prow
        if !iszero(A[prow, step])
            if step != prow
                # Swap rows `step` and `prow` across every column.
                @simd for c in 1:ncols
                    A[step, c], A[prow, c] = A[prow, c], A[step, c]
                end
            end
            # Scale the sub-diagonal part of the pivot column.
            pivinv = inv(A[step, step])
            @turbo check_empty=true warn_check_args=false for r in (step + 1):nrows
                A[r, step] *= pivinv
            end
        elseif info == 0
            info = step
        end
        if step == nsteps
            break
        end
        # Rank-1 update of the trailing submatrix.
        @turbo warn_check_args=false for c in (step + 1):ncols
            for r in (step + 1):nrows
                A[r, c] -= A[r, step] * A[step, c]
            end
        end
    end
    return info
end
0 commit comments