2525# SOFTWARE.
2626
2727
28- from typing import Literal
28+ from typing import Literal , TypeAlias
2929
3030import torch
3131from torch import Tensor
3737from ._utils .pref_vector import pref_vector_to_str_suffix , pref_vector_to_weighting
3838from ._weighting_bases import Weighting
3939
40+ SUPPORTED_SCALE_MODE : TypeAlias = Literal ["min" , "median" , "rmse" ]
41+
4042
4143class AlignedMTL (GramianWeightedAggregator ):
4244 r"""
@@ -58,10 +60,10 @@ class AlignedMTL(GramianWeightedAggregator):
5860 def __init__ (
5961 self ,
6062 pref_vector : Tensor | None = None ,
61- scale_mode : Literal [ "min" , "median" , "rmse" ] = "min" ,
63+ scale_mode : SUPPORTED_SCALE_MODE = "min" ,
6264 ):
6365 self ._pref_vector = pref_vector
64- self ._scale_mode = scale_mode
66+ self ._scale_mode : SUPPORTED_SCALE_MODE = scale_mode
6567 super ().__init__ (AlignedMTLWeighting (pref_vector , scale_mode = scale_mode ))
6668
6769 def __repr__ (self ) -> str :
@@ -89,11 +91,11 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):
8991 def __init__ (
9092 self ,
9193 pref_vector : Tensor | None = None ,
92- scale_mode : Literal [ "min" , "median" , "rmse" ] = "min" ,
94+ scale_mode : SUPPORTED_SCALE_MODE = "min" ,
9395 ):
9496 super ().__init__ ()
9597 self ._pref_vector = pref_vector
96- self ._scale_mode = scale_mode
98+ self ._scale_mode : SUPPORTED_SCALE_MODE = scale_mode
9799 self .weighting = pref_vector_to_weighting (pref_vector , default = MeanWeighting ())
98100
99101 def forward (self , gramian : PSDMatrix ) -> Tensor :
@@ -105,7 +107,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor:
105107
106108 @staticmethod
107109 def _compute_balance_transformation (
108- M : Tensor , scale_mode : Literal [ "min" , "median" , "rmse" ] = "min"
110+ M : Tensor , scale_mode : SUPPORTED_SCALE_MODE = "min"
109111 ) -> Tensor :
110112 lambda_ , V = torch .linalg .eigh (M , UPLO = "U" ) # More modern equivalent to torch.symeig
111113 tol = torch .max (lambda_ ) * len (M ) * torch .finfo ().eps
0 commit comments