Skip to content

Commit d0861b2

Browse files
ZhengkaiZfacebook-github-bot
authored andcommitted
get rid of unnecessary fx wrap in regrouping (#2882)
Summary: [torchrec] get rid of unnecessary fx wrap in regrouping Reviewed By: taylorKempWork, FulinHuang Differential Revision: D72884977
1 parent 05aea06 commit d0861b2

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

torchrec/modules/regroup.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,6 @@ def _permuted_values(
3434
return torch.cat(values, dim=dim)
3535

3636

37-
@torch.fx.wrap
38-
def _build_dict(
39-
keys: List[str],
40-
values: Union[torch.Tensor, List[torch.Tensor]],
41-
splits: List[int],
42-
dim: int,
43-
) -> Dict[str, torch.Tensor]:
44-
if isinstance(values, torch.Tensor):
45-
return dict(zip(keys, torch.split(values, splits, dim=dim)))
46-
else:
47-
return dict(zip(keys, values))
48-
49-
5037
@torch.fx.wrap
5138
def module_init(module: "KTRegroupAsDict", keyed_tensors: List[KeyedTensor]) -> None:
5239
assert len(keyed_tensors) > 0, "Empty list provided"
@@ -115,6 +102,12 @@ def forward(self, values: List[torch.Tensor]) -> List[torch.Tensor]:
115102
)
116103

117104

105+
def _to_tensor_dict(
106+
keys: List[str], values: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]
107+
) -> Dict[str, torch.Tensor]:
108+
return {key: values[i] for i, key in enumerate(keys)}
109+
110+
118111
class KTRegroupAsDict(torch.nn.Module, CacheMixin):
119112
"""
120113
KTRegroupAsDict is a nn.Module that mirrors behavior of static method KeyedTensor.regroup_as_dict()
@@ -204,11 +197,13 @@ def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
204197
if self._use_fbgemm_regroup:
205198
values = _get_kts_values(keyed_tensors)
206199
permuted_values = self._permute_pooled_embs_impl(values)
200+
return _to_tensor_dict(self._keys, permuted_values)
207201
else:
208202
permuted_values = _permuted_values(
209203
keyed_tensors, self._idx_key_pairs, self._dim
210204
)
211-
return _build_dict(self._keys, permuted_values, self._splits, self._dim)
205+
splitted_values = torch.split(permuted_values, self._splits, dim=self._dim)
206+
return _to_tensor_dict(self._keys, splitted_values)
212207

213208
def clear_cache(self) -> None:
214209
self._is_inited = False

0 commit comments

Comments
 (0)