Skip to content

Commit cc016ca

Browse files
committed
test(backtest): convert to polars and resolve dtype differences from save/import
1 parent 8907d48 commit cc016ca

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

tests/test_backtest.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
"""
1414

1515
from simulation.model import Defaults, Trial
16-
import pandas as pd
1716
from pathlib import Path
17+
import polars as pl
18+
from polars.testing import assert_frame_equal
1819

1920

2021
def test_reproduction():
@@ -39,21 +40,23 @@ def test_reproduction():
3940
trial.run_trial()
4041

4142
# Compare patient-level results
42-
exp_patient = pd.read_csv(
43-
Path(__file__).parent.joinpath('exp_results/patient.csv'))
44-
pd.testing.assert_frame_equal(trial.patient_results_df, exp_patient)
43+
exp_patient = pl.read_csv(
44+
Path(__file__).parent.joinpath('exp_results/patient.csv')).cast({
45+
'run': pl.Int32})
46+
assert_frame_equal(trial.patient_results_df, exp_patient)
4547

4648
# Compare trial-level results
47-
exp_trial = pd.read_csv(
49+
exp_trial = pl.read_csv(
4850
Path(__file__).parent.joinpath('exp_results/trial.csv'))
49-
pd.testing.assert_frame_equal(trial.trial_results_df, exp_trial)
51+
assert_frame_equal(trial.trial_results_df, exp_trial)
5052

5153
# Compare interval audit results
52-
exp_interval = pd.read_csv(
53-
Path(__file__).parent.joinpath('exp_results/interval.csv'))
54-
pd.testing.assert_frame_equal(trial.interval_audit_df, exp_interval)
54+
exp_interval = pl.read_csv(
55+
Path(__file__).parent.joinpath('exp_results/interval.csv')).cast({
56+
'run': pl.Int32})
57+
assert_frame_equal(trial.interval_audit_df, exp_interval)
5558

5659
# Compare overall results
57-
exp_overall = pd.read_csv(
58-
Path(__file__).parent.joinpath('exp_results/overall.csv'), index_col=0)
59-
pd.testing.assert_frame_equal(trial.overall_results_df, exp_overall)
60+
exp_overall = pl.read_csv(
61+
Path(__file__).parent.joinpath('exp_results/overall.csv'))
62+
assert_frame_equal(trial.overall_results_df, exp_overall)

0 commit comments

Comments
 (0)