-
Notifications
You must be signed in to change notification settings - Fork 45.6k
/
Copy pathtn_expand_condense_test.py
159 lines (123 loc) · 5.76 KB
/
tn_expand_condense_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ExpandCondense tensor network layer."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for the ExpandCondense tensor-network (TN) layer.

  Most tests are parameterized over ``(input_dim, proj_multiple)`` pairs.
  The "valid" pairs use input widths that are multiples of 128 (768, 1024),
  which the closed-form parameter count in
  ``test_expandcondense_num_parameters`` relies on. ``test_incorrect_sizes``
  uses widths that are not (912, 200) and expects an ``AssertionError``.
  """

  def setUp(self):
    super().setUp()
    # 100 binary labels of shape (100, 1): 50 ones followed by 50 zeros,
    # matching the 100-row random inputs generated by each test.
    self.labels = np.concatenate((np.ones((50, 1)), np.zeros((50, 1))), axis=0)

  def _build_model(self, data, proj_multiple=2):
    """Builds a small binary classifier around a TNExpandCondense layer.

    Args:
      data: Input array; only its last dimension is used, as the model's
        input width.
      proj_multiple: Forwarded to ``TNExpandCondense`` as
        ``proj_multiplier``.

    Returns:
      An uncompiled ``tf_keras`` Sequential model: ``TNExpandCondense``
      (ReLU activation, with bias) followed by a ``Dense(1, sigmoid)`` head.
    """
    model = tf_keras.models.Sequential()
    model.add(
        TNExpandCondense(
            proj_multiplier=proj_multiple,
            use_bias=True,
            activation='relu',
            input_shape=(data.shape[-1],)))
    model.add(tf_keras.layers.Dense(1, activation='sigmoid'))
    return model

  @parameterized.parameters((768, 6), (1024, 2))
  def test_train(self, input_dim, proj_multiple):
    """Checks that a few epochs of training reduce loss and raise accuracy."""
    # Fix the random seed so the first-vs-last epoch comparisons are
    # reproducible rather than flaky.
    tf_keras.utils.set_random_seed(0)
    data = np.random.randint(10, size=(100, input_dim))
    model = self._build_model(data, proj_multiple)
    model.compile(
        optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    # Train the model for 5 epochs.
    history = model.fit(data, self.labels, epochs=5, batch_size=32)

    # Loss should decrease and accuracy increase between the first and the
    # last epoch.
    self.assertGreater(history.history['loss'][0], history.history['loss'][-1])
    self.assertLess(
        history.history['accuracy'][0], history.history['accuracy'][-1])

  @parameterized.parameters((768, 6), (1024, 2))
  def test_weights_change(self, input_dim, proj_multiple):
    """Checks that training updates every weight tensor of the model."""
    tf_keras.utils.set_random_seed(0)
    data = np.random.randint(10, size=(100, input_dim))
    model = self._build_model(data, proj_multiple)
    model.compile(
        optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    before = model.get_weights()
    model.fit(data, self.labels, epochs=5, batch_size=32)
    after = model.get_weights()

    # Compare lengths first so zip() below cannot silently skip tensors.
    self.assertEqual(len(before), len(after))
    # Make sure every weight tensor changed during training.
    for idx, (w_before, w_after) in enumerate(zip(before, after)):
      self.assertTrue(
          np.any(w_after != w_before),
          msg=f'Weight tensor {idx} was not updated by training.')

  @parameterized.parameters((768, 6), (1024, 2))
  def test_output_shape(self, input_dim, proj_multiple):
    """Checks that the actual output shape matches compute_output_shape()."""
    data = np.random.randint(10, size=(100, input_dim))
    model = self._build_model(data, proj_multiple)
    input_shape = data.shape

    actual_output_shape = model(data).shape
    expected_output_shape = model.compute_output_shape(input_shape)
    self.assertEqual(expected_output_shape, actual_output_shape)

  @parameterized.parameters((768, 6), (1024, 2))
  def test_expandcondense_num_parameters(self, input_dim, proj_multiple):
    """Checks the TN layer's parameter count against a closed-form formula."""
    data = np.random.randint(10, size=(100, input_dim))
    proj_size = proj_multiple * data.shape[-1]

    # Build the TN layer on its own (no Dense head, unlike _build_model) so
    # count_params() covers only TNExpandCondense.
    model = tf_keras.models.Sequential()
    model.add(
        TNExpandCondense(
            proj_multiplier=proj_multiple,
            use_bias=True,
            activation='relu',
            input_shape=(data.shape[-1],)))

    # Expected sizes of the layer's weights and bias as assumed by this test.
    # The `// 128` terms presume the input width is a multiple of 128, which
    # holds for both parameterized cases.
    w1_params = data.shape[-1]**2
    w2_params = 128 * 128 * (proj_size // data.shape[-1])
    w3_params = 128 * 128 * (proj_size // data.shape[-1])
    w4_params = (data.shape[-1] // 128) * 128 * data.shape[-1]
    bias_params = ((data.shape[-1] // 128) * 128 *
                   (proj_size // data.shape[-1]))

    expected_num_parameters = (w1_params + w2_params + w3_params +
                               w4_params) + bias_params
    self.assertEqual(expected_num_parameters, model.count_params())

  @parameterized.parameters((912, 6), (200, 2))
  def test_incorrect_sizes(self, input_dim, proj_multiple):
    """Checks that unsupported input widths raise an AssertionError."""
    data = np.random.randint(10, size=(100, input_dim))
    # The error is expected while the model is constructed/built; compile()
    # is kept inside the context in case the size check is deferred.
    with self.assertRaises(AssertionError):
      model = self._build_model(data, proj_multiple)
      model.compile(optimizer='adam', loss='binary_crossentropy')

  @parameterized.parameters((768, 6), (1024, 2))
  def test_config(self, input_dim, proj_multiple):
    """Checks that a layer rebuilt from its config has the same param count."""
    data = np.random.randint(10, size=(100, input_dim))
    model = self._build_model(data, proj_multiple)
    expected_num_parameters = model.layers[0].count_params()

    # Serialize the model and use the TN layer's config to create a new layer.
    model_config = model.get_config()
    # NOTE(review): entry 1 is assumed to be TNExpandCondense, with entry 0
    # being the InputLayer tf_keras records for `input_shape` -- confirm if
    # the Keras version changes.
    layer_config = model_config['layers'][1]['config']
    new_layer = TNExpandCondense.from_config(layer_config)

    # Build the layer so its weights exist and can be counted below.
    new_layer.build(layer_config['batch_input_shape'])

    # The original layer and the one rebuilt from config should match.
    self.assertEqual(expected_num_parameters, new_layer.count_params())

  @parameterized.parameters((768, 6), (1024, 2))
  def test_model_save(self, input_dim, proj_multiple):
    """Checks that a saved-then-reloaded model reproduces the predictions."""
    # Seed like the other training tests so this test is reproducible.
    tf_keras.utils.set_random_seed(0)
    data = np.random.randint(10, size=(100, input_dim))
    model = self._build_model(data, proj_multiple)
    model.compile(
        optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    # Train the model for 5 epochs.
    model.fit(data, self.labels, epochs=5, batch_size=32)

    save_path = os.path.join(self.get_temp_dir(), 'test_model')
    model.save(save_path)
    loaded_model = tf_keras.models.load_model(save_path)

    # Compare predictions of the original and reloaded models. Predictions
    # are floats computed by a deserialized graph, so use a tolerance-based
    # comparison rather than requiring bit-exact equality.
    self.assertAllClose(model.predict(data), loaded_model.predict(data))
if __name__ == '__main__':
  # Run all test cases in this module via TensorFlow's test runner.
  tf.test.main()