-
Notifications
You must be signed in to change notification settings - Fork 608
feat(pt_expt): add dp freeze support and dp test tests for .pte models #5302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
35745c1
11f96ba
e4298d3
33fe75e
1d99141
8479016
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -160,6 +160,64 @@ def train( | |
| trainer.run() | ||
|
|
||
|
|
||
def freeze(
    model: str,
    output: str = "frozen_model.pte",
    head: str | None = None,
) -> None:
    """Freeze a pt_expt checkpoint into a .pte exported model.

    The checkpoint is read with ``torch.load``; the model is rebuilt from the
    ``model_params`` stored under the state dict's ``_extra_state`` entry, the
    weights are restored through a ``ModelWrapper``, and the serialized model
    is passed to ``deserialize_to_file`` to write ``output``.

    Parameters
    ----------
    model : str
        Path to the checkpoint file (.pt).
    output : str
        Path for the output .pte file.
    head : str or None
        Head to freeze in multi-task mode (not yet supported). For a
        single-task checkpoint the value is unused; a warning is logged when
        one is supplied so the request is not dropped silently.

    Raises
    ------
    ValueError
        If the checkpoint's state dict has no ``_extra_state`` dict that
        contains ``model_params``.
    NotImplementedError
        If the stored ``model_params`` contain ``model_dict`` (multi-task).
    """
    import torch

    from deepmd.pt_expt.model.get_model import (
        get_model,
    )
    from deepmd.pt_expt.train.wrapper import (
        ModelWrapper,
    )
    from deepmd.pt_expt.utils.env import (
        DEVICE,
    )
    from deepmd.pt_expt.utils.serialization import (
        deserialize_to_file,
    )

    state_dict = torch.load(model, map_location=DEVICE, weights_only=True)
    # The weights may be nested under a top-level "model" key; unwrap them
    # when that key is present.
    if "model" in state_dict:
        state_dict = state_dict["model"]

    extra_state = state_dict.get("_extra_state")
    if not isinstance(extra_state, dict) or "model_params" not in extra_state:
        raise ValueError(
            f"Unsupported checkpoint format at '{model}': missing "
            "'_extra_state.model_params' in model state dict."
        )
    model_params = extra_state["model_params"]

    if "model_dict" in model_params:
        raise NotImplementedError(
            "Multi-task freeze is not yet supported for the pt_expt backend."
        )

    if head is not None:
        # Only single-task checkpoints reach this point, so there is no head
        # to select; tell the user instead of ignoring the argument silently.
        log.warning(
            "Ignoring head '%s': the checkpoint holds a single-task model.",
            head,
        )

    m = get_model(model_params)
    # Load through the wrapper so the state-dict keys line up with the ones
    # that were saved.
    wrapper = ModelWrapper(m)
    wrapper.load_state_dict(state_dict)
    m.eval()

    model_dict = m.serialize()
    deserialize_to_file(output, {"model": model_dict})
    log.info("Saved frozen model to %s", output)
|
|
||
|
|
||
| def main(args: list[str] | argparse.Namespace | None = None) -> None: | ||
| """Entry point for the pt_expt backend CLI. | ||
|
|
||
|
|
@@ -195,6 +253,28 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None: | |
| skip_neighbor_stat=FLAGS.skip_neighbor_stat, | ||
| output=FLAGS.output, | ||
| ) | ||
| elif FLAGS.command == "freeze": | ||
| if Path(FLAGS.checkpoint_folder).is_dir(): | ||
| checkpoint_path = Path(FLAGS.checkpoint_folder) | ||
| # pt_expt training saves a symlink "model.ckpt.pt" → latest ckpt | ||
| default_ckpt = checkpoint_path / "model.ckpt.pt" | ||
| if default_ckpt.exists(): | ||
| FLAGS.model = str(default_ckpt) | ||
| else: | ||
| raise FileNotFoundError( | ||
| f"Cannot find checkpoint in '{checkpoint_path}'. " | ||
| "Expected 'model.ckpt.pt' (created by pt_expt training)." | ||
| ) | ||
| else: | ||
| model_path = Path(FLAGS.checkpoint_folder) | ||
| if not model_path.exists(): | ||
| raise FileNotFoundError( | ||
| f"Checkpoint path '{model_path}' does not exist." | ||
| ) | ||
| FLAGS.model = str(model_path) | ||
| if not FLAGS.output.endswith((".pte", ".pt2")): | ||
|
wanghan-iapcm marked this conversation as resolved.
|
||
| FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte")) | ||
|
Comment on lines
+275
to
+276
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Verify whether `.pt2` is supported wherever `.pte` is consumed.
# Expected: extension routing/loading code should explicitly accept both,
# otherwise prefer normalizing freeze output to `.pte` only.
rg -nP --type=py -C3 '\.pte\b|\.pt2\b'
rg -nP --type=py -C4 'freeze|test|infer|load.*model|endswith\(\(.*\.pte|\.pt2'Repository: deepmodeling/deepmd-kit Length of output: 50381 🏁 Script executed: # Search for code that loads .pte or .pt2 files in inference/test paths
rg -n --type=py '\.pte|\.pt2' --max-count=100 deepmd/pt_expt/infer/ deepmd/pt_expt/utils/serialization.py deepmd/backend/pt_expt.py
# Check the DeepEval class that loads models
rg -n --type=py -B2 -A5 'class DeepEval|def.*__init__' deepmd/pt_expt/infer/deep_eval.py | head -50
# Search for any explicit .pt2 handling
rg -n --type=py 'pt2' deepmd/Repository: deepmodeling/deepmd-kit Length of output: 1948 🏁 Script executed: # Check if .pt2 is used anywhere else in tests or loading code
rg -n 'pt2' source/tests/ deepmd/pt_expt/
# Check if there's any model loading with file extension handling
rg -n --type=py 'load.*pte|pte.*load|suffix' deepmd/pt_expt/infer/ deepmd/pt_expt/utils/
# Look for file operations that might accept both extensions
rg -n --type=py 'endswith|suffix|extension' deepmd/pt_expt/ -A2 -B2 | head -80Repository: deepmodeling/deepmd-kit Length of output: 4004 🏁 Script executed: # Final verification: check if pt2 is referenced in any way outside main.py
rg -r 'pt2' --type=py . | grep -v 'test_hybrid_muon\|test_adamuon\|descrpt2r4'
# Verify the backend suffixes definition once more
cat -n deepmd/backend/pt_expt.py | sed -n '40,50p'Repository: deepmodeling/deepmd-kit Length of output: 50382 🏁 Script executed: # Get the backend file content to see the suffix definition
cat -n deepmd/backend/pt_expt.py | sed -n '40,50p'
# Search more specifically for code that reads .pte files
rg -n --type=py 'torch.export.load|load.*pte' deepmd/
Repository: deepmodeling/deepmd-kit
Length of output: 654

Tighten output normalization to accept only `.pte`.

Lines 275-276 preserve both `.pte` and `.pt2` output suffixes.

Suggested fix:

- if not FLAGS.output.endswith((".pte", ".pt2")):
+ if not FLAGS.output.endswith(".pte"):
      FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))

🤖 Prompt for AI Agents |
||
| freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head) | ||
| else: | ||
| raise RuntimeError( | ||
| f"Unsupported command '{FLAGS.command}' for the pt_expt backend." | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| import argparse | ||
| import os | ||
| import shutil | ||
| import tempfile | ||
| import unittest | ||
| from copy import ( | ||
| deepcopy, | ||
| ) | ||
|
|
||
| import torch | ||
|
|
||
| from deepmd.pt_expt.entrypoints.main import ( | ||
| freeze, | ||
| main, | ||
| ) | ||
| from deepmd.pt_expt.model.get_model import ( | ||
| get_model, | ||
| ) | ||
| from deepmd.pt_expt.train.wrapper import ( | ||
| ModelWrapper, | ||
| ) | ||
|
|
||
# Small se_e2_a model configuration; the tests deep-copy it before building
# the throw-away checkpoint that gets frozen.
model_se_e2_a = {
    "type_map": ["O", "H", "B"],
    # Descriptor section.
    "descriptor": {
        "type": "se_e2_a",
        "sel": [46, 92, 4],
        "rcut_smth": 0.5,
        "rcut": 4.0,
        "neuron": [25, 50, 100],
        "resnet_dt": False,
        "axis_neuron": 16,
        "seed": 1,
    },
    # Fitting-network section.
    "fitting_net": {
        "neuron": [24, 24, 24],
        "resnet_dt": True,
        "seed": 1,
    },
    "data_stat_nbatch": 20,
}
|
|
||
|
|
||
class TestDPFreezePtExpt(unittest.TestCase):
    """Test dp freeze for the pt_expt backend."""

    @classmethod
    def setUpClass(cls) -> None:
        """Create a scratch directory containing a freshly built checkpoint."""
        cls.tmpdir = tempfile.mkdtemp()
        # Class cleanups still run when the rest of setUpClass raises, whereas
        # tearDownClass would be skipped and the directory left behind.
        cls.addClassCleanup(shutil.rmtree, cls.tmpdir)

        # Build a model and save a fake checkpoint
        model_params = deepcopy(model_se_e2_a)
        model = get_model(model_params)
        wrapper = ModelWrapper(model, model_params=model_params)
        cls.ckpt_file = os.path.join(cls.tmpdir, "model.pt")
        torch.save({"model": wrapper.state_dict()}, cls.ckpt_file)

    def _freeze_flags(self, output: str) -> argparse.Namespace:
        """Return the namespace handed to main() for a freeze run to *output*."""
        return argparse.Namespace(
            command="freeze",
            checkpoint_folder=self.ckpt_file,
            output=output,
            head=None,
            log_level=2,  # WARNING
            log_path=None,
        )

    def test_freeze_pte(self) -> None:
        """Freeze to .pte and verify the file is created."""
        output = os.path.join(self.tmpdir, "frozen_model.pte")
        freeze(model=self.ckpt_file, output=output)
        self.assertTrue(os.path.exists(output))

    def test_freeze_main_dispatcher(self) -> None:
        """Test main() CLI dispatcher with freeze command."""
        output_file = os.path.join(self.tmpdir, "frozen_via_main.pte")
        main(self._freeze_flags(output_file))
        self.assertTrue(os.path.exists(output_file))

    def test_freeze_default_suffix(self) -> None:
        """Test that main() defaults output suffix to .pte."""
        output_file = os.path.join(self.tmpdir, "frozen_default_suffix.pth")
        main(self._freeze_flags(output_file))
        expected = os.path.join(self.tmpdir, "frozen_default_suffix.pte")
        self.assertTrue(os.path.exists(expected))
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| import json | ||
| import os | ||
| import shutil | ||
| import tempfile | ||
| import unittest | ||
| from copy import ( | ||
| deepcopy, | ||
| ) | ||
| from pathlib import ( | ||
| Path, | ||
| ) | ||
|
|
||
| import torch | ||
|
|
||
| from deepmd.entrypoints.test import test as dp_test | ||
| from deepmd.pt_expt.entrypoints.main import ( | ||
| freeze, | ||
| ) | ||
| from deepmd.pt_expt.model.get_model import ( | ||
| get_model, | ||
| ) | ||
| from deepmd.pt_expt.train.wrapper import ( | ||
| ModelWrapper, | ||
| ) | ||
|
|
||
# Small se_e2_a model configuration; used both to build the checkpoint that is
# frozen to .pte and as the "model" section of the JSON input for dp test.
model_se_e2_a = {
    "type_map": ["O", "H", "B"],
    # Descriptor section.
    "descriptor": {
        "type": "se_e2_a",
        "sel": [46, 92, 4],
        "rcut_smth": 0.5,
        "rcut": 4.0,
        "neuron": [25, 50, 100],
        "resnet_dt": False,
        "axis_neuron": 16,
        "seed": 1,
    },
    # Fitting-network section.
    "fitting_net": {
        "neuron": [24, 24, 24],
        "resnet_dt": True,
        "seed": 1,
    },
    "data_stat_nbatch": 20,
}
|
|
||
|
|
||
class TestDPTestPtExpt(unittest.TestCase):
    """Test dp test for the pt_expt backend (.pte models)."""

    @classmethod
    def setUpClass(cls) -> None:
        """Freeze a freshly built model to .pte for the tests to evaluate."""
        cls.data_file = str(
            Path(__file__).parents[1] / "pt" / "water" / "data" / "single"
        )

        # Register each scratch directory for removal as soon as it exists:
        # class cleanups run even if a later step here raises, whereas
        # tearDownClass would be skipped and both directories leaked.
        detail_dir = tempfile.mkdtemp()
        cls.addClassCleanup(shutil.rmtree, detail_dir, ignore_errors=True)
        cls.detail_file = os.path.join(detail_dir, "test_dp_test_pt_expt_detail")

        cls.tmpdir = tempfile.mkdtemp()
        cls.addClassCleanup(shutil.rmtree, cls.tmpdir)

        # Build a model, save a checkpoint, and freeze to .pte
        model_params = deepcopy(model_se_e2_a)
        model = get_model(model_params)
        wrapper = ModelWrapper(model, model_params=model_params)
        ckpt_file = os.path.join(cls.tmpdir, "model.pt")
        torch.save({"model": wrapper.state_dict()}, ckpt_file)

        cls.pte_file = os.path.join(cls.tmpdir, "frozen_model.pte")
        freeze(model=ckpt_file, output=cls.pte_file)

    def _run_dp_test(self, detail: str, **data_source) -> None:
        """Run dp test on the frozen model, writing detail files under *detail*.

        ``data_source`` carries the arguments that differ between tests
        (``system`` and, optionally, ``valid_json``).
        """
        dp_test(
            model=self.pte_file,
            datafile=None,
            set_prefix="set",
            numb_test=0,
            rand_seed=None,
            shuffle_test=False,
            detail_file=detail,
            atomic=False,
            **data_source,
        )

    def test_dp_test_system(self) -> None:
        """Test dp test with -s system path."""
        detail = self.detail_file + "_sys"
        self._run_dp_test(detail, system=self.data_file)
        for suffix in (".e.out", ".f.out", ".v.out"):
            self.assertTrue(os.path.exists(detail + suffix))

    def test_dp_test_input_json(self) -> None:
        """Test dp test with --valid-data JSON input."""
        config = {
            "model": deepcopy(model_se_e2_a),
            "training": {
                "training_data": {"systems": [self.data_file]},
                "validation_data": {"systems": [self.data_file]},
            },
        }
        input_json = os.path.join(self.tmpdir, "test_input.json")
        with open(input_json, "w", encoding="utf-8") as fp:
            json.dump(config, fp, indent=4)

        detail = self.detail_file + "_json"
        self._run_dp_test(detail, system=None, valid_json=input_json)
        self.assertTrue(os.path.exists(detail + ".e.out"))
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
Uh oh!
There was an error while loading. Please reload this page.