1
+ import torch
2
+ from torch .utils .data import Dataset ,DataLoader
3
+ import cv2
4
+ import os
5
+ from tqdm import tqdm
6
+ from config import config
7
+ from glob import glob
8
+ import os
9
+ from torchvision import transforms
10
+ import numpy as np
11
+ import random
12
+ from shutil import copy
13
+ from PIL import Image
14
+ import math
15
+
16
+ np .random .seed (666 ) #设置随机种子 为了保证每次划分训练集和测试机的是相同的
17
+
18
+
19
+ '''
20
+ # 1. 对于mini_data 数据集的解析
21
+ def parse_data_config(data_path):
22
+ files = []
23
+ #
24
+ for img in os.listdir(data_path):
25
+ image = data_path + img
26
+ label = img.split("__")[0][3:]
27
+ files.append((image,label))
28
+ return files
29
+
30
+ #划分训练集和测试集
31
+ # ratio 为划分为测试集的比例
32
+ def divide_data(data_path,ratio):
33
+ files = parse_data_config(data_path)
34
+ temp = np.array(files)
35
+ test_data = []
36
+ train_data = []
37
+ for i in range(config.num_classes):
38
+ temp_data = []
39
+ for data in temp:
40
+ if data[1] == str(i):
41
+ temp_data.append(data)
42
+ np.random.shuffle(np.array(temp_data))
43
+ test_data =test_data + temp_data[:int(ratio * len(temp_data))]
44
+ train_data = train_data + temp_data[int(ratio*len(temp_data))+1:]
45
+ # np.random.shuffle(temp)
46
+ # test_data = files[:int(ratio * len(files))]
47
+ # train_data = files[int(ratio*len(files))+1:]
48
+
49
+ # 从训练集中挑选 10 中图片保存到 example 文件夹中
50
+ if not os.path.exists(config.example_folder):
51
+ os.mkdir(config.example_folder)
52
+ else:
53
+ for i in os.listdir(config.example_folder):
54
+ os.remove(os.path.join(config.example_folder+i))
55
+ for i in range(10):
56
+ index = random.randint(0,len(test_data)-1) # 随机生成图片的索引
57
+ copy(test_data[index][0],config.example_folder) # 将挑选的图像复制到example文件夹
58
+
59
+ return test_data, train_data
60
+ '''
61
+ # 2. 对于flowers 数据集的解析
62
+ def get_files (file_dir ,ratio ):
63
+ roses = []
64
+ labels_roses = []
65
+ tulips = []
66
+ labels_tulips = []
67
+ dandelion = []
68
+ labels_dandelion = []
69
+ sunflowers = []
70
+ labels_sunflowers = []
71
+ for file in os .listdir (file_dir + 'roses' ):
72
+ roses .append (file_dir + 'roses' + '/' + file )
73
+ labels_roses .append (0 )
74
+ for file in os .listdir (file_dir + 'tulips' ):
75
+ tulips .append (file_dir + 'tulips' + '/' + file )
76
+ labels_tulips .append (1 )
77
+ for file in os .listdir (file_dir + 'dandelion' ):
78
+ tulips .append (file_dir + 'dandelion' + '/' + file )
79
+ labels_dandelion .append (2 )
80
+ for file in os .listdir (file_dir + 'sunflowers' ):
81
+ sunflowers .append (file_dir + 'sunflowers' + '/' + file )
82
+ labels_sunflowers .append (3 )
83
+
84
+ image_list = np .hstack ((roses ,tulips , dandelion , sunflowers ))
85
+ labels_list = np .hstack ((labels_roses , labels_tulips , labels_dandelion , labels_sunflowers ))
86
+ temp = np .array ([image_list , labels_list ])
87
+ temp = temp .transpose ()
88
+ np .random .shuffle (temp )
89
+ all_image_list = list (temp [:,0 ])
90
+ all_label_list = list (temp [:,1 ])
91
+ all_label_list = [int (i ) for i in all_label_list ]
92
+ length = len (all_image_list )
93
+ n_test = int (math .ceil (length * ratio ))
94
+ n_train = length - n_test
95
+
96
+ tra_image = all_image_list [0 :n_train ]
97
+ tra_label = all_label_list [0 :n_train ]
98
+
99
+ test_image = all_image_list [n_train :- 1 ]
100
+ test_label = all_label_list [n_train :- 1 ]
101
+
102
+ train_data = [(tra_image [i ],tra_label [i ]) for i in range (len (tra_image ))]
103
+ test_data = [(test_image [i ],test_label [i ]) for i in range (len (test_image ))]
104
+ # print("train_data = ",test_image)
105
+ # print("test_data = " , test_label)
106
+ return test_data ,train_data
107
+
108
+ #这个数据集类的作用就是加载训练和测试时的数据
109
+ class datasets (Dataset ):
110
+ def __init__ (self ,data ,transform = None ,test = False ):
111
+ imgs = []
112
+ labels = []
113
+ self .test = test
114
+ self .len = len (data )
115
+ self .data = data
116
+ self .transform = transform
117
+ for i in self .data :
118
+ imgs .append (i [0 ])
119
+ self .imgs = imgs
120
+ labels .append (int (i [1 ]) ) #pytorch中交叉熵需要从0开始
121
+ self .labels = labels
122
+ def __getitem__ (self ,index ):
123
+ if self .test :
124
+ filename = self .imgs [index ]
125
+ filename = filename
126
+ img_path = self .imgs [index ]
127
+ img = cv2 .imread (img_path )
128
+ img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB )
129
+ img = cv2 .resize (img , (config .img_width , config .img_height ))
130
+ img = transforms .ToTensor ()(img )
131
+ return img ,filename
132
+ else :
133
+ img_path = self .imgs [index ]
134
+ label = self .labels [index ]
135
+ #label = int(label)
136
+ img = cv2 .imread (img_path )
137
+ img = cv2 .cvtColor (img ,cv2 .COLOR_BGR2RGB )
138
+ img = cv2 .resize (img ,(config .img_width ,config .img_height ))
139
+ # img = transforms.ToTensor()(img)
140
+
141
+ if self .transform is not None :
142
+ img = Image .fromarray (img )
143
+ img = self .transform (img )
144
+
145
+ else :
146
+ img = transforms .ToTensor ()(img )
147
+ return img ,label
148
+
149
+ def __len__ (self ):
150
+ return len (self .data )#self.len
151
+
152
+ def collate_fn (batch ): #表示如何将多个样本拼接成一个batch
153
+ imgs = []
154
+ label = []
155
+ for sample in batch :
156
+ imgs .append (sample [0 ])
157
+ label .append (sample [1 ])
158
+
159
+ return torch .stack (imgs , 0 ),label
160
+
161
+
162
+ #用于调试代码
163
+ if __name__ == '__main__' :
164
+ test_data ,_ = get_files (config .data_folder ,0.2 )
165
+ for i in (test_data ):
166
+ print (i )
167
+ print (len (test_data ))
168
+
169
+ transform = transforms .Compose ([transforms .ToTensor ()])
170
+ data = datasets (test_data ,transform = transform )
171
+ #print(data[0])
0 commit comments