Commit b227f54
Author: pytorchbot
Commit message: 2025-01-23 nightly release (4d7b7ff)
1 parent: 61a849d

File tree: 9 files changed (+58, −127 lines)


.github/scripts/install_fbgemm.sh

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,10 @@ if [[ $CU_VERSION = cu* ]]; then
   echo "[NOVA] Setting LD_LIBRARY_PATH ..."
   conda env config vars set -p ${CONDA_ENV} \
     LD_LIBRARY_PATH="/usr/local/lib:${CUDA_HOME}/lib64:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
+else
+  echo "[NOVA] Setting LD_LIBRARY_PATH ..."
+  conda env config vars set -p ${CONDA_ENV} \
+    LD_LIBRARY_PATH="/usr/local/lib:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
 fi
 
 if [ "$CHANNEL" = "nightly" ]; then

torchrec/distributed/embedding.py

Lines changed: 2 additions & 11 deletions
@@ -26,7 +26,6 @@
 )
 
 import torch
-from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -91,7 +90,6 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1200,15 +1198,8 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: TypeUnion[KeyedJaggedTensor, TensorDict],
+        features: KeyedJaggedTensor,
     ) -> Awaitable[Awaitable[KJTList]]:
-        need_permute: bool = True
-        if isinstance(features, TensorDict):
-            feature_keys = list(features.keys())  # pyre-ignore[6]
-            if self._features_order:
-                feature_keys = [feature_keys[i] for i in self._features_order]
-                need_permute = False
-            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
@@ -1218,7 +1209,7 @@ def input_dist(
             unpadded_features = features
             features = pad_vbe_kjt_lengths(unpadded_features)
 
-        if need_permute and self._features_order:
+        if self._features_order:
             features = features.permute(
                 self._features_order,
                 # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
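
With TensorDict support removed from `ShardedEmbeddingCollection.input_dist`, a caller that still holds a `TensorDict` would need to convert it to a `KeyedJaggedTensor` before invoking the sharded module. Below is a minimal call-site sketch, assuming `torchrec.sparse.tensor_dict.maybe_td_to_kjt` is still importable (only its use inside these modules is removed by this commit); the helper function name `to_kjt` is illustrative, not part of torchrec.

```python
# Hypothetical call-site conversion sketch (not part of this commit).
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt  # imported by these modules prior to this commit


def to_kjt(features) -> KeyedJaggedTensor:
    # Pass KeyedJaggedTensor through unchanged; convert a TensorDict explicitly.
    if isinstance(features, TensorDict):
        return maybe_td_to_kjt(features, None)
    return features

# kjt = to_kjt(td_features)
# awaitable = sharded_embedding_collection.input_dist(ctx, kjt)
```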

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 12 deletions
@@ -27,7 +27,6 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
-from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -95,7 +94,6 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -658,7 +656,9 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            True
+            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            else False
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,16 +1189,8 @@ def _create_inverse_indices_permute_indices(
 
     # pyre-ignore [14]
     def input_dist(
-        self,
-        ctx: EmbeddingBagCollectionContext,
-        features: Union[KeyedJaggedTensor, TensorDict],
+        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
-        if isinstance(features, TensorDict):
-            feature_keys = list(features.keys())  # pyre-ignore[6]
-            if len(self._features_order) > 0:
-                feature_keys = [feature_keys[i] for i in self._features_order]
-            self._has_features_permute = False  # feature_keys are in order
-            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:

torchrec/distributed/model_parallel.py

Lines changed: 9 additions & 26 deletions
@@ -770,7 +770,7 @@ def _create_process_groups(
     ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
         """
         Creates process groups for sharding and replication, the process groups
-        are created in the same exact order on all ranks as per `dist.new_group` API.
+        are created using the DeviceMesh API.
 
         Args:
             global_rank (int): The global rank of the current process.
@@ -781,44 +781,27 @@ def _create_process_groups(
             Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh,
                 replication process group, and allreduce process group.
         """
-        # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
         peer_matrix = []
-        sharding_pg, replica_pg = None, None
         step = world_size // local_size
 
-        my_group_rank = global_rank % step
         for group_rank in range(world_size // local_size):
             peers = [step * r + group_rank for r in range(local_size)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
             peer_matrix.append(peers)
-            if my_group_rank == group_rank:
-                logger.warning(
-                    f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]"
-                )
-                sharding_pg = curr_pg
-        assert sharding_pg is not None, "sharding_pg is not initialized!"
-        dist.barrier()
-
-        my_inter_rank = global_rank // step
-        for inter_rank in range(local_size):
-            peers = [inter_rank * step + r for r in range(step)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
-            if my_inter_rank == inter_rank:
-                logger.warning(
-                    f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]"
-                )
-                replica_pg = curr_pg
-        assert replica_pg is not None, "replica_pg is not initialized!"
-        dist.barrier()
 
         mesh = DeviceMesh(
            device_type=self._device.type,
            mesh=peer_matrix,
            mesh_dim_names=("replicate", "shard"),
         )
         logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
+        sharding_pg = mesh.get_group(mesh_dim="shard")
+        logger.warning(
+            f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]"
+        )
+        replica_pg = mesh.get_group(mesh_dim="replicate")
+        logger.warning(
+            f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]"
+        )
 
         return mesh, sharding_pg, replica_pg
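
The replacement derives both groups from a single 2D `DeviceMesh` instead of issuing per-group `dist.new_group` calls and barriers. The following is a minimal standalone sketch of that pattern under the assumption that the default process group is already initialized (e.g. via torchrun); `local_size` and the device choice are illustrative placeholders.

```python
# Sketch: derive sharding/replication process groups from one 2D DeviceMesh.
# Assumes torch.distributed is already initialized; values below are illustrative.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

world_size = dist.get_world_size()
local_size = 2                      # ranks per sharding group (assumption)
step = world_size // local_size

# Same layout as in the patch: each row is one sharding group,
# each column position across rows forms one replication group.
peer_matrix = [
    [step * r + group_rank for r in range(local_size)]
    for group_rank in range(step)
]

mesh = DeviceMesh(
    device_type="cuda" if torch.cuda.is_available() else "cpu",
    mesh=peer_matrix,
    mesh_dim_names=("replicate", "shard"),
)
sharding_pg = mesh.get_group(mesh_dim="shard")      # group of ranks along the shard dim
replica_pg = mesh.get_group(mesh_dim="replicate")   # group of ranks along the replicate dim
```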

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 6 additions & 26 deletions
@@ -147,7 +147,6 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -178,9 +177,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
-        for _ in range(num_inputs):
-            inputs.append(
+    for _ in range(num_inputs):
+        inputs.append(
+            (
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -189,26 +188,8 @@ def gen_model_and_input(
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-            )
-    elif generate == ModelInput.generate:
-        for _ in range(num_inputs):
-            inputs.append(
-                ModelInput.generate(
-                    world_size=world_size,
-                    tables=tables,
-                    dedup_tables=dedup_tables,
-                    weighted_tables=weighted_tables or [],
-                    num_float_features=num_float_features,
-                    variable_batch_size=variable_batch_size,
-                    batch_size=batch_size,
-                    long_indices=long_indices,
-                    input_type=input_type,
-                )
-            )
-    else:
-        for _ in range(num_inputs):
-            inputs.append(
-                cast(ModelInputCallable, generate)(
+                if generate == ModelInput.generate_variable_batch_input
+                else cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -219,6 +200,7 @@ def gen_model_and_input(
                     long_indices=long_indices,
                 )
             )
+        )
     return (model, inputs)
 
 
@@ -315,7 +297,6 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -338,7 +319,6 @@ def sharding_single_rank_test(
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
-            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
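
The refactored `gen_model_and_input` collapses three generation branches into a single loop whose body selects the generator with a conditional expression. A stripped-down sketch of that dispatch pattern is shown below; the generator names are placeholders, not the real `ModelInput` callables.

```python
# Minimal sketch of the one-loop, conditional-expression dispatch pattern.
from typing import Callable, List


def make_variable_batch_input(i: int) -> str:   # placeholder generator
    return f"variable_batch_input_{i}"


def make_fixed_batch_input(i: int) -> str:      # placeholder generator
    return f"fixed_batch_input_{i}"


def gen_inputs(generate: Callable[[int], str], num_inputs: int) -> List[str]:
    inputs = []
    for i in range(num_inputs):
        inputs.append(
            # One append; the conditional expression picks which callable runs.
            make_variable_batch_input(i)
            if generate is make_variable_batch_input
            else generate(i)
        )
    return inputs


print(gen_inputs(make_fixed_batch_input, 2))  # ['fixed_batch_input_0', 'fixed_batch_input_1']
```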

torchrec/distributed/tests/test_sequence_model_parallel.py

Lines changed: 0 additions & 41 deletions
@@ -376,44 +376,3 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
-
-
-@skip_if_asan_class
-class TDSequenceModelParallelTest(SequenceModelParallelTest):
-
-    def test_sharding_variable_batch(self) -> None:
-        pass
-
-    def _test_sharding(
-        self,
-        sharders: List[TestEmbeddingCollectionSharder],
-        backend: str = "gloo",
-        world_size: int = 2,
-        local_size: Optional[int] = None,
-        constraints: Optional[Dict[str, ParameterConstraints]] = None,
-        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
-        qcomms_config: Optional[QCommsConfig] = None,
-        apply_optimizer_in_backward_config: Optional[
-            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
-        ] = None,
-        variable_batch_size: bool = False,
-        variable_batch_per_feature: bool = False,
-    ) -> None:
-        self._run_multi_process_test(
-            callable=sharding_single_rank_test,
-            world_size=world_size,
-            local_size=local_size,
-            model_class=model_class,
-            tables=self.tables,
-            embedding_groups=self.embedding_groups,
-            sharders=sharders,
-            optim=EmbOptimType.EXACT_SGD,
-            backend=backend,
-            constraints=constraints,
-            qcomms_config=qcomms_config,
-            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
-            variable_batch_size=variable_batch_size,
-            variable_batch_per_feature=variable_batch_per_feature,
-            global_constant_batch=True,
-            input_type="td",
-        )

torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

Lines changed: 2 additions & 2 deletions
@@ -160,7 +160,7 @@ def main(
 
     tables = [
         EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
+            num_embeddings=(i + 1) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@ def main(
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
+            num_embeddings=(i + 1) * 1000,
            embedding_dim=dim_emb,
            name="weighted_table_" + str(i),
            feature_names=["weighted_feature_" + str(i)],

torchrec/modules/embedding_modules.py

Lines changed: 2 additions & 8 deletions
@@ -19,7 +19,6 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 
 @torch.fx.wrap
@@ -219,10 +218,7 @@ def __init__(
         self._feature_names: List[List[str]] = [table.feature_names for table in tables]
         self.reset_parameters()
 
-    def forward(
-        self,
-        features: KeyedJaggedTensor,  # can also take TensorDict as input
-    ) -> KeyedTensor:
+    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
         and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -233,7 +229,6 @@ def forward(
             KeyedTensor
         """
         flat_feature_names: List[str] = []
-        features = maybe_td_to_kjt(features, None)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
@@ -453,7 +448,7 @@ def __init__(  # noqa C901
 
     def forward(
         self,
-        features: KeyedJaggedTensor,  # can also take TensorDict as input
+        features: KeyedJaggedTensor,
     ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -466,7 +461,6 @@ def forward(
             Dict[str, JaggedTensor]
         """
 
-        features = maybe_td_to_kjt(features, None)
         feature_embeddings: Dict[str, JaggedTensor] = {}
         jt_dict: Dict[str, JaggedTensor] = features.to_dict()
         for i, emb_module in enumerate(self.embeddings.values()):
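
Since these unsharded `forward` signatures now accept only `KeyedJaggedTensor`, a caller that previously relied on the implicit TensorDict conversion would build the KJT explicitly. A minimal sketch using `KeyedJaggedTensor.from_lengths_sync`; the feature names and values below are made up for illustration.

```python
# Build a KeyedJaggedTensor explicitly for the now KJT-only forward().
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1"],
    values=torch.tensor([1, 2, 3, 4, 5]),   # flattened ids for both features
    lengths=torch.tensor([2, 1, 1, 1]),     # per-(feature, sample) lengths; sums to len(values)
)
# pooled = embedding_bag_collection(features)   # returns a KeyedTensor
```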

torchrec/schema/utils.py

Lines changed: 29 additions & 1 deletion
@@ -8,6 +8,32 @@
 # pyre-strict
 
 import inspect
+import typing
+from typing import Any
+
+
+def _is_annot_compatible(prev: object, curr: object) -> bool:
+    if prev == curr:
+        return True
+
+    if not (prev_origin := typing.get_origin(prev)):
+        return False
+    if not (curr_origin := typing.get_origin(curr)):
+        return False
+
+    if prev_origin != curr_origin:
+        return False
+
+    prev_args = typing.get_args(prev)
+    curr_args = typing.get_args(curr)
+    if len(prev_args) != len(curr_args):
+        return False
+
+    for prev_arg, curr_arg in zip(prev_args, curr_args):
+        if not _is_annot_compatible(prev_arg, curr_arg):
+            return False
+
+    return True
 
 
 def is_signature_compatible(
@@ -84,6 +110,8 @@ def is_signature_compatible(
         return False
 
     # TODO: Account for Union Types?
-    if current_signature.return_annotation != previous_signature.return_annotation:
+    if not _is_annot_compatible(
+        previous_signature.return_annotation, current_signature.return_annotation
+    ):
         return False
     return True
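
The new `_is_annot_compatible` helper leans on `typing.get_origin` and `typing.get_args` to compare parameterized annotations structurally rather than by strict equality. The snippet below illustrates what those standard-library primitives return; it is not torchrec-specific.

```python
# How typing.get_origin / typing.get_args decompose parameterized annotations.
import typing
from typing import Dict, List, Optional

print(typing.get_origin(List[int]))        # <class 'list'>
print(typing.get_args(List[int]))          # (<class 'int'>,)

print(typing.get_origin(Dict[str, int]))   # <class 'dict'>
print(typing.get_args(Dict[str, int]))     # (<class 'str'>, <class 'int'>)

# Optional[int] is Union[int, None]; origin and args reflect that.
print(typing.get_origin(Optional[int]))    # typing.Union
print(typing.get_args(Optional[int]))      # (<class 'int'>, <class 'NoneType'>)

# Plain (unparameterized) annotations have no origin, so the helper
# falls back to its initial equality check.
print(typing.get_origin(int))              # None
```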
