diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index ee33feb0af..8b4eebc9da 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -466,7 +466,9 @@ def build( # nframes x nloc_masked constant_matrix = tf.reshape( tf.reshape( - tf.tile(tf.repeat(self.t_bias_atom_polar, natoms[2:]), [nframes]), + tf.tile( + tf.repeat(self.t_bias_atom_polar, natoms[2:]), [nframes] + ), [nframes, -1], )[nloc_mask], [nframes, -1], @@ -518,7 +520,9 @@ def build( # shift and scale sel_type_idx = self.sel_type.index(type_i) final_layer = final_layer * self.scale[sel_type_idx] - final_layer = final_layer + tf.slice(self.t_bias_atom_polar, [sel_type_idx], [1]) * tf.eye( + final_layer = final_layer + tf.slice( + self.t_bias_atom_polar, [sel_type_idx], [1] + ) * tf.eye( 3, batch_shape=[tf.shape(inputs)[0], natoms[2 + type_i]], dtype=GLOBAL_TF_FLOAT_PRECISION,