|
5 | 5 | import logging
|
6 | 6 | from typing import Any, List, Optional, Sequence, Tuple
|
7 | 7 |
|
8 |
| -import numpy as np |
9 | 8 | import tensorrt as trt
|
10 | 9 | import torch
|
11 | 10 | from torch.export import ExportedProgram
|
@@ -53,7 +52,7 @@ def construct_refit_mapping(
|
53 | 52 | module: torch.fx.GraphModule,
|
54 | 53 | inputs: Sequence[Input],
|
55 | 54 | settings: CompilationSettings = CompilationSettings(),
|
56 |
| -) -> dict[str, np.ndarray]: |
| 55 | +) -> Any: |
57 | 56 | """Find out the weight mapping between weight in exported program and TensorRT engine
|
58 | 57 | Args:
|
59 | 58 | module: FX GraphModule to interpret
|
@@ -93,8 +92,8 @@ def construct_refit_mapping_from_weight_name_map(
|
93 | 92 | # If weights is not in sd, we can leave it unchanged
|
94 | 93 | continue
|
95 | 94 | else:
|
96 |
| - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) |
97 |
| - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) |
| 95 | + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) |
| 96 | + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) |
98 | 97 | engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
|
99 | 98 | to_torch_device(settings.device)
|
100 | 99 | )
|
@@ -148,8 +147,8 @@ def _refit_single_trt_engine_with_gm(
|
148 | 147 | for constant_name, val in constant_mapping.items():
|
149 | 148 | np_weight_type = val.dtype
|
150 | 149 | val_tensor = torch.from_numpy(val).cuda()
|
151 |
| - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) |
152 |
| - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) |
| 150 | + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) |
| 151 | + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) |
153 | 152 | constant_mapping_with_type[constant_name] = (
|
154 | 153 | val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
|
155 | 154 | trt_dtype,
|
@@ -179,7 +178,7 @@ def _refit_single_trt_engine_with_gm(
|
179 | 178 | raise AssertionError(f"{layer_name} is not found in weight mapping")
|
180 | 179 | # Use Numpy to create weights
|
181 | 180 | weight = mapping[layer_name]
|
182 |
| - trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType) |
| 181 | + trt_dtype = dtype._from(weight.dtype).to(trt.DataType) |
183 | 182 | trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
|
184 | 183 | refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
|
185 | 184 | refitted.add(layer_name)
|
|
0 commit comments