Merged
76 commits
8356629
feat(pt_expt): add dos, dipole, polar and property fittings
Feb 22, 2026
292fa72
add make_fx, mv itertools to parameterized
Feb 22, 2026
7289167
feat(pt_expt): full models dipole, polar, dos, property and dp-zbl
Feb 22, 2026
553b91d
rm _forward_lower
Feb 22, 2026
0753cd7
rm register_dpmodel_mapping from fitting
Feb 22, 2026
6d6adfe
remove the atomic model in pt_expt. mv atomic model's output stat tes…
Feb 22, 2026
bf448ac
add translated_output_def
Feb 22, 2026
9572a04
base model registration
Feb 22, 2026
0dcd03b
implement compute_or_load_stat
Feb 23, 2026
28fbd08
fix bug in test_ener
Feb 23, 2026
237e4a8
refact make_model, concrete models from different backends inherit fr…
Feb 23, 2026
2a958ec
Add compute_or_load_stat consistency tests and fix dpmodel backend bugs
Feb 23, 2026
41af959
rm tmp test files
Feb 24, 2026
b2028a8
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
5a4a5d2
remove concrete methods and data from BaseModel
Feb 24, 2026
cab1b35
Merge branch 'master' into feat-other-full-model
Feb 24, 2026
19f9058
rm model_type
Feb 24, 2026
26b0a40
fix spin model
Feb 24, 2026
356a1e6
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
eecd82b
add get_observed_type_list to abstract API and implement in dpmodel
Feb 24, 2026
aba2d71
fix: dpmodel change_type_map drops model_with_new_type_stat and uses …
Feb 24, 2026
21dc4e7
consolidate get_out_bias/set_out_bias into base_atomic_model
Feb 24, 2026
61722b9
change fitting -> fitting_net
Feb 24, 2026
c41515a
fix: dpmodel change_out_bias missing compute_fitting_input_stat for s…
Feb 24, 2026
124eedd
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
3827a9c
fix bug
Feb 24, 2026
9e926bf
fix bug
Feb 24, 2026
f1dbd4f
add missing get_observed_type_list to paddel
Feb 24, 2026
df132d4
add tests for get_model_def_script get_min_nbor_dist and set_case_embd
Feb 24, 2026
0c169cb
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
1da8708
fix hlo
Feb 24, 2026
b4d43f0
add dipole model api tests. mv get_observed_type_list to base
Feb 24, 2026
6ac0cef
fix frozen model
Feb 24, 2026
4b54857
add polar model api tests.
Feb 24, 2026
c35ee54
add property model api tests, fix bugs
Feb 24, 2026
11c0201
add dos test, fix bug
Feb 24, 2026
eb32961
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 24, 2026
fbfd042
add ut for dp-zbl model
Feb 25, 2026
b49a10f
add test_get_model_def_script test_get_min_nbor_dist test_set_case_em…
Feb 25, 2026
d42c8d8
chore(pt): mv the input stat update to model_change_out_bias to keep …
Feb 25, 2026
b7af468
chore(pd): update in the same way as pt
Feb 25, 2026
0ec5748
update test for change out bias
Feb 25, 2026
00f83cc
test the stat is changed
Feb 25, 2026
ff4a27c
rm unused methods
Feb 25, 2026
15f2af8
use deep copy
Feb 25, 2026
35d4cbe
Merge remote-tracking branch 'origin/feat-other-full-model' into feat…
Feb 25, 2026
4540382
Extracted compare_variables_recursive to source/tests/consistent/mode…
Feb 26, 2026
aa2643e
Merge branch 'master' into feat-other-full-model
Feb 26, 2026
9427d25
fix: remove dead code and redundant assignments in dpmodel atomic models
Feb 26, 2026
4af1245
feat(pt_expt): add training infrastructure using DeepmdDataSystem
Feb 26, 2026
6ff1e2b
Merge branch 'feat-other-full-model' into feat-pt-expt-train
Feb 26, 2026
8be14bf
fix bug in test
Feb 26, 2026
f3e1b90
enable model compile. print wall-time in stdout
Feb 26, 2026
d6cecc2
padding nall so the model is compiled with mode `real`. fix bug in e…
Feb 26, 2026
4e8003a
more robust stat
Feb 26, 2026
04ed4cb
refactor: unify make_stat_input and validate fparam/aparam in stat
Feb 27, 2026
b9cb358
fix bug
Feb 27, 2026
934679e
Merge branch 'master' into feat-pt-expt-train
Feb 27, 2026
35e4edf
fix
Feb 27, 2026
b42e08d
changed find_ to int. refactorizing get_data
Feb 27, 2026
3e8ba6e
port dpmodel loss to pt_expt and use user-facing model output keys
Feb 27, 2026
99190cc
Merge branch 'master' into feat-pt-expt-train
wanghan-iapcm Feb 27, 2026
dfdf2ba
Merge remote-tracking branch 'origin/feat-pt-expt-train' into feat-pt…
Feb 27, 2026
18b39e3
use full like
Feb 27, 2026
f4675d1
fix bugs in ut
Feb 27, 2026
4eb92d3
fix: compiled training force loss not decreasing
Feb 27, 2026
fc980d5
fix the rmse_v inconsistency but in the dp backend. improve the consi…
Feb 27, 2026
3656a69
fix tests
Feb 27, 2026
6e8607e
Merge branch 'master' into feat-pt-expt-train
Mar 1, 2026
429bae6
fix ut
Mar 1, 2026
fb2eb42
fix bug
Mar 2, 2026
9b430a4
add training ut for dpa3 model
Mar 2, 2026
cb9f55b
fix: load checkpoint before torch.compile to support restart
Mar 2, 2026
9793548
fix: remove + self.start_step from both lambdas, since last_epoch alr…
Mar 2, 2026
99c0221
add ut
Mar 2, 2026
f0f260d
feat(pt_expt): run update_sel neighbor statistics on GPU
Mar 3, 2026
4 changes: 2 additions & 2 deletions deepmd/backend/pt_expt.py
@@ -32,7 +32,7 @@
class PyTorchExportableBackend(Backend):
"""PyTorch exportable backend."""

- name = "PyTorch Exportable"
+ name = "PyTorch-Exportable"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature.ENTRY_POINT
@@ -63,7 +63,7 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
Callable[[Namespace], None]
The entry point hook of the backend.
"""
- from deepmd.pt.entrypoints.main import main as deepmd_main
+ from deepmd.pt_expt.entrypoints.main import main as deepmd_main

return deepmd_main

18 changes: 9 additions & 9 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
@@ -332,19 +332,19 @@ def wrapped_sampler() -> list[dict]:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
- if (
- "find_fparam" not in sampled[0]
- and "fparam" not in sampled[0]
- and self.has_default_fparam()
- ):
+ # For systems where fparam is missing (find_fparam == 0),
+ # fill with default fparam if available and mark as found.
+ if self.has_default_fparam():
default_fparam = self.get_default_fparam()
if default_fparam is not None:
default_fparam_np = np.array(default_fparam)
for sample in sampled:
- nframe = sample["atype"].shape[0]
- sample["fparam"] = np.tile(
- default_fparam_np.reshape(1, -1), (nframe, 1)
- )
+ if "find_fparam" in sample and not sample["find_fparam"]:
+ nframe = sample["atype"].shape[0]
+ sample["fparam"] = np.tile(
+ default_fparam_np.reshape(1, -1), (nframe, 1)
+ )
+ sample["find_fparam"] = np.bool_(True)
return sampled

return wrapped_sampler
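
Review note: the per-system fill introduced in this hunk can be exercised in isolation. The following is a minimal sketch, assuming each sample is a plain dict of NumPy arrays; `fill_default_fparam` is a hypothetical helper name used only for illustration and is not a function in the codebase.

```python
import numpy as np


def fill_default_fparam(sampled, default_fparam):
    """Fill fparam for every system whose find_fparam flag is falsy.

    Hypothetical stand-alone helper mirroring the loop in the hunk above.
    """
    default_np = np.array(default_fparam)
    for sample in sampled:
        # Only systems that carry the flag and report "not found" are filled.
        if "find_fparam" in sample and not sample["find_fparam"]:
            nframe = sample["atype"].shape[0]
            sample["fparam"] = np.tile(default_np.reshape(1, -1), (nframe, 1))
            sample["find_fparam"] = np.bool_(True)
    return sampled


sampled = [
    # System 0: three frames, no fparam on disk.
    {"atype": np.zeros((3, 4), dtype=np.int64), "find_fparam": np.bool_(False)},
    # System 1: two frames, fparam already present and left untouched.
    {
        "atype": np.zeros((2, 4), dtype=np.int64),
        "find_fparam": np.bool_(True),
        "fparam": np.ones((2, 2)),
    },
]
filled = fill_default_fparam(sampled, [0.5, 1.5])
print(filled[0]["fparam"].shape)  # (3, 2)
```

Unlike the removed condition, which inspected only `sampled[0]`, this form decides per system, so a data set that mixes systems with and without fparam is handled.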
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/dpa1.py
@@ -242,6 +242,8 @@ class DescrptDPA1(NativeOP, BaseDescriptor):
arXiv preprint arXiv:2208.08236.
"""

+ _update_sel_cls = UpdateSel
+
def __init__(
self,
rcut: float,
@@ -662,7 +664,7 @@ def update_sel(
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
- min_nbor_dist, sel = UpdateSel().update_one_sel(
+ min_nbor_dist, sel = cls._update_sel_cls().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
)
local_jdata_cpy["sel"] = sel[0]
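
Review note: the `_update_sel_cls` class attribute turns the helper class into an override point. The sketch below shows the pattern with toy classes. Every name in it (`ToyUpdateSel`, `ToyAcceleratedUpdateSel`, `ToyDescriptor`, `ToyBackendDescriptor`, the `tag` field) is hypothetical, and the toy `update_one_sel` only imitates the call shape seen in the hunk. The diff itself does not state why the hook was added; the commit titled "run update_sel neighbor statistics on GPU" suggests a backend-specific substitute.

```python
class ToyUpdateSel:
    """Stand-in for the neighbor-statistics helper (assumed interface)."""

    tag = "default"

    def update_one_sel(self, train_data, type_map, rcut, sel, mixed_type=False):
        # Toy rule: report a fixed minimum distance and keep sel unchanged.
        return 0.9, [sel]


class ToyAcceleratedUpdateSel(ToyUpdateSel):
    """Hypothetical backend-specific replacement."""

    tag = "accelerated"


class ToyDescriptor:
    # Subclasses swap the helper by overriding this attribute only.
    _update_sel_cls = ToyUpdateSel

    @classmethod
    def update_sel(cls, train_data, type_map, local_jdata):
        local_jdata_cpy = local_jdata.copy()
        helper = cls._update_sel_cls()
        min_nbor_dist, sel = helper.update_one_sel(
            train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
        )
        local_jdata_cpy["sel"] = sel[0]
        local_jdata_cpy["helper"] = helper.tag  # recorded for the demo only
        return local_jdata_cpy, min_nbor_dist


class ToyBackendDescriptor(ToyDescriptor):
    _update_sel_cls = ToyAcceleratedUpdateSel


base_out, _ = ToyDescriptor.update_sel(None, ["O", "H"], {"rcut": 6.0, "sel": 120})
sub_out, _ = ToyBackendDescriptor.update_sel(None, ["O", "H"], {"rcut": 6.0, "sel": 120})
print(base_out["helper"], sub_out["helper"])  # default accelerated
```

Because `update_sel` is a classmethod that resolves the helper through `cls`, the subclass inherits the method body unchanged.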
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/dpa2.py
@@ -441,6 +441,8 @@ class DescrptDPA2(NativeOP, BaseDescriptor):
Comput Mater 10, 293 (2024). https://doi.org/10.1038/s41524-024-01493-2
"""

+ _update_sel_cls = UpdateSel
+
def __init__(
self,
ntypes: int,
@@ -1114,7 +1116,7 @@ def update_sel(
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
- update_sel = UpdateSel()
+ update_sel = cls._update_sel_cls()
min_nbor_dist, repinit_sel = update_sel.update_one_sel(
train_data,
type_map,
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/dpa3.py
@@ -337,6 +337,8 @@ class DescrptDPA3(NativeOP, BaseDescriptor):
arXiv preprint arXiv:2506.01686 (2025).
"""

+ _update_sel_cls = UpdateSel
+
def __init__(
self,
ntypes: int,
@@ -729,7 +731,7 @@ def update_sel(
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
- update_sel = UpdateSel()
+ update_sel = cls._update_sel_cls()
min_nbor_dist, repflow_e_sel = update_sel.update_one_sel(
train_data,
type_map,
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -51,7 +51,7 @@ class BD(ABC, PluginVariant, make_plugin_registry("descriptor")):
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
if cls is BD:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
- return super().__new__(cls)
+ return object.__new__(cls)

@abstractmethod
def get_rcut(self) -> float:
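
Review note: the same one-line change appears in `make_base_fitting.py` and `base_model.py`. In Python, `super().__new__(cls)` forwards to the next class in the MRO, while `object.__new__(cls)` allocates the instance directly without consulting any intermediate `__new__`. The diff does not state the motivation. The sketch below only illustrates the dispatching-`__new__` pattern with `object.__new__`, using a hypothetical hand-rolled registry (`ToyBase`, `register`, `ToyA`, `ToyB`) in place of `make_plugin_registry` and `j_get_type`.

```python
class ToyBase:
    """Toy base class that dispatches construction by a "type" keyword."""

    _registry = {}

    def __new__(cls, *args, **kwargs):
        if cls is ToyBase:
            # Calling the base class picks the registered concrete class.
            cls = ToyBase._registry[kwargs.get("type", "a")]
        # Allocate directly; no other __new__ in the MRO is consulted.
        return object.__new__(cls)

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @staticmethod
    def register(key):
        def decorator(sub):
            ToyBase._registry[key] = sub
            return sub

        return decorator


@ToyBase.register("a")
class ToyA(ToyBase):
    pass


@ToyBase.register("b")
class ToyB(ToyBase):
    pass


obj = ToyBase(type="b")
print(type(obj).__name__)  # ToyB
```

Since the returned object is an instance of `ToyBase`, Python still runs `__init__` on it with the original arguments.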
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/se_e2_a.py
@@ -149,6 +149,8 @@ class DescrptSeA(NativeOP, BaseDescriptor):
Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441-4451.
"""

+ _update_sel_cls = UpdateSel
+
def __init__(
self,
rcut: float,
@@ -582,7 +584,7 @@ def update_sel(
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
- min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
+ min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/se_r.py
@@ -128,6 +128,8 @@ class DescrptSeR(NativeOP, BaseDescriptor):
Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441-4451.
"""

+ _update_sel_cls = UpdateSel
+
def __init__(
self,
rcut: float,
@@ -505,7 +507,7 @@ def update_sel(
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
- min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
+ min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/se_t.py
@@ -116,6 +116,8 @@ class DescrptSeT(NativeOP, BaseDescriptor):
Not used in this descriptor, only to be compat with input.
"""

+ _update_sel_cls = UpdateSel
+
def __init__(
self,
rcut: float,
@@ -505,7 +507,7 @@ def update_sel(
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
- min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
+ min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/se_t_tebd.py
@@ -141,6 +141,8 @@ class DescrptSeTTebd(NativeOP, BaseDescriptor):

"""

+ _update_sel_cls = UpdateSel
+
def __init__(
self,
rcut: float,
@@ -500,7 +502,7 @@ def update_sel(
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
- min_nbor_dist, sel = UpdateSel().update_one_sel(
+ min_nbor_dist, sel = cls._update_sel_cls().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
)
local_jdata_cpy["sel"] = sel[0]
22 changes: 22 additions & 0 deletions deepmd/dpmodel/fitting/general_fitting.py
@@ -266,6 +266,17 @@ def compute_input_stats(
)
else:
sampled = merged() if callable(merged) else merged
+ for ii, frame in enumerate(sampled):
+ if "find_fparam" not in frame:
+ raise ValueError(
+ f"numb_fparam > 0 but fparam is not acquired "
+ f"for system {ii}."
+ )
+ if not frame["find_fparam"]:
+ raise ValueError(
+ f"numb_fparam > 0 but no fparam data is provided "
+ f"for system {ii}."
+ )
cat_data = np.concatenate(
[frame["fparam"] for frame in sampled], axis=0
)
@@ -313,6 +324,17 @@ def compute_input_stats(
)
else:
sampled = merged() if callable(merged) else merged
+ for ii, frame in enumerate(sampled):
+ if "find_aparam" not in frame:
+ raise ValueError(
+ f"numb_aparam > 0 but aparam is not acquired "
+ f"for system {ii}."
+ )
+ if not frame["find_aparam"]:
+ raise ValueError(
+ f"numb_aparam > 0 but no aparam data is provided "
+ f"for system {ii}."
+ )
sys_sumv = []
sys_sumv2 = []
sys_sumn = []
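
Review note: the two validation loops added in this file differ only in the parameter name. A minimal sketch of the shared check follows, assuming frames are plain dicts; `check_param_found` is a hypothetical helper name and is not part of the PR.

```python
def check_param_found(sampled, name):
    """Raise ValueError unless every system provides the parameter `name`.

    Hypothetical helper; `name` would be "fparam" or "aparam".
    """
    flag = f"find_{name}"
    for ii, frame in enumerate(sampled):
        if flag not in frame:
            # The data loader never requested this key for the system.
            raise ValueError(
                f"numb_{name} > 0 but {name} is not acquired for system {ii}."
            )
        if not frame[flag]:
            # The key was requested, but the system has no data for it.
            raise ValueError(
                f"numb_{name} > 0 but no {name} data is provided for system {ii}."
            )


check_param_found([{"find_fparam": True}], "fparam")  # passes silently

try:
    check_param_found([{"find_aparam": True}, {"find_aparam": False}], "aparam")
except ValueError as err:
    message = str(err)
    print(message)
```

The two messages separate a missing flag (the key was never requested) from a false flag (the key was requested but no data was found), which points to different fixes on the user side.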
2 changes: 1 addition & 1 deletion deepmd/dpmodel/fitting/make_base_fitting.py
@@ -42,7 +42,7 @@ class BF(ABC, PluginVariant, make_plugin_registry("fitting")):
def __new__(cls: type, *args: Any, **kwargs: Any) -> Any:
if cls is BF:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
- return super().__new__(cls)
+ return object.__new__(cls)

@abstractmethod
def output_def(self) -> FittingOutputDef:
10 changes: 5 additions & 5 deletions deepmd/dpmodel/loss/ener.py
@@ -95,10 +95,10 @@ def call(
label_dict: dict[str, Array],
) -> dict[str, Array]:
"""Calculate loss from model results and labeled results."""
- energy = model_dict["energy_redu"]
- force = model_dict["energy_derv_r"]
- virial = model_dict["energy_derv_c_redu"]
- atom_ener = model_dict["energy"]
+ energy = model_dict["energy"]
+ force = model_dict["force"]
+ virial = model_dict["virial"]
+ atom_ener = model_dict["atom_energy"]
energy_hat = label_dict["energy"]
force_hat = label_dict["force"]
virial_hat = label_dict["virial"]
@@ -212,7 +212,7 @@ def call(
)
loss += pref_v * l_huber_loss
more_loss["rmse_v"] = self.display_if_exist(
- xp.sqrt(l2_virial_loss), find_virial
+ xp.sqrt(l2_virial_loss) * atom_norm, find_virial
)
if self.has_ae:
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
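
Review note: the first hunk in this file switches the loss from internal output names to user-facing ones. The mapping below is read off the old and new lines by variable name; `INTERNAL_TO_USER` and `to_user_keys` are hypothetical names for illustration, not part of the PR. One detail worth seeing in code is that `"energy"` appears on both sides (the old per-atom key and the new total-energy key), so a translation has to build a fresh dict rather than rename keys in place.

```python
# Mapping inferred from the hunk: old internal key -> new user-facing key.
INTERNAL_TO_USER = {
    "energy_redu": "energy",
    "energy_derv_r": "force",
    "energy_derv_c_redu": "virial",
    "energy": "atom_energy",
}


def to_user_keys(model_dict):
    """Return a new dict with internal keys translated in a single pass."""
    return {INTERNAL_TO_USER.get(key, key): val for key, val in model_dict.items()}


internal = {
    "energy": [0.4, 0.6],        # per-atom energies under the old naming
    "energy_redu": 1.0,          # total energy under the old naming
    "energy_derv_r": [[0.0, 0.0, 0.1]],
}
user = to_user_keys(internal)
print(sorted(user))  # ['atom_energy', 'energy', 'force']
```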
5 changes: 4 additions & 1 deletion deepmd/dpmodel/loss/loss.py
@@ -53,8 +53,11 @@ def display_if_exist(loss: Array, find_property: float) -> Array:
the loss scalar or NaN
"""
xp = array_api_compat.array_namespace(loss)
+ dev = array_api_compat.device(loss)
return xp.where(
- xp.asarray(find_property, dtype=xp.bool), loss, xp.asarray(xp.nan)
+ xp.asarray(find_property, dtype=xp.bool, device=dev),
+ loss,
+ xp.asarray(xp.nan, device=dev),
)

@classmethod
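
Review note: the behavior of `display_if_exist` (report the loss when the label exists, NaN otherwise) can be checked with NumPy alone. This is a simplified sketch, not the PR's implementation: it calls NumPy directly instead of going through `array_api_compat`, and it omits the `device=` arguments because NumPy arrays always live on the CPU. The hunk adds those arguments so that the condition and the NaN constant are created on the same device as `loss`.

```python
import numpy as np


def display_if_exist(loss, find_property):
    """Return `loss` when the property was found, NaN otherwise (NumPy-only sketch)."""
    return np.where(
        np.asarray(find_property, dtype=bool),
        loss,
        np.asarray(np.nan),
    )


shown = display_if_exist(np.asarray(0.25), 1.0)
hidden = display_if_exist(np.asarray(0.25), 0.0)
print(float(shown), float(hidden))  # 0.25 nan
```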
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/base_model.py
@@ -42,7 +42,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "BaseModel":
if model_type == "standard":
model_type = kwargs.get("fitting", {}).get("type", "ener")
cls = cls.get_class_by_type(model_type)
- return super().__new__(cls)
+ return object.__new__(cls)

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
86 changes: 86 additions & 0 deletions deepmd/dpmodel/utils/batch.py
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Normalize raw batches from DeepmdDataSystem into canonical format."""

from typing import (
Any,
)

import numpy as np

# Keys that are metadata / not needed by models or loss functions.
_DROP_KEYS = {"default_mesh", "sid", "fid"}

# Keys that belong to model input (everything else is label).
_INPUT_KEYS = {"coord", "atype", "spin", "box", "fparam", "aparam"}


def normalize_batch(batch: dict[str, Any]) -> dict[str, Any]:
"""Normalize a raw batch from :class:`DeepmdDataSystem` to canonical format.

The following conversions are applied:

* ``"type"`` is renamed to ``"atype"`` (int64).
* ``"natoms_vec"`` (1-D) is tiled to 2-D ``[nframes, 2+ntypes]``
and stored as ``"natoms"``.
* ``find_*`` flags are converted to ``np.bool_``.
* Metadata keys (``default_mesh``, ``sid``, ``fid``) are dropped.

Parameters
----------
batch : dict[str, Any]
Raw batch dict returned by ``DeepmdDataSystem.get_batch()``.

Returns
-------
dict[str, Any]
Normalized batch dict (new dict; the input is not mutated).
"""
out: dict[str, Any] = {}

for key, val in batch.items():
if key in _DROP_KEYS:
continue

if key == "type":
out["atype"] = val.astype(np.int64)
elif key.startswith("find_"):
out[key] = np.bool_(float(val) > 0.5)
elif key == "natoms_vec":
nv = val
if nv.ndim == 1 and "coord" in batch:
nframes = batch["coord"].shape[0]
nv = np.tile(nv, (nframes, 1))
out["natoms"] = nv
else:
out[key] = val

return out


def split_batch(
batch: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Split a normalized batch into input and label dicts.

Parameters
----------
batch : dict[str, Any]
Normalized batch (output of :func:`normalize_batch`).

Returns
-------
input_dict : dict[str, Any]
Model inputs (coord, atype, box, fparam, aparam, spin).
label_dict : dict[str, Any]
Labels and find flags (energy, force, virial, find_*, natoms, …).
"""
input_dict: dict[str, Any] = {}
label_dict: dict[str, Any] = {}

for key, val in batch.items():
if key in _INPUT_KEYS:
input_dict[key] = val
else:
label_dict[key] = val

return input_dict, label_dict
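
Review note: a short usage sketch for the two new functions. To stay self-contained it carries a condensed copy of `normalize_batch` and `split_batch` from the file above (docstrings dropped, `split_batch` written with comprehensions). The toy `raw` batch is an assumption about what `DeepmdDataSystem.get_batch()` returns, reduced to a few keys, and is not taken from real data.

```python
import numpy as np

_DROP_KEYS = {"default_mesh", "sid", "fid"}
_INPUT_KEYS = {"coord", "atype", "spin", "box", "fparam", "aparam"}


def normalize_batch(batch):
    """Condensed copy of normalize_batch from the new file."""
    out = {}
    for key, val in batch.items():
        if key in _DROP_KEYS:
            continue
        if key == "type":
            out["atype"] = val.astype(np.int64)
        elif key.startswith("find_"):
            out[key] = np.bool_(float(val) > 0.5)
        elif key == "natoms_vec":
            nv = val
            if nv.ndim == 1 and "coord" in batch:
                nv = np.tile(nv, (batch["coord"].shape[0], 1))
            out["natoms"] = nv
        else:
            out[key] = val
    return out


def split_batch(batch):
    """Condensed copy of split_batch from the new file."""
    input_dict = {k: v for k, v in batch.items() if k in _INPUT_KEYS}
    label_dict = {k: v for k, v in batch.items() if k not in _INPUT_KEYS}
    return input_dict, label_dict


# Toy raw batch: 2 frames, 3 atoms, 1 atom type (shapes are illustrative).
raw = {
    "coord": np.zeros((2, 9)),
    "type": np.zeros((2, 3), dtype=np.int32),
    "natoms_vec": np.array([3, 3, 3]),
    "energy": np.zeros((2, 1)),
    "find_energy": 1.0,
    "default_mesh": np.zeros(6),
}
inputs, labels = split_batch(normalize_batch(raw))
print(sorted(inputs))          # ['atype', 'coord']
print(sorted(labels))          # ['energy', 'find_energy', 'natoms']
print(labels["natoms"].shape)  # (2, 3)
```

Note that `natoms` and the `find_*` flags land in the label dict, because `split_batch` treats every key outside `_INPUT_KEYS` as a label.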
8 changes: 6 additions & 2 deletions deepmd/dpmodel/utils/env_mat.py
@@ -26,7 +26,10 @@ def compute_smooth_weight(
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
xp = array_api_compat.array_namespace(distance)
- distance = xp.clip(distance, min=rmin, max=rmax)
+ # Use where instead of clip so that make_fx tracing does not
+ # decompose it into boolean-indexed ops with data-dependent sizes.
+ distance = xp.where(distance < rmin, xp.full_like(distance, rmin), distance)
+ distance = xp.where(distance > rmax, xp.full_like(distance, rmax), distance)
uu = (distance - rmin) / (rmax - rmin)
uu2 = uu * uu
vv = uu2 * uu * (-6.0 * uu2 + 15.0 * uu - 10.0) + 1.0
@@ -42,7 +45,8 @@ def compute_exp_sw(
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
xp = array_api_compat.array_namespace(distance)
- distance = xp.clip(distance, min=0.0, max=rmax)
+ distance = xp.where(distance < 0.0, xp.zeros_like(distance), distance)
+ distance = xp.where(distance > rmax, xp.full_like(distance, rmax), distance)
C = 20
a = C / rmin
b = rmin
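
Review note: the two `where` calls are meant to be a drop-in replacement for `clip`. A minimal NumPy check of that equivalence follows, together with the smooth-weight polynomial from `compute_smooth_weight`. `clamp_with_where` and `smooth_weight` are hypothetical names for this sketch, and returning `vv` directly is an assumption, since the hunk is cut off before the function's return statement.

```python
import numpy as np


def clamp_with_where(distance, rmin, rmax):
    """Clamp to [rmin, rmax] using two element-wise selects instead of clip."""
    distance = np.where(distance < rmin, np.full_like(distance, rmin), distance)
    distance = np.where(distance > rmax, np.full_like(distance, rmax), distance)
    return distance


def smooth_weight(distance, rmin, rmax):
    """Polynomial switch from the hunk: 1 at rmin, 0 at rmax."""
    d = clamp_with_where(distance, rmin, rmax)
    uu = (d - rmin) / (rmax - rmin)
    uu2 = uu * uu
    return uu2 * uu * (-6.0 * uu2 + 15.0 * uu - 10.0) + 1.0


r = np.array([0.0, 0.5, 3.0, 6.0, 7.5])
clamped = clamp_with_where(r, 0.5, 6.0)
print(np.array_equal(clamped, np.clip(r, 0.5, 6.0)))  # True

weights = smooth_weight(r, 0.5, 6.0)
print(weights[0], weights[-1])  # 1.0 0.0
```

Both forms are element-wise and shape-preserving; per the comment in the hunk, the difference that matters is how `make_fx` decomposes them during tracing.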