Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/groundlight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,18 @@ def _verify_connectivity(self) -> None:
)
raise GroundlightClientError(msg) from e

@staticmethod
def _propagate_result_type(iq_dict: dict) -> None:
"""Propagate result_type into the result dict for correct Pydantic Union deserialization.
Edge responses may omit result_type from the result object, causing Pydantic to
pick the wrong type from the Union (e.g. BinaryClassificationResult instead of
MultiClassificationResult).
"""
result = iq_dict.get("result")
result_type = iq_dict.get("result_type")
if isinstance(result, dict) and result_type and not result.get("result_type"):
result["result_type"] = result_type

@staticmethod
def _fixup_image_query(iq: ImageQuery) -> ImageQuery:
"""
Expand Down Expand Up @@ -605,7 +617,9 @@ def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-b
if obj.result_type == "counting" and getattr(obj.result, "label", None):
obj.result.pop("label")
obj.result["count"] = None
iq = ImageQuery.parse_obj(obj.to_dict())
iq_dict = obj.to_dict()
self._propagate_result_type(iq_dict)
iq = ImageQuery.parse_obj(iq_dict)
return self._fixup_image_query(iq)

def list_image_queries(
Expand Down Expand Up @@ -636,7 +650,10 @@ def list_image_queries(
if detector_id:
params["detector_id"] = detector_id
obj = self.image_queries_api.list_image_queries(**params)
image_queries = PaginatedImageQueryList.parse_obj(obj.to_dict())
obj_dict = obj.to_dict()
for iq_dict in obj_dict.get("results") or []:
self._propagate_result_type(iq_dict)
image_queries = PaginatedImageQueryList.parse_obj(obj_dict)
if image_queries.results is not None:
image_queries.results = [self._fixup_image_query(iq) for iq in image_queries.results]
return image_queries
Expand Down Expand Up @@ -809,7 +826,9 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t
params["image_query_id"] = image_query_id

raw_image_query = self.image_queries_api.submit_image_query(**params)
image_query = ImageQuery.parse_obj(raw_image_query.to_dict())
iq_dict = raw_image_query.to_dict()
self._propagate_result_type(iq_dict)
image_query = ImageQuery.parse_obj(iq_dict)

if wait > 0:
if confidence_threshold is None:
Expand Down
Loading