forked from pclubiitk/model-zoo
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
59 lines (40 loc) · 1.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# -*- coding: utf-8 -*-
"""main.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/14vKeukBgFkbdgQecsrBUw2LqAmedDGOX
"""
import argparse
import tensorflow as tf
import tensorflow.keras as keras
from datetime import datetime
from model import DenseNet3D_121, T3D_121, T3D_169
from dataloader import load_ucf101
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]`` exactly as before;
            passing a list makes the function usable from tests or other code.

    Returns:
        A ``(model_type, num_epochs)`` tuple: the ``--model`` name (str,
        default ``"T3D_121"``) and the ``--epochs`` count (int, default 100).
    """
    parser = argparse.ArgumentParser(
        description="Train a DenseNet3D / T3D model on UCF101.")
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs (default: 100)')
    parser.add_argument('--model', type=str, default="T3D_121",
                        help='model architecture to build (default: T3D_121)')
    args = parser.parse_args(argv)
    return args.model, args.epochs
model_type, num_epochs = parse_args()

# Per-sample input shape 32 x 224 x 224 x 3, presumably
# (frames, height, width, RGB channels) -- confirm against load_ucf101().
# keras.Input expects the shape as ONE tuple; passing the dimensions
# positionally would bind them to shape / batch_size / name / dtype.
inputs = keras.Input(shape=(32, 224, 224, 3))

# Map each --model value to the builder it selects. 'DenseNet3D_169' is kept
# as a legacy alias because earlier versions of this script accepted that
# name and built T3D_169 for it.
_MODEL_BUILDERS = {
    'DenseNet3D_121': DenseNet3D_121,
    'T3D_121': T3D_121,
    'T3D_169': T3D_169,
    'DenseNet3D_169': T3D_169,
}
try:
    build_model = _MODEL_BUILDERS[model_type]
except KeyError:
    # Previously an unknown name left `outputs` undefined (NameError).
    raise ValueError(
        f"Unknown --model {model_type!r}; expected one of "
        f"{sorted(_MODEL_BUILDERS)}"
    ) from None
outputs = build_model(inputs)
model = keras.models.Model(inputs, outputs)

train_clip, train_label, test_clip, test_label = load_ucf101()

# One TensorBoard log directory per run, named by the run's start time.
logdir = f"logs/loss/{datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# The original used SGD(..., decay=0.0001), i.e. lr = 0.1 / (1 + 1e-4 * step).
# The new Keras optimizers (TF >= 2.11) no longer accept `decay`, so the same
# schedule is expressed explicitly: InverseTimeDecay with decay_steps=1
# computes initial_learning_rate / (1 + decay_rate * step).
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.1, decay_steps=1, decay_rate=0.0001)
optimizer = tf.keras.optimizers.SGD(
    learning_rate=lr_schedule, momentum=0.9, nesterov=True)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['acc'],
              )

training_history = model.fit(train_clip, train_label, epochs=num_epochs, batch_size=128,
                             validation_data=(test_clip, test_label),
                             callbacks=[tensorboard_callback])

# history['loss'] is the *training* loss. The test split is passed as
# validation_data, so its per-epoch loss is recorded under 'val_loss'.
# A plain mean is used because numpy was never imported by this script.
test_losses = training_history.history.get('val_loss', [])
if test_losses:  # empty when num_epochs == 0
    print("Average test loss: ", sum(test_losses) / len(test_losses))