Skip to content

Commit 3622bde

Browse files
authored
Merge pull request #9 from darsnack/gpu-support
Allow any backing array type and add Adapt support
2 parents 1eb9387 + 2802016 commit 3622bde

File tree

4 files changed

+210
-11
lines changed

4 files changed

+210
-11
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
version:
13-
- '1.0'
13+
# - '1.0' # incompatible with CUDA.jl
1414
- '1'
1515
- 'nightly'
1616
os:

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ uuid = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
33
authors = ["Jun Tian <[email protected]> and contributors"]
44
version = "0.1.5"
55

6+
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
69
[compat]
710
julia = "1"
811

912
[extras]
1013
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
14+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1115

1216
[targets]
13-
test = ["Test"]
17+
test = ["CUDA", "Test"]

src/CircularArrayBuffers.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
module CircularArrayBuffers
22

3+
using Adapt
4+
35
export CircularArrayBuffer, CircularVectorBuffer, capacity, isfull
46

57
"""
6-
CircularArrayBuffer{T}(sz::Integer...) -> CircularArrayBuffer{T, N}
8+
CircularArrayBuffer{T}(sz::Integer...) -> CircularArrayBuffer{T, N, Array{T, N}}
79
810
`CircularArrayBuffer` uses a `N`-dimension `Array` of size `sz` to serve as a buffer for
911
`N-1`-dimension `Array`s of the same size.
1012
"""
11-
mutable struct CircularArrayBuffer{T,N} <: AbstractArray{T,N}
12-
buffer::Array{T,N}
13+
mutable struct CircularArrayBuffer{T,N,S<:AbstractArray{T,N}} <: AbstractArray{T,N}
14+
buffer::S
1315
first::Int
1416
nframes::Int
1517
step_size::Int
1618
end
1719

18-
const CircularVectorBuffer{T} = CircularArrayBuffer{T,1}
20+
const CircularVectorBuffer{T,S} = CircularArrayBuffer{T,1,S}
1921

2022
CircularVectorBuffer{T}(n::Integer) where {T} = CircularArrayBuffer{T}(n)
2123

@@ -28,12 +30,25 @@ function CircularArrayBuffer(A::AbstractArray{T,N}) where {T,N}
2830
CircularArrayBuffer(A, 1, size(A, N), N == 1 ? 1 : *(size(A)[1:end-1]...))
2931
end
3032

33+
Adapt.adapt_structure(to, cb::CircularArrayBuffer) =
34+
CircularArrayBuffer(adapt(to, cb.buffer), cb.first, cb.nframes, cb.step_size)
35+
36+
function Base.show(io::IO, ::MIME"text/plain", cb::CircularArrayBuffer{T}) where T
37+
print(io, ndims(cb) == 1 ? "CircularVectorBuffer(" : "CircularArrayBuffer(")
38+
Base.showarg(io, cb.buffer, false)
39+
print(io, ") with eltype $T:\n")
40+
Base.print_array(io, adapt(Array, cb))
41+
return nothing
42+
end
43+
3144
Base.IndexStyle(::CircularArrayBuffer) = IndexLinear()
3245

3346
Base.size(cb::CircularArrayBuffer{T,N}, i::Integer) where {T,N} = i == N ? cb.nframes : size(cb.buffer, i)
3447
Base.size(cb::CircularArrayBuffer{T,N}) where {T,N} = ntuple(i -> size(cb, i), N)
3548
Base.getindex(cb::CircularArrayBuffer{T,N}, i::Int) where {T,N} = getindex(cb.buffer, _buffer_index(cb, i))
49+
Base.getindex(cb::CircularArrayBuffer{T,N}, I...) where {T,N} = getindex(cb.buffer, Base.front(I)..., _buffer_frame(cb, Base.last(I)))
3650
Base.setindex!(cb::CircularArrayBuffer{T,N}, v, i::Int) where {T,N} = setindex!(cb.buffer, v, _buffer_index(cb, i))
51+
Base.setindex!(cb::CircularArrayBuffer{T,N}, v, I...) where {T,N} = setindex!(cb.buffer, v, Base.front(I)..., _buffer_frame(cb, Base.last(I)))
3752

3853
capacity(cb::CircularArrayBuffer{T,N}) where {T,N} = size(cb.buffer, N)
3954
isfull(cb::CircularArrayBuffer) = cb.nframes == capacity(cb)
@@ -47,6 +62,7 @@ Base.isempty(cb::CircularArrayBuffer) = cb.nframes == 0
4762
ind
4863
end
4964
end
65+
@inline _buffer_index(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(Base.Fix1(_buffer_index, cb), I)
5066

5167
@inline function _buffer_frame(cb::CircularArrayBuffer, i::Int)
5268
n = capacity(cb)
@@ -58,7 +74,7 @@ end
5874
end
5975
end
6076

61-
_buffer_frame(cb::CircularArrayBuffer, I::Vector{Int}) = map(i -> _buffer_frame(cb, i), I)
77+
_buffer_frame(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(i -> _buffer_frame(cb, i), I)
6278

6379
function Base.empty!(cb::CircularArrayBuffer)
6480
cb.nframes = 0
@@ -72,13 +88,14 @@ function Base.push!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
7288
cb.nframes += 1
7389
end
7490
if N == 1
91+
i = _buffer_frame(cb, cb.nframes)
7592
if ndims(data) == 0
76-
cb[cb.nframes] = data[]
93+
cb.buffer[i:i] .= data[]
7794
else
78-
cb[cb.nframes] = data
95+
cb.buffer[i:i] .= data
7996
end
8097
else
81-
cb[ntuple(_ -> (:), N - 1)..., cb.nframes] .= data
98+
cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)] .= data
8299
end
83100
cb
84101
end

test/runtests.jl

Lines changed: 179 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
using CircularArrayBuffers
22
using Test
3+
using Adapt
4+
using CUDA
5+
CUDA.allowscalar(false)
36

4-
@testset "CircularArrayBuffers.jl" begin
7+
@testset "CircularArrayBuffers (Array)" begin
58
A = ones(2, 2)
69
C = ones(Float32, 2, 2)
710

@@ -165,3 +168,178 @@ using Test
165168
]
166169
end
167170
end
171+
172+
if CUDA.functional()
173+
@testset "CircularArrayBuffers (CuArray)" begin
174+
A = CUDA.ones(2, 2)
175+
Ac = adapt(Array, A)
176+
C = CUDA.ones(Float32, 2, 2)
177+
178+
@testset "Adapt" begin
179+
X = CircularArrayBuffer(rand(2, 3))
180+
Xc = adapt(CuArray, X)
181+
@test Xc isa CircularArrayBuffer{Float64,2,<:CuArray}
182+
@test adapt(Array, Xc) == X
183+
end
184+
185+
# https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/551
186+
@testset "1D with 0d data" begin
187+
b = adapt(CuArray, CircularArrayBuffer{Int}(3))
188+
CUDA.@allowscalar push!(b, CUDA.zeros(Int, ()))
189+
@test length(b) == 1
190+
@test CUDA.@allowscalar b[1] == 0
191+
end
192+
193+
@testset "1D Int" begin
194+
b = adapt(CuArray, CircularArrayBuffer{Int}(3))
195+
196+
@test eltype(b) == Int
197+
@test capacity(b) == 3
198+
@test isfull(b) == false
199+
@test isempty(b) == true
200+
@test length(b) == 0
201+
@test size(b) == (0,)
202+
# element must has the exact same length with the element of buffer
203+
@test_throws Exception push!(b, [1, 2])
204+
205+
for x in 1:3
206+
push!(b, x)
207+
end
208+
209+
@test capacity(b) == 3
210+
@test isfull(b) == true
211+
@test length(b) == 3
212+
@test size(b) == (3,)
213+
# scalar indexing is not allowed
214+
@test_throws ErrorException b[1]
215+
@test_throws ErrorException b[end]
216+
@test CUDA.@allowscalar b[1:end] == cu([1, 2, 3])
217+
218+
for x in 4:5
219+
push!(b, x)
220+
end
221+
222+
@test capacity(b) == 3
223+
@test length(b) == 3
224+
@test size(b) == (3,)
225+
@test CUDA.@allowscalar b[1:end] == [3, 4, 5]
226+
227+
empty!(b)
228+
@test isfull(b) == false
229+
@test isempty(b) == true
230+
@test length(b) == 0
231+
@test size(b) == (0,)
232+
233+
push!(b, 6)
234+
@test isfull(b) == false
235+
@test isempty(b) == false
236+
@test length(b) == 1
237+
@test size(b) == (1,)
238+
@test CUDA.@allowscalar b[1] == 6
239+
240+
push!(b, 7)
241+
push!(b, 8)
242+
@test isfull(b) == true
243+
@test isempty(b) == false
244+
@test length(b) == 3
245+
@test size(b) == (3,)
246+
@test CUDA.@allowscalar b[1:3] == cu([6, 7, 8])
247+
248+
push!(b, 9)
249+
@test isfull(b) == true
250+
@test isempty(b) == false
251+
@test length(b) == 3
252+
@test size(b) == (3,)
253+
@test CUDA.@allowscalar b[1:3] == cu([7, 8, 9])
254+
255+
x = CUDA.@allowscalar pop!(b)
256+
@test x == 9
257+
@test length(b) == 2
258+
@test CUDA.@allowscalar b[1:2] == cu([7, 8])
259+
260+
x = CUDA.@allowscalar popfirst!(b)
261+
@test x == 7
262+
@test length(b) == 1
263+
@test CUDA.@allowscalar b[1] == 8
264+
265+
x = CUDA.@allowscalar pop!(b)
266+
@test x == 8
267+
@test length(b) == 0
268+
269+
@test_throws ArgumentError pop!(b)
270+
@test_throws ArgumentError popfirst!(b)
271+
end
272+
273+
@testset "2D Float64" begin
274+
b = adapt(CuArray, CircularArrayBuffer{Float64}(2, 2, 3))
275+
276+
@test eltype(b) == Float64
277+
@test capacity(b) == 3
278+
@test isfull(b) == false
279+
@test length(b) == 0
280+
@test size(b) == (2, 2, 0)
281+
282+
for x in 1:3
283+
push!(b, x * A)
284+
end
285+
286+
@test capacity(b) == 3
287+
@test isfull(b) == true
288+
@test length(b) == 2 * 2 * 3
289+
@test size(b) == (2, 2, 3)
290+
for i in 1:3
291+
@test b[:, :, i] == i * A
292+
end
293+
@test b[:, :, end] == 3 * A
294+
295+
for x in 4:5
296+
push!(b, x * CUDA.ones(Float64, 2, 2))
297+
end
298+
299+
@test capacity(b) == 3
300+
@test length(b) == 2 * 2 * 3
301+
@test size(b) == (2, 2, 3)
302+
@test b[:, :, 1] == 3 * A
303+
@test b[:, :, end] == 5 * A
304+
305+
# doing b == ... triggers scalar indexing
306+
@test CUDA.@allowscalar b == cu(reshape([c for x in 3:5 for c in x * Ac], 2, 2, 3))
307+
308+
push!(b, 6 * CUDA.ones(Float32, 2, 2))
309+
push!(b, 7 * CUDA.ones(Int, 2, 2))
310+
@test CUDA.@allowscalar b == cu(reshape([c for x in 5:7 for c in x * Ac], 2, 2, 3))
311+
312+
x = pop!(b)
313+
@test x == 7 * CUDA.ones(Float64, 2, 2)
314+
@test CUDA.@allowscalar b == cu(reshape([c for x in 5:6 for c in x * Ac], 2, 2, 2))
315+
end
316+
317+
@testset "append!" begin
318+
b = adapt(CuArray, CircularArrayBuffer{Int}(2, 3))
319+
append!(b, CUDA.zeros(2))
320+
append!(b, 1:4)
321+
@test CUDA.@allowscalar b == cu([
322+
0 1 3
323+
0 2 4
324+
])
325+
326+
327+
b = adapt(CuArray, CircularArrayBuffer{Int}(2, 3))
328+
for i in 1:5
329+
push!(b, CUDA.fill(i, 2))
330+
end
331+
empty!(b)
332+
append!(b, 1:4)
333+
@test CUDA.@allowscalar b == cu([
334+
1 3
335+
2 4
336+
])
337+
338+
append!(b, 5:8)
339+
@test CUDA.@allowscalar b == cu([
340+
3 5 7
341+
4 6 8
342+
])
343+
end
344+
end
345+
end

0 commit comments

Comments
 (0)