Skip to content

Commit c55caf7

Browse files
committed
feat: mostly working
1 parent 4d359d9 commit c55caf7

File tree

8 files changed

+222
-8
lines changed

8 files changed

+222
-8
lines changed

CondaPkg.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[pip.deps]
22
jax = ">= 0.6"
33
tensorflow = ">= 2.17"
4+
numpy = ">= 2"

deps/ReactantExtra/API.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,17 @@ extern "C" MlirOperation mlirOperationParse(MlirContext ctx, MlirBlock block,
270270
.release()};
271271
}
272272

273+
extern "C" MlirType mlirGetFunctionTypeFromOperation(MlirOperation op) {
274+
if (auto funcOp = dyn_cast<mlir::FunctionOpInterface>(unwrap(op))) {
275+
return wrap(funcOp.getFunctionType());
276+
}
277+
ReactantThrowError("Not a function op");
278+
}
279+
280+
extern "C" bool mlirIsFunctionOpInterface(MlirOperation op) {
281+
return llvm::isa<mlir::FunctionOpInterface>(unwrap(op));
282+
}
283+
273284
// TODO mlirComplexAttrGetnValue
274285
// TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return
275286
// wrap(complex::NumberAttr::getTypeID()); }

deps/ReactantExtra/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,8 @@ cc_library(
888888
"-Wl,-exported_symbol,_hlo_sharding_*",
889889
"-Wl,-exported_symbol,_free_ifrt_sharding",
890890
"-Wl,-exported_symbol,_addSdyPropagationPipeline",
891+
"-Wl,-exported_symbol,_mlirGetFunctionTypeFromOperation",
892+
"-Wl,-exported_symbol,_mlirIsFunctionOpInterface",
891893
],
892894
}),
893895
linkstatic = True,

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ const JAX_TRACING_SUPPORTED = Ref{Bool}(false)
1010

1111
const tfptr = Ref{Py}()
1212
const tf2xlaptr = Ref{Py}()
13+
const npptr = Ref{Py}()
1314

1415
const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false)
1516

@@ -26,9 +27,9 @@ const NUMPY_SIMPLE_TYPES = Dict(
2627
Float16 => :float16,
2728
Float32 => :float32,
2829
Float64 => :float64,
29-
ComplexF16 => :complex32,
30-
ComplexF32 => :complex64,
31-
ComplexF64 => :complex128,
30+
ComplexF16 => :complex16,
31+
ComplexF32 => :complex32,
32+
ComplexF64 => :complex64,
3233
)
3334

3435
function __init__()
@@ -45,6 +46,7 @@ function __init__()
4546
tfptr[] = pyimport("tensorflow")
4647
tfptr[].config.set_visible_devices(pylist(); device_type="GPU")
4748
tf2xlaptr[] = pyimport("tensorflow.compiler.tf2xla.python.xla")
49+
npptr[] = pyimport("numpy")
4850
SAVED_MODEL_EXPORT_SUPPORTED[] = true
4951
catch err
5052
@warn "Failed to import tensorflow. Exporting Reactant compiled functions as \
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# TODO: at some point, we should use the TF C++ API to export the SavedModel
2+
3+
function Reactant.Serialization.serialization_supported(::Val{:SavedModel})
4+
return SAVED_MODEL_EXPORT_SUPPORTED[]
5+
end
6+
7+
function _extract_call_parameters(args::Tuple, input_locations, state_dict)
8+
call_args = []
9+
for loc in input_locations
10+
if loc isa Reactant.Serialization.TFSavedModel.InputArgument
11+
push!(call_args, args[loc.position])
12+
else
13+
push!(call_args, state_dict[loc.name])
14+
end
15+
end
16+
return call_args
17+
end
18+
19+
function _wrap_as_tf_func(spec::Reactant.Serialization.TFSavedModel.ReactantFunctionSpec)
20+
Touts = pylist([string(sig.dtype) for sig in spec.output_signature])
21+
Souts = pylist([pylist(sig.shape) for sig in spec.output_signature])
22+
return pyfunc(
23+
function (args...)
24+
return tf2xlaptr[].call_module(
25+
pytuple(
26+
_extract_call_parameters(args, spec.input_locations, spec.state_dict)
27+
);
28+
version=5,
29+
Tout=Touts, # dtype information
30+
Sout=Souts, # Shape information
31+
function_list=pylist([]), # No functions to call
32+
:module => spec.bytecode,
33+
)
34+
end,
35+
)
36+
end
37+
38+
function _make_input_signatures(
39+
fn_spec::Reactant.Serialization.TFSavedModel.ReactantFunctionSpec
40+
)
41+
input_pos_to_spec = Dict(
42+
loc.position => spec for
43+
(loc, spec) in zip(fn_spec.input_locations, fn_spec.input_signature) if
44+
loc isa Reactant.Serialization.TFSavedModel.InputArgument
45+
)
46+
47+
sigs = []
48+
for i in 1:length(input_pos_to_spec)
49+
spec = input_pos_to_spec[i]
50+
dtype = getproperty(tfptr[], spec.dtype)
51+
push!(
52+
sigs,
53+
tfptr[].TensorSpec(;
54+
shape=pylist(spec.shape), dtype=dtype, name="args_$(i - 1)"
55+
),
56+
)
57+
end
58+
return sigs
59+
end
60+
61+
function Reactant.Serialization.TFSavedModel.__to_tf_saved_model(
62+
fn_spec::Reactant.Serialization.TFSavedModel.ReactantFunctionSpec, path::String
63+
)
64+
tfm = tfptr[].Module()
65+
66+
state_dict = Dict(
67+
k => tfptr[].Variable(
68+
npptr[].asarray(permutedims(v, collect(ndims(v):-1:1)));
69+
# npptr[].asarray(v);
70+
trainable=false,
71+
name=k,
72+
) for (k, v) in fn_spec.state_dict
73+
)
74+
75+
@show fn_spec.input_signature
76+
@show fn_spec.output_signature
77+
78+
input_signatures = _make_input_signatures(fn_spec)
79+
80+
tfm.f = getproperty(tfptr[], :function)(
81+
_wrap_as_tf_func(fn_spec); input_signature=pylist(input_signatures)
82+
)
83+
tfm._variables = pylist(collect(values(state_dict)))
84+
85+
signatures = Dict(
86+
tfptr[].saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY =>
87+
tfm.f.get_concrete_function(pylist(input_signatures)...),
88+
)
89+
save_options = tfptr[].saved_model.SaveOptions(; function_aliases=Dict("" => tfm.f))
90+
91+
tfptr[].saved_model.save(tfm, path; signatures=signatures, options=save_options)
92+
93+
return nothing
94+
end

src/mlir/IR/Operation.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,18 @@ function create_operation_at_front(args...; kwargs...)
368368
Base.pushfirst!(block(), res)
369369
return res
370370
end
371+
372+
function FunctionType(op::Operation)
373+
is_function_op = @ccall API.mlir_c.mlirIsFunctionOpInterface(
374+
op::API.MlirOperation
375+
)::Bool
376+
if is_function_op
377+
return Type(
378+
@ccall API.mlir_c.mlirGetFunctionTypeFromOperation(
379+
op::API.MlirOperation
380+
)::API.MlirType
381+
)
382+
else
383+
throw("operation is not a function operation")
384+
end
385+
end

src/serialization/Serialization.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ module Serialization
77

88
using ..Reactant: Reactant, MLIR
99

10+
serialization_supported(::Val) = false
11+
1012
include("TFSavedModel.jl")
1113

1214
end

src/serialization/TFSavedModel.jl

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
11
module TFSavedModel
22

3+
using ..Serialization: serialization_supported
34
using ..Reactant: AbstractConcreteArray, AbstractConcreteNumber, Compiler, MLIR
45

56
# https://github.com/openxla/stablehlo/blob/955fa7e6e3b0a6411edc8ff6fcce1e644440acbd/stablehlo/integrations/python/stablehlo/savedmodel/stablehlo_to_tf_saved_model.py
67

8+
const NUMPY_SIMPLE_TYPES = Dict(
9+
Bool => :bool,
10+
Int8 => :int8,
11+
Int16 => :int16,
12+
Int32 => :int32,
13+
Int64 => :int64,
14+
UInt8 => :uint8,
15+
UInt16 => :uint16,
16+
UInt32 => :uint32,
17+
UInt64 => :uint64,
18+
Float16 => :float16,
19+
Float32 => :float32,
20+
Float64 => :float64,
21+
ComplexF16 => :complex16,
22+
ComplexF32 => :complex32,
23+
ComplexF64 => :complex64,
24+
)
25+
726
struct VariableSignature
827
shape::Vector{Int}
928
dtype::Symbol
@@ -19,21 +38,89 @@ struct Parameter <: VariableType
1938
name::String
2039
end
2140

41+
struct ReactantFunctionSpec
42+
input_signature::Vector{VariableSignature}
43+
output_signature::Vector{VariableSignature}
44+
input_locations::Vector{<:VariableType}
45+
bytecode::Base.CodeUnits{UInt8,String}
46+
state_dict::Dict
47+
end
48+
2249
function export_as_saved_model(
2350
thunk::Compiler.Thunk,
2451
saved_model_path::String,
2552
target_version::VersionNumber,
26-
input_locations,
27-
state_dict::Dict{String, <:Union{<:AbstractConcreteArray,<:AbstractConcreteNumber}},
53+
input_locations::Vector,
54+
state_dict::Dict,
2855
)
2956
isempty(thunk.module_string) && error(
3057
"To export a thunk, ensure that it has been compiled with `serializable=true`."
3158
)
3259

33-
mlir_mod = parse(MLIR.IR.Module, thunk.module_string)
34-
display(mlir_mod)
60+
if !serialization_supported(Val(:SavedModel))
61+
error("Serialization to SavedModel is not supported. This might happen if \
62+
PythonCall hasn't been installed and loaded.")
63+
end
64+
65+
mlir_mod = MLIR.IR.with_context() do ctx
66+
parse(MLIR.IR.Module, thunk.module_string)
67+
end
68+
69+
ftype = MLIR.IR.FunctionType(first(MLIR.IR.body(mlir_mod)))
70+
71+
input_signature = [
72+
VariableSignature(
73+
reverse(collect(Int64, size(MLIR.IR.input(ftype, i)))),
74+
# collect(Int64, size(MLIR.IR.input(ftype, i))),
75+
NUMPY_SIMPLE_TYPES[MLIR.IR.julia_type(eltype(MLIR.IR.input(ftype, i)))],
76+
) for i in 1:MLIR.IR.ninputs(ftype)
77+
]
78+
79+
output_signature = [
80+
VariableSignature(
81+
reverse(collect(Int64, size(MLIR.IR.result(ftype, i)))),
82+
# collect(Int64, size(MLIR.IR.result(ftype, i))),
83+
NUMPY_SIMPLE_TYPES[MLIR.IR.julia_type(eltype(MLIR.IR.result(ftype, i)))],
84+
) for i in 1:MLIR.IR.nresults(ftype)
85+
]
86+
87+
if isempty(input_locations)
88+
input_locations = [InputArgument(i) for i in 1:length(input_signature)]
89+
end
90+
91+
c_print_callback = @cfunction(
92+
MLIR.IR.print_callback, Cvoid, (MLIR.API.MlirStringRef, Any)
93+
)
94+
ref = Ref(IOBuffer())
95+
result = MLIR.IR.LogicalResult(
96+
MLIR.API.stablehloSerializePortableArtifactFromModule(
97+
mlir_mod, string(target_version), c_print_callback, ref, true
98+
),
99+
)
100+
MLIR.IR.isfailure(result) && throw("Couldn't serialize the module")
101+
serialized_module = codeunits(String(take!(ref[])))
102+
103+
return to_tf_saved_model(
104+
ReactantFunctionSpec(
105+
input_signature,
106+
output_signature,
107+
input_locations,
108+
serialized_module,
109+
Dict(k => Array(v) for (k, v) in state_dict),
110+
),
111+
saved_model_path,
112+
)
113+
end
35114

36-
return nothing
115+
function to_tf_saved_model(fn_spec::ReactantFunctionSpec, path::String)
116+
if !serialization_supported(Val(:SavedModel))
117+
error("Serialization to SavedModel is not supported. This might happen if \
118+
PythonCall hasn't been installed and loaded.")
119+
end
120+
return __to_tf_saved_model(fn_spec, path)
37121
end
38122

123+
# Defined in the PythonCallExt module
124+
function __to_tf_saved_model end
125+
39126
end

0 commit comments

Comments
 (0)