Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix)Make the weighted avarange fit for all kinds of systems #4593

Draft
wants to merge 18 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 55 additions & 71 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test(
)

if isinstance(dp, DeepPot):
err = test_ener(
err, find_energy, find_force, find_virial = test_ener(
dp,
data,
system,
Expand All @@ -143,6 +143,29 @@ def test(
atomic,
append_detail=(cc != 0),
)
err_part = {}

if find_energy == 1:
err_part["mae_e"] = err["mae_e"]
err_part["mae_ea"] = err["mae_ea"]
err_part["rmse_e"] = err["rmse_e"]
err_part["rmse_ea"] = err["rmse_ea"]

if find_force == 1:
if "rmse_f" in err:
err_part["mae_f"] = err["mae_f"]
err_part["rmse_f"] = err["rmse_f"]
else:
err_part["mae_fr"] = err["mae_fr"]
err_part["rmse_fr"] = err["rmse_fr"]
err_part["mae_fm"] = err["mae_fm"]
err_part["rmse_fm"] = err["rmse_fm"]
if find_virial == 1:
err_part["mae_v"] = err["mae_v"]
err_part["rmse_v"] = err["rmse_v"]

err = err_part

elif isinstance(dp, DeepDOS):
err = test_dos(
dp,
Expand Down Expand Up @@ -303,10 +326,11 @@ def test_ener(
if dp.has_spin:
data.add("spin", 3, atomic=True, must=True, high_prec=False)
data.add("force_mag", 3, atomic=True, must=False, high_prec=False)
if dp.has_hessian:
data.add("hessian", 1, atomic=True, must=True, high_prec=False)

test_data = data.get_test()
find_energy = test_data.get("find_energy")
Fixed Show fixed Hide fixed
find_force = test_data.get("find_force")
find_virial = test_data.get("find_virial")
Comment on lines +333 to +335
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unused flags.

The flags find_energy, find_force, and find_virial are obtained from test_data but are not used in error filtering since the filtering code is commented out. This creates dead code.

Apply this diff to remove the unused assignments:

-    find_energy = test_data.get("find_energy")
-    find_force = test_data.get("find_force")
-    find_virial = test_data.get("find_virial")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
find_energy = test_data.get("find_energy")
find_force = test_data.get("find_force")
find_virial = test_data.get("find_virial")
🧰 Tools
🪛 Ruff (0.8.2)

332-332: Local variable find_energy is assigned to but never used

Remove assignment to unused variable find_energy

(F841)

mixed_type = data.mixed_type
natoms = len(test_data["type"][0])
nframes = test_data["box"].shape[0]
Expand Down Expand Up @@ -354,9 +378,6 @@ def test_ener(
energy = energy.reshape([numb_test, 1])
force = force.reshape([numb_test, -1])
virial = virial.reshape([numb_test, 9])
if dp.has_hessian:
hessian = ret[3]
hessian = hessian.reshape([numb_test, -1])
if has_atom_ener:
ae = ret[3]
av = ret[4]
Expand Down Expand Up @@ -420,10 +441,6 @@ def test_ener(
rmse_ea = rmse_e / natoms
mae_va = mae_v / natoms
rmse_va = rmse_v / natoms
if dp.has_hessian:
diff_h = hessian - test_data["hessian"][:numb_test]
mae_h = mae(diff_h)
rmse_h = rmse(diff_h)
if has_atom_ener:
diff_ae = test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1])
mae_ae = mae(diff_ae)
Expand All @@ -439,26 +456,24 @@ def test_ener(
log.info(f"Energy RMSE : {rmse_e:e} eV")
log.info(f"Energy MAE/Natoms : {mae_ea:e} eV")
log.info(f"Energy RMSE/Natoms : {rmse_ea:e} eV")
if not out_put_spin:
log.info(f"Force MAE : {mae_f:e} eV/A")
log.info(f"Force RMSE : {rmse_f:e} eV/A")
else:
log.info(f"Force atom MAE : {mae_fr:e} eV/A")
log.info(f"Force atom RMSE : {rmse_fr:e} eV/A")
log.info(f"Force spin MAE : {mae_fm:e} eV/uB")
log.info(f"Force spin RMSE : {rmse_fm:e} eV/uB")
if find_force == 1:
if not out_put_spin:
log.info(f"Force MAE : {mae_f:e} eV/A")
log.info(f"Force RMSE : {rmse_f:e} eV/A")
else:
log.info(f"Force atom MAE : {mae_fr:e} eV/A")
log.info(f"Force atom RMSE : {rmse_fr:e} eV/A")
log.info(f"Force spin MAE : {mae_fm:e} eV/uB")
log.info(f"Force spin RMSE : {rmse_fm:e} eV/uB")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Resolve inconsistency in flag usage.

The flags find_force and find_virial are used in conditional logging but were marked for removal earlier. We need to maintain consistency in error handling and logging.

Consider one of these approaches:

  1. Keep the flags and restore the error filtering code
  2. Remove conditional logging and always log all available metrics

If keeping the flags, restore the error filtering code:

+    err_part = {}
+    if find_energy == 1:
+        err_part["mae_e"] = err["mae_e"]
+        err_part["mae_ea"] = err["mae_ea"]
+        err_part["rmse_e"] = err["rmse_e"]
+        err_part["rmse_ea"] = err["rmse_ea"]

Also applies to: 470-470

if data.pbc and not out_put_spin:
if data.pbc and not out_put_spin and find_virial == 1:
log.info(f"Virial MAE : {mae_v:e} eV")
log.info(f"Virial RMSE : {rmse_v:e} eV")
log.info(f"Virial MAE/Natoms : {mae_va:e} eV")
log.info(f"Virial RMSE/Natoms : {rmse_va:e} eV")
if has_atom_ener:
log.info(f"Atomic ener MAE : {mae_ae:e} eV")
log.info(f"Atomic ener RMSE : {rmse_ae:e} eV")
if dp.has_hessian:
log.info(f"Hessian MAE : {mae_h:e} eV/A^2")
log.info(f"Hessian RMSE : {rmse_h:e} eV/A^2")

if detail_file is not None:
detail_path = Path(detail_file)
Expand Down Expand Up @@ -542,24 +557,8 @@ def test_ener(
"pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz",
append=append_detail,
)
if dp.has_hessian:
data_h = test_data["hessian"][:numb_test].reshape(-1, 1)
pred_h = hessian.reshape(-1, 1)
h = np.concatenate(
(
data_h,
pred_h,
),
axis=1,
)
save_txt_file(
detail_path.with_suffix(".h.out"),
h,
header=f"{system}: data_h pred_h (3Na*3Na matrix in row-major order)",
append=append_detail,
)
if not out_put_spin:
dict_to_return = {
return {
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_f": (mae_f, force.size),
Expand All @@ -570,9 +569,9 @@ def test_ener(
"rmse_f": (rmse_f, force.size),
"rmse_v": (rmse_v, virial.size),
"rmse_va": (rmse_va, virial.size),
}
} ,find_energy,find_force,find_virial,
else:
dict_to_return = {
return {
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_fr": (mae_fr, force_r.size),
Expand All @@ -585,11 +584,7 @@ def test_ener(
"rmse_fm": (rmse_fm, force_m.size),
"rmse_v": (rmse_v, virial.size),
"rmse_va": (rmse_va, virial.size),
}
if dp.has_hessian:
dict_to_return["mae_h"] = (mae_h, hessian.size)
dict_to_return["rmse_h"] = (rmse_h, hessian.size)
return dict_to_return
} ,find_energy,find_force,find_virial,


def print_ener_sys_avg(avg: dict[str, float]) -> None:
Expand All @@ -616,9 +611,6 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None:
log.info(f"Virial RMSE : {avg['rmse_v']:e} eV")
log.info(f"Virial MAE/Natoms : {avg['mae_va']:e} eV")
log.info(f"Virial RMSE/Natoms : {avg['rmse_va']:e} eV")
if "rmse_h" in avg.keys():
log.info(f"Hessian MAE : {avg['mae_h']:e} eV/A^2")
log.info(f"Hessian RMSE : {avg['rmse_h']:e} eV/A^2")


def test_dos(
Expand Down Expand Up @@ -739,9 +731,9 @@ def test_dos(
frame_output = np.hstack((test_out, pred_out))

save_txt_file(
detail_path.with_suffix(f".dos.out.{ii}"),
detail_path.with_suffix(".dos.out.%.d" % ii),
frame_output,
header=f"{system} - {ii}: data_dos pred_dos",
header="%s - %.d: data_dos pred_dos" % (system, ii),
append=append_detail,
)

Expand All @@ -753,9 +745,9 @@ def test_dos(
frame_output = np.hstack((test_out, pred_out))

save_txt_file(
detail_path.with_suffix(f".ados.out.{ii}"),
detail_path.with_suffix(".ados.out.%.d" % ii),
frame_output,
header=f"{system} - {ii}: data_ados pred_ados",
header="%s - %.d: data_ados pred_ados" % (system, ii),
append=append_detail,
)

Expand Down Expand Up @@ -814,17 +806,9 @@ def test_property(
tuple[list[np.ndarray], list[int]]
arrays with results and their shapes
"""
var_name = dp.get_var_name()
assert isinstance(var_name, str)
data.add(var_name, dp.task_dim, atomic=False, must=True, high_prec=True)
data.add("property", dp.task_dim, atomic=False, must=True, high_prec=True)
if has_atom_property:
data.add(
f"atom_{var_name}",
dp.task_dim,
atomic=True,
must=False,
high_prec=True,
)
data.add("atom_property", dp.task_dim, atomic=True, must=False, high_prec=True)

if dp.get_dim_fparam() > 0:
data.add(
Expand Down Expand Up @@ -875,12 +859,12 @@ def test_property(
aproperty = ret[1]
aproperty = aproperty.reshape([numb_test, natoms * dp.task_dim])

diff_property = property - test_data[var_name][:numb_test]
diff_property = property - test_data["property"][:numb_test]
mae_property = mae(diff_property)
rmse_property = rmse(diff_property)

if has_atom_property:
diff_aproperty = aproperty - test_data[f"atom_{var_name}"][:numb_test]
diff_aproperty = aproperty - test_data["atom_property"][:numb_test]
mae_aproperty = mae(diff_aproperty)
rmse_aproperty = rmse(diff_aproperty)

Expand All @@ -897,29 +881,29 @@ def test_property(
detail_path = Path(detail_file)

for ii in range(numb_test):
test_out = test_data[var_name][ii].reshape(-1, 1)
test_out = test_data["property"][ii].reshape(-1, 1)
pred_out = property[ii].reshape(-1, 1)

frame_output = np.hstack((test_out, pred_out))

save_txt_file(
detail_path.with_suffix(f".property.out.{ii}"),
detail_path.with_suffix(".property.out.%.d" % ii),
frame_output,
header=f"{system} - {ii}: data_property pred_property",
header="%s - %.d: data_property pred_property" % (system, ii),
append=append_detail,
)

if has_atom_property:
for ii in range(numb_test):
test_out = test_data[f"atom_{var_name}"][ii].reshape(-1, 1)
test_out = test_data["atom_property"][ii].reshape(-1, 1)
pred_out = aproperty[ii].reshape(-1, 1)

frame_output = np.hstack((test_out, pred_out))

save_txt_file(
detail_path.with_suffix(f".aproperty.out.{ii}"),
detail_path.with_suffix(".aproperty.out.%.d" % ii),
frame_output,
header=f"{system} - {ii}: data_aproperty pred_aproperty",
header="%s - %.d: data_aproperty pred_aproperty" % (system, ii),
append=append_detail,
)

Expand Down
Loading
Loading