2020
2121from nff .io .gmm import GaussianMixture
2222from nff .train .evaluate import evaluate
23- from nff .utils .cuda import batch_detach
2423from nff .utils .prediction import get_residual
2524
2625__all__ = [
4241class Uncertainty :
4342 """Base class for uncertainty predictions."""
4443
45- def __init__ ( # noqa: D107
44+ def __init__ (
4645 self ,
4746 order : str ,
4847 calibrate : bool ,
@@ -69,7 +68,7 @@ def __init__( # noqa: D107
6968
7069 self .CP = ConformalPrediction (alpha = cp_alpha )
7170
72- def __call__ (self , * args , ** kwargs ): # noqa: D102
71+ def __call__ (self , * args , ** kwargs ):
7372 return self .get_uncertainty (* args , ** kwargs )
7473
7574 def set_min_uncertainty (self , min_uncertainty : float , force : bool = False ) -> None :
@@ -175,7 +174,7 @@ class ConformalPrediction:
175174 on calibration data and apply to test data during prediction.
176175 """
177176
178- def __init__ (self , alpha : float ): # noqa: D107
177+ def __init__ (self , alpha : float ):
179178 self .alpha = alpha
180179
181180 def fit (
@@ -225,7 +224,7 @@ class EnsembleUncertainty(Uncertainty):
225224 targ_unit (Union[str, None], optional): Target unit of the quantity. Defaults to None.
226225 """
227226
228- def __init__ ( # noqa: D107
227+ def __init__ (
229228 self ,
230229 quantity : str ,
231230 order : str ,
@@ -342,7 +341,7 @@ class for the possible options.
342341 min_uncertainty (Union[float, None], optional): Minimum uncertainty value. Defaults to None.
343342 """
344343
345- def __init__ ( # noqa: D107
344+ def __init__ (
346345 self ,
347346 order : str = "atomic" ,
348347 shared_v : bool = False ,
@@ -429,7 +428,7 @@ class for the possible options.
429428 min_uncertainty (Union[float, None], optional): Minimum uncertainty value. Defaults to None.
430429 """
431430
432- def __init__ ( # noqa: D107
431+ def __init__ (
433432 self ,
434433 variance_key : str = "var" ,
435434 quantity : str = "forces" ,
@@ -480,7 +479,7 @@ class for the possible options.
480479 gmm_path (Union[str, None], optional): Path to the saved GMM model. Defaults to None.
481480 """
482481
483- def __init__ ( # noqa: D107
482+ def __init__ (
484483 self ,
485484 train_embed_key : str = "train_embedding" ,
486485 test_embed_key : str = "embedding" ,
@@ -701,30 +700,56 @@ def get_unc_class(model: torch.nn.Module, info_dict: dict) -> Uncertainty:
701700 # to refit it
702701 if info_dict .get ("uncertainty_type" ) == "gmm" and unc_class .is_fitted () is False :
703702 print ("GMM: Doing train prediction" )
704- train_predicted , _train_targs , _loss = evaluate (
705- model = model ,
706- loader = info_dict ["train_dset" ],
707- loss_fn = info_dict ["loss_fn" ],
708- device = device ,
709- requires_embedding = True ,
710- )
703+ if any (c in model .__repr__ () for c in ["Painn" , "SchNet" ]):
704+ train_predicted , _train_targs , _loss = evaluate (
705+ model = model ,
706+ loader = info_dict ["train_dset" ],
707+ loss_fn = info_dict ["loss_fn" ],
708+ device = device ,
709+ requires_embedding = True ,
710+ )
711+
712+ # GMM requires a 2D tensor for the embeddings, with the
713+ train_embedding = torch .concat (train_predicted ["embedding" ])
714+
715+ elif "MACE" in model .__repr__ ():
716+ _ , train_predicted = evaluate (
717+ model = model ,
718+ dset = info_dict ["train_dset" ],
719+ batch_size = info_dict ["batch_size" ],
720+ device = device ,
721+ embedding_kwargs = info_dict ["uncertainty_params" ]["embedding_kwargs" ],
722+ )
711723
712- # GMM requires a 2D tensor for the embeddings, with the
713- train_embedding = torch .stack ([t .flatten () for t in train_predicted ["embedding" ]], dim = 0 )
724+ train_embedding = train_predicted ["embeddings" ].detach ().cpu ().squeeze ()
714725
715726 print ("COLVAR: Fitting GMM" )
716727 unc_class .fit_gmm (train_embedding )
717728 calibrate = info_dict ["uncertainty_params" ].get ("calibrate" , False )
718- if calibrate and unc_class .CP . qhat is None :
729+ if calibrate and ( not hasattr ( unc_class .CP , "qhat" ) or unc_class . CP . qhat is None ) :
719730 print ("COLVAR: Fitting ConformalPrediction" )
720- calib_target , calib_predicted = evaluate (
721- model = model ,
722- dset = info_dict ["calib_dset" ],
723- batch_size = info_dict ["batch_size" ],
724- device = device ,
725- embedding_kwargs = info_dict ["uncertainty_params" ]["embedding_kwargs" ],
726- )
731+ if any (c in model .__repr__ () for c in ["Painn" , "SchNet" ]):
732+ calib_predicted , calib_target , _loss = evaluate (
733+ model = model ,
734+ loader = info_dict ["calib_dset" ],
735+ loss_fn = info_dict ["loss_fn" ],
736+ device = device ,
737+ requires_embedding = True ,
738+ )
739+
740+ elif "MACE" in model .__repr__ ():
741+ calib_target , calib_predicted = evaluate (
742+ model = model ,
743+ dset = info_dict ["calib_dset" ],
744+ batch_size = info_dict ["batch_size" ],
745+ device = device ,
746+ embedding_kwargs = info_dict ["uncertainty_params" ]["embedding_kwargs" ],
747+ )
748+
727749 # calib_predicted["embeddings"] = calib_predicted["embeddings"][0]
750+ print (calib_predicted .keys ())
751+ print (len (calib_predicted [unc_class .test_key ]))
752+ print (calib_predicted [unc_class .test_key ][0 ].shape )
728753 calib_uncertainty = (
729754 unc_class (
730755 results = calib_predicted ,
0 commit comments