1
+ # Copyright 2022 IBM, Red Hat
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from torchvision .datasets import MNIST
17
+ from torchvision import transforms
18
+
19
+ def download_mnist_dataset (destination_dir ):
20
+ # Ensure the destination directory exists
21
+ if not os .path .exists (destination_dir ):
22
+ os .makedirs (destination_dir )
23
+
24
+ # Define transformations
25
+ transform = transforms .Compose ([
26
+ transforms .ToTensor (),
27
+ transforms .Normalize ((0.1307 ,), (0.3081 ,))
28
+ ])
29
+
30
+ # Download the training data
31
+ train_set = MNIST (root = destination_dir , train = True , download = True , transform = transform )
32
+
33
+ # Download the test data
34
+ test_set = MNIST (root = destination_dir , train = False , download = True , transform = transform )
35
+
36
+ print (f"MNIST dataset downloaded in { destination_dir } " )
37
+
38
+ # Specify the directory where you
39
+ script_dir = os .path .dirname (os .path .abspath (__file__ ))
40
+ destination_dir = script_dir + "/mnist_datasets"
41
+
42
+ download_mnist_dataset (destination_dir )
0 commit comments