Skip to content

Commit 1a7da6e

Browse files
zhxchen17 authored and pytorchmergebot committed
[export] Add test to enforce consistency between synced thrift and generated thrift from schema.py (pytorch#141989)
Summary: In this diff we implement a way to ensure that the internal thrift schema from cfgr (configerator/structs/caffe2/torch/export/schema.thrift) and the OSS schema (torch/_export/serde/schema.thrift) stay in sync, by adding a unit test that reflects on the type names and fields of each schema and compares them field by field. When new fields or types are detected in torch/_export/serde/schema.thrift, the test fails on trunk and the error message tells people to add the missing field/type to the cfgr thrift schema, so that the two schemas are always in sync in practice.

Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_thrift_schema_in_sync

Differential Revision: D66716834

Pull Request resolved: pytorch#141989

Approved by: https://github.com/yiming0416
1 parent bab15df commit 1a7da6e

File tree

8 files changed

+117
-55
lines changed

8 files changed

+117
-55
lines changed

scripts/export/update_schema.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@
5858
first_line = (
5959
"@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
6060
)
61-
checksum = f"checksum<<{commit.checksum_result}>>"
61+
checksum = f"checksum<<{commit.checksum_next}>>"
6262
yaml_header = "# " + first_line
6363
yaml_header += "\n# " + checksum
6464
yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
@@ -73,7 +73,7 @@
7373
yaml_content = yaml_header + "\n" + yaml_payload
7474

7575
thrift_schema = "// " + first_line
76-
thrift_schema += "\n// " + checksum
76+
thrift_schema += f"\n// checksum<<{commit.thrift_checksum_next}>>"
7777
thrift_schema += "\n" + commit.thrift_schema
7878

7979
if args.dry_run:

setup.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1338,6 +1338,7 @@ def main():
13381338
"_inductor/codegen/*.h",
13391339
"_inductor/codegen/aoti_runtime/*.cpp",
13401340
"_export/serde/*.yaml",
1341+
"_export/serde/*.thrift",
13411342
"share/cmake/ATen/*.cmake",
13421343
"share/cmake/Caffe2/*.cmake",
13431344
"share/cmake/Caffe2/public/*.cmake",

test/export/test_schema.py

Lines changed: 56 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,27 @@ def test_schema_compatibility(self):
2626
except SchemaUpdateError as e:
2727
self.fail(f"Failed to update schema: {e}\n{msg}")
2828

29-
self.assertEqual(commit.checksum_base, commit.checksum_result, msg)
29+
self.assertEqual(commit.checksum_head, commit.checksum_next, msg)
30+
31+
def test_thrift_schema_unchanged(self):
32+
msg = """
33+
Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script:
34+
Example(s):
35+
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
36+
"""
37+
38+
if IS_FBCODE:
39+
msg += """or
40+
buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/
41+
"""
42+
43+
try:
44+
commit = update_schema()
45+
except SchemaUpdateError as e:
46+
self.fail(f"Failed to update schema: {e}\n{msg}")
47+
48+
self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_real, msg)
49+
self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_next, msg)
3050

3151
def test_schema_diff(self):
3252
additions, subtractions = _diff_schema(
@@ -105,14 +125,17 @@ def test_schema_check(self):
105125

106126
commit = _Commit(
107127
result=src,
108-
checksum_result="",
128+
checksum_next="",
109129
yaml_path="",
110130
additions=additions,
111131
subtractions=subtractions,
112132
base=dst,
113-
checksum_base="",
133+
checksum_head="",
114134
cpp_header="",
115135
cpp_header_path="",
136+
thrift_checksum_head="",
137+
thrift_checksum_real="",
138+
thrift_checksum_next="",
116139
thrift_schema="",
117140
thrift_schema_path="",
118141
)
@@ -141,14 +164,17 @@ def test_schema_check(self):
141164

142165
commit = _Commit(
143166
result=src,
144-
checksum_result="",
167+
checksum_next="",
145168
yaml_path="",
146169
additions=additions,
147170
subtractions=subtractions,
148171
base=dst,
149-
checksum_base="",
172+
checksum_head="",
150173
cpp_header="",
151174
cpp_header_path="",
175+
thrift_checksum_head="",
176+
thrift_checksum_real="",
177+
thrift_checksum_next="",
152178
thrift_schema="",
153179
thrift_schema_path="",
154180
)
@@ -180,14 +206,17 @@ def test_schema_check(self):
180206

181207
commit = _Commit(
182208
result=src,
183-
checksum_result="",
209+
checksum_next="",
184210
yaml_path="",
185211
additions=additions,
186212
subtractions=subtractions,
187213
base=dst,
188-
checksum_base="",
214+
checksum_head="",
189215
cpp_header="",
190216
cpp_header_path="",
217+
thrift_checksum_head="",
218+
thrift_checksum_real="",
219+
thrift_checksum_next="",
191220
thrift_schema="",
192221
thrift_schema_path="",
193222
)
@@ -242,14 +271,17 @@ def test_schema_check(self):
242271

243272
commit = _Commit(
244273
result=src,
245-
checksum_result="",
274+
checksum_next="",
246275
yaml_path="",
247276
additions=additions,
248277
subtractions=subtractions,
249278
base=dst,
250-
checksum_base="",
279+
checksum_head="",
251280
cpp_header="",
252281
cpp_header_path="",
282+
thrift_checksum_head="",
283+
thrift_checksum_real="",
284+
thrift_checksum_next="",
253285
thrift_schema="",
254286
thrift_schema_path="",
255287
)
@@ -274,14 +306,17 @@ def test_schema_check(self):
274306

275307
commit = _Commit(
276308
result=src,
277-
checksum_result="",
309+
checksum_next="",
278310
yaml_path="",
279311
additions=additions,
280312
subtractions=subtractions,
281313
base=dst,
282-
checksum_base="",
314+
checksum_head="",
283315
cpp_header="",
284316
cpp_header_path="",
317+
thrift_checksum_head="",
318+
thrift_checksum_real="",
319+
thrift_checksum_next="",
285320
thrift_schema="",
286321
thrift_schema_path="",
287322
)
@@ -313,14 +348,17 @@ def test_schema_check(self):
313348

314349
commit = _Commit(
315350
result=src,
316-
checksum_result="",
351+
checksum_next="",
317352
yaml_path="",
318353
additions=additions,
319354
subtractions=subtractions,
320355
base=dst,
321-
checksum_base="",
356+
checksum_head="",
322357
cpp_header="",
323358
cpp_header_path="",
359+
thrift_checksum_head="",
360+
thrift_checksum_real="",
361+
thrift_checksum_next="",
324362
thrift_schema="",
325363
thrift_schema_path="",
326364
)
@@ -349,14 +387,17 @@ def test_schema_check(self):
349387

350388
commit = _Commit(
351389
result=src,
352-
checksum_result="",
390+
checksum_next="",
353391
yaml_path="",
354392
additions=additions,
355393
subtractions=subtractions,
356394
base=dst,
357-
checksum_base="",
395+
checksum_head="",
358396
cpp_header="",
359397
cpp_header_path="",
398+
thrift_checksum_head="",
399+
thrift_checksum_real="",
400+
thrift_checksum_next="",
360401
thrift_schema="",
361402
thrift_schema_path="",
362403
)

torch/_export/serde/schema.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,8 +58,8 @@ class Device:
5858
@dataclass(repr=False)
5959
class SymExprHint(_Union):
6060
as_int: Annotated[int, 10]
61-
as_float: Annotated[float, 20]
62-
as_bool: Annotated[bool, 30]
61+
as_bool: Annotated[bool, 20]
62+
as_float: Annotated[float, 30]
6363

6464

6565
# This is for storing the symbolic expressions behind symints/symfloats/symbools

torch/_export/serde/schema.thrift

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
// @generated by update_schema.py
2-
// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>>
2+
// checksum<<0e89c5e620ad16c05bfe4fa2060ad43dcb0938dc31d77faad36b92f216c2c903>>
33

4-
namespace py3 torch._export.schema
4+
namespace py3 torch._export
55
namespace cpp2 torch._export.schema
66

77
enum Layout {
@@ -51,8 +51,8 @@ struct Device {
5151

5252
union SymExprHint {
5353
10: i64 as_int;
54-
20: double as_float;
55-
30: bool as_bool;
54+
20: bool as_bool;
55+
30: double as_float;
5656
}
5757

5858
struct SymExpr {

torch/_export/serde/schema.yaml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
# @generated by update_schema.py
2-
# checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>>
2+
# checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>>
33
Argument:
44
kind: union
55
fields:
@@ -380,10 +380,10 @@ SymExprHint:
380380
fields:
381381
as_int:
382382
type: int
383-
as_float:
384-
type: float
385383
as_bool:
386384
type: bool
385+
as_float:
386+
type: float
387387
SymFloat:
388388
kind: union
389389
fields:

torch/_export/serde/schema_check.py

Lines changed: 34 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def _staged_schema():
3131
thrift_type_defs: Dict[str, str] = {}
3232

3333
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
34-
def dump_type(t) -> Tuple[str, str, str]:
34+
def dump_type(t, level: int) -> Tuple[str, str, str]:
3535
CPP_TYPE_MAP = {
3636
str: "std::string",
3737
int: "int64_t",
@@ -90,20 +90,21 @@ def dump_type(t) -> Tuple[str, str, str]:
9090
"",
9191
)
9292
elif o == Union:
93+
assert level == 0, "Optional is only supported at the top level."
9394
args = typing.get_args(t)
9495
assert len(args) == 2 and args[1] == type(None)
95-
yaml_type, cpp_type, thrift_type = dump_type(args[0])
96+
yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1)
9697
return (
9798
f"Optional[{yaml_type}]",
9899
f"std::optional<{cpp_type}>",
99100
f"optional {thrift_type}",
100101
)
101102
elif o == Annotated:
102-
return dump_type(t.__origin__)
103+
return dump_type(t.__origin__, level)
103104
else:
104105
raise AssertionError(f"Type {t} is not supported in export schema.")
105106
yaml_arg_types, cpp_arg_types, thrift_arg_types = zip(
106-
*[dump_type(x) for x in typing.get_args(t)]
107+
*[dump_type(x, level + 1) for x in typing.get_args(t)]
107108
)
108109
return (
109110
(f"{yaml_head}[{', '.join(yaml_arg_types)}]"),
@@ -136,7 +137,7 @@ def dump_cpp_value(v) -> str:
136137
)
137138

138139
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str], str, int]:
139-
t, cpp_type, thrift_type = dump_type(f.type)
140+
t, cpp_type, thrift_type = dump_type(f.type, 0)
140141
ret = {"type": t}
141142
cpp_default: Optional[str] = None
142143
assert (
@@ -455,7 +456,7 @@ class ForwardRef {{
455456
}} // namespace torch
456457
"""
457458
thrift_schema = f"""
458-
namespace py3 torch._export.schema
459+
namespace py3 torch._export
459460
namespace cpp2 torch._export.schema
460461
{chr(10).join(thrift_enum_defs)}
461462
{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
@@ -528,21 +529,24 @@ def _diff_schema(dst, src):
528529
return additions, subtractions
529530

530531

531-
def _hash_schema(s):
532-
return hashlib.sha256(repr(s).encode("utf-8")).hexdigest()
532+
def _hash_content(s: str):
533+
return hashlib.sha256(s.strip().encode("utf-8")).hexdigest()
533534

534535

535536
@dataclasses.dataclass
536537
class _Commit:
537538
result: Dict[str, Any]
538-
checksum_result: str
539+
checksum_next: str
539540
yaml_path: str
540541
additions: Dict[str, Any]
541542
subtractions: Dict[str, Any]
542543
base: Dict[str, Any]
543-
checksum_base: Optional[str]
544+
checksum_head: Optional[str]
544545
cpp_header: str
545546
cpp_header_path: str
547+
thrift_checksum_head: Optional[str]
548+
thrift_checksum_real: Optional[str]
549+
thrift_checksum_next: str
546550
thrift_schema: str
547551
thrift_schema_path: str
548552

@@ -555,13 +559,26 @@ def update_schema():
555559
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
556560
_check(match is not None, "checksum not found in schema.yaml")
557561
assert match is not None
558-
checksum_base = match.group(1)
562+
checksum_head = match.group(1)
563+
564+
thrift_content = importlib.resources.read_text(__package__, "schema.thrift")
565+
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content)
566+
_check(match is not None, "checksum not found in schema.thrift")
567+
assert match is not None
568+
thrift_checksum_head = match.group(1)
569+
thrift_content = thrift_content.splitlines()
570+
assert thrift_content[0].startswith("// @" + "generated")
571+
assert thrift_content[1].startswith("// checksum<<")
572+
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))
573+
559574
from yaml import load, Loader
560575

561576
dst = load(content, Loader=Loader)
562577
assert isinstance(dst, dict)
563578
else:
564-
checksum_base = None
579+
checksum_head = None
580+
thrift_checksum_head = None
581+
thrift_checksum_real = None
565582
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
566583

567584
src, cpp_header, thrift_schema = _staged_schema()
@@ -574,14 +591,17 @@ def update_schema():
574591

575592
return _Commit(
576593
result=src,
577-
checksum_result=_hash_schema(src),
594+
checksum_next=_hash_content(repr(src)),
578595
yaml_path=yaml_path,
579596
additions=additions,
580597
subtractions=subtractions,
581598
base=dst,
582-
checksum_base=checksum_base,
599+
checksum_head=checksum_head,
583600
cpp_header=cpp_header,
584601
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
602+
thrift_checksum_head=thrift_checksum_head,
603+
thrift_checksum_real=thrift_checksum_real,
604+
thrift_checksum_next=_hash_content(thrift_schema),
585605
thrift_schema=thrift_schema,
586606
thrift_schema_path=thrift_schema_path,
587607
)

0 commit comments

Comments
 (0)