Skip to content

Commit 2e00576

Browse files
author
SAAS R7 User1
committed
update
1 parent 8082876 commit 2e00576

File tree

5 files changed

+41
-20
lines changed

5 files changed

+41
-20
lines changed

xnn/base.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,13 @@ def call(self, inputs, training=False):
9898

9999
return output
100100

101+
@tf.function
101102
def predict(self, x):
102-
return self.apply(tf.cast(x, tf.float32), training=False).numpy()
103+
return self.apply(tf.cast(x, tf.float32), training=False)
104+
105+
@tf.function
106+
def evaluate(self, x, y, training=False):
107+
return self.loss_fn(y, self.apply(tf.cast(x, tf.float32), training=training))
103108

104109
@tf.function
105110
def train_step_init(self, inputs, labels):
@@ -113,13 +118,12 @@ def get_active_subnets(self):
113118
if self.bn_flag:
114119
beta = self.output_layer.output_weights.numpy() * self.output_layer.subnet_swicher.numpy()
115120
else:
116-
subnet_norm = [self.subnet_blocks.subnets[i].subnet_bn.moving_variance.numpy()[0] ** 0.5 for i in range(self.subnet_num)]
121+
subnet_norm = [self.subnet_blocks.subnets[i].subnet_norm.numpy()[0] for i in range(self.subnet_num)]
117122
beta = self.output_layer.output_weights.numpy() * np.array([subnet_norm]).reshape([-1, 1]) * self.output_layer.subnet_swicher.numpy()
118123

119124
subnets_scale = (np.abs(beta) / np.sum(np.abs(beta))).reshape([-1])
120125
sorted_index = np.argsort(subnets_scale)
121126
active_index = sorted_index[subnets_scale[sorted_index].cumsum()>self.beta_threshold][::-1]
122-
# active_index = sorted_index[subnets_scale[sorted_index]>self.beta_threshold][::-1]
123127
return active_index, beta, subnets_scale
124128

125129
def fit(self, train_x, train_y):
@@ -150,8 +154,8 @@ def fit(self, train_x, train_y):
150154
batch_yy = tr_y[offset:(offset + self.batch_size)]
151155
self.train_step_init(tf.cast(batch_xx, tf.float32), batch_yy)
152156

153-
self.err_train.append(self.loss_fn(tr_y, self.apply(tf.cast(tr_x, tf.float32), training=True)).numpy())
154-
self.err_val.append(self.loss_fn(val_y, self.apply(tf.cast(val_x, tf.float32), training=True)).numpy())
157+
self.err_train.append(self.evaluate(tr_x, tr_y, training=True))
158+
self.err_val.append(self.evaluate(val_x, val_y, training=True))
155159
if self.verbose & (epoch % 1 == 0):
156160
print("Training epoch: %d, train loss: %0.5f, val loss: %0.5f" %
157161
(epoch + 1, self.err_train[-1], self.err_val[-1]))
@@ -188,12 +192,13 @@ def fit(self, train_x, train_y):
188192
batch_yy = tr_y[offset:(offset + self.batch_size)]
189193
self.train_step_finetune(tf.cast(batch_xx, tf.float32), batch_yy)
190194

191-
self.err_train.append(self.loss_fn(tr_y, self.apply(tf.cast(tr_x, tf.float32), training=True)).numpy())
192-
self.err_val.append(self.loss_fn(val_y, self.apply(tf.cast(val_x, tf.float32), training=True)).numpy())
195+
self.err_train.append(self.evaluate(tr_x, tr_y, training=True))
196+
self.err_val.append(self.evaluate(val_x, val_y, training=True))
193197
if self.verbose & (epoch % 1 == 0):
194198
print("Tuning epoch: %d, train loss: %0.5f, val loss: %0.5f" %
195199
(epoch + 1, self.err_train[-1], self.err_val[-1]))
196200

201+
self.evaluate(train_x, train_y, training=True)
197202
# record the key values in the network
198203
self.subnet_input_min = []
199204
self.subnet_input_max = []
@@ -240,9 +245,10 @@ def visualize(self, folder="./results", name="demo", dummy_name=None, save_eps=F
240245
np.min(subnets_outputs), np.max(subnets_outputs), 6), 2)
241246
ax1.set_yticks(yint)
242247
ax1.set_yticklabels(["{0: .2f}".format(j) for j in yint])
243-
legend_style = mlines.Line2D([], [], color='black', marker='o', linewidth=0.0, markersize=6,
244-
label='Scale: ' + str(np.round(100 * subnets_scale[indice], 1)) + "%")
245-
plt.legend(handles=[legend_style], fontsize=18)
248+
ax1.set_ylim([np.min(subnets_outputs) - (np.max(subnets_outputs) - np.min(subnets_outputs))*0.1,
249+
np.max(subnets_outputs) + (np.max(subnets_outputs) - np.min(subnets_outputs))*0.25])
250+
ax1.text(0.25, 0.9,'Scale: ' + str(np.round(100 * subnets_scale[indice], 1)) + "%",
251+
fontsize=16, horizontalalignment='center', verticalalignment='center', transform=ax1.transAxes)
246252

247253
ax2 = f.add_subplot(np.int(max_ids), 2, i * 2 + 2)
248254
ax2.bar(np.arange(input_size), coef_index.T[indice, :input_size])

xnn/gamnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self, input_num, input_dummy_num=0, subnet_arch=[10, 6], task="Regr
7070
subnet_arch=subnet_arch,
7171
task=task,
7272
proj_method="gam",
73-
activation_func=tf.tanh,
73+
activation_func=activation_func,
7474
bn_flag=True,
7575
lr_bp=lr_bp,
7676
l1_proj=0,

xnn/layers.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def build(self, input_shape=None):
5151
initializer=self.kernel_iniializer,
5252
trainable=self.trainable,
5353
regularizer=tf.keras.regularizers.l1(self.l1_proj))
54+
self.built = True
5455

5556
def call(self, inputs, training=False):
5657
output = tf.matmul(inputs, self.proj_weights)
@@ -59,19 +60,22 @@ def call(self, inputs, training=False):
5960

6061
class Subnetwork(tf.keras.layers.Layer):
6162

62-
def __init__(self, subnet_arch=[10, 6], activation_func=tf.tanh, smooth_lambda=0.0, bn_flag=False):
63+
def __init__(self, subnet_arch=[10, 6], activation_func=tf.tanh, smooth_lambda=0.0, bn_flag=False, subnet_id=0):
6364
super(Subnetwork, self).__init__()
6465
self.dense = []
6566
self.subnet_arch = subnet_arch
6667
self.activation_func = activation_func
6768
self.smooth_lambda = smooth_lambda
6869
self.bn_flag = bn_flag
70+
self.subnet_id = subnet_id
6971

7072
def build(self, input_shape=None):
7173
for nodes in self.subnet_arch:
7274
self.dense.append(layers.Dense(nodes, activation=self.activation_func))
73-
self.output_layer = layers.Dense(1, activation=tf.identity)
74-
self.subnet_bn = BatchNormalization(momentum=0.0, epsilon=1e-10, center=False, scale=False)
75+
self.output_layer = layers.Dense(1, activation=self.activation_func)
76+
self.moving_mean = self.add_weight(name="mean"+str(self.subnet_id), shape=[1], initializer=tf.zeros_initializer(),trainable=False)
77+
self.moving_norm = self.add_weight(name="norm"+str(self.subnet_id), shape=[1], initializer=tf.ones_initializer(),trainable=False)
78+
self.built = True
7579

7680
def call(self, inputs, training=False):
7781
with tf.GradientTape() as t1:
@@ -85,12 +89,21 @@ def call(self, inputs, training=False):
8589
self.grad1 = t2.gradient(self.output_original, inputs)
8690
self.grad2 = t1.gradient(self.grad1, inputs)
8791

92+
if training:
93+
mean, norm = tf.reduce_mean(self.output_original, 0), tf.maximum(tf.math.reduce_std(self.output_original, 0), 1e-10)
94+
self.subnet_mean = mean
95+
self.subnet_norm = norm
96+
self.moving_mean.assign(mean)
97+
self.moving_norm.assign(norm)
98+
else:
99+
self.subnet_mean = self.moving_mean
100+
self.subnet_norm = self.moving_norm
101+
88102
if self.bn_flag:
89-
output = self.subnet_bn(self.output_original, training=training)
103+
output = (self.output_original - self.subnet_mean) / (self.subnet_norm)
90104
else:
91-
_ = self.subnet_bn(self.output_original, training=training)
92105
output = self.output_original
93-
self.smooth_penalty = tf.reduce_mean(tf.square(self.grad2)) / tf.sqrt(self.subnet_bn.moving_variance)
106+
self.smooth_penalty = tf.reduce_mean(tf.square(self.grad2)) / self.subnet_norm
94107
return output
95108

96109

@@ -110,7 +123,8 @@ def build(self, input_shape=None):
110123
self.subnets.append(Subnetwork(self.subnet_arch,
111124
self.activation_func,
112125
self.smooth_lambda,
113-
self.bn_flag))
126+
self.bn_flag,
127+
subnet_id=i))
114128
self.built = True
115129

116130
def call(self, inputs, training=False):
@@ -149,6 +163,7 @@ def build(self, input_shape=None):
149163
shape=[1],
150164
initializer=tf.zeros_initializer(),
151165
trainable=True)
166+
self.built = True
152167

153168
def call(self, inputs, training=False):
154169
output = (tf.matmul(inputs, self.subnet_swicher * self.output_weights) + self.output_bias)

xnn/sosxnn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(self, input_num, subnet_num, input_dummy_num=0, subnet_arch=[10, 6]
9191
subnet_arch=subnet_arch,
9292
task=task,
9393
proj_method="orthogonal",
94-
activation_func=tf.tanh,
94+
activation_func=activation_func,
9595
bn_flag=True,
9696
lr_bp=lr_bp,
9797
l1_proj=l1_proj,

xnn/xnn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(self, input_num, subnet_num, input_dummy_num=0, subnet_arch=[10, 6]
8585
subnet_arch=subnet_arch,
8686
task=task,
8787
proj_method="comb",
88-
activation_func=tf.tanh,
88+
activation_func=activation_func,
8989
bn_flag=False,
9090
lr_bp=lr_bp,
9191
l1_proj=l1_proj,

0 commit comments

Comments
 (0)