Skip to content

Commit d4b705f

Browse files
committed
fix bugs when trying to adapt to RLEnvs.jl
1 parent 8e5ef58 commit d4b705f

File tree

3 files changed

+24
-11
lines changed

3 files changed

+24
-11
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
*.jl.*.cov
22
*.jl.cov
33
*.jl.mem
4-
/Manifest.toml
4+
Manifest.toml
55
/docs/build/

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CommonRLSpaces"
22
uuid = "408f5b3e-f2a2-48a6-b4bb-c8aa44c458e6"
33
authors = ["Jun Tian <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/basic.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,38 @@ struct Space{T}
66
s::T
77
end
88

9-
Space(s::Type{T}) where {T} = Space(typemin(T):typemax(T))
9+
Space(s::Type{T}) where {T<:Number} = Space(typemin(T):typemax(T))
1010

11-
Space(x, dims::Int...) = Space(fill(x, dims))
12-
Space(x::Type{T}, dims::Int...) where {T<:Integer} = Space(fill(typemin(x):typemax(T), dims))
13-
Space(x::Type{T}, dims::Int...) where {T<:AbstractFloat} = Space(fill(typemin(x) .. typemax(T), dims))
11+
Space(x, dims::Int...) = Space(Fill(x, dims))
12+
Space(x::Type{T}, dim::Int, extra_dims::Int...) where {T<:Integer} = Space(Fill(typemin(x):typemax(T), dim, extra_dims...))
13+
Space(x::Type{T}, dim::Int, extra_dims::Int...) where {T<:AbstractFloat} = Space(Fill(typemin(x) .. typemax(T), dim, extra_dims...))
14+
Space(x::Type{T}, dim::Int, extra_dims::Int...) where {T} = Space(Fill(T, dim, extra_dims...))
1415

1516
Base.size(s::Space) = size(SpaceStyle(s))
17+
Base.length(s::Space) = length(SpaceStyle(s), s)
18+
Base.getindex(s::Space, i...) = getindex(SpaceStyle(s), s, i...)
19+
Base.:(==)(s1::Space, s2::Space) = s1.s == s2.s
1620

1721
#####
1822

1923
abstract type AbstractSpaceStyle{S} end
2024

21-
Base.size(::AbstractSpaceStyle{S}) where {S} = S
22-
2325
struct DiscreteSpaceStyle{S} <: AbstractSpaceStyle{S} end
2426
struct ContinuousSpaceStyle{S} <: AbstractSpaceStyle{S} end
2527

2628
SpaceStyle(::Space{<:Tuple}) = DiscreteSpaceStyle{()}()
27-
SpaceStyle(::Space{<:AbstractRange}) = DiscreteSpaceStyle{()}()
29+
SpaceStyle(::Space{<:AbstractVector{<:Number}}) = DiscreteSpaceStyle{()}()
2830
SpaceStyle(::Space{<:AbstractInterval}) = ContinuousSpaceStyle{()}()
2931

3032
SpaceStyle(s::Space{<:AbstractArray{<:Tuple}}) = DiscreteSpaceStyle{size(s.s)}()
3133
SpaceStyle(s::Space{<:AbstractArray{<:AbstractRange}}) = DiscreteSpaceStyle{size(s.s)}()
3234
SpaceStyle(s::Space{<:AbstractArray{<:AbstractInterval}}) = ContinuousSpaceStyle{size(s.s)}()
3335

36+
Base.size(::AbstractSpaceStyle{S}) where {S} = S
37+
Base.length(::DiscreteSpaceStyle{()}, s) = length(s.s)
38+
Base.getindex(::DiscreteSpaceStyle{()}, s, i...) = getindex(s.s, i...)
39+
Base.length(::DiscreteSpaceStyle, s) = mapreduce(length, *, s.s)
40+
3441
#####
3542

3643
Random.rand(rng::Random.AbstractRNG, s::Space) = rand(rng, s.s)
@@ -45,6 +52,7 @@ Random.rand(
4552
) = map(x -> rand(rng, x), s.s)
4653

4754
Base.in(x, s::Space) = x in s.s
55+
Base.in(x, s::Space{<:Type}) = x isa s.s
4856

4957
Base.in(
5058
x,
@@ -69,15 +77,20 @@ function Random.rand(rng::AbstractRNG, s::Interval{:closed,:closed,T}) where {T}
6977
end
7078
end
7179

80+
Base.iterate(s::Space, args...) = iterate(SpaceStyle(s), s, args...)
81+
Base.iterate(::DiscreteSpaceStyle{()}, s::Space, args...) = iterate(s.s, args...)
82+
7283
#####
7384

74-
const TupleSpace = Tuple{Vararg{<:Space}}
85+
const TupleSpace = Tuple{Vararg{Space}}
7586
const NamedSpace = NamedTuple{<:Any,<:TupleSpace}
87+
const VectorSpace = Vector{<:Space}
7688
const DictSpace = Dict{<:Any,<:Space}
7789

78-
Random.rand(rng::AbstractRNG, s::Union{TupleSpace,NamedSpace}) = map(x -> rand(rng, x), s)
90+
Random.rand(rng::AbstractRNG, s::Union{TupleSpace,NamedSpace,VectorSpace}) = map(x -> rand(rng, x), s)
7991
Random.rand(rng::AbstractRNG, s::DictSpace) = Dict(k => rand(rng, s[k]) for k in keys(s))
8092

8193
Base.in(xs::Tuple, ts::TupleSpace) = length(xs) == length(ts) && all(((x, s),) -> x in s, zip(xs, ts))
94+
Base.in(xs::AbstractVector, ts::VectorSpace) = length(xs) == length(ts) && all(((x, s),) -> x in s, zip(xs, ts))
8295
Base.in(xs::NamedTuple{names}, ns::NamedTuple{names,<:TupleSpace}) where {names} = all(((x, s),) -> x in s, zip(xs, ns))
8396
Base.in(xs::Dict, ds::DictSpace) = length(xs) == length(ds) && all(k -> haskey(ds, k) && xs[k] in ds[k], keys(xs))

0 commit comments

Comments
 (0)