@@ -844,6 +844,8 @@ def build_model(self,config,placeholders):
844
844
assert (self .version == 8 or self .version == 10 )
845
845
846
846
#Input layer---------------------------------------------------------------------------------
847
+ #tf.compat.v1.disable_eager_execution() # Important, fix for tensorflow 2.4
848
+ tf .compat .v1 .disable_v2_behavior ()
847
849
bin_inputs = (placeholders ["bin_inputs" ] if "bin_inputs" in placeholders else
848
850
tf .compat .v1 .placeholder (tf .float32 , [None ] + self .bin_input_shape , name = "bin_inputs" ))
849
851
global_inputs = (placeholders ["global_inputs" ] if "global_inputs" in placeholders else
@@ -1268,6 +1270,8 @@ def __init__(self,model,for_optimization,placeholders):
1268
1270
shortterm_value_error_prediction = tf .math .softplus (moremiscvalues_output [:,0 ]) * 0.25
1269
1271
shortterm_score_error_prediction = tf .math .softplus (moremiscvalues_output [:,1 ]) * 30.0
1270
1272
1273
+ tf .compat .v1 .disable_v2_behavior ()
1274
+
1271
1275
#Loss function
1272
1276
self .policy_target = (placeholders ["policy_target" ] if "policy_target" in placeholders else
1273
1277
tf .compat .v1 .placeholder (tf .float32 , [None ] + model .policy_target_shape ))
@@ -1545,6 +1549,7 @@ def __init__(self,model,for_optimization,placeholders):
1545
1549
1546
1550
if for_optimization :
1547
1551
#Prior/Regularization
1552
+ tf .compat .v1 .disable_v2_behavior ()
1548
1553
self .l2_reg_coeff = (placeholders ["l2_reg_coeff" ] if "l2_reg_coeff" in placeholders else
1549
1554
tf .compat .v1 .placeholder (tf .float32 ))
1550
1555
self .reg_loss_per_weight = self .l2_reg_coeff * (
0 commit comments