diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index a7ac5c972..e98d99ecc 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -145,6 +145,7 @@ def create_embedding_bag_sharding(
     device: Optional[torch.device] = None,
     permute_embeddings: bool = False,
     qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    independent_emb_key_pg: Optional[dist.ProcessGroup] = None,
 ) -> EmbeddingSharding[
     EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
 ]:
@@ -165,6 +166,7 @@ def create_embedding_bag_sharding(
             env,
             device,
             qcomm_codecs_registry=qcomm_codecs_registry,
+            independent_emb_key_pg=independent_emb_key_pg,
         )
     elif sharding_type == ShardingType.DATA_PARALLEL.value:
         return DpPooledEmbeddingSharding(sharding_infos, env, device)
@@ -587,7 +589,7 @@ def __init__(
         self._table_names: List[str] = []
         self._pooling_type_to_rs_features: Dict[str, List[str]] = defaultdict(list)
         self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {}
-
+        self.independent_emb_key_pg: dist.ProcessGroup = dist.new_group()
         for config in self._embedding_bag_configs:
             self._table_names.append(config.name)
             self._table_name_to_config[config.name] = config
@@ -633,6 +635,7 @@ def __init__(
                 device,
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
+                independent_emb_key_pg=self.independent_emb_key_pg,
             )
             for embedding_configs in sharding_type_to_sharding_infos.values()
         ]
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index 0ecdabb7a..0260e233f 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -119,9 +119,11 @@ def __init__(
         need_pos: bool = False,
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
         device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
+        independent_emb_key_pg: Optional[dist.ProcessGroup] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
+        self._independent_emb_key_pg: Optional[dist.ProcessGroup] = independent_emb_key_pg
         self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D)
         self._pg: Optional[dist.ProcessGroup] = (
             self._env.sharding_pg  # pyre-ignore[16]
@@ -538,7 +540,9 @@ def create_input_dist(
         return RwSparseFeaturesDist(
             # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
             #  `Optional[ProcessGroup]`.
-            pg=self._pg,
+            pg=self._independent_emb_key_pg
+            if self._independent_emb_key_pg is not None
+            else self._pg,
             num_features=num_features,
             feature_hash_sizes=feature_hash_sizes,
             device=device if device is not None else self._device,