Skip to content

Commit

Permalink
Merge branch 'bias_atom_polar' of github.com:Yi-FanLi/deepmd-kit into…
Browse files Browse the repository at this point in the history
… bias_atom_polar
  • Loading branch information
Yi-FanLi committed Feb 6, 2025
2 parents b794731 + 4390144 commit 8259be7
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions deepmd/tf/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,9 @@ def build(
# nframes x nloc_masked
constant_matrix = tf.reshape(
tf.reshape(
tf.tile(tf.repeat(self.t_bias_atom_polar, natoms[2:]), [nframes]),
tf.tile(
tf.repeat(self.t_bias_atom_polar, natoms[2:]), [nframes]
),
[nframes, -1],
)[nloc_mask],
[nframes, -1],
Expand Down Expand Up @@ -518,7 +520,9 @@ def build(
# shift and scale
sel_type_idx = self.sel_type.index(type_i)
final_layer = final_layer * self.scale[sel_type_idx]
final_layer = final_layer + tf.slice(self.t_bias_atom_polar, [sel_type_idx], [1]) * tf.eye(
final_layer = final_layer + tf.slice(
self.t_bias_atom_polar, [sel_type_idx], [1]
) * tf.eye(
3,
batch_shape=[tf.shape(inputs)[0], natoms[2 + type_i]],
dtype=GLOBAL_TF_FLOAT_PRECISION,
Expand Down Expand Up @@ -565,7 +569,7 @@ def init_variables(
self.fitting_net_variables = get_fitting_net_variables_from_graph_def(
graph_def, suffix=suffix
)
if self.shift_diag
if self.shift_diag:
try:
self.bias_atom_polar = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_bias_atom_polar"
Expand Down

0 comments on commit 8259be7

Please sign in to comment.