|
63 | 63 | DataLoadingThread, |
64 | 64 | EmbeddingPipelinedForward, |
65 | 65 | get_h2d_func, |
| 66 | + GetAttrArgInfoStep, |
| 67 | + GetItemArgInfoStep, |
| 68 | + NoopArgInfoStep, |
66 | 69 | PipelinedForward, |
67 | 70 | PipelinedPostproc, |
68 | 71 | PipelineStage, |
| 72 | + PostprocArgInfoStep, |
69 | 73 | SparseDataDistUtil, |
70 | 74 | StageOut, |
71 | 75 | TrainPipelineContext, |
@@ -1152,44 +1156,56 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None: |
1152 | 1156 | pipelined_weighted_ebc = pipeline._pipelined_modules[1] |
1153 | 1157 |
|
1154 | 1158 | # Check pipelined args |
1155 | | - for ebc in [pipelined_ebc, pipelined_weighted_ebc]: |
1156 | | - self.assertEqual(len(ebc.forward._args), 1) |
1157 | | - self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) |
1158 | | - self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) |
1159 | | - self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) |
1160 | | - self.assertIsInstance( |
1161 | | - ebc.forward._args[0].postproc_modules[0], PipelinedPostproc |
1162 | | - ) |
1163 | | - self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) |
1164 | | - |
| 1159 | + self.assertEqual(len(pipelined_ebc.forward._args.args), 1) |
| 1160 | + self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0) |
1165 | 1161 | self.assertEqual( |
1166 | | - pipelined_ebc.forward._args[0].postproc_modules[0], |
1167 | | - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1168 | | - # `postproc_nonweighted`. |
1169 | | - pipelined_model.module.postproc_nonweighted, |
| 1162 | + pipelined_ebc.forward._args.args[0].steps, |
| 1163 | + [ |
| 1164 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1165 | + PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted), |
| 1166 | + GetItemArgInfoStep(0), |
| 1167 | + ], |
1170 | 1168 | ) |
| 1169 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1) |
| 1170 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0) |
1171 | 1171 | self.assertEqual( |
1172 | | - pipelined_weighted_ebc.forward._args[0].postproc_modules[0], |
1173 | | - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1174 | | - # `postproc_weighted`. |
1175 | | - pipelined_model.module.postproc_weighted, |
| 1172 | + pipelined_weighted_ebc.forward._args.args[0].steps, |
| 1173 | + [ |
| 1174 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1175 | + PostprocArgInfoStep(pipelined_model.module.postproc_weighted), |
| 1176 | + GetItemArgInfoStep(0), |
| 1177 | + ], |
1176 | 1178 | ) |
1177 | 1179 |
|
1178 | 1180 | # postproc args |
1179 | 1181 | self.assertEqual(len(pipeline._pipelined_postprocs), 2) |
1180 | | - input_attr_names = {"idlist_features", "idscore_features"} |
1181 | | - for i in range(len(pipeline._pipelined_postprocs)): |
1182 | | - postproc_mod = pipeline._pipelined_postprocs[i] |
1183 | | - self.assertEqual(len(postproc_mod._args), 1) |
| 1182 | + # postprocs can be added in any order, so we can't assert on exact steps structures |
| 1183 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1) |
| 1184 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0) |
| 1185 | + self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2) |
| 1186 | + self.assertEqual( |
| 1187 | + pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep() |
| 1188 | + ) |
| 1189 | + self.assertIsInstance( |
| 1190 | + pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep |
| 1191 | + ) |
1184 | 1192 |
|
1185 | | - input_attr_name = postproc_mod._args[0].input_attrs[1] |
1186 | | - self.assertTrue(input_attr_name in input_attr_names) |
1187 | | - self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name]) |
1188 | | - input_attr_names.remove(input_attr_name) |
| 1193 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1) |
| 1194 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0) |
| 1195 | + self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2) |
| 1196 | + self.assertEqual( |
| 1197 | + pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep() |
| 1198 | + ) |
| 1199 | + self.assertIsInstance( |
| 1200 | + pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep |
| 1201 | + ) |
1189 | 1202 |
|
1190 | | - self.assertEqual(postproc_mod._args[0].is_getitems, [False, False]) |
1191 | | - # no parent postproc module in FX graph |
1192 | | - self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None]) |
| 1203 | + get_arg_infos = { |
| 1204 | + # pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep |
| 1205 | + postproc._args.args[0].steps[1].attr_name |
| 1206 | + for postproc in pipeline._pipelined_postprocs |
| 1207 | + } |
| 1208 | + self.assertEqual(get_arg_infos, {"idlist_features", "idscore_features"}) |
1193 | 1209 |
|
1194 | 1210 | # pyre-ignore |
1195 | 1211 | @unittest.skipIf( |
@@ -1235,69 +1251,63 @@ def test_pipeline_postproc_recursive(self) -> None: |
1235 | 1251 | pipelined_weighted_ebc = pipeline._pipelined_modules[1] |
1236 | 1252 |
|
1237 | 1253 | # Check pipelined args |
1238 | | - for ebc in [pipelined_ebc, pipelined_weighted_ebc]: |
1239 | | - self.assertEqual(len(ebc.forward._args), 1) |
1240 | | - self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0]) |
1241 | | - self.assertEqual(ebc.forward._args[0].is_getitems, [False, True]) |
1242 | | - self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2) |
1243 | | - self.assertIsInstance( |
1244 | | - ebc.forward._args[0].postproc_modules[0], PipelinedPostproc |
1245 | | - ) |
1246 | | - self.assertEqual(ebc.forward._args[0].postproc_modules[1], None) |
1247 | | - |
| 1254 | + self.assertEqual(len(pipelined_ebc.forward._args.args), 1) |
| 1255 | + self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0) |
1248 | 1256 | self.assertEqual( |
1249 | | - pipelined_ebc.forward._args[0].postproc_modules[0], |
1250 | | - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1251 | | - # `postproc_nonweighted`. |
1252 | | - pipelined_model.module.postproc_nonweighted, |
| 1257 | + pipelined_ebc.forward._args.args[0].steps, |
| 1258 | + [ |
| 1259 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1260 | + PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted), |
| 1261 | + GetItemArgInfoStep(0), |
| 1262 | + ], |
1253 | 1263 | ) |
| 1264 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1) |
| 1265 | + self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0) |
1254 | 1266 | self.assertEqual( |
1255 | | - pipelined_weighted_ebc.forward._args[0].postproc_modules[0], |
1256 | | - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1257 | | - # `postproc_weighted`. |
1258 | | - pipelined_model.module.postproc_weighted, |
| 1267 | + pipelined_weighted_ebc.forward._args.args[0].steps, |
| 1268 | + [ |
| 1269 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
| 1270 | + PostprocArgInfoStep(pipelined_model.module.postproc_weighted), |
| 1271 | + GetItemArgInfoStep(0), |
| 1272 | + ], |
1259 | 1273 | ) |
1260 | 1274 |
|
1261 | 1275 | # postproc args |
1262 | 1276 | self.assertEqual(len(pipeline._pipelined_postprocs), 3) |
1263 | 1277 |
|
1264 | | - # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1265 | | - # `_postproc_module`. |
| 1278 | + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`. |
1266 | 1279 | parent_postproc_mod = pipelined_model.module._postproc_module |
1267 | 1280 |
|
1268 | 1281 | for postproc_mod in pipeline._pipelined_postprocs: |
1269 | 1282 | # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1270 | 1283 | # `postproc_nonweighted`. |
1271 | 1284 | if postproc_mod == pipelined_model.module.postproc_nonweighted: |
1272 | | - self.assertEqual(len(postproc_mod._args), 1) |
1273 | | - args = postproc_mod._args[0] |
1274 | | - self.assertEqual(args.input_attrs, ["", "idlist_features"]) |
1275 | | - self.assertEqual(args.is_getitems, [False, False]) |
1276 | | - self.assertEqual(len(args.postproc_modules), 2) |
| 1285 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1286 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
1277 | 1287 | self.assertEqual( |
1278 | | - args.postproc_modules[0], |
1279 | | - parent_postproc_mod, |
| 1288 | + postproc_mod._args.args[0].steps, |
| 1289 | + [ |
| 1290 | + PostprocArgInfoStep(parent_postproc_mod), |
| 1291 | + GetAttrArgInfoStep("idlist_features"), |
| 1292 | + ], |
1280 | 1293 | ) |
1281 | | - self.assertEqual(args.postproc_modules[1], None) |
| 1294 | + |
1282 | 1295 | # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute |
1283 | 1296 | # `postproc_weighted`. |
1284 | 1297 | elif postproc_mod == pipelined_model.module.postproc_weighted: |
1285 | | - self.assertEqual(len(postproc_mod._args), 1) |
1286 | | - args = postproc_mod._args[0] |
1287 | | - self.assertEqual(args.input_attrs, ["", "idscore_features"]) |
1288 | | - self.assertEqual(args.is_getitems, [False, False]) |
1289 | | - self.assertEqual(len(args.postproc_modules), 2) |
| 1298 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1299 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
1290 | 1300 | self.assertEqual( |
1291 | | - args.postproc_modules[0], |
1292 | | - parent_postproc_mod, |
| 1301 | + postproc_mod._args.args[0].steps, |
| 1302 | + [ |
| 1303 | + PostprocArgInfoStep(parent_postproc_mod), |
| 1304 | + GetAttrArgInfoStep("idscore_features"), |
| 1305 | + ], |
1293 | 1306 | ) |
1294 | | - self.assertEqual(args.postproc_modules[1], None) |
1295 | 1307 | elif postproc_mod == parent_postproc_mod: |
1296 | | - self.assertEqual(len(postproc_mod._args), 1) |
1297 | | - args = postproc_mod._args[0] |
1298 | | - self.assertEqual(args.input_attrs, [""]) |
1299 | | - self.assertEqual(args.is_getitems, [False]) |
1300 | | - self.assertEqual(args.postproc_modules, [None]) |
| 1308 | + self.assertEqual(len(postproc_mod._args.args), 1) |
| 1309 | + self.assertEqual(len(postproc_mod._args.kwargs), 0) |
| 1310 | + self.assertEqual(postproc_mod._args.args[0].steps, [NoopArgInfoStep()]) |
1301 | 1311 |
|
1302 | 1312 | # pyre-ignore |
1303 | 1313 | @unittest.skipIf( |
|
0 commit comments