Skip to content

Commit 3db0a05

Browse files
committed
fixing result type issue
1 parent 02ad98a commit 3db0a05

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

src/groundlight/client.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,18 @@ def _verify_connectivity(self) -> None:
230230
)
231231
raise GroundlightClientError(msg) from e
232232

233+
@staticmethod
234+
def _propagate_result_type(iq_dict: dict) -> None:
235+
"""Propagate result_type into the result dict for correct Pydantic Union deserialization.
236+
Edge responses may omit result_type from the result object, causing Pydantic to
237+
pick the wrong type from the Union (e.g. BinaryClassificationResult instead of
238+
MultiClassificationResult).
239+
"""
240+
result = iq_dict.get("result")
241+
result_type = iq_dict.get("result_type")
242+
if isinstance(result, dict) and result_type and not result.get("result_type"):
243+
result["result_type"] = result_type
244+
233245
@staticmethod
234246
def _fixup_image_query(iq: ImageQuery) -> ImageQuery:
235247
"""
@@ -605,7 +617,9 @@ def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-b
605617
if obj.result_type == "counting" and getattr(obj.result, "label", None):
606618
obj.result.pop("label")
607619
obj.result["count"] = None
608-
iq = ImageQuery.parse_obj(obj.to_dict())
620+
iq_dict = obj.to_dict()
621+
self._propagate_result_type(iq_dict)
622+
iq = ImageQuery.parse_obj(iq_dict)
609623
return self._fixup_image_query(iq)
610624

611625
def list_image_queries(
@@ -636,7 +650,10 @@ def list_image_queries(
636650
if detector_id:
637651
params["detector_id"] = detector_id
638652
obj = self.image_queries_api.list_image_queries(**params)
639-
image_queries = PaginatedImageQueryList.parse_obj(obj.to_dict())
653+
obj_dict = obj.to_dict()
654+
for iq_dict in obj_dict.get("results") or []:
655+
self._propagate_result_type(iq_dict)
656+
image_queries = PaginatedImageQueryList.parse_obj(obj_dict)
640657
if image_queries.results is not None:
641658
image_queries.results = [self._fixup_image_query(iq) for iq in image_queries.results]
642659
return image_queries
@@ -809,7 +826,9 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t
809826
params["image_query_id"] = image_query_id
810827

811828
raw_image_query = self.image_queries_api.submit_image_query(**params)
812-
image_query = ImageQuery.parse_obj(raw_image_query.to_dict())
829+
iq_dict = raw_image_query.to_dict()
830+
self._propagate_result_type(iq_dict)
831+
image_query = ImageQuery.parse_obj(iq_dict)
813832

814833
if wait > 0:
815834
if confidence_threshold is None:

0 commit comments

Comments
 (0)