Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap BNNS sublibrary with Clang and integrate with Random #76

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@ uuid = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
version = "0.4.1"

[deps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[extensions]
RandomExt = "Random"

[compat]
BFloat16s = "0.5.0"
CEnum = "0.5.0"
julia = "1.9"

[extras]
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -18,7 +29,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
186 changes: 186 additions & 0 deletions ext/RandomExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
module RandomExt

@static if Sys.isapple()

using BFloat16s
using AppleAccelerate: BNNS
using .BNNS: BNNSFilterParameters,
BNNSRandomGeneratorMethodAES_CTR,
BNNSCreateRandomGenerator,
BNNSCreateRandomGeneratorWithSeed,
BNNSRandomGeneratorStateSize,
BNNSRandomGeneratorSetState,
BNNSRandomGeneratorGetState,
BNNSNDArrayDescriptor,
BNNSRandomFillNormalFloat,
BNNSRandomFillUniformFloat,
BNNSRandomFillUniformInt
using Random: Random, AbstractRNG

"""
RNG()

A random number generator using AppleAccelerate's BNNS functionality.
"""
mutable struct RNG <: AbstractRNG
ptr::Ptr{Nothing}
function RNG(filter_parameters::Union{Nothing, BNNSFilterParameters}=nothing)
params = isnothing(filter_parameters) ? Ptr{BNNSFilterParameters}(0) : [filter_parameters]
res = new(BNNSCreateRandomGenerator(BNNSRandomGeneratorMethodAES_CTR, params))
# finalizer(res) do
# BNNSDestroyRandomGenerator(res.ptr)
# end
return res
end
function RNG(seed::Integer, filter_parameters::Union{Nothing, BNNSFilterParameters}=nothing)
seed = seed%UInt64
params = isnothing(filter_parameters) ? Ptr{BNNSFilterParameters}(0) : [filter_parameters]
res = new(BNNSCreateRandomGeneratorWithSeed(BNNSRandomGeneratorMethodAES_CTR, seed, params))
# finalizer(res) do
# BNNSDestroyRandomGenerator(res.ptr)
# end
return res
end
end

BNNS.bnns_rng() = RNG()
BNNS.bnns_rng(seed::Integer) = RNG(seed)

@static if isdefined(Base, :Memory) #VERSION >= v"1.11"
function _get_rng_state(rng::RNG)
stateSize = BNNSRandomGeneratorStateSize(rng.ptr)
state = Memory{UInt8}(undef, Int64(stateSize))
BNNSRandomGeneratorGetState(rng.ptr, stateSize, state)
return state

Check warning on line 54 in ext/RandomExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RandomExt.jl#L50-L54

Added lines #L50 - L54 were not covered by tests
end
else
function _get_rng_state(rng::RNG)
stateSize = BNNSRandomGeneratorStateSize(rng.ptr)
state = Vector{UInt8}(undef, Int64(stateSize))
BNNSRandomGeneratorGetState(rng.ptr, stateSize, state)
return state
end
end

function Base.copy!(dest::RNG, src::RNG)
state = _get_rng_state(src)
BNNSRandomGeneratorSetState(dest.ptr, length(state), state)
return dest
end

function Base.copy(rng::RNG)
newrng = RNG()
return copy!(newrng, rng)

Check warning on line 73 in ext/RandomExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RandomExt.jl#L71-L73

Added lines #L71 - L73 were not covered by tests
end

Base.:(==)(rng1::RNG, rng2::RNG) = _get_rng_state(rng1) == _get_rng_state(rng2)

Check warning on line 76 in ext/RandomExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RandomExt.jl#L76

Added line #L76 was not covered by tests

function Random.seed!(rng::RNG, seed::Integer)
return copy!(rng, RNG(seed))
end

function Random.seed!(rng::RNG)
return copy!(rng, RNG())

Check warning on line 83 in ext/RandomExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RandomExt.jl#L82-L83

Added lines #L82 - L83 were not covered by tests
end

const GLOBAL_RNG = Ref{RNG}()
function BNNS.default_rng()
if !isassigned(GLOBAL_RNG)
GLOBAL_RNG[] = BNNS.bnns_rng()
end
return GLOBAL_RNG[]
end

const BNNSInt = Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}
const BNNSFloat = Union{Float16, Float32, BFloat16}

const BNNSUniform = Union{<:BNNSInt,<:BNNSFloat}
const BNNSNormal = BNNSFloat

function Random.rand!(rng::RNG, A::DenseArray{T}) where {T<:BNNSInt}
isempty(A) && return A
desc = Ref(BNNSNDArrayDescriptor(A))
res = BNNSRandomFillUniformInt(rng.ptr, desc, typemin(signed(T)), typemax(signed(T)))
@assert res == 0
return A
end
function Random.rand!(rng::RNG, A::DenseArray{T}) where {T<:BNNSFloat}
isempty(A) && return A
desc = Ref(BNNSNDArrayDescriptor(A))
res = BNNSRandomFillUniformFloat(rng.ptr, desc, T(0), T(1))
@assert res == 0
return A
end
function Random.randn!(rng::RNG, A::DenseArray{T}) where {T<:BNNSFloat}
isempty(A) && return A
desc = Ref(BNNSNDArrayDescriptor(A))
res = BNNSRandomFillNormalFloat(rng.ptr, desc, Float32(0), Float32(1))
@assert res == 0
return A
end

# Out of place
Random.rand(rng::RNG, ::Type{T}, dims::Dims) where T <: BNNSUniform =
Random.rand!(rng, Array{T,length(dims)}(undef, dims...))
Random.randn(rng::RNG, ::Type{T}, dims::Dims) where T <: BNNSNormal =

Check warning on line 125 in ext/RandomExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RandomExt.jl#L125

Added line #L125 was not covered by tests
Random.randn!(rng, Array{T,length(dims)}(undef, dims...))

# support all dimension specifications
Random.rand(rng::RNG, ::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSUniform =
Random.rand!(rng, Array{T,length(dims) + 1}(undef, dim1, dims...))
Random.randn(rng::RNG, ::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSNormal =
Random.randn!(rng, Array{T,length(dims) + 1}(undef, dim1, dims...))

# untyped out-of-place
Random.rand(rng::RNG, dim1::Integer, dims::Integer...) =
Random.rand!(rng, Array{Float32,length(dims) + 1}(undef, dim1, dims...))
Random.randn(rng::RNG, dim1::Integer, dims::Integer...) =
Random.randn!(rng, Array{Float32,length(dims) + 1}(undef, dim1, dims...))

# scalars
Random.rand(rng::RNG, T::Union{Type{Float16}, Type{Float32}, Type{BFloat16},
Type{Int8}, Type{UInt8},
Type{Int16}, Type{UInt16},
Type{Int32}, Type{UInt32},
Type{Int64}, Type{UInt64}}=Float32) = Random.rand(rng, T, 1)[1]

# This is the only way I could fix method ambiguity
Random.randn(rng::RNG, T::Type{BFloat16}) = Random.randn(rng, T, 1)[1]
Random.randn(rng::RNG, T::Type{Float16}) = Random.randn(rng, T, 1)[1]

Check warning on line 149 in ext/RandomExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RandomExt.jl#L148-L149

Added lines #L148 - L149 were not covered by tests
Random.randn(rng::RNG, T::Type{Float32}) = Random.randn(rng, T, 1)[1]
Random.randn(rng::RNG) = Random.randn(rng, Float32)


# GPUArrays out-of-place
function BNNS.rand(::Type{T}, dims::Dims) where T <: BNNSUniform
return Random.rand!(BNNS.default_rng(), Array{T,length(dims)}(undef, dims...))
end
function BNNS.randn(::Type{T}, dims::Dims) where T <: BNNSNormal
return Random.randn!(BNNS.default_rng(), Array{T,length(dims)}(undef, dims...))
end

# support all dimension specifications
function BNNS.rand(::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSUniform
return Random.rand!(BNNS.default_rng(), Array{T,length(dims) + 1}(undef, dim1, dims...))
end
function BNNS.randn(::Type{T}, dim1::Integer, dims::Integer...) where T <: BNNSNormal
return Random.randn!(BNNS.default_rng(), Array{T,length(dims) + 1}(undef, dim1, dims...))
end

# untyped out-of-place
BNNS.rand(dim1::Integer, dims::Integer...) =
Random.rand!(BNNS.default_rng(), Array{Float32,length(dims) + 1}(undef, dim1, dims...))
BNNS.randn(dim1::Integer, dims::Integer...) =
Random.randn!(BNNS.default_rng(), Array{Float32,length(dims) + 1}(undef, dim1, dims...))

# scalars
BNNS.rand(T::Type=Float32) = BNNS.rand(T, 1)[1]
BNNS.randn(T::Type=Float32) = BNNS.randn(T, 1)[1]

# seeding
function BNNS.seed!(seed=Base.rand(UInt64))
Random.seed!(BNNS.default_rng(), seed)

Check warning on line 182 in ext/RandomExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RandomExt.jl#L181-L182

Added lines #L181 - L182 were not covered by tests
end

end
end # module
42 changes: 42 additions & 0 deletions lib/BNNS/BNNS.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using BFloat16s

include("libBNNS.jl")

bnnsdatatype_modifier(::Type{T}) where {T <: Union{AbstractFloat, Bool}} = BNNSDataTypeFloatBit
bnnsdatatype_modifier(::Type{T}) where {T <: Signed} = BNNSDataTypeIntBit
bnnsdatatype_modifier(::Type{T}) where {T <: Unsigned} = BNNSDataTypeUIntBit
bnnsdatatype_modifier(::Type{Bool}) = BNNSDataTypeMiscellaneousBit
bnnsdatatype_modifier(::Type{BFloat16}) = 0x18000

Base.convert(::Type{BNNSDataType}, T) = BNNSDataType(bnnsdatatype_modifier(T) | UInt32(sizeof(T)*8))

function BNNSNDArrayDescriptor(arr::AbstractArray{T, N}) where {T,N}
N > 8 && throw(ArgumentError("BNNSNDArrays do not support more than 8 dimensions."))


layout = BNNSDataLayout(UInt32(N) * UInt32(BNNSDataLayoutVector) | 0x8000)
# layout = datalayout[N]
sz = ntuple(Val(8)) do i
Csize_t(get(size(arr), i, 0))
end
stride = ntuple(_ -> Csize_t(0), Val(8))
return GC.@preserve arr BNNSNDArrayDescriptor(BNNSNDArrayFlagBackpropSet,
layout,
sz,
stride,
Ptr{Nothing}(pointer(arr)),
T,
0,
T,
1,
0)
end

# Definitions for the Random extension
function bnns_rng end
function default_rng end
function rand end
function randn end
function rand! end
function randn! end
function seed! end
Loading
Loading