Skip to content

Commit 82201f6

Browse files
che-sh authored and facebook-github-bot committed
Re-land torchrec pipeline refactor (with fix) (#2853)
Summary: Reintroduces #2741 with fix. **Note**: restores the original for now, to make it easier to review the fix (next diff). Will merge the next diff (with the fix) into this before landing. Reviewed By: sarckk. Differential Revision: D71086956
1 parent f0ae23d commit 82201f6

File tree

4 files changed

+597
-576
lines changed

4 files changed

+597
-576
lines changed

torchrec/distributed/train_pipeline/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
_to_device, # noqa
2727
_wait_for_batch, # noqa
2828
ArgInfo, # noqa
29+
ArgInfoStepFactory, # noqa
30+
CallArgs, # noqa
2931
DataLoadingThread, # noqa
3032
In, # noqa
3133
Out, # noqa

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

+80-70
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,13 @@
6363
DataLoadingThread,
6464
EmbeddingPipelinedForward,
6565
get_h2d_func,
66+
GetAttrArgInfoStep,
67+
GetItemArgInfoStep,
68+
NoopArgInfoStep,
6669
PipelinedForward,
6770
PipelinedPostproc,
6871
PipelineStage,
72+
PostprocArgInfoStep,
6973
SparseDataDistUtil,
7074
StageOut,
7175
TrainPipelineContext,
@@ -1152,44 +1156,56 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
11521156
pipelined_weighted_ebc = pipeline._pipelined_modules[1]
11531157

11541158
# Check pipelined args
1155-
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
1156-
self.assertEqual(len(ebc.forward._args), 1)
1157-
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
1158-
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
1159-
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
1160-
self.assertIsInstance(
1161-
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
1162-
)
1163-
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
1164-
1159+
self.assertEqual(len(pipelined_ebc.forward._args.args), 1)
1160+
self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0)
11651161
self.assertEqual(
1166-
pipelined_ebc.forward._args[0].postproc_modules[0],
1167-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1168-
# `postproc_nonweighted`.
1169-
pipelined_model.module.postproc_nonweighted,
1162+
pipelined_ebc.forward._args.args[0].steps,
1163+
[
1164+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `postproc_nonweighted`.
1165+
PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted),
1166+
GetItemArgInfoStep(0),
1167+
],
11701168
)
1169+
self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1)
1170+
self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0)
11711171
self.assertEqual(
1172-
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
1173-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1174-
# `postproc_weighted`.
1175-
pipelined_model.module.postproc_weighted,
1172+
pipelined_weighted_ebc.forward._args.args[0].steps,
1173+
[
1174+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `postproc_weighted`.
1175+
PostprocArgInfoStep(pipelined_model.module.postproc_weighted),
1176+
GetItemArgInfoStep(0),
1177+
],
11761178
)
11771179

11781180
# postproc args
11791181
self.assertEqual(len(pipeline._pipelined_postprocs), 2)
1180-
input_attr_names = {"idlist_features", "idscore_features"}
1181-
for i in range(len(pipeline._pipelined_postprocs)):
1182-
postproc_mod = pipeline._pipelined_postprocs[i]
1183-
self.assertEqual(len(postproc_mod._args), 1)
1182+
# postprocs can be added in any order, so we can't assert on exact steps structures
1183+
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1)
1184+
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0)
1185+
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2)
1186+
self.assertEqual(
1187+
pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep()
1188+
)
1189+
self.assertIsInstance(
1190+
pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep
1191+
)
11841192

1185-
input_attr_name = postproc_mod._args[0].input_attrs[1]
1186-
self.assertTrue(input_attr_name in input_attr_names)
1187-
self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name])
1188-
input_attr_names.remove(input_attr_name)
1193+
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1)
1194+
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0)
1195+
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2)
1196+
self.assertEqual(
1197+
pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep()
1198+
)
1199+
self.assertIsInstance(
1200+
pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep
1201+
)
11891202

1190-
self.assertEqual(postproc_mod._args[0].is_getitems, [False, False])
1191-
# no parent postproc module in FX graph
1192-
self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None])
1203+
get_arg_infos = {
1204+
# pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep
1205+
postproc._args.args[0].steps[1].attr_name
1206+
for postproc in pipeline._pipelined_postprocs
1207+
}
1208+
self.assertEqual(get_arg_infos, {"idlist_features", "idscore_features"})
11931209

11941210
# pyre-ignore
11951211
@unittest.skipIf(
@@ -1235,69 +1251,63 @@ def test_pipeline_postproc_recursive(self) -> None:
12351251
pipelined_weighted_ebc = pipeline._pipelined_modules[1]
12361252

12371253
# Check pipelined args
1238-
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
1239-
self.assertEqual(len(ebc.forward._args), 1)
1240-
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
1241-
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
1242-
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
1243-
self.assertIsInstance(
1244-
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
1245-
)
1246-
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
1247-
1254+
self.assertEqual(len(pipelined_ebc.forward._args.args), 1)
1255+
self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0)
12481256
self.assertEqual(
1249-
pipelined_ebc.forward._args[0].postproc_modules[0],
1250-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1251-
# `postproc_nonweighted`.
1252-
pipelined_model.module.postproc_nonweighted,
1257+
pipelined_ebc.forward._args.args[0].steps,
1258+
[
1259+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `postproc_nonweighted`.
1260+
PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted),
1261+
GetItemArgInfoStep(0),
1262+
],
12531263
)
1264+
self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1)
1265+
self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0)
12541266
self.assertEqual(
1255-
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
1256-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1257-
# `postproc_weighted`.
1258-
pipelined_model.module.postproc_weighted,
1267+
pipelined_weighted_ebc.forward._args.args[0].steps,
1268+
[
1269+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `postproc_weighted`.
1270+
PostprocArgInfoStep(pipelined_model.module.postproc_weighted),
1271+
GetItemArgInfoStep(0),
1272+
],
12591273
)
12601274

12611275
# postproc args
12621276
self.assertEqual(len(pipeline._pipelined_postprocs), 3)
12631277

1264-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1265-
# `_postproc_module`.
1278+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
12661279
parent_postproc_mod = pipelined_model.module._postproc_module
12671280

12681281
for postproc_mod in pipeline._pipelined_postprocs:
12691282
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
12701283
# `postproc_nonweighted`.
12711284
if postproc_mod == pipelined_model.module.postproc_nonweighted:
1272-
self.assertEqual(len(postproc_mod._args), 1)
1273-
args = postproc_mod._args[0]
1274-
self.assertEqual(args.input_attrs, ["", "idlist_features"])
1275-
self.assertEqual(args.is_getitems, [False, False])
1276-
self.assertEqual(len(args.postproc_modules), 2)
1285+
self.assertEqual(len(postproc_mod._args.args), 1)
1286+
self.assertEqual(len(postproc_mod._args.kwargs), 0)
12771287
self.assertEqual(
1278-
args.postproc_modules[0],
1279-
parent_postproc_mod,
1288+
postproc_mod._args.args[0].steps,
1289+
[
1290+
PostprocArgInfoStep(parent_postproc_mod),
1291+
GetAttrArgInfoStep("idlist_features"),
1292+
],
12801293
)
1281-
self.assertEqual(args.postproc_modules[1], None)
1294+
12821295
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
12831296
# `postproc_weighted`.
12841297
elif postproc_mod == pipelined_model.module.postproc_weighted:
1285-
self.assertEqual(len(postproc_mod._args), 1)
1286-
args = postproc_mod._args[0]
1287-
self.assertEqual(args.input_attrs, ["", "idscore_features"])
1288-
self.assertEqual(args.is_getitems, [False, False])
1289-
self.assertEqual(len(args.postproc_modules), 2)
1298+
self.assertEqual(len(postproc_mod._args.args), 1)
1299+
self.assertEqual(len(postproc_mod._args.kwargs), 0)
12901300
self.assertEqual(
1291-
args.postproc_modules[0],
1292-
parent_postproc_mod,
1301+
postproc_mod._args.args[0].steps,
1302+
[
1303+
PostprocArgInfoStep(parent_postproc_mod),
1304+
GetAttrArgInfoStep("idscore_features"),
1305+
],
12931306
)
1294-
self.assertEqual(args.postproc_modules[1], None)
12951307
elif postproc_mod == parent_postproc_mod:
1296-
self.assertEqual(len(postproc_mod._args), 1)
1297-
args = postproc_mod._args[0]
1298-
self.assertEqual(args.input_attrs, [""])
1299-
self.assertEqual(args.is_getitems, [False])
1300-
self.assertEqual(args.postproc_modules, [None])
1308+
self.assertEqual(len(postproc_mod._args.args), 1)
1309+
self.assertEqual(len(postproc_mod._args.kwargs), 0)
1310+
self.assertEqual(postproc_mod._args.args[0].steps, [NoopArgInfoStep()])
13011311

13021312
# pyre-ignore
13031313
@unittest.skipIf(

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

+42-70
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
TrainPipelineSparseDistTestBase,
2424
)
2525
from torchrec.distributed.train_pipeline.utils import (
26-
_build_args_kwargs,
27-
_get_node_args,
2826
_rewrite_model,
2927
ArgInfo,
28+
ArgInfoStepFactory,
29+
CallArgs,
30+
NodeArgsHelper,
3031
PipelinedForward,
3132
PipelinedPostproc,
3233
TrainPipelineContext,
@@ -110,17 +111,19 @@ def test_rewrite_model(self) -> None:
110111
self.assertEqual(
111112
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
112113
# `sparse`.
113-
sharded_model.module.sparse.ebc.forward._args[0].postproc_modules[0],
114+
sharded_model.module.sparse.ebc.forward._args.args[0]
115+
.steps[0]
116+
.postproc_module,
114117
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
115118
# `postproc_module`.
116119
sharded_model.module.postproc_module,
117120
)
118121
self.assertEqual(
119122
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
120123
# `sparse`.
121-
sharded_model.module.sparse.weighted_ebc.forward._args[0].postproc_modules[
122-
0
123-
],
124+
sharded_model.module.sparse.weighted_ebc.forward._args.args[0]
125+
.steps[0]
126+
.postproc_module,
124127
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
125128
# `postproc_module`.
126129
sharded_model.module.postproc_module,
@@ -154,7 +157,7 @@ def forward(self, x):
154157
rewritten_model.test_module = PipelinedPostproc(
155158
postproc_module=rewritten_model.test_module,
156159
fqn="test_module",
157-
args=[],
160+
args=CallArgs(args=[], kwargs={}),
158161
context=TrainPipelineContext(),
159162
default_stream=MagicMock(),
160163
dist_stream=MagicMock(),
@@ -260,83 +263,53 @@ def test_restore_from_snapshot(self) -> None:
260263
@parameterized.expand(
261264
[
262265
(
263-
[
264-
# Empty attrs to ignore any attr based logic.
265-
ArgInfo(
266-
input_attrs=[
267-
"",
268-
],
269-
is_getitems=[False],
270-
postproc_modules=[None],
271-
constants=[None],
272-
name="id_list_features",
273-
),
274-
ArgInfo(
275-
input_attrs=[],
276-
is_getitems=[],
277-
postproc_modules=[],
278-
constants=[],
279-
name="id_score_list_features",
280-
),
281-
],
266+
CallArgs(
267+
args=[],
268+
kwargs={
269+
"id_list_features": ArgInfo(steps=[ArgInfoStepFactory.noop()]),
270+
# Empty attrs to ignore any attr based logic.
271+
"id_score_list_features": ArgInfo(
272+
steps=[ArgInfoStepFactory.noop()]
273+
),
274+
},
275+
),
282276
0,
283277
["id_list_features", "id_score_list_features"],
284278
),
285279
(
286-
[
287-
# Empty attrs to ignore any attr based logic.
288-
ArgInfo(
289-
input_attrs=[
290-
"",
291-
],
292-
is_getitems=[False],
293-
postproc_modules=[None],
294-
constants=[None],
295-
name=None,
296-
),
297-
ArgInfo(
298-
input_attrs=[],
299-
is_getitems=[],
300-
postproc_modules=[],
301-
constants=[],
302-
name=None,
303-
),
304-
],
280+
CallArgs(
281+
args=[
282+
# Empty attrs to ignore any attr based logic.
283+
ArgInfo(steps=[ArgInfoStepFactory.noop()]),
284+
ArgInfo(steps=[]),
285+
],
286+
kwargs={},
287+
),
305288
2,
306289
[],
307290
),
308291
(
309-
[
310-
# Empty attrs to ignore any attr based logic.
311-
ArgInfo(
312-
input_attrs=[
313-
"",
314-
],
315-
is_getitems=[False],
316-
postproc_modules=[None],
317-
constants=[None],
318-
name=None,
319-
),
320-
ArgInfo(
321-
input_attrs=[],
322-
is_getitems=[],
323-
postproc_modules=[],
324-
constants=[],
325-
name="id_score_list_features",
326-
),
327-
],
292+
CallArgs(
293+
args=[
294+
# Empty attrs to ignore any attr based logic.
295+
ArgInfo(
296+
steps=[ArgInfoStepFactory.noop()],
297+
)
298+
],
299+
kwargs={"id_score_list_features": ArgInfo(steps=[])},
300+
),
328301
1,
329302
["id_score_list_features"],
330303
),
331304
]
332305
)
333306
def test_build_args_kwargs(
334307
self,
335-
fwd_args: List[ArgInfo],
308+
fwd_args: CallArgs,
336309
args_len: int,
337310
kwarges_keys: List[str],
338311
) -> None:
339-
args, kwargs = _build_args_kwargs("initial_input", fwd_args)
312+
args, kwargs = fwd_args.build_args_kwargs("initial_input")
340313
self.assertEqual(len(args), args_len)
341314
self.assertEqual(list(kwargs.keys()), kwarges_keys)
342315

@@ -367,10 +340,9 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:
367340
{},
368341
)
369342

370-
num_found = 0
371-
_, num_found = _get_node_args(
372-
MagicMock(), kjt_node, set(), TrainPipelineContext(), False
373-
)
343+
node_args_helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), False)
344+
345+
_, num_found = node_args_helper.get_node_args(kjt_node)
374346

375347
# Weights is call_module node, so we should only find 2 args unmodified
376348
self.assertEqual(num_found, len(kjt_args) - 1)

0 commit comments

Comments
 (0)