diff --git a/.pylintrc b/.pylintrc index 8838f39f..588224e4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -87,10 +87,10 @@ disable=raw-checker-failed, missing-class-docstring, missing-function-docstring, invalid-name, - too-few-public-methods, line-too-long, import-error, # we only install linter dependencies in CI/CD wrong-import-order, # we use ruff to enforce import order + duplicate-code, # pylint has a tendancy to capture docstrings as duplicate code # Enable the message, report, category or checker with the given id(s). You can diff --git a/generated/docs/ImageQueriesApi.md b/generated/docs/ImageQueriesApi.md index 7f39a4da..47a65ae5 100644 --- a/generated/docs/ImageQueriesApi.md +++ b/generated/docs/ImageQueriesApi.md @@ -203,11 +203,12 @@ with groundlight_openapi_client.ApiClient(configuration) as api_client: api_instance = image_queries_api.ImageQueriesApi(api_client) page = 1 # int | A page number within the paginated result set. (optional) page_size = 1 # int | Number of items to return per page. (optional) + predictor_id = "predictor_id_example" # str | Optionally filter image queries by detector ID. (optional) # example passing only required values which don't have defaults set # and optional values try: - api_response = api_instance.list_image_queries(page=page, page_size=page_size) + api_response = api_instance.list_image_queries(page=page, page_size=page_size, predictor_id=predictor_id) pprint(api_response) except groundlight_openapi_client.ApiException as e: print("Exception when calling ImageQueriesApi->list_image_queries: %s\n" % e) @@ -220,6 +221,7 @@ Name | Type | Description | Notes ------------- | ------------- | ------------- | ------------- **page** | **int**| A page number within the paginated result set. | [optional] **page_size** | **int**| Number of items to return per page. | [optional] + **predictor_id** | **str**| Optionally filter image queries by detector ID. | [optional] ### Return type diff --git a/generated/groundlight_openapi_client/api/image_queries_api.py b/generated/groundlight_openapi_client/api/image_queries_api.py index cd8d0577..b25ad772 100644 --- a/generated/groundlight_openapi_client/api/image_queries_api.py +++ b/generated/groundlight_openapi_client/api/image_queries_api.py @@ -129,6 +129,7 @@ def __init__(self, api_client=None): "all": [ "page", "page_size", + "predictor_id", ], "required": [], "nullable": [], @@ -141,14 +142,17 @@ def __init__(self, api_client=None): "openapi_types": { "page": (int,), "page_size": (int,), + "predictor_id": (str,), }, "attribute_map": { "page": "page", "page_size": "page_size", + "predictor_id": "predictor_id", }, "location_map": { "page": "query", "page_size": "query", + "predictor_id": "query", }, "collection_format_map": {}, }, @@ -375,6 +379,7 @@ def list_image_queries(self, **kwargs): Keyword Args: page (int): A page number within the paginated result set.. [optional] page_size (int): Number of items to return per page.. [optional] + predictor_id (str): Optionally filter image queries by detector ID.. [optional] _return_http_data_only (bool): response data without head status code and headers. Default is True. _preload_content (bool): if False, the urllib3.HTTPResponse object diff --git a/generated/groundlight_openapi_client/model/patched_detector_request.py b/generated/groundlight_openapi_client/model/patched_detector_request.py index 64534047..251cb75d 100644 --- a/generated/groundlight_openapi_client/model/patched_detector_request.py +++ b/generated/groundlight_openapi_client/model/patched_detector_request.py @@ -8,7 +8,6 @@ Generated by: https://openapi-generator.tech """ - import re # noqa: F401 import sys # noqa: F401 @@ -25,7 +24,7 @@ file_type, none_type, validate_get_composed_info, - OpenApiModel + OpenApiModel, ) from groundlight_openapi_client.exceptions import ApiAttributeError @@ -34,9 +33,10 @@ def lazy_import(): from groundlight_openapi_client.model.blank_enum import BlankEnum from groundlight_openapi_client.model.escalation_type_enum import EscalationTypeEnum from groundlight_openapi_client.model.status_enum import StatusEnum - globals()['BlankEnum'] = BlankEnum - globals()['EscalationTypeEnum'] = EscalationTypeEnum - globals()['StatusEnum'] = StatusEnum + + globals()["BlankEnum"] = BlankEnum + globals()["EscalationTypeEnum"] = EscalationTypeEnum + globals()["StatusEnum"] = StatusEnum class PatchedDetectorRequest(ModelNormal): @@ -63,21 +63,20 @@ class PatchedDetectorRequest(ModelNormal): as additional properties values. """ - allowed_values = { - } + allowed_values = {} validations = { - ('name',): { - 'max_length': 200, - 'min_length': 1, + ("name",): { + "max_length": 200, + "min_length": 1, }, - ('confidence_threshold',): { - 'inclusive_maximum': 1.0, - 'inclusive_minimum': 0.0, + ("confidence_threshold",): { + "inclusive_maximum": 1.0, + "inclusive_minimum": 0.0, }, - ('patience_time',): { - 'inclusive_maximum': 3600, - 'inclusive_minimum': 0, + ("patience_time",): { + "inclusive_maximum": 3600, + "inclusive_minimum": 0, }, } @@ -88,7 +87,17 @@ def additional_properties_type(): of type self, this must run after the class is loaded """ lazy_import() - return (bool, date, datetime, dict, float, int, list, str, none_type,) # noqa: E501 + return ( + bool, + date, + datetime, + dict, + float, + int, + list, + str, + none_type, + ) # noqa: E501 _nullable = False @@ -104,28 +113,46 @@ def openapi_types(): """ lazy_import() return { - 'name': (str,), # noqa: E501 - 'confidence_threshold': (float,), # noqa: E501 - 'patience_time': (float,), # noqa: E501 - 'status': (bool, date, datetime, dict, float, int, list, str, none_type,), # noqa: E501 - 'escalation_type': (bool, date, datetime, dict, float, int, list, str, none_type,), # noqa: E501 + "name": (str,), # noqa: E501 + "confidence_threshold": (float,), # noqa: E501 + "patience_time": (float,), # noqa: E501 + "status": ( + bool, + date, + datetime, + dict, + float, + int, + list, + str, + none_type, + ), # noqa: E501 + "escalation_type": ( + bool, + date, + datetime, + dict, + float, + int, + list, + str, + none_type, + ), # noqa: E501 } @cached_property def discriminator(): return None - attribute_map = { - 'name': 'name', # noqa: E501 - 'confidence_threshold': 'confidence_threshold', # noqa: E501 - 'patience_time': 'patience_time', # noqa: E501 - 'status': 'status', # noqa: E501 - 'escalation_type': 'escalation_type', # noqa: E501 + "name": "name", # noqa: E501 + "confidence_threshold": "confidence_threshold", # noqa: E501 + "patience_time": "patience_time", # noqa: E501 + "status": "status", # noqa: E501 + "escalation_type": "escalation_type", # noqa: E501 } - read_only_vars = { - } + read_only_vars = {} _composed_schemas = {} @@ -172,17 +199,18 @@ def _from_openapi_data(cls, *args, **kwargs): # noqa: E501 escalation_type (bool, date, datetime, dict, float, int, list, str, none_type): Category that define internal proccess for labeling image queries * `STANDARD` - STANDARD * `NO_HUMAN_LABELING` - NO_HUMAN_LABELING. [optional] # noqa: E501 """ - _check_type = kwargs.pop('_check_type', True) - _spec_property_naming = kwargs.pop('_spec_property_naming', False) - _path_to_item = kwargs.pop('_path_to_item', ()) - _configuration = kwargs.pop('_configuration', None) - _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) + _check_type = kwargs.pop("_check_type", True) + _spec_property_naming = kwargs.pop("_spec_property_naming", False) + _path_to_item = kwargs.pop("_path_to_item", ()) + _configuration = kwargs.pop("_configuration", None) + _visited_composed_classes = kwargs.pop("_visited_composed_classes", ()) self = super(OpenApiModel, cls).__new__(cls) if args: raise ApiTypeError( - "Invalid positional arguments=%s passed to %s. Remove those invalid positional arguments." % ( + "Invalid positional arguments=%s passed to %s. Remove those invalid positional arguments." + % ( args, self.__class__.__name__, ), @@ -198,22 +226,24 @@ def _from_openapi_data(cls, *args, **kwargs): # noqa: E501 self._visited_composed_classes = _visited_composed_classes + (self.__class__,) for var_name, var_value in kwargs.items(): - if var_name not in self.attribute_map and \ - self._configuration is not None and \ - self._configuration.discard_unknown_keys and \ - self.additional_properties_type is None: + if ( + var_name not in self.attribute_map + and self._configuration is not None + and self._configuration.discard_unknown_keys + and self.additional_properties_type is None + ): # discard variable. continue setattr(self, var_name, var_value) return self required_properties = set([ - '_data_store', - '_check_type', - '_spec_property_naming', - '_path_to_item', - '_configuration', - '_visited_composed_classes', + "_data_store", + "_check_type", + "_spec_property_naming", + "_path_to_item", + "_configuration", + "_visited_composed_classes", ]) @convert_js_args_to_python_args @@ -258,15 +288,16 @@ def __init__(self, *args, **kwargs): # noqa: E501 escalation_type (bool, date, datetime, dict, float, int, list, str, none_type): Category that define internal proccess for labeling image queries * `STANDARD` - STANDARD * `NO_HUMAN_LABELING` - NO_HUMAN_LABELING. [optional] # noqa: E501 """ - _check_type = kwargs.pop('_check_type', True) - _spec_property_naming = kwargs.pop('_spec_property_naming', False) - _path_to_item = kwargs.pop('_path_to_item', ()) - _configuration = kwargs.pop('_configuration', None) - _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) + _check_type = kwargs.pop("_check_type", True) + _spec_property_naming = kwargs.pop("_spec_property_naming", False) + _path_to_item = kwargs.pop("_path_to_item", ()) + _configuration = kwargs.pop("_configuration", None) + _visited_composed_classes = kwargs.pop("_visited_composed_classes", ()) if args: raise ApiTypeError( - "Invalid positional arguments=%s passed to %s. Remove those invalid positional arguments." % ( + "Invalid positional arguments=%s passed to %s. Remove those invalid positional arguments." + % ( args, self.__class__.__name__, ), @@ -282,13 +313,17 @@ def __init__(self, *args, **kwargs): # noqa: E501 self._visited_composed_classes = _visited_composed_classes + (self.__class__,) for var_name, var_value in kwargs.items(): - if var_name not in self.attribute_map and \ - self._configuration is not None and \ - self._configuration.discard_unknown_keys and \ - self.additional_properties_type is None: + if ( + var_name not in self.attribute_map + and self._configuration is not None + and self._configuration.discard_unknown_keys + and self.additional_properties_type is None + ): # discard variable. continue setattr(self, var_name, var_value) if var_name in self.read_only_vars: - raise ApiAttributeError(f"`{var_name}` is a read-only attribute. Use `from_openapi_data` to instantiate " - f"class with read only attributes.") + raise ApiAttributeError( + f"`{var_name}` is a read-only attribute. Use `from_openapi_data` to instantiate " + "class with read only attributes." + ) diff --git a/generated/model.py b/generated/model.py index 25027f04..827e7ea1 100644 --- a/generated/model.py +++ b/generated/model.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: public-api.yaml -# timestamp: 2024-12-09T18:29:17+00:00 +# timestamp: 2024-12-10T01:13:13+00:00 from __future__ import annotations diff --git a/src/groundlight/binary_labels.py b/src/groundlight/binary_labels.py index c1d20470..557a4245 100644 --- a/src/groundlight/binary_labels.py +++ b/src/groundlight/binary_labels.py @@ -53,25 +53,3 @@ def convert_internal_label_to_display( logger.warning(f"Unrecognized internal label {label} - leaving it alone as a string.") return label - - -def convert_display_label_to_internal( - context: Union[ImageQuery, Detector, str], # pylint: disable=unused-argument - label: Union[Label, str], -) -> str: - """Convert a label that comes from the user into the label string that we send to the server. We - are strict here, and only allow YES/NO. - - NOTE: We accept case-insensitive label strings from the user, but we send UPPERCASE labels to - the server. E.g., user inputs "yes" -> the label is returned as "YES". - """ - # NOTE: In the future we should validate against actually supported labels for the detector - if not isinstance(label, str): - raise ValueError(f"Expected a string label, but got {label} of type {type(label)}") - upper = label.upper() - if upper == Label.YES: - return DeprecatedLabel.PASS.value - if upper == Label.NO: - return DeprecatedLabel.FAIL.value - - raise ValueError(f"Invalid label string '{label}'. Must be one of '{Label.YES.value}','{Label.NO.value}'.") diff --git a/src/groundlight/client.py b/src/groundlight/client.py index b45300bc..9cd4107d 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -13,9 +13,11 @@ from groundlight_openapi_client.api.labels_api import LabelsApi from groundlight_openapi_client.api.user_api import UserApi from groundlight_openapi_client.exceptions import NotFoundException, UnauthorizedException +from groundlight_openapi_client.model.b_box_geometry_request import BBoxGeometryRequest from groundlight_openapi_client.model.detector_creation_input_request import DetectorCreationInputRequest from groundlight_openapi_client.model.label_value_request import LabelValueRequest from groundlight_openapi_client.model.patched_detector_request import PatchedDetectorRequest +from groundlight_openapi_client.model.roi_request import ROIRequest from model import ( ROI, BinaryClassificationResult, @@ -26,7 +28,7 @@ ) from urllib3.exceptions import InsecureRequestWarning -from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display +from groundlight.binary_labels import Label, convert_internal_label_to_display from groundlight.config import API_TOKEN_MISSING_HELP_MESSAGE, API_TOKEN_VARIABLE_NAME, DISABLE_TLS_VARIABLE_NAME from groundlight.encodings import url_encode_dict from groundlight.images import ByteStreamWrapper, parse_supported_image_types @@ -1070,7 +1072,10 @@ def _wait_for_result( return image_query def add_label( - self, image_query: Union[ImageQuery, str], label: Union[Label, str], rois: Union[List[ROI], str, None] = None + self, + image_query: Union[ImageQuery, str], + label: Union[Label, int, str], + rois: Union[List[ROI], str, None] = None, ): """ Provide a new label (annotation) for an image query. This is used to provide ground-truth labels @@ -1091,17 +1096,28 @@ def add_label( rois = [ROI(x=100, y=100, width=50, height=50)] gl.add_label(image_query, "YES", rois=rois) - :param image_query: Either an ImageQuery object (returned from methods like :meth:`ask_ml`) or an image query ID - string starting with "iq_". - :param label: The label value to assign, typically "YES" or "NO" for binary classification detectors. - For multi-class detectors, use one of the defined class names. - :param rois: Optional list of ROI objects defining regions of interest in the image. - Each ROI specifies a bounding box with x, y coordinates and width, height. + :param image_query: Either an ImageQuery object (returned from methods like + `ask_ml`) or an image query ID string starting with "iq_". + + :param label: The label value to assign, typically "YES" or "NO" for binary + classification detectors. For multi-class detectors, use one of + the defined class names. + + :param rois: Optional list of ROI objects defining regions of interest in the + image. Each ROI specifies a bounding box with x, y coordinates + and width, height. :return: None """ if isinstance(rois, str): raise TypeError("rois must be a list of ROI objects. CLI support is not implemented") + + # NOTE: bool is a subclass of int + if type(label) == int: # noqa: E721 pylint: disable=unidiomatic-typecheck + label = str(label) + elif not isinstance(label, (str, Label)): + raise TypeError("label must be a string or integer") + if isinstance(image_query, ImageQuery): image_query_id = image_query.id else: @@ -1109,9 +1125,16 @@ def add_label( # Some old imagequery id's started with "chk_" if not image_query_id.startswith(("chk_", "iq_")): raise ValueError(f"Invalid image query id {image_query_id}") - api_label = convert_display_label_to_internal(image_query_id, label) - rois_json = [roi.dict() for roi in rois] if rois else None - request_params = LabelValueRequest(label=api_label, image_query_id=image_query_id, rois=rois_json) + geometry_requests = [BBoxGeometryRequest(**roi.geometry.dict()) for roi in rois] if rois else None + roi_requests = ( + [ + ROIRequest(label=roi.label, score=roi.score, geometry=geometry) + for roi, geometry in zip(rois, geometry_requests) + ] + if rois and geometry_requests + else None + ) + request_params = LabelValueRequest(label=label, image_query_id=image_query_id, rois=roi_requests) self.labels_api.create_label(request_params) def start_inspection(self) -> str: diff --git a/src/groundlight/experimental_api.py b/src/groundlight/experimental_api.py index d3d9721c..5b8765b8 100644 --- a/src/groundlight/experimental_api.py +++ b/src/groundlight/experimental_api.py @@ -17,22 +17,18 @@ from groundlight_openapi_client.api.image_queries_api import ImageQueriesApi from groundlight_openapi_client.api.notes_api import NotesApi from groundlight_openapi_client.model.action_request import ActionRequest -from groundlight_openapi_client.model.b_box_geometry_request import BBoxGeometryRequest from groundlight_openapi_client.model.channel_enum import ChannelEnum from groundlight_openapi_client.model.condition_request import ConditionRequest from groundlight_openapi_client.model.count_mode_configuration import CountModeConfiguration from groundlight_openapi_client.model.detector_group_request import DetectorGroupRequest from groundlight_openapi_client.model.escalation_type_enum import EscalationTypeEnum -from groundlight_openapi_client.model.label_value_request import LabelValueRequest from groundlight_openapi_client.model.multi_class_mode_configuration import MultiClassModeConfiguration from groundlight_openapi_client.model.patched_detector_request import PatchedDetectorRequest -from groundlight_openapi_client.model.roi_request import ROIRequest from groundlight_openapi_client.model.rule_request import RuleRequest from groundlight_openapi_client.model.status_enum import StatusEnum from groundlight_openapi_client.model.verb_enum import VerbEnum -from model import ROI, BBoxGeometry, Detector, DetectorGroup, ImageQuery, ModeEnum, PaginatedRuleList, Rule +from model import ROI, BBoxGeometry, Detector, DetectorGroup, ModeEnum, PaginatedRuleList, Rule -from groundlight.binary_labels import Label, convert_display_label_to_internal from groundlight.images import parse_supported_image_types from groundlight.optional_imports import Image, np @@ -499,66 +495,6 @@ def create_roi(self, label: str, top_left: Tuple[float, float], bottom_right: Tu ), ) - # TODO: remove duplicate method on subclass - # pylint: disable=duplicate-code - def add_label( - self, image_query: Union[ImageQuery, str], label: Union[Label, str], rois: Union[List[ROI], str, None] = None - ): - """ - Provide a new label (annotation) for an image query. This is used to provide ground-truth labels - for training detectors, or to correct the results of detectors. - - **Example usage**:: - - gl = ExperimentalApi() - - # Using an ImageQuery object - image_query = gl.ask_ml(detector_id, image_data) - gl.add_label(image_query, "YES") - - # Using an image query ID string directly - gl.add_label("iq_abc123", "NO") - - # With regions of interest (ROIs) - rois = [ROI(x=100, y=100, width=50, height=50)] - gl.add_label(image_query, "YES", rois=rois) - - :param image_query: Either an ImageQuery object (returned from methods like - `ask_ml`) or an image query ID string starting with "iq_". - - :param label: The label value to assign, typically "YES" or "NO" for binary - classification detectors. For multi-class detectors, use one of - the defined class names. - - :param rois: Optional list of ROI objects defining regions of interest in the - image. Each ROI specifies a bounding box with x, y coordinates - and width, height. - - :return: None - """ - if isinstance(rois, str): - raise TypeError("rois must be a list of ROI objects. CLI support is not implemented") - if isinstance(image_query, ImageQuery): - image_query_id = image_query.id - else: - image_query_id = str(image_query) - # Some old imagequery id's started with "chk_" - # TODO: handle iqe_ for image_queries returned from edge endpoints - if not image_query_id.startswith(("chk_", "iq_")): - raise ValueError(f"Invalid image query id {image_query_id}") - api_label = convert_display_label_to_internal(image_query_id, label) - geometry_requests = [BBoxGeometryRequest(**roi.geometry.dict()) for roi in rois] if rois else None - roi_requests = ( - [ - ROIRequest(label=roi.label, score=roi.score, geometry=geometry) - for roi, geometry in zip(rois, geometry_requests) - ] - if rois and geometry_requests - else None - ) - request_params = LabelValueRequest(label=api_label, image_query_id=image_query_id, rois=roi_requests) - self.labels_api.create_label(request_params) - def reset_detector(self, detector: Union[str, Detector]) -> None: """ Removes all image queries and training data for the given detector. This effectively resets diff --git a/src/groundlight/internalapi.py b/src/groundlight/internalapi.py index f5a2e45b..18f2f3a7 100644 --- a/src/groundlight/internalapi.py +++ b/src/groundlight/internalapi.py @@ -87,7 +87,7 @@ def __init__(self, status=None, reason=None, http_resp=None): super().__init__(status, reason, http_resp) -class RequestsRetryDecorator: +class RequestsRetryDecorator: # pylint: disable=too-few-public-methods """ Decorate a function to retry sending HTTP requests. diff --git a/test/integration/test_groundlight.py b/test/integration/test_groundlight.py index 01a22bc2..ab82f0e7 100644 --- a/test/integration/test_groundlight.py +++ b/test/integration/test_groundlight.py @@ -10,7 +10,7 @@ import pytest from groundlight import Groundlight -from groundlight.binary_labels import VALID_DISPLAY_LABELS, DeprecatedLabel, Label, convert_internal_label_to_display +from groundlight.binary_labels import VALID_DISPLAY_LABELS, Label, convert_internal_label_to_display from groundlight.internalapi import ApiException, InternalApiError, NotFoundError from groundlight.optional_imports import * from groundlight.status_codes import is_user_error @@ -622,44 +622,15 @@ def test_add_label_names(gl: Groundlight, image_query_yes: ImageQuery, image_que gl.add_label(iqid_no, "NO") gl.add_label(iqid_no, "no") - # Invalid labels - with pytest.raises(ValueError): - gl.add_label(iqid_yes, "PASS") - with pytest.raises(ValueError): - gl.add_label(iqid_yes, "FAIL") - with pytest.raises(ValueError): - gl.add_label(iqid_yes, DeprecatedLabel.PASS) - with pytest.raises(ValueError): - gl.add_label(iqid_yes, DeprecatedLabel.FAIL) - with pytest.raises(ValueError): - gl.add_label(iqid_yes, "sorta") - with pytest.raises(ValueError): - gl.add_label(iqid_yes, "YES ") - with pytest.raises(ValueError): - gl.add_label(iqid_yes, " YES") - with pytest.raises(ValueError): - gl.add_label(iqid_yes, "0") - with pytest.raises(ValueError): - gl.add_label(iqid_yes, "1") - - # We technically don't allow these in the type signature, but users might do it anyway - with pytest.raises(ValueError): - gl.add_label(iqid_yes, 0) # type: ignore - with pytest.raises(ValueError): - gl.add_label(iqid_yes, 1) # type: ignore - with pytest.raises(ValueError): + with pytest.raises(TypeError): gl.add_label(iqid_yes, None) # type: ignore - with pytest.raises(ValueError): + with pytest.raises(TypeError): gl.add_label(iqid_yes, True) # type: ignore - with pytest.raises(ValueError): + with pytest.raises(TypeError): gl.add_label(iqid_yes, False) # type: ignore - with pytest.raises(ValueError): + with pytest.raises(TypeError): gl.add_label(iqid_yes, b"YES") # type: ignore - # We may want to support something like this in the future, but not yet - with pytest.raises(ValueError): - gl.add_label(iqid_yes, Label.UNCLEAR) - def test_label_conversion_produces_strings(): # In our code, it's easier to work with enums, but we allow users to pass in strings or enums diff --git a/test/unit/test_labels.py b/test/unit/test_labels.py new file mode 100644 index 00000000..d894801d --- /dev/null +++ b/test/unit/test_labels.py @@ -0,0 +1,59 @@ +from datetime import datetime + +import pytest +from groundlight import ApiException, ExperimentalApi + + +def test_binary_labels(gl_experimental: ExperimentalApi): + name = f"Test binary labels{datetime.utcnow()}" + det = gl_experimental.create_detector(name, "test_query") + iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg") + gl_experimental.add_label(iq1, "YES") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "YES" + gl_experimental.add_label(iq1, "NO") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "NO" + gl_experimental.add_label(iq1, "UNCLEAR") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "UNCLEAR" + with pytest.raises(ApiException) as _: + gl_experimental.add_label(iq1, "MAYBE") + + +def test_counting_labels(gl_experimental: ExperimentalApi): + name = f"Test binary labels{datetime.utcnow()}" + det = gl_experimental.create_counting_detector(name, "test_query", "test_object_class") + iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg") + gl_experimental.add_label(iq1, 0) + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.count == 0 + good_label = 5 + gl_experimental.add_label(iq1, good_label) + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.count == good_label + with pytest.raises(ApiException) as _: + gl_experimental.add_label(iq1, "MAYBE") + with pytest.raises(ApiException) as _: + gl_experimental.add_label(iq1, -999) + + +def test_multiclass_labels(gl_experimental: ExperimentalApi): + name = f"Test binary labels{datetime.utcnow()}" + det = gl_experimental.create_multiclass_detector(name, "test_query", class_names=["apple", "banana", "cherry"]) + iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg") + gl_experimental.add_label(iq1, "apple") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "apple" + gl_experimental.add_label(iq1, "banana") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "banana" + gl_experimental.add_label(iq1, "cherry") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "cherry" + # You can submit the index of the class as well + gl_experimental.add_label(iq1, 2) + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "cherry" + with pytest.raises(ApiException) as _: + gl_experimental.add_label(iq1, "MAYBE")