Skip to content

Commit 65efa1e

Browse files
che-shfacebook-github-bot
authored andcommitted
Re-land torchrec pipeline refactor (with fix) (#2853)
Summary: Reintroduces #2741 with fix **Note**: restores original for now, to make it easier to review the fix (next diff). Will merge the next diff (with the fix) into this before landing Reviewed By: sarckk, TroyGarden Differential Revision: D71086956
1 parent 54723ee commit 65efa1e

File tree

4 files changed

+602
-581
lines changed

4 files changed

+602
-581
lines changed

torchrec/distributed/train_pipeline/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
_to_device, # noqa
2828
_wait_for_batch, # noqa
2929
ArgInfo, # noqa
30+
ArgInfoStepFactory, # noqa
31+
CallArgs, # noqa
3032
DataLoadingThread, # noqa
3133
In, # noqa
3234
Out, # noqa

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 80 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,13 @@
6464
DataLoadingThread,
6565
EmbeddingPipelinedForward,
6666
get_h2d_func,
67+
GetAttrArgInfoStep,
68+
GetItemArgInfoStep,
69+
NoopArgInfoStep,
6770
PipelinedForward,
6871
PipelinedPostproc,
6972
PipelineStage,
73+
PostprocArgInfoStep,
7074
SparseDataDistUtil,
7175
StageOut,
7276
TrainPipelineContext,
@@ -1254,44 +1258,56 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
12541258
pipelined_weighted_ebc = pipeline._pipelined_modules[1]
12551259

12561260
# Check pipelined args
1257-
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
1258-
self.assertEqual(len(ebc.forward._args), 1)
1259-
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
1260-
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
1261-
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
1262-
self.assertIsInstance(
1263-
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
1264-
)
1265-
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
1266-
1261+
self.assertEqual(len(pipelined_ebc.forward._args.args), 1)
1262+
self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0)
12671263
self.assertEqual(
1268-
pipelined_ebc.forward._args[0].postproc_modules[0],
1269-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1270-
# `postproc_nonweighted`.
1271-
pipelined_model.module.postproc_nonweighted,
1264+
pipelined_ebc.forward._args.args[0].steps,
1265+
[
1266+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
1267+
PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted),
1268+
GetItemArgInfoStep(0),
1269+
],
12721270
)
1271+
self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1)
1272+
self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0)
12731273
self.assertEqual(
1274-
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
1275-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1276-
# `postproc_weighted`.
1277-
pipelined_model.module.postproc_weighted,
1274+
pipelined_weighted_ebc.forward._args.args[0].steps,
1275+
[
1276+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
1277+
PostprocArgInfoStep(pipelined_model.module.postproc_weighted),
1278+
GetItemArgInfoStep(0),
1279+
],
12781280
)
12791281

12801282
# postproc args
12811283
self.assertEqual(len(pipeline._pipelined_postprocs), 2)
1282-
input_attr_names = {"idlist_features", "idscore_features"}
1283-
for i in range(len(pipeline._pipelined_postprocs)):
1284-
postproc_mod = pipeline._pipelined_postprocs[i]
1285-
self.assertEqual(len(postproc_mod._args), 1)
1284+
# postprocs can be added in any order, so we can't assert on exact steps structures
1285+
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1)
1286+
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0)
1287+
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2)
1288+
self.assertEqual(
1289+
pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep()
1290+
)
1291+
self.assertIsInstance(
1292+
pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep
1293+
)
12861294

1287-
input_attr_name = postproc_mod._args[0].input_attrs[1]
1288-
self.assertTrue(input_attr_name in input_attr_names)
1289-
self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name])
1290-
input_attr_names.remove(input_attr_name)
1295+
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1)
1296+
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0)
1297+
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2)
1298+
self.assertEqual(
1299+
pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep()
1300+
)
1301+
self.assertIsInstance(
1302+
pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep
1303+
)
12911304

1292-
self.assertEqual(postproc_mod._args[0].is_getitems, [False, False])
1293-
# no parent postproc module in FX graph
1294-
self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None])
1305+
get_arg_infos = {
1306+
# pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep
1307+
postproc._args.args[0].steps[1].attr_name
1308+
for postproc in pipeline._pipelined_postprocs
1309+
}
1310+
self.assertEqual(get_arg_infos, {"idlist_features", "idscore_features"})
12951311

12961312
# pyre-ignore
12971313
@unittest.skipIf(
@@ -1337,69 +1353,63 @@ def test_pipeline_postproc_recursive(self) -> None:
13371353
pipelined_weighted_ebc = pipeline._pipelined_modules[1]
13381354

13391355
# Check pipelined args
1340-
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
1341-
self.assertEqual(len(ebc.forward._args), 1)
1342-
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
1343-
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
1344-
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
1345-
self.assertIsInstance(
1346-
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
1347-
)
1348-
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
1349-
1356+
self.assertEqual(len(pipelined_ebc.forward._args.args), 1)
1357+
self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0)
13501358
self.assertEqual(
1351-
pipelined_ebc.forward._args[0].postproc_modules[0],
1352-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1353-
# `postproc_nonweighted`.
1354-
pipelined_model.module.postproc_nonweighted,
1359+
pipelined_ebc.forward._args.args[0].steps,
1360+
[
1361+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
1362+
PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted),
1363+
GetItemArgInfoStep(0),
1364+
],
13551365
)
1366+
self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1)
1367+
self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0)
13561368
self.assertEqual(
1357-
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
1358-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1359-
# `postproc_weighted`.
1360-
pipelined_model.module.postproc_weighted,
1369+
pipelined_weighted_ebc.forward._args.args[0].steps,
1370+
[
1371+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
1372+
PostprocArgInfoStep(pipelined_model.module.postproc_weighted),
1373+
GetItemArgInfoStep(0),
1374+
],
13611375
)
13621376

13631377
# postproc args
13641378
self.assertEqual(len(pipeline._pipelined_postprocs), 3)
13651379

1366-
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
1367-
# `_postproc_module`.
1380+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
13681381
parent_postproc_mod = pipelined_model.module._postproc_module
13691382

13701383
for postproc_mod in pipeline._pipelined_postprocs:
13711384
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
13721385
# `postproc_nonweighted`.
13731386
if postproc_mod == pipelined_model.module.postproc_nonweighted:
1374-
self.assertEqual(len(postproc_mod._args), 1)
1375-
args = postproc_mod._args[0]
1376-
self.assertEqual(args.input_attrs, ["", "idlist_features"])
1377-
self.assertEqual(args.is_getitems, [False, False])
1378-
self.assertEqual(len(args.postproc_modules), 2)
1387+
self.assertEqual(len(postproc_mod._args.args), 1)
1388+
self.assertEqual(len(postproc_mod._args.kwargs), 0)
13791389
self.assertEqual(
1380-
args.postproc_modules[0],
1381-
parent_postproc_mod,
1390+
postproc_mod._args.args[0].steps,
1391+
[
1392+
PostprocArgInfoStep(parent_postproc_mod),
1393+
GetAttrArgInfoStep("idlist_features"),
1394+
],
13821395
)
1383-
self.assertEqual(args.postproc_modules[1], None)
1396+
13841397
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
13851398
# `postproc_weighted`.
13861399
elif postproc_mod == pipelined_model.module.postproc_weighted:
1387-
self.assertEqual(len(postproc_mod._args), 1)
1388-
args = postproc_mod._args[0]
1389-
self.assertEqual(args.input_attrs, ["", "idscore_features"])
1390-
self.assertEqual(args.is_getitems, [False, False])
1391-
self.assertEqual(len(args.postproc_modules), 2)
1400+
self.assertEqual(len(postproc_mod._args.args), 1)
1401+
self.assertEqual(len(postproc_mod._args.kwargs), 0)
13921402
self.assertEqual(
1393-
args.postproc_modules[0],
1394-
parent_postproc_mod,
1403+
postproc_mod._args.args[0].steps,
1404+
[
1405+
PostprocArgInfoStep(parent_postproc_mod),
1406+
GetAttrArgInfoStep("idscore_features"),
1407+
],
13951408
)
1396-
self.assertEqual(args.postproc_modules[1], None)
13971409
elif postproc_mod == parent_postproc_mod:
1398-
self.assertEqual(len(postproc_mod._args), 1)
1399-
args = postproc_mod._args[0]
1400-
self.assertEqual(args.input_attrs, [""])
1401-
self.assertEqual(args.is_getitems, [False])
1402-
self.assertEqual(args.postproc_modules, [None])
1410+
self.assertEqual(len(postproc_mod._args.args), 1)
1411+
self.assertEqual(len(postproc_mod._args.kwargs), 0)
1412+
self.assertEqual(postproc_mod._args.args[0].steps, [NoopArgInfoStep()])
14031413

14041414
# pyre-ignore
14051415
@unittest.skipIf(

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 42 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
TrainPipelineSparseDistTestBase,
2424
)
2525
from torchrec.distributed.train_pipeline.utils import (
26-
_build_args_kwargs,
27-
_get_node_args,
2826
_rewrite_model,
2927
ArgInfo,
28+
ArgInfoStepFactory,
29+
CallArgs,
30+
NodeArgsHelper,
3031
PipelinedForward,
3132
PipelinedPostproc,
3233
TrainPipelineContext,
@@ -110,17 +111,19 @@ def test_rewrite_model(self) -> None:
110111
self.assertEqual(
111112
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
112113
# `sparse`.
113-
sharded_model.module.sparse.ebc.forward._args[0].postproc_modules[0],
114+
sharded_model.module.sparse.ebc.forward._args.args[0]
115+
.steps[0]
116+
.postproc_module,
114117
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
115118
# `postproc_module`.
116119
sharded_model.module.postproc_module,
117120
)
118121
self.assertEqual(
119122
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
120123
# `sparse`.
121-
sharded_model.module.sparse.weighted_ebc.forward._args[0].postproc_modules[
122-
0
123-
],
124+
sharded_model.module.sparse.weighted_ebc.forward._args.args[0]
125+
.steps[0]
126+
.postproc_module,
124127
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
125128
# `postproc_module`.
126129
sharded_model.module.postproc_module,
@@ -154,7 +157,7 @@ def forward(self, x):
154157
rewritten_model.test_module = PipelinedPostproc(
155158
postproc_module=rewritten_model.test_module,
156159
fqn="test_module",
157-
args=[],
160+
args=CallArgs(args=[], kwargs={}),
158161
context=TrainPipelineContext(),
159162
default_stream=MagicMock(),
160163
dist_stream=MagicMock(),
@@ -260,83 +263,53 @@ def test_restore_from_snapshot(self) -> None:
260263
@parameterized.expand(
261264
[
262265
(
263-
[
264-
# Empty attrs to ignore any attr based logic.
265-
ArgInfo(
266-
input_attrs=[
267-
"",
268-
],
269-
is_getitems=[False],
270-
postproc_modules=[None],
271-
constants=[None],
272-
name="id_list_features",
273-
),
274-
ArgInfo(
275-
input_attrs=[],
276-
is_getitems=[],
277-
postproc_modules=[],
278-
constants=[],
279-
name="id_score_list_features",
280-
),
281-
],
266+
CallArgs(
267+
args=[],
268+
kwargs={
269+
"id_list_features": ArgInfo(steps=[ArgInfoStepFactory.noop()]),
270+
# Empty attrs to ignore any attr based logic.
271+
"id_score_list_features": ArgInfo(
272+
steps=[ArgInfoStepFactory.noop()]
273+
),
274+
},
275+
),
282276
0,
283277
["id_list_features", "id_score_list_features"],
284278
),
285279
(
286-
[
287-
# Empty attrs to ignore any attr based logic.
288-
ArgInfo(
289-
input_attrs=[
290-
"",
291-
],
292-
is_getitems=[False],
293-
postproc_modules=[None],
294-
constants=[None],
295-
name=None,
296-
),
297-
ArgInfo(
298-
input_attrs=[],
299-
is_getitems=[],
300-
postproc_modules=[],
301-
constants=[],
302-
name=None,
303-
),
304-
],
280+
CallArgs(
281+
args=[
282+
# Empty attrs to ignore any attr based logic.
283+
ArgInfo(steps=[ArgInfoStepFactory.noop()]),
284+
ArgInfo(steps=[]),
285+
],
286+
kwargs={},
287+
),
305288
2,
306289
[],
307290
),
308291
(
309-
[
310-
# Empty attrs to ignore any attr based logic.
311-
ArgInfo(
312-
input_attrs=[
313-
"",
314-
],
315-
is_getitems=[False],
316-
postproc_modules=[None],
317-
constants=[None],
318-
name=None,
319-
),
320-
ArgInfo(
321-
input_attrs=[],
322-
is_getitems=[],
323-
postproc_modules=[],
324-
constants=[],
325-
name="id_score_list_features",
326-
),
327-
],
292+
CallArgs(
293+
args=[
294+
# Empty attrs to ignore any attr based logic.
295+
ArgInfo(
296+
steps=[ArgInfoStepFactory.noop()],
297+
)
298+
],
299+
kwargs={"id_score_list_features": ArgInfo(steps=[])},
300+
),
328301
1,
329302
["id_score_list_features"],
330303
),
331304
]
332305
)
333306
def test_build_args_kwargs(
334307
self,
335-
fwd_args: List[ArgInfo],
308+
fwd_args: CallArgs,
336309
args_len: int,
337310
kwarges_keys: List[str],
338311
) -> None:
339-
args, kwargs = _build_args_kwargs("initial_input", fwd_args)
312+
args, kwargs = fwd_args.build_args_kwargs("initial_input")
340313
self.assertEqual(len(args), args_len)
341314
self.assertEqual(list(kwargs.keys()), kwarges_keys)
342315

@@ -367,10 +340,9 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:
367340
{},
368341
)
369342

370-
num_found = 0
371-
_, num_found = _get_node_args(
372-
MagicMock(), kjt_node, set(), TrainPipelineContext(), False
373-
)
343+
node_args_helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), False)
344+
345+
_, num_found = node_args_helper.get_node_args(kjt_node)
374346

375347
# Weights is call_module node, so we should only find 2 args unmodified
376348
self.assertEqual(num_found, len(kjt_args) - 1)

0 commit comments

Comments
 (0)