|
10 | 10 | import copy
|
11 | 11 | import enum
|
12 | 12 | import unittest
|
| 13 | +from typing import Any, List |
13 | 14 | from unittest.mock import MagicMock
|
14 | 15 |
|
15 | 16 | import torch
|
| 17 | +from parameterized import parameterized |
16 | 18 |
|
17 | 19 | from torchrec.distributed.embedding_types import EmbeddingComputeKernel
|
18 | 20 | from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
|
|
21 | 23 | TrainPipelineSparseDistTestBase,
|
22 | 24 | )
|
23 | 25 | from torchrec.distributed.train_pipeline.utils import (
|
| 26 | + _build_args_kwargs, |
24 | 27 | _get_node_args,
|
25 | 28 | _rewrite_model,
|
| 29 | + ArgInfo, |
26 | 30 | PipelinedForward,
|
27 | 31 | PipelinedPostproc,
|
28 | 32 | TrainPipelineContext,
|
@@ -253,6 +257,89 @@ def test_restore_from_snapshot(self) -> None:
|
253 | 257 | for source_model_type, recipient_model_type in variants:
|
254 | 258 | self._test_restore_from_snapshot(source_model_type, recipient_model_type)
|
255 | 259 |
|
| 260 | + @parameterized.expand( |
| 261 | + [ |
| 262 | + ( |
| 263 | + [ |
| 264 | + # Empty attrs to ignore any attr based logic. |
| 265 | + ArgInfo( |
| 266 | + input_attrs=[ |
| 267 | + "", |
| 268 | + ], |
| 269 | + is_getitems=[False], |
| 270 | + postproc_modules=[None], |
| 271 | + constants=[None], |
| 272 | + name="id_list_features", |
| 273 | + ), |
| 274 | + ArgInfo( |
| 275 | + input_attrs=[], |
| 276 | + is_getitems=[], |
| 277 | + postproc_modules=[], |
| 278 | + constants=[], |
| 279 | + name="id_score_list_features", |
| 280 | + ), |
| 281 | + ], |
| 282 | + 0, |
| 283 | + ["id_list_features", "id_score_list_features"], |
| 284 | + ), |
| 285 | + ( |
| 286 | + [ |
| 287 | + # Empty attrs to ignore any attr based logic. |
| 288 | + ArgInfo( |
| 289 | + input_attrs=[ |
| 290 | + "", |
| 291 | + ], |
| 292 | + is_getitems=[False], |
| 293 | + postproc_modules=[None], |
| 294 | + constants=[None], |
| 295 | + name=None, |
| 296 | + ), |
| 297 | + ArgInfo( |
| 298 | + input_attrs=[], |
| 299 | + is_getitems=[], |
| 300 | + postproc_modules=[], |
| 301 | + constants=[], |
| 302 | + name=None, |
| 303 | + ), |
| 304 | + ], |
| 305 | + 2, |
| 306 | + [], |
| 307 | + ), |
| 308 | + ( |
| 309 | + [ |
| 310 | + # Empty attrs to ignore any attr based logic. |
| 311 | + ArgInfo( |
| 312 | + input_attrs=[ |
| 313 | + "", |
| 314 | + ], |
| 315 | + is_getitems=[False], |
| 316 | + postproc_modules=[None], |
| 317 | + constants=[None], |
| 318 | + name=None, |
| 319 | + ), |
| 320 | + ArgInfo( |
| 321 | + input_attrs=[], |
| 322 | + is_getitems=[], |
| 323 | + postproc_modules=[], |
| 324 | + constants=[], |
| 325 | + name="id_score_list_features", |
| 326 | + ), |
| 327 | + ], |
| 328 | + 1, |
| 329 | + ["id_score_list_features"], |
| 330 | + ), |
| 331 | + ] |
| 332 | + ) |
| 333 | + def test_build_args_kwargs( |
| 334 | + self, |
| 335 | + fwd_args: List[ArgInfo], |
| 336 | + args_len: int, |
| 337 | + kwarges_keys: List[str], |
| 338 | + ) -> None: |
| 339 | + args, kwargs = _build_args_kwargs("initial_input", fwd_args) |
| 340 | + self.assertEqual(len(args), args_len) |
| 341 | + self.assertEqual(list(kwargs.keys()), kwarges_keys) |
| 342 | + |
256 | 343 |
|
257 | 344 | class TestUtils(unittest.TestCase):
|
258 | 345 | def test_get_node_args_helper_call_module_kjt(self) -> None:
|
|
0 commit comments