@@ -106,7 +106,7 @@ impl<H: Hashable, C: Clone + Ord + core::fmt::Debug, const DEPTH: u8> CompleteTr
106
106
self . mark ( ) ;
107
107
}
108
108
Retention :: Checkpoint { id, marking } => {
109
- let latest_checkpoint = self . checkpoints . keys ( ) . rev ( ) . next ( ) ;
109
+ let latest_checkpoint = self . checkpoints . keys ( ) . next_back ( ) ;
110
110
if Some ( & id) > latest_checkpoint {
111
111
append ( & mut self . leaves , value, DEPTH ) ?;
112
112
if marking == Marking :: Marked {
@@ -145,7 +145,7 @@ impl<H: Hashable, C: Clone + Ord + core::fmt::Debug, const DEPTH: u8> CompleteTr
145
145
if !self . marks . contains ( & pos) {
146
146
self . marks . insert ( pos) ;
147
147
148
- if let Some ( checkpoint) = self . checkpoints . values_mut ( ) . rev ( ) . next ( ) {
148
+ if let Some ( checkpoint) = self . checkpoints . values_mut ( ) . next_back ( ) {
149
149
checkpoint. marked . insert ( pos) ;
150
150
}
151
151
}
@@ -156,11 +156,14 @@ impl<H: Hashable, C: Clone + Ord + core::fmt::Debug, const DEPTH: u8> CompleteTr
156
156
}
157
157
}
158
158
159
+ // Creates a new checkpoint with the specified identifier and the given tree position; if `pos`
160
+ // is not provided, the position of the most recently appended leaf is used, or a new
161
+ // checkpoint of the empty tree is added if appropriate.
159
162
fn checkpoint ( & mut self , id : C , pos : Option < Position > ) {
160
163
self . checkpoints . insert (
161
164
id,
162
165
Checkpoint :: at_length ( pos. map_or_else (
163
- || 0 ,
166
+ || self . leaves . len ( ) ,
164
167
|p| usize:: try_from ( p) . expect ( MAX_COMPLETE_SIZE_ERROR ) + 1 ,
165
168
) ) ,
166
169
) ;
@@ -170,16 +173,12 @@ impl<H: Hashable, C: Clone + Ord + core::fmt::Debug, const DEPTH: u8> CompleteTr
170
173
}
171
174
172
175
fn leaves_at_checkpoint_depth ( & self , checkpoint_depth : usize ) -> Option < usize > {
173
- if checkpoint_depth == 0 {
174
- Some ( self . leaves . len ( ) )
175
- } else {
176
- self . checkpoints
177
- . iter ( )
178
- . rev ( )
179
- . skip ( checkpoint_depth - 1 )
180
- . map ( |( _, c) | c. leaves_len )
181
- . next ( )
182
- }
176
+ self . checkpoints
177
+ . iter ( )
178
+ . rev ( )
179
+ . skip ( checkpoint_depth)
180
+ . map ( |( _, c) | c. leaves_len )
181
+ . next ( )
183
182
}
184
183
185
184
/// Removes the oldest checkpoint. Returns true if successful and false if
@@ -237,21 +236,20 @@ impl<H: Hashable + PartialEq + Clone, C: Ord + Clone + core::fmt::Debug, const D
237
236
}
238
237
}
239
238
240
- fn root ( & self , checkpoint_depth : usize ) -> Option < H > {
241
- self . leaves_at_checkpoint_depth ( checkpoint_depth)
242
- . and_then ( |len| root ( & self . leaves [ 0 ..len] , DEPTH ) )
239
+ fn root ( & self , checkpoint_depth : Option < usize > ) -> Option < H > {
240
+ checkpoint_depth. map_or_else (
241
+ || root ( & self . leaves [ ..] , DEPTH ) ,
242
+ |depth| {
243
+ self . leaves_at_checkpoint_depth ( depth)
244
+ . and_then ( |len| root ( & self . leaves [ 0 ..len] , DEPTH ) )
245
+ } ,
246
+ )
243
247
}
244
248
245
249
fn witness ( & self , position : Position , checkpoint_depth : usize ) -> Option < Vec < H > > {
246
- if self . marks . contains ( & position) && checkpoint_depth <= self . checkpoints . len ( ) {
250
+ if self . marks . contains ( & position) {
247
251
let leaves_len = self . leaves_at_checkpoint_depth ( checkpoint_depth) ?;
248
- let c_idx = self . checkpoints . len ( ) - checkpoint_depth;
249
- if self
250
- . checkpoints
251
- . iter ( )
252
- . skip ( c_idx)
253
- . any ( |( _, c) | c. marked . contains ( & position) )
254
- {
252
+ if u64:: from ( position) >= u64:: try_from ( leaves_len) . unwrap ( ) {
255
253
// The requested position was marked after the checkpoint was created, so we
256
254
// cannot create a witness.
257
255
None
@@ -279,7 +277,7 @@ impl<H: Hashable + PartialEq + Clone, C: Ord + Clone + core::fmt::Debug, const D
279
277
280
278
fn remove_mark ( & mut self , position : Position ) -> bool {
281
279
if self . marks . contains ( & position) {
282
- if let Some ( c) = self . checkpoints . values_mut ( ) . rev ( ) . next ( ) {
280
+ if let Some ( c) = self . checkpoints . values_mut ( ) . next_back ( ) {
283
281
c. forgotten . insert ( position) ;
284
282
} else {
285
283
self . marks . remove ( & position) ;
@@ -291,22 +289,43 @@ impl<H: Hashable + PartialEq + Clone, C: Ord + Clone + core::fmt::Debug, const D
291
289
}
292
290
293
291
fn checkpoint ( & mut self , id : C ) -> bool {
294
- if Some ( & id) > self . checkpoints . keys ( ) . rev ( ) . next ( ) {
292
+ if Some ( & id) > self . checkpoints . keys ( ) . next_back ( ) {
295
293
Self :: checkpoint ( self , id, self . current_position ( ) ) ;
296
294
true
297
295
} else {
298
296
false
299
297
}
300
298
}
301
299
302
- fn rewind ( & mut self ) -> bool {
303
- if let Some ( ( id, c) ) = self . checkpoints . iter ( ) . rev ( ) . next ( ) {
304
- self . leaves . truncate ( c. leaves_len ) ;
305
- for pos in c. marked . iter ( ) {
306
- self . marks . remove ( pos) ;
300
+ fn checkpoint_count ( & self ) -> usize {
301
+ self . checkpoints . len ( )
302
+ }
303
+
304
+ fn rewind ( & mut self , depth : usize ) -> bool {
305
+ if self . checkpoints . len ( ) > depth {
306
+ let mut to_delete = vec ! [ ] ;
307
+ for ( idx, ( id, c) ) in self
308
+ . checkpoints
309
+ . iter_mut ( )
310
+ . rev ( )
311
+ . enumerate ( )
312
+ . take ( depth + 1 )
313
+ {
314
+ for pos in c. marked . iter ( ) {
315
+ self . marks . remove ( pos) ;
316
+ }
317
+ if idx < depth {
318
+ to_delete. push ( id. clone ( ) ) ;
319
+ } else {
320
+ self . leaves . truncate ( c. leaves_len ) ;
321
+ c. marked . clear ( ) ;
322
+ c. forgotten . clear ( ) ;
323
+ }
307
324
}
308
- let id = id. clone ( ) ; // needed to avoid mutable/immutable borrow conflict
309
- self . checkpoints . remove ( & id) ;
325
+ for cid in to_delete. iter ( ) {
326
+ self . checkpoints . remove ( cid) ;
327
+ }
328
+
310
329
true
311
330
} else {
312
331
false
@@ -316,8 +335,6 @@ impl<H: Hashable + PartialEq + Clone, C: Ord + Clone + core::fmt::Debug, const D
316
335
317
336
#[ cfg( test) ]
318
337
mod tests {
319
- use std:: convert:: TryFrom ;
320
-
321
338
use super :: CompleteTree ;
322
339
use crate :: {
323
340
check_append, check_checkpoint_rewind, check_rewind_remove_mark, check_root_hashes,
@@ -334,7 +351,7 @@ mod tests {
334
351
}
335
352
336
353
let tree = CompleteTree :: < SipHashable , ( ) , DEPTH > :: new ( 100 ) ;
337
- assert_eq ! ( tree. root( 0 ) . unwrap ( ) , expected) ;
354
+ assert_eq ! ( tree. root( None ) , Some ( expected) ) ;
338
355
}
339
356
340
357
#[ test]
@@ -362,7 +379,7 @@ mod tests {
362
379
) ,
363
380
) ;
364
381
365
- assert_eq ! ( tree. root( 0 ) . unwrap ( ) , expected) ;
382
+ assert_eq ! ( tree. root( None ) , Some ( expected) ) ;
366
383
}
367
384
368
385
#[ test]
@@ -408,10 +425,12 @@ mod tests {
408
425
) ,
409
426
) ;
410
427
411
- assert_eq ! ( tree. root( 0 ) . unwrap( ) , expected) ;
428
+ assert_eq ! ( tree. root( None ) , Some ( expected. clone( ) ) ) ;
429
+ tree. checkpoint ( ( ) , None ) ;
430
+ assert_eq ! ( tree. root( Some ( 0 ) ) , Some ( expected. clone( ) ) ) ;
412
431
413
432
for i in 0u64 ..( 1 << DEPTH ) {
414
- let position = Position :: try_from ( i ) . unwrap ( ) ;
433
+ let position = Position :: from ( i ) ;
415
434
let path = tree. witness ( position, 0 ) . unwrap ( ) ;
416
435
assert_eq ! (
417
436
compute_root_from_witness( SipHashable ( i) , position, & path) ,
0 commit comments