import fnmatch
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset

9
class RandomScaleCrop(object):
    """
    Random scale-crop augmentation for an image and its aligned maps.

    A scale factor ``sc`` is drawn from ``scale``; a window of size
    (H/sc, W/sc) is cropped at a random position and resized back to the
    original (H, W). The returned depth is divided by ``sc`` so depth
    values stay consistent with the change of apparent scale (per the
    upstream MTAN discussion).

    Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
    """

    def __init__(self, scale=(1.0, 1.2, 1.5)):
        # Stored as a tuple: a mutable default list would be shared across
        # all instances (classic Python mutable-default pitfall).
        self.scale = tuple(scale)

    def __call__(self, img, label, depth, normal):
        """
        Args:
            img:    (C, H, W) image tensor.
            label:  (H, W) tensor of per-pixel class labels.
            depth:  (1, H, W) depth map -- TODO confirm channel count with caller.
            normal: (C, H, W) surface-normal map.

        Returns:
            Tuple ``(img_, label_, depth_ / sc, normal_)``, each resized
            back to (H, W).
        """
        height, width = img.shape[-2:]
        sc = random.choice(self.scale)
        h, w = int(height / sc), int(width / sc)
        i = random.randint(0, height - h)
        j = random.randint(0, width - w)
        img_ = F.interpolate(
            img[None, :, i : i + h, j : j + w],
            size=(height, width),
            mode="bilinear",
            align_corners=True,
        ).squeeze(0)
        # Nearest-neighbour for labels: interpolation must not invent classes.
        label_ = (
            F.interpolate(
                label[None, None, i : i + h, j : j + w],
                size=(height, width),
                mode="nearest",
            )
            .squeeze(0)
            .squeeze(0)
        )
        depth_ = F.interpolate(
            depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest"
        ).squeeze(0)
        normal_ = F.interpolate(
            normal[None, :, i : i + h, j : j + w],
            size=(height, width),
            mode="bilinear",
            align_corners=True,
        ).squeeze(0)
        return img_, label_, depth_ / sc, normal_
48
+
49
class RandomScaleCropCityScapes(object):
    """
    Random scale-crop augmentation for CityScapes: image + label + depth
    (no surface normals). Same scheme as ``RandomScaleCrop``: crop a
    window of size (H/sc, W/sc) at a random position and resize back to
    (H, W); depth is divided by ``sc`` to stay consistent with the scale
    change.

    Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
    """

    def __init__(self, scale=(1.0, 1.2, 1.5)):
        # Tuple instead of a mutable default list (shared-state pitfall).
        self.scale = tuple(scale)

    def __call__(self, img, label, depth):
        """
        Args:
            img:   (C, H, W) image tensor.
            label: (H, W) tensor of per-pixel class labels.
            depth: (1, H, W) depth map -- TODO confirm channel count with caller.

        Returns:
            Tuple ``(img_, label_, depth_ / sc)``, each resized back to (H, W).
        """
        height, width = img.shape[-2:]
        sc = random.choice(self.scale)
        h, w = int(height / sc), int(width / sc)
        i = random.randint(0, height - h)
        j = random.randint(0, width - w)
        img_ = F.interpolate(
            img[None, :, i : i + h, j : j + w],
            size=(height, width),
            mode="bilinear",
            align_corners=True,
        ).squeeze(0)
        # Nearest-neighbour for labels so no new class values are created.
        label_ = (
            F.interpolate(
                label[None, None, i : i + h, j : j + w],
                size=(height, width),
                mode="nearest",
            )
            .squeeze(0)
            .squeeze(0)
        )
        depth_ = F.interpolate(
            depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest"
        ).squeeze(0)
        return img_, label_, depth_ / sc
66
+
67
class CityScapes(Dataset):
    """
    CityScapes dataset for joint 7-class semantic segmentation and depth
    estimation, loading pre-processed ``.npy`` files laid out as::

        root/{train,val}/image/{index}.npy     # H x W x C
        root/{train,val}/label_7/{index}.npy   # H x W
        root/{train,val}/depth/{index}.npy     # H x W x 1

    We could further improve the performance with the data augmentation of NYUv2 defined in:
    [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing
    [2] Pattern affinitive propagation across depth, surface normal and semantic segmentation
    [3] Mti-net: Multiscale task interaction networks for multi-task learning

    1. Random scale in a selected ratio 1.0, 1.2, and 1.5.
    2. Random horizontal flip.

    Please note that: all baselines and MTAN did NOT apply data augmentation in the original paper.
    """

    def __init__(self, root, train=True, augmentation=False):
        """
        Args:
            root: dataset root directory (``~`` is expanded).
            train: use the ``train`` split if True, else ``val``.
            augmentation: apply random scale-crop + horizontal flip.
        """
        self.train = train
        self.root = os.path.expanduser(root)
        self.augmentation = augmentation

        # Build the split path from self.root (the original code used the
        # raw ``root`` here, silently discarding the expanduser() result,
        # which broke "~/..." paths).
        split = 'train' if train else 'val'
        self.data_path = self.root + '/' + split

        # Dataset length = number of image .npy files in the split.
        self.data_len = len(fnmatch.filter(os.listdir(self.data_path + '/image'), '*.npy'))

    def __getitem__(self, index):
        # Files are stored channel-last; moveaxis(-1, 0) -> channel-first
        # for torch. Labels are 2-D and need no axis shuffle.
        image = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/image/{:d}.npy'.format(index)), -1, 0))
        semantic = torch.from_numpy(np.load(self.data_path + '/label_7/{:d}.npy'.format(index)))
        depth = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/depth/{:d}.npy'.format(index)), -1, 0))

        # Apply data augmentation if required.
        if self.augmentation:
            image, semantic, depth = RandomScaleCropCityScapes()(image, semantic, depth)
            # Random horizontal flip: width is dim 2 for CHW tensors and
            # dim 1 for the HW label map.
            if torch.rand(1) < 0.5:
                image = torch.flip(image, dims=[2])
                semantic = torch.flip(semantic, dims=[1])
                depth = torch.flip(depth, dims=[2])

        return image.float(), semantic.float(), depth.float()

    def __len__(self):
        return self.data_len