|
63 | 63 | DataLoadingThread,
|
64 | 64 | EmbeddingPipelinedForward,
|
65 | 65 | get_h2d_func,
|
| 66 | + GetAttrArgInfoStep, |
| 67 | + GetItemArgInfoStep, |
| 68 | + NoopArgInfoStep, |
66 | 69 | PipelinedForward,
|
67 | 70 | PipelinedPostproc,
|
68 | 71 | PipelineStage,
|
| 72 | + PostprocArgInfoStep, |
69 | 73 | SparseDataDistUtil,
|
70 | 74 | StageOut,
|
71 | 75 | TrainPipelineContext,
|
@@ -1152,44 +1156,56 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
|
1152 | 1156 | pipelined_weighted_ebc = pipeline._pipelined_modules[1]
|
1153 | 1157 |
|
1154 | 1158 | # Check pipelined args
|
1155 |
| - for ebc in [pipelined_ebc, pipelined_weighted_ebc]: |
1156 |
| - self.assertEqual(len(ebc.forward._args), 1) |
1157 |
| - self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) |
1158 |
| - self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) |
1159 |
| - self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) |
1160 |
| - self.assertIsInstance( |
1161 |
| - ebc.forward._args[0].postproc_modules[0], PipelinedPostproc |
1162 |
| - ) |
1163 |
| - self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) |
1164 |
| - |
| 1159 | + self.assertEqual(len(pipelined_ebc.forward._args.args), 1) |
| 1160 | + self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0) |
1165 | 1161 | self.assertEqual(
|
1166 |
| - pipelined_ebc.forward._args[0].postproc_modules[0], |
1167 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1168 |
| - # `postproc_nonweighted`. |
1169 |
| - pipelined_model.module.postproc_nonweighted, |
| 1162 | + pipelined_ebc.forward._args.args[0].steps, |
| 1163 | + [ |
| 1164 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1165 | + PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted), |
| 1166 | + GetItemArgInfoStep(0), |
| 1167 | + ], |
1170 | 1168 | )
|
| 1169 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1) |
| 1170 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0) |
1171 | 1171 | self.assertEqual(
|
1172 |
| - pipelined_weighted_ebc.forward._args[0].postproc_modules[0], |
1173 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1174 |
| - # `postproc_weighted`. |
1175 |
| - pipelined_model.module.postproc_weighted, |
| 1172 | + pipelined_weighted_ebc.forward._args.args[0].steps, |
| 1173 | + [ |
| 1174 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1175 | + PostprocArgInfoStep(pipelined_model.module.postproc_weighted), |
| 1176 | + GetItemArgInfoStep(0), |
| 1177 | + ], |
1176 | 1178 | )
|
1177 | 1179 |
|
1178 | 1180 | # postproc args
|
1179 | 1181 | self.assertEqual(len(pipeline._pipelined_postprocs), 2)
|
1180 |
| - input_attr_names = {"idlist_features", "idscore_features"} |
1181 |
| - for i in range(len(pipeline._pipelined_postprocs)): |
1182 |
| - postproc_mod = pipeline._pipelined_postprocs[i] |
1183 |
| - self.assertEqual(len(postproc_mod._args), 1) |
| 1182 | + # postprocs can be added in any order, so we can't assert on exact steps structures |
| 1183 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1) |
| 1184 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0) |
| 1185 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2) |
| 1186 | + self.assertEqual( |
| 1187 | + pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep() |
| 1188 | + ) |
| 1189 | + self.assertIsInstance( |
| 1190 | + pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep |
| 1191 | + ) |
1184 | 1192 |
|
1185 |
| - input_attr_name = postproc_mod._args[0].input_attrs[1] |
1186 |
| - self.assertTrue(input_attr_name in input_attr_names) |
1187 |
| - self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name]) |
1188 |
| - input_attr_names.remove(input_attr_name) |
| 1193 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1) |
| 1194 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0) |
| 1195 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2) |
| 1196 | + self.assertEqual( |
| 1197 | + pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep() |
| 1198 | + ) |
| 1199 | + self.assertIsInstance( |
| 1200 | + pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep |
| 1201 | + ) |
1189 | 1202 |
|
1190 |
| - self.assertEqual(postproc_mod._args[0].is_getitems, [False, False]) |
1191 |
| - # no parent postproc module in FX graph |
1192 |
| - self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None]) |
| 1203 | + get_arg_infos = { |
| 1204 | + # pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep |
| 1205 | + postproc._args.args[0].steps[1].attr_name |
| 1206 | + for postproc in pipeline._pipelined_postprocs |
| 1207 | + } |
| 1208 | + self.assertEqual(get_arg_infos, {"idlist_features", "idscore_features"}) |
1193 | 1209 |
|
1194 | 1210 | # pyre-ignore
|
1195 | 1211 | @unittest.skipIf(
|
@@ -1235,69 +1251,63 @@ def test_pipeline_postproc_recursive(self) -> None:
|
1235 | 1251 | pipelined_weighted_ebc = pipeline._pipelined_modules[1]
|
1236 | 1252 |
|
1237 | 1253 | # Check pipelined args
|
1238 |
| - for ebc in [pipelined_ebc, pipelined_weighted_ebc]: |
1239 |
| - self.assertEqual(len(ebc.forward._args), 1) |
1240 |
| - self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) |
1241 |
| - self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) |
1242 |
| - self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) |
1243 |
| - self.assertIsInstance( |
1244 |
| - ebc.forward._args[0].postproc_modules[0], PipelinedPostproc |
1245 |
| - ) |
1246 |
| - self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) |
1247 |
| - |
| 1254 | + self.assertEqual(len(pipelined_ebc.forward._args.args), 1) |
| 1255 | + self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0) |
1248 | 1256 | self.assertEqual(
|
1249 |
| - pipelined_ebc.forward._args[0].postproc_modules[0], |
1250 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1251 |
| - # `postproc_nonweighted`. |
1252 |
| - pipelined_model.module.postproc_nonweighted, |
| 1257 | + pipelined_ebc.forward._args.args[0].steps, |
| 1258 | + [ |
| 1259 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1260 | + PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted), |
| 1261 | + GetItemArgInfoStep(0), |
| 1262 | + ], |
1253 | 1263 | )
|
| 1264 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1) |
| 1265 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0) |
1254 | 1266 | self.assertEqual(
|
1255 |
| - pipelined_weighted_ebc.forward._args[0].postproc_modules[0], |
1256 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1257 |
| - # `postproc_weighted`. |
1258 |
| - pipelined_model.module.postproc_weighted, |
| 1267 | + pipelined_weighted_ebc.forward._args.args[0].steps, |
| 1268 | + [ |
| 1269 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1270 | + PostprocArgInfoStep(pipelined_model.module.postproc_weighted), |
| 1271 | + GetItemArgInfoStep(0), |
| 1272 | + ], |
1259 | 1273 | )
|
1260 | 1274 |
|
1261 | 1275 | # postproc args
|
1262 | 1276 | self.assertEqual(len(pipeline._pipelined_postprocs), 3)
|
1263 | 1277 |
|
1264 |
| - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1265 |
| - # `_postproc_module`. |
| 1278 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
1266 | 1279 | parent_postproc_mod = pipelined_model.module._postproc_module
|
1267 | 1280 |
|
1268 | 1281 | for postproc_mod in pipeline._pipelined_postprocs:
|
1269 | 1282 | # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
1270 | 1283 | # `postproc_nonweighted`.
|
1271 | 1284 | if postproc_mod == pipelined_model.module.postproc_nonweighted:
|
1272 |
| - self.assertEqual(len(postproc_mod._args), 1) |
1273 |
| - args = postproc_mod._args[0] |
1274 |
| - self.assertEqual(args.input_attrs, ["", "idlist_features"]) |
1275 |
| - self.assertEqual(args.is_getitems, [False, False]) |
1276 |
| - self.assertEqual(len(args.postproc_modules), 2) |
| 1285 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1286 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
1277 | 1287 | self.assertEqual(
|
1278 |
| - args.postproc_modules[0], |
1279 |
| - parent_postproc_mod, |
| 1288 | + postproc_mod._args.args[0].steps, |
| 1289 | + [ |
| 1290 | + PostprocArgInfoStep(parent_postproc_mod), |
| 1291 | + GetAttrArgInfoStep("idlist_features"), |
| 1292 | + ], |
1280 | 1293 | )
|
1281 |
| - self.assertEqual(args.postproc_modules[1], None) |
| 1294 | + |
1282 | 1295 | # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
1283 | 1296 | # `postproc_weighted`.
|
1284 | 1297 | elif postproc_mod == pipelined_model.module.postproc_weighted:
|
1285 |
| - self.assertEqual(len(postproc_mod._args), 1) |
1286 |
| - args = postproc_mod._args[0] |
1287 |
| - self.assertEqual(args.input_attrs, ["", "idscore_features"]) |
1288 |
| - self.assertEqual(args.is_getitems, [False, False]) |
1289 |
| - self.assertEqual(len(args.postproc_modules), 2) |
| 1298 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1299 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
1290 | 1300 | self.assertEqual(
|
1291 |
| - args.postproc_modules[0], |
1292 |
| - parent_postproc_mod, |
| 1301 | + postproc_mod._args.args[0].steps, |
| 1302 | + [ |
| 1303 | + PostprocArgInfoStep(parent_postproc_mod), |
| 1304 | + GetAttrArgInfoStep("idscore_features"), |
| 1305 | + ], |
1293 | 1306 | )
|
1294 |
| - self.assertEqual(args.postproc_modules[1], None) |
1295 | 1307 | elif postproc_mod == parent_postproc_mod:
|
1296 |
| - self.assertEqual(len(postproc_mod._args), 1) |
1297 |
| - args = postproc_mod._args[0] |
1298 |
| - self.assertEqual(args.input_attrs, [""]) |
1299 |
| - self.assertEqual(args.is_getitems, [False]) |
1300 |
| - self.assertEqual(args.postproc_modules, [None]) |
| 1308 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1309 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
| 1310 | + self.assertEqual(postproc_mod._args.args[0].steps, [NoopArgInfoStep()]) |
1301 | 1311 |
|
1302 | 1312 | # pyre-ignore
|
1303 | 1313 | @unittest.skipIf(
|
|
0 commit comments