Rearrange async code to fix httpx pool timeouts #566

Merged Feb 9, 2025 · 8 commits · showing changes from 5 commits
181 changes: 86 additions & 95 deletions project/npda/general_functions/csv/csv_upload.py
@@ -115,31 +115,6 @@ async def validate_visit_using_form(patient_form, row, async_client):

return form

async def validate_rows(rows, async_client):
first_row = rows.iloc[0]
patient_row_index = int(first_row["row_index"])

transfer_fields = validate_transfer(first_row)

patient_form = await validate_patient_using_form(first_row, async_client)
# Pull through cleaned_data so we can use it in the async visit validators

patient_form.is_valid()

visit_forms = []
for _, row in rows.iterrows():
visit_form = await validate_visit_using_form(
patient_form, row, async_client
)
visit_forms.append((visit_form, int(row["row_index"])))

return (
patient_form,
transfer_fields,
patient_row_index,
visit_forms,
)

def retain_errors_and_invalid_field_data(form):
# We want to retain fields even if they're invalid so that we can return them to the user
# Use the field value from cleaned_data, falling back to data if it's not there
@@ -155,34 +130,11 @@ def retain_errors_and_invalid_field_data(form):
None if form.is_valid() else form.errors.get_json_data(escape_html=True)
)

async def validate_rows_in_parallel(rows_by_patient, async_client):
tasks = []

async with asyncio.TaskGroup() as tg:
for _, rows in rows_by_patient:
task = tg.create_task(validate_rows(rows, async_client))
tasks.append(task)

return [task.result() for task in tasks]

def record_errors_from_form(errors_to_return, row_index, form):
for field, errors in form.errors.as_data().items():
for error in errors:
errors_to_return[row_index][field].extend(error.messages)

def do_not_save_patient_if_no_unique_identifier(patient_form):
if (
patient_form.cleaned_data.get("nhs_number") is None
and patient_form.cleaned_data.get("unique_reference_number") is None
):
patient = patient_form.save(commit=False)
else:
patient = patient_form.save(
commit=True
) # save the patient if there is a unique identifier

return patient

""""
Create the submission and save the csv file
"""
@@ -275,59 +227,98 @@ def do_not_save_patient_if_no_unique_identifier(patient_form):
# dict[number, dict[str, list[str]]]
errors_to_return = collections.defaultdict(lambda: collections.defaultdict(list))

async with httpx.AsyncClient() as async_client:
validation_results_by_patient = await validate_rows_in_parallel(
rows_by_patient=visits_by_patient, async_client=async_client
)
async def process_rows_for_patient(rows, async_client):
patient = None

for (
patient_form,
transfer_fields,
patient_row_index,
# first_row_field_errors,
parsed_visits,
) in validation_results_by_patient:
record_errors_from_form(errors_to_return, patient_row_index, patient_form)

patient = None

try:
retain_errors_and_invalid_field_data(patient_form)
patient = await sync_to_async(
do_not_save_patient_if_no_unique_identifier
)(patient_form)

if patient:
# add the patient to a new Transfer instance
transfer_fields["paediatric_diabetes_unit"] = pdu
transfer_fields["patient"] = patient
await Transfer.objects.acreate(**transfer_fields)

await new_submission.patients.aadd(patient)
except Exception as error:
logger.exception(
f"Error saving patient for {pdu_pz_code} from {csv_file_name}[{patient_row_index}]: {error}"
)
first_row = rows.iloc[0]
patient_row_index = int(first_row["row_index"])

transfer_fields = validate_transfer(first_row)

patient_form = await validate_patient_using_form(first_row, async_client)

# Pull through cleaned_data so we can use it in the async visit validators
patient_form.is_valid()

# We don't know what field caused the error so add to __all__
errors_to_return[patient_row_index]["__all__"].append(str(error))
record_errors_from_form(errors_to_return, patient_row_index, patient_form)

visit_forms = []
for _, row in rows.iterrows():
visit_form = await validate_visit_using_form(
patient_form, row, async_client
)
visit_forms.append((visit_form, int(row["row_index"])))

try:
retain_errors_and_invalid_field_data(patient_form)

patient = await sync_to_async(lambda: patient_form.save())()

if patient:
for visit_form, visit_row_index in parsed_visits:
record_errors_from_form(
errors_to_return, visit_row_index, visit_form
)
# add the patient to a new Transfer instance
transfer_fields["paediatric_diabetes_unit"] = pdu
transfer_fields["patient"] = patient
await Transfer.objects.acreate(**transfer_fields)

await new_submission.patients.aadd(patient)
except Exception as error:
logger.exception(
f"Error saving patient for {pdu_pz_code} from {csv_file_name}[{patient_row_index}]: {error}"
)

try:
retain_errors_and_invalid_field_data(visit_form)
visit_form.instance.patient = patient
# We don't know what field caused the error so add to __all__
errors_to_return[patient_row_index]["__all__"].append(str(error))

await sync_to_async(lambda: visit_form.save())()
except Exception as error:
logger.exception(
f"Error saving visit for {pdu_pz_code} from {csv_file_name}[{visit_row_index}]: {error}"
)
errors_to_return[visit_row_index]["__all__"].append(str(error))
if patient:
for visit_form, visit_row_index in visit_forms:
record_errors_from_form(
errors_to_return, visit_row_index, visit_form
)

try:
retain_errors_and_invalid_field_data(visit_form)
visit_form.instance.patient = patient

await sync_to_async(lambda: visit_form.save())()
except Exception as error:
logger.exception(
f"Error saving visit for {pdu_pz_code} from {csv_file_name}[{visit_row_index}]: {error}"
)
errors_to_return[visit_row_index]["__all__"].append(str(error))

async with httpx.AsyncClient() as async_client:
async with asyncio.TaskGroup() as tg:
# The maximum number of patients we will process in parallel
# NB: each patient has a variable number of visits
#
# I tried 20, 10, 5 and 3 with 200 patients (16 visits each)
# 20: 59s.
# 10: 44s.
# 5: 42s
# 3: 45s
#
# I also tried no task group at all, just doing each patient in sequence
# That took 1m 1s.
#
# So I went with 5. Seems a reasonable balance between an actual speed up and not hammering third party APIs.
throttle_semaphore = asyncio.Semaphore(5)

counter = 1

for _, rows in visits_by_patient:
async def task(ix, rows):
if throttle_semaphore.locked():
print(f"!! [PATIENT] {ix} waiting to start")

async with throttle_semaphore:
print(f"!! [PATIENT] {ix} starting")
await process_rows_for_patient(rows, async_client)
print(f"!! [PATIENT] {ix} complete")

tg.create_task(task(counter, rows))
counter += 1

# TODO MRB: why is it saying 1 row's worth of errors? I must have broken the error reporting somehow

# Store the errors to report back to the user in the Data Quality Report
if errors_to_return:
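A minimal sketch of the concurrency pattern this diff adopts: each patient is processed end-to-end by one coroutine, and an asyncio.Semaphore caps how many run at once so the shared httpx connection pool is never asked for more connections than it holds. The names here (process_one, MAX_CONCURRENT) are illustrative, not the project's API, and asyncio.TaskGroup requires Python 3.11+.

import asyncio

MAX_CONCURRENT = 5  # the value the PR settled on after benchmarking

async def process_one(item: int) -> None:
    # Stand-in for per-patient validation and saving via a shared async client
    await asyncio.sleep(0.1)

async def process_all(items: list[int]) -> None:
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    async def throttled(item: int) -> None:
        # At most MAX_CONCURRENT bodies execute concurrently; the rest
        # queue on the semaphore instead of opening new connections
        async with semaphore:
            await process_one(item)

    # TaskGroup waits for every task and propagates the first exception
    async with asyncio.TaskGroup() as tg:
        for item in items:
            tg.create_task(throttled(item))

if __name__ == "__main__":
    asyncio.run(process_all(list(range(20))))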
8 changes: 4 additions & 4 deletions project/npda/models/patient.py
@@ -38,8 +38,7 @@ class Patient(models.Model):
"""
The Patient class.

The index of multiple deprivation is calculated in the save() method using the postcode supplied and the
RCPCH Census Platform
The index of multiple deprivation is calculated using the postcode supplied and the RCPCH Census Platform

Custom methods age and age_days, returns the age
"""
@@ -148,8 +147,7 @@ class Meta:
CAN_OPT_OUT_CHILD_FROM_INCLUSION_IN_AUDIT,
]

def clean(self):
super().clean()
def save(self, **kwargs):
if not self.nhs_number and not self.unique_reference_number:
raise ValidationError(
"Either NHS Number or Unique Reference Number must be provided."
@@ -159,6 +157,8 @@ def clean(self):
"Only one of NHS Number or Unique Reference Number should be provided."
)

super().save(**kwargs)

def __str__(self) -> str:
if self.unique_reference_number:
return f"ID: {self.pk}, {self.unique_reference_number}"
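The model change above moves the identifier invariant from clean() into save(). Django's Model.save() never calls clean() or full_clean() automatically, so code that saved instances directly could bypass a clean()-only check; raising from save() enforces the rule on every write path and replaces the old do_not_save_patient_if_no_unique_identifier helper in csv_upload.py. A self-contained sketch of the pattern, using simplified stand-in fields rather than the real Patient model:

from django.core.exceptions import ValidationError
from django.db import models

class PatientSketch(models.Model):
    # Simplified stand-ins for the real model's identifier fields
    nhs_number = models.CharField(max_length=10, null=True, blank=True)
    unique_reference_number = models.CharField(max_length=50, null=True, blank=True)

    def save(self, **kwargs):
        # save() does not run clean(), so the invariant lives here to
        # cover every save path, not just form validation
        if not self.nhs_number and not self.unique_reference_number:
            raise ValidationError(
                "Either NHS Number or Unique Reference Number must be provided."
            )
        if self.nhs_number and self.unique_reference_number:
            raise ValidationError(
                "Only one of NHS Number or Unique Reference Number should be provided."
            )
        super().save(**kwargs)

As the new tests note, the invariant is enforced in Python rather than as a database constraint, so writes that skip save() (e.g. bulk_create) would not be checked.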
20 changes: 17 additions & 3 deletions project/npda/tests/factories/seed_users.py
@@ -51,9 +51,10 @@ def _seed_users_fixture(django_db_setup, django_db_blocker):

GOSH_PZ_CODE = "PZ196"
ALDER_HEY_PZ_CODE = "PZ074"
JERSEY_PZ_CODE = "PZ248"

logger.info(f"Seeding test users at {GOSH_PZ_CODE=} and {ALDER_HEY_PZ_CODE=}.")
# Seed a user of each type at GOSH
logger.info(f"Seeding test users at {GOSH_PZ_CODE=}, {ALDER_HEY_PZ_CODE=} and {JERSEY_PZ_CODE=}.")
# Seed a user of each type
for user in users:
first_name = user.role_str

@@ -95,9 +96,22 @@ def _seed_users_fixture(django_db_setup, django_db_blocker):
organisation_employers=[ALDER_HEY_PZ_CODE],
)

# Jersey user
new_user_jersey = NPDAUserFactory(
first_name=first_name,
role=user.role,
# Assign flags based on user role
is_active=is_active,
is_staff=is_staff,
is_rcpch_audit_team_member=is_rcpch_audit_team_member,
is_rcpch_staff=is_rcpch_staff,
groups=[user.group_name],
organisation_employers=[JERSEY_PZ_CODE],
)

logger.info(f"Seeded users: \n{new_user_gosh=} and \n{new_user_alder_hey=}")

assert NPDAUser.objects.count() == len(users) * 2
assert NPDAUser.objects.count() == len(users) * 3

@pytest.fixture(scope="session")
def seed_users_fixture(django_db_setup, django_db_blocker):
70 changes: 61 additions & 9 deletions project/npda/tests/test_csv_upload.py
@@ -136,13 +136,13 @@ def test_user(seed_groups_fixture, seed_users_fixture):
# The database is not rolled back if we used the built in async support for pytest
# https://github.com/pytest-dev/pytest-asyncio/issues/226
@async_to_sync
async def csv_upload_sync(user, dataframe):
async def csv_upload_sync(user, dataframe, pdu_pz_code=ALDER_HEY_PZ_CODE):
return await csv_upload(
user,
dataframe,
csv_file_name=None,
csv_file_bytes=None,
pdu_pz_code=ALDER_HEY_PZ_CODE,
pdu_pz_code=pdu_pz_code,
audit_year=2024,
)

@@ -219,7 +219,6 @@ def test_multiple_patients(
@pytest.mark.parametrize(
"column,model_field",
[
pytest.param("NHS Number", "nhs_number"),
pytest.param("Date of Birth", "date_of_birth"),
pytest.param("Diabetes Type", "diabetes_type"),
pytest.param("Date of Diabetes Diagnosis", "diagnosis_date"),
@@ -232,6 +231,7 @@ def test_missing_mandatory_field(
single_row_valid_df,
column,
model_field,
pdu_pz_code
):
# As these tests need full transaction support we can't use our session fixtures
test_user = NPDAUser.objects.filter(
@@ -251,7 +251,64 @@

assert model_field in errors[0]

# Catastrophic - we can't save this patient at all so we won't save any of the patients in the submission
# Catastrophic - we can't save this patient at all
assert Patient.objects.count() == 0


@pytest.mark.django_db
def test_missing_nhs_number(
seed_groups_per_function_fixture,
seed_users_per_function_fixture,
single_row_valid_df
):
# As these tests need full transaction support we can't use our session fixtures
test_user = NPDAUser.objects.filter(
organisation_employers__pz_code=ALDER_HEY_PZ_CODE
).first()

# Delete all patients to ensure we're starting from a clean slate
Patient.objects.all().delete()

single_row_valid_df.loc[0, "NHS Number"] = None

assert (
Patient.objects.count() == 0
), "There should be no patients in the database before the test"

errors = csv_upload_sync(test_user, single_row_valid_df)

assert "nhs_number" in errors[0]

# We shouldn't save this patient (invariant enforced in Patient.save not in the database)
assert Patient.objects.count() == 0


@pytest.mark.django_db
def test_missing_unique_reference_number(
seed_groups_per_function_fixture,
seed_users_per_function_fixture,
single_row_valid_df
):
# As these tests need full transaction support we can't use our session fixtures
test_user = NPDAUser.objects.filter(
organisation_employers__pz_code=ALDER_HEY_PZ_CODE
).first()

# Delete all patients to ensure we're starting from a clean slate
Patient.objects.all().delete()

df = single_row_valid_df.rename(columns={"NHS Number": "Unique Reference Number"})
df.loc[0, "Unique Reference Number"] = None

assert (
Patient.objects.count() == 0
), "There should be no patients in the database before the test"

errors = csv_upload_sync(test_user, df, "PZ248")

assert "unique_reference_number" in errors[0]

# We shouldn't save this patient (invariant enforced in Patient.save not in the database)
assert Patient.objects.count() == 0


@@ -366,11 +423,6 @@ def test_invalid_nhs_number(test_user, single_row_valid_df):
errors = csv_upload_sync(test_user, single_row_valid_df)
assert "nhs_number" in errors[0]

# Catastrophic - Patient not save
Review comment (Member Author): It's fine to save an invalid NHS number, it will be flagged in the data quality report and UI

assert Patient.objects.count() == 0

# TODO MRB: create a ValidationError model field (https://github.com/rcpch/national-paediatric-diabetes-audit/issues/332)


@pytest.mark.django_db
def test_future_date_of_birth(test_user, single_row_valid_df):
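The csv_upload_sync wrapper near the top of this test file exists because pytest's built-in async support does not roll the database back between tests (per the pytest-asyncio issue linked in the comment), so the coroutine is wrapped with asgiref's async_to_sync and driven from ordinary synchronous tests. A minimal sketch of that pattern, assuming pytest and asgiref are installed:

from asgiref.sync import async_to_sync

@async_to_sync
async def fetch_value() -> int:
    # Stand-in for an async function under test, such as csv_upload
    return 42

def test_fetch_value():
    # The decorated coroutine is now a plain synchronous callable, so
    # pytest's transactional fixtures behave as they do for sync tests
    assert fetch_value() == 42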