-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustomdatasets.py
60 lines (47 loc) · 1.78 KB
/
customdatasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from skimage.io import imread
from torch.utils import data
from torchvision import transforms
class SegmentationDataSet(data.Dataset):
    """Dataset of (input image, target image) pairs for segmentation.

    ``inputs[i]`` and ``targets[i]`` are handed to ``imread`` and form one
    sample. Each sample is returned as a pair of tensors: the input cast to
    ``torch.float32`` and the target cast to ``torch.long``.

    Args:
        inputs: Items readable by ``imread`` (one per sample).
        targets: Items readable by ``imread``, index-aligned with ``inputs``.
        transform: Optional callable ``(x, y) -> (x, y)`` applied on every
            ``__getitem__`` call, after loading (and after ``pre_transform``).
        use_cache: If true, every pair is loaded once at construction time
            in a ``multiprocessing.Pool`` and kept in ``self.cached_data``.
        pre_transform: Optional callable ``(x, y) -> (x, y)`` applied right
            after a pair is read. With ``use_cache`` it runs once per pair
            inside a worker process, so it must be picklable; without
            caching it runs on every access.
    """

    def __init__(self,
                 inputs: list,
                 targets: list,
                 transform=None,
                 use_cache=False,
                 pre_transform=None,
                 ):
        super().__init__()
        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long
        self.use_cache = use_cache
        self.pre_transform = pre_transform

        if self.use_cache:
            # Imported here so the uncached path never touches multiprocessing.
            from itertools import repeat
            from multiprocessing import Pool

            # Read (and pre-transform) all pairs in parallel, once.
            with Pool() as pool:
                self.cached_data = pool.starmap(
                    self.read_images,
                    zip(inputs, targets, repeat(self.pre_transform)),
                )

    def __len__(self):
        """Return the number of samples (the length of ``inputs``)."""
        return len(self.inputs)

    def __getitem__(self,
                    index: int):
        """Return sample ``index`` as ``(input_tensor, target_tensor)``."""
        if self.use_cache:
            x, y = self.cached_data[index]
        else:
            # Use the same loader as the cache so ``pre_transform`` is applied
            # whether or not caching is enabled.
            x, y = self.read_images(self.inputs[index],
                                    self.targets[index],
                                    self.pre_transform)

        if self.transform is not None:
            x, y = self.transform(x, y)

        # Copy before wrapping: ``torch.from_numpy`` shares memory with the
        # array and rejects negative strides (e.g. after a flip), and the
        # copy also keeps cached arrays from being aliased by the tensors.
        x = torch.from_numpy(x.copy()).type(self.inputs_dtype)
        y = torch.from_numpy(y.copy()).type(self.targets_dtype)
        return x, y

    @staticmethod
    def read_images(inp, tar, pre_transform):
        """Read one input/target pair and apply ``pre_transform`` if given.

        Args:
            inp: Item passed to ``imread`` for the input image.
            tar: Item passed to ``imread`` for the target image.
            pre_transform: Optional callable ``(inp, tar) -> (inp, tar)``.

        Returns:
            The ``(input, target)`` pair as produced by ``imread`` and, when
            provided, ``pre_transform``.
        """
        inp, tar = imread(inp), imread(tar)
        if pre_transform is not None:
            inp, tar = pre_transform(inp, tar)
        return inp, tar