|
3 | 3 | from typing import Any, Callable, Optional, Union |
4 | 4 |
|
5 | 5 | import numpy as np |
| 6 | +import tensorrt as trt |
6 | 7 | import torch |
7 | 8 | from torch.fx.node import Target |
8 | 9 | from torch_tensorrt import _enums |
|
15 | 16 | get_trt_tensor, |
16 | 17 | has_dynamic_shape, |
17 | 18 | set_layer_name, |
| 19 | + to_torch, |
18 | 20 | ) |
19 | 21 | from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor |
20 | 22 |
|
21 | | -import tensorrt as trt |
22 | | - |
23 | 23 |
|
24 | 24 | def get_python_op_from_trt_elementwise_op( |
25 | 25 | trt_op: TRTElementWiseOp, |
@@ -125,10 +125,9 @@ def convert_binary_elementwise( |
125 | 125 | # dtype but we don't have a way to detect whether it makes sense for the |
126 | 126 | # scalar to be float or half. Hence we go with the lhs dtype. |
127 | 127 | if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)): |
128 | | - rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype)) |
| 128 | + rhs_val = to_torch(rhs_val, dtype=lhs_dtype) |
129 | 129 | if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)): |
130 | | - lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype)) |
131 | | - |
| 130 | + lhs_val = to_torch(lhs_val, dtype=rhs_dtype) |
132 | 131 | lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) |
133 | 132 | rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) |
134 | 133 |
|
|
0 commit comments