diff --git a/src/eachobs.jl b/src/eachobs.jl index 99c96c4..4932df7 100644 --- a/src/eachobs.jl +++ b/src/eachobs.jl @@ -56,7 +56,8 @@ The original data is preserved in the `data` field of the DataLoader. - `data`: The data to be iterated over. The data type has to be supported by [`numobs`](@ref) and [`getobs`](@ref). - `batchsize`: If less than 0, iterates over individual observations. - Otherwise, each iteration (except possibly the last) yields a mini-batch + If 0, then one mini-batch containing all `numobs(x)` observations. + If larger than 0, each iteration (except possibly the last) yields a mini-batch containing `batchsize` observations. Default `1`. - `buffer`: If `buffer=true` and supported by the type of `data`, a buffer will be allocated and reused for memory efficiency. @@ -149,6 +150,7 @@ function DataLoader( if !(collate ∈ (Val(nothing), Val(true), Val(false))) throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`.")) end + batchsize = batchsize == 0 ? numobs(data) : batchsize return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng) end