@@ -797,70 +797,43 @@ def __repr__(self):
797797 def _pretty_repr (self ):
798798 return self .__class__ .__name__ + "(hyperparameter=\" {}\" , settings={})" .format (self .name , json .dumps (self ._algo_settings [self .name ], indent = 4 ))
799799
800- def _set_values (self , values = None ):
801- if values is None :
802- warnings .warn ("Categorical hyperparameter \" {}\" not modified" .format (self .name ))
803- else :
804- assert isinstance (values , dict ), "Invalid values input type for hyperparameter " \
805- "\" {}\" : " .format (self .name ) + \
806- "must be a dictionary"
807- admissible_values = self ._algo_settings [self .name ]["values" ].keys ()
808- for category , setting in values .items ():
809- assert category in admissible_values , "Unknown categorical value \" " + category + "\" . Expected a member of " + str (list (admissible_values ))
810- value_error_message = "Invalid input value for hyperparameter \" {}\" , category \" {}\" : " .format (self .name , category )
811- value_error_message += "expected a {\" enabled\" : bool} dictionary"
812- assert isinstance (setting , dict ), value_error_message
813- assert list (setting .keys ()) == ["enabled" ], value_error_message
814- assert all (type (v ) == bool for v in setting .values ()), value_error_message
815- for category , setting in values .items ():
816- self ._algo_settings [self .name ]["values" ][category ] = setting
817-
818- def enable_categories (self , categories , disable_others = False ):
819- """
820- Enables the search over categories listed in the first argument.
821- :param list categories: will enable the search over the provided categories
822- :param bool disable_others: if True, will also disable the search over categories not listed in the first argument
823- :return current CategoricalHyperparameterSettings
824- """
825- accepted_categories = self .get_all_categories ()
826- for category in categories :
827- assert isinstance (category , string_types )
828- assert category in accepted_categories
829- self ._set_values ({category : {"enabled" : True }
830- for category in categories })
831- if disable_others :
832- self ._set_values ({category : {"enabled" : False }
833- for category in accepted_categories
834- if category not in categories })
800+ def set_values (self , values ):
801+ """
802+ Enables the search over listed values (categories).
803+ :param values: enable the search over the provided categories and disable the search over non-provided categories
804+ :type values: list of str
805+ :return: current CategoricalHyperparameterSettings
806+ """
807+ all_possible_values = self .get_all_possible_values ()
808+ for category in values :
809+ assert isinstance (category , string_types ), \
810+ "Invalid input type {} for categorical hyperparameter {}: must be a string" .format (type (category ), self .name )
811+ assert category in all_possible_values , \
812+ "Invalid input value \" {}\" for categorical hyperparameter {}: must be a member of {}" .format (category , self .name , all_possible_values )
813+
814+ for category in all_possible_values :
815+ if category in values :
816+ self ._algo_settings [self .name ]["values" ][category ] = {"enabled" : True }
817+ else :
818+ self ._algo_settings [self .name ]["values" ][category ] = {"enabled" : False }
835819 return self
836820
837- def disable_categories (self , categories , enable_others = False ):
838- """
839- Disables the search over categories listed in the first argument.
840- :param list categories: will disable the search over the provided categories
841- :param bool enable_others: if True, will also enable the search over categories not listed in the first argument
842- :return current CategoricalHyperparameterSettings
843- """
844- accepted_categories = self .get_all_categories ()
845- for category in categories :
846- assert isinstance (category , string_types )
847- assert category in accepted_categories
848- self ._set_values ({category : {"enabled" : False }
849- for category in categories })
850- if enable_others :
851- self ._set_values ({category : {"enabled" : True }
852- for category in accepted_categories
853- if category not in categories })
854- return self
821+ def get_values (self ):
822+ """
823+ :return: list of enabled categories for this hyperparameter
824+ :rtype: list of str
825+ """
826+ values_dict = self ._algo_settings [self .name ]["values" ]
827+ return [value for value in values_dict .keys () if values_dict [value ]["enabled" ]]
855828
856- def get_all_categories (self ):
829+ def get_all_possible_values (self ):
857830 """
858- :return: list of valid categories for this hyperparameter
831+ :return: list of possible values for this hyperparameter
832+ :rtype: list of str
859833 """
860834 return list (self ._algo_settings [self .name ]["values" ].keys ())
861835
862836
863-
864837class SingleValueHyperparameterSettings (HyperparameterSettings ):
865838
866839 def __init__ (self , name , algo_settings , accepted_types = None ):
@@ -876,14 +849,28 @@ def __repr__(self):
876849
877850 def set_value (self , value ):
878851 """
879- :param bool | int | float value:
852+ :param value:
853+ :type value: bool | int | float
880854 :return: current SingleValueHyperparameterSettings
881855 """
882856 if self .accepted_types is not None :
883857 assert any (isinstance (value , accepted_type ) for accepted_type in self .accepted_types ), "Invalid type for hyperparameter {}. Type must be one of: {}" .format (self .name , self .accepted_types )
884858 self ._algo_settings [self .name ] = value
885859 return self
886860
861+ def get_value (self ):
862+ """
863+ :return: current value
864+ :rtype: bool | int | float
865+ """
866+ return self ._algo_settings [self .name ]
867+
868+ def get_accepted_types (self ):
869+ """
870+ :return: valid types for this hyperparameter
871+ """
872+ return self .accepted_types
873+
887874
888875class SingleCategoryHyperparameterSettings (HyperparameterSettings ):
889876
@@ -905,14 +892,29 @@ def __repr__(self):
905892
906893 def set_value (self , value ):
907894 """
908- :param str value:
895+ :param value:
896+ :type value: str
909897 :return: current SingleValueHyperparameterSettings
910898 """
911899 if self .accepted_values is not None :
912900 assert value in self .accepted_values , "Invalid value for hyperparameter {}. Must be in {}" .format (self .name , json .dumps (self .accepted_values ))
913901 self ._algo_settings [self .name ] = value
914902 return self
915903
904+ def get_value (self ):
905+ """
906+ :return: current value
907+ :rtype: str
908+ """
909+ return self ._algo_settings [self .name ]
910+
911+ def get_all_possible_values (self ):
912+ """
913+ :return: list of possible values for this hyperparameter
914+ :rtype: list of str
915+ """
916+ return list (self ._algo_settings [self .name ]["values" ].keys ())
917+
916918
917919class PredictionAlgorithmSettings (dict ):
918920
0 commit comments