From 4783c09278d11c5bf5f2629010fe488a2724f3e9 Mon Sep 17 00:00:00 2001 From: seebees Date: Tue, 22 Oct 2024 12:53:50 -0700 Subject: [PATCH 01/10] feat: Optimize sort by `Below` The `Below` function is in the hot path. The slice (x[1..]) operation is not optimized in Dafny. This optimizes this function by turning the recursive slice into a loop over an index into the seq. Further, a bounded integer version is also included. --- .../StructuredEncryption/src/SortCanon.dfy | 80 ++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy index 1033d899e..f9c977cbf 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy @@ -148,11 +148,89 @@ module SortCanon { } } - predicate method Below(x: seq, y: seq) { + predicate Below(x: seq, y: seq) { |x| != 0 ==> && |y| != 0 && x[0] <= y[0] && (x[0] == y[0] ==> Below(x[1..], y[1..])) + } by method { + + // The slice x[1..], y[1..] are un-optimized operations in Dafny. + // This means that their usage will result in a lot of data copying. + // Additionally, it is very likely that the size of these sequences + // will be less than uint64. + // So writing an optimized version that only works on bounded types + // should further optimize this hot code. + + if HasUint64Len(x) && HasUint64Len(y) { + return BoundedBelow(x,y); + } + + if |x| == 0 { + assert Below(x, y); + return true; + } + + if |y| == 0 { + assert !Below(x, y); + return false; + } + + for i := 0 to |x| + invariant i <= |y| + // The function on the initial arguments + // is equal to function applied to the intermediate arguments. 
+ invariant Below(x, y) == Below(x[i..], y[i..]) + { + if |y| <= i { + return false; + } else if y[i] < x[i] { + return false; + } else if x[i] < y[i] { + return true; + } else { + assert x[i] == y[i]; + } + } + + return true; + } + + predicate BoundedBelow(x: seq64, y: seq64) + { + Below(x,y) + } by method { + var xLength := |x| as uint64; + var yLength := |y| as uint64; + + if xLength == 0 { + assert BoundedBelow(x, y); + return true; + } + + if yLength == 0 { + assert !BoundedBelow(x, y); + return false; + } + + for i := 0 to xLength + invariant i <= yLength + // The function on the initial arguments + // is equal to function applied to the intermediate arguments. + invariant BoundedBelow(x, y) == BoundedBelow(x[i..], y[i..]) + { + if yLength <= i { + return false; + } else if y[i] < x[i] { + return false; + } else if x[i] < y[i] { + return true; + } else { + assert x[i] == y[i]; + } + } + + return true; } lemma BelowIsTotal() From b4a071aac043179d4beeef22e57c3096df0452c7 Mon Sep 17 00:00:00 2001 From: seebees Date: Sat, 26 Oct 2024 06:43:00 -0700 Subject: [PATCH 02/10] Compatibility things --- .../src/DynamoDbEncryptionBranchKeyIdSupplier.dfy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DynamoDbEncryption/dafny/DynamoDbEncryption/src/DynamoDbEncryptionBranchKeyIdSupplier.dfy b/DynamoDbEncryption/dafny/DynamoDbEncryption/src/DynamoDbEncryptionBranchKeyIdSupplier.dfy index 4533b8def..c79ca3294 100644 --- a/DynamoDbEncryption/dafny/DynamoDbEncryption/src/DynamoDbEncryptionBranchKeyIdSupplier.dfy +++ b/DynamoDbEncryption/dafny/DynamoDbEncryption/src/DynamoDbEncryptionBranchKeyIdSupplier.dfy @@ -70,7 +70,7 @@ module DynamoDbEncryptionBranchKeyIdSupplier { // We expect this interface to be implemented in the native language, // so any errors thrown by the native implementation will appear as Opaque errors if err.Opaque? 
then - MPL.Opaque(obj:=err.obj) + MPL.Opaque(obj:=err.obj, alt_text:="") else MPL.AwsCryptographicMaterialProvidersException(message:="Unexpected error while getting Branch Key ID.") } From ea794cc2b84497c3d5781b66ebd988c6f319822f Mon Sep 17 00:00:00 2001 From: seebees Date: Thu, 31 Oct 2024 09:32:52 -0700 Subject: [PATCH 03/10] feat: Optimize sort by `MergeSort` The `MergeSort` function is in the hot path. The slice (x[1..]) operation is not optimized in Dafny. This optimizes this function by turning the recursive slice into a loop over an index into the seq. Further, a bounded integer version is also included. It also limits the total amount of data copied. --- .../src/OptimizedMergeSort.dfy | 515 ++++++++++++++++++ .../StructuredEncryption/src/SortCanon.dfy | 16 +- 2 files changed, 529 insertions(+), 2 deletions(-) create mode 100644 DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy new file mode 100644 index 000000000..a9b3d12a5 --- /dev/null +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy @@ -0,0 +1,515 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +include "../Model/AwsCryptographyDbEncryptionSdkStructuredEncryptionTypes.dfy" + +module {:options "-functionSyntax:4"} OptimizedMergeSort { + import Seq.MergeSort + import Relations = MergeSort.Relations + import opened StandardLibrary.UInt + + // The Seq.MergeSort function implemented as implemented + // does not compile to an optimal implementation + // in any of the Dafny target languages. + // This implementation aims to be significantly more optimal. + // First, it minimizes copies. + // It does this by making 2 arrays of the original sequence + // and then using these 2 as left and right alternatively. 
+ // This can be audited by verifying + // that the arrays are only sliced into a seq in `FastMergeSort`. + // All other slicing is done in ghost state. + // Second, is has a bounded number implementation + // that avoids using `nat`. + + function {:isolate_assertions} FastMergeSort(s: seq, lessThanOrEq: (T, T) -> bool): (result :seq) + requires Relations.TotalOrdering(lessThanOrEq) + { + MergeSort.MergeSortBy(s, lessThanOrEq) + } + by method { + if |s| <= 1 { + return s; + } else { + + // The slice x[1..], y[1..] are un-optimized operations in Dafny. + // This means that their usage will result in a lot of data copying. + // Additional, it is very likely that these size of these sequences + // will be less than uint64. + // So writing an optimized version that only works on bounded types + // should further optimized this hot code. + + var left := new T[|s|](i requires 0 <= i < |s| => s[i]); + var right := new T[|s|](i requires 0 <= i < |s| => s[i]); + var lo, hi := 0, right.Length; + + label BEFORE_WORK: + + if HasUint64Len(s) { + var boundedLo: uint64, boundedHi: uint64 := 0, right.Length as uint64; + ghost var _ := BoundedMergeSortMethod(left, right, lessThanOrEq, boundedLo, boundedHi, Right); + + result := right[..]; + } else { + ghost var _ := MergeSortMethod(left, right, lessThanOrEq, lo, hi, Right); + + result := right[..]; + } + + ghost var other := MergeSort.MergeSortBy(s, lessThanOrEq); + + assert Relations.SortedBy(right[..], lessThanOrEq) by { + assert right[..] == right[lo..hi]; + } + assert multiset(right[..]) == multiset(other) by { + calc { + multiset(right[..]); + == {assert right[..] == right[lo..hi];} + multiset(right[lo..hi]); + == + multiset(old@BEFORE_WORK(left[lo..hi])); + == {assert old@BEFORE_WORK(left[lo..hi]) == s;} + multiset(s); + == + multiset(other); + } + } + + // Implementing a by method can be complicated. + // Because methods can have non-determinism, + // where functions can not. 
+ // This means that Dafny normally wants to know + // that the method and function maintain equality at every step. + // But this is hard for this kind of optimized sorting. + // Because what is the functional state at every step + // and how does it correspond the state of 2 optimized arrays? + // This lemma works around this + // by proving that the outcomes are always deterministic and the same. + // It does this by proving that given a total ordering, + // there is one and only one way to sort a given sequence. + TotalOrderingImpliesSortingIsUnique(right[..], other, lessThanOrEq); + } + } + + datatype PlaceResults = Left | Right | Either + type ResultPlacement = r: PlaceResults | !r.Either? witness * + + method {:isolate_assertions} MergeSortMethod( + left: array, + right: array, + lessThanOrEq: (T, T) -> bool, + lo: nat, + hi: nat, + where: PlaceResults + ) + returns (resultPlacement: ResultPlacement) + requires Relations.TotalOrdering(lessThanOrEq) + requires lo < hi <= left.Length + requires hi <= right.Length && left != right + // reads left, right + modifies left, right + ensures !where.Either? ==> where == resultPlacement + + // We do not modify anything before lo + ensures left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) + // We do not modify anything above hi + ensures left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]) + + ensures multiset(left[lo..hi]) == multiset(old(left[lo..hi])) + ensures resultPlacement.Left? ==> Relations.SortedBy(left[lo..hi], lessThanOrEq) + ensures resultPlacement.Right? ==> Relations.SortedBy(right[lo..hi], lessThanOrEq) + ensures resultPlacement.Right? ==> multiset(right[lo..hi]) == multiset(old(left[lo..hi])) + + decreases hi - lo + { + if hi - lo == 1 { + if where == Right { + right[lo] := left[lo]; + return Right; + } else { + return Left; + } + } + + ghost var beforeWork := multiset(left[lo..hi]); + var mid := (lo + hi) / 2; + var placement? 
:= MergeSortMethod(left, right, lessThanOrEq, lo, mid, Either); + assert left[mid..hi] == old(left[mid..hi]); + ghost var placement2? := MergeSortMethod(left, right, lessThanOrEq, mid, hi, placement?); + assert placement2? == placement?; + + ghost var preMergeResult := if placement?.Left? then left else right; + calc { + multiset(preMergeResult[lo..hi]); + == { assert preMergeResult[lo..hi] == preMergeResult[lo..mid] + preMergeResult[mid..hi]; } + multiset(preMergeResult[lo..mid] + preMergeResult[mid..hi]); + == + multiset(old(left[lo..mid]) + old(left[mid..hi])); + == { assert old(left[lo..hi]) == old(left[lo..mid]) + old(left[mid..hi]); } + beforeWork; + } + + ghost var mergedResult; + if placement?.Left? { + MergeIntoRight(left := left, right := right, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); + resultPlacement := Right; + + mergedResult := right[lo..hi]; + assert left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]); + } else { + assert placement?.Right?; + MergeIntoRight(left := right, right := left, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); + resultPlacement := Left; + + mergedResult := left[lo..hi]; + assert left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]); + } + + label BEFORE_RETURN: + assert left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]); + if resultPlacement.Left? && where == Right { + forall i | lo <= i < hi { + right[i] := left[i]; + } + + assert right[lo..hi] == mergedResult; + assert left[..] == old@BEFORE_RETURN(left[..]); + assert right[..lo] == old(right[..lo]); + + resultPlacement := Right; + } + if resultPlacement.Right? && where == Left { + forall i | lo <= i < hi { + left[i] := right[i]; + } + + assert left[lo..hi] == mergedResult; + assert right[..] 
== old@BEFORE_RETURN(right[..]); + assert left[..lo] == old(left[..lo]); + + resultPlacement := Left; + } + } + + method {:isolate_assertions} MergeIntoRight( + nameonly left: array, + nameonly right: array, + nameonly lessThanOrEq: (T, T) -> bool, + nameonly lo: nat, + nameonly mid: nat, + nameonly hi: nat + ) + requires Relations.TotalOrdering(lessThanOrEq) + requires lo <= mid <= hi <= left.Length + requires hi <= right.Length && left != right + // We store "left" in [lo..mid] + requires Relations.SortedBy(left[lo..mid], lessThanOrEq) + // We store "right" in [mid..hi] + requires Relations.SortedBy(left[mid..hi], lessThanOrEq) + // reads left, right + modifies right + // We do not modify anything before lo + ensures right[..lo] == old(right[..lo]) && left[..lo] == old(left[..lo]) + // We do not modify anything above hi + ensures right[hi..] == old(right[hi..]) && left[..lo] == old(left[..lo]) + ensures Relations.SortedBy(right[lo..hi], lessThanOrEq) + ensures multiset(right[lo..hi]) == multiset(old(left[lo..hi])) + ensures multiset(left[lo..hi]) == multiset(old(left[lo..hi])) + { + var leftPosition, rightPosition, iter := lo, mid, lo; + while iter < hi + modifies right + + invariant lo <= leftPosition <= mid <= rightPosition <= hi + invariant leftPosition - lo + rightPosition - mid == iter - lo + invariant right[..lo] == old(right[..lo]) + invariant right[hi..] 
== old(right[hi..]) + + invariant Relations.SortedBy(left[leftPosition..mid], lessThanOrEq) + invariant Relations.SortedBy(left[rightPosition..hi], lessThanOrEq) + invariant Below(right[lo..iter], left[leftPosition..mid], lessThanOrEq) + invariant Below(right[lo..iter], left[rightPosition..hi], lessThanOrEq) + invariant Relations.SortedBy(right[lo..iter], lessThanOrEq) + invariant multiset(right[lo..iter]) == multiset(left[lo..leftPosition]) + multiset(left[mid..rightPosition]) + { + if leftPosition == mid || (rightPosition < hi && lessThanOrEq(left[rightPosition], left[leftPosition])) { + right[iter] := left[rightPosition]; + + PushStillSortedBy(right, lo, iter, lessThanOrEq); + rightPosition, iter := rightPosition + 1, iter + 1; + + BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); + BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); + } else { + right[iter] := left[leftPosition]; + + PushStillSortedBy(right, lo, iter, lessThanOrEq); + leftPosition, iter := leftPosition + 1, iter + 1; + + BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); + BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); + } + } + assert multiset(right[lo..hi]) == multiset(old(left[lo..hi])) by { + assert leftPosition == mid && rightPosition == hi; + assert old(left[lo..hi]) == left[lo..hi] == left[lo..mid] + left[mid..hi]; + } + } + + // Helpers to prove MergeSort + + ghost predicate Below(a: seq, b: seq, lessThanOrEq: (T, T) -> bool) + requires Relations.TotalOrdering(lessThanOrEq) + { + forall i, j :: 0 <= i < |a| && 0 <= j < |b| ==> lessThanOrEq(a[i], b[j]) + } + + lemma BelowIsTransitive(a: seq, b: seq, lessThanOrEq: (T, T) -> bool) + requires Relations.TotalOrdering(lessThanOrEq) + requires Relations.SortedBy(a, lessThanOrEq) + requires Relations.SortedBy(b, lessThanOrEq) + requires 0 < |a| && 0 < |b| ==> lessThanOrEq(a[|a| - 1], b[0]) + ensures Below(a, b, lessThanOrEq) + {} + + lemma 
PushStillSortedBy(a: array, lo:nat, i: nat, lessThanOrEq: (T, T) -> bool) + requires 0 <= lo <= i < a.Length + requires Relations.SortedBy(a[lo..i], lessThanOrEq) + requires |a[lo..i]| == 0 || lessThanOrEq(a[lo..i][|a[lo..i]| - 1], a[i]) + requires Relations.TotalOrdering(lessThanOrEq) + ensures Relations.SortedBy(a[lo..i + 1], lessThanOrEq) + ensures lo < i ==> lessThanOrEq(a[i - 1], a[i]) + {} + + lemma {:isolate_assertions} TotalOrderingImpliesSortingIsUnique(s1: seq, s2: seq, lessThanOrEq: (T, T) -> bool) + requires Relations.TotalOrdering(lessThanOrEq) + requires Relations.SortedBy(s1, lessThanOrEq) && Relations.SortedBy(s2, lessThanOrEq) + requires multiset(s1) == multiset(s2) + ensures s1 == s2 + { + if |s1| == 0 { + } else { + assert s1[0] in s2 by { + assert s1[0] in multiset(s2); + } + + var i :| 0 <= i < |s2| && s2[i] == s1[0]; + assert multiset{s1[0]} == multiset{s2[i]}; + assert multiset{s1[0]} + multiset(s1[1..]) == multiset(s1) by { + assert s1 == [s1[0]] + s1[1..]; + } + assert multiset{s2[i]} + multiset(s2[0..i] + s2[i+1..]) == multiset(s2) by { + assert s2 == s2[0..i] + [s2[i]] + s2[i+1..]; + } + + assert Relations.SortedBy(s1[1..], lessThanOrEq); + assert Relations.SortedBy(s2[0..i] + s2[i+1..], lessThanOrEq) by { + if i == 0 || i == |s2| - 1 { + } else { + assert lessThanOrEq(s2[i-1], s2[i]); + assert lessThanOrEq(s2[i], s2[i+1]); + } + } + MultisetProperty(multiset{s1[0]}, multiset(s1[1..]), multiset(s2[0..i] + s2[i+1..])); + TotalOrderingImpliesSortingIsUnique(s1[1..], s2[0..i] + s2[i+1..], lessThanOrEq); + + if i == 0 { + } else { + assert s1 == [s2[i]] + s1[1..]; + assert lessThanOrEq(s2[0], s2[i]); + assert s2[0] in s1; + } + } + } + + lemma MultisetProperty(m: multiset, a: multiset, b: multiset) + requires m + a == m + b + ensures a == b + { + var a' := (m + a) - m; + var b' := (m + b) - m; + assert a == a' == b' == b; + } + + // These are bounded implementations of the above. 
+ // They do exactly the same thing, + // but they use `uint64`. + // This further speeds things up + // because math with bounded variables + // is significantly faster that math with big numbers. + + method {:isolate_assertions} BoundedMergeSortMethod( + left: array, + right: array, + lessThanOrEq: (T, T) -> bool, + lo: uint64, + hi: uint64, + where: PlaceResults + ) + returns (resultPlacement: ResultPlacement) + requires Relations.TotalOrdering(lessThanOrEq) + requires + && left.Length < UINT64_LIMIT + && right.Length < UINT64_LIMIT + requires lo < hi <= left.Length as uint64 + requires hi <= right.Length as uint64 && left != right + // reads left, right + modifies left, right + ensures !where.Either? ==> where == resultPlacement + + // We do not modify anything before lo + ensures left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) + // We do not modify anything above hi + ensures left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]) + + ensures multiset(left[lo..hi]) == multiset(old(left[lo..hi])) + ensures resultPlacement.Left? ==> Relations.SortedBy(left[lo..hi], lessThanOrEq) + ensures resultPlacement.Right? ==> Relations.SortedBy(right[lo..hi], lessThanOrEq) + ensures resultPlacement.Right? ==> multiset(right[lo..hi]) == multiset(old(left[lo..hi])) + + decreases hi - lo + { + if hi - lo == 1 { + if where == Right { + right[lo] := left[lo]; + return Right; + } else { + return Left; + } + } + + ghost var beforeWork := multiset(left[lo..hi]); + var mid := ((hi - lo)/2) + lo; + var placement? := BoundedMergeSortMethod(left, right, lessThanOrEq, lo, mid, Either); + assert left[mid..hi] == old(left[mid..hi]); + ghost var placement2? := BoundedMergeSortMethod(left, right, lessThanOrEq, mid, hi, placement?); + assert placement2? == placement?; + + ghost var preMergeResult := if placement?.Left? 
then left else right; + calc { + multiset(preMergeResult[lo..hi]); + == { assert preMergeResult[lo..hi] == preMergeResult[lo..mid] + preMergeResult[mid..hi]; } + multiset(preMergeResult[lo..mid] + preMergeResult[mid..hi]); + == + multiset(old(left[lo..mid]) + old(left[mid..hi])); + == { assert old(left[lo..hi]) == old(left[lo..mid]) + old(left[mid..hi]); } + beforeWork; + } + + ghost var mergedResult; + if placement?.Left? { + BoundedMergeIntoRight(left := left, right := right, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); + resultPlacement, mergedResult := Right, right[lo..hi]; + + assert left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]); + } else { + assert placement?.Right?; + BoundedMergeIntoRight(left := right, right := left, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); + resultPlacement, mergedResult := Left, left[lo..hi]; + + assert left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]); + } + + label BEFORE_RETURN: + assert left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]); + if resultPlacement.Left? && where == Right { + forall i | lo <= i < hi { + right[i] := left[i]; + } + + assert right[lo..hi] == mergedResult; + assert left[..] == old@BEFORE_RETURN(left[..]); + assert right[..lo] == old(right[..lo]); + + resultPlacement := Right; + } + if resultPlacement.Right? && where == Left { + forall i | lo <= i < hi { + left[i] := right[i]; + } + assert left[lo..hi] == mergedResult; + assert right[..] 
== old@BEFORE_RETURN(right[..]); + assert left[..lo] == old(left[..lo]); + + resultPlacement := Left; + } + } + + method {:isolate_assertions} BoundedMergeIntoRight( + nameonly left: array, + nameonly right: array, + nameonly lessThanOrEq: (T, T) -> bool, + nameonly lo: uint64, + nameonly mid: uint64, + nameonly hi: uint64 + ) + requires Relations.TotalOrdering(lessThanOrEq) + requires + && left.Length < UINT64_LIMIT + && right.Length < UINT64_LIMIT + requires lo <= mid <= hi <= left.Length as uint64 + requires hi <= right.Length as uint64 && left != right + // We store "left" in [lo..mid] + requires Relations.SortedBy(left[lo..mid], lessThanOrEq) + // We store "right" in [mid..hi] + requires Relations.SortedBy(left[mid..hi], lessThanOrEq) + // reads left, right + modifies right + // We do not modify anything before lo + ensures right[..lo] == old(right[..lo]) && left[..lo] == old(left[..lo]) + // We do not modify anything above hi + ensures right[hi..] == old(right[hi..]) && left[..lo] == old(left[..lo]) + ensures Relations.SortedBy(right[lo..hi], lessThanOrEq) + ensures multiset(right[lo..hi]) == multiset(old(left[lo..hi])) + ensures multiset(left[lo..hi]) == multiset(old(left[lo..hi])) + { + var leftPosition, rightPosition, iter := lo, mid, lo; + while iter < hi + modifies right + + invariant lo <= leftPosition <= mid <= rightPosition <= hi + invariant leftPosition as nat - lo as nat + rightPosition as nat - mid as nat == iter as nat - lo as nat + invariant right[..lo] == old(right[..lo]) + invariant right[hi..] 
== old(right[hi..]) + + invariant Relations.SortedBy(left[leftPosition..mid], lessThanOrEq) + invariant Relations.SortedBy(left[rightPosition..hi], lessThanOrEq) + invariant Below(right[lo..iter], left[leftPosition..mid], lessThanOrEq) + invariant Below(right[lo..iter], left[rightPosition..hi], lessThanOrEq) + invariant Relations.SortedBy(right[lo..iter], lessThanOrEq) + invariant multiset(right[lo..iter]) == multiset(left[lo..leftPosition]) + multiset(left[mid..rightPosition]) + { + label BEFORE_WORK: + if leftPosition == mid || (rightPosition < hi && lessThanOrEq(left[rightPosition], left[leftPosition])) { + right[iter] := left[rightPosition]; + + PushStillSortedBy(right, lo as nat, iter as nat, lessThanOrEq); + rightPosition, iter := rightPosition + 1, iter + 1; + + BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); + + assert 0 < |right[lo..iter]| && 0 < |left[rightPosition..hi]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[rightPosition..hi][0]); + BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); + } else { + right[iter] := left[leftPosition]; + + PushStillSortedBy(right, lo as nat, iter as nat, lessThanOrEq); + leftPosition, iter := leftPosition + 1, iter + 1; + + assert 0 < |right[lo..iter]| && 0 < |left[leftPosition..mid]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[leftPosition..mid][0]) by { + if 0 < |right[lo..iter]| && 0 < |left[leftPosition..mid]| { + assert lessThanOrEq(left[leftPosition-1], left[leftPosition]) by { + assert lo <= leftPosition-1 < leftPosition < mid; + assert Relations.SortedBy(left[lo..mid], lessThanOrEq); + } + } + } + BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); + BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); + } + } + assert multiset(right[lo..hi]) == multiset(old(left[lo..hi])) by { + assert leftPosition == mid && rightPosition == hi; + assert old(left[lo..hi]) == left[lo..hi] == 
left[lo..mid] + left[mid..hi]; + } + } +} diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy index f9c977cbf..ae86e268b 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy @@ -3,6 +3,7 @@ include "../Model/AwsCryptographyDbEncryptionSdkStructuredEncryptionTypes.dfy" include "Util.dfy" +include "OptimizedMergeSort.dfy" module SortCanon { export @@ -22,6 +23,7 @@ module SortCanon { import opened Relations import opened Seq.MergeSort import opened StructuredEncryptionUtil + import OptimizedMergeSort predicate method AuthBelow(x: CanonAuthItem, y: CanonAuthItem) { Below(x.key, y.key) @@ -295,7 +297,7 @@ module SortCanon { {} - function method AuthSort(x : CanonAuthList) : (result : CanonAuthList) + function AuthSort(x : CanonAuthList) : (result : CanonAuthList) requires CanonAuthListHasNoDuplicates(x) ensures multiset(x) == multiset(result) ensures SortedBy(result, AuthBelow) @@ -307,9 +309,14 @@ module SortCanon { CanonAuthListMultiNoDup(x, ret); assert CanonAuthListHasNoDuplicates(ret); ret + } by method { + AuthBelowIsTotal(); + result := OptimizedMergeSort.FastMergeSort(x, AuthBelow); + CanonAuthListMultiNoDup(x, result); + assert CanonAuthListHasNoDuplicates(result); } - function method CryptoSort(x : CanonCryptoList) : (result : CanonCryptoList) + function CryptoSort(x : CanonCryptoList) : (result : CanonCryptoList) requires CanonCryptoListHasNoDuplicates(x) ensures multiset(x) == multiset(result) ensures multiset(result) == multiset(x) @@ -322,6 +329,11 @@ module SortCanon { CanonCryptoListMultiNoDup(x, ret); assert CanonCryptoListHasNoDuplicates(ret); ret + } by method { + CryptoBelowIsTotal(); + result := OptimizedMergeSort.FastMergeSort(x, CryptoBelow); + CanonCryptoListMultiNoDup(x, result); + assert CanonCryptoListHasNoDuplicates(result); } lemma MultisetHasNoDuplicates(xs: 
CanonCryptoList) From a41a673416a91cf5c8b47395e24ca9e99f96d088 Mon Sep 17 00:00:00 2001 From: seebees Date: Fri, 1 Nov 2024 13:08:12 -0700 Subject: [PATCH 04/10] roll back --- .../src/DynamoDbEncryptionBranchKeyIdSupplier.dfy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DynamoDbEncryption/dafny/DynamoDbEncryption/src/DynamoDbEncryptionBranchKeyIdSupplier.dfy b/DynamoDbEncryption/dafny/DynamoDbEncryption/src/DynamoDbEncryptionBranchKeyIdSupplier.dfy index c79ca3294..4533b8def 100644 --- a/DynamoDbEncryption/dafny/DynamoDbEncryption/src/DynamoDbEncryptionBranchKeyIdSupplier.dfy +++ b/DynamoDbEncryption/dafny/DynamoDbEncryption/src/DynamoDbEncryptionBranchKeyIdSupplier.dfy @@ -70,7 +70,7 @@ module DynamoDbEncryptionBranchKeyIdSupplier { // We expect this interface to be implemented in the native language, // so any errors thrown by the native implementation will appear as Opaque errors if err.Opaque? then - MPL.Opaque(obj:=err.obj, alt_text:="") + MPL.Opaque(obj:=err.obj) else MPL.AwsCryptographicMaterialProvidersException(message:="Unexpected error while getting Branch Key ID.") } From 3aa38da1bd6683c86d4fbf54af31cb313fbdc912 Mon Sep 17 00:00:00 2001 From: seebees Date: Fri, 1 Nov 2024 20:53:15 -0700 Subject: [PATCH 05/10] do not use forall comprehension --- .../src/OptimizedMergeSort.dfy | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy index a9b3d12a5..1bde8c058 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy @@ -85,7 +85,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { // This lemma works around this // by proving that the outcomes are always deterministic and the same. 
// It does this by proving that given a total ordering, - // there is one and only one way to sort a given sequence. + // there is one and only one way to sort a given sequence. TotalOrderingImpliesSortingIsUnique(right[..], other, lessThanOrEq); } } @@ -412,21 +412,36 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { label BEFORE_RETURN: assert left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]); if resultPlacement.Left? && where == Right { - forall i | lo <= i < hi { + for i := lo to hi + modifies right + invariant left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) + invariant left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]) + invariant right[lo..i] == left[lo..i] + { right[i] := left[i]; } - assert right[lo..hi] == mergedResult; + assert right[lo..hi] == mergedResult by { + assert mergedResult == left[lo..hi]; + } assert left[..] == old@BEFORE_RETURN(left[..]); assert right[..lo] == old(right[..lo]); resultPlacement := Right; } if resultPlacement.Right? && where == Left { - forall i | lo <= i < hi { + for i := lo to hi + modifies left + invariant left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) + invariant left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]) + invariant left[lo..i] == right[lo..i] + { left[i] := right[i]; } - assert left[lo..hi] == mergedResult; + + assert left[lo..hi] == mergedResult by { + assert mergedResult == right[lo..hi]; + } assert right[..] 
== old@BEFORE_RETURN(right[..]); assert left[..lo] == old(left[..lo]); From 63789b68bbaea65b1aa437741585b2dfbe6295fb Mon Sep 17 00:00:00 2001 From: seebees Date: Mon, 4 Nov 2024 09:59:05 -0800 Subject: [PATCH 06/10] update optimize --- .../src/OptimizedMergeSort.dfy | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy index 1bde8c058..5c5152a79 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy @@ -167,7 +167,12 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { label BEFORE_RETURN: assert left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]); if resultPlacement.Left? && where == Right { - forall i | lo <= i < hi { + for i := lo to hi + modifies right + invariant left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) + invariant left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]) + invariant right[lo..i] == left[lo..i] + { right[i] := left[i]; } @@ -178,7 +183,12 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { resultPlacement := Right; } if resultPlacement.Right? && where == Left { - forall i | lo <= i < hi { + for i := lo to hi + modifies left + invariant left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) + invariant left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]) + invariant left[lo..i] == right[lo..i] + { left[i] := right[i]; } @@ -412,6 +422,14 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { label BEFORE_RETURN: assert left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]); if resultPlacement.Left? && where == Right { + // A forall comprehension might seem like a nice fit here, + // however this does not good for two reasons. 
+ // First, Dafny currently creates a range fur the full bounds of the bounded number + // see: https://github.com/dafny-lang/dafny/issues/5897 + // Second this would create two loops. + // First loop would create the `lo to hi` range of numbers. + // The second loop would then loop over these elements. + // A single loop with for i := lo to hi modifies right invariant left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) From 005f287bf947e37f1dbdbc4e3bca1b7c3f15b4cf Mon Sep 17 00:00:00 2001 From: Andy Jewell Date: Mon, 10 Feb 2025 11:11:14 -0500 Subject: [PATCH 07/10] m --- .../dafny/StructuredEncryption/src/OptimizedMergeSort.dfy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy index 5c5152a79..7b3ba1b36 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy @@ -429,7 +429,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { // Second this would create two loops. // First loop would create the `lo to hi` range of numbers. // The second loop would then loop over these elements. 
- // A single loop with + // A single loop with for i := lo to hi modifies right invariant left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) From bfa2d684c515195efc9d747f59c899982df6b154 Mon Sep 17 00:00:00 2001 From: seebees Date: Tue, 3 Dec 2024 13:35:17 -0800 Subject: [PATCH 08/10] Update some of the proof --- .../src/OptimizedMergeSort.dfy | 44 ++++++++++++++----- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy index 7b3ba1b36..e071bfa9e 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy @@ -105,7 +105,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { requires Relations.TotalOrdering(lessThanOrEq) requires lo < hi <= left.Length requires hi <= right.Length && left != right - // reads left, right + reads left, right modifies left, right ensures !where.Either? 
==> where == resultPlacement @@ -217,6 +217,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { requires Relations.SortedBy(left[mid..hi], lessThanOrEq) // reads left, right modifies right + reads left, right // We do not modify anything before lo ensures right[..lo] == old(right[..lo]) && left[..lo] == old(left[..lo]) // We do not modify anything above hi @@ -239,8 +240,10 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { invariant Below(right[lo..iter], left[leftPosition..mid], lessThanOrEq) invariant Below(right[lo..iter], left[rightPosition..hi], lessThanOrEq) invariant Relations.SortedBy(right[lo..iter], lessThanOrEq) - invariant multiset(right[lo..iter]) == multiset(left[lo..leftPosition]) + multiset(left[mid..rightPosition]) + invariant multiset(right[lo..iter]) == multiset(left[lo..leftPosition] + left[mid..rightPosition]) { + + ghost var oldRightPosition, oldIter, oldLeftPosition := rightPosition, iter, leftPosition; if leftPosition == mid || (rightPosition < hi && lessThanOrEq(left[rightPosition], left[leftPosition])) { right[iter] := left[rightPosition]; @@ -255,7 +258,26 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { PushStillSortedBy(right, lo, iter, lessThanOrEq); leftPosition, iter := leftPosition + 1, iter + 1; + assert 0 < |right[lo..iter]| && 0 < |left[leftPosition..mid]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[leftPosition..mid][0]) by { + if 0 == |right[lo..iter]| || 0 == |left[leftPosition..mid]| { + } else { + assert rightPosition == oldRightPosition; + assert oldLeftPosition < mid; + // This is true, but uncommenting it causes the proof to fail + // leaving it here it make what is going on a little more clear + // assert right[lo..iter][|right[lo..iter]| - 1] == right[oldIter]; + assert left[leftPosition..mid][0] == left[leftPosition]; + } + } BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); + + assert 0 < |right[lo..iter]| && 0 < |left[rightPosition..hi]| 
==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[rightPosition..hi][0]) by { + if 0 == |right[lo..iter]| || 0 == |left[rightPosition..hi]| { + } else { + assert right[lo..iter][|right[lo..iter]| - 1] == right[iter - 1]; + assert left[rightPosition..hi][0] == left[rightPosition]; + } + } BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); } } @@ -362,7 +384,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { && right.Length < UINT64_LIMIT requires lo < hi <= left.Length as uint64 requires hi <= right.Length as uint64 && left != right - // reads left, right + reads left, right modifies left, right ensures !where.Either? ==> where == resultPlacement @@ -485,7 +507,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { requires Relations.SortedBy(left[lo..mid], lessThanOrEq) // We store "right" in [mid..hi] requires Relations.SortedBy(left[mid..hi], lessThanOrEq) - // reads left, right + reads left, right modifies right // We do not modify anything before lo ensures right[..lo] == old(right[..lo]) && left[..lo] == old(left[..lo]) @@ -509,9 +531,10 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { invariant Below(right[lo..iter], left[leftPosition..mid], lessThanOrEq) invariant Below(right[lo..iter], left[rightPosition..hi], lessThanOrEq) invariant Relations.SortedBy(right[lo..iter], lessThanOrEq) - invariant multiset(right[lo..iter]) == multiset(left[lo..leftPosition]) + multiset(left[mid..rightPosition]) + invariant multiset(right[lo..iter]) == multiset(left[lo..leftPosition] + left[mid..rightPosition]) { - label BEFORE_WORK: + + ghost var oldRightPosition, oldIter, oldLeftPosition := rightPosition, iter, leftPosition; if leftPosition == mid || (rightPosition < hi && lessThanOrEq(left[rightPosition], left[leftPosition])) { right[iter] := left[rightPosition]; @@ -529,11 +552,10 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { leftPosition, iter := leftPosition + 1, iter + 1; assert 0 < 
|right[lo..iter]| && 0 < |left[leftPosition..mid]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[leftPosition..mid][0]) by { - if 0 < |right[lo..iter]| && 0 < |left[leftPosition..mid]| { - assert lessThanOrEq(left[leftPosition-1], left[leftPosition]) by { - assert lo <= leftPosition-1 < leftPosition < mid; - assert Relations.SortedBy(left[lo..mid], lessThanOrEq); - } + if 0 == |right[lo..iter]| || 0 == |left[leftPosition..mid]| { + } else { + assert right[lo..iter][|right[lo..iter]| - 1] == right[iter - 1]; + assert left[leftPosition..mid][0] == left[leftPosition]; } } BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); From 40153a582b3cd900844ef28e87519cd556a3d7c4 Mon Sep 17 00:00:00 2001 From: seebees Date: Wed, 12 Feb 2025 15:15:27 -0800 Subject: [PATCH 09/10] Updates to keep inline with Dafny standard library. --- .../src/OptimizedMergeSort.dfy | 320 +++++++++++++----- .../StructuredEncryption/src/SortCanon.dfy | 4 +- 2 files changed, 241 insertions(+), 83 deletions(-) diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy index e071bfa9e..f85b5cb82 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy @@ -4,11 +4,16 @@ include "../Model/AwsCryptographyDbEncryptionSdkStructuredEncryptionTypes.dfy" module {:options "-functionSyntax:4"} OptimizedMergeSort { - import Seq.MergeSort - import Relations = MergeSort.Relations - import opened StandardLibrary.UInt - // The Seq.MergeSort function implemented as implemented + import Relations + import BoundedInts + import InternalModule = Seq.MergeSort + + predicate HasUint64Len(s: seq) { + |s| < BoundedInts.TWO_TO_THE_64 + } + + // The MergeSortBy function implemented as implemented // does not compile to an optimal implementation // in any of the Dafny target languages. 
// This implementation aims to be significantly more optimal. @@ -16,15 +21,80 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { // It does this by making 2 arrays of the original sequence // and then using these 2 as left and right alternatively. // This can be audited by verifying - // that the arrays are only sliced into a seq in `FastMergeSort`. + // that the arrays are only sliced into a seq in `MergeSortNat`. // All other slicing is done in ghost state. // Second, is has a bounded number implementation // that avoids using `nat`. - function {:isolate_assertions} FastMergeSort(s: seq, lessThanOrEq: (T, T) -> bool): (result :seq) + function {:isolate_assertions} MergeSort(s: seq, lessThanOrEq: (T, T) -> bool): (result :seq) + requires Relations.TotalOrdering(lessThanOrEq) + requires HasUint64Len(s) + { + InternalModule.MergeSortBy(s, lessThanOrEq) + } + by method { + if |s| <= 1 { + return s; + } else { + + // The slices x[1..] and y[1..] are un-optimized operations in Dafny. + // This means that their usage will result in a lot of data copying. + // Additionally, it is very likely that the size of these sequences + // will be less than uint64. + // So writing an optimized version that only works on bounded types + // should further optimize this hot code. + + var left := new T[|s|](i requires 0 <= i < |s| => s[i]); + var right := new T[|s|](i requires 0 <= i < |s| => s[i]); + var lo, hi := 0, right.Length; + + label BEFORE_WORK: + + var boundedLo: BoundedInts.uint64, boundedHi: BoundedInts.uint64 := 0, right.Length as BoundedInts.uint64; + ghost var _ := MergeSortMethod(left, right, lessThanOrEq, boundedLo, boundedHi, Right); + + result := right[..]; + + ghost var other := InternalModule.MergeSortBy(s, lessThanOrEq); + + assert Relations.SortedBy(right[..], lessThanOrEq) by { + assert right[..] == right[lo..hi]; + } + assert multiset(right[..]) == multiset(other) by { + calc { + multiset(right[..]); + == {assert right[..]
== right[lo..hi];} + multiset(right[lo..hi]); + == + multiset(old@BEFORE_WORK(left[lo..hi])); + == {assert old@BEFORE_WORK(left[lo..hi]) == s;} + multiset(s); + == + multiset(other); + } + } + + // Implementing a by method can be complicated. + // Because methods can have non-determinism, + // where functions can not. + // This means that Dafny normally wants to know + // that the method and function maintain equality at every step. + // But this is hard for this kind of optimized sorting. + // Because what is the functional state at every step + // and how does it correspond the state of 2 optimized arrays? + // This lemma works around this + // by proving that the outcomes are always deterministic and the same. + // It does this by proving that given a total ordering, + // there is one and only one way to sort a given sequence. + TotalOrderingImpliesSortingIsUnique(right[..], other, lessThanOrEq); + } + } + + // This is included as sugar in case you don't want to ensure your seq HasUint64Len. + function {:isolate_assertions} MergeSortNat(s: seq, lessThanOrEq: (T, T) -> bool): (result :seq) requires Relations.TotalOrdering(lessThanOrEq) { - MergeSort.MergeSortBy(s, lessThanOrEq) + InternalModule.MergeSortBy(s, lessThanOrEq) } by method { if |s| <= 1 { @@ -45,17 +115,22 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { label BEFORE_WORK: if HasUint64Len(s) { - var boundedLo: uint64, boundedHi: uint64 := 0, right.Length as uint64; - ghost var _ := BoundedMergeSortMethod(left, right, lessThanOrEq, boundedLo, boundedHi, Right); + var boundedLo: BoundedInts.uint64, boundedHi: BoundedInts.uint64 := 0, right.Length as BoundedInts.uint64; + ghost var _ := MergeSortMethod(left, right, lessThanOrEq, boundedLo, boundedHi, Right); result := right[..]; } else { - ghost var _ := MergeSortMethod(left, right, lessThanOrEq, lo, hi, Right); + // Fallback to `nat` or BigInt. 
+ // This is a little silly, but this ensures + // that the behavior for very large seq will be the same. + // Though it is likely if any such seq existed in the real world, + // the performance improvement here would still not be enough to complete the sort... + ghost var _ := NatMergeSortMethod(left, right, lessThanOrEq, lo, hi, Right); result := right[..]; } - ghost var other := MergeSort.MergeSortBy(s, lessThanOrEq); + ghost var other := InternalModule.MergeSortBy(s, lessThanOrEq); assert Relations.SortedBy(right[..], lessThanOrEq) by { assert right[..] == right[lo..hi]; @@ -93,18 +168,26 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { datatype PlaceResults = Left | Right | Either type ResultPlacement = r: PlaceResults | !r.Either? witness * + // These are bounded implementations of merge sort. + // This further speeds things up + // because math with bounded variables + // is significantly faster than math with big numbers. + method {:isolate_assertions} MergeSortMethod( left: array, right: array, lessThanOrEq: (T, T) -> bool, - lo: nat, - hi: nat, + lo: BoundedInts.uint64, + hi: BoundedInts.uint64, where: PlaceResults ) returns (resultPlacement: ResultPlacement) requires Relations.TotalOrdering(lessThanOrEq) - requires lo < hi <= left.Length - requires hi <= right.Length && left != right + requires left.Length < BoundedInts.TWO_TO_THE_64 + requires right.Length < BoundedInts.TWO_TO_THE_64 + requires lo < hi <= left.Length as BoundedInts.uint64 + requires hi <= right.Length as BoundedInts.uint64 + requires left != right reads left, right modifies left, right ensures !where.Either? ==> where == resultPlacement @@ -131,7 +214,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { } ghost var beforeWork := multiset(left[lo..hi]); - var mid := (lo + hi) / 2; + var mid := ((hi - lo)/2) + lo; var placement? := MergeSortMethod(left, right, lessThanOrEq, lo, mid, Either); assert left[mid..hi] == old(left[mid..hi]); ghost var placement2?
:= MergeSortMethod(left, right, lessThanOrEq, mid, hi, placement?); @@ -140,33 +223,47 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { ghost var preMergeResult := if placement?.Left? then left else right; calc { multiset(preMergeResult[lo..hi]); - == { assert preMergeResult[lo..hi] == preMergeResult[lo..mid] + preMergeResult[mid..hi]; } + == { LemmaSplitAt(preMergeResult[..], lo as nat, mid as nat, hi as nat); } multiset(preMergeResult[lo..mid] + preMergeResult[mid..hi]); == multiset(old(left[lo..mid]) + old(left[mid..hi])); - == { assert old(left[lo..hi]) == old(left[lo..mid]) + old(left[mid..hi]); } + == { LemmaSplitAt(old(left[..]), lo as nat, mid as nat, hi as nat); } beforeWork; } ghost var mergedResult; if placement?.Left? { MergeIntoRight(left := left, right := right, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); - resultPlacement := Right; + resultPlacement, mergedResult := Right, right[lo..hi]; - mergedResult := right[lo..hi]; - assert left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]); + assert left[hi..] == old(left[hi..]); + assert right[hi..] == old(right[hi..]); + assert left[..lo] == old(left[..lo]); + assert right[..lo] == old(right[..lo]); } else { assert placement?.Right?; MergeIntoRight(left := right, right := left, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); - resultPlacement := Left; + resultPlacement, mergedResult := Left, left[lo..hi]; - mergedResult := left[lo..hi]; - assert left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]); + assert left[hi..] == old(left[hi..]); + assert right[hi..] == old(right[hi..]); + assert left[..lo] == old(left[..lo]); + assert right[..lo] == old(right[..lo]); } label BEFORE_RETURN: - assert left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]); + assert left[hi..] == old(left[hi..]); + assert right[hi..] 
== old(right[hi..]); + assert left[..lo] == old(left[..lo]); + assert right[..lo] == old(right[..lo]); if resultPlacement.Left? && where == Right { + // A forall comprehension might seem like a nice fit here, + // however this is not good for two reasons. + // First, Dafny currently creates a range for the full bounds of the bounded number + // see: https://github.com/dafny-lang/dafny/issues/5897 + // Second, this would create two loops. + // The first loop would create the `lo to hi` range of numbers. + // The second loop would then loop over these elements. for i := lo to hi modifies right invariant left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]) @@ -174,9 +271,12 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { invariant right[lo..i] == left[lo..i] { right[i] := left[i]; + assert right[lo..i] == left[lo..i]; } - assert right[lo..hi] == mergedResult; + assert right[lo..hi] == mergedResult by { + assert mergedResult == left[lo..hi]; + } assert left[..] == old@BEFORE_RETURN(left[..]); assert right[..lo] == old(right[..lo]); @@ -190,9 +290,12 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { invariant left[lo..i] == right[lo..i] { left[i] := right[i]; + assert right[lo..i] == left[lo..i]; } - assert left[lo..hi] == mergedResult; + assert left[lo..hi] == mergedResult by { + assert mergedResult == right[lo..hi]; + } assert right[..]
== old@BEFORE_RETURN(right[..]); assert left[..lo] == old(left[..lo]); @@ -204,20 +307,22 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { nameonly left: array, nameonly right: array, nameonly lessThanOrEq: (T, T) -> bool, - nameonly lo: nat, - nameonly mid: nat, - nameonly hi: nat + nameonly lo: BoundedInts.uint64, + nameonly mid: BoundedInts.uint64, + nameonly hi: BoundedInts.uint64 ) requires Relations.TotalOrdering(lessThanOrEq) - requires lo <= mid <= hi <= left.Length - requires hi <= right.Length && left != right + requires + && left.Length < BoundedInts.TWO_TO_THE_64 + && right.Length < BoundedInts.TWO_TO_THE_64 + requires lo <= mid <= hi <= left.Length as BoundedInts.uint64 + requires hi <= right.Length as BoundedInts.uint64 && left != right // We store "left" in [lo..mid] requires Relations.SortedBy(left[lo..mid], lessThanOrEq) // We store "right" in [mid..hi] requires Relations.SortedBy(left[mid..hi], lessThanOrEq) - // reads left, right - modifies right reads left, right + modifies right // We do not modify anything before lo ensures right[..lo] == old(right[..lo]) && left[..lo] == old(left[..lo]) // We do not modify anything above hi @@ -231,7 +336,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { modifies right invariant lo <= leftPosition <= mid <= rightPosition <= hi - invariant leftPosition - lo + rightPosition - mid == iter - lo + invariant leftPosition as nat - lo as nat + rightPosition as nat - mid as nat == iter as nat - lo as nat invariant right[..lo] == old(right[..lo]) invariant right[hi..] 
== old(right[hi..]) @@ -247,30 +352,37 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { if leftPosition == mid || (rightPosition < hi && lessThanOrEq(left[rightPosition], left[leftPosition])) { right[iter] := left[rightPosition]; - PushStillSortedBy(right, lo, iter, lessThanOrEq); + PushStillSortedBy(right, lo as nat, iter as nat, lessThanOrEq); rightPosition, iter := rightPosition + 1, iter + 1; BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); + + assert 0 < |right[lo..iter]| && 0 < |left[rightPosition..hi]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[rightPosition..hi][0]) by { + if 0 == |right[lo..iter]| || 0 == |left[rightPosition..hi]| { + } else { + assert Relations.SortedBy(left[oldRightPosition..hi], lessThanOrEq); + assert lessThanOrEq(left[oldRightPosition..hi][0], left[oldRightPosition..hi][1]); + } + } BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); + + assert multiset(right[lo..iter]) == multiset(left[lo..leftPosition] + left[mid..rightPosition]) by { + // Dafny just wants to be reminded + } } else { right[iter] := left[leftPosition]; - PushStillSortedBy(right, lo, iter, lessThanOrEq); + PushStillSortedBy(right, lo as nat, iter as nat, lessThanOrEq); leftPosition, iter := leftPosition + 1, iter + 1; assert 0 < |right[lo..iter]| && 0 < |left[leftPosition..mid]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[leftPosition..mid][0]) by { if 0 == |right[lo..iter]| || 0 == |left[leftPosition..mid]| { } else { - assert rightPosition == oldRightPosition; - assert oldLeftPosition < mid; - // This is true, but uncommenting it causes the proof to fail - // leaving it here it make what is going on a little more clear - // assert right[lo..iter][|right[lo..iter]| - 1] == right[oldIter]; - assert left[leftPosition..mid][0] == left[leftPosition]; + assert Relations.SortedBy(left[oldLeftPosition..mid], lessThanOrEq); + assert lessThanOrEq(left[oldLeftPosition..mid][0], 
left[oldLeftPosition..mid][1]); } } BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); - assert 0 < |right[lo..iter]| && 0 < |left[rightPosition..hi]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[rightPosition..hi][0]) by { if 0 == |right[lo..iter]| || 0 == |left[rightPosition..hi]| { } else { @@ -279,10 +391,15 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { } } BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); + + assert multiset(right[lo..iter]) == multiset(left[lo..leftPosition] + left[mid..rightPosition]) by { + // Dafny just wants to be reminded + } } } assert multiset(right[lo..hi]) == multiset(old(left[lo..hi])) by { assert leftPosition == mid && rightPosition == hi; + LemmaSplitAt(left[..], lo as nat, mid as nat, hi as nat); assert old(left[lo..hi]) == left[lo..hi] == left[lo..mid] + left[mid..hi]; } } @@ -362,28 +479,37 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { assert a == a' == b' == b; } - // These are bounded implementations of the above. - // They do exactly the same thing, - // but they use `uint64`. - // This further speeds things up - // because math with bounded variables - // is significantly faster that math with big numbers. + lemma LemmaNewFirstElementStillSortedBy(x: T, s: seq, lessThan: (T, T) -> bool) + requires Relations.SortedBy(s, lessThan) + requires |s| == 0 || lessThan(x, s[0]) + requires Relations.TotalOrdering(lessThan) + ensures Relations.SortedBy([x] + s, lessThan) + {} - method {:isolate_assertions} BoundedMergeSortMethod( + lemma LemmaSplitAt(s: seq, lo: nat, split: nat, hi: nat) + requires 0 <= lo + requires lo <= split + requires split <= hi + requires hi <= |s| + ensures s[lo..hi] == s[lo..split] + s[split..hi] + {} + + // This is the nat version of merge sort. + // This is an exact copy of the bounded integer implementation above + // but with `nat` instead of BoundedInts.uint64. 
+ + method {:isolate_assertions} NatMergeSortMethod( left: array, right: array, lessThanOrEq: (T, T) -> bool, - lo: uint64, - hi: uint64, + lo: nat, + hi: nat, where: PlaceResults ) returns (resultPlacement: ResultPlacement) requires Relations.TotalOrdering(lessThanOrEq) - requires - && left.Length < UINT64_LIMIT - && right.Length < UINT64_LIMIT - requires lo < hi <= left.Length as uint64 - requires hi <= right.Length as uint64 && left != right + requires lo < hi <= left.Length + requires hi <= right.Length && left != right reads left, right modifies left, right ensures !where.Either? ==> where == resultPlacement @@ -411,42 +537,51 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { ghost var beforeWork := multiset(left[lo..hi]); var mid := ((hi - lo)/2) + lo; - var placement? := BoundedMergeSortMethod(left, right, lessThanOrEq, lo, mid, Either); + var placement? := NatMergeSortMethod(left, right, lessThanOrEq, lo, mid, Either); assert left[mid..hi] == old(left[mid..hi]); - ghost var placement2? := BoundedMergeSortMethod(left, right, lessThanOrEq, mid, hi, placement?); + ghost var placement2? := NatMergeSortMethod(left, right, lessThanOrEq, mid, hi, placement?); assert placement2? == placement?; ghost var preMergeResult := if placement?.Left? then left else right; calc { multiset(preMergeResult[lo..hi]); - == { assert preMergeResult[lo..hi] == preMergeResult[lo..mid] + preMergeResult[mid..hi]; } + == { LemmaSplitAt(preMergeResult[..], lo as nat, mid as nat, hi as nat); } multiset(preMergeResult[lo..mid] + preMergeResult[mid..hi]); == multiset(old(left[lo..mid]) + old(left[mid..hi])); - == { assert old(left[lo..hi]) == old(left[lo..mid]) + old(left[mid..hi]); } + == { LemmaSplitAt(old(left[..]), lo as nat, mid as nat, hi as nat); } beforeWork; } ghost var mergedResult; if placement?.Left? 
{ - BoundedMergeIntoRight(left := left, right := right, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); + NatMergeIntoRight(left := left, right := right, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); resultPlacement, mergedResult := Right, right[lo..hi]; - assert left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]); + assert left[hi..] == old(left[hi..]); + assert right[hi..] == old(right[hi..]); + assert left[..lo] == old(left[..lo]); + assert right[..lo] == old(right[..lo]); } else { assert placement?.Right?; - BoundedMergeIntoRight(left := right, right := left, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); + NatMergeIntoRight(left := right, right := left, lessThanOrEq := lessThanOrEq, lo := lo, mid := mid, hi := hi); resultPlacement, mergedResult := Left, left[lo..hi]; - assert left[hi..] == old(left[hi..]) && right[hi..] == old(right[hi..]); + assert left[hi..] == old(left[hi..]); + assert right[hi..] == old(right[hi..]); + assert left[..lo] == old(left[..lo]); + assert right[..lo] == old(right[..lo]); } label BEFORE_RETURN: - assert left[..lo] == old(left[..lo]) && right[..lo] == old(right[..lo]); + assert left[hi..] == old(left[hi..]); + assert right[hi..] == old(right[hi..]); + assert left[..lo] == old(left[..lo]); + assert right[..lo] == old(right[..lo]); if resultPlacement.Left? && where == Right { // A forall comprehension might seem like a nice fit here, // however this does not good for two reasons. - // First, Dafny currently creates a range fur the full bounds of the bounded number + // First, Dafny currently creates a range for the full bounds of the bounded number // see: https://github.com/dafny-lang/dafny/issues/5897 // Second this would create two loops. // First loop would create the `lo to hi` range of numbers. 
@@ -459,6 +594,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { invariant right[lo..i] == left[lo..i] { right[i] := left[i]; + assert right[lo..i] == left[lo..i]; } assert right[lo..hi] == mergedResult by { @@ -477,6 +613,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { invariant left[lo..i] == right[lo..i] { left[i] := right[i]; + assert right[lo..i] == left[lo..i]; } assert left[lo..hi] == mergedResult by { @@ -489,26 +626,24 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { } } - method {:isolate_assertions} BoundedMergeIntoRight( + method {:isolate_assertions} NatMergeIntoRight( nameonly left: array, nameonly right: array, nameonly lessThanOrEq: (T, T) -> bool, - nameonly lo: uint64, - nameonly mid: uint64, - nameonly hi: uint64 + nameonly lo: nat, + nameonly mid: nat, + nameonly hi: nat ) requires Relations.TotalOrdering(lessThanOrEq) - requires - && left.Length < UINT64_LIMIT - && right.Length < UINT64_LIMIT - requires lo <= mid <= hi <= left.Length as uint64 - requires hi <= right.Length as uint64 && left != right + requires lo <= mid <= hi <= left.Length + requires hi <= right.Length && left != right // We store "left" in [lo..mid] requires Relations.SortedBy(left[lo..mid], lessThanOrEq) // We store "right" in [mid..hi] requires Relations.SortedBy(left[mid..hi], lessThanOrEq) - reads left, right + // reads left, right modifies right + reads left, right // We do not modify anything before lo ensures right[..lo] == old(right[..lo]) && left[..lo] == old(left[..lo]) // We do not modify anything above hi @@ -543,8 +678,18 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); - assert 0 < |right[lo..iter]| && 0 < |left[rightPosition..hi]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[rightPosition..hi][0]); + assert 0 < |right[lo..iter]| && 0 < |left[rightPosition..hi]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 
1], left[rightPosition..hi][0]) by { + if 0 == |right[lo..iter]| || 0 == |left[rightPosition..hi]| { + } else { + assert Relations.SortedBy(left[oldRightPosition..hi], lessThanOrEq); + assert lessThanOrEq(left[oldRightPosition..hi][0], left[oldRightPosition..hi][1]); + } + } BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); + + assert multiset(right[lo..iter]) == multiset(left[lo..leftPosition] + left[mid..rightPosition]) by { + // Dafny just wants to be reminded + } } else { right[iter] := left[leftPosition]; @@ -554,17 +699,30 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { assert 0 < |right[lo..iter]| && 0 < |left[leftPosition..mid]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[leftPosition..mid][0]) by { if 0 == |right[lo..iter]| || 0 == |left[leftPosition..mid]| { } else { - assert right[lo..iter][|right[lo..iter]| - 1] == right[iter - 1]; - assert left[leftPosition..mid][0] == left[leftPosition]; + assert Relations.SortedBy(left[oldLeftPosition..mid], lessThanOrEq); + assert lessThanOrEq(left[oldLeftPosition..mid][0], left[oldLeftPosition..mid][1]); } } BelowIsTransitive(right[lo..iter], left[leftPosition..mid], lessThanOrEq); + assert 0 < |right[lo..iter]| && 0 < |left[rightPosition..hi]| ==> lessThanOrEq(right[lo..iter][|right[lo..iter]| - 1], left[rightPosition..hi][0]) by { + if 0 == |right[lo..iter]| || 0 == |left[rightPosition..hi]| { + } else { + assert right[lo..iter][|right[lo..iter]| - 1] == right[iter - 1]; + assert left[rightPosition..hi][0] == left[rightPosition]; + } + } BelowIsTransitive(right[lo..iter], left[rightPosition..hi], lessThanOrEq); + + assert multiset(right[lo..iter]) == multiset(left[lo..leftPosition] + left[mid..rightPosition]) by { + // Dafny just wants to be reminded + } } } assert multiset(right[lo..hi]) == multiset(old(left[lo..hi])) by { assert leftPosition == mid && rightPosition == hi; + LemmaSplitAt(left[..], lo as nat, mid as nat, hi as nat); assert old(left[lo..hi]) 
== left[lo..hi] == left[lo..mid] + left[mid..hi]; } } + } diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy index ae86e268b..f4a5fbde2 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/SortCanon.dfy @@ -311,7 +311,7 @@ module SortCanon { ret } by method { AuthBelowIsTotal(); - result := OptimizedMergeSort.FastMergeSort(x, AuthBelow); + result := OptimizedMergeSort.MergeSortNat(x, AuthBelow); CanonAuthListMultiNoDup(x, result); assert CanonAuthListHasNoDuplicates(result); } @@ -331,7 +331,7 @@ module SortCanon { ret } by method { CryptoBelowIsTotal(); - result := OptimizedMergeSort.FastMergeSort(x, CryptoBelow); + result := OptimizedMergeSort.MergeSortNat(x, CryptoBelow); CanonCryptoListMultiNoDup(x, result); assert CanonCryptoListHasNoDuplicates(result); } From 286e7c7a46bf6f08013eab6655ecb43bb2535341 Mon Sep 17 00:00:00 2001 From: seebees Date: Wed, 12 Feb 2025 15:30:43 -0800 Subject: [PATCH 10/10] reads on methods not enabled --- .../StructuredEncryption/src/OptimizedMergeSort.dfy | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy index f85b5cb82..2ed337996 100644 --- a/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy +++ b/DynamoDbEncryption/dafny/StructuredEncryption/src/OptimizedMergeSort.dfy @@ -188,7 +188,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { requires lo < hi <= left.Length as BoundedInts.uint64 requires hi <= right.Length as BoundedInts.uint64 requires left != right - reads left, right + // reads left, right modifies left, right ensures !where.Either? 
==> where == resultPlacement @@ -321,7 +321,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { requires Relations.SortedBy(left[lo..mid], lessThanOrEq) // We store "right" in [mid..hi] requires Relations.SortedBy(left[mid..hi], lessThanOrEq) - reads left, right + // reads left, right modifies right // We do not modify anything before lo ensures right[..lo] == old(right[..lo]) && left[..lo] == old(left[..lo]) @@ -510,7 +510,7 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { requires Relations.TotalOrdering(lessThanOrEq) requires lo < hi <= left.Length requires hi <= right.Length && left != right - reads left, right + // reads left, right modifies left, right ensures !where.Either? ==> where == resultPlacement @@ -641,9 +641,8 @@ module {:options "-functionSyntax:4"} OptimizedMergeSort { requires Relations.SortedBy(left[lo..mid], lessThanOrEq) // We store "right" in [mid..hi] requires Relations.SortedBy(left[mid..hi], lessThanOrEq) - // reads left, right modifies right - reads left, right + // reads left, right // We do not modify anything before lo ensures right[..lo] == old(right[..lo]) && left[..lo] == old(left[..lo]) // We do not modify anything above hi