7-train.py
#!/usr/bin/env python3
"""Trains a model using mini-batch gradient descent, optionally with
validation data, learning rate decay and early stopping"""
import tensorflow.keras as K


def train_model(network, data, labels, batch_size, epochs,
                validation_data=None, early_stopping=False,
                patience=0, learning_rate_decay=False, alpha=0.1,
                decay_rate=1, verbose=True, shuffle=False):
    """Trains the model and returns the generated History object"""
    callbacks = []
    # Early stopping is only applied when validation data is available,
    # since it monitors the validation loss
    if validation_data is not None and early_stopping:
        callbacks.append(K.callbacks.EarlyStopping(monitor="val_loss",
                                                   patience=patience))
    # Inverse time decay of the learning rate, updated once per epoch
    if validation_data is not None and learning_rate_decay:
        def lr_scheduler(epoch):
            """Computes the learning rate for the given epoch"""
            return alpha / (1 + decay_rate * epoch)
        callbacks.append(K.callbacks.LearningRateScheduler(
            schedule=lr_scheduler, verbose=1))
    history = network.fit(
        x=data,
        y=labels,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        shuffle=shuffle,
        validation_data=validation_data,
        callbacks=callbacks)
    return history
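

# --- Usage sketch (illustrative, not part of the original task file) ---
# A minimal example of calling train_model, assuming a small compiled
# Keras classifier and random NumPy data; the model, shapes and
# hyperparameters below are placeholders chosen only for demonstration.
if __name__ == "__main__":
    import numpy as np

    np.random.seed(0)
    X = np.random.randn(1000, 10).astype("float32")
    Y = K.utils.to_categorical(np.random.randint(0, 3, size=1000), 3)

    network = K.Sequential([
        K.layers.Dense(16, activation="relu", input_shape=(10,)),
        K.layers.Dense(3, activation="softmax"),
    ])
    network.compile(optimizer=K.optimizers.SGD(),
                    loss="categorical_crossentropy",
                    metrics=["accuracy"])

    # Early stopping and inverse-time learning rate decay are both active
    # because validation_data is provided
    history = train_model(network, X, Y, batch_size=32, epochs=5,
                          validation_data=(X, Y), early_stopping=True,
                          patience=2, learning_rate_decay=True,
                          alpha=0.01, decay_rate=1)
    print(sorted(history.history.keys()))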