Skip to content

Commit 38f3f61

Browse files
committed
style: black formatting for precommit
1 parent fc4b54b commit 38f3f61

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

Diff for: demo-notebooks/guided-demos/download_mnist_datasets.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,32 @@
1616
from torchvision.datasets import MNIST
1717
from torchvision import transforms
1818

19+
1920
def download_mnist_dataset(destination_dir):
2021
# Ensure the destination directory exists
2122
if not os.path.exists(destination_dir):
2223
os.makedirs(destination_dir)
2324

2425
# Define transformations
25-
transform = transforms.Compose([
26-
transforms.ToTensor(),
27-
transforms.Normalize((0.1307,), (0.3081,))
28-
])
26+
transform = transforms.Compose(
27+
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
28+
)
2929

3030
# Download the training data
31-
train_set = MNIST(root=destination_dir, train=True, download=True, transform=transform)
31+
train_set = MNIST(
32+
root=destination_dir, train=True, download=True, transform=transform
33+
)
3234

3335
# Download the test data
34-
test_set = MNIST(root=destination_dir, train=False, download=True, transform=transform)
36+
test_set = MNIST(
37+
root=destination_dir, train=False, download=True, transform=transform
38+
)
3539

3640
print(f"MNIST dataset downloaded in {destination_dir}")
3741

42+
3843
# Specify the directory where you
3944
script_dir = os.path.dirname(os.path.abspath(__file__))
4045
destination_dir = script_dir + "/mnist_datasets"
4146

42-
download_mnist_dataset(destination_dir)
47+
download_mnist_dataset(destination_dir)

Diff for: demo-notebooks/guided-demos/mnist_disconnected.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def prepare_data(self):
121121
def setup(self, stage=None):
122122
# Assign train/val datasets for use in dataloaders
123123
if stage == "fit" or stage is None:
124-
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform, download=False)
124+
mnist_full = MNIST(
125+
self.data_dir, train=True, transform=self.transform, download=False
126+
)
125127
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
126128

127129
# Assign test dataset for use in dataloader(s)

0 commit comments

Comments
 (0)