|
1 | 1 | using CircularArrayBuffers
|
2 | 2 | using Test
|
| 3 | +using Adapt |
| 4 | +using CUDA |
| 5 | +CUDA.allowscalar(false) |
3 | 6 |
|
4 |
| -@testset "CircularArrayBuffers.jl" begin |
| 7 | +@testset "CircularArrayBuffers (Array)" begin |
5 | 8 | A = ones(2, 2)
|
6 | 9 | C = ones(Float32, 2, 2)
|
7 | 10 |
|
@@ -165,3 +168,178 @@ using Test
|
165 | 168 | ]
|
166 | 169 | end
|
167 | 170 | end
|
| 171 | + |
| 172 | +if CUDA.functional() |
| 173 | + @testset "CircularArrayBuffers (CuArray)" begin |
| 174 | + A = CUDA.ones(2, 2) |
| 175 | + Ac = adapt(Array, A) |
| 176 | + C = CUDA.ones(Float32, 2, 2) |
| 177 | + |
| 178 | + @testset "Adapt" begin |
| 179 | + X = CircularArrayBuffer(rand(2, 3)) |
| 180 | + Xc = adapt(CuArray, X) |
| 181 | + @test Xc isa CircularArrayBuffer{Float64,2,<:CuArray} |
| 182 | + @test adapt(Array, Xc) == X |
| 183 | + end |
| 184 | + |
| 185 | + # https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/551 |
| 186 | + @testset "1D with 0d data" begin |
| 187 | + b = adapt(CuArray, CircularArrayBuffer{Int}(3)) |
| 188 | + CUDA.@allowscalar push!(b, CUDA.zeros(Int, ())) |
| 189 | + @test length(b) == 1 |
| 190 | + @test CUDA.@allowscalar b[1] == 0 |
| 191 | + end |
| 192 | + |
| 193 | + @testset "1D Int" begin |
| 194 | + b = adapt(CuArray, CircularArrayBuffer{Int}(3)) |
| 195 | + |
| 196 | + @test eltype(b) == Int |
| 197 | + @test capacity(b) == 3 |
| 198 | + @test isfull(b) == false |
| 199 | + @test isempty(b) == true |
| 200 | + @test length(b) == 0 |
| 201 | + @test size(b) == (0,) |
| 202 | + # element must has the exact same length with the element of buffer |
| 203 | + @test_throws Exception push!(b, [1, 2]) |
| 204 | + |
| 205 | + for x in 1:3 |
| 206 | + push!(b, x) |
| 207 | + end |
| 208 | + |
| 209 | + @test capacity(b) == 3 |
| 210 | + @test isfull(b) == true |
| 211 | + @test length(b) == 3 |
| 212 | + @test size(b) == (3,) |
| 213 | + # scalar indexing is not allowed |
| 214 | + @test_throws ErrorException b[1] |
| 215 | + @test_throws ErrorException b[end] |
| 216 | + @test CUDA.@allowscalar b[1:end] == cu([1, 2, 3]) |
| 217 | + |
| 218 | + for x in 4:5 |
| 219 | + push!(b, x) |
| 220 | + end |
| 221 | + |
| 222 | + @test capacity(b) == 3 |
| 223 | + @test length(b) == 3 |
| 224 | + @test size(b) == (3,) |
| 225 | + @test CUDA.@allowscalar b[1:end] == [3, 4, 5] |
| 226 | + |
| 227 | + empty!(b) |
| 228 | + @test isfull(b) == false |
| 229 | + @test isempty(b) == true |
| 230 | + @test length(b) == 0 |
| 231 | + @test size(b) == (0,) |
| 232 | + |
| 233 | + push!(b, 6) |
| 234 | + @test isfull(b) == false |
| 235 | + @test isempty(b) == false |
| 236 | + @test length(b) == 1 |
| 237 | + @test size(b) == (1,) |
| 238 | + @test CUDA.@allowscalar b[1] == 6 |
| 239 | + |
| 240 | + push!(b, 7) |
| 241 | + push!(b, 8) |
| 242 | + @test isfull(b) == true |
| 243 | + @test isempty(b) == false |
| 244 | + @test length(b) == 3 |
| 245 | + @test size(b) == (3,) |
| 246 | + @test CUDA.@allowscalar b[1:3] == cu([6, 7, 8]) |
| 247 | + |
| 248 | + push!(b, 9) |
| 249 | + @test isfull(b) == true |
| 250 | + @test isempty(b) == false |
| 251 | + @test length(b) == 3 |
| 252 | + @test size(b) == (3,) |
| 253 | + @test CUDA.@allowscalar b[1:3] == cu([7, 8, 9]) |
| 254 | + |
| 255 | + x = CUDA.@allowscalar pop!(b) |
| 256 | + @test x == 9 |
| 257 | + @test length(b) == 2 |
| 258 | + @test CUDA.@allowscalar b[1:2] == cu([7, 8]) |
| 259 | + |
| 260 | + x = CUDA.@allowscalar popfirst!(b) |
| 261 | + @test x == 7 |
| 262 | + @test length(b) == 1 |
| 263 | + @test CUDA.@allowscalar b[1] == 8 |
| 264 | + |
| 265 | + x = CUDA.@allowscalar pop!(b) |
| 266 | + @test x == 8 |
| 267 | + @test length(b) == 0 |
| 268 | + |
| 269 | + @test_throws ArgumentError pop!(b) |
| 270 | + @test_throws ArgumentError popfirst!(b) |
| 271 | + end |
| 272 | + |
| 273 | + @testset "2D Float64" begin |
| 274 | + b = adapt(CuArray, CircularArrayBuffer{Float64}(2, 2, 3)) |
| 275 | + |
| 276 | + @test eltype(b) == Float64 |
| 277 | + @test capacity(b) == 3 |
| 278 | + @test isfull(b) == false |
| 279 | + @test length(b) == 0 |
| 280 | + @test size(b) == (2, 2, 0) |
| 281 | + |
| 282 | + for x in 1:3 |
| 283 | + push!(b, x * A) |
| 284 | + end |
| 285 | + |
| 286 | + @test capacity(b) == 3 |
| 287 | + @test isfull(b) == true |
| 288 | + @test length(b) == 2 * 2 * 3 |
| 289 | + @test size(b) == (2, 2, 3) |
| 290 | + for i in 1:3 |
| 291 | + @test b[:, :, i] == i * A |
| 292 | + end |
| 293 | + @test b[:, :, end] == 3 * A |
| 294 | + |
| 295 | + for x in 4:5 |
| 296 | + push!(b, x * CUDA.ones(Float64, 2, 2)) |
| 297 | + end |
| 298 | + |
| 299 | + @test capacity(b) == 3 |
| 300 | + @test length(b) == 2 * 2 * 3 |
| 301 | + @test size(b) == (2, 2, 3) |
| 302 | + @test b[:, :, 1] == 3 * A |
| 303 | + @test b[:, :, end] == 5 * A |
| 304 | + |
| 305 | + # doing b == ... triggers scalar indexing |
| 306 | + @test CUDA.@allowscalar b == cu(reshape([c for x in 3:5 for c in x * Ac], 2, 2, 3)) |
| 307 | + |
| 308 | + push!(b, 6 * CUDA.ones(Float32, 2, 2)) |
| 309 | + push!(b, 7 * CUDA.ones(Int, 2, 2)) |
| 310 | + @test CUDA.@allowscalar b == cu(reshape([c for x in 5:7 for c in x * Ac], 2, 2, 3)) |
| 311 | + |
| 312 | + x = pop!(b) |
| 313 | + @test x == 7 * CUDA.ones(Float64, 2, 2) |
| 314 | + @test CUDA.@allowscalar b == cu(reshape([c for x in 5:6 for c in x * Ac], 2, 2, 2)) |
| 315 | + end |
| 316 | + |
| 317 | + @testset "append!" begin |
| 318 | + b = adapt(CuArray, CircularArrayBuffer{Int}(2, 3)) |
| 319 | + append!(b, CUDA.zeros(2)) |
| 320 | + append!(b, 1:4) |
| 321 | + @test CUDA.@allowscalar b == cu([ |
| 322 | + 0 1 3 |
| 323 | + 0 2 4 |
| 324 | + ]) |
| 325 | + |
| 326 | + |
| 327 | + b = adapt(CuArray, CircularArrayBuffer{Int}(2, 3)) |
| 328 | + for i in 1:5 |
| 329 | + push!(b, CUDA.fill(i, 2)) |
| 330 | + end |
| 331 | + empty!(b) |
| 332 | + append!(b, 1:4) |
| 333 | + @test CUDA.@allowscalar b == cu([ |
| 334 | + 1 3 |
| 335 | + 2 4 |
| 336 | + ]) |
| 337 | + |
| 338 | + append!(b, 5:8) |
| 339 | + @test CUDA.@allowscalar b == cu([ |
| 340 | + 3 5 7 |
| 341 | + 4 6 8 |
| 342 | + ]) |
| 343 | + end |
| 344 | + end |
| 345 | +end |
0 commit comments