@@ -129,6 +129,8 @@ def default_class_name(self, layer: 'keras.Layer') -> str:
129129 return layer .__class__ .__name__
130130
131131 def maybe_get_activation_config (self , layer , out_tensors ):
132+ import inspect
133+
132134 import keras
133135
134136 activation = getattr (layer , 'activation' , None )
@@ -139,12 +141,31 @@ def maybe_get_activation_config(self, layer, out_tensors):
139141 intermediate_tensor_name = f'{ out_tensors [0 ].name } _activation'
140142 act_cls_name = activation .__name__
141143 act_config = {
142- 'class_name' : 'Activation' ,
143144 'activation' : act_cls_name ,
144145 'name' : f'{ name } _{ act_cls_name } ' ,
145146 'input_keras_tensor_names' : [intermediate_tensor_name ],
146147 'output_keras_tensor_names' : [out_tensors [0 ].name ],
147148 }
149+
150+ # Check activation type & update parameters
151+ match activation :
152+ case keras .activations .softmax :
153+ class_name = 'Softmax'
154+ act_config ['axis' ] = - 1
155+ case keras .activations .hard_sigmoid :
156+ class_name = 'HardActivation'
157+ case keras .activations .leaky_relu :
158+ class_name = 'LeakyReLU'
159+ signature = inspect .signature (keras .activations .leaky_relu )
160+ act_config ['activ_param' ] = signature .parameters ['negative_slope' ].default
161+ case keras .activations .elu :
162+ class_name = 'ELU'
163+ signature = inspect .signature (keras .activations .elu )
164+ act_config ['activ_param' ] = signature .parameters ['alpha' ].default
165+ case _:
166+ class_name = 'Activation'
167+ act_config ['class_name' ] = class_name
168+
148169 return act_config , intermediate_tensor_name
149170 return None , None
150171
0 commit comments