Skip to content

Commit d66b647

Browse files
author
SAAS R7 User1
committed
fix a bug in catnet
1 parent d279ba1 commit d66b647

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

exnn/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,9 @@ def visualize(self, folder="./results/", name="demo", save_png=False, save_eps=F
441441

442442
if self.cfeature_num_ > 0:
443443
for indice in active_categ_index:
444-
feature_name = self.cfeature_list_[indice - self.numerical_input_num]
445-
dummy_gamma = self.categ_blocks.categnets[indice - self.numerical_input_num].categ_bias.numpy()
446-
norm = self.categ_blocks.categnets[indice - self.numerical_input_num].moving_norm.numpy()
444+
feature_name = self.cfeature_list_[indice - self.subnet_num]
445+
dummy_gamma = self.categ_blocks.categnets[indice - self.subnet_num].categ_bias.numpy()
446+
norm = self.categ_blocks.categnets[indice - self.subnet_num].moving_norm.numpy()
447447
ax3 = fig.add_subplot(np.int(max_ids), 1, np.int(max_ids))
448448
ax3.bar(np.arange(len(self.dummy_values_[feature_name])), np.sign(beta[indice]) * dummy_gamma[:, 0] / norm)
449449
ax3.set_xticks(np.arange(len(self.dummy_values_[feature_name])))

0 commit comments

Comments
 (0)