@@ -15,7 +15,7 @@
 from collections import OrderedDict
 from contextlib import AbstractContextManager, nullcontext
 from dataclasses import asdict
-from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
@@ -511,3 +511,44 @@ def interaction(self, *args, **kwargs) -> None:
             pdb.Pdb.interaction(self, *args, **kwargs)
         finally:
             sys.stdin = _stdin
+
+
+def create_global_tensor_shape_stride_from_metadata(
+    parameter_sharding: ParameterSharding, devices_per_node: Optional[int] = None
+) -> Tuple[torch.Size, Tuple[int, int]]:
+    """
+    Create a global tensor shape and stride from shard metadata.
+
+    Returns:
+        torch.Size: global tensor shape.
+        tuple: global tensor stride.
+    """
+    size = None
+    if parameter_sharding.sharding_type == ShardingType.COLUMN_WISE.value:
+        row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[0]  # pyre-ignore[16]
+        col_dim = 0
+        for shard in parameter_sharding.sharding_spec.shards:
+            col_dim += shard.shard_sizes[1]
+        size = torch.Size([row_dim, col_dim])
+    elif (
+        parameter_sharding.sharding_type == ShardingType.ROW_WISE.value
+        or parameter_sharding.sharding_type == ShardingType.TABLE_ROW_WISE.value
+    ):
+        row_dim = 0
+        col_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[1]
+        for shard in parameter_sharding.sharding_spec.shards:
+            row_dim += shard.shard_sizes[0]
+        size = torch.Size([row_dim, col_dim])
+    elif parameter_sharding.sharding_type == ShardingType.TABLE_WISE.value:
+        size = torch.Size(parameter_sharding.sharding_spec.shards[0].shard_sizes)
+    elif parameter_sharding.sharding_type == ShardingType.GRID_SHARD.value:
+        # we need the node group size to correctly derive the global shape from a single shard
+        assert devices_per_node is not None
+        row_dim, col_dim = 0, 0
+        num_cw_shards = len(parameter_sharding.sharding_spec.shards) // devices_per_node
+        for _ in range(num_cw_shards):
+            col_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[1]
+        for _ in range(devices_per_node):
+            row_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[0]
+        size = torch.Size([row_dim, col_dim])
+    # fall back to an empty shape/stride when no known sharding type matched
+    return (size, (size[1], 1)) if size else (torch.Size([0, 0]), (0, 1))  # pyre-ignore[7]
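
Below is a minimal usage sketch (not part of the diff) showing how the new helper could be called for a column-wise sharded table. It assumes the function lands in torchrec.distributed.utils (the file this diff appears to modify) and relies on torchrec's public ParameterSharding/ShardingType types together with torch's EnumerableShardingSpec/ShardMetadata; the table size, shard sizes, ranks, placements, and compute kernel are illustrative.

# Usage sketch (illustrative): a 1024 x 256 embedding table split column-wise
# into two 128-column shards placed on two ranks.
import torch
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata
from torchrec.distributed.types import ParameterSharding, ShardingType
# assumed import path once this diff lands in torchrec/distributed/utils.py
from torchrec.distributed.utils import create_global_tensor_shape_stride_from_metadata

shards = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 128], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[0, 128], shard_sizes=[1024, 128], placement="rank:1/cuda:1"),
]
parameter_sharding = ParameterSharding(
    sharding_type=ShardingType.COLUMN_WISE.value,
    compute_kernel="fused",  # illustrative kernel name
    ranks=[0, 1],
    sharding_spec=EnumerableShardingSpec(shards),
)

size, stride = create_global_tensor_shape_stride_from_metadata(parameter_sharding)
# size == torch.Size([1024, 256]); stride == (256, 1), i.e. a contiguous row-major layout

For GRID_SHARD, devices_per_node must also be passed so the helper can recover the per-node row/column grouping from a single shard's metadata.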