Skip to content

Commit 1e9839e

Browse files
committed
Out-of-place triu/tril for Symmetric in each branch
1 parent 39ee3af commit 1e9839e

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

src/symmetric.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -474,49 +474,49 @@ Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo)
474474
# tril/triu
475475
function tril(A::Hermitian, k::Integer=0)
476476
if A.uplo == 'U' && k <= 0
477-
return tril!(copy(A.data'),k)
477+
return tril_maybe_inplace(copy(A.data'),k)
478478
elseif A.uplo == 'U' && k > 0
479-
return tril!(copy(A.data'),-1) + tril!(triu(A.data),k)
479+
return tril_maybe_inplace(copy(A.data'),-1) + tril_maybe_inplace(triu(A.data),k)
480480
elseif A.uplo == 'L' && k <= 0
481481
return tril(A.data,k)
482482
else
483-
return tril(A.data,-1) + tril!(triu!(copy(A.data')),k)
483+
return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(A.data')),k)
484484
end
485485
end
486486

487487
function tril(A::Symmetric, k::Integer=0)
488488
if A.uplo == 'U' && k <= 0
489-
return tril!(copy(transpose(A.data)),k)
489+
return tril_maybe_inplace(copy(transpose(A.data)),k)
490490
elseif A.uplo == 'U' && k > 0
491-
return tril!(copy(transpose(A.data)),-1) + tril!(triu(A.data),k)
491+
return tril_maybe_inplace(copy(transpose(A.data)),-1) + tril_maybe_inplace(triu(A.data),k)
492492
elseif A.uplo == 'L' && k <= 0
493493
return tril(A.data,k)
494494
else
495-
return tril(A.data,-1) + tril!(triu!(copy(transpose(A.data))),k)
495+
return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(transpose(A.data))),k)
496496
end
497497
end
498498

499499
function triu(A::Hermitian, k::Integer=0)
500500
if A.uplo == 'U' && k >= 0
501501
return triu(A.data,k)
502502
elseif A.uplo == 'U' && k < 0
503-
return triu(A.data,1) + triu!(tril!(copy(A.data')),k)
503+
return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(A.data')),k)
504504
elseif A.uplo == 'L' && k >= 0
505-
return triu!(copy(A.data'),k)
505+
return triu_maybe_inplace(copy(A.data'),k)
506506
else
507-
return triu!(copy(A.data'),1) + triu!(tril(A.data),k)
507+
return triu_maybe_inplace(copy(A.data'),1) + triu_maybe_inplace(tril(A.data),k)
508508
end
509509
end
510510

511511
function triu(A::Symmetric, k::Integer=0)
512512
if A.uplo == 'U' && k >= 0
513513
return triu(A.data,k)
514514
elseif A.uplo == 'U' && k < 0
515-
return triu(A.data,1) + triu!(tril!(copy(transpose(A.data))),k)
515+
return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(transpose(A.data))),k)
516516
elseif A.uplo == 'L' && k >= 0
517-
return triu!(copy(transpose(A.data)),k)
517+
return triu_maybe_inplace(copy(transpose(A.data)),k)
518518
else
519-
return triu!(copy(transpose(A.data)),1) + triu!(tril(A.data),k)
519+
return triu_maybe_inplace(copy(transpose(A.data)),1) + triu_maybe_inplace(tril(A.data),k)
520520
end
521521
end
522522

src/triangular.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,11 @@ function tril!(A::UnitLowerTriangular, k::Integer=0)
484484
return tril!(LowerTriangular(A.data), k)
485485
end
486486

487+
tril_maybe_inplace(A, k::Integer=0) = tril(A, k)
488+
triu_maybe_inplace(A, k::Integer=0) = triu(A, k)
489+
tril_maybe_inplace(A::StridedMatrix, k::Integer=0) = tril!(A, k)
490+
triu_maybe_inplace(A::StridedMatrix, k::Integer=0) = triu!(A, k)
491+
487492
adjoint(A::LowerTriangular) = UpperTriangular(adjoint(A.data))
488493
adjoint(A::UpperTriangular) = LowerTriangular(adjoint(A.data))
489494
adjoint(A::UnitLowerTriangular) = UnitUpperTriangular(adjoint(A.data))

test/symmetric.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,4 +1209,27 @@ end
12091209
end
12101210
end
12111211

1212+
@testset "triu/tril with immutable arrays" begin
1213+
struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T}
1214+
a :: A
1215+
end
1216+
Base.size(A::ImmutableMatrix) = size(A.a)
1217+
Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j)
1218+
Base.copy(A::ImmutableMatrix) = A
1219+
LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a))
1220+
LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a))
1221+
1222+
A = ImmutableMatrix([1 2; 3 4])
1223+
for T in (Symmetric, Hermitian), uplo in (:U, :L)
1224+
H = T(A, uplo)
1225+
MH = Matrix(H)
1226+
@test triu(H,-1) == triu(MH,-1)
1227+
@test triu(H) == triu(MH)
1228+
@test triu(H,1) == triu(MH,1)
1229+
@test tril(H,1) == tril(MH,1)
1230+
@test tril(H) == tril(MH)
1231+
@test tril(H,-1) == tril(MH,-1)
1232+
end
1233+
end
1234+
12121235
end # module TestSymmetric

0 commit comments

Comments
 (0)