Skip to content

Commit 7720eb5

Browse files
committed
feat: serialization
1 parent 0c1b0e3 commit 7720eb5

File tree

4 files changed

+96
-8
lines changed

4 files changed

+96
-8
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1313
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1414
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1515
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
16+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1617
LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
1718
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1819
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -73,6 +74,7 @@ Functors = "0.5"
7374
GPUArraysCore = "0.2"
7475
GPUCompiler = "1.3"
7576
HTTP = "1.10.15"
77+
JLD2 = "0.5.12"
7678
KernelAbstractions = "0.9.30"
7779
LLVM = "9.1"
7880
LLVMOpenMP_jll = "18.1.7"

src/Compiler.jl

+40-8
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,7 @@ macro compile(args...)
13001300
:raise => false,
13011301
:shardy_passes => :(:to_mhlo_shardings),
13021302
:assert_nonallocating => false,
1303+
:serializable => true,
13031304
)
13041305
return esc(first(compile_call_expr(__module__, compile, default_options, args...)))
13051306
end
@@ -1963,7 +1964,7 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
19631964
return (client, device)
19641965
end
19651966

1966-
function compile_xla(f, args; client=nothing, kwargs...)
1967+
function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs...)
19671968
# register MLIR dialects
19681969
ctx = MLIR.IR.Context(Reactant.registry[], false)
19691970
context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0)
@@ -2002,6 +2003,15 @@ function compile_xla(f, args; client=nothing, kwargs...)
20022003
global_device_ids = collect(Int64, mlir_fn_res.global_device_ids)
20032004
mlir_fn_res.is_sharded && (device = nothing)
20042005

2006+
# XLA.compile mutates the module, for serialization we need to keep a copy
2007+
if serializable
2008+
mod_pre_xla = MLIR.IR.Module(
2009+
MLIR.API.mlirModuleFromOperation(copy(MLIR.IR.Operation(mod)))
2010+
)
2011+
else
2012+
mod_pre_xla = mod
2013+
end
2014+
20052015
exec = XLA.compile(
20062016
client,
20072017
device,
@@ -2015,7 +2025,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
20152025
mlir_fn_res.use_shardy_partitioner,
20162026
)
20172027

2018-
return mod, exec, mlir_fn_res, device, client
2028+
return mod_pre_xla, exec, mlir_fn_res, device, client
20192029
finally
20202030
MLIR.IR.deactivate!(ctx)
20212031
end
@@ -2158,16 +2168,24 @@ function compile(f, args; sync=false, kwargs...)
21582168
mlir_fn_res.fnwrapped,
21592169
exec,
21602170
mlir_fn_res.is_sharded ? nothing : device,
2171+
serializable ? mod : nothing,
21612172
)
21622173
end
21632174

21642175
# inspired by RuntimeGeneratedFunction.jl
21652176
const __thunk_body_cache = Dict{Symbol,Expr}()
21662177

2167-
struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2178+
struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,M<:Union{Nothing,MLIR.IR.Module}}
21682179
f::FTy
21692180
exec::ExecTy
21702181
device::DeviceTy
2182+
mod::M
2183+
end
2184+
2185+
function Base.show(
2186+
io::IO, thunk::Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2187+
) where {FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2188+
return print(io, "Reactant compiled function $(thunk.f) (with tag $(tag))")
21712189
end
21722190

21732191
XLA.cost_analysis(thunk::Thunk) = XLA.cost_analysis(thunk.exec)
@@ -2179,11 +2197,16 @@ XLA.get_parameter_shardings(thunk::Thunk) = XLA.get_parameter_shardings(thunk.ex
21792197
struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
21802198

21812199
function Base.showerror(
2182-
io::IO, ece::MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}
2183-
) where {FTy,tag,ArgTypes,FoundTypes,IsClosure}
2200+
io::IO,
2201+
::MisMatchedThunkTypeError{
2202+
Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes
2203+
},
2204+
) where {FTy,tag,ArgTypes,FoundTypes,IsClosure,ExecTy,DeviceTy}
21842205
print(
21852206
io,
2186-
"\nThe Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure})` exists, but no method is defined for this combination of argument types.",
2207+
"\nThe Reactant-compiled function \
2208+
`$(Thunk{FTy, tag, ArgTypes, IsClosure, ExecTy, DeviceTy})` exists, but no method \
2209+
is defined for this combination of argument types.",
21872210
)
21882211
print(
21892212
io,
@@ -2231,10 +2254,19 @@ function register_thunk(
22312254
isclosure::Bool,
22322255
exec,
22332256
device,
2257+
mod,
22342258
)
22352259
__thunk_body_cache[tag] = body
2236-
return Thunk{Core.Typeof(f),tag,argtys,isclosure,Core.Typeof(exec),Core.Typeof(device)}(
2237-
f, exec, device
2260+
return Thunk{
2261+
Core.Typeof(f),
2262+
tag,
2263+
argtys,
2264+
isclosure,
2265+
Core.Typeof(exec),
2266+
Core.Typeof(device),
2267+
Core.Typeof(mod),
2268+
}(
2269+
f, exec, device, mod
22382270
)
22392271
end
22402272

src/Reactant.jl

+2
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ include("Compiler.jl")
181181

182182
include("Overlay.jl")
183183

184+
include("Serialize.jl")
185+
184186
function Enzyme.make_zero(
185187
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
186188
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}

src/Serialize.jl

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
module Serialize
2+
3+
# TODO: move these deps into an extension
4+
# TODO: Deal with sharding/global devices
5+
6+
using JLD2
7+
using Reactant: Reactant, MLIR
8+
9+
struct SerializedThunk{FTy,tag,ArgTypes,IsClosure}
10+
f::FTy
11+
body::Expr
12+
end
13+
14+
# function JLD2.writeas(
15+
# ::Type{<:Reactant.Compiler.Thunk{FTy,tag,ArgTypes,IsClosure}}
16+
# ) where {FTy,tag,ArgTypes,IsClosure}
17+
# return SerializedThunk{FTy,tag,ArgTypes,IsClosure}
18+
# end
19+
20+
# function JLD2.wconvert(
21+
# ::Type{SerializedThunk{FTy,tag,ArgTypes,IsClosure}},
22+
# thunk::Reactant.Compiler.Thunk{FTy,tag,ArgTypes,IsClosure},
23+
# ) where {FTy,tag,ArgTypes,IsClosure}
24+
# if thunk.mod === nothing
25+
# throw("To serialize a compiled thunk, ensure it is called with `serializable=true`")
26+
# end
27+
28+
# return error("TODO")
29+
# end
30+
31+
# function JLD2.rconvert(
32+
# ::Type{Reactant.Compiler.Thunk{FTy,tag,ArgTypes,IsClosure}},
33+
# serialized::SerializedThunk{FTy,tag,ArgTypes,IsClosure},
34+
# ) where {FTy,tag,ArgTypes,IsClosure}
35+
# return error("TODO")
36+
# end
37+
38+
function serialize(
39+
thunk::Reactant.Compiler.Thunk{FTy,tag,ArgTypes,IsClosure}
40+
) where {FTy,tag,ArgTypes,IsClosure}
41+
if thunk.mod === nothing
42+
throw("To serialize a compiled thunk, ensure it is called with `serializable=true`")
43+
end
44+
45+
serializable_thunk = SerializedThunk{FTy,tag,ArgTypes,IsClosure}(
46+
thunk.f, Reactant.Compiler.__thunk_body_cache[tag]
47+
)
48+
end
49+
50+
function deserialize() end
51+
52+
end

0 commit comments

Comments
 (0)