Skip to content

Commit 047e656

Browse files
ringohoffmanjainapurva
authored andcommitted
Use importlib.util.find_spec to check if lm_eval is installed instead of trying to import it (#1023)
Use importlib.util.find_spec to check if lm_eval is installed instead of trying to import it There is a circular dependency when trying to import lm_eval inside torchao. The chain is like this: torchao -> lm_eval -> transformers.pipelines -> torchao And results in the following error: RuntimeError: Failed to import transformers.pipelines because of the following error (look up to see its traceback): cannot import name 'quantize_' from partially initialized module 'torchao.quantization' which 1. causes _lm_eval_available to be erroneously set to False, even if lm_eval is available 2. interrupts lm_eval's initialization, leaving it partially initialized you can observe this with: >>> import torchao >>> import lm_eval.__main__ >>> import lm_eval.api.registry >> lm_eval.api.registry AttributeError: module 'lm_eval' has no attribute 'api' Having a bare except clause here was suppressing this circular import error, which from glancing around seems kind of like a general pattern in this code base. It might be worth reconsidering this pattern.
1 parent 8d508ed commit 047e656

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

torchao/quantization/utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
6-
from typing import Dict, List, Optional, Tuple
6+
import importlib.util
7+
from typing import Dict, List, Optional
78

89
import torch
910
from torch.utils._python_dispatch import TorchDispatchMode
@@ -40,12 +41,7 @@
4041
"recommended_inductor_config_setter"
4142
]
4243

43-
try:
44-
import lm_eval # pyre-ignore[21] # noqa: F401
45-
46-
_lm_eval_available = True
47-
except:
48-
_lm_eval_available = False
44+
_lm_eval_available = importlib.util.find_spec("lm_eval") is not None
4945

5046
# basic SQNR
5147
def compute_error(x, y):

0 commit comments

Comments
 (0)