Skip to content

Commit

Permalink
Added a TestVerifier and used to detect and fix bugs in witness rec…
Browse files Browse the repository at this point in the history
…onstruction (#778)

* Added a TestVerifier and used to detect and fix bugs in witness reconstruction

* Update plan files

* Sort after extracting from HashMap to get deterministic order

* Only run TestVerifier tests on Linux

* Fixed bug in CI
  • Loading branch information
MatthewDaggitt authored Feb 9, 2024
1 parent 4565d0b commit 8fcf08c
Show file tree
Hide file tree
Showing 83 changed files with 359,252 additions and 72,016 deletions.
26 changes: 23 additions & 3 deletions .github/workflows/build-vehicle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
version: "3.10.2.1"
project-file: "cabal.project.ghc-9.4.8"
extra-args: ""
golden-test-args: ""
include:
# Build with GHC 9.8.1
- os:
Expand All @@ -37,6 +38,8 @@ jobs:
version: "3.10.2.1"
project-file: "cabal.project.ghc-9.8.1"
extra-args: ""
# The TestVerifier uses bash so only test on a single Linux version.
golden-test-args: "--test-option='--allowlist-externals=TestVerifier'"
# Build with GHC 9.6.4
- os:
name: Linux
Expand All @@ -48,6 +51,7 @@ jobs:
version: "3.10.2.1"
project-file: "cabal.project.ghc-9.6.4"
extra-args: ""
golden-test-args: ""
# Build with GHC 9.2.8
- os:
name: Linux
Expand All @@ -59,6 +63,7 @@ jobs:
version: "3.8.1.0"
project-file: "cabal.project.ghc-9.2.8"
extra-args: ""
golden-test-args: ""
# Build with GHC 9.0.2
- os:
name: Linux
Expand All @@ -70,6 +75,7 @@ jobs:
version: "3.8.1.0"
project-file: "cabal.project.ghc-9.0.2"
extra-args: ""
golden-test-args: ""
# Build with GHC 8.10.7
- os:
name: Linux
Expand All @@ -81,6 +87,7 @@ jobs:
version: "3.8.1.0"
project-file: "cabal.project.ghc-8.10.7"
extra-args: ""
golden-test-args: ""
# Build with -fnothunks:
- os:
name: Linux
Expand All @@ -92,6 +99,7 @@ jobs:
version: "3.10.2.1"
project-file: "cabal.project.nothunks.ghc-9.4.8"
extra-args: "-fnothunks"
golden-test-args: ""
# 20-12-2022:
# This test is disabled because -fghc-debug requires two threads, which triggers #342.
# Build with -fghc-debug:
Expand Down Expand Up @@ -134,22 +142,34 @@ jobs:
cabal-project-file: ${{ matrix.haskell.cabal.project-file }}
cabal-project-freeze-file: ${{ matrix.haskell.cabal.project-file }}.freeze

- name: Test Vehicle
- name: Run Vehicle unit tests
run: |
cabal test \
vehicle:test:unit-tests \
vehicle:test:golden-tests \
--test-show-details=always \
--test-option=--color=always \
--test-option=--num-threads=1 \
--project-file=${{ matrix.haskell.cabal.project-file }} \
${{ matrix.haskell.cabal.extra-args }}
shell: sh

- name: Run Vehicle golden tests
run: |
cabal test \
vehicle:test:golden-tests \
--test-show-details=always \
--test-option=--color=always \
--test-option=--num-threads=1 \
${{ matrix.haskell.cabal.golden-test-args }} \
--project-file=${{ matrix.haskell.cabal.project-file }} \
${{ matrix.haskell.cabal.extra-args }} \
shell: sh

- name: Build Vehicle
run: |
mkdir -p bin
cabal install vehicle:exe:vehicle \
cabal install \
vehicle:exe:vehicle \
--overwrite-policy=always \
--install-method=copy \
--installdir=bin \
Expand Down
2 changes: 2 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* Added new command-line option `--verifier-args` to `verify` mode that allows extra
arguments to be passed directly to the verifier.

* Fixed bug when reconstructing witnesses using Fourier-Motzkin elimination.

## Version 0.11.1

* Fixed bug properties involving the comparison of abstract `Index` values would throw
Expand Down
26 changes: 21 additions & 5 deletions vehicle/src/Vehicle/Backend/Queries/PostProcessing.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import Control.Monad (foldM, forM, unless, when)
import Control.Monad.Reader (MonadReader (..))
import Control.Monad.State (get)
import Data.Either (partitionEithers)
import Data.HashMap.Strict qualified as HashMap (toList)
import Data.LinkedHashMap qualified as LinkedHashMap
import Data.List (sortOn)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NonEmpty
import Data.Map (Map)
Expand Down Expand Up @@ -36,7 +38,8 @@ convertPartitionsToQueries ::
Partitions ->
m (DisjunctAll (MetaNetwork, UserVariableReconstruction, QueryContents))
convertPartitionsToQueries partitions = do
allQueries <- forM (partitionsToDisjuncts partitions) $ \(userVarSol, assertionTree) -> do
allQueries <- forM (partitionsToDisjuncts partitions) $ \(reconstruction, assertionTree) -> do
fullReconstruction <- reconstructNetworkTensorVars reconstruction
networkVarAssertions <- convertToNetworkRatVarAssertions assertionTree
let dnfTree = exprToDNF networkVarAssertions
forM dnfTree $ \conjuncts -> do
Expand All @@ -45,7 +48,7 @@ convertPartitionsToQueries partitions = do
-- Compile queries to particular format
let contents = prettifyQueryContents (variables metaNetwork) newConjuncts
-- Return the result
return (metaNetwork, userVarSol, contents)
return (metaNetwork, fullReconstruction, contents)
return $ disjunctDisjuncts allQueries

-- This is separated from `convertPartitionsToQueries` above because for
Expand All @@ -64,6 +67,19 @@ compileQueryToFormat (metaNetwork, userVars, contents@QueryContents {..}) = do
queryText <- formatQuery queryFormat queryAddress contents
return (queryMetaData, queryText)

--------------------------------------------------------------------------------
-- Step 0: Add reconstruction steps for network tensor variables.

reconstructNetworkTensorVars ::
(MonadQueryStructure m) =>
UserVariableReconstruction ->
m UserVariableReconstruction
reconstructNetworkTensorVars solutions = do
GlobalCtx {..} <- get
let networkTensorVars = sortOn fst $ HashMap.toList $ networkVariableReductions
let mkStep (var, (ratVars, _)) = ReconstructTensor (NetworkTensorVar var) (fmap NetworkRationalVar ratVars)
return $ foldr (\v -> (mkStep v :)) solutions networkTensorVars

--------------------------------------------------------------------------------
-- Step 1: Reduce tensor equalities to a series of rational equalities and
-- checks that the expression only contains network variables.
Expand All @@ -83,9 +99,9 @@ convertToNetworkRatVarAssertions = go

convert :: Assertion -> m (BooleanExpr QueryAssertion)
convert = \case
TensorEq tensorEquality -> do
rationalEqualities <- reduceTensorEquality tensorEquality
let assertions = fmap (Query . RationalEq) rationalEqualities
TensorEq (TensorEquality tensorEquality) -> do
rationalEqualities <- reduceTensorExpr tensorEquality
let assertions = fmap (Query . RationalEq . RationalEquality) rationalEqualities
go $ Conjunct $ ConjunctAll (NonEmpty.fromList assertions)
RationalEq (RationalEquality expr) ->
Query <$> makeQueryAssertion Equal expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ compileTensorAssertion spinePrefix x y = do
let maybeAssertion = liftA2 tensorEqToAssertion x' y'
case maybeAssertion of
Just assertion -> return $ mkTrivialPartition assertion
Nothing -> compileBoolExpr =<< appStdlibDef StdEqualsVector (spinePrefix <> (Arg mempty Explicit Relevant <$> [x, y]))
Nothing -> do
logDebug MaxDetail $ "Unable to solve tensor equality so reducing to rational equalities"
compileBoolExpr =<< appStdlibDef StdEqualsVector (spinePrefix <> (Arg mempty Explicit Relevant <$> [x, y]))

compileTensorLinearExpr ::
forall m.
Expand Down
101 changes: 46 additions & 55 deletions vehicle/src/Vehicle/Backend/Queries/UserVariableElimination/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,13 @@ mapAssertionExprs ft fr ass = checkTriviality $ case ass of
RationalIneq RationalInequality {..} -> RationalIneq $ RationalInequality strictness (fr rationalIneqExpr)

substituteTensorEq ::
(OriginalUserVariable, TensorEquality) ->
Map RationalVariable RationalEquality ->
(OriginalUserVariable, LinearExpr TensorVariable RationalTensor) ->
Map RationalVariable (LinearExpr RationalVariable Rational) ->
Assertion ->
MaybeTrivial Assertion
substituteTensorEq (var, solution) ratSolutions =
mapAssertionExprs
(eliminateVar (UserTensorVar var) (tensorEqExpr solution))
(eliminateVar (UserTensorVar var) solution)
eliminateRatVars
where
-- Usually the expression being substituted into is much smaller than the number of tensor
Expand All @@ -374,11 +374,11 @@ substituteTensorEq (var, solution) ratSolutions =
let vc = Sparse (Map.singleton v c) 0
case Map.lookup v ratSolutions of
Nothing -> vc
Just sol -> eliminateVar v (rationalEqExpr sol) vc
Just sol -> eliminateVar v sol vc

substituteRationalEq :: UserRationalVariable -> RationalEquality -> Assertion -> MaybeTrivial Assertion
substituteRationalEq :: UserRationalVariable -> LinearExpr RationalVariable Rational -> Assertion -> MaybeTrivial Assertion
substituteRationalEq var solution =
mapAssertionExprs id (eliminateVar (UserRationalVar var) (rationalEqExpr solution))
mapAssertionExprs id (eliminateVar (UserRationalVar var) solution)

--------------------------------------------------------------------------------
-- Partitions
Expand All @@ -388,10 +388,10 @@ type AssertionTree = BooleanExpr Assertion
-- | One step in the process for transforming unreduced user variables into
-- reduced network input and output variables.
data UserVariableReconstructionStep
= SolveTensorEquality OriginalUserVariable TensorEquality
| SolveRationalEquality UserRationalVariable RationalEquality
= SolveTensorEquality OriginalUserVariable (LinearExpr TensorVariable RationalTensor)
| SolveRationalEquality UserRationalVariable (LinearExpr RationalVariable Rational)
| SolveRationalInequalities UserRationalVariable FourierMotzkinVariableSolution
| ReconstructTensor OriginalUserVariable [UserRationalVariable]
| ReconstructTensor TensorVariable [RationalVariable]
deriving (Eq, Ord, Show, Generic)

instance ToJSON UserVariableReconstructionStep
Expand All @@ -400,10 +400,10 @@ instance FromJSON UserVariableReconstructionStep

instance Pretty UserVariableReconstructionStep where
pretty = \case
SolveTensorEquality v _s -> "SolveTensorEquality[" <+> pretty v <+> "]" -- "=" <+> pretty s <+> "]"
SolveRationalEquality v _s -> "SolveRationalEquality[" <+> pretty v <+> "]" -- "=" <+> pretty s <+> "]"
SolveRationalInequalities v _ -> "SolveRationalInequalities[" <+> pretty v <+> "]"
ReconstructTensor v _vs -> "ReconstructTensor[" <+> pretty v <+> "]"
SolveTensorEquality v s -> "Equation:" <+> pretty v <+> "=" <+> pretty s
SolveRationalEquality v s -> "Equation:" <+> pretty v <+> "=" <+> pretty s
SolveRationalInequalities v s -> "Inequalities:" <+> pretty v <+> "bounded" <+> pretty s
ReconstructTensor v vs -> "Reconstruct:" <+> pretty v <+> "from" <+> prettyList vs

-- | The steps for transforming unreduced user variables into reduced network
-- input and output varibles.
Expand All @@ -414,38 +414,9 @@ instance Pretty UserVariableReconstructionStep where
-- The steps are stored in the same order they occured during compilation.
type UserVariableReconstruction = [UserVariableReconstructionStep]

type UserVarSolutions = UserVariableReconstruction

addRationalEqualitySolution ::
UserRationalVariable ->
RationalEquality ->
UserVarSolutions ->
UserVarSolutions
addRationalEqualitySolution var eq solutions =
SolveRationalEquality var eq : solutions

addRationalInequalitySolution ::
UserRationalVariable ->
FourierMotzkinVariableSolution ->
UserVarSolutions ->
UserVarSolutions
addRationalInequalitySolution var eq solutions =
SolveRationalInequalities var eq : solutions

addTensorEqualitySolution ::
OriginalUserVariable ->
TensorEquality ->
UserVarSolutions ->
UserVarSolutions
addTensorEqualitySolution var eq solutions =
SolveTensorEquality var eq : solutions

unionSolutions :: UserVarSolutions -> UserVarSolutions -> UserVarSolutions
unionSolutions = (<>)

type Partition = (UserVarSolutions, AssertionTree)
type Partition = (UserVariableReconstruction, AssertionTree)

newtype Partitions = Partitions (Map UserVarSolutions AssertionTree)
newtype Partitions = Partitions (Map UserVariableReconstruction AssertionTree)

partitionsToDisjuncts :: Partitions -> DisjunctAll Partition
partitionsToDisjuncts (Partitions ps) = DisjunctAll $ NonEmpty.fromList $ Map.toList ps
Expand All @@ -464,7 +435,7 @@ orPartitions :: Partitions -> Partitions -> Partitions
orPartitions (Partitions p1) (Partitions p2) =
Partitions $ Map.unionWith orBoolExpr p1 p2

mkSinglePartition :: (UserVarSolutions, MaybeTrivial AssertionTree) -> MaybeTrivial Partitions
mkSinglePartition :: (UserVariableReconstruction, MaybeTrivial AssertionTree) -> MaybeTrivial Partitions
mkSinglePartition (solutions, maybeAssertion) =
fmap (Partitions . Map.singleton solutions) maybeAssertion

Expand All @@ -474,19 +445,41 @@ mkTrivialPartition assertion = mkSinglePartition (mempty, NonTrivial $ Query ass
--------------------------------------------------------------------------------
-- Variable reconstruction

data FMBound = FMBound
{ boundStrictness :: Strictness,
boundValue :: LinearExpr RationalVariable Rational
}
deriving (Show, Eq, Ord, Generic)

instance ToJSON FMBound

instance FromJSON FMBound

instance Pretty FMBound where
pretty FMBound {..} =
pretty boundValue <> (if boundStrictness == Strict then " (strictly)" else "")

-- | A FM solution for an normalised user variable is two lists of constraints.
-- The variable value must be greater than the first set of assertions, and less than
-- the second set of assertions.
data FourierMotzkinVariableSolution = FMSolution
{ lowerBounds :: [RationalInequality],
upperBounds :: [RationalInequality]
{ lowerBounds :: [FMBound],
upperBounds :: [FMBound]
}
deriving (Show, Eq, Ord, Generic)

instance ToJSON FourierMotzkinVariableSolution

instance FromJSON FourierMotzkinVariableSolution

instance Pretty FourierMotzkinVariableSolution where
pretty FMSolution {..} =
"below by max"
<+> pretty lowerBounds
<+> "and"
<+> "above by min"
<+> pretty upperBounds

--------------------------------------------------------------------------------
-- Monads

Expand Down Expand Up @@ -585,11 +578,11 @@ appStdlibDef fn spine = do
Just fnBody -> normaliseApp fnBody spine
Nothing -> compilerDeveloperError $ "Unexpected found" <+> quotePretty fn <+> "to have no body"

reduceTensorEquality ::
reduceTensorExpr ::
(MonadQueryStructure m) =>
TensorEquality ->
m [RationalEquality]
reduceTensorEquality (TensorEquality (Sparse coeff constant)) = do
LinearExpr TensorVariable RationalTensor ->
m [LinearExpr RationalVariable Rational]
reduceTensorExpr (Sparse coeff constant) = do
let constValues = Vector.toList $ tensorValues constant
let numRatEqs = product (tensorDims constant)
coeffList <- traverse (\(v, c) -> (,c) <$> getReducedVariablesFor v) (Map.toList coeff)
Expand All @@ -600,10 +593,8 @@ reduceTensorEquality (TensorEquality (Sparse coeff constant)) = do
[([RationalVariable], Coefficient)] ->
[Rational] ->
Int ->
RationalEquality
mkRatEquality coeffs consts i = do
let expr = Sparse (Map.fromList (fmap (first (!! i)) coeffs)) (consts !! i)
RationalEquality expr
LinearExpr RationalVariable Rational
mkRatEquality coeffs consts i = Sparse (Map.fromList (fmap (first (!! i)) coeffs)) (consts !! i)

--------------------------------------------------------------------------------
-- Context operations
Expand Down
Loading

0 comments on commit 8fcf08c

Please sign in to comment.