We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3a502c4 commit 65b2c45Copy full SHA for 65b2c45
exnn/base.py
@@ -156,7 +156,7 @@ def get_active_subnets(self, beta_threshold=0):
156
if self.bn_flag:
157
beta = self.output_layer.output_weights.numpy()
158
else:
159
- subnet_norm = [self.subnet_blocks.subnets[i].moving_norm.numpy()[0] for i in range(self.numerical_input_num)]
+ subnet_norm = [self.subnet_blocks.subnets[i].moving_norm.numpy()[0] for i in range(self.subnet_num)]
160
categ_norm = [self.categ_blocks.categnets[i].moving_norm.numpy()[0]for i in range(self.categ_variable_num)]
161
beta = self.output_layer.output_weights.numpy() * np.hstack([subnet_norm, categ_norm]).reshape([-1, 1])
162
beta = beta * self.output_layer.output_switcher.numpy()
0 commit comments