Skip to content

Commit

Permalink
apply Black 2024 style in fbcode (4/16)
Browse files Browse the repository at this point in the history
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447727

fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f
  • Loading branch information
amyreese authored and facebook-github-bot committed Mar 3, 2024
1 parent 8c65a2d commit a865d4d
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 44 deletions.
8 changes: 5 additions & 3 deletions examples/torchrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,11 @@ def init_embdedding_configs(
EmbeddingBagConfig(
name=f"t_{feature_name}",
embedding_dim=embedding_dim,
num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx]
if num_embeddings is None
else num_embeddings,
num_embeddings=(
none_throws(num_embeddings_per_feature)[feature_idx]
if num_embeddings is None
else num_embeddings
),
feature_names=[feature_name],
)
for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
Expand Down
44 changes: 24 additions & 20 deletions tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,33 +686,37 @@ def test_best_checkpoint_no_top_k(self) -> None:
bcs.on_train_epoch_end(state, my_train_unit)
self.assertEqual(
bcs._ckpt_dirpaths,
[
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
]
if mode == "min"
else [
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
],
(
[
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
]
if mode == "min"
else [
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
]
),
)

my_train_unit.train_loss = 0.015
my_train_unit.train_progress.increment_epoch()
bcs.on_train_epoch_end(state, my_train_unit)
self.assertEqual(
bcs._ckpt_dirpaths,
[
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.015"),
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
]
if mode == "min"
else [
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.015"),
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
],
(
[
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.015"),
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
]
if mode == "min"
else [
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.015"),
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
]
),
)

def test_best_checkpoint_top_k(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class PrecisionTest(unittest.TestCase):
def test_convert_precision_str_to_dtype_success(self) -> None:
for (precision_str, expected_dtype) in [
for precision_str, expected_dtype in [
("fp16", torch.float16),
("bf16", torch.bfloat16),
("fp32", None),
Expand Down
3 changes: 1 addition & 2 deletions torchtnt/framework/callbacks/base_csv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def get_step_output_rows(
unit: TPredictUnit,
# pyre-fixme: Missing parameter annotation [2]
step_output: Any,
) -> Union[List[str], List[List[str]]]:
...
) -> Union[List[str], List[List[str]]]: ...

def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
if get_global_rank() == 0:
Expand Down
1 change: 1 addition & 0 deletions torchtnt/framework/callbacks/checkpointer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass
from typing import Literal, Optional


# TODO: eventually support overriding all knobs
@dataclass
class KnobOptions:
Expand Down
1 change: 0 additions & 1 deletion torchtnt/framework/callbacks/memory_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


class MemorySnapshot(Callback):

"""
A callback for memory snapshot collection during training, saving pickle files to the user-specified directory.
Uses `Memory Snapshots <https://zdevito.github.io/2022/08/16/memory-snapshots.html>`.
Expand Down
6 changes: 3 additions & 3 deletions torchtnt/utils/data/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@ def __init__(
float(weights[name]) for name in self._iterator_names
]
self._iterator_is_exhausted: List[bool] = [False] * len(self._iterator_names)
self.stopping_mechanism: Optional[
StoppingMechanism
] = iteration_strategy.stopping_mechanism
self.stopping_mechanism: Optional[StoppingMechanism] = (
iteration_strategy.stopping_mechanism
)
self.enforce_same_loader_across_ranks: bool = (
iteration_strategy.enforce_same_loader_across_ranks
)
Expand Down
6 changes: 2 additions & 4 deletions torchtnt/utils/loggers/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def log(self, name: str, data: Scalar, step: int) -> None:
self._len_before_flush = len(self._log_buffer)

@abstractmethod
def flush(self) -> None:
...
def flush(self) -> None: ...

@abstractmethod
def close(self) -> None:
...
def close(self) -> None: ...
9 changes: 3 additions & 6 deletions torchtnt/utils/memory_snapshot_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,13 @@ def __exit__(
self.stop()

@abstractmethod
def start(self) -> None:
...
def start(self) -> None: ...

@abstractmethod
def stop(self) -> None:
...
def stop(self) -> None: ...

@abstractmethod
def step(self) -> None:
...
def step(self) -> None: ...


class MemorySnapshotProfiler(MemorySnapshotProfilerBase):
Expand Down
6 changes: 2 additions & 4 deletions torchtnt/utils/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
class Stateful(Protocol):
"""Defines the interface for checkpoint saving and loading."""

def state_dict(self) -> Dict[str, Any]:
...
def state_dict(self) -> Dict[str, Any]: ...

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


StatefulDict = Dict[str, Stateful]
Expand Down

0 comments on commit a865d4d

Please sign in to comment.