|
| 1 | +# A patch file to replace ImageNet with a dummy dataset. |
| 2 | +# Use only for benchmarking purposes. |
| 3 | + |
| 4 | +diff --git a/train.py b/train.py |
| 5 | +index 6e3b058..8ddbcdd 100755 |
| 6 | +--- a/train.py |
| 7 | ++++ b/train.py |
| 8 | +@@ -61,6 +61,34 @@ except ImportError: |
| 9 | + torch.backends.cudnn.benchmark = True |
| 10 | + _logger = logging.getLogger('train') |
| 11 | + |
| 12 | ++ |
| 13 | ++class DummyImageDataset(torch.utils.data.Dataset): |
| 14 | ++ """Dummy dataset with synthetic images.""" |
| 15 | ++ _IMAGE_HEIGHT = 3072 |
| 16 | ++ _IMAGE_WIDTH = 2304 |
| 17 | ++ |
| 18 | ++ def __init__(self, num_images, num_classes): |
| 19 | ++ import numpy as np |
| 20 | ++ from PIL import Image |
| 21 | ++ imarray = np.random.rand(self._IMAGE_HEIGHT, self._IMAGE_WIDTH, 3) * 255 |
| 22 | ++ self.img = Image.fromarray(imarray.astype('uint8')).convert('RGB') |
| 23 | ++ self.num_images = num_images |
| 24 | ++ self.num_classes = num_classes |
| 25 | ++ self.transform = None |
| 26 | ++ self.target_transform = None |
| 27 | ++ |
| 28 | ++ def __len__(self): |
| 29 | ++ return self.num_images |
| 30 | ++ |
| 31 | ++ def __getitem__(self, idx): |
| 32 | ++ if self.transform is not None: |
| 33 | ++ img = self.transform(self.img) |
| 34 | ++ target = idx % self.num_classes |
| 35 | ++ if self.target_transform is not None: |
| 36 | ++ target = self.target_transform(target) |
| 37 | ++ return img, target |
| 38 | ++ |
| 39 | ++ |
| 40 | + # The first arg parser parses out only the --config argument, this argument is used to |
| 41 | + # load a yaml file containing key-values that override the defaults for the main parser below |
| 42 | + config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) |
| 43 | +@@ -71,8 +99,6 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', |
| 44 | + parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') |
| 45 | + |
| 46 | + # Dataset parameters |
| 47 | +-parser.add_argument('data_dir', metavar='DIR', |
| 48 | +- help='path to dataset') |
| 49 | + parser.add_argument('--dataset', '-d', metavar='NAME', default='', |
| 50 | + help='dataset type (default: ImageFolder/ImageTar if empty)') |
| 51 | + parser.add_argument('--train-split', metavar='NAME', default='train', |
| 52 | +@@ -486,17 +512,8 @@ def main(): |
| 53 | + _logger.info('Scheduled epochs: {}'.format(num_epochs)) |
| 54 | + |
| 55 | + # create the train and eval datasets |
| 56 | +- dataset_train = create_dataset( |
| 57 | +- args.dataset, root=args.data_dir, split=args.train_split, is_training=True, |
| 58 | +- class_map=args.class_map, |
| 59 | +- download=args.dataset_download, |
| 60 | +- batch_size=args.batch_size, |
| 61 | +- repeats=args.epoch_repeats) |
| 62 | +- dataset_eval = create_dataset( |
| 63 | +- args.dataset, root=args.data_dir, split=args.val_split, is_training=False, |
| 64 | +- class_map=args.class_map, |
| 65 | +- download=args.dataset_download, |
| 66 | +- batch_size=args.batch_size) |
| 67 | ++ dataset_train = DummyImageDataset(num_images=1231167, num_classes=1000) |
| 68 | ++ dataset_eval = DummyImageDataset(num_images=50000, num_classes=1000) |
| 69 | + |
| 70 | + # setup mixup / cutmix |
| 71 | + collate_fn = None |
0 commit comments