feat(pt_expt): add training infrastructure #5270
wanghan-iapcm merged 76 commits into deepmodeling:master from
Conversation
…ts to dpmodel's folder.
…om the base models of the corresponding backend
Add TestEnerComputeOrLoadStat to the consistency test framework, comparing
dp, pt, and pt_expt backends after compute_or_load_stat. Tests cover
descriptor stats, fparam/aparam fitting stats, output bias, and forward
consistency, parameterized over exclusion types and fparam source
(default injection vs explicit data). Both compute and load-from-file
paths are tested.
Three dpmodel bugs found and fixed:
- repflows.py: compute_input_stats now respects set_stddev_constant,
matching the pt backend behavior
- stat.py: compute_output_stats_global now applies atom_exclude_types
mask to natoms before computing output bias
- general_fitting.py: compute_input_stats now supports save/load of
fparam/aparam stats via stat_file_path, matching the pt backend
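The stat.py fix above can be illustrated with a small numpy sketch: the global output bias is obtained by a least-squares fit of per-frame energies against per-type atom counts, and atoms of excluded types must be masked out of those counts first so they receive no bias. `compute_output_bias` and its shapes are illustrative stand-ins, not the actual deepmd API.

```python
import numpy as np

# Illustrative sketch, not the deepmd implementation: fit the per-type
# output bias by least squares, zeroing out excluded types first.
def compute_output_bias(energies, natoms, atom_exclude_types=()):
    """energies: (nframes,); natoms: (nframes, ntypes) per-type atom counts."""
    natoms = np.asarray(natoms, dtype=np.float64).copy()
    for t in atom_exclude_types:
        natoms[:, t] = 0.0  # excluded types contribute nothing to the fit
    bias, *_ = np.linalg.lstsq(
        natoms, np.asarray(energies, dtype=np.float64), rcond=None
    )
    return bias

# Two frames, two types, E = 1.0 * n_type0; type 1 is excluded, so its
# bias stays at the minimum-norm value of 0.
print(compute_output_bias([2.0, 3.0], [[2, 5], [3, 1]], atom_exclude_types=[1]))
```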
…-other-full-model
Move get_observed_type_list from a PT-only method to a backend-independent abstract API on BaseBaseModel, with a concrete implementation in dpmodel's make_model CM using array_api_compat for torch compatibility. Add a cross-backend consistency test that verifies dp, pt, and pt_expt return identical results when only a subset of types is observed.
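As a rough sketch of how one backend-independent implementation can serve every backend, the toy helper below mimics the idea. The namespace is passed explicitly here, where the real code would call `array_api_compat.array_namespace(atype)` so numpy and torch inputs dispatch to the right backend; `observed_types` itself and its signature are hypothetical.

```python
import numpy as np

# Toy sketch in the spirit of get_observed_type_list; xp stands in for the
# namespace array_api_compat would resolve from the input array.
def observed_types(atype, ntypes, xp=np):
    """Return a bool list marking which of the ntypes appear in atype."""
    flat = xp.reshape(atype, (-1,))
    return [bool(xp.any(flat == t)) for t in range(ntypes)]

print(observed_types(np.array([[0, 2], [2, 0]]), 4))  # [True, False, True, False]
```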
…bare np ops
dpmodel's model-level change_type_map was not forwarding model_with_new_type_stat to the atomic model, so fine-tuning with new atom types would silently lose reference statistics. Align with the pt backend by unwrapping .atomic_model and passing it through.

Also fix array API violations in fitting change_type_map methods: np.zeros/np.ones/np.concatenate fail when arrays are torch tensors (pt_expt backend). Replace with xp.zeros/xp.ones/xp.concat using proper array namespace and device.

Add cross-backend test (test_change_type_map_extend_stat) that exercises the model-level change_type_map with model_with_new_type_stat across dp, pt, and pt_expt.
Add get_out_bias() and set_out_bias() methods to dpmodel's base_atomic_model, and update make_model to call them instead of accessing the attribute directly. For PT, add get_out_bias() to base_atomic_model and remove the redundant implementations from dp_atomic_model, pairtab_atomic_model, and linear_atomic_model.
…et-by-statistic
The PT backend calls atomic_model.compute_fitting_input_stat(merged) in change_out_bias when mode is set-by-statistic, but dpmodel/pt_expt did not. This meant fparam/aparam statistics (avg, inv_std) were never updated during bias adjustment in these backends.

Add compute_fitting_input_stat to dpmodel's DPAtomicModel and call it from make_model.change_out_bias. Enhance test_change_out_bias with fparam/aparam data, pt_expt coverage, and verification that fitting input stats are updated after set-by-statistic but unchanged after change-by-statistic.
…-other-full-model
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: fb2eb42812
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/dpmodel/utils/env_mat.py (1)
1-1: ⚠️ Potential issue | 🟠 Major: Address ruff linting violations in env_mat.py before commit.
The file has 4 linting violations that must be fixed:
- Lines 27 & 46 (TRY003): Move exception messages into the exception class instead of raising with long messages outside
- Line 155 (RUF059): Replace unused unpacked variables `nf` and `nnei` with `_nf` and `_nnei`

Run `ruff check deepmd/dpmodel/utils/env_mat.py` to verify all violations are resolved.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/dpmodel/utils/env_mat.py` at line 1, Fix the TRY003 issues by moving long error strings into the exception constructors inside the except blocks: replace patterns that re-raise or raise with external messages by instantiating the exception with the message (and use "from err" when appropriate) in the except handlers found in the functions that parse/load environments (the exception-raising blocks around read/parse logic in env_mat.py). Also fix the RUF059 by renaming the unused unpacked variables nf and nnei to _nf and _nnei in the tuple unpacking inside the environment-to-matrix routine (the unpack in the function that builds the neighbor/environment arrays, e.g., env_to_matrix/compute_env_matrix) so ruff no longer reports unused variables; run ruff check deepmd/dpmodel/utils/env_mat.py to verify all violations are resolved.
♻️ Duplicate comments (1)
deepmd/pt_expt/train/training.py (1)
702-709: ⚠️ Potential issue | 🟠 Major: Fix checkpoint path handling for nested `save_ckpt` directories and resolve ruff lint violation.
Lines 702-709 have two issues:
- When `save_ckpt` contains a directory path, the relative symlink target becomes broken (resolves to an incorrect doubled path like `checkpoints/checkpoints/...`).
- Parent directories are not created before `torch.save()`, causing runtime errors for nested paths.

Additionally, line 747 has an unused `model_pred` variable (RUF059) that must be fixed before commit.

🔧 Proposed fixes

For checkpoint handling (lines 702-709):
```diff
-        ckpt_path = f"{self.save_ckpt}-{step}.pt"
-        torch.save(state, ckpt_path)
+        ckpt_path = Path(f"{self.save_ckpt}-{step}.pt")
+        ckpt_path.parent.mkdir(parents=True, exist_ok=True)
+        torch.save(state, ckpt_path)
         # symlink latest
         latest = Path(f"{self.save_ckpt}.pt")
         if latest.is_symlink() or latest.exists():
             latest.unlink()
-        latest.symlink_to(ckpt_path)
-        log.info(f"Saved checkpoint to {ckpt_path}")
+        latest.symlink_to(ckpt_path.resolve())
+        log.info("Saved checkpoint to %s", ckpt_path)
```

For RUF059 at line 747:
```diff
-        model_pred, loss, more_loss = self.wrapper(
+        _, loss, more_loss = self.wrapper(
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 702 - 709, The checkpoint code should ensure parent directories exist and create a symlink that points to the checkpoint file name (not a path that doubles the directory); before calling torch.save(state, ckpt_path) create the parent directory for Path(ckpt_path) (mkdir(parents=True, exist_ok=True)); after saving create the symlink target using the checkpoint file name (Path(ckpt_path).name) so latest.symlink_to(...) does not produce a doubled path; keep existing logic to unlink existing latest. Also remove or rename the unused variable model_pred (e.g., delete the assignment or convert it to _model_pred) to resolve the RUF059 lint error. Reference symbols: self.save_ckpt, ckpt_path, latest, torch.save, and model_pred.
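The relative-symlink failure mode described in this comment is easy to reproduce in isolation. The sketch below uses throwaway paths (not the real save_ckpt layout) to show why a relative target string is resolved against the symlink's own directory and doubles the checkpoint directory:

```python
import tempfile
from pathlib import Path

# Standalone demo of the doubled-path symlink bug. A symlink stores its
# target *string*; a relative target resolves against the directory that
# contains the symlink, so a link inside checkpoints/ whose target starts
# with "checkpoints/" points at checkpoints/checkpoints/...
tmp = tempfile.mkdtemp()
ckpt_path = Path(tmp) / "checkpoints" / "model-5.pt"
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
ckpt_path.write_text("weights")

latest = Path(tmp) / "checkpoints" / "model.pt"
latest.symlink_to(Path("checkpoints/model-5.pt"))  # relative target: broken
broken_resolves = latest.exists()  # resolves to checkpoints/checkpoints/...

latest.unlink()
latest.symlink_to(ckpt_path.resolve())             # absolute target: works
fixed_resolves = latest.exists()

print(broken_resolves, fixed_resolves)  # False True
```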
🧹 Nitpick comments (4)
deepmd/dpmodel/loss/loss.py (1)
56-60: Preserve `loss` dtype for the NaN branch. Great device fix. Consider setting the NaN dtype to `loss.dtype` to avoid unintended promotion in `xp.where`.

Suggested patch
```diff
         return xp.where(
             xp.asarray(find_property, dtype=xp.bool, device=dev),
             loss,
-            xp.asarray(xp.nan, device=dev),
+            xp.asarray(xp.nan, dtype=loss.dtype, device=dev),
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/dpmodel/loss/loss.py` around lines 56 - 60, The NaN branch currently constructs a NaN without preserving the original loss dtype, which can cause dtype promotion in xp.where; update the NaN construction used in the xp.where call (the branch alongside xp.asarray(find_property, ...), xp.where, and the loss variable) to create a NaN with dtype=loss.dtype and the same device (e.g., use xp.asarray/xp.full with dtype=loss.dtype and device=dev) so the resulting xp.where preserves the loss dtype.
deepmd/pt_expt/model/get_model.py (1)
70-70: `get_standard_model` return annotation is too narrow. This function can return multiple model classes, so `-> EnergyModel` is misleading for typing and call sites.

♻️ Proposed fix

```diff
-def get_standard_model(data: dict) -> EnergyModel:
+def get_standard_model(data: dict) -> BaseModel:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/model/get_model.py` at line 70, get_standard_model's return annotation is too narrow (currently -> EnergyModel) because it can return multiple different model classes; update the function signature for get_standard_model to use a broader/accurate typing (for example: -> Any, -> Union[EnergyModel, OtherModelClass1, OtherModelClass2] or a shared Protocol/ABC) and add the necessary typing imports (Any/Union/Protocol) so call sites and static checkers see the true possible return types; ensure you reference the actual concrete model class names used in the function when constructing the Union or implement a Protocol/ABC that all returned classes conform to and use that as the return annotation.
source/tests/pt_expt/conftest.py (1)
29-32: Consider adding a compatibility guard for the private Torch APIs used in this fixture. This fixture depends on `torch.utils._device.DeviceContext` and `torch.overrides._get_current_function_mode_stack()`, which are private, undocumented APIs not covered by PyTorch's compatibility guarantees and may change across versions without deprecation.

While these APIs are necessary to work around a PyTorch bug in `torch.optim.adam.Adam._init_group` (which omits `device=` on the default code path), consider wrapping the fixture logic in a try-except block or version check to prevent test collection failures if these internals change in future PyTorch releases. The current documentation is thorough, but an explicit compatibility guard would make the fixture more resilient, especially given that the repository allows torch to remain unpinned for macOS x86.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/conftest.py` around lines 29 - 32, Wrap the fixture code that imports/uses the private APIs torch.utils._device.DeviceContext and torch.overrides._get_current_function_mode_stack in a compatibility guard: surround the import and any usage with a try/except ImportError/AttributeError (or check torch.__version__ range) and provide a safe fallback path that skips the workaround when those symbols are missing (e.g., return/skip the fixture or use a no-op context), and log a clear warning so tests don’t fail during collection if the private APIs have changed.
deepmd/pt_expt/train/wrapper.py (1)
65-74: Fail fast when `label` is provided but `cur_lr` is missing. Right now `None` can flow into the loss call. A small guard will make training-time failures explicit and easier to debug.

♻️ Proposed change

```diff
     if self.inference_only or label is None:
         return model_pred, None, None
     else:
+        if cur_lr is None:
+            raise ValueError("cur_lr must be provided when label is not None")
         natoms = atype.shape[-1]
         loss, more_loss = self.loss(
             cur_lr,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/wrapper.py` around lines 65 - 74, When label is provided but cur_lr can be None, add a fail-fast guard before calling self.loss: check if label is not None and cur_lr is None and raise a clear exception (e.g., ValueError) explaining that cur_lr is required during training; place this check in the same block that currently branches on self.inference_only/label (the code around self.inference_only, label, cur_lr and the call to self.loss) so you never pass None into self.loss.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/dpmodel/fitting/general_fitting.py`:
- Around line 271-279: The ValueError messages raised in general_fitting.py (the
two blocks that check frame["acquired_fparam"] and frame["find_fparam"] around
the numb_fparam logic) are long inline strings that trigger Ruff TRY003; extract
the message construction into a shared helper or custom exception class and
replace the inline f-strings with a short raise using that helper/exception.
Specifically, add a function like build_fparam_error(sys_index, detail_key) or a
custom exception FparamError(message) and call it from the existing checks where
you currently have raise ValueError(f"numb_fparam > 0 but fparam is not acquired
for system {ii}.") and raise ValueError(f"numb_fparam > 0 but no fparam data is
provided for system {ii}."); ensure the original formatted text is produced by
the helper and that both places (the block referencing frame["acquired_fparam"]
and the block referencing frame["find_fparam"]) use the new centralized routine
so strings are not inline in the raise statements.
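One possible shape of the FparamError class suggested in the prompt above; the class name comes from the prompt itself, and the message wording only approximates the original inline ValueError strings.

```python
# Sketch of the TRY003 fix pattern: the long message lives inside the
# exception class so the raise site stays short. FparamError's exact
# constructor signature here is an assumption, not deepmd code.
class FparamError(ValueError):
    """Carries the long fparam error message (satisfies Ruff TRY003)."""

    def __init__(self, system_index: int) -> None:
        super().__init__(
            f"numb_fparam > 0 but fparam is not acquired for system {system_index}."
        )

try:
    raise FparamError(3)  # the raise site no longer carries an inline f-string
except ValueError as err:
    print(err)
```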
In `@deepmd/dpmodel/utils/batch.py`:
- Around line 46-47: The current branch handling keys with
key.startswith("find_") uses float(val) which fails for array-valued flags;
change it to a vectorized conversion that preserves array shapes — convert val
to a numpy array and produce a boolean mask (e.g. np.asarray(val) and, if values
are probabilities, compare elementwise to 0.5, or simply astype(bool) for truthy
values) and assign that to out[key] so frame-wise masks and non-scalar inputs
are handled correctly.
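The scalar-vs-array problem described above can be sketched as follows. `to_find_mask` is a hypothetical helper name, not the actual batch.py code: `float(val)` raises TypeError for a per-frame list, while the vectorized form handles both scalars and arrays.

```python
import numpy as np

# Sketch of the suggested fix: convert find_* flags with numpy so both
# scalar flags and frame-wise arrays become boolean masks.
def to_find_mask(val):
    return np.asarray(val).astype(bool)

try:
    float([1.0, 0.0])  # the old scalar path fails on array-valued flags
except TypeError as err:
    print("float() fails:", err)

print(to_find_mask(1.0))              # scalar flag
print(to_find_mask([1.0, 0.0, 1.0]))  # frame-wise mask preserved
```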
In `@deepmd/dpmodel/utils/env_mat.py`:
- Around line 48-52: compute_exp_sw may divide by zero when rmin == 0 because a
= C / rmin; update compute_exp_sw to guard rmin <= 0 by either raising a clear
ValueError or clamping rmin to a small positive epsilon before computing a (keep
C and b semantics intact). Locate the computation using symbols a, C, rmin (and
xp.where/xp.full_like) and add the rmin check prior to computing a so that a
cannot become inf/nan and subsequent exp_sw calculations remain valid.
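A minimal sketch of the suggested guard, with the switching-function math reduced to the single division that can blow up; compute_exp_sw's real C, b, and rmin semantics are not reproduced here.

```python
import math

# Illustrative guard for a = C / rmin: clamp non-positive rmin to a small
# epsilon (raising ValueError is the stricter alternative) so the scale
# factor can never become inf/nan.
def safe_inverse_scale(C, rmin, eps=1e-12):
    if rmin <= 0.0:
        rmin = eps  # or: raise ValueError("rmin must be positive")
    return C / rmin

print(safe_inverse_scale(2.0, 0.5))                    # normal case: 4.0
print(math.isfinite(safe_inverse_scale(1.0, 0.0)))     # guarded case: True
```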
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 199-201: The long exception string in the raise expression should
be pulled out into a separate variable to satisfy Ruff TRY003; create a
descriptive variable (e.g., err_msg) that formats the message using
FLAGS.command, then raise RuntimeError(err_msg) instead of embedding the
f-string directly in the raise; update the raise site that currently uses
RuntimeError(f"Unsupported command '{FLAGS.command}' for the pt_expt backend.")
to use the new variable.
- Around line 80-88: The code fails when training_params["stat_file"] is a
nested path because parent directories aren't created before touching the
file/dir; update the stat_file_path handling so that when stat_file_path is a
file (endswith ".h5" or ".hdf5") you first ensure its parent directory exists
(Path(stat_file_path).parent.mkdir(parents=True, exist_ok=True)) before creating
the h5 file with h5py.File, and when stat_file_path is meant to be a directory
ensure you create all nested directories
(Path(stat_file_path).mkdir(parents=True, exist_ok=True)); afterward wrap the
path with DPPath as before (referencing stat_file_path and DPPath in main.py).
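The directory-creation fix can be sketched as below. The paths are throwaway stand-ins for training_params["stat_file"], and write_text stands in for h5py.File(..., "w") so the example needs no optional dependency.

```python
import tempfile
from pathlib import Path

# Sketch of the stat_file handling: ensure parents exist before creating
# a nested .h5 file, or create the full directory tree for a dir path.
with tempfile.TemporaryDirectory() as tmp:
    stat_file_path = Path(tmp) / "stats" / "nested" / "stat.h5"
    if stat_file_path.suffix in (".h5", ".hdf5"):
        # the fix: create parent directories before touching the file
        stat_file_path.parent.mkdir(parents=True, exist_ok=True)
        stat_file_path.write_text("")  # stand-in for h5py.File(..., "w")
    else:
        # directory-style stat path: create all nested directories
        stat_file_path.mkdir(parents=True, exist_ok=True)
    created = stat_file_path.exists()
print(created)  # True
```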
In `@deepmd/pt_expt/model/get_model.py`:
- Line 95: The inline f-string in the raise statement for Unknown fitting type
(raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")) triggers Ruff
TRY003; fix by constructing the error message in a separate variable (e.g., msg
= f"Unknown fitting type: {fitting_net_type}") and then raise the exception with
that variable (raise RuntimeError(msg)) in the get_model.py code path that
handles fitting_net_type so the message construction is not inline in the raise.
In `@deepmd/pt_expt/train/training.py`:
- Around line 747-749: The tuple returned by self.wrapper(...) is being unpacked
into model_pred, loss, more_loss but model_pred is unused; update the unpack
targets in the training loop to use a throwaway name (e.g., _model_pred or _ )
instead of model_pred so RUF059 is satisfied, ensure the unpack pattern that
includes loss and more_loss remains unchanged, and run ruff check . and ruff
format . before committing; locate this change around the call to self.wrapper
in training.py (the training loop that assigns model_pred, loss, more_loss).
In `@source/tests/pt_expt/loss/test_ener.py`:
- Line 136: dp_loss(...) is being unpacked into l_dp and more_dp but more_dp is
never used, causing a Ruff RUF059 lint failure; update the call site in
test_ener.py where dp_loss(learning_rate, natoms, model_pred_np, label_np) is
assigned so the unused value is ignored (e.g., use a single-target assignment or
a throwaway variable) and ensure only l_dp is retained, referencing the dp_loss
call and the l_dp/more_dp names when you make the change.
In `@source/tests/pt_expt/test_training.py`:
- Line 140: Rename the ambiguous loop variable l in the list comprehensions to a
clear name like line (e.g., change lines = [l for l in f.readlines() if not
l.startswith("#")] to use line) for both occurrences mentioned, and fix the
unused unpack of more_loss by either removing the unused variable from the
unpack or replacing it with an underscore (e.g., use _, more_loss = ... or
more_loss, _ = ... as appropriate) so Ruff errors E741 and RUF059 are resolved.
📒 Files selected for processing (40)
- deepmd/backend/pt_expt.py
- deepmd/dpmodel/atomic_model/base_atomic_model.py
- deepmd/dpmodel/descriptor/make_base_descriptor.py
- deepmd/dpmodel/fitting/general_fitting.py
- deepmd/dpmodel/fitting/make_base_fitting.py
- deepmd/dpmodel/loss/ener.py
- deepmd/dpmodel/loss/loss.py
- deepmd/dpmodel/model/base_model.py
- deepmd/dpmodel/utils/batch.py
- deepmd/dpmodel/utils/env_mat.py
- deepmd/dpmodel/utils/env_mat_stat.py
- deepmd/dpmodel/utils/stat.py
- deepmd/pt_expt/entrypoints/__init__.py
- deepmd/pt_expt/entrypoints/main.py
- deepmd/pt_expt/loss/__init__.py
- deepmd/pt_expt/loss/ener.py
- deepmd/pt_expt/model/__init__.py
- deepmd/pt_expt/model/get_model.py
- deepmd/pt_expt/train/__init__.py
- deepmd/pt_expt/train/training.py
- deepmd/pt_expt/train/wrapper.py
- deepmd/pt_expt/utils/stat.py
- deepmd/tf/model/model_stat.py
- deepmd/utils/argcheck.py
- deepmd/utils/model_stat.py
- source/tests/common/dpmodel/test_fitting_stat.py
- source/tests/consistent/common.py
- source/tests/consistent/fitting/test_ener.py
- source/tests/consistent/loss/test_ener.py
- source/tests/consistent/model/test_dipole.py
- source/tests/consistent/model/test_dos.py
- source/tests/consistent/model/test_ener.py
- source/tests/consistent/model/test_polar.py
- source/tests/consistent/model/test_property.py
- source/tests/consistent/test_make_stat_input.py
- source/tests/pt_expt/conftest.py
- source/tests/pt_expt/fitting/test_fitting_stat.py
- source/tests/pt_expt/loss/__init__.py
- source/tests/pt_expt/loss/test_ener.py
- source/tests/pt_expt/test_training.py
_compile_model() wraps self.wrapper.model with _CompiledModel, changing state_dict keys from model.* to model.original_model.*. When this ran before the resume block, load_state_dict would fail on checkpoints saved from uncompiled training. Move the compile block after resume so checkpoint keys always match, then wrap the restored model. Add restart/init_model/restart+compile tests and DPA3 training test.
…eady handles the offset.
Actionable comments posted: 2
🧹 Nitpick comments (3)
source/tests/pt_expt/test_training.py (1)
1-630: Consider extracting repeated lcurve reading into a helper. The pattern `[line for line in f.readlines() if not line.startswith("#")]` appears 6 times. A small helper would reduce duplication and centralize the fix for the `l` variable issue.

♻️ Suggested helper

```python
def _read_lcurve_data_lines(path: str) -> list[str]:
    """Read non-comment lines from an lcurve file."""
    with open(path) as f:
        return [line for line in f.readlines() if not line.startswith("#")]
```

Then replace all occurrences with:
```python
lines = _read_lcurve_data_lines(lcurve_path)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training.py` around lines 1 - 630, Several tests duplicate the lcurve-reading comprehension ([l for l in f.readlines() if not l.startswith("#")]) — extract it to a helper function (e.g., _read_lcurve_data_lines(path: str) -> list[str]) and replace all inline occurrences in TestTraining._run_training, TestRestart.test_restart, TestRestart.test_init_model, TestRestart.test_restart_with_compile, TestTrainingDPA3.test_training_loop (and any other tests using the same pattern) to call this helper; ensure the helper opens the file, filters out lines starting with "#" and returns the list so the tests use lines = _read_lcurve_data_lines(lcurve_path).
deepmd/pt_expt/train/training.py (2)
69-80: Input dict mutation may cause unexpected side effects.
`loss_params` is modified in-place by adding `"starter_learning_rate"`. If the caller reuses the original dict, they'll see this key unexpectedly.

🛡️ Suggested fix

```diff
 def get_loss(
     loss_params: dict[str, Any],
     start_lr: float,
     _ntypes: int,
     _model: Any,
 ) -> EnergyLoss:
+    loss_params = loss_params.copy()
     loss_type = loss_params.get("type", "ener")
     if loss_type == "ener":
         loss_params["starter_learning_rate"] = start_lr
     return EnergyLoss(**loss_params)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 69 - 80, The get_loss function mutates the input loss_params by setting "starter_learning_rate", which can cause caller-side surprises; fix by making a shallow copy of loss_params (or use dict unpacking) into a new dict (e.g., params = dict(loss_params) or {**loss_params}), set params["starter_learning_rate"] = start_lr, and pass params to EnergyLoss(**params) instead of mutating loss_params directly; keep the same logic and error handling in get_loss.
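The copy-before-mutate pattern this comment asks for can be shown in a few lines. `get_loss_params` is a hypothetical stand-in for the real get_loss, which goes on to build an EnergyLoss from these parameters.

```python
# Sketch of the non-mutating get_loss pattern: copy the params dict before
# injecting derived keys so the caller's dict stays clean.
def get_loss_params(loss_params: dict, start_lr: float) -> dict:
    params = {**loss_params}  # shallow copy instead of in-place mutation
    if params.get("type", "ener") == "ener":
        params["starter_learning_rate"] = start_lr
    return params

caller_dict = {"type": "ener"}
new = get_loss_params(caller_dict, 1e-3)
print("starter_learning_rate" in caller_dict)  # False: caller unaffected
print(new["starter_learning_rate"])            # 0.001
```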
244-252: Consider documenting the unused `do_atomic_virial` parameter. The parameter is part of the forward signature for API compatibility but is not used. Adding a brief comment or using `_do_atomic_virial` would clarify intent and silence ARG002.

📝 Suggested fix

```diff
 def forward(
     self,
     coord: torch.Tensor,
     atype: torch.Tensor,
     box: torch.Tensor | None = None,
     fparam: torch.Tensor | None = None,
     aparam: torch.Tensor | None = None,
-    do_atomic_virial: bool = False,
+    do_atomic_virial: bool = False,  # noqa: ARG002 - API compatibility
 ) -> dict[str, torch.Tensor]:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 244 - 252, The forward method currently accepts do_atomic_virial but does not use it; update the signature or body of the deepmd.pt_expt.train.training.forward function to indicate intentional unused status by either renaming the parameter to _do_atomic_virial or adding a short inline comment (e.g., "# kept for API compatibility, unused") and reference it (e.g., pass or assert False in debug) to silence ARG002 while preserving API compatibility; ensure the change only affects the forward signature and local use so callers remain unaffected.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@source/tests/pt_expt/test_training.py`:
- Around line 433-435: Rename the ambiguous comprehension variable "l" to "line"
in all list comprehensions in the test file (the occurrences that build "lines =
[l for l in f.readlines() if not l.startswith('#')]" and the similar
comprehensions later in the same test), i.e., update the iterator variable in
those list comprehensions to "line" and adjust any references inside the
comprehension accordingly; then run "ruff check ." and "ruff format ." to ensure
linter/formatter compliance.
- Around line 617-619: The test uses a single-letter variable `l` in a list
comprehension (e.g., inside the block that opens lcurve_path) which triggers
E741; replace `l` with a descriptive name like `line` in that comprehension
(lines = [line for line in f.readlines() if not line.startswith("#")]) and
update any other identical occurrences in source/tests/pt_expt/test_training.py
to use `line` instead of `l`; after changes run `ruff check .` and `ruff format
.` to ensure style checks pass.
Introduce a _update_sel_cls class variable on dpmodel descriptors so that each backend can supply its own UpdateSel / NeighborStat. The dpmodel default stays on CPU (numpy); pt_expt overrides the class variable with a GPU-aware UpdateSel backed by a @torch_module-wrapped NeighborStatOP, so array_api_compat dispatches to torch on DEVICE.
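The class-variable dispatch described above can be sketched with toy classes; UpdateSel, TorchUpdateSel, and the descriptor classes here are stand-ins for the real deepmd types, and only the override mechanism is shown.

```python
from typing import ClassVar

# Toy sketch of the _update_sel_cls pattern: the base class holds a class
# variable pointing at a default implementation, and a backend subclass
# swaps in its own with a one-line override.
class UpdateSel:
    device = "cpu"   # numpy-backed default

class TorchUpdateSel(UpdateSel):
    device = "gpu"   # torch-backed, GPU-aware variant

class BaseDescriptor:
    # each backend may override this single class variable
    _update_sel_cls: ClassVar[type] = UpdateSel

    @classmethod
    def update_sel(cls) -> UpdateSel:
        return cls._update_sel_cls()

class PtExptDescriptor(BaseDescriptor):
    _update_sel_cls = TorchUpdateSel  # one-line backend override

print(BaseDescriptor.update_sel().device)    # cpu
print(PtExptDescriptor.update_sel().device)  # gpu
```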
🧹 Nitpick comments (1)
deepmd/pt_expt/utils/neighbor_stat.py (1)
63-72: Consider hoisting the `data_set` assignment outside the inner loop. The `data_set` variable only depends on `ii` and doesn't change within the inner loop over `jj`. Moving it outside the inner loop would be slightly cleaner.

♻️ Suggested refactor

```diff
 for ii in range(len(data.system_dirs)):
+    data_set = data.data_systems[ii]
     for jj in data.data_systems[ii].dirs:
-        data_set = data.data_systems[ii]
         data_set_data = data_set._load_set(jj)
         minrr2, max_nnei = self._execute(
             data_set_data["coord"],
             data_set_data["type"],
             data_set_data["box"] if data_set.pbc else None,
         )
         yield np.max(max_nnei, axis=0), np.min(minrr2), jj
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/utils/neighbor_stat.py` around lines 63 - 72, The inner loop currently reassigns data_set each iteration even though it only depends on ii; move the statement data_set = data.data_systems[ii] immediately before the inner loop (so it is set once per outer loop) and remove the redundant assignment inside the for jj loop, leaving uses of data_set._load_set(jj) and data_set.pbc unchanged; ensure indentation is adjusted accordingly in the function that contains the loops.
📒 Files selected for processing (17)
- deepmd/dpmodel/descriptor/dpa1.py
- deepmd/dpmodel/descriptor/dpa2.py
- deepmd/dpmodel/descriptor/dpa3.py
- deepmd/dpmodel/descriptor/se_e2_a.py
- deepmd/dpmodel/descriptor/se_r.py
- deepmd/dpmodel/descriptor/se_t.py
- deepmd/dpmodel/descriptor/se_t_tebd.py
- deepmd/pt_expt/descriptor/dpa1.py
- deepmd/pt_expt/descriptor/dpa2.py
- deepmd/pt_expt/descriptor/dpa3.py
- deepmd/pt_expt/descriptor/se_atten_v2.py
- deepmd/pt_expt/descriptor/se_e2_a.py
- deepmd/pt_expt/descriptor/se_r.py
- deepmd/pt_expt/descriptor/se_t.py
- deepmd/pt_expt/descriptor/se_t_tebd.py
- deepmd/pt_expt/utils/neighbor_stat.py
- deepmd/pt_expt/utils/update_sel.py
…expt.utils.neighbor_stat (ported from PR deepmodeling#5270)