Skip to content

Commit eb050fd

Browse files
authored
Merge pull request #11 from JuliaReinforcementLearning/fix_RL_600
support multi dimensions in the last dim of circular array buffer
2 parents b90a938 + 12bda6c commit eb050fd

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

src/CircularArrayBuffers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
Adapt.adapt_structure(to, cb::CircularArrayBuffer) =
3434
CircularArrayBuffer(adapt(to, cb.buffer), cb.first, cb.nframes, cb.step_size)
3535

36-
function Base.show(io::IO, ::MIME"text/plain", cb::CircularArrayBuffer{T}) where T
36+
function Base.show(io::IO, ::MIME"text/plain", cb::CircularArrayBuffer{T}) where {T}
3737
print(io, ndims(cb) == 1 ? "CircularVectorBuffer(" : "CircularArrayBuffer(")
3838
Base.showarg(io, cb.buffer, false)
3939
print(io, ") with eltype $T:\n")
@@ -77,7 +77,7 @@ end
7777
end
7878
end
7979

80-
_buffer_frame(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(i -> _buffer_frame(cb, i), I)
80+
_buffer_frame(cb::CircularArrayBuffer, I::AbstractArray{<:Integer}) = map(i -> _buffer_frame(cb, i), I)
8181

8282
function Base.empty!(cb::CircularArrayBuffer)
8383
cb.nframes = 0

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ CUDA.allowscalar(false)
7979
@test size(b) == (3,)
8080
@test b[[1, 2, 3]] == [7, 8, 9]
8181

82+
# https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/600
83+
@test b[[1 3; 2 3]] == [7 9; 8 9]
84+
@test @view(b[[1 3; 2 3]]) == [7 9; 8 9]
85+
8286
x = pop!(b)
8387
@test x == 9
8488
@test length(b) == 2

0 commit comments

Comments
 (0)