Skip to content

Commit 5afb189

Browse files
authored
[Containers] document nested and vectorized_product (#2751)
1 parent 55af1ee commit 5afb189

File tree

3 files changed

+102
-32
lines changed

3 files changed

+102
-32
lines changed

Diff for: docs/src/reference/containers.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ Containers.DenseAxisArray
88
Containers.SparseAxisArray
99
Containers.container
1010
Containers.default_container
11+
Containers.@container
1112
Containers.VectorizedProductIterator
13+
Containers.vectorized_product
1214
Containers.NestedIterator
13-
Containers.@container
15+
Containers.nested
1416
```

Diff for: src/Containers/nested_iterator.jl

+68-23
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,20 @@
1010
end
1111
1212
Iterators over the tuples that are produced by a nested for loop.
13-
For instance, if `length(iterators) == 3` , this corresponds to the tuples
14-
`(i1, i2, i3)` produced by:
15-
```
13+
14+
Construct a `NestedIterator` using [`nested`](@ref).
15+
16+
## Example
17+
18+
If `length(iterators) == 3` ,
19+
```julia
20+
x = NestedIterator(iterators, condition)
21+
for (i1, i2, i3) in x
22+
# produces (i1, i2, i3)
23+
end
24+
````
25+
is the same as
26+
```julia
1627
for i1 in iterators[1]()
1728
for i2 in iterator[2](i1)
1829
for i3 in iterator[3](i1, i2)
@@ -28,40 +39,69 @@ struct NestedIterator{T,C}
2839
iterators::T # Tuple of functions
2940
condition::C
3041
end
42+
43+
"""
44+
nested(iterators...; condition = (args...) -> true)
45+
46+
Create a [`NestedIterator`](@ref).
47+
48+
## Example
49+
50+
```julia
51+
nested(1:2, ["A", "B"]; condition = (i, j) -> isodd(i) || j == "B")
52+
```
53+
"""
3154
function nested(iterators...; condition = (args...) -> true)
3255
return NestedIterator(iterators, condition)
3356
end
57+
3458
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown()
59+
3560
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown()
36-
function next_iterate(iterators, condition, elems, states, iterator, elem_state)
61+
62+
function _next_iterate(
63+
iterators,
64+
condition,
65+
elems,
66+
states,
67+
iterator,
68+
elem_state,
69+
)
3770
while true
38-
elem_state === nothing && return nothing
71+
if elem_state === nothing
72+
return
73+
end
3974
elem, state = elem_state
40-
elems_states = first_iterate(
75+
elems_states = _first_iterate(
4176
Base.tail(iterators),
4277
condition,
4378
(elems..., elem),
4479
(states..., (iterator, state, elem)),
4580
)
46-
elems_states !== nothing && return elems_states
47-
# This could be written as a recursive function where we call `next_iterate`
48-
# here with this new value of `next_iterate` instead of the `while` loop`.
49-
# However, if there are too many consecutive elements for which `condition`
50-
# is `false` for the last iterator, this will result in a `StackOverflow`.
51-
# See https://github.com/jump-dev/JuMP.jl/issues/2335
81+
if elems_states !== nothing
82+
return elems_states
83+
end
84+
# This could be written as a recursive function where we call
85+
# `_next_iterate` here with this new value of `_next_iterate` instead of
86+
# the `while` loop`.
87+
#
88+
# However, if there are too many consecutive elements for which
89+
# `condition`is `false` for the last iterator, this will result in a
90+
# `StackOverflow`. See https://github.com/jump-dev/JuMP.jl/issues/2335
5291
elem_state = iterate(iterator, state)
5392
end
5493
end
55-
function first_iterate(::Tuple{}, condition, elems, states)
94+
95+
function _first_iterate(::Tuple{}, condition, elems, states)
5696
if condition(elems...)
5797
return elems, states
58-
else
59-
return nothing
6098
end
99+
return
61100
end
62-
function first_iterate(iterators, condition, elems, states)
101+
102+
function _first_iterate(iterators, condition, elems, states)
63103
iterator = iterators[1](elems...)
64-
return next_iterate(
104+
return _next_iterate(
65105
iterators,
66106
condition,
67107
elems,
@@ -70,9 +110,11 @@ function first_iterate(iterators, condition, elems, states)
70110
iterate(iterator),
71111
)
72112
end
73-
tail_iterate(::Tuple{}, condition, elems, states, prev_states) = nothing
74-
function tail_iterate(iterators, condition, elems, states, prev_states)
75-
next = tail_iterate(
113+
114+
_tail_iterate(::Tuple{}, condition, elems, states, prev_states) = nothing
115+
116+
function _tail_iterate(iterators, condition, elems, states, prev_states)
117+
next = _tail_iterate(
76118
Base.tail(iterators),
77119
condition,
78120
(elems..., states[1][3]),
@@ -83,7 +125,7 @@ function tail_iterate(iterators, condition, elems, states, prev_states)
83125
return next
84126
end
85127
iterator = states[1][1]
86-
return next_iterate(
128+
return _next_iterate(
87129
iterators,
88130
condition,
89131
elems,
@@ -92,12 +134,15 @@ function tail_iterate(iterators, condition, elems, states, prev_states)
92134
iterate(iterator, states[1][2]),
93135
)
94136
end
137+
95138
function Base.iterate(it::NestedIterator)
96-
return first_iterate(it.iterators, it.condition, tuple(), tuple())
139+
return _first_iterate(it.iterators, it.condition, tuple(), tuple())
97140
end
141+
98142
function Base.iterate(it::NestedIterator, states)
99-
return tail_iterate(it.iterators, it.condition, tuple(), states, tuple())
143+
return _tail_iterate(it.iterators, it.condition, tuple(), states, tuple())
100144
end
145+
101146
function _eltype_or_any(::NestedIterator{<:Tuple{Vararg{Any,N}}}) where {N}
102147
return NTuple{N,Any}
103148
end

Diff for: src/Containers/vectorized_product_iterator.jl

+31-8
Original file line numberDiff line numberDiff line change
@@ -27,50 +27,73 @@
2727
# Long story short, we want to tried everything as an iterator without shape
2828
# while `Iterators.ProductIterator` does care about preserving the shape
2929
# when doing the Cartesian product.
30+
3031
"""
3132
struct VectorizedProductIterator{T}
3233
prod::Iterators.ProductIterator{T}
3334
end
3435
35-
Cartesian product of the iterators `prod.iterators`. It is the same iterator as
36-
`Base.Iterators.ProductIterator` except that it is independent of the
37-
`IteratorSize` of the elements of `prod.iterators`.
38-
For instance:
39-
* `size(Iterators.product(1, 2))` is `tuple()` while
40-
`size(VectorizedProductIterator(1, 2))` is `(1, 1)`.
41-
* `size(Iterators.product(ones(2, 3)))` is `(2, 3)` while
42-
`size(VectorizedProductIterator(ones(2, 3)))` is `(1, 1)`.
36+
A wrapper type for `Iterators.ProuctIterator` that discards shape information
37+
and returns a `Vector`.
38+
39+
Construct a `VectorizedProductIterator` using [`vectorized_product`](@ref).
4340
"""
4441
struct VectorizedProductIterator{T}
4542
prod::Iterators.ProductIterator{T}
4643
end
4744

4845
# Collect iterators with unknown size so they can be used as axes.
4946
_collect(::Base.SizeUnknown, x) = collect(x)
47+
5048
_collect(::Any, x) = x
49+
5150
function _collect(::Base.IsInfinite, x)
5251
return error("Unable to form a container. Axis $(x) has infinite size!")
5352
end
53+
5454
_collect(x) = _collect(Base.IteratorSize(x), x)
5555

56+
"""
57+
vectorized_product(iterators...)
58+
59+
Created a [`VectorizedProductIterator`](@ref).
60+
61+
## Examples
62+
63+
```julia
64+
vectorized_product(1:2, ["A", "B"])
65+
```
66+
"""
5667
function vectorized_product(iterators...)
5768
return VectorizedProductIterator(Iterators.product(_collect.(iterators)...))
5869
end
70+
5971
function Base.IteratorSize(
6072
::Type{<:VectorizedProductIterator{<:Tuple{Vararg{Any,N}}}},
6173
) where {N}
6274
return Base.HasShape{N}()
6375
end
76+
6477
Base.IteratorEltype(::Type{<:VectorizedProductIterator}) = Base.EltypeUnknown()
78+
6579
Base.size(it::VectorizedProductIterator) = _prod_size(it.prod.iterators)
80+
6681
_prod_size(::Tuple{}) = ()
82+
6783
_prod_size(t::Tuple) = (length(t[1]), _prod_size(Base.tail(t))...)
84+
6885
Base.axes(it::VectorizedProductIterator) = _prod_indices(it.prod.iterators)
86+
6987
_prod_indices(::Tuple{}) = ()
88+
7089
function _prod_indices(t::Tuple)
7190
return (Base.OneTo(length(t[1])), _prod_indices(Base.tail(t))...)
7291
end
92+
7393
Base.ndims(it::VectorizedProductIterator) = length(axes(it))
94+
7495
Base.length(it::VectorizedProductIterator) = prod(size(it))
96+
7597
Base.iterate(it::VectorizedProductIterator, args...) = iterate(it.prod, args...)
98+
7699
_eltype_or_any(it::VectorizedProductIterator) = eltype(it.prod)

0 commit comments

Comments
 (0)