Skip to content

Commit b90ac7e

Browse files
aliafzalfacebook-github-bot
authored andcommitted
positional and kwargs corner case fix in for _build_args_kwargs (#2714)
Summary: ## Context: Setting None in positional args is colliding with the kwargs in situations when kwargs contains the argument name accepted by the method. Eg : ``` def input_dist(ctx, id_feature_list): ... // If _build_args_kwargs returns: args = [None] kwargs = {'id_feature_list': KJT} input_dist(ctx, *args, **kwargs) // extends to input_dist(ctx, None, id_feature_list=KJT) ``` which results in "TypeError: got multiple values for argument 'id_feature_list'" because id_feature_list is provided both positionally (None) and via kwargs. Reviewed By: sarckk Differential Revision: D68892351
1 parent 96abf2a commit b90ac7e

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

+87
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
import copy
1111
import enum
1212
import unittest
13+
from typing import Any, List
1314
from unittest.mock import MagicMock
1415

1516
import torch
17+
from parameterized import parameterized
1618

1719
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
1820
from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
@@ -21,8 +23,10 @@
2123
TrainPipelineSparseDistTestBase,
2224
)
2325
from torchrec.distributed.train_pipeline.utils import (
26+
_build_args_kwargs,
2427
_get_node_args,
2528
_rewrite_model,
29+
ArgInfo,
2630
PipelinedForward,
2731
PipelinedPostproc,
2832
TrainPipelineContext,
@@ -253,6 +257,89 @@ def test_restore_from_snapshot(self) -> None:
253257
for source_model_type, recipient_model_type in variants:
254258
self._test_restore_from_snapshot(source_model_type, recipient_model_type)
255259

260+
@parameterized.expand(
261+
[
262+
(
263+
[
264+
# Empty attrs to ignore any attr based logic.
265+
ArgInfo(
266+
input_attrs=[
267+
"",
268+
],
269+
is_getitems=[False],
270+
postproc_modules=[None],
271+
constants=[None],
272+
name="id_list_features",
273+
),
274+
ArgInfo(
275+
input_attrs=[],
276+
is_getitems=[],
277+
postproc_modules=[],
278+
constants=[],
279+
name="id_score_list_features",
280+
),
281+
],
282+
0,
283+
["id_list_features", "id_score_list_features"],
284+
),
285+
(
286+
[
287+
# Empty attrs to ignore any attr based logic.
288+
ArgInfo(
289+
input_attrs=[
290+
"",
291+
],
292+
is_getitems=[False],
293+
postproc_modules=[None],
294+
constants=[None],
295+
name=None,
296+
),
297+
ArgInfo(
298+
input_attrs=[],
299+
is_getitems=[],
300+
postproc_modules=[],
301+
constants=[],
302+
name=None,
303+
),
304+
],
305+
2,
306+
[],
307+
),
308+
(
309+
[
310+
# Empty attrs to ignore any attr based logic.
311+
ArgInfo(
312+
input_attrs=[
313+
"",
314+
],
315+
is_getitems=[False],
316+
postproc_modules=[None],
317+
constants=[None],
318+
name=None,
319+
),
320+
ArgInfo(
321+
input_attrs=[],
322+
is_getitems=[],
323+
postproc_modules=[],
324+
constants=[],
325+
name="id_score_list_features",
326+
),
327+
],
328+
1,
329+
["id_score_list_features"],
330+
),
331+
]
332+
)
333+
def test_build_args_kwargs(
334+
self,
335+
fwd_args: List[ArgInfo],
336+
args_len: int,
337+
kwarges_keys: List[str],
338+
) -> None:
339+
args, kwargs = _build_args_kwargs("initial_input", fwd_args)
340+
self.assertEqual(len(args), args_len)
341+
self.assertEqual(list(kwargs.keys()), kwarges_keys)
342+
256343

257344
class TestUtils(unittest.TestCase):
258345
def test_get_node_args_helper_call_module_kjt(self) -> None:

torchrec/distributed/train_pipeline/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@ def _build_args_kwargs(
230230
else:
231231
args.append(arg)
232232
else:
233-
args.append(None)
233+
if arg_info.name:
234+
kwargs[arg_info.name] = None
235+
else:
236+
args.append(None)
234237
return args, kwargs
235238

236239

0 commit comments

Comments
 (0)