
Commit 617d3e2

Authored by HydrogenSulfate, with coderabbitai[bot] and pre-commit-ci[bot]
pd: support different label_dict in CINN (#4795)
Pop the unnecessary items when wrapping the model with `jit.to_static`, so that se_e2_a/dpa2/dpa3 can be supported without extra modification. @njzjz, can you give some suggestions for better code improvements? The current approach of fetching data via `self.get_data` isn't very concise.

## Summary by CodeRabbit

- **Bug Fixes**
  - Improved compatibility by dynamically matching label input specifications to the available label keys during model compilation when CINN is enabled. This prevents errors caused by mismatched label keys at runtime.

---------

Signed-off-by: HydrogenSulfate <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a3677b6 commit 617d3e2
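In short, the fix builds the label part of `input_spec` from a sample batch instead of hard-coding it. A minimal standalone sketch of that idea (key names and shapes follow the diff in this commit; `sample_label_dict` is a made-up stand-in for what `self.get_data(is_train=True)` would return for an energy/force-only dataset):

```python
import numpy as np
from paddle import static

# Spec templates for every label key the trainer knows how to describe
# (key names and shapes follow the diff in this commit).
spec_templates = {
    "find_energy": np.float32(1.0),
    "energy": static.InputSpec([1, 1], "float64", name="energy"),
    "find_force": np.float32(1.0),
    "force": static.InputSpec([1, -1, 3], "float64", name="force"),
    "find_virial": np.float32(0.0),
    "virial": static.InputSpec([1, 9], "float64", name="virial"),
    "natoms": static.InputSpec([1, -1], "int32", name="natoms"),
}

# Hypothetical label dict of a dataset that provides energies and forces only;
# in the real code this comes from `self.get_data(is_train=True)`.
sample_label_dict = {
    "find_energy": 1.0, "energy": None,
    "find_force": 1.0, "force": None,
    "natoms": None,
}

# Keep only the spec entries whose keys the sample batch actually contains, so
# the static-graph input signature matches what the dataloader feeds at runtime.
label_dict_spec = {
    k: spec_templates[k] for k in sample_label_dict if k in spec_templates
}
assert "virial" not in label_dict_spec  # no virial labels, so no virial spec
```

Per the summary above, filtering against the sample batch is what prevents the mismatched-label-key errors that a fixed spec dict produced for models trained without some of those labels.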

File tree

1 file changed: +22 −13 lines


deepmd/pd/train/training.py

Lines changed: 22 additions & 13 deletions
@@ -607,6 +607,27 @@ def warm_up_linear(step, warmup_steps):
             )
 
             backend = "CINN" if CINN else None
+            # NOTE: This is a trick to decide the right input_spec for wrapper.forward
+            _, label_dict, _ = self.get_data(is_train=True)
+
+            # Define specification templates
+            spec_templates = {
+                "find_box": np.float32(1.0),
+                "find_coord": np.float32(1.0),
+                "find_numb_copy": np.float32(0.0),
+                "numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"),
+                "find_energy": np.float32(1.0),
+                "energy": static.InputSpec([1, 1], "float64", name="energy"),
+                "find_force": np.float32(1.0),
+                "force": static.InputSpec([1, -1, 3], "float64", name="force"),
+                "find_virial": np.float32(0.0),
+                "virial": static.InputSpec([1, 9], "float64", name="virial"),
+                "natoms": static.InputSpec([1, -1], "int32", name="natoms"),
+            }
+            # Build spec only for keys present in sample data
+            label_dict_spec = {
+                k: spec_templates[k] for k in label_dict.keys() if k in spec_templates
+            }
             self.wrapper.forward = jit.to_static(
                 backend=backend,
                 input_spec=[
@@ -615,19 +636,7 @@ def warm_up_linear(step, warmup_steps):
                     None,  # spin
                     static.InputSpec([1, 9], "float64", name="box"),  # box
                     static.InputSpec([], "float64", name="cur_lr"),  # cur_lr
-                    {
-                        "find_box": np.float32(1.0),
-                        "find_coord": np.float32(1.0),
-                        "find_numb_copy": np.float32(0.0),
-                        "numb_copy": static.InputSpec(
-                            [1, 1], "int64", name="numb_copy"
-                        ),
-                        "find_energy": np.float32(1.0),
-                        "energy": static.InputSpec([1, 1], "float64", name="energy"),
-                        "find_force": np.float32(1.0),
-                        "force": static.InputSpec([1, -1, 3], "float64", name="force"),
-                        "natoms": static.InputSpec([1, -1], "int32", name="natoms"),
-                    },  # label,
+                    label_dict_spec,  # label,
                     # None,  # task_key
                     # False,  # inference_only
                     # False,  # do_atomic_virial
