|
64 | 64 | DataLoadingThread,
|
65 | 65 | EmbeddingPipelinedForward,
|
66 | 66 | get_h2d_func,
|
| 67 | + GetAttrArgInfoStep, |
| 68 | + GetItemArgInfoStep, |
| 69 | + NoopArgInfoStep, |
67 | 70 | PipelinedForward,
|
68 | 71 | PipelinedPostproc,
|
69 | 72 | PipelineStage,
|
| 73 | + PostprocArgInfoStep, |
70 | 74 | SparseDataDistUtil,
|
71 | 75 | StageOut,
|
72 | 76 | TrainPipelineContext,
|
@@ -1254,44 +1258,56 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
|
1254 | 1258 | pipelined_weighted_ebc = pipeline._pipelined_modules[1]
|
1255 | 1259 |
|
1256 | 1260 | # Check pipelined args
|
1257 |
| - for ebc in [pipelined_ebc, pipelined_weighted_ebc]: |
1258 |
| - self.assertEqual(len(ebc.forward._args), 1) |
1259 |
| - self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) |
1260 |
| - self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) |
1261 |
| - self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) |
1262 |
| - self.assertIsInstance( |
1263 |
| - ebc.forward._args[0].postproc_modules[0], PipelinedPostproc |
1264 |
| - ) |
1265 |
| - self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) |
1266 |
| - |
| 1261 | + self.assertEqual(len(pipelined_ebc.forward._args.args), 1) |
| 1262 | + self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0) |
1267 | 1263 | self.assertEqual(
|
1268 |
| - pipelined_ebc.forward._args[0].postproc_modules[0], |
1269 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1270 |
| - # `postproc_nonweighted`. |
1271 |
| - pipelined_model.module.postproc_nonweighted, |
| 1264 | + pipelined_ebc.forward._args.args[0].steps, |
| 1265 | + [ |
| 1266 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1267 | + PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted), |
| 1268 | + GetItemArgInfoStep(0), |
| 1269 | + ], |
1272 | 1270 | )
|
| 1271 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1) |
| 1272 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0) |
1273 | 1273 | self.assertEqual(
|
1274 |
| - pipelined_weighted_ebc.forward._args[0].postproc_modules[0], |
1275 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1276 |
| - # `postproc_weighted`. |
1277 |
| - pipelined_model.module.postproc_weighted, |
| 1274 | + pipelined_weighted_ebc.forward._args.args[0].steps, |
| 1275 | + [ |
| 1276 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1277 | + PostprocArgInfoStep(pipelined_model.module.postproc_weighted), |
| 1278 | + GetItemArgInfoStep(0), |
| 1279 | + ], |
1278 | 1280 | )
|
1279 | 1281 |
|
1280 | 1282 | # postproc args
|
1281 | 1283 | self.assertEqual(len(pipeline._pipelined_postprocs), 2)
|
1282 |
| - input_attr_names = {"idlist_features", "idscore_features"} |
1283 |
| - for i in range(len(pipeline._pipelined_postprocs)): |
1284 |
| - postproc_mod = pipeline._pipelined_postprocs[i] |
1285 |
| - self.assertEqual(len(postproc_mod._args), 1) |
| 1284 | + # postprocs can be added in any order, so we can't assert on exact steps structures |
| 1285 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1) |
| 1286 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0) |
| 1287 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2) |
| 1288 | + self.assertEqual( |
| 1289 | + pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep() |
| 1290 | + ) |
| 1291 | + self.assertIsInstance( |
| 1292 | + pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep |
| 1293 | + ) |
1286 | 1294 |
|
1287 |
| - input_attr_name = postproc_mod._args[0].input_attrs[1] |
1288 |
| - self.assertTrue(input_attr_name in input_attr_names) |
1289 |
| - self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name]) |
1290 |
| - input_attr_names.remove(input_attr_name) |
| 1295 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1) |
| 1296 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0) |
| 1297 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2) |
| 1298 | + self.assertEqual( |
| 1299 | + pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep() |
| 1300 | + ) |
| 1301 | + self.assertIsInstance( |
| 1302 | + pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep |
| 1303 | + ) |
1291 | 1304 |
|
1292 |
| - self.assertEqual(postproc_mod._args[0].is_getitems, [False, False]) |
1293 |
| - # no parent postproc module in FX graph |
1294 |
| - self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None]) |
| 1305 | + get_arg_infos = { |
| 1306 | + # pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep |
| 1307 | + postproc._args.args[0].steps[1].attr_name |
| 1308 | + for postproc in pipeline._pipelined_postprocs |
| 1309 | + } |
| 1310 | + self.assertEqual(get_arg_infos, {"idlist_features", "idscore_features"}) |
1295 | 1311 |
|
1296 | 1312 | # pyre-ignore
|
1297 | 1313 | @unittest.skipIf(
|
@@ -1337,69 +1353,63 @@ def test_pipeline_postproc_recursive(self) -> None:
|
1337 | 1353 | pipelined_weighted_ebc = pipeline._pipelined_modules[1]
|
1338 | 1354 |
|
1339 | 1355 | # Check pipelined args
|
1340 |
| - for ebc in [pipelined_ebc, pipelined_weighted_ebc]: |
1341 |
| - self.assertEqual(len(ebc.forward._args), 1) |
1342 |
| - self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) |
1343 |
| - self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) |
1344 |
| - self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) |
1345 |
| - self.assertIsInstance( |
1346 |
| - ebc.forward._args[0].postproc_modules[0], PipelinedPostproc |
1347 |
| - ) |
1348 |
| - self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) |
1349 |
| - |
| 1356 | + self.assertEqual(len(pipelined_ebc.forward._args.args), 1) |
| 1357 | + self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0) |
1350 | 1358 | self.assertEqual(
|
1351 |
| - pipelined_ebc.forward._args[0].postproc_modules[0], |
1352 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1353 |
| - # `postproc_nonweighted`. |
1354 |
| - pipelined_model.module.postproc_nonweighted, |
| 1359 | + pipelined_ebc.forward._args.args[0].steps, |
| 1360 | + [ |
| 1361 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1362 | + PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted), |
| 1363 | + GetItemArgInfoStep(0), |
| 1364 | + ], |
1355 | 1365 | )
|
| 1366 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1) |
| 1367 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0) |
1356 | 1368 | self.assertEqual(
|
1357 |
| - pipelined_weighted_ebc.forward._args[0].postproc_modules[0], |
1358 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1359 |
| - # `postproc_weighted`. |
1360 |
| - pipelined_model.module.postproc_weighted, |
| 1369 | + pipelined_weighted_ebc.forward._args.args[0].steps, |
| 1370 | + [ |
| 1371 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1372 | + PostprocArgInfoStep(pipelined_model.module.postproc_weighted), |
| 1373 | + GetItemArgInfoStep(0), |
| 1374 | + ], |
1361 | 1375 | )
|
1362 | 1376 |
|
1363 | 1377 | # postproc args
|
1364 | 1378 | self.assertEqual(len(pipeline._pipelined_postprocs), 3)
|
1365 | 1379 |
|
1366 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1367 |
| - # `_postproc_module`. |
| 1380 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
1368 | 1381 | parent_postproc_mod = pipelined_model.module._postproc_module
|
1369 | 1382 |
|
1370 | 1383 | for postproc_mod in pipeline._pipelined_postprocs:
|
1371 | 1384 | # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
1372 | 1385 | # `postproc_nonweighted`.
|
1373 | 1386 | if postproc_mod == pipelined_model.module.postproc_nonweighted:
|
1374 |
| - self.assertEqual(len(postproc_mod._args), 1) |
1375 |
| - args = postproc_mod._args[0] |
1376 |
| - self.assertEqual(args.input_attrs, ["", "idlist_features"]) |
1377 |
| - self.assertEqual(args.is_getitems, [False, False]) |
1378 |
| - self.assertEqual(len(args.postproc_modules), 2) |
| 1387 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1388 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
1379 | 1389 | self.assertEqual(
|
1380 |
| - args.postproc_modules[0], |
1381 |
| - parent_postproc_mod, |
| 1390 | + postproc_mod._args.args[0].steps, |
| 1391 | + [ |
| 1392 | + PostprocArgInfoStep(parent_postproc_mod), |
| 1393 | + GetAttrArgInfoStep("idlist_features"), |
| 1394 | + ], |
1382 | 1395 | )
|
1383 |
| - self.assertEqual(args.postproc_modules[1], None) |
| 1396 | + |
1384 | 1397 | # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
1385 | 1398 | # `postproc_weighted`.
|
1386 | 1399 | elif postproc_mod == pipelined_model.module.postproc_weighted:
|
1387 |
| - self.assertEqual(len(postproc_mod._args), 1) |
1388 |
| - args = postproc_mod._args[0] |
1389 |
| - self.assertEqual(args.input_attrs, ["", "idscore_features"]) |
1390 |
| - self.assertEqual(args.is_getitems, [False, False]) |
1391 |
| - self.assertEqual(len(args.postproc_modules), 2) |
| 1400 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1401 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
1392 | 1402 | self.assertEqual(
|
1393 |
| - args.postproc_modules[0], |
1394 |
| - parent_postproc_mod, |
| 1403 | + postproc_mod._args.args[0].steps, |
| 1404 | + [ |
| 1405 | + PostprocArgInfoStep(parent_postproc_mod), |
| 1406 | + GetAttrArgInfoStep("idscore_features"), |
| 1407 | + ], |
1395 | 1408 | )
|
1396 |
| - self.assertEqual(args.postproc_modules[1], None) |
1397 | 1409 | elif postproc_mod == parent_postproc_mod:
|
1398 |
| - self.assertEqual(len(postproc_mod._args), 1) |
1399 |
| - args = postproc_mod._args[0] |
1400 |
| - self.assertEqual(args.input_attrs, [""]) |
1401 |
| - self.assertEqual(args.is_getitems, [False]) |
1402 |
| - self.assertEqual(args.postproc_modules, [None]) |
| 1410 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1411 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
| 1412 | + self.assertEqual(postproc_mod._args.args[0].steps, [NoopArgInfoStep()]) |
1403 | 1413 |
|
1404 | 1414 | # pyre-ignore
|
1405 | 1415 | @unittest.skipIf(
|
|
0 commit comments