Skip to content

Commit c585796

Browse files
authored
Add T=bfloat16 to custom_ops registration (#2688)
1 parent 50530e8 commit c585796

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ REGISTER_OP("Addons>AdjustHsvInYiq")
3131
.Input("scale_s: float")
3232
.Input("scale_v: float")
3333
.Output("output: T")
34-
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
34+
.Attr(
35+
"T: {uint8, int8, int16, int32, int64, half, float, double, bfloat16}")
3536
.SetShapeFn([](InferenceContext* c) {
3637
ShapeHandle images, delta_h, scale_s, scale_v;
3738

@@ -70,4 +71,4 @@ output: The hsv-adjusted image or images. No clipping will be done in this op.
7071
)Doc");
7172

7273
} // end namespace addons
73-
} // namespace tensorflow
74+
} // namespace tensorflow

tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ components: Component ids for each pixel in "image". Same shape as "image". Zero
5555

5656
REGISTER_OP("Addons>EuclideanDistanceTransform")
5757
.Input("images: uint8")
58-
.Attr("dtype: {float16, float32, float64}")
58+
.Attr("dtype: {bfloat16, float16, float32, float64}")
5959
.Output("transformed_images: dtype")
6060
.SetShapeFn(shape_inference::UnchangedShape)
6161
.Doc(EuclideanDistanceTransformDoc);
@@ -65,9 +65,9 @@ REGISTER_OP("Addons>ImageConnectedComponents")
6565
.Output("components: int64")
6666
.Attr(
6767
"dtype: {int64, int32, uint16, int16, uint8, int8, half, float, "
68-
"double, bool, string}")
68+
"bfloat16, double, bool, string}")
6969
.SetShapeFn(shape_inference::UnchangedShape)
7070
.Doc(ImageConnectedComponentsDoc);
7171

7272
} // end namespace addons
73-
} // namespace tensorflow
73+
} // namespace tensorflow

tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ REGISTER_OP("Addons>Resampler")
2929
.Input("data: T")
3030
.Input("warp: T")
3131
.Output("output: T")
32-
.Attr("T: {half, float, double}")
32+
.Attr("T: {bfloat16, half, float, double}")
3333
.SetShapeFn([](InferenceContext* c) {
3434
ShapeHandle data;
3535
ShapeHandle warp;
@@ -53,7 +53,7 @@ REGISTER_OP("Addons>ResamplerGrad")
5353
.Input("grad_output: T")
5454
.Output("grad_data: T")
5555
.Output("grad_warp: T")
56-
.Attr("T: {half, float, double}")
56+
.Attr("T: {bfloat16, half, float, double}")
5757
.SetShapeFn([](InferenceContext* c) {
5858
c->set_output(0, c->input(0));
5959
c->set_output(1, c->input(1));
@@ -62,4 +62,4 @@ REGISTER_OP("Addons>ResamplerGrad")
6262
.Doc(R"doc(Resampler Grad op.)doc");
6363

6464
} // namespace addons
65-
} // namespace tensorflow
65+
} // namespace tensorflow

tensorflow_addons/custom_ops/layers/cc/ops/embedding_bag_ops.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ REGISTER_OP("Addons>EmbeddingBag")
2828
.Input("params: T")
2929
.Input("weights: T")
3030
.Output("output: T")
31-
.Attr("T: {half, float, double}")
31+
.Attr("T: {bfloat16, half, float, double}")
3232
.Attr("Tindices: {int32, int64}")
3333
.Attr("combiner: {'SUM', 'MEAN'} = 'SUM'")
3434
.SetShapeFn([](InferenceContext* c) {
@@ -51,7 +51,7 @@ REGISTER_OP("Addons>EmbeddingBagGrad")
5151
.Input("grads: T")
5252
.Output("params_grads: T")
5353
.Output("weights_grads: T")
54-
.Attr("T: {half, float, double}")
54+
.Attr("T: {bfloat16, half, float, double}")
5555
.Attr("Tindices: {int32, int64}")
5656
.Attr("combiner: {'SUM', 'MEAN'} = 'SUM'")
5757
.SetShapeFn([](InferenceContext* c) {

0 commit comments

Comments
 (0)