feat(pt_expt): add dp freeze support and dp test tests for .pte models#5302
wanghan-iapcm wants to merge 1 commit into deepmodeling:master
Add freeze() function to pt_expt backend that loads a .pt checkpoint, reconstructs the model, serializes it, and exports to .pte via deserialize_to_file. Wire the freeze command in the main() CLI dispatcher. Add separate test files for dp freeze (test_dp_freeze.py) and dp test (test_dp_test.py) verifying the full freeze-then-test pipeline works end-to-end with .pte models.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 35745c127e
    if Path(FLAGS.checkpoint_folder).is_dir():
        checkpoint_path = Path(FLAGS.checkpoint_folder)
        latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
        FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
Support pt_expt checkpoint directory layout in freeze
When dp freeze is run with a directory (-c default is .), this branch unconditionally reads <dir>/checkpoint, but pt_expt training writes checkpoints as <save_ckpt>-<step>.pt plus a <save_ckpt>.pt symlink and does not create a checkpoint file (see deepmd/pt_expt/train/training.py::save_checkpoint). As a result, the normal post-training workflow dp --pt-expt freeze fails with FileNotFoundError before freeze() is called unless users manually pass a .pt file path.
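One way the directory branch could tolerate both layouts is sketched below. This is a hypothetical helper, not code from the PR: the name `resolve_checkpoint`, the `save_ckpt` default of `"model"`, and the fallback order are all assumptions based only on the layout described in this comment.

```python
import re
from pathlib import Path


def resolve_checkpoint(folder, save_ckpt="model"):
    """Return the checkpoint file to freeze from `folder` (hypothetical helper)."""
    folder = Path(folder)

    # Layout 1: a text file named "checkpoint" that holds the latest file name.
    pointer = folder / "checkpoint"
    if pointer.is_file():
        return folder / pointer.read_text().strip()

    # Layout 2: "<save_ckpt>.pt", a symlink (or file) for the latest step.
    latest = folder / f"{save_ckpt}.pt"
    if latest.exists():
        return latest

    # Fallback: pick the highest step among "<save_ckpt>-<step>.pt" files.
    pattern = re.compile(rf"{re.escape(save_ckpt)}-(\d+)\.pt")
    candidates = []
    for path in folder.glob(f"{save_ckpt}-*.pt"):
        match = pattern.fullmatch(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        raise FileNotFoundError(f"no checkpoint found in {folder}")
    return max(candidates, key=lambda item: item[0])[1]
```

Comparing steps as integers rather than as strings matters here: a lexicographic sort would rank `model-900.pt` above `model-2000.pt`.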
📝 Walkthrough
A new "freeze" CLI command is added to the pt_expt backend for serializing PyTorch model checkpoints to .pte format. The implementation loads model checkpoints, validates multi-task compatibility, builds models, and serializes to frozen output files. Comprehensive tests validate freeze functionality and dp_test command integration with frozen models.
Sequence Diagram
sequenceDiagram
participant CLI as main()
participant Freeze as freeze()
participant Model as ModelWrapper
participant File as Output File
CLI->>Freeze: freeze(checkpoint_path, output_file, head)
Freeze->>Freeze: Resolve & validate checkpoint path
Freeze->>Freeze: Load checkpoint state_dict
Freeze->>Model: get_model() → build model
Model->>Model: Initialize model parameters
Freeze->>Model: ModelWrapper(model) + load state_dict
Model->>Model: Serialize to .pte format
Freeze->>File: deserialize_to_file() write
File->>File: Save frozen model
Freeze->>CLI: Log success message
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
🧹 Nitpick comments (2)
source/tests/pt_expt/test_dp_test.py (1)
27-45: Consider extracting shared test fixtures.
The `model_se_e2_a` configuration and checkpoint creation pattern are duplicated in `test_dp_freeze.py`. Consider extracting these to a shared conftest or helper module to reduce duplication.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_dp_test.py` around lines 27 - 45, The model configuration dictionary model_se_e2_a and the checkpoint creation logic duplicated between test_dp_test.py and test_dp_freeze.py should be extracted into a shared pytest fixture or helper function (e.g., in conftest.py or a test_helpers module); create a fixture named model_se_e2_a that returns the dict and a helper fixture/function (e.g., make_checkpoint or checkpoint_fixture) that encapsulates the checkpoint creation pattern, then update both tests to accept those fixtures instead of redefining the dict/checkpoint code so duplication is removed and maintenance is centralized.
deepmd/pt_expt/entrypoints/main.py (1)
256-257: Minor: `.pt2` suffix check might be undocumented.
The code accepts both `.pte` and `.pt2` suffixes, but the docstring and default only mention `.pte`. Consider documenting `.pt2` if it's intentionally supported, or remove it if not needed.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/entrypoints/main.py` around lines 256 - 257, The FLAGS.output handling accepts both ".pte" and ".pt2" but only ".pte" is documented; decide whether ".pt2" is intentional and then update code accordingly: if intended, add ".pt2" to the module/docstring and the FLAGS.output help/default text (where FLAGS is defined) and update any docs/tests to mention ".pt2"; otherwise remove ".pt2" from the tuple in the conditional so FLAGS.output only normalizes to ".pte". Ensure changes reference FLAGS.output and the suffix check in main.py.
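The suffix handling under discussion can be illustrated with a minimal sketch. It assumes the accepted suffixes are exactly `.pte` and `.pt2` and that any other suffix is replaced rather than appended to; `normalize_output` and `ACCEPTED_SUFFIXES` are hypothetical names, not identifiers from `main.py`.

```python
from pathlib import Path

# Assumption: these are the only suffixes treated as already valid.
ACCEPTED_SUFFIXES = (".pte", ".pt2")


def normalize_output(output: str) -> str:
    """Return `output` unchanged if its suffix is accepted, else force `.pte`."""
    if output.endswith(ACCEPTED_SUFFIXES):
        return output
    # with_suffix replaces an existing suffix or appends one if none is present.
    return str(Path(output).with_suffix(".pte"))
```

Keeping the accepted suffixes in one named tuple gives the help text and docstring a single place to stay in sync with, whichever way the `.pt2` question is decided.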
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 250-253: The code reading the checkpoint name using
(checkpoint_path / "checkpoint").read_text() assigns latest_ckpt_file with
possible trailing newline/whitespace which breaks FLAGS.model path construction;
update the read to strip whitespace (e.g., call .strip() on the result) before
using checkpoint_path.joinpath and set FLAGS.model =
str(checkpoint_path.joinpath(latest_ckpt_file.strip())), ensuring you reference
FLAGS.checkpoint_folder, checkpoint_path, latest_ckpt_file and FLAGS.model when
making the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: e8b4749c-4b84-4f85-a10c-9f671ace384a
📒 Files selected for processing (3)
- deepmd/pt_expt/entrypoints/main.py
- source/tests/pt_expt/test_dp_freeze.py
- source/tests/pt_expt/test_dp_test.py
    if Path(FLAGS.checkpoint_folder).is_dir():
        checkpoint_path = Path(FLAGS.checkpoint_folder)
        latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
        FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
Strip whitespace from checkpoint filename to avoid path errors.
read_text() preserves trailing newlines. If the checkpoint file contains "model-100.pt\n", the constructed path will be invalid.
Proposed fix
 if Path(FLAGS.checkpoint_folder).is_dir():
     checkpoint_path = Path(FLAGS.checkpoint_folder)
-    latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
+    latest_ckpt_file = (checkpoint_path / "checkpoint").read_text().strip()
     FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
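The failure mode is easy to reproduce in isolation. The sketch below is illustrative only: `read_pointer` is a hypothetical helper that mirrors the two variants of the line in question, and the file names are made up for the demonstration.

```python
import tempfile
from pathlib import Path


def read_pointer(folder: Path, strip: bool) -> Path:
    """Build the checkpoint path from the `checkpoint` pointer file."""
    name = (folder / "checkpoint").read_text()
    if strip:
        name = name.strip()
    return folder / name


with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    (folder / "model-100.pt").write_bytes(b"")
    # A trailing newline is what most editors and `echo` would leave behind.
    (folder / "checkpoint").write_text("model-100.pt\n")
    raw_exists = read_pointer(folder, strip=False).exists()
    stripped_exists = read_pointer(folder, strip=True).exists()
```

Here `raw_exists` is false because the joined path ends in a newline, while `stripped_exists` is true.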
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5302      +/-   ##
==========================================
+ Coverage   82.32%   82.33%   +0.01%
==========================================
  Files         768      768
  Lines       77098    77125      +27
  Branches     3659     3660       +1
==========================================
+ Hits        63469    63500      +31
+ Misses      12458    12453       -5
- Partials     1171     1172       +1
☔ View full report in Codecov by Sentry.
Summary
- `dp freeze` support for the pt_expt backend, enabling checkpoint `.pt` → exported `.pte` conversion
- Tests for `dp freeze` and `dp test` with `.pte` models

Background
The pt_expt backend can export models to `.pte` via `deserialize_to_file()`, and `dp test` can already load `.pte` models through the registered `DeepEval`. However, `dp freeze` was not wired up: calling `dp freeze -b pt-expt` hit `RuntimeError: Unsupported command 'freeze'`.

Changes
`deepmd/pt_expt/entrypoints/main.py`
- `freeze()` function: loads `.pt` checkpoint → reconstructs model via `get_model` + `ModelWrapper` → serializes → exports to `.pte` via `deserialize_to_file`
- `freeze` command in `main()` dispatcher with checkpoint directory resolution and `.pte` default suffix

`source/tests/pt_expt/test_dp_freeze.py` (new)
- `test_freeze_pte`: verify `.pte` file is created from checkpoint
- `test_freeze_main_dispatcher`: test `main()` CLI dispatcher with freeze command
- `test_freeze_default_suffix`: verify non-`.pte` output suffix is corrected to `.pte`

`source/tests/pt_expt/test_dp_test.py` (new)
- `test_dp_test_system`: test `dp test` with `-s` system path, verify `.e.out`, `.f.out`, `.v.out` outputs
- `test_dp_test_input_json`: test `dp test` with `--valid-data` JSON input

Test plan
- `python -m pytest source/tests/pt_expt/test_dp_freeze.py -v` (3 passed)
- `python -m pytest source/tests/pt_expt/test_dp_test.py -v` (2 passed)