diff --git a/.github/workflows/python-code-quality.yml b/.github/workflows/python-code-quality.yml index b44139b42..ac25df2b8 100644 --- a/.github/workflows/python-code-quality.yml +++ b/.github/workflows/python-code-quality.yml @@ -22,12 +22,6 @@ jobs: - name: Check Python Formatting run: | make python-linting - - name: Check Python Imports - run: | - make python-check-imports - name: Check for Python Dead Code run: | make python-check-dead-code - - name: Check Python Security - run: | - make python-check-security diff --git a/.vscode/extensions.json b/.vscode/extensions.json index c3cfc4972..e460c205d 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -7,7 +7,8 @@ "streetsidesoftware.code-spell-checker", "vscode-icons-team.vscode-icons", "yzhang.dictionary-completion", - "yzhang.markdown-all-in-one" + "yzhang.markdown-all-in-one", + "charliermarsh.ruff" ], "unwantedRecommendations": [ "googlecloudtools.cloudcode", diff --git a/Makefile b/Makefile index d7f129e26..494cc55aa 100644 --- a/Makefile +++ b/Makefile @@ -510,13 +510,11 @@ trigger-dos-deployment-pipeline: rm -rf jenkins.cookies python-linting: - make python-code-check FILES=application make docker-run-ruff python-code-checks: make python-check-dead-code - make python-check-imports - make python-code-check FILES=application + make python-linting make unit-test echo "Python code checks completed" @@ -526,32 +524,6 @@ python-check-dead-code: DIR=$(APPLICATION_DIR) \ CMD="python -m vulture" -python-format: - make python-code-format FILES=application - make python-code-format FILES=test - -python-check-imports: - make -s docker-run-python \ - IMAGE=$$(make _docker-get-reg)/tester:latest \ - DIR=$(APPLICATION_DIR) \ - CMD="python -m isort . 
-l=120 --check-only --profile=black \ - --force-alphabetical-sort-within-sections --known-local-folder=common \ - " - -python-fix-imports: - make -s docker-run-python \ - IMAGE=$$(make _docker-get-reg)/tester:latest \ - DIR=$(APPLICATION_DIR) \ - CMD="python -m isort . -l=120 --profile=black --force-alphabetical-sort-within-sections \ - --known-local-folder=common \ - " - -python-check-security: - make -s docker-run-python \ - IMAGE=$$(make _docker-get-reg)/tester:latest \ - DIR=$(APPLICATION_DIR) \ - CMD="python -m bandit -r . -c pyproject.toml" - create-ecr-repositories: make docker-create-repository NAME=change-event-dlq-handler make docker-create-repository NAME=dos-db-handler @@ -651,8 +623,10 @@ undeploy-dynamodb-cleanup-job: # Undeploys dynamodb cleanup job docker-run-ruff: # Runs ruff tests - mandatory: RUFF_OPTS=[options] make -s docker-run \ - IMAGE=$$(make _docker-get-reg)/tester:latest \ + IMAGE=$$(make _docker-get-reg)/tester \ CMD="ruff check . $(RUFF_OPTS)" -ruff-auto-fix: # Auto fixes ruff warnings +python-ruff-fix: # Auto fixes ruff warnings make docker-run-ruff RUFF_OPTS="--fix" + +.SILENT: docker-run-ruff diff --git a/application/change_event_dlq_handler/change_event_dlq_handler.py b/application/change_event_dlq_handler/change_event_dlq_handler.py index b4c974166..00ada7380 100644 --- a/application/change_event_dlq_handler/change_event_dlq_handler.py +++ b/application/change_event_dlq_handler/change_event_dlq_handler.py @@ -1,7 +1,9 @@ +from typing import Any + from aws_embedded_metrics import metric_scope from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.tracing import Tracer -from aws_lambda_powertools.utilities.data_classes import event_source, SQSEvent +from aws_lambda_powertools.utilities.data_classes import SQSEvent, event_source from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext from common.constants import FIFO_DLQ_HANDLER_REPORT_ID @@ -20,8 +22,8 @@ @event_source(data_class=SQSEvent) 
@logger.inject_lambda_context(clear_state=True) @metric_scope -def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: - """Entrypoint handler for the change event dlq handler lambda +def lambda_handler(event: SQSEvent, context: LambdaContext, metrics: Any) -> None: # noqa: ARG001, ANN401 + """Entrypoint handler for the change event dlq handler lambda. Messages are sent to the change event dlq handler lambda when a message fails in either the change event queue or holding queue @@ -29,6 +31,7 @@ def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: Args: event (SQSEvent): Lambda function invocation event (list of 1 SQS Message) context (LambdaContext): Lambda function context object + metrics (Metrics): CloudWatch embedded metrics object """ record = next(event.records) handle_sqs_msg_attributes(record.message_attributes) diff --git a/application/change_event_dlq_handler/tests/test_change_event_dlq_handler.py b/application/change_event_dlq_handler/tests/test_change_event_dlq_handler.py index b6fea71c4..ce39fc4cb 100644 --- a/application/change_event_dlq_handler/tests/test_change_event_dlq_handler.py +++ b/application/change_event_dlq_handler/tests/test_change_event_dlq_handler.py @@ -1,12 +1,12 @@ from dataclasses import dataclass from json import dumps -from typing import Any, Dict +from typing import Any from unittest.mock import MagicMock, patch +import pytest from aws_embedded_metrics.logger.metrics_logger import MetricsLogger -from pytest import fixture -from ..change_event_dlq_handler import lambda_handler +from application.change_event_dlq_handler.change_event_dlq_handler import lambda_handler from common.tests.conftest import PHARMACY_STANDARD_EVENT, PHARMACY_STANDARD_EVENT_STAFF FILE_PATH = "application.change_event_dlq_handler.change_event_dlq_handler" @@ -30,9 +30,9 @@ } -@fixture +@pytest.fixture() def dead_letter_change_event_from_change_event_queue(): - yield { + return { "Records": [ { "messageId": 
"059f36b4-87a3-44ab-83d2-661975830a7d", @@ -62,14 +62,14 @@ def dead_letter_change_event_from_change_event_queue(): "eventSource": "aws:sqs", "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", "awsRegion": "us-east-2", - } - ] + }, + ], } -@fixture +@pytest.fixture() def dead_letter_staff_change_event_from_change_event_queue(): - yield { + return { "Records": [ { "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", @@ -99,14 +99,14 @@ def dead_letter_staff_change_event_from_change_event_queue(): "eventSource": "aws:sqs", "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", "awsRegion": "us-east-2", - } - ] + }, + ], } -@fixture +@pytest.fixture() def dead_letter_change_event_from_holding_queue(): - yield { + return { "Records": [ { "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", @@ -123,16 +123,16 @@ def dead_letter_change_event_from_holding_queue(): "eventSource": "aws:sqs", "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", "awsRegion": "us-east-2", - } - ] + }, + ], } -@fixture +@pytest.fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - All dummy values.""" function_name: str = "change-event-dlq-handler" memory_limit_in_mb: int = 128 @@ -151,8 +151,8 @@ def test_lambda_handler_event_from_change_event_queue( mock_set_dimensions: MagicMock, mock_add_change_event_to_dynamodb: MagicMock, mock_extract_body: MagicMock, - dead_letter_staff_change_event_from_change_event_queue: Dict[str, Any], - dead_letter_change_event_from_change_event_queue: Dict[str, Any], + dead_letter_staff_change_event_from_change_event_queue: dict[str, Any], + dead_letter_change_event_from_change_event_queue: dict[str, Any], lambda_context, ): # Arrange @@ -163,7 +163,7 @@ def test_lambda_handler_event_from_change_event_queue( # Assert mock_extract_body.assert_called_once_with(dead_letter_change_event_from_change_event_queue["Records"][0]["body"]) expected_timestamp = 
int( - dead_letter_change_event_from_change_event_queue["Records"][0]["attributes"]["SentTimestamp"] + dead_letter_change_event_from_change_event_queue["Records"][0]["attributes"]["SentTimestamp"], ) mock_add_change_event_to_dynamodb.assert_called_once_with(extracted_body, 123456789, expected_timestamp) @@ -175,7 +175,7 @@ def test_lambda_handler_event_from_holding_queue( mock_put_metric: MagicMock, mock_set_dimensions: MagicMock, mock_add_change_event_to_dynamodb: MagicMock, - dead_letter_change_event_from_holding_queue: Dict[str, Any], + dead_letter_change_event_from_holding_queue: dict[str, Any], lambda_context, ): # Act @@ -183,5 +183,7 @@ def test_lambda_handler_event_from_holding_queue( # Assert expected_timestamp = int(dead_letter_change_event_from_holding_queue["Records"][0]["attributes"]["SentTimestamp"]) mock_add_change_event_to_dynamodb.assert_called_once_with( - CHANGE_EVENT_FROM_CHANGE_EVENT_QUEUE, CHANGE_EVENT_FROM_HOLDING_QUEUE["sequence_number"], expected_timestamp + CHANGE_EVENT_FROM_CHANGE_EVENT_QUEUE, + CHANGE_EVENT_FROM_HOLDING_QUEUE["sequence_number"], + expected_timestamp, ) diff --git a/application/common/appconfig.py b/application/common/appconfig.py index 8b55750ec..b3e4513b5 100644 --- a/application/common/appconfig.py +++ b/application/common/appconfig.py @@ -1,15 +1,15 @@ from os import getenv -from typing import Any, Dict +from typing import Any from aws_lambda_powertools.utilities.feature_flags.appconfig import AppConfigStore from aws_lambda_powertools.utilities.feature_flags.feature_flags import FeatureFlags class AppConfig: - """Application configuration""" + """Application configuration.""" def __init__(self, name: str) -> None: - """Initialise the application configuration + """Initialise the application configuration. 
Args: name (str): name of the application configuration profile @@ -22,8 +22,8 @@ def __init__(self, name: str) -> None: name=name, ) - def get_raw_configuration(self) -> Dict[str, Any]: - """Get the raw configuration + def get_raw_configuration(self) -> dict[str, Any]: + """Get the raw configuration. Returns: dict: raw configuration @@ -31,7 +31,7 @@ def get_raw_configuration(self) -> Dict[str, Any]: return self.app_config.get_raw_configuration def get_feature_flags(self) -> FeatureFlags: - """Get the feature flags for the given name + """Get the feature flags for the given name. Returns: FeatureFlags: feature flags class diff --git a/application/common/constants.py b/application/common/constants.py index 2afe0f1cb..d83ad29b6 100644 --- a/application/common/constants.py +++ b/application/common/constants.py @@ -27,6 +27,8 @@ VALID_SERVICE_TYPES_KEY = "VALID_SERVICE_TYPES" ODSCODE_LENGTH_KEY = "ODSCODE_LENGTH" +CLOSED_AND_HIDDEN_STATUSES = ["HIDDEN", "CLOSED"] + SERVICE_TYPES = { PHARMACY_ORG_TYPE_ID: { SERVICE_TYPES_ALIAS_KEY: PHARMACY_SERVICE_KEY, diff --git a/application/common/dos.py b/application/common/dos.py index 161f97982..888271fa7 100644 --- a/application/common/dos.py +++ b/application/common/dos.py @@ -1,7 +1,7 @@ from collections import defaultdict +from collections.abc import Iterable from dataclasses import dataclass, fields from itertools import groupby -from typing import Dict, Iterable, List, Optional, Set, Union from aws_lambda_powertools.logging import Logger from psycopg import Connection @@ -26,7 +26,7 @@ class DoSService: """Class to represent a DoS Service, field names are equal to equivalent db column names.""" - id: int + id: int # noqa: A003 uid: int name: str odscode: str @@ -45,15 +45,16 @@ class DoSService: longitude: float @staticmethod - def field_names() -> List[str]: + def field_names() -> list[str]: + """Returns a list of field names for this class.""" return [f.name for f in fields(DoSService)] def __init__(self, db_cursor_row: 
dict) -> None: - """Sets the attributes of this object to those found in the db row + """Sets the attributes of this object to those found in the db row. + Args: - db_cursor_row (dict): row from db as key/val pairs + db_cursor_row (dict): row from db as key/val pairs. """ - for row_key, row_value in db_cursor_row.items(): setattr(self, row_key, row_value) @@ -62,7 +63,7 @@ def __init__(self, db_cursor_row: dict) -> None: self.palliative_care = False def __repr__(self) -> str: - """Returns a string representation of this object""" + """Returns a string representation of this object.""" if self.publicname is not None: name = self.publicname elif self.name is not None: @@ -76,14 +77,16 @@ def __repr__(self) -> str: ) def normal_postcode(self) -> str: + """Returns the postcode with no spaces and in uppercase.""" return self.postcode.replace(" ", "").upper() def any_generic_bankholiday_open_periods(self) -> bool: + """Returns True if any of the opening times are generic bank holiday opening times.""" return len(self.standard_opening_times.generic_bankholiday) > 0 -def get_matching_dos_services(odscode: str, org_type_id: str) -> List[DoSService]: - """Retrieves DoS Services from DoS database +def get_matching_dos_services(odscode: str, org_type_id: str) -> list[DoSService]: + """Retrieves DoS Services from DoS database. 
Args: odscode (str): ODScode to match on @@ -106,13 +109,13 @@ def get_matching_dos_services(odscode: str, org_type_id: str) -> List[DoSService named_args = {"ODS": f"{odscode}%"} # Safe as conditional is configurable but variables is inputed to psycopg as variables sql_query = ( - "SELECT s.id, uid, s.name, odscode, address, postcode, web, typeid," # nosec B608 + "SELECT s.id, uid, s.name, odscode, address, postcode, web, typeid," # noqa: S608 "statusid, publicphone, publicname, st.name servicename" " FROM services s LEFT JOIN servicetypes st ON s.typeid = st.id" f" WHERE {conditions}" ) with connect_to_dos_db_replica() as connection: - cursor = query_dos_db(connection=connection, query=sql_query, vars=named_args) + cursor = query_dos_db(connection=connection, query=sql_query, query_vars=named_args) # Create list of DoSService objects from returned rows services = [DoSService(row) for row in cursor.fetchall()] cursor.close() @@ -120,10 +123,19 @@ def get_matching_dos_services(odscode: str, org_type_id: str) -> List[DoSService return services -def get_dos_locations(postcode: Union[str, None] = None, try_cache: bool = True) -> List[DoSLocation]: +def get_dos_locations(postcode: str | None = None, try_cache: bool = True) -> list[DoSLocation]: + """Retrieves DoS Locations from DoS database. + + Args: + postcode (str, optional): Postcode to match on. Defaults to None. + try_cache (bool, optional): Whether to try and use the local cache. Defaults to True. 
+ + Returns: + list[DoSLocation]: List of DoSLocation objects with matching postcode, taken from DoS database + """ logger.info(f"Searching for DoS locations with postcode of '{postcode}'") norm_pc = postcode.replace(" ", "").upper() - global dos_location_cache + global dos_location_cache # noqa: PLW0602 if try_cache and norm_pc in dos_location_cache: logger.info(f"Postcode {norm_pc} location/s found in local cache.") return dos_location_cache[norm_pc] @@ -132,12 +144,14 @@ def get_dos_locations(postcode: Union[str, None] = None, try_cache: bool = True) postcode_variations = [norm_pc] + [f"{norm_pc[:i]} {norm_pc[i:]}" for i in range(1, len(norm_pc))] db_column_names = [f.name for f in fields(DoSLocation)] sql_command = ( - f"SELECT {', '.join(db_column_names)} FROM locations WHERE postcode = ANY(%(pc_variations)s)" # nosec B608 + f"SELECT {', '.join(db_column_names)} FROM locations WHERE postcode = ANY(%(pc_variations)s)" # noqa: S608 # Safe as conditional is configurable but variables is inputted to psycopg as variables ) with connect_to_dos_db_replica() as connection: - cursor = query_dos_db(connection=connection, query=sql_command, vars={"pc_variations": postcode_variations}) + cursor = query_dos_db( + connection=connection, query=sql_command, query_vars={"pc_variations": postcode_variations}, + ) dos_locations = [DoSLocation(**row) for row in cursor.fetchall()] cursor.close() dos_location_cache[norm_pc] = dos_locations @@ -146,20 +160,23 @@ def get_dos_locations(postcode: Union[str, None] = None, try_cache: bool = True) return dos_locations -def get_all_valid_dos_postcodes() -> Set[str]: +def get_all_valid_dos_postcodes() -> set[str]: """Gets all the valid DoS postcodes that are found in the locations table. - Returns: A set of normalised postcodes as strings""" + + Returns: + set[str]: A set of normalised postcodes as strings. 
+ """ logger.info("Collecting all valid postcodes from DoS DB") sql_command = "SELECT postcode FROM locations" with connect_to_dos_db_replica() as connection: cursor = query_dos_db(connection=connection, query=sql_command) - postcodes = set(row["postcode"].replace(" ", "").upper() for row in cursor.fetchall()) + postcodes = {row["postcode"].replace(" ", "").upper() for row in cursor.fetchall()} cursor.close() logger.info(f"Found {len(postcodes)} unique postcodes from DoS DB.") return postcodes -def get_valid_dos_location(postcode: str) -> Optional[DoSLocation]: +def get_valid_dos_location(postcode: str) -> DoSLocation | None: """Gets the valid DoS location for the given postcode. Args: @@ -172,11 +189,11 @@ def get_valid_dos_location(postcode: str) -> Optional[DoSLocation]: return dos_locations[0] if dos_locations else None -def get_services_from_db(typeids: Iterable) -> List[DoSService]: - """VUNERABLE TO SQL INJECTION: DO NOT USE IN LAMBDA""" +def get_services_from_db(typeids: Iterable) -> list[DoSService]: + """VUNERABLE TO SQL INJECTION: DO NOT USE IN LAMBDA.""" # Find base services sql_query = ( - "SELECT s.id, uid, s.name, odscode, address, postcode, web, typeid, " # nosec B608 - Not for use within lambda + "SELECT s.id, uid, s.name, odscode, address, postcode, web, typeid, " # noqa: S608 - Not for use within lambda "statusid, publicphone, publicname, st.name servicename " "FROM services s LEFT JOIN servicetypes st ON s.typeid = st.id " f"WHERE typeid IN ({','.join(map(str, typeids))}) " @@ -186,11 +203,11 @@ def get_services_from_db(typeids: Iterable) -> List[DoSService]: cursor = query_dos_db(connection=connection, query=sql_query) services = [DoSService(row) for row in cursor.fetchall()] cursor.close() - service_id_strings = set(str(s.id) for s in services) + service_id_strings = {str(s.id) for s in services} # Collect and apply all std open times to services sql_query = ( - "SELECT sdo.serviceid, sdo.dayid, otd.name, sdot.starttime, sdot.endtime " # nosec 
- Not used within lambda + "SELECT sdo.serviceid, sdo.dayid, otd.name, sdot.starttime, sdot.endtime " # noqa: S608 "FROM servicedayopenings sdo " "INNER JOIN servicedayopeningtimes sdot " "ON sdo.id = sdot.servicedayopeningid " @@ -199,7 +216,7 @@ def get_services_from_db(typeids: Iterable) -> List[DoSService]: f"WHERE sdo.serviceid IN ({','.join(service_id_strings)})" ) cursor = query_dos_db(connection=connection, query=sql_query) - std_open_times = db_rows_to_std_open_times_map([db_row for db_row in cursor.fetchall()]) + std_open_times = db_rows_to_std_open_times_map(list(cursor.fetchall())) for service in services: service.standard_opening_times = std_open_times.get(service.id, StandardOpeningTimes()) cursor.close() @@ -207,14 +224,14 @@ def get_services_from_db(typeids: Iterable) -> List[DoSService]: # Collect and apply all spec open times to services # Not used within lambda sql_query = ( - "SELECT ssod.serviceid, ssod.date, ssot.starttime, ssot.endtime, ssot.isclosed " # nosec + "SELECT ssod.serviceid, ssod.date, ssot.starttime, ssot.endtime, ssot.isclosed " # noqa: S608 "FROM servicespecifiedopeningdates ssod " "INNER JOIN servicespecifiedopeningtimes ssot " "ON ssod.id = ssot.servicespecifiedopeningdateid " f"WHERE ssod.serviceid IN ({','.join(service_id_strings)})" ) cursor = query_dos_db(connection=connection, query=sql_query) - spec_open_times = db_rows_to_spec_open_times_map([row for row in cursor.fetchall()]) + spec_open_times = db_rows_to_spec_open_times_map(list(cursor.fetchall())) for service in services: service.specified_opening_times = spec_open_times.get(service.id, []) cursor.close() @@ -222,17 +239,17 @@ def get_services_from_db(typeids: Iterable) -> List[DoSService]: return services -def get_specified_opening_times_from_db(connection: Connection, service_id: int) -> List[SpecifiedOpeningTime]: - """Retrieves specified opening times from DoS database +def get_specified_opening_times_from_db(connection: Connection, service_id: int) -> 
list[SpecifiedOpeningTime]: + """Retrieves specified opening times from DoS database. Args: - serviceid (int): serviceid to match on + connection (Connection): Connection to DoS database + service_id (int): serviceid to match on Returns: List[SpecifiedOpeningTime]: List of Specified Opening times with matching serviceid """ - logger.info(f"Searching for specified opening times with serviceid that matches '{service_id}'") sql_query = ( @@ -243,16 +260,18 @@ def get_specified_opening_times_from_db(connection: Connection, service_id: int) "WHERE ssod.serviceid = %(SERVICE_ID)s" ) named_args = {"SERVICE_ID": service_id} - cursor = query_dos_db(connection=connection, query=sql_query, vars=named_args) + cursor = query_dos_db(connection=connection, query=sql_query, query_vars=named_args) specified_opening_times = db_rows_to_spec_open_times(cursor.fetchall()) cursor.close() return specified_opening_times def get_standard_opening_times_from_db(connection: Connection, service_id: int) -> StandardOpeningTimes: - """Retrieves standard opening times from DoS database. If ther service id does not even match any service this - function will still return a blank StandardOpeningTime with no opening periods.""" + """Retrieves standard opening times from DoS database. + If the service id does not even match any service this function will still return a blank StandardOpeningTime + with no opening periods. 
+ """ logger.info(f"Searching for standard opening times with serviceid that matches '{service_id}'") sql_command = ( "SELECT sdo.serviceid, sdo.dayid, otd.name, sdot.starttime, sdot.endtime " @@ -264,15 +283,16 @@ def get_standard_opening_times_from_db(connection: Connection, service_id: int) "WHERE sdo.serviceid = %(SERVICE_ID)s" ) named_args = {"SERVICE_ID": service_id} - cursor = query_dos_db(connection=connection, query=sql_command, vars=named_args) + cursor = query_dos_db(connection=connection, query=sql_command, query_vars=named_args) standard_opening_times = db_rows_to_std_open_times(cursor.fetchall()) cursor.close() return standard_opening_times -def db_rows_to_spec_open_times(db_rows: Iterable[dict]) -> List[SpecifiedOpeningTime]: - """Turns a set of dos database rows into a list of SpecifiedOpenTime objects - note: The rows must to be for the same service +def db_rows_to_spec_open_times(db_rows: Iterable[dict]) -> list[SpecifiedOpeningTime]: + """Turns a set of dos database rows into a list of SpecifiedOpenTime objects. + + note: The rows must to be for the same service. """ specified_opening_times = [] date_sorted_rows = sorted(db_rows, key=lambda row: (row["date"], row["starttime"])) @@ -289,8 +309,10 @@ def db_rows_to_spec_open_times(db_rows: Iterable[dict]) -> List[SpecifiedOpening return specified_opening_times -def db_rows_to_spec_open_times_map(db_rows: Iterable[dict]) -> Dict[str, List[SpecifiedOpeningTime]]: - """Turns a set of dos database rows (from multiple services) into lists of SpecifiedOpenTime objects +def db_rows_to_spec_open_times_map(db_rows: Iterable[dict]) -> dict[str, list[SpecifiedOpeningTime]]: + """Map DB rows to SpecifiedOpeningTime objects. + + Turns a set of dos database rows (from multiple services) into lists of SpecifiedOpenTime objects which are sorted into a dictionary where the key is the service id of the service those SpecifiedOpenTime objects correspond to. 
""" @@ -306,8 +328,9 @@ def db_rows_to_spec_open_times_map(db_rows: Iterable[dict]) -> Dict[str, List[Sp def db_rows_to_std_open_times(db_rows: Iterable[dict]) -> StandardOpeningTimes: - """Turns a set of dos database rows into a StandardOpeningTime object - note: The rows must be for the same service + """Turns a set of dos database rows into a StandardOpeningTime object. + + note: The rows must be for the same service. """ standard_opening_times = StandardOpeningTimes() for row in db_rows: @@ -319,8 +342,10 @@ def db_rows_to_std_open_times(db_rows: Iterable[dict]) -> StandardOpeningTimes: return standard_opening_times -def db_rows_to_std_open_times_map(db_rows: Iterable[dict]) -> Dict[str, StandardOpeningTimes]: - """Turns a set of dos database rows (from multiple services) into StandardOpeningTime objects +def db_rows_to_std_open_times_map(db_rows: Iterable[dict]) -> dict[str, StandardOpeningTimes]: + """Map DB rows to StandardOpeningTime objects. + + Turns a set of dos database rows (from multiple services) into StandardOpeningTime objects which are sorted into a dictionary where the key is the service id of the service those StandardOpeningTime objects correspond to. """ @@ -336,7 +361,7 @@ def db_rows_to_std_open_times_map(db_rows: Iterable[dict]) -> Dict[str, Standard def has_palliative_care(service: DoSService, connection: Connection) -> bool: - """Checks if a service has palliative care + """Checks if a service has palliative care. 
Args: service: The service to check @@ -356,7 +381,7 @@ def has_palliative_care(service: DoSService, connection: Connection) -> bool: "PALIATIVE_CARE_SYMPTOM_GROUP": DOS_PALLIATIVE_CARE_SYMPTOM_GROUP, "PALIATIVE_CARE_SYMPTOM_DESCRIMINATOR": DOS_PALLIATIVE_CARE_SYMPTOM_DISCRIMINATOR, } - cursor = query_dos_db(connection=connection, query=sql_command, vars=named_args) + cursor = query_dos_db(connection=connection, query=sql_command, query_vars=named_args) cursor.fetchall() return cursor.rowcount != 0 return False diff --git a/application/common/dos_db_connection.py b/application/common/dos_db_connection.py index b5b4683cd..2aae613c9 100644 --- a/application/common/dos_db_connection.py +++ b/application/common/dos_db_connection.py @@ -1,11 +1,12 @@ +from collections.abc import Generator from contextlib import contextmanager from os import environ from time import time_ns -from typing import Any, Dict, Generator, Optional +from typing import Any from aws_lambda_powertools.logging import Logger -from psycopg import connect, Connection, Cursor -from psycopg.rows import dict_row, DictRow +from psycopg import Connection, Cursor, connect +from psycopg.rows import DictRow, dict_row from typing_extensions import LiteralString from common.secretsmanager import get_secret @@ -16,7 +17,7 @@ @contextmanager def connect_to_dos_db_replica() -> Generator[Connection, None, None]: - """Creates a new connection to the DoS DB Replica + """Creates a new connection to the DoS DB Replica. Yields: Generator[connection, None, None]: Connection to the database @@ -45,7 +46,7 @@ def connect_to_dos_db_replica() -> Generator[Connection, None, None]: @contextmanager def connect_to_dos_db() -> Generator[Connection[DictRow], None, None]: - """Creates a new connection to the DoS DB + """Creates a new connection to the DoS DB. 
Yields: Generator[connection, None, None]: Connection to the database @@ -66,10 +67,15 @@ def connect_to_dos_db() -> Generator[Connection[DictRow], None, None]: db_connection.close() -def connection_to_db( - server: str, port: str, db_name: str, db_schema: str, db_user: str, db_password: str +def connection_to_db( # noqa: PLR0913 + server: str, + port: str, + db_name: str, + db_schema: str, + db_user: str, + db_password: str, ) -> Connection: - """Creates a new connection to a database + """Creates a new connection to a database. Args: server (str): Database server to connect to @@ -97,23 +103,29 @@ def connection_to_db( def query_dos_db( - connection: Connection, query: LiteralString, vars: Optional[Dict[str, Any]] = None, log_vars: bool = True + connection: Connection, + query: LiteralString, + query_vars: dict[str, Any] | None = None, + log_vars: bool = True, ) -> Cursor[DictRow]: - """Queries the database given in the connection object + """Queries the database given in the connection object. Args: connection (Connection): Connection to the database query (str): Query to execute - vars (Optional[Dict[str, Any]], optional): Variables to use in the query. Defaults to None. + query_vars (Optional[Dict[str, Any]], optional): Variables to use in the query. Defaults to None. + log_vars (bool, optional): Whether to log the query variables. Defaults to True. 
Returns: DictRow: Cursor to the query results """ cursor = connection.cursor(row_factory=dict_row) - logger.info("Query to execute", extra={"query": query, "vars": vars if log_vars else "Vars have been redacted."}) + logger.info( + "Query to execute", extra={"query": query, "vars": query_vars if log_vars else "Vars have been redacted."}, + ) time_start = time_ns() // 1000000 - cursor.execute(query=query, params=vars) + cursor.execute(query=query, params=query_vars) logger.info(f"DoS DB query completed in {(time_ns() // 1000000) - time_start}ms") return cursor diff --git a/application/common/dos_location.py b/application/common/dos_location.py index ebf41c88f..d12bc72f1 100644 --- a/application/common/dos_location.py +++ b/application/common/dos_location.py @@ -3,9 +3,9 @@ @dataclass(init=True, repr=True) class DoSLocation: - """A Class to represent a location in the UK store within the DoS Database locations table""" + """A Class to represent a location in the UK store within the DoS Database locations table.""" - id: int + id: int # noqa: A003 postcode: str easting: float northing: float @@ -14,7 +14,9 @@ class DoSLocation: longitude: float def normal_postcode(self) -> str: + """Returns the postcode in a normalised format.""" return self.postcode.replace(" ", "").upper() def is_valid(self) -> bool: + """Returns True if the location is valid.""" return None not in (self.easting, self.northing, self.latitude, self.longitude) diff --git a/application/common/dynamodb.py b/application/common/dynamodb.py index 96cd7b34b..0ad978b0f 100644 --- a/application/common/dynamodb.py +++ b/application/common/dynamodb.py @@ -5,13 +5,13 @@ from json import dumps, loads from os import environ from time import time -from typing import Any, Dict, List, Union +from typing import Any from aws_lambda_powertools.logging.logger import Logger from boto3 import client, resource from boto3.dynamodb.types import TypeSerializer -from common.errors import DynamoDBException +from common.errors 
import DynamoDBError TTL = 157680000 # int((365*5)*24*60*60) 5 years in seconds logger = Logger(child=True) @@ -19,7 +19,7 @@ ddb_resource = resource("dynamodb", region_name=environ["AWS_REGION"]) -def dict_hash(change_event: Dict[str, Any], sequence_number: str) -> str: +def dict_hash(change_event: dict[str, Any], sequence_number: str) -> str: """MD5 hash of a dictionary.""" change_event_hash = hashlib.new("md5", usedforsecurity=False) encoded = dumps([change_event, sequence_number], sort_keys=True).encode() @@ -28,7 +28,7 @@ def dict_hash(change_event: Dict[str, Any], sequence_number: str) -> str: def put_circuit_is_open(circuit: str, is_open: bool) -> None: - """Set the circuit open status for a given circuit + """Set the circuit open status for a given circuit. Args: circuit (str): Name of the circuit @@ -44,16 +44,19 @@ def put_circuit_is_open(circuit: str, is_open: bool) -> None: put_item = {k: serializer.serialize(v) for k, v in dynamo_record.items()} response = dynamodb.put_item(TableName=environ["CHANGE_EVENTS_TABLE_NAME"], Item=put_item) logger.info("Put circuit status", extra={"response": response, "item": put_item}) - except Exception as err: - raise DynamoDBException(f"Unable to set circuit '{circuit}' to open.") from err + except Exception as err: # noqa: BLE001 + msg = f"Unable to set circuit '{circuit}' to open." + raise DynamoDBError(msg) from err -def get_circuit_is_open(circuit: str) -> Union[bool, None]: - """Gets the open status of a given circuit +def get_circuit_is_open(circuit: str) -> bool | None: + """Gets the open status of a given circuit. + Args: circuit (str): Name of the circuit + Returns: - Union[bool, None]: returns the status or None if the circuit does not exist + Union[bool, None]: returns the status or None if the circuit does not exist. 
""" try: respone = dynamodb.get_item( @@ -71,15 +74,18 @@ def get_circuit_is_open(circuit: str) -> Union[bool, None]: logger.debug(f"Circuit '{circuit}' is_open resp={item}") return None if item is None else bool(item["IsOpen"]["BOOL"]) - except Exception as err: - raise DynamoDBException(f"Unable to get circuit status for '{circuit}'.") from err + except Exception as err: # noqa: BLE001 + msg = f"Unable to get circuit status for '{circuit}'." + raise DynamoDBError(msg) from err + +def add_change_event_to_dynamodb(change_event: dict[str, Any], sequence_number: int, event_received_time: int) -> str: + """Add change event to dynamodb but store the message and use the event for details. -def add_change_event_to_dynamodb(change_event: Dict[str, Any], sequence_number: int, event_received_time: int) -> str: - """Add change event to dynamodb but store the message and use the event for details Args: change_event (Dict[str, Any]): sequence id for given ODSCode - event_received_time (str): received timestamp from SQSEvent + sequence_number (int): sequence id for given ODSCode + event_received_time (str): received timestamp from SQSEvent. 
Returns: dict: returns response from dynamodb @@ -98,19 +104,21 @@ def add_change_event_to_dynamodb(change_event: Dict[str, Any], sequence_number: put_item = {k: serializer.serialize(v) for k, v in dynamo_record.items()} response = dynamodb.put_item(TableName=environ["CHANGE_EVENTS_TABLE_NAME"], Item=put_item) logger.info("Added record to dynamodb", extra={"response": response, "item": put_item}) - except Exception as err: - raise DynamoDBException(f"Unable to add change event (seq no: {sequence_number}) into dynamodb") from err + except Exception as err: # noqa: BLE001 + msg = f"Unable to add change event (seq no: {sequence_number}) into dynamodb" + raise DynamoDBError(msg) from err return record_id def get_latest_sequence_id_for_a_given_odscode_from_dynamodb(odscode: str) -> int: - """Get latest sequence id for a given odscode from dynamodb + """Get latest sequence id for a given odscode from dynamodb. + Args: odscode (str): odscode for the change event + Returns: - int: Sequence number of the message or None if not present + int: Sequence number of the message or None if not present. 
""" - # try: resp = dynamodb.query( TableName=environ["CHANGE_EVENTS_TABLE_NAME"], IndexName="gsi_ods_sequence", @@ -124,8 +132,6 @@ def get_latest_sequence_id_for_a_given_odscode_from_dynamodb(odscode: str) -> in if resp.get("Count") > 0: sequence_number = int(resp.get("Items")[0]["SequenceNumber"]["N"]) logger.debug(f"Sequence number for osdscode '{odscode}'= {sequence_number}") - # except Exception as err: - # raise DynamoDBException(f"Unable to get sequence id from dynamodb for a given ODSCode '{odscode}'.") from err return sequence_number @@ -134,15 +140,16 @@ def get_newest_event_per_odscode(threads: int = 2, limit: int = None) -> dict[st change_event_table = ddb_resource.Table(environ["CHANGE_EVENTS_TABLE_NAME"]) logger.info( f"Returning newest events per ODSCode from DDB table " - f"{environ['CHANGE_EVENTS_TABLE_NAME']}' ({threads} threads).") + f"{environ['CHANGE_EVENTS_TABLE_NAME']}' ({threads} threads).", + ) - def merge_newest_events(newest_events: dict, more_events: List[dict]): + def merge_newest_events(newest_events: dict, more_events: list[dict]): # noqa: ANN202 for event in more_events: newest_event = newest_events.get(event["ODSCode"]) if not (newest_event is not None and newest_event["SequenceNumber"] > event["SequenceNumber"]): newest_events[event["ODSCode"]] = event - def scan_thread(segment: int, total_segments: int): + def scan_thread(segment: int, total_segments: int): # noqa: ANN202 scan_kwargs = {"Segment": segment, "TotalSegments": total_segments} if limit is not None: scan_kwargs["Limit"] = limit @@ -159,10 +166,10 @@ def scan_thread(segment: int, total_segments: int): scan_kwargs["ExclusiveStartKey"] = resp["LastEvaluatedKey"] else: return newest_events + return None with ThreadPoolExecutor() as executor: - thread_runs = [ - executor.submit(scan_thread, segment=i, total_segments=threads) for i in range(threads)] + thread_runs = [executor.submit(scan_thread, segment=i, total_segments=threads) for i in range(threads)] newest_events = {} 
for thread in thread_runs: merge_newest_events(newest_events, thread.result().values()) diff --git a/application/common/errors.py b/application/common/errors.py index 5d60cdfce..3aab94552 100644 --- a/application/common/errors.py +++ b/application/common/errors.py @@ -1,6 +1,6 @@ -class ValidationException(Exception): - pass +class ValidationError(Exception): + """Exception raised for errors in the input.""" -class DynamoDBException(Exception): - pass +class DynamoDBError(Exception): + """Exception raised for all DynamoDB errors.""" diff --git a/application/common/middlewares.py b/application/common/middlewares.py index f737a1e50..bb138147b 100644 --- a/application/common/middlewares.py +++ b/application/common/middlewares.py @@ -1,49 +1,79 @@ +from typing import Any + from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.middleware_factory import lambda_handler_decorator from aws_lambda_powertools.utilities.typing import LambdaContext from botocore.exceptions import ClientError -from common.errors import ValidationException +from common.errors import ValidationError from common.utilities import extract_body, json_str_body logger = Logger(child=True) @lambda_handler_decorator(trace_execution=True) -def redact_staff_key_from_event(handler, event, context: LambdaContext): +def redact_staff_key_from_event(handler, event, context: LambdaContext) -> Any: # noqa: ANN001, ANN401 + """Lambda middleware to remove the 'Staff' key from the Change Event payload. 
+ + Args: + handler: Lambda handler function + event: Lambda event + context: Lambda context object + + Returns: + Any: Lambda handler response + """ logger.info("Checking if 'Staff' key needs removing from Change Event payload") - if 'Records' in event and len(list(event['Records'])) > 0: - for record in event['Records']: - change_event = extract_body(record['body']) - if change_event.pop('Staff', None) is not None: - record['body'] = json_str_body(change_event) + if "Records" in event and len(list(event["Records"])) > 0: + for record in event["Records"]: + change_event = extract_body(record["body"]) + if change_event.pop("Staff", None) is not None: + record["body"] = json_str_body(change_event) logger.info("Redacted 'Staff' key from Change Event payload") return handler(event, context) @lambda_handler_decorator(trace_execution=True) -def unhandled_exception_logging(handler, event, context: LambdaContext): +def unhandled_exception_logging(handler, event, context: LambdaContext) -> Any: # noqa: ANN001, ANN401 + """Lambda middleware to log unhandled exceptions. 
+ + Args: + handler: Lambda handler function + event: Lambda event + context: Lambda context object + + Returns: + Any: Lambda handler response + """ try: - response = handler(event, context) - return response - except ValidationException as err: - logger.exception(f"Validation Error - {err}", extra={"error": err, "event": event}) - return + return handler(event, context) + except ValidationError as error: + logger.exception(f"Validation Error - {error}", extra={"event": event}) # noqa: TRY401 + return None except ClientError as err: error_code = err.response["Error"]["Code"] error_msg = err.response["Error"]["Message"] logger.exception(f"Boto3 Client Error - '{error_code}': {error_msg}", extra={"error": err, "event": event}) - raise err - except BaseException as err: - logger.exception(f"Something went wrong - {err}", extra={"error": err, "event": event}) - raise err + raise + except BaseException: + logger.exception("Error Occurred", extra={"event": event}) + raise @lambda_handler_decorator(trace_execution=True) -def unhandled_exception_logging_hidden_event(handler, event, context: LambdaContext): +def unhandled_exception_logging_hidden_event(handler, event, context: LambdaContext) -> Any: # noqa: ANN001, ANN401 + """Lambda middleware to log unhandled exceptions but hide the event. 
+ + Args: + handler: Lambda handler function + event: Lambda event + context: Lambda context object + + Returns: + Any: Lambda handler response + """ try: - response = handler(event, context) - return response - except BaseException as err: - logger.error("Something went wrong but the event is hidden") - raise err + return handler(event, context) + except BaseException: + logger.exception("Something went wrong but the event is hidden") + raise diff --git a/application/common/nhs.py b/application/common/nhs.py index 8cfd9a68c..175d629e3 100644 --- a/application/common/nhs.py +++ b/application/common/nhs.py @@ -2,20 +2,25 @@ from dataclasses import dataclass from datetime import datetime from itertools import groupby -from typing import Any, Dict, List, Optional, Union +from typing import Any from aws_lambda_powertools.logging import Logger -from common.constants import DENTIST_SERVICE_TYPE_IDS, NHS_UK_PALLIATIVE_CARE_SERVICE_CODE, PHARMACY_SERVICE_TYPE_IDS +from common.constants import ( + CLOSED_AND_HIDDEN_STATUSES, + DENTIST_SERVICE_TYPE_IDS, + NHS_UK_PALLIATIVE_CARE_SERVICE_CODE, + PHARMACY_SERVICE_TYPE_IDS, +) from common.dos import DoSService -from common.opening_times import OpenPeriod, SpecifiedOpeningTime, StandardOpeningTimes, WEEKDAYS +from common.opening_times import WEEKDAYS, OpenPeriod, SpecifiedOpeningTime, StandardOpeningTimes logger = Logger(child=True) @dataclass class NHSEntity: - """This is an object to store an NHS Entity data + """This is an object to store an NHS Entity data. Some fields are pulled straight from the payload while others are processed first. So attribute names differ from payload format for consistency within object. 
@@ -28,16 +33,16 @@ class NHSEntity: org_type: str org_sub_type: str org_status: str - address_lines: List[str] + address_lines: list[str] postcode: str website: str phone: str - standard_opening_times: Optional[StandardOpeningTimes] - specified_opening_times: Optional[List[SpecifiedOpeningTime]] + standard_opening_times: StandardOpeningTimes | None + specified_opening_times: list[SpecifiedOpeningTime] | None palliative_care: bool - CLOSED_AND_HIDDEN_STATUSES = ["HIDDEN", "CLOSED"] - def __init__(self, entity_data: dict): + def __init__(self, entity_data: dict) -> None: + """Initialise the object with the entity data.""" self.entity_data = entity_data self.odscode = entity_data.get("ODSCode") @@ -52,7 +57,7 @@ def __init__(self, entity_data: dict): self.address_lines = [ line for line in [entity_data.get(x) for x in [f"Address{i}" for i in range(1, 5)] + ["City", "County"]] - if isinstance(line, str) and line.strip() != "" + if isinstance(line, str) and line.strip() ] self.standard_opening_times = self._get_standard_opening_times() @@ -62,25 +67,26 @@ def __init__(self, entity_data: dict): self.palliative_care = self.extract_uec_service(NHS_UK_PALLIATIVE_CARE_SERVICE_CODE) def __repr__(self) -> str: + """Returns a string representation of the object.""" return f"" - def normal_postcode(self): + def normal_postcode(self) -> str: + """Returns the postcode in a normalised format.""" return self.postcode.replace(" ", "").upper() - def extract_contact(self, contact_type: str) -> Optional[str]: - """Returns the nested contact value within the input payload""" + def extract_contact(self, contact_type: str) -> str | None: + """Returns the nested contact value within the input payload.""" for item in self.entity_data.get("Contacts", []): if ( item.get("ContactMethodType", "").upper() == contact_type.upper() and item.get("ContactType", "").upper() == "PRIMARY" and item.get("ContactAvailabilityType", "").upper() == "OFFICE HOURS" ): - return item.get("ContactValue") return 
None - def extract_uec_service(self, service_code: str) -> Union[bool, None]: - """Extracts the UEC service from the payload (e.g. Palliative Care) + def extract_uec_service(self, service_code: str) -> bool | None: + """Extracts the UEC service from the payload (e.g. Palliative Care). Args: service_code (str): NHS UK Service Code of the UEC service to extract if exists @@ -90,11 +96,12 @@ def extract_uec_service(self, service_code: str) -> Union[bool, None]: """ if isinstance(self.entity_data.get("UecServices", []), list): return any(item.get("ServiceCode") == service_code for item in self.entity_data.get("UecServices", [])) - else: - return None + return None def _get_standard_opening_times(self) -> StandardOpeningTimes: - """Filters the raw opening times data for standard weekly opening + """Get the standard opening times. + + Filters the raw opening times data for standard weekly opening times and returns it in a StandardOpeningTimes object. Args: @@ -116,8 +123,8 @@ def _get_standard_opening_times(self) -> StandardOpeningTimes: return std_opening_times - def _get_specified_opening_times(self) -> List[SpecifiedOpeningTime]: - """Get all the Specified Opening Times + def _get_specified_opening_times(self) -> list[SpecifiedOpeningTime]: + """Get all the Specified Opening Times. Args: opening_time_type (str): OpeningTimeType to filter the data, General for pharmacy @@ -146,16 +153,15 @@ def _get_specified_opening_times(self) -> List[SpecifiedOpeningTime]: return specified_opening_times def is_status_hidden_or_closed(self) -> bool: - """Check if the status is hidden or closed. If so, return True + """Check if the status is hidden or closed. If so, return True. 
Returns: bool: True if status is hidden or closed, False otherwise """ - return self.org_status.upper() in self.CLOSED_AND_HIDDEN_STATUSES + return self.org_status.upper() in CLOSED_AND_HIDDEN_STATUSES def all_times_valid(self) -> bool: - """Does checks on all opening times for correct format, business rules, overlaps""" - + """Does checks on all opening times for correct format, business rules, overlaps.""" # Check format matches either spec or std format for item in self.entity_data.get("OpeningTimes", []): if not (is_std_opening_json(item) or is_spec_opening_json(item)): @@ -165,16 +171,26 @@ def all_times_valid(self) -> bool: return self.standard_opening_times.is_valid() and SpecifiedOpeningTime.valid_list(self.specified_opening_times) def is_matching_dos_service(self, dos_service: DoSService) -> bool: + """Check if the entity matches the DoS service. + + Args: + dos_service (DoSService): DoS service to check against + + Returns: + bool: True if the entity matches the DoS service, False otherwise + """ if None in (self.odscode, dos_service.odscode): return False if dos_service.typeid in PHARMACY_SERVICE_TYPE_IDS: return ( - len(dos_service.odscode) >= 5 and len(self.odscode) >= 5 and dos_service.odscode[:5] == self.odscode[:5] + len(dos_service.odscode) >= 5 # noqa: PLR2004 + and len(self.odscode) >= 5 # noqa: PLR2004 + and dos_service.odscode[:5] == self.odscode[:5] ) if dos_service.typeid in DENTIST_SERVICE_TYPE_IDS: - if not (len(dos_service.odscode) >= 6 and len(self.odscode) >= 7): + if not (len(dos_service.odscode) >= 6 and len(self.odscode) >= 7): # noqa: PLR2004 return False odscode_extra_0 = f"{dos_service.odscode[0]}0{dos_service.odscode[1:]}" return self.odscode[:7] in (dos_service.odscode[:7], odscode_extra_0[:7]) @@ -184,15 +200,13 @@ def is_matching_dos_service(self, dos_service: DoSService) -> bool: def is_std_opening_json(item: dict) -> bool: - """Checks EXACT match to definition of General/Standard opening time for NHS Open time payload 
object""" - + """Checks EXACT match to definition of General/Standard opening time for NHS Open time payload object.""" # Check values if ( str(item.get("OpeningTimeType")).upper() != "GENERAL" or str(item.get("Weekday")).lower() not in WEEKDAYS or item.get("AdditionalOpeningDate") not in [None, ""] ): - return False is_open = item.get("IsOpen") @@ -214,8 +228,7 @@ def is_std_opening_json(item: dict) -> bool: def is_spec_opening_json(item: dict) -> bool: - """Checks EXACT match to definition of Additional/Spec opening time for NHS Open time payload object""" - + """Checks EXACT match to definition of Additional/Spec opening time for NHS Open time payload object.""" if str(item.get("OpeningTimeType")).upper() != "ADDITIONAL": return False @@ -243,11 +256,14 @@ def is_spec_opening_json(item: dict) -> bool: def match_nhs_entities_to_services( - nhs_entities: List[NHSEntity], services: List[DoSService] -) -> Dict[str, List[DoSService]]: - """Takes lists of NHS Entities and DoS Services and creates a dict where the keys are NHS odscodes - and the values are the corresponding lists of services that match that code.""" + nhs_entities: list[NHSEntity], + services: list[DoSService], +) -> dict[str, list[DoSService]]: + """Match NHS Entities to corresponding list of services. + Takes lists of NHS Entities and DoS Services and creates a dict where the keys are NHS odscodes + and the values are the corresponding lists of services that match that code. + """ logger.info("Matching all NHS Entities to corresponding list of services.") servicelist_map = defaultdict(list) for nhs_entity in nhs_entities: @@ -257,26 +273,26 @@ def match_nhs_entities_to_services( logger.info( f"{len(servicelist_map)}/{len(nhs_entities)} nhs entities matches with at least 1 service. " - f"{len(nhs_entities) - len(servicelist_map)} not matched." 
+ f"{len(nhs_entities) - len(servicelist_map)} not matched.", ) return dict(servicelist_map) -def skip_if_key_is_none(key: Any) -> bool: - """If the key is None, skip the item""" - +def skip_if_key_is_none(key: Any) -> bool: # noqa: ANN401 + """If the key is None, skip the item.""" return key is None -def get_palliative_care_log_value(palliative_care: bool, skip_palliative_care: bool) -> Union[bool, str]: - """Get the value to log for palliative care +def get_palliative_care_log_value(palliative_care: bool, skip_palliative_care: bool) -> bool | str: + """Get the value to log for palliative care. Args: palliative_care (bool): The value of palliative care skip_palliative_care (bool): Whether to skip palliative care Returns: - bool | str: The value to log""" + bool | str: The value to log + """ return ( "Never been updated on Profile Manager, skipped palliative care checks" if skip_palliative_care diff --git a/application/common/opening_times.py b/application/common/opening_times.py index f613ffd95..d55ba3772 100644 --- a/application/common/opening_times.py +++ b/application/common/opening_times.py @@ -1,7 +1,7 @@ from contextlib import suppress from dataclasses import dataclass from datetime import date, datetime, time -from typing import Any, Dict, List, Optional +from typing import Any, Optional from aws_lambda_powertools.logging import Logger @@ -14,58 +14,121 @@ @dataclass(unsafe_hash=True, init=True) class OpenPeriod: + """Represents a period of time when a service is open. + + Attributes: + start (time): The start time of the open period + end (time): The end time of the open period + """ start: time end: time def start_string(self) -> str: + """Get the start time as a string. + + Returns: + str: The start time as a string + """ return self.start.strftime("%H:%M:%S") def end_string(self) -> str: + """Get the end time as a string. 
+ + Returns: + str: The end time as a string + """ return self.end.strftime("%H:%M:%S") - def __str__(self): + def __str__(self) -> str: + """Get the open period as a string. + + Returns: + str: The open period as a string + """ return f"{self.start_string()}-{self.end_string()}" - def __repr__(self): + def __repr__(self) -> str: + """Get the open period as a string. + + Returns: + str: The open period as a string + """ return f"OpenPeriod({self})" - def __eq__(self, other: Any): + def __eq__(self, other: Any) -> bool: # noqa: ANN401 + """Check if two OpenPeriod objects are equal. + + Args: + other (Any): The object to compare to + + Returns: + bool: True if the objects are equal, False otherwise + """ return isinstance(other, OpenPeriod) and self.start == other.start and self.end == other.end - def __lt__(self, other: Any): + def __lt__(self, other: Any) -> bool: # noqa: ANN401 + """Check if one OpenPeriod object is less than another. + + Args: + other (Any): The object to compare to + + Returns: + bool: True if the first object is less than the second, False otherwise + """ if self.start == other.start: return self.end < other.end return self.start < other.start - def __gt__(self, other: Any): + def __gt__(self, other: Any) -> bool: # noqa: ANN401 + """Check if one OpenPeriod object is less than another. + + Args: + other (Any): The object to compare to + + Returns: + bool: True if the first object is less than the second, False otherwise + """ if self.start == other.start: return self.end > other.end return self.start > other.start def start_before_end(self) -> bool: + """Check if the start time is before the end time. + + Returns: + bool: True if the start time is before the end time, False otherwise + """ return self.start < self.end - def overlaps(self, other) -> bool: - assert self.start_before_end() - assert other.start_before_end() + def overlaps(self, other: Any) -> bool: # noqa: ANN401 + """Check if two OpenPeriod objects overlap. 
+ + Args: + other (Any): The object to compare to + + Returns: + bool: True if the objects overlap, False otherwise + """ + assert self.start_before_end() # noqa: S101 + assert other.start_before_end() # noqa: S101 return self.start <= other.end and other.start <= self.end def export_db_string_format(self) -> str: - """Exports open period into a DoS db accepted format for previous value in the service history entry""" + """Exports open period into a DoS db accepted format for previous value in the service history entry.""" return f"{self.start.strftime(DOS_TIME_FORMAT)}-{self.end.strftime(DOS_TIME_FORMAT)}" def export_time_in_seconds(self) -> str: - """Exports open period into a DoS DB accepted format for service history""" + """Exports open period into a DoS DB accepted format for service history.""" return f"{self._seconds_since_midnight(self.start)}-{self._seconds_since_midnight(self.end)}" def _seconds_since_midnight(self, time: time) -> int: - """Returns the number of seconds since midnight for the given time""" + """Returns the number of seconds since midnight for the given time.""" return time.hour * 60 * 60 + time.minute * 60 + time.second @staticmethod - def any_overlaps(open_periods: List["OpenPeriod"]) -> bool: - """Returns whether any OpenPeriod object in list overlaps any others in the list""" + def any_overlaps(open_periods: list["OpenPeriod"]) -> bool: + """Returns whether any OpenPeriod object in list overlaps any others in the list.""" untested = open_periods.copy() while len(untested) > 1: test_op = untested.pop(0) @@ -75,28 +138,28 @@ def any_overlaps(open_periods: List["OpenPeriod"]) -> bool: return False @staticmethod - def list_string(open_periods: List["OpenPeriod"]) -> str: - """Returns a string version of a list of open periods in a consistently sorted order + def list_string(open_periods: list["OpenPeriod"]) -> str: + """Returns a string version of a list of open periods in a consistently sorted order. eg. 
'[08:00:00-13:00:00, 14:00:00-17:00:00, 18:00:00-20:00:00] """ - sorted_str_list = [str(op) for op in sorted(list(open_periods))] + sorted_str_list = [str(op) for op in sorted(open_periods)] return f"[{', '.join(sorted_str_list)}]" @staticmethod - def all_start_before_end(open_periods: List["OpenPeriod"]) -> bool: - """Returns whether all OpenPeriod object in list start before they ends""" + def all_start_before_end(open_periods: list["OpenPeriod"]) -> bool: + """Returns whether all OpenPeriod object in list start before they ends.""" return all(op.start_before_end() for op in open_periods) @staticmethod - def equal_lists(a: List["OpenPeriod"], b: List["OpenPeriod"]) -> bool: - """Checks equality between 2 lists of open periodsRelies on sorting and eq functions in OpenPeriod""" + def equal_lists(a: list["OpenPeriod"], b: list["OpenPeriod"]) -> bool: + """Checks equality between 2 lists of open periodsRelies on sorting and eq functions in OpenPeriod.""" return sorted(a) == sorted(b) @staticmethod def from_string(open_period_string: str) -> Optional["OpenPeriod"]: - """Builds an OpenPeriod object from a string like 12:00-13:00 or 12:00:00-13:00:00""" + """Builds an OpenPeriod object from a string like 12:00-13:00 or 12:00:00-13:00:00.""" try: startime_str, endtime_str = open_period_string.split("-") return OpenPeriod.from_string_times(startime_str, endtime_str) @@ -105,7 +168,7 @@ def from_string(open_period_string: str) -> Optional["OpenPeriod"]: @staticmethod def from_string_times(opening_time_str: str, closing_time_str: str) -> Optional["OpenPeriod"]: - """Builds an OpenPeriod object from string time arguments""" + """Builds an OpenPeriod object from string time arguments.""" open_time = string_to_time(opening_time_str) close_time = string_to_time(closing_time_str) if None in (open_time, close_time): @@ -113,8 +176,8 @@ def from_string_times(opening_time_str: str, closing_time_str: str) -> Optional[ return OpenPeriod(open_time, close_time) - def 
export_test_format(self) -> Dict[str, str]: - """Exports open period for use in the DoS DB Hander""" + def export_test_format(self) -> dict[str, str]: + """Exports open period for use in the DoS DB Hander.""" return { "start_time": self.start.strftime(DOS_TIME_FORMAT), "end_time": self.end.strftime(DOS_TIME_FORMAT), @@ -122,31 +185,78 @@ def export_test_format(self) -> Dict[str, str]: class SpecifiedOpeningTime: - def __init__(self, open_periods: List[OpenPeriod], specified_date: date, is_open: bool = True): - assert isinstance(specified_date, date) + """A class to represent a specified opening time for a service.""" + + def __init__(self, open_periods: list[OpenPeriod], specified_date: date, is_open: bool = True) -> None: + """Initialise a SpecifiedOpeningTime object. + + Args: + open_periods (list[OpenPeriod]): A list of OpenPeriod objects + specified_date (date): The date the open periods apply to + is_open (bool, optional): Whether the service is open on the specified date. Defaults to True. + """ + assert isinstance(specified_date, date) # noqa: S101 self.open_periods = open_periods self.date = specified_date self.is_open = is_open def date_string(self) -> str: + """Returns the date as a string in the format DD-MM-YYYY. + + Returns: + str: The date as a string in the format DD-MM-YYYY + """ return self.date.strftime("%d-%m-%Y") def open_periods_string(self) -> str: + """Returns a string version of the open periods. + + Returns: + str: A string version of the open periods + """ return OpenPeriod.list_string(self.open_periods) - def __hash__(self): + def __hash__(self) -> int: + """Returns a hash of the object. + + Returns: + int: A hash of the object + """ return hash((tuple(sorted(self.open_periods)), self.date, self.is_open)) - def __repr__(self): + def __repr__(self) -> str: + """Returns a string representation of the object. 
+ + Returns: + str: A string representation of the object + """ return f"" - def __str__(self): + def __str__(self) -> str: + """Returns a string representation of the object. + + Returns: + str: A string representation of the object + """ return f"{self.open_string()} on {self.date_string()} {self.open_periods_string()}" - def open_string(self): + def open_string(self) -> str: + """Returns a string representation of whether the service is open or closed. + + Returns: + str: A string representation of whether the service is open or closed + """ return "OPEN" if self.is_open else "CLOSED" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: # noqa: ANN401 + """Checks equality between 2 SpecifiedOpeningTime objects. + + Args: + other (Any): The object to compare to + + Returns: + bool: Whether the objects are equal + """ return ( isinstance(other, SpecifiedOpeningTime) and self.is_open == other.is_open @@ -154,14 +264,14 @@ def __eq__(self, other): and OpenPeriod.equal_lists(self.open_periods, other.open_periods) ) - def export_service_history_format(self) -> List[str]: - """Exports Specified opening time into a DoS service history accepted format""" + def export_service_history_format(self) -> list[str]: + """Exports Specified opening time into a DoS service history accepted format.""" exp_open_periods = [op.export_time_in_seconds() for op in sorted(self.open_periods)] date_str = self.date.strftime(DOS_DATE_FORMAT) return [f"{date_str}-{period}" for period in exp_open_periods] if self.is_open else [f"{date_str}-closed"] - def export_dos_log_format(self) -> List[str]: - """Exports Specified opening times into a DoS Logs accepted format""" + def export_dos_log_format(self) -> list[str]: + """Exports Specified opening times into a DoS Logs accepted format.""" exp_open_periods = [op.export_db_string_format() for op in sorted(self.open_periods)] date_str = self.date.strftime(DOS_DATE_FORMAT) return [f"{date_str}-{period}" for period in exp_open_periods] if 
self.is_open else [f"{date_str}-closed"] @@ -171,9 +281,11 @@ def contradiction(self) -> bool: return self.is_open != (len(self.open_periods) > 0) def any_overlaps(self) -> bool: + """Returns whether any of the open periods overlap.""" return OpenPeriod.any_overlaps(self.open_periods) def all_start_before_end(self) -> bool: + """Returns whether all open periods start before they end.""" return OpenPeriod.all_start_before_end(self.open_periods) def is_valid(self) -> bool: @@ -181,36 +293,44 @@ def is_valid(self) -> bool: return self.all_start_before_end() and (not self.any_overlaps()) and (not self.contradiction()) @staticmethod - def equal_lists(a: List["SpecifiedOpeningTime"], b: List["SpecifiedOpeningTime"]) -> bool: - """Checks equality between 2 lists of SpecifiedOpeningTime Relies on equality, - and hash functions of SpecifiedOpeningTime""" + def equal_lists(a: list["SpecifiedOpeningTime"], b: list["SpecifiedOpeningTime"]) -> bool: + """Checks equality between 2 lists of SpecifiedOpeningTime. + + Checks equality between 2 lists of SpecifiedOpeningTime Relies on equality, + and hash functions of SpecifiedOpeningTime. 
+ """ hash_list_a = [hash(a) for a in a] hash_list_b = [hash(b) for b in b] return sorted(hash_list_a) == sorted(hash_list_b) @staticmethod - def valid_list(list: List["SpecifiedOpeningTime"]) -> bool: - return all([x.is_valid() for x in list]) + def valid_list(times_list: list["SpecifiedOpeningTime"]) -> bool: + """Checks whether a list of SpecifiedOpeningTime is valid.""" + return all(x.is_valid() for x in times_list) @staticmethod - def remove_past_dates(list: List["SpecifiedOpeningTime"], date_now=None) -> List["SpecifiedOpeningTime"]: + def remove_past_dates( + times_list: list["SpecifiedOpeningTime"], + date_now: Any = None, # noqa: ANN401 + ) -> list["SpecifiedOpeningTime"]: + """Removes any SpecifiedOpeningTime objects from the list that are in the past.""" if date_now is None: - date_now = datetime.now().date() + date_now = datetime.now().date() # noqa: DTZ005 future_dates = [] - for item in list: + for item in times_list: if item.date >= date_now: future_dates.append(item) return future_dates def export_test_format(self) -> dict: - """Exports Specified opening time into a test format that can be used in the tests""" + """Exports Specified opening time into a test format that can be used in the tests.""" exp_open_periods = [op.export_test_format() for op in sorted(self.open_periods)] date_str = self.date.strftime(DOS_DATE_FORMAT) return {date_str: exp_open_periods} @staticmethod - def export_test_format_list(spec_opening_dates: List["SpecifiedOpeningTime"]) -> dict: - """Runs the export_test_format on a list of SpecifiedOpeningTime objects and combines the results""" + def export_test_format_list(spec_opening_dates: list["SpecifiedOpeningTime"]) -> dict: + """Runs the export_test_format on a list of SpecifiedOpeningTime objects and combines the results.""" opening_dates_cr_format = {} for spec_open_date in spec_opening_dates: spec_open_date_payload = spec_open_date.export_test_format() @@ -219,7 +339,7 @@ def export_test_format_list(spec_opening_dates: 
List["SpecifiedOpeningTime"]) -> class StandardOpeningTimes: - """Represents the standard openings times for a week. Structured as a set of OpenPeriods per day + """Represents the standard openings times for a week. Structured as a set of OpenPeriods per day. monday: [OpenPeriod1, OpenPeriod2] tuesday: [OpenPeriod1] @@ -229,48 +349,48 @@ class StandardOpeningTimes: An empty list that no open periods means CLOSED """ - def __init__(self): - # Initialise all weekday OpenPeriod lists as empty + def __init__(self) -> None: + """Initialises the StandardOpeningTimes object with empty lists for each day.""" for day in WEEKDAYS: setattr(self, day, []) self.generic_bankholiday = [] self.explicit_closed_days = set() - def __repr__(self): + def __repr__(self) -> str: + """Returns a string representation of the StandardOpeningTimes object.""" closed_days_str = "" if len(self.explicit_closed_days) > 0: closed_days_str = f" exp_closed_days={self.explicit_closed_days}" - return f"" + return f"" - def __str__(self): + def __str__(self) -> str: + """Returns a string representation of the StandardOpeningTimes object.""" return self.to_string(", ") - def __len__(self): + def __len__(self) -> int: + """Returns the number of OpenPeriods in the StandardOpeningTimes object.""" return sum([len(getattr(self, day)) for day in WEEKDAYS]) - def __eq__(self, other: "StandardOpeningTimes"): - """Check equality of 2 StandardOpeningTimes (generic bankholiday values are ignored)""" - + def __eq__(self, other: "StandardOpeningTimes") -> bool: + """Check equality of 2 StandardOpeningTimes (generic bankholiday values are ignored).""" if not isinstance(other, StandardOpeningTimes): return False if self.all_closed_days() != other.all_closed_days(): return False - for day in WEEKDAYS: - if not OpenPeriod.equal_lists(self.get_openings(day), other.get_openings(day)): - return False - - return True + return all(OpenPeriod.equal_lists(self.get_openings(day), other.get_openings(day)) for day in WEEKDAYS) def 
to_string(self, seperator: str = ", ") -> str: + """Returns a string representation of the StandardOpeningTimes object.""" return seperator.join([f"{day}={OpenPeriod.list_string(getattr(self, day))}" for day in WEEKDAYS]) - def get_openings(self, day: str) -> List[OpenPeriod]: + def get_openings(self, day: str) -> list[OpenPeriod]: + """Returns the list of OpenPeriods for the given day.""" return getattr(self, day.lower()) - def all_closed_days(self) -> List[str]: + def all_closed_days(self) -> list[str]: """Returns a set of all implicit AND explicit closed days.""" all_closed_days = self.explicit_closed_days @@ -282,20 +402,19 @@ def all_closed_days(self) -> List[str]: return all_closed_days def fully_closed(self) -> bool: - """"Returns whether the object contains any openings""" - for day in WEEKDAYS: - if len(getattr(self, day)) > 0: - return False - return True + """Returns whether the object contains any openings.""" + return all(len(getattr(self, day)) <= 0 for day in WEEKDAYS) def is_open(self, weekday: str) -> bool: + """Returns whether the object contains any openings for the given day.""" return len(getattr(self, weekday)) > 0 def same_openings(self, other: "StandardOpeningTimes", day: str) -> bool: + """Returns whether the object contains the same openings for the given day.""" return OpenPeriod.equal_lists(self.get_openings(day), other.get_openings(day)) def add_open_period(self, open_period: OpenPeriod, weekday: str) -> None: - """Adds a formatted open period to the specified weekda + """Adds a formatted open period to the specified weekday. 
Args: open_period (OpenPeriod): The open period to add @@ -310,40 +429,34 @@ def add_open_period(self, open_period: OpenPeriod, weekday: str) -> None: else: logger.error(f"Cannot add opening time for invalid weekday '{weekday}', open period not added.") - def any_overlaps(self): - for weekday in WEEKDAYS: - if OpenPeriod.any_overlaps(getattr(self, weekday)): - return True - return False + def any_overlaps(self) -> bool: + """Returns True if any open period overlaps with another open period.""" + return any(OpenPeriod.any_overlaps(getattr(self, weekday)) for weekday in WEEKDAYS) - def all_start_before_end(self): - for weekday in WEEKDAYS: - if not OpenPeriod.all_start_before_end(getattr(self, weekday)): - return False - return True + def all_start_before_end(self) -> bool: + """Returns True if all open periods start before they end.""" + return all(OpenPeriod.all_start_before_end(getattr(self, weekday)) for weekday in WEEKDAYS) def any_contradictions(self) -> bool: """Returns True if any open period falls on a day that is marked as closed.""" - for weekday in self.explicit_closed_days: - if self.is_open(weekday): - return True - return False + return any(self.is_open(weekday) for weekday in self.explicit_closed_days) def is_valid(self) -> bool: + """Returns True if the object is valid.""" return self.all_start_before_end() and not self.any_overlaps() and not self.any_contradictions() def export_opening_times_for_day(self, weekday: str) -> list[str]: - """Exports standard opening times into DoS format for a specific day in the week""" + """Exports standard opening times into DoS format for a specific day in the week.""" open_periods = sorted(getattr(self, weekday)) return [open_period.export_db_string_format() for open_period in open_periods] def export_opening_times_in_seconds_for_day(self, weekday: str) -> list[str]: - """Exports standard opening times into time in seconds format for a specific day in the week""" + """Exports standard opening times into time in 
seconds format for a specific day in the week.""" open_periods = sorted(getattr(self, weekday)) return [open_period.export_time_in_seconds() for open_period in open_periods] - def export_test_format(self) -> Dict[str, List[Dict[str, str]]]: - """Exports standard opening times into a test format""" + def export_test_format(self) -> dict[str, list[dict[str, str]]]: + """Exports standard opening times into a test format.""" change = {} for weekday in WEEKDAYS: open_periods = sorted(getattr(self, weekday)) @@ -351,8 +464,8 @@ def export_test_format(self) -> Dict[str, List[Dict[str, str]]]: return change -def opening_period_times_from_list(open_periods: List[OpenPeriod], with_space: bool = True) -> str: - """Converts a list of OpenPeriods into a string of times separated by a space +def opening_period_times_from_list(open_periods: list[OpenPeriod], with_space: bool = True) -> str: + """Converts a list of OpenPeriods into a string of times separated by a space. Args: open_periods (List[OpenPeriod]): The list of OpenPeriods to convert @@ -365,7 +478,8 @@ def opening_period_times_from_list(open_periods: List[OpenPeriod], with_space: b ) -def string_to_time(time_str: str) -> Optional[time]: +def string_to_time(time_str: str) -> time | None: + """Converts a string to a time object.""" for time_format in ("%H:%M", "%H:%M:%S"): with suppress(ValueError): return datetime.strptime(str(time_str), time_format).time() diff --git a/application/common/report_logging.py b/application/common/report_logging.py index b3c9b33e1..3a4636a78 100644 --- a/application/common/report_logging.py +++ b/application/common/report_logging.py @@ -1,6 +1,6 @@ import json from os import environ -from typing import List, Union +from typing import Any from aws_embedded_metrics import metric_scope from aws_lambda_powertools.logging.logger import Logger @@ -19,15 +19,15 @@ UNMATCHED_PHARMACY_REPORT_ID, UNMATCHED_SERVICE_TYPE_REPORT_ID, ) -from common.dos import DoSService, VALID_STATUS_ID +from common.dos 
import VALID_STATUS_ID, DoSService from common.nhs import NHSEntity from common.opening_times import OpenPeriod logger = Logger(child=True) -def log_blank_standard_opening_times(nhs_entity: NHSEntity, matching_services: List[DoSService]) -> None: - """Log events where matches services are found but no std opening times exist +def log_blank_standard_opening_times(nhs_entity: NHSEntity, matching_services: list[DoSService]) -> None: + """Log events where matches services are found but no std opening times exist. Args: nhs_entity (NHSEntity): The NHS entity to report @@ -52,8 +52,8 @@ def log_blank_standard_opening_times(nhs_entity: NHSEntity, matching_services: L ) -def log_closed_or_hidden_services(nhs_entity: NHSEntity, matching_services: List[DoSService]) -> None: - """Log closed or hidden NHS UK services +def log_closed_or_hidden_services(nhs_entity: NHSEntity, matching_services: list[DoSService]) -> None: + """Log closed or hidden NHS UK services. Args: nhs_entity (NHSEntity): The NHS entity to report @@ -79,11 +79,11 @@ def log_closed_or_hidden_services(nhs_entity: NHSEntity, matching_services: List def log_unmatched_nhsuk_service(nhs_entity: NHSEntity) -> None: - """Log unmatched NHS Services + """Log unmatched NHS Services. + Args: - nhs_entity (NHSEntity): NHS entity to log + nhs_entity (NHSEntity): NHS entity to log. """ - logger.warning( f"No matching DOS services found that fit all criteria for ODSCode '{nhs_entity.odscode}'", extra={ @@ -105,11 +105,13 @@ def log_unmatched_nhsuk_service(nhs_entity: NHSEntity) -> None: @metric_scope -def log_invalid_nhsuk_postcode(nhs_entity: NHSEntity, dos_service: DoSService, metrics) -> None: - """Log invalid NHS pharmacy postcode +def log_invalid_nhsuk_postcode(nhs_entity: NHSEntity, dos_service: DoSService, metrics: Any) -> None: # noqa: ANN401 + """Log invalid NHS pharmacy postcode. 
+ Args: nhs_entity (NHSEntity): The NHS entity to report - dos_service (List[DoSService]): The list of DoS matching services + dos_service (List[DoSService]): The list of DoS matching services. + metrics (Any): The metrics object to report to. """ error_msg = f"NHS entity '{nhs_entity.odscode}' postcode '{nhs_entity.postcode}' is not a valid DoS postcode!" logger.warning( @@ -141,12 +143,17 @@ def log_invalid_nhsuk_postcode(nhs_entity: NHSEntity, dos_service: DoSService, m @metric_scope -def log_invalid_open_times(nhs_entity: NHSEntity, matching_services: List[DoSService], metrics) -> None: - """Report invalid open times for nhs entity +def log_invalid_open_times( + nhs_entity: NHSEntity, + matching_services: list[DoSService], + metrics: Any, # noqa: ANN401 +) -> None: + """Report invalid open times for nhs entity. Args: nhs_entity (NHSEntity): The NHS entity to report matching_services (List[DoSService]): The list of DoS matching services + metrics (Any): The metrics object to report to. """ error_msg = f"NHS Entity '{nhs_entity.odscode}' has a misformatted or illogical set of opening times." logger.warning( @@ -167,11 +174,12 @@ def log_invalid_open_times(nhs_entity: NHSEntity, matching_services: List[DoSSer metrics.put_metric("InvalidOpenTimes", 1, "Count") -def log_unmatched_service_types(nhs_entity: NHSEntity, unmatched_services: List[DoSService]) -> None: - """Log unmatched DOS service types +def log_unmatched_service_types(nhs_entity: NHSEntity, unmatched_services: list[DoSService]) -> None: + """Log unmatched DOS service types. + Args: nhs_entity (NHSEntity): The NHS entity to report - unmatched_services (List[DoSService]): The list of DoS unmatched services + unmatched_services (List[DoSService]): The list of DoS unmatched services. 
""" for unmatched_service in unmatched_services: logger.warning( @@ -195,8 +203,12 @@ def log_unmatched_service_types(nhs_entity: NHSEntity, unmatched_services: List[ def log_service_with_generic_bank_holiday(nhs_entity: NHSEntity, dos_service: DoSService) -> None: - """Log a service found to have a generic bank holiday open times set in DoS.""" + """Log a service found to have a generic bank holiday open times set in DoS. + Args: + nhs_entity (NHSEntity): The NHS entity to report + dos_service (DoSService): The DoS service to report + """ open_periods_str = OpenPeriod.list_string(dos_service.standard_opening_times.generic_bankholiday) logger.warning( @@ -216,6 +228,12 @@ def log_service_with_generic_bank_holiday(nhs_entity: NHSEntity, dos_service: Do def log_website_is_invalid(nhs_uk_entity: NHSEntity, nhs_website: str) -> None: + """Log a service found to have an invalid website. + + Args: + nhs_uk_entity (NHSEntity): The NHS entity to report + nhs_website (str): The NHS website to report + """ logger.warning( f"Website is not valid, {nhs_website=}", extra={ @@ -229,6 +247,11 @@ def log_website_is_invalid(nhs_uk_entity: NHSEntity, nhs_website: str) -> None: def log_palliative_care_z_code_does_not_exist(symptom_group_symptom_discriminator_combo_rowcount: int) -> None: + """Log a service found to have an invalid website. + + Args: + symptom_group_symptom_discriminator_combo_rowcount (int): The number of rows returned from the database query + """ logger.warning( "Palliative care Z code does not exist in the DoS database", extra={ @@ -241,7 +264,7 @@ def log_palliative_care_z_code_does_not_exist(symptom_group_symptom_discriminato ) -def log_service_updated( +def log_service_updated( # noqa: PLR0913 action: str, data_field_modified: str, new_value: str, @@ -250,6 +273,17 @@ def log_service_updated( service_uid: str, type_id: str, ) -> None: + """Log a service update. 
+ + Args: + action (str): The action performed + data_field_modified (str): The data field modified + new_value (str): The new value + previous_value (str): The previous value + service_name (str): The service name + service_uid (str): The service uid + type_id (str): The type id + """ logger.warning( "Service update complete", extra={ @@ -267,6 +301,12 @@ def log_service_updated( def log_palliative_care_not_equal(nhs_uk_palliative_care: bool, dos_palliative_care: bool) -> None: + """Log a service found to have an invalid website. + + Args: + nhs_uk_palliative_care (bool): The NHS website to report + dos_palliative_care (bool): The NHS entity to report + """ logger.warning( "Palliative care not equal", extra={ @@ -278,8 +318,17 @@ def log_palliative_care_not_equal(nhs_uk_palliative_care: bool, dos_palliative_c def log_incorrect_palliative_stockholder_type( - nhs_uk_palliative_care: Union[bool, str], dos_palliative_care: bool, dos_service: DoSService + nhs_uk_palliative_care: bool | str, + dos_palliative_care: bool, + dos_service: DoSService, ) -> None: + """Log a service found to have an invalid website. + + Args: + nhs_uk_palliative_care (bool): The NHS website to report + dos_palliative_care (bool): The NHS entity to report + dos_service (DoSService): The DoS service to report + """ logger.warning( "Palliative care on wrong service type", extra={ @@ -291,7 +340,13 @@ def log_incorrect_palliative_stockholder_type( ) -def log_unexpected_pharmacy_profiling(matching_services: List[DoSService], reason: str) -> None: +def log_unexpected_pharmacy_profiling(matching_services: list[DoSService], reason: str) -> None: + """Log a service found to have an invalid website. 
+ + Args: + matching_services (list[DoSService]): The DoS services to report + reason (str): The reason for the report + """ for service in matching_services: logger.warning( "Pharmacy profiling is incorrect", diff --git a/application/common/s3.py b/application/common/s3.py index 8c682c034..59c1c4633 100644 --- a/application/common/s3.py +++ b/application/common/s3.py @@ -7,11 +7,11 @@ def put_content_to_s3(content: bytes, s3_filename: str) -> None: - """Upload a file contents to S3 + """Upload a file contents to S3. Args: content (bytes): File contents - s3_file_name (str): The filename when the file is stored in S3 + s3_filename (str): The filename when the file is stored in S3 """ bucket = getenv("SEND_EMAIL_BUCKET_NAME") client("s3").put_object(Body=content, Bucket=bucket, Key=s3_filename, ServerSideEncryption="AES256") diff --git a/application/common/secretsmanager.py b/application/common/secretsmanager.py index a735d14fa..eaf9711d9 100644 --- a/application/common/secretsmanager.py +++ b/application/common/secretsmanager.py @@ -1,5 +1,4 @@ from json import loads -from typing import Dict from aws_lambda_powertools.logging import Logger from boto3 import client @@ -10,8 +9,8 @@ secrets_manager = client(service_name="secretsmanager") -def get_secret(secret_name: str) -> Dict[str, str]: - """Get the secret from AWS Secrets Manager +def get_secret(secret_name: str) -> dict[str, str]: + """Get the secret from AWS Secrets Manager. 
Args: secret_name (str): Secret name to get @@ -25,7 +24,7 @@ def get_secret(secret_name: str) -> Dict[str, str]: try: secret_value_response = secrets_manager.get_secret_value(SecretId=secret_name) except ClientError as err: - raise Exception(f"Failed getting secret '{secret_name}' from secrets manager") from err + msg = f"Failed getting secret '{secret_name}' from secrets manager" + raise Exception(msg) from err # noqa: TRY002 secrets_json_str = secret_value_response["SecretString"] - secrets = loads(secrets_json_str) - return secrets + return loads(secrets_json_str) diff --git a/application/common/service_type.py b/application/common/service_type.py index 2ed04c049..21f36df35 100644 --- a/application/common/service_type.py +++ b/application/common/service_type.py @@ -1,4 +1,3 @@ -from typing import List from aws_lambda_powertools.logging import Logger @@ -7,8 +6,8 @@ logger = Logger(child=True) -def get_valid_service_types(organisation_type_id: str) -> List[int]: - """Get the valid service types for the organisation type id +def get_valid_service_types(organisation_type_id: str) -> list[int]: + """Get the valid service types for the organisation type id. 
Args: organisation_type_id (str): organisation type id from nhs uk entity diff --git a/application/common/tests/conftest.py b/application/common/tests/conftest.py index 944a847cf..ed3e34df6 100644 --- a/application/common/tests/conftest.py +++ b/application/common/tests/conftest.py @@ -3,24 +3,27 @@ from os import environ, path from pathlib import Path from random import choices, randint, uniform +from typing import Any +import pytest +from aws_lambda_powertools.utilities.typing import LambdaContext from boto3 import Session from moto import mock_dynamodb -from pytest import fixture -from ..dos import DoSLocation, DoSService -from ..opening_times import StandardOpeningTimes +from application.common.dos import DoSLocation, DoSService +from application.common.opening_times import StandardOpeningTimes STD_EVENT_PATH = path.join(Path(__file__).parent.resolve(), "STANDARD_EVENT.json") -with open(STD_EVENT_PATH, "r", encoding="utf8") as file: +with open(STD_EVENT_PATH, encoding="utf8") as file: PHARMACY_STANDARD_EVENT = json.load(file) STD_EVENT_STAFF_PATH = path.join(Path(__file__).parent.resolve(), "STANDARD_EVENT_WITH_STAFF.json") -with open(STD_EVENT_STAFF_PATH, "r", encoding="utf8") as file: +with open(STD_EVENT_STAFF_PATH, encoding="utf8") as file: PHARMACY_STANDARD_EVENT_STAFF = json.load(file) -def get_std_event(**kwargs) -> dict: +def get_std_event(**kwargs: Any) -> dict: # noqa: ANN401 + """Creates a standard event with random data for the unit testing.""" event = PHARMACY_STANDARD_EVENT.copy() for name, value in kwargs.items(): if value is not None: @@ -28,8 +31,8 @@ def get_std_event(**kwargs) -> dict: return event -def dummy_dos_service(**kwargs) -> DoSService: - """Creates a DoSService Object with random data for the unit testing""" +def dummy_dos_service(**kwargs: Any) -> DoSService: # noqa: ANN401 + """Creates a DoSService Object with random data for the unit testing.""" test_data = {} for col in DoSService.field_names(): random_str = 
"".join(choices("ABCDEFGHIJKLM", k=8)) @@ -45,8 +48,8 @@ def dummy_dos_service(**kwargs) -> DoSService: return dos_service -def blank_dos_service(**kwargs) -> DoSService: - """Creates a DoSService Object with blank str data for the unit testing""" +def blank_dos_service(**kwargs: Any) -> DoSService: # noqa: ANN401 + """Creates a DoSService Object with blank str data for the unit testing.""" test_data = {} for col in DoSService.field_names(): test_data[col] = "" @@ -60,7 +63,7 @@ def blank_dos_service(**kwargs) -> DoSService: def dummy_dos_location() -> DoSLocation: - """Creates a DoSLocation Object with random data for the unit testing""" + """Creates a DoSLocation Object with random data for the unit testing.""" return DoSLocation( id=randint(1111, 9999), postcode="".join(choices("01234567890ABCDEFGHIJKLM", k=6)), @@ -72,42 +75,46 @@ def dummy_dos_location() -> DoSLocation: ) -@fixture -def change_event(): - change_event = PHARMACY_STANDARD_EVENT.copy() - yield change_event +@pytest.fixture() +def change_event() -> dict: + """Generate a change event for testing.""" + return PHARMACY_STANDARD_EVENT.copy() -@fixture -def aws_credentials(): +@pytest.fixture() +def _aws_credentials() -> None: """Mocked AWS Credentials for moto.""" environ["AWS_ACCESS_KEY_ID"] = "testing" - environ["AWS_SECRET_ACCESS_KEY"] = "testing" - environ["AWS_SECURITY_TOKEN"] = "testing" - environ["AWS_SESSION_TOKEN"] = "testing" + environ["AWS_SECRET_ACCESS_KEY"] = "testing" # noqa: S105 + environ["AWS_SECURITY_TOKEN"] = "testing" # noqa: S105 + environ["AWS_SESSION_TOKEN"] = "testing" # noqa: S105 environ["CHANGE_EVENTS_TABLE_NAME"] = "CHANGE_EVENTS_TABLE" environ["AWS_REGION"] = "us-east-2" -@fixture -def dynamodb_client(boto_session): - yield boto_session.client("dynamodb", region_name=environ["AWS_REGION"]) +@pytest.fixture() +def dynamodb_client(boto_session: Any) -> Any: # noqa: ANN401 + """DynamoDB Client Class.""" + return boto_session.client("dynamodb", 
region_name=environ["AWS_REGION"]) -@fixture -def dynamodb_resource(boto_session): - yield boto_session.resource("dynamodb", region_name=environ["AWS_REGION"]) +@pytest.fixture() +def dynamodb_resource(boto_session: Any) -> Any: # noqa: ANN401 + """DynamoDB Resource Class.""" + return boto_session.resource("dynamodb", region_name=environ["AWS_REGION"]) -@fixture -def boto_session(aws_credentials): +@pytest.fixture() +def boto_session(_aws_credentials: Any) -> Any: # noqa: ANN401 + """Mocked AWS Credentials for moto.""" with mock_dynamodb(): yield Session() -@fixture -def dead_letter_message(): - yield { +@pytest.fixture() +def dead_letter_message() -> dict: + """Generate a dead letter message for testing.""" + return { "Records": [ { "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", @@ -149,13 +156,15 @@ def dead_letter_message(): "eventSource": "aws:sqs", "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:cr-fifo-dlq-queue", "awsRegion": "us-east-2", - } - ] + }, + ], } -@fixture -def lambda_context(): +@pytest.fixture() +def lambda_context() -> LambdaContext: + """Generate a lambda context for testing.""" + @dataclass class LambdaContext: function_name: str = "service-matcher" diff --git a/application/common/tests/test_appconfig.py b/application/common/tests/test_appconfig.py index 7d977c51d..573a61445 100644 --- a/application/common/tests/test_appconfig.py +++ b/application/common/tests/test_appconfig.py @@ -1,7 +1,7 @@ from os import environ from unittest.mock import MagicMock, patch -from ..appconfig import AppConfig +from application.common.appconfig import AppConfig FILE_PATH = "application.common.appconfig" @@ -16,7 +16,7 @@ def test_app_config(mock_app_config_store): AppConfig(feature_flags_name) # Assert mock_app_config_store.assert_called_once_with( - environment=environment, application=f"uec-dos-int-{environment}-lambda-app-config", name=feature_flags_name + environment=environment, application=f"uec-dos-int-{environment}-lambda-app-config", 
name=feature_flags_name, ) # Clean up del environ["SHARED_ENVIRONMENT"] diff --git a/application/common/tests/test_dos.py b/application/common/tests/test_dos.py index aac9c6377..52b42c7ea 100644 --- a/application/common/tests/test_dos.py +++ b/application/common/tests/test_dos.py @@ -2,12 +2,13 @@ from random import choices from unittest.mock import MagicMock, patch -from ..dos import ( +from .conftest import dummy_dos_service +from application.common.dos import ( + DoSService, db_rows_to_spec_open_times, db_rows_to_spec_open_times_map, db_rows_to_std_open_times, db_rows_to_std_open_times_map, - DoSService, get_all_valid_dos_postcodes, get_dos_locations, get_matching_dos_services, @@ -17,8 +18,7 @@ get_valid_dos_location, has_palliative_care, ) -from ..opening_times import OpenPeriod, SpecifiedOpeningTime, StandardOpeningTimes -from .conftest import dummy_dos_service +from application.common.opening_times import OpenPeriod, SpecifiedOpeningTime, StandardOpeningTimes from common.constants import ( DENTIST_ORG_TYPE_ID, DOS_PALLIATIVE_CARE_SYMPTOM_DISCRIMINATOR, @@ -53,9 +53,6 @@ def test_field_names(): def test__init__(): - """Pass in random list of values as a mock database row then make sure - they're correctly set as the attributes of the created object. 
- """ # Arrange test_db_row = {} for column in DoSService.field_names(): @@ -134,7 +131,7 @@ def test_get_matching_dos_services_pharmacy_services_returned(mock_query_dos_db, "statusid, publicphone, publicname, st.name servicename FROM services s " "LEFT JOIN servicetypes st ON s.typeid = st.id WHERE odscode LIKE %(ODS)s" ), - vars={"ODS": f"{odscode[:5]}%"}, + query_vars={"ODS": f"{odscode[:5]}%"}, ) mock_cursor.fetchall.assert_called_with() mock_cursor.close.assert_called_with() @@ -200,7 +197,7 @@ def test_get_matching_dos_services_dentist_services_returned(mock_query_dos_db, "st.name servicename FROM services s LEFT JOIN servicetypes st ON s.typeid = st.id WHERE " "odscode = %(ODS)s or odscode LIKE %(ODS7)s" ), - vars={"ODS": f"{ods6_code}", "ODS7": f"{odscode}%"}, + query_vars={"ODS": f"{ods6_code}", "ODS7": f"{odscode}%"}, ) mock_cursor.fetchall.assert_called_with() mock_cursor.close.assert_called_with() @@ -228,7 +225,7 @@ def test_get_matching_dos_services_no_services_returned(mock_query_dos_db, mock_ "publicphone, publicname, st.name servicename FROM services s LEFT JOIN servicetypes" " st ON s.typeid = st.id WHERE odscode LIKE %(ODS)s" ), - vars={"ODS": f"{odscode[:5]}%"}, + query_vars={"ODS": f"{odscode[:5]}%"}, ) mock_cursor.fetchall.assert_called_with() mock_cursor.close.assert_called_with() @@ -302,7 +299,7 @@ def test_get_specified_opening_times_from_db_times_returned(mock_query_dos_db, m "", "", "", - ] + ], ) # Act responses = get_specified_opening_times_from_db(connection=mock_connection, service_id=service_id) @@ -319,7 +316,7 @@ def test_get_specified_opening_times_from_db_times_returned(mock_query_dos_db, m "INNER JOIN servicespecifiedopeningtimes ssot " "ON ssod.id = ssot.servicespecifiedopeningdateid " "WHERE ssod.serviceid = %(SERVICE_ID)s", - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) @@ -358,7 +355,7 @@ def test_get_standard_opening_times_from_db_times_returned(mock_query_dos_db, mo "LEFT JOIN 
openingtimedays otd " "ON sdo.dayid = otd.id " "WHERE sdo.serviceid = %(SERVICE_ID)s", - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) @@ -389,7 +386,7 @@ def test_get_specified_opening_times_from_db_no_times_returned(mock_query_dos_db "INNER JOIN servicespecifiedopeningtimes ssot " "ON ssod.id = ssot.servicespecifiedopeningdateid " "WHERE ssod.serviceid = %(SERVICE_ID)s", - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) @@ -410,7 +407,7 @@ def test_get_dos_locations(mock_query_dos_db, mock_connect_to_dos_db_replica): "postaltown": "town", "latitude": 4.0, "longitude": 2.0, - } + }, ] mock_cursor.fetchall.return_value = db_return mock_query_dos_db.return_value = mock_cursor @@ -433,7 +430,7 @@ def test_get_dos_locations(mock_query_dos_db, mock_connect_to_dos_db_replica): connection=mock_connection, query="SELECT id, postcode, easting, northing, postaltown, latitude, longitude " "FROM locations WHERE postcode = ANY(%(pc_variations)s)", - vars={"pc_variations": postcode_variations}, + query_vars={"pc_variations": postcode_variations}, ) @@ -548,7 +545,7 @@ def test_get_services_from_db(mock_query_dos_db, mock_connect_to_dos_db_replica) "friday=[13:00:00-15:30:00], saturday=[], sunday=[]>" "" "", - ] + ], ) # Act @@ -755,7 +752,7 @@ def test_db_rows_to_std_open_times_map(): assert actual_std_open_times_map == expcted_std_open_times_map -def get_db_item(odscode="FA9321", name="fake name", id=9999, typeid=13): +def get_db_item(odscode="FA9321", name="fake name", id=9999, typeid=13): # noqa: A002 return { "id": id, "uid": "159514725", @@ -800,7 +797,9 @@ def test_has_palliative_care(mock_query_dos_db: MagicMock): assert True is has_palliative_care(dos_service, connection) # Assert mock_query_dos_db.assert_called_once_with( - connection=connection, query=expected_sql_command, vars=expected_named_args + connection=connection, + query=expected_sql_command, + query_vars=expected_named_args, ) diff --git 
a/application/common/tests/test_dos_db_connection.py b/application/common/tests/test_dos_db_connection.py index 391dfa6b5..14e10cdb4 100644 --- a/application/common/tests/test_dos_db_connection.py +++ b/application/common/tests/test_dos_db_connection.py @@ -3,7 +3,12 @@ from psycopg.rows import dict_row -from ..dos_db_connection import connect_to_dos_db, connect_to_dos_db_replica, connection_to_db, query_dos_db +from application.common.dos_db_connection import ( + connect_to_dos_db, + connect_to_dos_db_replica, + connection_to_db, + query_dos_db, +) FILE_PATH = "application.common.dos_db_connection" diff --git a/application/common/tests/test_dos_location.py b/application/common/tests/test_dos_location.py index 1c7010e4e..5c3abc554 100644 --- a/application/common/tests/test_dos_location.py +++ b/application/common/tests/test_dos_location.py @@ -1,52 +1,51 @@ import pytest -from application.common.dos_location import DoSLocation - from .conftest import dummy_dos_location +from application.common.dos_location import DoSLocation @pytest.mark.parametrize( - "dos_location, expected_result", + ("dos_location", "expected_result"), [ ( DoSLocation( - id=1, postcode="TE57ER", easting=None, northing=None, postaltown="TOWN", latitude=None, longitude=None + id=1, postcode="TE57ER", easting=None, northing=None, postaltown="TOWN", latitude=None, longitude=None, ), False, ), ( DoSLocation( - id=1, postcode="TE57ER", easting=None, northing=1, postaltown="TOWN", latitude=1.1, longitude=1.1 + id=1, postcode="TE57ER", easting=None, northing=1, postaltown="TOWN", latitude=1.1, longitude=1.1, ), False, ), ( DoSLocation( - id=1, postcode="TE57ER", easting=1, northing=None, postaltown="TOWN", latitude=1.1, longitude=1.1 + id=1, postcode="TE57ER", easting=1, northing=None, postaltown="TOWN", latitude=1.1, longitude=1.1, ), False, ), ( DoSLocation( - id=1, postcode="TE57ER", easting=1, northing=1, postaltown="TOWN", latitude=None, longitude=1.1 + id=1, postcode="TE57ER", easting=1, 
northing=1, postaltown="TOWN", latitude=None, longitude=1.1, ), False, ), ( DoSLocation( - id=1, postcode="TE57ER", easting=1, northing=1, postaltown="TOWN", latitude=1.1, longitude=None + id=1, postcode="TE57ER", easting=1, northing=1, postaltown="TOWN", latitude=1.1, longitude=None, ), False, ), ( DoSLocation( - id=1, postcode="TE57ER", easting=None, northing=None, postaltown="TOWN", latitude=1.1, longitude=1.1 + id=1, postcode="TE57ER", easting=None, northing=None, postaltown="TOWN", latitude=1.1, longitude=1.1, ), False, ), ( DoSLocation( - id=1, postcode="TE57ER", easting=1, northing=1, postaltown="TOWN", latitude=None, longitude=None + id=1, postcode="TE57ER", easting=1, northing=1, postaltown="TOWN", latitude=None, longitude=None, ), False, ), @@ -64,7 +63,7 @@ def test_doslocation_is_valid(dos_location: DoSLocation, expected_result: bool): @pytest.mark.parametrize( - "input_postcode, expected_result", + ("input_postcode", "expected_result"), [ ("TE57ER", "TE57ER"), ("TE5 7ER", "TE57ER"), diff --git a/application/common/tests/test_dynamodb.py b/application/common/tests/test_dynamodb.py index a612efeb5..32ca59185 100644 --- a/application/common/tests/test_dynamodb.py +++ b/application/common/tests/test_dynamodb.py @@ -2,19 +2,20 @@ from json import dumps, loads from os import environ from time import time +from typing import Any from unittest.mock import patch +import pytest from aws_lambda_powertools.logging import Logger from boto3.dynamodb.types import TypeDeserializer -from pytest import fixture, raises FILE_PATH = "application.common.dynamodb" -@fixture -def dynamodb_table_create(dynamodb_client): - """Create a DynamoDB CHANGE_EVENTS_TABLE table fixture.""" - table = dynamodb_client.create_table( +@pytest.fixture() +def dynamodb_table_create(dynamodb_client) -> dict[str, Any]: + """Create a DynamoDB CHANGE_EVENTS_TABLE table pytest.fixture.""" + return dynamodb_client.create_table( TableName=environ["CHANGE_EVENTS_TABLE_NAME"], 
BillingMode="PAY_PER_REQUEST", KeySchema=[ @@ -34,52 +35,49 @@ def dynamodb_table_create(dynamodb_client): {"AttributeName": "SequenceNumber", "KeyType": "RANGE"}, ], "Projection": {"ProjectionType": "ALL"}, - } + }, ], ) - return table def test_get_circuit_is_open_none(dynamodb_table_create, dynamodb_client): - from ..dynamodb import get_circuit_is_open + from application.common.dynamodb import get_circuit_is_open - is_open = get_circuit_is_open("BLABLABLA") - assert is_open is None + assert get_circuit_is_open("BLABLABLA") is None def test_put_and_get_circuit_is_open(dynamodb_table_create, dynamodb_client): - from ..dynamodb import get_circuit_is_open, put_circuit_is_open + from application.common.dynamodb import get_circuit_is_open, put_circuit_is_open put_circuit_is_open("TESTCIRCUIT", True) - is_open = get_circuit_is_open("TESTCIRCUIT") - assert is_open + assert get_circuit_is_open("TESTCIRCUIT") def test_put_circuit_exception(dynamodb_table_create, dynamodb_client): - from ..dynamodb import put_circuit_is_open + from application.common.dynamodb import put_circuit_is_open temp_table = environ["CHANGE_EVENTS_TABLE_NAME"] del environ["CHANGE_EVENTS_TABLE_NAME"] - with raises(Exception): + with pytest.raises(Exception): # noqa: PT011,B017 put_circuit_is_open("TESTCIRCUIT", True) environ["CHANGE_EVENTS_TABLE_NAME"] = temp_table def test_get_circuit_exception(dynamodb_table_create, dynamodb_client): - from ..dynamodb import get_circuit_is_open + from application.common.dynamodb import get_circuit_is_open temp_table = environ["CHANGE_EVENTS_TABLE_NAME"] del environ["CHANGE_EVENTS_TABLE_NAME"] - with raises(Exception): + with pytest.raises(Exception): # noqa: PT011,B017 get_circuit_is_open("TESTCIRCUIT") environ["CHANGE_EVENTS_TABLE_NAME"] = temp_table def test_add_change_event_to_dynamodb(dynamodb_table_create, change_event, dynamodb_client): - from ..dynamodb import add_change_event_to_dynamodb, dict_hash, TTL + from application.common.dynamodb import TTL, 
add_change_event_to_dynamodb, dict_hash # Arrange event_received_time = int(time()) @@ -105,9 +103,14 @@ def test_add_change_event_to_dynamodb(dynamodb_table_create, change_event, dynam def test_get_latest_sequence_id_for_same_change_event_from_dynamodb( - dynamodb_table_create, change_event, dynamodb_client + dynamodb_table_create, + change_event, + dynamodb_client, ): - from ..dynamodb import add_change_event_to_dynamodb, get_latest_sequence_id_for_a_given_odscode_from_dynamodb + from application.common.dynamodb import ( + add_change_event_to_dynamodb, + get_latest_sequence_id_for_a_given_odscode_from_dynamodb, + ) event_received_time = int(time()) add_change_event_to_dynamodb(change_event.copy(), 1, event_received_time) @@ -131,7 +134,10 @@ def test_get_latest_sequence_id_for_same_change_event_from_dynamodb( def test_same_sequence_id_and_same_change_event_multiple_times(dynamodb_table_create, change_event, dynamodb_client): - from ..dynamodb import add_change_event_to_dynamodb, get_latest_sequence_id_for_a_given_odscode_from_dynamodb + from application.common.dynamodb import ( + add_change_event_to_dynamodb, + get_latest_sequence_id_for_a_given_odscode_from_dynamodb, + ) event_received_time = int(time()) add_change_event_to_dynamodb(change_event.copy(), 3, event_received_time) @@ -153,7 +159,7 @@ def test_same_sequence_id_and_same_change_event_multiple_times(dynamodb_table_cr def test_no_records_in_db_for_a_given_odscode(dynamodb_table_create, change_event): - from ..dynamodb import get_latest_sequence_id_for_a_given_odscode_from_dynamodb + from application.common.dynamodb import get_latest_sequence_id_for_a_given_odscode_from_dynamodb latest_sequence_number = get_latest_sequence_id_for_a_given_odscode_from_dynamodb(change_event["ODSCode"]) assert latest_sequence_number == 0 @@ -161,10 +167,15 @@ def test_no_records_in_db_for_a_given_odscode(dynamodb_table_create, change_even @patch.object(Logger, "error") def 
test_get_latest_sequence_id_for_different_change_event_from_dynamodb( - mock_logger, dynamodb_table_create, change_event, dynamodb_client + mock_logger, + dynamodb_table_create, + change_event, + dynamodb_client, ): - - from ..dynamodb import add_change_event_to_dynamodb, get_latest_sequence_id_for_a_given_odscode_from_dynamodb + from application.common.dynamodb import ( + add_change_event_to_dynamodb, + get_latest_sequence_id_for_a_given_odscode_from_dynamodb, + ) event_received_time = int(time()) odscode = change_event["ODSCode"] @@ -241,21 +252,21 @@ def copy_and_modify_website(ce, new_website: str): def test_get_newest_event_per_odscode(dynamodb_table_create, change_event, dynamodb_client, dynamodb_resource): - from ..dynamodb import add_change_event_to_dynamodb, get_newest_event_per_odscode + from application.common.dynamodb import add_change_event_to_dynamodb, get_newest_event_per_odscode - ceAAA11 = change_event.copy() + ceAAA11 = change_event.copy() # noqa: N806 ceAAA11["ODSCode"] = "AAA11" add_change_event_to_dynamodb(ceAAA11, 301, int(time())) for i in range(10): add_change_event_to_dynamodb(ceAAA11, i, int(time())) - ceBBB22 = change_event.copy() + ceBBB22 = change_event.copy() # noqa: N806 ceBBB22["ODSCode"] = "BBB22" add_change_event_to_dynamodb(ceBBB22, 505, int(time())) for i in range(10): add_change_event_to_dynamodb(ceBBB22, i, int(time())) - ceCCC33 = change_event.copy() + ceCCC33 = change_event.copy() # noqa: N806 ceCCC33["ODSCode"] = "CCC33" add_change_event_to_dynamodb(ceCCC33, 400, int(time())) for i in range(10): diff --git a/application/common/tests/test_errors.py b/application/common/tests/test_errors.py index adc7f3f99..ee187973d 100644 --- a/application/common/tests/test_errors.py +++ b/application/common/tests/test_errors.py @@ -1,15 +1,17 @@ -from pytest import raises +import pytest -from ..errors import DynamoDBException, ValidationException +from application.common.errors import DynamoDBError, ValidationError def 
test_validation_exception(): # Arrange & Act - with raises(ValidationException): - raise ValidationException("Test") + with pytest.raises(ValidationError): # noqa: PT012 + msg = "Test" + raise ValidationError(msg) def test_dynamodb_exception(): # Arrange & Act - with raises(DynamoDBException): - raise DynamoDBException("Test") + with pytest.raises(DynamoDBError): # noqa: PT012 + msg = "Test" + raise DynamoDBError(msg) diff --git a/application/common/tests/test_middlewares.py b/application/common/tests/test_middlewares.py index 2009d81ca..1a82626a5 100644 --- a/application/common/tests/test_middlewares.py +++ b/application/common/tests/test_middlewares.py @@ -1,22 +1,23 @@ import logging +import re from json import dumps +import pytest from aws_lambda_powertools.utilities.data_classes import SQSEvent from botocore.exceptions import ClientError -from pytest import raises -from ..middlewares import ( +from application.common.middlewares import ( redact_staff_key_from_event, unhandled_exception_logging, unhandled_exception_logging_hidden_event, ) -from ..tests.conftest import PHARMACY_STANDARD_EVENT, PHARMACY_STANDARD_EVENT_STAFF -from ..utilities import extract_body +from application.common.tests.conftest import PHARMACY_STANDARD_EVENT, PHARMACY_STANDARD_EVENT_STAFF +from application.common.utilities import extract_body def test_redact_staff_key_from_event_with_no_staff_key(caplog): @redact_staff_key_from_event() - def dummy_handler(event, context): + def dummy_handler(event, context) -> SQSEvent: return event # Arrange @@ -32,12 +33,12 @@ def dummy_handler(event, context): def test_redact_staff_key_from_event(caplog): @redact_staff_key_from_event() - def dummy_handler(event, context): + def dummy_handler(event, context) -> SQSEvent: return event # Arrange event = SQS_EVENT.copy() - event['Records'][0]['body'] = dumps(PHARMACY_STANDARD_EVENT_STAFF.copy()) + event["Records"][0]["body"] = dumps(PHARMACY_STANDARD_EVENT_STAFF.copy()) assert "Staff" in 
extract_body(event["Records"][0]["body"]) # Act result = dummy_handler(event, None) @@ -48,7 +49,7 @@ def dummy_handler(event, context): def test_redact_staff_key_from_event_no_records(caplog): @redact_staff_key_from_event() - def dummy_handler(event, context): + def dummy_handler(event, context) -> SQSEvent: return event # Arrange @@ -63,27 +64,28 @@ def dummy_handler(event, context): def test_unhandled_exception_logging(caplog): @unhandled_exception_logging - def client_error_func(event, context): + def client_error_func(event, context) -> None: raise ClientError({"Error": {"Code": "dummy_error", "Message": "dummy_message"}}, "op_name") @unhandled_exception_logging - def regular_error_func(event, context): - raise Exception("dummy exception message") + def regular_error_func(event, context) -> None: + msg = "dummy exception message" + raise Exception(msg) # noqa: TRY002 with caplog.at_level(logging.ERROR): - - with raises(ClientError): + with pytest.raises( + ClientError, + match=re.escape("An error occurred (dummy_error) when calling the op_name operation: dummy_message"), + ): client_error_func(None, None) - assert "Boto3 Client Error - 'dummy_error': dummy_message" in caplog.text - with raises(Exception): + with pytest.raises(Exception, match="dummy exception message"): regular_error_func(None, None) - assert "dummy_error" in caplog.text def test_unhandled_exception_logging_no_error(): @unhandled_exception_logging - def dummy_handler(event, context): + def dummy_handler(event, context) -> None: pass # Arrange @@ -95,19 +97,19 @@ def dummy_handler(event, context): def test_unhandled_exception_logging_hidden_event(caplog): @unhandled_exception_logging_hidden_event - def regular_error_func(event, context): - raise Exception("dummy exception message") + def regular_error_func(event, context) -> None: + msg = "dummy exception message" + raise Exception(msg) # noqa: TRY002 with caplog.at_level(logging.ERROR): - - with raises(Exception): + with 
pytest.raises(Exception, match="dummy exception message"): regular_error_func(None, None) assert "dummy_error" not in caplog.text def test_unhandled_exception_logging_hidden_event_no_error(): @unhandled_exception_logging_hidden_event - def dummy_handler(event, context): + def dummy_handler(event, context) -> None: pass # Arrange @@ -136,6 +138,6 @@ def dummy_handler(event, context): "eventSource": "aws:sqs", "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", "awsRegion": "us-east-2", - } - ] + }, + ], } diff --git a/application/common/tests/test_nhs.py b/application/common/tests/test_nhs.py index a24c3e380..378741339 100644 --- a/application/common/tests/test_nhs.py +++ b/application/common/tests/test_nhs.py @@ -2,15 +2,16 @@ import pytest -from ..nhs import ( +from .conftest import PHARMACY_STANDARD_EVENT, dummy_dos_service +from application.common.constants import CLOSED_AND_HIDDEN_STATUSES +from application.common.nhs import ( + NHSEntity, get_palliative_care_log_value, is_spec_opening_json, is_std_opening_json, match_nhs_entities_to_services, - NHSEntity, skip_if_key_is_none, ) -from .conftest import dummy_dos_service, PHARMACY_STANDARD_EVENT from common.constants import DENTIST_SERVICE_TYPE_IDS, PHARMACY_SERVICE_TYPE_IDS from common.opening_times import OpenPeriod, SpecifiedOpeningTime, StandardOpeningTimes @@ -102,8 +103,8 @@ def test_get_specified_opening_times(): "AdditionalOpeningDate": "Jan 20 2023", "IsOpen": False, }, - ] - } + ], + }, ) # Act # Assert @@ -124,7 +125,7 @@ def test_get_specified_opening_times(): ), f"NHS entity should contain {exp_spec_open_time} but can't be found in list {actual_spec_open_times}" assert len(actual_spec_open_times) == len( - expected + expected, ), f"Should return {len(expected)} , actually: {len(actual_spec_open_times)}" @@ -173,8 +174,8 @@ def test_get_standard_opening_times(): "AdditionalOpeningDate": "", "IsOpen": False, }, - ] - } + ], + }, ) # Act expected_std_open_times = StandardOpeningTimes() @@ 
-201,7 +202,7 @@ def test_is_status_hidden_or_closed_open_service(organisation_status: str): assert not result -@pytest.mark.parametrize("organisation_status", NHSEntity.CLOSED_AND_HIDDEN_STATUSES) +@pytest.mark.parametrize("organisation_status", CLOSED_AND_HIDDEN_STATUSES) def test_is_status_hidden_or_closed_not_open_service(organisation_status: str): # Arrange test_data = {"OrganisationStatus": organisation_status} @@ -213,7 +214,7 @@ def test_is_status_hidden_or_closed_not_open_service(organisation_status: str): @pytest.mark.parametrize( - "open_time_json, expected", + ("open_time_json", "expected"), [ ({}, False), ( @@ -380,7 +381,7 @@ def test_is_std_opening_json(open_time_json, expected): @pytest.mark.parametrize( - "open_time_json, expected", + ("open_time_json", "expected"), [ ({}, False), ( @@ -559,7 +560,7 @@ def test_is_spec_opening_json(open_time_json, expected): assert actual == expected, f"Spec time should be valid={expected} but wasn't. open_time={open_time_json}" -def test_is_matching_dos_service(): +def test_is_matching_dos_service(): # noqa: PLR0915 nhs_entity = NHSEntity({}) dos_service = dummy_dos_service() @@ -671,7 +672,7 @@ def test_match_nhs_entities_to_services(): @pytest.mark.parametrize( - "input_value, output_value", + ("input_value", "output_value"), [ ("", None), (None, None), @@ -683,7 +684,7 @@ def test_match_nhs_entities_to_services(): "ServiceName": "Pharmacy palliative care medication stockholder", "ServiceDescription": None, "ServiceCode": "SRV0559", - } + }, ], True, ), @@ -695,14 +696,21 @@ def test_extract_uec_service(input_value, output_value): @pytest.mark.parametrize( - "input_value,output_value", [(None, True), ("", False), ("V012345", False), (False, False), ("V012345", False)] + ("input_value", "output_value"), + [ + (None, True), + ("", False), + ("V012345", False), + (False, False), + ("V012345", False), + ], ) def test_skip_if_key_is_none(input_value, output_value): assert output_value == 
skip_if_key_is_none(input_value) @pytest.mark.parametrize( - "palliative_care,skip_palliative_care,output_value", + ("palliative_care", "skip_palliative_care", "output_value"), [ (True, False, True), (False, False, False), diff --git a/application/common/tests/test_opening_times.py b/application/common/tests/test_opening_times.py index 90c15d891..c68bc3a3b 100644 --- a/application/common/tests/test_opening_times.py +++ b/application/common/tests/test_opening_times.py @@ -2,29 +2,28 @@ import pytest -from ..opening_times import ( - opening_period_times_from_list, +from application.common.opening_times import ( + WEEKDAYS, OpenPeriod, SpecifiedOpeningTime, StandardOpeningTimes, - WEEKDAYS, + opening_period_times_from_list, ) OP = OpenPeriod.from_string -def test_open_period_repr(capsys): +def test_open_period_repr(): # Arrange open_period = OpenPeriod(time(8, 0), time(12, 0)) # Act - print(open_period) + value = repr(open_period) # Assert - captured = capsys.readouterr() - assert captured.out == "08:00:00-12:00:00\n" + assert value == "OpenPeriod(08:00:00-12:00:00)" @pytest.mark.parametrize( - "start, end, other_start, other_end, expected", + ("start", "end", "other_start", "other_end", "expected"), [ (time(8, 0), time(12, 0), time(8, 0), time(12, 0), True), (time(8, 0), time(12, 0), time(13, 0), time(23, 0), False), @@ -77,7 +76,10 @@ def test_open_period_eq_hash(): assert hash(a) != hash(a2) -@pytest.mark.parametrize("start, end, expected", [(time(8, 0), time(12, 0), True), (time(12, 0), time(8, 0), False)]) +@pytest.mark.parametrize( + ("start", "end", "expected"), + [(time(8, 0), time(12, 0), True), (time(12, 0), time(8, 0), False)], +) def test_open_period_start_before_end(start, end, expected): # Arrange open_period = OpenPeriod(start, end) @@ -229,7 +231,7 @@ def test_open_period_hash(opening_period_2: OpenPeriod): assert open_period_1 == opening_period_2, f"{open_period_1} not found to be equal to {opening_period_2}" assert hash(open_period_1) == hash( - 
opening_period_2 + opening_period_2, ), f"hash {hash(open_period_1)} not found to be equal to {hash(opening_period_2)}" @@ -480,7 +482,7 @@ def test_specifiedopentimes_remove_past_dates(): future2 = SpecifiedOpeningTime([a, b, c], (now_date + timedelta(weeks=5))) past = SpecifiedOpeningTime([b], (now_date - timedelta(weeks=4))) - assert SpecifiedOpeningTime.remove_past_dates(list=[future1, future2, past]) == [future1, future2] + assert SpecifiedOpeningTime.remove_past_dates(times_list=[future1, future2, past]) == [future1, future2] def test_specifiedopentime_export_service_history_format_open(): @@ -511,7 +513,9 @@ def test_specifiedopentime_export_service_history_format_closed(): def test_specifiedopentime_export_dos_log_format_open(): # Arrange specified_opening_time = SpecifiedOpeningTime( - [OpenPeriod(time(9, 0, 0), time(11, 0, 0))], date(2021, 12, 24), is_open=True + [OpenPeriod(time(9, 0, 0), time(11, 0, 0))], + date(2021, 12, 24), + is_open=True, ) # Act result = specified_opening_time.export_dos_log_format() @@ -529,7 +533,7 @@ def test_specifiedopentime_export_dos_log_format_closed(): @pytest.mark.parametrize( - "expected, actual", + ("expected", "actual"), [ ({"2021-12-25": []}, SpecifiedOpeningTime([], date(2021, 12, 25))), ( @@ -541,7 +545,7 @@ def test_specifiedopentime_export_dos_log_format_closed(): "2039-12-30": [ {"start_time": "02:00", "end_time": "09:30"}, {"start_time": "11:45", "end_time": "18:00"}, - ] + ], }, SpecifiedOpeningTime( [OpenPeriod(time(2, 0, 0), time(9, 30, 0)), OpenPeriod(time(11, 45, 0), time(18, 0, 0))], @@ -554,7 +558,7 @@ def test_specifiedopentime_export_dos_log_format_closed(): {"start_time": "05:00", "end_time": "09:30"}, {"start_time": "11:45", "end_time": "18:00"}, {"start_time": "20:45", "end_time": "22:00"}, - ] + ], }, SpecifiedOpeningTime( [ @@ -681,7 +685,6 @@ def test_stdopeningtimes_export_opening_times_in_seconds_for_day(): def test_standard_opening_times_export_test_format(): - # Start with empty 
std_opening_times = StandardOpeningTimes() expected = { @@ -723,7 +726,7 @@ def test_opening_period_times_from_list(): # Act response = opening_period_times_from_list(times) # Assert - assert "08:00-09:00, 09:00-10:00" == response + assert response == "08:00-09:00, 09:00-10:00" def test_std_open_times_fully_closed(): diff --git a/application/common/tests/test_report_logging.py b/application/common/tests/test_report_logging.py index 7b1e361f0..7b5fcba62 100644 --- a/application/common/tests/test_report_logging.py +++ b/application/common/tests/test_report_logging.py @@ -5,8 +5,7 @@ from aws_lambda_powertools.logging import Logger from application.common.constants import INCORRECT_PALLIATIVE_STOCKHOLDER_TYPE_REPORT_ID - -from ..report_logging import ( +from application.common.report_logging import ( log_blank_standard_opening_times, log_closed_or_hidden_services, log_incorrect_palliative_stockholder_type, @@ -115,7 +114,7 @@ def test_log_unmatched_nhsuk_service(mock_logger): "City": "city", "County": "country", "Postcode": "MK2 4AX", - } + }, ) # Act log_unmatched_nhsuk_service(nhs_entity) @@ -150,7 +149,7 @@ def test_log_invalid_nhsuk_postcode(mock_logger): county = "county" city = "city" nhs_entity = NHSEntity( - {"Address1": "address1", "Address2": "address2", "Address3": "address3", "City": city, "County": county} + {"Address1": "address1", "Address2": "address2", "Address3": "address3", "City": city, "County": county}, ) nhs_entity.odscode = "SLC4X" nhs_entity.org_name = "OrganisationName" @@ -267,7 +266,7 @@ def test_log_service_with_generic_bank_holiday(mock_logger): def test_log_unmatched_service_types(mock_logger): # Arrange nhs_entity = NHSEntity( - {"Address1": "address1", "Address2": "address2", "Address3": "address3", "City": "city", "County": "county"} + {"Address1": "address1", "Address2": "address2", "Address3": "address3", "City": "city", "County": "county"}, ) nhs_entity.odscode = "SLC4X" nhs_entity.org_name = "OrganisationName" @@ -331,7 +330,7 @@ 
def test_log_palliative_care_z_code_does_not_exist(mock_logger: MagicMock): symptom_group_symptom_discriminator_combo_rowcount = 1 # Act log_palliative_care_z_code_does_not_exist( - symptom_group_symptom_discriminator_combo_rowcount=symptom_group_symptom_discriminator_combo_rowcount + symptom_group_symptom_discriminator_combo_rowcount=symptom_group_symptom_discriminator_combo_rowcount, ) # Assert assert ( diff --git a/application/common/tests/test_s3.py b/application/common/tests/test_s3.py index e6e8b4749..c1c550367 100644 --- a/application/common/tests/test_s3.py +++ b/application/common/tests/test_s3.py @@ -17,7 +17,7 @@ def test_put_content_to_s3(mock_client): # Assert mock_client.assert_called_once_with("s3") mock_client.return_value.put_object.assert_called_once_with( - Body=content, Bucket=bucket_name, Key=s3_filename, ServerSideEncryption="AES256" + Body=content, Bucket=bucket_name, Key=s3_filename, ServerSideEncryption="AES256", ) # Cleanup del environ["SEND_EMAIL_BUCKET_NAME"] diff --git a/application/common/tests/test_secretsmanager.py b/application/common/tests/test_secretsmanager.py index 7417c3854..4a3c6cd49 100644 --- a/application/common/tests/test_secretsmanager.py +++ b/application/common/tests/test_secretsmanager.py @@ -1,8 +1,8 @@ from json import dumps import boto3 +import pytest from moto import mock_secretsmanager -from pytest import raises FILE_PATH = "application.common.secretsmanager" @@ -26,5 +26,5 @@ def test_get_secret(): def test_get_secret_resource_not_found(): from application.common.secretsmanager import get_secret - with raises(Exception, match="Failed getting secret 'fake_secret_name' from secrets manager"): + with pytest.raises(Exception, match="Failed getting secret 'fake_secret_name' from secrets manager"): get_secret("fake_secret_name") diff --git a/application/common/tests/test_service_type.py b/application/common/tests/test_service_type.py index 33a1f3e52..edecbcb2d 100644 --- a/application/common/tests/test_service_type.py 
+++ b/application/common/tests/test_service_type.py @@ -1,13 +1,13 @@ import pytest -from ..constants import SERVICE_TYPES, VALID_SERVICE_TYPES_KEY -from ..service_type import get_valid_service_types +from application.common.constants import SERVICE_TYPES, VALID_SERVICE_TYPES_KEY +from application.common.service_type import get_valid_service_types FILE_PATH = "application.common.service_type" @pytest.mark.parametrize( - "org_type, expected_valid_service_types", + ("org_type", "expected_valid_service_types"), [ ( "Dentist", diff --git a/application/common/tests/test_utilities.py b/application/common/tests/test_utilities.py index 51a3080a0..99d50aa36 100644 --- a/application/common/tests/test_utilities.py +++ b/application/common/tests/test_utilities.py @@ -1,10 +1,10 @@ from json import loads from os import environ +import pytest from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord -from pytest import mark, raises -from ..utilities import ( +from application.common.utilities import ( add_metric, extract_body, get_sequence_number, @@ -31,7 +31,7 @@ def test_extract_body_exception(): # Arrange expected_change_event = "test" # Act & Assert - with raises(Exception): + with pytest.raises(ValueError, match="Change Event unable to be extracted"): extract_body(expected_change_event) @@ -41,16 +41,13 @@ def test_json_str_body(): # Act result = json_str_body({"test": "test"}) # Assert - assert ( - result == expected_json_str - ), f"Change event body should be {expected_json_str} str but is {result}" + assert result == expected_json_str, f"Change event body should be {expected_json_str} str but is {result}" def test_expected_json_str_exception(): # Act & Assert - with raises(Exception) as exception: + with pytest.raises(TypeError, match="Object of type set is not JSON serializable"): json_str_body(body={"not a json dict"}) - assert "Dict Change Event body cannot be converted to a JSON string" in str(exception.value) def test_get_sequence_number(): @@ 
-120,13 +117,13 @@ def test_handle_sqs_msg_attributes(dead_letter_message): assert attributes["error_msg_http_code"] == "400" -@mark.parametrize("val,expected", [("", True), (" ", True), (None, True), ("True val", False)]) +@pytest.mark.parametrize(("val", "expected"), [("", True), (" ", True), (None, True), ("True val", False)]) def test_is_val_none_or_empty(val, expected): assert is_val_none_or_empty(val) == expected -@mark.parametrize( - "input_dict,keys_tobe_removed,msg_limit,expected", +@pytest.mark.parametrize( + ("input_dict", "keys_tobe_removed", "msg_limit", "expected"), [ ({"Name": "John", "Address": ["2", "4"], "Age": 34}, ["Address"], 20, {"Name": "John", "Age": 34}), ({"Name": "John", "Address": ["2", "4"], "Age": 34}, ["Address", "Age"], 20, {"Name": "John"}), @@ -168,6 +165,6 @@ def test_add_metric(): "eventSource": "aws:sqs", "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", "awsRegion": "us-east-2", - } - ] + }, + ], } diff --git a/application/common/types.py b/application/common/types.py index 54c637b01..06acf7a88 100644 --- a/application/common/types.py +++ b/application/common/types.py @@ -1,10 +1,10 @@ -from typing import Any, Dict, Optional, TypedDict +from typing import Any, TypedDict class HoldingQueueChangeEventItem(TypedDict): - """Represents a change event sent to the service matcher lambda via the holding queue""" + """Represents a change event sent to the service matcher lambda via the holding queue.""" - change_event: Dict[str, Any] + change_event: dict[str, Any] dynamo_record_id: str correlation_id: str sequence_number: int @@ -12,14 +12,14 @@ class HoldingQueueChangeEventItem(TypedDict): class UpdateRequest(TypedDict): - """Class to represent the update request payload""" + """Class to represent the update request payload.""" - change_event: Dict[str, Any] + change_event: dict[str, Any] service_id: str class UpdateRequestMetadata(TypedDict): - """Class to represent the update request metadata""" + """Class to represent 
the update request metadata.""" dynamo_record_id: str correlation_id: str @@ -30,17 +30,20 @@ class UpdateRequestMetadata(TypedDict): class UpdateRequestQueueItem(TypedDict): - """Class to represent the update request queue item containing the payload and metadata - Optional fields are for the health check as it does not have a payload or metadata""" + """Update Request Queue Item. - update_request: Optional[UpdateRequest] - recipient_id: Optional[str] - metadata: Optional[UpdateRequestMetadata] + Class to represent the update request queue item containing the payload and metadata + Optional fields are for the health check as it does not have a payload or metadata. + """ + + update_request: UpdateRequest | None + recipient_id: str | None + metadata: UpdateRequestMetadata | None is_health_check: bool class EmailFile(TypedDict): - """Class to represent the email file saved to S3""" + """Class to represent the email file saved to S3.""" correlation_id: str email_body: str @@ -49,7 +52,7 @@ class EmailFile(TypedDict): class EmailMessage(TypedDict): - """Class to represent the email message for the send email lambda""" + """Class to represent the email message for the send email lambda.""" change_id: str correlation_id: str diff --git a/application/common/utilities.py b/application/common/utilities.py index 4a9624262..ac854d37f 100644 --- a/application/common/utilities.py +++ b/application/common/utilities.py @@ -1,6 +1,6 @@ from json import dumps, loads from os import environ -from typing import Any, Dict, Union +from typing import Any from aws_embedded_metrics import metric_scope from aws_lambda_powertools.logging import Logger @@ -9,10 +9,11 @@ logger = Logger() -def is_val_none_or_empty(val: Any) -> bool: - """Checks if the value is None or empty +def is_val_none_or_empty(val: Any) -> bool: # noqa: ANN401 + """Checks if the value is None or empty. 
+ Args: - val Any: Value to be checked + val (Any): Value to check Returns: bool: True if the value is None or empty, False otherwise @@ -20,47 +21,62 @@ def is_val_none_or_empty(val: Any) -> bool: return not (val and not val.isspace()) -def extract_body(body: str) -> Dict[str, Any]: - """Extracts the event body from the lambda function invocation event +def extract_body(body: str) -> dict[str, Any]: + """Extracts the event body from the lambda function invocation event. Args: - message_body (str): A JSON string body + body (str): Lambda function invocation event body + Returns: Dict[str, Any]: Message body as a dictionary """ try: body = loads(body) except ValueError as e: - raise ValueError("Change Event unable to be extracted") from e + msg = "Change Event unable to be extracted" + raise ValueError(msg) from e return body -def json_str_body(body: Dict[str, Any]) -> str: - """Encode a Dict event body from the lambda function invocation event into a JSON string +def json_str_body(body: dict[str, Any]) -> str: + """Encode a Dict event body from the lambda function invocation event into a JSON string. Args: - body Dict[str, Any]: body as a dictionary + body (Dict[str, Any]): A Dict body + Returns: (str): A JSON string body """ try: return dumps(body) except ValueError as e: - raise ValueError("Dict Change Event body cannot be converted to a JSON string") from e + msg = "Dict Change Event body cannot be converted to a JSON string" + raise ValueError(msg) from e -def get_sequence_number(record: SQSRecord) -> Union[int, None]: - """Gets the sequence number from the SQS record sent by NHS UK +def get_sequence_number(record: SQSRecord) -> int | None: + """Gets the sequence number from the SQS record sent by NHS UK. + Args: record (SQSRecord): SQS record + Returns: - Optional[int]: Sequence number of the message or None if not present + Optional[int]: Sequence number of the message or None if not present. 
""" seq_num_str = record.message_attributes.get("sequence-number", {}).get("stringValue") return None if seq_num_str is None else int(seq_num_str) -def get_sqs_msg_attribute(msg_attributes: Dict[str, Any], key: str) -> Union[str, float, None]: +def get_sqs_msg_attribute(msg_attributes: dict[str, Any], key: str) -> str | float | None: + """Gets the value of the given key from the SQS message attributes. + + Args: + msg_attributes (dict[str, Any]): Message attributes + key (str): Key to get the value for + + Returns: + str | float | None: Value of the given key or None if not present. + """ attribute = msg_attributes.get(key) if attribute is None: return None @@ -69,9 +85,18 @@ def get_sqs_msg_attribute(msg_attributes: Dict[str, Any], key: str) -> Union[str return attribute.get("stringValue") if data_type == "Number": return float(attribute.get("stringValue")) + return None + +def handle_sqs_msg_attributes(msg_attributes: dict[str, Any]) -> dict[str, Any] | None: + """Extracts the error message and error message http code from the SQS message attributes. -def handle_sqs_msg_attributes(msg_attributes: Dict[str, Any]) -> Dict[str, Any]: + Args: + msg_attributes (dict[str, Any]): Message attributes + + Returns: + dict[str, Any]: Dictionary with error message and error message http code or None if not present. + """ if msg_attributes is not None: attributes = {"error_msg": "", "error_msg_http_code": ""} if "error_msg_http_code" in msg_attributes: @@ -80,16 +105,23 @@ def handle_sqs_msg_attributes(msg_attributes: Dict[str, Any]) -> Dict[str, Any]: attributes["error_msg"] = msg_attributes["error_msg"]["stringValue"] return attributes + return None + +def remove_given_keys_from_dict_by_msg_limit( + event: dict[str, Any], + keys: list, + msg_limit: int = 10000, +) -> dict[str, Any]: + """Removing given keys from the dictionary if the dictionary size is more than message limit. 
-def remove_given_keys_from_dict_by_msg_limit(event: Dict[str, Any], keys: list, msg_limit: int = 10000): - """Removing given keys from the dictionary if the dictionary size is more than message limit Args: - event Dict[str, Any]: Message body as a dictionary - keys list: keys to be removed - msg_limit int: message limit in char length + event (Dict[str, Any]): Message body as a dictionary + keys (list): List of keys to be removed from the dictionary + msg_limit (int): Message limit in characters + Returns: - Dict[str, Any]: Message body as a dictionary + Dict[str, Any]: Message body as a dictionary. """ msg_length = len(dumps(event).encode("utf-8")) if msg_length > msg_limit: @@ -98,8 +130,8 @@ def remove_given_keys_from_dict_by_msg_limit(event: Dict[str, Any], keys: list, @metric_scope -def add_metric(metric_name: str, metrics) -> None: # type: ignore - """Adds a metric to the custom metrics collection +def add_metric(metric_name: str, metrics: Any) -> None: # noqa: ANN401 + """Adds a metric to the custom metrics collection. 
Args: metric_name (str): Name of the metric to be added to CloudWatch diff --git a/application/dos_db_handler/dos_db_handler.py b/application/dos_db_handler/dos_db_handler.py index 6940c4f30..b611ae438 100644 --- a/application/dos_db_handler/dos_db_handler.py +++ b/application/dos_db_handler/dos_db_handler.py @@ -1,10 +1,10 @@ from json import dumps -from typing import Any, Dict +from typing import Any from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext -from common.dos import get_specified_opening_times_from_db, get_standard_opening_times_from_db, SpecifiedOpeningTime +from common.dos import SpecifiedOpeningTime, get_specified_opening_times_from_db, get_standard_opening_times_from_db from common.dos_db_connection import connect_to_dos_db, query_dos_db from common.middlewares import unhandled_exception_logging @@ -13,8 +13,8 @@ @unhandled_exception_logging() @logger.inject_lambda_context(clear_state=True) -def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> str: - """Entrypoint handler for the lambda +def lambda_handler(event: dict[str, Any], context: LambdaContext) -> str: # noqa: ARG001 + """Entrypoint handler for the lambda. WARNING: This lambda is for TESTING PURPOSES ONLY. It is not intended to be used in production. 
@@ -36,7 +36,7 @@ def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> str: if request["type"] == "write": # returns a single value (typically id) return dumps(result, default=str)[0][0] - elif request["type"] == "read": + elif request["type"] == "read": # noqa: RET505 # returns all values return dumps(result, default=str) elif request["type"] == "insert": @@ -45,7 +45,8 @@ def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> str: elif request["type"] == "change_event_standard_opening_times": service_id = request.get("service_id") if service_id is None: - raise ValueError("Missing service_id") + msg = "Missing service_id" + raise ValueError(msg) with connect_to_dos_db() as connection: standard_opening_times = get_standard_opening_times_from_db(connection=connection, service_id=service_id) result = standard_opening_times.export_test_format() @@ -53,20 +54,31 @@ def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> str: elif request["type"] == "change_event_specified_opening_times": service_id = request.get("service_id") if service_id is None: - raise ValueError("Missing service_id") + msg = "Missing service_id" + raise ValueError(msg) with connect_to_dos_db() as connection: specified_opening_times = get_specified_opening_times_from_db(connection=connection, service_id=service_id) result = SpecifiedOpeningTime.export_test_format_list(specified_opening_times) return result else: # add comment - raise ValueError("Unsupported request") + msg = "Unsupported request" + raise ValueError(msg) -def run_query(query, query_vars) -> list: +def run_query(query: str, query_vars: dict) -> list: + """Run a query against the database. 
+ + Args: + query (str): Query to run + query_vars (dict): Query variables + + Returns: + list: Query result + """ logger.info("Running query", extra={"query": query}) with connect_to_dos_db() as connection: - cursor = query_dos_db(connection=connection, query=query, vars=query_vars) + cursor = query_dos_db(connection=connection, query=query, query_vars=query_vars) query_result = cursor.fetchall() connection.commit() cursor.close() diff --git a/application/dos_db_update_dlq_handler/dos_db_update_dlq_handler.py b/application/dos_db_update_dlq_handler/dos_db_update_dlq_handler.py index 7834c48ee..523f8d054 100644 --- a/application/dos_db_update_dlq_handler/dos_db_update_dlq_handler.py +++ b/application/dos_db_update_dlq_handler/dos_db_update_dlq_handler.py @@ -1,7 +1,9 @@ +from typing import Any + from aws_embedded_metrics import metric_scope from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.tracing import Tracer -from aws_lambda_powertools.utilities.data_classes import event_source, SQSEvent +from aws_lambda_powertools.utilities.data_classes import SQSEvent, event_source from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext from common.constants import DLQ_HANDLER_REPORT_ID @@ -17,15 +19,17 @@ @tracer.capture_lambda_handler() @event_source(data_class=SQSEvent) @logger.inject_lambda_context( - clear_state=True, correlation_id_path='Records[0].messageAttributes."correlation-id".stringValue' + clear_state=True, + correlation_id_path='Records[0].messageAttributes."correlation-id".stringValue', ) @metric_scope -def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: - """Entrypoint handler for the lambda +def lambda_handler(event: SQSEvent, context: LambdaContext, metrics: Any) -> None: # noqa: ANN401, ARG001 + """Entrypoint handler for the lambda. 
Args: event (SQSEvent): Lambda function invocation event (list of 1 SQS Message) context (LambdaContext): Lambda function context object + metrics (Any): Embedded metrics object """ record = next(event.records) message = record.body diff --git a/application/dos_db_update_dlq_handler/tests/test_dos_db_update_dlq_handler.py b/application/dos_db_update_dlq_handler/tests/test_dos_db_update_dlq_handler.py index bf0bee28b..a66286662 100644 --- a/application/dos_db_update_dlq_handler/tests/test_dos_db_update_dlq_handler.py +++ b/application/dos_db_update_dlq_handler/tests/test_dos_db_update_dlq_handler.py @@ -2,19 +2,19 @@ from os import environ from unittest.mock import patch +import pytest from aws_embedded_metrics.logger.metrics_logger import MetricsLogger -from pytest import fixture -from ..dos_db_update_dlq_handler import lambda_handler +from application.dos_db_update_dlq_handler.dos_db_update_dlq_handler import lambda_handler FILE_PATH = "application.dos_db_update_dlq_handler.dos_db_update_dlq_handler" -@fixture +@pytest.fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - All dummy values.""" function_name: str = "dos-db-update-dlq-handler" memory_limit_in_mb: int = 128 @@ -67,8 +67,8 @@ def test_lambda_handler(mock_put_metric, mock_set_dimentions, mock_extract_body, "dataType": "String", }, }, - } - ] + }, + ], } environ["ENV"] = "test" mock_extract_body.return_value = extracted_body diff --git a/application/event_replay/event_replay.py b/application/event_replay/event_replay.py index 94ee3240d..b969bf0b9 100644 --- a/application/event_replay/event_replay.py +++ b/application/event_replay/event_replay.py @@ -1,7 +1,7 @@ from decimal import Decimal from os import getenv from time import time_ns -from typing import Any, Dict +from typing import Any from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.tracing import Tracer @@ -19,8 +19,8 @@ 
@tracer.capture_lambda_handler() @unhandled_exception_logging @logger.inject_lambda_context(clear_state=True) -def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> str: - """Entrypoint handler for the authoriser lambda +def lambda_handler(event: dict[str, Any], context: LambdaContext) -> str: # noqa: ARG001 + """Entrypoint handler for the authoriser lambda. Args: event (Dict[str, Any]): Lambda function invocation event @@ -43,18 +43,39 @@ def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> str: return dumps({"message": "The change event has been re-sent successfully", "correlation_id": correlation_id}) -def validate_event(event: Dict[str, Any]) -> None: +def validate_event(event: dict[str, Any]) -> None: + """Validate the event payload. + + Args: + event (dict[str, Any]): The event payload + """ if "odscode" not in event: - raise ValueError("Missing 'odscode' in event") + msg = "Missing 'odscode' in event" + raise ValueError(msg) if "sequence_number" not in event: - raise ValueError("Missing 'sequence_number' in event") + msg = "Missing 'sequence_number' in event" + raise ValueError(msg) + +def build_correlation_id() -> str: + """Build a correlation id for the event replay. -def build_correlation_id(): + Returns: + str: The correlation id + """ return f'{time_ns()}-{getenv("ENV")}-replayed-event' -def get_change_event(odscode: str, sequence_number: Decimal) -> Dict[str, Any]: +def get_change_event(odscode: str, sequence_number: Decimal) -> dict[str, Any]: + """Get the change event from dynamodb. 
+ + Args: + odscode (str): The ods code of the organisation + sequence_number (Decimal): The sequence number of the change event + + Returns: + dict[str, Any]: The change event + """ response = client("dynamodb").query( TableName=getenv("CHANGE_EVENTS_TABLE_NAME"), IndexName="gsi_ods_sequence", @@ -72,7 +93,8 @@ def get_change_event(odscode: str, sequence_number: Decimal) -> Dict[str, Any]: ScanIndexForward=False, ) if len(response["Items"]) == 0: - raise ValueError(f"No change event found for ods code {odscode} and sequence number {sequence_number}") + msg = f"No change event found for ods code {odscode} and sequence number {sequence_number}" + raise ValueError(msg) item = response["Items"][0] logger.info("Retrieved change event from dynamodb", extra={"item": item}) deserializer = TypeDeserializer() @@ -82,7 +104,15 @@ def get_change_event(odscode: str, sequence_number: Decimal) -> Dict[str, Any]: return change_event -def send_change_event(change_event: Dict[str, Any], odscode: str, sequence_number: int, correlation_id: str): +def send_change_event(change_event: dict[str, Any], odscode: str, sequence_number: int, correlation_id: str) -> None: + """Send the change event to the change event SQS queue. 
+ + Args: + change_event (dict[str, Any]): The change event + odscode (str): The ods code of the organisation + sequence_number (int): The sequence number of the change event + correlation_id (str): The correlation id of the event replay + """ sqs = client("sqs") queue_url = sqs.get_queue_url(QueueName=getenv("CHANGE_EVENT_SQS_NAME"))["QueueUrl"] logger.info("Sending change event to SQS", extra={"queue_url": queue_url}) diff --git a/application/event_replay/tests/test_event_replay.py b/application/event_replay/tests/test_event_replay.py index bfdda6f60..b63c2d65e 100644 --- a/application/event_replay/tests/test_event_replay.py +++ b/application/event_replay/tests/test_event_replay.py @@ -2,12 +2,12 @@ from decimal import Decimal from json import dumps from os import environ -from typing import Any, Dict +from typing import Any from unittest.mock import patch +import pytest from aws_lambda_powertools.logging import Logger from boto3.dynamodb.types import TypeSerializer -from pytest import fixture, raises from application.event_replay.event_replay import ( build_correlation_id, @@ -20,16 +20,16 @@ FILE_PATH = "application.event_replay.event_replay" -@fixture -def event() -> Dict[str, Any]: +@pytest.fixture() +def event() -> dict[str, Any]: return {"odscode": "FXXX1", "sequence_number": "1"} -@fixture +@pytest.fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - All dummy values.""" function_name: str = "event-replay" memory_limit_in_mb: int = 128 @@ -39,7 +39,7 @@ class LambdaContext: return LambdaContext() -@fixture +@pytest.fixture() def change_event(): return { "Address1": "Flat 619", @@ -72,14 +72,17 @@ def test_lambda_handler( response = lambda_handler(event, lambda_context) # Assert assert response == dumps( - {"message": "The change event has been re-sent successfully", "correlation_id": correlation_id} + {"message": "The change event has been re-sent successfully", 
"correlation_id": correlation_id}, ) mock_append_keys.assert_any_call(ods_code=event["odscode"]) mock_append_keys.assert_any_call(sequence_number=event["sequence_number"]) mock_build_correlation_id.assert_called_once_with() mock_get_change_event.assert_called_once_with(event["odscode"], Decimal(event["sequence_number"])) mock_send_change_event.assert_called_once_with( - change_event, event["odscode"], int(event["sequence_number"]), correlation_id + change_event, + event["odscode"], + int(event["sequence_number"]), + correlation_id, ) @@ -92,7 +95,7 @@ def test_validate_event_no_odscode(event): # Arrange del event["odscode"] # Act & Assert - with raises(Exception): + with pytest.raises(ValueError, match="odscode"): validate_event(event) @@ -100,7 +103,7 @@ def test_validate_event_no_sequence_number(event): # Arrange del event["sequence_number"] # Act & Assert - with raises(Exception): + with pytest.raises(ValueError, match="sequence_number"): validate_event(event) @@ -155,7 +158,7 @@ def test_get_change_event_no_change_event_in_dynamodb(mock_client, event, change environ["AWS_REGION"] = "eu-west-1" mock_client.return_value.query.return_value = {"Items": []} # Act - with raises(Exception): + with pytest.raises(ValueError, match="No change event found for ods code FXXX1 and sequence number 1"): get_change_event(event["odscode"], Decimal(event["sequence_number"])) # Assert mock_client.assert_called_with("dynamodb") diff --git a/application/ingest_change_event/change_event_validation.py b/application/ingest_change_event/change_event_validation.py index 6e441fa1b..9d10de9b7 100644 --- a/application/ingest_change_event/change_event_validation.py +++ b/application/ingest_change_event/change_event_validation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.utilities.validation import validate @@ -13,56 +13,62 @@ SERVICE_TYPES, SERVICE_TYPES_ALIAS_KEY, ) -from common.errors 
import ValidationException +from common.errors import ValidationError logger = Logger(child=True) -def validate_change_event(event: Dict[str, Any]) -> None: - """Validate event using business rules +def validate_change_event(event: dict[str, Any]) -> None: + """Validate event using business rules. + Args: - event (Dict[str, Any]): Lambda function invocation event + event (Dict[str, Any]): Lambda function invocation event. """ logger.info(f"Attempting to validate event payload: {event}") try: validate(event=event, schema=INPUT_SCHEMA) except SchemaValidationError as exception: - raise ValidationException(exception) from exception + raise ValidationError(exception) from exception validate_organisation_keys(event.get("OrganisationTypeId"), event.get("OrganisationSubType")) check_ods_code_length(event["ODSCode"], SERVICE_TYPES[event["OrganisationTypeId"]][ODSCODE_LENGTH_KEY]) logger.info("Event has been validated") def check_ods_code_length(odscode: str, odscode_length: int) -> None: - """Check ODS code length as expected, exception raise if error + """Check ODS code length as expected, exception raise if error. + Note: ods code type is checked by schema validation + Args: - odscode (str): odscode of NHS UK service + odscode (str): odscode of NHS UK service. + odscode_length (int): expected length of odscode. """ logger.debug(f"Checking ODSCode {odscode} length") if len(odscode) != odscode_length: - raise ValidationException(f"ODSCode Wrong Length, '{odscode}' is not length {odscode_length}.") + msg = f"ODSCode Wrong Length, '{odscode}' is not length {odscode_length}." + raise ValidationError(msg) def validate_organisation_keys(org_type_id: str, org_sub_type: str) -> None: - """Validate the organisation type id and organisation sub type + """Validate the organisation type id and organisation sub type. 
Args: org_type_id (str): organisation type id org_sub_type (str): organisation sub type Raises: - ValidationException: Either Org Type ID or Org Sub Type is not part of the valid list + ValidationError: Either Org Type ID or Org Sub Type is not part of the valid list """ validate_organisation_type_id(org_type_id) if org_sub_type in SERVICE_TYPES[org_type_id][ORGANISATION_SUB_TYPES_KEY]: logger.info(f"Subtype type id: {org_sub_type} validated") else: - raise ValidationException(f"Unexpected Org Sub Type ID: '{org_sub_type}'") + msg = f"Unexpected Org Sub Type ID: '{org_sub_type}'" + raise ValidationError(msg) def validate_organisation_type_id(org_type_id: str) -> None: - """Check if the organisation type id is valid + """Check if the organisation type id is valid. Args: org_type_id (str): organisation type id @@ -70,8 +76,10 @@ def validate_organisation_type_id(org_type_id: str) -> None: app_config = AppConfig("ingest-change-event") feature_flags = app_config.get_feature_flags() in_accepted_org_types: bool = feature_flags.evaluate( - name="accepted_org_types", context={"org_type": org_type_id}, default=False - ) # type: ignore + name="accepted_org_types", + context={"org_type": org_type_id}, + default=False, + ) logger.debug(f"Accepted org types: {in_accepted_org_types}") if ( org_type_id == PHARMACY_ORG_TYPE_ID @@ -86,7 +94,8 @@ def validate_organisation_type_id(org_type_id: str) -> None: ) else: logger.append_keys(in_accepted_org_types=in_accepted_org_types, app_config=app_config.get_raw_configuration()) - raise ValidationException(f"Unexpected Org Type ID: '{org_type_id}'") + msg = f"Unexpected Org Type ID: '{org_type_id}'" + raise ValidationError(msg) INPUT_SCHEMA = { diff --git a/application/ingest_change_event/ingest_change_event.py b/application/ingest_change_event/ingest_change_event.py index 1b033de80..7902a2719 100644 --- a/application/ingest_change_event/ingest_change_event.py +++ b/application/ingest_change_event/ingest_change_event.py @@ -1,11 +1,12 @@ 
from json import dumps from os import environ from time import gmtime, strftime, time_ns +from typing import Any from aws_embedded_metrics import metric_scope from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.tracing import Tracer -from aws_lambda_powertools.utilities.data_classes import event_source, SQSEvent +from aws_lambda_powertools.utilities.data_classes import SQSEvent, event_source from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext from boto3 import client @@ -29,8 +30,8 @@ correlation_id_path='Records[0].messageAttributes."correlation-id".stringValue', ) @metric_scope -def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: - """Entrypoint handler for the ingest change event lambda +def lambda_handler(event: SQSEvent, context: LambdaContext, metrics: Any) -> None: # noqa: ANN401, ARG001 + """Entrypoint handler for the ingest change event lambda. This lambda runs the change event validation, puts the change event on the dynamodb table and then sends the validated change event to the delay queue. @@ -38,12 +39,14 @@ def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: Args: event (SQSEvent): Lambda function invocation event context (LambdaContext): Lambda function context object + metrics (Any): Embedded metrics object Event: The event payload should contain an Update Request """ time_start_ns = time_ns() if len(list(event.records)) != 1: - raise ValueError(f"{len(list(event.records))} records found in event. Expected 1.") + msg = f"{len(list(event.records))} records found in event. Expected 1." 
+ raise ValueError(msg) record = next(event.records) change_event = extract_body(record.body) @@ -71,7 +74,7 @@ def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: if sequence_number is None: logger.error("No sequence number provided, so message will be ignored.") return - elif sequence_number < db_latest_sequence_number: + elif sequence_number < db_latest_sequence_number: # noqa: RET505 logger.error( "Sequence id is smaller than the existing one in db for a given odscode, so will be ignored", extra={"incoming_sequence_number": sequence_number, "db_latest_sequence_number": db_latest_sequence_number}, @@ -93,11 +96,12 @@ def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: @metric_scope -def add_change_event_received_metric(ods_code: str, metrics) -> None: # type: ignore - """Adds a success metric to the custom metrics collection +def add_change_event_received_metric(ods_code: str, metrics: Any) -> None: # noqa: ANN401 + """Adds a success metric to the custom metrics collection. 
Args: - event (UpdateRequestQueueItem): Lambda function invocation event + ods_code (str): ODS Code of the change event + metrics (Any): Embedded metrics object """ metrics.set_namespace("UEC-DOS-INT") metrics.set_property("message", f"Change Event Received for ODSCode: {ods_code}") diff --git a/application/ingest_change_event/tests/conftest.py b/application/ingest_change_event/tests/conftest.py index 7f3a5c236..7c8012b15 100644 --- a/application/ingest_change_event/tests/conftest.py +++ b/application/ingest_change_event/tests/conftest.py @@ -1,22 +1,23 @@ -from pytest import fixture +import pytest from testfixtures import LogCapture from common.tests.conftest import PHARMACY_STANDARD_EVENT, PHARMACY_STANDARD_EVENT_STAFF -@fixture() -def log_capture(): +@pytest.fixture() +def log_capture() -> LogCapture: + """LogCapture fixture for lambda functions.""" with LogCapture(names="lambda") as capture: yield capture -@fixture -def change_event(): - change_event = PHARMACY_STANDARD_EVENT.copy() - yield change_event +@pytest.fixture() +def change_event() -> dict: + """Get a standard change event.""" + return PHARMACY_STANDARD_EVENT.copy() -@fixture -def change_event_staff(): - change_event_staff = PHARMACY_STANDARD_EVENT_STAFF.copy() - yield change_event_staff +@pytest.fixture() +def change_event_staff() -> dict: + """Get a standard change event with staff.""" + return PHARMACY_STANDARD_EVENT_STAFF.copy() diff --git a/application/ingest_change_event/tests/test_change_event_validation.py b/application/ingest_change_event/tests/test_change_event_validation.py index d3c2f1178..b2a52fd95 100644 --- a/application/ingest_change_event/tests/test_change_event_validation.py +++ b/application/ingest_change_event/tests/test_change_event_validation.py @@ -1,14 +1,13 @@ from unittest.mock import MagicMock, patch import pytest -from pytest import raises -from ...ingest_change_event.change_event_validation import ( +from application.ingest_change_event.change_event_validation import ( + 
ValidationError, check_ods_code_length, validate_change_event, validate_organisation_keys, validate_organisation_type_id, - ValidationException, ) from common.constants import DENTIST_ORG_TYPE_ID, PHARMACY_ORG_TYPE_ID @@ -27,7 +26,7 @@ def test_validate_change_event_missing_key(mock_check_ods_code_length, mock_vali # Arrange del change_event["ODSCode"] # Act - with raises(ValidationException): + with pytest.raises(ValidationError): validate_change_event(change_event) # Assert mock_check_ods_code_length.assert_not_called() @@ -35,7 +34,7 @@ def test_validate_change_event_missing_key(mock_check_ods_code_length, mock_vali @pytest.mark.parametrize( - "odscode, odscode_length", + ("odscode", "odscode_length"), [ ("FXXX1", 5), ("AAAAA", 5), @@ -49,7 +48,7 @@ def test_check_ods_code_length(odscode, odscode_length): @pytest.mark.parametrize( - "odscode, odscode_length", + ("odscode", "odscode_length"), [ ("FXXX11", 5), ("AAAA", 5), @@ -59,12 +58,12 @@ def test_check_ods_code_length(odscode, odscode_length): ) def test_check_ods_code_length_incorrect_length(odscode, odscode_length): # Act & Assert - with raises(ValidationException): + with pytest.raises(ValidationError): check_ods_code_length(odscode, odscode_length) @pytest.mark.parametrize( - "org_type_id, org_sub_type", + ("org_type_id", "org_sub_type"), [ ( "Dentist", @@ -87,7 +86,7 @@ def test_validate_organisation_keys( @pytest.mark.parametrize( - "org_type_id, org_sub_type", + ("org_type_id", "org_sub_type"), [ ( "Dentist", @@ -101,12 +100,14 @@ def test_validate_organisation_keys( ) @patch(f"{FILE_PATH}.validate_organisation_type_id") def test_validate_organisation_keys_org_sub_type_id_exception( - mock_validate_organisation_type_id, org_type_id, org_sub_type + mock_validate_organisation_type_id, + org_type_id, + org_sub_type, ): # Act & Assert - with raises(ValidationException) as exception: + with pytest.raises(ValidationError) as exception: validate_organisation_keys(org_type_id, org_sub_type) - assert 
f"Unexpected Org Sub Type ID: '{org_sub_type}'" in str(exception.value) + assert f"Unexpected Org Sub Type ID: '{org_sub_type}'" in str(exception.value) @pytest.mark.parametrize("org_type_id", [PHARMACY_ORG_TYPE_ID, DENTIST_ORG_TYPE_ID]) @@ -120,7 +121,9 @@ def test_validate_organisation_type_id(mock_app_config, org_type_id): validate_organisation_type_id(org_type_id) # Assert feature_flags.evaluate.assert_called_once_with( - name="accepted_org_types", context={"org_type": org_type_id}, default=False + name="accepted_org_types", + context={"org_type": org_type_id}, + default=False, ) @@ -132,11 +135,13 @@ def test_validate_organisation_type_id_wrong_org_type_id_exception(mock_app_conf mock_app_config().get_feature_flags.return_value = feature_flags feature_flags.evaluate.return_value = False # Act - with raises(ValidationException) as exception: + with pytest.raises(ValidationError) as exception: validate_organisation_type_id(org_type_id) - assert f"Unexpected Org Type ID: '{org_type_id}'" in str(exception.value) + assert f"Unexpected Org Type ID: '{org_type_id}'" in str(exception.value) # Assert feature_flags.evaluate.assert_called_once_with( - name="accepted_org_types", context={"org_type": org_type_id}, default=False + name="accepted_org_types", + context={"org_type": org_type_id}, + default=False, ) mock_app_config().get_raw_configuration.assert_called_once_with() diff --git a/application/ingest_change_event/tests/test_ingest_change_event.py b/application/ingest_change_event/tests/test_ingest_change_event.py index ffa8e1061..dc835f4b2 100644 --- a/application/ingest_change_event/tests/test_ingest_change_event.py +++ b/application/ingest_change_event/tests/test_ingest_change_event.py @@ -3,10 +3,10 @@ from os import environ from unittest.mock import MagicMock, patch +import pytest from aws_embedded_metrics.logger.metrics_logger import MetricsLogger from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.utilities.typing import LambdaContext 
-from pytest import fixture, raises from application.common.types import HoldingQueueChangeEventItem from application.ingest_change_event.ingest_change_event import add_change_event_received_metric, lambda_handler @@ -14,34 +14,37 @@ FILE_PATH = "application.ingest_change_event.ingest_change_event" -@fixture -def mock_metric_logger(): +@pytest.fixture(autouse=True) +def _mock_metric_logger() -> None: InvocationTracker.reset() - async def flush(self): - print("flush called") + async def flush(self) -> None: InvocationTracker.record() MetricsLogger.flush = flush -class InvocationTracker(object): +class InvocationTracker: + """Tracks the number of times a function has been invoked.""" + invocations = 0 @staticmethod - def record(): + def record() -> None: + """Record an invocation.""" InvocationTracker.invocations += 1 @staticmethod - def reset(): + def reset() -> None: + """Reset the invocation count.""" InvocationTracker.invocations = 0 -@fixture +@pytest.fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - All dummy values.""" function_name: str = "ingest-change-event" memory_limit_in_mb: int = 128 @@ -91,7 +94,7 @@ def test_lambda_handler( dynamo_record_id=None, correlation_id=None, sequence_number=None, - message_received=None, # type: ignore + message_received=None, ) # Act response = lambda_handler(event, lambda_context) @@ -102,7 +105,9 @@ def test_lambda_handler( mock_validate_change_event.assert_called_once_with(change_event) mock_add_change_event_received_metric.assert_called_once_with(ods_code=change_event["ODSCode"]) mock_remove_given_keys_from_dict_by_msg_limit.assert_called_once_with( - change_event, ["Facilities", "Metrics"], 10000 + change_event, + ["Facilities", "Metrics"], + 10000, ) mock_get_latest_sequence_id_for_a_given_odscode_from_dynamodb.assert_called_once_with(change_event["ODSCode"]) mock_add_change_event_to_dynamodb.assert_called_once_with(change_event, 
sequence_number, sqs_timestamp) @@ -161,7 +166,7 @@ def test_lambda_handler_with_sensitive_staff_key( dynamo_record_id=None, correlation_id=None, sequence_number=None, - message_received=None, # type: ignore + message_received=None, ) # Act response = lambda_handler(event, lambda_context) @@ -171,7 +176,9 @@ def test_lambda_handler_with_sensitive_staff_key( mock_validate_change_event.assert_called_once_with(change_event) mock_add_change_event_received_metric.assert_called_once_with(ods_code=change_event["ODSCode"]) mock_remove_given_keys_from_dict_by_msg_limit.assert_called_once_with( - change_event, ["Facilities", "Metrics"], 10000 + change_event, + ["Facilities", "Metrics"], + 10000, ) mock_get_latest_sequence_id_for_a_given_odscode_from_dynamodb.assert_called_once_with(change_event["ODSCode"]) mock_add_change_event_to_dynamodb.assert_called_once_with(change_event, sequence_number, sqs_timestamp) @@ -234,7 +241,7 @@ def test_lambda_handler_no_sequence_number( dynamo_record_id=None, correlation_id=None, sequence_number=None, - message_received=None, # type: ignore + message_received=None, ) # Act response = lambda_handler(event, lambda_context) @@ -245,7 +252,9 @@ def test_lambda_handler_no_sequence_number( mock_validate_change_event.assert_called_once_with(change_event) mock_add_change_event_received_metric.assert_called_once_with(ods_code=change_event["ODSCode"]) mock_remove_given_keys_from_dict_by_msg_limit.assert_called_once_with( - change_event, ["Facilities", "Metrics"], 10000 + change_event, + ["Facilities", "Metrics"], + 10000, ) mock_get_latest_sequence_id_for_a_given_odscode_from_dynamodb.assert_called_once_with(change_event["ODSCode"]) mock_add_change_event_to_dynamodb.assert_called_once_with(change_event, sequence_number, sqs_timestamp) @@ -299,7 +308,7 @@ def test_lambda_handler_less_than_latest_sequence_number( dynamo_record_id=None, correlation_id=None, sequence_number=None, - message_received=None, # type: ignore + message_received=None, ) # Act 
response = lambda_handler(event, lambda_context) @@ -310,7 +319,9 @@ def test_lambda_handler_less_than_latest_sequence_number( mock_validate_change_event.assert_called_once_with(change_event) mock_add_change_event_received_metric.assert_called_once_with(ods_code=change_event["ODSCode"]) mock_remove_given_keys_from_dict_by_msg_limit.assert_called_once_with( - change_event, ["Facilities", "Metrics"], 10000 + change_event, + ["Facilities", "Metrics"], + 10000, ) mock_get_latest_sequence_id_for_a_given_odscode_from_dynamodb.assert_called_once_with(change_event["ODSCode"]) mock_add_change_event_to_dynamodb.assert_called_once_with(change_event, sequence_number, sqs_timestamp) @@ -366,10 +377,10 @@ def test_lambda_handler_mutiple_records( dynamo_record_id=None, correlation_id=None, sequence_number=None, - message_received=None, # type: ignore + message_received=None, ) # Act - with raises(ValueError): + with pytest.raises(ValueError, match="3 records found in event. Expected 1."): lambda_handler(event, lambda_context) # Assert mock_time_ns.assert_called_once() @@ -386,7 +397,7 @@ def test_lambda_handler_mutiple_records( del environ["HOLDING_QUEUE_URL"] -def test_add_change_event_received_metric(mock_metric_logger): +def test_add_change_event_received_metric(): # Arrange odscode = "V12345" environ["ENV"] = "test" @@ -419,6 +430,6 @@ def test_add_change_event_received_metric(mock_metric_logger): "eventSource": "aws:sqs", "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", "awsRegion": "us-east-2", - } - ] + }, + ], } diff --git a/application/orchestrator/orchestrator.py b/application/orchestrator/orchestrator.py index 6d2b8b7b0..2ef75ad88 100644 --- a/application/orchestrator/orchestrator.py +++ b/application/orchestrator/orchestrator.py @@ -1,7 +1,7 @@ from json import dumps from os import environ, getenv from time import gmtime, sleep, strftime, time -from typing import Any, Dict +from typing import Any from aws_lambda_powertools.logging import Logger from 
aws_lambda_powertools.tracing import Tracer @@ -23,8 +23,8 @@ @unhandled_exception_logging() @tracer.capture_lambda_handler() @logger.inject_lambda_context(clear_state=True) -def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> None: - """Entrypoint handler for the orchestrator lambda +def lambda_handler(event: dict[str, Any], context: LambdaContext) -> None: + """Entrypoint handler for the orchestrator lambda. Args: event (Dict[str, Any]): Lambda function invocation event @@ -40,10 +40,14 @@ def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> None: # Wait then continue sleep(int(environ["SLEEP_FOR_WHEN_OPEN"])) update_request_queue_item = UpdateRequestQueueItem( - is_health_check=True, update_request=None, recipient_id=None, metadata=None + is_health_check=True, + update_request=None, + recipient_id=None, + metadata=None, ) logger.info( - "Sending health check to try and re-open the circuit", extra={"request": update_request_queue_item} + "Sending health check to try and re-open the circuit", + extra={"request": update_request_queue_item}, ) invoke_lambda(update_request_queue_item) continue @@ -100,9 +104,9 @@ def lambda_handler(event: Dict[str, Any], context: LambdaContext) -> None: loop = loop + 1 -def invoke_lambda(payload: Dict[str, Any]) -> Dict[str, Any]: +def invoke_lambda(payload: dict[str, Any]) -> dict[str, Any]: return lambda_client.invoke( FunctionName=getenv("SERVICE_SYNC_FUNCTION_NAME"), InvocationType="Event", Payload=dumps(payload), - ) # type: ignore + ) diff --git a/application/orchestrator/tests/test_orchestrator.py b/application/orchestrator/tests/test_orchestrator.py index c1007c086..89eda5c0f 100644 --- a/application/orchestrator/tests/test_orchestrator.py +++ b/application/orchestrator/tests/test_orchestrator.py @@ -7,17 +7,16 @@ from pytest import approx, fixture from application.orchestrator.orchestrator import invoke_lambda, lambda_handler - from common.types import UpdateRequestMetadata, 
UpdateRequestQueueItem FILE_PATH = "application.orchestrator.orchestrator" -@fixture +@fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - All dummy values.""" function_name: str = "orchestrator" memory_limit_in_mb: int = 128 @@ -46,7 +45,9 @@ def test_invoke_lambda(mock_lambda_client: MagicMock, lambda_context: LambdaCont response = invoke_lambda(payload) # Assert mock_lambda_client.invoke.assert_called_once_with( - FunctionName="MyFirstFunction", InvocationType="Event", Payload=dumps(payload) + FunctionName="MyFirstFunction", + InvocationType="Event", + Payload=dumps(payload), ) assert response == expected del environ["SERVICE_SYNC_FUNCTION_NAME"] @@ -128,20 +129,20 @@ def test_orchestrator_circuit_closed_single_loop( "Messages": [ {"MessageAttributes": EXAMPLE_ATTRIBUTES, "Body": dumps(EXAMPLE_MESSAGE_1), "ReceiptHandle": "H1"}, {"MessageAttributes": EXAMPLE_ATTRIBUTES, "Body": dumps(EXAMPLE_MESSAGE_2), "ReceiptHandle": "H2"}, - ] + ], } # Act lambda_handler({}, lambda_context) # Assert - assert 2 == mock_invoke.call_count - assert 2 == mock_sleep.call_count + assert mock_invoke.call_count == 2 + assert mock_sleep.call_count == 2 c0_args, c0_kwargs = mock_sleep.call_args_list[0] c1_args, c1_kwargs = mock_sleep.call_args_list[1] - assert 0.4 == approx(c0_args[0]) - assert 0.3 == approx(c1_args[0]) + assert approx(c0_args[0]) == 0.4 + assert approx(c1_args[0]) == 0.3 @patch(f"{FILE_PATH}.get_circuit_is_open", return_value=False) @@ -184,12 +185,12 @@ def test_orchestrator_circuit_closed_double_loop( "Messages": [ {"MessageAttributes": EXAMPLE_ATTRIBUTES, "Body": dumps(EXAMPLE_MESSAGE_1), "ReceiptHandle": "H1"}, {"MessageAttributes": EXAMPLE_ATTRIBUTES, "Body": dumps(EXAMPLE_MESSAGE_2), "ReceiptHandle": "H2"}, - ] + ], }, { "Messages": [ {"MessageAttributes": EXAMPLE_ATTRIBUTES, "Body": dumps(EXAMPLE_MESSAGE_3), "ReceiptHandle": "H3"}, - ] + ], }, ] @@ -197,15 +198,15 @@ def 
test_orchestrator_circuit_closed_double_loop( lambda_handler({}, lambda_context) # Assert - assert 3 == mock_invoke.call_count - assert 3 == mock_sleep.call_count + assert mock_invoke.call_count == 3 + assert mock_sleep.call_count == 3 c0_args, c0_kwargs = mock_sleep.call_args_list[0] c1_args, c1_kwargs = mock_sleep.call_args_list[1] c2_args, c2_kwargs = mock_sleep.call_args_list[2] - assert 0.4 == approx(c0_args[0]) - assert 0.3 == approx(c1_args[0]) - assert 0 == approx(c2_args[0]) + assert approx(c0_args[0]) == 0.4 + assert approx(c1_args[0]) == 0.3 + assert approx(c2_args[0]) == 0 @patch(f"{FILE_PATH}.get_circuit_is_open", return_value=False) @@ -233,9 +234,9 @@ def test_orchestrator_circuit_closed_single_loop_no_messages( lambda_handler({}, lambda_context) # Assert - assert 3 == mock_time.call_count - assert 0 == mock_invoke.call_count - assert 1 == mock_sleep.call_count + assert mock_time.call_count == 3 + assert mock_invoke.call_count == 0 + assert mock_sleep.call_count == 1 mock_sleep.assert_called_once_with(1) @@ -266,8 +267,8 @@ def test_orchestrator_circuit_closed_single_loop_circuit_open( lambda_handler({}, lambda_context) # Assert - assert 3 == mock_time.call_count - assert 1 == mock_invoke.call_count - assert 1 == mock_sleep.call_count + assert mock_time.call_count == 3 + assert mock_invoke.call_count == 1 + assert mock_sleep.call_count == 1 mock_invoke.assert_called_once_with(EXPECTED_HEALTH_CHECK) mock_sleep.assert_called_once_with(5) diff --git a/application/pyproject.toml b/application/pyproject.toml index 76a5085c1..428453983 100644 --- a/application/pyproject.toml +++ b/application/pyproject.toml @@ -1,7 +1,3 @@ -[tool.bandit] -exclude_dirs = ["/tests","dos_db_handler"] -skips = ["B101"] - [tool.vulture] make_whitelist = true paths=["."] diff --git a/application/requirements-dev.txt b/application/requirements-dev.txt index 09edf334f..0c29513be 100644 --- a/application/requirements-dev.txt +++ b/application/requirements-dev.txt @@ -1,12 +1,8 @@ 
Faker aws-lambda-context -bandit boto3 -coverage -isort locust moto -mutmut pandas pytest pytest-bdd diff --git a/application/send_email/__init__.py b/application/send_email/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/application/send_email/send_email.py b/application/send_email/send_email.py index 029502072..9f5275518 100644 --- a/application/send_email/send_email.py +++ b/application/send_email/send_email.py @@ -19,8 +19,8 @@ @tracer.capture_lambda_handler() @unhandled_exception_logging_hidden_event @logger.inject_lambda_context(clear_state=True, correlation_id_path="correlation_id") -def lambda_handler(event: EmailMessage, context: LambdaContext) -> None: - """Entrypoint handler for the service_sync lambda +def lambda_handler(event: EmailMessage, context: LambdaContext) -> None: # noqa: ARG001 + """Entrypoint handler for the service_sync lambda. Args: event (EmailMessage): Lambda function invocation event @@ -37,7 +37,7 @@ def lambda_handler(event: EmailMessage, context: LambdaContext) -> None: def send_email(email_address: str, html_content: str, subject: str, correlation_id: str) -> None: - """Send an email to the specified email address + """Send an email to the specified email address. 
Args: email_address (str): Email address to send the email to @@ -45,7 +45,6 @@ def send_email(email_address: str, html_content: str, subject: str, correlation_ subject (str): Subject of the email correlation_id (str): Correlation ID of the email """ - aws_account_name = environ["AWS_ACCOUNT_NAME"] if aws_account_name != "nonprod" or "email" in correlation_id: logger.info("Preparing to send email") @@ -71,8 +70,9 @@ def send_email(email_address: str, html_content: str, subject: str, correlation_ logger.info("Sent email") smtp.quit() logger.info("Disconnected from SMTP server") - add_metric("EmailSent") # type: ignore - except BaseException as exception: - add_metric("EmailFailed") # type: ignore - logger.error("Email failed", extra={"error_name": type(exception).__name__}) - raise SMTPException("An error occurred while sending the email") + add_metric("EmailSent") + except BaseException: + add_metric("EmailFailed") + logger.exception("Email failed") + msg = "An error occurred while sending the email" + raise SMTPException(msg) from None diff --git a/application/send_email/tests/__init__.py b/application/send_email/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/application/send_email/tests/test_send_email.py b/application/send_email/tests/test_send_email.py index 4c6313853..064d91617 100644 --- a/application/send_email/tests/test_send_email.py +++ b/application/send_email/tests/test_send_email.py @@ -3,11 +3,10 @@ from smtplib import SMTPException from unittest.mock import MagicMock, patch +import pytest from aws_lambda_powertools.utilities.typing import LambdaContext -from pytest import fixture, raises from application.send_email.send_email import lambda_handler, send_email - from common.types import EmailMessage FILE_PATH = "application.send_email.send_email" @@ -28,11 +27,11 @@ ) -@fixture +@pytest.fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - 
All dummy values.""" function_name: str = "send-email" memory_limit_in_mb: int = 128 @@ -63,7 +62,10 @@ def test_lambda_handler(mock_send_email: MagicMock, lambda_context: LambdaContex @patch(f"{FILE_PATH}.SMTP") @patch(f"{FILE_PATH}.get_secret") def test_send_email( - mock_get_secret: MagicMock, mock_smtp: MagicMock, mock_mime_multipart: MagicMock, add_metric: MagicMock + mock_get_secret: MagicMock, + mock_smtp: MagicMock, + mock_mime_multipart: MagicMock, + add_metric: MagicMock, ): # Arrange environ["AWS_ACCOUNT_NAME"] = "test" @@ -129,7 +131,10 @@ def test_send_email_nonprod(mock_get_secret: MagicMock, mock_smtp: MagicMock, mo @patch(f"{FILE_PATH}.SMTP") @patch(f"{FILE_PATH}.get_secret") def test_send_email_exception( - mock_get_secret: MagicMock, mock_smtp: MagicMock, mock_mime_multipart: MagicMock, add_metric: MagicMock + mock_get_secret: MagicMock, + mock_smtp: MagicMock, + mock_mime_multipart: MagicMock, + add_metric: MagicMock, ): # Arrange environ["AWS_ACCOUNT_NAME"] = "test" @@ -144,7 +149,7 @@ def test_send_email_exception( } mock_smtp.return_value.ehlo.side_effect = SMTPException() # Act - with raises(SMTPException, match="An error occurred while sending the email"): + with pytest.raises(SMTPException, match="An error occurred while sending the email"): send_email( email_address=RECIPIENT_EMAIL_ADDRESS, html_content=EMAIL_BODY, diff --git a/application/service_matcher/requirements.txt b/application/service_matcher/requirements.txt index dd348b6a3..338fc2680 100644 --- a/application/service_matcher/requirements.txt +++ b/application/service_matcher/requirements.txt @@ -1,3 +1,4 @@ aws-embedded-metrics aws-lambda-powertools[tracer] ~= 2.1.0 psycopg[binary] +pytz diff --git a/application/service_matcher/service_matcher.py b/application/service_matcher/service_matcher.py index 14fb83231..012b61804 100644 --- a/application/service_matcher/service_matcher.py +++ b/application/service_matcher/service_matcher.py @@ -3,17 +3,18 @@ from json import dumps from 
operator import countOf from os import environ -from typing import Any, Dict, List +from typing import Any from aws_embedded_metrics import metric_scope from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.tracing import Tracer -from aws_lambda_powertools.utilities.data_classes import event_source, SQSEvent +from aws_lambda_powertools.utilities.data_classes import SQSEvent, event_source from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext from boto3 import client +from pytz import timezone from common.constants import DENTIST_ORG_TYPE_ID, PHARMACY_ORG_TYPE_ID, PHARMACY_SERVICE_TYPE_ID -from common.dos import DoSService, get_matching_dos_services, VALID_STATUS_ID +from common.dos import VALID_STATUS_ID, DoSService, get_matching_dos_services from common.middlewares import unhandled_exception_logging from common.nhs import NHSEntity from common.report_logging import ( @@ -38,13 +39,14 @@ @logger.inject_lambda_context(clear_state=True) @event_source(data_class=SQSEvent) @metric_scope -def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: - """Entrypoint handler for the service_matcher lambda +def lambda_handler(event: SQSEvent, context: LambdaContext, metrics: Any) -> None: # noqa: ANN401 + """Entrypoint handler for the service_matcher lambda. 
Args: event (SQSEvent): Lambda function invocation event (list of 1 SQS Message) Change Event has been validate by the ingest change event lambda context (LambdaContext): Lambda function context object + metrics (Any): Embedded metrics object Event: The event payload should contain a NHS Entity (Service) """ @@ -58,7 +60,10 @@ def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: metrics.set_property("function_name", context.function_name) # Get Datetime from milliseconds - date_time = datetime.fromtimestamp(holding_queue_change_event_item["message_received"] // 1000.0) + date_time = datetime.fromtimestamp( + holding_queue_change_event_item["message_received"] // 1000.0, + tz=timezone("Europe/London"), + ) metrics.set_property("message_received", date_time.strftime("%m/%d/%Y, %H:%M:%S")) nhs_entity = NHSEntity(change_event) @@ -92,11 +97,13 @@ def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: service for service in matching_services if service.typeid == PHARMACY_SERVICE_TYPE_ID ] log_unexpected_pharmacy_profiling( - matching_services=type_13_matching_services, reason="Multiple 'Pharmacy' type services found (type 13)" + matching_services=type_13_matching_services, + reason="Multiple 'Pharmacy' type services found (type 13)", ) elif countOf(dos_matching_service_types, PHARMACY_SERVICE_TYPE_ID) == 0: log_unexpected_pharmacy_profiling( - matching_services=matching_services, reason="No 'Pharmacy' type services found (type 13)" + matching_services=matching_services, + reason="No 'Pharmacy' type services found (type 13)", ) update_requests: list[UpdateRequest] = [ @@ -111,15 +118,25 @@ def lambda_handler(event: SQSEvent, context: LambdaContext, metrics) -> None: ) -def divide_chunks(to_chunk, chunk_size): +def divide_chunks(to_chunk: list, chunk_size: int) -> Any: # noqa: ANN401 + """Yield successive n-sized chunks from l.""" # looping till length l for i in range(0, len(to_chunk), chunk_size): - yield to_chunk[i : i + 
chunk_size] # noqa: E203 + yield to_chunk[i : i + chunk_size] + +def get_matching_services(nhs_entity: NHSEntity) -> list[DoSService]: + """Gets the matching DoS services for the given nhs entity. -def get_matching_services(nhs_entity: NHSEntity) -> List[DoSService]: - """Using the nhs entity attributed to this object, it finds the - matching DoS services from the db and filters the results""" + Using the nhs entity attributed to this object, it finds the + matching DoS services from the db and filters the results. + + Args: + nhs_entity (NHSEntity): The nhs entity to match against. + + Returns: + list[DoSService]: The list of matching DoS services. + """ # Check database for services with same first 5 digits of ODSCode logger.info(f"Getting matching DoS Services for odscode '{nhs_entity.odscode}'.") matching_dos_services = get_matching_dos_services(nhs_entity.odscode, nhs_entity.org_type_id) @@ -139,23 +156,26 @@ def get_matching_services(nhs_entity: NHSEntity) -> List[DoSService]: if nhs_entity.org_type_id == PHARMACY_ORG_TYPE_ID: logger.info( f"Found {len(matching_dos_services)} services in DB with " - f"matching first 5 chars of ODSCode: {matching_dos_services}" + f"matching first 5 chars of ODSCode: {matching_dos_services}", ) elif nhs_entity.org_type_id == DENTIST_ORG_TYPE_ID: logger.info(f"Found {len(matching_dos_services)} services in DB with matching ODSCode: {matching_dos_services}") logger.info( f"Found {len(matching_services)} services with typeid in " f"allowlist {valid_service_types} and status id = " - f"{VALID_STATUS_ID}: {matching_services}" + f"{VALID_STATUS_ID}: {matching_services}", ) return matching_services def send_update_requests( - update_requests: List[Dict[str, Any]], message_received: int, record_id: str, sequence_number: int + update_requests: list[dict[str, Any]], + message_received: int, + record_id: str, + sequence_number: int, ) -> None: - """Sends update request payload off to next part of workflow""" + """Sends update request 
payload off to next part of workflow.""" messages = [] for update_request in update_requests: service_id = update_request.get("service_id") @@ -193,7 +213,7 @@ def send_update_requests( "message_deduplication_id": {"DataType": "String", "StringValue": message_deduplication_id}, "message_group_id": {"DataType": "String", "StringValue": message_group_id}, }, - } + }, ) chunks = list(divide_chunks(messages, 10)) for i, chunk in enumerate(chunks): diff --git a/application/service_matcher/tests/conftest.py b/application/service_matcher/tests/conftest.py index 46b93a801..c598ce936 100644 --- a/application/service_matcher/tests/conftest.py +++ b/application/service_matcher/tests/conftest.py @@ -1,16 +1,25 @@ -from pytest import fixture +import pytest from testfixtures import LogCapture from common.tests.conftest import PHARMACY_STANDARD_EVENT -@fixture() -def log_capture(): +@pytest.fixture() +def log_capture() -> LogCapture: + """Capture logs. + + Yields: + LogCapture: Log capture + """ with LogCapture(names="lambda") as capture: yield capture -@fixture -def change_event(): - change_event = PHARMACY_STANDARD_EVENT.copy() - yield change_event +@pytest.fixture() +def change_event() -> dict: + """Generate a change event. 
+ + Returns: + dict: Change event + """ + return PHARMACY_STANDARD_EVENT.copy() diff --git a/application/service_matcher/tests/test_service_matcher.py b/application/service_matcher/tests/test_service_matcher.py index 61fc58b92..5e1d9430f 100644 --- a/application/service_matcher/tests/test_service_matcher.py +++ b/application/service_matcher/tests/test_service_matcher.py @@ -5,38 +5,36 @@ from os import environ from unittest.mock import patch +import pytest from aws_embedded_metrics.logger.metrics_logger import MetricsLogger from aws_lambda_powertools.logging import Logger -from pytest import fixture, raises from application.common.types import HoldingQueueChangeEventItem from application.service_matcher.service_matcher import get_matching_services, lambda_handler, send_update_requests - from common.nhs import NHSEntity from common.opening_times import OpenPeriod, SpecifiedOpeningTime -from common.tests.conftest import dummy_dos_service, PHARMACY_STANDARD_EVENT +from common.tests.conftest import PHARMACY_STANDARD_EVENT, dummy_dos_service FILE_PATH = "application.service_matcher.service_matcher" SERVICE_MATCHER_ENVIRONMENT_VARIABLES = ["ENV"] -@fixture -def mock_metric_logger(): +@pytest.fixture(autouse=True) +def _mock_metric_logger() -> None: InvocationTracker.reset() - async def flush(self): - print("flush called") + async def flush(self) -> None: InvocationTracker.record() MetricsLogger.flush = flush -@fixture +@pytest.fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - All dummy values.""" function_name: str = "service-matcher" memory_limit_in_mb: int = 128 @@ -110,7 +108,6 @@ def test_lambda_handler_unmatched_service( mock_get_matching_services, change_event, lambda_context, - mock_metric_logger, ): # Arrange mock_entity = NHSEntity(change_event) @@ -317,9 +314,10 @@ def test_lambda_handler_should_throw_exception_if_event_records_len_not_eq_one(l sqs_event["Records"] = [] for 
env in SERVICE_MATCHER_ENVIRONMENT_VARIABLES: environ[env] = "test" - - with raises(Exception): + # Act / Assert + with pytest.raises(StopIteration): lambda_handler(sqs_event, lambda_context) + # Clean up for env in SERVICE_MATCHER_ENVIRONMENT_VARIABLES: del environ[env] @@ -353,7 +351,12 @@ def test_send_update_requests(mock_logger, get_correlation_id_mockm, mock_sqs): "MessageDeduplicationId": f"1-{hashed_payload}", "MessageGroupId": "1", "MessageAttributes": get_message_attributes( - "1", message_received, record_id, odscode, f"1-{hashed_payload}", "1" + "1", + message_received, + record_id, + odscode, + f"1-{hashed_payload}", + "1", ), } mock_sqs.send_message_batch.assert_called_with( @@ -418,7 +421,6 @@ def test_lambda_handler_unexpected_pharmacy_profiling_multiple_type_13s( mock_log_unexpected_pharmacy_profiling, change_event, lambda_context, - mock_metric_logger, ): # Arrange mock_entity = NHSEntity(change_event) @@ -440,7 +442,8 @@ def test_lambda_handler_unexpected_pharmacy_profiling_multiple_type_13s( mock_get_matching_services.assert_called_once_with(mock_entity) mock_send_update_requests.assert_called() mock_log_unexpected_pharmacy_profiling.assert_called_once_with( - matching_services=[service, service], reason="Multiple 'Pharmacy' type services found (type 13)" + matching_services=[service, service], + reason="Multiple 'Pharmacy' type services found (type 13)", ) # Clean up for env in SERVICE_MATCHER_ENVIRONMENT_VARIABLES: @@ -464,7 +467,6 @@ def test_lambda_handler_unexpected_pharmacy_profiling_no_type_13s( mock_log_unexpected_pharmacy_profiling, change_event, lambda_context, - mock_metric_logger, ): # Arrange mock_entity = NHSEntity(change_event) @@ -486,7 +488,8 @@ def test_lambda_handler_unexpected_pharmacy_profiling_no_type_13s( mock_get_matching_services.assert_called_once_with(mock_entity) mock_send_update_requests.assert_called() mock_log_unexpected_pharmacy_profiling.assert_called_once_with( - matching_services=[service, service], reason="No 
'Pharmacy' type services found (type 13)" + matching_services=[service, service], + reason="No 'Pharmacy' type services found (type 13)", ) # Clean up for env in SERVICE_MATCHER_ENVIRONMENT_VARIABLES: @@ -520,18 +523,22 @@ def test_lambda_handler_unexpected_pharmacy_profiling_no_type_13s( "eventSource": "aws:sqs", "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", "awsRegion": "us-east-2", - } - ] + }, + ], } -class InvocationTracker(object): +class InvocationTracker: + """Tracks the number of times a function has been invoked.""" + invocations = 0 @staticmethod - def record(): + def record() -> None: + """Record an invocation.""" InvocationTracker.invocations += 1 @staticmethod - def reset(): + def reset() -> None: + """Reset the invocation count.""" InvocationTracker.invocations = 0 diff --git a/application/service_sync/changes_to_dos.py b/application/service_sync/changes_to_dos.py index 201c489d3..79a77847a 100644 --- a/application/service_sync/changes_to_dos.py +++ b/application/service_sync/changes_to_dos.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from aws_lambda_powertools.logging import Logger @@ -9,7 +9,7 @@ from common.dos import DoSService, get_valid_dos_location from common.dos_location import DoSLocation from common.nhs import NHSEntity -from common.opening_times import opening_period_times_from_list, SpecifiedOpeningTime, StandardOpeningTimes +from common.opening_times import SpecifiedOpeningTime, StandardOpeningTimes, opening_period_times_from_list from common.report_logging import log_invalid_nhsuk_postcode from common.utilities import is_val_none_or_empty @@ -18,33 +18,33 @@ @dataclass(init=True, repr=True) class ChangesToDoS: - """Class to determine if an update needs to be made to the DoS db and if so, what the update should be""" + """Class to determine if an update needs to be made to the DoS db and if so, what the update should be.""" # 
Holding data classes for use within this class dos_service: DoSService nhs_entity: NHSEntity service_histories: ServiceHistories # Variable to know if fields need to be changed - demographic_changes: Dict[Optional[str], Any] = field(default_factory=dict) - standard_opening_times_changes: Dict[Optional[int], Any] = field(default_factory=dict) + demographic_changes: dict[str | None, Any] = field(default_factory=dict) + standard_opening_times_changes: dict[int | None, Any] = field(default_factory=dict) specified_opening_times_changes: bool = False palliative_care_changes: bool = False # New value to be saved to the database - new_address: Optional[str] = None - new_postcode: Optional[str] = None - new_public_phone: Optional[str] = None - new_specified_opening_times: Optional[List[SpecifiedOpeningTime]] = None - new_website: Optional[str] = None - new_palliative_care: Optional[bool] = None + new_address: str | None = None + new_postcode: str | None = None + new_public_phone: str | None = None + new_specified_opening_times: list[SpecifiedOpeningTime] | None = None + new_website: str | None = None + new_palliative_care: bool | None = None # Existing DoS data for use building service history - current_address: Optional[str] = None - current_postcode: Optional[str] = None - current_public_phone: Optional[str] = None - current_specified_opening_times: Optional[List[SpecifiedOpeningTime]] = None - current_website: Optional[str] = None - current_palliative_care: Optional[bool] = None + current_address: str | None = None + current_postcode: str | None = None + current_public_phone: str | None = None + current_specified_opening_times: list[SpecifiedOpeningTime] | None = None + current_website: str | None = None + current_palliative_care: bool | None = None # Each day that has changed will have a current and new value in the format below # new_day_opening_times e.g. 
new_monday_opening_times @@ -52,7 +52,7 @@ class ChangesToDoS: # The type of the value is a list of OpenPeriod objects def check_for_standard_opening_times_day_changes(self, weekday: str) -> bool: - """Check if the standard opening times have changed for a specific day + """Check if the standard opening times have changed for a specific day. Args: weekday (str): The day of the week lowercase to check (e.g. "monday") @@ -66,11 +66,9 @@ def check_for_standard_opening_times_day_changes(self, weekday: str) -> bool: nhs_opening_times = nhs_standard_open_dates.get_openings(weekday.title()) if not dos_standard_open_dates.same_openings(nhs_standard_open_dates, weekday): logger.info( - ( - f"{weekday.title()} opening times not equal. " - f"dos={opening_period_times_from_list(dos_opening_times)}, " - f"nhs={opening_period_times_from_list(nhs_opening_times)}" - ) + f"{weekday.title()} opening times not equal. " + f"dos={opening_period_times_from_list(dos_opening_times)}, " + f"nhs={opening_period_times_from_list(nhs_opening_times)}", ) # Set variable for the correct day setattr(self, f"current_{weekday}_opening_times", dos_opening_times) @@ -78,17 +76,15 @@ def check_for_standard_opening_times_day_changes(self, weekday: str) -> bool: return True else: logger.info( - ( - f"{weekday.title()} opening times are equal, so no change. " - f"dos={opening_period_times_from_list(dos_opening_times)} " - f"nhs={opening_period_times_from_list(nhs_opening_times)}" - ) + f"{weekday.title()} opening times are equal, so no change. " + f"dos={opening_period_times_from_list(dos_opening_times)} " + f"nhs={opening_period_times_from_list(nhs_opening_times)}", ) return False def check_for_specified_opening_times_changes(self) -> bool: """Check if the specified opening times have changed - Also past specified opening times are removed from the comparison + Also past specified opening times are removed from the comparison. 
Returns: bool: If there are changes to the specified opening times (not valiated) @@ -116,9 +112,9 @@ def check_for_specified_opening_times_changes(self) -> bool: self.new_specified_opening_times = future_nhs_spec_open_dates return True - def check_for_address_and_postcode_for_changes(self) -> Tuple[bool, bool, Optional[DoSLocation]]: + def check_for_address_and_postcode_for_changes(self) -> tuple[bool, bool, DoSLocation | None]: """Check if address and postcode have changed between dos_service and nhs_entity, - Postcode changes are validated against the DoS locations table + Postcode changes are validated against the DoS locations table. Returns: Tuple[bool, bool]: Tuple of booleans, first is if address has changed, second is if postcode has changed, third is the DoSLocation object for the postcode @@ -166,7 +162,7 @@ def check_for_address_and_postcode_for_changes(self) -> Tuple[bool, bool, Option return not is_address_same, not is_postcode_same, valid_dos_location def check_website_for_change(self) -> bool: - """Compares the website of from the dos_service and nhs_entity""" + """Compares the website of from the dos_service and nhs_entity.""" if is_val_none_or_empty(self.nhs_entity.website) and not is_val_none_or_empty(self.dos_service.web): # Deleting the existing website self.current_website = self.dos_service.web @@ -181,7 +177,7 @@ def check_website_for_change(self) -> bool: return False def compare_and_validate_website(self, dos_service: DoSService, nhs_entity: NHSEntity, nhs_website: str) -> bool: - """Compares the website of from the dos_service and formatted nhs website + """Compares the website of from the dos_service and formatted nhs website. 
Args: dos_service (DoSService): DoSService object to compare @@ -201,7 +197,7 @@ def compare_and_validate_website(self, dos_service: DoSService, nhs_entity: NHSE return False def check_public_phone_for_change(self) -> bool: - """Compares the public phone of from the dos_service and nhs_entity + """Compares the public phone of from the dos_service and nhs_entity. Returns: bool: True if the public phone has changed, False if not @@ -212,7 +208,7 @@ def check_public_phone_for_change(self) -> bool: not is_val_none_or_empty(self.current_public_phone) or not is_val_none_or_empty(self.new_public_phone) ): logger.info( - f"Public Phone is not equal, DoS='{self.current_public_phone}' != NHS UK='{self.new_public_phone}'" + f"Public Phone is not equal, DoS='{self.current_public_phone}' != NHS UK='{self.new_public_phone}'", ) return True else: @@ -220,7 +216,7 @@ def check_public_phone_for_change(self) -> bool: return False def check_palliative_care_for_change(self) -> bool: - """Compares the palliative care of from the dos_service and nhs_entity + """Compares the palliative care of from the dos_service and nhs_entity. 
Returns: bool: True if the palliative care is different, False if not @@ -230,11 +226,11 @@ def check_palliative_care_for_change(self) -> bool: if self.current_palliative_care != self.new_palliative_care: logger.info( f"Palliative Care is not equal, DoS='{self.current_palliative_care}' " - + f"!= NHS UK='{self.new_palliative_care}'" + + f"!= NHS UK='{self.new_palliative_care}'", ) return True else: logger.info( - f"Palliative Care is equal, DoS='{self.current_palliative_care}' == NHSUK='{self.new_palliative_care}'" + f"Palliative Care is equal, DoS='{self.current_palliative_care}' == NHSUK='{self.new_palliative_care}'", ) return False diff --git a/application/service_sync/compare_data.py b/application/service_sync/compare_data.py index fc987ebcf..a6db7e871 100644 --- a/application/service_sync/compare_data.py +++ b/application/service_sync/compare_data.py @@ -24,7 +24,7 @@ ) from common.dos import DoSService from common.dos_location import DoSLocation -from common.nhs import get_palliative_care_log_value, NHSEntity, skip_if_key_is_none +from common.nhs import NHSEntity, get_palliative_care_log_value, skip_if_key_is_none from common.opening_times import DAY_IDS, WEEKDAYS from common.report_logging import log_incorrect_palliative_stockholder_type, log_palliative_care_not_equal @@ -32,7 +32,9 @@ def compare_nhs_uk_and_dos_data( - dos_service: DoSService, nhs_entity: NHSEntity, service_histories: ServiceHistories + dos_service: DoSService, + nhs_entity: NHSEntity, + service_histories: ServiceHistories, ) -> ChangesToDoS: """Compares the data of the dos_service and nhs_entity and returns a ChangesToDoS object. @@ -110,7 +112,7 @@ def compare_location_data(changes_to_dos: ChangesToDoS) -> ChangesToDoS: - longitude - town - easting - - northing + - northing. 
Args: changes_to_dos (ChangesToDoS): ChangesToDoS holder object @@ -191,7 +193,8 @@ def compare_opening_times(changes_to_dos: ChangesToDoS) -> ChangesToDoS: for weekday, dos_weekday_key, day_id in zip(WEEKDAYS, DOS_STANDARD_OPENING_TIMES_CHANGE_KEY_LIST, DAY_IDS): if changes_to_dos.check_for_standard_opening_times_day_changes(weekday=weekday): changes_to_dos.standard_opening_times_changes[day_id] = getattr( - changes_to_dos, f"new_{weekday}_opening_times" + changes_to_dos, + f"new_{weekday}_opening_times", ) changes_to_dos.service_histories.add_standard_opening_times_change( current_opening_times=changes_to_dos.dos_service.standard_opening_times, @@ -262,7 +265,7 @@ def compare_palliative_care(changes_to_dos: ChangesToDoS) -> ChangesToDoS: ChangesToDoS: ChangesToDoS holder object """ skip_palliative_care_check = skip_if_key_is_none( - changes_to_dos.nhs_entity.extract_uec_service(NHS_UK_PALLIATIVE_CARE_SERVICE_CODE) + changes_to_dos.nhs_entity.extract_uec_service(NHS_UK_PALLIATIVE_CARE_SERVICE_CODE), ) logger.debug(f"Skip palliative care check: {skip_palliative_care_check}") if ( @@ -276,14 +279,16 @@ def compare_palliative_care(changes_to_dos: ChangesToDoS) -> ChangesToDoS: dos_palliative_care=changes_to_dos.dos_service.palliative_care, ) changes_to_dos.service_histories.add_sgsdid_change( - sgsdid=DOS_PALLIATIVE_CARE_SGSDID, new_value=changes_to_dos.nhs_entity.palliative_care + sgsdid=DOS_PALLIATIVE_CARE_SGSDID, + new_value=changes_to_dos.nhs_entity.palliative_care, ) elif ( changes_to_dos.dos_service.typeid in DOS_PHARMACY_NO_PALLIATIVE_CARE_TYPES and changes_to_dos.dos_service.palliative_care is True ): nhs_uk_palliative_care = get_palliative_care_log_value( - changes_to_dos.nhs_entity.palliative_care, skip_palliative_care_check + changes_to_dos.nhs_entity.palliative_care, + skip_palliative_care_check, ) log_incorrect_palliative_stockholder_type( nhs_uk_palliative_care=nhs_uk_palliative_care, @@ -295,7 +300,8 @@ def compare_palliative_care(changes_to_dos: 
ChangesToDoS) -> ChangesToDoS: "No change / Not suitable for palliative care comparison", extra={ "nhs_uk_palliative_care": get_palliative_care_log_value( - changes_to_dos.nhs_entity.palliative_care, skip_palliative_care_check + changes_to_dos.nhs_entity.palliative_care, + skip_palliative_care_check, ), "dos_palliative_care": changes_to_dos.dos_service.palliative_care, }, diff --git a/application/service_sync/dos_data.py b/application/service_sync/dos_data.py index 82325f7db..98afdc883 100644 --- a/application/service_sync/dos_data.py +++ b/application/service_sync/dos_data.py @@ -1,10 +1,9 @@ from os import environ -from typing import Dict, List, Tuple from aws_lambda_powertools.logging import Logger from psycopg import Connection from psycopg.rows import DictRow -from psycopg.sql import Identifier, Literal, SQL +from psycopg.sql import SQL, Identifier, Literal from .changes_to_dos import ChangesToDoS from .service_histories import ServiceHistories @@ -26,12 +25,12 @@ def run_db_health_check() -> None: - """Runs a health check to ensure the db is running""" + """Runs a health check to ensure the db is running.""" try: logger.info("Running health check") with connect_to_dos_db() as connection: cursor = query_dos_db(connection=connection, query="SELECT id FROM services LIMIT 1") - response_rows: List[DictRow] = cursor.fetchall() + response_rows: list[DictRow] = cursor.fetchall() if len(response_rows) > 0: logger.info("DoS database is running") else: @@ -40,7 +39,7 @@ def run_db_health_check() -> None: return with connect_to_dos_db_replica() as connection: cursor = query_dos_db(connection=connection, query="SELECT id FROM services LIMIT 1") - response_rows: List[DictRow] = cursor.fetchall() + response_rows: list[DictRow] = cursor.fetchall() if len(response_rows) > 0: logger.info("DoS database replica is running") else: @@ -56,8 +55,8 @@ def run_db_health_check() -> None: add_metric("ServiceSyncHealthCheckFailure") -def get_dos_service_and_history(service_id: int) -> 
Tuple[DoSService, ServiceHistories]: - """Retrieves DoS Services from DoS database +def get_dos_service_and_history(service_id: int) -> tuple[DoSService, ServiceHistories]: + """Retrieves DoS Services from DoS database. Args: service_id (str): Id of service to retrieve @@ -75,8 +74,8 @@ def get_dos_service_and_history(service_id: int) -> Tuple[DoSService, ServiceHis # Connect to the DoS database with connect_to_dos_db() as connection: # Query the DoS database for the service - cursor = query_dos_db(connection=connection, query=sql_query, vars=query_vars) - rows: List[DictRow] = cursor.fetchall() + cursor = query_dos_db(connection=connection, query=sql_query, query_vars=query_vars) + rows: list[DictRow] = cursor.fetchall() if len(rows) == 1: # Select first row (service) and create DoSService object service = DoSService(rows[0]) @@ -84,15 +83,19 @@ def get_dos_service_and_history(service_id: int) -> Tuple[DoSService, ServiceHis logger.append_keys(service_uid=service.uid) logger.append_keys(type_id=service.typeid) elif not rows: - raise ValueError(f"Service ID {service_id} not found") + msg = f"Service ID {service_id} not found" + raise ValueError(msg) else: - raise ValueError(f"Multiple services found for Service Id: {service_id}") + msg = f"Multiple services found for Service Id: {service_id}" + raise ValueError(msg) # Set up remaining service data service.standard_opening_times = get_standard_opening_times_from_db( - connection=connection, service_id=service_id + connection=connection, + service_id=service_id, ) service.specified_opening_times = get_specified_opening_times_from_db( - connection=connection, service_id=service_id + connection=connection, + service_id=service_id, ) # Set up palliative care flag service.palliative_care = has_palliative_care(service=service, connection=connection) @@ -105,20 +108,21 @@ def get_dos_service_and_history(service_id: int) -> Tuple[DoSService, ServiceHis def update_dos_data(changes_to_dos: ChangesToDoS, service_id: int, 
service_histories: ServiceHistories) -> None: - """Updates the DoS database with the changes to the service + """Updates the DoS database with the changes to the service. Args: changes_to_dos (ChangesToDoS): Changes to the dos service service_id (int): Id of service to update service_histories (ServiceHistories): Service history of the service """ - connection = None try: # Save all the changes to the DoS database with a single transaction with connect_to_dos_db() as connection: is_demographic_changes: bool = save_demographics_into_db( - connection=connection, service_id=service_id, demographics_changes=changes_to_dos.demographic_changes + connection=connection, + service_id=service_id, + demographics_changes=changes_to_dos.demographic_changes, ) is_standard_opening_times_changes: bool = save_standard_opening_times_into_db( connection=connection, @@ -144,7 +148,7 @@ def update_dos_data(changes_to_dos: ChangesToDoS, service_id: int, service_histo is_standard_opening_times_changes, is_specified_opening_times_changes, is_palliative_care_changes, - ] + ], ): service_histories.save_service_histories(connection=connection) connection.commit() @@ -160,7 +164,7 @@ def update_dos_data(changes_to_dos: ChangesToDoS, service_id: int, service_histo def save_demographics_into_db(connection: Connection, service_id: int, demographics_changes: dict) -> bool: - """Saves the demographic changes to the DoS database + """Saves the demographic changes to the DoS database. 
Args: connection (connection): Connection to the DoS database @@ -185,7 +189,7 @@ def save_demographics_into_db(connection: Connection, service_id: int, demograph cursor = query_dos_db( connection=connection, query=query_str, - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) cursor.close() return True @@ -196,9 +200,11 @@ def save_demographics_into_db(connection: Connection, service_id: int, demograph def save_standard_opening_times_into_db( - connection: Connection, service_id: int, standard_opening_times_changes: Dict[int, List[OpenPeriod]] + connection: Connection, + service_id: int, + standard_opening_times_changes: dict[int, list[OpenPeriod]], ) -> bool: - """Saves the standard opening times changes to the DoS database + """Saves the standard opening times changes to the DoS database. Args: connection (connection): Connection to the DoS database @@ -217,7 +223,7 @@ def save_standard_opening_times_into_db( cursor = query_dos_db( connection=connection, query="""DELETE FROM servicedayopenings WHERE serviceid=%(SERVICE_ID)s AND dayid=%(DAY_ID)s""", - vars={"SERVICE_ID": service_id, "DAY_ID": dayid}, + query_vars={"SERVICE_ID": service_id, "DAY_ID": dayid}, ) cursor.close() if opening_periods != []: @@ -228,7 +234,7 @@ def save_standard_opening_times_into_db( """INSERT INTO servicedayopenings (serviceid, dayid) """ """VALUES (%(SERVICE_ID)s, %(DAY_ID)s) RETURNING id""" ), - vars={"SERVICE_ID": service_id, "DAY_ID": dayid}, + query_vars={"SERVICE_ID": service_id, "DAY_ID": dayid}, ) # Get the id of the newly created servicedayopenings entry by using the RETURNING clause service_day_opening_id = cursor.fetchone()["id"] @@ -243,7 +249,7 @@ def save_standard_opening_times_into_db( """INSERT INTO servicedayopeningtimes (servicedayopeningid, starttime, endtime) """ """VALUES (%(SERVICE_DAY_OPENING_ID)s, %(OPEN_PERIOD_START)s, %(OPEN_PERIOD_END)s);""" ), - vars={ + query_vars={ "SERVICE_DAY_OPENING_ID": service_day_opening_id, 
"OPEN_PERIOD_START": open_period.start, "OPEN_PERIOD_END": open_period.end, @@ -262,9 +268,9 @@ def save_specified_opening_times_into_db( connection: Connection, service_id: int, is_changes: bool, - specified_opening_times_changes: List[SpecifiedOpeningTime], + specified_opening_times_changes: list[SpecifiedOpeningTime], ) -> bool: - """Saves the specified opening times changes to the DoS database + """Saves the specified opening times changes to the DoS database. Args: connection (connection): Connection to the DoS database @@ -275,7 +281,6 @@ def save_specified_opening_times_into_db( Returns: bool: True if changes were made to the database, False if no changes were made """ - if is_changes: logger.info(f"Deleting all specified opening times for service id {service_id}") # Cascade delete the standard opening times in both @@ -283,7 +288,7 @@ def save_specified_opening_times_into_db( cursor = query_dos_db( connection=connection, query=("""DELETE FROM servicespecifiedopeningdates WHERE serviceid=%(SERVICE_ID)s """), - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) cursor.close() for specified_opening_times_day in specified_opening_times_changes: @@ -294,7 +299,7 @@ def save_specified_opening_times_into_db( """INSERT INTO servicespecifiedopeningdates (date,serviceid) """ """VALUES (%(SPECIFIED_OPENING_TIMES_DATE)s,%(SERVICE_ID)s) RETURNING id;""" ), - vars={"SPECIFIED_OPENING_TIMES_DATE": specified_opening_times_day.date, "SERVICE_ID": service_id}, + query_vars={"SPECIFIED_OPENING_TIMES_DATE": specified_opening_times_day.date, "SERVICE_ID": service_id}, ) # Get the id of the newly created servicedayopenings entry by using the RETURNING clause service_specified_opening_date_id = cursor.fetchone()["id"] @@ -304,10 +309,8 @@ def save_specified_opening_times_into_db( open_period: OpenPeriod # Type hint for the for loop for open_period in specified_opening_times_day.open_periods: logger.debug( - ( - "Saving standard opening times period for 
dayid: " - f"{specified_opening_times_day.date}, period: {open_period}" - ) + "Saving standard opening times period for dayid: " + f"{specified_opening_times_day.date}, period: {open_period}", ) cursor = query_dos_db( connection=connection, @@ -317,7 +320,7 @@ def save_specified_opening_times_into_db( """VALUES (%(OPEN_PERIOD_START)s, %(OPEN_PERIOD_END)s,""" """%(IS_CLOSED)s,%(SERVICE_SPECIFIED_OPENING_DATE_ID)s);""" ), - vars={ + query_vars={ "OPEN_PERIOD_START": open_period.start, "OPEN_PERIOD_END": open_period.end, "IS_CLOSED": not specified_opening_times_day.is_open, @@ -335,7 +338,7 @@ def save_specified_opening_times_into_db( """VALUES ('00:00:00', '00:00:00',""" """%(IS_CLOSED)s,%(SERVICE_SPECIFIED_OPENING_DATE_ID)s);""" ), - vars={ + query_vars={ "IS_CLOSED": not specified_opening_times_day.is_open, "SERVICE_SPECIFIED_OPENING_DATE_ID": service_specified_opening_date_id, }, @@ -349,9 +352,12 @@ def save_specified_opening_times_into_db( def save_palliative_care_into_db( - connection: Connection, service_id: int, is_changes: bool, palliative_care: bool + connection: Connection, + service_id: int, + is_changes: bool, + palliative_care: bool, ) -> bool: - """Saves the palliative care changes to the DoS database + """Saves the palliative care changes to the DoS database. 
Args: connection (connection): Connection to the DoS database @@ -364,7 +370,7 @@ def save_palliative_care_into_db( """ def save_palliative_care_update() -> None: - """Saves the palliative care update to the DoS database""" + """Saves the palliative care update to the DoS database.""" query_vars = { "SERVICE_ID": service_id, "SDID": DOS_PALLIATIVE_CARE_SYMPTOM_DISCRIMINATOR, @@ -378,7 +384,7 @@ def save_palliative_care_update() -> None: else: query = "DELETE FROM servicesgsds WHERE serviceid=%(SERVICE_ID)s AND sdid=%(SDID)s AND sgid=%(SGID)s;" logger.debug(f"Setting palliative care to false for service id {service_id}") - cursor = query_dos_db(connection=connection, query=query, vars=query_vars) + cursor = query_dos_db(connection=connection, query=query, query_vars=query_vars) cursor.close() logger.info( f"Saving palliative care changes for service id {service_id}", @@ -408,7 +414,7 @@ def save_palliative_care_update() -> None: def validate_dos_palliative_care_z_code_exists(connection: Connection) -> bool: - """Validates that the palliative care Z code exists in the DoS database + """Validates that the palliative care Z code exists in the DoS database. 
Args: connection (connection): Connection to the DoS database @@ -422,7 +428,7 @@ def validate_dos_palliative_care_z_code_exists(connection: Connection) -> bool: "SELECT id FROM symptomgroupsymptomdiscriminators " "WHERE symptomgroupid=%(SGID)s AND symptomdiscriminatorid=%(SDID)s;" ), - vars={"SGID": DOS_PALLIATIVE_CARE_SYMPTOM_GROUP, "SDID": DOS_PALLIATIVE_CARE_SYMPTOM_DISCRIMINATOR}, + query_vars={"SGID": DOS_PALLIATIVE_CARE_SYMPTOM_GROUP, "SDID": DOS_PALLIATIVE_CARE_SYMPTOM_DISCRIMINATOR}, ) symptom_group_symptom_discriminator_combo_rowcount = cursor.rowcount cursor.close() diff --git a/application/service_sync/format.py b/application/service_sync/format.py index 365a849fe..60cd78600 100644 --- a/application/service_sync/format.py +++ b/application/service_sync/format.py @@ -6,13 +6,15 @@ def format_address(address: str) -> str: """Formats an address line to title case and removes apostrophes. As well it replaces any '&' symbols with and. Args: - value (str): Address line to format + address (str): Address line to format Returns: str: Formatted address line """ address = sub( - r"[A-Za-z]+('[A-Za-z]+)?", lambda word: word.group(0).capitalize(), address + r"[A-Za-z]+('[A-Za-z]+)?", + lambda word: word.group(0).capitalize(), + address, ) # Capitalise first letter of each word address = address.replace("'", "") # Remove apostrophes address = address.replace("&", "and") # Replace '&' with 'and' @@ -29,7 +31,7 @@ def format_website(website: str) -> str: str: Formatted website """ nhs_uk_website = urlparse(website) - if nhs_uk_website.netloc == "": # handle website like www.test.com + if not nhs_uk_website.netloc: # handle website like www.test.com if "/" in website: nhs_uk_website = website.split("/") nhs_uk_website[0] = nhs_uk_website[0].lower() diff --git a/application/service_sync/pending_changes.py b/application/service_sync/pending_changes.py index 88caafead..9149e0efc 100644 --- a/application/service_sync/pending_changes.py +++ 
b/application/service_sync/pending_changes.py @@ -1,9 +1,8 @@ from dataclasses import dataclass from datetime import datetime -from json import dumps, JSONDecodeError, loads +from json import JSONDecodeError, dumps, loads from os import environ from time import time_ns -from typing import List, Optional from aws_lambda_powertools.logging import Logger from boto3 import client @@ -22,9 +21,9 @@ @dataclass(repr=True) class PendingChange: - """A class representing a pending change from the DoS database with useful information about the change""" + """A class representing a pending change from the DoS database with useful information about the change.""" - id: str # Id of the pending change from the change table + id: str # Id of the pending change from the change table # noqa: A003 value: str # Value of the pending change as a JSON string creatorsname: str # User name of the user who made the change email: str # Email address of the user who made the change @@ -34,15 +33,16 @@ class PendingChange: user_id: str # User id of the user who made the change def __init__(self, db_cursor_row: dict) -> None: - """Sets the attributes of this object to those found in the db row + """Sets the attributes of this object to those found in the db row. + Args: - db_cursor_row (dict): row from db as key/val pairs + db_cursor_row (dict): row from db as key/val pairs. """ for row_key, row_value in db_cursor_row.items(): setattr(self, row_key, row_value) def __repr__(self) -> str: - """Returns a string representation of this object + """Returns a string representation of this object. Returns: str: String representation of this object @@ -60,7 +60,7 @@ def __repr__(self) -> str: ) def is_valid(self) -> bool: - """Checks if the pending change is valid + """Checks if the pending change is valid. 
Returns: bool: True if the pending change is valid, False otherwise @@ -68,17 +68,17 @@ def is_valid(self) -> bool: try: value_dict = loads(self.value) changes = value_dict["new"] - is_types_valid = [True if change in DI_CHANGE_ITEMS else False for change in changes.keys()] + is_types_valid = [change in DI_CHANGE_ITEMS for change in changes] return all(is_types_valid) except Exception: logger.exception( - f"Invalid JSON at pending change {self.id}, unable to show as contains sensitive user data" + f"Invalid JSON at pending change {self.id}, unable to show as contains sensitive user data", ) return False def check_and_remove_pending_dos_changes(service_id: str) -> None: - """Checks for pending changes in DoS and removes them if they exist + """Checks for pending changes in DoS and removes them if they exist. Args: service_id (str): The ID of the service to check @@ -96,8 +96,8 @@ def check_and_remove_pending_dos_changes(service_id: str) -> None: logger.info("No valid pending changes found") -def get_pending_changes(connection: Connection, service_id: str) -> Optional[List[PendingChange]]: - """Gets pending changes for a service ID +def get_pending_changes(connection: Connection, service_id: str) -> list[PendingChange] | None: + """Gets pending changes for a service ID. 
Args: connection (connection): The connection to the DoS database @@ -113,13 +113,13 @@ def get_pending_changes(connection: Connection, service_id: str) -> Optional[Lis "WHERE serviceid=%(SERVICE_ID)s AND approvestatus='PENDING'" ) query_vars = {"SERVICE_ID": service_id} - cursor = query_dos_db(connection=connection, query=sql_query, vars=query_vars) - response_rows: List[DictRow] = cursor.fetchall() + cursor = query_dos_db(connection=connection, query=sql_query, query_vars=query_vars) + response_rows: list[DictRow] = cursor.fetchall() cursor.close() if len(response_rows) < 1: return None logger.info(f"Pending changes found for Service ID {service_id}") - pending_changes: List[PendingChange] = [] + pending_changes: list[PendingChange] = [] for row in response_rows: pending_change = PendingChange(row) logger.info(f"Pending change found: {pending_change}", extra={"pending_change": pending_change}) @@ -132,8 +132,8 @@ def get_pending_changes(connection: Connection, service_id: str) -> Optional[Lis return pending_changes -def reject_pending_changes(connection: Connection, pending_changes: List[PendingChange]) -> None: - """Rejects pending changes from the database +def reject_pending_changes(connection: Connection, pending_changes: list[PendingChange]) -> None: + """Rejects pending changes from the database. 
Args: connection (connection): The connection to the DoS database @@ -146,7 +146,7 @@ def reject_pending_changes(connection: Connection, pending_changes: List[Pending ) # SQL Injection is prevented by the query only using data from DoS DB sql_query = ( - "UPDATE changes SET approvestatus='REJECTED', " # nosec B608 + "UPDATE changes SET approvestatus='REJECTED', " # noqa: S608 "modifiedtimestamp=%(TIMESTAMP)s, modifiersname=%(USER_NAME)s" f""" WHERE {conditions}""" ) @@ -154,25 +154,28 @@ def reject_pending_changes(connection: Connection, pending_changes: List[Pending "USER_NAME": DOS_INTEGRATION_USER_NAME, "TIMESTAMP": datetime.now(timezone("Europe/London")), } - cursor = query_dos_db(connection=connection, query=sql_query, vars=query_vars) + cursor = query_dos_db(connection=connection, query=sql_query, query_vars=query_vars) cursor.close() logger.info("Rejected pending change/s", extra={"pending_changes": pending_changes}) -def log_rejected_changes(pending_changes: List[PendingChange]) -> None: - """Logs the rejected changes +def log_rejected_changes(pending_changes: list[PendingChange]) -> None: + """Logs the rejected changes. Args: pending_changes (List[PendingChange]): The pending changes to log """ for pending_change in pending_changes: ServiceUpdateLogger( - service_uid=pending_change.uid, service_name=pending_change.name, type_id=pending_change.typeid, odscode="" + service_uid=pending_change.uid, + service_name=pending_change.name, + type_id=pending_change.typeid, + odscode="", ).log_rejected_change(pending_change.id) -def send_rejection_emails(pending_changes: List[PendingChange]) -> None: - """Sends rejection emails to the users who created the pending changes +def send_rejection_emails(pending_changes: list[PendingChange]) -> None: + """Sends rejection emails to the users who created the pending changes. 
Args: pending_changes (List[PendingChange]): The pending changes to send rejection emails for @@ -211,15 +214,16 @@ def send_rejection_emails(pending_changes: List[PendingChange]) -> None: def build_change_rejection_email_contents(pending_change: PendingChange, file_name: str) -> str: - """Builds the contents of the change rejection email + """Builds the contents of the change rejection email. Args: pending_change (PendingChange): The pending change to build the email for + file_name (str): The name of the file to upload to S3 Returns: str: The contents of the email """ - with open("service_sync/rejection-email.html", "r") as email_template: + with open("service_sync/rejection-email.html") as email_template: file_contents = email_template.read() email_template.close() email_correlation_id = f"{pending_change.uid}-{time_ns()}" diff --git a/application/service_sync/service_histories.py b/application/service_sync/service_histories.py index b8e1962d8..479860414 100644 --- a/application/service_sync/service_histories.py +++ b/application/service_sync/service_histories.py @@ -2,7 +2,7 @@ from itertools import chain from json import dumps, loads from time import time -from typing import Any, List +from typing import Any from aws_lambda_powertools.logging import Logger from psycopg import Connection @@ -23,6 +23,8 @@ class ServiceHistories: + """A service to be added to the servicehistories table.""" + NEW_CHANGE_KEY: str service_history: dict[str, Any] existing_service_history: dict[str, Any] @@ -30,6 +32,11 @@ class ServiceHistories: history_already_exists: bool def __init__(self, service_id: int) -> None: + """Initialises the ServiceHistories object. + + Args: + service_id (int): The service id of the service to be added to the servicehistories table. 
+ """ # Epoch time in seconds rounded down to the nearest second self.current_epoch_time = int(time()) # Use same date/time from epoch time and format it to DoS date/time format @@ -40,7 +47,7 @@ def __init__(self, service_id: int) -> None: self.NEW_CHANGE_KEY = "new_change" def get_service_history_from_db(self, connection: Connection) -> None: - """Gets the service_histories json from the database + """Gets the service_histories json from the database. Args: connection (Connection): The connection to the database @@ -51,7 +58,7 @@ def get_service_history_from_db(self, connection: Connection) -> None: query="Select history from servicehistories where serviceid = %(SERVICE_ID)s", params={"SERVICE_ID": self.service_id}, ) - results: List[Any] = cursor.fetchall() + results: list[Any] = cursor.fetchall() if results != []: # Change History exists in the database logger.debug(f"Service history exists in the database for serviceid {self.service_id}") @@ -65,7 +72,7 @@ def get_service_history_from_db(self, connection: Connection) -> None: self.history_already_exists = False def create_service_histories_entry(self) -> None: - """Creates a new entry in the service_histories json for any changes that will be made to the service""" + """Creates a new entry in the service_histories json for any changes that will be made to the service.""" self.service_history[self.NEW_CHANGE_KEY] = { "new": {}, "initiator": {"userid": DOS_INTEGRATION_USER_NAME, "timestamp": "TBD"}, @@ -73,7 +80,7 @@ def create_service_histories_entry(self) -> None: } # Timestamp will be created when the change is sent to db for it to be realtime def add_change(self, dos_change_key: str, change: ServiceHistoriesChange) -> None: - """Adds a change to the updated service_histories json""" + """Adds a change to the updated service_histories json.""" self.service_history[self.NEW_CHANGE_KEY]["new"][dos_change_key] = change.get_change() def add_standard_opening_times_change( @@ -83,7 +90,7 @@ def 
add_standard_opening_times_change( weekday: str, dos_weekday_change_key: str, ) -> ServiceHistoriesChange: - """Adds a standard opening times change to the updated service_histories json + """Adds a standard opening times change to the updated service_histories json. Args: current_opening_times (StandardOpeningTimes): The current standard opening times @@ -115,10 +122,10 @@ def add_standard_opening_times_change( def add_specified_opening_times_change( self, - current_opening_times: List[SpecifiedOpeningTime], - new_opening_times: List[SpecifiedOpeningTime], + current_opening_times: list[SpecifiedOpeningTime], + new_opening_times: list[SpecifiedOpeningTime], ) -> ServiceHistoriesChange: - """Adds a change to the updated service_histories json + """Adds a change to the updated service_histories json. Args: current_opening_times (List[SpecifiedOpeningTime]): The current specified opening times @@ -155,7 +162,7 @@ def add_specified_opening_times_change( return change def add_sgsdid_change(self, sgsdid: str, new_value: bool) -> ServiceHistoriesChange: - """Adds a change to the updated service_histories json + """Adds a change to the updated service_histories json. Args: sgsdid (str): The sgsdid for the change @@ -180,8 +187,8 @@ def add_sgsdid_change(self, sgsdid: str, new_value: bool) -> ServiceHistoriesCha ) return change - def get_formatted_specified_opening_times(self, opening_times: List[SpecifiedOpeningTime]) -> list[str]: - """Returns the specified opening times in the format that is expected by the DoS Service History + def get_formatted_specified_opening_times(self, opening_times: list[SpecifiedOpeningTime]) -> list[str]: + """Returns the specified opening times in the format that is expected by the DoS Service History. 
Args: opening_times (List[SpecifiedOpeningTime]): The specified opening times to be formatted @@ -192,11 +199,11 @@ def get_formatted_specified_opening_times(self, opening_times: List[SpecifiedOpe # Get the opening times in the format that is expected by the DoS Service History Table opening_times = [ specified_opening_time.export_service_history_format() for specified_opening_time in opening_times - ] # type: ignore + ] return list(chain.from_iterable(opening_times)) def save_service_histories(self, connection: Connection) -> None: - """Saves the service_histories json to the database + """Saves the service_histories json to the database. Args: connection (connection): The database connection @@ -220,7 +227,7 @@ def save_service_histories(self, connection: Connection) -> None: """UPDATE services SET modifiedby=%(USER_NAME)s, """ """modifiedtime=%(CURRENT_DATE_TIME)s WHERE id = %(SERVICE_ID)s;""" ), - vars={ + query_vars={ "USER_NAME": DOS_INTEGRATION_USER_NAME, "CURRENT_DATE_TIME": current_date_time, "SERVICE_ID": self.service_id, @@ -235,7 +242,7 @@ def save_service_histories(self, connection: Connection) -> None: query=( """UPDATE servicehistories SET history = %(SERVICE_HISTORY)s WHERE serviceid = %(SERVICE_ID)s;""" ), - vars={"SERVICE_HISTORY": json_service_history, "SERVICE_ID": self.service_id}, + query_vars={"SERVICE_HISTORY": json_service_history, "SERVICE_ID": self.service_id}, log_vars=False, ) logger.info(f"Service history updated for serviceid {self.service_id}") @@ -248,7 +255,7 @@ def save_service_histories(self, connection: Connection) -> None: """INSERT INTO servicehistories (serviceid, history) """ """VALUES (%(SERVICE_ID)s, %(SERVICE_HISTORY)s);""" ), - vars={"SERVICE_ID": self.service_id, "SERVICE_HISTORY": json_service_history}, + query_vars={"SERVICE_ID": self.service_id, "SERVICE_HISTORY": json_service_history}, log_vars=False, ) cursor.close() diff --git a/application/service_sync/service_histories_change.py 
b/application/service_sync/service_histories_change.py index d125b6aee..5094222ea 100644 --- a/application/service_sync/service_histories_change.py +++ b/application/service_sync/service_histories_change.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict +from typing import Any from aws_lambda_powertools.logging import Logger @@ -17,7 +17,7 @@ @dataclass(repr=True) class ServiceHistoriesChange: - """A change to be added to the servicehistories table""" + """A change to be added to the servicehistories table.""" data: str previous_value: Any @@ -25,7 +25,22 @@ class ServiceHistoriesChange: change_action: str area: str - def __init__(self, data: Any, previous_value: Any, change_key: str, area=DOS_DEMOGRAPHICS_AREA_TYPE) -> None: + def __init__( + self, + data: Any, # noqa: ANN401 + previous_value: Any, # noqa: ANN401 + change_key: str, + area: str = DOS_DEMOGRAPHICS_AREA_TYPE, + ) -> None: + """Initialises the ServiceHistoriesChange object. + + Args: + data (Any): The data to be added to the servicehistories table. + previous_value (Any): The previous value of the data to be added to the servicehistories table. + change_key (str): The change key for the data to be added to the servicehistories table. + area (str, optional): The area of the data to be added to the servicehistories table. + Defaults to DOS_DEMOGRAPHICS_AREA_TYPE. + """ self.data = data self.previous_value = previous_value self.change_key = change_key @@ -41,10 +56,11 @@ def __init__(self, data: Any, previous_value: Any, change_key: str, area=DOS_DEM self.change_action = self.get_sgsd_change_action() else: logger.error(f"Unknown change key {self.change_key}") - raise ValueError("Unknown change key") + msg = "Unknown change key" + raise ValueError(msg) def get_demographics_change_action(self) -> str: - """Gets the change action for a demographics change + """Gets the change action for a demographics change. 
Returns: str: Change action - add, delete, modify @@ -53,13 +69,13 @@ def get_demographics_change_action(self) -> str: previous_value = self.previous_value if previous_value is None or previous_value == "None" and new_value is not None: return "add" - elif new_value is None: + elif new_value is None: # noqa: RET505 return "delete" else: return "modify" def get_sgsd_change_action(self) -> str: - """Gets the change action for a sgsd change + """Gets the change action for a sgsd change. Returns: str: Change action - add, delete @@ -69,23 +85,24 @@ def get_sgsd_change_action(self) -> str: return "add" if value == "add" else "delete" def get_opening_times_change_action(self) -> str: - """Gets the change action for a opening times (specified or standard) change + """Gets the change action for a opening times (specified or standard) change. Returns: str: Change action - add, delete, modify """ if "remove" in self.data and "add" in self.data: return "modify" - elif "remove" in self.data: + elif "remove" in self.data: # noqa: RET505 return "delete" elif "add" in self.data: return "add" else: logger.error(f"Unknown change action from {self.data}") - raise ValueError("Unknown change action") + msg = "Unknown change action" + raise ValueError(msg) - def get_change(self) -> Dict[str, Any]: - """Gets the change to be added to the servicehistories table + def get_change(self) -> dict[str, Any]: + """Gets the change to be added to the servicehistories table. 
Returns: Dict[str, Any]: Change to be added to the servicehistories table diff --git a/application/service_sync/service_sync.py b/application/service_sync/service_sync.py index a8986aa31..78b97db32 100644 --- a/application/service_sync/service_sync.py +++ b/application/service_sync/service_sync.py @@ -1,6 +1,6 @@ from os import environ from time import time_ns -from typing import Any, Dict +from typing import Any from aws_embedded_metrics import metric_scope from aws_lambda_powertools.logging import Logger @@ -24,8 +24,8 @@ @tracer.capture_lambda_handler() @unhandled_exception_logging @logger.inject_lambda_context(clear_state=True, correlation_id_path="metadata.correlation_id") -def lambda_handler(event: UpdateRequestQueueItem, context: LambdaContext) -> None: - """Entrypoint handler for the service_sync lambda +def lambda_handler(event: UpdateRequestQueueItem, context: LambdaContext) -> None: # noqa: ARG001 + """Entrypoint handler for the service_sync lambda. Args: event (UpdateRequestQueueItem): Lambda function invocation event @@ -41,7 +41,7 @@ def lambda_handler(event: UpdateRequestQueueItem, context: LambdaContext) -> Non service_id: int = event["update_request"]["service_id"] check_and_remove_pending_dos_changes(service_id) # Set up NHS UK Service - change_event: Dict[str, Any] = event["update_request"]["change_event"] + change_event: dict[str, Any] = event["update_request"]["change_event"] nhs_entity = NHSEntity(change_event) # Get current DoS state dos_service, service_histories = get_dos_service_and_history(service_id=service_id) @@ -58,16 +58,21 @@ def lambda_handler(event: UpdateRequestQueueItem, context: LambdaContext) -> Non # Delete the message from the queue remove_sqs_message_from_queue(event=event) # Log custom metrics - add_success_metric(event=event) # type: ignore + add_success_metric(event=event) add_metric("UpdateRequestSuccess") add_metric("ServiceUpdateSuccess") except Exception: put_circuit_is_open(environ["CIRCUIT"], True) - 
add_metric("UpdateRequestFailed") # type: ignore + add_metric("UpdateRequestFailed") logger.exception("Error processing change event") def set_up_logging(event: UpdateRequestQueueItem) -> None: + """Sets up the logger with the ODS code and service ID. + + Args: + event (UpdateRequestQueueItem): Lambda function invocation event + """ logger.append_keys( ods_code=event["update_request"]["change_event"].get("ODSCode"), service_id=event["update_request"]["service_id"], @@ -75,7 +80,7 @@ def set_up_logging(event: UpdateRequestQueueItem) -> None: def remove_sqs_message_from_queue(event: UpdateRequestQueueItem) -> None: - """Removes the SQS message from the queue + """Removes the SQS message from the queue. Args: event (UpdateRequestQueueItem): Lambda function invocation event @@ -86,11 +91,12 @@ def remove_sqs_message_from_queue(event: UpdateRequestQueueItem) -> None: @metric_scope -def add_success_metric(event: UpdateRequestQueueItem, metrics) -> None: # type: ignore - """Adds a success metric to the custom metrics collection +def add_success_metric(event: UpdateRequestQueueItem, metrics: Any) -> None: # noqa: ANN401 + """Adds a success metric to the custom metrics collection. 
Args: event (UpdateRequestQueueItem): Lambda function invocation event + metrics (Any): Custom metrics collection """ after = time_ns() // 1000000 metadata: UpdateRequestMetadata = event["metadata"] diff --git a/application/service_sync/service_update_logging.py b/application/service_sync/service_update_logging.py index 43cc7403b..f1f16e057 100644 --- a/application/service_sync/service_update_logging.py +++ b/application/service_sync/service_update_logging.py @@ -1,7 +1,7 @@ from itertools import chain -from logging import Formatter, INFO, Logger, StreamHandler +from logging import INFO, Formatter, Logger, StreamHandler from os import environ, getenv -from typing import Any, Dict, List, Optional, Union +from typing import Any from aws_embedded_metrics import metric_scope from aws_lambda_powertools.logging import Logger as PowerToolsLogger @@ -15,14 +15,14 @@ DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, DOS_STANDARD_OPENING_TIMES_CHANGE_KEY_LIST, ) -from common.opening_times import opening_period_times_from_list, SpecifiedOpeningTime, StandardOpeningTimes +from common.opening_times import SpecifiedOpeningTime, StandardOpeningTimes, opening_period_times_from_list from common.report_logging import log_service_updated logger = PowerToolsLogger(child=True) class ServiceUpdateLogger: - """A class to handle specfic logs to be sent to DoS Splunk""" + """A class to handle specfic logs to be sent to DoS Splunk.""" NULL_VALUE: str = "NULL" dos_basic_format = "%(asctime)s|%(levelname)s|DOS_INTEGRATION_%(environment)s|%(message)s" @@ -30,6 +30,14 @@ class ServiceUpdateLogger: logger: PowerToolsLogger def __init__(self, service_uid: str, service_name: str, type_id: str, odscode: str) -> None: + """Initialise the ServiceUpdateLogger. 
+ + Args: + service_uid (str): The service uid + service_name (str): The service name + type_id (str): The service type id + odscode (str): The service odscode + """ # Create new logger / get existing logger self.dos_logger = Logger("dos_logger") self.logger = PowerToolsLogger(child=True) @@ -49,9 +57,12 @@ def __init__(self, service_uid: str, service_name: str, type_id: str, odscode: s self.environment = getenv("ENV", "UNKNOWN").upper() def get_opening_times_change( - self, data_field_modified: str, previous_value: Optional[str], new_value: Optional[str] + self, + data_field_modified: str, + previous_value: str | None, + new_value: str | None, ) -> tuple[str, str]: - """Get the opening times change in the format required for the log message + """Get the opening times change in the format required for the log message. Args: data_field_modified (str): The dos change name for field that was modified e.g cmsopentimemonday @@ -61,11 +72,11 @@ def get_opening_times_change( Returns: tuple[str, str]: The formatted previous and new values """ - existing_value = f"{data_field_modified}_existing={previous_value}" if previous_value != "" else previous_value - if previous_value != "" and new_value != "": + existing_value = f"{data_field_modified}_existing={previous_value}" if previous_value else previous_value + if previous_value and new_value: # Modify updated_value = f"{data_field_modified}_update=remove={previous_value}add={new_value}" - elif new_value == "": + elif not new_value: # Remove updated_value = f"{data_field_modified}_update=remove={previous_value}" else: @@ -74,9 +85,13 @@ def get_opening_times_change( return existing_value, updated_value def log_service_update( - self, data_field_modified: str, action: str, previous_value: Optional[str], new_value: Optional[str] + self, + data_field_modified: str, + action: str, + previous_value: str | None, + new_value: str | None, ) -> None: - """Logs a service update to DoS Splunk + """Logs a service update to DoS Splunk. 
Args: data_field_modified (str): The dos change name for field that was modified e.g cmsurl @@ -97,7 +112,7 @@ def log_service_update( service_uid=self.service_uid, type_id=self.type_id, ) - add_service_updated_metric(data_field_modified=data_field_modified) # type: ignore + add_service_updated_metric(data_field_modified=data_field_modified) self.dos_logger.info( msg=( @@ -109,15 +124,15 @@ def log_service_update( extra={"environment": self.environment}, ) - def log_standard_opening_times_service_update_for_weekday( + def log_standard_opening_times_service_update_for_weekday( # noqa: PLR0913 self, data_field_modified: str, action: str, - previous_value: Union[StandardOpeningTimes, str], - new_value: Union[StandardOpeningTimes, str], + previous_value: StandardOpeningTimes | str, + new_value: StandardOpeningTimes | str, weekday: str, ) -> None: - """Logs a service update to DoS Splunk for a standard opening times update + """Logs a service update to DoS Splunk for a standard opening times update. 
Args: data_field_modified (str): The dos change name for field that was modified e.g cmsopentimemonday @@ -129,12 +144,12 @@ def log_standard_opening_times_service_update_for_weekday( previous_value = ( opening_period_times_from_list(open_periods=previous_value.get_openings(weekday), with_space=False) if not isinstance(previous_value, str) - else previous_value # type: ignore + else previous_value ) new_value = ( opening_period_times_from_list(open_periods=new_value.get_openings(weekday), with_space=False) if not isinstance(new_value, str) - else new_value # type: ignore + else new_value ) existing_value, updated_value = self.get_opening_times_change(data_field_modified, previous_value, new_value) @@ -148,35 +163,37 @@ def log_standard_opening_times_service_update_for_weekday( def log_specified_opening_times_service_update( self, action: str, - previous_value: Optional[List[SpecifiedOpeningTime]], - new_value: Optional[List[SpecifiedOpeningTime]], + previous_value: list[SpecifiedOpeningTime] | None, + new_value: list[SpecifiedOpeningTime] | None, ) -> None: - """Logs a service update to DoS Splunk for a specified opening times update + """Logs a service update to DoS Splunk for a specified opening times update. 
Args: action (str): The action that was performed e.g add, remove, update previous_value (Optional[List[SpecifiedOpeningTime]]): The previous value of the field or none new_value (Optional[List[SpecifiedOpeningTime]]): The new value of the field or none - """ # noqa: E501 + """ def get_and_format_specified_opening_times( - specified_opening_times: Optional[List[SpecifiedOpeningTime]], + specified_opening_times: list[SpecifiedOpeningTime] | None, ) -> str: specified_opening_times = ( [specified_opening_time.export_dos_log_format() for specified_opening_time in specified_opening_times] if specified_opening_times is not None - else "" # type: ignore + else "" ) return ( ",".join(list(chain.from_iterable(specified_opening_times))) if isinstance(specified_opening_times, list) else "" - ) # type: ignore + ) previous_value = get_and_format_specified_opening_times(previous_value) new_value = get_and_format_specified_opening_times(new_value) existing_value, updated_value = self.get_opening_times_change( - DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, previous_value, new_value + DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, + previous_value, + new_value, ) self.log_service_update( @@ -190,7 +207,7 @@ def log_rejected_change( self, change_id: str, ) -> None: - """Logs a rejected change to DoS Splunk + """Logs a rejected change to DoS Splunk. Args: change_id (str): The change id to log @@ -209,7 +226,13 @@ def log_sgsdid_service_update( self, action: str, new_value: str, - ): + ) -> None: + """Logs a service update to DoS Splunk for a sgsdid update. 
+ + Args: + action (str): The action that was performed e.g add, remove, update + new_value (str): The new value of the field + """ add_or_remove = "add" if action == "add" else "remove" self.log_service_update( data_field_modified=DOS_SGSDID_CHANGE_KEY, @@ -236,7 +259,7 @@ def log_service_updates(changes_to_dos: ChangesToDoS, service_histories: Service odscode=str(changes_to_dos.nhs_entity.odscode), ) most_recent_service_history_entry = list(service_histories.service_history.keys())[0] - service_history_changes: Dict[str, str] = service_histories.service_history[most_recent_service_history_entry][ + service_history_changes: dict[str, str] = service_histories.service_history[most_recent_service_history_entry][ "new" ] for change_key, change_values in service_history_changes.items(): @@ -258,7 +281,8 @@ def log_service_updates(changes_to_dos: ChangesToDoS, service_histories: Service ) elif change_key == DOS_SGSDID_CHANGE_KEY: service_update_logger.log_sgsdid_service_update( - action=change_values.get("changetype", "UNKOWN"), new_value=DOS_PALLIATIVE_CARE_SGSDID + action=change_values.get("changetype", "UNKOWN"), + new_value=DOS_PALLIATIVE_CARE_SGSDID, ) else: service_update_logger.log_service_update( @@ -271,7 +295,13 @@ def log_service_updates(changes_to_dos: ChangesToDoS, service_histories: Service @metric_scope -def add_service_updated_metric(data_field_modified: str, metrics: Any) -> None: +def add_service_updated_metric(data_field_modified: str, metrics: Any) -> None: # noqa: ANN401 + """Adds a metric to the service updated metric. 
+ + Args: + data_field_modified (str): The data field modified + metrics (Any): The metrics object + """ metrics.set_namespace("UEC-DOS-INT") metrics.set_property("correlation_id", logger.get_correlation_id()) metrics.put_dimensions({"ENV": environ["ENV"], "field": data_field_modified}) diff --git a/application/service_sync/tests/test_changes_to_dos.py b/application/service_sync/tests/test_changes_to_dos.py index 7ce00b57c..293dc4d2e 100644 --- a/application/service_sync/tests/test_changes_to_dos.py +++ b/application/service_sync/tests/test_changes_to_dos.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock, patch -from pytest import mark +import pytest from application.common.opening_times import WEEKDAYS from application.service_sync.changes_to_dos import ChangesToDoS @@ -34,7 +34,7 @@ def test_changes_to_dos(): assert None is changes_to_dos.current_website -@mark.parametrize("weekday", WEEKDAYS) +@pytest.mark.parametrize("weekday", WEEKDAYS) def test_changes_to_dos_check_for_standard_opening_times_day_changes(weekday: str): # Arrange dos_service = MagicMock() @@ -49,7 +49,7 @@ def test_changes_to_dos_check_for_standard_opening_times_day_changes(weekday: st assert hasattr(changes_to_dos, f"new_{weekday}_opening_times") -@mark.parametrize("weekday", WEEKDAYS) +@pytest.mark.parametrize("weekday", WEEKDAYS) def test_changes_to_dos_check_for_standard_opening_times_day_changes_no_changes(weekday: str): # Arrange dos_service = MagicMock() @@ -120,7 +120,8 @@ def test_changes_to_dos_check_for_address_and_postcode_for_changes(mock_get_vali @patch(f"{FILE_PATH}.log_invalid_nhsuk_postcode") @patch(f"{FILE_PATH}.get_valid_dos_location") def test_changes_to_dos_check_for_address_and_postcode_for_changes_postcode_invalid( - mock_get_valid_dos_location: MagicMock, mock_log_invalid_nhsuk_postcode: MagicMock + mock_get_valid_dos_location: MagicMock, + mock_log_invalid_nhsuk_postcode: MagicMock, ): # Arrange dos_service = MagicMock() @@ -144,7 +145,8 @@ def 
test_changes_to_dos_check_for_address_and_postcode_for_changes_postcode_inva @patch(f"{FILE_PATH}.is_val_none_or_empty") @patch(f"{FILE_PATH}.format_website") def test_changes_to_dos_check_website_for_change_remove_website( - mock_format_website: MagicMock, mock_is_val_none_or_empty: MagicMock + mock_format_website: MagicMock, + mock_is_val_none_or_empty: MagicMock, ): # Arrange dos_service = MagicMock() diff --git a/application/service_sync/tests/test_compare_data.py b/application/service_sync/tests/test_compare_data.py index 49dd051e6..fd88044ba 100644 --- a/application/service_sync/tests/test_compare_data.py +++ b/application/service_sync/tests/test_compare_data.py @@ -1,4 +1,4 @@ -from unittest.mock import call, MagicMock, patch +from unittest.mock import MagicMock, call, patch from application.common.constants import ( DOS_ADDRESS_CHANGE_KEY, @@ -49,7 +49,9 @@ def test_compare_nhs_uk_and_dos_data( response = compare_nhs_uk_and_dos_data(dos_service, nhs_entity, service_histories) # Assert mock_changes_to_dos.assert_called_once_with( - dos_service=dos_service, nhs_entity=nhs_entity, service_histories=service_histories + dos_service=dos_service, + nhs_entity=nhs_entity, + service_histories=service_histories, ) mock_compare_website.assert_called_once_with(changes_to_dos=mock_changes_to_dos.return_value) mock_compare_public_phone.assert_called_once_with(changes_to_dos=mock_compare_website.return_value) @@ -155,7 +157,7 @@ def test_compare_location_data(mock_set_up_for_services_table_change: MagicMock) update_service_history=False, ), call().__eq__(mock_set_up_for_services_table_change.return_value), - ] # type: ignore + ], ) @@ -246,10 +248,11 @@ def test_compare_opening_times( dos_weekday_change_key=DOS_STANDARD_OPENING_TIMES_SUNDAY_CHANGE_KEY, weekday="sunday", ), - ] + ], ) changes_to_dos.service_histories.add_specified_opening_times_change.assert_called_once_with( - current_opening_times=None, new_opening_times=None + current_opening_times=None, + 
new_opening_times=None, ) @@ -334,7 +337,8 @@ def test_set_up_for_services_table_change(mock_service_histories_change: MagicMo change_key=change_key, ) changes_to_dos.service_histories.add_change.assert_called_once_with( - dos_change_key=change_key, change=mock_service_histories_change.return_value + dos_change_key=change_key, + change=mock_service_histories_change.return_value, ) @@ -368,7 +372,8 @@ def test_set_up_for_services_table_change_no_service_history_update(mock_service @patch(f"{FILE_PATH}.log_incorrect_palliative_stockholder_type") @patch(f"{FILE_PATH}.log_palliative_care_not_equal") def test_compare_palliative_care_unequal( - mock_log_palliative_care_not_equal: MagicMock, mock_log_incorrect_palliative_stockholder_type: MagicMock + mock_log_palliative_care_not_equal: MagicMock, + mock_log_incorrect_palliative_stockholder_type: MagicMock, ): # Arrange dos_service = MagicMock() @@ -384,7 +389,8 @@ def test_compare_palliative_care_unequal( # Assert assert response == changes_to_dos mock_log_palliative_care_not_equal.assert_called_once_with( - nhs_uk_palliative_care=nhs_palliative_care, dos_palliative_care=dos_palliative_care + nhs_uk_palliative_care=nhs_palliative_care, + dos_palliative_care=dos_palliative_care, ) mock_log_incorrect_palliative_stockholder_type.assert_not_called() diff --git a/application/service_sync/tests/test_dos_data.py b/application/service_sync/tests/test_dos_data.py index 1440043a0..e736e72ce 100644 --- a/application/service_sync/tests/test_dos_data.py +++ b/application/service_sync/tests/test_dos_data.py @@ -1,9 +1,9 @@ from datetime import date, time from os import environ -from unittest.mock import call, MagicMock, patch +from unittest.mock import MagicMock, call, patch +import pytest from aws_lambda_powertools.logging import Logger -from pytest import raises from application.common.opening_times import OpenPeriod, SpecifiedOpeningTime from application.service_sync.dos_data import ( @@ -41,7 +41,7 @@ def 
test_run_db_health_check_success( run_db_health_check() # Assert mock_logger.assert_has_calls( - [call("Running health check"), call("DoS database is running"), call("DoS database replica is running")] + [call("Running health check"), call("DoS database is running"), call("DoS database replica is running")], ) mock_connect_to_dos_db.assert_called_once() mock_connect_to_dos_db_replica.assert_called_once() @@ -149,14 +149,16 @@ def test_get_dos_service_and_history( # Assert assert mock_dos_service() == dos_service mock_get_standard_opening_times_from_db.assert_called_once_with( - connection=mock_connect_to_dos_db().__enter__(), service_id=service_id + connection=mock_connect_to_dos_db().__enter__(), + service_id=service_id, ) mock_get_specified_opening_times_from_db.assert_called_once_with( - connection=mock_connect_to_dos_db().__enter__(), service_id=service_id + connection=mock_connect_to_dos_db().__enter__(), + service_id=service_id, ) assert mock_service_histories() == service_history mock_service_histories.return_value.get_service_history_from_db.assert_called_once_with( - mock_connect_to_dos_db().__enter__() + mock_connect_to_dos_db().__enter__(), ) mock_service_histories.return_value.create_service_histories_entry.assert_called_once_with() @@ -171,7 +173,7 @@ def test_get_dos_service_and_history_no_match( service_id = 12345 mock_query_dos_db.return_value.fetchall.return_value = [] # Act - with raises(ValueError, match=f"Service ID {service_id} not found"): + with pytest.raises(ValueError, match=f"Service ID {service_id} not found"): get_dos_service_and_history(service_id) mock_connect_to_dos_db.assert_called_once() @@ -186,7 +188,7 @@ def test_get_dos_service_and_history_mutiple_matches( service_id = 12345 mock_query_dos_db.return_value.fetchall.return_value = [["Test"], ["Test"]] # Act - with raises(ValueError, match=f"Multiple services found for Service Id: {service_id}"): + with pytest.raises(ValueError, match=f"Multiple services found for Service Id: 
{service_id}"): get_dos_service_and_history(service_id) mock_connect_to_dos_db.assert_called_once() @@ -304,7 +306,7 @@ def test_save_demographics_into_db(mock_query_dos_db: MagicMock, mock_sql: Magic mock_query_dos_db.assert_called_once_with( connection=mock_connection, query=query, - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) @@ -369,7 +371,8 @@ def test_save_specified_opening_times_into_db_closed(mock_query_dos_db: MagicMoc @patch(f"{FILE_PATH}.validate_dos_palliative_care_z_code_exists") @patch(f"{FILE_PATH}.query_dos_db") def test_save_palliative_care_into_db_insert( - mock_query_dos_db: MagicMock, mock_validate_dos_palliative_care_z_code_exists: MagicMock + mock_query_dos_db: MagicMock, + mock_validate_dos_palliative_care_z_code_exists: MagicMock, ): # Arrange mock_connection = MagicMock() @@ -383,14 +386,15 @@ def test_save_palliative_care_into_db_insert( mock_query_dos_db.assert_called_once_with( connection=mock_connection, query="INSERT INTO servicesgsds (serviceid, sdid, sgid) VALUES (%(SERVICE_ID)s, %(SDID)s, %(SGID)s);", - vars={"SERVICE_ID": service_id, "SDID": 14167, "SGID": 360}, + query_vars={"SERVICE_ID": service_id, "SDID": 14167, "SGID": 360}, ) @patch(f"{FILE_PATH}.validate_dos_palliative_care_z_code_exists") @patch(f"{FILE_PATH}.query_dos_db") def test_save_palliative_care_into_db_delete( - mock_query_dos_db: MagicMock, mock_validate_dos_palliative_care_z_code_exists: MagicMock + mock_query_dos_db: MagicMock, + mock_validate_dos_palliative_care_z_code_exists: MagicMock, ): # Arrange mock_connection = MagicMock() @@ -404,7 +408,7 @@ def test_save_palliative_care_into_db_delete( mock_query_dos_db.assert_called_once_with( connection=mock_connection, query="DELETE FROM servicesgsds WHERE serviceid=%(SERVICE_ID)s AND sdid=%(SDID)s AND sgid=%(SGID)s;", - vars={"SERVICE_ID": service_id, "SDID": 14167, "SGID": 360}, + query_vars={"SERVICE_ID": service_id, "SDID": 14167, "SGID": 360}, ) @@ -412,7 +416,9 @@ def 
test_save_palliative_care_into_db_delete( @patch(f"{FILE_PATH}.validate_dos_palliative_care_z_code_exists") @patch(f"{FILE_PATH}.query_dos_db") def test_save_palliative_care_into_db_no_z_code( - mock_query_dos_db: MagicMock, mock_validate_dos_palliative_care_z_code_exists: MagicMock, mock_add_metric: MagicMock + mock_query_dos_db: MagicMock, + mock_validate_dos_palliative_care_z_code_exists: MagicMock, + mock_add_metric: MagicMock, ): # Arrange mock_connection = MagicMock() @@ -431,7 +437,9 @@ def test_save_palliative_care_into_db_no_z_code( @patch(f"{FILE_PATH}.validate_dos_palliative_care_z_code_exists") @patch(f"{FILE_PATH}.query_dos_db") def test_save_palliative_care_into_db_no_change( - mock_query_dos_db: MagicMock, mock_validate_dos_palliative_care_z_code_exists: MagicMock, mock_add_metric: MagicMock + mock_query_dos_db: MagicMock, + mock_validate_dos_palliative_care_z_code_exists: MagicMock, + mock_add_metric: MagicMock, ): # Arrange mock_connection = MagicMock() @@ -463,10 +471,10 @@ def test_validate_dos_palliative_care_z_code_exists(mock_query_dos_db: MagicMock "SELECT id FROM symptomgroupsymptomdiscriminators WHERE symptomgroupid=%(SGID)s " "AND symptomdiscriminatorid=%(SDID)s;" ), - vars={"SGID": 360, "SDID": 14167}, + query_vars={"SGID": 360, "SDID": 14167}, ), call().close(), - ] + ], ) @@ -487,8 +495,8 @@ def test_validate_dos_palliative_care_z_code_exists_does_not_exist(mock_query_do "SELECT id FROM symptomgroupsymptomdiscriminators WHERE symptomgroupid=%(SGID)s " "AND symptomdiscriminatorid=%(SDID)s;" ), - vars={"SGID": 360, "SDID": 14167}, + query_vars={"SGID": 360, "SDID": 14167}, ), call().close(), - ] + ], ) diff --git a/application/service_sync/tests/test_format.py b/application/service_sync/tests/test_format.py index e13ff219d..d37157cec 100644 --- a/application/service_sync/tests/test_format.py +++ b/application/service_sync/tests/test_format.py @@ -1,10 +1,10 @@ -from pytest import mark +import pytest from application.service_sync.format 
import format_address, format_website -@mark.parametrize( - "address, formatted_address", +@pytest.mark.parametrize( + ("address", "formatted_address"), [ ("3rd Floor", "3Rd Floor"), ("24 Hour Road", "24 Hour Road"), @@ -29,8 +29,8 @@ def test_format_address(address: str, formatted_address: str): assert formatted_address == format_address(address) -@mark.parametrize( - "website, formatted_website", +@pytest.mark.parametrize( + ("website", "formatted_website"), [ ("www.test.com", "www.test.com"), ("www.test.com", "www.test.com"), diff --git a/application/service_sync/tests/test_pending_changes.py b/application/service_sync/tests/test_pending_changes.py index e83681df9..25e8519f4 100644 --- a/application/service_sync/tests/test_pending_changes.py +++ b/application/service_sync/tests/test_pending_changes.py @@ -1,17 +1,17 @@ from json import dumps from os import environ from random import choices -from unittest.mock import call, MagicMock, patch +from unittest.mock import MagicMock, call, patch -from pytest import CaptureFixture +import pytest from pytz import timezone from application.service_sync.pending_changes import ( + PendingChange, build_change_rejection_email_contents, check_and_remove_pending_dos_changes, get_pending_changes, log_rejected_changes, - PendingChange, reject_pending_changes, send_rejection_emails, ) @@ -118,7 +118,8 @@ def test_check_and_remove_pending_dos_changes( assert None is response mock_connect_to_dos_db.assert_called_once() mock_get_pending_changes.assert_called_once_with( - connection=mock_connect_to_dos_db.return_value.__enter__.return_value, service_id=service_id + connection=mock_connect_to_dos_db.return_value.__enter__.return_value, + service_id=service_id, ) mock_reject_pending_changes.assert_called_once_with( connection=mock_connect_to_dos_db.return_value.__enter__.return_value, @@ -149,7 +150,8 @@ def test_check_and_remove_pending_dos_changes_no_pending_changes( assert None is response mock_connect_to_dos_db.assert_called_once() 
mock_get_pending_changes.assert_called_once_with( - connection=mock_connect_to_dos_db.return_value.__enter__.return_value, service_id=service_id + connection=mock_connect_to_dos_db.return_value.__enter__.return_value, + service_id=service_id, ) mock_reject_pending_changes.assert_not_called() mock_log_rejected_changes.assert_not_called() @@ -177,7 +179,8 @@ def test_check_and_remove_pending_dos_changes_invalid_changes( assert None is response mock_connect_to_dos_db.assert_called_once() mock_get_pending_changes.assert_called_once_with( - connection=mock_connect_to_dos_db.return_value.__enter__.return_value, service_id=service_id + connection=mock_connect_to_dos_db.return_value.__enter__.return_value, + service_id=service_id, ) mock_reject_pending_changes.assert_not_called() mock_log_rejected_changes.assert_not_called() @@ -188,7 +191,9 @@ def test_check_and_remove_pending_dos_changes_invalid_changes( @patch(f"{FILE_PATH}.PendingChange.is_valid") @patch(f"{FILE_PATH}.query_dos_db") def test_get_pending_changes_is_pending_changes_valid_changes( - mock_query_dos_db: MagicMock, mock_is_valid: MagicMock, mock_repr: MagicMock + mock_query_dos_db: MagicMock, + mock_is_valid: MagicMock, + mock_repr: MagicMock, ): # Arrange connection = MagicMock() @@ -202,7 +207,7 @@ def test_get_pending_changes_is_pending_changes_valid_changes( mock_query_dos_db.assert_called_once_with( connection=connection, query=EXPECTED_QUERY, - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) assert mock_repr.call_count == 2 mock_is_valid.assert_called_once() @@ -213,7 +218,9 @@ def test_get_pending_changes_is_pending_changes_valid_changes( @patch(f"{FILE_PATH}.PendingChange.is_valid") @patch(f"{FILE_PATH}.query_dos_db") def test_get_pending_changes_is_pending_changes_invalid_changes( - mock_query_dos_db: MagicMock, mock_is_valid: MagicMock, mock_repr: MagicMock + mock_query_dos_db: MagicMock, + mock_is_valid: MagicMock, + mock_repr: MagicMock, ): # Arrange connection = 
MagicMock() @@ -227,7 +234,7 @@ def test_get_pending_changes_is_pending_changes_invalid_changes( mock_query_dos_db.assert_called_once_with( connection=connection, query=EXPECTED_QUERY, - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) assert mock_repr.call_count == 3 mock_is_valid.assert_called_once() @@ -248,7 +255,7 @@ def test_get_pending_changes_no_changes(mock_query_dos_db: MagicMock, mock_is_va mock_query_dos_db.assert_called_once_with( connection=connection, query=EXPECTED_QUERY, - vars={"SERVICE_ID": service_id}, + query_vars={"SERVICE_ID": service_id}, ) mock_is_valid.assert_not_called() assert None is response @@ -272,7 +279,7 @@ def test_reject_pending_changes_single_rejection(mock_query_dos_db: MagicMock, m "UPDATE changes SET approvestatus='REJECTED', " f"modifiedtimestamp=%(TIMESTAMP)s, modifiersname=%(USER_NAME)s WHERE id='{pending_change.id}'" ), - vars={"USER_NAME": "DOS_INTEGRATION", "TIMESTAMP": mock_datetime.now.return_value}, + query_vars={"USER_NAME": "DOS_INTEGRATION", "TIMESTAMP": mock_datetime.now.return_value}, ) @@ -300,11 +307,11 @@ def test_reject_pending_changes_multiple_rejections(mock_query_dos_db: MagicMock f"modifiedtimestamp=%(TIMESTAMP)s, modifiersname=%(USER_NAME)s " f"WHERE id in ('{pending_change1.id}','{pending_change2.id}','{pending_change3.id}')" ), - vars={"USER_NAME": "DOS_INTEGRATION", "TIMESTAMP": mock_datetime.now.return_value}, + query_vars={"USER_NAME": "DOS_INTEGRATION", "TIMESTAMP": mock_datetime.now.return_value}, ) -def test_log_rejected_changes(capsys: CaptureFixture): +def test_log_rejected_changes(capsys: pytest.CaptureFixture): # Arrange pending_change = PendingChange(ROW) pending_changes = [pending_change] @@ -355,10 +362,10 @@ def test_send_rejection_emails( "user_id": pending_change.user_id, "email_body": mock_build_change_rejection_email_contents.return_value, "email_subject": expected_subject, - } + }, ), call(mock_email_message.return_value), - ] + ], ) 
mock_put_content_to_s3.assert_called_once_with( content=mock_dumps.return_value, diff --git a/application/service_sync/tests/test_service_histories.py b/application/service_sync/tests/test_service_histories.py index ec00ad913..08be15bbb 100644 --- a/application/service_sync/tests/test_service_histories.py +++ b/application/service_sync/tests/test_service_histories.py @@ -10,10 +10,9 @@ DOS_STANDARD_OPENING_TIMES_MONDAY_CHANGE_KEY, ) from application.common.opening_times import OpenPeriod, SpecifiedOpeningTime +from application.service_sync.service_histories import ServiceHistories from application.service_sync.service_histories_change import ServiceHistoriesChange -from ..service_histories import ServiceHistories - FILE_PATH = "application.service_sync.service_histories" SERVICE_ID = 1 @@ -28,7 +27,7 @@ def test_service_histories(mock_time: MagicMock): assert "new_change" == service_histories.NEW_CHANGE_KEY assert {} == service_histories.service_history assert {} == service_histories.existing_service_history - assert SERVICE_ID == service_histories.service_id + assert service_histories.service_id == SERVICE_ID assert time == service_histories.current_epoch_time mock_time.assert_called_once() @@ -47,7 +46,8 @@ def test_service_histories_get_service_history_from_db_rows_returned(): assert change == service_history.existing_service_history mock_connection.cursor.assert_called_once_with(row_factory=dict_row) mock_connection.cursor.return_value.execute.assert_called_once_with( - query="Select history from servicehistories where serviceid = %(SERVICE_ID)s", params={"SERVICE_ID": SERVICE_ID} + query="Select history from servicehistories where serviceid = %(SERVICE_ID)s", + params={"SERVICE_ID": SERVICE_ID}, ) mock_connection.cursor.return_value.fetchall.assert_called_once() @@ -64,7 +64,8 @@ def test_service_histories_get_service_history_from_db_no_rows_returned(): assert {} == service_history.existing_service_history 
mock_connection.cursor.assert_called_once_with(row_factory=dict_row) mock_connection.cursor.return_value.execute.assert_called_once_with( - query="Select history from servicehistories where serviceid = %(SERVICE_ID)s", params={"SERVICE_ID": SERVICE_ID} + query="Select history from servicehistories where serviceid = %(SERVICE_ID)s", + params={"SERVICE_ID": SERVICE_ID}, ) mock_connection.cursor.return_value.fetchall.assert_called_once() @@ -81,7 +82,7 @@ def test_service_histories_create_service_histories_entry_no_history_already_exi "new": {}, "initiator": {"userid": "DOS_INTEGRATION", "timestamp": "TBD"}, "approver": {"userid": "DOS_INTEGRATION", "timestamp": "TBD"}, - } + }, } == service_history.service_history @@ -100,7 +101,7 @@ def test_service_histories_add_change(): "new": {change_key: change}, "initiator": {"userid": "DOS_INTEGRATION", "timestamp": "TBD"}, "approver": {"userid": "DOS_INTEGRATION", "timestamp": "TBD"}, - } + }, } == service_history.service_history mock_service_history_change.get_change.assert_called_once_with() @@ -112,7 +113,7 @@ def test_service_histories_add_standard_opening_times_change(mock_service_histor service_history.add_change = mock_add_change = MagicMock() current_opening_times = MagicMock() current_opening_times.export_opening_times_in_seconds_for_day.return_value = current_opening_times_in_seconds = [ - "456-789" + "456-789", ] current_opening_times.export_opening_times_for_day.return_value = ( current_opening_times_for_day @@ -130,7 +131,8 @@ def test_service_histories_add_standard_opening_times_change(mock_service_histor ) # Assert mock_add_change.assert_called_once_with( - dos_change_key=DOS_STANDARD_OPENING_TIMES_MONDAY_CHANGE_KEY, change=mock_service_histories_change_variable + dos_change_key=DOS_STANDARD_OPENING_TIMES_MONDAY_CHANGE_KEY, + change=mock_service_histories_change_variable, ) current_opening_times.export_opening_times_in_seconds_for_day.assert_called_once_with(weekday) 
new_opening_times.export_opening_times_in_seconds_for_day.assert_called_once_with(weekday) @@ -165,7 +167,8 @@ def test_service_histories_add_standard_opening_times_change_no_change(mock_serv ) # Assert mock_add_change.assert_called_once_with( - dos_change_key=DOS_STANDARD_OPENING_TIMES_MONDAY_CHANGE_KEY, change=mock_service_histories_change_variable + dos_change_key=DOS_STANDARD_OPENING_TIMES_MONDAY_CHANGE_KEY, + change=mock_service_histories_change_variable, ) current_opening_times.export_opening_times_in_seconds_for_day.assert_called_once_with(weekday) new_opening_times.export_opening_times_in_seconds_for_day.assert_called_once_with(weekday) @@ -180,7 +183,8 @@ def test_service_histories_add_standard_opening_times_change_no_change(mock_serv @patch(f"{FILE_PATH}.ServiceHistories.get_formatted_specified_opening_times") @patch(f"{FILE_PATH}.ServiceHistoriesChange") def test_service_histories_add_specified_opening_times_change_modify( - mock_service_histories_change: MagicMock, mock_get_formatted_specified_opening_times: MagicMock + mock_service_histories_change: MagicMock, + mock_get_formatted_specified_opening_times: MagicMock, ): # Arrange service_history = ServiceHistories(service_id=SERVICE_ID) @@ -206,7 +210,8 @@ def test_service_histories_add_specified_opening_times_change_modify( service_history.add_specified_opening_times_change(current_opening_times, new_opening_times) # Assert mock_add_change.assert_called_once_with( - dos_change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, change=mock_service_histories_change_variable + dos_change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, + change=mock_service_histories_change_variable, ) mock_service_histories_change.assert_called_once_with( change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, @@ -218,7 +223,8 @@ def test_service_histories_add_specified_opening_times_change_modify( @patch(f"{FILE_PATH}.ServiceHistories.get_formatted_specified_opening_times") @patch(f"{FILE_PATH}.ServiceHistoriesChange") def 
test_service_histories_add_specified_opening_times_change_add( - mock_service_histories_change: MagicMock, mock_get_formatted_specified_opening_times: MagicMock + mock_service_histories_change: MagicMock, + mock_get_formatted_specified_opening_times: MagicMock, ): # Arrange service_history = ServiceHistories(service_id=SERVICE_ID) @@ -239,7 +245,8 @@ def test_service_histories_add_specified_opening_times_change_add( service_history.add_specified_opening_times_change(current_opening_times, new_opening_times) # Assert mock_add_change.assert_called_once_with( - dos_change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, change=mock_service_histories_change_variable + dos_change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, + change=mock_service_histories_change_variable, ) mock_service_histories_change.assert_called_once_with( change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, @@ -251,7 +258,8 @@ def test_service_histories_add_specified_opening_times_change_add( @patch(f"{FILE_PATH}.ServiceHistories.get_formatted_specified_opening_times") @patch(f"{FILE_PATH}.ServiceHistoriesChange") def test_service_histories_add_specified_opening_times_change_remove( - mock_service_histories_change: MagicMock, mock_get_formatted_specified_opening_times: MagicMock + mock_service_histories_change: MagicMock, + mock_get_formatted_specified_opening_times: MagicMock, ): # Arrange service_history = ServiceHistories(service_id=SERVICE_ID) @@ -272,7 +280,8 @@ def test_service_histories_add_specified_opening_times_change_remove( service_history.add_specified_opening_times_change(current_opening_times, new_opening_times) # Assert mock_add_change.assert_called_once_with( - dos_change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, change=mock_service_histories_change_variable + dos_change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, + change=mock_service_histories_change_variable, ) mock_service_histories_change.assert_called_once_with( change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, @@ -286,7 +295,8 
@@ def test_service_histories_add_specified_opening_times_change_remove( @patch(f"{FILE_PATH}.ServiceHistories.get_formatted_specified_opening_times") @patch(f"{FILE_PATH}.ServiceHistoriesChange") def test_service_histories_add_specified_opening_times_change_no_change( - mock_service_histories_change: MagicMock, mock_get_formatted_specified_opening_times: MagicMock + mock_service_histories_change: MagicMock, + mock_get_formatted_specified_opening_times: MagicMock, ): # Arrange service_history = ServiceHistories(service_id=SERVICE_ID) @@ -299,7 +309,8 @@ def test_service_histories_add_specified_opening_times_change_no_change( service_history.add_specified_opening_times_change(current_opening_times, new_opening_times) # Assert mock_add_change.assert_called_once_with( - dos_change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, change=mock_service_histories_change_variable + dos_change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, + change=mock_service_histories_change_variable, ) mock_service_histories_change.assert_called_once_with( change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, @@ -338,7 +349,7 @@ def test_service_histories_get_formatted_specified_opening_times(): specified_opening_times = [SpecifiedOpeningTime(open_periods, date(2022, 12, 26), True)] # Act formatted_specified_opening_times = service_history.get_formatted_specified_opening_times( - opening_times=specified_opening_times + opening_times=specified_opening_times, ) # Assert assert [ @@ -365,11 +376,11 @@ def test_service_histories_save_service_histories_insert(mock_query_dos_db: Magi "data": "52 Green Lane$Southgate", "area": "demographic", "previous": "51 Green Lane$Southgate", - } + }, }, "initiator": {"userid": "DOS_INTEGRATION", "timestamp": "TBD"}, "approver": {"userid": "DOS_INTEGRATION", "timestamp": "TBD"}, - } + }, } # Act service_history.save_service_histories(mock_connection) @@ -394,11 +405,11 @@ def test_service_histories_save_service_histories_update(mock_query_dos_db: Magi "data": "52 
Green Lane$Southgate", "area": "demographic", "previous": "51 Green Lane$Southgate", - } + }, }, "initiator": {"userid": "DOS_INTEGRATION", "timestamp": "TBD"}, "approver": {"userid": "DOS_INTEGRATION", "timestamp": "TBD"}, - } + }, } # Act service_history.save_service_histories(mock_connection) diff --git a/application/service_sync/tests/test_service_histories_change.py b/application/service_sync/tests/test_service_histories_change.py index 312dbf617..b1c7709b9 100644 --- a/application/service_sync/tests/test_service_histories_change.py +++ b/application/service_sync/tests/test_service_histories_change.py @@ -1,8 +1,8 @@ from unittest.mock import patch -from pytest import mark, raises +import pytest -from ..service_histories_change import ServiceHistoriesChange +from application.service_sync.service_histories_change import ServiceHistoriesChange from common.constants import ( DOS_DEMOGRAPHICS_AREA_TYPE, DOS_SERVICES_TABLE_CHANGE_TYPE_LIST, @@ -22,21 +22,23 @@ PREVIOUS_VALUE = "Old value to be removed from db" -@mark.parametrize("demographics_change_key", (DOS_SERVICES_TABLE_CHANGE_TYPE_LIST)) +@pytest.mark.parametrize("demographics_change_key", (DOS_SERVICES_TABLE_CHANGE_TYPE_LIST)) @patch(f"{FILE_PATH}.ServiceHistoriesChange.get_demographics_change_action") def test_service_histories_change_demographics_change(mock_get_demographics_change_action, demographics_change_key): # Act service_histories_change = ServiceHistoriesChange( - data=DATA, previous_value=PREVIOUS_VALUE, change_key=demographics_change_key + data=DATA, + previous_value=PREVIOUS_VALUE, + change_key=demographics_change_key, ) # Assert - assert DATA == service_histories_change.data - assert PREVIOUS_VALUE == service_histories_change.previous_value - assert DOS_DEMOGRAPHICS_AREA_TYPE == service_histories_change.area + assert service_histories_change.data == DATA + assert service_histories_change.previous_value == PREVIOUS_VALUE + assert service_histories_change.area == DOS_DEMOGRAPHICS_AREA_TYPE 
mock_get_demographics_change_action.assert_called_once_with() -@mark.parametrize( +@pytest.mark.parametrize( "opening_times_change_key", [ DOS_STANDARD_OPENING_TIMES_MONDAY_CHANGE_KEY, @@ -53,12 +55,14 @@ def test_service_histories_change_demographics_change(mock_get_demographics_chan def test_service_histories_change_opening_times_change(mock_get_opening_times_change_action, opening_times_change_key): # Act service_histories_change = ServiceHistoriesChange( - data=DATA, previous_value=PREVIOUS_VALUE, change_key=opening_times_change_key + data=DATA, + previous_value=PREVIOUS_VALUE, + change_key=opening_times_change_key, ) # Assert - assert DATA == service_histories_change.data - assert PREVIOUS_VALUE == service_histories_change.previous_value - assert DOS_DEMOGRAPHICS_AREA_TYPE == service_histories_change.area + assert service_histories_change.data == DATA + assert service_histories_change.previous_value == PREVIOUS_VALUE + assert service_histories_change.area == DOS_DEMOGRAPHICS_AREA_TYPE mock_get_opening_times_change_action.assert_called_once_with() @@ -66,28 +70,30 @@ def test_service_histories_change_opening_times_change(mock_get_opening_times_ch @patch(f"{FILE_PATH}.ServiceHistoriesChange.get_demographics_change_action") def test_service_histories_change_no_change(demographics_change_key, mock_get_opening_times_change_action): # Act - with raises(ValueError, match="Unknown change key"): + with pytest.raises(ValueError, match="Unknown change key"): ServiceHistoriesChange(data=DATA, previous_value=PREVIOUS_VALUE, change_key="ANY") # Assert demographics_change_key.assert_not_called() mock_get_opening_times_change_action.assert_not_called() -@mark.parametrize( - "data, previous_value, expected_action", +@pytest.mark.parametrize( + ("data", "previous_value", "expected_action"), [(DATA, PREVIOUS_VALUE, "modify"), (None, PREVIOUS_VALUE, "delete"), (DATA, None, "add")], ) def test_service_histories_change_get_demographics_change_action(data, previous_value, 
expected_action): # Act service_histories_change = ServiceHistoriesChange( - data=data, previous_value=previous_value, change_key=DOS_WEBSITE_CHANGE_KEY + data=data, + previous_value=previous_value, + change_key=DOS_WEBSITE_CHANGE_KEY, ) # get_demographics_change_action should be called by __init__ function # Assert assert expected_action == service_histories_change.change_action -@mark.parametrize( - "data, expected_action", +@pytest.mark.parametrize( + ("data", "expected_action"), [ ({"remove": "TO_REMOVE", "add": "TO_ADD"}, "modify"), ({"remove": "TO_REMOVE"}, "delete"), @@ -97,7 +103,9 @@ def test_service_histories_change_get_demographics_change_action(data, previous_ def test_service_histories_change_get_opening_times_change_action(data, expected_action): # Act service_histories_change = ServiceHistoriesChange( - data=data, previous_value=None, change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY + data=data, + previous_value=None, + change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY, ) # get_opening_times_change_action should be called by __init__ function # Assert assert expected_action == service_histories_change.change_action @@ -105,7 +113,7 @@ def test_service_histories_change_get_opening_times_change_action(data, expected def test_service_histories_change_get_opening_times_change_action_error(): # Act & Assert - with raises(ValueError, match="Unknown change action"): + with pytest.raises(ValueError, match="Unknown change action"): ServiceHistoriesChange(data={}, previous_value=None, change_key=DOS_SPECIFIED_OPENING_TIMES_CHANGE_KEY) @@ -114,7 +122,9 @@ def test_service_histories_change_get_change(mock_get_demographics_change_action # Arrange mock_get_demographics_change_action.return_value = change_action = "Change Action" service_histories_change = ServiceHistoriesChange( - data=DATA, previous_value=PREVIOUS_VALUE, change_key=DOS_WEBSITE_CHANGE_KEY + data=DATA, + previous_value=PREVIOUS_VALUE, + change_key=DOS_WEBSITE_CHANGE_KEY, ) # Act response = 
service_histories_change.get_change() @@ -132,7 +142,9 @@ def test_service_histories_change_get_change_add(mock_get_demographics_change_ac # Arrange mock_get_demographics_change_action.return_value = change_action = "add" service_histories_change = ServiceHistoriesChange( - data=DATA, previous_value=PREVIOUS_VALUE, change_key=DOS_WEBSITE_CHANGE_KEY + data=DATA, + previous_value=PREVIOUS_VALUE, + change_key=DOS_WEBSITE_CHANGE_KEY, ) # Act response = service_histories_change.get_change() diff --git a/application/service_sync/tests/test_service_sync.py b/application/service_sync/tests/test_service_sync.py index f53cd6282..a434943cb 100644 --- a/application/service_sync/tests/test_service_sync.py +++ b/application/service_sync/tests/test_service_sync.py @@ -1,12 +1,17 @@ from dataclasses import dataclass from os import environ -from unittest.mock import call, MagicMock, patch +from unittest.mock import MagicMock, call, patch +import pytest from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.utilities.typing import LambdaContext -from pytest import fixture -from ..service_sync import add_success_metric, lambda_handler, remove_sqs_message_from_queue, set_up_logging +from application.service_sync.service_sync import ( + add_success_metric, + lambda_handler, + remove_sqs_message_from_queue, + set_up_logging, +) from common.types import UpdateRequest, UpdateRequestMetadata, UpdateRequestQueueItem FILE_PATH = "application.service_sync.service_sync" @@ -29,11 +34,11 @@ ) -@fixture +@pytest.fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - All dummy values.""" function_name: str = "service-sync" memory_limit_in_mb: int = 128 @@ -116,12 +121,14 @@ def test_lambda_handler_no_healthcheck( # Assert mock_set_up_logging.assert_called_once_with(UPDATE_REQUEST_QUEUE_ITEM) mock_check_and_remove_pending_dos_changes.assert_called_once_with( - 
UPDATE_REQUEST_QUEUE_ITEM["update_request"]["service_id"] + UPDATE_REQUEST_QUEUE_ITEM["update_request"]["service_id"], ) mock_nhs_entity.assert_called_once_with(CHANGE_EVENT) mock_get_dos_service_and_history.assert_called_once_with(service_id=SERVICE_ID) mock_compare_nhs_uk_and_dos_data.assert_called_once_with( - dos_service=dos_service, nhs_entity=nhs_entity, service_histories=service_histories + dos_service=dos_service, + nhs_entity=nhs_entity, + service_histories=service_histories, ) mock_update_dos_data.assert_called_once_with( changes_to_dos=mock_compare_nhs_uk_and_dos_data(), @@ -174,7 +181,7 @@ def test_lambda_handler_no_healthcheck_exception( # Assert mock_set_up_logging.assert_called_once_with(UPDATE_REQUEST_QUEUE_ITEM) mock_check_and_remove_pending_dos_changes.assert_called_once_with( - UPDATE_REQUEST_QUEUE_ITEM["update_request"]["service_id"] + UPDATE_REQUEST_QUEUE_ITEM["update_request"]["service_id"], ) mock_nhs_entity.assert_called_once_with(CHANGE_EVENT) mock_get_dos_service_and_history.assert_called_once_with(service_id=SERVICE_ID) @@ -209,7 +216,8 @@ def test_remove_sqs_message_from_queue(mock_client: MagicMock, mock_logger_info: # Assert mock_client.assert_called_once_with("sqs") mock_client.return_value.delete_message.assert_called_once_with( - QueueUrl=update_request_queue_url, ReceiptHandle=RECIPIENT_ID + QueueUrl=update_request_queue_url, + ReceiptHandle=RECIPIENT_ID, ) mock_logger_info.assert_called_once_with("Removed SQS message from queue", extra={"receipt_handle": RECIPIENT_ID}) # Cleanup diff --git a/application/service_sync/tests/test_service_update_logging.py b/application/service_sync/tests/test_service_update_logging.py index 13d4726d3..c0c67956f 100644 --- a/application/service_sync/tests/test_service_update_logging.py +++ b/application/service_sync/tests/test_service_update_logging.py @@ -3,7 +3,7 @@ from os import environ from unittest.mock import MagicMock, patch -from pytest import fixture +import pytest from 
application.common.constants import ( DOS_INTEGRATION_USER_NAME, @@ -13,7 +13,7 @@ DOS_STANDARD_OPENING_TIMES_FRIDAY_CHANGE_KEY, ) from application.common.opening_times import OpenPeriod, SpecifiedOpeningTime -from application.service_sync.service_update_logging import log_service_updates, ServiceUpdateLogger +from application.service_sync.service_update_logging import ServiceUpdateLogger, log_service_updates SERVICE_UID = "12345" SERVICE_NAME = "Test Service" @@ -27,9 +27,9 @@ FILE_PATH = "application.service_sync.service_update_logging" -@fixture +@pytest.fixture() def service_update_logger(): - yield ServiceUpdateLogger(service_uid=SERVICE_UID, service_name=SERVICE_NAME, type_id=TYPE_ID, odscode=ODSCODE) + return ServiceUpdateLogger(service_uid=SERVICE_UID, service_name=SERVICE_NAME, type_id=TYPE_ID, odscode=ODSCODE) def test_dos_logger(service_update_logger: ServiceUpdateLogger): @@ -53,7 +53,9 @@ def test_service_update_logger_get_opening_times_change_modify(service_update_lo data_field_modified = "test_field" # Act response = service_update_logger.get_opening_times_change( - data_field_modified=data_field_modified, previous_value=previous_value, new_value=new_value + data_field_modified=data_field_modified, + previous_value=previous_value, + new_value=new_value, ) assert ( f"{data_field_modified}_existing={previous_value}", @@ -68,7 +70,9 @@ def test_service_update_logger_get_opening_times_change_remove(service_update_lo data_field_modified = "test_field" # Act response = service_update_logger.get_opening_times_change( - data_field_modified=data_field_modified, previous_value=previous_value, new_value=new_value + data_field_modified=data_field_modified, + previous_value=previous_value, + new_value=new_value, ) assert ( f"{data_field_modified}_existing={previous_value}", @@ -83,14 +87,17 @@ def test_service_update_logger_get_opening_times_change_add(service_update_logge data_field_modified = "test_field" # Act response = 
service_update_logger.get_opening_times_change( - data_field_modified=data_field_modified, previous_value=previous_value, new_value=new_value + data_field_modified=data_field_modified, + previous_value=previous_value, + new_value=new_value, ) assert ("", f"{data_field_modified}_update=add={new_value}") == response @patch(f"{FILE_PATH}.log_service_updated") def test_service_update_logger_log_service_update( - mock_log_service_update: MagicMock, service_update_logger: ServiceUpdateLogger + mock_log_service_update: MagicMock, + service_update_logger: ServiceUpdateLogger, ): # Arrange environ["ENV"] = "UNKNOWN" @@ -147,7 +154,9 @@ def test_service_update_logger_log_standard_opening_times_service_update_for_wee # Assert mock_opening_period_times_from_list.assert_not_called() mock_get_opening_times_change.assert_called_once_with( - EXAMPLE_DATA_FIELD_MODIFIED, EXAMPLE_PREVIOUS_VALUE, EXAMPLE_NEW_VALUE + EXAMPLE_DATA_FIELD_MODIFIED, + EXAMPLE_PREVIOUS_VALUE, + EXAMPLE_NEW_VALUE, ) mock_log_service_update.assert_called_once_with( data_field_modified=EXAMPLE_DATA_FIELD_MODIFIED, @@ -165,7 +174,10 @@ def test_service_update_logger_log_specified_opening_times_service_update( ): # Arrange service_update_logger = ServiceUpdateLogger( - service_uid=SERVICE_UID, service_name=SERVICE_NAME, type_id=TYPE_ID, odscode=ODSCODE + service_uid=SERVICE_UID, + service_name=SERVICE_NAME, + type_id=TYPE_ID, + odscode=ODSCODE, ) open_periods = [ OpenPeriod(time(1, 0, 0), time(2, 0, 0)), @@ -209,8 +221,8 @@ def test_log_service_updates_demographics_change(mock_service_update_logger: Mag "area": "demographics", "previous": EXAMPLE_PREVIOUS_VALUE, }, - } - } + }, + }, } service_histories.service_history.keys.return_value = [time_stamp] service_histories.service_history.__getitem__.return_value.__getitem__.return_value = service_history[time_stamp][ @@ -249,8 +261,8 @@ def test_log_service_updates_standard_opening_times_change(mock_service_update_l "area": "demographics", "previous": 
EXAMPLE_PREVIOUS_VALUE, }, - } - } + }, + }, } service_histories.service_history.keys.return_value = [time_stamp] service_histories.service_history.__getitem__.return_value.__getitem__.return_value = service_history[time_stamp][ @@ -290,8 +302,8 @@ def test_log_service_updates_specified_opening_times_change(mock_service_update_ "area": "demographics", "previous": EXAMPLE_PREVIOUS_VALUE, }, - } - } + }, + }, } service_histories.service_history.keys.return_value = [time_stamp] service_histories.service_history.__getitem__.return_value.__getitem__.return_value = service_history[time_stamp][ @@ -329,8 +341,8 @@ def test_log_service_updates_sgsdid_change(mock_service_update_logger: MagicMock "area": "clinical", "previous": "", }, - } - } + }, + }, } service_histories.service_history.keys.return_value = [time_stamp] service_histories.service_history.__getitem__.return_value.__getitem__.return_value = service_history[time_stamp][ @@ -347,5 +359,6 @@ def test_log_service_updates_sgsdid_change(mock_service_update_logger: MagicMock ) mock_service_update_logger.return_value.log_sgsdid_service_update.assert_called_once_with( - action="cmssgsdid", new_value=DOS_PALLIATIVE_CARE_SGSDID + action="cmssgsdid", + new_value=DOS_PALLIATIVE_CARE_SGSDID, ) diff --git a/application/service_sync/tests/test_validation.py b/application/service_sync/tests/test_validation.py index 1c577faad..6fc914b96 100644 --- a/application/service_sync/tests/test_validation.py +++ b/application/service_sync/tests/test_validation.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch +import pytest from aws_lambda_powertools.logging import Logger -from pytest import mark from application.common.nhs import NHSEntity from application.service_sync.validation import validate_opening_times, validate_website @@ -12,7 +12,8 @@ @patch(f"{FILE_PATH}.log_service_with_generic_bank_holiday") @patch.object(Logger, "warning") def test_validate_opening_times_sucessful( - mock_warning_logger: MagicMock, 
mock_log_service_with_generic_bank_holiday: MagicMock + mock_warning_logger: MagicMock, + mock_log_service_with_generic_bank_holiday: MagicMock, ): # Arrange nhs_entity = MagicMock() @@ -31,7 +32,8 @@ def test_validate_opening_times_sucessful( @patch(f"{FILE_PATH}.log_service_with_generic_bank_holiday") @patch.object(Logger, "warning") def test_validate_opening_times_failure( - mock_warning_logger: MagicMock, mock_log_service_with_generic_bank_holiday: MagicMock + mock_warning_logger: MagicMock, + mock_log_service_with_generic_bank_holiday: MagicMock, ): # Arrange nhs_entity = MagicMock() @@ -44,15 +46,13 @@ def test_validate_opening_times_failure( # Assert assert result is False mock_warning_logger.assert_called_once_with( - ( - f"Opening Times for NHS Entity '{nhs_entity.odscode}' were previously found " - "to be invalid or illogical. Skipping change." - ) + f"Opening Times for NHS Entity '{nhs_entity.odscode}' were previously found " + "to be invalid or illogical. Skipping change.", ) mock_log_service_with_generic_bank_holiday.assert_called_once_with(nhs_entity, dos_service) -@mark.parametrize( +@pytest.mark.parametrize( "website", [ "www.test.com", @@ -68,7 +68,7 @@ def test_validate_website_sucess(mock_log_website_is_invalid: MagicMock, website mock_log_website_is_invalid.assert_not_called() -@mark.parametrize( +@pytest.mark.parametrize( "website", [ "https://testpharmacy@gmail.com", diff --git a/application/service_sync/validation.py b/application/service_sync/validation.py index 709b01b2f..13b429f8b 100644 --- a/application/service_sync/validation.py +++ b/application/service_sync/validation.py @@ -23,10 +23,8 @@ def validate_opening_times(dos_service: DoSService, nhs_entity: NHSEntity) -> bo log_service_with_generic_bank_holiday(nhs_entity, dos_service) if not nhs_entity.all_times_valid(): logger.warning( - ( - f"Opening Times for NHS Entity '{nhs_entity.odscode}' " - "were previously found to be invalid or illogical. Skipping change." 
- ) + f"Opening Times for NHS Entity '{nhs_entity.odscode}' " + "were previously found to be invalid or illogical. Skipping change.", ) return False return True diff --git a/application/slack_messenger/slack_messenger.py b/application/slack_messenger/slack_messenger.py index 9e2ad4c6f..6063ccf56 100644 --- a/application/slack_messenger/slack_messenger.py +++ b/application/slack_messenger/slack_messenger.py @@ -1,12 +1,12 @@ from datetime import datetime from json import loads from os import environ -from typing import Any, Dict, List +from typing import Any from urllib.parse import quote from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.tracing import Tracer -from aws_lambda_powertools.utilities.data_classes import event_source, SNSEvent +from aws_lambda_powertools.utilities.data_classes import SNSEvent, event_source from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext from requests import post @@ -16,7 +16,16 @@ tracer = Tracer() -def get_message_for_cloudwatch_event(event: SNSEvent) -> Dict[str, Any]: +def get_message_for_cloudwatch_event(event: SNSEvent) -> dict[str, Any]: + """Get message for cloudwatch event. 
+ + Args: + event (SNSEvent): SNS event + + Returns: + dict[str, Any]: Message for slack + """ + def is_expression_alarm() -> bool: logger.debug( "Checking if alarm is an expression alarm", @@ -28,7 +37,7 @@ def is_expression_alarm() -> bool: ) return "Expression" in str(trigger) - def get_attachments_fields() -> List[Dict[str, Any]]: + def get_attachments_fields() -> list[dict[str, Any]]: fields = [ { "title": "Alarm Name", @@ -52,10 +61,10 @@ def get_attachments_fields() -> List[Dict[str, Any]]: { "title": "Trigger", "value": f"{trigger['Statistic']} {metric_name} {trigger['ComparisonOperator']} " - f"{str(trigger['Threshold'])} for {str(trigger['EvaluationPeriods'])} period(s) " - f" of {str(trigger['Period'])} seconds.", + f"{trigger['Threshold']!s} for {trigger['EvaluationPeriods']!s} period(s) " + f" of {trigger['Period']!s} seconds.", "short": False, - } + }, ) return fields @@ -90,22 +99,27 @@ def get_attachments_fields() -> List[Dict[str, Any]]: "type": "mrkdwn", "text": f":rotating_light: *<{link}|{alarm_name}>*", }, - } + }, ], "attachments": [ { "color": colour, "fields": get_attachments_fields(), "ts": timestamp, - } + }, ], } -def send_msg_slack(message: Dict[str, Any]) -> None: +def send_msg_slack(message: dict[str, Any]) -> None: + """Send message to slack. 
+ + Args: + message (dict[str, Any]): Message to send to slack + """ url = environ["SLACK_WEBHOOK_URL"] channel = environ["SLACK_ALERT_CHANNEL"] - headers: Dict[str, str] = {"Content-Type": "application/json", "Accept": "application/json"} + headers: dict[str, str] = {"Content-Type": "application/json", "Accept": "application/json"} message["channel"] = channel message["icon_emoji"] = "" @@ -126,18 +140,17 @@ def send_msg_slack(message: Dict[str, Any]) -> None: @tracer.capture_lambda_handler() @event_source(data_class=SNSEvent) @logger.inject_lambda_context(clear_state=True) -def lambda_handler(event: SNSEvent, context: LambdaContext) -> None: - """Entrypoint handler for the slack_messenger lambda +def lambda_handler(event: SNSEvent, _context: LambdaContext) -> None: + """Entrypoint handler for the slack_messenger lambda. Args: - event (SNSEvent): + event (SNSEvent): SNS event context (LambdaContext): Lambda function context object Event: The event payload Some code may need to be changed if the exact input format is changed. 
""" - message = get_message_for_cloudwatch_event(event) logger.info("Sending alert to slack.", extra={"slack_message": message}) send_msg_slack(message) diff --git a/application/slack_messenger/tests/__init__.py b/application/slack_messenger/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/application/slack_messenger/tests/test_slack_messenger.py b/application/slack_messenger/tests/test_slack_messenger.py index ed64ee6c7..a8e846e98 100644 --- a/application/slack_messenger/tests/test_slack_messenger.py +++ b/application/slack_messenger/tests/test_slack_messenger.py @@ -3,19 +3,19 @@ from os import environ from unittest.mock import patch +import pytest from aws_lambda_powertools.utilities.data_classes import SNSEvent -from pytest import fixture, mark, raises from application.slack_messenger.slack_messenger import get_message_for_cloudwatch_event, lambda_handler, send_msg_slack FILE_PATH = "application.slack_messenger.slack_messenger" -@fixture +@pytest.fixture() def lambda_context(): @dataclass class LambdaContext: - """Mock LambdaContext - All dummy values""" + """Mock LambdaContext - All dummy values.""" function_name: str = "slack-messenger" memory_limit_in_mb: int = 128 @@ -76,8 +76,8 @@ class LambdaContext: "UnsubscribeUrl": "whocares", "MessageAttributes": {}, }, - } - ] + }, + ], } WEBHOOK_URL = "https://hooks.slack.com/services/1/2/3" @@ -102,7 +102,7 @@ def test_lambda_handler_slack_messenger(mock_send, mock_get, lambda_context): def test_send_message_missing_url(lambda_context): message = {} # Act - with raises(KeyError): + with pytest.raises(KeyError): send_msg_slack(message) @@ -111,7 +111,7 @@ def test_send_message_url_no_channel(lambda_context): message = {} environ["SLACK_WEBHOOK_URL"] = WEBHOOK_URL # Act & Assert - with raises(KeyError): + with pytest.raises(KeyError): send_msg_slack(message) # Clean Up del environ["SLACK_WEBHOOK_URL"] @@ -136,7 +136,10 @@ def test_send_message(mock_post, lambda_context): del 
environ["SLACK_WEBHOOK_URL"] -@mark.parametrize("new_state_value, colour", (("ALARM", "#e01e5a"), ("OK", "good"), ("INSUFFICIENT_DATA", "warning"))) +@pytest.mark.parametrize( + ("new_state_value", "colour"), + [("ALARM", "#e01e5a"), ("OK", "good"), ("INSUFFICIENT_DATA", "warning")], +) def test_get_message_from_event(new_state_value, colour): # Arrange sns_event_dict = SNS_EVENT.copy() diff --git a/build/automation/etc/githooks/scripts/python-code-pre-commit.sh b/build/automation/etc/githooks/scripts/python-code-pre-commit.sh index aa1ac20b0..bb4acf7e5 100755 --- a/build/automation/etc/githooks/scripts/python-code-pre-commit.sh +++ b/build/automation/etc/githooks/scripts/python-code-pre-commit.sh @@ -3,14 +3,6 @@ set -e [ $(make project-check-if-tech-is-included-in-stack NAME=python) == false ] && exit 0 -if [ $(make git-check-if-commit-changed-directory DIR=application PRECOMMIT=true) == true ]; then - make -s python-code-format python-code-check \ - FILES=application -fi - -if [ $(make git-check-if-commit-changed-directory DIR=test PRECOMMIT=true) == true ]; then - make -s python-code-format python-code-check \ - FILES=test -fi +make python-linting exit 0 diff --git a/pyproject.toml b/pyproject.toml index f646e80ba..63bb6c41a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,12 @@ [tool.ruff] -select = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] +select = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TID", "TRY", "UP", "YTT"] ignore = [ "D100", # Missing docstring in public module "D104", # Missing docstring in public package "PTH123", # Use 
`pathlib.Path` instead of `os.path` "G004", # Logging statement uses f-string - "ANN101", # Missing type annotation for self in method + "ANN101", # Missing type annotation for self in method + "DTZ007", # Allow naive datetime from `strptime` without `%z` in tests. ] # Allow autofix for all enabled rules (when `--fix`) is provided. @@ -34,8 +35,11 @@ exclude = [ "dist", "node_modules", "venv", - "application", # Ignore application folder temporarily. - "scripts" # Ignore scripts folder temporarily. + "scripts", # Ignore scripts folder + "application/orchestrator", # Ignore orchestrator folder as it will be removed soon. + "application/service_sync/changes_to_dos.py", + "application/service_sync/compare_data.py", + "application/service_sync/dos_data.py", ] # Same as Black. @@ -48,11 +52,25 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" target-version = "py310" [tool.ruff.per-file-ignores] -"**test_*.py" = [ +"application/**test_*.py" = [ "S101", # Allow `assert` in tests. "S105",# Allow fake secrets in tests. "SLF001", # Allow `_function` in tests. "PLR0913", # Allow many arrguments in tests. + "ANN001", # Allow missing type annotation in tests. + "ARG001", # Allow unused function arguments in tests. + "ANN201", # Allow missing return type annotation in tests. + "D101", # Allow missing docstring in public class in tests. + "D103", # Allow missing docstring in tests. + "S311", # Allow standard pseudo-random generators in tests (not used for cryptographic purposes). + "PLR2004", # Allow magic values used in comparisons in tests. + "S608", # Allow possible SQL injection vector through string concatenation in tests. + "DTZ001", # Allow `datetime.datetime()` without `tzinfo` argument in tests. + "DTZ005", # Allow `datetime.datetime.now()` without `tz` argument in tests. 
+ + ] +"application/**conftest.py" = [ + "S311", # Allow Standard pseudo-random generators are not suitable for cryptographic purposes in tests. ] [tool.ruff.pydocstyle] diff --git a/test/integration/steps/test_steps.py b/test/integration/steps/test_steps.py index 398663446..c556ee3c1 100644 --- a/test/integration/steps/test_steps.py +++ b/test/integration/steps/test_steps.py @@ -126,15 +126,15 @@ def a_service_table_entry_is_created(context: Context, ods_code: int = 0, servic ods_code = str(randint(10000, 99999)) query_values = { "id": str(randint(100000, 999999)), - "uid": f"test{str(randint(10000,99999))}", + "uid": f"test{randint(10000, 99999)!s}", "service_type": service_type, "service_status": 1, - "name": f"Test Pharmacy {str(randint(100,999))}", + "name": f"Test Pharmacy {randint(100, 999)!s}", "odscode": ods_code, - "address": f"{str(randint(100,999))} Test Address", + "address": f"{randint(100, 999)!s} Test Address", "town": "Nottingham", "postcode": "NG11GS", - "publicphone": f"{str(randint(10000000000, 99999999999))}", + "publicphone": f"{randint(10000000000, 99999999999)!s}", "web": "www.google.com", } context.generator_data = query_values diff --git a/test/integration/steps/utilities/generator.py b/test/integration/steps/utilities/generator.py index b0bd6df51..3c4f77710 100644 --- a/test/integration/steps/utilities/generator.py +++ b/test/integration/steps/utilities/generator.py @@ -223,7 +223,7 @@ def add_specified_openings_to_dos(context: Context) -> Any: date = datetime.strptime(day["date"], "%b %d %Y").strftime("%Y-%m-%d") query = ( 'INSERT INTO pathwaysdos.servicespecifiedopeningdates("date", serviceid) ' - f"VALUES('{str(date)}', {int(context.service_id)}) RETURNING id" + f"VALUES('{date!s}', {int(context.service_id)}) RETURNING id" ) lambda_payload = {"type": "read", "query": query, "query_vars": None} response = invoke_dos_db_handler_lambda(lambda_payload)