|
| 1 | +# Kernel and Op Implementation and Registration API |
| 2 | + |
| 3 | +| Status | Accepted | |
| 4 | +:-------------- |:---------------------------------------------------- | |
| 5 | +| **Author(s)** | James Ring ([email protected]) |
| 6 | +| **Sponsor** | Günhan Gülsoy ([email protected]) |
| 7 | +| **Updated** | 2020-06-02 | |
| 8 | + |
| 9 | +## Objective |
| 10 | + |
| 11 | +TensorFlow (TF) currently provides a C++ API for implementing kernels and ops. |
| 12 | +The Voltron project aims to create a modular/plugin-based TF implementation with |
| 13 | +stable API and ABI surfaces. Plugins will be able to create and register custom kernel |
| 14 | +and op implementations. |
| 15 | + |
| 16 | +In order to provide a stable ABI, the Voltron team has chosen to provide C APIs |
| 17 | +to plugin authors. This document introduces the C API for op and kernel |
| 18 | +registration. For authors who wish to continue using C++ to interface with |
| 19 | +TensorFlow, an ABI-stable C++ header-only API is provided. |
| 20 | + |
| 21 | +## Motivation |
| 22 | + |
| 23 | +Presently, there is no ABI-stable API for extending TensorFlow with new kernels |
| 24 | +and ops. There is no guarantee that a plugin written with one compiler will work |
| 25 | +with a version of TensorFlow built with another, even on the same operating |
| 26 | +system and architecture. This makes it difficult to distribute plugins without |
| 27 | +also distributing the source code and requiring end-users to build the plugin |
| 28 | +alongside TensorFlow. |
| 29 | + |
| 30 | +An ABI-stable API for extending TensorFlow will simplify the distribution of |
| 31 | +plugins and allow plugin authors to distribute binary artifacts without |
| 32 | +necessarily publishing plugin source code. |
| 33 | + |
| 34 | +## User Benefit |
| 35 | + |
| 36 | +Plugin authors will be able to publish plugins that users can use more easily. |
| 37 | +In turn, the TensorFlow community will benefit from an increase in the number and |
| 38 | +variety of available plugins. |
| 39 | + |
| 40 | +## Design Overview |
| 41 | + |
| 42 | +In general, the kernel and op registration C APIs aim to permit the |
| 43 | +implementation of any kernel or op that is currently possible with the C++ API. |
| 44 | +Where possible, existing C++ function implementations are reused from within a C |
| 45 | +wrapper. The purpose of the wrapper is simply to provide ABI stability. |
| 46 | + |
| 47 | +Since plugins will be dynamically loaded (e.g. via `dlopen` on POSIX), the API |
| 48 | +avoids relying on static initialization. |
| 49 | + |
| 50 | +The intention is that existing kernels should be able to be ported to the new |
| 51 | +APIs with a minimum of reimplementation effort. This precludes a from-scratch |
| 52 | +re-imagining of TensorFlow APIs. |
| 53 | + |
| 54 | +The following diagram describes the components built with the proposed C and C++ |
| 55 | +APIs. |
| 56 | + |
| 57 | + +----------------+ <--+ |
| 58 | + | | | |
| 59 | + | Plugin | | |
| 60 | + | | | |
| 61 | + +----------------+ | |
| 62 | + | | | |
| 63 | + | C++ header API | | Plugin |
| 64 | + | | | my_plugin.so |
| 65 | + +--> +----------------+ | |
| 66 | + | | | | |
| 67 | + | | C API headers | | |
| 68 | + | | | | |
| 69 | + | +----------------+ <--+ |
| 70 | + | | | |
| 71 | + | | C API impl | |
| 72 | + Core | | | |
| 73 | + TensorFlow | +----------------+ |
| 74 | + libtf.so | | | |
| 75 | + | | Core C++ APIs | |
| 76 | + | | | |
| 77 | + +--> +----------------+ |
| 78 | + |
| 79 | +In this example, there are two shared objects: `my_plugin.so` and |
| 80 | +`libtensorflow.so` (labeled `libtf.so` in the diagram). `my_plugin.so` is implemented in terms of the C++ |
| 81 | +header-only API, which is in turn implemented in terms of the C API headers. The |
| 82 | +C API implementation is provided by TensorFlow at runtime when it loads the |
| 83 | +plugin's shared object. |
| 84 | + |
| 85 | +This design addresses the changes to the existing C API that are |
| 86 | +required to support op and kernel plugins. It also introduces the C++ |
| 87 | +header-only API, which currently does not exist. |
| 88 | + |
| 89 | +## Ops |
| 90 | + |
| 91 | +This section introduces changes to the C API that are required to support ops. |
| 92 | +An alpha version of this API is already checked in at `tensorflow/c/ops.h`. |
| 93 | + |
| 94 | +### Registration |
| 95 | + |
| 96 | +In the C++ API, ops are registered at static initialization time using the |
| 97 | +`REGISTER_OP` macro. For example: |
| 98 | + |
| 99 | +```c++ |
| 100 | +REGISTER_OP("Bitcast") |
| 101 | + .Input("input: T") |
| 102 | + .Output("output: type") |
| 103 | + .Attr("T: {bfloat16, ...}") |
| 104 | + .Attr("type: {bfloat16, ...}") |
| 105 | + .SetShapeFn([](InferenceContext* ctx) { ... }) |
| 106 | + .Doc("A bitcast operator"); |
| 107 | +``` |
| 108 | +
|
| 109 | +The equivalent C API will be a series of functions that operate on |
| 110 | +`TF_OpDefinitionBuilder *`, a pointer to an opaque struct (i.e. a struct whose |
| 111 | +content is not made known to the user). The functions include, but are not |
| 112 | +limited to: |
| 113 | +
|
| 114 | +* `TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name)`: |
| 115 | + constructs and returns a new op registration builder for an op with the given |
| 116 | + name |
| 117 | +
|
| 118 | +* `void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder, const |
| 119 | + char* attr)`: adds the given attribute to the builder (equivalent to `Attr` |
| 120 | + above) |
| 121 | +
|
| 122 | +* `void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder, const |
| 123 | + char* input)`: adds the given input to the builder (equivalent to `Input` |
| 124 | + above) |
| 125 | +
|
| 126 | +Additional functions are provided for setting other properties of the operation |
| 127 | +(e.g. `TF_OpDefinitionBuilderSetIsCommutative`). |
| 128 | +
|
| 129 | +Registration is then actually performed using the `TF_RegisterOpDefinition` |
| 130 | +function. This function populates a `TF_Status` indicating whether registration |
| 131 | +was successful and frees the resources associated with the op definition |
| 132 | +builder. |
| 133 | +
|
| 134 | +The C equivalent of the bitcast op registration example above is shown below: |
| 135 | +
|
| 136 | +```c++ |
| 137 | +
|
| 138 | +#include "tensorflow/c/ops.h" |
| 139 | +
|
| 140 | +void InferBitcastShape(TF_ShapeInferenceContext* ctx, // see the section below on |
| 141 | + TF_Status* status); // shape inference |
| 142 | +
|
| 143 | +void InitPlugin() { |
| 144 | + TF_OpDefinitionBuilder* b = TF_NewOpDefinitionBuilder("Bitcast"); |
| 145 | + TF_OpDefinitionBuilderAddInput(b, "input: T"); |
| 146 | + TF_OpDefinitionBuilderAddOutput(b, "output: type"); |
| 147 | + TF_OpDefinitionBuilderAddAttr(b, "T: {bfloat16, ...}"); |
| 148 | + TF_OpDefinitionBuilderAddAttr(b, "type: {bfloat16, ...}"); |
| 149 | + TF_OpDefinitionBuilderSetShapeInferenceFunction(b, &InferBitcastShape); |
| 150 | +
|
| 151 | + TF_Status* status = TF_NewStatus(); |
| 152 | + TF_RegisterOpDefinition(b, status); |
| 153 | + if (TF_GetCode(status) != TF_OK) { /* handle errors */ } |
|  | + TF_DeleteStatus(status);  // release the status object, as in the kernel example |
| 154 | +} |
| 155 | +
|
| 156 | +``` |
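
The builder functions above all take a pointer to an opaque struct, so a plugin never depends on the struct's layout. The sketch below mimics that pattern in a self-contained form. Every `Demo_` name is hypothetical and exists only for illustration; none of them is part of the TensorFlow C API, and the in-memory registry is an assumption made so the example compiles on its own.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; NOT the TensorFlow C API.
extern "C" {
// Only a forward declaration is visible to callers: the type is opaque.
typedef struct Demo_OpBuilder Demo_OpBuilder;

Demo_OpBuilder* Demo_NewOpBuilder(const char* op_name);
void Demo_OpBuilderAddInput(Demo_OpBuilder* b, const char* input);
void Demo_OpBuilderAddAttr(Demo_OpBuilder* b, const char* attr);
// Returns 0 on success and 1 on failure. Always frees the builder.
int Demo_RegisterOp(Demo_OpBuilder* b);
int Demo_RegisteredOpCount(void);
}

// Implementation side: the layout is free to change without breaking callers.
struct Demo_OpBuilder {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> attrs;
};

// Function-local static, so nothing depends on static initialization order.
static std::map<std::string, Demo_OpBuilder>& Registry() {
  static std::map<std::string, Demo_OpBuilder> registry;
  return registry;
}

extern "C" Demo_OpBuilder* Demo_NewOpBuilder(const char* op_name) {
  assert(op_name != nullptr);
  Demo_OpBuilder* b = new Demo_OpBuilder;
  b->name = op_name;
  return b;
}

extern "C" void Demo_OpBuilderAddInput(Demo_OpBuilder* b, const char* input) {
  b->inputs.push_back(input);
}

extern "C" void Demo_OpBuilderAddAttr(Demo_OpBuilder* b, const char* attr) {
  b->attrs.push_back(attr);
}

extern "C" int Demo_RegisterOp(Demo_OpBuilder* b) {
  // A second registration under the same name is rejected.
  const bool inserted = Registry().emplace(b->name, *b).second;
  delete b;  // the builder is consumed whether or not registration succeeded
  return inserted ? 0 : 1;
}

extern "C" int Demo_RegisteredOpCount(void) {
  return static_cast<int>(Registry().size());
}
```

In this sketch the registration call consumes the builder, which mirrors the statement above that `TF_RegisterOpDefinition` frees the builder's resources. Callers therefore must not touch the builder after registering it.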
| 157 | + |
| 158 | +### Shape Inference |
| 159 | + |
| 160 | +A significant feature of certain ops is their ability to infer their output |
| 161 | +shapes. TensorFlow will invoke the registered shape inference function (if one |
| 162 | +is provided) when it needs to know the op's output shape. The registration |
| 163 | +function declaration is shown below: |
| 164 | + |
| 165 | + |
| 166 | +```c++ |
| 167 | +void TF_OpDefinitionBuilderSetShapeInferenceFunction( |
| 168 | + TF_OpDefinitionBuilder* builder, |
| 169 | + void (*shape_inference_func)(TF_ShapeInferenceContext* ctx, TF_Status* status)); |
| 170 | +``` |
| 171 | +
|
| 172 | +A series of functions prefixed with `TF_ShapeInferenceContext` is provided for |
| 173 | +the following purposes: |
| 174 | +
|
| 175 | +* Examining operator input shapes (`TF_ShapeInferenceContextGetInput`) |
| 176 | +
|
| 177 | +* Creating and deleting shape and dimension handles (`TF_{New,Delete}ShapeHandle`, `TF_{New,Delete}DimensionHandle`) |
| 178 | +
|
| 179 | +* Manipulating shape and dimension handles (`TF_ShapeInferenceContextWithRank`, `TF_ShapeInferenceContextDim`) |
| 180 | +
|
| 181 | +In general, C analogues to the C++ methods in `tensorflow::shape_inference` |
| 182 | +(see `tensorflow/core/framework/shape_inference.h`) will be provided. |
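
The callback-plus-status shape of the shape inference hook can be sketched in isolation. The example below is a simplified model under stated assumptions: the `Demo_` types are hypothetical, and the context is a plain struct here for brevity, whereas the real `TF_ShapeInferenceContext` is opaque and is accessed only through functions.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for illustration only; NOT the TensorFlow C API.
struct Demo_Status {
  int code;  // 0 means OK, anything else is an error
};

struct Demo_ShapeContext {
  std::vector<std::vector<long long>> inputs;  // one dimension list per input
  std::vector<long long> output;               // filled in by the callback
};

typedef void (*Demo_ShapeFn)(Demo_ShapeContext* ctx, Demo_Status* status);

// Plugin side: the output has the same shape as input 0.
void InferUnchangedShape(Demo_ShapeContext* ctx, Demo_Status* status) {
  if (ctx->inputs.empty()) {
    status->code = 1;  // report the failure through the status out-parameter
    return;
  }
  ctx->output = ctx->inputs[0];
  status->code = 0;
}

// Runtime side: invoke the registered callback only if one was provided.
bool Demo_RunShapeInference(Demo_ShapeFn fn, Demo_ShapeContext* ctx) {
  assert(ctx != nullptr);
  if (fn == nullptr) return false;  // no shape function was registered
  Demo_Status status = {0};
  fn(ctx, &status);
  return status.code == 0;
}
```

Errors travel through the status out-parameter rather than through exceptions, which keeps the callback usable across a C boundary.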
| 183 | +
|
| 184 | +## Kernels |
| 185 | +
|
| 186 | +This section introduces changes to the C API that are required to support |
| 187 | +kernels. An alpha version of this API is already checked in at |
| 188 | +`tensorflow/c/kernels.h`. |
| 189 | +
|
| 190 | +### Registration |
| 191 | +
|
| 192 | +Kernel registration with the C++ API is accomplished with the |
| 193 | +`REGISTER_KERNEL_BUILDER` macro. This macro expands to code that relies on |
| 194 | +static initialization to register the provided kernel with the global kernel |
| 195 | +registry. See below for an example of registering a kernel with the C++ API: |
| 196 | +
|
| 197 | +```c++ |
| 198 | +
|
| 199 | +#include "tensorflow/core/framework/op_kernel.h" |
| 200 | +
|
| 201 | +class BitcastOp : public OpKernel { |
|  | + public:  // the constructor and Compute must be accessible to the framework |
| 202 | + explicit BitcastOp(OpKernelConstruction* context) : OpKernel(context) { … } |
| 203 | + void Compute(OpKernelContext* context) override { … } |
| 204 | +}; |
| 205 | +
|
| 206 | +REGISTER_KERNEL_BUILDER(Name("Bitcast").Device(DEVICE_CPU), BitcastOp); |
| 207 | +``` |
| 208 | + |
| 209 | +The equivalent C API provides a series of functions that operate on |
| 210 | +`TF_KernelBuilder`, an opaque struct obtained with the `TF_NewKernelBuilder` call. |
| 211 | +The kernel builder is registered with TensorFlow using the |
| 212 | +`TF_RegisterKernelBuilder` function. See below for an example of registering |
| 213 | +the bitcast kernel using the C API: |
| 214 | + |
| 215 | +```c++ |
|  | +#include <stdlib.h>  /* calloc, free */ |
| 216 | +#include "tensorflow/c/kernels.h" |
| 217 | + |
| 218 | +typedef struct bitcast_kernel { … } bitcast_kernel; |
| 219 | + |
| 220 | +// Bitcast_Create, Bitcast_Compute and Bitcast_Delete actually implement the |
| 221 | +// kernel. See the section below for discussion on kernel implementation. |
| 222 | +static void* Bitcast_Create(TF_OpKernelConstruction* context) { |
| 223 | + bitcast_kernel* k = (bitcast_kernel*) calloc(1, sizeof(bitcast_kernel)); |
| 224 | + /* initialize the fields of k as needed */ |
| 225 | + return (void*) k; |
| 226 | +} |
| 227 | + |
| 228 | +static void Bitcast_Compute(void* k, TF_OpKernelContext* context) { |
| 229 | + bitcast_kernel* kernel = (bitcast_kernel*) k; // this is the pointer returned by |
| 230 | + // Bitcast_Create |
| 231 | + /* compute the result */ |
| 232 | + TF_SetOutput(context, ...); |
| 233 | +} |
| 234 | + |
| 235 | +static void Bitcast_Delete(void *k) { free(k); } |
| 236 | + |
| 237 | +void InitPlugin() { |
| 238 | + TF_KernelBuilder* builder = TF_NewKernelBuilder(/*op_name*/"Bitcast", DEVICE_CPU, |
| 239 | + &Bitcast_Create, &Bitcast_Compute, &Bitcast_Delete); |
| 240 | + TF_Status* status = TF_NewStatus(); |
| 241 | + TF_RegisterKernelBuilder(/*kernel_name*/"Bitcast", builder, status); |
| 242 | + if (TF_GetCode(status) != TF_OK) { /* handle errors */ } |
| 243 | + TF_DeleteStatus(status); |
| 244 | +} |
| 245 | +``` |
| 246 | +
|
| 247 | +The registration function prototypes are provided below. Kernel authors must |
| 248 | +provide a compute function. Creation and deletion functions are optional. |
| 249 | +However, if the creation function allocates memory, a deletion function |
| 250 | +that frees that memory should also be provided; otherwise the memory will |
| 251 | +leak. |
| 252 | +
|
| 253 | +```c++ |
| 254 | +TF_KernelBuilder* TF_NewKernelBuilder( |
| 255 | + const char* op_name, const char* device_name, |
| 256 | + void* (*create_func)(TF_OpKernelConstruction*), |
| 257 | + void (*compute_func)(void*, TF_OpKernelContext*), |
| 258 | + void (*delete_func)(void*)); |
| 259 | +
|
| 260 | +void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, |
| 261 | + TF_Status* status); |
| 262 | +``` |
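
The create, compute, and delete callbacks form a simple lifecycle: state is created once, passed to every compute call, and released once. The sketch below models that lifecycle with a hypothetical host function. All `Demo_` and `Counter_` names are assumptions for illustration and are not TensorFlow functions, and the integer inputs stand in for real tensors.

```cpp
#include <cassert>
#include <cstdlib>

// Hypothetical stand-ins for illustration only; NOT the TensorFlow C API.
struct Demo_KernelFns {
  void* (*create_func)();                                      // optional
  void (*compute_func)(void* kernel, int input, int* output);  // required
  void (*delete_func)(void* kernel);                           // optional
};

// Plugin side: per-kernel state that lives between compute calls.
struct CounterKernel {
  int calls;
};

static int g_live_kernels = 0;  // counts allocations so a leak would be visible

static void* Counter_Create() {
  void* k = std::calloc(1, sizeof(CounterKernel));
  assert(k != nullptr);
  ++g_live_kernels;
  return k;
}

static void Counter_Compute(void* kernel, int input, int* output) {
  // This is the pointer that Counter_Create returned.
  CounterKernel* k = static_cast<CounterKernel*>(kernel);
  ++k->calls;
  *output = input + k->calls;
}

static void Counter_Delete(void* kernel) {
  std::free(kernel);
  --g_live_kernels;
}

// Host side: create once, compute once per input, delete once.
int Demo_RunKernel(const Demo_KernelFns& fns, const int* inputs, int n) {
  assert(fns.compute_func != nullptr);
  void* state = fns.create_func ? fns.create_func() : nullptr;
  int last = 0;
  for (int i = 0; i < n; ++i) {
    fns.compute_func(state, inputs[i], &last);
  }
  if (fns.delete_func) fns.delete_func(state);
  return last;
}
```

Because the state pointer is an untyped `void*`, the host never needs to know the kernel's struct layout; only the plugin's own callbacks cast it back.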
| 263 | + |
| 264 | +### Implementation |
| 265 | + |
| 266 | +The main classes for C++ kernel implementations are `OpKernelConstruction` |
| 267 | +(provided by TensorFlow to the kernel constructor) and `OpKernelContext` |
| 268 | +(provided to the kernel's `Compute` method). The analogues in the C API are |
| 269 | +`TF_OpKernelConstruction` and `TF_OpKernelContext`. The aim of the C API is to |
| 270 | +provide functions for working with these structs that match, as closely as |
| 271 | +possible, the C++ API. |
| 272 | +
|
| 273 | +### Inputs and Outputs |
| 274 | +
|
| 275 | +Kernels must be able to retrieve their inputs and provide outputs. In the C++ |
| 276 | +API, the `tensorflow::OpKernelContext::GetInput` and `SetOutput` family of |
| 277 | +functions provides this functionality. The equivalent C calls will be |
| 278 | +`TF_GetInput` and `TF_SetOutput`. These functions operate on `TF_Tensor`, which |
| 279 | +is already part of the existing TensorFlow C API. |
| 280 | +
|
| 281 | +String tensors will be supported in an ABI-stable way. This will require |
| 282 | +changes to their binary representation described in the [tstring design |
| 283 | +document](https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md). |
| 284 | +
|
| 285 | +## C++ Header-Only API |
| 286 | +
|
| 287 | +As described above, the main motivation for providing a C API is ABI stability. |
| 288 | +However, some programmers may find the C API less convenient than the |
| 289 | +non-ABI-stable C++ API. To address this concern, we plan to provide a |
| 290 | +header-only C++ API that is implemented in terms of the ABI-stable C API. This |
| 291 | +API will contain classes such as `Tensor`, `OpKernelContext`, and |
| 292 | +`OpKernelConstruction`, whose names will be familiar to existing C++ API users. |
| 293 | +Ideally, this API will be as close as possible to the existing non-ABI-stable |
| 294 | +TensorFlow C++ API, so that kernels and ops currently implemented in C++ may be |
| 295 | +ported to the ABI-stable C++ API with as little implementation churn as possible. |
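
One way such a header-only layer could be structured is as thin inline classes that own a C handle and forward every call to the C functions. The sketch below assumes a hypothetical C surface (`Demo_Tensor` and friends) and a hypothetical `demo::Tensor` wrapper; neither is the proposed TensorFlow API, and the stand-in implementation is included only so the example is self-contained.

```cpp
#include <cassert>

// Hypothetical C surface for illustration only; NOT the TensorFlow C API.
extern "C" {
typedef struct Demo_Tensor Demo_Tensor;  // opaque to callers
Demo_Tensor* Demo_NewTensor(int num_elements);
void Demo_DeleteTensor(Demo_Tensor* t);
int Demo_TensorNumElements(const Demo_Tensor* t);
int Demo_LiveTensorCount(void);
}

// Stand-in implementation; in a real setup this would live in the core library.
struct Demo_Tensor {
  int num_elements;
};
static int g_live_tensors = 0;

extern "C" Demo_Tensor* Demo_NewTensor(int num_elements) {
  ++g_live_tensors;
  return new Demo_Tensor{num_elements};
}
extern "C" void Demo_DeleteTensor(Demo_Tensor* t) {
  if (t == nullptr) return;
  --g_live_tensors;
  delete t;
}
extern "C" int Demo_TensorNumElements(const Demo_Tensor* t) {
  assert(t != nullptr);
  return t->num_elements;
}
extern "C" int Demo_LiveTensorCount(void) { return g_live_tensors; }

// Header-only wrapper: every member is defined inline and only calls the C
// functions, so no C++ type crosses the library boundary.
namespace demo {
class Tensor {
 public:
  explicit Tensor(int num_elements) : handle_(Demo_NewTensor(num_elements)) {}
  ~Tensor() { Demo_DeleteTensor(handle_); }

  // Copying is disabled so that the handle has exactly one owner.
  Tensor(const Tensor&) = delete;
  Tensor& operator=(const Tensor&) = delete;

  int NumElements() const { return Demo_TensorNumElements(handle_); }

 private:
  Demo_Tensor* handle_;
};
}  // namespace demo
```

Since the wrapper is compiled into the plugin itself, only the C functions need a stable ABI; the plugin's compiler is free to lay out `demo::Tensor` however it likes.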