@@ -816,6 +816,159 @@ def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None:
816816 0 ,
817817 )
818818
# pyre-fixme[56]
@given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
@settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
def test_row_wise_bucket_level_sharding(self, data_type: DataType) -> None:
    """Check a row-wise sharding plan that carries per-rank bucket counts.

    Two 4096 x 64 tables are planned with ``row_wise``:

    * ``table_0`` is split over one cpu shard and two cuda shards and is
      given ``num_buckets_per_rank=[20, 30, 40]``.
    * ``table_1`` is placed entirely on cpu and gets no bucket information.

    The test first verifies the placement (device type, device index, rank)
    of every shard, then compares the whole plan against the expected
    ``ParameterSharding`` entries, including ``bucket_id_offset`` and
    ``num_buckets`` for each shard.

    Args:
        data_type: Embedding data type drawn by hypothesis (FP32 or FP16).
    """
    # Table/feature names must be exactly "table_<idx>" / "feature_<idx>"
    # (no padding) so they match the per_param_sharding keys below.
    embedding_config = [
        EmbeddingBagConfig(
            name=f"table_{idx}",
            feature_names=[f"feature_{idx}"],
            embedding_dim=64,
            num_embeddings=4096,
            data_type=data_type,
        )
        for idx in range(2)
    ]
    module_sharding_plan = construct_module_sharding_plan(
        EmbeddingCollection(tables=embedding_config),
        per_param_sharding={
            "table_0": row_wise(
                sizes_placement=(
                    [2048, 1024, 1024],
                    ["cpu", "cuda", "cuda"],
                ),
                num_buckets_per_rank=[20, 30, 40],
            ),
            "table_1": row_wise(
                sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"])
            ),
        },
        local_size=1,
        world_size=2,
        device_type="cuda",
    )

    # pyre-ignore[16]
    table_0_shards = module_sharding_plan["table_0"].sharding_spec.shards
    table_1_shards = module_sharding_plan["table_1"].sharding_spec.shards

    # Make sure the per_param_sharding settings override the default
    # device_type ("cuda"): the first shard of table_0 must land on cpu.
    first_placement = table_0_shards[0].placement
    self.assertEqual(first_placement.device().type, "cpu")
    # cpu always has rank 0
    self.assertEqual(first_placement.rank(), 0)

    # The remaining table_0 shards are on cuda.
    for i in range(1, 3):
        placement = table_0_shards[i].placement
        self.assertEqual(placement.device().type, "cuda")
        # first rank is assigned to cpu so index = rank - 1
        self.assertEqual(placement.device().index, i - 1)
        self.assertEqual(placement.rank(), i)

    # Every table_1 shard is on cpu.
    for i in range(3):
        placement = table_1_shards[i].placement
        self.assertEqual(placement.device().type, "cpu")
        # cpu always has rank 0
        self.assertEqual(placement.rank(), 0)

    expected = {
        # bucket_id_offset is the running total of the preceding shards'
        # num_buckets: 0, 20, 20 + 30.
        "table_0": ParameterSharding(
            sharding_type="row_wise",
            compute_kernel="quant",
            ranks=[0, 1, 2],
            sharding_spec=EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0, 0],
                        shard_sizes=[2048, 64],
                        placement="rank:0/cpu",
                        bucket_id_offset=0,
                        num_buckets=20,
                    ),
                    ShardMetadata(
                        shard_offsets=[2048, 0],
                        shard_sizes=[1024, 64],
                        placement="rank:1/cuda:0",
                        bucket_id_offset=20,
                        num_buckets=30,
                    ),
                    ShardMetadata(
                        shard_offsets=[3072, 0],
                        shard_sizes=[1024, 64],
                        placement="rank:2/cuda:1",
                        bucket_id_offset=50,
                        num_buckets=40,
                    ),
                ]
            ),
        ),
        # No num_buckets_per_rank was given, so bucket fields stay None.
        "table_1": ParameterSharding(
            sharding_type="row_wise",
            compute_kernel="quant",
            ranks=[0, 1, 2],
            sharding_spec=EnumerableShardingSpec(
                shards=[
                    ShardMetadata(
                        shard_offsets=[0, 0],
                        shard_sizes=[2048, 64],
                        placement="rank:0/cpu",
                        bucket_id_offset=None,
                        num_buckets=None,
                    ),
                    ShardMetadata(
                        shard_offsets=[2048, 0],
                        shard_sizes=[1024, 64],
                        placement="rank:0/cpu",
                        bucket_id_offset=None,
                        num_buckets=None,
                    ),
                    ShardMetadata(
                        shard_offsets=[3072, 0],
                        shard_sizes=[1024, 64],
                        placement="rank:0/cpu",
                        bucket_id_offset=None,
                        num_buckets=None,
                    ),
                ]
            ),
        ),
    }
    self.assertDictEqual(expected, module_sharding_plan)

971+
819972 # pyre-fixme[56]
820973 @given (data_type = st .sampled_from ([DataType .FP32 , DataType .FP16 ]))
821974 @settings (verbosity = Verbosity .verbose , max_examples = 8 , deadline = None )
@@ -929,18 +1082,85 @@ def test_str(self) -> None:
9291082 )
9301083 expected = """module: ebc
9311084
932- param | sharding type | compute kernel | ranks
1085+ param | sharding type | compute kernel | ranks
9331086-------- | ------------- | -------------- | ------
9341087user_id | table_wise | dense | [0]
9351088movie_id | row_wise | dense | [0, 1]
9361089
937- param | shard offsets | shard sizes | placement
1090+ param | shard offsets | shard sizes | placement
9381091-------- | ------------- | ----------- | -------------
9391092user_id | [0, 0] | [4096, 32] | rank:0/cuda:0
9401093movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0
9411094movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1
1095+ """
1096+ for i in range (len (expected .splitlines ())):
1097+ self .assertEqual (
1098+ expected .splitlines ()[i ].strip (), str (plan ).splitlines ()[i ].strip ()
1099+ )
1100+
1101+ def test_str_bucket_wise_sharding (self ) -> None :
1102+ plan = ShardingPlan (
1103+ {
1104+ "ebc" : EmbeddingModuleShardingPlan (
1105+ {
1106+ "user_id" : ParameterSharding (
1107+ sharding_type = "table_wise" ,
1108+ compute_kernel = "dense" ,
1109+ ranks = [0 ],
1110+ sharding_spec = EnumerableShardingSpec (
1111+ [
1112+ ShardMetadata (
1113+ shard_offsets = [0 , 0 ],
1114+ shard_sizes = [4096 , 32 ],
1115+ placement = "rank:0/cuda:0" ,
1116+ ),
1117+ ]
1118+ ),
1119+ ),
1120+ "movie_id" : ParameterSharding (
1121+ sharding_type = "row_wise" ,
1122+ compute_kernel = "dense" ,
1123+ ranks = [0 , 1 ],
1124+ sharding_spec = EnumerableShardingSpec (
1125+ [
1126+ ShardMetadata (
1127+ shard_offsets = [0 , 0 ],
1128+ shard_sizes = [2048 , 32 ],
1129+ placement = "rank:0/cuda:0" ,
1130+ bucket_id_offset = 0 ,
1131+ num_buckets = 20 ,
1132+ ),
1133+ ShardMetadata (
1134+ shard_offsets = [2048 , 0 ],
1135+ shard_sizes = [2048 , 32 ],
1136+ placement = "rank:0/cuda:1" ,
1137+ bucket_id_offset = 20 ,
1138+ num_buckets = 30 ,
1139+ ),
1140+ ]
1141+ ),
1142+ ),
1143+ }
1144+ )
1145+ }
1146+ )
1147+ expected = """module: ebc
1148+
1149+ param | sharding type | compute kernel | ranks
1150+ -------- | ------------- | -------------- | ------
1151+ user_id | table_wise | dense | [0]
1152+ movie_id | row_wise | dense | [0, 1]
1153+
1154+ param | shard offsets | shard sizes | placement | bucket id offset | num buckets
1155+ -------- | ------------- | ----------- | ------------- | ---------------- | -----------
1156+ user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 | None | None
1157+ movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 | 0 | 20
1158+ movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 | 20 | 30
9421159"""
9431160 self .maxDiff = None
1161+ print ("STR PLAN BUCKET WISE" )
1162+ print (str (plan ))
1163+ print ("=======" )
9441164 for i in range (len (expected .splitlines ())):
9451165 self .assertEqual (
9461166 expected .splitlines ()[i ].strip (), str (plan ).splitlines ()[i ].strip ()
0 commit comments