@@ -1081,33 +1081,66 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no
                 # note: at this point kvstore backed tensors don't own valid snapshots, so no read
                 # access is allowed on them.
 
-                # create ShardedTensor from local shards and metadata, avoiding an all_gather collective
-                sharding_spec = none_throws(
-                    self.module_sharding_plan[table_name].sharding_spec
-                )
-
-                tensor_properties = TensorProperties(
-                    dtype=(
-                        data_type_to_dtype(
-                            self._table_name_to_config[table_name].data_type
+                if self._table_name_to_config[table_name].use_virtual_table:
+                    # virtual table sizes will be recalculated before checkpointing; here we cannot
+                    # use the sharding spec to build tensor metadata, which would exceed the checkpoint capacity limit
+                    self._model_parallel_name_to_sharded_tensor[table_name] = (
+                        ShardedTensor._init_from_local_shards(
+                            local_shards,
+                            (
+                                [
+                                    # assuming virtual tables only support rw sharding for now.
+                                    # When the backend returns whole rows, respect dim(1);
+                                    # otherwise a shard-dim-exceeds-tensor-dim error is raised.
+                                    (
+                                        0
+                                        if dim == 0
+                                        else (
+                                            local_shards[0].metadata.shard_sizes[1]
+                                            if dim == 1
+                                            else dim_size
+                                        )
+                                    )
+                                    for dim, dim_size in enumerate(
+                                        self._name_to_table_size[table_name]
+                                    )
+                                ]
+                            ),
+                            process_group=(
+                                self._env.sharding_pg
+                                if isinstance(self._env, ShardingEnv2D)
+                                else self._env.process_group
+                            ),
                         )
-                    ),
-                )
+                    )
+                else:
+                    # create ShardedTensor from local shards and metadata, avoiding an all_gather collective
+                    sharding_spec = none_throws(
+                        self.module_sharding_plan[table_name].sharding_spec
+                    )
 
-                self._model_parallel_name_to_sharded_tensor[table_name] = (
-                    ShardedTensor._init_from_local_shards_and_global_metadata(
-                        local_shards=local_shards,
-                        sharded_tensor_metadata=sharding_spec.build_metadata(
-                            tensor_sizes=self._name_to_table_size[table_name],
-                            tensor_properties=tensor_properties,
-                        ),
-                        process_group=(
-                            self._env.sharding_pg
-                            if isinstance(self._env, ShardingEnv2D)
-                            else self._env.process_group
+                    tensor_properties = TensorProperties(
+                        dtype=(
+                            data_type_to_dtype(
+                                self._table_name_to_config[table_name].data_type
+                            )
                         ),
                     )
-                )
+
+                    self._model_parallel_name_to_sharded_tensor[table_name] = (
+                        ShardedTensor._init_from_local_shards_and_global_metadata(
+                            local_shards=local_shards,
+                            sharded_tensor_metadata=sharding_spec.build_metadata(
+                                tensor_sizes=self._name_to_table_size[table_name],
+                                tensor_properties=tensor_properties,
+                            ),
+                            process_group=(
+                                self._env.sharding_pg
+                                if isinstance(self._env, ShardingEnv2D)
+                                else self._env.process_group
+                            ),
+                        )
+                    )
 
 
 def extract_sharded_kvtensors(
     module: ShardedEmbeddingBagCollection,
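
Editor's note (not part of the diff): the tensor-size list comprehension in the new virtual-table branch reduces to a small pure function. Below is a minimal sketch of that computation, assuming row-wise (rw) sharding; `virtual_table_sizes`, `table_size`, and `shard_dim1` are hypothetical names standing in for the inlined comprehension, `self._name_to_table_size[table_name]`, and `local_shards[0].metadata.shard_sizes[1]`.

    from typing import List


    def virtual_table_sizes(table_size: List[int], shard_dim1: int) -> List[int]:
        # hypothetical helper mirroring the inlined comprehension above:
        # dim 0 (rows) is reported as 0, since the true row count of a virtual
        # table is only recalculated before checkpointing, and using the planned
        # size would exceed the checkpoint capacity limit;
        # dim 1 (columns) must match what the backend actually returns per row,
        # otherwise ShardedTensor validation raises a shard-dim-exceeds-tensor-dim
        # error.
        return [
            0 if dim == 0 else (shard_dim1 if dim == 1 else dim_size)
            for dim, dim_size in enumerate(table_size)
        ]


    # e.g. a planned 1_000_000 x 128 table whose local shards are 132 columns
    # wide (backend returning whole rows) yields [0, 132]:
    assert virtual_table_sizes([1_000_000, 128], 132) == [0, 132]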
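
Editor's note (not part of the diff): the else branch can skip the all_gather because `ShardingSpec.build_metadata` computes the global `ShardedTensorMetadata` locally from the plan, with no communication. A standalone sketch using torch's `ChunkShardingSpec`; the concrete spec type is an assumption here, as torchrec passes whatever spec the sharding plan holds.

    import torch
    from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    # Row-wise chunking of an 8 x 4 tensor across two ranks; no process group
    # or collective is needed to build the metadata.
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
    metadata = spec.build_metadata(
        tensor_sizes=torch.Size([8, 4]),
        tensor_properties=TensorProperties(dtype=torch.float32),
    )
    for shard in metadata.shards_metadata:
        print(shard.shard_offsets, shard.shard_sizes)
    # expected: [0, 0] [4, 4] then [4, 0] [4, 4]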