Skip to content

Commit ff90b57

Browse files
author
SAAS R7 User1
committed
fix a bug in catnet
1 parent 10775a2 commit ff90b57

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

exnn/exnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class ExNN(BaseNet):
66
"""
77
Enhanced explainable neural network (ExNN) based on sparse, orthogonal and smooth constraints.
88
9-
ExNN is based on our paper (Yang et al. 2018) with the following implementation details:
9+
ExNN is based on our paper (Yang et al. 2020 TNNLS) with the following implementation details:
1010
1111
1. Categorical variables should be first converted by one-hot encoding, and we directly link each of the dummy variables as a bias term to final output.
1212
@@ -80,7 +80,7 @@ class ExNN(BaseNet):
8080
8181
References
8282
----------
83-
.. Yang, Zebin, Aijun Zhang, and Agus Sudjianto. "Enhancing Explainability of Neural Networks through Architecture Constraints." arXiv preprint arXiv:1901.03838 (2019).
83+
.. Yang, Zebin, Aijun Zhang, and Agus Sudjianto. "Enhancing Explainability of Neural Networks through Architecture Constraints." TNNLS (2020).
8484
8585
"""
8686

exnn/gamnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
class GAMNet(BaseNet):
66
"""
7-
Generalized additive model vai neural network implementation. It is just a simplified version of sosxnn with identity projection layer.
7+
Generalized additive model via neural network implementation. It is just a simplified version of exnn with identity projection layer.
88
99
Parameters
1010
----------

exnn/layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ def __init__(self, category_num, bn_flag=True, cagetnet_id=0):
4646
self.cagetnet_id = cagetnet_id
4747

4848
self.categ_bias = self.add_weight(name="cate_bias_" + str(self.cagetnet_id),
49-
shape=[self.depth, 1],
49+
shape=[self.category_num, 1],
5050
initializer=tf.zeros_initializer(),
5151
trainable=True)
5252
self.moving_mean = self.add_weight(name="mean" + str(self.cagetnet_id), shape=[1], initializer=tf.zeros_initializer(), trainable=False)
5353
self.moving_norm = self.add_weight(name="norm" + str(self.cagetnet_id), shape=[1], initializer=tf.ones_initializer(), trainable=False)
5454

5555
def call(self, inputs, training=False):
5656

57-
dummy = tf.one_hot(indices=tf.cast(inputs[:, 0], tf.int32), depth=self.depth)
57+
dummy = tf.one_hot(indices=tf.cast(inputs[:, 0], tf.int32), depth=self.category_num)
5858
self.output_original = tf.matmul(dummy, self.categ_bias)
5959

6060
if training:

0 commit comments

Comments
 (0)