diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 04202405d6..eb48c6ce03 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,4 +1,5 @@ #include +#include #include "exla_client.h" #include "exla_cuda.h" @@ -11,11 +12,36 @@ #include "stablehlo/dialect/StablehloOps.h" #include "xla/pjrt/pjrt_api.h" #include "xla/service/platform_util.h" +#include "xla/service/custom_call_target_registry.h" // All of these are created with calls to `new` and subsequently // passed to the VM as pointers-to-pointers so we balance it out // with calls to delete rather than just using the default destructor. +// We need to hold a reference to the `dlopen` handle for as long +// as EXLA is running, so we have this resource which holds the handle, +// then we define a custom free which calls `dlclose`. Then it's up to +// the caller to keep this resource in scope so it's not garbage collected +typedef struct { + void * handle; +} ExlaPlugin; + +typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[], int **dims); + +typedef struct { + const char* name; + ExlaCustomCallFunction func; +} ExlaPluginCustomCall; + +static ErlNifResourceType* exla_plugin_resource_type; + +void free_exla_plugin(ErlNifEnv* env, void* obj) { + ExlaPlugin* plugin = reinterpret_cast(obj); + if (plugin != nullptr) { + dlclose(plugin->handle); + } +} + void free_exla_executable(ErlNifEnv* env, void* obj) { exla::ExlaExecutable** executable = reinterpret_cast(obj); if (*executable != nullptr) { @@ -65,10 +91,17 @@ static int open_resources(ErlNifEnv* env) { if (!exla::nif::open_resource(env, mod, "ExlaMLIRModule")) { return -1; } - if (!exla::nif::open_resource(env, mod, "MLIRContext")) { return -1; } + + // Just a C Resource + ErlNifResourceFlags flags = ErlNifResourceFlags(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER); + exla_plugin_resource_type = enif_open_resource_type(env, mod, "ExlaPlugin", free_exla_plugin, flags, NULL); + if (!exla_plugin_resource_type) { + return -1; + } + return 1; } @@ -911,6 +944,80 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) return exla::nif::ok(env); } +// Plugins + +ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 1) { + return exla::nif::error(env, "Bad argument count."); + } + + std::string library_path; + + if (!exla::nif::get(env, argv[0], library_path)) { + return exla::nif::error(env, "Unable to get library path."); + } + + void* handle = dlopen(library_path.c_str(), RTLD_NOW); + if (!handle) { + return exla::nif::error(env, "Unable to open library."); + } + + ExlaPlugin* plugin = (ExlaPlugin*) enif_alloc_resource(exla_plugin_resource_type, sizeof(ExlaPlugin)); + plugin->handle = handle; + + ERL_NIF_TERM result = enif_make_resource(env, plugin); + enif_release_resource(plugin); + + return exla::nif::ok(env, result); +} + +ERL_NIF_TERM register_custom_call_symbol(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 3) { + return exla::nif::error(env, "Bad argument count."); + } + + ExlaPlugin* plugin; + std::string symbol; + std::vector> dimensions; + + if (!enif_get_resource(env, argv[0], exla_plugin_resource_type, (void **) &plugin)) { + return exla::nif::error(env, "Unable to get plugin."); + } + if (!exla::nif::get(env, argv[1], symbol)) { + return exla::nif::error(env, "Unable to get symbol."); + } + if (!exla::nif::get_list(env, argv[2], dimensions)) { + return exla::nif::error(env, "Unable to get dimensions."); + } + + ExlaCustomCallFunction function = (ExlaCustomCallFunction) dlsym(plugin->handle, symbol.c_str()); + + if (!function) { + return exla::nif::error(env, "Could not find symbol."); + } + + auto lambda = [&dimensions, function](void *in[], const void *out[]) { + std::vector> int_dims(dimensions.size()); + for (size_t i = 0; i < dimensions.size(); ++i) { + int_dims[i].resize(dimensions[i].size()); + std::transform(dimensions[i].begin(), dimensions[i].end(), int_dims[i].begin(), + [](exla::int64 x) { return static_cast(x); }); + } + + std::vector dims_ptrs; + for (auto& d : int_dims) { + dims_ptrs.push_back(d.data()); + } + + function(in, out, dims_ptrs.data()); + }; + + // TODO: GPU/Client flag + XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol.c_str(), function); + + return exla::nif::ok(env); +} + static ErlNifFunc exla_funcs[] = { // MLIR Builder {"mlir_new_context", 0, mlir_new_context}, @@ -947,6 +1054,10 @@ static ErlNifFunc exla_funcs[] = { {"start_log_sink", 1, start_log_sink}, // Serialization {"serialize_executable", 1, serialize_executable}, - {"deserialize_executable", 2, deserialize_executable}}; + {"deserialize_executable", 2, deserialize_executable}, + // Plugins + {"load_custom_call_plugin_library", 1, load_custom_call_plugin_library}, + {"register_custom_call_symbol", 3, register_custom_call_symbol} + }; ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, &upgrade, NULL); diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index d38785f6ed..d802f2a55d 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -248,6 +248,25 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { return 1; } +int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector>& var) { + unsigned int length; + if (!enif_get_list_length(env, list, &length)) { + return 0; + } + var.reserve(length); + ERL_NIF_TERM head, tail; + + while (enif_get_list_cell(env, list, &head, &tail)) { + std::vector elem; + if (!get_list(env, head, elem)) { + return 0; + } + var.push_back(elem); + list = tail; + } + return 1; +} + int get_binary(ErlNifEnv* env, ERL_NIF_TERM term, ErlNifBinary* var) { return enif_inspect_binary(env, term, var); } diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 5abf7e3cda..8244511174 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -247,6 +247,8 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); +int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector>& var); + template int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { unsigned int length; diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 3bdfa30d0c..03dc4b7e4c 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -18,6 +18,7 @@ defmodule EXLA.Application do name: EXLA.MLIR.ContextPool, lazy: true}, EXLA.Client, + EXLA.Plugin, EXLA.Defn.Lock, EXLA.Defn.LockedCache, {Task.Supervisor, name: EXLA.Defn.TaskSupervisor} diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index b53795055d..dd54727b3d 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -815,6 +815,27 @@ defmodule EXLA.MLIR.Value do {q, r} end + def plugin_custom_call(registered_name, [%Value{function: func} | _] = args, result_typespec) do + operand_shapes = + Enum.map(args, fn %Value{function: ^func} = value -> + %{shape: op_shape} = get_typespec(value) + constant(func, Tuple.to_list(op_shape), Typespec.tensor({:s, 64}, {length(op_shape)})) + end) + + operands = + args + |> Enum.zip_with(operand_shapes, fn val, shape -> [val, shape] end) + |> List.flatten() + + # TODO: GPU + attributes = [ + call_target_name: attr_string(registered_name), + backend_config: attr_string("Host") + ] + + op(func, "stablehlo.custom_call", operands, result_typespec, attributes: attributes) + end + def get_tuple_element(%Value{function: func} = operand, index, typespec) do result_types = typespecs_to_mlir_types([typespec]) attributes = [index: attr_i32(index)] diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 6830df726c..dd90ced016 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -112,4 +112,8 @@ defmodule EXLA.NIF do def get_c_api_client(_device_type), do: :erlang.nif_error(:undef) def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef) + + def load_custom_call_plugin_library(_library_path), do: :erlang.nif_error(:undef) + + def register_custom_call_symbol(_plugin, _symbol, _dimensions), do: :erlang.nif_error(:undef) end diff --git a/exla/lib/exla/plugin.ex b/exla/lib/exla/plugin.ex new file mode 100644 index 0000000000..b7f682867e --- /dev/null +++ b/exla/lib/exla/plugin.ex @@ -0,0 +1,56 @@ +defmodule EXLA.Plugin do + @moduledoc """ + Plugin system for registering custom calls. + """ + use GenServer + + # TODO: Register and lookup per client + + def start_link(_opts) do + GenServer.start_link(__MODULE__, %{}, name: __MODULE__) + end + + def register(key, library_path) do + GenServer.cast(__MODULE__, {:register, key, library_path}) + end + + def lookup(key) do + GenServer.call(__MODULE__, {:lookup, key}) + end + + def register_symbol(key, symbol, dimensions) do + if ref = lookup(key) do + EXLA.NIF.register_custom_call_symbol(ref, symbol, dimensions) + end + end + + @impl true + def init(_opts) do + {:ok, %{}} + end + + @impl true + def handle_cast({:register, key, library_path}, state) do + case state do + %{^key => _ref} -> + {:noreply, state} + + %{} -> + ref = + library_path + |> EXLA.NIF.load_custom_call_plugin_library() + |> unwrap!() + + {:noreply, Map.put(state, key, ref)} + end + end + + @impl true + def handle_call({:lookup, key}, _from, state) do + value = Map.get(state, key) + {:reply, value, state} + end + + defp unwrap!({:ok, ref}), do: ref + defp unwrap!({:error, reason}), do: raise("#{reason}") +end diff --git a/exla/test/exla/plugin_test.exs b/exla/test/exla/plugin_test.exs new file mode 100644 index 0000000000..ca9f9dfce1 --- /dev/null +++ b/exla/test/exla/plugin_test.exs @@ -0,0 +1,9 @@ +defmodule EXLA.PluginTest do + use ExUnit.Case + + describe "register/1" do + test "registers a plugin" do + assert :ok = EXLA.Plugin.register(:custom_plugin, "test/support/c/libcustom_plugin.so") + end + end +end diff --git a/exla/test/support/c/custom_plugin.c b/exla/test/support/c/custom_plugin.c new file mode 100644 index 0000000000..b3c70f0950 --- /dev/null +++ b/exla/test/support/c/custom_plugin.c @@ -0,0 +1,22 @@ +#include +#include + +typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[], int **dims); + +typedef struct { + const char* name; + ExlaCustomCallFunction func; +} ExlaPluginCustomCall; + +extern "C" void custom_increment(void *out[], const void *in[], int **dims) { + int64_t *operand = (int64_t *)in[0]; + int64_t *dim_sizes = (int64_t *)dims[0]; + + int64_t *out_buffer = (int64_t *)out[0]; + + int64_t n = dim_sizes[0]; + + for (int64_t i = 0; i < n; i++) { + out_buffer[i] = operand[i] + 1; + } +} \ No newline at end of file diff --git a/exla/test/support/c/libcustom_plugin.so b/exla/test/support/c/libcustom_plugin.so new file mode 100755 index 0000000000..90e7640043 Binary files /dev/null and b/exla/test/support/c/libcustom_plugin.so differ