Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions deepmd/backend/pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Callable,
)
from typing import (
TYPE_CHECKING,
ClassVar,
)

from deepmd.backend.backend import (
Backend,
)
from deepmd.pretrained.registry import (
available_model_names,
)

if TYPE_CHECKING:
from argparse import (
Namespace,
)

from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)


@Backend.register("pretrained")
class PretrainedBackend(Backend):
    """Virtual backend that dispatches built-in pretrained model names.

    Unlike the real compute backends (TensorFlow/PyTorch/Paddle/JAX), this
    backend is never meant to be selected directly by users.  Its only job
    is to route built-in pretrained model names into the ordinary deep-eval
    loading machinery.

    Every built-in pretrained model name is registered as a suffix-like
    alias, which lets callers pass the name itself, e.g.
    ``DeepPot("DPA-3.2-5M")``.
    """

    # Human-readable backend name.
    name = "Pretrained"
    # Only deep-eval is supported; all other hooks raise below.
    features: ClassVar[Backend.Feature] = Backend.Feature.DEEP_EVAL
    # Lower-cased model names act as pseudo file suffixes for dispatch.
    suffixes: ClassVar[list[str]] = [
        model_name.lower() for model_name in available_model_names()
    ]

    def is_available(self) -> bool:
        """Always available: this virtual backend has no dependencies."""
        return True

    @property
    def entry_point_hook(self) -> Callable[["Namespace"], None]:
        """Not supported: there is no CLI entry point for this backend."""
        raise NotImplementedError("Unsupported backend: pretrained")

    @property
    def deep_eval(self) -> type["DeepEvalBackend"]:
        """Adapter class that resolves the alias and delegates evaluation."""
        from deepmd.pretrained.deep_eval import (
            PretrainedDeepEvalBackend,
        )

        return PretrainedDeepEvalBackend

    @property
    def neighbor_stat(self) -> type["NeighborStat"]:
        """Not supported for the virtual pretrained backend."""
        raise NotImplementedError("Unsupported backend: pretrained")

    @property
    def serialize_hook(self) -> Callable[[str], dict]:
        """Not supported for the virtual pretrained backend."""
        raise NotImplementedError("Unsupported backend: pretrained")

    @property
    def deserialize_hook(self) -> Callable[[str, dict], None]:
        """Not supported for the virtual pretrained backend."""
        raise NotImplementedError("Unsupported backend: pretrained")
5 changes: 5 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from deepmd.loggers.loggers import (
set_log_handles,
)
from deepmd.pretrained.entrypoints import (
pretrained_entrypoint,
)


def main(args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -97,5 +100,7 @@ def main(args: argparse.Namespace) -> None:
convert_backend(**dict_args)
elif args.command == "show":
show(**dict_args)
elif args.command == "pretrained":
pretrained_entrypoint(args)
else:
raise ValueError(f"Unknown command: {args.command}")
33 changes: 33 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from deepmd.backend.backend import (
Backend,
)
from deepmd.pretrained.registry import (
available_model_names,
)

try:
from deepmd._version import version as __version__
Expand Down Expand Up @@ -942,6 +945,35 @@ def main_parser() -> argparse.ArgumentParser:
],
nargs="+",
)

# pretrained
parser_pretrained = subparsers.add_parser(
"pretrained",
parents=[parser_log],
help="Manage builtin pretrained models",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
)
pretrained_subparsers = parser_pretrained.add_subparsers(
dest="pretrained_command",
required=True,
)
parser_pretrained_download = pretrained_subparsers.add_parser(
"download",
help="Download one pretrained model",
)

parser_pretrained_download.add_argument(
"MODEL",
choices=available_model_names(),
help="Pretrained model name",
)
parser_pretrained_download.add_argument(
"--cache-dir",
default=None,
type=str,
help="Optional cache directory for pretrained model files",
)

return parser


Expand Down Expand Up @@ -997,6 +1029,7 @@ def main(args: list[str] | None = None) -> None:
"gui",
"convert-backend",
"show",
"pretrained",
):
# common entrypoints
from deepmd.entrypoints.main import main as deepmd_main
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pretrained/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Pretrained model helpers for DeePMD-kit."""
186 changes: 186 additions & 0 deletions deepmd/pretrained/deep_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DeepEval adapter for pretrained model-name aliases."""

from __future__ import (
annotations,
)

from pathlib import (
Path,
)
from typing import (
TYPE_CHECKING,
Any,
)

from deepmd.infer.deep_eval import (
DeepEval,
DeepEvalBackend,
)
from deepmd.pretrained.download import (
resolve_model_path,
)
from deepmd.pretrained.registry import (
MODEL_REGISTRY,
)

if TYPE_CHECKING:
import numpy as np


class InvalidPretrainedAliasError(ValueError):
    """Raised for an alias string that names no built-in pretrained model."""

    def __init__(self, model_file: str) -> None:
        # Keep the exact message format: callers/tests may match on it.
        message = f"Invalid pretrained model name: {model_file}"
        super().__init__(message)


def parse_pretrained_alias(model_file: str) -> str:
    """Extract built-in pretrained model name from alias string.

    Accepted form:
    - ``<MODEL>`` where ``<MODEL>`` is a built-in registry name

    The comparison falls back to a case-insensitive match so that, e.g.,
    ``dpa-3.2-5m`` resolves to the canonically-cased registry key.
    """
    candidate = Path(model_file).name

    # Fast path: exact registry hit.
    if candidate in MODEL_REGISTRY:
        return candidate

    # Slow path: case-insensitive scan over the registry keys; the first
    # match in registry order wins.
    folded = candidate.lower()
    match = next(
        (name for name in MODEL_REGISTRY if name.lower() == folded),
        None,
    )
    if match is not None:
        return match

    raise InvalidPretrainedAliasError(model_file)


class PretrainedDeepEvalBackend(DeepEvalBackend):
    """Alias-resolving adapter around a concrete deep-eval backend.

    The constructor maps a pretrained model name to a local file path and
    builds the real backend from that path; every other method simply
    forwards to the wrapped backend instance.
    """

    def __init__(
        self,
        model_file: str,
        output_def: object,
        *args: object,
        auto_batch_size: object = True,
        neighbor_list: object | None = None,
        **kwargs: object,
    ) -> None:
        model_name = parse_pretrained_alias(model_file)
        local_path = resolve_model_path(model_name)

        # Re-enter DeepEvalBackend with the resolved file path so that its
        # suffix-based dispatch (.pt/.pb/.dp ...) picks the proper backend.
        self._backend = DeepEvalBackend(
            str(local_path),
            output_def,
            *args,
            auto_batch_size=auto_batch_size,
            neighbor_list=neighbor_list,
            **kwargs,
        )

    def eval(
        self,
        coords: np.ndarray,
        cells: np.ndarray | None,
        atom_types: np.ndarray,
        atomic: bool = False,
        fparam: np.ndarray | None = None,
        aparam: np.ndarray | None = None,
        **kwargs: Any,
    ) -> dict[str, np.ndarray]:
        """Forward a model evaluation to the wrapped backend."""
        return self._backend.eval(
            coords,
            cells,
            atom_types,
            atomic,
            fparam=fparam,
            aparam=aparam,
            **kwargs,
        )

    def eval_descriptor(
        self,
        coords: np.ndarray,
        cells: np.ndarray | None,
        atom_types: np.ndarray,
        fparam: np.ndarray | None = None,
        aparam: np.ndarray | None = None,
        efield: np.ndarray | None = None,
        mixed_type: bool = False,
        **kwargs: Any,
    ) -> np.ndarray:
        """Forward a descriptor evaluation to the wrapped backend."""
        return self._backend.eval_descriptor(
            coords,
            cells,
            atom_types,
            fparam=fparam,
            aparam=aparam,
            efield=efield,
            mixed_type=mixed_type,
            **kwargs,
        )

    def eval_fitting_last_layer(
        self,
        coords: np.ndarray,
        cells: np.ndarray | None,
        atom_types: np.ndarray,
        fparam: np.ndarray | None = None,
        aparam: np.ndarray | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        """Forward a fitting-last-layer evaluation to the wrapped backend."""
        return self._backend.eval_fitting_last_layer(
            coords,
            cells,
            atom_types,
            fparam=fparam,
            aparam=aparam,
            **kwargs,
        )

    def get_rcut(self) -> float:
        """Cutoff radius reported by the wrapped backend."""
        return self._backend.get_rcut()

    def get_ntypes(self) -> int:
        """Number of atom types reported by the wrapped backend."""
        return self._backend.get_ntypes()

    def get_type_map(self) -> list[str]:
        """Type map reported by the wrapped backend."""
        return self._backend.get_type_map()

    def get_dim_fparam(self) -> int:
        """Frame-parameter dimension reported by the wrapped backend."""
        return self._backend.get_dim_fparam()

    def has_default_fparam(self) -> bool:
        """Whether the wrapped backend provides a default frame parameter."""
        return self._backend.has_default_fparam()

    def get_dim_aparam(self) -> int:
        """Atomic-parameter dimension reported by the wrapped backend."""
        return self._backend.get_dim_aparam()

    @property
    def model_type(self) -> type[DeepEval]:
        """Model type reported by the wrapped backend."""
        return self._backend.model_type

    def get_sel_type(self) -> list[int]:
        """Selected atom types reported by the wrapped backend."""
        return self._backend.get_sel_type()

    def get_numb_dos(self) -> int:
        """Number of DOS values reported by the wrapped backend."""
        return self._backend.get_numb_dos()

    def get_has_efield(self) -> bool:
        """Whether the wrapped backend supports an electric field."""
        return self._backend.get_has_efield()

    def get_has_spin(self) -> bool:
        """Whether the wrapped backend supports spin."""
        return self._backend.get_has_spin()

    def get_has_hessian(self) -> bool:
        """Whether the wrapped backend supports a Hessian."""
        return self._backend.get_has_hessian()

    def get_var_name(self) -> str:
        """Variable name reported by the wrapped backend."""
        return self._backend.get_var_name()

    def get_ntypes_spin(self) -> int:
        """Number of spin atom types reported by the wrapped backend."""
        return self._backend.get_ntypes_spin()

    def get_model(self) -> Any:
        """Underlying model object of the wrapped backend."""
        return self._backend.get_model()
Loading