Skip to content

Commit 3eeb664

Browse files
Mr-Pepegaborbernat
andauthored
fix: Recursively evaluate guarded code (#393)
* Add test for nested guarded imports * Handle nested type guards * Refactor --------- Co-authored-by: Bernát Gábor <[email protected]>
1 parent 5eb0fcf commit 3eeb664

File tree

3 files changed

+62
-25
lines changed

3 files changed

+62
-25
lines changed

src/sphinx_autodoc_typehints/__init__.py

+50-25
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import ast
5+
import importlib
56
import inspect
67
import re
78
import sys
@@ -404,32 +405,56 @@ def get_all_type_hints(autodoc_mock_imports: list[str], obj: Any, name: str) ->
404405
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID = set()
405406

406407

407-
def _resolve_type_guarded_imports(autodoc_mock_imports: list[str], obj: Any) -> None: # noqa: C901
408-
if hasattr(obj, "__module__") and obj.__module__ in _TYPE_GUARD_IMPORTS_RESOLVED:
409-
return # already processed module
410-
if not hasattr(obj, "__globals__"): # classes with __slots__ do not have this
411-
return # if lacks globals nothing we can do
412-
if id(obj.__globals__) in _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID:
413-
return # already processed object
414-
_TYPE_GUARD_IMPORTS_RESOLVED.add(obj.__module__)
415-
if obj.__module__ not in sys.builtin_module_names:
416-
if hasattr(obj, "__globals__"):
417-
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID.add(id(obj.__globals__))
418-
419-
module = inspect.getmodule(obj)
420-
if module:
408+
def _should_skip_guarded_import_resolution(obj: Any) -> bool:
409+
if isinstance(obj, types.ModuleType):
410+
return False # Don't skip modules
411+
412+
if not hasattr(obj, "__globals__"):
413+
return True # Skip objects without __globals__
414+
415+
if hasattr(obj, "__module__"):
416+
return obj.__module__ in _TYPE_GUARD_IMPORTS_RESOLVED or obj.__module__ in sys.builtin_module_names
417+
418+
return id(obj.__globals__) in _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID
419+
420+
421+
def _execute_guarded_code(autodoc_mock_imports: list[str], obj: Any, module_code: str) -> None:
422+
for _, part in _TYPE_GUARD_IMPORT_RE.findall(module_code):
423+
guarded_code = textwrap.dedent(part)
424+
try:
421425
try:
422-
module_code = inspect.getsource(module)
423-
except (TypeError, OSError):
424-
... # no source code => no type guards
425-
else:
426-
for _, part in _TYPE_GUARD_IMPORT_RE.findall(module_code):
427-
guarded_code = textwrap.dedent(part)
428-
try:
429-
with mock(autodoc_mock_imports):
430-
exec(guarded_code, obj.__globals__) # noqa: S102
431-
except Exception as exc: # noqa: BLE001
432-
_LOGGER.warning("Failed guarded type import with %r", exc)
426+
with mock(autodoc_mock_imports):
427+
exec(guarded_code, getattr(obj, "__globals__", obj.__dict__)) # noqa: S102
428+
except ImportError as exc:
429+
# ImportError might have occurred because the module has guarded code as well,
430+
# so we recurse on the module.
431+
if exc.name:
432+
_resolve_type_guarded_imports(autodoc_mock_imports, importlib.import_module(exc.name))
433+
434+
# Retry the guarded code and see if it works now after resolving all nested type guards.
435+
with mock(autodoc_mock_imports):
436+
exec(guarded_code, getattr(obj, "__globals__", obj.__dict__)) # noqa: S102
437+
except Exception as exc: # noqa: BLE001
438+
_LOGGER.warning("Failed guarded type import with %r", exc)
439+
440+
441+
def _resolve_type_guarded_imports(autodoc_mock_imports: list[str], obj: Any) -> None:
442+
if _should_skip_guarded_import_resolution(obj):
443+
return
444+
445+
if hasattr(obj, "__globals__"):
446+
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID.add(id(obj.__globals__))
447+
448+
module = inspect.getmodule(obj)
449+
450+
if module:
451+
try:
452+
module_code = inspect.getsource(module)
453+
except (TypeError, OSError):
454+
... # no source code => no type guards
455+
else:
456+
_TYPE_GUARD_IMPORTS_RESOLVED.add(module.__name__)
457+
_execute_guarded_code(autodoc_mock_imports, obj, module_code)
433458

434459

435460
def _get_type_hint(autodoc_mock_imports: list[str], name: str, obj: Any) -> dict[str, Any]:

tests/roots/test-resolve-typing-guard/demo_typing_guard.py

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from decimal import Decimal
1313
from typing import Sequence
1414

15+
from demo_typing_guard_dummy import Literal # guarded by another `if TYPE_CHECKING` in demo_typing_guard_dummy
16+
1517

1618
if typing.TYPE_CHECKING:
1719
from typing import AnyStr
@@ -52,6 +54,10 @@ def guarded(self, item: Decimal) -> None:
5254
"""
5355

5456

57+
def func(_x: Literal) -> None:
58+
...
59+
60+
5561
__all__ = [
5662
"a",
5763
"ValueError",
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
from viktor import AI # module part of autodoc_mock_imports # noqa: F401
46

7+
if TYPE_CHECKING:
8+
# Nested type guard
9+
from typing import Literal # noqa: F401
10+
511

612
class AnotherClass:
713
"""Another class is here"""

0 commit comments

Comments
 (0)