|
5 | 5 | import logging |
6 | 6 | from typing import Any, List, Optional, Sequence, Tuple |
7 | 7 |
|
8 | | -import numpy as np |
9 | 8 | import tensorrt as trt |
10 | 9 | import torch |
11 | 10 | from torch.export import ExportedProgram |
@@ -53,7 +52,7 @@ def construct_refit_mapping( |
53 | 52 | module: torch.fx.GraphModule, |
54 | 53 | inputs: Sequence[Input], |
55 | 54 | settings: CompilationSettings = CompilationSettings(), |
56 | | -) -> dict[str, np.ndarray]: |
| 55 | +) -> Any: |
57 | 56 | """Find out the weight mapping between weight in exported program and TensorRT engine |
58 | 57 | Args: |
59 | 58 | module: FX GraphModule to interpret |
@@ -93,8 +92,8 @@ def construct_refit_mapping_from_weight_name_map( |
93 | 92 | # If weights is not in sd, we can leave it unchanged |
94 | 93 | continue |
95 | 94 | else: |
96 | | - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) |
97 | | - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) |
| 95 | + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) |
| 96 | + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) |
98 | 97 | engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to( |
99 | 98 | to_torch_device(settings.device) |
100 | 99 | ) |
@@ -148,8 +147,8 @@ def _refit_single_trt_engine_with_gm( |
148 | 147 | for constant_name, val in constant_mapping.items(): |
149 | 148 | np_weight_type = val.dtype |
150 | 149 | val_tensor = torch.from_numpy(val).cuda() |
151 | | - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) |
152 | | - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) |
| 150 | + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) |
| 151 | + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) |
153 | 152 | constant_mapping_with_type[constant_name] = ( |
154 | 153 | val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), |
155 | 154 | trt_dtype, |
@@ -179,7 +178,7 @@ def _refit_single_trt_engine_with_gm( |
179 | 178 | raise AssertionError(f"{layer_name} is not found in weight mapping") |
180 | 179 | # Use Numpy to create weights |
181 | 180 | weight = mapping[layer_name] |
182 | | - trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType) |
| 181 | + trt_dtype = dtype._from(weight.dtype).to(trt.DataType) |
183 | 182 | trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size) |
184 | 183 | refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) |
185 | 184 | refitted.add(layer_name) |
|
0 commit comments