@@ -1300,6 +1300,7 @@ macro compile(args...)
1300
1300
:raise => false ,
1301
1301
:shardy_passes => :(:to_mhlo_shardings ),
1302
1302
:assert_nonallocating => false ,
1303
+ :serializable => true ,
1303
1304
)
1304
1305
return esc (first (compile_call_expr (__module__, compile, default_options, args... )))
1305
1306
end
@@ -1963,7 +1964,7 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
1963
1964
return (client, device)
1964
1965
end
1965
1966
1966
- function compile_xla (f, args; client= nothing , kwargs... )
1967
+ function compile_xla (f, args; client= nothing , serializable :: Bool = false , kwargs... )
1967
1968
# register MLIR dialects
1968
1969
ctx = MLIR. IR. Context (Reactant. registry[], false )
1969
1970
context_gc_vector[ctx] = Vector {Union{TracedRArray,TracedRNumber}} (undef, 0 )
@@ -2002,6 +2003,15 @@ function compile_xla(f, args; client=nothing, kwargs...)
2002
2003
global_device_ids = collect (Int64, mlir_fn_res. global_device_ids)
2003
2004
mlir_fn_res. is_sharded && (device = nothing )
2004
2005
2006
+ # XLA.compile mutates the module, for serialization we need to keep a copy
2007
+ if serializable
2008
+ mod_pre_xla = MLIR. IR. Module (
2009
+ MLIR. API. mlirModuleFromOperation (copy (MLIR. IR. Operation (mod)))
2010
+ )
2011
+ else
2012
+ mod_pre_xla = mod
2013
+ end
2014
+
2005
2015
exec = XLA. compile (
2006
2016
client,
2007
2017
device,
@@ -2015,7 +2025,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
2015
2025
mlir_fn_res. use_shardy_partitioner,
2016
2026
)
2017
2027
2018
- return mod , exec, mlir_fn_res, device, client
2028
+ return mod_pre_xla , exec, mlir_fn_res, device, client
2019
2029
finally
2020
2030
MLIR. IR. deactivate! (ctx)
2021
2031
end
@@ -2158,16 +2168,24 @@ function compile(f, args; sync=false, kwargs...)
2158
2168
mlir_fn_res. fnwrapped,
2159
2169
exec,
2160
2170
mlir_fn_res. is_sharded ? nothing : device,
2171
+ serializable ? mod : nothing ,
2161
2172
)
2162
2173
end
2163
2174
2164
2175
# inspired by RuntimeGeneratedFunction.jl
2165
2176
const __thunk_body_cache = Dict {Symbol,Expr} ()
2166
2177
2167
- struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2178
+ struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,M <: Union{Nothing,MLIR.IR.Module} }
2168
2179
f:: FTy
2169
2180
exec:: ExecTy
2170
2181
device:: DeviceTy
2182
+ mod:: M
2183
+ end
2184
+
2185
+ function Base. show (
2186
+ io:: IO , thunk:: Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2187
+ ) where {FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2188
+ return print (io, " Reactant compiled function $(thunk. f) (with tag $(tag) )" )
2171
2189
end
2172
2190
2173
2191
XLA. cost_analysis (thunk:: Thunk ) = XLA. cost_analysis (thunk. exec)
@@ -2179,11 +2197,16 @@ XLA.get_parameter_shardings(thunk::Thunk) = XLA.get_parameter_shardings(thunk.ex
2179
2197
struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
2180
2198
2181
2199
function Base. showerror (
2182
- io:: IO , ece:: MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}
2183
- ) where {FTy,tag,ArgTypes,FoundTypes,IsClosure}
2200
+ io:: IO ,
2201
+ :: MisMatchedThunkTypeError {
2202
+ Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes
2203
+ },
2204
+ ) where {FTy,tag,ArgTypes,FoundTypes,IsClosure,ExecTy,DeviceTy}
2184
2205
print (
2185
2206
io,
2186
- " \n The Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure}) ` exists, but no method is defined for this combination of argument types." ,
2207
+ " \n The Reactant-compiled function \
2208
+ `$(Thunk{FTy, tag, ArgTypes, IsClosure, ExecTy, DeviceTy}) ` exists, but no method \
2209
+ is defined for this combination of argument types." ,
2187
2210
)
2188
2211
print (
2189
2212
io,
@@ -2231,10 +2254,19 @@ function register_thunk(
2231
2254
isclosure:: Bool ,
2232
2255
exec,
2233
2256
device,
2257
+ mod,
2234
2258
)
2235
2259
__thunk_body_cache[tag] = body
2236
- return Thunk {Core.Typeof(f),tag,argtys,isclosure,Core.Typeof(exec),Core.Typeof(device)} (
2237
- f, exec, device
2260
+ return Thunk{
2261
+ Core. Typeof (f),
2262
+ tag,
2263
+ argtys,
2264
+ isclosure,
2265
+ Core. Typeof (exec),
2266
+ Core. Typeof (device),
2267
+ Core. Typeof (mod),
2268
+ }(
2269
+ f, exec, device, mod
2238
2270
)
2239
2271
end
2240
2272
0 commit comments