Skip to content

Commit 4999964

Browse files
faran928facebook-github-bot
authored andcommitted
Enable Output Dist module when DI + Lowering is required (#2846)
Summary: Pull Request resolved: #2846 Enable Output Dist module when DI + Lowering is required. This is mainly to handle empty or zero tensors being passed across the boundary of Intermodules and merge as DI look up results during publishing may lead to this kind of set up. Reviewed By: jiayisuse Differential Revision: D71671275 fbshipit-source-id: 94936476d361b299ee7e5ffb9884f5265e364765
1 parent b27dc74 commit 4999964

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

torchrec/distributed/quant_embedding.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,15 @@ def _construct_jagged_tensors_tw(
300300
return ret
301301

302302

303+
@torch.fx.wrap
304+
def _fx_marker_construct_jagged_tensor(
305+
values: torch.Tensor,
306+
lengths: torch.Tensor,
307+
weights: Optional[torch.Tensor],
308+
) -> JaggedTensor:
309+
return JaggedTensor(values=values, lengths=lengths, weights=weights)
310+
311+
303312
def _construct_jagged_tensors_rw(
304313
embeddings: List[torch.Tensor],
305314
feature_keys: List[str],
@@ -326,7 +335,7 @@ def _construct_jagged_tensors_rw(
326335
length_per_key,
327336
)
328337
for i, key in enumerate(feature_keys):
329-
ret[key] = JaggedTensor(
338+
ret[key] = _fx_marker_construct_jagged_tensor(
330339
values=embs_split_per_key[i],
331340
lengths=lengths_list[i],
332341
weights=values_list[i] if need_indices else None,
@@ -455,6 +464,7 @@ def _construct_jagged_tensors(
455464
rw_bucket_mapping_tensor,
456465
),
457466
)
467+
458468
elif sharding_type == ShardingType.COLUMN_WISE.value:
459469
return _construct_jagged_tensors_cw(
460470
embeddings,
@@ -1112,6 +1122,7 @@ def forward(self, features: KeyedJaggedTensor) -> Tuple[
11121122
List[Optional[torch.Tensor]],
11131123
List[Optional[torch.Tensor]],
11141124
]:
1125+
11151126
with torch.no_grad():
11161127
ret: List[KJTList] = []
11171128
unbucketize_permute_tensor = []

0 commit comments

Comments
 (0)