Skip to content

Commit 404fbb5

Browse files
authored
specialize one for FillArrays (#184)
1 parent 255f641 commit 404fbb5

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/FillArrays.jl

+13-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
66
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
77
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
88
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero,
9-
show, view, in, mapreduce
9+
show, view, in, mapreduce, one
1010

1111
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
1212
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec, TransposeAbsVec,
@@ -336,7 +336,7 @@ axes(rd::Diagonal{<:Any,<:AbstractFill}) = (axes(rd.diag,1),axes(rd.diag,1))
336336
axes(T::AbstractTriangular{<:Any,<:AbstractFill}) = axes(parent(T))
337337

338338
axes(rd::RectDiagonal) = rd.axes
339-
size(rd::RectDiagonal) = length.(rd.axes)
339+
size(rd::RectDiagonal) = map(length, rd.axes)
340340

341341
@inline function getindex(rd::RectDiagonal{T}, i::Integer, j::Integer) where T
342342
@boundscheck checkbounds(rd, i, j)
@@ -551,6 +551,17 @@ zero(r::Zeros{T,N}) where {T,N} = r
551551
zero(r::Ones{T,N}) where {T,N} = Zeros{T,N}(r.axes)
552552
zero(r::Fill{T,N}) where {T,N} = Zeros{T,N}(r.axes)
553553

554+
#########
555+
# oneunit
556+
#########
557+
558+
function one(A::AbstractFill{T,2}) where {T}
559+
Base.require_one_based_indexing(A)
560+
m, n = size(A)
561+
m == n || throw(ArgumentError("multiplicative identity defined only for square matrices"))
562+
SquareEye{T}(m)
563+
end
564+
554565
#########
555566
# any/all/isone/iszero
556567
#########

test/runtests.jl

+9
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,15 @@ end
398398
@test convert(Diagonal{Int}, Eye(5)) == Diagonal(ones(Int,5))
399399
end
400400

401+
@testset "one" begin
402+
@testset for A in Any[Eye(4), Zeros(4,4), Ones(4,4), Fill(3,4,4)]
403+
B = one(A)
404+
@test B * A == A * B == A
405+
end
406+
@test_throws ArgumentError one(Ones(3,4))
407+
@test_throws ArgumentError one(Ones((3:5,4:5)))
408+
end
409+
401410
@testset "Sparse vectors and matrices" begin
402411
@test SparseVector(Zeros(5)) ==
403412
SparseVector{Float64}(Zeros(5)) ==

0 commit comments

Comments
 (0)