1+ import torch
2+ from torch .utils .data import Dataset ,DataLoader
3+ import cv2
4+ import os
5+ from tqdm import tqdm
6+ from config import config
7+ from glob import glob
8+ import os
9+ from torchvision import transforms
10+ import numpy as np
11+ import random
12+ from shutil import copy
13+ from PIL import Image
14+ import math
15+
16+ np .random .seed (666 ) # Fix the random seed so that the train/test split is the same on every run
17+
18+
19+ '''
20+ # 1. Parsing of the mini_data dataset
21+ def parse_data_config(data_path):
22+ files = []
23+ #
24+ for img in os.listdir(data_path):
25+ image = data_path + img
26+ label = img.split("__")[0][3:]
27+ files.append((image,label))
28+ return files
29+
30+ # Split the data into a training set and a test set
31+ # ratio is the fraction of the data assigned to the test set
32+ def divide_data(data_path,ratio):
33+ files = parse_data_config(data_path)
34+ temp = np.array(files)
35+ test_data = []
36+ train_data = []
37+ for i in range(config.num_classes):
38+ temp_data = []
39+ for data in temp:
40+ if data[1] == str(i):
41+ temp_data.append(data)
42+ np.random.shuffle(np.array(temp_data))
43+ test_data =test_data + temp_data[:int(ratio * len(temp_data))]
44+ train_data = train_data + temp_data[int(ratio*len(temp_data))+1:]
45+ # np.random.shuffle(temp)
46+ # test_data = files[:int(ratio * len(files))]
47+ # train_data = files[int(ratio*len(files))+1:]
48+
49+ # Pick 10 random images (the code below samples them from test_data) and save them to the example folder
50+ if not os.path.exists(config.example_folder):
51+ os.mkdir(config.example_folder)
52+ else:
53+ for i in os.listdir(config.example_folder):
54+ os.remove(os.path.join(config.example_folder+i))
55+ for i in range(10):
56+ index = random.randint(0,len(test_data)-1) # randomly pick an image index
57+ copy(test_data[index][0],config.example_folder) # copy the selected image into the example folder
58+
59+ return test_data, train_data
60+ '''
# 2. Parsing of the flowers dataset
def get_files(file_dir, ratio):
    """Collect the flower images under ``file_dir`` and split them into test/train sets.

    The directory is expected to contain the sub-directories ``roses``,
    ``tulips``, ``dandelion`` and ``sunflowers``; every file found in them is
    labelled 0, 1, 2 and 3 respectively.  The samples are shuffled with
    ``np.random.shuffle`` (so the result depends on NumPy's global random
    state) and then split.

    Args:
        file_dir: Root directory of the dataset.  A trailing slash is
            accepted but not required.
        ratio: Fraction of all samples placed in the test set.  The test
            set size is ``ceil(len(samples) * ratio)``, clamped to
            ``[0, len(samples)]``.

    Returns:
        ``(test_data, train_data)`` -- note the order.  Each is a list of
        ``(image_path, label)`` tuples with ``label`` as an ``int``.

    Raises:
        OSError: If one of the class sub-directories cannot be listed.
    """
    # Sub-directory name -> class label, in the order the samples are gathered.
    class_labels = (
        ('roses', 0),
        ('tulips', 1),
        ('dandelion', 2),
        ('sunflowers', 3),
    )

    samples = []
    for class_name, label in class_labels:
        # os.path.join makes the trailing slash on file_dir optional.
        class_dir = os.path.join(file_dir, class_name)
        for file in os.listdir(class_dir):
            # Each image is paired with the label of its own directory; the
            # label is stored as text so the rows form one string array.
            samples.append([class_dir + '/' + file, str(label)])

    # One row per sample: column 0 is the path, column 1 the label.
    # reshape keeps the array two-dimensional even when no files were found.
    temp = np.array(samples, dtype=str).reshape(-1, 2)
    np.random.shuffle(temp)  # shuffles rows in place
    all_image_list = [str(path) for path in temp[:, 0]]
    all_label_list = [int(label) for label in temp[:, 1]]

    length = len(all_image_list)
    n_test = int(math.ceil(length * ratio))
    n_test = min(max(n_test, 0), length)  # guard against ratios outside [0, 1]
    n_train = length - n_test

    train_data = list(zip(all_image_list[:n_train], all_label_list[:n_train]))
    # Slice to the end of the list: a ``:-1`` upper bound would silently
    # discard the last shuffled sample.
    test_data = list(zip(all_image_list[n_train:], all_label_list[n_train:]))
    return test_data, train_data
107+
# Dataset class that loads the samples used during training and testing
class datasets(Dataset):
    """Image dataset built from a list of ``(image_path, label)`` pairs.

    Args:
        data: Sequence of samples; element 0 of each sample is the image
            path and element 1 is the label (anything ``int()`` accepts).
        transform: Optional callable applied to the PIL image in training
            mode.  It is not used when ``test`` is true.
        test: When true, ``__getitem__`` returns ``(tensor, filename)``
            instead of ``(tensor, label)``.
    """

    def __init__(self, data, transform=None, test=False):
        self.test = test
        self.len = len(data)
        self.data = data
        self.transform = transform
        # Built outside of any loop so both attributes exist even when
        # ``data`` is empty.
        self.imgs = [sample[0] for sample in data]
        # PyTorch's cross-entropy loss expects class indices starting at 0.
        self.labels = [int(sample[1]) for sample in data]

    def _load_rgb(self, img_path):
        """Read ``img_path`` with OpenCV and return it as a resized RGB array.

        Raises:
            FileNotFoundError: If OpenCV could not read the image.
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None; fail here with
            # the offending path instead of deep inside cvtColor.
            raise FileNotFoundError(
                "Failed to read image: {}".format(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return cv2.resize(img, (config.img_width, config.img_height))

    def __getitem__(self, index):
        """Return ``(image, filename)`` in test mode, else ``(image, label)``."""
        img_path = self.imgs[index]
        img = self._load_rgb(img_path)

        if self.test:
            # Test mode: plain tensor conversion, paired with the file name.
            return transforms.ToTensor()(img), img_path

        if self.transform is not None:
            # The transform is handed a PIL image rather than a NumPy array.
            img = self.transform(Image.fromarray(img))
        else:
            img = transforms.ToTensor()(img)
        return img, self.labels[index]

    def __len__(self):
        return len(self.data)
151+
def collate_fn(batch):
    """Combine several samples into a single batch.

    Element 0 of every sample is stacked along a new leading dimension;
    element 1 of every sample is gathered, in order, into a plain list.

    Returns:
        ``(stacked_images, labels)``.
    """
    images = [sample[0] for sample in batch]
    targets = [sample[1] for sample in batch]
    stacked = torch.stack(images, 0)
    return stacked, targets
160+
161+
# Used for debugging this module
if __name__ == '__main__':
    # get_files returns (test_data, train_data); only the test split is inspected.
    held_out, _ = get_files(config.data_folder, 0.2)
    for sample in held_out:
        print(sample)
    print(len(held_out))

    to_tensor = transforms.Compose([transforms.ToTensor()])
    data = datasets(held_out, transform=to_tensor)
    # print(data[0])
0 commit comments