Commit a94dedc

bf16 support for elementwise operation (#3462)
1 parent 8943fb9 commit a94dedc

File tree

2 files changed: +26 −5 lines changed

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py (+4 −5)

@@ -3,6 +3,7 @@
 from typing import Any, Callable, Optional, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -15,11 +16,10 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
+    to_torch,
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
 
-import tensorrt as trt
-
 
 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -125,10 +125,9 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
-        rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
+        rhs_val = to_torch(rhs_val, dtype=lhs_dtype)
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
-        lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype))
-
+        lhs_val = to_torch(lhs_val, dtype=rhs_dtype)
     lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
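Why the scalar path changes here: numpy has no native bfloat16 dtype, so the removed np.array(...) conversion cannot hold the Python scalar in a bf16 lhs tensor's precision, while a torch tensor can. A minimal sketch of that scalar promotion, using torch.tensor as a rough stand-in for the to_torch helper shown in the diff (the dtype value is illustrative):

import torch

# Stand-in for to_torch(rhs_val, dtype=lhs_dtype): promote a Python scalar to a
# torch tensor in the dtype of the TRT tensor operand. numpy cannot do this for
# bf16 because it has no bfloat16 dtype.
lhs_dtype = torch.bfloat16   # dtype of the lhs TRT tensor (illustrative)
rhs_val = -2                 # Python scalar operand

rhs_tensor = torch.tensor([rhs_val], dtype=lhs_dtype)
print(rhs_tensor.dtype)      # torch.bfloat16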

tests/py/dynamo/conversion/test_binary_ops_aten.py (+22 −0)

@@ -228,6 +228,28 @@ def forward(self, x, y):
         ]
         self.run_test_with_dynamic_shape(Op(), input_specs)
 
+    @parameterized.expand(
+        [
+            (f"bf16_{op[0].__name__}_one_constant", op[0])
+            for op in elementwise_ops
+            if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
+        ]
+    )
+    def test_elementwise_ops_bf16(self, _, orig_op):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.constant = torch.randn(1)
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                x = self.orig_op(x, self.constant)
+                return self.orig_op(x, -2)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randn(2, 2, dtype=torch.bfloat16)]
+        self.run_test(m, inputs)
+
 
 if __name__ == "__main__":
     run_tests()
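For context, a hedged end-to-end sketch of the pattern the new test exercises: an elementwise op mixing a bf16 input with constant scalars, compiled through the Torch-TensorRT dynamo path. The model, shapes, and the assumption that enabled_precisions accepts torch.bfloat16 on a build containing this change are illustrative, not taken from this commit:

import torch
import torch_tensorrt

class AddMulConst(torch.nn.Module):
    # Mirrors the test's pattern: one tensor/constant op followed by a tensor/scalar op.
    def forward(self, x):
        return (x + 1.5) * -2

model = AddMulConst().eval().cuda()
inputs = [torch.randn(2, 2, dtype=torch.bfloat16, device="cuda")]

# Assumed invocation: dynamo frontend with bf16 enabled; requires a CUDA GPU and
# a torch_tensorrt build that includes this change.
trt_model = torch_tensorrt.compile(
    model, ir="dynamo", inputs=inputs, enabled_precisions={torch.bfloat16}
)
print(trt_model(*inputs))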
