diff --git a/lambdas/handlers/mns_notification_handler.py b/lambdas/handlers/mns_notification_handler.py index ffc8df2a8..9e4e797ad 100644 --- a/lambdas/handlers/mns_notification_handler.py +++ b/lambdas/handlers/mns_notification_handler.py @@ -1,8 +1,9 @@ import json +from pydantic import ValidationError + from enums.mns_notification_types import MNSNotificationTypes from models.sqs.mns_sqs_message import MNSSQSMessage -from pydantic import ValidationError from services.process_mns_message_service import MNSNotificationService from utils.audit_logging_setup import LoggingService from utils.decorators.ensure_env_var import ensure_environment_variables @@ -21,7 +22,8 @@ "LLOYD_GEORGE_DYNAMODB_NAME", "DOCUMENT_REVIEW_DYNAMODB_NAME", "MNS_NOTIFICATION_QUEUE_URL", - ] + "RESTRICTIONS_TABLE_NAME", + ], ) @override_error_check def lambda_handler(event, context): @@ -38,7 +40,7 @@ def lambda_handler(event, context): request_context.patient_nhs_no = mns_message.subject.nhs_number logger.info( - f"Processing SQS message for nhs number: {mns_message.subject.nhs_number}" + f"Processing SQS message for nhs number: {mns_message.subject.nhs_number}", ) if mns_message.type in MNSNotificationTypes.__members__.values(): diff --git a/lambdas/models/user_restrictions/user_restrictions.py b/lambdas/models/user_restrictions/user_restrictions.py index c38e80b1b..210fb7cea 100644 --- a/lambdas/models/user_restrictions/user_restrictions.py +++ b/lambdas/models/user_restrictions/user_restrictions.py @@ -10,11 +10,12 @@ class UserRestrictionsFields(StrEnum): ID = "ID" CREATOR = "CreatorSmartcard" RESTRICTED_USER = "RestrictedSmartcard" - REMOVED_BY = "RemoverSmartCard" - CUSTODIAN = "Custodian" + REMOVED_BY = "RemoverSmartcard" NHS_NUMBER = "NhsNumber" + CUSTODIAN = "Custodian" IS_ACTIVE = "IsActive" LAST_UPDATED = "LastUpdated" + CREATED = "Created" class UserRestrictionIndexes(StrEnum): diff --git a/lambdas/services/process_mns_message_service.py b/lambdas/services/process_mns_message_service.py index 96e647195..ff88df005 100644 --- a/lambdas/services/process_mns_message_service.py +++ b/lambdas/services/process_mns_message_service.py @@ -1,16 +1,20 @@ import os from botocore.exceptions import ClientError + from enums.death_notification_status import DeathNotificationStatus from enums.mns_notification_types import MNSNotificationTypes from enums.patient_ods_inactive_status import PatientOdsInactiveStatus from models.document_reference import DocumentReference from models.document_review import DocumentUploadReviewReference from models.sqs.mns_sqs_message import MNSSQSMessage +from models.user_restrictions.user_restrictions import UserRestriction from services.base.sqs_service import SQSService from services.document_reference_service import DocumentReferenceService from services.document_upload_review_service import DocumentUploadReviewService -from services.feature_flags_service import FeatureFlagService +from services.user_restrictions.user_restriction_dynamo_service import ( + UserRestrictionDynamoService, +) from utils.audit_logging_setup import LoggingService from utils.exceptions import PdsErrorException from utils.utilities import get_pds_service @@ -25,7 +29,7 @@ def __init__(self): self.pds_service = get_pds_service() self.sqs_service = SQSService() self.queue = os.getenv("MNS_NOTIFICATION_QUEUE_URL") - self.feature_flag_service = FeatureFlagService() + self.restrictions_dynamo_service = UserRestrictionDynamoService() def handle_mns_notification(self, message: MNSSQSMessage): try: @@ -36,35 +40,31 @@ def handle_mns_notification(self, message: MNSSQSMessage): case MNSNotificationTypes.DEATH_NOTIFICATION: logger.info("Handling death status notification.") self.handle_death_notification(message) - - except PdsErrorException as e: - logger.info("An error occurred when calling PDS") - logger.info( - f"Unable to process message: {message.id}, of type: {message.type}" - ) - logger.info(f"{e}") - raise e - - except ClientError as e: - logger.info( - f"Unable to process message: {message.id}, of type: {message.type}" + except (PdsErrorException, ClientError) as e: + logger.error( + f"Unable to process message: {message.id}, of type: {message.type}", ) - logger.info(f"{e}") - raise e + logger.error(str(e)) + raise def handle_gp_change_notification(self, message: MNSSQSMessage) -> None: - lg_documents, review_documents = self.get_all_patient_documents( - message.subject.nhs_number + nhs_number = message.subject.nhs_number + lg_documents, review_documents, restrictions = self._fetch_patient_data( + nhs_number, ) - if not lg_documents and not review_documents: + if not (lg_documents or review_documents or restrictions): return - updated_ods_code = self.get_updated_gp_ods(message.subject.nhs_number) - self.update_all_patient_documents( - lg_documents, review_documents, updated_ods_code + updated_ods_code = self.get_updated_gp_ods(nhs_number) + self._apply_ods_update( + nhs_number, + lg_documents, + review_documents, + restrictions, + updated_ods_code, ) - logger.info("Update complete for change of GP") + logger.info("Update complete for change of GP.") def handle_death_notification(self, message: MNSSQSMessage) -> None: death_notification_type = message.data["deathNotificationStatus"] @@ -73,52 +73,91 @@ def handle_death_notification(self, message: MNSSQSMessage) -> None: match death_notification_type: case DeathNotificationStatus.INFORMAL: logger.info( - "Patient is deceased - INFORMAL, moving on to the next message." + "Patient is deceased - INFORMAL, moving on to the next message.", ) case DeathNotificationStatus.REMOVED: - lg_documents, review_documents = self.get_all_patient_documents( - nhs_number + lg_documents, review_documents, restrictions = self._fetch_patient_data( + nhs_number, ) - if lg_documents or review_documents: - updated_ods_code = self.get_updated_gp_ods(nhs_number) - self.update_all_patient_documents( - lg_documents, review_documents, updated_ods_code - ) - logger.info("Update complete for death notification change.") + if not (lg_documents or review_documents or restrictions): + return + + updated_ods_code = self.get_updated_gp_ods(nhs_number) + self._apply_ods_update( + nhs_number, + lg_documents, + review_documents, + restrictions, + updated_ods_code, + ) + logger.info("Update complete for death notification change.") case DeathNotificationStatus.FORMAL: - lg_documents, review_documents = self.get_all_patient_documents( - nhs_number + lg_documents, review_documents, restrictions = self._fetch_patient_data( + nhs_number, + ) + self._apply_ods_update( + nhs_number, + lg_documents, + review_documents, + restrictions, + PatientOdsInactiveStatus.DECEASED, + ) + logger.info( + f"Update complete, patient marked {PatientOdsInactiveStatus.DECEASED}.", ) - if lg_documents or review_documents: - self.update_all_patient_documents( - lg_documents, - review_documents, - PatientOdsInactiveStatus.DECEASED, - ) - logger.info( - f"Update complete, patient marked {PatientOdsInactiveStatus.DECEASED}." - ) + def _fetch_patient_data( + self, + nhs_number: str, + ) -> tuple[ + list[DocumentReference], + list[DocumentUploadReviewReference], + list[UserRestriction], + ]: + lg_documents, review_documents = self.get_all_patient_documents(nhs_number) + restrictions = ( + self.restrictions_dynamo_service.query_restrictions_by_nhs_number( + nhs_number=nhs_number, + ) + ) + return lg_documents, review_documents, restrictions + + def _apply_ods_update( + self, + nhs_number: str, + lg_documents: list[DocumentReference], + review_documents: list[DocumentUploadReviewReference], + restrictions: list[UserRestriction], + ods_code: str, + ) -> None: + if lg_documents or review_documents: + self.update_all_patient_documents(lg_documents, review_documents, ods_code) + if restrictions: + self.update_restrictions( + nhs_number=nhs_number, + custodian=ods_code, + restrictions=restrictions, + ) def get_updated_gp_ods(self, nhs_number: str) -> str: patient_details = self.pds_service.fetch_patient_details(nhs_number) return patient_details.general_practice_ods def get_all_patient_documents( - self, nhs_number: str + self, + nhs_number: str, ) -> tuple[list[DocumentReference], list[DocumentUploadReviewReference]]: - """Fetch patient documents from both LG and document review tables.""" lg_documents = ( self.lg_document_service.fetch_documents_from_table_with_nhs_number( - nhs_number + nhs_number, ) ) review_documents = ( self.document_review_service.fetch_documents_from_table_with_nhs_number( - nhs_number + nhs_number, ) ) @@ -133,9 +172,26 @@ def update_all_patient_documents( """Update documents in both tables if they exist.""" if lg_documents: self.lg_document_service.update_patient_ods_code( - lg_documents, updated_ods_code + lg_documents, + updated_ods_code, ) if review_documents: self.document_review_service.update_document_review_custodian( - review_documents, updated_ods_code + review_documents, + updated_ods_code, + ) + + def update_restrictions( + self, + nhs_number: str, + custodian: str, + restrictions: list[UserRestriction], + ) -> None: + for restriction in restrictions: + logger.info(f"Updating restriction {restriction.id}") + self.restrictions_dynamo_service.update_restriction_custodian( + restriction_id=restriction.id, + updated_custodian=custodian, ) + + logger.info(f"All restrictions for patient {nhs_number} updated.") diff --git a/lambdas/services/user_restrictions/user_restriction_dynamo_service.py b/lambdas/services/user_restrictions/user_restriction_dynamo_service.py index 47ebf7a8b..269af1d56 100644 --- a/lambdas/services/user_restrictions/user_restriction_dynamo_service.py +++ b/lambdas/services/user_restrictions/user_restriction_dynamo_service.py @@ -17,6 +17,7 @@ from utils.dynamo_utils import build_mixed_condition_expression from utils.exceptions import ( UserRestrictionConditionCheckFailedException, + UserRestrictionDynamoDBException, UserRestrictionValidationException, ) @@ -77,31 +78,33 @@ def update_restriction_inactive( removed_by: str, patient_id: str, ): - try: - logger.info("Updating user restriction inactive.") - current_time = int(datetime.now(timezone.utc).timestamp()) + logger.info("Updating user restriction inactive.") + current_time = int(datetime.now(timezone.utc).timestamp()) - updated_fields = { - UserRestrictionsFields.REMOVED_BY.value: removed_by, - UserRestrictionsFields.LAST_UPDATED.value: current_time, - UserRestrictionsFields.IS_ACTIVE.value: False, - } + updated_fields = { + UserRestrictionsFields.REMOVED_BY.value: removed_by, + UserRestrictionsFields.LAST_UPDATED.value: current_time, + UserRestrictionsFields.IS_ACTIVE.value: False, + } + try: self.dynamo_service.update_item( table_name=self.table_name, key_pair={UserRestrictionsFields.ID.value: restriction_id}, updated_fields=updated_fields, - condition_expression=f"{UserRestrictionsFields.IS_ACTIVE.value} = :true " - f"AND {UserRestrictionsFields.RESTRICTED_USER.value} <> :user_id " - f"AND {UserRestrictionsFields.NHS_NUMBER.value} = :patient_id", + condition_expression=( + f"{UserRestrictionsFields.IS_ACTIVE} = :true" + f" AND {UserRestrictionsFields.RESTRICTED_USER} <> :user_id" + f" AND {UserRestrictionsFields.NHS_NUMBER} = :patient_id" + ), expression_attribute_values={ ":true": True, ":user_id": removed_by, ":patient_id": patient_id, }, ) - except ClientError as e: + logger.error(e) if ( e.response["Error"]["Code"] == DynamoClientErrors.CONDITION_CHECK_FAILURE @@ -111,7 +114,62 @@ def update_restriction_inactive( f"Unexpected DynamoDB error in update_restriction_inactive: " f"{e.response['Error']['Code']} - {e}", ) - raise e + raise UserRestrictionDynamoDBException( + "An issue occurred while updating user restriction inactive", + ) + + def query_restrictions_by_nhs_number( + self, + nhs_number: str, + ) -> list[UserRestriction]: + try: + logger.info("Building IsActive filter for DynamoDB query.") + filter_builder = DynamoQueryFilterBuilder() + filter_builder.add_condition( + UserRestrictionsFields.IS_ACTIVE, + AttributeOperator.EQUAL, + True, + ) + active_filter_expression = filter_builder.build() + + logger.info("Querying Restrictions by NHS Number.") + items = self.dynamo_service.query_table( + table_name=self.table_name, + index_name=UserRestrictionIndexes.NHS_NUMBER_INDEX, + search_key=UserRestrictionsFields.NHS_NUMBER, + search_condition=nhs_number, + query_filter=active_filter_expression, + ) + + return self._validate_restrictions(items) + except ClientError as e: + logger.error(e) + raise UserRestrictionDynamoDBException( + "An issue occurred while querying restrictions", + ) + + def update_restriction_custodian(self, restriction_id: str, updated_custodian: str): + logger.info(f"Updating custodian for restriction: {restriction_id}") + current_time = int(datetime.now(timezone.utc).timestamp()) + + updated_fields = { + UserRestrictionsFields.LAST_UPDATED.value: current_time, + UserRestrictionsFields.CUSTODIAN.value: updated_custodian, + } + + try: + self.dynamo_service.update_item( + table_name=self.table_name, + key_pair={UserRestrictionsFields.ID.value: restriction_id}, + updated_fields=updated_fields, + ) + except ClientError as e: + logger.error( + f"DynamoDB ClientError when updating custodian for restriction {restriction_id}: {e}", + ) + raise UserRestrictionDynamoDBException( + f"An issue occurred while updating restriction custodian for restriction {restriction_id}", + ) from e @staticmethod def _build_query_filter( diff --git a/lambdas/tests/unit/services/test_process_mns_message_service.py b/lambdas/tests/unit/services/test_process_mns_message_service.py index 8bcd0945c..f9bd0a63f 100644 --- a/lambdas/tests/unit/services/test_process_mns_message_service.py +++ b/lambdas/tests/unit/services/test_process_mns_message_service.py @@ -1,13 +1,18 @@ from unittest.mock import MagicMock +import freezegun import pytest from botocore.exceptions import ClientError + +from enums.feature_flags import FeatureFlags from enums.patient_ods_inactive_status import PatientOdsInactiveStatus from models.document_reference import DocumentReference from models.document_review import DocumentUploadReviewReference from models.sqs.mns_sqs_message import MNSSQSMessage +from models.user_restrictions.user_restrictions import UserRestriction +from services.feature_flags_service import FeatureFlagService from services.process_mns_message_service import MNSNotificationService -from tests.unit.conftest import TEST_CURRENT_GP_ODS, TEST_NHS_NUMBER +from tests.unit.conftest import TEST_CURRENT_GP_ODS, TEST_NHS_NUMBER, TEST_UUID from tests.unit.handlers.test_mns_notification_handler import ( MOCK_DEATH_MESSAGE_BODY, MOCK_GP_CHANGE_MESSAGE_BODY, @@ -18,24 +23,64 @@ @pytest.fixture -def mns_service(mocker, set_env, monkeypatch): +def mock_user_restriction_disabled(mocker): + mock_function = mocker.patch.object(FeatureFlagService, "get_feature_flags_by_flag") + mock_feature_flag = mock_function.return_value = { + FeatureFlags.USER_RESTRICTION_ENABLED: False, + } + yield mock_feature_flag + + +@pytest.fixture +def mock_user_restriction_enabled(mocker): + mock_function = mocker.patch.object(FeatureFlagService, "get_feature_flags_by_flag") + mock_feature_flag = mock_function.return_value = { + FeatureFlags.USER_RESTRICTION_ENABLED: True, + } + yield mock_feature_flag + + +MOCK_RESTRICTION_DICT = { + "ID": TEST_UUID, + "RestrictedSmartcard": "123456789012", + "NhsNumber": TEST_NHS_NUMBER, + "Custodian": TEST_CURRENT_GP_ODS, + "Created": 1700000000, + "CreatorSmartcard": "223456789022", + "RemoverSmartCard": None, + "IsActive": True, + "LastUpdated": 1700000001, +} + +MOCK_RESTRICTION = UserRestriction.model_validate(MOCK_RESTRICTION_DICT) + + +@pytest.fixture +def mns_service(mocker, set_env, monkeypatch, mock_user_restriction_enabled): monkeypatch.setenv("PDS_FHIR_IS_STUBBED", "False") service = MNSNotificationService() mocker.patch.object(service, "pds_service") mocker.patch.object(service, "document_review_service") mocker.patch.object(service, "lg_document_service") mocker.patch.object(service, "sqs_service") + mocker.patch.object(service, "restrictions_dynamo_service") yield service @pytest.fixture -def mns_service_feature_disabled(mocker, set_env, monkeypatch): +def mns_service_restrictions_feature_disabled( + mocker, + set_env, + monkeypatch, + mock_user_restriction_disabled, +): monkeypatch.setenv("PDS_FHIR_IS_STUBBED", "False") service = MNSNotificationService() mocker.patch.object(service, "pds_service") mocker.patch.object(service, "document_review_service") mocker.patch.object(service, "lg_document_service") mocker.patch.object(service, "sqs_service") + mocker.patch.object(service, "restrictions_dynamo_service") yield service @@ -89,7 +134,9 @@ def mock_document_review_references(mocker): def test_handle_gp_change_message_called_message_type_gp_change( - mns_service, mock_handle_gp_change, mock_handle_death_notification + mns_service, + mock_handle_gp_change, + mock_handle_death_notification, ): mns_service.handle_mns_notification(gp_change_message) @@ -98,7 +145,9 @@ def test_handle_gp_change_message_called_message_type_gp_change( def test_handle_gp_change_message_not_called_message_death_message( - mns_service, mock_handle_death_notification, mock_handle_gp_change + mns_service, + mock_handle_death_notification, + mock_handle_gp_change, ): mns_service.handle_mns_notification(death_notification_message) @@ -119,10 +168,13 @@ def test_handle_mns_notification_error_handling_pds_error(mns_service, mocker): def test_handle_mns_notification_error_handling_client_error(mns_service, mocker): client_error = ClientError( - {"Error": {"Code": "TestException", "Message": "Test exception"}}, "operation" + {"Error": {"Code": "TestException", "Message": "Test exception"}}, + "operation", ) mocker.patch.object( - mns_service, "handle_gp_change_notification", side_effect=client_error + mns_service, + "handle_gp_change_notification", + side_effect=client_error, ) with pytest.raises(ClientError): @@ -130,7 +182,10 @@ def test_handle_mns_notification_error_handling_client_error(mns_service, mocker def test_handle_gp_change_notification_with_patient_documents( - mns_service, mock_document_references, mock_document_review_references, mocker + mns_service, + mock_document_references, + mock_document_review_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mns_service.get_all_patient_documents.return_value = ( @@ -144,26 +199,31 @@ def test_handle_gp_change_notification_with_patient_documents( mns_service.handle_gp_change_notification(gp_change_message) mns_service.get_all_patient_documents.assert_called_once_with( - gp_change_message.subject.nhs_number + gp_change_message.subject.nhs_number, ) mns_service.get_updated_gp_ods.assert_called_once_with( - gp_change_message.subject.nhs_number + gp_change_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( - mock_document_references, mock_document_review_references, NEW_ODS_CODE + mock_document_references, + mock_document_review_references, + NEW_ODS_CODE, ) def test_handle_gp_change_notification_no_patient_documents(mns_service, mocker): mocker.patch.object(mns_service, "get_all_patient_documents") mns_service.get_all_patient_documents.return_value = ([], []) + mns_service.restrictions_dynamo_service.query_restrictions_by_nhs_number.return_value = ( + [] + ) mocker.patch.object(mns_service, "get_updated_gp_ods") mocker.patch.object(mns_service, "update_all_patient_documents") mns_service.handle_gp_change_notification(gp_change_message) mns_service.get_all_patient_documents.assert_called_once_with( - gp_change_message.subject.nhs_number + gp_change_message.subject.nhs_number, ) mns_service.get_updated_gp_ods.assert_not_called() mns_service.update_all_patient_documents.assert_not_called() @@ -182,7 +242,10 @@ def test_handle_death_notification_informal(mns_service, mocker): def test_handle_death_notification_removed_with_documents( - mns_service, mock_document_references, mock_document_review_references, mocker + mns_service, + mock_document_references, + mock_document_review_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mocker.patch.object(mns_service, "get_updated_gp_ods") @@ -196,13 +259,15 @@ def test_handle_death_notification_removed_with_documents( mns_service.handle_death_notification(removed_death_notification_message) mns_service.get_all_patient_documents.assert_called_once_with( - removed_death_notification_message.subject.nhs_number + removed_death_notification_message.subject.nhs_number, ) mns_service.get_updated_gp_ods.assert_called_once_with( - removed_death_notification_message.subject.nhs_number + removed_death_notification_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( - mock_document_references, mock_document_review_references, NEW_ODS_CODE + mock_document_references, + mock_document_review_references, + NEW_ODS_CODE, ) @@ -211,18 +276,24 @@ def test_handle_death_notification_removed_no_documents(mns_service, mocker): mocker.patch.object(mns_service, "get_updated_gp_ods") mocker.patch.object(mns_service, "update_all_patient_documents") mns_service.get_all_patient_documents.return_value = ([], []) + mns_service.restrictions_dynamo_service.query_restrictions_by_nhs_number.return_value = ( + [] + ) mns_service.handle_death_notification(removed_death_notification_message) mns_service.get_all_patient_documents.assert_called_once_with( - removed_death_notification_message.subject.nhs_number + removed_death_notification_message.subject.nhs_number, ) mns_service.get_updated_gp_ods.assert_not_called() mns_service.update_all_patient_documents.assert_not_called() def test_handle_death_notification_formal_with_documents( - mns_service, mock_document_references, mock_document_review_references, mocker + mns_service, + mock_document_references, + mock_document_review_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mocker.patch.object(mns_service, "get_updated_gp_ods") @@ -235,7 +306,7 @@ def test_handle_death_notification_formal_with_documents( mns_service.handle_death_notification(death_notification_message) mns_service.get_all_patient_documents.assert_called_once_with( - death_notification_message.subject.nhs_number + death_notification_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( mock_document_references, @@ -254,7 +325,7 @@ def test_handle_death_notification_formal_no_documents(mns_service, mocker): mns_service.handle_death_notification(death_notification_message) mns_service.get_all_patient_documents.assert_called_once_with( - death_notification_message.subject.nhs_number + death_notification_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_not_called() @@ -269,12 +340,15 @@ def test_get_updated_gp_ods(mns_service): assert result == expected_ods mns_service.pds_service.fetch_patient_details.assert_called_once_with( - TEST_NHS_NUMBER + TEST_NHS_NUMBER, ) def test_pds_is_called_death_notification_removed( - mns_service, mocker, mock_document_references, mock_document_review_references + mns_service, + mocker, + mock_document_references, + mock_document_review_references, ): mocker.patch.object(mns_service, "get_updated_gp_ods") mocker.patch.object(mns_service, "update_all_patient_documents") @@ -290,6 +364,28 @@ def test_pds_is_called_death_notification_removed( mns_service.update_all_patient_documents.assert_called() +@pytest.mark.parametrize( + "mns_event", + [removed_death_notification_message, gp_change_message], +) +def test_pds_called_only_restrictions_present( + mns_service, + mocker, + mns_event, +): + mocker.patch.object(mns_service, "get_updated_gp_ods") + mocker.patch.object(mns_service, "get_all_patient_documents") + mns_service.get_all_patient_documents.return_value = ([], []) + + mns_service.restrictions_dynamo_service.query_restrictions_by_nhs_number.return_value = [ + MOCK_RESTRICTION, + ] + + mns_service.handle_mns_notification(mns_event) + + mns_service.get_updated_gp_ods.assert_called() + + def test_get_all_patient_documents(mns_service, mocker): expected_lg_docs = [MagicMock(spec=DocumentReference)] expected_review_docs = [MagicMock(spec=DocumentUploadReviewReference)] @@ -306,54 +402,71 @@ def test_get_all_patient_documents(mns_service, mocker): assert lg_docs == expected_lg_docs assert review_docs == expected_review_docs mns_service.lg_document_service.fetch_documents_from_table_with_nhs_number.assert_called_once_with( - TEST_NHS_NUMBER + TEST_NHS_NUMBER, ) mns_service.document_review_service.fetch_documents_from_table_with_nhs_number.assert_called_once_with( - TEST_NHS_NUMBER + TEST_NHS_NUMBER, ) def test_update_all_patient_documents_with_both_types( - mns_service, mock_document_references, mock_document_review_references, mocker + mns_service, + mock_document_references, + mock_document_review_references, + mocker, ): mns_service.update_all_patient_documents( - mock_document_references, mock_document_review_references, NEW_ODS_CODE + mock_document_references, + mock_document_review_references, + NEW_ODS_CODE, ) mns_service.lg_document_service.update_patient_ods_code.assert_called_once_with( - mock_document_references, NEW_ODS_CODE + mock_document_references, + NEW_ODS_CODE, ) mns_service.document_review_service.update_document_review_custodian.assert_called_once_with( - mock_document_review_references, NEW_ODS_CODE + mock_document_review_references, + NEW_ODS_CODE, ) def test_update_all_patient_documents_with_only_lg_documents( - mns_service, mock_document_references, mocker + mns_service, + mock_document_references, + mocker, ): mns_service.update_all_patient_documents(mock_document_references, [], NEW_ODS_CODE) mns_service.lg_document_service.update_patient_ods_code.assert_called_once_with( - mock_document_references, NEW_ODS_CODE + mock_document_references, + NEW_ODS_CODE, ) mns_service.document_review_service.update_document_review_custodian.assert_not_called() def test_update_all_patient_documents_with_only_review_documents( - mns_service, mock_document_review_references, mocker + mns_service, + mock_document_review_references, + mocker, ): mns_service.update_all_patient_documents( - [], mock_document_review_references, NEW_ODS_CODE + [], + mock_document_review_references, + NEW_ODS_CODE, ) mns_service.lg_document_service.update_patient_ods_code.assert_not_called() mns_service.document_review_service.update_document_review_custodian.assert_called_once_with( - mock_document_review_references, NEW_ODS_CODE + mock_document_review_references, + NEW_ODS_CODE, ) def test_handle_gp_change_notification_with_only_lg_documents( - mns_service, mock_document_references, mocker + mns_service, + mock_document_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mns_service.get_all_patient_documents.return_value = ( @@ -367,18 +480,22 @@ def test_handle_gp_change_notification_with_only_lg_documents( mns_service.handle_gp_change_notification(gp_change_message) mns_service.get_all_patient_documents.assert_called_once_with( - gp_change_message.subject.nhs_number + gp_change_message.subject.nhs_number, ) mns_service.get_updated_gp_ods.assert_called_once_with( - gp_change_message.subject.nhs_number + gp_change_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( - mock_document_references, [], NEW_ODS_CODE + mock_document_references, + [], + NEW_ODS_CODE, ) def test_handle_gp_change_notification_with_only_review_documents( - mns_service, mock_document_review_references, mocker + mns_service, + mock_document_review_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mns_service.get_all_patient_documents.return_value = ( @@ -392,18 +509,22 @@ def test_handle_gp_change_notification_with_only_review_documents( mns_service.handle_gp_change_notification(gp_change_message) mns_service.get_all_patient_documents.assert_called_once_with( - gp_change_message.subject.nhs_number + gp_change_message.subject.nhs_number, ) mns_service.get_updated_gp_ods.assert_called_once_with( - gp_change_message.subject.nhs_number + gp_change_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( - [], mock_document_review_references, NEW_ODS_CODE + [], + mock_document_review_references, + NEW_ODS_CODE, ) def test_handle_death_notification_formal_with_only_lg_documents( - mns_service, mock_document_references, mocker + mns_service, + mock_document_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mocker.patch.object(mns_service, "get_updated_gp_ods") @@ -416,7 +537,7 @@ def test_handle_death_notification_formal_with_only_lg_documents( mns_service.handle_death_notification(death_notification_message) mns_service.get_all_patient_documents.assert_called_once_with( - death_notification_message.subject.nhs_number + death_notification_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( mock_document_references, @@ -427,7 +548,9 @@ def test_handle_death_notification_formal_with_only_lg_documents( def test_handle_death_notification_formal_with_only_review_documents( - mns_service, mock_document_review_references, mocker + mns_service, + mock_document_review_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mocker.patch.object(mns_service, "get_updated_gp_ods") @@ -440,7 +563,7 @@ def test_handle_death_notification_formal_with_only_review_documents( mns_service.handle_death_notification(death_notification_message) mns_service.get_all_patient_documents.assert_called_once_with( - death_notification_message.subject.nhs_number + death_notification_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( [], @@ -451,7 +574,9 @@ def test_handle_death_notification_formal_with_only_review_documents( def test_handle_death_notification_removed_with_only_lg_documents( - mns_service, mock_document_references, mocker + mns_service, + mock_document_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mocker.patch.object(mns_service, "get_updated_gp_ods") @@ -465,18 +590,22 @@ def test_handle_death_notification_removed_with_only_lg_documents( mns_service.handle_death_notification(removed_death_notification_message) mns_service.get_all_patient_documents.assert_called_once_with( - removed_death_notification_message.subject.nhs_number + removed_death_notification_message.subject.nhs_number, ) mns_service.get_updated_gp_ods.assert_called_once_with( - removed_death_notification_message.subject.nhs_number + removed_death_notification_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( - mock_document_references, [], NEW_ODS_CODE + mock_document_references, + [], + NEW_ODS_CODE, ) def test_handle_death_notification_removed_with_only_review_documents( - mns_service, mock_document_review_references, mocker + mns_service, + mock_document_review_references, + mocker, ): mocker.patch.object(mns_service, "get_all_patient_documents") mocker.patch.object(mns_service, "get_updated_gp_ods") @@ -490,11 +619,94 @@ def test_handle_death_notification_removed_with_only_review_documents( mns_service.handle_death_notification(removed_death_notification_message) mns_service.get_all_patient_documents.assert_called_once_with( - removed_death_notification_message.subject.nhs_number + removed_death_notification_message.subject.nhs_number, ) mns_service.get_updated_gp_ods.assert_called_once_with( - removed_death_notification_message.subject.nhs_number + removed_death_notification_message.subject.nhs_number, ) mns_service.update_all_patient_documents.assert_called_once_with( - [], mock_document_review_references, NEW_ODS_CODE + [], + mock_document_review_references, + NEW_ODS_CODE, + ) + + +@pytest.mark.parametrize( + "mns_event", + [removed_death_notification_message, gp_change_message], +) +def test_handle_mns_notification_calls_update_restrictions_called_on_notification( + mns_service, + mns_event, + mocker, +): + mns_service.restrictions_dynamo_service.query_restrictions_by_nhs_number.return_value = [ + MOCK_RESTRICTION, + ] + mocker.patch.object(mns_service, "update_restrictions") + updated_ods = mocker.patch.object( + mns_service, + "get_updated_gp_ods", + return_value=TEST_CURRENT_GP_ODS, + ) + + mns_service.handle_mns_notification(mns_event) + + mns_service.update_restrictions.assert_called_with( + nhs_number=TEST_NHS_NUMBER, + custodian=updated_ods.return_value, + restrictions=[MOCK_RESTRICTION], + ) + + +def test_handle_mns_notification_calls_update_restriction_called_on_death_notification( + mns_service, + mocker, +): + mns_service.restrictions_dynamo_service.query_restrictions_by_nhs_number.return_value = [ + MOCK_RESTRICTION, + ] + mocker.patch.object(mns_service, "update_restrictions") + + mns_service.handle_mns_notification(death_notification_message) + + mns_service.restrictions_dynamo_service.query_restrictions_by_nhs_number.assert_called_with( + nhs_number=TEST_NHS_NUMBER, ) + + mns_service.update_restrictions.assert_called_once_with( + nhs_number=TEST_NHS_NUMBER, + custodian=PatientOdsInactiveStatus.DECEASED, + restrictions=[MOCK_RESTRICTION], + ) + + +@freezegun.freeze_time("2021-04-01") +def test_update_restrictions_uses_restriction_dynamo_service(mns_service, mocker): + mocker.patch.object(mns_service, "get_updated_gp_ods", return_value=NEW_ODS_CODE) + + mns_service.restrictions_dynamo_service.query_restrictions_by_nhs_number.return_value = [ + MOCK_RESTRICTION, + ] + + mns_service.update_restrictions( + nhs_number=TEST_NHS_NUMBER, + custodian=NEW_ODS_CODE, + restrictions=[MOCK_RESTRICTION], + ) + + mns_service.restrictions_dynamo_service.update_restriction_custodian.assert_called_once_with( + restriction_id=MOCK_RESTRICTION.id, + updated_custodian=NEW_ODS_CODE, + ) + + +def test_update_restrictions_not_called_on_informal_death_notification( + mns_service, + mocker, +): + mock_update_restriction = mocker.patch.object(mns_service, "update_restrictions") + + mns_service.handle_mns_notification(informal_death_notification_message) + + mock_update_restriction.assert_not_called() diff --git a/lambdas/tests/unit/services/user_restriction/test_user_restriction_dynamo_service.py b/lambdas/tests/unit/services/user_restriction/test_user_restriction_dynamo_service.py index 8fb849988..3ea664a01 100644 --- a/lambdas/tests/unit/services/user_restriction/test_user_restriction_dynamo_service.py +++ b/lambdas/tests/unit/services/user_restriction/test_user_restriction_dynamo_service.py @@ -1,4 +1,5 @@ import pytest +from boto3.dynamodb.conditions import Attr from botocore.exceptions import ClientError from freezegun import freeze_time from pydantic import ValidationError @@ -11,16 +12,21 @@ from services.user_restrictions.user_restriction_dynamo_service import ( UserRestrictionDynamoService, ) -from tests.unit.conftest import TEST_CURRENT_GP_ODS, TEST_NHS_NUMBER, TEST_UUID +from tests.unit.conftest import ( + MOCK_USER_RESTRICTION_TABLE, + TEST_CURRENT_GP_ODS, + TEST_NHS_NUMBER, + TEST_UUID, +) from tests.unit.services.user_restriction.conftest import MOCK_IDENTIFIER from utils.exceptions import ( UserRestrictionConditionCheckFailedException, + UserRestrictionDynamoDBException, UserRestrictionValidationException, ) TEST_ODS_CODE = "Y12345" TEST_SMART_CARD_ID = "SC001" -MOCK_USER_RESTRICTION_TABLE = "test_user_restriction_table" TEST_NEXT_TOKEN = "some-opaque-next-token" MOCK_TIME_STAMP = 1704110400 @@ -145,6 +151,42 @@ def test_query_restrictions_returns_empty_list_when_no_items(mock_service): assert next_token is None +def test_query_restrictions_by_nhs_number(mock_service, mocker): + mock_validate = mocker.patch.object(mock_service, "_validate_restrictions") + mock_service.dynamo_service.query_table.return_value = [MOCK_RESTRICTION_ITEM] + + expected_filter = Attr("IsActive").eq(True) + + mock_service.query_restrictions_by_nhs_number(nhs_number=TEST_NHS_NUMBER) + + mock_service.dynamo_service.query_table.assert_called_with( + table_name=MOCK_USER_RESTRICTION_TABLE, + index_name=UserRestrictionIndexes.NHS_NUMBER_INDEX.value, + search_key=UserRestrictionsFields.NHS_NUMBER.value, + search_condition=TEST_NHS_NUMBER, + query_filter=expected_filter, + ) + + mock_validate.assert_called_with([MOCK_RESTRICTION_ITEM]) + + +def test_query_restrictions_by_nhs_number_handles_client_error(mock_service): + mock_service.dynamo_service.query_table.side_effect = ClientError( + {"Error": {"Code": "500", "Message": "DynamoDB error"}}, + "query", + ) + + with pytest.raises(UserRestrictionDynamoDBException): + mock_service.query_restrictions_by_nhs_number(nhs_number=TEST_NHS_NUMBER) + + +def test_query_restrictions_by_nhs_number_handles_validation_error(mock_service): + mock_service.dynamo_service.query_table.return_value = [{"invalid": "object"}] + + with pytest.raises(UserRestrictionValidationException): + mock_service.query_restrictions_by_nhs_number(nhs_number=TEST_NHS_NUMBER) + + def test_validate_restrictions_raises_for_invalid_items(): with pytest.raises(UserRestrictionValidationException): UserRestrictionDynamoService._validate_restrictions([{"invalid": "data"}]) diff --git a/lambdas/utils/exceptions.py b/lambdas/utils/exceptions.py index 160228ef8..e73c9e080 100644 --- a/lambdas/utils/exceptions.py +++ b/lambdas/utils/exceptions.py @@ -257,3 +257,7 @@ class UserRestrictionException(Exception): class UserRestrictionConditionCheckFailedException(Exception): pass + + +class UserRestrictionDynamoDBException(Exception): + pass