Skip to content

Commit 42f53d5

Browse files
author
brandon
committed
Adds sdk label support for multiclass detectors
1 parent 5e9da70 commit 42f53d5

File tree

4 files changed

+86
-95
lines changed

4 files changed

+86
-95
lines changed

src/groundlight/binary_labels.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,25 +53,3 @@ def convert_internal_label_to_display(
5353

5454
logger.warning(f"Unrecognized internal label {label} - leaving it alone as a string.")
5555
return label
56-
57-
58-
def convert_display_label_to_internal(
59-
context: Union[ImageQuery, Detector, str], # pylint: disable=unused-argument
60-
label: Union[Label, str],
61-
) -> str:
62-
"""Convert a label that comes from the user into the label string that we send to the server. We
63-
are strict here, and only allow YES/NO.
64-
65-
NOTE: We accept case-insensitive label strings from the user, but we send UPPERCASE labels to
66-
the server. E.g., user inputs "yes" -> the label is returned as "YES".
67-
"""
68-
# NOTE: In the future we should validate against actually supported labels for the detector
69-
if not isinstance(label, str):
70-
raise ValueError(f"Expected a string label, but got {label} of type {type(label)}")
71-
upper = label.upper()
72-
if upper == Label.YES:
73-
return DeprecatedLabel.PASS.value
74-
if upper == Label.NO:
75-
return DeprecatedLabel.FAIL.value
76-
77-
raise ValueError(f"Invalid label string '{label}'. Must be one of '{Label.YES.value}','{Label.NO.value}'.")

src/groundlight/client.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
from groundlight_openapi_client.api.labels_api import LabelsApi
1414
from groundlight_openapi_client.api.user_api import UserApi
1515
from groundlight_openapi_client.exceptions import NotFoundException, UnauthorizedException
16+
from groundlight_openapi_client.model.b_box_geometry_request import BBoxGeometryRequest
1617
from groundlight_openapi_client.model.detector_creation_input_request import DetectorCreationInputRequest
1718
from groundlight_openapi_client.model.label_value_request import LabelValueRequest
1819
from groundlight_openapi_client.model.patched_detector_request import PatchedDetectorRequest
20+
from groundlight_openapi_client.model.roi_request import ROIRequest
1921
from model import (
2022
ROI,
2123
BinaryClassificationResult,
@@ -26,7 +28,7 @@
2628
)
2729
from urllib3.exceptions import InsecureRequestWarning
2830

29-
from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display
31+
from groundlight.binary_labels import Label, convert_internal_label_to_display
3032
from groundlight.config import API_TOKEN_MISSING_HELP_MESSAGE, API_TOKEN_VARIABLE_NAME, DISABLE_TLS_VARIABLE_NAME
3133
from groundlight.encodings import url_encode_dict
3234
from groundlight.images import ByteStreamWrapper, parse_supported_image_types
@@ -1066,16 +1068,17 @@ def _wait_for_result(
10661068
image_query = self._fixup_image_query(image_query)
10671069
return image_query
10681070

1071+
# pylint: disable=duplicate-code
10691072
def add_label(
1070-
self, image_query: Union[ImageQuery, str], label: Union[Label, str], rois: Union[List[ROI], str, None] = None
1073+
self, image_query: Union[ImageQuery, str], label: Union[Label, int, str], rois: Union[List[ROI], str, None] = None
10711074
):
10721075
"""
10731076
Provide a new label (annotation) for an image query. This is used to provide ground-truth labels
10741077
for training detectors, or to correct the results of detectors.
10751078
10761079
**Example usage**::
10771080
1078-
gl = Groundlight()
1081+
gl = ExperimentalApi()
10791082
10801083
# Using an ImageQuery object
10811084
image_query = gl.ask_ml(detector_id, image_data)
@@ -1088,27 +1091,41 @@ def add_label(
10881091
rois = [ROI(x=100, y=100, width=50, height=50)]
10891092
gl.add_label(image_query, "YES", rois=rois)
10901093
1091-
:param image_query: Either an ImageQuery object (returned from methods like :meth:`ask_ml`) or an image query ID
1092-
string starting with "iq_".
1093-
:param label: The label value to assign, typically "YES" or "NO" for binary classification detectors.
1094-
For multi-class detectors, use one of the defined class names.
1095-
:param rois: Optional list of ROI objects defining regions of interest in the image.
1096-
Each ROI specifies a bounding box with x, y coordinates and width, height.
1094+
:param image_query: Either an ImageQuery object (returned from methods like
1095+
`ask_ml`) or an image query ID string starting with "iq_".
1096+
1097+
:param label: The label value to assign, typically "YES" or "NO" for binary
1098+
classification detectors. For multi-class detectors, use one of
1099+
the defined class names.
1100+
1101+
:param rois: Optional list of ROI objects defining regions of interest in the
1102+
image. Each ROI specifies a bounding box with x, y coordinates
1103+
and width, height.
10971104
10981105
:return: None
10991106
"""
11001107
if isinstance(rois, str):
11011108
raise TypeError("rois must be a list of ROI objects. CLI support is not implemented")
1109+
if isinstance(label, int):
1110+
label = str(label)
11021111
if isinstance(image_query, ImageQuery):
11031112
image_query_id = image_query.id
11041113
else:
11051114
image_query_id = str(image_query)
11061115
# Some old imagequery id's started with "chk_"
1116+
# TODO: handle iqe_ for image_queries returned from edge endpoints
11071117
if not image_query_id.startswith(("chk_", "iq_")):
11081118
raise ValueError(f"Invalid image query id {image_query_id}")
1109-
api_label = convert_display_label_to_internal(image_query_id, label)
1110-
rois_json = [roi.dict() for roi in rois] if rois else None
1111-
request_params = LabelValueRequest(label=api_label, image_query_id=image_query_id, rois=rois_json)
1119+
geometry_requests = [BBoxGeometryRequest(**roi.geometry.dict()) for roi in rois] if rois else None
1120+
roi_requests = (
1121+
[
1122+
ROIRequest(label=roi.label, score=roi.score, geometry=geometry)
1123+
for roi, geometry in zip(rois, geometry_requests)
1124+
]
1125+
if rois and geometry_requests
1126+
else None
1127+
)
1128+
request_params = LabelValueRequest(label=label, image_query_id=image_query_id, rois=roi_requests)
11121129
self.labels_api.create_label(request_params)
11131130

11141131
def start_inspection(self) -> str:

src/groundlight/experimental_api.py

Lines changed: 1 addition & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from groundlight_openapi_client.model.verb_enum import VerbEnum
3333
from model import ROI, BBoxGeometry, Detector, DetectorGroup, ImageQuery, ModeEnum, PaginatedRuleList, Rule
3434

35-
from groundlight.binary_labels import Label, convert_display_label_to_internal
35+
from groundlight.binary_labels import Label
3636
from groundlight.images import parse_supported_image_types
3737
from groundlight.optional_imports import Image, np
3838

@@ -499,66 +499,6 @@ def create_roi(self, label: str, top_left: Tuple[float, float], bottom_right: Tu
499499
),
500500
)
501501

502-
# TODO: remove duplicate method on subclass
503-
# pylint: disable=duplicate-code
504-
def add_label(
505-
self, image_query: Union[ImageQuery, str], label: Union[Label, str], rois: Union[List[ROI], str, None] = None
506-
):
507-
"""
508-
Provide a new label (annotation) for an image query. This is used to provide ground-truth labels
509-
for training detectors, or to correct the results of detectors.
510-
511-
**Example usage**::
512-
513-
gl = ExperimentalApi()
514-
515-
# Using an ImageQuery object
516-
image_query = gl.ask_ml(detector_id, image_data)
517-
gl.add_label(image_query, "YES")
518-
519-
# Using an image query ID string directly
520-
gl.add_label("iq_abc123", "NO")
521-
522-
# With regions of interest (ROIs)
523-
rois = [ROI(x=100, y=100, width=50, height=50)]
524-
gl.add_label(image_query, "YES", rois=rois)
525-
526-
:param image_query: Either an ImageQuery object (returned from methods like
527-
`ask_ml`) or an image query ID string starting with "iq_".
528-
529-
:param label: The label value to assign, typically "YES" or "NO" for binary
530-
classification detectors. For multi-class detectors, use one of
531-
the defined class names.
532-
533-
:param rois: Optional list of ROI objects defining regions of interest in the
534-
image. Each ROI specifies a bounding box with x, y coordinates
535-
and width, height.
536-
537-
:return: None
538-
"""
539-
if isinstance(rois, str):
540-
raise TypeError("rois must be a list of ROI objects. CLI support is not implemented")
541-
if isinstance(image_query, ImageQuery):
542-
image_query_id = image_query.id
543-
else:
544-
image_query_id = str(image_query)
545-
# Some old imagequery id's started with "chk_"
546-
# TODO: handle iqe_ for image_queries returned from edge endpoints
547-
if not image_query_id.startswith(("chk_", "iq_")):
548-
raise ValueError(f"Invalid image query id {image_query_id}")
549-
api_label = convert_display_label_to_internal(image_query_id, label)
550-
geometry_requests = [BBoxGeometryRequest(**roi.geometry.dict()) for roi in rois] if rois else None
551-
roi_requests = (
552-
[
553-
ROIRequest(label=roi.label, score=roi.score, geometry=geometry)
554-
for roi, geometry in zip(rois, geometry_requests)
555-
]
556-
if rois and geometry_requests
557-
else None
558-
)
559-
request_params = LabelValueRequest(label=api_label, image_query_id=image_query_id, rois=roi_requests)
560-
self.labels_api.create_label(request_params)
561-
562502
def reset_detector(self, detector: Union[str, Detector]) -> None:
563503
"""
564504
Removes all image queries and training data for the given detector. This effectively resets

test/unit/test_labels.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from datetime import datetime
2+
3+
import pytest
4+
from groundlight import ExperimentalApi, ApiException
5+
6+
7+
def test_binary_labels(gl_experimental: ExperimentalApi):
8+
name = f"Test binary labels{datetime.utcnow()}"
9+
det = gl_experimental.create_detector(name, "test_query")
10+
iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg")
11+
gl_experimental.add_label(iq1, "YES")
12+
iq1 = gl_experimental.get_image_query(iq1.id)
13+
assert iq1.result.label == "YES"
14+
gl_experimental.add_label(iq1, "NO")
15+
iq1 = gl_experimental.get_image_query(iq1.id)
16+
assert iq1.result.label == "NO"
17+
gl_experimental.add_label(iq1, "UNCLEAR")
18+
iq1 = gl_experimental.get_image_query(iq1.id)
19+
assert iq1.result.label == "UNCLEAR"
20+
with pytest.raises(ApiException) as _:
21+
gl_experimental.add_label(iq1, "MAYBE")
22+
23+
def test_counting_labels(gl_experimental: ExperimentalApi):
24+
name = f"Test binary labels{datetime.utcnow()}"
25+
det = gl_experimental.create_counting_detector(name, "test_query")
26+
iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg")
27+
gl_experimental.add_label(iq1, 0)
28+
iq1 = gl_experimental.get_image_query(iq1.id)
29+
assert iq1.result.count == 0
30+
gl_experimental.add_label(iq1, 5)
31+
iq1 = gl_experimental.get_image_query(iq1.id)
32+
assert iq1.result.count == 5
33+
with pytest.raises(ApiException) as _:
34+
gl_experimental.add_label(iq1, "MAYBE")
35+
with pytest.raises(ApiException) as _:
36+
gl_experimental.add_label(iq1, -999)
37+
38+
def test_multiclass_labels(gl_experimental: ExperimentalApi):
39+
name = f"Test binary labels{datetime.utcnow()}"
40+
det = gl_experimental.create_multiclass_detector(name, "test_query", class_names=["apple", "banana", "cherry"])
41+
iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg")
42+
gl_experimental.add_label(iq1, "apple")
43+
iq1 = gl_experimental.get_image_query(iq1.id)
44+
assert iq1.result.label == "apple"
45+
gl_experimental.add_label(iq1, "banana")
46+
iq1 = gl_experimental.get_image_query(iq1.id)
47+
assert iq1.result.label == "banana"
48+
gl_experimental.add_label(iq1, "cherry")
49+
iq1 = gl_experimental.get_image_query(iq1.id)
50+
assert iq1.result.label == "cherry"
51+
# You can submit the index of the class as well
52+
gl_experimental.add_label(iq1, 2)
53+
iq1 = gl_experimental.get_image_query(iq1.id)
54+
assert iq1.result.label == "cherry"
55+
with pytest.raises(ApiException) as _:
56+
gl_experimental.add_label(iq1, "MAYBE")

0 commit comments

Comments
 (0)