
feat(pt_expt): add training infrastructure#5270

Merged
wanghan-iapcm merged 76 commits into deepmodeling:master from wanghan-iapcm:feat-pt-expt-train
Mar 5, 2026

Conversation


@wanghan-iapcm wanghan-iapcm commented Feb 26, 2026

Summary by CodeRabbit

  • New Features
    • New "PyTorch-Exportable" backend with end-to-end training CLI, model factory, loss wrapper, model wrapper, checkpointing, neighbor/stat utilities, and optional torch.compile acceleration with automatic re-tracing.
  • Bug Fixes
    • Added runtime validation for required parameter fields (fparam/aparam) and ensured loss/reporting arrays are device-consistent.
  • Refactor
    • Reworked stat input, batch normalization, and model-stat collection for clearer, backend-agnostic data flow.

Han Wang added 30 commits February 22, 2026 16:10
…om the base models of the corresponding backend
  Add TestEnerComputeOrLoadStat to the consistency test framework, comparing
  dp, pt, and pt_expt backends after compute_or_load_stat. Tests cover
  descriptor stats, fparam/aparam fitting stats, output bias, and forward
  consistency, parameterized over exclusion types and fparam source
  (default injection vs explicit data). Both compute and load-from-file
  paths are tested.

  Three dpmodel bugs found and fixed:
  - repflows.py: compute_input_stats now respects set_stddev_constant,
    matching the pt backend behavior
  - stat.py: compute_output_stats_global now applies atom_exclude_types
    mask to natoms before computing output bias
  - general_fitting.py: compute_input_stats now supports save/load of
    fparam/aparam stats via stat_file_path, matching the pt backend
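  The atom_exclude_types fix in stat.py can be sketched roughly as follows; the function name and array shapes are illustrative, not the actual deepmd API:

  ```python
  import numpy as np

  def mask_excluded_natoms(natoms, atom_exclude_types):
      """Zero the per-type atom counts for excluded types before the
      solve that produces the output bias (sketch of the stat.py fix)."""
      masked = np.array(natoms, copy=True)
      masked[:, atom_exclude_types] = 0
      return masked

  out = mask_excluded_natoms([[4, 2, 3], [1, 5, 0]], [1])
  ```

  With type 1 excluded, its atom counts no longer contribute to the bias fit.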
  Move get_observed_type_list from a PT-only method to a backend-independent
  abstract API on BaseBaseModel, with a concrete implementation in dpmodel's
  make_model CM using array_api_compat for torch compatibility. Add a
  cross-backend consistency test that verifies dp, pt, and pt_expt return
  identical results when only a subset of types is observed.
…bare np ops

  dpmodel's model-level change_type_map was not forwarding
  model_with_new_type_stat to the atomic model, so fine-tuning with new
  atom types would silently lose reference statistics. Align with the pt
  backend by unwrapping .atomic_model and passing it through.

  Also fix array API violations in fitting change_type_map methods:
  np.zeros/np.ones/np.concatenate fail when arrays are torch tensors
  (pt_expt backend). Replace with xp.zeros/xp.ones/xp.concat using
  proper array namespace and device.
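  The array-API fix can be sketched as below; plain numpy stands in for the dispatched namespace, and the helper name is illustrative. In the actual fix, xp comes from array_api_compat.array_namespace so the same code dispatches to torch for the pt_expt backend:

  ```python
  import numpy as np

  def extend_stat(stat, n_new_types: int):
      """Append zero-filled rows for new atom types without hard-coding
      numpy, so torch tensors (pt_expt) take the same path."""
      xp = np  # real code: xp = array_api_compat.array_namespace(stat)
      pad = xp.zeros((n_new_types,) + stat.shape[1:], dtype=stat.dtype)
      return xp.concatenate([stat, pad], axis=0)

  out = extend_stat(np.ones((2, 3)), 2)
  ```

  Passing dtype (and, under array_api_compat, device) explicitly is what keeps the torch path from silently falling back to CPU float64.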

  Add cross-backend test (test_change_type_map_extend_stat) that
  exercises the model-level change_type_map with
  model_with_new_type_stat across dp, pt, and pt_expt.
  Add get_out_bias() and set_out_bias() methods to dpmodel's
  base_atomic_model, and update make_model to call them instead of
  accessing the attribute directly. For PT, add get_out_bias() to
  base_atomic_model and remove the redundant implementations from
  dp_atomic_model, pairtab_atomic_model, and linear_atomic_model.
…et-by-statistic

  The PT backend calls atomic_model.compute_fitting_input_stat(merged) in
  change_out_bias when mode is set-by-statistic, but dpmodel/pt_expt did
  not. This meant fparam/aparam statistics (avg, inv_std) were never updated
  during bias adjustment in these backends.

  Add compute_fitting_input_stat to dpmodel's DPAtomicModel and call it
  from make_model.change_out_bias. Enhance test_change_out_bias with
  fparam/aparam data, pt_expt coverage, and verification that fitting input
  stats are updated after set-by-statistic but unchanged after
  change-by-statistic.
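  The fparam/aparam statistics mentioned above (avg, inv_std) can be sketched as a simple reduction over concatenated frames; the protection floor value here is an assumption, not the deepmd default:

  ```python
  import numpy as np

  def fitting_input_stat(fparam_frames, protection=1e-2):
      """avg and inverse std over per-system fparam frames (sketch)."""
      data = np.concatenate(fparam_frames, axis=0)  # (nframes, numb_fparam)
      avg = data.mean(axis=0)
      # floor the std so inv_std stays finite for near-constant inputs
      inv_std = 1.0 / np.maximum(data.std(axis=0), protection)
      return avg, inv_std

  avg, inv_std = fitting_input_stat([np.zeros((3, 2)), np.ones((3, 2))])
  ```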
@wanghan-iapcm wanghan-iapcm added the Test CUDA Trigger test CUDA workflow label Mar 1, 2026
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Mar 1, 2026
@wanghan-iapcm wanghan-iapcm added the Test CUDA Trigger test CUDA workflow label Mar 2, 2026
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Mar 2, 2026
@wanghan-iapcm wanghan-iapcm marked this pull request as ready for review March 2, 2026 08:15
@wanghan-iapcm wanghan-iapcm requested review from iProzd and njzjz March 2, 2026 08:15
@chatgpt-codex-connector chatgpt-codex-connector bot left a comment
💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: fb2eb42812


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/dpmodel/utils/env_mat.py (1)

1-1: ⚠️ Potential issue | 🟠 Major

Address ruff linting violations in env_mat.py before commit.

The file has 4 linting violations that must be fixed:

  • Lines 27 & 46 (TRY003): Move exception messages into the exception class instead of raising with long messages outside
  • Line 155 (RUF059): Replace unused unpacked variables nf and nnei with _nf and _nnei

Run ruff check deepmd/dpmodel/utils/env_mat.py to verify all violations are resolved.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/dpmodel/utils/env_mat.py` at line 1, Fix the TRY003 issues by moving
long error strings into the exception constructors inside the except blocks:
replace patterns that re-raise or raise with external messages by instantiating
the exception with the message (and use "from err" when appropriate) in the
except handlers found in the functions that parse/load environments (the
exception-raising blocks around read/parse logic in env_mat.py). Also fix the
RUF059 by renaming the unused unpacked variables nf and nnei to _nf and _nnei in
the tuple unpacking inside the environment-to-matrix routine (the unpack in the
function that builds the neighbor/environment arrays, e.g.,
env_to_matrix/compute_env_matrix) so ruff no longer reports unused variables;
run ruff check deepmd/dpmodel/utils/env_mat.py to verify all violations are
resolved.
♻️ Duplicate comments (1)
deepmd/pt_expt/train/training.py (1)

702-709: ⚠️ Potential issue | 🟠 Major

Fix checkpoint path handling for nested save_ckpt directories and resolve ruff lint violation.

Lines 702-709 have two issues:

  1. When save_ckpt contains a directory path, the relative symlink target becomes broken (resolves to incorrect doubled path like checkpoints/checkpoints/...).
  2. Parent directories are not created before torch.save(), causing runtime errors for nested paths.

Additionally, line 747 has an unused model_pred variable (RUF059) that must be fixed before commit.

🔧 Proposed fixes

For checkpoint handling (lines 702-709):

-        ckpt_path = f"{self.save_ckpt}-{step}.pt"
-        torch.save(state, ckpt_path)
+        ckpt_path = Path(f"{self.save_ckpt}-{step}.pt")
+        ckpt_path.parent.mkdir(parents=True, exist_ok=True)
+        torch.save(state, ckpt_path)
         # symlink latest
         latest = Path(f"{self.save_ckpt}.pt")
         if latest.is_symlink() or latest.exists():
             latest.unlink()
-        latest.symlink_to(ckpt_path)
-        log.info(f"Saved checkpoint to {ckpt_path}")
+        latest.symlink_to(ckpt_path.resolve())
+        log.info("Saved checkpoint to %s", ckpt_path)

For RUF059 at line 747:

-            model_pred, loss, more_loss = self.wrapper(
+            _, loss, more_loss = self.wrapper(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/train/training.py` around lines 702 - 709, The checkpoint code
should ensure parent directories exist and create a symlink that points to the
checkpoint file name (not a path that doubles the directory); before calling
torch.save(state, ckpt_path) create the parent directory for Path(ckpt_path)
(mkdir(parents=True, exist_ok=True)); after saving create the symlink target
using the checkpoint file name (Path(ckpt_path).name) so latest.symlink_to(...)
does not produce a doubled path; keep existing logic to unlink existing latest.
Also remove or rename the unused variable model_pred (e.g., delete the
assignment or convert it to _model_pred) to resolve the RUF059 lint error.
Reference symbols: self.save_ckpt, ckpt_path, latest, torch.save, and
model_pred.
🧹 Nitpick comments (4)
deepmd/dpmodel/loss/loss.py (1)

56-60: Preserve loss dtype for the NaN branch.

Great device fix. Consider setting NaN dtype to loss.dtype to avoid unintended promotion in xp.where.

Suggested patch
         return xp.where(
             xp.asarray(find_property, dtype=xp.bool, device=dev),
             loss,
-            xp.asarray(xp.nan, device=dev),
+            xp.asarray(xp.nan, dtype=loss.dtype, device=dev),
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/dpmodel/loss/loss.py` around lines 56 - 60, The NaN branch currently
constructs a NaN without preserving the original loss dtype, which can cause
dtype promotion in xp.where; update the NaN construction used in the xp.where
call (the branch alongside xp.asarray(find_property, ...), xp.where, and the
loss variable) to create a NaN with dtype=loss.dtype and the same device (e.g.,
use xp.asarray/xp.full with dtype=loss.dtype and device=dev) so the resulting
xp.where preserves the loss dtype.
deepmd/pt_expt/model/get_model.py (1)

70-70: get_standard_model return annotation is too narrow.

This function can return multiple model classes, so -> EnergyModel is misleading for typing and call sites.

♻️ Proposed fix
-def get_standard_model(data: dict) -> EnergyModel:
+def get_standard_model(data: dict) -> BaseModel:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/model/get_model.py` at line 70, get_standard_model's return
annotation is too narrow (currently -> EnergyModel) because it can return
multiple different model classes; update the function signature for
get_standard_model to use a broader/accurate typing (for example: -> Any, ->
Union[EnergyModel, OtherModelClass1, OtherModelClass2] or a shared Protocol/ABC)
and add the necessary typing imports (Any/Union/Protocol) so call sites and
static checkers see the true possible return types; ensure you reference the
actual concrete model class names used in the function when constructing the
Union or implement a Protocol/ABC that all returned classes conform to and use
that as the return annotation.
source/tests/pt_expt/conftest.py (1)

29-32: Consider adding a compatibility guard for the private Torch APIs used in this fixture.

This fixture depends on torch.utils._device.DeviceContext and torch.overrides._get_current_function_mode_stack(), which are private, undocumented APIs not covered by PyTorch's compatibility guarantees and may change across versions without deprecation.

While these APIs are necessary to work around a PyTorch bug in torch.optim.adam.Adam._init_group (which omits device= on the default code path), consider wrapping the fixture logic in a try-except block or version check to prevent test collection failures if these internals change in future PyTorch releases. The current documentation is thorough, but an explicit compatibility guard would make the fixture more resilient—especially given that the repository allows torch to remain unpinned for macOS x86.
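  A minimal sketch of such a guard, assuming the fixture only needs a yes/no answer before applying the workaround (the helper name is hypothetical):

  ```python
  import logging

  log = logging.getLogger(__name__)

  def _device_ctx_available() -> bool:
      """Return True only if the private torch APIs the fixture relies on
      still exist; ImportError also covers torch being absent entirely."""
      try:
          from torch.utils._device import DeviceContext  # noqa: F401
          from torch.overrides import (  # noqa: F401
              _get_current_function_mode_stack,
          )
      except (ImportError, AttributeError):
          log.warning(
              "private torch device-mode APIs unavailable; "
              "skipping the Adam device workaround"
          )
          return False
      return True
  ```

  The fixture can then fall back to a no-op context when this returns False, so test collection never fails on a torch upgrade.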

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt_expt/conftest.py` around lines 29 - 32, Wrap the fixture code
that imports/uses the private APIs torch.utils._device.DeviceContext and
torch.overrides._get_current_function_mode_stack in a compatibility guard:
surround the import and any usage with a try/except ImportError/AttributeError
(or check torch.__version__ range) and provide a safe fallback path that skips
the workaround when those symbols are missing (e.g., return/skip the fixture or
use a no-op context), and log a clear warning so tests don’t fail during
collection if the private APIs have changed.
deepmd/pt_expt/train/wrapper.py (1)

65-74: Fail fast when label is provided but cur_lr is missing.

Right now None can flow into the loss call. A small guard will make training-time failures explicit and easier to debug.

♻️ Proposed change
         if self.inference_only or label is None:
             return model_pred, None, None
         else:
+            if cur_lr is None:
+                raise ValueError("cur_lr must be provided when label is not None")
             natoms = atype.shape[-1]
             loss, more_loss = self.loss(
                 cur_lr,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/train/wrapper.py` around lines 65 - 74, When label is provided
but cur_lr can be None, add a fail-fast guard before calling self.loss: check if
label is not None and cur_lr is None and raise a clear exception (e.g.,
ValueError) explaining that cur_lr is required during training; place this check
in the same block that currently branches on self.inference_only/label (the code
around self.inference_only, label, cur_lr and the call to self.loss) so you
never pass None into self.loss.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/dpmodel/fitting/general_fitting.py`:
- Around line 271-279: The ValueError messages raised in general_fitting.py (the
two blocks that check frame["acquired_fparam"] and frame["find_fparam"] around
the numb_fparam logic) are long inline strings that trigger Ruff TRY003; extract
the message construction into a shared helper or custom exception class and
replace the inline f-strings with a short raise using that helper/exception.
Specifically, add a function like build_fparam_error(sys_index, detail_key) or a
custom exception FparamError(message) and call it from the existing checks where
you currently have raise ValueError(f"numb_fparam > 0 but fparam is not acquired
for system {ii}.") and raise ValueError(f"numb_fparam > 0 but no fparam data is
provided for system {ii}."); ensure the original formatted text is produced by
the helper and that both places (the block referencing frame["acquired_fparam"]
and the block referencing frame["find_fparam"]) use the new centralized routine
so strings are not inline in the raise statements.
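  The centralized-helper approach described above might look like this; the helper name is illustrative, and the messages are the ones quoted from the existing checks:

  ```python
  def _fparam_error(reason: str, sys_index: int) -> ValueError:
      """Build the ValueError outside the raise site so the raise
      statement stays short (satisfies Ruff TRY003)."""
      return ValueError(f"numb_fparam > 0 but {reason} for system {sys_index}.")

  # at the two check sites:
  #   raise _fparam_error("fparam is not acquired", ii)
  #   raise _fparam_error("no fparam data is provided", ii)
  ```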

In `@deepmd/dpmodel/utils/batch.py`:
- Around line 46-47: The current branch handling keys with
key.startswith("find_") uses float(val) which fails for array-valued flags;
change it to a vectorized conversion that preserves array shapes — convert val
to a numpy array and produce a boolean mask (e.g. np.asarray(val) and, if values
are probabilities, compare elementwise to 0.5, or simply astype(bool) for truthy
values) and assign that to out[key] so frame-wise masks and non-scalar inputs
are handled correctly.
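  The vectorized conversion could be sketched as below, assuming the find_ flags are truthy floats (0.0/1.0); the function name is illustrative:

  ```python
  import numpy as np

  def convert_find_flag(val):
      """Array-safe replacement for float(val): preserves per-frame
      shape while mapping truthy entries to boolean flags."""
      return np.asarray(val).astype(bool)

  scalar = convert_find_flag(1.0)          # 0-d array, truthy
  frames = convert_find_flag([1.0, 0.0, 1.0])  # frame-wise mask
  ```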

In `@deepmd/dpmodel/utils/env_mat.py`:
- Around line 48-52: compute_exp_sw may divide by zero when rmin == 0 because a
= C / rmin; update compute_exp_sw to guard rmin <= 0 by either raising a clear
ValueError or clamping rmin to a small positive epsilon before computing a (keep
C and b semantics intact). Locate the computation using symbols a, C, rmin (and
xp.where/xp.full_like) and add the rmin check prior to computing a so that a
cannot become inf/nan and subsequent exp_sw calculations remain valid.
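  A guard of the suggested shape, choosing the raise variant (the function name, message, and the decision to raise rather than clamp are assumptions):

  ```python
  def exp_sw_scale(C: float, rmin: float) -> float:
      """Guard the a = C / rmin division in compute_exp_sw so `a`
      can never become inf/nan."""
      if rmin <= 0.0:
          raise ValueError("rmin must be positive for the exponential switch")
      return C / rmin

  a = exp_sw_scale(2.0, 0.5)
  ```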

In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 199-201: The long exception string in the raise expression should
be pulled out into a separate variable to satisfy Ruff TRY003; create a
descriptive variable (e.g., err_msg) that formats the message using
FLAGS.command, then raise RuntimeError(err_msg) instead of embedding the
f-string directly in the raise; update the raise site that currently uses
RuntimeError(f"Unsupported command '{FLAGS.command}' for the pt_expt backend.")
to use the new variable.
- Around line 80-88: The code fails when training_params["stat_file"] is a
nested path because parent directories aren't created before touching the
file/dir; update the stat_file_path handling so that when stat_file_path is a
file (endswith ".h5" or ".hdf5") you first ensure its parent directory exists
(Path(stat_file_path).parent.mkdir(parents=True, exist_ok=True)) before creating
the h5 file with h5py.File, and when stat_file_path is meant to be a directory
ensure you create all nested directories
(Path(stat_file_path).mkdir(parents=True, exist_ok=True)); afterward wrap the
path with DPPath as before (referencing stat_file_path and DPPath in main.py).

In `@deepmd/pt_expt/model/get_model.py`:
- Line 95: The inline f-string in the raise statement for Unknown fitting type
(raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")) triggers Ruff
TRY003; fix by constructing the error message in a separate variable (e.g., msg
= f"Unknown fitting type: {fitting_net_type}") and then raise the exception with
that variable (raise RuntimeError(msg)) in the get_model.py code path that
handles fitting_net_type so the message construction is not inline in the raise.

In `@deepmd/pt_expt/train/training.py`:
- Around line 747-749: The tuple returned by self.wrapper(...) is being unpacked
into model_pred, loss, more_loss but model_pred is unused; update the unpack
targets in the training loop to use a throwaway name (e.g., _model_pred or _ )
instead of model_pred so RUF059 is satisfied, ensure the unpack pattern that
includes loss and more_loss remains unchanged, and run ruff check . and ruff
format . before committing; locate this change around the call to self.wrapper
in training.py (the training loop that assigns model_pred, loss, more_loss).

In `@source/tests/pt_expt/loss/test_ener.py`:
- Line 136: dp_loss(...) is being unpacked into l_dp and more_dp but more_dp is
never used, causing a Ruff RUF059 lint failure; update the call site in
test_ener.py where dp_loss(learning_rate, natoms, model_pred_np, label_np) is
assigned so the unused value is ignored (e.g., use a single-target assignment or
a throwaway variable) and ensure only l_dp is retained, referencing the dp_loss
call and the l_dp/more_dp names when you make the change.

In `@source/tests/pt_expt/test_training.py`:
- Line 140: Rename the ambiguous loop variable l in the list comprehensions to a
clear name like line (e.g., change lines = [l for l in f.readlines() if not
l.startswith("#")] to use line) for both occurrences mentioned, and fix the
unused unpack of more_loss by either removing the unused variable from the
unpack or replacing it with an underscore (e.g., use _, more_loss = ... or
more_loss, _ = ... as appropriate) so Ruff errors E741 and RUF059 are resolved.


ℹ️ Review info

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a3db25a and 9b430a4.

📒 Files selected for processing (40)
  • deepmd/backend/pt_expt.py
  • deepmd/dpmodel/atomic_model/base_atomic_model.py
  • deepmd/dpmodel/descriptor/make_base_descriptor.py
  • deepmd/dpmodel/fitting/general_fitting.py
  • deepmd/dpmodel/fitting/make_base_fitting.py
  • deepmd/dpmodel/loss/ener.py
  • deepmd/dpmodel/loss/loss.py
  • deepmd/dpmodel/model/base_model.py
  • deepmd/dpmodel/utils/batch.py
  • deepmd/dpmodel/utils/env_mat.py
  • deepmd/dpmodel/utils/env_mat_stat.py
  • deepmd/dpmodel/utils/stat.py
  • deepmd/pt_expt/entrypoints/__init__.py
  • deepmd/pt_expt/entrypoints/main.py
  • deepmd/pt_expt/loss/__init__.py
  • deepmd/pt_expt/loss/ener.py
  • deepmd/pt_expt/model/__init__.py
  • deepmd/pt_expt/model/get_model.py
  • deepmd/pt_expt/train/__init__.py
  • deepmd/pt_expt/train/training.py
  • deepmd/pt_expt/train/wrapper.py
  • deepmd/pt_expt/utils/stat.py
  • deepmd/tf/model/model_stat.py
  • deepmd/utils/argcheck.py
  • deepmd/utils/model_stat.py
  • source/tests/common/dpmodel/test_fitting_stat.py
  • source/tests/consistent/common.py
  • source/tests/consistent/fitting/test_ener.py
  • source/tests/consistent/loss/test_ener.py
  • source/tests/consistent/model/test_dipole.py
  • source/tests/consistent/model/test_dos.py
  • source/tests/consistent/model/test_ener.py
  • source/tests/consistent/model/test_polar.py
  • source/tests/consistent/model/test_property.py
  • source/tests/consistent/test_make_stat_input.py
  • source/tests/pt_expt/conftest.py
  • source/tests/pt_expt/fitting/test_fitting_stat.py
  • source/tests/pt_expt/loss/__init__.py
  • source/tests/pt_expt/loss/test_ener.py
  • source/tests/pt_expt/test_training.py

Han Wang added 3 commits March 2, 2026 18:30
  _compile_model() wraps self.wrapper.model with _CompiledModel, changing
  state_dict keys from model.* to model.original_model.*. When this ran
  before the resume block, load_state_dict would fail on checkpoints saved
  from uncompiled training. Move the compile block after resume so
  checkpoint keys always match, then wrap the restored model.

  Add restart/init_model/restart+compile tests and DPA3 training test.
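  The ordering fix described in this commit can be sketched without torch; the class names are stand-ins for the real wrapper and _CompiledModel:

  ```python
  class DummyModel:
      def __init__(self):
          self.weight = 0.0

  class CompiledModel:
      """Stand-in for _CompiledModel: wrapping shifts state_dict keys
      from model.* to model.original_model.*."""
      def __init__(self, original_model):
          self.original_model = original_model

  def resume_then_compile(model, checkpoint, compile_enabled):
      # 1) restore into the *uncompiled* model so keys match checkpoints
      #    saved from uncompiled training runs
      model.weight = checkpoint["model.weight"]
      # 2) only then wrap for torch.compile-style acceleration
      return CompiledModel(model) if compile_enabled else model

  restored = resume_then_compile(DummyModel(), {"model.weight": 1.5}, True)
  ```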
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (3)
source/tests/pt_expt/test_training.py (1)

1-630: Consider extracting repeated lcurve reading into a helper.

The pattern [line for line in f.readlines() if not line.startswith("#")] appears 6 times. A small helper would reduce duplication and centralize the fix for the l variable issue.

♻️ Suggested helper
def _read_lcurve_data_lines(path: str) -> list[str]:
    """Read non-comment lines from an lcurve file."""
    with open(path) as f:
        return [line for line in f.readlines() if not line.startswith("#")]

Then replace all occurrences with:

lines = _read_lcurve_data_lines(lcurve_path)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt_expt/test_training.py` around lines 1 - 630, Several tests
duplicate the lcurve-reading comprehension ([l for l in f.readlines() if not
l.startswith("#")]) — extract it to a helper function (e.g.,
_read_lcurve_data_lines(path: str) -> list[str]) and replace all inline
occurrences in TestTraining._run_training, TestRestart.test_restart,
TestRestart.test_init_model, TestRestart.test_restart_with_compile,
TestTrainingDPA3.test_training_loop (and any other tests using the same pattern)
to call this helper; ensure the helper opens the file, filters out lines
starting with "#" and returns the list so the tests use lines =
_read_lcurve_data_lines(lcurve_path).
deepmd/pt_expt/train/training.py (2)

69-80: Input dict mutation may cause unexpected side effects.

loss_params is modified in-place by adding "starter_learning_rate". If the caller reuses the original dict, they'll see this key unexpectedly.

🛡️ Suggested fix
 def get_loss(
     loss_params: dict[str, Any],
     start_lr: float,
     _ntypes: int,
     _model: Any,
 ) -> EnergyLoss:
+    loss_params = loss_params.copy()
     loss_type = loss_params.get("type", "ener")
     if loss_type == "ener":
         loss_params["starter_learning_rate"] = start_lr
         return EnergyLoss(**loss_params)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/train/training.py` around lines 69 - 80, The get_loss function
mutates the input loss_params by setting "starter_learning_rate", which can
cause caller-side surprises; fix by making a shallow copy of loss_params (or use
dict unpacking) into a new dict (e.g., params = dict(loss_params) or
{**loss_params}), set params["starter_learning_rate"] = start_lr, and pass
params to EnergyLoss(**params) instead of mutating loss_params directly; keep
the same logic and error handling in get_loss.

244-252: Consider documenting the unused do_atomic_virial parameter.

The parameter is part of the forward signature for API compatibility but is not used. Adding a brief comment or using _do_atomic_virial would clarify intent and silence ARG002.

📝 Suggested fix
     def forward(
         self,
         coord: torch.Tensor,
         atype: torch.Tensor,
         box: torch.Tensor | None = None,
         fparam: torch.Tensor | None = None,
         aparam: torch.Tensor | None = None,
-        do_atomic_virial: bool = False,
+        do_atomic_virial: bool = False,  # noqa: ARG002 - API compatibility
     ) -> dict[str, torch.Tensor]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/train/training.py` around lines 244 - 252, The forward method
currently accepts do_atomic_virial but does not use it; update the signature or
body of the deepmd.pt_expt.train.training.forward function to indicate
intentional unused status by either renaming the parameter to _do_atomic_virial
or adding a short inline comment (e.g., "# kept for API compatibility, unused")
and reference it (e.g., pass or assert False in debug) to silence ARG002 while
preserving API compatibility; ensure the change only affects the forward
signature and local use so callers remain unaffected.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@source/tests/pt_expt/test_training.py`:
- Around line 433-435: Rename the ambiguous comprehension variable "l" to "line"
in all list comprehensions in the test file (the occurrences that build "lines =
[l for l in f.readlines() if not l.startswith('#')]" and the similar
comprehensions later in the same test), i.e., update the iterator variable in
those list comprehensions to "line" and adjust any references inside the
comprehension accordingly; then run "ruff check ." and "ruff format ." to ensure
linter/formatter compliance.
- Around line 617-619: The test uses a single-letter variable `l` in a list
comprehension (e.g., inside the block that opens lcurve_path) which triggers
E741; replace `l` with a descriptive name like `line` in that comprehension
(lines = [line for line in f.readlines() if not line.startswith("#")]) and
update any other identical occurrences in source/tests/pt_expt/test_training.py
to use `line` instead of `l`; after changes run `ruff check .` and `ruff format
.` to ensure style checks pass.

---

Nitpick comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 69-80: The get_loss function mutates the input loss_params by
setting "starter_learning_rate", which can cause caller-side surprises; fix by
making a shallow copy of loss_params (or use dict unpacking) into a new dict
(e.g., params = dict(loss_params) or {**loss_params}), set
params["starter_learning_rate"] = start_lr, and pass params to
EnergyLoss(**params) instead of mutating loss_params directly; keep the same
logic and error handling in get_loss.
- Around line 244-252: The forward method currently accepts do_atomic_virial but
does not use it; update the signature or body of the
deepmd.pt_expt.train.training.forward function to indicate intentional unused
status by either renaming the parameter to _do_atomic_virial or adding a short
inline comment (e.g., "# kept for API compatibility, unused") and reference it
(e.g., pass or assert False in debug) to silence ARG002 while preserving API
compatibility; ensure the change only affects the forward signature and local
use so callers remain unaffected.
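The non-mutating fix suggested for `get_loss` can be sketched as follows; the `EnergyLoss` construction is elided and the function simply returns the built kwargs, since the dict handling is the point (the parameter names mirror the review comment, not the exact source):

```python
def get_loss(loss_params: dict, start_lr: float) -> dict:
    """Build loss-constructor kwargs without mutating the caller's dict."""
    # Shallow copy via dict unpacking; the caller's loss_params is untouched.
    params = {**loss_params, "starter_learning_rate": start_lr}
    # In the real code this would be passed on as EnergyLoss(**params).
    return params

caller_params = {"type": "ener", "start_pref_e": 0.02}
built = get_loss(caller_params, start_lr=1e-3)
```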

In `@source/tests/pt_expt/test_training.py`:
- Around line 1-630: Several tests duplicate the lcurve-reading comprehension
([l for l in f.readlines() if not l.startswith("#")]) — extract it to a helper
function (e.g., _read_lcurve_data_lines(path: str) -> list[str]) and replace all
inline occurrences in TestTraining._run_training, TestRestart.test_restart,
TestRestart.test_init_model, TestRestart.test_restart_with_compile,
TestTrainingDPA3.test_training_loop (and any other tests using the same pattern)
to call this helper; ensure the helper opens the file, filters out lines
starting with "#" and returns the list so the tests use lines =
_read_lcurve_data_lines(lcurve_path).
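A sketch of the suggested helper, demonstrated against a throwaway temp file standing in for the training lcurve output (the demo file contents are hypothetical):

```python
import tempfile

def _read_lcurve_data_lines(path: str) -> list[str]:
    """Return lcurve lines with '#' comment lines filtered out."""
    with open(path) as f:
        return [line for line in f.readlines() if not line.startswith("#")]

# Demo with a throwaway file standing in for the training lcurve output.
with tempfile.NamedTemporaryFile("w", suffix=".out", delete=False) as tmp:
    tmp.write("# step rmse\n1 0.5\n2 0.4\n")
    lcurve_path = tmp.name

lines = _read_lcurve_data_lines(lcurve_path)
```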

ℹ️ Review info

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cb9f55b and 99c0221.

📒 Files selected for processing (2)
  • deepmd/pt_expt/train/training.py
  • source/tests/pt_expt/test_training.py

@njzjz njzjz linked an issue Mar 2, 2026 that may be closed by this pull request
  Introduce a _update_sel_cls class variable on dpmodel descriptors so
  that each backend can supply its own UpdateSel / NeighborStat.  The
  dpmodel default stays on CPU (numpy); pt_expt overrides the class
  variable with a GPU-aware UpdateSel backed by a @torch_module-wrapped
  NeighborStatOP, so array_api_compat dispatches to torch on DEVICE.
@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
deepmd/pt_expt/utils/neighbor_stat.py (1)

63-72: Consider hoisting data_set assignment outside the inner loop.

The data_set variable only depends on ii and doesn't change within the inner loop over jj. Moving it outside the inner loop would be slightly cleaner.

♻️ Suggested refactor
         for ii in range(len(data.system_dirs)):
+            data_set = data.data_systems[ii]
             for jj in data.data_systems[ii].dirs:
-                data_set = data.data_systems[ii]
                 data_set_data = data_set._load_set(jj)
                 minrr2, max_nnei = self._execute(
                     data_set_data["coord"],
                     data_set_data["type"],
                     data_set_data["box"] if data_set.pbc else None,
                 )
                 yield np.max(max_nnei, axis=0), np.min(minrr2), jj
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/utils/neighbor_stat.py` around lines 63 - 72, The inner loop
currently reassigns data_set each iteration even though it only depends on ii;
move the statement data_set = data.data_systems[ii] immediately before the inner
loop (so it is set once per outer loop) and remove the redundant assignment
inside the for jj loop, leaving uses of data_set._load_set(jj) and data_set.pbc
unchanged; ensure indentation is adjusted accordingly in the function that
contains the loops.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@deepmd/pt_expt/utils/neighbor_stat.py`:
- Around line 63-72: The inner loop currently reassigns data_set each iteration
even though it only depends on ii; move the statement data_set =
data.data_systems[ii] immediately before the inner loop (so it is set once per
outer loop) and remove the redundant assignment inside the for jj loop, leaving
uses of data_set._load_set(jj) and data_set.pbc unchanged; ensure indentation is
adjusted accordingly in the function that contains the loops.
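The hoisting refactor above can be illustrated with stand-in data structures; `FakeSystem` and the literals below are hypothetical substitutes for the real data systems, which carry `_load_set` and `pbc` in addition to `dirs`:

```python
# FakeSystem and the literals below are stand-ins for the real data systems.
class FakeSystem:
    def __init__(self, name: str, dirs: list[str]) -> None:
        self.name = name
        self.dirs = dirs

systems = [FakeSystem("A", ["set.000", "set.001"]), FakeSystem("B", ["set.000"])]

visited = []
for ii in range(len(systems)):
    data_set = systems[ii]  # hoisted: depends only on ii, not on jj
    for jj in data_set.dirs:
        visited.append((data_set.name, jj))
```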


📥 Commits

Reviewing files that changed from the base of the PR and between 99c0221 and f0f260d.

📒 Files selected for processing (17)
  • deepmd/dpmodel/descriptor/dpa1.py
  • deepmd/dpmodel/descriptor/dpa2.py
  • deepmd/dpmodel/descriptor/dpa3.py
  • deepmd/dpmodel/descriptor/se_e2_a.py
  • deepmd/dpmodel/descriptor/se_r.py
  • deepmd/dpmodel/descriptor/se_t.py
  • deepmd/dpmodel/descriptor/se_t_tebd.py
  • deepmd/pt_expt/descriptor/dpa1.py
  • deepmd/pt_expt/descriptor/dpa2.py
  • deepmd/pt_expt/descriptor/dpa3.py
  • deepmd/pt_expt/descriptor/se_atten_v2.py
  • deepmd/pt_expt/descriptor/se_e2_a.py
  • deepmd/pt_expt/descriptor/se_r.py
  • deepmd/pt_expt/descriptor/se_t.py
  • deepmd/pt_expt/descriptor/se_t_tebd.py
  • deepmd/pt_expt/utils/neighbor_stat.py
  • deepmd/pt_expt/utils/update_sel.py


@njzjz njzjz left a comment

LGTM

wanghan-iapcm pushed a commit to wanghan-iapcm/deepmd-kit that referenced this pull request Mar 4, 2026
@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Mar 4, 2026
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to a conflict with the base branch Mar 4, 2026
@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Mar 5, 2026
Merged via the queue into deepmodeling:master with commit fdeff2b Mar 5, 2026
70 checks passed
@wanghan-iapcm wanghan-iapcm deleted the feat-pt-expt-train branch March 5, 2026 07:36


Development

Successfully merging this pull request may close these issues.

PyTorch Exportable Training Loop Implementation

4 participants