Skip to content

Commit 50ae63a

Browse files
committed
Make SingleValueHyperparameterSettings, SingleCategoryHyperparameterSettings and CategoricalHyperparameterSettings interface more consistent with each other
1 parent 9ba0b80 commit 50ae63a

File tree

1 file changed

+60
-58
lines changed

1 file changed

+60
-58
lines changed

dataikuapi/dss/ml.py

Lines changed: 60 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -797,70 +797,43 @@ def __repr__(self):
797797
def _pretty_repr(self):
798798
return self.__class__.__name__ + "(hyperparameter=\"{}\", settings={})".format(self.name, json.dumps(self._algo_settings[self.name], indent=4))
799799

800-
def _set_values(self, values=None):
801-
if values is None:
802-
warnings.warn("Categorical hyperparameter \"{}\" not modified".format(self.name))
803-
else:
804-
assert isinstance(values, dict), "Invalid values input type for hyperparameter " \
805-
"\"{}\": ".format(self.name) + \
806-
"must be a dictionary"
807-
admissible_values = self._algo_settings[self.name]["values"].keys()
808-
for category, setting in values.items():
809-
assert category in admissible_values, "Unknown categorical value \"" + category + "\". Expected a member of " + str(list(admissible_values))
810-
value_error_message = "Invalid input value for hyperparameter \"{}\", category \"{}\": ".format(self.name, category)
811-
value_error_message += "expected a {\"enabled\": bool} dictionary"
812-
assert isinstance(setting, dict), value_error_message
813-
assert list(setting.keys()) == ["enabled"], value_error_message
814-
assert all(type(v) == bool for v in setting.values()), value_error_message
815-
for category, setting in values.items():
816-
self._algo_settings[self.name]["values"][category] = setting
817-
818-
def enable_categories(self, categories, disable_others=False):
819-
"""
820-
Enables the search over categories listed in the first argument.
821-
:param list categories: will enable the search over the provided categories
822-
:param bool disable_others: if True, will also disable the search over categories not listed in the first argument
823-
:return current CategoricalHyperparameterSettings
824-
"""
825-
accepted_categories = self.get_all_categories()
826-
for category in categories:
827-
assert isinstance(category, string_types)
828-
assert category in accepted_categories
829-
self._set_values({category: {"enabled": True}
830-
for category in categories})
831-
if disable_others:
832-
self._set_values({category: {"enabled": False}
833-
for category in accepted_categories
834-
if category not in categories})
800+
def set_values(self, values):
801+
"""
802+
Enables the search over listed values (categories).
803+
:param values: enable the search over the provided categories and disable the search over non-provided categories
804+
:type values: list of str
805+
:return: current CategoricalHyperparameterSettings
806+
"""
807+
all_possible_values = self.get_all_possible_values()
808+
for category in values:
809+
assert isinstance(category, string_types), \
810+
"Invalid input type {} for categorical hyperparameter {}: must be a string".format(type(category), self.name)
811+
assert category in all_possible_values, \
812+
"Invalid input value \"{}\" for categorical hyperparameter {}: must be a member of {}".format(category, self.name, all_possible_values)
813+
814+
for category in all_possible_values:
815+
if category in values:
816+
self._algo_settings[self.name]["values"][category] = {"enabled": True}
817+
else:
818+
self._algo_settings[self.name]["values"][category] = {"enabled": False}
835819
return self
836820

837-
def disable_categories(self, categories, enable_others=False):
838-
"""
839-
Disables the search over categories listed in the first argument.
840-
:param list categories: will disable the search over the provided categories
841-
:param bool enable_others: if True, will also enable the search over categories not listed in the first argument
842-
:return current CategoricalHyperparameterSettings
843-
"""
844-
accepted_categories = self.get_all_categories()
845-
for category in categories:
846-
assert isinstance(category, string_types)
847-
assert category in accepted_categories
848-
self._set_values({category: {"enabled": False}
849-
for category in categories})
850-
if enable_others:
851-
self._set_values({category: {"enabled": True}
852-
for category in accepted_categories
853-
if category not in categories})
854-
return self
821+
def get_values(self):
822+
"""
823+
:return: list of enabled categories for this hyperparameter
824+
:rtype: list of str
825+
"""
826+
values_dict = self._algo_settings[self.name]["values"]
827+
return [value for value in values_dict.keys() if values_dict[value]["enabled"]]
855828

856-
def get_all_categories(self):
829+
def get_all_possible_values(self):
857830
"""
858-
:return: list of valid categories for this hyperparameter
831+
:return: list of possible values for this hyperparameter
832+
:rtype: list of str
859833
"""
860834
return list(self._algo_settings[self.name]["values"].keys())
861835

862836

863-
864837
class SingleValueHyperparameterSettings(HyperparameterSettings):
865838

866839
def __init__(self, name, algo_settings, accepted_types=None):
@@ -876,14 +849,28 @@ def __repr__(self):
876849

877850
def set_value(self, value):
878851
"""
879-
:param bool | int | float value:
852+
:param value:
853+
:type value: bool | int | float
880854
:return: current SingleValueHyperparameterSettings
881855
"""
882856
if self.accepted_types is not None:
883857
assert any(isinstance(value, accepted_type) for accepted_type in self.accepted_types), "Invalid type for hyperparameter {}. Type must be one of: {}".format(self.name, self.accepted_types)
884858
self._algo_settings[self.name] = value
885859
return self
886860

861+
def get_value(self):
862+
"""
863+
:return: current value
864+
:rtype: bool | int | float
865+
"""
866+
return self._algo_settings[self.name]
867+
868+
def get_accepted_types(self):
869+
"""
870+
:return: valid types for this hyperparameter
871+
"""
872+
return self.accepted_types
873+
887874

888875
class SingleCategoryHyperparameterSettings(HyperparameterSettings):
889876

@@ -905,14 +892,29 @@ def __repr__(self):
905892

906893
def set_value(self, value):
907894
"""
908-
:param str value:
895+
:param value:
896+
:type value: str
909897
:return: current SingleValueHyperparameterSettings
910898
"""
911899
if self.accepted_values is not None:
912900
assert value in self.accepted_values, "Invalid value for hyperparameter {}. Must be in {}".format(self.name, json.dumps(self.accepted_values))
913901
self._algo_settings[self.name] = value
914902
return self
915903

904+
def get_value(self):
905+
"""
906+
:return: current value
907+
:rtype: str
908+
"""
909+
return self._algo_settings[self.name]
910+
911+
def get_all_possible_values(self):
912+
"""
913+
:return: list of possible values for this hyperparameter
914+
:rtype: list of str
915+
"""
916+
return list(self._algo_settings[self.name]["values"].keys())
917+
916918

917919
class PredictionAlgorithmSettings(dict):
918920

0 commit comments

Comments
 (0)