Commit e5e2565

Revert "add NJT/TD support in test data generator (#2528)"

This reverts commit e35119d.

1 parent: ebd64f3

10 files changed: +44, -121 lines

Diff for: install-requirements.txt (-1)

@@ -1,5 +1,4 @@
 fbgemm-gpu
-tensordict
 torchmetrics==1.0.3
 tqdm
 pyre-extensions

Diff for: requirements.txt (-1)

@@ -7,7 +7,6 @@ numpy
 pandas
 pyre-extensions
 scikit-build
-tensordict
 torchmetrics==1.0.3
 torchx
 tqdm

Diff for: torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py (+3, -6)

@@ -9,8 +9,6 @@
 
 #!/usr/bin/env python3
 
-from typing import Dict, List
-
 import click
 
 import torch
@@ -84,10 +82,9 @@ def op_bench(
     )
 
     def _func_to_benchmark(
-        kjts: List[Dict[str, KeyedJaggedTensor]],
+        kjt: KeyedJaggedTensor,
        model: torch.nn.Module,
    ) -> torch.Tensor:
-        kjt = kjts[0]["feature"]
        return model.forward(kjt.values(), kjt.offsets())

    # breakpoint() # import fbvscode; fbvscode.set_trace()
@@ -111,8 +108,8 @@ def _func_to_benchmark
 
     result = benchmark_func(
         name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}",
-        bench_inputs=[{"feature": inputs}],
-        prof_inputs=[{"feature": inputs}],
+        bench_inputs=inputs,  # pyre-ignore
+        prof_inputs=inputs,  # pyre-ignore
         num_benchmarks=10,
         num_profiles=10,
         profile_dir=".",

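Note: a small, self-contained sketch (not repo code) of what the reverted benchmark closure consumes. The _func_to_benchmark signature and the values()/offsets() usage come from the diff above; the sample KeyedJaggedTensor and the model stand-in are invented for illustration. Before the revert the same data was wrapped as bench_inputs=[{"feature": inputs}] and unwrapped inside the closure; after the revert the KeyedJaggedTensor is passed straight through.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Toy batch: one feature, two samples with two ids each (illustrative values only).
kjt = KeyedJaggedTensor(
    keys=["feature"],
    values=torch.tensor([0, 1, 2, 3]),
    lengths=torch.tensor([2, 2]),
)

def _func_to_benchmark(kjt: KeyedJaggedTensor, model: torch.nn.Module) -> torch.Tensor:
    # Post-revert shape: the embedding op takes flat id values plus bag offsets.
    return model.forward(kjt.values(), kjt.offsets())

# Pre-revert, bench_inputs was [{"feature": kjt}] and the closure did kjts[0]["feature"];
# post-revert, bench_inputs is the KJT batch itself.
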
Diff for: torchrec/distributed/benchmark/benchmark_utils.py (+1, -4)

@@ -374,14 +374,11 @@ def get_inputs(
 
         if train:
             sparse_features_by_rank = [
-                model_input.idlist_features
-                for model_input in model_input_by_rank
-                if isinstance(model_input.idlist_features, KeyedJaggedTensor)
+                model_input.idlist_features for model_input in model_input_by_rank
             ]
             inputs_batch.append(sparse_features_by_rank)
         else:
             sparse_features = model_input_by_rank[0].idlist_features
-            assert isinstance(sparse_features, KeyedJaggedTensor)
             inputs_batch.append([sparse_features])
 
     # Transpose if train, as inputs_by_rank is currently in [B X R] format

Diff for: torchrec/distributed/test_utils/infer_utils.py (+1, -3)

@@ -264,7 +264,6 @@ def model_input_to_forward_args_kjt(
     Optional[torch.Tensor],
 ]:
     kjt = mi.idlist_features
-    assert isinstance(kjt, KeyedJaggedTensor)
     return (
         kjt._keys,
         kjt._values,
@@ -292,8 +291,7 @@ def model_input_to_forward_args(
 ]:
     idlist_kjt = mi.idlist_features
     idscore_kjt = mi.idscore_features
-    assert isinstance(idlist_kjt, KeyedJaggedTensor)
-    assert isinstance(idscore_kjt, KeyedJaggedTensor)
+    assert idscore_kjt is not None
     return (
         mi.float_features,
         idlist_kjt._keys,

Diff for: torchrec/distributed/test_utils/test_model.py (+36, -87)

@@ -14,7 +14,6 @@
 
 import torch
 import torch.nn as nn
-from tensordict import TensorDict
 from torchrec.distributed.embedding_tower_sharding import (
     EmbeddingTowerCollectionSharder,
     EmbeddingTowerSharder,
@@ -47,8 +46,8 @@
 @dataclass
 class ModelInput(Pipelineable):
     float_features: torch.Tensor
-    idlist_features: Union[KeyedJaggedTensor, TensorDict]
-    idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]]
+    idlist_features: KeyedJaggedTensor
+    idscore_features: Optional[KeyedJaggedTensor]
     label: torch.Tensor
 
     @staticmethod
@@ -77,13 +76,11 @@ def generate(
         randomize_indices: bool = True,
         device: Optional[torch.device] = None,
         max_feature_lengths: Optional[List[int]] = None,
-        input_type: str = "kjt",
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
         and a list of local (multi-rank training) batches of world_size.
         """
-
         batch_size_by_rank = [batch_size] * world_size
         if variable_batch_size:
             batch_size_by_rank = [
@@ -202,26 +199,11 @@ def _validate_pooling_factor(
             )
             global_idlist_lengths.append(lengths)
             global_idlist_indices.append(indices)
-
-        if input_type == "kjt":
-            global_idlist_input = KeyedJaggedTensor(
-                keys=idlist_features,
-                values=torch.cat(global_idlist_indices),
-                lengths=torch.cat(global_idlist_lengths),
-            )
-        elif input_type == "td":
-            dict_of_nt = {
-                k: torch.nested.nested_tensor_from_jagged(
-                    values=values,
-                    lengths=lengths,
-                )
-                for k, values, lengths in zip(
-                    idlist_features, global_idlist_indices, global_idlist_lengths
-                )
-            }
-            global_idlist_input = TensorDict(source=dict_of_nt)
-        else:
-            raise ValueError(f"For IdList features, unknown input type {input_type}")
+        global_idlist_kjt = KeyedJaggedTensor(
+            keys=idlist_features,
+            values=torch.cat(global_idlist_indices),
+            lengths=torch.cat(global_idlist_lengths),
+        )
 
         for idx in range(len(idscore_ind_ranges)):
             ind_range = idscore_ind_ranges[idx]
@@ -263,25 +245,16 @@ def _validate_pooling_factor(
             global_idscore_lengths.append(lengths)
             global_idscore_indices.append(indices)
             global_idscore_weights.append(weights)
-
-        if input_type == "kjt":
-            global_idscore_input = (
-                KeyedJaggedTensor(
-                    keys=idscore_features,
-                    values=torch.cat(global_idscore_indices),
-                    lengths=torch.cat(global_idscore_lengths),
-                    weights=torch.cat(global_idscore_weights),
-                )
-                if global_idscore_indices
-                else None
+        global_idscore_kjt = (
+            KeyedJaggedTensor(
+                keys=idscore_features,
+                values=torch.cat(global_idscore_indices),
+                lengths=torch.cat(global_idscore_lengths),
+                weights=torch.cat(global_idscore_weights),
             )
-        elif input_type == "td":
-            assert (
-                len(idscore_features) == 0
-            ), "TensorDict does not support weighted features"
-            global_idscore_input = None
-        else:
-            raise ValueError(f"For weighted features, unknown input type {input_type}")
+            if global_idscore_indices
+            else None
+        )
 
         if randomize_indices:
             global_float = torch.rand(
@@ -330,57 +303,36 @@ def _validate_pooling_factor(
                     weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
                 )
 
-            if input_type == "kjt":
-                local_idlist_input = KeyedJaggedTensor(
-                    keys=idlist_features,
-                    values=torch.cat(local_idlist_indices),
-                    lengths=torch.cat(local_idlist_lengths),
-                )
-
-                local_idscore_input = (
-                    KeyedJaggedTensor(
-                        keys=idscore_features,
-                        values=torch.cat(local_idscore_indices),
-                        lengths=torch.cat(local_idscore_lengths),
-                        weights=torch.cat(local_idscore_weights),
-                    )
-                    if local_idscore_indices
-                    else None
-                )
-            elif input_type == "td":
-                dict_of_nt = {
-                    k: torch.nested.nested_tensor_from_jagged(
-                        values=values,
-                        lengths=lengths,
-                    )
-                    for k, values, lengths in zip(
-                        idlist_features, local_idlist_indices, local_idlist_lengths
-                    )
-                }
-                local_idlist_input = TensorDict(source=dict_of_nt)
-                assert (
-                    len(idscore_features) == 0
-                ), "TensorDict does not support weighted features"
-                local_idscore_input = None
+            local_idlist_kjt = KeyedJaggedTensor(
+                keys=idlist_features,
+                values=torch.cat(local_idlist_indices),
+                lengths=torch.cat(local_idlist_lengths),
+            )
 
-            else:
-                raise ValueError(
-                    f"For weighted features, unknown input type {input_type}"
+            local_idscore_kjt = (
+                KeyedJaggedTensor(
+                    keys=idscore_features,
+                    values=torch.cat(local_idscore_indices),
+                    lengths=torch.cat(local_idscore_lengths),
+                    weights=torch.cat(local_idscore_weights),
                 )
+                if local_idscore_indices
+                else None
+            )
 
             local_input = ModelInput(
                 float_features=global_float[r * batch_size : (r + 1) * batch_size],
-                idlist_features=local_idlist_input,
-                idscore_features=local_idscore_input,
+                idlist_features=local_idlist_kjt,
+                idscore_features=local_idscore_kjt,
                 label=global_label[r * batch_size : (r + 1) * batch_size],
             )
             local_inputs.append(local_input)
 
         return (
             ModelInput(
                 float_features=global_float,
-                idlist_features=global_idlist_input,
-                idscore_features=global_idscore_input,
+                idlist_features=global_idlist_kjt,
+                idscore_features=global_idscore_kjt,
                 label=global_label,
             ),
             local_inputs,
@@ -671,9 +623,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
 
     def record_stream(self, stream: torch.Stream) -> None:
         self.float_features.record_stream(stream)
-        if isinstance(self.idlist_features, KeyedJaggedTensor):
-            self.idlist_features.record_stream(stream)
-        if isinstance(self.idscore_features, KeyedJaggedTensor):
+        self.idlist_features.record_stream(stream)
+        if self.idscore_features is not None:
             self.idscore_features.record_stream(stream)
         self.label.record_stream(stream)
 
@@ -1880,8 +1831,6 @@ def forward(self, input: ModelInput) -> ModelInput:
         )
 
         # stride will be same but features will be joined
-        assert isinstance(modified_input.idlist_features, KeyedJaggedTensor)
-        assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor)
         modified_input.idlist_features = KeyedJaggedTensor.concat(
             [modified_input.idlist_features, self._extra_input.idlist_features]
         )

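Note: to make the two reverted-away data layouts concrete, here is a minimal sketch (not repo code) of what ModelInput.generate produced for its two input_type values before this revert, using invented feature names and toy values. The constructor calls mirror the removed and kept lines above; only the KeyedJaggedTensor path remains after the revert, and tensordict is dropped from both requirements files.

import torch
from tensordict import TensorDict  # dependency removed from requirements by this commit
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

keys = ["feature_0", "feature_1"]
values = torch.tensor([10, 11, 12, 13, 14])  # flat ids for both features (toy data)
lengths = torch.tensor([2, 1, 1, 1])         # per-sample bag sizes, batch size 2

# Path kept by the revert ("kjt"): one KeyedJaggedTensor spanning all id-list features.
idlist_kjt = KeyedJaggedTensor(keys=keys, values=values, lengths=lengths)

# Path removed by the revert ("td"): one jagged NestedTensor per feature in a TensorDict,
# built the same way as the deleted dict_of_nt code above.
dict_of_nt = {
    "feature_0": torch.nested.nested_tensor_from_jagged(
        values=values[:3], lengths=lengths[:2]
    ),
    "feature_1": torch.nested.nested_tensor_from_jagged(
        values=values[3:], lengths=lengths[2:]
    ),
}
idlist_td = TensorDict(source=dict_of_nt)
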
Diff for: torchrec/distributed/tests/test_infer_shardings.py (-3)

@@ -1987,7 +1987,6 @@ def test_sharded_quant_fp_ebc_tw(
         inputs = []
         for model_input in model_inputs:
             kjt = model_input.idlist_features
-            assert isinstance(kjt, KeyedJaggedTensor)
             kjt = kjt.to(local_device)
             weights = torch.rand(
                 kjt._values.size(0), dtype=torch.float, device=local_device
@@ -2167,7 +2166,6 @@ def test_sharded_quant_mc_ec_rw(
         inputs = []
         for model_input in model_inputs:
             kjt = model_input.idlist_features
-            assert isinstance(kjt, KeyedJaggedTensor)
             kjt = kjt.to(local_device)
             weights = None
             inputs.append(
@@ -2303,7 +2301,6 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
         )
         inputs = []
         kjt = model_inputs[0].idlist_features
-        assert isinstance(kjt, KeyedJaggedTensor)
         kjt = kjt.to(local_device)
         weights = torch.rand(
             kjt._values.size(0), dtype=torch.float, device=local_device

Diff for: torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py (+2, -10)

@@ -75,11 +75,6 @@ def _gen_pipelines(
     default=100,
     help="Total number of sparse embeddings to be used.",
 )
-@click.option(
-    "--ratio_features_weighted",
-    default=0.4,
-    help="percentage of features weighted vs unweighted",
-)
 @click.option(
     "--dim_emb",
     type=int,
@@ -137,7 +132,6 @@ def _gen_pipelines(
 def main(
     world_size: int,
     n_features: int,
-    ratio_features_weighted: float,
     dim_emb: int,
     n_batches: int,
     batch_size: int,
@@ -155,9 +149,8 @@ def main(
     os.environ["MASTER_ADDR"] = str("localhost")
     os.environ["MASTER_PORT"] = str(get_free_port())
 
-    num_weighted_features = int(n_features * ratio_features_weighted)
-    num_features = n_features - num_weighted_features
-
+    num_features = n_features // 2
+    num_weighted_features = n_features // 2
     tables = [
         EmbeddingBagConfig(
             num_embeddings=(i + 1) * 1000,
@@ -264,7 +257,6 @@ def _generate_data(
             world_size=world_size,
             num_float_features=num_float_features,
             pooling_avg=pooling_factor,
-            input_type=input_type,
         )[1]
         for i in range(num_batches)
     ]

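Note: a toy illustration (not repo code) of the feature split being reverted here. The removed --ratio_features_weighted option computed a configurable weighted/unweighted split of the sparse features, while the restored code divides n_features evenly; with the option's default of 100 features the two behaviors differ as shown below. The helper name is hypothetical.

# Hypothetical helper, for illustration only (not part of pipeline_benchmarks.py).
def split_features(n_features: int, ratio_features_weighted: float = 0.4):
    # Removed behavior: configurable weighted share of the sparse features.
    num_weighted_old = int(n_features * ratio_features_weighted)
    num_unweighted_old = n_features - num_weighted_old
    # Restored behavior: an even split via integer division.
    num_unweighted_new = n_features // 2
    num_weighted_new = n_features // 2
    return (num_unweighted_old, num_weighted_old), (num_unweighted_new, num_weighted_new)

print(split_features(100))  # ((60, 40), (50, 50))
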
Diff for: torchrec/distributed/train_pipeline/tests/test_train_pipelines.py (+1, -5)

@@ -306,11 +306,7 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         # `parameters`.
         optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
 
-        data = [
-            i.idlist_features
-            for i in local_model_inputs
-            if isinstance(i.idlist_features, KeyedJaggedTensor)
-        ]
+        data = [i.idlist_features for i in local_model_inputs]
         dataloader = iter(data)
         pipeline = TrainPipelinePT2(
             model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing

Diff for: torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py (-1)

@@ -169,7 +169,6 @@ def generate_kjt(
         randomize_indices=True,
         device=device,
     )[0]
-    assert isinstance(global_input.idlist_features, KeyedJaggedTensor)
     return global_input.idlist_features
 
 