diff --git a/lambdas/tests/unit/handlers/test_pdm_get_fhir_document_reference_by_id_handler.py b/lambdas/tests/unit/handlers/test_pdm_get_fhir_document_reference_by_id_handler.py index 147e7b089..37b6080ab 100644 --- a/lambdas/tests/unit/handlers/test_pdm_get_fhir_document_reference_by_id_handler.py +++ b/lambdas/tests/unit/handlers/test_pdm_get_fhir_document_reference_by_id_handler.py @@ -1,5 +1,9 @@ +import json +from copy import deepcopy + import pytest +from enums.lambda_error import LambdaError from enums.mtls import MtlsCommonNames from enums.snomed_codes import SnomedCodes from handlers.get_fhir_document_reference_handler import ( @@ -9,6 +13,7 @@ from models.document_reference import DocumentReference from tests.unit.conftest import TEST_UUID from tests.unit.helpers.data.dynamo.dynamo_responses import MOCK_SEARCH_RESPONSE +from utils.lambda_exceptions import GetFhirDocumentReferenceException from utils.lambda_handler_utils import extract_bearer_token SNOMED_CODE = SnomedCodes.PATIENT_DATA.value.code @@ -27,7 +32,7 @@ "userAgent": "curl/7.64.1", "clientCert": { "clientCertPem": "-----BEGIN CERTIFICATE-----...", - "subjectDN": "CN=ndrclient.main.int.pdm.national.nhs.uk,O=NHS,C=UK", + "subjectDN": "CN=ndrclient.main.dev.pdm.national.nhs.uk,O=NHS,C=UK", "issuerDN": "CN=NHS Root CA,O=NHS,C=UK", "serialNumber": "12:34:56", "validity": { @@ -49,8 +54,7 @@ def mock_config_service(mocker): mock_config = mocker.patch( "handlers.get_fhir_document_reference_handler.DynamicConfigurationService", ) - mock_config_instance = mock_config.return_value - return mock_config_instance + return mock_config.return_value @pytest.fixture @@ -70,10 +74,42 @@ def mock_mtls_common_names(monkeypatch): monkeypatch.setattr( MtlsCommonNames, "_get_mtls_common_names", - classmethod(lambda cls: {"PDM": ["ndrclient.main.int.pdm.national.nhs.uk"]}), + classmethod(lambda cls: {"PDM": ["ndrclient.main.dev.pdm.national.nhs.uk"]}), + ) + + +@pytest.fixture +def mock_mtls_disallow_all(monkeypatch): + monkeypatch.setattr( + MtlsCommonNames, + "_get_mtls_common_names", + classmethod(lambda cls: {"PDM": []}), ) +@pytest.fixture +def unauthorized_cn_event(): + ev = deepcopy(MOCK_MTLS_VALID_EVENT) + ev["requestContext"]["identity"]["clientCert"][ + "subjectDN" + ] = "CN=unauthorised.client.nhs.uk,O=NHS,C=UK" + return ev + + +@pytest.fixture +def event_missing_client_cert(): + ev = deepcopy(MOCK_MTLS_VALID_EVENT) + ev["requestContext"]["identity"].pop("clientCert", None) + return ev + + +@pytest.fixture +def event_malformed_subject_dn(): + ev = deepcopy(MOCK_MTLS_VALID_EVENT) + ev["requestContext"]["identity"]["clientCert"]["subjectDN"] = "O=NHS,C=UK" + return ev + + def test_lambda_handler_happy_path_with_mtls_pdm_login( set_env, mock_mtls_common_names, @@ -88,7 +124,7 @@ def test_lambda_handler_happy_path_with_mtls_pdm_login( assert response["statusCode"] == 200 assert response["body"] == "test_document_reference" - # Verify correct method calls + mock_document_service.handle_get_document_reference_request.assert_called_once_with( SNOMED_CODE, TEST_UUID, @@ -107,3 +143,88 @@ def test_extract_document_parameters_valid_pdm(): document_id, snomed_code = extract_document_parameters(MOCK_MTLS_VALID_EVENT) assert snomed_code is None assert document_id == TEST_UUID + + +def test_lambda_handler_mtls_unauthorised_cn_returns_400( + set_env, + mock_mtls_disallow_all, + mock_document_service, + unauthorized_cn_event, + context, +): + resp = lambda_handler(unauthorized_cn_event, context) + assert resp["statusCode"] == 400 + mock_document_service.handle_get_document_reference_request.assert_not_called() + + +def test_lambda_handler_mtls_missing_client_cert_returns_401( + set_env, + mock_mtls_disallow_all, + mock_document_service, + event_missing_client_cert, + context, +): + resp = lambda_handler(event_missing_client_cert, context) + assert resp["statusCode"] == 401 + mock_document_service.handle_get_document_reference_request.assert_not_called() + + +def test_lambda_handler_mtls_malformed_subject_dn_returns_401( + set_env, + mock_mtls_disallow_all, + mock_document_service, + event_malformed_subject_dn, + context, +): + resp = lambda_handler(event_malformed_subject_dn, context) + assert resp["statusCode"] == 401 + mock_document_service.handle_get_document_reference_request.assert_not_called() + + +def test_lambda_handler_mtls_invalid_path_parameters_returns_400( + set_env, + mock_mtls_common_names, + mock_document_service, + context, +): + ev = deepcopy(MOCK_MTLS_VALID_EVENT) + ev["pathParameters"] = {"id": "invalid_format_no_tilde"} + resp = lambda_handler(ev, context) + assert resp["statusCode"] == 400 + mock_document_service.handle_get_document_reference_request.assert_not_called() + + +@pytest.mark.parametrize( + "status, lambda_error", + [ + (404, LambdaError.DocumentReferenceNotFound), + (403, LambdaError.DocumentReferenceForbidden), + (400, LambdaError.DocumentReferenceMissingParameters), + (500, LambdaError.DocumentReferenceGeneralError), + ], +) +def test_lambda_handler_mtls_service_errors( + set_env, + mock_mtls_common_names, + mock_document_service, + context, + status, + lambda_error, +): + mock_document_service.handle_get_document_reference_request.side_effect = ( + GetFhirDocumentReferenceException(status, lambda_error) + ) + + resp = lambda_handler(MOCK_MTLS_VALID_EVENT, context) + assert resp["statusCode"] == status + + body = json.loads(resp["body"]) + assert body["resourceType"] == "OperationOutcome" + assert ( + body["issue"][0]["details"]["coding"][0]["code"] + == lambda_error.value.get("fhir_coding").code + ) + assert ( + body["issue"][0]["details"]["coding"][0]["display"] + == lambda_error.value.get("fhir_coding").display + ) diff --git a/lambdas/tests/unit/services/test_pdm_get_fhir_document_reference_by_id_service.py b/lambdas/tests/unit/services/test_pdm_get_fhir_document_reference_by_id_service.py index 04d4cfe17..eeef2f8ca 100644 --- a/lambdas/tests/unit/services/test_pdm_get_fhir_document_reference_by_id_service.py +++ b/lambdas/tests/unit/services/test_pdm_get_fhir_document_reference_by_id_service.py @@ -36,13 +36,11 @@ def test_get_document_reference_service(patched_service): def test_handle_get_document_reference_request(patched_service, mocker, set_env): documents = create_test_doc_store_refs() - expected = documents[0] - mock_document_ref = documents[0] mocker.patch.object( patched_service, "get_core_document_references", - return_value=mock_document_ref, + return_value=expected, ) actual = patched_service.handle_get_document_reference_request( @@ -54,18 +52,13 @@ def test_handle_get_document_reference_request(patched_service, mocker, set_env) def test_get_dynamo_table_for_patient_data_doc_type(patched_service): - """Test _get_dynamo_table_for_doc_type method with a non-Lloyd George document type.""" - patient_data_code = SnomedCodes.PATIENT_DATA.value - result = patched_service._get_dynamo_table_for_doc_type(patient_data_code) assert result == str(DynamoTables.CORE) -# Not PDM however the code that this relates to was introduced because of PDM +# The following two tests are not PDM however the code that this relates to was introduced because of PDM def test_get_dynamo_table_for_unsupported_doc_type(patched_service): - """Test _get_dynamo_table_for_doc_type method with a non-Lloyd George document type.""" - non_lg_code = SnomedCode(code="non-lg-code", display_name="Non Lloyd George") with pytest.raises(InvalidDocTypeException) as exc_info: @@ -75,18 +68,13 @@ def test_get_dynamo_table_for_unsupported_doc_type(patched_service): assert exc_info.value.error == LambdaError.DocTypeDB -# Not PDM however the code that this relates to was introduced because of PDM def test_get_dynamo_table_for_lloyd_george_doc_type(patched_service): - """Test _get_dynamo_table_for_doc_type method with Lloyd George document type.""" lg_code = SnomedCodes.LLOYD_GEORGE.value - result = patched_service._get_dynamo_table_for_doc_type(lg_code) - assert result == str(DynamoTables.LLOYD_GEORGE) def test_get_document_references_empty_result(patched_service): - # Test when no documents are found patched_service.document_service.get_item.return_value = None with pytest.raises(GetFhirDocumentReferenceException) as exc_info: