Skip to content

feat: add leaf macro to avoid tracing into types #854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/api/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ Reactant.Profiler.@annotate
Reactant.devices
Reactant.addressable_devices
```

## Tracing

```@docs
Reactant.@leaf
```
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Random: Random, AbstractRNG
using Functors: @leaf
using Functors: Functors

using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`
Expand Down
89 changes: 65 additions & 24 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,61 @@
function traced_type_inner end
function make_tracer end

"""
@leaf type [make_tracer = true]

This marks a type as a leaf type for the purposes of tracing in reactant. This means that
we won't recurse into the type and it will be left untouched.
"""
macro leaf(args...)
@assert length(args) ≥ 1
orig_type, args = args[1], args[2:end]

options = Dict{Symbol,Any}()
while length(args) ≥ 1
if !Meta.isexpr(args[1], :(=))
error("Invalid argument $(args[1])")
end
options[args[1].args[1]] = args[1].args[2]
args = args[2:end]
end

subtype = Meta.isexpr(orig_type, :(<:))
type = subtype ? orig_type.args[1] : orig_type

traced_type_inner_expr = quote
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(T::Type{$(orig_type)}),
seen,
@nospecialize(mode::$(TraceMode)),
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime),
)
return T
end
end

make_tracer_expr = if get(options, :make_tracer, true)
quote
function Reactant.make_tracer(
seen, @nospecialize(prev::$(type)), @nospecialize(path), mode; kwargs...
)
return prev
end
end
else
:()
end

return esc(
quote
$traced_type_inner_expr
$make_tracer_expr
end,
)
end

@enum TraceMode begin
ConcreteToTraced = 1
TracedTrack = 2
Expand All @@ -14,35 +72,27 @@ end

function traced_type_inner end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{Union{}}), @nospecialize(args...)
)
return T
for T in (Symbol, Union{})
@eval begin
@leaf $T make_tracer = false
end
end

for T in (
DataType,
Module,
Nothing,
Symbol,
AbstractChar,
AbstractString,
AbstractFloat,
Integer,
RNumber,
Val,
VersionNumber,
Base.ExceptionStack,
Core.MethodInstance,
)
@eval Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:$T}),
seen,
@nospecialize(mode::TraceMode),
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime)
)
return T
end
@eval @leaf <:$T
end

Base.@nospecializeinfer function traced_type_inner(
Expand Down Expand Up @@ -907,15 +957,6 @@ function Base.showerror(io::IO, err::NoFieldMatchError)
end
end

function make_tracer(
seen,
@nospecialize(prev::Union{Base.ExceptionStack,Core.MethodInstance}),
@nospecialize(path),
mode;
kwargs...,
)
return prev
end
append_path(@nospecialize(path), i) = (path..., i)

function make_tracer(
Expand Down
10 changes: 5 additions & 5 deletions src/Types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ abstract type AbstractConcreteArray{T,N} <: RArray{T,N} end
# Traced Types

## MissingTracedValue -- defined in ReactantCore
@leaf MissingTracedValue
Functors.@leaf MissingTracedValue

## TracedRNumber
mutable struct TracedRNumber{T} <: RNumber{T}
Expand All @@ -26,7 +26,7 @@ mutable struct TracedRNumber{T} <: RNumber{T}
end
end

@leaf TracedRNumber
Functors.@leaf TracedRNumber

## TracedRArray
mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
Expand All @@ -45,7 +45,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
end
end

@leaf TracedRArray
Functors.@leaf TracedRArray
Adapt.parent_type(::Type{TracedRArray{T,N}}) where {T,N} = TracedRArray{T,N}

const WrappedTracedRArray{T,N} = WrappedArray{
Expand Down Expand Up @@ -79,7 +79,7 @@ function ConcretePJRTNumber{T}(data::Tuple{XLA.PJRT.AsyncBuffer}) where {T}
return ConcretePJRTNumber{T,1,Sharding.NoShardInfo}(data, Sharding.NoShardInfo())
end

@leaf ConcretePJRTNumber
Functors.@leaf ConcretePJRTNumber

function ConcretePJRTNumber{T}(data::T2; kwargs...) where {T<:Number,T2<:Number}
carray = ConcretePJRTArray(fill(convert(T, data)); kwargs...)
Expand Down Expand Up @@ -115,7 +115,7 @@ mutable struct ConcretePJRTArray{T,N,D,S<:Sharding.ShardInfo} <: AbstractConcret
sharding::S
end

@leaf ConcretePJRTArray
Functors.@leaf ConcretePJRTArray
Adapt.parent_type(::Type{<:ConcretePJRTArray{T,N}}) where {T,N} = ConcretePJRTArray{T,N}
function Adapt.parent_type(::Type{ConcretePJRTArray{T,N,D,S}}) where {T,N,D,S}
return ConcretePJRTArray{T,N,D,S}
Expand Down
Loading