Skip to content

Commit 9cd16e3

Browse files
committed
test(unittest model): convert to polars (though does now have warnings about joblib backend)
tests/test_unittest_model.py: 21 warnings /home/amy/mambaforge/envs/template-des/lib/python3.13/site-packages/joblib/externals/loky/backend/fork_exec.py:38: DeprecationWarning: This process (pid=49667) is multi-threaded, use of fork() may lead to deadlocks in the child. pid = os.fork()
1 parent cc016ca commit 9cd16e3

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

tests/test_unittest_model.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from simulation.model import Defaults, Exponential, Model, Trial
1616
import numpy as np
17-
import pandas as pd
17+
from polars.testing import assert_frame_equal, assert_series_equal
18+
import polars as pl
1819
import pytest
1920

2021

@@ -129,11 +130,11 @@ def test_high_demand():
129130

130131
# Check that the final patient in the patient-level results is not seen
131132
# by a nurse.
132-
last_patient = results['patient'].iloc[-1]
133-
assert np.isnan(last_patient['q_time_nurse']), (
133+
last_patient = results['patient'].tail(1)
134+
assert np.isnan(last_patient.select('q_time_nurse')), (
134135
'Expect last patient in high demand scenario to have queue time NaN.'
135136
)
136-
assert np.isnan(last_patient['time_with_nurse']), (
137+
assert np.isnan(last_patient.select('time_with_nurse')), (
137138
'Expect last patient in high demand scenario to have NaN for time' +
138139
'with nurse.'
139140
)
@@ -194,26 +195,26 @@ def helper_warmup(warm_up_period):
194195
results_none = helper_warmup(warm_up_period=0)
195196

196197
# Extract result of first patient
197-
first_warmup = results_warmup['patient'].iloc[0]
198-
first_none = results_none['patient'].iloc[0]
198+
first_warmup = results_warmup['patient'].head(1)
199+
first_none = results_none['patient'].head(1)
199200

200201
# Check that model with warm-up has arrival time later than warm-up length
201202
# and queue time greater than 0
202-
assert first_warmup['arrival_time'] > 500, (
203+
assert first_warmup.select('arrival_time').item() > 500, (
203204
'Expect first patient to arrive after time 500 when model is run ' +
204205
f'with warm-up, but got {first_warmup["arrival_time"]}.'
205206
)
206-
assert first_warmup['q_time_nurse'] > 0, (
207+
assert first_warmup.select('q_time_nurse').item() > 0, (
207208
'Expect first patient to need to queue in model with warm-up and ' +
208209
f'high arrival rate, but got {first_warmup["q_time_nurse"]}.'
209210
)
210211

211212
# Check that model without warm-up has arrival and queue time of 0
212-
assert first_none['arrival_time'] == 0, (
213+
assert first_none.select('arrival_time').item() == 0, (
213214
'Expect first patient to arrive at time 0 when model is run ' +
214215
f'without warm-up, but got {first_none["arrival_time"]}.'
215216
)
216-
assert first_none['q_time_nurse'] == 0, (
217+
assert first_none.select('q_time_nurse').item() == 0, (
217218
'Expect first patient to have no wait time in model without warm-up ' +
218219
f'but got {first_none["q_time_nurse"]}.'
219220
)
@@ -228,14 +229,13 @@ def test_arrivals():
228229
trial.run_trial()
229230

230231
# Get count of patients from patient-level and trial-level results
231-
patient_n = trial.patient_results_df.groupby('run')['patient_id'].count()
232+
patient_n = trial.patient_results_df.group_by('run').agg(
233+
pl.col('patient_id').count()).get_column('patient_id')
232234
trial_n = trial.trial_results_df['arrivals']
233235

234236
# Compare the counts from each run
235-
assert all(patient_n == trial_n), (
236-
'The number of arrivals in the trial-level results should be ' +
237-
'consistent with the number of patients in the patient-level results.'
238-
)
237+
assert_series_equal(
238+
patient_n, trial_n, check_dtypes=False, check_names=False)
239239

240240

241241
@pytest.mark.parametrize('param_name, initial_value, adjusted_value', [
@@ -349,7 +349,7 @@ def test_seed_stability():
349349
result2 = trial2.run_single(run=33)
350350

351351
# Check that dataframes with patient-level results are equal
352-
pd.testing.assert_frame_equal(result1['patient'], result2['patient'])
352+
assert_frame_equal(result1['patient'], result2['patient'])
353353

354354

355355
def test_interval_audit_time():
@@ -421,8 +421,8 @@ def test_parallel():
421421
results[mode] = trial.run_single(run=0)
422422

423423
# Verify results are identical
424-
pd.testing.assert_frame_equal(
424+
assert_frame_equal(
425425
results['seq']['patient'], results['par']['patient'])
426-
pd.testing.assert_frame_equal(
426+
assert_frame_equal(
427427
results['seq']['interval_audit'], results['par']['interval_audit'])
428428
assert results['seq']['trial'] == results['par']['trial']

0 commit comments

Comments
 (0)