Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions application/common/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,10 @@ def db_rows_to_spec_open_times_map(db_rows: Iterable[dict]) -> dict[str, list[Sp
for db_row in db_rows:
serviceid_dbrows_map[db_row["serviceid"]].append(db_row)

serviceid_specopentimes_map = {}
for service_id, db_rows in serviceid_dbrows_map.items():
serviceid_specopentimes_map[service_id] = db_rows_to_spec_open_times(db_rows)

return serviceid_specopentimes_map
return {
service_id: db_rows_to_spec_open_times(db_rows)
for service_id, db_rows in serviceid_dbrows_map.items()
}


def db_rows_to_std_open_times(db_rows: Iterable[dict]) -> StandardOpeningTimes:
Expand Down Expand Up @@ -362,11 +361,10 @@ def db_rows_to_std_open_times_map(db_rows: Iterable[dict]) -> dict[str, Standard
for db_row in db_rows:
serviceid_dbrows_map[db_row["serviceid"]].append(db_row)

serviceid_stdopentimes_map = {}
for service_id, db_rows in serviceid_dbrows_map.items():
serviceid_stdopentimes_map[service_id] = db_rows_to_std_open_times(db_rows)

return serviceid_stdopentimes_map
return {
service_id: db_rows_to_std_open_times(db_rows)
for service_id, db_rows in serviceid_dbrows_map.items()
}


def has_palliative_care(service: DoSService, connection: Connection) -> bool:
Expand Down
5 changes: 4 additions & 1 deletion application/common/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ def get_newest_event_per_odscode(threads: int = 2, limit: int = None) -> dict[st
def merge_newest_events(newest_events: dict, more_events: list[dict]): # noqa: ANN202
for event in more_events:
newest_event = newest_events.get(event["ODSCode"])
if not (newest_event is not None and newest_event["SequenceNumber"] > event["SequenceNumber"]):
if (
newest_event is None
or newest_event["SequenceNumber"] <= event["SequenceNumber"]
):
newest_events[event["ODSCode"]] = event

def scan_thread(segment: int, total_segments: int): # noqa: ANN202
Expand Down
4 changes: 2 additions & 2 deletions application/common/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@lambda_handler_decorator(trace_execution=True)
def redact_staff_key_from_event(handler, event, context: LambdaContext) -> Any: # noqa: ANN001, ANN401
def redact_staff_key_from_event(handler, event, context: LambdaContext) -> Any: # noqa: ANN001, ANN401
"""Lambda middleware to remove the 'Staff' key from the Change Event payload.

Args:
Expand All @@ -24,7 +24,7 @@ def redact_staff_key_from_event(handler, event, context: LambdaContext) -> Any:
Any: Lambda handler response
"""
logger.info("Checking if 'Staff' key needs removing from Change Event payload")
if "Records" in event and len(list(event["Records"])) > 0:
if "Records" in event and list(event["Records"]):
for record in event["Records"]:
change_event = extract_body(record["body"])
if change_event.pop("Staff", None) is not None:
Expand Down
40 changes: 23 additions & 17 deletions application/common/nhs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,20 @@ def normal_postcode(self) -> str:

def extract_contact(self, contact_type: str) -> str | None:
"""Returns the nested contact value within the input payload."""
for item in self.entity_data.get("Contacts", []):
if (
item.get("ContactMethodType", "").upper() == contact_type.upper()
and item.get("ContactType", "").upper() == "PRIMARY"
and item.get("ContactAvailabilityType", "").upper() == "OFFICE HOURS"
):
return item.get("ContactValue")
return None
return next(
(
item.get("ContactValue")
for item in self.entity_data.get("Contacts", [])
if (
item.get("ContactMethodType", "").upper()
== contact_type.upper()
and item.get("ContactType", "").upper() == "PRIMARY"
and item.get("ContactAvailabilityType", "").upper()
== "OFFICE HOURS"
)
),
None,
)

def extract_uec_service(self, service_code: str) -> bool | None:
"""Extracts the UEC service from the payload (e.g. Palliative Care).
Expand Down Expand Up @@ -190,7 +196,7 @@ def is_matching_dos_service(self, dos_service: DoSService) -> bool:
)

if dos_service.typeid in DENTIST_SERVICE_TYPE_IDS:
if not (len(dos_service.odscode) >= 6 and len(self.odscode) >= 7): # noqa: PLR2004
if len(dos_service.odscode) < 6 or len(self.odscode) < 7: # noqa: PLR2004
return False
odscode_extra_0 = f"{dos_service.odscode[0]}0{dos_service.odscode[1:]}"
return self.odscode[:7] in (dos_service.odscode[:7], odscode_extra_0[:7])
Expand Down Expand Up @@ -221,10 +227,10 @@ def is_std_opening_json(item: dict) -> bool:
return False

# If marked as closed, ensure open time values are not present
if not is_open and (any(value not in ["", None] for value in (open_time, close_time))):
return False

return True
return bool(
is_open
or all(value in ["", None] for value in (open_time, close_time)),
)


def is_spec_opening_json(item: dict) -> bool:
Expand All @@ -249,10 +255,10 @@ def is_spec_opening_json(item: dict) -> bool:
return False

# If marked as closed, ensure open time values are not present
if not is_open and (any(value not in ["", None] for value in (open_time, close_time))):
return False

return True
return bool(
is_open
or all(value in ["", None] for value in (open_time, close_time)),
)


def match_nhs_entities_to_services(
Expand Down
10 changes: 3 additions & 7 deletions application/common/opening_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,7 @@ def remove_past_dates(
"""Removes any SpecifiedOpeningTime objects from the list that are in the past."""
if date_now is None:
date_now = datetime.now().date() # noqa: DTZ005
future_dates = []
for item in times_list:
if item.date >= date_now:
future_dates.append(item)
return future_dates
return [item for item in times_list if item.date >= date_now]

def export_test_format(self) -> dict:
"""Exports Specified opening time into a test format that can be used in the tests."""
Expand All @@ -334,7 +330,7 @@ def export_test_format_list(spec_opening_dates: list["SpecifiedOpeningTime"]) ->
opening_dates_cr_format = {}
for spec_open_date in spec_opening_dates:
spec_open_date_payload = spec_open_date.export_test_format()
opening_dates_cr_format.update(spec_open_date_payload)
opening_dates_cr_format |= spec_open_date_payload
return opening_dates_cr_format


Expand Down Expand Up @@ -370,7 +366,7 @@ def __str__(self) -> str:

def __len__(self) -> int:
"""Returns the number of OpenPeriods in the StandardOpeningTimes object."""
return sum([len(getattr(self, day)) for day in WEEKDAYS])
return sum(len(getattr(self, day)) for day in WEEKDAYS)

def __eq__(self, other: "StandardOpeningTimes") -> bool:
"""Check equality of 2 StandardOpeningTimes (generic bankholiday values are ignored)."""
Expand Down
6 changes: 2 additions & 4 deletions application/common/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,9 @@ def dummy_dos_service(**kwargs: Any) -> DoSService: # noqa: ANN401
return dos_service


def blank_dos_service(**kwargs: Any) -> DoSService: # noqa: ANN401
def blank_dos_service(**kwargs: Any) -> DoSService: # noqa: ANN401
"""Creates a DoSService Object with blank str data for the unit testing."""
test_data = {}
for col in DoSService.field_names():
test_data[col] = ""
test_data = {col: "" for col in DoSService.field_names()}
dos_service = DoSService(test_data)

for name, value in kwargs.items():
Expand Down
2 changes: 1 addition & 1 deletion application/common/tests/test_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_add_change_event_to_dynamodb(dynamodb_table_create, change_event, dynam
expected = loads(dumps(change_event), parse_float=Decimal)

assert response_id == change_id
assert deserialized["EventReceived"] == int(event_received_time)
assert deserialized["EventReceived"] == event_received_time
assert deserialized["TTL"] == int(event_received_time + TTL)
assert deserialized["Id"] == change_id
assert deserialized["SequenceNumber"] == 1
Expand Down
2 changes: 0 additions & 2 deletions application/common/tests/test_opening_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def test_open_period_eq_hash():
assert hash(c) != hash(d)

assert d == d
assert hash(d) == hash(d)

b.end = time(17, 0, 0)
assert a == b
assert hash(a) == hash(b)
Expand Down
2 changes: 1 addition & 1 deletion application/common/tests/test_report_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_log_invalid_open_times(mock_logger):
nhs_entity.odscode = "SLC4X"
nhs_entity.org_name = "OrganisationName"

dos_services = [dummy_dos_service() for i in range(3)]
dos_services = [dummy_dos_service() for _ in range(3)]
# Act
log_invalid_open_times(nhs_entity, dos_services)
# Assert
Expand Down
4 changes: 1 addition & 3 deletions application/common/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def get_sqs_msg_attribute(msg_attributes: dict[str, Any], key: str) -> str | flo
data_type = attribute.get("dataType")
if data_type == "String":
return attribute.get("stringValue")
if data_type == "Number":
return float(attribute.get("stringValue"))
return None
return float(attribute.get("stringValue")) if data_type == "Number" else None


def handle_sqs_msg_attributes(msg_attributes: dict[str, Any]) -> dict[str, Any] | None:
Expand Down
2 changes: 1 addition & 1 deletion application/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ pytest-sugar
pytest-xdist
requests
responses
ruff == 0.0.269
ruff == 0.0.270
testfixtures
vulture
2 changes: 1 addition & 1 deletion application/service_matcher/service_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_matching_services(nhs_entity: NHSEntity) -> list[DoSService]:
matching_services.append(service)
else:
non_matching_services.append(service)
if len(non_matching_services) > 0:
if non_matching_services:
log_unmatched_service_types(nhs_entity, non_matching_services)

if nhs_entity.org_type_id == PHARMACY_ORG_TYPE_ID:
Expand Down
9 changes: 2 additions & 7 deletions application/service_sync/changes_to_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,10 @@

from aws_lambda_powertools.logging import Logger

from .format import format_address, format_website
from .service_histories import ServiceHistories
from .validation import validate_website
from common.dos import DoSService, get_valid_dos_location
from common.dos_location import DoSLocation
from common.dos import DoSService
from common.nhs import NHSEntity
from common.opening_times import SpecifiedOpeningTime, StandardOpeningTimes, opening_period_times_from_list
from common.report_logging import log_invalid_nhsuk_postcode
from common.utilities import is_val_none_or_empty
from common.opening_times import SpecifiedOpeningTime

logger = Logger(child=True)

Expand Down
Loading