Skip to content

Commit b27dc74

Browse files
fix row-wise alltoall error when some embeddings use mean pooling and others use sum pooling (#2809)
Summary: KJT.lengths is modified by mean pooling callback. When some embeddings use mean pooling and others use sum pooling, KJT.lengths will be incorrect. Pull Request resolved: #2809 Reviewed By: PaulZhang12 Differential Revision: D71759505 Pulled By: iamzainhuda fbshipit-source-id: d899a2a072f84e9b6692aac0653914f5d3490f27
1 parent 367a323 commit b27dc74

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

torchrec/distributed/embeddingbag.py

+1
Original file line numberDiff line numberDiff line change
@@ -1716,6 +1716,7 @@ def _create_mean_pooling_divisor(
17161716
lengths = torch.index_select(input=lengths, dim=0, index=indices)
17171717

17181718
# only convert the sum pooling features to be 1 lengths
1719+
lengths = lengths.clone()
17191720
for feature in pooling_type_to_rs_features[PoolingType.SUM.value]:
17201721
feature_index = kjt_key_indices[feature]
17211722
feature_index = feature_index * batch_size

0 commit comments

Comments
 (0)