Commit 0a78345

aporialiao authored and facebook-github-bot committed
4/n CW ShardedTensor Test (#2863)
Summary: Pull Request resolved: #2863

The current implementation already supports CW sharding, so this only adds a unit test, plus some utils for generating rank placements for +1 ranks.

Reviewed By: iamzainhuda

Differential Revision: D72017500

fbshipit-source-id: 6daef515fa8aa0a320ed6df1254653acf7b99430
1 parent 9c0856e commit 0a78345
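
For orientation, here is a minimal sketch of what the new placement helper produces and how the TW and CW tests consume it. The helper body mirrors the diff below; the seed, world size, and table counts are illustrative assumptions, not values from the commit.

    import random
    from typing import List

    def generate_rank_placements(
        world_size: int,
        num_tables: int,
        ranks_per_tables: List[int],
    ) -> List[List[int]]:
        # One sorted list of distinct ranks per table; the sentinel [0]
        # means "randomize how many ranks each table gets".
        placements = []
        max_rank = world_size - 1
        if ranks_per_tables == [0]:
            ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
        for i in range(num_tables):
            ranks_per_table = ranks_per_tables[i]
            placement = sorted(random.sample(range(world_size), ranks_per_table))
            placements.append(placement)
        return placements

    random.seed(0)  # illustrative seed, not part of the commit
    cw = generate_rank_placements(world_size=4, num_tables=2, ranks_per_tables=[2, 3])
    tw = generate_rank_placements(world_size=4, num_tables=2, ranks_per_tables=[1, 1])
    print(cw)  # a possible output: [[0, 3], [1, 2, 3]] -> CW uses the whole list per table
    print(tw)  # a possible output: [[2], [0]]          -> TW uses the single rank per table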


torchrec/distributed/tests/test_dynamic_sharding.py

Lines changed: 77 additions & 5 deletions
@@ -26,6 +26,7 @@
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 
 from torchrec.distributed.sharding_plan import (
+    column_wise,
     construct_module_sharding_plan,
     get_module_to_default_sharders,
     table_wise,
@@ -79,6 +80,23 @@ def generate_embedding_bag_config(
     return embedding_bag_config
 
 
+def generate_rank_placements(
+    world_size: int,
+    num_tables: int,
+    ranks_per_tables: List[int],
+) -> List[List[int]]:
+    # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
+    placements = []
+    max_rank = world_size - 1
+    if ranks_per_tables == [0]:
+        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
+    for i in range(num_tables):
+        ranks_per_table = ranks_per_tables[i]
+        placement = sorted(random.sample(range(world_size), ranks_per_table))
+        placements.append(placement)
+    return placements
+
+
 def create_test_initial_state_dict(
     sharded_module_type: nn.Module,
     num_tables: int,
@@ -379,19 +397,73 @@ def test_dynamic_sharding_ebc_tw(
     ) -> None:
         # Tests EBC dynamic sharding implementation for TW
 
+        # Table wise can only have 1 rank allocated per table:
+        ranks_per_tables = [1 for _ in range(num_tables)]
         # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
-        old_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
-        new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
+        old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+        new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+
+        while new_ranks == old_ranks:
+            new_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
+        per_param_sharding = {}
+        new_per_param_sharding = {}
+
+        # Construct parameter shardings
+        for i in range(num_tables):
+            per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i][0])
+            new_per_param_sharding[table_name(i)] = table_wise(rank=new_ranks[i][0])
+
+        self._run_ebc_resharding_test(
+            per_param_sharding,
+            new_per_param_sharding,
+            num_tables,
+            world_size,
+            data_type,
+        )
 
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    @given(  # pyre-ignore
+        num_tables=st.sampled_from([2, 3, 4]),
+        data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
+        world_size=st.sampled_from([3, 4]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_dynamic_sharding_ebc_cw(
+        self,
+        num_tables: int,
+        data_type: DataType,
+        world_size: int,
+    ) -> None:
+        # Tests EBC dynamic sharding implementation for CW
+
+        # Force the ranks per table to be consistent
+        ranks_per_tables = [
+            random.randint(1, world_size - 1) for _ in range(num_tables)
+        ]
+
+        old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+        new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+
+        # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
         while new_ranks == old_ranks:
-            new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
+            old_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
+            new_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
         per_param_sharding = {}
         new_per_param_sharding = {}
 
         # Construct parameter shardings
         for i in range(num_tables):
-            per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i])
-            new_per_param_sharding[table_name(i)] = table_wise(rank=new_ranks[i])
+            per_param_sharding[table_name(i)] = column_wise(ranks=old_ranks[i])
+            new_per_param_sharding[table_name(i)] = column_wise(ranks=new_ranks[i])
 
         self._run_ebc_resharding_test(
             per_param_sharding,
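
Taken together, the CW test's plan construction reduces to the pattern below. This is a hedged sketch rather than code from the commit: table_name and _run_ebc_resharding_test are helpers local to test_dynamic_sharding.py, generate_rank_placements is the module-level helper added in the second hunk (also sketched above), and the concrete world size and table count are assumptions standing in for the hypothesis-sampled values.

    import random

    from torchrec.distributed.sharding_plan import column_wise

    world_size = 4  # assumed; the test samples from [3, 4]
    num_tables = 2  # assumed; the test samples from [2, 3, 4]

    def table_name(i: int) -> str:
        # Stand-in for the helper defined in test_dynamic_sharding.py.
        return f"table_{i}"

    # Fix how many ranks each table gets so old and new placements stay
    # comparable ("Force the ranks per table to be consistent").
    ranks_per_tables = [random.randint(1, world_size - 1) for _ in range(num_tables)]

    old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
    new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)

    # Resharding is only exercised if the placement actually changes.
    while new_ranks == old_ranks:
        old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
        new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)

    # CW consumes the full rank list per table, where TW took only ranks[0].
    per_param_sharding = {
        table_name(i): column_wise(ranks=old_ranks[i]) for i in range(num_tables)
    }
    new_per_param_sharding = {
        table_name(i): column_wise(ranks=new_ranks[i]) for i in range(num_tables)
    }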
