|
6 | 6 |
|
7 | 7 | from __future__ import annotations
|
8 | 8 |
|
| 9 | +from enum import Enum |
9 | 10 | import logging
|
10 | 11 | import os
|
11 | 12 | from pathlib import Path
|
@@ -78,36 +79,33 @@ def set_backend(dso, pte):
|
78 | 79 | active_builder_args_pte = pte
|
79 | 80 |
|
80 | 81 |
|
81 |
| -def use_aoti_backend() -> bool: |
class _Backend(Enum):
    """Code-generation backend selected by the active builder args."""

    # Plain int values. The original `AOTI = 0,` carried a trailing comma,
    # which made the member's value the tuple (0,) while EXECUTORCH stayed
    # the int 1 — harmless for identity comparisons, but inconsistent and
    # surprising to anyone inspecting `.value`.
    AOTI = 0
    EXECUTORCH = 1
| 85 | + |
| 86 | + |
def _active_backend() -> _Backend:
    """Resolve which codegen backend the active builder args select.

    Returns:
        ``_Backend.AOTI`` when a DSO path is set, or when neither path is
        set (eager mode shares the AOTI code path); ``_Backend.EXECUTORCH``
        when only a PTE path is set.

    Raises:
        RuntimeError: if both a DSO and a PTE path are set — each export
        must choose exactly one backend.
    """
    # NOTE: no `global` statements needed — they are only required to
    # *assign* module globals; plain reads already resolve to module scope.

    # eager == aoti, which is when backend has not been explicitly set
    if not active_builder_args_dso and not active_builder_args_pte:
        return _Backend.AOTI

    if active_builder_args_pte and active_builder_args_dso:
        raise RuntimeError(
            "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
        )

    return _Backend.AOTI if active_builder_args_dso else _Backend.EXECUTORCH
95 | 101 |
|
96 | 102 |
|
97 |
| -def use_et_backend() -> bool: |
98 |
| - global active_builder_args_dso |
99 |
| - global active_builder_args_pte |
100 |
| - |
101 |
| - # eager == aoti, which is when backend has not been explicitly set |
102 |
| - if not (active_builder_args_pte or active_builder_args_dso): |
103 |
| - return False |
def use_aoti_backend() -> bool:
    """Report whether the AOTI (DSO / eager) backend is currently active."""
    return _active_backend() is _Backend.AOTI
104 | 105 |
|
105 |
| - if active_builder_args_pte and active_builder_args_dso: |
106 |
| - raise RuntimeError( |
107 |
| - "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!" |
108 |
| - ) |
109 | 106 |
|
110 |
| - return bool(active_builder_args_pte) |
def use_et_backend() -> bool:
    """Report whether the ExecuTorch (PTE) backend is currently active."""
    backend = _active_backend()
    return backend is _Backend.EXECUTORCH
111 | 109 |
|
112 | 110 |
|
113 | 111 | ##########################################################################
|
@@ -142,9 +140,9 @@ def name_to_dtype(name, device):
|
142 | 140 | return torch.float16
|
143 | 141 | return torch.bfloat16
|
144 | 142 |
|
145 |
| - if name in name_to_dtype_dict: |
| 143 | + try: |
146 | 144 | return name_to_dtype_dict[name]
|
147 |
| - else: |
| 145 | + except KeyError: |
148 | 146 | raise RuntimeError(f"unsupported dtype name {name} specified")
|
149 | 147 |
|
150 | 148 |
|
@@ -212,10 +210,7 @@ def canonical_path(path):
|
212 | 210 |
|
213 | 211 |
|
def state_dict_device(d, device="cpu") -> Dict:
    """Return a new state dict with every value of *d* moved to *device*.

    The input mapping is not mutated; each value is converted via its
    ``.to(device=...)`` method (tensor-like objects).
    """
    moved = {}
    for name, weight in d.items():
        moved[name] = weight.to(device=device)
    return moved
219 | 214 |
|
220 | 215 |
|
221 | 216 | #########################################################################
|
@@ -259,9 +254,9 @@ def get_device(device) -> str:
|
259 | 254 | return torch.device(device)
|
260 | 255 |
|
261 | 256 |
|
262 |
| -def is_cuda_or_cpu_device(device) -> bool: |
263 |
| - return device == "" or str(device) == "cpu" or ("cuda" in str(device)) |
264 |
| - |
265 |
| - |
def is_cpu_device(device) -> bool:
    """True when *device* names the CPU; the empty string means "default" and counts as CPU."""
    if device == "":
        return True
    return str(device) == "cpu"
|
| 259 | + |
| 260 | + |
def is_cuda_or_cpu_device(device) -> bool:
    """True for CPU (or unset/default) devices and for any CUDA device string."""
    return ("cuda" in str(device)) or is_cpu_device(device)
0 commit comments