Skip to content

Commit 2f06b0c

Browse files
authored
Merge pull request #46 from chriselrod/format
Format the entire library
2 parents d7b32c6 + cab3237 commit 2f06b0c

File tree

4 files changed

+158
-152
lines changed

4 files changed

+158
-152
lines changed

perf/lu.jl

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads())
44
BLAS.set_num_threads(nc)
55
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5
66

7-
function luflop(m, n=m; innerflop=2)
7+
function luflop(m, n = m; innerflop = 2)
88
sum(1:min(m, n)) do k
99
invflop = 1
10-
scaleflop = isempty(k+1:m) ? 0 : sum(k+1:m)
11-
updateflop = isempty(k+1:n) ? 0 : sum(k+1:n) do j
12-
isempty(k+1:m) ? 0 : sum(k+1:m) do i
10+
scaleflop = isempty((k + 1):m) ? 0 : sum((k + 1):m)
11+
updateflop = isempty((k + 1):n) ? 0 :
12+
sum((k + 1):n) do j
13+
isempty((k + 1):m) ? 0 : sum((k + 1):m) do i
1314
innerflop
1415
end
1516
end
@@ -28,45 +29,45 @@ for n in ns
2829
rng = MersenneTwister(123)
2930
global A = rand(rng, n, n)
3031
bt = @belapsed LinearAlgebra.lu!(B) setup=(B = copy(A))
31-
push!(bas_mflops, luflop(n)/bt/1e9)
32+
push!(bas_mflops, luflop(n) / bt / 1e9)
3233

3334
rt = @belapsed RecursiveFactorization.lu!(B) setup=(B = copy(A))
34-
push!(rec_mflops, luflop(n)/rt/1e9)
35+
push!(rec_mflops, luflop(n) / rt / 1e9)
3536

36-
rt4 = @belapsed RecursiveFactorization.lu!(B; threshold=4) setup=(B = copy(A))
37-
push!(rec4_mflops, luflop(n)/rt4/1e9)
37+
rt4 = @belapsed RecursiveFactorization.lu!(B; threshold = 4) setup=(B = copy(A))
38+
push!(rec4_mflops, luflop(n) / rt4 / 1e9)
3839

39-
rt800 = @belapsed RecursiveFactorization.lu!(B; threshold=800) setup=(B = copy(A))
40-
push!(rec800_mflops, luflop(n)/rt800/1e9)
40+
rt800 = @belapsed RecursiveFactorization.lu!(B; threshold = 800) setup=(B = copy(A))
41+
push!(rec800_mflops, luflop(n) / rt800 / 1e9)
4142

4243
ref = @belapsed LinearAlgebra.generic_lufact!(B) setup=(B = copy(A))
43-
push!(ref_mflops, luflop(n)/ref/1e9)
44+
push!(ref_mflops, luflop(n) / ref / 1e9)
4445
end
4546

4647
using DataFrames, VegaLite
4748
blaslib = if VERSION v"1.7.0-beta2"
48-
config = BLAS.get_config().loaded_libs
49-
occursin("libmkl_rt", config[1].libname) ? :MKL : :OpenBLAS
49+
config = BLAS.get_config().loaded_libs
50+
occursin("libmkl_rt", config[1].libname) ? :MKL : :OpenBLAS
5051
else
51-
BLAS.vendor() === :mkl ? :MKL : :OpenBLAS
52+
BLAS.vendor() === :mkl ? :MKL : :OpenBLAS
5253
end
5354
df = DataFrame(Size = ns,
5455
Reference = ref_mflops)
5556
setproperty!(df, blaslib, bas_mflops)
5657
setproperty!(df, Symbol("RF with default threshold"), rec_mflops)
5758
setproperty!(df, Symbol("RF fully recursive"), rec4_mflops)
5859
setproperty!(df, Symbol("RF fully iterative"), rec800_mflops)
59-
df = stack(df, [Symbol("RF with default threshold"),
60-
Symbol("RF fully recursive"),
61-
Symbol("RF fully iterative"),
62-
blaslib,
63-
:Reference], variable_name = :Library, value_name = :GFLOPS)
64-
plt = df |> @vlplot(
65-
:line, color = {:Library, scale={scheme="category10"}},
66-
x = {:Size}, y = {:GFLOPS},
67-
width = 1000, height = 600
68-
)
69-
save(joinpath(homedir(), "Pictures", "lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)
60+
df = stack(df,
61+
[Symbol("RF with default threshold"),
62+
Symbol("RF fully recursive"),
63+
Symbol("RF fully iterative"),
64+
blaslib,
65+
:Reference], variable_name = :Library, value_name = :GFLOPS)
66+
plt = df |> @vlplot(:line, color={:Library, scale = {scheme = "category10"}},
67+
x={:Size}, y={:GFLOPS},
68+
width=1000, height=600)
69+
save(joinpath(homedir(), "Pictures",
70+
"lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)
7071

7172
#=
7273
using Plot

src/RecursiveFactorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ include("./lu.jl")
44

55
let
66
while true
7-
lu!(rand(2,2))
7+
lu!(rand(2, 2))
88
break
99
end
1010
end

src/lu.jl

Lines changed: 97 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
using LoopVectorization
22
using TriangularSolve: ldiv!
3-
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS, LinearAlgebra, Adjoint, Transpose
3+
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4+
LinearAlgebra, Adjoint, Transpose
45
using StrideArraysCore
56
using Polyester: @batch
67

8+
@generated function _unit_lower_triangular(B::A) where {T, A <: AbstractMatrix{T}}
9+
Expr(:new, UnitLowerTriangular{T, A}, :B)
10+
end
711
# 1.7 compat
8-
normalize_pivot(t::Val{T}) where T = t
12+
normalize_pivot(t::Val{T}) where {T} = t
913
to_stdlib_pivot(t) = t
1014
if VERSION >= v"1.7.0-DEV.1188"
1115
normalize_pivot(::LinearAlgebra.RowMaximum) = Val(true)
@@ -18,19 +22,20 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
1822
return lu!(copy(A), normalize_pivot(pivot), thread; kwargs...)
1923
end
2024

21-
function lu!(A, pivot = Val(true), thread = Val(true); check=true, kwargs...)
22-
m, n = size(A)
25+
function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
26+
m, n = size(A)
2327
minmn = min(m, n)
2428
F = if minmn < 10 # avx introduces small performance degradation
25-
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check=check)
29+
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check = check)
2630
else
27-
lu!(A, Vector{BlasInt}(undef, minmn), normalize_pivot(pivot), thread; check=check, kwargs...)
31+
lu!(A, Vector{BlasInt}(undef, minmn), normalize_pivot(pivot), thread; check = check,
32+
kwargs...)
2833
end
2934
return F
3035
end
3136

3237
for (f, T) in [(:adjoint, :Adjoint), (:transpose, :Transpose)], lu in (:lu, :lu!)
33-
@eval $lu(A::$T, args...; kwargs...) = $f($lu(parent(A), args...; kwargs...))
38+
@eval $lu(A::$T, args...; kwargs...) = $f($lu(parent(A), args...; kwargs...))
3439
end
3540

3641
const RECURSION_THRESHOLD = Ref(-1)
@@ -44,23 +49,21 @@ end
4449
recurse(::StridedArray) = true
4550
recurse(_) = false
4651

47-
function lu!(
48-
A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
49-
pivot = Val(true), thread = Val(true);
50-
check::Bool=true,
51-
# the performance is not sensitive wrt blocksize, and 8 is a good default
52-
blocksize::Integer=length(A) 40_000 ? 8 : 16,
53-
threshold::Integer=pick_threshold()
54-
) where T
52+
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
53+
pivot = Val(true), thread = Val(true);
54+
check::Bool = true,
55+
# the performance is not sensitive wrt blocksize, and 8 is a good default
56+
blocksize::Integer = length(A) 40_000 ? 8 : 16,
57+
threshold::Integer = pick_threshold()) where {T}
5558
pivot = normalize_pivot(pivot)
5659
info = zero(BlasInt)
5760
m, n = size(A)
5861
mnmin = min(m, n)
5962
if recurse(A) && mnmin > threshold
60-
if T <: Union{Float32,Float64}
61-
GC.@preserve ipiv A begin
62-
info = recurse!(PtrArray(A), pivot, m, n, mnmin, PtrArray(ipiv), info, blocksize, thread)
63-
end
63+
if T <: Union{Float32, Float64}
64+
GC.@preserve ipiv A begin info = recurse!(PtrArray(A), pivot, m, n, mnmin,
65+
PtrArray(ipiv), info, blocksize,
66+
thread) end
6467
else
6568
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
6669
end
@@ -71,30 +74,33 @@ function lu!(
7174
LU{T, typeof(A)}(A, ipiv, info)
7275
end
7376

74-
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, ::Val{true}) where {Pivot}
75-
if length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1))
76-
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true))
77-
else
78-
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
79-
end
77+
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
78+
::Val{true}) where {Pivot}
79+
if length(A) * _sizeof(eltype(A)) >
80+
0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1))
81+
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true))
82+
else
83+
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
84+
end
8085
end
81-
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, ::Val{false}) where {Pivot}
82-
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
86+
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
87+
::Val{false}) where {Pivot}
88+
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
8389
end
84-
@inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize, ::Val{Thread}) where {Pivot,Thread}
85-
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))
86-
@inbounds if m < n # fat matrix
87-
# [AL AR]
88-
AL = @view A[:, 1:m]
89-
AR = @view A[:, m+1:n]
90-
apply_permutation!(ipiv, AR, Val(Thread))
91-
ldiv!(UnitLowerTriangular(AL), AR, Val(Thread))
92-
end
93-
info
90+
@inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
91+
::Val{Thread}) where {Pivot, Thread}
92+
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::Int
93+
@inbounds if m < n # fat matrix
94+
# [AL AR]
95+
AL = @view A[:, 1:m]
96+
AR = @view A[:, (m + 1):n]
97+
apply_permutation!(ipiv, AR, Val(Thread))
98+
ldiv!(_unit_lower_triangular(AL), AR, Val(Thread))
99+
end
100+
info
94101
end
95102

96-
97-
@inline function nsplit(::Type{T}, n) where T
103+
@inline function nsplit(::Type{T}, n) where {T}
98104
k = 512 ÷ (isbitstype(T) ? sizeof(T) : 8)
99105
k_2 = k ÷ 2
100106
return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
@@ -125,8 +131,8 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
125131
end
126132
nothing
127133
end
128-
129-
function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize, thread)::BlasInt where {T,Pivot}
134+
function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize,
135+
thread)::BlasInt where {T, Pivot}
130136
@inbounds begin
131137
if n <= max(blocksize, 1)
132138
info = _generic_lufact!(A, Val(Pivot), ipiv, info)
@@ -147,18 +153,18 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
147153
# Partition the matrix A
148154
# [AL AR]
149155
AL = @view A[:, 1:n1]
150-
AR = @view A[:, n1+1:n]
156+
AR = @view A[:, (n1 + 1):n]
151157
# AL AR
152158
# [A11 A12]
153159
# [A21 A22]
154160
A11 = @view A[1:n1, 1:n1]
155-
A12 = @view A[1:n1, n1+1:n]
156-
A21 = @view A[n1+1:m, 1:n1]
157-
A22 = @view A[n1+1:m, n1+1:n]
161+
A12 = @view A[1:n1, (n1 + 1):n]
162+
A21 = @view A[(n1 + 1):m, 1:n1]
163+
A22 = @view A[(n1 + 1):m, (n1 + 1):n]
158164
# [P1]
159165
# [P2]
160166
P1 = @view ipiv[1:n1]
161-
P2 = @view ipiv[n1+1:n]
167+
P2 = @view ipiv[(n1 + 1):n]
162168
# ========================================
163169

164170
# [ A11 ] [ L11 ]
@@ -170,7 +176,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
170176
# [ A22 ] [ 0 ] [ A22 ]
171177
Pivot && apply_permutation!(P1, AR, thread)
172178
# A12 = L11 U12 => U12 = L11 \ A12
173-
ldiv!(UnitLowerTriangular(A11), A12, thread)
179+
ldiv!(_unit_lower_triangular(A11), A12, thread)
174180
# Schur complement:
175181
# We have A22 = L21 U12 + A′22, hence
176182
# A′22 = A22 - L21 U12
@@ -191,23 +197,23 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
191197
end # inbounds
192198
end
193199

194-
function schur_complement!(𝐂, 𝐀, 𝐁,::Val{THREAD}=Val(true)) where {THREAD}
200+
function schur_complement!(𝐂, 𝐀, 𝐁, ::Val{THREAD} = Val(true)) where {THREAD}
195201
# mul!(𝐂,𝐀,𝐁,-1,1)
196202
if THREAD
197-
@tturbo warn_check_args=false for m 1:size(𝐀,1), n 1:size(𝐁,2)
203+
@tturbo warn_check_args=false for m in 1:size(𝐀, 1), n in 1:size(𝐁, 2)
198204
𝐂ₘₙ = zero(eltype(𝐂))
199-
for k 1:size(𝐀,2)
200-
𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
205+
for k in 1:size(𝐀, 2)
206+
𝐂ₘₙ -= 𝐀[m, k] * 𝐁[k, n]
201207
end
202-
𝐂[m,n] = 𝐂ₘₙ + 𝐂[m,n]
208+
𝐂[m, n] = 𝐂ₘₙ + 𝐂[m, n]
203209
end
204210
else
205-
@turbo warn_check_args=false for m 1:size(𝐀,1), n 1:size(𝐁,2)
211+
@turbo warn_check_args=false for m in 1:size(𝐀, 1), n in 1:size(𝐁, 2)
206212
𝐂ₘₙ = zero(eltype(𝐂))
207-
for k 1:size(𝐀,2)
208-
𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
213+
for k in 1:size(𝐀, 2)
214+
𝐂ₘₙ -= 𝐀[m, k] * 𝐁[k, n]
209215
end
210-
𝐂[m,n] = 𝐂ₘₙ + 𝐂[m,n]
216+
𝐂[m, n] = 𝐂ₘₙ + 𝐂[m, n]
211217
end
212218
end
213219
end
@@ -216,49 +222,47 @@ end
216222
Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
217223
License is MIT: https://julialang.org/license
218224
=#
219-
function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
225+
function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
220226
m, n = size(A)
221227
minmn = length(ipiv)
222-
@inbounds begin
223-
for k = 1:minmn
224-
# find index max
225-
kp = k
226-
if Pivot
227-
amax = abs(zero(eltype(A)))
228-
for i = k:m
229-
absi = abs(A[i,k])
230-
if absi > amax
231-
kp = i
232-
amax = absi
233-
end
234-
end
235-
end
236-
ipiv[k] = kp
237-
if !iszero(A[kp,k])
238-
if k != kp
239-
# Interchange
240-
@simd for i = 1:n
241-
tmp = A[k,i]
242-
A[k,i] = A[kp,i]
243-
A[kp,i] = tmp
244-
end
228+
@inbounds begin for k in 1:minmn
229+
# find index max
230+
kp = k
231+
if Pivot
232+
amax = abs(zero(eltype(A)))
233+
for i in k:m
234+
absi = abs(A[i, k])
235+
if absi > amax
236+
kp = i
237+
amax = absi
245238
end
246-
# Scale first column
247-
Akkinv = inv(A[k,k])
248-
@turbo check_empty=true warn_check_args=false for i = k+1:m
249-
A[i,k] *= Akkinv
250-
end
251-
elseif info == 0
252-
info = k
253239
end
254-
k == minmn && break
255-
# Update the rest
256-
@turbo warn_check_args=false for j = k+1:n
257-
for i = k+1:m
258-
A[i,j] -= A[i,k]*A[k,j]
240+
end
241+
ipiv[k] = kp
242+
if !iszero(A[kp, k])
243+
if k != kp
244+
# Interchange
245+
@simd for i in 1:n
246+
tmp = A[k, i]
247+
A[k, i] = A[kp, i]
248+
A[kp, i] = tmp
259249
end
260250
end
251+
# Scale first column
252+
Akkinv = inv(A[k, k])
253+
@turbo check_empty=true warn_check_args=false for i in (k + 1):m
254+
A[i, k] *= Akkinv
255+
end
256+
elseif info == 0
257+
info = k
261258
end
262-
end
259+
k == minmn && break
260+
# Update the rest
261+
@turbo warn_check_args=false for j in (k + 1):n
262+
for i in (k + 1):m
263+
A[i, j] -= A[i, k] * A[k, j]
264+
end
265+
end
266+
end end
263267
return info
264268
end

0 commit comments

Comments
 (0)