@@ -816,6 +816,159 @@ def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None:
816
816
0 ,
817
817
)
818
818
819
+ # pyre-fixme[56]
820
+ @given (data_type = st .sampled_from ([DataType .FP32 , DataType .FP16 ]))
821
+ @settings (verbosity = Verbosity .verbose , max_examples = 8 , deadline = None )
822
+ def test_row_wise_bucket_level_sharding (self , data_type : DataType ) -> None :
823
+
824
+ embedding_config = [
825
+ EmbeddingBagConfig (
826
+ name = f"table_{ idx } " ,
827
+ feature_names = [f"feature_{ idx } " ],
828
+ embedding_dim = 64 ,
829
+ num_embeddings = 4096 ,
830
+ data_type = data_type ,
831
+ )
832
+ for idx in range (2 )
833
+ ]
834
+ module_sharding_plan = construct_module_sharding_plan (
835
+ EmbeddingCollection (tables = embedding_config ),
836
+ per_param_sharding = {
837
+ "table_0" : row_wise (
838
+ sizes_placement = (
839
+ [2048 , 1024 , 1024 ],
840
+ ["cpu" , "cuda" , "cuda" ],
841
+ ),
842
+ num_buckets_per_rank = [20 , 30 , 40 ],
843
+ ),
844
+ "table_1" : row_wise (
845
+ sizes_placement = ([2048 , 1024 , 1024 ], ["cpu" , "cpu" , "cpu" ])
846
+ ),
847
+ },
848
+ local_size = 1 ,
849
+ world_size = 2 ,
850
+ device_type = "cuda" ,
851
+ )
852
+
853
+ # Make sure per_param_sharding setting override the default device_type
854
+ device_table_0_shard_0 = (
855
+ # pyre-ignore[16]
856
+ module_sharding_plan ["table_0" ]
857
+ .sharding_spec .shards [0 ]
858
+ .placement
859
+ )
860
+ self .assertEqual (
861
+ device_table_0_shard_0 .device ().type ,
862
+ "cpu" ,
863
+ )
864
+ # cpu always has rank 0
865
+ self .assertEqual (
866
+ device_table_0_shard_0 .rank (),
867
+ 0 ,
868
+ )
869
+ for i in range (1 , 3 ):
870
+ device_table_0_shard_i = (
871
+ module_sharding_plan ["table_0" ].sharding_spec .shards [i ].placement
872
+ )
873
+ self .assertEqual (
874
+ device_table_0_shard_i .device ().type ,
875
+ "cuda" ,
876
+ )
877
+ # first rank is assigned to cpu so index = rank - 1
878
+ self .assertEqual (
879
+ device_table_0_shard_i .device ().index ,
880
+ i - 1 ,
881
+ )
882
+ self .assertEqual (
883
+ device_table_0_shard_i .rank (),
884
+ i ,
885
+ )
886
+ for i in range (3 ):
887
+ device_table_1_shard_i = (
888
+ module_sharding_plan ["table_1" ].sharding_spec .shards [i ].placement
889
+ )
890
+ self .assertEqual (
891
+ device_table_1_shard_i .device ().type ,
892
+ "cpu" ,
893
+ )
894
+ # cpu always has rank 0
895
+ self .assertEqual (
896
+ device_table_1_shard_i .rank (),
897
+ 0 ,
898
+ )
899
+
900
+ expected = {
901
+ "table_0" : ParameterSharding (
902
+ sharding_type = "row_wise" ,
903
+ compute_kernel = "quant" ,
904
+ ranks = [
905
+ 0 ,
906
+ 1 ,
907
+ 2 ,
908
+ ],
909
+ sharding_spec = EnumerableShardingSpec (
910
+ shards = [
911
+ ShardMetadata (
912
+ shard_offsets = [0 , 0 ],
913
+ shard_sizes = [2048 , 64 ],
914
+ placement = "rank:0/cpu" ,
915
+ bucket_id_offset = 0 ,
916
+ num_buckets = 20 ,
917
+ ),
918
+ ShardMetadata (
919
+ shard_offsets = [2048 , 0 ],
920
+ shard_sizes = [1024 , 64 ],
921
+ placement = "rank:1/cuda:0" ,
922
+ bucket_id_offset = 20 ,
923
+ num_buckets = 30 ,
924
+ ),
925
+ ShardMetadata (
926
+ shard_offsets = [3072 , 0 ],
927
+ shard_sizes = [1024 , 64 ],
928
+ placement = "rank:2/cuda:1" ,
929
+ bucket_id_offset = 50 ,
930
+ num_buckets = 40 ,
931
+ ),
932
+ ]
933
+ ),
934
+ ),
935
+ "table_1" : ParameterSharding (
936
+ sharding_type = "row_wise" ,
937
+ compute_kernel = "quant" ,
938
+ ranks = [
939
+ 0 ,
940
+ 1 ,
941
+ 2 ,
942
+ ],
943
+ sharding_spec = EnumerableShardingSpec (
944
+ shards = [
945
+ ShardMetadata (
946
+ shard_offsets = [0 , 0 ],
947
+ shard_sizes = [2048 , 64 ],
948
+ placement = "rank:0/cpu" ,
949
+ bucket_id_offset = None ,
950
+ num_buckets = None ,
951
+ ),
952
+ ShardMetadata (
953
+ shard_offsets = [2048 , 0 ],
954
+ shard_sizes = [1024 , 64 ],
955
+ placement = "rank:0/cpu" ,
956
+ bucket_id_offset = None ,
957
+ num_buckets = None ,
958
+ ),
959
+ ShardMetadata (
960
+ shard_offsets = [3072 , 0 ],
961
+ shard_sizes = [1024 , 64 ],
962
+ placement = "rank:0/cpu" ,
963
+ bucket_id_offset = None ,
964
+ num_buckets = None ,
965
+ ),
966
+ ]
967
+ ),
968
+ ),
969
+ }
970
+ self .assertDictEqual (expected , module_sharding_plan )
971
+
819
972
# pyre-fixme[56]
820
973
@given (data_type = st .sampled_from ([DataType .FP32 , DataType .FP16 ]))
821
974
@settings (verbosity = Verbosity .verbose , max_examples = 8 , deadline = None )
@@ -929,18 +1082,85 @@ def test_str(self) -> None:
929
1082
)
930
1083
expected = """module: ebc
931
1084
932
- param | sharding type | compute kernel | ranks
1085
+ param | sharding type | compute kernel | ranks
933
1086
-------- | ------------- | -------------- | ------
934
1087
user_id | table_wise | dense | [0]
935
1088
movie_id | row_wise | dense | [0, 1]
936
1089
937
- param | shard offsets | shard sizes | placement
1090
+ param | shard offsets | shard sizes | placement
938
1091
-------- | ------------- | ----------- | -------------
939
1092
user_id | [0, 0] | [4096, 32] | rank:0/cuda:0
940
1093
movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0
941
1094
movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1
1095
+ """
1096
+ for i in range (len (expected .splitlines ())):
1097
+ self .assertEqual (
1098
+ expected .splitlines ()[i ].strip (), str (plan ).splitlines ()[i ].strip ()
1099
+ )
1100
+
1101
+ def test_str_bucket_wise_sharding (self ) -> None :
1102
+ plan = ShardingPlan (
1103
+ {
1104
+ "ebc" : EmbeddingModuleShardingPlan (
1105
+ {
1106
+ "user_id" : ParameterSharding (
1107
+ sharding_type = "table_wise" ,
1108
+ compute_kernel = "dense" ,
1109
+ ranks = [0 ],
1110
+ sharding_spec = EnumerableShardingSpec (
1111
+ [
1112
+ ShardMetadata (
1113
+ shard_offsets = [0 , 0 ],
1114
+ shard_sizes = [4096 , 32 ],
1115
+ placement = "rank:0/cuda:0" ,
1116
+ ),
1117
+ ]
1118
+ ),
1119
+ ),
1120
+ "movie_id" : ParameterSharding (
1121
+ sharding_type = "row_wise" ,
1122
+ compute_kernel = "dense" ,
1123
+ ranks = [0 , 1 ],
1124
+ sharding_spec = EnumerableShardingSpec (
1125
+ [
1126
+ ShardMetadata (
1127
+ shard_offsets = [0 , 0 ],
1128
+ shard_sizes = [2048 , 32 ],
1129
+ placement = "rank:0/cuda:0" ,
1130
+ bucket_id_offset = 0 ,
1131
+ num_buckets = 20 ,
1132
+ ),
1133
+ ShardMetadata (
1134
+ shard_offsets = [2048 , 0 ],
1135
+ shard_sizes = [2048 , 32 ],
1136
+ placement = "rank:0/cuda:1" ,
1137
+ bucket_id_offset = 20 ,
1138
+ num_buckets = 30 ,
1139
+ ),
1140
+ ]
1141
+ ),
1142
+ ),
1143
+ }
1144
+ )
1145
+ }
1146
+ )
1147
+ expected = """module: ebc
1148
+
1149
+ param | sharding type | compute kernel | ranks
1150
+ -------- | ------------- | -------------- | ------
1151
+ user_id | table_wise | dense | [0]
1152
+ movie_id | row_wise | dense | [0, 1]
1153
+
1154
+ param | shard offsets | shard sizes | placement | bucket id offset | num buckets
1155
+ -------- | ------------- | ----------- | ------------- | ---------------- | -----------
1156
+ user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 | None | None
1157
+ movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 | 0 | 20
1158
+ movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 | 20 | 30
942
1159
"""
943
1160
self .maxDiff = None
1161
+ print ("STR PLAN BUCKET WISE" )
1162
+ print (str (plan ))
1163
+ print ("=======" )
944
1164
for i in range (len (expected .splitlines ())):
945
1165
self .assertEqual (
946
1166
expected .splitlines ()[i ].strip (), str (plan ).splitlines ()[i ].strip ()
0 commit comments