Skip to content

Commit d4b896f

Browse files
aorenstePaulZhang12
authored andcommitted
Teach is_signature_compatible() to dig into similar annotations (#2693)
Summary: Pull Request resolved: #2693 D68450007 updated some annotations in pytorch. This function wasn't correctly evaluating `typing.Dict[X, Y]` and `dict[X, Y]` as the equivalent. Reviewed By: izaitsevfb Differential Revision: D68475380 fbshipit-source-id: 3b71ab41f95e6c20986ebe6fbf6f9cbe3b3d58f9
1 parent 8e5f5df commit d4b896f

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

torchrec/schema/utils.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,32 @@
88
# pyre-strict
99

1010
import inspect
11+
import typing
12+
from typing import Any
13+
14+
15+
def _is_annot_compatible(prev: object, curr: object) -> bool:
16+
if prev == curr:
17+
return True
18+
19+
if not (prev_origin := typing.get_origin(prev)):
20+
return False
21+
if not (curr_origin := typing.get_origin(curr)):
22+
return False
23+
24+
if prev_origin != curr_origin:
25+
return False
26+
27+
prev_args = typing.get_args(prev)
28+
curr_args = typing.get_args(curr)
29+
if len(prev_args) != len(curr_args):
30+
return False
31+
32+
for prev_arg, curr_arg in zip(prev_args, curr_args):
33+
if not _is_annot_compatible(prev_arg, curr_arg):
34+
return False
35+
36+
return True
1137

1238

1339
def is_signature_compatible(
@@ -84,6 +110,8 @@ def is_signature_compatible(
84110
return False
85111

86112
# TODO: Account for Union Types?
87-
if current_signature.return_annotation != previous_signature.return_annotation:
113+
if not _is_annot_compatible(
114+
previous_signature.return_annotation, current_signature.return_annotation
115+
):
88116
return False
89117
return True

0 commit comments

Comments
 (0)