16
16
from torchvision .datasets import MNIST
17
17
from torchvision import transforms
18
18
19
+
19
20
def download_mnist_dataset (destination_dir ):
20
21
# Ensure the destination directory exists
21
22
if not os .path .exists (destination_dir ):
22
23
os .makedirs (destination_dir )
23
24
24
25
# Define transformations
25
- transform = transforms .Compose ([
26
- transforms .ToTensor (),
27
- transforms .Normalize ((0.1307 ,), (0.3081 ,))
28
- ])
26
+ transform = transforms .Compose (
27
+ [transforms .ToTensor (), transforms .Normalize ((0.1307 ,), (0.3081 ,))]
28
+ )
29
29
30
30
# Download the training data
31
- train_set = MNIST (root = destination_dir , train = True , download = True , transform = transform )
31
+ train_set = MNIST (
32
+ root = destination_dir , train = True , download = True , transform = transform
33
+ )
32
34
33
35
# Download the test data
34
- test_set = MNIST (root = destination_dir , train = False , download = True , transform = transform )
36
+ test_set = MNIST (
37
+ root = destination_dir , train = False , download = True , transform = transform
38
+ )
35
39
36
40
print (f"MNIST dataset downloaded in { destination_dir } " )
37
41
42
+
38
43
# Specify the directory where you
39
44
script_dir = os .path .dirname (os .path .abspath (__file__ ))
40
45
destination_dir = script_dir + "/mnist_datasets"
41
46
42
- download_mnist_dataset (destination_dir )
47
+ download_mnist_dataset (destination_dir )
0 commit comments