Skip to content

Commit

Permalink
Merge pull request #8524 from diffblue/smt2-onehot
Browse files Browse the repository at this point in the history
SMT2: support `onehot` and `onehot0`
  • Loading branch information
kroening authored Dec 3, 2024
2 parents c902db3 + 7e39d04 commit 7eef276
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/solvers/flattening/boolbv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,10 @@ literalt boolbvt::convert_rest(const exprt &expr)
expr.id()==ID_reduction_nor || expr.id()==ID_reduction_nand ||
expr.id()==ID_reduction_xor || expr.id()==ID_reduction_xnor)
return convert_reduction(to_unary_expr(expr));
else if(expr.id()==ID_onehot || expr.id()==ID_onehot0)
return convert_onehot(to_unary_expr(expr));
else if(expr.id() == ID_onehot)
return convert_onehot(to_onehot_expr(expr));
else if(expr.id() == ID_onehot0)
return convert_onehot(to_onehot0_expr(expr));
else if(
const auto binary_overflow =
expr_try_dynamic_cast<binary_overflow_exprt>(expr))
Expand Down
8 changes: 8 additions & 0 deletions src/solvers/smt2/smt2_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,14 @@ void smt2_convt::convert_expr(const exprt &expr)
out << ")) #b1)"; // bvlshr, extract, =
}
}
else if(expr.id() == ID_onehot)
{
convert_expr(to_onehot_expr(expr).lower());
}
else if(expr.id() == ID_onehot0)
{
convert_expr(to_onehot0_expr(expr).lower());
}
else if(expr.id()==ID_extractbits)
{
const extractbits_exprt &extractbits_expr = to_extractbits_expr(expr);
Expand Down
34 changes: 34 additions & 0 deletions src/util/bitvector_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,37 @@ exprt zero_extend_exprt::lower() const
return extractbits_exprt{op(), 0, type()};
}
}

static exprt onehot_lowering(const exprt &expr)
{
exprt one_seen = false_exprt{};
const auto width = to_bitvector_type(expr.type()).get_width();
exprt::operandst more_than_one_seen_disjuncts;
more_than_one_seen_disjuncts.reserve(width);

for(std::size_t i = 0; i < width; i++)
{
auto bit = extractbit_exprt{expr, i};
more_than_one_seen_disjuncts.push_back(and_exprt{bit, one_seen});
one_seen = or_exprt{one_seen, bit};
}

auto more_than_one_seen = disjunction(more_than_one_seen_disjuncts);

return and_exprt{one_seen, not_exprt{more_than_one_seen}};
}

exprt onehot_exprt::lower() const
{
auto symbol = symbol_exprt{"onehot-op", op().type()};

return let_exprt{symbol, op(), onehot_lowering(symbol)};
}

exprt onehot0_exprt::lower() const
{
auto symbol = symbol_exprt{"onehot-op", op().type()};

// same as onehot, but on flipped operand bits
return let_exprt{symbol, bitnot_exprt{op()}, onehot_lowering(symbol)};
}
70 changes: 70 additions & 0 deletions src/util/bitvector_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1742,4 +1742,74 @@ inline zero_extend_exprt &to_zero_extend_expr(exprt &expr)
return static_cast<zero_extend_exprt &>(expr);
}

/// \brief A Boolean expression returning true iff the given
/// operand consists of exactly one '1' and '0' otherwise.
class onehot_exprt : public unary_predicate_exprt
{
public:
explicit onehot_exprt(exprt _op)
: unary_predicate_exprt(ID_onehot, std::move(_op))
{
}

/// lowering to extractbit
exprt lower() const;
};

/// \brief Cast an exprt to a \ref onehot_exprt
///
/// \a expr must be known to be \ref onehot_exprt.
///
/// \param expr: Source expression
/// \return Object of type \ref onehot_exprt
inline const onehot_exprt &to_onehot_expr(const exprt &expr)
{
PRECONDITION(expr.id() == ID_onehot);
onehot_exprt::check(expr);
return static_cast<const onehot_exprt &>(expr);
}

/// \copydoc to_onehot_expr(const exprt &)
inline onehot_exprt &to_onehot_expr(exprt &expr)
{
PRECONDITION(expr.id() == ID_onehot);
onehot_exprt::check(expr);
return static_cast<onehot_exprt &>(expr);
}

/// \brief A Boolean expression returning true iff the given
/// operand consists of exactly one '0' and '1' otherwise.
class onehot0_exprt : public unary_predicate_exprt
{
public:
explicit onehot0_exprt(exprt _op)
: unary_predicate_exprt(ID_onehot0, std::move(_op))
{
}

/// lowering to extractbit
exprt lower() const;
};

/// \brief Cast an exprt to a \ref onehot0_exprt
///
/// \a expr must be known to be \ref onehot0_exprt.
///
/// \param expr: Source expression
/// \return Object of type \ref onehot0_exprt
inline const onehot0_exprt &to_onehot0_expr(const exprt &expr)
{
PRECONDITION(expr.id() == ID_onehot0);
onehot0_exprt::check(expr);
return static_cast<const onehot0_exprt &>(expr);
}

/// \copydoc to_onehot0_expr(const exprt &)
inline onehot0_exprt &to_onehot0_expr(exprt &expr)
{
PRECONDITION(expr.id() == ID_onehot0);
onehot0_exprt::check(expr);
return static_cast<onehot0_exprt &>(expr);
}

#endif // CPROVER_UTIL_BITVECTOR_EXPR_H
78 changes: 78 additions & 0 deletions unit/util/bitvector_expr.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
// Author: Diffblue Ltd.

#include <util/arith_tools.h>
#include <util/bitvector_expr.h>
#include <util/bitvector_types.h>
#include <util/cout_message.h>
#include <util/namespace.h>
#include <util/std_expr.h>
#include <util/symbol_table.h>

#include <solvers/flattening/boolbv.h>
#include <solvers/sat/satcheck.h>
#include <testing-utils/use_catch.h>

TEST_CASE(
Expand Down Expand Up @@ -64,3 +71,74 @@ TEMPLATE_TEST_CASE(
}
}
}

TEST_CASE("onehot expression lowering", "[core][util][expr]")
{
console_message_handlert message_handler;
message_handler.set_verbosity(0);
satcheckt satcheck{message_handler};
symbol_tablet symbol_table;
namespacet ns{symbol_table};
boolbvt boolbv{ns, satcheck, message_handler};
unsignedbv_typet u8{8};

GIVEN("A bit-vector that is one-hot")
{
boolbv << onehot_exprt{from_integer(64, u8)}.lower();

THEN("the lowering of onehot is true")
{
REQUIRE(boolbv() == decision_proceduret::resultt::D_SATISFIABLE);
}
}

GIVEN("A bit-vector that is not one-hot")
{
boolbv << onehot_exprt{from_integer(5, u8)}.lower();

THEN("the lowering of onehot is false")
{
REQUIRE(boolbv() == decision_proceduret::resultt::D_UNSATISFIABLE);
}
}

GIVEN("A bit-vector that is not one-hot")
{
boolbv << onehot_exprt{from_integer(0, u8)}.lower();

THEN("the lowering of onehot is false")
{
REQUIRE(boolbv() == decision_proceduret::resultt::D_UNSATISFIABLE);
}
}

GIVEN("A bit-vector that is one-hot 0")
{
boolbv << onehot0_exprt{from_integer(0xfe, u8)}.lower();

THEN("the lowering of onehot0 is true")
{
REQUIRE(boolbv() == decision_proceduret::resultt::D_SATISFIABLE);
}
}

GIVEN("A bit-vector that is not one-hot 0")
{
boolbv << onehot0_exprt{from_integer(0x7e, u8)}.lower();

THEN("the lowering of onehot0 is false")
{
REQUIRE(boolbv() == decision_proceduret::resultt::D_UNSATISFIABLE);
}
}

GIVEN("A bit-vector that is not one-hot 0")
{
boolbv << onehot0_exprt{from_integer(0xff, u8)}.lower();

THEN("the lowering of onehot0 is false")
{
REQUIRE(boolbv() == decision_proceduret::resultt::D_UNSATISFIABLE);
}
}
}
2 changes: 2 additions & 0 deletions unit/util/module_dependencies.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
testing-utils
util
solvers/flattening
solvers/sat

0 comments on commit 7eef276

Please sign in to comment.