Commit 7ff88b3

Merge pull request #1041 from warshallrho/master
Fix a bug in resnet; add performance test in static model
2 parents b1cd59d + 8e3ecea commit 7ff88b3

File tree: 4 files changed, +168 −8 lines


CHANGELOG.md (+3 −1)

@@ -95,6 +95,7 @@ To release a new version, please update the changelog as followed:
 - Support string dtype in InputLayer (#PR 1017)
 - Support Dynamic RNN in RNN (#PR 1023)
 - Add ResNet50 static model (#PR 1030)
+- Add performance test code in static model (#PR 1041)

 ### Changed

@@ -115,6 +116,7 @@ To release a new version, please update the changelog as followed:
 - Copy original model's `trainable_weights` and `nontrainable_weights` when initializing `LayerList` (#PR 1029)
 - Remove redundant parts in `model.all_layers` (#PR 1029)
 - Replace `tf.image.resize_image_with_crop_or_pad` with `tf.image.resize_with_crop_or_pad` (#PR 1032)
+- Fix a bug in `ResNet50` static model (#PR 1041)

 ### Removed

@@ -124,7 +126,7 @@ To release a new version, please update the changelog as followed:

 - @zsdonghao
 - @ChrisWu1997: #1010 #1015 #1025 #1030
-- @warshallrho: #1017 #1021 #1026 #1029 #1032
+- @warshallrho: #1017 #1021 #1026 #1029 #1032 #1041
 - @ArnoldLIULJ: #1023
 - @JingqingZ: #1023

tensorlayer/models/resnet.py (+7 −7)

@@ -150,21 +150,21 @@ def ResNet50(pretrained=False, end_with='fc1000', n_classes=1000, name=None):
     n = BatchNorm(name='bn_conv1', act='relu')(n)
     n = MaxPool2d((3, 3), strides=(2, 2), name='max_pool1')(n)

-    for i, name in enumerate(block_names):
-        if len(name) == 2:
-            stage = int(name[0])
-            block = name[1]
+    for i, block_name in enumerate(block_names):
+        if len(block_name) == 2:
+            stage = int(block_name[0])
+            block = block_name[1]
             if block == 'a':
                 strides = (1, 1) if stage == 2 else (2, 2)
                 n = conv_block(n, 3, block_filters[stage - 2], stage=stage, block=block, strides=strides)
             else:
                 n = identity_block(n, 3, block_filters[stage - 2], stage=stage, block=block)
-        elif name == 'avg_pool':
+        elif block_name == 'avg_pool':
             n = GlobalMeanPool2d(name='avg_pool')(n)
-        elif name == 'fc1000':
+        elif block_name == 'fc1000':
             n = Dense(n_classes, name='fc1000')(n)

-        if name == end_with:
+        if block_name == end_with:
             break

     network = Model(inputs=ni, outputs=n, name=name)
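The fix is a rename: the loop variable `name` previously shadowed the `name` parameter of `ResNet50`, so by the time `Model(inputs=ni, outputs=n, name=name)` ran, `name` held whichever block label the loop stopped on rather than the model name supplied by the caller. Renaming the loop variable to `block_name` leaves the parameter intact for the final `Model(...)` call. A minimal sketch of the shadowing effect (illustrative only, not code from the repository):

def build(name=None):
    # the loop variable reuses (and overwrites) the function parameter
    for name in ['2a', '2b', 'avg_pool']:
        pass
    return name  # 'avg_pool', no matter what the caller passed

assert build(name='my_resnet') == 'avg_pool'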
New file (name not shown) (+79 −0) — performance test, static VGG16, train step wrapped with @tf.function

@@ -0,0 +1,79 @@
+import time
+import os
+import psutil
+import tensorflow as tf
+import tensorlayer as tl
+from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE
+
+gpus = tf.config.experimental.list_physical_devices('GPU')
+if gpus:
+    for gpu in gpus:
+        tf.config.experimental.set_memory_growth(gpu, True)
+
+tl.logging.set_verbosity(tl.logging.DEBUG)
+
+# get the whole model
+vgg = tl.models.vgg16(mode='static')
+
+# system monitor
+info = psutil.virtual_memory()
+monitor_interval = MONITOR_INTERVAL
+avg_mem_usage = 0
+max_mem_usage = 0
+count = 0
+total_time = 0
+
+# training setting
+num_iter = NUM_ITERS
+batch_size = BATCH_SIZE
+train_weights = vgg.trainable_weights
+optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
+loss_object = tl.cost.cross_entropy
+
+# data generator
+gen = random_input_generator(num_iter, batch_size)
+
+
+# training function
+@tf.function
+def train_step(x_batch, y_batch):
+    # forward + backward
+    with tf.GradientTape() as tape:
+        ## compute outputs
+        _logits = vgg(x_batch)
+        ## compute loss and update model
+        _loss = loss_object(_logits, y_batch)
+
+    grad = tape.gradient(_loss, train_weights)
+    optimizer.apply_gradients(zip(grad, train_weights))
+
+
+# begin training
+vgg.train()
+
+for idx, data in enumerate(gen):
+    start_time = time.time()
+
+    train_step(data[0], data[1])
+
+    end_time = time.time()
+    consume_time = end_time - start_time
+    total_time += consume_time
+
+    if idx % monitor_interval == 0:
+        cur_usage = psutil.Process(os.getpid()).memory_info().rss
+        max_mem_usage = max(cur_usage, max_mem_usage)
+        avg_mem_usage += cur_usage
+        count += 1
+        tl.logging.info(
+            "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s".format(
+                idx, cur_usage / (1024 * 1024), consume_time
+            )
+        )
+
+print('consumed time:', total_time)
+
+avg_mem_usage = avg_mem_usage / count / (1024 * 1024)
+max_mem_usage = max_mem_usage / (1024 * 1024)
+print('average memory usage: {:.2f}MB'.format(avg_mem_usage))
+print('maximum memory usage: {:.2f}MB'.format(max_mem_usage))
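Both performance scripts import their settings and data generator from exp_config, which is not part of this commit. For readability, here is a hypothetical stand-in for that module; only the imported names come from the scripts above, while the constant values, input shape, and generator implementation are assumptions:

# exp_config.py — hypothetical stand-in; the repository's actual module is not shown in this diff
import numpy as np

MONITOR_INTERVAL = 50     # assumed value: log every N iterations
NUM_ITERS = 300           # assumed value: total training iterations
BATCH_SIZE = 32           # assumed value
LERANING_RATE = 0.0001    # assumed value; spelling kept to match the import above


def random_input_generator(num_iters, batch_size, input_shape=(224, 224, 3), n_classes=1000):
    # Yield (images, labels) batches of random data shaped for VGG16.
    for _ in range(num_iters):
        x = np.random.random((batch_size,) + input_shape).astype(np.float32)
        y = np.random.randint(0, n_classes, size=(batch_size,)).astype(np.int64)
        yield x, y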
New file (name not shown) (+79 −0) — performance test, static VGG16, eager train step

@@ -0,0 +1,79 @@
+import time
+import os
+import psutil
+import tensorflow as tf
+import tensorlayer as tl
+from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE
+
+gpus = tf.config.experimental.list_physical_devices('GPU')
+if gpus:
+    for gpu in gpus:
+        tf.config.experimental.set_memory_growth(gpu, True)
+
+tl.logging.set_verbosity(tl.logging.DEBUG)
+
+# get the whole model
+vgg = tl.models.vgg16(mode='static')
+
+# system monitor
+info = psutil.virtual_memory()
+monitor_interval = MONITOR_INTERVAL
+avg_mem_usage = 0
+max_mem_usage = 0
+count = 0
+total_time = 0
+
+# training setting
+num_iter = NUM_ITERS
+batch_size = BATCH_SIZE
+train_weights = vgg.trainable_weights
+optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
+loss_object = tl.cost.cross_entropy
+
+# data generator
+gen = random_input_generator(num_iter, batch_size)
+
+
+# training function
+def train_step(x_batch, y_batch):
+    # forward + backward
+    with tf.GradientTape() as tape:
+        ## compute outputs
+        _logits = vgg(x_batch)
+        ## compute loss and update model
+        _loss = loss_object(_logits, y_batch)
+
+    grad = tape.gradient(_loss, train_weights)
+    optimizer.apply_gradients(zip(grad, train_weights))
+    return _loss
+
+
+# begin training
+vgg.train()
+
+for idx, data in enumerate(gen):
+    start_time = time.time()
+
+    loss = train_step(data[0], data[1])
+
+    end_time = time.time()
+    consume_time = end_time - start_time
+    total_time += consume_time
+
+    if idx % monitor_interval == 0:
+        cur_usage = psutil.Process(os.getpid()).memory_info().rss
+        max_mem_usage = max(cur_usage, max_mem_usage)
+        avg_mem_usage += cur_usage
+        count += 1
+        tl.logging.info(
+            "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s, loss {:.4f}".format(
+                idx, cur_usage / (1024 * 1024), consume_time, loss
+            )
+        )
+
+print('consumed time:', total_time)
+
+avg_mem_usage = avg_mem_usage / count / (1024 * 1024)
+max_mem_usage = max_mem_usage / (1024 * 1024)
+print('average memory usage: {:.2f}MB'.format(avg_mem_usage))
+print('maximum memory usage: {:.2f}MB'.format(max_mem_usage))
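The two new scripts share the same training loop; they differ only in that the first wraps train_step with @tf.function, so the step is traced and executed as a graph, while the second runs the same step eagerly and returns the loss so it can be logged each monitored iteration. A minimal sketch of that toggle (illustrative only; the helper names below are not from the repository):

import tensorflow as tf

def train_step(x_batch, y_batch):
    # forward pass, loss, and optimizer update go here, as in the scripts above
    pass

graph_train_step = tf.function(train_step)  # traced on first call, then run as a graph

# train_step(x, y)        -> eager execution: easy to debug, loss is a concrete value
# graph_train_step(x, y)  -> graph execution: typically faster per iteration after tracing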
