Skip to content

Solvers: Replace uses of namespacet::follow #8235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/solvers/flattening/boolbv_struct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ Author: Daniel Kroening, [email protected]

bvt boolbvt::convert_struct(const struct_exprt &expr)
{
const struct_typet &struct_type=to_struct_type(ns.follow(expr.type()));
const struct_typet &struct_type =
expr.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(expr.type()))
: to_struct_type(expr.type());

std::size_t width=boolbv_width(struct_type);

Expand Down
14 changes: 10 additions & 4 deletions src/solvers/flattening/boolbv_typecast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,17 +538,23 @@ bool boolbvt::type_conversion(
return false;
}
}
else if(ns.follow(dest_type).id() == ID_struct)
else if(dest_type.id() == ID_struct || dest_type.id() == ID_struct_tag)
{
const struct_typet &dest_struct = to_struct_type(ns.follow(dest_type));
const struct_typet &dest_struct =
dest_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(dest_type))
: to_struct_type(dest_type);

if(ns.follow(src_type).id() == ID_struct)
if(src_type.id() == ID_struct || src_type.id() == ID_struct_tag)
{
// we do subsets

dest.resize(dest_width, const_literal(false));

const struct_typet &op_struct = to_struct_type(ns.follow(src_type));
const struct_typet &op_struct =
src_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(src_type))
: to_struct_type(src_type);

const struct_typet::componentst &dest_comp = dest_struct.components();

Expand Down
12 changes: 8 additions & 4 deletions src/solvers/flattening/boolbv_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ void boolbvt::convert_update_rec(
{
const irep_idt &component_name=designator.get(ID_component_name);

if(ns.follow(type).id() == ID_struct)
if(type.id() == ID_struct || type.id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(type));
const struct_typet &struct_type =
type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type))
: to_struct_type(type);

std::size_t struct_offset=0;

Expand Down Expand Up @@ -144,9 +146,11 @@ void boolbvt::convert_update_rec(
convert_update_rec(
designators, d+1, new_type, new_offset, new_value, bv);
}
else if(ns.follow(type).id() == ID_union)
else if(type.id() == ID_union || type.id() == ID_union_tag)
{
const union_typet &union_type = to_union_type(ns.follow(type));
const union_typet &union_type = type.id() == ID_union_tag
? ns.follow_tag(to_union_tag_type(type))
: to_union_type(type);

const union_typet::componentt &component=
union_type.get_component(component_name);
Expand Down
31 changes: 21 additions & 10 deletions src/solvers/flattening/bv_pointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,18 +312,23 @@ std::optional<bvt> bv_pointerst::convert_address_of_rec(const exprt &expr)
{
const member_exprt &member_expr=to_member_expr(expr);
const exprt &struct_op = member_expr.compound();
const typet &struct_op_type=ns.follow(struct_op.type());

// recursive call
auto bv_opt = convert_address_of_rec(struct_op);
if(!bv_opt.has_value())
return {};

bvt bv = std::move(*bv_opt);
if(struct_op_type.id()==ID_struct)
if(
struct_op.type().id() == ID_struct ||
struct_op.type().id() == ID_struct_tag)
{
auto offset = member_offset(
to_struct_type(struct_op_type), member_expr.get_component_name(), ns);
const struct_typet &struct_op_type =
struct_op.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(struct_op.type()))
: to_struct_type(struct_op.type());
auto offset =
member_offset(struct_op_type, member_expr.get_component_name(), ns);
CHECK_RETURN(offset.has_value());

// add offset
Expand All @@ -333,7 +338,8 @@ std::optional<bvt> bv_pointerst::convert_address_of_rec(const exprt &expr)
else
{
INVARIANT(
struct_op_type.id() == ID_union,
struct_op.type().id() == ID_union ||
struct_op.type().id() == ID_union_tag,
"member expression should operate on struct or union");
// nothing to do, all members have offset 0
}
Expand Down Expand Up @@ -551,21 +557,26 @@ bvt bv_pointerst::convert_pointer_type(const exprt &expr)
else if(expr.id() == ID_field_address)
{
const auto &field_address_expr = to_field_address_expr(expr);
const typet &compound_type = ns.follow(field_address_expr.compound_type());
const typet &compound_type = field_address_expr.compound_type();

// recursive call
auto bv = convert_bitvector(field_address_expr.base());

if(compound_type.id() == ID_struct)
if(compound_type.id() == ID_struct || compound_type.id() == ID_struct_tag)
{
auto offset = member_offset(
to_struct_type(compound_type), field_address_expr.component_name(), ns);
const struct_typet &struct_type =
compound_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(compound_type))
: to_struct_type(compound_type);
auto offset =
member_offset(struct_type, field_address_expr.component_name(), ns);
CHECK_RETURN(offset.has_value());

// add offset
bv = offset_arithmetic(field_address_expr.type(), bv, *offset);
}
else if(compound_type.id() == ID_union)
else if(
compound_type.id() == ID_union || compound_type.id() == ID_union_tag)
{
// nothing to do, all fields have offset 0
}
Expand Down
40 changes: 28 additions & 12 deletions src/solvers/smt2/smt2_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,10 @@ void smt2_convt::convert_address_of_rec(
struct_op_type.id() == ID_struct || struct_op_type.id() == ID_struct_tag,
"member expression operand shall have struct type");

const struct_typet &struct_type = to_struct_type(ns.follow(struct_op_type));
const struct_typet &struct_type =
struct_op_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(struct_op_type))
: to_struct_type(struct_op_type);

const irep_idt &component_name = member_expr.get_component_name();

Expand Down Expand Up @@ -3159,7 +3162,10 @@ void smt2_convt::convert_floatbv_typecast(const floatbv_typecast_exprt &expr)

void smt2_convt::convert_struct(const struct_exprt &expr)
{
const struct_typet &struct_type = to_struct_type(ns.follow(expr.type()));
const struct_typet &struct_type =
expr.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(expr.type()))
: to_struct_type(expr.type());

const struct_typet::componentst &components=
struct_type.components();
Expand Down Expand Up @@ -3262,10 +3268,9 @@ void smt2_convt::flatten_array(const exprt &expr)

void smt2_convt::convert_union(const union_exprt &expr)
{
const union_typet &union_type = to_union_type(ns.follow(expr.type()));
const exprt &op=expr.op();

std::size_t total_width=boolbv_width(union_type);
std::size_t total_width = boolbv_width(expr.type());

std::size_t member_width=boolbv_width(op.type());

Expand Down Expand Up @@ -4182,7 +4187,10 @@ void smt2_convt::convert_with(const with_exprt &expr)
}
else if(expr_type.id() == ID_struct || expr_type.id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(expr_type));
const struct_typet &struct_type =
expr_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(expr_type))
: to_struct_type(expr_type);

const exprt &index = expr.where();
const exprt &value = expr.new_value();
Expand Down Expand Up @@ -4253,11 +4261,9 @@ void smt2_convt::convert_with(const with_exprt &expr)
}
else if(expr_type.id() == ID_union || expr_type.id() == ID_union_tag)
{
const union_typet &union_type = to_union_type(ns.follow(expr_type));

const exprt &value = expr.new_value();

std::size_t total_width=boolbv_width(union_type);
std::size_t total_width = boolbv_width(expr_type);

std::size_t member_width=boolbv_width(value.type());

Expand Down Expand Up @@ -4399,7 +4405,10 @@ void smt2_convt::convert_member(const member_exprt &expr)

if(struct_op_type.id() == ID_struct || struct_op_type.id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(struct_op_type));
const struct_typet &struct_type =
struct_op_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(struct_op_type))
: to_struct_type(struct_op_type);

INVARIANT(
struct_type.has_component(name), "struct should have accessed component");
Expand Down Expand Up @@ -4496,7 +4505,9 @@ void smt2_convt::flatten2bv(const exprt &expr)
if(use_datatypes)
{
// concatenate elements
const struct_typet &struct_type = to_struct_type(ns.follow(type));
const struct_typet &struct_type =
type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type))
: to_struct_type(type);

const struct_typet::componentst &components=
struct_type.components();
Expand Down Expand Up @@ -4622,7 +4633,9 @@ void smt2_convt::unflatten(

out << "(mk-" << smt_typename;

const struct_typet &struct_type = to_struct_type(ns.follow(type));
const struct_typet &struct_type =
type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type))
: to_struct_type(type);

const struct_typet::componentst &components=
struct_type.components();
Expand Down Expand Up @@ -5501,8 +5514,11 @@ void smt2_convt::convert_type(const typet &type)
else if(type.id() == ID_union || type.id() == ID_union_tag)
{
std::size_t width=boolbv_width(type);
const union_typet &union_type = type.id() == ID_union_tag
? ns.follow_tag(to_union_tag_type(type))
: to_union_type(type);
CHECK_RETURN_WITH_DIAGNOSTICS(
to_union_type(ns.follow(type)).components().empty() || width != 0,
union_type.components().empty() || width != 0,
"failed to get width of union");

out << "(_ BitVec " << width << ")";
Expand Down
17 changes: 12 additions & 5 deletions src/solvers/smt2_incremental/encoding/struct_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ extricate_updates(const with_exprt &struct_expr)
static exprt encode(const with_exprt &with, const namespacet &ns)
{
const auto tag_type = type_checked_cast<struct_tag_typet>(with.type());
const auto struct_type =
type_checked_cast<struct_typet>(ns.follow(with.type()));
const auto struct_type = ns.follow_tag(tag_type);
const auto updates = extricate_updates(with);
const auto components =
make_range(struct_type.components())
Expand Down Expand Up @@ -194,11 +193,19 @@ static std::size_t count_trailing_bit_width(
/// the combined width of the fields which follow the field being selected.
exprt struct_encodingt::encode_member(const member_exprt &member_expr) const
{
const auto &type = ns.get().follow(member_expr.compound().type());
const auto &compound_type = member_expr.compound().type();
const auto offset_bits = [&]() -> std::size_t {
if(can_cast_type<union_typet>(type))
if(
can_cast_type<union_typet>(compound_type) ||
can_cast_type<union_tag_typet>(compound_type))
{
return 0;
const auto &struct_type = type_checked_cast<struct_typet>(type);
}
const auto &struct_type =
compound_type.id() == ID_struct_tag
? ns.get().follow_tag(
type_checked_cast<struct_tag_typet>(compound_type))
: type_checked_cast<struct_typet>(compound_type);
return count_trailing_bit_width(
struct_type, member_expr.get_component_name(), *boolbv_width);
}();
Expand Down
14 changes: 11 additions & 3 deletions src/solvers/strings/string_refinement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,10 @@ static void add_equations_for_symbol_resolution(
{
if(rhs.type().id() == ID_struct || rhs.type().id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(rhs.type()));
const struct_typet &struct_type =
rhs.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(rhs.type()))
: to_struct_type(rhs.type());
for(const auto &comp : struct_type.components())
{
if(is_char_pointer_type(comp.type()))
Expand Down Expand Up @@ -377,7 +380,10 @@ extract_strings_from_lhs(const exprt &lhs, const namespacet &ns)
result.push_back(lhs);
else if(lhs.type().id() == ID_struct || lhs.type().id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(lhs.type()));
const struct_typet &struct_type =
lhs.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(lhs.type()))
: to_struct_type(lhs.type());
for(const auto &comp : struct_type.components())
{
const std::vector<exprt> strings_in_comp = extract_strings_from_lhs(
Expand Down Expand Up @@ -439,7 +445,9 @@ static void add_string_equation_to_symbol_resolution(
eq.rhs().type().id() == ID_struct_tag)
{
const struct_typet &struct_type =
to_struct_type(ns.follow(eq.rhs().type()));
eq.rhs().type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(eq.rhs().type()))
: to_struct_type(eq.rhs().type());
for(const auto &comp : struct_type.components())
{
const member_exprt lhs_data(eq.lhs(), comp.get_name(), comp.type());
Expand Down