-
Notifications
You must be signed in to change notification settings - Fork 599
feat(pt_expt): add training infrastructure #5270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
wanghan-iapcm
merged 76 commits into
deepmodeling:master
from
wanghan-iapcm:feat-pt-expt-train
Mar 5, 2026
Merged
Changes from all commits
Commits
Show all changes
76 commits
Select commit
Hold shift + click to select a range
8356629
feat(pt_expt): add dos, dipole, polar and property fittings
292fa72
add make_fx, mv itertools to parameterized
7289167
feat(pt_expt): full models dipole, polar, dos, property and dp-zbl
553b91d
rm _forward_lower
0753cd7
rm register_dpmodel_mapping from fitting
6d6adfe
remove the atomic model in pt_expt. mv atomic model's output stat tes…
bf448ac
add translated_output_def
9572a04
base model registration
0dcd03b
implement compute_or_load_stat
28fbd08
fix bug in test_ener
237e4a8
refact make_model, concrete models from different backends inherit fr…
2a958ec
Add compute_or_load_stat consistency tests and fix dpmodel backend bugs
41af959
rm tmp test files
b2028a8
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
5a4a5d2
remove concrete methods and data from BaseModel
cab1b35
Merge branch 'master' into feat-other-full-model
19f9058
rm model_type
26b0a40
fix spin model
356a1e6
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
eecd82b
add get_observed_type_list to abstract API and implement in dpmodel
aba2d71
fix: dpmodel change_type_map drops model_with_new_type_stat and uses …
21dc4e7
consolidate get_out_bias/set_out_bias into base_atomic_model
61722b9
change fitting -> fitting_net
c41515a
fix: dpmodel change_out_bias missing compute_fitting_input_stat for s…
124eedd
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
3827a9c
fix bug
9e926bf
fix bug
f1dbd4f
add missing get_observed_type_list to paddel
df132d4
add tests for get_model_def_script get_min_nbor_dist and set_case_embd
0c169cb
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
1da8708
fix hlo
b4d43f0
add dipole model api tests. mv get_observed_type_list to base
6ac0cef
fix frozen model
4b54857
add polar model api tests.
c35ee54
add property model api tests, fix bugs
11c0201
add dos test, fix bug
eb32961
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
fbfd042
add ut for dp-zbl model
b49a10f
add test_get_model_def_script test_get_min_nbor_dist test_set_case_em…
d42c8d8
chore(pt): mv the input stat update to model_change_out_bias to keep …
b7af468
chore(pd): update in the same way as pt
0ec5748
update test for change out bias
00f83cc
test the stat is changed
ff4a27c
rm unused methods
15f2af8
use deep copy
35d4cbe
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
4540382
Extracted compare_variables_recursive to source/tests/consistent/mode…
aa2643e
Merge branch 'master' into feat-other-full-model
9427d25
fix: remove dead code and redundant assignments in dpmodel atomic models
4af1245
feat(pt_expt): add training infrastructure using DeepmdDataSystem
6ff1e2b
Merge branch 'feat-other-full-model' into feat-pt-expt-train
8be14bf
fix bug in test
f3e1b90
enable model compile. print wall-time in stdout
d6cecc2
padding nall so the model is compiled with mode `real`. fix bug in e…
4e8003a
more robust stat
04ed4cb
refactor: unify make_stat_input and validate fparam/aparam in stat
b9cb358
fix bug
934679e
Merge branch 'master' into feat-pt-expt-train
35e4edf
fix
b42e08d
changed find_ to int. refactorizing get_data
3e8ba6e
port dpmodel loss to pt_expt and use user-facing model output keys
99190cc
Merge branch 'master' into feat-pt-expt-train
wanghan-iapcm dfdf2ba
Merge remote-tracking branch 'origin/feat-pt-expt-train' into feat-pt…
18b39e3
use full like
f4675d1
fix bugs in ut
4eb92d3
fix: compiled training force loss not decreasing
fc980d5
fix the rmse_v inconsistency but in the dp backend. improve the consi…
3656a69
fix tests
6e8607e
Merge branch 'master' into feat-pt-expt-train
429bae6
fix ut
fb2eb42
fix bug
9b430a4
add training ut for dpa3 model
cb9f55b
fix: load checkpoint before torch.compile to support restart
9793548
fix: remove + self.start_step from both lambdas, since last_epoch alr…
99c0221
add ut
f0f260d
feat(pt_expt): run update_sel neighbor statistics on GPU
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| """Normalize raw batches from DeepmdDataSystem into canonical format.""" | ||
|
|
||
| from typing import ( | ||
| Any, | ||
| ) | ||
|
|
||
| import numpy as np | ||
|
|
||
| # Keys that are metadata / not needed by models or loss functions. | ||
| _DROP_KEYS = {"default_mesh", "sid", "fid"} | ||
|
|
||
| # Keys that belong to model input (everything else is label). | ||
| _INPUT_KEYS = {"coord", "atype", "spin", "box", "fparam", "aparam"} | ||
|
|
||
|
|
||
| def normalize_batch(batch: dict[str, Any]) -> dict[str, Any]: | ||
| """Normalize a raw batch from :class:`DeepmdDataSystem` to canonical format. | ||
|
|
||
| The following conversions are applied: | ||
|
|
||
| * ``"type"`` is renamed to ``"atype"`` (int64). | ||
| * ``"natoms_vec"`` (1-D) is tiled to 2-D ``[nframes, 2+ntypes]`` | ||
| and stored as ``"natoms"``. | ||
| * ``find_*`` flags are converted to ``np.bool_``. | ||
| * Metadata keys (``default_mesh``, ``sid``, ``fid``) are dropped. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| batch : dict[str, Any] | ||
| Raw batch dict returned by ``DeepmdDataSystem.get_batch()``. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[str, Any] | ||
| Normalized batch dict (new dict; the input is not mutated). | ||
| """ | ||
| out: dict[str, Any] = {} | ||
|
|
||
| for key, val in batch.items(): | ||
| if key in _DROP_KEYS: | ||
| continue | ||
|
|
||
| if key == "type": | ||
| out["atype"] = val.astype(np.int64) | ||
| elif key.startswith("find_"): | ||
| out[key] = np.bool_(float(val) > 0.5) | ||
wanghan-iapcm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| elif key == "natoms_vec": | ||
| nv = val | ||
| if nv.ndim == 1 and "coord" in batch: | ||
| nframes = batch["coord"].shape[0] | ||
| nv = np.tile(nv, (nframes, 1)) | ||
| out["natoms"] = nv | ||
| else: | ||
| out[key] = val | ||
|
|
||
| return out | ||
|
|
||
|
|
||
| def split_batch( | ||
| batch: dict[str, Any], | ||
| ) -> tuple[dict[str, Any], dict[str, Any]]: | ||
| """Split a normalized batch into input and label dicts. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| batch : dict[str, Any] | ||
| Normalized batch (output of :func:`normalize_batch`). | ||
|
|
||
| Returns | ||
| ------- | ||
| input_dict : dict[str, Any] | ||
| Model inputs (coord, atype, box, fparam, aparam, spin). | ||
| label_dict : dict[str, Any] | ||
| Labels and find flags (energy, force, virial, find_*, natoms, …). | ||
| """ | ||
| input_dict: dict[str, Any] = {} | ||
| label_dict: dict[str, Any] = {} | ||
|
|
||
| for key, val in batch.items(): | ||
| if key in _INPUT_KEYS: | ||
| input_dict[key] = val | ||
| else: | ||
| label_dict[key] = val | ||
|
|
||
| return input_dict, label_dict | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.