From fc0f88f45bf47f2c8916c359c09860f1f651b1a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yifan=20Li=E6=9D=8E=E4=B8=80=E5=B8=86?= Date: Thu, 6 Feb 2025 03:37:31 -0500 Subject: [PATCH 1/2] fix syntax error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Yifan Li李一帆 --- deepmd/tf/fit/polar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 2fe85dbc1c..ee33feb0af 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -565,7 +565,7 @@ def init_variables( self.fitting_net_variables = get_fitting_net_variables_from_graph_def( graph_def, suffix=suffix ) - if self.shift_diag + if self.shift_diag: try: self.bias_atom_polar = get_tensor_by_name_from_graph( graph, f"fitting_attr{suffix}/t_bias_atom_polar" From 4390144186ddc012a059fcbdcc2b68ce6b3f98c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Feb 2025 08:40:23 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/tf/fit/polar.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index ee33feb0af..8b4eebc9da 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -466,7 +466,9 @@ def build( # nframes x nloc_masked constant_matrix = tf.reshape( tf.reshape( - tf.tile(tf.repeat(self.t_bias_atom_polar, natoms[2:]), [nframes]), + tf.tile( + tf.repeat(self.t_bias_atom_polar, natoms[2:]), [nframes] + ), [nframes, -1], )[nloc_mask], [nframes, -1], @@ -518,7 +520,9 @@ def build( # shift and scale sel_type_idx = self.sel_type.index(type_i) final_layer = final_layer * self.scale[sel_type_idx] - final_layer = final_layer + tf.slice(self.t_bias_atom_polar, [sel_type_idx], [1]) * tf.eye( + final_layer = final_layer + tf.slice( + self.t_bias_atom_polar, [sel_type_idx], [1] + ) * tf.eye( 3, batch_shape=[tf.shape(inputs)[0], natoms[2 + type_i]], dtype=GLOBAL_TF_FLOAT_PRECISION,