Skip to content

Commit ecc628d

Browse files
authored
Code cleanup in build_utils.py: Improve backend detection with _Backend (#1306)
Bunch of minor code quality things.
1 parent fa6f9b6 commit ecc628d

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

torchchat/utils/build_utils.py

+20-25
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from __future__ import annotations
88

9+
from enum import Enum
910
import logging
1011
import os
1112
from pathlib import Path
@@ -78,36 +79,33 @@ def set_backend(dso, pte):
7879
active_builder_args_pte = pte
7980

8081

class _Backend(Enum):
    """Code-generation backend selected via the active builder args."""

    # NOTE: the original wrote `AOTI = 0,` — the trailing comma makes the
    # member's value the tuple (0,), not the int 0. Member identity still
    # works, but the value is surprising; use a plain int for consistency
    # with EXECUTORCH.
    AOTI = 0  # ahead-of-time inductor (DSO); also the eager default
    EXECUTORCH = 1  # ExecuTorch (PTE)


def _active_backend() -> _Backend:
    """Return the backend implied by the active builder args.

    Eager mode (neither DSO nor PTE explicitly set) is treated as AOTI.

    Raises:
        RuntimeError: if both the DSO and PTE export options are active,
            since code generation must pick exactly one implementation.
    """
    global active_builder_args_dso
    global active_builder_args_pte

    # eager == aoti, which is when backend has not been explicitly set
    if (not active_builder_args_dso) and not (active_builder_args_pte):
        return _Backend.AOTI

    if active_builder_args_pte and active_builder_args_dso:
        raise RuntimeError(
            "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
        )

    return _Backend.AOTI if active_builder_args_dso else _Backend.EXECUTORCH


def use_aoti_backend() -> bool:
    """True when the AOT Inductor (DSO) backend is active."""
    return _active_backend() == _Backend.AOTI


def use_et_backend() -> bool:
    """True when the ExecuTorch (PTE) backend is active."""
    return _active_backend() == _Backend.EXECUTORCH
113111
##########################################################################
@@ -142,9 +140,9 @@ def name_to_dtype(name, device):
142140
return torch.float16
143141
return torch.bfloat16
144142

145-
if name in name_to_dtype_dict:
143+
try:
146144
return name_to_dtype_dict[name]
147-
else:
145+
except KeyError:
148146
raise RuntimeError(f"unsupported dtype name {name} specified")
149147

150148

@@ -212,10 +210,7 @@ def canonical_path(path):
212210

213211

214212
def state_dict_device(d, device="cpu") -> Dict:
    """Return a new dict mapping each key of *d* to its tensor moved to *device*.

    The input mapping is not mutated; a fresh dict is built and returned.
    """
    moved = {}
    for name, tensor in d.items():
        moved[name] = tensor.to(device=device)
    return moved
219214

220215

221216
#########################################################################
@@ -259,9 +254,9 @@ def get_device(device) -> str:
259254
return torch.device(device)
260255

261256

262-
def is_cuda_or_cpu_device(device) -> bool:
263-
return device == "" or str(device) == "cpu" or ("cuda" in str(device))
264-
265-
266257
def is_cpu_device(device) -> bool:
    """True when *device* names the CPU; the empty string counts as CPU."""
    if device == "":
        return True
    return str(device) == "cpu"


def is_cuda_or_cpu_device(device) -> bool:
    """True when *device* is the CPU (per is_cpu_device) or any CUDA device."""
    if is_cpu_device(device):
        return True
    return "cuda" in str(device)

0 commit comments

Comments
 (0)