Skip to content

Commit de6fac8

Browse files
22quinnfacebook-github-bot
authored and committed
Allow Dynamo tracing for embeddings_cat_empty_rank_handle_inference (#2881)
Summary: Pull Request resolved: #2881 Currently embeddings_cat_empty_rank_handle_inference cannot be traced by Dynamo. We have to exclude the op from lowering, causing unnecessary graph breaks of all kinds. This change allows the op to be traced into. Differential Revision: D72879113 fbshipit-source-id: 6219021dfbd5b71b450cda290c3f96394ef65521
1 parent 0bd591a commit de6fac8

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

torchrec/distributed/embedding_lookup.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,22 @@ def get_tbes_to_register(
731731
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
732732
return get_tbes_to_register_from_iterable(self._emb_modules)
733733

734+
def embeddings_cat_empty_rank_handle_inference(
    self,
    embeddings: List[torch.Tensor],
    dim: int = 0,
) -> torch.Tensor:
    """Concatenate per-group embedding outputs, tolerating empty ranks.

    Args:
        embeddings: one output tensor per grouped config on this rank.
        dim: dimension along which to concatenate when there are
            multiple groups.

    Returns:
        The concatenated tensor; the single tensor when there is exactly
        one group; or a dummy empty tensor when this rank holds no
        grouped configs.
    """
    num_groups = len(self.grouped_configs)
    if num_groups == 0:
        # Empty rank: hand back a placeholder tensor so downstream
        # consumers still receive a tensor of the module's output dtype.
        target_device: Optional[torch.device] = None
        if self.device is not None:
            target_device = torch.device(self.device)
        return torch.empty([0], dtype=self.output_dtype, device=target_device)
    if num_groups == 1:
        # Nothing to concatenate for a single group.
        return embeddings[0]
    return torch.cat(embeddings, dim=dim)
749+
734750
def forward(
735751
self,
736752
sparse_features: KeyedJaggedTensor,
@@ -747,9 +763,7 @@ def forward(
747763
# 2d embedding by nature
748764
embeddings.append(self._emb_modules[i].forward(features_by_group[i]))
749765

750-
return embeddings_cat_empty_rank_handle_inference(
751-
embeddings, device=self.device, dtype=self.output_dtype
752-
)
766+
return self.embeddings_cat_empty_rank_handle_inference(embeddings)
753767

754768
# pyre-ignore [14]
755769
def state_dict(
@@ -865,6 +879,22 @@ def get_tbes_to_register(
865879
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
866880
return get_tbes_to_register_from_iterable(self._emb_modules)
867881

882+
def embeddings_cat_empty_rank_handle_inference(
    self,
    embeddings: List[torch.Tensor],
    dim: int = 0,
) -> torch.Tensor:
    """Concatenate grouped embedding outputs; safe for zero or one group.

    Args:
        embeddings: the per-group output tensors for this rank.
        dim: concatenation dimension used when multiple groups exist.

    Returns:
        ``torch.cat`` of the inputs, the lone input tensor, or a dummy
        empty tensor when ``grouped_configs`` is empty on this rank.
    """
    group_count = len(self.grouped_configs)
    if group_count > 1:
        return torch.cat(embeddings, dim=dim)
    if group_count == 1:
        return embeddings[0]
    # Empty rank: return a placeholder empty tensor in the module's
    # output dtype, placed on the module's device when one is set.
    placeholder_device: Optional[torch.device] = (
        None if self.device is None else torch.device(self.device)
    )
    return torch.empty([0], dtype=self.output_dtype, device=placeholder_device)
897+
868898
def forward(
869899
self,
870900
sparse_features: KeyedJaggedTensor,
@@ -897,11 +927,9 @@ def forward(
897927
features = self._feature_processor(features)
898928
embeddings.append(emb_op.forward(features))
899929

900-
return embeddings_cat_empty_rank_handle_inference(
930+
return self.embeddings_cat_empty_rank_handle_inference(
901931
embeddings,
902932
dim=1,
903-
device=self.device,
904-
dtype=self.output_dtype,
905933
)
906934

907935
# pyre-ignore [14]

0 commit comments

Comments
 (0)