|
39 | 39 | SequenceShardingContext,
|
40 | 40 | )
|
41 | 41 | from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
|
| 42 | +from torchrec.modules.utils import ( |
| 43 | + _fx_trec_get_feature_length, |
| 44 | + _get_batching_hinted_output, |
| 45 | +) |
42 | 46 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
43 | 47 |
|
| 48 | +torch.fx.wrap("_get_batching_hinted_output") |
| 49 | +torch.fx.wrap("_fx_trec_get_feature_length") |
| 50 | + |
44 | 51 |
|
45 | 52 | class RwSequenceEmbeddingDist(
|
46 | 53 | BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
|
@@ -169,26 +176,70 @@ def __init__(
|
169 | 176 | device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
|
170 | 177 | ) -> None:
|
171 | 178 | super().__init__()
|
172 |
| - self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size) |
173 | 179 | self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
|
174 | 180 | device_type_from_sharding_infos
|
175 | 181 | )
|
| 182 | + num_cpu_ranks = 0 |
| 183 | + if self._device_type_from_sharding_infos and isinstance( |
| 184 | + self._device_type_from_sharding_infos, tuple |
| 185 | + ): |
| 186 | + for device_type in self._device_type_from_sharding_infos: |
| 187 | + if device_type == "cpu": |
| 188 | + num_cpu_ranks += 1 |
| 189 | + elif self._device_type_from_sharding_infos == "cpu": |
| 190 | + num_cpu_ranks = world_size |
| 191 | + |
| 192 | + self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne( |
| 193 | + device, world_size - num_cpu_ranks |
| 194 | + ) |
176 | 195 |
|
177 | 196 | def forward(
|
178 | 197 | self,
|
179 | 198 | local_embs: List[torch.Tensor],
|
180 | 199 | sharding_ctx: Optional[InferSequenceShardingContext] = None,
|
181 | 200 | ) -> List[torch.Tensor]:
|
182 |
| - if self._device_type_from_sharding_infos is not None: |
183 |
| - if isinstance(self._device_type_from_sharding_infos, tuple): |
184 |
| - # Update the tagging when tuple has heterogenous device type |
185 |
| - # Done in next diff stack along with the full support for |
186 |
| - # hetergoenous device type |
187 |
| - return local_embs |
188 |
| - elif self._device_type_from_sharding_infos == "cpu": |
189 |
| - # for cpu sharder, output dist should be a no-op |
190 |
| - return local_embs |
191 |
| - return self._dist(local_embs) |
| 201 | + assert ( |
| 202 | + self._device_type_from_sharding_infos is not None |
| 203 | + ), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist" |
| 204 | + if isinstance(self._device_type_from_sharding_infos, tuple): |
| 205 | + assert sharding_ctx is not None |
| 206 | + assert sharding_ctx.embedding_names_per_rank is not None |
| 207 | + assert len(self._device_type_from_sharding_infos) == len( |
| 208 | + local_embs |
| 209 | + ), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types" |
| 210 | + non_cpu_local_embs = [] |
| 211 | + # Here looping through local_embs is also compatible with tracing |
| 212 | + # given the number of lookups / shards within ShardedQuantEmbeddingCollection |
| 213 | + # are fixed and local_embs is the output of those lookups. However, still |
| 214 | + # using _device_type_from_sharding_infos to iterate on local_embs list as |
| 215 | + # that's a better practice. |
| 216 | + for i, device_type in enumerate(self._device_type_from_sharding_infos): |
| 217 | + if device_type != "cpu": |
| 218 | + non_cpu_local_embs.append( |
| 219 | + _get_batching_hinted_output( |
| 220 | + _fx_trec_get_feature_length( |
| 221 | + sharding_ctx.features[i], |
| 222 | + # pyre-fixme [16] |
| 223 | + sharding_ctx.embedding_names_per_rank[i], |
| 224 | + ), |
| 225 | + local_embs[i], |
| 226 | + ) |
| 227 | + ) |
| 228 | + non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs) |
| 229 | + index = 0 |
| 230 | + result = [] |
| 231 | + for i, device_type in enumerate(self._device_type_from_sharding_infos): |
| 232 | + if device_type == "cpu": |
| 233 | + result.append(local_embs[i]) |
| 234 | + else: |
| 235 | + result.append(non_cpu_local_embs_dist[index]) |
| 236 | + index += 1 |
| 237 | + return result |
| 238 | + elif self._device_type_from_sharding_infos == "cpu": |
| 239 | + # for cpu sharder, output dist should be a no-op |
| 240 | + return local_embs |
| 241 | + else: |
| 242 | + return self._device_dist(local_embs) |
192 | 243 |
|
193 | 244 |
|
194 | 245 | class InferRwSequenceEmbeddingSharding(
|
@@ -237,6 +288,7 @@ def create_lookup(
|
237 | 288 | world_size=self._world_size,
|
238 | 289 | fused_params=fused_params,
|
239 | 290 | device=device if device is not None else self._device,
|
| 291 | + device_type_from_sharding_infos=self._device_type_from_sharding_infos, |
240 | 292 | )
|
241 | 293 |
|
242 | 294 | def create_output_dist(
|
|
0 commit comments