-
Notifications
You must be signed in to change notification settings - Fork 363
/
Copy path_features.py
99 lines (75 loc) · 3.07 KB
/
_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import functools
import os
import sys
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar

from packaging import version
from torch_tensorrt._utils import sanitized_torch_version
# Immutable record holding one flag per optional feature. Field order matters:
# instances may be built positionally.
FeatureSet = namedtuple(
    "FeatureSet",
    "torchscript_frontend torch_tensorrt_runtime dynamo_frontend fx_frontend refit",
)
# Directory containing this module; the shared libraries are looked up in its
# "lib" subdirectory.
trtorch_dir = os.path.dirname(__file__)

# The library file names depend on the platform: "*.dll" when sys.platform
# starts with "win", "lib*.so" otherwise.
if sys.platform.startswith("win"):
    linked_file = os.path.join("lib", "torchtrt.dll")
    linked_file_runtime = os.path.join("lib", "torchtrt_runtime.dll")
else:
    linked_file = os.path.join("lib", "libtorchtrt.so")
    linked_file_runtime = os.path.join("lib", "libtorchtrt_runtime.so")

linked_file_full_path = os.path.join(trtorch_dir, linked_file)
linked_file_runtime_full_path = os.path.join(trtorch_dir, linked_file_runtime)

# TorchScript frontend: enabled when the full library file exists on disk.
_TS_FE_AVAIL = os.path.isfile(linked_file_full_path)
# Runtime: enabled when either the full library or the runtime-only library
# file exists.
_TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
# Dynamo frontend: enabled when the sanitized torch version is >= "2.1.dev".
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
# These two flags are unconditionally enabled here.
# NOTE(review): the needs_refit error message mentions Python 3.13, but no
# Python-version check is applied to _REFIT_AVAIL in this module - confirm.
_FX_FE_AVAIL = True
_REFIT_AVAIL = True

ENABLED_FEATURES = FeatureSet(
    torchscript_frontend=_TS_FE_AVAIL,
    torch_tensorrt_runtime=_TORCHTRT_RT_AVAIL,
    dynamo_frontend=_DYNAMO_FE_AVAIL,
    fx_frontend=_FX_FE_AVAIL,
    refit=_REFIT_AVAIL,
)
def _enabled_features_str() -> str:
    """Return a multi-line, human-readable summary of the feature flags.

    Each line reports one module-level availability flag as ``ENABLED`` or
    ``DISABLED``. The refit flag is reported alongside the frontends and the
    runtime so that the summary covers every field tracked by ``FeatureSet``.

    Returns:
        The report text, starting with ``"Enabled Features:"`` and ending
        with a trailing newline.
    """

    def enabled(flag: bool) -> str:
        # Render a flag the way the report expects it.
        return "ENABLED" if flag else "DISABLED"

    return (
        "Enabled Features:\n"
        f" - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n"
        f" - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n"
        f" - FX Frontend: {enabled(_FX_FE_AVAIL)}\n"
        f" - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n"
        f" - Refit: {enabled(_REFIT_AVAIL)}\n"
    )
def needs_torch_tensorrt_runtime(f: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate ``f`` so it only runs when the Torch-TensorRT runtime is enabled.

    Args:
        f: The callable to guard.

    Returns:
        A wrapper that carries ``f``'s metadata (name, docstring,
        ``__wrapped__``) and forwards all arguments to ``f``.

    Raises:
        NotImplementedError: Raised by the wrapper, at call time, when
            ``ENABLED_FEATURES.torch_tensorrt_runtime`` is falsy.
    """

    # functools.wraps keeps __name__/__doc__/signature introspection pointing
    # at the guarded callable instead of the generic "wrapper".
    @functools.wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # The flag is read on every call, not once at decoration time.
        if not ENABLED_FEATURES.torch_tensorrt_runtime:
            raise NotImplementedError("Torch-TensorRT Runtime is not available")
        return f(*args, **kwargs)

    return wrapper
def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate ``f`` so it only runs when the refit feature is enabled.

    Args:
        f: The callable to guard.

    Returns:
        A wrapper that carries ``f``'s metadata (name, docstring,
        ``__wrapped__``) and forwards all arguments to ``f``.

    Raises:
        NotImplementedError: Raised by the wrapper, at call time, when
            ``ENABLED_FEATURES.refit`` is falsy.
    """

    # functools.wraps keeps __name__/__doc__/signature introspection pointing
    # at the guarded callable instead of the generic "wrapper".
    @functools.wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # The flag is read on every call, not once at decoration time.
        if not ENABLED_FEATURES.refit:
            raise NotImplementedError(
                "Refit feature is currently not available in Python 3.13 or higher"
            )
        return f(*args, **kwargs)

    return wrapper
T = TypeVar("T")


def for_all_methods(
    decorator: Callable[..., Any], exclude: Optional[List[str]] = None
) -> Callable[..., Any]:
    """Build a class decorator that applies ``decorator`` to a class's callables.

    Only attributes defined directly on the class (its own ``__dict__``) are
    considered; inherited attributes are left alone. Static methods and class
    methods are decorated through their underlying function and re-wrapped in
    the same descriptor type, so they keep working when called on instances
    and on subclasses.

    Args:
        decorator: Callable applied to every callable attribute of the class.
        exclude: Attribute names to leave untouched. ``None`` or an empty
            list means nothing is excluded.

    Returns:
        A class decorator that mutates the class in place and returns it.
    """
    excluded: List[str] = exclude if exclude else []

    def decorate(cls: Type[T]) -> Type[T]:
        # Walk a snapshot of the namespace because attributes are rebound
        # while we iterate.
        for name, raw in list(vars(cls).items()):
            if name in excluded:
                continue
            if isinstance(raw, staticmethod):
                # getattr() would strip the descriptor and the replacement
                # would start receiving ``self``; decorate the function and
                # restore the staticmethod wrapper instead.
                setattr(cls, name, staticmethod(decorator(raw.__func__)))
            elif isinstance(raw, classmethod):
                # Same idea: keep late binding to the calling class.
                setattr(cls, name, classmethod(decorator(raw.__func__)))
            else:
                member = getattr(cls, name)
                if callable(member):
                    setattr(cls, name, decorator(member))
        return cls

    return decorate