
Commit c2f7d61

Thomas Polasek authored and facebook-github-bot committed
Back out "Convert directory fbcode/torchrec to use the Ruff Formatter"
Summary:
Original commit changeset: ee300de21222
Original Phabricator Diff: D66013071
bypass-github-export-checks
Reviewed By: aporialiao
Differential Revision: D66198773
fbshipit-source-id: 4a8e5a124937a8329d7ed39444d3fdc5e4f2d10c
1 parent 34cdb1d commit c2f7d61

File tree

107 files changed (+472, −369 lines)


benchmarks/ebc_benchmarks.py

Lines changed: 2 additions & 0 deletions
@@ -163,6 +163,7 @@ def get_fused_ebc_uvm_time(
     location: EmbeddingLocation,
     epochs: int = 100,
 ) -> Tuple[float, float]:
+
     fused_ebc = FusedEmbeddingBagCollection(
         tables=embedding_bag_configs,
         optimizer_type=torch.optim.SGD,
@@ -194,6 +195,7 @@ def get_ebc_comparison(
     device: torch.device,
     epochs: int = 100,
 ) -> Tuple[float, float, float, float, float]:
+
     # Simple EBC module wrapping a list of nn.EmbeddingBag
     ebc = EmbeddingBagCollection(
         tables=embedding_bag_configs,

benchmarks/ebc_benchmarks_utils.py

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,7 @@ def get_random_dataset(
     embedding_bag_configs: List[EmbeddingBagConfig],
     pooling_factors: Optional[Dict[str, int]] = None,
 ) -> IterableDataset[Batch]:
+
     if pooling_factors is None:
         pooling_factors = {}

@@ -56,6 +57,7 @@ def train_one_epoch(
     dataset: IterableDataset[Batch],
     device: torch.device,
 ) -> float:
+
     start_time = time.perf_counter()

     for data in dataset:
@@ -80,6 +82,7 @@ def train_one_epoch_fused_optimizer(
     dataset: IterableDataset[Batch],
     device: torch.device,
 ) -> float:
+
     start_time = time.perf_counter()

     for data in dataset:
@@ -103,6 +106,7 @@ def train(
     device: torch.device,
     epochs: int = 100,
 ) -> Tuple[float, float]:
+
     training_time = []
     for _ in range(epochs):
         if optimizer:
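The four hunks above only insert cosmetic blank lines; the timing pattern they touch is unchanged. A minimal sketch of that shared pattern, for readers skimming the hunks (timed_epoch and step_fn are hypothetical names, not part of this diff):

    import time
    from typing import Callable, Iterable

    def timed_epoch(step_fn: Callable, batches: Iterable) -> float:
        # Same wall-clock pattern as train_one_epoch: start a counter,
        # process every batch, return elapsed seconds.
        start_time = time.perf_counter()
        for batch in batches:
            step_fn(batch)
        return time.perf_counter() - start_time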

examples/bert4rec/bert4rec_main.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@

 # OSS import
 try:
+
     # pyre-ignore[21]
     # @manual=//torchrec/github/examples/bert4rec:bert4rec_metrics
     from bert4rec_metrics import recalls_and_ndcgs_for_ks

examples/golden_training/train_dlrm_data_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def train(
     )

 def dense_filter(
-    named_parameters: Iterator[Tuple[str, nn.Parameter]],
+    named_parameters: Iterator[Tuple[str, nn.Parameter]]
 ) -> Iterator[Tuple[str, nn.Parameter]]:
     for fqn, param in named_parameters:
         if "sparse" not in fqn:

examples/retrieval/two_tower_retrieval.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@

 # OSS import
 try:
+
     # pyre-ignore[21]
     # @manual=//torchrec/github/examples/retrieval:knn_index
     from knn_index import get_index

tools/lint/black_linter.py

Lines changed: 1 addition & 3 deletions
@@ -179,9 +179,7 @@ def main() -> None:
         level=(
             logging.NOTSET
             if args.verbose
-            else logging.DEBUG
-            if len(args.filenames) < 1000
-            else logging.INFO
+            else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
         ),
         stream=sys.stderr,
     )
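This hunk only collapses a stacked conditional expression onto one line; both spellings parse to the same nested ternary. A small equivalence check (verbose and num_files are stand-ins for args.verbose and len(args.filenames)):

    import logging

    verbose, num_files = False, 42

    one_line = (
        logging.NOTSET if verbose else logging.DEBUG if num_files < 1000 else logging.INFO
    )
    stacked = (
        logging.NOTSET
        if verbose
        else logging.DEBUG
        if num_files < 1000
        else logging.INFO
    )
    assert one_line == stacked == logging.DEBUG  # both pick DEBUG here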

torchrec/datasets/criteo.py

Lines changed: 2 additions & 3 deletions
@@ -351,9 +351,8 @@ def get_file_row_ranges_and_remainder(

         # If the ranges overlap.
         if rank_left_g <= file_right_g and rank_right_g >= file_left_g:
-            overlap_left_g, overlap_right_g = (
-                max(rank_left_g, file_left_g),
-                min(rank_right_g, file_right_g),
+            overlap_left_g, overlap_right_g = max(rank_left_g, file_left_g), min(
+                rank_right_g, file_right_g
             )

             # Convert overlap in global numbers to (local) numbers specific to the
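The reflowed lines compute the intersection of two global row ranges. A worked example with made-up ranges (the real values come from rank and file offsets):

    rank_left_g, rank_right_g = 100, 199  # rows assigned to this rank
    file_left_g, file_right_g = 150, 299  # rows stored in this file

    # The ranges overlap, so take the inner endpoints.
    if rank_left_g <= file_right_g and rank_right_g >= file_left_g:
        overlap_left_g, overlap_right_g = max(rank_left_g, file_left_g), min(
            rank_right_g, file_right_g
        )
        assert (overlap_left_g, overlap_right_g) == (150, 199)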

torchrec/datasets/random.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,7 @@ def __init__(
         *,
         min_ids_per_features: Optional[List[int]] = None,
     ) -> None:
+
         self.keys = keys
         self.keys_length: int = len(keys)
         self.batch_size = batch_size
@@ -75,6 +76,7 @@ def __next__(self) -> Batch:
         return batch

     def _generate_batch(self) -> Batch:
+
         values = []
         lengths = []
         for key_idx, _ in enumerate(self.keys):
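For context, _generate_batch fills per-key values and lengths lists before assembling a sparse batch. A minimal, hedged sketch of that data shape using torchrec's KeyedJaggedTensor (the keys and sizes below are made up; the real generator draws them from the configured pooling factors):

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    keys = ["feature_0", "feature_1"]
    values, lengths = [], []
    for key_idx, _ in enumerate(keys):
        lengths.append(torch.tensor([1, 2]))        # ids per sample for this key
        values.append(torch.randint(0, 100, (3,)))  # 1 + 2 ids total

    kjt = KeyedJaggedTensor(
        keys=keys,
        values=torch.cat(values),
        lengths=torch.cat(lengths),
    )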

torchrec/datasets/test_utils/criteo_test_utils.py

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ def _create_dataset_npys(
         labels: Optional[np.ndarray] = None,
     ) -> Generator[Tuple[str, ...], None, None]:
         with tempfile.TemporaryDirectory() as tmpdir:
+
             if filenames is None:
                 filenames = [filename]


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 17 additions & 12 deletions
@@ -785,7 +785,9 @@ def purge(self) -> None:
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -897,7 +899,9 @@ def named_parameters(
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(),
@@ -1078,9 +1082,8 @@ def named_parameters(
         combined_key = "/".join(
             [config.name for config in self._config.embedding_tables]
         )
-        yield (
-            append_prefix(prefix, f"{combined_key}.weight"),
-            cast(nn.Parameter, self._emb_module.weights),
+        yield append_prefix(prefix, f"{combined_key}.weight"), cast(
+            nn.Parameter, self._emb_module.weights
         )


@@ -1098,8 +1101,7 @@ def __init__(
         self._pg = pg

         self._pooling: PoolingMode = pooling_type_to_pooling_mode(
-            config.pooling,
-            sharding_type,  # pyre-ignore[6]
+            config.pooling, sharding_type  # pyre-ignore[6]
         )

         self._local_rows: List[int] = []
@@ -1218,7 +1220,9 @@ def purge(self) -> None:
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -1358,7 +1362,9 @@ def named_parameters(
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
-        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(),
@@ -1561,7 +1567,6 @@ def named_parameters(
         combined_key = "/".join(
             [config.name for config in self._config.embedding_tables]
         )
-        yield (
-            append_prefix(prefix, f"{combined_key}.weight"),
-            cast(nn.Parameter, self._emb_module.weights),
+        yield append_prefix(prefix, f"{combined_key}.weight"), cast(
+            nn.Parameter, self._emb_module.weights
         )
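All four assert hunks in this file are the same rewrap: the condition moves into parentheses so the long message sits on its own line. Both forms execute identically, as this standalone snippet illustrates (remove_duplicate is bound locally here rather than being a method argument, and the message is shortened):

    remove_duplicate = True

    # One-line form (the "-" side of the hunks):
    assert remove_duplicate, "remove_duplicate=False not supported"

    # Parenthesized form (the "+" side): same statement, different wrapping.
    assert (
        remove_duplicate
    ), "remove_duplicate=False not supported"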

Comments (0)