@@ -1300,6 +1300,7 @@ macro compile(args...)
13001300 :raise => false ,
13011301 :shardy_passes => :(:to_mhlo_shardings ),
13021302 :assert_nonallocating => false ,
1303+ :serializable => true ,
13031304 )
13041305 return esc (first (compile_call_expr (__module__, compile, default_options, args... )))
13051306end
@@ -1963,7 +1964,7 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
19631964 return (client, device)
19641965end
19651966
1966- function compile_xla (f, args; client= nothing , kwargs... )
1967+ function compile_xla (f, args; client= nothing , serializable :: Bool = false , kwargs... )
19671968 # register MLIR dialects
19681969 ctx = MLIR. IR. Context (Reactant. registry[], false )
19691970 context_gc_vector[ctx] = Vector {Union{TracedRArray,TracedRNumber}} (undef, 0 )
@@ -2002,6 +2003,15 @@ function compile_xla(f, args; client=nothing, kwargs...)
20022003 global_device_ids = collect (Int64, mlir_fn_res. global_device_ids)
20032004 mlir_fn_res. is_sharded && (device = nothing )
20042005
2006+ # XLA.compile mutates the module, for serialization we need to keep a copy
2007+ if serializable
2008+ mod_pre_xla = MLIR. IR. Module (
2009+ MLIR. API. mlirModuleFromOperation (copy (MLIR. IR. Operation (mod)))
2010+ )
2011+ else
2012+ mod_pre_xla = mod
2013+ end
2014+
20052015 exec = XLA. compile (
20062016 client,
20072017 device,
@@ -2015,7 +2025,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
20152025 mlir_fn_res. use_shardy_partitioner,
20162026 )
20172027
2018- return mod , exec, mlir_fn_res, device, client
2028+ return mod_pre_xla , exec, mlir_fn_res, device, client
20192029 finally
20202030 MLIR. IR. deactivate! (ctx)
20212031 end
@@ -2158,16 +2168,24 @@ function compile(f, args; sync=false, kwargs...)
21582168 mlir_fn_res. fnwrapped,
21592169 exec,
21602170 mlir_fn_res. is_sharded ? nothing : device,
2171+ serializable ? mod : nothing ,
21612172 )
21622173end
21632174
21642175# inspired by RuntimeGeneratedFunction.jl
21652176const __thunk_body_cache = Dict {Symbol,Expr} ()
21662177
2167- struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2178+ struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,M <: Union{Nothing,MLIR.IR.Module} }
21682179 f:: FTy
21692180 exec:: ExecTy
21702181 device:: DeviceTy
2182+ mod:: M
2183+ end
2184+
2185+ function Base. show (
2186+ io:: IO , thunk:: Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2187+ ) where {FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
2188+ return print (io, " Reactant compiled function $(thunk. f) (with tag $(tag) )" )
21712189end
21722190
21732191XLA. cost_analysis (thunk:: Thunk ) = XLA. cost_analysis (thunk. exec)
@@ -2179,11 +2197,16 @@ XLA.get_parameter_shardings(thunk::Thunk) = XLA.get_parameter_shardings(thunk.ex
21792197struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
21802198
21812199function Base. showerror (
2182- io:: IO , ece:: MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes}
2183- ) where {FTy,tag,ArgTypes,FoundTypes,IsClosure}
2200+ io:: IO ,
2201+ :: MisMatchedThunkTypeError {
2202+ Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes
2203+ },
2204+ ) where {FTy,tag,ArgTypes,FoundTypes,IsClosure,ExecTy,DeviceTy}
21842205 print (
21852206 io,
2186- " \n The Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure}) ` exists, but no method is defined for this combination of argument types." ,
2207+ " \n The Reactant-compiled function \
2208+ `$(Thunk{FTy, tag, ArgTypes, IsClosure, ExecTy, DeviceTy}) ` exists, but no method \
2209+ is defined for this combination of argument types." ,
21872210 )
21882211 print (
21892212 io,
@@ -2231,10 +2254,19 @@ function register_thunk(
22312254 isclosure:: Bool ,
22322255 exec,
22332256 device,
2257+ mod,
22342258)
22352259 __thunk_body_cache[tag] = body
2236- return Thunk {Core.Typeof(f),tag,argtys,isclosure,Core.Typeof(exec),Core.Typeof(device)} (
2237- f, exec, device
2260+ return Thunk{
2261+ Core. Typeof (f),
2262+ tag,
2263+ argtys,
2264+ isclosure,
2265+ Core. Typeof (exec),
2266+ Core. Typeof (device),
2267+ Core. Typeof (mod),
2268+ }(
2269+ f, exec, device, mod
22382270 )
22392271end
22402272
0 commit comments