Skip to content

Commit 0d0f719

Browse files
committed
further fix for tf2
1 parent e49af3f commit 0d0f719

File tree

3 files changed

+7
-0
lines changed

3 files changed

+7
-0
lines changed

python/model.py

+5
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,8 @@ def build_model(self,config,placeholders):
844844
assert(self.version == 8 or self.version == 10)
845845

846846
#Input layer---------------------------------------------------------------------------------
847+
#tf.compat.v1.disable_eager_execution() # Important, fix for tensorflow 2.4
848+
tf.compat.v1.disable_v2_behavior()
847849
bin_inputs = (placeholders["bin_inputs"] if "bin_inputs" in placeholders else
848850
tf.compat.v1.placeholder(tf.float32, [None] + self.bin_input_shape, name="bin_inputs"))
849851
global_inputs = (placeholders["global_inputs"] if "global_inputs" in placeholders else
@@ -1268,6 +1270,8 @@ def __init__(self,model,for_optimization,placeholders):
12681270
shortterm_value_error_prediction = tf.math.softplus(moremiscvalues_output[:,0]) * 0.25
12691271
shortterm_score_error_prediction = tf.math.softplus(moremiscvalues_output[:,1]) * 30.0
12701272

1273+
tf.compat.v1.disable_v2_behavior()
1274+
12711275
#Loss function
12721276
self.policy_target = (placeholders["policy_target"] if "policy_target" in placeholders else
12731277
tf.compat.v1.placeholder(tf.float32, [None] + model.policy_target_shape))
@@ -1545,6 +1549,7 @@ def __init__(self,model,for_optimization,placeholders):
15451549

15461550
if for_optimization:
15471551
#Prior/Regularization
1552+
tf.compat.v1.disable_v2_behavior()
15481553
self.l2_reg_coeff = (placeholders["l2_reg_coeff"] if "l2_reg_coeff" in placeholders else
15491554
tf.compat.v1.placeholder(tf.float32))
15501555
self.reg_loss_per_weight = self.l2_reg_coeff * (

python/tfrecordio.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def make_raw_input_feature_placeholders(model_config,pos_len,batch_size):
2222
num_bin_input_features = Model.get_num_bin_input_features(model_config)
2323
num_global_input_features = Model.get_num_global_input_features(model_config)
2424

25+
tf.compat.v1.disable_v2_behavior()
2526
return {
2627
"binchwp": tf.compat.v1.placeholder(tf.uint8,[batch_size,num_bin_input_features,(pos_len*pos_len+7)//8]),
2728
"ginc": tf.compat.v1.placeholder(tf.float32,[batch_size,num_global_input_features]),

python/train.py

+1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def trainlog(s):
173173
assign_ops = []
174174
for variable in itertools.chain(tf.compat.v1.model_variables(), tf.compat.v1.trainable_variables()):
175175
if variable.name.startswith("swa_model/"):
176+
tf.compat.v1.disable_v2_behavior()
176177
placeholder = tf.compat.v1.placeholder(variable.dtype,variable.shape)
177178
assign_ops.append(tf.compat.v1.assign(variable,placeholder))
178179
swa_assign_placeholders[variable.name] = placeholder

0 commit comments

Comments
 (0)