From b0affeffaab52435e8b520c9bea2903496b428c2 Mon Sep 17 00:00:00 2001 From: Akol125 Date: Thu, 29 Jan 2026 22:17:03 +0000 Subject: [PATCH 1/5] make external service status code reusable --- lambdas/mns_subscription/src/mns_service.py | 41 +----- lambdas/mns_subscription/src/models/errors.py | 99 ------------- lambdas/mns_subscription/tests/test_errors.py | 100 ------------- .../tests/test_mns_service.py | 15 +- lambdas/shared/src/common/models/errors.py | 139 ++++++++++++++++++ lambdas/shared/src/common/pds_service.py | 8 + .../shared/tests/test_common/test_errors.py | 80 ++++++++++ 7 files changed, 238 insertions(+), 244 deletions(-) delete mode 100644 lambdas/mns_subscription/src/models/errors.py delete mode 100644 lambdas/mns_subscription/tests/test_errors.py diff --git a/lambdas/mns_subscription/src/mns_service.py b/lambdas/mns_subscription/src/mns_service.py index 309d44898..7aadbab84 100644 --- a/lambdas/mns_subscription/src/mns_service.py +++ b/lambdas/mns_subscription/src/mns_service.py @@ -7,15 +7,7 @@ from common.authentication import AppRestrictedAuth from common.models.errors import ( - ResourceNotFoundError, - UnhandledResponseError, -) -from models.errors import ( - BadRequestError, - ConflictError, - ServerError, - TokenValidationError, - UnauthorizedError, + raise_error_response, ) SQS_ARN = os.getenv("SQS_ARN") @@ -61,7 +53,7 @@ def subscribe_notification(self) -> dict | None: if response.status_code in (200, 201): return response.json() else: - MnsService.raise_error_response(response) + raise_error_response(response) def get_subscription(self) -> dict | None: response = requests.get(MNS_URL, headers=self.request_headers, timeout=10) @@ -79,7 +71,7 @@ def get_subscription(self) -> dict | None: return resource return None else: - MnsService.raise_error_response(response) + raise_error_response(response) def check_subscription(self) -> dict: """ @@ -107,7 +99,7 @@ def delete_subscription(self, subscription_id: str) -> str: logging.info(f"Deleted subscription {subscription_id}") return "Subscription Successfully Deleted..." else: - MnsService.raise_error_response(response) + raise_error_response(response) def check_delete_subscription(self): try: @@ -123,28 +115,3 @@ def check_delete_subscription(self): return "Subscription successfully deleted" except Exception as e: return f"Error deleting subscription: {str(e)}" - - @staticmethod - def raise_error_response(response): - error_mapping = { - 401: (TokenValidationError, "Token validation failed for the request"), - 400: ( - BadRequestError, - "Bad request: Resource type or parameters incorrect", - ), - 403: ( - UnauthorizedError, - "You don't have the right permissions for this request", - ), - 500: (ServerError, "Internal Server Error"), - 404: (ResourceNotFoundError, "Subscription or Resource not found"), - 409: (ConflictError, "SQS Queue Already Subscribed, can't re-subscribe"), - } - exception_class, error_message = error_mapping.get( - response.status_code, - (UnhandledResponseError, f"Unhandled error: {response.status_code}"), - ) - - if response.status_code == 404: - raise exception_class(resource_type=response.json(), resource_id=error_message) - raise exception_class(response=response.json(), message=error_message) diff --git a/lambdas/mns_subscription/src/models/errors.py b/lambdas/mns_subscription/src/models/errors.py deleted file mode 100644 index f85b7286e..000000000 --- a/lambdas/mns_subscription/src/models/errors.py +++ /dev/null @@ -1,99 +0,0 @@ -import uuid -from dataclasses import dataclass - -from common.models.errors import Code, Severity, create_operation_outcome - - -@dataclass -class UnauthorizedError(RuntimeError): - response: dict | str - message: str - - def __str__(self): - return f"{self.message}\n{self.response}" - - @staticmethod - def to_operation_outcome() -> dict: - msg = "Unauthorized request" - return create_operation_outcome( - resource_id=str(uuid.uuid4()), - severity=Severity.error, - code=Code.forbidden, - diagnostics=msg, - ) - - -@dataclass -class TokenValidationError(RuntimeError): - response: dict | str - message: str - - def __str__(self): - return f"{self.message}\n{self.response}" - - @staticmethod - def to_operation_outcome() -> dict: - msg = "Missing/Invalid Token" - return create_operation_outcome( - resource_id=str(uuid.uuid4()), - severity=Severity.error, - code=Code.invalid, - diagnostics=msg, - ) - - -@dataclass -class ConflictError(RuntimeError): - response: dict | str - message: str - - def __str__(self): - return f"{self.message}\n{self.response}" - - @staticmethod - def to_operation_outcome() -> dict: - msg = "Conflict" - return create_operation_outcome( - resource_id=str(uuid.uuid4()), - severity=Severity.error, - code=Code.duplicate, - diagnostics=msg, - ) - - -@dataclass -class BadRequestError(RuntimeError): - """Use when payload is missing required parameters""" - - response: dict | str - message: str - - def __str__(self): - return f"{self.message}\n{self.response}" - - def to_operation_outcome(self) -> dict: - return create_operation_outcome( - resource_id=str(uuid.uuid4()), - severity=Severity.error, - code=Code.incomplete, - diagnostics=self.__str__(), - ) - - -@dataclass -class ServerError(RuntimeError): - """Use when there is a server error""" - - response: dict | str - message: str - - def __str__(self): - return f"{self.message}\n{self.response}" - - def to_operation_outcome(self) -> dict: - return create_operation_outcome( - resource_id=str(uuid.uuid4()), - severity=Severity.error, - code=Code.server_error, - diagnostics=self.__str__(), - ) diff --git a/lambdas/mns_subscription/tests/test_errors.py b/lambdas/mns_subscription/tests/test_errors.py deleted file mode 100644 index 888019641..000000000 --- a/lambdas/mns_subscription/tests/test_errors.py +++ /dev/null @@ -1,100 +0,0 @@ -import unittest -from unittest.mock import patch - -import models.errors as errors - - -class TestErrors(unittest.TestCase): - def setUp(self): - TEST_UUID = "01234567-89ab-cdef-0123-4567890abcde" - # Patch uuid4 - self.uuid4_patch = patch("uuid.uuid4", return_value=TEST_UUID) - self.mock_uuid4 = self.uuid4_patch.start() - self.addCleanup(self.uuid4_patch.stop) - - def assert_response_message(self, context, response, message): - self.assertEqual(context.exception.response, response) - self.assertEqual(context.exception.message, message) - - def assert_operation_outcome(self, outcome): - self.assertEqual(outcome.get("resourceType"), "OperationOutcome") - - def test_errors_unauthorized_error(self): - """Test correct operation of UnauthorizedError""" - test_response = "test_response" - test_message = "test_message" - - with self.assertRaises(errors.UnauthorizedError) as context: - raise errors.UnauthorizedError(test_response, test_message) - self.assert_response_message(context, test_response, test_message) - self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") - outcome = context.exception.to_operation_outcome() - self.assert_operation_outcome(outcome) - issue = outcome.get("issue")[0] - self.assertEqual(issue.get("severity"), errors.Severity.error) - self.assertEqual(issue.get("code"), errors.Code.forbidden) - self.assertEqual(issue.get("diagnostics"), "Unauthorized request") - - def test_errors_token_validation_error(self): - """Test correct operation of TokenValidationError""" - test_response = "test_response" - test_message = "test_message" - - with self.assertRaises(errors.TokenValidationError) as context: - raise errors.TokenValidationError(test_response, test_message) - self.assert_response_message(context, test_response, test_message) - self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") - outcome = context.exception.to_operation_outcome() - self.assert_operation_outcome(outcome) - issue = outcome.get("issue")[0] - self.assertEqual(issue.get("severity"), errors.Severity.error) - self.assertEqual(issue.get("code"), errors.Code.invalid) - self.assertEqual(issue.get("diagnostics"), "Missing/Invalid Token") - - def test_errors_conflict_error(self): - """Test correct operation of ConflictError""" - test_response = "test_response" - test_message = "test_message" - - with self.assertRaises(errors.ConflictError) as context: - raise errors.ConflictError(test_response, test_message) - self.assert_response_message(context, test_response, test_message) - self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") - outcome = context.exception.to_operation_outcome() - self.assert_operation_outcome(outcome) - issue = outcome.get("issue")[0] - self.assertEqual(issue.get("severity"), errors.Severity.error) - self.assertEqual(issue.get("code"), errors.Code.duplicate) - self.assertEqual(issue.get("diagnostics"), "Conflict") - - def test_errors_bad_request_error(self): - """Test correct operation of BadRequestError""" - test_response = "test_response" - test_message = "test_message" - - with self.assertRaises(errors.BadRequestError) as context: - raise errors.BadRequestError(test_response, test_message) - self.assert_response_message(context, test_response, test_message) - self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") - outcome = context.exception.to_operation_outcome() - self.assert_operation_outcome(outcome) - issue = outcome.get("issue")[0] - self.assertEqual(issue.get("severity"), errors.Severity.error) - self.assertEqual(issue.get("code"), errors.Code.incomplete) - self.assertEqual(issue.get("diagnostics"), f"{test_message}\n{test_response}") - - def test_errors_server_error(self): - """Test correct operation of ServerError""" - test_response = "test_response" - test_message = "test_message" - - with self.assertRaises(errors.ServerError) as context: - raise errors.ServerError(test_response, test_message) - self.assert_response_message(context, test_response, test_message) - self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") - outcome = context.exception.to_operation_outcome() - self.assert_operation_outcome(outcome) - issue = outcome.get("issue")[0] - self.assertEqual(issue.get("severity"), errors.Severity.error) - self.assertEqual(issue.get("code"), errors.Code.server_error) - self.assertEqual(issue.get("diagnostics"), f"{test_message}\n{test_response}") diff --git a/lambdas/mns_subscription/tests/test_mns_service.py b/lambdas/mns_subscription/tests/test_mns_service.py index 66bb34134..190b91cf8 100644 --- a/lambdas/mns_subscription/tests/test_mns_service.py +++ b/lambdas/mns_subscription/tests/test_mns_service.py @@ -4,16 +4,15 @@ from common.authentication import AppRestrictedAuth from common.models.errors import ( - ResourceNotFoundError, - UnhandledResponseError, -) -from mns_service import MNS_URL, MnsService -from models.errors import ( BadRequestError, + ResourceNotFoundError, ServerError, TokenValidationError, UnauthorizedError, + UnhandledResponseError, + raise_error_response, ) +from mns_service import MNS_URL, MnsService SQS_ARN = "arn:aws:sqs:eu-west-2:123456789012:my-queue" @@ -255,7 +254,7 @@ def mock_response(self, status_code, json_data=None): def test_404_resource_found_error(self): resp = self.mock_response(404, {"resource": "Not found"}) with self.assertRaises(ResourceNotFoundError) as context: - MnsService.raise_error_response(resp) + raise_error_response(resp) self.assertIn("Subscription or Resource not found", str(context.exception)) self.assertEqual(context.exception.resource_id, "Subscription or Resource not found") self.assertEqual(context.exception.resource_type, {"resource": "Not found"}) @@ -263,7 +262,7 @@ def test_404_resource_found_error(self): def test_400_bad_request_error(self): resp = self.mock_response(400, {"resource": "Invalid"}) with self.assertRaises(BadRequestError) as context: - MnsService.raise_error_response(resp) + raise_error_response(resp) self.assertIn("Bad request: Resource type or parameters incorrect", str(context.exception)) self.assertEqual( context.exception.message, @@ -274,7 +273,7 @@ def test_400_bad_request_error(self): def test_unhandled_status_code(self): resp = self.mock_response(418, {"resource": 1234}) with self.assertRaises(UnhandledResponseError) as context: - MnsService.raise_error_response(resp) + raise_error_response(resp) self.assertIn("Unhandled error: 418", str(context.exception)) self.assertEqual(context.exception.response, {"resource": 1234}) diff --git a/lambdas/shared/src/common/models/errors.py b/lambdas/shared/src/common/models/errors.py index 060fb7495..a85754936 100644 --- a/lambdas/shared/src/common/models/errors.py +++ b/lambdas/shared/src/common/models/errors.py @@ -157,6 +157,120 @@ def to_operation_outcome(self) -> dict: ) +@dataclass +class UnauthorizedError(RuntimeError): + response: dict | str + message: str + + def __str__(self): + return f"{self.message}\n{self.response}" + + @staticmethod + def to_operation_outcome() -> dict: + msg = "Unauthorized request" + return create_operation_outcome( + resource_id=str(uuid.uuid4()), + severity=Severity.error, + code=Code.forbidden, + diagnostics=msg, + ) + + +@dataclass +class TokenValidationError(RuntimeError): + response: dict | str + message: str + + def __str__(self): + return f"{self.message}\n{self.response}" + + @staticmethod + def to_operation_outcome() -> dict: + msg = "Missing/Invalid Token" + return create_operation_outcome( + resource_id=str(uuid.uuid4()), + severity=Severity.error, + code=Code.invalid, + diagnostics=msg, + ) + + +@dataclass +class ForbiddenError(Exception): + response: dict | str + message: str + + def __str__(self): + return f"{self.message}\n{self.response}" + + @staticmethod + def to_operation_outcome() -> dict: + msg = "Forbidden" + return create_operation_outcome( + resource_id=str(uuid.uuid4()), + severity=Severity.error, + code=Code.forbidden, + diagnostics=msg, + ) + + +@dataclass +class ConflictError(RuntimeError): + response: dict | str + message: str + + def __str__(self): + return f"{self.message}\n{self.response}" + + @staticmethod + def to_operation_outcome() -> dict: + msg = "Conflict" + return create_operation_outcome( + resource_id=str(uuid.uuid4()), + severity=Severity.error, + code=Code.duplicate, + diagnostics=msg, + ) + + +@dataclass +class BadRequestError(RuntimeError): + """Use when payload is missing required parameters""" + + response: dict | str + message: str + + def __str__(self): + return f"{self.message}\n{self.response}" + + def to_operation_outcome(self) -> dict: + return create_operation_outcome( + resource_id=str(uuid.uuid4()), + severity=Severity.error, + code=Code.incomplete, + diagnostics=self.__str__(), + ) + + +@dataclass +class ServerError(RuntimeError): + """Use when there is a server error""" + + response: dict | str + message: str + + def __str__(self): + return f"{self.message}\n{self.response}" + + def to_operation_outcome(self) -> dict: + return create_operation_outcome( + resource_id=str(uuid.uuid4()), + severity=Severity.error, + code=Code.server_error, + diagnostics=self.__str__(), + ) + + def create_operation_outcome(resource_id: str, severity: Severity, code: Code, diagnostics: str) -> dict: """Create an OperationOutcome object. Do not use `fhir.resource` library since it adds unnecessary validations""" return { @@ -179,3 +293,28 @@ def create_operation_outcome(resource_id: str, severity: Severity, code: Code, d } ], } + + +def raise_error_response(response): + error_mapping = { + 401: (TokenValidationError, "Token validation failed for the request"), + 400: ( + BadRequestError, + "Bad request: Resource type or parameters incorrect", + ), + 403: ( + UnauthorizedError, + "Forbidden: You do not have permission to access this resource", + ), + 500: (ServerError, "Internal Server Error"), + 404: (ResourceNotFoundError, "Subscription or Resource not found"), + 409: (ConflictError, "SQS Queue Already Subscribed, can't re-subscribe"), + } + exception_class, error_message = error_mapping.get( + response.status_code, + (UnhandledResponseError, f"Unhandled error: {response.status_code}"), + ) + + if response.status_code == 404: + raise exception_class(resource_type=response.json(), resource_id=error_message) + raise exception_class(response=response.json(), message=error_message) diff --git a/lambdas/shared/src/common/pds_service.py b/lambdas/shared/src/common/pds_service.py index 65dafef8b..939fdb0ef 100644 --- a/lambdas/shared/src/common/pds_service.py +++ b/lambdas/shared/src/common/pds_service.py @@ -34,6 +34,14 @@ def get_patient_details(self, patient_id: str) -> dict | None: elif response.status_code == 404: logger.info("Patient not found") return None + elif response.status_code in (400, 401, 403): + logger.info(f"PDS Client Error: Status = {response.status_code} - Body {response.text}") + msg = "Client error occurred while calling PDS" + raise UnhandledResponseError(response=response.json(), message=msg) + elif response.status_code in (500, 502, 503, 504): + logger.error(f"PDS Server Error: Status = {response.status_code} - Body {response.text}") + msg = "Server error occurred while calling PDS" + raise UnhandledResponseError(response=response.json(), message=msg) else: logger.error(f"PDS. Error response: {response.status_code} - {response.text}") msg = "Downstream service failed to validate the patient" diff --git a/lambdas/shared/tests/test_common/test_errors.py b/lambdas/shared/tests/test_common/test_errors.py index e843a84c6..fe1160ceb 100644 --- a/lambdas/shared/tests/test_common/test_errors.py +++ b/lambdas/shared/tests/test_common/test_errors.py @@ -173,3 +173,83 @@ def test_errors_identifier_duplication_error(self): issue.get("diagnostics"), f"The provided identifier: {test_identifier} is duplicated", ) + + def test_errors_unauthorized_error(self): + """Test correct operation of UnauthorizedError""" + test_response = "test_response" + test_message = "test_message" + + with self.assertRaises(errors.UnauthorizedError) as context: + raise errors.UnauthorizedError(test_response, test_message) + self.assert_response_message(context, test_response, test_message) + self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") + outcome = context.exception.to_operation_outcome() + self.assert_operation_outcome(outcome) + issue = outcome.get("issue")[0] + self.assertEqual(issue.get("severity"), errors.Severity.error) + self.assertEqual(issue.get("code"), errors.Code.forbidden) + self.assertEqual(issue.get("diagnostics"), "Unauthorized request") + + def test_errors_token_validation_error(self): + """Test correct operation of TokenValidationError""" + test_response = "test_response" + test_message = "test_message" + + with self.assertRaises(errors.TokenValidationError) as context: + raise errors.TokenValidationError(test_response, test_message) + self.assert_response_message(context, test_response, test_message) + self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") + outcome = context.exception.to_operation_outcome() + self.assert_operation_outcome(outcome) + issue = outcome.get("issue")[0] + self.assertEqual(issue.get("severity"), errors.Severity.error) + self.assertEqual(issue.get("code"), errors.Code.invalid) + self.assertEqual(issue.get("diagnostics"), "Missing/Invalid Token") + + def test_errors_conflict_error(self): + """Test correct operation of ConflictError""" + test_response = "test_response" + test_message = "test_message" + + with self.assertRaises(errors.ConflictError) as context: + raise errors.ConflictError(test_response, test_message) + self.assert_response_message(context, test_response, test_message) + self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") + outcome = context.exception.to_operation_outcome() + self.assert_operation_outcome(outcome) + issue = outcome.get("issue")[0] + self.assertEqual(issue.get("severity"), errors.Severity.error) + self.assertEqual(issue.get("code"), errors.Code.duplicate) + self.assertEqual(issue.get("diagnostics"), "Conflict") + + def test_errors_bad_request_error(self): + """Test correct operation of BadRequestError""" + test_response = "test_response" + test_message = "test_message" + + with self.assertRaises(errors.BadRequestError) as context: + raise errors.BadRequestError(test_response, test_message) + self.assert_response_message(context, test_response, test_message) + self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") + outcome = context.exception.to_operation_outcome() + self.assert_operation_outcome(outcome) + issue = outcome.get("issue")[0] + self.assertEqual(issue.get("severity"), errors.Severity.error) + self.assertEqual(issue.get("code"), errors.Code.incomplete) + self.assertEqual(issue.get("diagnostics"), f"{test_message}\n{test_response}") + + def test_errors_server_error(self): + """Test correct operation of ServerError""" + test_response = "test_response" + test_message = "test_message" + + with self.assertRaises(errors.ServerError) as context: + raise errors.ServerError(test_response, test_message) + self.assert_response_message(context, test_response, test_message) + self.assertEqual(str(context.exception), f"{test_message}\n{test_response}") + outcome = context.exception.to_operation_outcome() + self.assert_operation_outcome(outcome) + issue = outcome.get("issue")[0] + self.assertEqual(issue.get("severity"), errors.Severity.error) + self.assertEqual(issue.get("code"), errors.Code.server_error) + self.assertEqual(issue.get("diagnostics"), f"{test_message}\n{test_response}") From 93937caeb40012d7e12bad39316499948dfe07b9 Mon Sep 17 00:00:00 2001 From: Akol125 Date: Fri, 30 Jan 2026 12:34:53 +0000 Subject: [PATCH 2/5] add logging and more status code --- lambdas/shared/src/common/models/errors.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/lambdas/shared/src/common/models/errors.py b/lambdas/shared/src/common/models/errors.py index a85754936..2bc2ec8a0 100644 --- a/lambdas/shared/src/common/models/errors.py +++ b/lambdas/shared/src/common/models/errors.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from enum import Enum +from common.clients import logger + class Code(str, Enum): forbidden = "forbidden" @@ -298,23 +300,25 @@ def create_operation_outcome(resource_id: str, severity: Severity, code: Code, d def raise_error_response(response): error_mapping = { 401: (TokenValidationError, "Token validation failed for the request"), - 400: ( - BadRequestError, - "Bad request: Resource type or parameters incorrect", - ), - 403: ( - UnauthorizedError, - "Forbidden: You do not have permission to access this resource", - ), + 400: (BadRequestError, "Bad request"), + 403: (ForbiddenError, "Forbidden: You do not have permission to access this resource"), 500: (ServerError, "Internal Server Error"), - 404: (ResourceNotFoundError, "Subscription or Resource not found"), + 404: (ResourceNotFoundError, "Resource not found"), 409: (ConflictError, "SQS Queue Already Subscribed, can't re-subscribe"), + 408: (ServerError, "Request Timeout"), + 429: (ServerError, "Too Many Requests"), + 503: (ServerError, "Service Unavailable"), + 502: (ServerError, "Bad Gateway"), + 504: (ServerError, "Gateway Timeout"), } + exception_class, error_message = error_mapping.get( response.status_code, (UnhandledResponseError, f"Unhandled error: {response.status_code}"), ) + logger.info(f"{error_message}. Status={response.status_code}. Body={response.text}") + if response.status_code == 404: raise exception_class(resource_type=response.json(), resource_id=error_message) raise exception_class(response=response.json(), message=error_message) From 288d738f6eea8567023c5c6f1e21752b7dee51fb Mon Sep 17 00:00:00 2001 From: Akol125 Date: Fri, 30 Jan 2026 13:26:47 +0000 Subject: [PATCH 3/5] add retry with backoff --- .../tests/test_mns_service.py | 14 +++++----- lambdas/shared/src/common/models/constants.py | 1 + lambdas/shared/src/common/models/errors.py | 28 ++++++++++++++++++- lambdas/shared/src/common/pds_service.py | 23 ++++++--------- 4 files changed, 43 insertions(+), 23 deletions(-) diff --git a/lambdas/mns_subscription/tests/test_mns_service.py b/lambdas/mns_subscription/tests/test_mns_service.py index 190b91cf8..aaf1b21f4 100644 --- a/lambdas/mns_subscription/tests/test_mns_service.py +++ b/lambdas/mns_subscription/tests/test_mns_service.py @@ -5,10 +5,10 @@ from common.authentication import AppRestrictedAuth from common.models.errors import ( BadRequestError, + ForbiddenError, ResourceNotFoundError, ServerError, TokenValidationError, - UnauthorizedError, UnhandledResponseError, raise_error_response, ) @@ -63,7 +63,7 @@ def test_not_found_subscription(self, mock_post): with self.assertRaises(ResourceNotFoundError) as context: service.subscribe_notification() - self.assertIn("Subscription or Resource not found", str(context.exception)) + self.assertIn("Resource not found", str(context.exception)) @patch("mns_service.requests.post") def test_unhandled_error(self, mock_post): @@ -170,7 +170,7 @@ def test_delete_subscription_403(self, mock_delete): mock_delete.return_value = mock_response service = MnsService(self.authenticator) - with self.assertRaises(UnauthorizedError): + with self.assertRaises(ForbiddenError): service.delete_subscription("sub-id-123") @patch("mns_service.requests.delete") @@ -255,18 +255,18 @@ def test_404_resource_found_error(self): resp = self.mock_response(404, {"resource": "Not found"}) with self.assertRaises(ResourceNotFoundError) as context: raise_error_response(resp) - self.assertIn("Subscription or Resource not found", str(context.exception)) - self.assertEqual(context.exception.resource_id, "Subscription or Resource not found") + self.assertIn("Resource not found", str(context.exception)) + self.assertEqual(context.exception.resource_id, "Resource not found") self.assertEqual(context.exception.resource_type, {"resource": "Not found"}) def test_400_bad_request_error(self): resp = self.mock_response(400, {"resource": "Invalid"}) with self.assertRaises(BadRequestError) as context: raise_error_response(resp) - self.assertIn("Bad request: Resource type or parameters incorrect", str(context.exception)) + self.assertIn("Bad request", str(context.exception)) self.assertEqual( context.exception.message, - "Bad request: Resource type or parameters incorrect", + "Bad request", ) self.assertEqual(context.exception.response, {"resource": "Invalid"}) diff --git a/lambdas/shared/src/common/models/constants.py b/lambdas/shared/src/common/models/constants.py index 137b9c04c..1cdd16d99 100644 --- a/lambdas/shared/src/common/models/constants.py +++ b/lambdas/shared/src/common/models/constants.py @@ -60,6 +60,7 @@ class Constants: COMPLETED_STATUS = "completed" REINSTATED_RECORD_STATUS = "reinstated" + RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504} class Urls: diff --git a/lambdas/shared/src/common/models/errors.py b/lambdas/shared/src/common/models/errors.py index 2bc2ec8a0..f50eb307f 100644 --- a/lambdas/shared/src/common/models/errors.py +++ b/lambdas/shared/src/common/models/errors.py @@ -1,8 +1,12 @@ +import time import uuid from dataclasses import dataclass from enum import Enum +import requests + from common.clients import logger +from common.models.constants import Constants class Code(str, Enum): @@ -304,7 +308,7 @@ def raise_error_response(response): 403: (ForbiddenError, "Forbidden: You do not have permission to access this resource"), 500: (ServerError, "Internal Server Error"), 404: (ResourceNotFoundError, "Resource not found"), - 409: (ConflictError, "SQS Queue Already Subscribed, can't re-subscribe"), + 409: (ConflictError, "Conflict: Resource already exists"), 408: (ServerError, "Request Timeout"), 429: (ServerError, "Too Many Requests"), 503: (ServerError, "Service Unavailable"), @@ -322,3 +326,25 @@ def raise_error_response(response): if response.status_code == 404: raise exception_class(resource_type=response.json(), resource_id=error_message) raise exception_class(response=response.json(), message=error_message) + + +def request_with_retry_backoff( + url: str, headers: dict, *, timeout: int = 5, max_retries: int = 2, backoff_seconds: float = 0.5 +): + for request_attempt in range(max_retries + 1): + response = requests.get(url, headers=headers, timeout=timeout) + + if response.status_code not in Constants.RETRYABLE_STATUS_CODES: + return response + + if request_attempt < max_retries: + logger.info( + f"Retryable response. Status={response.status_code}. " + f"Attempt={request_attempt + 1}/{max_retries + 1}. Retrying..." + ) + + time.sleep(backoff_seconds * (2**request_attempt)) + continue + + # out of retries, return last response to be handled by caller + return response diff --git a/lambdas/shared/src/common/pds_service.py b/lambdas/shared/src/common/pds_service.py index 939fdb0ef..504adedb4 100644 --- a/lambdas/shared/src/common/pds_service.py +++ b/lambdas/shared/src/common/pds_service.py @@ -1,10 +1,11 @@ import uuid -import requests - from common.authentication import AppRestrictedAuth from common.clients import logger -from common.models.errors import UnhandledResponseError +from common.models.errors import ( + raise_error_response, + requests_get_with_retries, +) class PdsService: @@ -27,22 +28,14 @@ def get_patient_details(self, patient_id: str) -> dict | None: "X-Request-ID": str(uuid.uuid4()), "X-Correlation-ID": str(uuid.uuid4()), } - response = requests.get(f"{self.base_url}/{patient_id}", headers=request_headers, timeout=5) + response = requests_get_with_retries( + f"{self.base_url}/{patient_id}", headers=request_headers, timeout=5, max_retries=2, backoff_seconds=0.5 + ) if response.status_code == 200: return response.json() elif response.status_code == 404: logger.info("Patient not found") return None - elif response.status_code in (400, 401, 403): - logger.info(f"PDS Client Error: Status = {response.status_code} - Body {response.text}") - msg = "Client error occurred while calling PDS" - raise UnhandledResponseError(response=response.json(), message=msg) - elif response.status_code in (500, 502, 503, 504): - logger.error(f"PDS Server Error: Status = {response.status_code} - Body {response.text}") - msg = "Server error occurred while calling PDS" - raise UnhandledResponseError(response=response.json(), message=msg) else: - logger.error(f"PDS. Error response: {response.status_code} - {response.text}") - msg = "Downstream service failed to validate the patient" - raise UnhandledResponseError(response=response.json(), message=msg) + raise_error_response(response) From d2c005c9e664b1f00d2e5075f6c1442a8c8a9d57 Mon Sep 17 00:00:00 2001 From: Akol125 Date: Fri, 30 Jan 2026 14:23:38 +0000 Subject: [PATCH 4/5] fix test and imports --- lambdas/shared/src/common/pds_service.py | 4 ++-- lambdas/shared/tests/test_common/test_pds_service.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lambdas/shared/src/common/pds_service.py b/lambdas/shared/src/common/pds_service.py index 504adedb4..3bc164861 100644 --- a/lambdas/shared/src/common/pds_service.py +++ b/lambdas/shared/src/common/pds_service.py @@ -4,7 +4,7 @@ from common.clients import logger from common.models.errors import ( raise_error_response, - requests_get_with_retries, + request_with_retry_backoff, ) @@ -28,7 +28,7 @@ def get_patient_details(self, patient_id: str) -> dict | None: "X-Request-ID": str(uuid.uuid4()), "X-Correlation-ID": str(uuid.uuid4()), } - response = requests_get_with_retries( + response = request_with_retry_backoff( f"{self.base_url}/{patient_id}", headers=request_headers, timeout=5, max_retries=2, backoff_seconds=0.5 ) diff --git a/lambdas/shared/tests/test_common/test_pds_service.py b/lambdas/shared/tests/test_common/test_pds_service.py index 2d1a4bdfd..27a30e662 100644 --- a/lambdas/shared/tests/test_common/test_pds_service.py +++ b/lambdas/shared/tests/test_common/test_pds_service.py @@ -5,7 +5,7 @@ from responses import matchers from common.authentication import AppRestrictedAuth -from common.models.errors import UnhandledResponseError +from common.models.errors import BadRequestError from common.pds_service import PdsService @@ -56,10 +56,10 @@ def test_get_patient_details_not_found(self): def test_get_patient_details_error(self): """it should raise exception if PDS responded with error""" patient_id = "900000009" - response = {"msg": "an-error"} + response = {"BadRequest": "Some error occurred"} responses.add(responses.GET, f"{self.base_url}/{patient_id}", status=400, json=response) - with self.assertRaises(UnhandledResponseError) as e: + with self.assertRaises(BadRequestError) as e: # When self.pds_service.get_patient_details(patient_id) From 434fbebb4f6d1bb584395134b58cd0cf6746e0c2 Mon Sep 17 00:00:00 2001 From: Akol125 Date: Fri, 30 Jan 2026 15:43:44 +0000 Subject: [PATCH 5/5] add test for reusable library --- lambdas/shared/src/common/models/errors.py | 4 +- .../shared/tests/test_common/test_errors.py | 147 +++++++++++++++++- 2 files changed, 148 insertions(+), 3 deletions(-) diff --git a/lambdas/shared/src/common/models/errors.py b/lambdas/shared/src/common/models/errors.py index f50eb307f..83a1742d3 100644 --- a/lambdas/shared/src/common/models/errors.py +++ b/lambdas/shared/src/common/models/errors.py @@ -335,7 +335,7 @@ def request_with_retry_backoff( response = requests.get(url, headers=headers, timeout=timeout) if response.status_code not in Constants.RETRYABLE_STATUS_CODES: - return response + break if request_attempt < max_retries: logger.info( @@ -347,4 +347,4 @@ def request_with_retry_backoff( continue # out of retries, return last response to be handled by caller - return response + return response diff --git a/lambdas/shared/tests/test_common/test_errors.py b/lambdas/shared/tests/test_common/test_errors.py index fe1160ceb..033bc1ee0 100644 --- a/lambdas/shared/tests/test_common/test_errors.py +++ b/lambdas/shared/tests/test_common/test_errors.py @@ -1,7 +1,9 @@ import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, call, patch import src.common.models.errors as errors +from src.common.models.constants import Constants +from src.common.models.errors import request_with_retry_backoff class TestErrors(unittest.TestCase): @@ -253,3 +255,146 @@ def test_errors_server_error(self): self.assertEqual(issue.get("severity"), errors.Severity.error) self.assertEqual(issue.get("code"), errors.Code.server_error) self.assertEqual(issue.get("diagnostics"), f"{test_message}\n{test_response}") + + +class TestRaiseErrorResponse(unittest.TestCase): + def _make_response(self, status_code: int, text="err", json_data=None): + response = MagicMock() + response.status_code = status_code + response.text = text + response.json.return_value = json_data if json_data is not None else {"error": "something"} + return response + + @patch("src.common.models.errors.logger") + def test_400_raises_bad_request_error(self, mock_logger): + response = self._make_response(400, text="bad request") + + with self.assertRaises(errors.BadRequestError) as ctx: + errors.raise_error_response(response) + + self.assertIn("Bad request", str(ctx.exception)) + mock_logger.info.assert_called_once() + + @patch("src.common.models.errors.logger") + def test_403_raises_forbidden_error(self, mock_logger): + response = self._make_response(403, text="forbidden") + + with self.assertRaises(errors.ForbiddenError) as ctx: + errors.raise_error_response(response) + + self.assertIn("Forbidden", str(ctx.exception)) + mock_logger.info.assert_called_once() + + @patch("src.common.models.errors.logger") + def test_500_raises_server_error(self, mock_logger): + response = self._make_response(500, text="server error") + + with self.assertRaises(errors.ServerError) as ctx: + errors.raise_error_response(response) + + self.assertIn("Internal Server Error", str(ctx.exception)) + mock_logger.info.assert_called_once() + + @patch("src.common.models.errors.logger") + def test_unhandled_status_raises_unhandled_response_error(self, mock_logger): + response = self._make_response(418, text="I'm a teapot") + + with self.assertRaises(errors.UnhandledResponseError) as ctx: + errors.raise_error_response(response) + + self.assertIn("Unhandled error: 418", str(ctx.exception)) + mock_logger.info.assert_called_once() + + @patch("src.common.models.errors.logger") + def test_404_uses_resource_not_found_error_constructor(self, mock_logger): + """ + This validates the special-case 404 block: + raise exception_class(resource_type=response.json(), resource_id=error_message) + """ + response_json = {"resource": "Patient"} + response = self._make_response(404, text="not found", json_data=response_json) + + with self.assertRaises(errors.ResourceNotFoundError) as ctx: + errors.raise_error_response(response) + + # Here we validate the exception received those specific args + exc = ctx.exception + self.assertEqual(exc.resource_type, response_json) + self.assertEqual(exc.resource_id, "Resource not found") + mock_logger.info.assert_called_once() + + +def _make_response(status_code: int, text: str = "err"): + response = MagicMock() + response.status_code = status_code + response.text = text + return response + + +class TestRequestWithRetryBackoff(unittest.TestCase): + @patch("time.sleep") + @patch("requests.get") + def test_returns_immediately_for_non_retryable_status(self, mock_get, mock_sleep): + # Arrange + mock_get.return_value = _make_response(400) + # Ensure retryable codes include 429/5xx only (example) + with patch.object(Constants, "RETRYABLE_STATUS_CODES", {429, 500, 502, 503, 504}): + # Act + resp = request_with_retry_backoff("http://example.com", {}, max_retries=2, backoff_seconds=0.5) + + # Assert + self.assertEqual(resp.status_code, 400) + self.assertEqual(mock_get.call_count, 1) + mock_sleep.assert_not_called() + + @patch("time.sleep") + @patch("requests.get") + def test_retries_until_exhausted_for_retryable_status(self, mock_get, mock_sleep): + # Arrange: always retryable => should attempt 1 + max_retries times + mock_get.side_effect = [ + _make_response(503), + _make_response(503), + _make_response(503), + ] + with patch.object(Constants, "RETRYABLE_STATUS_CODES", {429, 500, 502, 503, 504}): + # Act + resp = request_with_retry_backoff("http://example.com", {}, max_retries=2, backoff_seconds=0.5) + + # Assert + self.assertEqual(resp.status_code, 503) + self.assertEqual(mock_get.call_count, 3) # 1 initial + 2 retries + self.assertEqual(mock_sleep.call_count, 2) # sleep between retries only + + @patch("time.sleep") + @patch("requests.get") + def test_stops_retrying_when_non_retryable_received(self, mock_get, mock_sleep): + # Arrange: retryable twice, then success => should stop + mock_get.side_effect = [ + _make_response(503), + _make_response(503), + _make_response(200, text="ok"), + ] + with patch.object(Constants, "RETRYABLE_STATUS_CODES", {429, 500, 502, 503, 504}): + # Act + resp = request_with_retry_backoff("http://example.com", {}, max_retries=2, backoff_seconds=0.5) + + # Assert + self.assertEqual(resp.status_code, 200) + self.assertEqual(mock_get.call_count, 3) + self.assertEqual(mock_sleep.call_count, 2) + + @patch("time.sleep") + @patch("requests.get") + def test_backoff_values_are_exponential(self, mock_get, mock_sleep): + # Arrange: always retryable + mock_get.side_effect = [ + _make_response(503), + _make_response(503), + _make_response(503), + ] + with patch.object(Constants, "RETRYABLE_STATUS_CODES", {429, 500, 502, 503, 504}): + # Act + request_with_retry_backoff("http://example.com", {}, max_retries=2, backoff_seconds=0.5) + + # Assert: 0.5, 1.0 for attempts 0 and 1 + mock_sleep.assert_has_calls([call(0.5), call(1.0)])