diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 2d47664870..437ca452ef 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -129,19 +129,23 @@ end ``` """ macro trace(args...) - track_numbers = true expr = first(args) - if length(args) > 1 && Meta.isexpr(args[1], :(=)) - tn_expr = args[1] - tn_expr.args[1] == :track_numbers || - error("@trace supports setting track_numbers, but got $(tn_expr)") - - track_numbers = tn_expr.args[2] - expr = only(args[2:end]) - else - expr = only(args) + options = Dict([:track_numbers => false, :include_paths => :([])]) + while length(args) > 1 + kwarg, args = first(args), args[2:end] + if !Meta.isexpr(kwarg, :(=)) + error("Expected keyword argument but got $(kwarg)") + end + option, value = kwarg.args + if !haskey(options, option) + error("Unknown keyword argument $(option), expected one of $(keys(options))") + else + options[option] = value + end end - track_numbers = track_numbers ? Number : Union{} + expr = only(args) + track_numbers = options[:track_numbers] ? Number : Union{} + include_paths_expr = options[:include_paths] expr = macroexpand(__module__, expr) if Meta.isexpr(expr, :(=)) @@ -157,11 +161,12 @@ macro trace(args...) return esc(trace_call(__module__, call)) end Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr; track_numbers)) - Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr; track_numbers))) + Meta.isexpr(expr, :for) && + return (esc(trace_for(__module__, expr; track_numbers, include_paths_expr))) return error("Only `if-elseif-else` blocks are currently supported by `@trace`") end -function trace_for(mod, expr; track_numbers) +function trace_for(mod, expr; track_numbers, include_paths_expr) Meta.isexpr(expr, :for, 2) || error("expected for expr") assign, body = expr.args @@ -216,6 +221,8 @@ function trace_for(mod, expr; track_numbers) ) for (s, ref) in zip(external_syms, ref_syms) ] + include_paths = gensym(:include_paths) + reactant_code_block = quote let args = $(args_init) cond_fn = @@ -238,13 +245,14 @@ function trace_for(mod, expr; track_numbers) $counter[].mlir_data = ($counter[] + 1).mlir_data nothing end - + $(include_paths) = $(include_paths_expr) $(ReactantCore).traced_while( cond_fn, body_fn, args; track_numbers=$(track_numbers), verify_arg_names=$(QuoteNode(args_names)), + include_paths=$(include_paths), ) end end diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 560210a227..b5c654ccda 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -9,7 +9,14 @@ function ReactantCore.traced_call(f::Function, args...) end function ReactantCore.traced_while( - cond_fn::CFn, body_fn::BFn, args; track_numbers=Number, verify_arg_names=nothing + cond_fn::CFn, + body_fn::BFn, + args; + track_numbers=Number, + verify_arg_names=nothing, + include_paths=[], ) where {CFn,BFn} - return Ops.while_loop(cond_fn, body_fn, args...; track_numbers, verify_arg_names) + return Ops.while_loop( + cond_fn, body_fn, args...; track_numbers, verify_arg_names, include_paths + ) end diff --git a/src/Ops.jl b/src/Ops.jl index 5cd4171aeb..34467b2d5f 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1722,7 +1722,12 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead. end @noinline function while_loop( - cond_fn::CFn, body_fn::BFn, args...; track_numbers, verify_arg_names=nothing + cond_fn::CFn, + body_fn::BFn, + args...; + track_numbers, + verify_arg_names=nothing, + include_paths=[], ) where {CFn,BFn} # TODO: detect and prevent mutation within the condition @@ -1733,7 +1738,7 @@ end for (i, prev) in enumerate(args) @inbounds traced_args[i] = Reactant.make_tracer( - seen_args, prev, (), Reactant.NoStopTracedTrack; track_numbers + seen_args, prev, (), Reactant.NoStopTracedTrack; track_numbers, include_paths ) end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 31f466cd1a..cd7d394773 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -200,6 +200,7 @@ function make_mlir_fn( argprefix::Symbol=:args, resprefix::Symbol=:result, resargprefix::Symbol=:resargs, + include_paths=[], num_replicas=1, ) if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction @@ -238,7 +239,7 @@ function make_mlir_fn( end for i in 1:N @inbounds traced_args[i] = Reactant.make_tracer( - seen_args, args[i], (argprefix, i), inmode; toscalar, runtime + seen_args, args[i], (argprefix, i), inmode; toscalar, runtime, include_paths ) end @@ -376,6 +377,7 @@ function make_mlir_fn( (resargprefix, i), Reactant.NoStopTracedTrack; runtime, + include_paths=[], ) end traced_result diff --git a/src/Tracing.jl b/src/Tracing.jl index dc3965b40c..0f4e482b9c 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -14,6 +14,9 @@ end function traced_type_inner end +# temporary fallback to old behavior +traced_type_inner(args...; include_paths, kwargs...) = traced_type_inner(args...; kwargs...) + Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type{Union{}}), @nospecialize(args...) ) @@ -40,7 +43,8 @@ for T in ( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) return T end @@ -52,9 +56,14 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) - if mode == ArrayToConcrete && T <: track_numbers + should_track = T <: track_numbers + if !isempty(include_paths) + @assert all(isempty, include_paths) "Expected no include path to point inside a $T." + end + if mode == ArrayToConcrete && should_track if runtime isa Val{:PJRT} return ConcretePJRTNumber{ T,Sharding.ndevices(sharding),Sharding.shard_type(typeof(sharding), 0) @@ -65,24 +74,53 @@ Base.@nospecializeinfer function traced_type_inner( error("Unsupported runtime $runtime") end elseif (mode == NoStopTracedTrack || mode == TracedTrack || mode == TracedSetPath) && - T <: track_numbers + should_track return TracedRNumber{T} end return T end +function path_subtract(paths, val) + paths = filter(paths) do path + length(path) > 1 && first(path) == val + end + return map(paths) do path + return path[2:end] + end +end + Base.@nospecializeinfer function traced_type_inner( @nospecialize(C::Type{<:Complex}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) - C isa UnionAll || return Complex{ - traced_type_inner(C.parameters[1], seen, mode, track_numbers, sharding, runtime) - } - return C + C isa UnionAll && return C + re_T = traced_type_inner( + C.parameters[1], + seen, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, 1), + ) + im_T = traced_type_inner( + C.parameters[1], + seen, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, 2), + ) + if re_T != im_T + throw(NoFieldMatchError(C, C, (re_T, im_T))) + end + return Complex{re_T} end Base.@nospecializeinfer function traced_type_inner( @@ -91,7 +129,8 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) # functions are directly returned if T === Function || sizeof(T) == 0 @@ -104,7 +143,13 @@ Base.@nospecializeinfer function traced_type_inner( traced_fieldtypes = Type[] for i in 1:N next = traced_type_inner( - fieldtype(T, i), seen, mode, track_numbers, getproperty(sharding, i), runtime + fieldtype(T, i), + seen, + mode, + track_numbers, + getproperty(sharding, i), + runtime, + path_subtract(include_paths, i), ) changed |= next != fieldtype(T, i) push!(traced_fieldtypes, next) @@ -124,7 +169,8 @@ Base.@nospecializeinfer function traced_tuple_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) if T === Tuple return T @@ -133,7 +179,15 @@ Base.@nospecializeinfer function traced_tuple_type_inner( if T.var.lb === Union{} && T.var.ub === Any return UnionAll( T.var, - traced_type_inner(T.body, seen, mode, track_numbers, sharding, runtime), + traced_type_inner( + T.body, + seen, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, 1), + ), ) end throw(AssertionError("Type $T is not concrete type or concrete tuple")) @@ -141,7 +195,13 @@ Base.@nospecializeinfer function traced_tuple_type_inner( TT = Union{Type,Core.TypeofVararg}[] for i in 1:length(T.parameters) st = traced_type_inner( - T.parameters[i], seen, mode, track_numbers, sharding, runtime + T.parameters[i], + seen, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, i), ) push!(TT, st) end @@ -154,9 +214,14 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) - return Vararg{traced_type_inner(T.T, seen, mode, track_numbers, sharding, runtime),T.N} + @assert isempty(include_paths) "TODO: handle this." + return Vararg{ + traced_type_inner(T.T, seen, mode, track_numbers, sharding, runtime, include_paths), + T.N, + } end Base.@nospecializeinfer function traced_type_inner( @@ -165,8 +230,10 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert isempty(include_paths) "Expected no include path to point to a TypeVar." if T.lb === Union{} && T.ub === Any return T end @@ -179,11 +246,14 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) N = T.parameters[1] V = T.parameters[2] - return NamedTuple{N,traced_type_inner(V, seen, mode, track_numbers, sharding, runtime)} + return NamedTuple{ + N,traced_type_inner(V, seen, mode, track_numbers, sharding, runtime, include_paths) + } end Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict}) = nothing @@ -209,14 +279,16 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert isempty(include_paths) "TODO: handle dictionaries" V = dict_value(T) if V === nothing return T else K = dict_key(T) - V2 = traced_type_inner(V, seen, mode, track_numbers, sharding, runtime) + V2 = traced_type_inner(V, seen, mode, track_numbers, sharding, runtime, []) if V == V2 return T end @@ -239,8 +311,10 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert all(isempty, include_paths) "Expected no include path to point inside a ConcretePJRTNumber." if T0 isa UnionAll T = T0.body isa UnionAll ? T0.body.body.parameters[1] : T0.body.parameters[1] else @@ -267,8 +341,10 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert all(isempty, include_paths) "Expected no include path to point inside a ConcretePJRTNumber." T = T0 isa UnionAll ? T0.body.parameters[1] : T0.parameters[1] if mode == ConcreteToTraced @@ -289,8 +365,10 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert all(isempty, include_paths) "Expected no include path to point inside a ConcretePJRTArray." if T isa UnionAll if T.body isa UnionAll elT, N = T.body.body.parameters[1], T.body.body.parameters[2] @@ -321,8 +399,10 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert all(isempty, include_paths) "Expected no include path to point inside a ConcreteIFRTArray." if T isa UnionAll elT, N = T.body.parameters[1], T.body.parameters[2] else @@ -347,8 +427,10 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert all(isempty, include_paths) "Expected no include path to point inside a ConcreteRNG." if mode == ConcreteToTraced return TracedRNG elseif mode == TracedToConcrete @@ -370,8 +452,10 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert all(isempty, include_paths) "Expected no include path to point inside a TracedRArray." if mode == ConcreteToTraced throw("TracedRArray cannot be traced") elseif mode == TracedToConcrete @@ -403,8 +487,10 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert all(isempty, include_paths) "Expected no include path to point inside a TracedRNumber." if mode == ConcreteToTraced throw("TracedRNumber cannot be traced") elseif mode == TracedToConcrete @@ -449,14 +535,21 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) if mode == ConcreteToTraced throw("TracedRNG cannot be traced") elseif mode == TracedToConcrete return ConcreteRNG{ traced_type_inner( - TracedRArray{UInt64,1}, seen, mode, track_numbers, sharding, runtime + TracedRArray{UInt64,1}, + seen, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, 1), ), } elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath @@ -472,7 +565,8 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) return A end @@ -483,11 +577,13 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) where {T} + @assert isempty(include_paths) "TODO: handle include_paths for AbstractArray{T}" if mode == ConcreteToTraced return AbstractArray{ - traced_type_inner(T, seen, mode, track_numbers, sharding, runtime) + traced_type_inner(T, seen, mode, track_numbers, sharding, runtime, []) } else return A @@ -500,11 +596,13 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) where {T,N} + @assert isempty(include_paths) "TODO: handle include_paths for AbstractArray{T}" if mode == ConcreteToTraced return AbstractArray{ - traced_type_inner(T, seen, mode, track_numbers, sharding, runtime),N + traced_type_inner(T, seen, mode, track_numbers, sharding, runtime, []),N } else return A @@ -517,8 +615,10 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) + @assert isempty(include_paths) "TODO: handle include_paths for Array" T = eltype(A) if A isa UnionAll if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive @@ -528,7 +628,7 @@ Base.@nospecializeinfer function traced_type_inner( else return Array{ traced_type_inner( - T, seen, mode, track_numbers, getproperty(sharding, 1), runtime + T, seen, mode, track_numbers, getproperty(sharding, 1), runtime, [] ), } end @@ -544,7 +644,7 @@ Base.@nospecializeinfer function traced_type_inner( else return Array{ traced_type_inner( - T, seen, mode, track_numbers, getproperty(sharding, 1), runtime + T, seen, mode, track_numbers, getproperty(sharding, 1), runtime, [] ), N, } @@ -558,10 +658,12 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( mode::Reactant.TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) where {T,N,P,I,L} - P2 = Reactant.traced_type_inner(P, seen, mode, track_numbers, sharding, runtime) - I2 = Reactant.traced_type_inner(I, seen, mode, track_numbers, sharding, runtime) + @assert isempty(include_paths) "TODO: handle include_paths for SubArray" + P2 = Reactant.traced_type_inner(P, seen, mode, track_numbers, sharding, runtime, []) + I2 = Reactant.traced_type_inner(I, seen, mode, track_numbers, sharding, runtime, []) T2 = eltype(P2) return SubArray{T2,N,P2,I2,L} end @@ -573,7 +675,8 @@ for P in (Ptr, Core.LLVMPtr, Base.RefValue) @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) return $(P) end @@ -585,11 +688,13 @@ for P in (Ptr, Base.RefValue) @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) where {T} + @assert isempty(include_paths) "TODO: handle include_paths for $P{T}" return $P{ traced_type_inner( - PT.parameters[1], seen, mode, track_numbers, sharding, runtime + PT.parameters[1], seen, mode, track_numbers, sharding, runtime, [] ), } end @@ -601,11 +706,13 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) where {T} + @assert isempty(include_paths) "TODO: handle include_paths for Core.LLVMPtr" return Core.LLVMPtr{ traced_type_inner( - PT.body.parameters[1], seen, mode, track_numbers, sharding, runtime + PT.body.parameters[1], seen, mode, track_numbers, sharding, runtime, [] ), } end @@ -615,10 +722,15 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) where {T,A} + @assert isempty(include_paths) "TODO: handle include_paths for Core.LLVMPtr" return Core.LLVMPtr{ - traced_type_inner(PT.parameters[1], seen, mode, track_numbers, sharding, runtime),A + traced_type_inner( + PT.parameters[1], seen, mode, track_numbers, sharding, runtime, [] + ), + A, } end @@ -628,7 +740,8 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), + @nospecialize(include_paths), ) if T === Any return T @@ -647,15 +760,18 @@ Base.@nospecializeinfer function traced_type_inner( end if T <: Tuple - return traced_tuple_type_inner(T, seen, mode, track_numbers, sharding, runtime) + return traced_tuple_type_inner( + T, seen, mode, track_numbers, sharding, runtime, include_paths + ) end # unknown number of fields if Base.inferencebarrier(T) isa UnionAll if T.var.lb === Union{} && T.var.ub === Any || T <: Type + @assert isempty(include_paths) "TODO: handle include_paths for UnionAll type" return UnionAll( T.var, - traced_type_inner(T.body, seen, mode, track_numbers, sharding, runtime), + traced_type_inner(T.body, seen, mode, track_numbers, sharding, runtime, []), ) end aT = Base.argument_datatype(T) @@ -669,9 +785,10 @@ Base.@nospecializeinfer function traced_type_inner( end if T isa Union + @assert isempty(include_paths) "TODO: handle include_paths for Union type" return Union{ - traced_type_inner(T.a, seen, mode, track_numbers, sharding, runtime), - traced_type_inner(T.b, seen, mode, track_numbers, sharding, runtime), + traced_type_inner(T.a, seen, mode, track_numbers, sharding, runtime, []), + traced_type_inner(T.b, seen, mode, track_numbers, sharding, runtime, []), } end @@ -683,10 +800,6 @@ Base.@nospecializeinfer function traced_type_inner( throw(TracedTypeError("Unhandled abstract type $T")) end - if T <: Tuple - return traced_tuple_type_inner(T, seen, mode, track_numbers, sharding, runtime) - end - if haskey(seen, T) return seen[T] end @@ -698,7 +811,15 @@ Base.@nospecializeinfer function traced_type_inner( subTys = Union{Type,TypeVar}[] for f in 1:fieldcount(T) subT = fieldtype(T, f) - subTT = traced_type_inner(subT, seen2, mode, track_numbers, sharding, runtime) + subTT = traced_type_inner( + subT, + seen2, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, f), + ) changed |= subT != subTT push!(subTys, subTT) end @@ -718,6 +839,7 @@ Base.@nospecializeinfer function traced_type_inner( for (i, SST) in enumerate(T.parameters) if wrapped_cpjrt_array && i == 1 && SST isa Type && SST <: ReactantPrimitive # XXX: Sharding??? + # TODO: what should happen with include_paths here? TrT = traced_type_inner( ConcretePJRTNumber{ SST,Sharding.ndevices(sharding),Sharding.shard_type(typeof(sharding), 0) @@ -727,6 +849,7 @@ Base.@nospecializeinfer function traced_type_inner( track_numbers, sharding, runtime, + [], ) push!(subParms, TrT) elseif wrapped_cifrt_array && i == 1 && SST isa Type && SST <: ReactantPrimitive @@ -738,16 +861,19 @@ Base.@nospecializeinfer function traced_type_inner( track_numbers, sharding, runtime, + [], ) push!(subParms, TrT) elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber TrT = traced_type_inner( - unwrapped_eltype(SST), seen, mode, track_numbers, sharding, runtime + unwrapped_eltype(SST), seen, mode, track_numbers, sharding, runtime, [] ) push!(subParms, TrT) else if SST isa Type - TrT = traced_type_inner(SST, seen, mode, track_numbers, sharding, runtime) + TrT = traced_type_inner( + SST, seen, mode, track_numbers, sharding, runtime, [] + ) push!(subParms, TrT) else push!(subParms, SST) @@ -767,7 +893,15 @@ Base.@nospecializeinfer function traced_type_inner( for f in 1:fieldcount(T) subT = fieldtype(T, f) subT2 = fieldtype(TT2, f) - subTT = traced_type_inner(subT, seen3, mode, track_numbers, sharding, runtime) + subTT = traced_type_inner( + subT, + seen3, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, f), + ) if subT2 != subTT legal = false break @@ -878,7 +1012,7 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}() # end Base.@assume_effects :total @inline function traced_type( - T::Type, ::Val{mode}, track_numbers::Type, sharding, runtime + T::Type, ::Val{mode}, track_numbers::Type, sharding, runtime, include_paths=[] ) where {mode} if mode == TracedSetPath || mode == TracedTrack return T @@ -886,13 +1020,15 @@ Base.@assume_effects :total @inline function traced_type( cache = nothing cache_key = (mode, track_numbers, sharding) - if haskey(traced_type_cache, cache_key) + if false && haskey(traced_type_cache, cache_key) cache = traced_type_cache[cache_key] else cache = Dict{Type,Type}() traced_type_cache[cache_key] = cache end - return traced_type_inner(T, cache, mode, track_numbers, sharding, runtime) + return traced_type_inner( + T, cache, mode, track_numbers, sharding, runtime, include_paths + ) end abstract type TracedTypeException <: Exception end @@ -955,6 +1091,7 @@ Base.@nospecializeinfer function make_tracer_via_immutable_constructor( @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), + @nospecialize(include_paths = []), kwargs..., ) RT = Core.Typeof(prev) @@ -970,7 +1107,7 @@ Base.@nospecializeinfer function make_tracer_via_immutable_constructor( push!(path, RT) seen[prev] = VisitedObject(length(seen) + 1) end - TT = traced_type(RT, Val(mode), track_numbers, sharding, runtime) + TT = traced_type(RT, Val(mode), track_numbers, sharding, runtime, include_paths) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) @@ -999,6 +1136,7 @@ Base.@nospecializeinfer function make_tracer_via_immutable_constructor( track_numbers, sharding=Base.getproperty(sharding, i), runtime, + include_paths=path_subtract(include_paths, i), kwargs..., ) if xi !== xi2 @@ -1030,6 +1168,7 @@ Base.@nospecializeinfer function make_tracer_unknown( @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), + @nospecialize(include_paths = []), kwargs..., ) RT = Core.Typeof(prev) @@ -1045,7 +1184,7 @@ Base.@nospecializeinfer function make_tracer_unknown( push!(path, RT) seen[prev] = VisitedObject(length(seen) + 1) end - TT = traced_type(RT, Val(mode), track_numbers, sharding, runtime) + TT = traced_type(RT, Val(mode), track_numbers, sharding, runtime, include_paths) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) @@ -1074,6 +1213,7 @@ Base.@nospecializeinfer function make_tracer_unknown( track_numbers, sharding=Base.getproperty(sharding, i), runtime, + include_paths=path_subtract(include_paths, i), kwargs..., ) if xi !== xi2 @@ -1111,6 +1251,7 @@ Base.@nospecializeinfer function make_tracer_unknown( track_numbers, sharding=getproperty(sharding, i), runtime, + include_paths=path_subtract(include_paths, i), kwargs..., ) if xi !== xi2 @@ -1149,10 +1290,11 @@ function make_tracer( @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), + @nospecialize(include_paths = []), kwargs..., ) return make_tracer_unknown( - seen, prev, path, mode; track_numbers, sharding, runtime, kwargs... + seen, prev, path, mode; track_numbers, sharding, runtime, include_paths, kwargs... ) end @@ -1433,6 +1575,7 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), + @nospecialize(include_paths = []), kwargs..., ) if mode == TracedToTypes @@ -1440,7 +1583,12 @@ Base.@nospecializeinfer function make_tracer( return nothing end RT = Core.Typeof(prev) - if RT <: track_numbers && mode != TracedSetPath && mode != TracedTrack + should_track = RT <: track_numbers + if !isempty(include_paths) + @assert all(isempty, include_paths) "include path cannot point into a number" + should_track = true + end + if should_track && mode != TracedSetPath && mode != TracedTrack if mode == ArrayToConcrete runtime isa Val{:PJRT} && return ConcretePJRTNumber(prev; sharding) runtime isa Val{:IFRT} && return ConcreteIFRTNumber(prev; sharding) @@ -1496,18 +1644,47 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(path), mode; @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(include_paths = []), kwargs..., ) Sharding.is_sharded(sharding) && error("Cannot specify sharding for Complex") if mode == TracedToTypes push!(path, Core.Typeof(prev)) - make_tracer(seen, prev.re, path, mode; kwargs...) - make_tracer(seen, prev.im, path, mode; kwargs...) + make_tracer( + seen, + prev.re, + path, + mode; + include_paths=path_subtract(include_paths, 1), + kwargs..., + ) + make_tracer( + seen, + prev.im, + path, + mode; + include_paths=path_subtract(include_paths, 1), + kwargs..., + ) return nothing end return Complex( - make_tracer(seen, prev.re, append_path(path, :re), mode; kwargs...), - make_tracer(seen, prev.im, append_path(path, :im), mode; kwargs...), + make_tracer( + seen, + prev.re, + append_path(path, :re), + mode; + include_paths=path_subtract(include_paths, 1), + kwargs..., + ), + make_tracer( + seen, + prev.im, + append_path(path, :im), + mode; + include_paths=path_subtract(include_paths, 2), + kwargs..., + ), ) end @@ -1519,6 +1696,7 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), + @nospecialize(include_paths = []), kwargs..., ) RT = Core.Typeof(prev) @@ -1551,7 +1729,15 @@ Base.@nospecializeinfer function make_tracer( if isassigned(prev, I) pv = prev[I] make_tracer( - seen, pv, path, mode; track_numbers, sharding, runtime, kwargs... + seen, + pv, + path, + mode; + track_numbers, + sharding, + runtime, + include_paths=path_subtract(paths, I), + kwargs..., ) end end @@ -1572,11 +1758,13 @@ Base.@nospecializeinfer function make_tracer( track_numbers, sharding=Base.getproperty(sharding, I), runtime, + include_paths=path_subtract(include_paths, I), kwargs..., ) if pv !== nv same = false end + # TODO: nice error if types don't match up (if an element was part of the include path and has been traced) @inbounds newa[I] = nv end end @@ -1595,6 +1783,7 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), + @nospecialize(include_paths = []), kwargs..., ) where {Key,Value} RT = Core.Typeof(prev) @@ -1627,8 +1816,8 @@ Base.@nospecializeinfer function make_tracer( end return nothing end - Value2 = traced_type(Value, Val(mode), track_numbers, sharding, runtime) - newa = Dict{Key,Value2}() + new_DT = traced_type(RT, Val(mode), track_numbers, sharding, runtime, include_paths) + newa = new_DT() seen[prev] = newa same = true for (k, v) in prev @@ -1640,11 +1829,17 @@ Base.@nospecializeinfer function make_tracer( track_numbers, sharding=Base.getproperty(sharding, k), runtime, + include_paths=path_subtract(include_paths, k), kwargs..., ) if v !== nv same = false end + if !(nv isa dict_value(new_DT)) + error( + "Value at key $k has type $(typeof(nv)), but expected $(dict_value(new_DT))" + ) + end newa[k] = nv end if same @@ -1660,6 +1855,7 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(path), mode; @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(include_paths = []), kwargs..., ) RT = Core.Typeof(prev) @@ -1667,7 +1863,13 @@ Base.@nospecializeinfer function make_tracer( push!(path, RT) for (i, v) in enumerate(prev) make_tracer( - seen, v, path, mode; sharding=Base.getproperty(sharding, i), kwargs... + seen, + v, + path, + mode; + sharding=Base.getproperty(sharding, i), + include_paths=path_subtract(include_paths, i), + kwargs..., ) end return nothing @@ -1680,6 +1882,7 @@ Base.@nospecializeinfer function make_tracer( append_path(path, i), mode; sharding=Base.getproperty(sharding, i), + include_paths=path_subtract(include_paths, i), kwargs..., ) for (i, v) in enumerate(prev) )..., @@ -1694,6 +1897,7 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), + @nospecialize(include_paths = []), kwargs..., ) NT = Core.Typeof(prev) @@ -1704,12 +1908,21 @@ Base.@nospecializeinfer function make_tracer( push!(path, NT) for i in 1:length(A) make_tracer( - seen, Base.getfield(prev, i), path, mode; track_numbers, sharding, kwargs... + seen, + Base.getfield(prev, i), + path, + mode; + track_numbers, + include_paths=path_subtract(include_paths, i), + sharding, + kwargs..., ) end return nothing end - return NamedTuple{A,traced_type(RT, Val(mode), track_numbers, sharding, runtime)}(( + return NamedTuple{ + A,traced_type(RT, Val(mode), track_numbers, sharding, runtime, include_paths) + }(( ( make_tracer( seen, @@ -1718,6 +1931,7 @@ Base.@nospecializeinfer function make_tracer( mode; sharding=Base.getproperty(sharding, i), track_numbers, + include_paths=path_subtract(include_paths, i), runtime, kwargs..., ) for i in 1:length(A) @@ -1731,11 +1945,20 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(path), mode; @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(include_paths = []), kwargs..., ) if mode == TracedToTypes push!(path, Core.Box) - return make_tracer(seen, prev.contents, path, mode; sharding, kwargs...) + return make_tracer( + seen, + prev.contents, + path, + mode; + sharding, + include_paths=path_subtract(include_paths, :contents), + kwargs..., + ) end if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] @@ -1747,6 +1970,7 @@ Base.@nospecializeinfer function make_tracer( append_path(path, :contents), mode; sharding=Base.getproperty(sharding, :contents), + include_paths=path_subtract(include_paths, :contents), kwargs..., ) if tr === prev2 @@ -1786,10 +2010,17 @@ end @nospecialize(x), @nospecialize(track_numbers::Type), @nospecialize(sharding), - @nospecialize(runtime) + @nospecialize(runtime), ) return make_tracer( - OrderedIdDict(), x, (), ArrayToConcrete; track_numbers, sharding, runtime + OrderedIdDict(), + x, + (), + ArrayToConcrete; + track_numbers, + sharding, + runtime, + include_paths=[], ) end @@ -1904,13 +2135,34 @@ function Reactant.traced_type_inner( track_numbers::Type, sharding, runtime, + include_paths, ) - (T,) = RT.parameters - newT = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding, runtime) - if T == newT + FTs = fieldtypes(RT) + Tstart = Reactant.traced_type_inner( + FTs[1], + seen, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, :start), + ) + Tstop = Reactant.traced_type_inner( + FTs[2], + seen, + mode, + track_numbers, + sharding, + runtime, + path_subtract(include_paths, :stop), + ) + if Tstart != Tstop + throw(NoFieldMatchError(RT, RT, (Tstart, Tstop))) + end + if Tstart == first(FTs) return RT else - return TracedRNumberOverrides.TracedUnitRange{newT} + return TracedRNumberOverrides.TracedUnitRange{Tstart} end end @@ -1920,6 +2172,7 @@ function Reactant.make_tracer( @nospecialize(path), mode; @nospecialize(sharding = Sharding.NoSharding()), + include_paths, kwargs..., ) Reactant.Sharding.is_sharded(sharding) && error("Cannot specify sharding for UnitRange") @@ -1930,10 +2183,20 @@ function Reactant.make_tracer( return nothing end newstart = Reactant.make_tracer( - seen, prev.start, Reactant.append_path(path, :start), mode; kwargs... + seen, + prev.start, + Reactant.append_path(path, :start), + mode; + include_paths=path_subtract(include_paths, :start), + kwargs..., ) newstop = Reactant.make_tracer( - seen, prev.stop, Reactant.append_path(path, :stop), mode; kwargs... + seen, + prev.stop, + Reactant.append_path(path, :stop), + mode; + include_paths=path_subtract(include_paths, :stop), + kwargs..., ) if typeof(newstart) == typeof(prev.start) && typeof(newstop) == typeof(prev.stop) return prev @@ -1949,7 +2212,15 @@ function Reactant.traced_type_inner( track_numbers::Type, sharding, runtime, + include_paths, ) + @assert isempty(include_paths) "Currently cannot have include_paths pointing to a StepRangeLen because it has a type parameter that is statically derived from the fieldtypes." + # FTs = fieldtypes(RT) + # Tref = Reactant.traced_type_inner(FTs[1], seen, mode, track_numbers, sharding, runtime, path_subtrace(include_paths, :ref)) + # Tstep = Reactant.traced_type_inner(FTs[2], seen, mode, track_numbers, sharding, runtime, path_subtrace(include_paths, :step)) + # Tlen = Reactant.traced_type_inner(FTs[3], seen, mode, track_numbers, sharding, runtime, path_subtrace(include_paths, :len)) + # Toffset = Reactant.traced_type_inner(FTs[4], seen, mode, track_numbers, sharding, runtime, path_subtrace(include_paths, :offset)) + T, R, S, L = RT.parameters newT = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding, runtime) newR = Reactant.traced_type_inner(R, seen, mode, track_numbers, sharding, runtime) @@ -1968,29 +2239,55 @@ function Reactant.make_tracer( @nospecialize(path), mode; @nospecialize(sharding = Sharding.NoSharding()), + include_paths, kwargs..., ) + @assert isempty(include_paths) "Currently cannot have include_paths pointing to a StepRangeLen because it has a type parameter that is statically derived from the fieldtypes." Reactant.Sharding.is_sharded(sharding) && error("Cannot specify sharding for StepRangeLen") if mode == Reactant.TracedToTypes push!(path, Core.Typeof(prev)) - make_tracer(seen, prev.ref, path, mode; sharding, kwargs...) - make_tracer(seen, prev.step, path, mode; sharding, kwargs...) - make_tracer(seen, prev.len, path, mode; sharding, kwargs...) - make_tracer(seen, prev.offset, path, mode; sharding, kwargs...) + make_tracer(seen, prev.ref, path, mode; sharding, include_paths=[], kwargs...) + make_tracer(seen, prev.step, path, mode; sharding, include_paths=[], kwargs...) + make_tracer(seen, prev.len, path, mode; sharding, include_paths=[], kwargs...) + make_tracer(seen, prev.offset, path, mode; sharding, include_paths=[], kwargs...) return nothing end newref = Reactant.make_tracer( - seen, prev.ref, Reactant.append_path(path, :ref), mode; sharding, kwargs... + seen, + prev.ref, + Reactant.append_path(path, :ref), + mode; + sharding, + include_paths=[], + kwargs..., ) newstep = Reactant.make_tracer( - seen, prev.step, Reactant.append_path(path, :step), mode; sharding, kwargs... + seen, + prev.step, + Reactant.append_path(path, :step), + mode; + sharding, + include_paths=[], + kwargs..., ) newlen = Reactant.make_tracer( - seen, prev.len, Reactant.append_path(path, :len), mode; sharding, kwargs... + seen, + prev.len, + Reactant.append_path(path, :len), + mode; + sharding, + include_paths=[], + kwargs..., ) newoffset = Reactant.make_tracer( - seen, prev.offset, Reactant.append_path(path, :offset), mode; sharding, kwargs... + seen, + prev.offset, + Reactant.append_path(path, :offset), + mode; + sharding, + include_paths=[], + kwargs..., ) if typeof(newref) == typeof(prev.ref) && typeof(newstep) == typeof(prev.step) &&