Skip to content

Commit 2635dea

Browse files
roflmaostcnsajkoLilithHafnermkitti
authored
Add insertdims method which is inverse to dropdims (#45793)
Example: ```julia julia> a = [1 2; 3 4] 2×2 Matrix{Int64}: 1 2 3 4 julia> b = insertdims(a, dims=(1,4)) 1×2×2×1 Array{Int64, 4}: [:, :, 1, 1] = 1 3 [:, :, 2, 1] = 2 4 julia> b[1,1,1,1] = 5; a 2×2 Matrix{Int64}: 5 2 3 4 julia> b = insertdims(a, dims=(1,2)) 1×1×2×2 Array{Int64, 4}: [:, :, 1, 1] = 5 [:, :, 2, 1] = 3 [:, :, 1, 2] = 2 [:, :, 2, 2] = 4 julia> b = insertdims(a, dims=(1,3)) 1×2×1×2 Array{Int64, 4}: [:, :, 1, 1] = 5 3 [:, :, 1, 2] = 2 4 ``` --------- Co-authored-by: Neven Sajko <[email protected]> Co-authored-by: Lilith Orion Hafner <[email protected]> Co-authored-by: Mark Kittisopikul <[email protected]>
1 parent c6732a7 commit 2635dea

File tree

5 files changed

+96
-0
lines changed

5 files changed

+96
-0
lines changed

NEWS.md

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ New library functions
7272
* The new `isfull(c::Channel)` function can be used to check if `put!(c, some_value)` will block. ([#53159])
7373
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
7474
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
75+
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`
7576

7677
New library features
7778
--------------------

base/abstractarraymath.jl

+64
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,70 @@ function _dropdims(A::AbstractArray, dims::Dims)
9393
end
9494
_dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),))
9595

96+
97+
"""
98+
insertdims(A; dims)
99+
100+
Inverse of [`dropdims`](@ref); return an array with new singleton dimensions
101+
at every dimension in `dims`.
102+
103+
Repeated dimensions are forbidden and the largest entry in `dims` must be
104+
less than or equal than `ndims(A) + length(dims)`.
105+
106+
The result shares the same underlying data as `A`, such that the
107+
result is mutable if and only if `A` is mutable, and setting elements of one
108+
alters the values of the other.
109+
110+
See also: [`dropdims`](@ref), [`reshape`](@ref), [`vec`](@ref).
111+
# Examples
112+
```jldoctest
113+
julia> x = [1 2 3; 4 5 6]
114+
2×3 Matrix{Int64}:
115+
1 2 3
116+
4 5 6
117+
118+
julia> insertdims(x, dims=3)
119+
2×3×1 Array{Int64, 3}:
120+
[:, :, 1] =
121+
1 2 3
122+
4 5 6
123+
124+
julia> insertdims(x, dims=(1,2,5)) == reshape(x, 1, 1, 2, 3, 1)
125+
true
126+
127+
julia> dropdims(insertdims(x, dims=(1,2,5)), dims=(1,2,5))
128+
2×3 Matrix{Int64}:
129+
1 2 3
130+
4 5 6
131+
```
132+
133+
!!! compat "Julia 1.12"
134+
Requires Julia 1.12 or later.
135+
"""
136+
insertdims(A; dims) = _insertdims(A, dims)
137+
function _insertdims(A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {T, N, M}
138+
for i in eachindex(dims)
139+
1 dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))
140+
dims[i] N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added"))
141+
for j = 1:i-1
142+
dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique"))
143+
end
144+
end
145+
146+
# acc is a tuple, where the first entry is the final shape
147+
# the second entry off acc is a counter for the axes of A
148+
inds= Base._foldoneto((acc, i) ->
149+
i dims
150+
? ((acc[1]..., Base.OneTo(1)), acc[2])
151+
: ((acc[1]..., axes(A, acc[2])), acc[2] + 1),
152+
((), 1), Val(N+M))
153+
new_shape = inds[1]
154+
return reshape(A, new_shape)
155+
end
156+
_insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),))
157+
158+
159+
96160
## Unary operators ##
97161

98162
"""

base/exports.jl

+1
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ export
407407
indexin,
408408
argmax,
409409
argmin,
410+
insertdims,
410411
invperm,
411412
invpermute!,
412413
isassigned,

doc/src/base/arrays.md

+1
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ Base.parentindices
138138
Base.selectdim
139139
Base.reinterpret
140140
Base.reshape
141+
Base.insertdims
141142
Base.dropdims
142143
Base.vec
143144
Base.SubArray

test/arrayops.jl

+29
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,35 @@ end
308308
@test_throws ArgumentError dropdims(a, dims=4)
309309
@test_throws ArgumentError dropdims(a, dims=6)
310310

311+
312+
a = rand(8, 7)
313+
@test @inferred(insertdims(a, dims=1)) == @inferred(insertdims(a, dims=(1,))) == reshape(a, (1, 8, 7))
314+
@test @inferred(insertdims(a, dims=3)) == @inferred(insertdims(a, dims=(3,))) == reshape(a, (8, 7, 1))
315+
@test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 1, 7))
316+
@test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 1, 1, 8, 7))
317+
@test @inferred(insertdims(a, dims=(1, 4))) == reshape(a, (1, 8, 7, 1))
318+
@test @inferred(insertdims(a, dims=(1, 3, 5))) == reshape(a, (1, 8, 1, 7, 1))
319+
@test @inferred(insertdims(a, dims=(1, 2, 4, 6))) == reshape(a, (1, 1, 8, 1, 7, 1))
320+
@test @inferred(insertdims(a, dims=(1, 3, 4, 6))) == reshape(a, (1, 8, 1, 1, 7, 1))
321+
@test @inferred(insertdims(a, dims=(1, 4, 6, 3))) == reshape(a, (1, 8, 1, 1, 7, 1))
322+
@test @inferred(insertdims(a, dims=(1, 3, 5, 6))) == reshape(a, (1, 8, 1, 7, 1, 1))
323+
324+
@test_throws ArgumentError insertdims(a, dims=(1, 1, 2, 3))
325+
@test_throws ArgumentError insertdims(a, dims=(1, 2, 2, 3))
326+
@test_throws ArgumentError insertdims(a, dims=(1, 2, 3, 3))
327+
@test_throws UndefKeywordError insertdims(a)
328+
@test_throws ArgumentError insertdims(a, dims=0)
329+
@test_throws ArgumentError insertdims(a, dims=(1, 2, 1))
330+
@test_throws ArgumentError insertdims(a, dims=4)
331+
@test_throws ArgumentError insertdims(a, dims=6)
332+
333+
# insertdims and dropdims are inverses
334+
b = rand(1,1,1,5,1,1,7)
335+
for dims in [1, (1,), 2, (2,), 3, (3,), (1,3), (1,2,3), (1,2), (1,3,5), (1,2,5,6), (1,3,5,6), (1,3,5,6), (1,6,5,3)]
336+
@test dropdims(insertdims(a; dims); dims) == a
337+
@test insertdims(dropdims(b; dims); dims) == b
338+
end
339+
311340
sz = (5,8,7)
312341
A = reshape(1:prod(sz),sz...)
313342
@test A[2:6] == [2:6;]

0 commit comments

Comments
 (0)