|
| 1 | +import torch |
| 2 | +from torch.utils.data import Dataset |
| 3 | +import json |
| 4 | +import os |
| 5 | +from PIL import Image |
| 6 | +from utils import Transform |
| 7 | + |
| 8 | + |
class PascalVOCDataset(Dataset):
    """A PyTorch Dataset for Pascal VOC images and object annotations.

    Reads pre-parsed JSON index files (``{SPLIT}_images.json`` with image
    paths and ``{SPLIT}_objects.json`` with per-image boxes/labels/difficulties)
    from ``data_folder`` and yields transformed samples.
    """

    def __init__(self, data_folder, split, keep_difficult=False):
        """
        :param data_folder: folder containing the ``{SPLIT}_images.json`` and
            ``{SPLIT}_objects.json`` index files
        :param split: one of 'TRAIN' or 'TEST' (case-insensitive)
        :param keep_difficult: keep objects flagged as difficult; must be True
            for the TEST split so mAP is computed over all ground-truth boxes
        """
        self.split = split.upper()

        assert self.split in {'TRAIN', 'TEST'}
        if self.split == 'TEST':
            assert keep_difficult == True, 'MUST keep difficult boxes during val/test for mAP calculation!'

        self.data_folder = data_folder
        # Transform applies split-dependent augmentation/normalization (project utils).
        self.transform = Transform(split=self.split)
        self.keep_difficult = keep_difficult

        with open(os.path.join(data_folder, self.split + '_images.json'), 'r') as j:
            self.images = json.load(j)

        with open(os.path.join(data_folder, self.split + '_objects.json'), 'r') as j:
            self.objects = json.load(j)

        # Images and their annotation entries must be index-aligned.
        assert len(self.images) == len(self.objects)

    def __getitem__(self, i):
        """Load image *i* and its annotations, apply the split transform.

        :param i: sample index
        :return: (image, boxes, labels, difficulties) after transformation
        """
        image = Image.open(self.images[i], mode='r')
        image = image.convert('RGB')

        objects = self.objects[i]
        boxes = torch.FloatTensor(objects['boxes'])  # (n_objects, 4)
        labels = torch.LongTensor(objects['labels'])  # (n_objects)
        difficulties = torch.ByteTensor(objects['difficulties'])  # (n_objects)

        if not self.keep_difficult:
            # Fix: index with an explicit boolean mask. The original
            # `boxes[1 - difficulties]` relied on deprecated uint8 mask
            # indexing — under integer-index semantics it would gather
            # rows 0/1 instead of filtering out difficult objects.
            easy_mask = difficulties == 0  # (n_objects) bool
            boxes = boxes[easy_mask]
            labels = labels[easy_mask]
            difficulties = difficulties[easy_mask]

        image, boxes, labels, difficulties = self.transform(image, boxes, labels, difficulties)

        return image, boxes, labels, difficulties

    def __len__(self):
        """Number of samples in this split."""
        return len(self.images)

    def collate_fn(self, batch):
        """Collate samples with varying numbers of objects per image.

        Images are stacked into one tensor; boxes/labels/difficulties stay as
        lists of per-image tensors (they are ragged across the batch). Pass
        this as ``collate_fn`` to a DataLoader.

        :param batch: iterable of (image, boxes, labels, difficulties) tuples
        :return: (images tensor (N, 3, H, W), 3 lists of N tensors each)
        """
        images = list()
        boxes = list()
        labels = list()
        difficulties = list()
        for b in batch:
            images.append(b[0])
            boxes.append(b[1])
            labels.append(b[2])
            difficulties.append(b[3])
        images = torch.stack(images, dim=0)

        return images, boxes, labels, difficulties  # tensor (N, ...), 3 lists of N tensors each
0 commit comments