13
13
"""
14
14
15
15
from simulation .model import Defaults , Trial
16
- import pandas as pd
17
16
from pathlib import Path
17
+ import polars as pl
18
+ from polars .testing import assert_frame_equal
18
19
19
20
20
21
def test_reproduction ():
@@ -39,21 +40,23 @@ def test_reproduction():
39
40
trial .run_trial ()
40
41
41
42
# Compare patient-level results
42
- exp_patient = pd .read_csv (
43
- Path (__file__ ).parent .joinpath ('exp_results/patient.csv' ))
44
- pd .testing .assert_frame_equal (trial .patient_results_df , exp_patient )
43
+ exp_patient = pl .read_csv (
44
+ Path (__file__ ).parent .joinpath ('exp_results/patient.csv' )).cast ({
45
+ 'run' : pl .Int32 })
46
+ assert_frame_equal (trial .patient_results_df , exp_patient )
45
47
46
48
# Compare trial-level results
47
- exp_trial = pd .read_csv (
49
+ exp_trial = pl .read_csv (
48
50
Path (__file__ ).parent .joinpath ('exp_results/trial.csv' ))
49
- pd . testing . assert_frame_equal (trial .trial_results_df , exp_trial )
51
+ assert_frame_equal (trial .trial_results_df , exp_trial )
50
52
51
53
# Compare interval audit results
52
- exp_interval = pd .read_csv (
53
- Path (__file__ ).parent .joinpath ('exp_results/interval.csv' ))
54
- pd .testing .assert_frame_equal (trial .interval_audit_df , exp_interval )
54
+ exp_interval = pl .read_csv (
55
+ Path (__file__ ).parent .joinpath ('exp_results/interval.csv' )).cast ({
56
+ 'run' : pl .Int32 })
57
+ assert_frame_equal (trial .interval_audit_df , exp_interval )
55
58
56
59
# Compare overall results
57
- exp_overall = pd .read_csv (
58
- Path (__file__ ).parent .joinpath ('exp_results/overall.csv' ), index_col = 0 )
59
- pd . testing . assert_frame_equal (trial .overall_results_df , exp_overall )
60
+ exp_overall = pl .read_csv (
61
+ Path (__file__ ).parent .joinpath ('exp_results/overall.csv' ))
62
+ assert_frame_equal (trial .overall_results_df , exp_overall )
0 commit comments