diff --git a/src/symmetric.jl b/src/symmetric.jl index 804a063b..c768f091 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -474,25 +474,25 @@ Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo) # tril/triu function tril(A::Hermitian, k::Integer=0) if A.uplo == 'U' && k <= 0 - return tril!(copy(A.data'),k) + return tril_maybe_inplace(copy(A.data'),k) elseif A.uplo == 'U' && k > 0 - return tril!(copy(A.data'),-1) + tril!(triu(A.data),k) + return tril_maybe_inplace(copy(A.data'),-1) + tril_maybe_inplace(triu(A.data),k) elseif A.uplo == 'L' && k <= 0 return tril(A.data,k) else - return tril(A.data,-1) + tril!(triu!(copy(A.data')),k) + return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(A.data')),k) end end function tril(A::Symmetric, k::Integer=0) if A.uplo == 'U' && k <= 0 - return tril!(copy(transpose(A.data)),k) + return tril_maybe_inplace(copy(transpose(A.data)),k) elseif A.uplo == 'U' && k > 0 - return tril!(copy(transpose(A.data)),-1) + tril!(triu(A.data),k) + return tril_maybe_inplace(copy(transpose(A.data)),-1) + tril_maybe_inplace(triu(A.data),k) elseif A.uplo == 'L' && k <= 0 return tril(A.data,k) else - return tril(A.data,-1) + tril!(triu!(copy(transpose(A.data))),k) + return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(transpose(A.data))),k) end end @@ -500,11 +500,11 @@ function triu(A::Hermitian, k::Integer=0) if A.uplo == 'U' && k >= 0 return triu(A.data,k) elseif A.uplo == 'U' && k < 0 - return triu(A.data,1) + triu!(tril!(copy(A.data')),k) + return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(A.data')),k) elseif A.uplo == 'L' && k >= 0 - return triu!(copy(A.data'),k) + return triu_maybe_inplace(copy(A.data'),k) else - return triu!(copy(A.data'),1) + triu!(tril(A.data),k) + return triu_maybe_inplace(copy(A.data'),1) + triu_maybe_inplace(tril(A.data),k) end end @@ -512,11 +512,11 @@ function triu(A::Symmetric, k::Integer=0) if A.uplo == 'U' && k >= 0 return triu(A.data,k) elseif A.uplo == 'U' && k < 0 - return triu(A.data,1) + triu!(tril!(copy(transpose(A.data))),k) + return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(transpose(A.data))),k) elseif A.uplo == 'L' && k >= 0 - return triu!(copy(transpose(A.data)),k) + return triu_maybe_inplace(copy(transpose(A.data)),k) else - return triu!(copy(transpose(A.data)),1) + triu!(tril(A.data),k) + return triu_maybe_inplace(copy(transpose(A.data)),1) + triu_maybe_inplace(tril(A.data),k) end end diff --git a/src/triangular.jl b/src/triangular.jl index 26ad4204..7c99f85a 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -484,6 +484,11 @@ function tril!(A::UnitLowerTriangular, k::Integer=0) return tril!(LowerTriangular(A.data), k) end +tril_maybe_inplace(A, k::Integer=0) = tril(A, k) +triu_maybe_inplace(A, k::Integer=0) = triu(A, k) +tril_maybe_inplace(A::StridedMatrix, k::Integer=0) = tril!(A, k) +triu_maybe_inplace(A::StridedMatrix, k::Integer=0) = triu!(A, k) + adjoint(A::LowerTriangular) = UpperTriangular(adjoint(A.data)) adjoint(A::UpperTriangular) = LowerTriangular(adjoint(A.data)) adjoint(A::UnitLowerTriangular) = UnitUpperTriangular(adjoint(A.data)) diff --git a/test/symmetric.jl b/test/symmetric.jl index ae97b453..a9d1a883 100644 --- a/test/symmetric.jl +++ b/test/symmetric.jl @@ -1206,4 +1206,18 @@ end end end +@testset "triu/tril with immutable arrays" begin + A = ImmutableArray([1 2; 3 4]) + for T in (Symmetric, Hermitian), uplo in (:U, :L) + H = T(A, uplo) + MH = Matrix(H) + @test triu(H,-1) == triu(MH,-1) + @test triu(H) == triu(MH) + @test triu(H,1) == triu(MH,1) + @test tril(H,1) == tril(MH,1) + @test tril(H) == tril(MH) + @test tril(H,-1) == tril(MH,-1) + end +end + end # module TestSymmetric diff --git a/test/testhelpers/ImmutableArrays.jl b/test/testhelpers/ImmutableArrays.jl index 8f2d23be..014e8110 100644 --- a/test/testhelpers/ImmutableArrays.jl +++ b/test/testhelpers/ImmutableArrays.jl @@ -28,4 +28,7 @@ AbstractArray{T,N}(A::ImmutableArray{S,N}) where {S,T,N} = ImmutableArray(Abstra Base.copy(A::ImmutableArray) = ImmutableArray(copy(A.data)) Base.zero(A::ImmutableArray) = ImmutableArray(zero(A.data)) +Base.adjoint(A::ImmutableArray) = ImmutableArray(adjoint(A.data)) +Base.transpose(A::ImmutableArray) = ImmutableArray(transpose(A.data)) + end