-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnn_predict.py
137 lines (117 loc) · 4.57 KB
/
cnn_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# -*- coding: utf-8 -*-
"""
Train a 7-class image-classification CNN on 48x48 images with Keras.

Reads train/val/test image folders, builds a three-stage convolutional
network, trains it, evaluates it on the test set, and saves the model
architecture, weights, training history, and accuracy/loss curves.
"""
import os
import pickle

from keras.layers import BatchNormalization, MaxPooling2D, Dense, Dropout, Flatten, Conv2D
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot as plt
from tensorflow import optimizers

# Allow duplicate OpenMP runtimes instead of aborting (common MKL/conda clash).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Expose only the first GPU to TensorFlow.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Image-folder locations and the generators that read from them.
train_dir = r'C:\Users\sang\dataset\train'
val_dir = r'C:\Users\sang\dataset\val'
test_dir = r'C:\Users\sang\dataset\test'

# Training images get light augmentation on top of the 1/255 rescale:
# random shear, random zoom, and random horizontal flips.
train_datagen = ImageDataGenerator(rescale=1. / 255,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True)
# Validation and test images are only rescaled - no augmentation.
val_datagen = ImageDataGenerator(rescale=1. / 255)
test_datagen = ImageDataGenerator(rescale=1. / 255)
# Batch iterators over the image folders; labels are one-hot ('categorical')
# and are inferred from the sub-directory names.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(48, 48),  # resize every image to the network input size
    batch_size=128,
    shuffle=True,
    class_mode='categorical'
)
# Fix: use val_datagen here. The original used test_datagen by mistake,
# leaving val_datagen dead; both only rescale, so results are unchanged.
validation_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(48, 48),
    batch_size=128,
    shuffle=True,
    class_mode='categorical'
)
# Fix: do not shuffle the test set, so a partial evaluation (steps=28)
# always covers the same, reproducible sequence of samples.
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(48, 48),
    batch_size=128,
    shuffle=False,
    class_mode='categorical'
)
# Network: three conv/pool stages followed by a fully connected head.
model = Sequential([
    # Stage 1: 64 filters of 5x5, stride 1, 'same' padding, 48x48 RGB input.
    Conv2D(64, kernel_size=(5, 5), strides=(1, 1), activation='relu',
           padding='same', input_shape=(48, 48, 3)),
    MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),  # halve spatial size
    BatchNormalization(),
    Dropout(0.4),  # drop 40% of connections to curb overfitting
    # Stage 2: 128 filters.
    Conv2D(128, kernel_size=(5, 5), strides=(1, 1), activation='relu',
           padding='same'),
    MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
    BatchNormalization(),
    Dropout(0.4),
    # Stage 3: 256 filters.
    Conv2D(256, kernel_size=(5, 5), strides=(1, 1), activation='relu',
           padding='same'),
    MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
    # Fully connected classification head.
    Flatten(),
    Dropout(0.3),
    Dense(2048, activation='relu'),
    Dropout(0.4),
    Dense(1024, activation='relu'),
    Dropout(0.4),
    Dense(512, activation='relu'),
    Dense(7, activation='softmax'),  # 7-way class probabilities
])
model.summary()
# Compile: categorical cross-entropy with a low-learning-rate RMSprop.
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(learning_rate=0.0001),
              metrics=['accuracy'])

# Train. steps_per_epoch = 200 ~= train_num / batch_size (batch_size = 128).
history = model.fit(
    train_generator,
    steps_per_epoch=200,
    epochs=40,
    validation_data=validation_generator,
    validation_steps=28  # ~= valid_num / batch_size
)

# Evaluate on the held-out test set.
test_loss, test_acc = model.evaluate(test_generator, steps=28)
# Fix: accuracy was multiplied by 100 but printed with no percent sign,
# so e.g. 0.65 showed up as "test_acc: 65.0000". Print it as a percentage.
print("test_loss: %.4f - test_acc: %.2f%%" % (test_loss, test_acc * 100))
# Persist the trained network in several forms: architecture as JSON,
# weights-only HDF5, full model HDF5, and the training history via pickle.
with open('myModel_2_json.json', 'w') as json_file:
    json_file.write(model.to_json())
model.save_weights('myModel_2_weight.h5')
model.save('myModel_2.h5')
with open('fit_2_log.txt', 'wb') as file_txt:
    pickle.dump(history.history, file_txt, 0)  # protocol 0 (ASCII-compatible)
# Plot and save the training/validation accuracy and loss curves.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)

# One figure per metric: (figure name, train series, val series,
# train label, val label, title, output file).
for fig_name, train_series, val_series, train_label, val_label, title, out_file in (
        ("acc", acc, val_acc, 'Training acc', 'validation acc',
         'Accuracy curve', 'acc_2.jpg'),
        ("loss", loss, val_loss, 'Training loss', 'validation loss',
         'Loss curve', 'loss_2.jpg')):
    plt.figure(fig_name)
    plt.plot(epochs, train_series, 'r-', label=train_label)
    plt.plot(epochs, val_series, 'b', label=val_label)
    plt.title(title)
    plt.legend()
    plt.savefig(out_file)
    plt.show()