diff --git a/requirements.txt b/requirements.txt
index 6d63107dd..6b17aeac6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,7 @@
 torchmetrics==1.0.3
 torchx
 tqdm
 usort
+parameterized # for tests
 
 # https://github.com/pytorch/pytorch/blob/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc/requirements.txt#L3
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
index f23dc0fe0..53fae9001 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
@@ -10,9 +10,11 @@
 import copy
 import enum
 import unittest
+from typing import List
 from unittest.mock import MagicMock
 
 import torch
+from parameterized import parameterized
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
 
@@ -21,8 +23,10 @@
     TrainPipelineSparseDistTestBase,
 )
 from torchrec.distributed.train_pipeline.utils import (
+    _build_args_kwargs,
     _get_node_args,
     _rewrite_model,
+    ArgInfo,
     PipelinedForward,
     PipelinedPostproc,
     TrainPipelineContext,
@@ -253,6 +257,89 @@ def test_restore_from_snapshot(self) -> None:
         for source_model_type, recipient_model_type in variants:
             self._test_restore_from_snapshot(source_model_type, recipient_model_type)
 
+    @parameterized.expand(
+        [
+            (
+                [
+                    # Empty attrs to ignore any attr based logic.
+                    ArgInfo(
+                        input_attrs=[
+                            "",
+                        ],
+                        is_getitems=[False],
+                        postproc_modules=[None],
+                        constants=[None],
+                        name="id_list_features",
+                    ),
+                    ArgInfo(
+                        input_attrs=[],
+                        is_getitems=[],
+                        postproc_modules=[],
+                        constants=[],
+                        name="id_score_list_features",
+                    ),
+                ],
+                0,
+                ["id_list_features", "id_score_list_features"],
+            ),
+            (
+                [
+                    # Empty attrs to ignore any attr based logic.
+                    ArgInfo(
+                        input_attrs=[
+                            "",
+                        ],
+                        is_getitems=[False],
+                        postproc_modules=[None],
+                        constants=[None],
+                        name=None,
+                    ),
+                    ArgInfo(
+                        input_attrs=[],
+                        is_getitems=[],
+                        postproc_modules=[],
+                        constants=[],
+                        name=None,
+                    ),
+                ],
+                2,
+                [],
+            ),
+            (
+                [
+                    # Empty attrs to ignore any attr based logic.
+                    ArgInfo(
+                        input_attrs=[
+                            "",
+                        ],
+                        is_getitems=[False],
+                        postproc_modules=[None],
+                        constants=[None],
+                        name=None,
+                    ),
+                    ArgInfo(
+                        input_attrs=[],
+                        is_getitems=[],
+                        postproc_modules=[],
+                        constants=[],
+                        name="id_score_list_features",
+                    ),
+                ],
+                1,
+                ["id_score_list_features"],
+            ),
+        ]
+    )
+    def test_build_args_kwargs(
+        self,
+        fwd_args: List[ArgInfo],
+        args_len: int,
+        kwargs_keys: List[str],
+    ) -> None:
+        args, kwargs = _build_args_kwargs("initial_input", fwd_args)
+        self.assertEqual(len(args), args_len)
+        self.assertEqual(list(kwargs.keys()), kwargs_keys)
+
 
 class TestUtils(unittest.TestCase):
     def test_get_node_args_helper_call_module_kjt(self) -> None:
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 76cf87370..5b76c1e2d 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -230,7 +230,10 @@ def _build_args_kwargs(
             else:
                 args.append(arg)
         else:
-            args.append(None)
+            if arg_info.name:
+                kwargs[arg_info.name] = None
+            else:
+                args.append(None)
     return args, kwargs
 
 
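
Illustrative sketch (not a verbatim copy of the torchrec code) of the routing behavior exercised by the new test: when an `ArgInfo` resolves to no input but carries a `name`, `_build_args_kwargs` now passes the `None` placeholder as a keyword argument instead of positionally. `SimpleArgInfo` and `route_placeholders` below are hypothetical stand-ins, reduced to the name-based routing only.

```python
# Hypothetical, trimmed-down stand-ins for torchrec's ArgInfo /
# _build_args_kwargs, showing only the name-based routing of None
# placeholders introduced by this change.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class SimpleArgInfo:
    name: Optional[str] = None  # keyword name, if the arg is passed by keyword


def route_placeholders(
    fwd_args: List[SimpleArgInfo],
) -> Tuple[List[Any], Dict[str, Any]]:
    args: List[Any] = []
    kwargs: Dict[str, Any] = {}
    for arg_info in fwd_args:
        if arg_info.name:
            # New behavior: named-but-unresolved inputs become keyword args.
            kwargs[arg_info.name] = None
        else:
            # Old behavior is preserved for unnamed inputs.
            args.append(None)
    return args, kwargs


# Mirrors the third parameterized case above: one positional None and one
# keyword None under "id_score_list_features".
args, kwargs = route_placeholders(
    [SimpleArgInfo(name=None), SimpleArgInfo(name="id_score_list_features")]
)
assert len(args) == 1
assert list(kwargs.keys()) == ["id_score_list_features"]
```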