20
20
Tests for the provenance stored in the output tree sequences.
21
21
"""
22
22
import json
23
+ import math
24
+ import time
23
25
24
26
import pytest
25
27
import tskit
@@ -68,6 +70,51 @@ def test_ancestors_file(self, small_sd_fixture):
68
70
self .validate_file (ancestor_data )
69
71
70
72
73
+ class TestResourceMetrics :
74
+ """
75
+ Tests for the ResourceMetrics dataclass.
76
+ """
77
+
78
+ def test_create_and_asdict (self ):
79
+ metrics = provenance .ResourceMetrics (
80
+ elapsed_time = 1.5 , user_time = 1.0 , sys_time = 0.5 , max_memory = 1000
81
+ )
82
+ d = metrics .asdict ()
83
+ assert d == {
84
+ "elapsed_time" : 1.5 ,
85
+ "user_time" : 1.0 ,
86
+ "sys_time" : 0.5 ,
87
+ "max_memory" : 1000 ,
88
+ }
89
+
90
+ def test_combine_metrics (self ):
91
+ m1 = provenance .ResourceMetrics (
92
+ elapsed_time = 1.0 , user_time = 0.5 , sys_time = 0.2 , max_memory = 1000
93
+ )
94
+ m2 = provenance .ResourceMetrics (
95
+ elapsed_time = 2.0 , user_time = 1.5 , sys_time = 0.3 , max_memory = 2000
96
+ )
97
+ combined = provenance .ResourceMetrics .combine ([m1 , m2 ])
98
+ assert combined .elapsed_time == 3.0
99
+ assert combined .user_time == 2.0
100
+ assert combined .sys_time == 0.5
101
+ assert combined .max_memory == 2000
102
+
103
+ def test_combine_empty_list (self ):
104
+ with pytest .raises (ValueError ):
105
+ provenance .ResourceMetrics .combine ([])
106
+
107
+ def test_combine_single_metric (self ):
108
+ m = provenance .ResourceMetrics (
109
+ elapsed_time = 1.0 , user_time = 0.5 , sys_time = 0.2 , max_memory = 1000
110
+ )
111
+ combined = provenance .ResourceMetrics .combine ([m ])
112
+ assert combined .elapsed_time == 1.0
113
+ assert combined .user_time == 0.5
114
+ assert combined .sys_time == 0.2
115
+ assert combined .max_memory == 1000
116
+
117
+
71
118
class TestIncludeProvenance :
72
119
"""
73
120
Test that we can include or exclude provenances
@@ -124,6 +171,7 @@ def test_provenance_infer(self, small_sd_fixture, mmr, pc, post, precision):
124
171
assert params ["mismatch_ratio" ] == mmr
125
172
assert params ["path_compression" ] == pc
126
173
assert "simplify" not in params
174
+ assert "resources" in record
127
175
128
176
def test_provenance_generate_ancestors (self , small_sd_fixture ):
129
177
ancestors = tsinfer .generate_ancestors (small_sd_fixture )
@@ -132,6 +180,7 @@ def test_provenance_generate_ancestors(self, small_sd_fixture):
132
180
timestamp , record = p
133
181
params = record ["parameters" ]
134
182
assert params ["command" ] == "generate_ancestors"
183
+ assert "resources" in record
135
184
136
185
@pytest .mark .parametrize ("mmr" , [None , 0.1 ])
137
186
@pytest .mark .parametrize ("pc" , [True , False ])
@@ -154,6 +203,9 @@ def test_provenance_match_ancestors(self, small_sd_fixture, mmr, pc, precision):
154
203
assert params ["mismatch_ratio" ] == mmr
155
204
assert params ["path_compression" ] == pc
156
205
assert params ["precision" ] == precision
206
+ for provenance_index in [- 2 , - 1 ]:
207
+ record = json .loads (anc_ts .provenance (provenance_index ).record )
208
+ assert "resources" in record
157
209
158
210
@pytest .mark .parametrize ("mmr" , [None , 0.1 ])
159
211
@pytest .mark .parametrize ("pc" , [True , False ])
@@ -183,6 +235,9 @@ def test_provenance_match_samples(self, small_sd_fixture, mmr, pc, precision, po
183
235
assert params ["precision" ] == precision
184
236
assert params ["post_process" ] == post
185
237
assert "simplify" not in params # deprecated
238
+ for provenance_index in [- 3 , - 2 , - 1 ]:
239
+ record = json .loads (ts .provenance (provenance_index ).record )
240
+ assert "resources" in record
186
241
187
242
@pytest .mark .parametrize ("simp" , [True , False ])
188
243
def test_deprecated_simplify (self , small_sd_fixture , simp ):
@@ -207,15 +262,46 @@ def test_no_command(self):
207
262
with pytest .raises (ValueError ):
208
263
provenance .get_provenance_dict ()
209
264
210
- def validate_encoding (self , params ):
211
- pdict = provenance .get_provenance_dict ("test" , ** params )
265
+ def validate_encoding (self , params , resources = None ):
266
+ pdict = provenance .get_provenance_dict ("test" , resources = resources , ** params )
212
267
encoded = pdict ["parameters" ]
213
268
assert encoded ["command" ] == "test"
214
269
del encoded ["command" ]
215
270
assert encoded == params
271
+ if resources is not None :
272
+ assert "resources" in pdict
273
+ assert pdict ["resources" ] == resources
274
+ else :
275
+ assert "resources" not in pdict
216
276
217
277
def test_empty_params (self ):
218
278
self .validate_encoding ({})
219
279
220
280
def test_non_empty_params (self ):
221
281
self .validate_encoding ({"a" : 1 , "b" : "b" , "c" : 12345 })
282
+
283
+ def test_with_resources (self ):
284
+ self .validate_encoding (
285
+ {}, resources = {"elapsed_time" : 1.23 , "max_memory" : 567.89 }
286
+ )
287
+
288
+
289
+ def test_timing_and_memory_context_manager ():
290
+ with provenance .TimingAndMemory () as timing :
291
+ # Do some work to ensure measurable changes
292
+ time .sleep (0.1 )
293
+ for i in range (1000000 ):
294
+ math .sqrt (i )
295
+ _ = [0 ] * 1000000
296
+
297
+ assert timing .metrics is not None
298
+ assert timing .metrics .elapsed_time > 0.1
299
+ # Check we have highres timing
300
+ assert timing .metrics .elapsed_time < 1
301
+ assert timing .metrics .user_time > 0
302
+ assert timing .metrics .sys_time >= 0
303
+ assert timing .metrics .max_memory > 100_000_000
304
+
305
+ # Test metrics are not available during context
306
+ with provenance .TimingAndMemory () as timing2 :
307
+ assert timing2 .metrics is None
0 commit comments