Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions deepmd/pt_expt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,64 @@ def train(
trainer.run()


def freeze(
model: str,
output: str = "frozen_model.pte",
head: str | None = None,
) -> None:
"""Freeze a pt_expt checkpoint into a .pte exported model.

Parameters
----------
model : str
Path to the checkpoint file (.pt).
output : str
Path for the output .pte file.
head : str or None
Head to freeze in multi-task mode (not yet supported).
"""
import torch

from deepmd.pt_expt.model.get_model import (
get_model,
)
from deepmd.pt_expt.train.wrapper import (
ModelWrapper,
)
from deepmd.pt_expt.utils.env import (
DEVICE,
)
from deepmd.pt_expt.utils.serialization import (
deserialize_to_file,
)

state_dict = torch.load(model, map_location=DEVICE, weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]

Comment thread
wanghan-iapcm marked this conversation as resolved.
extra_state = state_dict.get("_extra_state")
if not isinstance(extra_state, dict) or "model_params" not in extra_state:
raise ValueError(
f"Unsupported checkpoint format at '{model}': missing "
"'_extra_state.model_params' in model state dict."
)
model_params = extra_state["model_params"]

if "model_dict" in model_params:
raise NotImplementedError(
"Multi-task freeze is not yet supported for the pt_expt backend."
)

m = get_model(model_params)
wrapper = ModelWrapper(m)
wrapper.load_state_dict(state_dict)
m.eval()

model_dict = m.serialize()
deserialize_to_file(output, {"model": model_dict})
log.info("Saved frozen model to %s", output)


def main(args: list[str] | argparse.Namespace | None = None) -> None:
"""Entry point for the pt_expt backend CLI.

Expand Down Expand Up @@ -195,6 +253,28 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
skip_neighbor_stat=FLAGS.skip_neighbor_stat,
output=FLAGS.output,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
# pt_expt training saves a symlink "model.ckpt.pt" → latest ckpt
default_ckpt = checkpoint_path / "model.ckpt.pt"
if default_ckpt.exists():
FLAGS.model = str(default_ckpt)
else:
raise FileNotFoundError(
f"Cannot find checkpoint in '{checkpoint_path}'. "
"Expected 'model.ckpt.pt' (created by pt_expt training)."
)
else:
model_path = Path(FLAGS.checkpoint_folder)
if not model_path.exists():
raise FileNotFoundError(
f"Checkpoint path '{model_path}' does not exist."
)
FLAGS.model = str(model_path)
if not FLAGS.output.endswith((".pte", ".pt2")):
Comment thread
wanghan-iapcm marked this conversation as resolved.
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
Comment on lines +275 to +276
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether `.pt2` is supported wherever `.pte` is consumed.
# Expected: extension routing/loading code should explicitly accept both,
# otherwise prefer normalizing freeze output to `.pte` only.

rg -nP --type=py -C3 '\.pte\b|\.pt2\b'
rg -nP --type=py -C4 'freeze|test|infer|load.*model|endswith\(\(.*\.pte|\.pt2'

Repository: deepmodeling/deepmd-kit

Length of output: 50381


🏁 Script executed:

# Search for code that loads .pte or .pt2 files in inference/test paths
rg -n --type=py '\.pte|\.pt2' --max-count=100 deepmd/pt_expt/infer/ deepmd/pt_expt/utils/serialization.py deepmd/backend/pt_expt.py

# Check the DeepEval class that loads models
rg -n --type=py -B2 -A5 'class DeepEval|def.*__init__' deepmd/pt_expt/infer/deep_eval.py | head -50

# Search for any explicit .pt2 handling
rg -n --type=py 'pt2' deepmd/

Repository: deepmodeling/deepmd-kit

Length of output: 1948


🏁 Script executed:

# Check if .pt2 is used anywhere else in tests or loading code
rg -n 'pt2' source/tests/ deepmd/pt_expt/

# Check if there's any model loading with file extension handling
rg -n --type=py 'load.*pte|pte.*load|suffix' deepmd/pt_expt/infer/ deepmd/pt_expt/utils/

# Look for file operations that might accept both extensions
rg -n --type=py 'endswith|suffix|extension' deepmd/pt_expt/ -A2 -B2 | head -80

Repository: deepmodeling/deepmd-kit

Length of output: 4004


🏁 Script executed:

# Final verification: check if pt2 is referenced in any way outside main.py
rg -r 'pt2' --type=py . | grep -v 'test_hybrid_muon\|test_adamuon\|descrpt2r4'

# Verify the backend suffixes definition once more
cat -n deepmd/backend/pt_expt.py | sed -n '40,50p'

Repository: deepmodeling/deepmd-kit

Length of output: 50382


🏁 Script executed:

# Get the backend file content to see the suffix definition
cat -n deepmd/backend/pt_expt.py | sed -n '40,50p'

# Search more specifically for code that reads .pte files  
rg -n --type=py 'torch.export.load|load.*pte' deepmd/

Repository: deepmodeling/deepmd-kit

Length of output: 654


Tighten output normalization to accept only .pte extension.

Lines 275-276 preserve both .pte and .pt2 suffixes, but the backend and all loaders (serialize_from_file, DeepEval, torch.export.load) only support .pte. If users provide .pt2 output, the resulting files cannot be loaded downstream, creating broken artifacts.

Suggested fix
-        if not FLAGS.output.endswith((".pte", ".pt2")):
+        if not FLAGS.output.endswith(".pte"):
             FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/entrypoints/main.py` around lines 275 - 276, The output
normalization should only accept the .pte extension: change the condition that
checks FLAGS.output so it only allows ".pte" (i.e., replace the
endswith((".pte", ".pt2")) check with endswith(".pte")), and if it doesn't, set
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte")); update the logic
around FLAGS.output and Path.with_suffix to drop support for ".pt2" so
downstream loaders (serialize_from_file, DeepEval, torch.export.load) always get
a .pte file.

freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
else:
raise RuntimeError(
f"Unsupported command '{FLAGS.command}' for the pt_expt backend."
Expand Down
101 changes: 101 additions & 0 deletions source/tests/pt_expt/test_dp_freeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
)

import torch

from deepmd.pt_expt.entrypoints.main import (
freeze,
main,
)
from deepmd.pt_expt.model.get_model import (
get_model,
)
from deepmd.pt_expt.train.wrapper import (
ModelWrapper,
)

# Minimal single-task se_e2_a model configuration used to build the
# checkpoint that the freeze tests below operate on.
model_se_e2_a = {
    "type_map": ["O", "H", "B"],
    "descriptor": {
        "type": "se_e2_a",
        "sel": [46, 92, 4],
        "rcut_smth": 0.50,
        "rcut": 4.00,
        "neuron": [25, 50, 100],
        "resnet_dt": False,
        "axis_neuron": 16,
        "seed": 1,
    },
    "fitting_net": {
        "neuron": [24, 24, 24],
        "resnet_dt": True,
        "seed": 1,
    },
    "data_stat_nbatch": 20,
}


class TestDPFreezePtExpt(unittest.TestCase):
    """Test dp freeze for the pt_expt backend."""

    @classmethod
    def setUpClass(cls) -> None:
        cls.tmpdir = tempfile.mkdtemp()

        # Assemble a minimal checkpoint on disk: a wrapped model whose
        # state dict carries the model params as extra state.
        params = deepcopy(model_se_e2_a)
        wrapped = ModelWrapper(get_model(params), model_params=params)
        cls.ckpt_file = os.path.join(cls.tmpdir, "model.pt")
        torch.save({"model": wrapped.state_dict()}, cls.ckpt_file)

    @classmethod
    def tearDownClass(cls) -> None:
        shutil.rmtree(cls.tmpdir)

    def _freeze_flags(self, output: str) -> argparse.Namespace:
        # Flags mirroring what the CLI parser would produce for "freeze".
        return argparse.Namespace(
            command="freeze",
            checkpoint_folder=self.ckpt_file,
            output=output,
            head=None,
            log_level=2,  # WARNING
            log_path=None,
        )

    def test_freeze_pte(self) -> None:
        """Freeze to .pte and verify the file is created."""
        target = os.path.join(self.tmpdir, "frozen_model.pte")
        freeze(model=self.ckpt_file, output=target)
        self.assertTrue(os.path.exists(target))

    def test_freeze_main_dispatcher(self) -> None:
        """Test main() CLI dispatcher with freeze command."""
        target = os.path.join(self.tmpdir, "frozen_via_main.pte")
        main(self._freeze_flags(target))
        self.assertTrue(os.path.exists(target))

    def test_freeze_default_suffix(self) -> None:
        """Test that main() defaults output suffix to .pte."""
        requested = os.path.join(self.tmpdir, "frozen_default_suffix.pth")
        main(self._freeze_flags(requested))
        normalized = os.path.join(self.tmpdir, "frozen_default_suffix.pte")
        self.assertTrue(os.path.exists(normalized))


# Allow running this test module directly with `python`.
if __name__ == "__main__":
    unittest.main()
127 changes: 127 additions & 0 deletions source/tests/pt_expt/test_dp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
)
from pathlib import (
Path,
)

import torch

from deepmd.entrypoints.test import test as dp_test
from deepmd.pt_expt.entrypoints.main import (
freeze,
)
from deepmd.pt_expt.model.get_model import (
get_model,
)
from deepmd.pt_expt.train.wrapper import (
ModelWrapper,
)

# Minimal single-task se_e2_a model configuration used to build the
# frozen .pte model that the dp-test tests below operate on.
model_se_e2_a = {
    "type_map": ["O", "H", "B"],
    "descriptor": {
        "type": "se_e2_a",
        "sel": [46, 92, 4],
        "rcut_smth": 0.50,
        "rcut": 4.00,
        "neuron": [25, 50, 100],
        "resnet_dt": False,
        "axis_neuron": 16,
        "seed": 1,
    },
    "fitting_net": {
        "neuron": [24, 24, 24],
        "resnet_dt": True,
        "seed": 1,
    },
    "data_stat_nbatch": 20,
}


class TestDPTestPtExpt(unittest.TestCase):
    """Test dp test for the pt_expt backend (.pte models)."""

    @classmethod
    def setUpClass(cls) -> None:
        cls.data_file = str(
            Path(__file__).parents[1] / "pt" / "water" / "data" / "single"
        )
        cls.detail_file = os.path.join(
            tempfile.mkdtemp(), "test_dp_test_pt_expt_detail"
        )
        cls.tmpdir = tempfile.mkdtemp()

        # Save a checkpoint from a freshly built model, then freeze it
        # into the .pte artifact exercised by the tests below.
        params = deepcopy(model_se_e2_a)
        wrapped = ModelWrapper(get_model(params), model_params=params)
        ckpt = os.path.join(cls.tmpdir, "model.pt")
        torch.save({"model": wrapped.state_dict()}, ckpt)

        cls.pte_file = os.path.join(cls.tmpdir, "frozen_model.pte")
        freeze(model=ckpt, output=cls.pte_file)

    @classmethod
    def tearDownClass(cls) -> None:
        shutil.rmtree(cls.tmpdir)
        detail_dir = os.path.dirname(cls.detail_file)
        if os.path.exists(detail_dir):
            shutil.rmtree(detail_dir)

    def test_dp_test_system(self) -> None:
        """Test dp test with -s system path."""
        detail = f"{self.detail_file}_sys"
        dp_test(
            model=self.pte_file,
            system=self.data_file,
            datafile=None,
            set_prefix="set",
            numb_test=0,
            rand_seed=None,
            shuffle_test=False,
            detail_file=detail,
            atomic=False,
        )
        for suffix in (".e.out", ".f.out", ".v.out"):
            self.assertTrue(os.path.exists(detail + suffix))

    def test_dp_test_input_json(self) -> None:
        """Test dp test with --valid-data JSON input."""
        config = {
            "model": deepcopy(model_se_e2_a),
            "training": {
                "training_data": {"systems": [self.data_file]},
                "validation_data": {"systems": [self.data_file]},
            },
        }
        input_json = os.path.join(self.tmpdir, "test_input.json")
        with open(input_json, "w") as fp:
            json.dump(config, fp, indent=4)

        detail = f"{self.detail_file}_json"
        dp_test(
            model=self.pte_file,
            system=None,
            datafile=None,
            valid_json=input_json,
            set_prefix="set",
            numb_test=0,
            rand_seed=None,
            shuffle_test=False,
            detail_file=detail,
            atomic=False,
        )
        self.assertTrue(os.path.exists(detail + ".e.out"))

# Allow running this test module directly with `python`.
if __name__ == "__main__":
    unittest.main()