diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index ade9eb51f3..61a7151208 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -32,7 +32,7 @@ class PyTorchExportableBackend(Backend): """PyTorch exportable backend.""" - name = "PyTorch Exportable" + name = "PyTorch-Exportable" """The formal name of the backend.""" features: ClassVar[Backend.Feature] = ( Backend.Feature.ENTRY_POINT @@ -63,7 +63,7 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]: Callable[[Namespace], None] The entry point hook of the backend. """ - from deepmd.pt.entrypoints.main import main as deepmd_main + from deepmd.pt_expt.entrypoints.main import main as deepmd_main return deepmd_main diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index ecfd08b61a..1058dff570 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -332,19 +332,19 @@ def wrapped_sampler() -> list[dict]: atom_exclude_types = self.atom_excl.get_exclude_types() for sample in sampled: sample["atom_exclude_types"] = list(atom_exclude_types) - if ( - "find_fparam" not in sampled[0] - and "fparam" not in sampled[0] - and self.has_default_fparam() - ): + # For systems where fparam is missing (find_fparam == 0), + # fill with default fparam if available and mark as found. + if self.has_default_fparam(): default_fparam = self.get_default_fparam() if default_fparam is not None: default_fparam_np = np.array(default_fparam) for sample in sampled: - nframe = sample["atype"].shape[0] - sample["fparam"] = np.tile( - default_fparam_np.reshape(1, -1), (nframe, 1) - ) + if "find_fparam" in sample and not sample["find_fparam"]: + nframe = sample["atype"].shape[0] + sample["fparam"] = np.tile( + default_fparam_np.reshape(1, -1), (nframe, 1) + ) + sample["find_fparam"] = np.bool_(True) return sampled return wrapped_sampler diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 34dcba6335..e0a9d47cac 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -242,6 +242,8 @@ class DescrptDPA1(NativeOP, BaseDescriptor): arXiv preprint arXiv:2208.08236. """ + _update_sel_cls = UpdateSel + def __init__( self, rcut: float, @@ -662,7 +664,7 @@ def update_sel( The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - min_nbor_dist, sel = UpdateSel().update_one_sel( + min_nbor_dist, sel = cls._update_sel_cls().update_one_sel( train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True ) local_jdata_cpy["sel"] = sel[0] diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 5ac636c37c..44e88fa38a 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -441,6 +441,8 @@ class DescrptDPA2(NativeOP, BaseDescriptor): Comput Mater 10, 293 (2024). https://doi.org/10.1038/s41524-024-01493-2 """ + _update_sel_cls = UpdateSel + def __init__( self, ntypes: int, @@ -1114,7 +1116,7 @@ def update_sel( The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - update_sel = UpdateSel() + update_sel = cls._update_sel_cls() min_nbor_dist, repinit_sel = update_sel.update_one_sel( train_data, type_map, diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index e385ae5dda..956607b114 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -337,6 +337,8 @@ class DescrptDPA3(NativeOP, BaseDescriptor): arXiv preprint arXiv:2506.01686 (2025). """ + _update_sel_cls = UpdateSel + def __init__( self, ntypes: int, @@ -729,7 +731,7 @@ def update_sel( The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - update_sel = UpdateSel() + update_sel = cls._update_sel_cls() min_nbor_dist, repflow_e_sel = update_sel.update_one_sel( train_data, type_map, diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index f87ca2c5b6..cfef017180 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -51,7 +51,7 @@ class BD(ABC, PluginVariant, make_plugin_registry("descriptor")): def __new__(cls, *args: Any, **kwargs: Any) -> Any: if cls is BD: cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) - return super().__new__(cls) + return object.__new__(cls) @abstractmethod def get_rcut(self) -> float: diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 4710987f54..3abdfe750f 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -149,6 +149,8 @@ class DescrptSeA(NativeOP, BaseDescriptor): Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441-4451. """ + _update_sel_cls = UpdateSel + def __init__( self, rcut: float, @@ -582,7 +584,7 @@ def update_sel( The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel( + min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel( train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False ) return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 5ea9ef525f..2fb0414633 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -128,6 +128,8 @@ class DescrptSeR(NativeOP, BaseDescriptor): Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441-4451. """ + _update_sel_cls = UpdateSel + def __init__( self, rcut: float, @@ -505,7 +507,7 @@ def update_sel( The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel( + min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel( train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False ) return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 7877a1e9ab..6e1e30ab1e 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -116,6 +116,8 @@ class DescrptSeT(NativeOP, BaseDescriptor): Not used in this descriptor, only to be compat with input. """ + _update_sel_cls = UpdateSel + def __init__( self, rcut: float, @@ -505,7 +507,7 @@ def update_sel( The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel( + min_nbor_dist, local_jdata_cpy["sel"] = cls._update_sel_cls().update_one_sel( train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False ) return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 994fa63b30..1687ff5ca7 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -141,6 +141,8 @@ class DescrptSeTTebd(NativeOP, BaseDescriptor): """ + _update_sel_cls = UpdateSel + def __init__( self, rcut: float, @@ -500,7 +502,7 @@ def update_sel( The minimum distance between two atoms """ local_jdata_cpy = local_jdata.copy() - min_nbor_dist, sel = UpdateSel().update_one_sel( + min_nbor_dist, sel = cls._update_sel_cls().update_one_sel( train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True ) local_jdata_cpy["sel"] = sel[0] diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index c7372f05ac..180e5458fb 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -266,6 +266,17 @@ def compute_input_stats( ) else: sampled = merged() if callable(merged) else merged + for ii, frame in enumerate(sampled): + if "find_fparam" not in frame: + raise ValueError( + f"numb_fparam > 0 but fparam is not acquired " + f"for system {ii}." + ) + if not frame["find_fparam"]: + raise ValueError( + f"numb_fparam > 0 but no fparam data is provided " + f"for system {ii}." + ) cat_data = np.concatenate( [frame["fparam"] for frame in sampled], axis=0 ) @@ -313,6 +324,17 @@ def compute_input_stats( ) else: sampled = merged() if callable(merged) else merged + for ii, frame in enumerate(sampled): + if "find_aparam" not in frame: + raise ValueError( + f"numb_aparam > 0 but aparam is not acquired " + f"for system {ii}." + ) + if not frame["find_aparam"]: + raise ValueError( + f"numb_aparam > 0 but no aparam data is provided " + f"for system {ii}." + ) sys_sumv = [] sys_sumv2 = [] sys_sumn = [] diff --git a/deepmd/dpmodel/fitting/make_base_fitting.py b/deepmd/dpmodel/fitting/make_base_fitting.py index 7b65a150b2..cf8172bd03 100644 --- a/deepmd/dpmodel/fitting/make_base_fitting.py +++ b/deepmd/dpmodel/fitting/make_base_fitting.py @@ -42,7 +42,7 @@ class BF(ABC, PluginVariant, make_plugin_registry("fitting")): def __new__(cls: type, *args: Any, **kwargs: Any) -> Any: if cls is BF: cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) - return super().__new__(cls) + return object.__new__(cls) @abstractmethod def output_def(self) -> FittingOutputDef: diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index 3bf9695852..9ab141bdc2 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -95,10 +95,10 @@ def call( label_dict: dict[str, Array], ) -> dict[str, Array]: """Calculate loss from model results and labeled results.""" - energy = model_dict["energy_redu"] - force = model_dict["energy_derv_r"] - virial = model_dict["energy_derv_c_redu"] - atom_ener = model_dict["energy"] + energy = model_dict["energy"] + force = model_dict["force"] + virial = model_dict["virial"] + atom_ener = model_dict["atom_energy"] energy_hat = label_dict["energy"] force_hat = label_dict["force"] virial_hat = label_dict["virial"] @@ -212,7 +212,7 @@ def call( ) loss += pref_v * l_huber_loss more_loss["rmse_v"] = self.display_if_exist( - xp.sqrt(l2_virial_loss), find_virial + xp.sqrt(l2_virial_loss) * atom_norm, find_virial ) if self.has_ae: atom_ener_reshape = xp.reshape(atom_ener, (-1,)) diff --git a/deepmd/dpmodel/loss/loss.py b/deepmd/dpmodel/loss/loss.py index 6dc468582a..4b9831c344 100644 --- a/deepmd/dpmodel/loss/loss.py +++ b/deepmd/dpmodel/loss/loss.py @@ -53,8 +53,11 @@ def display_if_exist(loss: Array, find_property: float) -> Array: the loss scalar or NaN """ xp = array_api_compat.array_namespace(loss) + dev = array_api_compat.device(loss) return xp.where( - xp.asarray(find_property, dtype=xp.bool), loss, xp.asarray(xp.nan) + xp.asarray(find_property, dtype=xp.bool, device=dev), + loss, + xp.asarray(xp.nan, device=dev), ) @classmethod diff --git a/deepmd/dpmodel/model/base_model.py b/deepmd/dpmodel/model/base_model.py index d87c0eb5b7..b89172c4f6 100644 --- a/deepmd/dpmodel/model/base_model.py +++ b/deepmd/dpmodel/model/base_model.py @@ -42,7 +42,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "BaseModel": if model_type == "standard": model_type = kwargs.get("fitting", {}).get("type", "ener") cls = cls.get_class_by_type(model_type) - return super().__new__(cls) + return object.__new__(cls) @abstractmethod def __call__(self, *args: Any, **kwds: Any) -> Any: diff --git a/deepmd/dpmodel/utils/batch.py b/deepmd/dpmodel/utils/batch.py new file mode 100644 index 0000000000..204ae9771f --- /dev/null +++ b/deepmd/dpmodel/utils/batch.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Normalize raw batches from DeepmdDataSystem into canonical format.""" + +from typing import ( + Any, +) + +import numpy as np + +# Keys that are metadata / not needed by models or loss functions. +_DROP_KEYS = {"default_mesh", "sid", "fid"} + +# Keys that belong to model input (everything else is label). +_INPUT_KEYS = {"coord", "atype", "spin", "box", "fparam", "aparam"} + + +def normalize_batch(batch: dict[str, Any]) -> dict[str, Any]: + """Normalize a raw batch from :class:`DeepmdDataSystem` to canonical format. + + The following conversions are applied: + + * ``"type"`` is renamed to ``"atype"`` (int64). + * ``"natoms_vec"`` (1-D) is tiled to 2-D ``[nframes, 2+ntypes]`` + and stored as ``"natoms"``. + * ``find_*`` flags are converted to ``np.bool_``. + * Metadata keys (``default_mesh``, ``sid``, ``fid``) are dropped. + + Parameters + ---------- + batch : dict[str, Any] + Raw batch dict returned by ``DeepmdDataSystem.get_batch()``. + + Returns + ------- + dict[str, Any] + Normalized batch dict (new dict; the input is not mutated). + """ + out: dict[str, Any] = {} + + for key, val in batch.items(): + if key in _DROP_KEYS: + continue + + if key == "type": + out["atype"] = val.astype(np.int64) + elif key.startswith("find_"): + out[key] = np.bool_(float(val) > 0.5) + elif key == "natoms_vec": + nv = val + if nv.ndim == 1 and "coord" in batch: + nframes = batch["coord"].shape[0] + nv = np.tile(nv, (nframes, 1)) + out["natoms"] = nv + else: + out[key] = val + + return out + + +def split_batch( + batch: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: + """Split a normalized batch into input and label dicts. + + Parameters + ---------- + batch : dict[str, Any] + Normalized batch (output of :func:`normalize_batch`). + + Returns + ------- + input_dict : dict[str, Any] + Model inputs (coord, atype, box, fparam, aparam, spin). + label_dict : dict[str, Any] + Labels and find flags (energy, force, virial, find_*, natoms, …). + """ + input_dict: dict[str, Any] = {} + label_dict: dict[str, Any] = {} + + for key, val in batch.items(): + if key in _INPUT_KEYS: + input_dict[key] = val + else: + label_dict[key] = val + + return input_dict, label_dict diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 5af7a9fc3c..e9407a435b 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -26,7 +26,10 @@ def compute_smooth_weight( if rmin >= rmax: raise ValueError("rmin should be less than rmax.") xp = array_api_compat.array_namespace(distance) - distance = xp.clip(distance, min=rmin, max=rmax) + # Use where instead of clip so that make_fx tracing does not + # decompose it into boolean-indexed ops with data-dependent sizes. + distance = xp.where(distance < rmin, xp.full_like(distance, rmin), distance) + distance = xp.where(distance > rmax, xp.full_like(distance, rmax), distance) uu = (distance - rmin) / (rmax - rmin) uu2 = uu * uu vv = uu2 * uu * (-6.0 * uu2 + 15.0 * uu - 10.0) + 1.0 @@ -42,7 +45,8 @@ def compute_exp_sw( if rmin >= rmax: raise ValueError("rmin should be less than rmax.") xp = array_api_compat.array_namespace(distance) - distance = xp.clip(distance, min=0.0, max=rmax) + distance = xp.where(distance < 0.0, xp.zeros_like(distance), distance) + distance = xp.where(distance > rmax, xp.full_like(distance, rmax), distance) C = 20 a = C / rmin b = rmin diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index 238e395104..721723821e 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -133,6 +133,7 @@ def iter( system["atype"], system["box"], ) + nframes, nloc = atype.shape[:2] ( extended_coord, extended_atype, @@ -169,12 +170,12 @@ def iter( env_mat = xp.reshape( env_mat, ( - coord.shape[0] * coord.shape[1], + nframes * nloc, self.descriptor.get_nsel(), self.last_dim, ), ) - atype = xp.reshape(atype, (coord.shape[0] * coord.shape[1],)) + atype = xp.reshape(atype, (nframes * nloc,)) # (1, nloc) eq (ntypes, 1), so broadcast is possible # shape: (ntypes, nloc) type_idx = xp.equal( @@ -199,7 +200,7 @@ def iter( # shape: (1, nloc, nnei) exclude_mask = xp.reshape( pair_exclude_mask.build_type_exclude_mask(nlist, extended_atype), - (1, coord.shape[0] * coord.shape[1], -1), + (1, nframes * nloc, -1), ) # shape: (ntypes, nloc, nnei) type_idx = xp.logical_and(type_idx[..., None], exclude_mask) diff --git a/deepmd/dpmodel/utils/stat.py b/deepmd/dpmodel/utils/stat.py index 34c500d7c8..8cb379380a 100644 --- a/deepmd/dpmodel/utils/stat.py +++ b/deepmd/dpmodel/utils/stat.py @@ -244,11 +244,9 @@ def compute_output_stats( for kk in keys: for idx, system in enumerate(sampled): - if (("find_atom_" + kk) in system) and ( - system["find_atom_" + kk] > 0.0 - ): + if (("find_atom_" + kk) in system) and system["find_atom_" + kk]: atomic_sampled_idx[kk].append(idx) - if (("find_" + kk) in system) and (system["find_" + kk] > 0.0): + if (("find_" + kk) in system) and system["find_" + kk]: global_sampled_idx[kk].append(idx) # use index to gather model predictions for the corresponding systems. diff --git a/deepmd/pt_expt/descriptor/dpa1.py b/deepmd/pt_expt/descriptor/dpa1.py index 3cef4ef78c..1bab33af8f 100644 --- a/deepmd/pt_expt/descriptor/dpa1.py +++ b/deepmd/pt_expt/descriptor/dpa1.py @@ -7,10 +7,13 @@ from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) @BaseDescriptor.register("se_atten") @BaseDescriptor.register("dpa1") @torch_module class DescrptDPA1(DescrptDPA1DP): - pass + _update_sel_cls = UpdateSel diff --git a/deepmd/pt_expt/descriptor/dpa2.py b/deepmd/pt_expt/descriptor/dpa2.py index db60accaa1..ba7f03e9e6 100644 --- a/deepmd/pt_expt/descriptor/dpa2.py +++ b/deepmd/pt_expt/descriptor/dpa2.py @@ -7,9 +7,12 @@ from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) @BaseDescriptor.register("dpa2") @torch_module class DescrptDPA2(DescrptDPA2DP): - pass + _update_sel_cls = UpdateSel diff --git a/deepmd/pt_expt/descriptor/dpa3.py b/deepmd/pt_expt/descriptor/dpa3.py index 82dc46a57f..7119f043bd 100644 --- a/deepmd/pt_expt/descriptor/dpa3.py +++ b/deepmd/pt_expt/descriptor/dpa3.py @@ -7,9 +7,12 @@ from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) @BaseDescriptor.register("dpa3") @torch_module class DescrptDPA3(DescrptDPA3DP): - pass + _update_sel_cls = UpdateSel diff --git a/deepmd/pt_expt/descriptor/se_atten_v2.py b/deepmd/pt_expt/descriptor/se_atten_v2.py index b0ae833075..5be8b94ea2 100644 --- a/deepmd/pt_expt/descriptor/se_atten_v2.py +++ b/deepmd/pt_expt/descriptor/se_atten_v2.py @@ -7,9 +7,12 @@ from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) @BaseDescriptor.register("se_atten_v2") @torch_module class DescrptSeAttenV2(DescrptSeAttenV2DP): - pass + _update_sel_cls = UpdateSel diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index fea695ebd9..d65682c200 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -7,10 +7,13 @@ from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) @BaseDescriptor.register("se_e2_a") @BaseDescriptor.register("se_a") @torch_module class DescrptSeA(DescrptSeADP): - pass + _update_sel_cls = UpdateSel diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index a449614f47..a2456ff58e 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -7,10 +7,13 @@ from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) @BaseDescriptor.register("se_e2_r") @BaseDescriptor.register("se_r") @torch_module class DescrptSeR(DescrptSeRDP): - pass + _update_sel_cls = UpdateSel diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index de76b4ecf7..9706f0ceb4 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -7,6 +7,9 @@ from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) @BaseDescriptor.register("se_e3") @@ -14,4 +17,4 @@ @BaseDescriptor.register("se_a_3be") @torch_module class DescrptSeT(DescrptSeTDP): - pass + _update_sel_cls = UpdateSel diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index 995dc24c3b..37c872ec64 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -7,9 +7,12 @@ from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) @BaseDescriptor.register("se_e3_tebd") @torch_module class DescrptSeTTebd(DescrptSeTTebdDP): - pass + _update_sel_cls = UpdateSel diff --git a/deepmd/pt_expt/entrypoints/__init__.py b/deepmd/pt_expt/entrypoints/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/entrypoints/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py new file mode 100644 index 0000000000..a0b814f07a --- /dev/null +++ b/deepmd/pt_expt/entrypoints/main.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Training entrypoint for the pt_expt backend.""" + +import argparse +import json +import logging +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import h5py + +from deepmd.pt_expt.train import ( + training, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, + get_data, + process_systems, +) +from deepmd.utils.path import ( + DPPath, +) + +log = logging.getLogger(__name__) + + +def get_trainer( + config: dict[str, Any], + init_model: str | None = None, + restart_model: str | None = None, +) -> training.Trainer: + """Build a :class:`training.Trainer` from a normalised config.""" + model_params = config["model"] + training_params = config["training"] + type_map = model_params["type_map"] + + # ----- training data ------------------------------------------------ + training_dataset_params = training_params["training_data"] + training_systems = process_systems( + training_dataset_params["systems"], + patterns=training_dataset_params.get("rglob_patterns", None), + ) + train_data = DeepmdDataSystem( + systems=training_systems, + batch_size=training_dataset_params["batch_size"], + test_size=1, + type_map=type_map, + trn_all_set=True, + sys_probs=training_dataset_params.get("sys_probs", None), + auto_prob_style=training_dataset_params.get("auto_prob", "prob_sys_size"), + ) + + # ----- validation data ---------------------------------------------- + validation_data = None + validation_dataset_params = training_params.get("validation_data", None) + if validation_dataset_params is not None: + val_systems = process_systems( + validation_dataset_params["systems"], + patterns=validation_dataset_params.get("rglob_patterns", None), + ) + validation_data = DeepmdDataSystem( + systems=val_systems, + batch_size=validation_dataset_params["batch_size"], + test_size=1, + type_map=type_map, + trn_all_set=True, + ) + + # ----- stat file path ----------------------------------------------- + stat_file_path = training_params.get("stat_file", None) + if stat_file_path is not None: + if not Path(stat_file_path).exists(): + if stat_file_path.endswith((".h5", ".hdf5")): + with h5py.File(stat_file_path, "w"): + pass + else: + Path(stat_file_path).mkdir() + stat_file_path = DPPath(stat_file_path, "a") + + trainer = training.Trainer( + config, + train_data, + stat_file_path=stat_file_path, + validation_data=validation_data, + init_model=init_model, + restart_model=restart_model, + ) + return trainer + + +def train( + input_file: str, + init_model: str | None = None, + restart: str | None = None, + skip_neighbor_stat: bool = False, + output: str = "out.json", +) -> None: + """Run training with the pt_expt backend. + + Parameters + ---------- + input_file : str + Path to the JSON configuration file. + init_model : str or None + Path to a checkpoint to initialise weights from. + restart : str or None + Path to a checkpoint to restart training from. + skip_neighbor_stat : bool + Skip neighbour statistics calculation. + output : str + Where to dump the normalised config. + """ + from deepmd.common import ( + j_loader, + ) + + log.info("Configuration path: %s", input_file) + config = j_loader(input_file) + + # suffix fix + if init_model is not None and not init_model.endswith(".pt"): + init_model += ".pt" + if restart is not None and not restart.endswith(".pt"): + restart += ".pt" + + # argcheck + config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") + config = normalize(config) + + # neighbour stat + if not skip_neighbor_stat: + log.info( + "Calculate neighbor statistics... " + "(add --skip-neighbor-stat to skip this step)" + ) + type_map = config["model"].get("type_map") + train_data = get_data(config["training"]["training_data"], 0, type_map, None) + from deepmd.pt_expt.model import ( + BaseModel, + ) + + config["model"], _min_nbor_dist = BaseModel.update_sel( + train_data, type_map, config["model"] + ) + + with open(output, "w") as fp: + json.dump(config, fp, indent=4) + + trainer = get_trainer(config, init_model, restart) + trainer.run() + + +def main(args: list[str] | argparse.Namespace | None = None) -> None: + """Entry point for the pt_expt backend CLI. + + Parameters + ---------- + args : list[str] | argparse.Namespace | None + Command-line arguments or pre-parsed namespace. + """ + from deepmd.loggers.loggers import ( + set_log_handles, + ) + from deepmd.main import ( + parse_args, + ) + + if not isinstance(args, argparse.Namespace): + FLAGS = parse_args(args=args) + else: + FLAGS = args + + set_log_handles( + FLAGS.log_level, + Path(FLAGS.log_path) if FLAGS.log_path else None, + mpi_log=None, + ) + log.info("DeePMD-kit backend: pt_expt (PyTorch Exportable)") + + if FLAGS.command == "train": + train( + input_file=FLAGS.INPUT, + init_model=FLAGS.init_model, + restart=FLAGS.restart, + skip_neighbor_stat=FLAGS.skip_neighbor_stat, + output=FLAGS.output, + ) + else: + raise RuntimeError( + f"Unsupported command '{FLAGS.command}' for the pt_expt backend." + ) diff --git a/deepmd/pt_expt/loss/__init__.py b/deepmd/pt_expt/loss/__init__.py new file mode 100644 index 0000000000..19f76a0cba --- /dev/null +++ b/deepmd/pt_expt/loss/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.pt_expt.loss.ener import ( + EnergyLoss, +) + +__all__ = [ + "EnergyLoss", +] diff --git a/deepmd/pt_expt/loss/ener.py b/deepmd/pt_expt/loss/ener.py new file mode 100644 index 0000000000..e5bd220bd0 --- /dev/null +++ b/deepmd/pt_expt/loss/ener.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.loss.ener import EnergyLoss as EnergyLossDP +from deepmd.pt_expt.common import ( + torch_module, +) + + +@torch_module +class EnergyLoss(EnergyLossDP): + pass diff --git a/deepmd/pt_expt/model/__init__.py b/deepmd/pt_expt/model/__init__.py index da120091e0..7b3f7cdeab 100644 --- a/deepmd/pt_expt/model/__init__.py +++ b/deepmd/pt_expt/model/__init__.py @@ -11,6 +11,9 @@ from .ener_model import ( EnergyModel, ) +from .get_model import ( + get_model, +) from .model import ( BaseModel, ) @@ -29,4 +32,5 @@ "EnergyModel", "PolarModel", "PropertyModel", + "get_model", ] diff --git a/deepmd/pt_expt/model/get_model.py b/deepmd/pt_expt/model/get_model.py new file mode 100644 index 0000000000..9d4296a50b --- /dev/null +++ b/deepmd/pt_expt/model/get_model.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Model factory for the pt_expt backend. + +Mirrors ``deepmd.dpmodel.model.model`` but uses the pt_expt +``BaseDescriptor`` / ``BaseFitting`` registries so that the +constructed objects are ``torch.nn.Module`` subclasses. +""" + +import copy +from typing import ( + Any, +) + +from deepmd.pt_expt.descriptor import ( + BaseDescriptor, +) +from deepmd.pt_expt.fitting import ( + BaseFitting, +) + +# Import from submodules directly to avoid circular import via __init__.py +from deepmd.pt_expt.model.dipole_model import ( + DipoleModel, +) +from deepmd.pt_expt.model.dos_model import ( + DOSModel, +) +from deepmd.pt_expt.model.ener_model import ( + EnergyModel, +) +from deepmd.pt_expt.model.model import ( + BaseModel, +) +from deepmd.pt_expt.model.polar_model import ( + PolarModel, +) +from deepmd.pt_expt.model.property_model import ( + PropertyModel, +) + + +def _get_standard_model_components( + data: dict[str, Any], + ntypes: int, +) -> tuple: + """Build descriptor and fitting from config dict.""" + # descriptor + data["descriptor"]["ntypes"] = ntypes + data["descriptor"]["type_map"] = copy.deepcopy(data["type_map"]) + descriptor = BaseDescriptor(**data["descriptor"]) + + # fitting + fitting_net = data.get("fitting_net", {}) + fitting_net["type"] = fitting_net.get("type", "ener") + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["type_map"] = copy.deepcopy(data["type_map"]) + fitting_net["mixed_types"] = descriptor.mixed_types() + if fitting_net["type"] in ["dipole", "polar"]: + fitting_net["embedding_width"] = descriptor.get_dim_emb() + fitting_net["dim_descrpt"] = descriptor.get_dim_out() + grad_force = "direct" not in fitting_net["type"] + if not grad_force: + fitting_net["out_dim"] = descriptor.get_dim_emb() + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + fitting = BaseFitting(**fitting_net) + return descriptor, fitting, fitting_net["type"] + + +def get_standard_model(data: dict) -> EnergyModel: + """Get a standard model from a config dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. + """ + data = copy.deepcopy(data) + ntypes = len(data["type_map"]) + descriptor, fitting, fitting_net_type = _get_standard_model_components(data, ntypes) + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + + if fitting_net_type == "dipole": + modelcls = DipoleModel + elif fitting_net_type == "polar": + modelcls = PolarModel + elif fitting_net_type == "dos": + modelcls = DOSModel + elif fitting_net_type in ["ener", "direct_force_ener"]: + modelcls = EnergyModel + elif fitting_net_type == "property": + modelcls = PropertyModel + else: + raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") + + model = modelcls( + descriptor=descriptor, + fitting=fitting, + type_map=data["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) + return model + + +def get_model(data: dict) -> BaseModel: + """Get a model from a config dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. + """ + model_type = data.get("type", "standard") + if model_type == "standard": + return get_standard_model(data) + else: + return BaseModel.get_class_by_type(model_type).get_model(data) diff --git a/deepmd/pt_expt/train/__init__.py b/deepmd/pt_expt/train/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/train/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py new file mode 100644 index 0000000000..f8730ed271 --- /dev/null +++ b/deepmd/pt_expt/train/training.py @@ -0,0 +1,886 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Training loop for the pt_expt backend. + +Uses ``DeepmdDataSystem`` (numpy-based batch provider) instead of the +pt backend's ``DpLoaderSet`` + ``DataLoader``. NumPy batches are +converted to torch tensors at the boundary. +""" + +import functools +import logging +import time +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.dpmodel.utils.batch import ( + normalize_batch, + split_batch, +) +from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, +) +from deepmd.loggers.training import ( + format_training_message_per_task, +) +from deepmd.pt_expt.loss import ( + EnergyLoss, +) +from deepmd.pt_expt.model import ( + get_model, +) +from deepmd.pt_expt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt_expt.utils.env import ( + DEVICE, + GLOBAL_PT_FLOAT_PRECISION, +) +from deepmd.pt_expt.utils.stat import ( + make_stat_input, +) +from deepmd.utils.data import ( + DataRequirementItem, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.path import ( + DPPath, +) + +log = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helper: loss factory (reused from pt) +# --------------------------------------------------------------------------- + + +def get_loss( + loss_params: dict[str, Any], + start_lr: float, + _ntypes: int, + _model: Any, +) -> EnergyLoss: + loss_type = loss_params.get("type", "ener") + if loss_type == "ener": + loss_params["starter_learning_rate"] = start_lr + return EnergyLoss(**loss_params) + else: + raise ValueError(f"Unsupported loss type for pt_expt: {loss_type}") + + +def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]: + additional_data_requirement: list[DataRequirementItem] = [] + if _model.get_dim_fparam() > 0: + additional_data_requirement.append( + DataRequirementItem( + "fparam", + _model.get_dim_fparam(), + atomic=False, + must=False, + default=0.0, + ) + ) + if _model.get_dim_aparam() > 0: + additional_data_requirement.append( + DataRequirementItem( + "aparam", _model.get_dim_aparam(), atomic=True, must=True + ) + ) + return additional_data_requirement + + +# --------------------------------------------------------------------------- +# torch.compile helpers +# --------------------------------------------------------------------------- + + +def _trace_and_compile( + model: torch.nn.Module, + ext_coord: torch.Tensor, + ext_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + compile_opts: dict[str, Any], +) -> torch.nn.Module: + """Trace ``forward_lower`` with ``make_fx`` and compile with ``torch.compile``. + + Parameters + ---------- + model : torch.nn.Module + The (uncompiled) model. Temporarily set to eval mode for tracing. + ext_coord, ext_atype, nlist, mapping, fparam, aparam + Sample tensors (already padded to the desired max_nall). + compile_opts : dict + Options forwarded to ``torch.compile`` (excluding ``dynamic``). + + Returns + ------- + torch.nn.Module + The compiled ``forward_lower`` callable. + """ + from torch.fx.experimental.proxy_tensor import ( + make_fx, + ) + + was_training = model.training + # Trace in train mode so that create_graph=True is captured inside + # task_deriv_one. Without this, the autograd.grad that computes + # forces is traced with create_graph=False (eval mode), producing + # force tensors that are detached from model parameters — force loss + # backprop cannot reach the weights and force RMSE never decreases. + model.train() + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + extended_coord = extended_coord.detach().requires_grad_(True) + return model.forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + ) + + # Use default tracing_mode="real" (concrete shapes) for best + # runtime performance. If data-dependent intermediate shapes + # change at runtime, the caller catches the error and retraces. + traced_lower = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam) + + if not was_training: + model.eval() + + # The inductor backend does not propagate gradients through the + # make_fx-decomposed autograd.grad ops (second-order gradients for + # force training). Use "aot_eager" which correctly preserves the + # gradient chain while still benefiting from make_fx decomposition. + if "backend" not in compile_opts: + compile_opts["backend"] = "aot_eager" + compiled_lower = torch.compile(traced_lower, dynamic=False, **compile_opts) + return compiled_lower + + +class _CompiledModel(torch.nn.Module): + """Coord extension (eager) -> pad nall -> compiled forward_lower. + + If a batch's ``nall`` exceeds the current ``max_nall``, the model is + automatically re-traced and recompiled with a larger pad size. + """ + + def __init__( + self, + original_model: torch.nn.Module, + compiled_forward_lower: torch.nn.Module, + max_nall: int, + compile_opts: dict[str, Any], + ) -> None: + super().__init__() + self.original_model = original_model + self.compiled_forward_lower = compiled_forward_lower + self._max_nall = max_nall + self._compile_opts = compile_opts + + def _recompile( + self, + ext_coord: torch.Tensor, + ext_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + new_max_nall: int, + ) -> None: + """Re-trace and recompile for the given inputs. + + If *new_max_nall* differs from the current ``_max_nall``, the + inputs are padded (or already padded by the caller). + """ + # Pad if the caller provides unpadded tensors (nall growth case) + actual_nall = ext_coord.shape[1] + pad_n = new_max_nall - actual_nall + if pad_n > 0: + ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n)) + ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n)) + mapping = torch.nn.functional.pad(mapping, (0, pad_n)) + + ext_coord = ext_coord.detach() + + self.compiled_forward_lower = _trace_and_compile( + self.original_model, + ext_coord, + ext_atype, + nlist, + mapping, + fparam, + aparam, + self._compile_opts, + ) + self._max_nall = new_max_nall + log.info( + "Recompiled model with max_nall=%d.", + new_max_nall, + ) + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, + ) + from deepmd.dpmodel.utils.region import ( + normalize_coord, + ) + + nframes, nloc = atype.shape[:2] + rcut = self.original_model.get_rcut() + sel = self.original_model.get_sel() + + # coord extension + nlist (data-dependent, run in eager) + coord_3d = coord.detach().reshape(nframes, nloc, 3) + box_flat = box.detach().reshape(nframes, 9) if box is not None else None + + if box_flat is not None: + coord_norm = normalize_coord(coord_3d, box_flat.reshape(nframes, 3, 3)) + else: + coord_norm = coord_3d + + ext_coord, ext_atype, mapping = extend_coord_with_ghosts( + coord_norm, atype, box_flat, rcut + ) + nlist = build_neighbor_list( + ext_coord, + ext_atype, + nloc, + rcut, + sel, + distinguish_types=False, + ) + ext_coord = ext_coord.reshape(nframes, -1, 3) + + # Grow max_nall if needed (retrace + recompile) + actual_nall = ext_coord.shape[1] + if actual_nall > self._max_nall: + new_max_nall = ((int(actual_nall * 1.2) + 7) // 8) * 8 + log.info( + "nall=%d exceeds max_nall=%d; recompiling with max_nall=%d.", + actual_nall, + self._max_nall, + new_max_nall, + ) + self._recompile( + ext_coord, ext_atype, nlist, mapping, fparam, aparam, new_max_nall + ) + + # Pad to max_nall so compiled graph sees a fixed shape + pad_n = self._max_nall - actual_nall + if pad_n > 0: + ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n)) + ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n)) + mapping = torch.nn.functional.pad(mapping, (0, pad_n)) + ext_coord = ext_coord.detach().requires_grad_(True) + + result = self.compiled_forward_lower( + ext_coord, ext_atype, nlist, mapping, fparam, aparam + ) + + # Translate forward_lower keys -> forward keys. + # ``extended_force`` lives on all extended atoms (nf, nall, 3). + # Ghost-atom forces must be scatter-summed back to local atoms + # via ``mapping`` — the same operation ``communicate_extended_output`` + # performs in the uncompiled path. + out: dict[str, torch.Tensor] = {} + out["atom_energy"] = result["atom_energy"] + out["energy"] = result["energy"] + if "extended_force" in result: + ext_force = result["extended_force"] # (nf, nall_padded, 3) + # mapping may be padded; only use actual_nall entries + map_actual = mapping[:, :actual_nall] # (nf, actual_nall) + ext_force_actual = ext_force[:, :actual_nall, :] # (nf, actual_nall, 3) + # scatter-sum extended forces onto local atoms + idx = map_actual.unsqueeze(-1).expand_as( + ext_force_actual + ) # (nf, actual_nall, 3) + force = torch.zeros( + nframes, nloc, 3, dtype=ext_force.dtype, device=ext_force.device + ) + force.scatter_add_(1, idx, ext_force_actual) + out["force"] = force + if "virial" in result: + out["virial"] = result["virial"] + if "extended_virial" in result: + out["extended_virial"] = result["extended_virial"] + if "atom_virial" in result: + out["atom_virial"] = result["atom_virial"] + if "mask" in result: + out["mask"] = result["mask"] + return out + + +# --------------------------------------------------------------------------- +# Trainer +# --------------------------------------------------------------------------- + + +class Trainer: + """Training driver for the pt_expt backend. + + Uses ``DeepmdDataSystem`` for data loading (numpy batches converted + to torch tensors at the boundary). Single-task, single-GPU only. + + Parameters + ---------- + config : dict + Full training configuration. + training_data : DeepmdDataSystem + Training data. + stat_file_path : DPPath or None + Path for saving / loading statistics. + validation_data : DeepmdDataSystem or None + Validation data. + init_model : str or None + Path to a checkpoint to initialise weights from. + restart_model : str or None + Path to a checkpoint to *restart* training from (restores step + optimiser). + """ + + def __init__( + self, + config: dict[str, Any], + training_data: DeepmdDataSystem, + stat_file_path: DPPath | None = None, + validation_data: DeepmdDataSystem | None = None, + init_model: str | None = None, + restart_model: str | None = None, + ) -> None: + resume_model = init_model or restart_model + resuming = resume_model is not None + self.restart_training = restart_model is not None + + model_params = config["model"] + training_params = config["training"] + + # Iteration config + self.num_steps = training_params["numb_steps"] + self.disp_file = training_params.get("disp_file", "lcurve.out") + self.disp_freq = training_params.get("disp_freq", 1000) + self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") + self.save_freq = training_params.get("save_freq", 1000) + self.display_in_training = training_params.get("disp_training", True) + self.timing_in_training = training_params.get("time_training", True) + self.lcurve_should_print_header = True + + # Model --------------------------------------------------------------- + self.model = get_model(deepcopy(model_params)).to(DEVICE) + + # Loss ---------------------------------------------------------------- + self.loss = get_loss( + config.get("loss", {}), + config["learning_rate"]["start_lr"], + len(model_params["type_map"]), + self.model, + ) + + # Data requirements --------------------------------------------------- + data_requirement = self.loss.label_requirement + data_requirement += get_additional_data_requirement(self.model) + training_data.add_data_requirements(data_requirement) + if validation_data is not None: + validation_data.add_data_requirements(data_requirement) + + self.training_data = training_data + self.validation_data = validation_data + self.valid_numb_batch = training_params.get("validation_data", {}).get( + "numb_btch", 1 + ) + + # Statistics ---------------------------------------------------------- + data_stat_nbatch = model_params.get("data_stat_nbatch", 10) + + @functools.lru_cache + def get_sample() -> list[dict[str, np.ndarray]]: + return make_stat_input(training_data, data_stat_nbatch) + + if not resuming: + self.model.compute_or_load_stat( + sampled_func=get_sample, + stat_file_path=stat_file_path, + ) + + # Learning rate ------------------------------------------------------- + lr_params = config["learning_rate"].copy() + lr_params["num_steps"] = self.num_steps + self.lr_schedule = LearningRateExp(**lr_params) + + # Gradient clipping + self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) + + # Model wrapper ------------------------------------------------------- + self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) + self.start_step = 0 + + # Optimiser ----------------------------------------------------------- + opt_type = training_params.get("opt_type", "Adam") + initial_lr = float(self.lr_schedule.value(self.start_step)) + + if opt_type == "Adam": + self.optimizer = torch.optim.Adam(self.wrapper.parameters(), lr=initial_lr) + elif opt_type == "AdamW": + weight_decay = training_params.get("weight_decay", 0.001) + self.optimizer = torch.optim.AdamW( + self.wrapper.parameters(), + lr=initial_lr, + weight_decay=weight_decay, + ) + else: + raise ValueError(f"Unsupported optimizer type: {opt_type}") + + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda step: self.lr_schedule.value(step) / initial_lr, + last_epoch=self.start_step - 1, + ) + + # Resume -------------------------------------------------------------- + if resuming: + log.info(f"Resuming from {resume_model}.") + state_dict = torch.load( + resume_model, map_location=DEVICE, weights_only=True + ) + if "model" in state_dict: + optimizer_state_dict = ( + state_dict["optimizer"] if self.restart_training else None + ) + state_dict = state_dict["model"] + else: + optimizer_state_dict = None + + self.start_step = ( + state_dict["_extra_state"]["train_infos"]["step"] + if self.restart_training + else 0 + ) + self.wrapper.load_state_dict(state_dict) + if optimizer_state_dict is not None: + self.optimizer.load_state_dict(optimizer_state_dict) + # rebuild scheduler from the resumed step. + # last_epoch handles the step offset; the lambda must NOT + # add self.start_step again (that would double-count). + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda step: self.lr_schedule.value(step) / initial_lr, + last_epoch=self.start_step - 1, + ) + + # torch.compile ------------------------------------------------------- + # The model's forward uses torch.autograd.grad (for forces) with + # create_graph=True so the loss backward can differentiate through + # forces. torch.compile does not support this "double backward". + # + # Solution: use make_fx to trace the model forward, which decomposes + # torch.autograd.grad into primitive ops. The resulting traced + # module is then compiled by torch.compile — no double backward. + self.enable_compile = training_params.get("enable_compile", False) + if self.enable_compile: + compile_opts = training_params.get("compile_options", {}) + log.info("Compiling model with torch.compile (%s)", compile_opts) + self._compile_model(compile_opts) + + # ------------------------------------------------------------------ + # torch.compile helpers + # ------------------------------------------------------------------ + + def _compile_model(self, compile_opts: dict[str, Any]) -> None: + """Replace ``self.model`` with a compiled version. + + The model's ``forward`` uses ``torch.autograd.grad`` (for force + computation) with ``create_graph=True``, which creates a "double + backward" that ``torch.compile`` cannot handle. + + Solution: use ``make_fx`` to trace ``forward_lower``, decomposing + ``torch.autograd.grad`` into primitive ops. The coord extension + + nlist build (data-dependent control flow) are kept outside the + compiled region. + + To avoid the overhead of symbolic tracing and dynamic shapes, the + extended-atom dimension (nall) is padded to a fixed maximum + estimated from the training data. This allows concrete-shape + tracing and ``dynamic=False``. If a batch exceeds the current + max_nall at runtime, the model is automatically re-traced and + recompiled with a larger pad size. + """ + from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, + ) + from deepmd.dpmodel.utils.region import ( + normalize_coord, + ) + + model = self.model + + # --- Estimate max_nall by sampling multiple batches --- + n_sample = 20 + max_nall = 0 + best_sample: ( + tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None + ) = None + + for _ii in range(n_sample): + inp, _ = self.get_data(is_train=True) + coord = inp["coord"].detach() + atype = inp["atype"].detach() + box = inp.get("box") + if box is not None: + box = box.detach() + + nframes, nloc = atype.shape[:2] + coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3) + atype_np = atype.cpu().numpy() + box_np = box.cpu().numpy().reshape(nframes, 9) if box is not None else None + + if box_np is not None: + coord_norm = normalize_coord(coord_np, box_np.reshape(nframes, 3, 3)) + else: + coord_norm = coord_np + + ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts( + coord_norm, atype_np, box_np, model.get_rcut() + ) + nlist_np = build_neighbor_list( + ext_coord_np, + ext_atype_np, + nloc, + model.get_rcut(), + model.get_sel(), + distinguish_types=False, + ) + ext_coord_np = ext_coord_np.reshape(nframes, -1, 3) + nall = ext_coord_np.shape[1] + if nall > max_nall: + max_nall = nall + best_sample = ( + ext_coord_np, + ext_atype_np, + mapping_np, + nlist_np, + nloc, + inp, + ) + + # Add 20 % margin and round up to a multiple of 8. + max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8 + log.info( + "Estimated max_nall=%d for compiled model (sampled %d batches).", + max_nall, + n_sample, + ) + + # --- Pad the largest sample to max_nall and trace --- + assert best_sample is not None + ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = ( + best_sample + ) + nframes = ext_coord_np.shape[0] + actual_nall = ext_coord_np.shape[1] + pad_n = max_nall - actual_nall + + if pad_n > 0: + ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0))) + ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n))) + mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n))) + + ext_coord = torch.tensor( + ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE) + nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE) + mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE) + fparam = sample_input.get("fparam") + aparam = sample_input.get("aparam") + + compile_opts.pop("dynamic", None) # always False for padded approach + + compiled_lower = _trace_and_compile( + model, + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam, + aparam, + compile_opts, + ) + + self.wrapper.model = _CompiledModel( + model, compiled_lower, max_nall, compile_opts + ) + log.info( + "Model compiled with padded nall=%d (tracing_mode=real, dynamic=False).", + max_nall, + ) + + # ------------------------------------------------------------------ + # Data helpers + # ------------------------------------------------------------------ + + def get_data( + self, + is_train: bool = True, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Fetch a batch and split into input / label dicts. + + Returns + ------- + input_dict, label_dict + """ + data_sys = self.training_data if is_train else self.validation_data + if data_sys is None: + return {}, {} + + batch = normalize_batch(data_sys.get_batch()) + input_dict, label_dict = split_batch(batch) + + # Convert numpy values to torch tensors. + for dd in (input_dict, label_dict): + for key, val in dd.items(): + if val is None: + continue + if isinstance(val, np.ndarray): + if np.issubdtype(val.dtype, np.integer): + dd[key] = torch.from_numpy(val).to(DEVICE) + else: + dd[key] = torch.from_numpy(val).to( + dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + elif isinstance(val, (float, np.bool_)): + dd[key] = torch.tensor( + float(val), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + # requires_grad on coord for force computation via autograd + if "coord" in input_dict and input_dict["coord"] is not None: + input_dict["coord"] = input_dict["coord"].requires_grad_(True) + + return input_dict, label_dict + + # ------------------------------------------------------------------ + # Checkpointing + # ------------------------------------------------------------------ + + def save_checkpoint(self, step: int) -> None: + self.wrapper.train_infos["step"] = step + state = { + "model": self.wrapper.state_dict(), + "optimizer": self.optimizer.state_dict(), + } + ckpt_path = f"{self.save_ckpt}-{step}.pt" + torch.save(state, ckpt_path) + # symlink latest + latest = Path(f"{self.save_ckpt}.pt") + if latest.is_symlink() or latest.exists(): + latest.unlink() + latest.symlink_to(ckpt_path) + log.info(f"Saved checkpoint to {ckpt_path}") + + # ------------------------------------------------------------------ + # Training loop + # ------------------------------------------------------------------ + + @torch.compiler.disable + def _optimizer_step(self) -> None: + """Run optimizer and scheduler step outside torch._dynamo. + + Dynamo intercepts tensor creation inside Adam._init_group, + which can trigger CUDA init on CPU-only builds. + """ + self.optimizer.step() + self.scheduler.step() + + def run(self) -> None: + fout = open( + self.disp_file, + mode="w" if not self.restart_training else "a", + buffering=1, + ) + log.info("Start to train %d steps.", self.num_steps) + + self.wrapper.train() + wall_start = time.time() + + for step_id in range(self.start_step, self.num_steps): + cur_lr = float(self.lr_schedule.value(step_id)) + + if self.timing_in_training: + t_start = time.time() + + # --- forward / backward --- + self.optimizer.zero_grad(set_to_none=True) + input_dict, label_dict = self.get_data(is_train=True) + + cur_lr_sched = self.scheduler.get_last_lr()[0] + model_pred, loss, more_loss = self.wrapper( + **input_dict, cur_lr=cur_lr_sched, label=label_dict + ) + loss.backward() + + if self.gradient_max_norm > 0.0: + torch.nn.utils.clip_grad_norm_( + self.wrapper.parameters(), self.gradient_max_norm + ) + + self._optimizer_step() + + if self.timing_in_training: + t_end = time.time() + + # --- display --- + display_step_id = step_id + 1 + if self.display_in_training and ( + display_step_id % self.disp_freq == 0 or display_step_id == 1 + ): + self.wrapper.eval() + + train_results = {k: v for k, v in more_loss.items() if "l2_" not in k} + + # validation + valid_results: dict[str, Any] = {} + if self.validation_data is not None: + sum_natoms = 0 + for _ii in range(self.valid_numb_batch): + val_input, val_label = self.get_data(is_train=False) + if not val_input: + break + _, _vloss, _vmore = self.wrapper( + **val_input, cur_lr=cur_lr_sched, label=val_label + ) + natoms = int(val_input["atype"].shape[-1]) + sum_natoms += natoms + for k, v in _vmore.items(): + if "l2_" not in k: + valid_results[k] = ( + valid_results.get(k, 0.0) + v * natoms + ) + if sum_natoms > 0: + valid_results = { + k: v / sum_natoms for k, v in valid_results.items() + } + + # wall-clock time + wall_elapsed = time.time() - wall_start + if self.timing_in_training: + step_time = t_end - t_start + log.info( + "step=%d wall=%.2fs step_time=%.4fs", + display_step_id, + wall_elapsed, + step_time, + ) + else: + log.info("step=%d wall=%.2fs", display_step_id, wall_elapsed) + + # log + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name="trn", + rmse=train_results, + learning_rate=cur_lr, + ) + ) + if valid_results: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name="val", + rmse=valid_results, + learning_rate=None, + ) + ) + + # lcurve file + if self.lcurve_should_print_header: + self.print_header(fout, train_results, valid_results) + self.lcurve_should_print_header = False + self.print_on_training( + fout, display_step_id, cur_lr, train_results, valid_results + ) + + self.wrapper.train() + + # --- checkpoint --- + if display_step_id % self.save_freq == 0: + self.save_checkpoint(display_step_id) + + # final save + self.save_checkpoint(self.num_steps) + wall_total = time.time() - wall_start + fout.close() + log.info("Training finished. Total wall time: %.2fs", wall_total) + + # ------------------------------------------------------------------ + # Printing helpers + # ------------------------------------------------------------------ + + def print_header( + self, + fout: Any, + train_results: dict[str, Any], + valid_results: dict[str, Any], + ) -> None: + train_keys = sorted(train_results.keys()) + header = "# {:5s}".format("step") + if valid_results: + for k in train_keys: + header += f" {k + '_val':>11s} {k + '_trn':>11s}" + else: + for k in train_keys: + header += f" {k + '_trn':>11s}" + header += " {:8s}\n".format("lr") + fout.write(header) + fout.flush() + + def print_on_training( + self, + fout: Any, + step_id: int, + cur_lr: float, + train_results: dict, + valid_results: dict, + ) -> None: + train_keys = sorted(train_results.keys()) + line = f"{step_id:7d}" + if valid_results: + for k in train_keys: + line += f" {valid_results.get(k, float('nan')):11.2e} {train_results[k]:11.2e}" + else: + for k in train_keys: + line += f" {train_results[k]:11.2e}" + line += f" {cur_lr:8.1e}\n" + fout.write(line) + fout.flush() diff --git a/deepmd/pt_expt/train/wrapper.py b/deepmd/pt_expt/train/wrapper.py new file mode 100644 index 0000000000..281168cdba --- /dev/null +++ b/deepmd/pt_expt/train/wrapper.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Any, +) + +import torch + +log = logging.getLogger(__name__) + + +class ModelWrapper(torch.nn.Module): + """Simplified model wrapper that bundles a model and a loss. + + Single-task only for now (no multi-task support). + + Parameters + ---------- + model : torch.nn.Module + The model to train. + loss : torch.nn.Module + The loss module. + model_params : dict, optional + Model parameters to store as extra state. + """ + + def __init__( + self, + model: torch.nn.Module, + loss: torch.nn.Module | None = None, + model_params: dict[str, Any] | None = None, + ) -> None: + super().__init__() + self.model_params = model_params if model_params is not None else {} + self.train_infos: dict[str, Any] = { + "lr": 0, + "step": 0, + } + self.model = model + self.loss = loss + self.inference_only = self.loss is None + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + cur_lr: float | torch.Tensor | None = None, + label: dict[str, torch.Tensor] | None = None, + do_atomic_virial: bool = False, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor | None, dict | None]: + input_dict = { + "coord": coord, + "atype": atype, + "box": box, + "do_atomic_virial": do_atomic_virial, + "fparam": fparam, + "aparam": aparam, + } + + model_pred = self.model(**input_dict) + + if self.inference_only or label is None: + return model_pred, None, None + else: + natoms = atype.shape[-1] + loss, more_loss = self.loss( + cur_lr, + natoms, + model_pred, + label, + ) + return model_pred, loss, more_loss + + def set_extra_state(self, state: dict) -> None: + self.model_params = state.get("model_params", {}) + self.train_infos = state.get("train_infos", {"lr": 0, "step": 0}) + + def get_extra_state(self) -> dict: + return { + "model_params": self.model_params, + "train_infos": self.train_infos, + } diff --git a/deepmd/pt_expt/utils/neighbor_stat.py b/deepmd/pt_expt/utils/neighbor_stat.py new file mode 100644 index 0000000000..cf9d9f3c18 --- /dev/null +++ b/deepmd/pt_expt/utils/neighbor_stat.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Iterator, +) + +import numpy as np +import torch + +from deepmd.dpmodel.utils.neighbor_stat import NeighborStatOP as NeighborStatOPDP +from deepmd.pt_expt.common import ( + torch_module, +) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat + + +@torch_module +class NeighborStatOP(NeighborStatOPDP): + pass + + +class NeighborStat(BaseNeighborStat): + """Neighbor statistics using torch on DEVICE. + + Parameters + ---------- + ntypes : int + The num of atom types + rcut : float + The cut-off radius + mixed_type : bool, optional, default=False + Treat all types as a single type. + """ + + def __init__( + self, + ntypes: int, + rcut: float, + mixed_type: bool = False, + ) -> None: + super().__init__(ntypes, rcut, mixed_type) + self.op = NeighborStatOP(ntypes, rcut, mixed_type) + + def iterator( + self, data: DeepmdDataSystem + ) -> Iterator[tuple[np.ndarray, float, str]]: + """Produce neighbor statistics for each data set. + + Yields + ------ + np.ndarray + The maximal number of neighbors + float + The squared minimal distance between two atoms + str + The directory of the data system + """ + for ii in range(len(data.system_dirs)): + for jj in data.data_systems[ii].dirs: + data_set = data.data_systems[ii] + data_set_data = data_set._load_set(jj) + minrr2, max_nnei = self._execute( + data_set_data["coord"], + data_set_data["type"], + data_set_data["box"] if data_set.pbc else None, + ) + yield np.max(max_nnei, axis=0), np.min(minrr2), jj + + def _execute( + self, + coord: np.ndarray, + atype: np.ndarray, + cell: np.ndarray | None, + ) -> tuple[np.ndarray, np.ndarray]: + """Execute the operation on DEVICE.""" + minrr2, max_nnei = self.op( + torch.from_numpy(coord).to(DEVICE), + torch.from_numpy(atype).to(DEVICE), + torch.from_numpy(cell).to(DEVICE) if cell is not None else None, + ) + minrr2 = minrr2.detach().cpu().numpy() + max_nnei = max_nnei.detach().cpu().numpy() + return minrr2, max_nnei diff --git a/deepmd/pt_expt/utils/stat.py b/deepmd/pt_expt/utils/stat.py new file mode 100644 index 0000000000..0a22ba4404 --- /dev/null +++ b/deepmd/pt_expt/utils/stat.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.utils.model_stat import ( + make_stat_input, +) + +__all__ = ["make_stat_input"] diff --git a/deepmd/pt_expt/utils/update_sel.py b/deepmd/pt_expt/utils/update_sel.py new file mode 100644 index 0000000000..20d3ef9da1 --- /dev/null +++ b/deepmd/pt_expt/utils/update_sel.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.pt_expt.utils.neighbor_stat import ( + NeighborStat, +) +from deepmd.utils.update_sel import ( + BaseUpdateSel, +) + + +class UpdateSel(BaseUpdateSel): + @property + def neighbor_stat(self) -> type[NeighborStat]: + return NeighborStat diff --git a/deepmd/tf/model/model_stat.py b/deepmd/tf/model/model_stat.py index 96c8b4a4af..1b0e8ab3a1 100644 --- a/deepmd/tf/model/model_stat.py +++ b/deepmd/tf/model/model_stat.py @@ -3,7 +3,9 @@ from deepmd.utils.model_stat import ( _make_all_stat_ref, - make_stat_input, +) +from deepmd.utils.model_stat import collect_batches as make_stat_input +from deepmd.utils.model_stat import ( merge_sys_stat, ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index eaa0892369..f227b65175 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -41,6 +41,7 @@ doc_only_tf_supported = "(Supported Backend: TensorFlow) " doc_only_pt_supported = "(Supported Backend: PyTorch) " +doc_only_pt_expt_supported = "(Supported Backend: PyTorch Exportable) " doc_only_pd_supported = "(Supported Backend: Paddle) " # descriptors doc_loc_frame = "Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame." @@ -3617,6 +3618,17 @@ def training_args( default=0, doc=doc_only_pt_supported + doc_zero_stage, ), + Argument( + "enable_compile", + bool, + optional=True, + default=False, + doc=doc_only_pt_expt_supported + + "Enable torch.compile to accelerate training. " + "Uses make_fx to decompose autograd into primitive ops, " + "then compiles with torch.compile/Inductor for kernel fusion. " + "The first training step will be slower due to one-time compilation.", + ), ] variants = [ Variant( diff --git a/deepmd/utils/model_stat.py b/deepmd/utils/model_stat.py index 8061c7aa9c..33ebbcae57 100644 --- a/deepmd/utils/model_stat.py +++ b/deepmd/utils/model_stat.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging from collections import ( defaultdict, ) @@ -8,6 +9,12 @@ import numpy as np +from deepmd.dpmodel.utils.batch import ( + normalize_batch, +) + +log = logging.getLogger(__name__) + def _make_all_stat_ref(data: Any, nbatches: int) -> dict[str, list[Any]]: all_stat = defaultdict(list) @@ -21,17 +28,20 @@ def _make_all_stat_ref(data: Any, nbatches: int) -> dict[str, list[Any]]: return all_stat -def make_stat_input( +def collect_batches( data: Any, nbatches: int, merge_sys: bool = True ) -> dict[str, list[Any]]: - """Pack data for statistics. + """Collect batches from a DeepmdDataSystem into a dict of lists. + + This is a low-level helper used by the TF backend and by + :func:`make_stat_input`. Parameters ---------- data - The data + The data (must support ``get_nsystems()`` and ``get_batch(sys_idx=)``) nbatches : int - The number of batches + The number of batches per system merge_sys : bool (True) Merge system data @@ -62,6 +72,55 @@ def make_stat_input( return all_stat +def make_stat_input( + data: Any, + nbatches: int, +) -> list[dict[str, np.ndarray]]: + """Pack data for statistics using DeepmdDataSystem. + + Collects *nbatches* batches from each system and concatenates them + into a single dict per system. The returned format + (``list[dict[str, np.ndarray]]``) is backend-agnostic and can be + consumed by ``compute_or_load_stat`` in dpmodel, pt_expt, and jax. + + Parameters + ---------- + data + The multi-system data manager + (must support ``get_nsystems()`` and ``get_batch(sys_idx=)``). + nbatches : int + Number of batches to collect per system. + + Returns + ------- + list[dict[str, np.ndarray]] + Per-system dicts with concatenated numpy arrays. + """ + all_stat = collect_batches(data, nbatches, merge_sys=False) + + nsystems = data.get_nsystems() + log.info(f"Packing data for statistics from {nsystems} systems") + + keys = list(all_stat.keys()) + lst: list[dict[str, np.ndarray]] = [] + for ii in range(nsystems): + merged: dict[str, np.ndarray] = {} + for key in keys: + vals = all_stat[key][ii] # list of batch arrays for this system + if isinstance(vals[0], np.ndarray): + if vals[0].ndim >= 2: + merged[key] = np.concatenate(vals, axis=0) + else: + # 1D arrays (e.g. natoms_vec) — per-system constant + merged[key] = vals[0] + else: + # scalar flags like find_* + merged[key] = vals[0] + + lst.append(normalize_batch(merged)) + return lst + + def merge_sys_stat(all_stat: dict[str, list[Any]]) -> dict[str, list[Any]]: first_key = next(iter(all_stat.keys())) nsys = len(all_stat[first_key]) diff --git a/source/tests/common/dpmodel/test_fitting_stat.py b/source/tests/common/dpmodel/test_fitting_stat.py index 101d2a9ad7..498124bd43 100644 --- a/source/tests/common/dpmodel/test_fitting_stat.py +++ b/source/tests/common/dpmodel/test_fitting_stat.py @@ -33,7 +33,9 @@ def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0)) tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0)) sys_dict["fparam"] = tmp_data_f + sys_dict["find_fparam"] = np.float32(1.0) sys_dict["aparam"] = tmp_data_a + sys_dict["find_aparam"] = np.float32(1.0) merged_output_stat.append(sys_dict) return merged_output_stat diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 76b7e9cb53..ed4e2caab9 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -238,7 +238,9 @@ class RefBackend(Enum): ARRAY_API_STRICT = 7 @abstractmethod - def extract_ret(self, ret: Any, backend: RefBackend) -> tuple[np.ndarray, ...]: + def extract_ret( + self, ret: Any, backend: RefBackend + ) -> tuple[np.ndarray, ...] | dict[str, np.ndarray]: """Extract the return value when comparing with other backends. Parameters @@ -250,10 +252,45 @@ def extract_ret(self, ret: Any, backend: RefBackend) -> tuple[np.ndarray, ...]: Returns ------- - tuple[np.ndarray, ...] - The extracted return value + tuple[np.ndarray, ...] | dict[str, np.ndarray] + The extracted return value. If a dict is returned, keys are used + in error messages to identify which value mismatches. """ + def _compare_ret(self, ret1, ret2) -> None: + """Compare two extracted return values (tuple or dict). + + For dicts, keys must match exactly unless one dict contains only + ``"loss"`` (e.g. TF backend), in which case only ``"loss"`` is compared. + """ + if isinstance(ret1, dict) and isinstance(ret2, dict): + keys1, keys2 = sorted(ret1.keys()), sorted(ret2.keys()) + if keys1 == ["loss"] or keys2 == ["loss"]: + compare_keys = ["loss"] + else: + self.assertEqual( + keys1, + keys2, + f"Keys mismatch: {keys1} vs {keys2}", + ) + compare_keys = keys1 + for key in compare_keys: + rr1, rr2 = ret1[key], ret2[key] + if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG: + continue + np.testing.assert_allclose( + rr1, rr2, rtol=self.rtol, atol=self.atol, err_msg=f"key: {key}" + ) + assert rr1.dtype == rr2.dtype, f"key {key}: {rr1.dtype} != {rr2.dtype}" + else: + for rr1, rr2 in zip(ret1, ret2, strict=True): + if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG: + continue + np.testing.assert_allclose( + rr1.ravel(), rr2.ravel(), rtol=self.rtol, atol=self.atol + ) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + def build_eval_tf( self, sess: "tf.Session", obj: Any, suffix: str ) -> list[np.ndarray]: @@ -388,11 +425,7 @@ def test_tf_consistent_with_ref(self) -> None: data2.pop("@version") np.testing.assert_equal(data1, data2) - for rr1, rr2 in zip(ret1, ret2, strict=True): - np.testing.assert_allclose( - rr1.ravel(), rr2.ravel(), rtol=self.rtol, atol=self.atol - ) - assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + self._compare_ret(ret1, ret2) def test_tf_self_consistent(self) -> None: """Test whether TF is self consistent.""" @@ -424,11 +457,7 @@ def test_dp_consistent_with_ref(self) -> None: ret2 = self.extract_ret(ret2, self.RefBackend.DP) data2 = dp_obj.serialize() np.testing.assert_equal(data1, data2) - for rr1, rr2 in zip(ret1, ret2, strict=True): - if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG: - continue - np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) - assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + self._compare_ret(ret1, ret2) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_dp_self_consistent(self) -> None: @@ -469,9 +498,7 @@ def test_pt_consistent_with_ref(self) -> None: data1.pop("@variables", None) data2.pop("@variables", None) np.testing.assert_equal(data1, data2) - for rr1, rr2 in zip(ret1, ret2, strict=True): - np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) - assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + self._compare_ret(ret1, ret2) def test_pt_self_consistent(self) -> None: """Test whether PT is self consistent.""" @@ -510,9 +537,7 @@ def test_pt_expt_consistent_with_ref(self) -> None: data1.pop("@variables", None) data2.pop("@variables", None) np.testing.assert_equal(data1, data2) - for rr1, rr2 in zip(ret1, ret2, strict=True): - np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) - assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + self._compare_ret(ret1, ret2) def test_pt_expt_self_consistent(self) -> None: """Test whether PT exportable is self consistent.""" @@ -547,9 +572,7 @@ def test_jax_consistent_with_ref(self) -> None: data1.pop("@variables", None) data2.pop("@variables", None) np.testing.assert_equal(data1, data2) - for rr1, rr2 in zip(ret1, ret2, strict=True): - np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) - assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + self._compare_ret(ret1, ret2) def test_jax_self_consistent(self) -> None: """Test whether JAX is self consistent.""" @@ -589,9 +612,7 @@ def test_pd_consistent_with_ref(self): data1.pop("@variables", None) data2.pop("@variables", None) np.testing.assert_equal(data1, data2) - for rr1, rr2 in zip(ret1, ret2, strict=True): - np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) - assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + self._compare_ret(ret1, ret2) def test_pd_self_consistent(self): """Test whether PD is self consistent.""" @@ -624,9 +645,7 @@ def test_array_api_strict_consistent_with_ref(self) -> None: ret2 = self.extract_ret(ret2, self.RefBackend.ARRAY_API_STRICT) data2 = array_api_strict_obj.serialize() np.testing.assert_equal(data1, data2) - for rr1, rr2 in zip(ret1, ret2, strict=True): - np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) - assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + self._compare_ret(ret1, ret2) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_array_api_strict_self_consistent(self) -> None: diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index 185a3d5801..ba0f68c163 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -523,6 +523,8 @@ def setUp(self) -> None: "aparam": rng.normal(size=(2, 6, numb_aparam)).astype( GLOBAL_NP_FLOAT_PRECISION ), + "find_fparam": True, + "find_aparam": True, }, { "fparam": rng.normal(size=(3, numb_fparam)).astype( @@ -531,6 +533,8 @@ def setUp(self) -> None: "aparam": rng.normal(size=(3, 6, numb_aparam)).astype( GLOBAL_NP_FLOAT_PRECISION ), + "find_fparam": True, + "find_aparam": True, }, ] @@ -583,6 +587,8 @@ def eval_pt(self, pt_obj: Any) -> Any: { "fparam": torch.from_numpy(d["fparam"]).to(PT_DEVICE), "aparam": torch.from_numpy(d["aparam"]).to(PT_DEVICE), + "find_fparam": d["find_fparam"], + "find_aparam": d["find_aparam"], } for d in self.stat_data ] @@ -669,6 +675,8 @@ def eval_jax(self, jax_obj: Any) -> Any: { "fparam": jnp.asarray(d["fparam"]), "aparam": jnp.asarray(d["aparam"]), + "find_fparam": d["find_fparam"], + "find_aparam": d["find_aparam"], } for d in self.stat_data ] @@ -696,6 +704,8 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: { "fparam": array_api_strict.asarray(d["fparam"]), "aparam": array_api_strict.asarray(d["aparam"]), + "find_fparam": d["find_fparam"], + "find_aparam": d["find_aparam"], } for d in self.stat_data ] @@ -727,6 +737,8 @@ def eval_pd(self, pd_obj: Any) -> Any: { "fparam": paddle.to_tensor(d["fparam"]).to(PD_DEVICE), "aparam": paddle.to_tensor(d["aparam"]).to(PD_DEVICE), + "find_fparam": d["find_fparam"], + "find_aparam": d["find_aparam"], } for d in self.stat_data ] diff --git a/source/tests/consistent/loss/test_ener.py b/source/tests/consistent/loss/test_ener.py index 36a8fba44a..1cc662fc5b 100644 --- a/source/tests/consistent/loss/test_ener.py +++ b/source/tests/consistent/loss/test_ener.py @@ -111,10 +111,10 @@ def setUp(self) -> None: ), } self.predict_dpmodel_style = { - "energy_derv_c_redu": self.predict["virial"], - "energy_derv_r": self.predict["force"], - "energy_redu": self.predict["energy"], - "energy": self.predict["atom_ener"], + "energy": self.predict["energy"], + "force": self.predict["force"], + "virial": self.predict["virial"], + "atom_energy": self.predict["atom_ener"], } self.label = { "energy": rng.random((self.nframes,)), @@ -251,8 +251,17 @@ def eval_pd(self, pd_obj: Any) -> Any: more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()} return loss, more_loss - def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: - return (ret[0],) + def extract_ret(self, ret: Any, backend) -> dict[str, np.ndarray]: + loss = ret[0] + result = {"loss": np.atleast_1d(np.asarray(loss, dtype=np.float64))} + if len(ret) > 1: + more_loss = ret[1] + for k in sorted(more_loss): + if k.startswith("rmse_"): + result[k] = np.atleast_1d( + np.asarray(more_loss[k], dtype=np.float64) + ) + return result @property def rtol(self) -> float: diff --git a/source/tests/consistent/model/test_dipole.py b/source/tests/consistent/model/test_dipole.py index 3b1bc6644a..8b9c24cd57 100644 --- a/source/tests/consistent/model/test_dipole.py +++ b/source/tests/consistent/model/test_dipole.py @@ -1346,6 +1346,7 @@ def setUp(self) -> None: "dipole": dipole_stat, "find_dipole": np.float32(1.0), "aparam": aparam_stat, + "find_aparam": np.float32(1.0), } # pt sample (torch tensors) pt_sample = { @@ -1357,6 +1358,7 @@ def setUp(self) -> None: "dipole": numpy_to_torch(dipole_stat), "find_dipole": np.float32(1.0), "aparam": numpy_to_torch(aparam_stat), + "find_aparam": np.float32(1.0), } if self.fparam_in_data: @@ -1365,9 +1367,19 @@ def setUp(self) -> None: ) np_sample["fparam"] = fparam_stat pt_sample["fparam"] = numpy_to_torch(fparam_stat) + np_sample["find_fparam"] = np.float32(1.0) + pt_sample["find_fparam"] = np.float32(1.0) self.expected_fparam_avg = np.mean(fparam_stat, axis=0) else: - # No fparam → _make_wrapped_sampler injects default_fparam + # No fparam in data. dpmodel keeps zero-padded fparam with + # find_fparam=0; _make_wrapped_sampler injects default_fparam. + np_sample["fparam"] = np.zeros( + (nframes, 2), dtype=GLOBAL_NP_FLOAT_PRECISION + ) + np_sample["find_fparam"] = np.float32(0.0) + # pt pipeline pops fparam/find_fparam (stat.py), then + # wrapped_sampler injects default_fparam when keys are absent. + # pt_sample has no fparam/find_fparam keys. self.expected_fparam_avg = np.array([0.5, -0.3]) self.np_sampled = [np_sample] diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py index 76cd1daa5c..016b4ffc04 100644 --- a/source/tests/consistent/model/test_dos.py +++ b/source/tests/consistent/model/test_dos.py @@ -1336,6 +1336,7 @@ def setUp(self) -> None: "dos": dos_stat, "find_dos": np.float32(1.0), "aparam": aparam_stat, + "find_aparam": np.float32(1.0), } # pt sample (torch tensors) pt_sample = { @@ -1347,6 +1348,7 @@ def setUp(self) -> None: "dos": numpy_to_torch(dos_stat), "find_dos": np.float32(1.0), "aparam": numpy_to_torch(aparam_stat), + "find_aparam": np.float32(1.0), } if self.fparam_in_data: @@ -1355,9 +1357,19 @@ def setUp(self) -> None: ) np_sample["fparam"] = fparam_stat pt_sample["fparam"] = numpy_to_torch(fparam_stat) + np_sample["find_fparam"] = np.float32(1.0) + pt_sample["find_fparam"] = np.float32(1.0) self.expected_fparam_avg = np.mean(fparam_stat, axis=0) else: - # No fparam -> _make_wrapped_sampler injects default_fparam + # No fparam in data. dpmodel keeps zero-padded fparam with + # find_fparam=0; _make_wrapped_sampler injects default_fparam. + np_sample["fparam"] = np.zeros( + (nframes, 2), dtype=GLOBAL_NP_FLOAT_PRECISION + ) + np_sample["find_fparam"] = np.float32(0.0) + # pt pipeline pops fparam/find_fparam (stat.py), then + # wrapped_sampler injects default_fparam when keys are absent. + # pt_sample has no fparam/find_fparam keys. self.expected_fparam_avg = np.array([0.5, -0.3]) self.np_sampled = [np_sample] diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index a49e5f8ca7..038b428fd9 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -1636,6 +1636,7 @@ def setUp(self) -> None: "energy": energy_stat, "find_energy": np.float32(1.0), "aparam": aparam_stat, + "find_aparam": np.float32(1.0), } # pt sample (torch tensors) pt_sample = { @@ -1647,6 +1648,7 @@ def setUp(self) -> None: "energy": numpy_to_torch(energy_stat), "find_energy": np.float32(1.0), "aparam": numpy_to_torch(aparam_stat), + "find_aparam": np.float32(1.0), } if self.fparam_in_data: @@ -1655,9 +1657,19 @@ def setUp(self) -> None: ) np_sample["fparam"] = fparam_stat pt_sample["fparam"] = numpy_to_torch(fparam_stat) + np_sample["find_fparam"] = np.float32(1.0) + pt_sample["find_fparam"] = np.float32(1.0) self.expected_fparam_avg = np.mean(fparam_stat, axis=0) else: - # No fparam → _make_wrapped_sampler injects default_fparam + # No fparam in data. dpmodel keeps zero-padded fparam with + # find_fparam=0; _make_wrapped_sampler injects default_fparam. + np_sample["fparam"] = np.zeros( + (nframes, 2), dtype=GLOBAL_NP_FLOAT_PRECISION + ) + np_sample["find_fparam"] = np.float32(0.0) + # pt pipeline pops fparam/find_fparam (stat.py), then + # wrapped_sampler injects default_fparam when keys are absent. + # pt_sample has no fparam/find_fparam keys. self.expected_fparam_avg = np.array([0.5, -0.3]) self.np_sampled = [np_sample] diff --git a/source/tests/consistent/model/test_polar.py b/source/tests/consistent/model/test_polar.py index 17235fb362..4fe3a2c6df 100644 --- a/source/tests/consistent/model/test_polar.py +++ b/source/tests/consistent/model/test_polar.py @@ -1340,6 +1340,7 @@ def setUp(self) -> None: "polarizability": polar_stat, "find_polarizability": np.float32(1.0), "aparam": aparam_stat, + "find_aparam": np.float32(1.0), } # pt sample (torch tensors) pt_sample = { @@ -1351,6 +1352,7 @@ def setUp(self) -> None: "polarizability": numpy_to_torch(polar_stat), "find_polarizability": np.float32(1.0), "aparam": numpy_to_torch(aparam_stat), + "find_aparam": np.float32(1.0), } if self.fparam_in_data: @@ -1359,9 +1361,19 @@ def setUp(self) -> None: ) np_sample["fparam"] = fparam_stat pt_sample["fparam"] = numpy_to_torch(fparam_stat) + np_sample["find_fparam"] = np.float32(1.0) + pt_sample["find_fparam"] = np.float32(1.0) self.expected_fparam_avg = np.mean(fparam_stat, axis=0) else: - # No fparam → _make_wrapped_sampler injects default_fparam + # No fparam in data. dpmodel keeps zero-padded fparam with + # find_fparam=0; _make_wrapped_sampler injects default_fparam. + np_sample["fparam"] = np.zeros( + (nframes, 2), dtype=GLOBAL_NP_FLOAT_PRECISION + ) + np_sample["find_fparam"] = np.float32(0.0) + # pt pipeline pops fparam/find_fparam (stat.py), then + # wrapped_sampler injects default_fparam when keys are absent. + # pt_sample has no fparam/find_fparam keys. self.expected_fparam_avg = np.array([0.5, -0.3]) self.np_sampled = [np_sample] diff --git a/source/tests/consistent/model/test_property.py b/source/tests/consistent/model/test_property.py index c03cca2010..326e72bbeb 100644 --- a/source/tests/consistent/model/test_property.py +++ b/source/tests/consistent/model/test_property.py @@ -1333,6 +1333,7 @@ def setUp(self) -> None: "foo": foo_stat, "find_foo": np.float32(1.0), "aparam": aparam_stat, + "find_aparam": np.float32(1.0), } # pt sample (torch tensors) pt_sample = { @@ -1344,6 +1345,7 @@ def setUp(self) -> None: "foo": numpy_to_torch(foo_stat), "find_foo": np.float32(1.0), "aparam": numpy_to_torch(aparam_stat), + "find_aparam": np.float32(1.0), } if self.fparam_in_data: @@ -1352,9 +1354,19 @@ def setUp(self) -> None: ) np_sample["fparam"] = fparam_stat pt_sample["fparam"] = numpy_to_torch(fparam_stat) + np_sample["find_fparam"] = np.float32(1.0) + pt_sample["find_fparam"] = np.float32(1.0) self.expected_fparam_avg = np.mean(fparam_stat, axis=0) else: - # No fparam -> _make_wrapped_sampler injects default_fparam + # No fparam in data. dpmodel keeps zero-padded fparam with + # find_fparam=0; _make_wrapped_sampler injects default_fparam. + np_sample["fparam"] = np.zeros( + (nframes, 2), dtype=GLOBAL_NP_FLOAT_PRECISION + ) + np_sample["find_fparam"] = np.float32(0.0) + # pt pipeline pops fparam/find_fparam (stat.py), then + # wrapped_sampler injects default_fparam when keys are absent. + # pt_sample has no fparam/find_fparam keys. self.expected_fparam_avg = np.array([0.5, -0.3]) self.np_sampled = [np_sample] diff --git a/source/tests/consistent/test_make_stat_input.py b/source/tests/consistent/test_make_stat_input.py new file mode 100644 index 0000000000..5a927f5541 --- /dev/null +++ b/source/tests/consistent/test_make_stat_input.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Consistency test between universal make_stat_input and pt make_stat_input. + +The universal make_stat_input (deepmd.utils.model_stat) uses DeepmdDataSystem +(numpy-based). The pt make_stat_input (deepmd.pt.utils.stat) uses DpLoaderSet + +DataLoader (torch-based). This test verifies that both produce equivalent +per-system stat dicts for the keys consumed by compute_or_load_stat. +""" + +import os +import unittest + +import numpy as np + +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + +from .common import ( + INSTALLED_PT, +) + +TESTS_DIR = os.path.dirname(os.path.dirname(__file__)) +EXAMPLE_DIR = os.path.join(TESTS_DIR, "..", "..", "examples", "water") + + +def _build_config( + systems: list[str], + type_map: list[str], + *, + data_stat_nbatch: int = 2, + numb_fparam: int = 0, + numb_aparam: int = 0, +) -> dict: + config = { + "model": { + "type_map": type_map, + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": data_stat_nbatch, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": {"systems": systems, "batch_size": 1}, + "validation_data": { + "systems": systems[:1], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 99999, + "save_freq": 99999, + }, + } + if numb_fparam > 0: + config["model"]["fitting_net"]["numb_fparam"] = numb_fparam + if numb_aparam > 0: + config["model"]["fitting_net"]["numb_aparam"] = numb_aparam + config = update_deepmd_input(config, warning=False) + config = normalize(config) + return config + + +def _get_universal_stat(config: dict, data_requirement: list[DataRequirementItem]): + """Get stat using the universal make_stat_input (DeepmdDataSystem).""" + from deepmd.utils.data_system import ( + DeepmdDataSystem, + ) + from deepmd.utils.model_stat import ( + make_stat_input, + ) + + model_params = config["model"] + training_params = config["training"] + systems = training_params["training_data"]["systems"] + nbatch = model_params.get("data_stat_nbatch", 10) + + data = DeepmdDataSystem( + systems=systems, + batch_size=training_params["training_data"]["batch_size"], + test_size=1, + rcut=model_params["descriptor"]["rcut"], + type_map=model_params["type_map"], + ) + for item in data_requirement: + data.add( + item.key, + item.ndof, + atomic=item.atomic, + must=item.must, + high_prec=item.high_prec, + type_sel=item.type_sel, + repeat=item.repeat, + default=item.default, + dtype=item.dtype, + output_natoms_for_type_sel=item.output_natoms_for_type_sel, + ) + + return make_stat_input(data, nbatch) + + +def _get_pt_stat(config: dict, data_requirement: list[DataRequirementItem]): + """Get stat using the pt make_stat_input (DpLoaderSet + DataLoader).""" + from deepmd.pt.utils.dataloader import ( + DpLoaderSet, + ) + from deepmd.pt.utils.stat import ( + make_stat_input, + ) + + model_params = config["model"] + training_params = config["training"] + systems = training_params["training_data"]["systems"] + nbatch = model_params.get("data_stat_nbatch", 10) + + loader = DpLoaderSet( + systems, + training_params["training_data"]["batch_size"], + model_params["type_map"], + seed=10, + ) + for item in data_requirement: + loader.add_data_requirement([item]) + + return make_stat_input(loader.systems, loader.dataloaders, nbatch) + + +def _to_numpy(val): + """Convert torch.Tensor or np.ndarray to numpy.""" + import torch + + if isinstance(val, torch.Tensor): + return val.detach().cpu().numpy() + return val + + +def _compare_stat( + test_case: unittest.TestCase, + universal_stat: list[dict], + pt_stat: list[dict], + check_keys: list[str], +) -> None: + """Compare universal and pt stat outputs for the given keys. + + Verifies structural equivalence: same number of systems, same key + presence, matching find_* flags, consistent nframes, and consistent + per-frame sizes. + """ + test_case.assertEqual(len(universal_stat), len(pt_stat)) + for sys_idx in range(len(universal_stat)): + for key in check_keys: + in_uni = key in universal_stat[sys_idx] + in_pt = key in pt_stat[sys_idx] + # pt pops fparam/find_fparam when find_fparam==0 but + # universal keeps them. Skip when find_* is 0. + if in_uni and not in_pt: + find_key = f"find_{key}" if not key.startswith("find_") else key + find_val = universal_stat[sys_idx].get(find_key, None) + if find_val is not None and not find_val: + continue + test_case.assertEqual( + in_uni, in_pt, f"system {sys_idx}: key '{key}' presence mismatch" + ) + if not in_uni: + continue + + v_uni = _to_numpy(universal_stat[sys_idx][key]) + v_pt = _to_numpy(pt_stat[sys_idx][key]) + + if key.startswith("find_"): + # universal returns bool, pt returns float32 + test_case.assertEqual( + bool(v_uni), + bool(float(np.ravel(v_pt)[0]) > 0.5), + f"system {sys_idx}, key '{key}': find flag mismatch", + ) + continue + + v_uni = np.asarray(v_uni, dtype=np.float64) + v_pt = np.asarray(v_pt, dtype=np.float64) + + nf_uni = v_uni.shape[0] if v_uni.ndim >= 2 else 1 + nf_pt = v_pt.shape[0] if v_pt.ndim >= 2 else 1 + test_case.assertEqual( + nf_uni, + nf_pt, + f"system {sys_idx}, key '{key}': nframes mismatch", + ) + # coord shape differs: universal [nf, natoms*3], pt [nf, natoms, 3]. + # Compare per-frame size. + test_case.assertEqual( + v_uni.size // max(nf_uni, 1), + v_pt.size // max(nf_pt, 1), + f"system {sys_idx}, key '{key}': per-frame size mismatch " + f"(uni shape {v_uni.shape}, pt shape {v_pt.shape})", + ) + + +# --- Standard data requirements for energy model --- +_ENER_DATA_REQ = [ + DataRequirementItem("energy", 1, atomic=False, must=False, high_prec=True), + DataRequirementItem("force", 3, atomic=True, must=False, high_prec=False), +] + +_COMMON_CHECK_KEYS = [ + "atype", + "box", + "coord", + "energy", + "natoms", + "find_energy", + "find_force", +] + + +@unittest.skipUnless(INSTALLED_PT, "PyTorch backend not installed") +class TestMakeStatInputNormal(unittest.TestCase): + """Test with normal (non-mixed-type) water data, multiple systems.""" + + def test_consistency(self) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + self.skipTest(f"Example data not found: {data_dir}") + + systems = [ + os.path.join(data_dir, "data_0"), + os.path.join(data_dir, "data_1"), + ] + config = _build_config(systems, ["O", "H"]) + + universal_stat = _get_universal_stat(config, _ENER_DATA_REQ) + pt_stat = _get_pt_stat(config, _ENER_DATA_REQ) + + self.assertEqual(len(universal_stat), 2) + _compare_stat(self, universal_stat, pt_stat, _COMMON_CHECK_KEYS) + + +@unittest.skipUnless(INSTALLED_PT, "PyTorch backend not installed") +class TestMakeStatInputMixedType(unittest.TestCase): + """Test with mixed-type data.""" + + def test_consistency(self) -> None: + data_dir = os.path.join(TESTS_DIR, "tf", "finetune", "data_mixed_type") + if not os.path.isdir(data_dir): + self.skipTest(f"Mixed-type data not found: {data_dir}") + + config = _build_config([data_dir], ["O", "H"]) + + universal_stat = _get_universal_stat(config, _ENER_DATA_REQ) + pt_stat = _get_pt_stat(config, _ENER_DATA_REQ) + + _compare_stat( + self, + universal_stat, + pt_stat, + [*_COMMON_CHECK_KEYS, "real_natoms_vec"], + ) + + # For mixed-type data, real_natoms_vec is the per-frame version. + # Verify it is present in the universal output. + for sys_idx in range(len(universal_stat)): + self.assertIn( + "real_natoms_vec", + universal_stat[sys_idx], + f"system {sys_idx}: real_natoms_vec should be present " + f"for mixed-type data", + ) + + +@unittest.skipUnless(INSTALLED_PT, "PyTorch backend not installed") +class TestMakeStatInputFparamAparam(unittest.TestCase): + """Test with data containing fparam and aparam, multiple systems.""" + + def test_consistency(self) -> None: + data_dir = os.path.join(TESTS_DIR, "pt", "model", "water", "data") + if not os.path.isdir(data_dir): + self.skipTest(f"Water fparam data not found: {data_dir}") + + # data_0 has fparam/aparam, data_1 does not — tests find_fparam=0 case + systems = [ + os.path.join(data_dir, "data_0"), + os.path.join(data_dir, "data_1"), + ] + config = _build_config(systems, ["O", "H"], numb_fparam=2, numb_aparam=1) + + data_requirement = [ + *_ENER_DATA_REQ, + DataRequirementItem("fparam", 2, atomic=False, must=False, high_prec=False), + DataRequirementItem("aparam", 1, atomic=True, must=False, high_prec=False), + ] + universal_stat = _get_universal_stat(config, data_requirement) + pt_stat = _get_pt_stat(config, data_requirement) + + self.assertEqual(len(universal_stat), 2) + _compare_stat( + self, + universal_stat, + pt_stat, + [*_COMMON_CHECK_KEYS, "fparam", "aparam", "find_fparam", "find_aparam"], + ) + + +@unittest.skipUnless(INSTALLED_PT, "PyTorch backend not installed") +class TestMakeStatInputSpin(unittest.TestCase): + """Test with data containing spin, multiple systems.""" + + def test_consistency(self) -> None: + data_dir = os.path.join(TESTS_DIR, "pt", "NiO", "data") + if not os.path.isdir(data_dir): + self.skipTest(f"NiO spin data not found: {data_dir}") + + systems = [ + os.path.join(data_dir, "data_0"), + os.path.join(data_dir, "data_0"), + ] + config = _build_config(systems, ["Ni", "O"]) + + data_requirement = [ + *_ENER_DATA_REQ, + DataRequirementItem("spin", 3, atomic=True, must=True, high_prec=False), + ] + universal_stat = _get_universal_stat(config, data_requirement) + pt_stat = _get_pt_stat(config, data_requirement) + + self.assertEqual(len(universal_stat), 2) + _compare_stat( + self, + universal_stat, + pt_stat, + [*_COMMON_CHECK_KEYS, "spin", "find_spin"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py index ec025c2202..15791050d6 100644 --- a/source/tests/pt_expt/conftest.py +++ b/source/tests/pt_expt/conftest.py @@ -1,4 +1,52 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Conftest for pt_expt tests. + +Clears any leaked ``torch.utils._device.DeviceContext`` modes that may +have been left on the torch function mode stack by ``make_fx`` or other +tracing utilities during test collection. A stale ``DeviceContext`` +silently reroutes ``torch.tensor(...)`` calls (without an explicit +``device=``) to a fake CUDA device, causing spurious "no NVIDIA driver" +errors on CPU-only machines. + +The leak is triggered when pytest collects descriptor test modules that +import ``make_fx``. A ``DeviceContext(cuda:127)`` ends up on the +``torch.overrides`` function mode stack and is never popped. + +Our own code (``display_if_exist`` in ``deepmd/dpmodel/loss/loss.py``) +is already fixed to pass ``device=`` explicitly. However, PyTorch's +``Adam._init_group`` (``torch/optim/adam.py``) contains:: + + torch.tensor(0.0, dtype=_get_scalar_dtype()) # no device= + +on the ``capturable=False, fused=False`` path (the default). This is +a PyTorch bug — the ``capturable=True`` branch correctly uses +``device=p.device`` but the default branch omits it. We cannot fix +PyTorch internals, so this fixture works around the issue by popping +leaked ``DeviceContext`` modes before each test. +""" + import pytest +import torch.utils._device as _device +from torch.overrides import ( + _get_current_function_mode_stack, +) + -pytest.importorskip("torch") +@pytest.fixture(autouse=True) +def _clear_leaked_device_context(): + """Pop any stale ``DeviceContext`` before each test, restore after.""" + popped = [] + while True: + modes = _get_current_function_mode_stack() + if not modes: + break + top = modes[-1] + if isinstance(top, _device.DeviceContext): + top.__exit__(None, None, None) + popped.append(top) + else: + break + yield + # Restore in reverse order so the stack is back to its original state. + for ctx in reversed(popped): + ctx.__enter__() diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py index b473c9309c..dcb99dd324 100644 --- a/source/tests/pt_expt/fitting/test_fitting_stat.py +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -40,6 +40,8 @@ def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): # dpmodel's compute_input_stats expects numpy arrays sys_dict["fparam"] = tmp_data_f sys_dict["aparam"] = tmp_data_a + sys_dict["find_fparam"] = True + sys_dict["find_aparam"] = True merged_output_stat.append(sys_dict) return merged_output_stat diff --git a/source/tests/pt_expt/loss/__init__.py b/source/tests/pt_expt/loss/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/loss/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/loss/test_ener.py b/source/tests/pt_expt/loss/test_ener.py new file mode 100644 index 0000000000..37d7d4c703 --- /dev/null +++ b/source/tests/pt_expt/loss/test_ener.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for the pt_expt EnergyLoss wrapper. + +Three test types: +- test_consistency — construct -> forward -> serialize/deserialize -> forward -> compare; + also compare with dpmodel +- test_consistency_with_find_flags — same but with find_* flags as torch tensors + (mimicking real training where get_data converts them) +""" + +import numpy as np +import pytest +import torch + +from deepmd.dpmodel.loss.ener import EnergyLoss as EnergyLossDP +from deepmd.pt_expt.loss.ener import ( + EnergyLoss, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +def _make_data( + rng: np.random.Generator, + nframes: int, + natoms: int, + dtype: torch.dtype, + device: torch.device, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """Build model prediction and label dicts as torch tensors.""" + model_pred = { + "energy": torch.tensor(rng.random((nframes,)), dtype=dtype, device=device), + "force": torch.tensor( + rng.random((nframes, natoms, 3)), dtype=dtype, device=device + ), + "virial": torch.tensor(rng.random((nframes, 9)), dtype=dtype, device=device), + "atom_energy": torch.tensor( + rng.random((nframes, natoms)), dtype=dtype, device=device + ), + } + label = { + "energy": torch.tensor(rng.random((nframes,)), dtype=dtype, device=device), + "force": torch.tensor( + rng.random((nframes, natoms, 3)), dtype=dtype, device=device + ), + "virial": torch.tensor(rng.random((nframes, 9)), dtype=dtype, device=device), + "atom_ener": torch.tensor( + rng.random((nframes, natoms)), dtype=dtype, device=device + ), + "atom_pref": torch.ones((nframes, natoms, 3), dtype=dtype, device=device), + "find_energy": torch.tensor(1.0, dtype=dtype, device=device), + "find_force": torch.tensor(1.0, dtype=dtype, device=device), + "find_virial": torch.tensor(1.0, dtype=dtype, device=device), + "find_atom_ener": torch.tensor(1.0, dtype=dtype, device=device), + "find_atom_pref": torch.tensor(1.0, dtype=dtype, device=device), + } + return model_pred, label + + +class TestEnergyLoss: + def setup_method(self) -> None: + self.device = env.DEVICE + + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + @pytest.mark.parametrize("use_huber", [False, True]) # huber loss + def test_consistency(self, prec, use_huber) -> None: + """Construct -> forward -> serialize/deserialize -> forward -> compare. + + Also compare with dpmodel. + """ + rng = np.random.default_rng(GLOBAL_SEED) + nframes, natoms = 2, 6 + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + learning_rate = 1e-3 + + loss0 = EnergyLoss( + starter_learning_rate=1e-3, + start_pref_e=0.02, + limit_pref_e=1.0, + start_pref_f=1000.0, + limit_pref_f=1.0, + start_pref_v=1.0, + limit_pref_v=1.0, + start_pref_ae=1.0, + limit_pref_ae=1.0, + start_pref_pf=0.0 if use_huber else 1.0, + limit_pref_pf=0.0 if use_huber else 1.0, + use_huber=use_huber, + ).to(self.device) + + model_pred, label = _make_data(rng, nframes, natoms, dtype, self.device) + + # Forward + l0, more0 = loss0(learning_rate, natoms, model_pred, label) + assert l0.shape == () + assert "rmse" in more0 + + # Serialize / deserialize round-trip + loss1 = EnergyLoss.deserialize(loss0.serialize()) + l1, more1 = loss1(learning_rate, natoms, model_pred, label) + + np.testing.assert_allclose( + l0.detach().cpu().numpy(), + l1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + for key in more0: + np.testing.assert_allclose( + more0[key].detach().cpu().numpy(), + more1[key].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=f"key={key}", + ) + + # Compare with dpmodel (numpy) + dp_loss = EnergyLossDP.deserialize(loss0.serialize()) + model_pred_np = {k: v.detach().cpu().numpy() for k, v in model_pred.items()} + label_np = { + k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v + for k, v in label.items() + } + l_dp, more_dp = dp_loss(learning_rate, natoms, model_pred_np, label_np) + + np.testing.assert_allclose( + l0.detach().cpu().numpy(), + np.array(l_dp), + rtol=rtol, + atol=atol, + err_msg="pt_expt vs dpmodel", + ) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py new file mode 100644 index 0000000000..3b3ab247bb --- /dev/null +++ b/source/tests/pt_expt/test_training.py @@ -0,0 +1,630 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Smoke test for the pt_expt training infrastructure. + +Verifies that: +1. ``get_model`` constructs a model from config +2. ``make_stat_input`` + ``compute_or_load_stat`` work +3. A few training steps run without error +4. Loss decreases over those steps +""" + +import os +import shutil +import tempfile +import unittest + +import torch + +from deepmd.pt_expt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt_expt.model import ( + get_model, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) + +EXAMPLE_DIR = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "examples", + "water", +) + + +def _make_config(data_dir: str, numb_steps: int = 5) -> dict: + """Build a minimal config dict pointing at *data_dir*.""" + config = { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [ + os.path.join(data_dir, "data_0"), + ], + "batch_size": 1, + }, + "validation_data": { + "systems": [ + os.path.join(data_dir, "data_3"), + ], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 5, + "save_freq": numb_steps, + }, + } + return config + + +class TestTraining(unittest.TestCase): + """Basic smoke test for the pt_expt training loop.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def test_get_model(self) -> None: + """Test that get_model constructs a model from config.""" + config = _make_config(self.data_dir) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + model = get_model(config["model"]) + # model should be a torch.nn.Module + self.assertIsInstance(model, torch.nn.Module) + # should have parameters + nparams = sum(p.numel() for p in model.parameters()) + self.assertGreater(nparams, 0) + + def _run_training(self, config: dict) -> None: + """Run training and verify lcurve + checkpoint creation.""" + tmpdir = tempfile.mkdtemp(prefix="pt_expt_train_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + trainer.run() + + # Read lcurve to verify training ran + lcurve_path = os.path.join(tmpdir, "lcurve.out") + self.assertTrue(os.path.exists(lcurve_path), "lcurve.out not created") + + with open(lcurve_path) as f: + lines = [l for l in f.readlines() if not l.startswith("#")] + self.assertGreater(len(lines), 0, "lcurve.out is empty") + + # Verify checkpoint was saved + ckpt_files = [f for f in os.listdir(tmpdir) if f.endswith(".pt")] + self.assertGreater(len(ckpt_files), 0, "No checkpoint files saved") + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_training_loop(self) -> None: + """Run a few training steps and verify outputs.""" + config = _make_config(self.data_dir, numb_steps=5) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + self._run_training(config) + + def test_training_loop_compiled(self) -> None: + """Run a few training steps with torch.compile enabled.""" + config = _make_config(self.data_dir, numb_steps=5) + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + self._run_training(config) + + +class TestCompiledRecompile(unittest.TestCase): + """Test that _CompiledModel recompiles when nall exceeds max_nall.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def test_nall_growth_triggers_recompile(self) -> None: + """Shrink max_nall to force a recompile, then verify training works.""" + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + config = _make_config(self.data_dir, numb_steps=5) + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_recompile_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + + # The wrapper.model should be a _CompiledModel + compiled_model = trainer.wrapper.model + self.assertIsInstance(compiled_model, _CompiledModel) + + original_max_nall = compiled_model._max_nall + self.assertGreater(original_max_nall, 0) + + # Artificially shrink max_nall to 1 so the next batch + # will certainly exceed it and trigger recompilation. + compiled_model._max_nall = 1 + old_compiled_lower = compiled_model.compiled_forward_lower + + # Run one training step — should trigger recompile + trainer.wrapper.train() + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab) + loss.backward() + trainer.optimizer.step() + + # max_nall should have grown beyond 1 + new_max_nall = compiled_model._max_nall + self.assertGreater(new_max_nall, 1) + + # compiled_forward_lower should be a new object + self.assertIsNot( + compiled_model.compiled_forward_lower, + old_compiled_lower, + ) + + # Loss should be a finite scalar + self.assertFalse(torch.isnan(loss)) + self.assertFalse(torch.isinf(loss)) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestCompiledConsistency(unittest.TestCase): + """Verify compiled model produces the same energy/force/virial as uncompiled.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def test_compiled_matches_uncompiled(self) -> None: + """Energy, force, virial from compiled model must match uncompiled.""" + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + config = _make_config(self.data_dir, numb_steps=1) + # enable virial in loss so the model returns it + config["loss"]["start_pref_v"] = 1.0 + config["loss"]["limit_pref_v"] = 1.0 + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_consistency_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + # Uncompiled model reference + uncompiled_model = trainer.model + uncompiled_model.eval() + + # Build compiled model from the same weights + config_compiled = _make_config(self.data_dir, numb_steps=1) + config_compiled["loss"]["start_pref_v"] = 1.0 + config_compiled["loss"]["limit_pref_v"] = 1.0 + config_compiled["training"]["enable_compile"] = True + config_compiled = update_deepmd_input(config_compiled, warning=False) + config_compiled = normalize(config_compiled) + trainer_compiled = get_trainer(config_compiled) + compiled_model = trainer_compiled.wrapper.model + self.assertIsInstance(compiled_model, _CompiledModel) + + # Copy uncompiled weights to compiled model so they match + compiled_model.original_model.load_state_dict( + uncompiled_model.state_dict() + ) + compiled_model.eval() + + # Get a batch and run both models + input_dict, _ = trainer.get_data(is_train=True) + coord = input_dict["coord"].detach() + atype = input_dict["atype"].detach() + box = input_dict.get("box") + if box is not None: + box = box.detach() + + # Force is computed via autograd.grad inside the model, so + # we cannot use torch.no_grad() here. + coord_uc = coord.clone().requires_grad_(True) + pred_uc = uncompiled_model(coord_uc, atype, box) + + pred_c = compiled_model(coord.clone(), atype, box) + + # Energy + torch.testing.assert_close( + pred_c["energy"], + pred_uc["energy"], + atol=1e-10, + rtol=1e-10, + msg="energy mismatch between compiled and uncompiled", + ) + # Force + self.assertIn("force", pred_c, "compiled model missing 'force'") + self.assertIn("force", pred_uc, "uncompiled model missing 'force'") + torch.testing.assert_close( + pred_c["force"], + pred_uc["force"], + atol=1e-10, + rtol=1e-10, + msg="force mismatch between compiled and uncompiled", + ) + # Virial + if "virial" in pred_uc: + self.assertIn("virial", pred_c, "compiled model missing 'virial'") + torch.testing.assert_close( + pred_c["virial"], + pred_uc["virial"], + atol=1e-10, + rtol=1e-10, + msg="virial mismatch between compiled and uncompiled", + ) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestGetData(unittest.TestCase): + """Test the batch data conversion in Trainer.get_data.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def test_batch_shapes(self) -> None: + """Verify input/label shapes from get_data.""" + config = _make_config(self.data_dir, numb_steps=5) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_getdata_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + input_dict, label_dict = trainer.get_data(is_train=True) + + # coord should be a tensor with requires_grad + self.assertIsInstance(input_dict["coord"], torch.Tensor) + self.assertTrue(input_dict["coord"].requires_grad) + + # atype should be an integer tensor + self.assertIsInstance(input_dict["atype"], torch.Tensor) + + # force label should be a tensor + if "force" in label_dict: + self.assertIsInstance(label_dict["force"], torch.Tensor) + + # energy label should exist + self.assertIn("energy", label_dict) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestRestart(unittest.TestCase): + """Test restart and init_model resume paths.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def _train_and_get_ckpt(self, config: dict, tmpdir: str) -> str: + """Train and return the path to the final checkpoint.""" + trainer = get_trainer(config) + trainer.run() + # find the latest checkpoint symlink + ckpt = os.path.join(tmpdir, "model.ckpt.pt") + self.assertTrue(os.path.exists(ckpt), "Checkpoint not created") + return ckpt + + def test_restart(self) -> None: + """Train 5 steps, restart from checkpoint, train 5 more.""" + tmpdir = tempfile.mkdtemp(prefix="pt_expt_restart_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Phase 1: train 5 steps + config = _make_config(self.data_dir, numb_steps=5) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + ckpt_path = self._train_and_get_ckpt(config, tmpdir) + + # Phase 2: restart from checkpoint, train to step 10 + config2 = _make_config(self.data_dir, numb_steps=10) + config2 = update_deepmd_input(config2, warning=False) + config2 = normalize(config2) + trainer2 = get_trainer(config2, restart_model=ckpt_path) + + # start_step should be restored + self.assertEqual(trainer2.start_step, 5) + + # LR should match the schedule at the resumed step, + # not double-count start_step. + expected_lr = trainer2.lr_schedule.value(trainer2.start_step) + actual_lr = trainer2.scheduler.get_last_lr()[0] + self.assertAlmostEqual( + actual_lr, + expected_lr, + places=10, + msg=f"LR after restart should be lr_schedule({trainer2.start_step})" + f"={expected_lr}, got {actual_lr}", + ) + + trainer2.run() + + # lcurve should have entries appended (restart opens in append mode) + with open(os.path.join(tmpdir, "lcurve.out")) as f: + lines = [l for l in f.readlines() if not l.startswith("#")] + self.assertGreater(len(lines), 0, "lcurve.out is empty after restart") + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_init_model(self) -> None: + """Train 5 steps, init_model from checkpoint (reset step), train 5 more.""" + tmpdir = tempfile.mkdtemp(prefix="pt_expt_init_model_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Phase 1: train 5 steps + config = _make_config(self.data_dir, numb_steps=5) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + ckpt_path = self._train_and_get_ckpt(config, tmpdir) + + # Phase 2: init_model — weights loaded but step reset to 0 + config2 = _make_config(self.data_dir, numb_steps=5) + config2 = update_deepmd_input(config2, warning=False) + config2 = normalize(config2) + trainer2 = get_trainer(config2, init_model=ckpt_path) + + # init_model resets step to 0 + self.assertEqual(trainer2.start_step, 0) + trainer2.run() + + with open(os.path.join(tmpdir, "lcurve.out")) as f: + lines = [l for l in f.readlines() if not l.startswith("#")] + self.assertGreater( + len(lines), 0, "lcurve.out is empty after init_model" + ) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_restart_with_compile(self) -> None: + """Train uncompiled, restart with compile enabled.""" + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_restart_compile_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Phase 1: train 5 steps without compile + config = _make_config(self.data_dir, numb_steps=5) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + ckpt_path = self._train_and_get_ckpt(config, tmpdir) + + # Phase 2: restart with compile enabled + config2 = _make_config(self.data_dir, numb_steps=10) + config2["training"]["enable_compile"] = True + config2 = update_deepmd_input(config2, warning=False) + config2 = normalize(config2) + trainer2 = get_trainer(config2, restart_model=ckpt_path) + + self.assertEqual(trainer2.start_step, 5) + self.assertIsInstance(trainer2.wrapper.model, _CompiledModel) + trainer2.run() + + with open(os.path.join(tmpdir, "lcurve.out")) as f: + lines = [l for l in f.readlines() if not l.startswith("#")] + self.assertGreater(len(lines), 0) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +def _make_dpa3_config(data_dir: str, numb_steps: int = 5) -> dict: + """Build a minimal DPA3 config dict pointing at *data_dir*.""" + config = { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 4, + "a_dim": 4, + "nlayers": 2, + "e_rcut": 3.0, + "e_rcut_smth": 0.5, + "e_sel": 18, + "a_rcut": 2.5, + "a_rcut_smth": 0.5, + "a_sel": 10, + "axis_neuron": 4, + "fix_stat_std": 0.3, + }, + "seed": 1, + }, + "fitting_net": { + "neuron": [8, 8], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [ + os.path.join(data_dir, "data_0"), + ], + "batch_size": 1, + }, + "validation_data": { + "systems": [ + os.path.join(data_dir, "data_3"), + ], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 5, + "save_freq": numb_steps, + }, + } + return config + + +class TestTrainingDPA3(unittest.TestCase): + """Smoke test for the pt_expt training loop with DPA3 descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def test_get_model(self) -> None: + """Test that get_model constructs a DPA3 model from config.""" + config = _make_dpa3_config(self.data_dir) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + model = get_model(config["model"]) + self.assertIsInstance(model, torch.nn.Module) + nparams = sum(p.numel() for p in model.parameters()) + self.assertGreater(nparams, 0) + + def test_training_loop(self) -> None: + """Run a few DPA3 training steps and verify outputs.""" + config = _make_dpa3_config(self.data_dir, numb_steps=5) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_dpa3_train_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + trainer.run() + + lcurve_path = os.path.join(tmpdir, "lcurve.out") + self.assertTrue(os.path.exists(lcurve_path), "lcurve.out not created") + + with open(lcurve_path) as f: + lines = [l for l in f.readlines() if not l.startswith("#")] + self.assertGreater(len(lines), 0, "lcurve.out is empty") + + ckpt_files = [f for f in os.listdir(tmpdir) if f.endswith(".pt")] + self.assertGreater(len(ckpt_files), 0, "No checkpoint files saved") + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +if __name__ == "__main__": + unittest.main()