Skip to content

Commit fef8a95

Browse files
committed
ENH: initial uncertainty sampling commit
- add prediction tools (may still need to be updated with more recent changes)
- update colvar and uncertainty files to work with the newer MACE version and other changes
1 parent 7b63605 commit fef8a95

3 files changed

Lines changed: 637 additions & 40 deletions

File tree

nff/md/colvars.py

Lines changed: 91 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import itertools
88
from itertools import repeat
9-
from typing import TYPE_CHECKING
9+
from typing import TYPE_CHECKING, Dict, Optional, Tuple
1010

1111
import numpy as np
1212
import torch
@@ -16,14 +16,15 @@
1616

1717
from nff.io.ase import AtomsBatch
1818
from nff.train import load_model
19+
from nff.train.evaluate import evaluate
1920
from nff.train.uncertainty import (
2021
EnsembleUncertainty,
2122
EvidentialUncertainty,
2223
GMMUncertainty,
2324
MVEUncertainty,
2425
)
2526
from nff.utils.cuda import batch_to
26-
from nff.utils.prediction import get_prediction, get_residual
27+
from nff.utils.prediction import evaluate_mace, get_prediction, get_residual
2728
from nff.utils.scatter import compute_grad
2829

2930
if TYPE_CHECKING:
@@ -152,21 +153,44 @@ def _init_uncertainty(self):
152153

153154
if self.info_dict.get("uncertainty_type") == "gmm" and self.unc_class.is_fitted() is False:
154155
print("COLVAR: Doing train prediction")
155-
_, train_predicted = get_prediction(
156-
model=self.model,
157-
dset=self.info_dict["train_dset"],
158-
batch_size=self.info_dict["batch_size"],
159-
device=self.device,
160-
requires_grad=False,
161-
)
156+
if any(c in self.model.__repr__() for c in ["Painn", "SchNet"]):
157+
train_predicted, _train_targs, _loss = evaluate(
158+
model=self.model,
159+
loader=self.info_dict["train_dset"],
160+
loss_fn=self.info_dict["loss_fn"],
161+
device=self.device,
162+
requires_embedding=True,
163+
)
162164

163-
train_embedding = train_predicted["embedding"][0].detach().cpu().squeeze()
164-
train_atomic_numbers = torch.cat(
165-
[torch.LongTensor(at.get_atomic_numbers()) for at in self.info_dict["train_dset"]]
166-
)
165+
# GMM requires a 2D tensor for the embeddings, with the per-sample embeddings concatenated along the first dimension
166+
train_embedding = torch.concat(train_predicted["embedding"])
167+
168+
elif "MACE" in self.model.__repr__():
169+
_, train_predicted = evaluate_mace(
170+
model=self.model,
171+
dset=self.info_dict["train_dset"],
172+
batch_size=self.info_dict["batch_size"],
173+
device=self.device,
174+
embedding_kwargs=self.info_dict["uncertainty_params"]["embedding_kwargs"],
175+
)
176+
177+
train_embedding = train_predicted["embeddings"].detach().cpu().squeeze()
178+
# print("COLVAR: Doing train prediction")
179+
# _, train_predicted = get_prediction(
180+
# model=self.model,
181+
# dset=self.info_dict["train_dset"],
182+
# batch_size=self.info_dict["batch_size"],
183+
# device=self.device,
184+
# requires_grad=False,
185+
# )
186+
187+
# train_embedding = train_predicted["embedding"][0].detach().cpu().squeeze()
188+
# train_atomic_numbers = torch.cat(
189+
# [torch.LongTensor(at.get_atomic_numbers()) for at in self.info_dict["train_dset"]]
190+
# )
167191

168192
print("COLVAR: Fitting GMM")
169-
self.unc_class.fit_gmm(train_embedding, train_atomic_numbers)
193+
self.unc_class.fit_gmm(train_embedding)
170194

171195
self.calibrate = self.info_dict["uncertainty_params"].get("calibrate", False)
172196
if self.calibrate:
@@ -667,7 +691,56 @@ def energy_gap(self, enkey1: str, enkey2: str):
667691

668692
return cv, cv_grad
669693

670-
def forward(self, atoms: Atoms) -> tuple[np.ndarray, np.ndarray]:
694+
def uncertainty(self, atoms: Atoms, pred=None, return_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
695+
if pred is None:
696+
_, pred = get_prediction(
697+
self.model,
698+
dset=[atoms],
699+
batch_size=self.info_dict["batch_size"],
700+
device=self.device,
701+
get_target=False,
702+
requires_grad=True,
703+
pool_embedding=True,
704+
)
705+
706+
# get neighbor list
707+
atoms.update_nbr_list()
708+
pred["nbr_list"] = torch.LongTensor(atoms.nbr_list).to(self.device)
709+
710+
# get atomic numbers
711+
pred["test_atomic_numbers"] = torch.LongTensor(atoms.get_atomic_numbers())
712+
713+
uncertainty = self.unc_class(
714+
results=pred,
715+
num_atoms=pred["num_atoms"],
716+
reset_min_uncertainty=False,
717+
device=self.device,
718+
)
719+
720+
if return_grad is False:
721+
return uncertainty, None
722+
723+
if not uncertainty.requires_grad:
724+
uncertainty.requires_grad = True
725+
726+
uncertainty_grad = compute_grad(
727+
inputs=pred["xyz"],
728+
output=uncertainty,
729+
allow_unused=True,
730+
)
731+
if uncertainty_grad is None:
732+
uncertainty_grad = torch.zeros_like(pred["xyz"])
733+
734+
# make sure uncertainty is a scalar
735+
uncertainty = uncertainty.sum()
736+
737+
return uncertainty, uncertainty_grad
738+
739+
def forward(
740+
self,
741+
atoms: Atoms,
742+
pred: Optional[Dict] = None, # noqa
743+
) -> tuple[np.ndarray, np.ndarray]:
671744
"""Switch function to call the right CV-func
672745
673746
Args:
@@ -732,6 +805,9 @@ def forward(self, atoms: Atoms) -> tuple[np.ndarray, np.ndarray]:
732805
elif self.info_dict["name"] == "energy_gap":
733806
cv, cv_grad = self.energy_gap(self.info_dict["enkey_1"], self.info_dict["enkey_2"])
734807

808+
elif self.info_dict["name"] == "uncertainty":
809+
cv, cv_grad = self.uncertainty(atoms, pred)
810+
735811
return cv.detach().cpu().numpy(), cv_grad.detach().cpu().numpy()
736812

737813

nff/train/uncertainty.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from nff.io.gmm import GaussianMixture
2222
from nff.train.evaluate import evaluate
23-
from nff.utils.cuda import batch_detach
2423
from nff.utils.prediction import get_residual
2524

2625
__all__ = [
@@ -42,7 +41,7 @@
4241
class Uncertainty:
4342
"""Base class for uncertainty predictions."""
4443

45-
def __init__( # noqa: D107
44+
def __init__(
4645
self,
4746
order: str,
4847
calibrate: bool,
@@ -69,7 +68,7 @@ def __init__( # noqa: D107
6968

7069
self.CP = ConformalPrediction(alpha=cp_alpha)
7170

72-
def __call__(self, *args, **kwargs): # noqa: D102
71+
def __call__(self, *args, **kwargs):
7372
return self.get_uncertainty(*args, **kwargs)
7473

7574
def set_min_uncertainty(self, min_uncertainty: float, force: bool = False) -> None:
@@ -175,7 +174,7 @@ class ConformalPrediction:
175174
on calibration data and apply to test data during prediction.
176175
"""
177176

178-
def __init__(self, alpha: float): # noqa: D107
177+
def __init__(self, alpha: float):
179178
self.alpha = alpha
180179

181180
def fit(
@@ -225,7 +224,7 @@ class EnsembleUncertainty(Uncertainty):
225224
targ_unit (Union[str, None], optional): Target unit of the quantity. Defaults to None.
226225
"""
227226

228-
def __init__( # noqa: D107
227+
def __init__(
229228
self,
230229
quantity: str,
231230
order: str,
@@ -342,7 +341,7 @@ class for the possible options.
342341
min_uncertainty (Union[float, None], optional): Minimum uncertainty value. Defaults to None.
343342
"""
344343

345-
def __init__( # noqa: D107
344+
def __init__(
346345
self,
347346
order: str = "atomic",
348347
shared_v: bool = False,
@@ -429,7 +428,7 @@ class for the possible options.
429428
min_uncertainty (Union[float, None], optional): Minimum uncertainty value. Defaults to None.
430429
"""
431430

432-
def __init__( # noqa: D107
431+
def __init__(
433432
self,
434433
variance_key: str = "var",
435434
quantity: str = "forces",
@@ -480,7 +479,7 @@ class for the possible options.
480479
gmm_path (Union[str, None], optional): Path to the saved GMM model. Defaults to None.
481480
"""
482481

483-
def __init__( # noqa: D107
482+
def __init__(
484483
self,
485484
train_embed_key: str = "train_embedding",
486485
test_embed_key: str = "embedding",
@@ -701,30 +700,56 @@ def get_unc_class(model: torch.nn.Module, info_dict: dict) -> Uncertainty:
701700
# to refit it
702701
if info_dict.get("uncertainty_type") == "gmm" and unc_class.is_fitted() is False:
703702
print("GMM: Doing train prediction")
704-
train_predicted, _train_targs, _loss = evaluate(
705-
model=model,
706-
loader=info_dict["train_dset"],
707-
loss_fn=info_dict["loss_fn"],
708-
device=device,
709-
requires_embedding=True,
710-
)
703+
if any(c in model.__repr__() for c in ["Painn", "SchNet"]):
704+
train_predicted, _train_targs, _loss = evaluate(
705+
model=model,
706+
loader=info_dict["train_dset"],
707+
loss_fn=info_dict["loss_fn"],
708+
device=device,
709+
requires_embedding=True,
710+
)
711+
712+
# GMM requires a 2D tensor for the embeddings, with the per-sample embeddings concatenated along the first dimension
713+
train_embedding = torch.concat(train_predicted["embedding"])
714+
715+
elif "MACE" in model.__repr__():
716+
_, train_predicted = evaluate(
717+
model=model,
718+
dset=info_dict["train_dset"],
719+
batch_size=info_dict["batch_size"],
720+
device=device,
721+
embedding_kwargs=info_dict["uncertainty_params"]["embedding_kwargs"],
722+
)
711723

712-
# GMM requires a 2D tensor for the embeddings, with the
713-
train_embedding = torch.stack([t.flatten() for t in train_predicted["embedding"]], dim=0)
724+
train_embedding = train_predicted["embeddings"].detach().cpu().squeeze()
714725

715726
print("COLVAR: Fitting GMM")
716727
unc_class.fit_gmm(train_embedding)
717728
calibrate = info_dict["uncertainty_params"].get("calibrate", False)
718-
if calibrate and unc_class.CP.qhat is None:
729+
if calibrate and (not hasattr(unc_class.CP, "qhat") or unc_class.CP.qhat is None):
719730
print("COLVAR: Fitting ConformalPrediction")
720-
calib_target, calib_predicted = evaluate(
721-
model=model,
722-
dset=info_dict["calib_dset"],
723-
batch_size=info_dict["batch_size"],
724-
device=device,
725-
embedding_kwargs=info_dict["uncertainty_params"]["embedding_kwargs"],
726-
)
731+
if any(c in model.__repr__() for c in ["Painn", "SchNet"]):
732+
calib_predicted, calib_target, _loss = evaluate(
733+
model=model,
734+
loader=info_dict["calib_dset"],
735+
loss_fn=info_dict["loss_fn"],
736+
device=device,
737+
requires_embedding=True,
738+
)
739+
740+
elif "MACE" in model.__repr__():
741+
calib_target, calib_predicted = evaluate(
742+
model=model,
743+
dset=info_dict["calib_dset"],
744+
batch_size=info_dict["batch_size"],
745+
device=device,
746+
embedding_kwargs=info_dict["uncertainty_params"]["embedding_kwargs"],
747+
)
748+
727749
# calib_predicted["embeddings"] = calib_predicted["embeddings"][0]
750+
print(calib_predicted.keys())
751+
print(len(calib_predicted[unc_class.test_key]))
752+
print(calib_predicted[unc_class.test_key][0].shape)
728753
calib_uncertainty = (
729754
unc_class(
730755
results=calib_predicted,

0 commit comments

Comments
 (0)