Skip to content

Commit d279ba1

Browse files
author
SAAS R7 User1
committed
fix a bug in catnet
1 parent ff90b57 commit d279ba1

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

exnn/layers.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,16 @@ def __init__(self, feature_list, cfeature_index_list, dummy_values, bn_flag=True
8484
self.bn_flag = bn_flag
8585

8686
self.categnets = []
87-
for i in self.cfeature_index_list:
88-
feature_name = self.feature_list[i]
87+
for i, idx in enumerate(self.cfeature_index_list):
88+
feature_name = self.feature_list[idx]
8989
self.categnets.append(CategNet(category_num=len(self.dummy_values[feature_name]), bn_flag=self.bn_flag, cagetnet_id=i))
9090

9191
def call(self, inputs, training=False):
9292
output = 0
9393
if len(self.cfeature_index_list) > 0:
9494
self.categ_output = []
95-
for i in self.cfeature_index_list:
96-
self.categ_output.append(self.categnets[i](tf.gather(inputs, [self.cfeature_index_list[i]], axis=1), training=training))
95+
for i, idx in enumerate(self.cfeature_index_list):
96+
self.categ_output.append(self.categnets[i](tf.gather(inputs, [idx], axis=1), training=training))
9797
output = tf.reshape(tf.squeeze(tf.stack(self.categ_output, 1)), [-1, len(self.cfeature_index_list)])
9898
return output
9999

0 commit comments

Comments
 (0)