@@ -26,6 +26,7 @@
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 
 from torchrec.distributed.sharding_plan import (
+    column_wise,
     construct_module_sharding_plan,
     get_module_to_default_sharders,
     table_wise,

@@ -79,6 +80,23 @@ def generate_embedding_bag_config(
     return embedding_bag_config
 
 
+def generate_rank_placements(
+    world_size: int,
+    num_tables: int,
+    ranks_per_tables: List[int],
+) -> List[List[int]]:
+    # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
+    placements = []
+    max_rank = world_size - 1
+    if ranks_per_tables == [0]:
+        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
+    for i in range(num_tables):
+        ranks_per_table = ranks_per_tables[i]
+        placement = sorted(random.sample(range(world_size), ranks_per_table))
+        placements.append(placement)
+    return placements
+
+
 def create_test_initial_state_dict(
     sharded_module_type: nn.Module,
     num_tables: int,
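
For reference, the new `generate_rank_placements` helper draws one sorted, duplicate-free rank list per table via `random.sample`, and treats `ranks_per_tables == [0]` as a sentinel meaning "pick a random shard count for each table". A minimal standalone sketch of the same logic (the example call and its values are illustrative, not from this PR):

```python
import random
from typing import List


def generate_rank_placements(
    world_size: int,
    num_tables: int,
    ranks_per_tables: List[int],
) -> List[List[int]]:
    # One sorted rank list per table; random.sample draws without
    # replacement, so a rank never appears twice in one placement.
    placements = []
    max_rank = world_size - 1
    if ranks_per_tables == [0]:
        # Sentinel: choose a random shard count for every table.
        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
    for ranks_per_table in ranks_per_tables:
        placements.append(sorted(random.sample(range(world_size), ranks_per_table)))
    return placements


# Example (values made up): 4 ranks, 3 tables with 2/1/3 shards each.
print(generate_rank_placements(4, 3, [2, 1, 3]))
# -> e.g. [[1, 3], [2], [0, 2, 3]]  (output varies run to run)
```

Because the sampling is without replacement, a column-wise placement can never assign two shards of the same table to one rank, which is exactly the invariant the CW test below relies on.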

@@ -379,19 +397,73 @@ def test_dynamic_sharding_ebc_tw(
     ) -> None:
         # Tests EBC dynamic sharding implementation for TW
 
+        # Table-wise sharding can only have 1 rank allocated per table:
+        ranks_per_tables = [1 for _ in range(num_tables)]
         # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
-        old_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
-        new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
+        old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+        new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+
+        while new_ranks == old_ranks:
+            new_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
+        per_param_sharding = {}
+        new_per_param_sharding = {}
+
+        # Construct parameter shardings
+        for i in range(num_tables):
+            per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i][0])
+            new_per_param_sharding[table_name(i)] = table_wise(rank=new_ranks[i][0])
+
+        self._run_ebc_resharding_test(
+            per_param_sharding,
+            new_per_param_sharding,
+            num_tables,
+            world_size,
+            data_type,
+        )
 
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    @given(  # pyre-ignore
+        num_tables=st.sampled_from([2, 3, 4]),
+        data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
+        world_size=st.sampled_from([3, 4]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_dynamic_sharding_ebc_cw(
+        self,
+        num_tables: int,
+        data_type: DataType,
+        world_size: int,
+    ) -> None:
+        # Tests EBC dynamic sharding implementation for CW
+
+        # Force the ranks per table to be consistent
+        ranks_per_tables = [
+            random.randint(1, world_size - 1) for _ in range(num_tables)
+        ]
+
+        old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+        new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+
+        # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
         while new_ranks == old_ranks:
-            new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
+            old_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
+            new_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
         per_param_sharding = {}
         new_per_param_sharding = {}
 
         # Construct parameter shardings
         for i in range(num_tables):
-            per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i])
-            new_per_param_sharding[table_name(i)] = table_wise(rank=new_ranks[i])
+            per_param_sharding[table_name(i)] = column_wise(ranks=old_ranks[i])
+            new_per_param_sharding[table_name(i)] = column_wise(ranks=new_ranks[i])
 
         self._run_ebc_resharding_test(
             per_param_sharding,
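
For readers unfamiliar with the two sharding helpers exercised above: `table_wise(rank=...)` pins an entire table to a single rank, which is why the TW test builds one-element placements and unwraps them with `old_ranks[i][0]`, while `column_wise(ranks=...)` splits a table's embedding columns across every rank in the list. A hypothetical, standalone sketch (toy table names and sizes, assumed 4-rank world; not code from this PR) of how such per-parameter shardings feed `construct_module_sharding_plan`:

```python
from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
    table_wise,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Two toy tables; names and sizes are made up for illustration.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0", embedding_dim=64, num_embeddings=100, feature_names=["f0"]
        ),
        EmbeddingBagConfig(
            name="table_1", embedding_dim=64, num_embeddings=100, feature_names=["f1"]
        ),
    ]
)

plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={
        "table_0": table_wise(rank=0),         # TW: whole table on rank 0
        "table_1": column_wise(ranks=[1, 3]),  # CW: columns split over ranks 1 and 3
    },
    local_size=4,
    world_size=4,
    device_type="cuda",
)
print(plan)
```

The `while new_ranks == old_ranks` retry serves the same purpose in both tests: it guarantees the old and new plans differ, so `_run_ebc_resharding_test` actually exercises a shard movement rather than a no-op.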