diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index a4bdc11b5..4fd86d016 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -274,6 +274,12 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]:
             embedding_shard_metadata.append(table.local_metadata)
         return embedding_shard_metadata
 
+    def features_weighted(self) -> List[bool]:
+        is_weighted = []
+        for table in self.embedding_tables:
+            is_weighted.extend([table.is_weighted] * table.num_features())
+        return is_weighted
+
 
 F = TypeVar("F", bound=Multistreamable)
 T = TypeVar("T")
diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py
index 2752290de..5f2c2c767 100644
--- a/torchrec/distributed/sharding/tw_sharding.py
+++ b/torchrec/distributed/sharding/tw_sharding.py
@@ -264,6 +264,15 @@ def features_per_rank(self) -> List[int]:
             features_per_rank.append(num_features)
         return features_per_rank
 
+    def is_weighted_per_rank(self) -> List[List[bool]]:
+        is_weighted = []
+        for grouped_embedding_configs in self._grouped_embedding_configs_per_rank:
+            is_weighted_per_rank = []
+            for grouped_config in grouped_embedding_configs:
+                is_weighted_per_rank.extend(grouped_config.features_weighted())
+            is_weighted.append(is_weighted_per_rank)
+        return is_weighted
+
 
 class TwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
     """
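
Below is a minimal, standalone sketch (not part of the diff) illustrating the expansion that features_weighted() performs: each table's per-table is_weighted flag is repeated once per feature the table serves, so downstream code can look up weightedness per feature. ToyTable and ToyGroupedConfig are hypothetical, simplified stand-ins for torchrec's actual table and grouped-config classes; is_weighted_per_rank() in the second file simply collects these per-feature lists for each rank's grouped configs.

# Hedged sketch with hypothetical stand-in classes; only the fields needed
# to demonstrate the per-table -> per-feature expansion are modeled.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyTable:
    feature_names: List[str]
    is_weighted: bool

    def num_features(self) -> int:
        return len(self.feature_names)


@dataclass
class ToyGroupedConfig:
    embedding_tables: List[ToyTable]

    def features_weighted(self) -> List[bool]:
        # Repeat each table's is_weighted flag once per feature it serves,
        # mirroring the method added in embedding_types.py above.
        is_weighted: List[bool] = []
        for table in self.embedding_tables:
            is_weighted.extend([table.is_weighted] * table.num_features())
        return is_weighted


config = ToyGroupedConfig(
    embedding_tables=[
        ToyTable(feature_names=["f0", "f1"], is_weighted=True),
        ToyTable(feature_names=["f2"], is_weighted=False),
    ]
)
assert config.features_weighted() == [True, True, False]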