diff --git a/src/solvers/flattening/boolbv_struct.cpp b/src/solvers/flattening/boolbv_struct.cpp index 9b49a69eb09..b418da6438f 100644 --- a/src/solvers/flattening/boolbv_struct.cpp +++ b/src/solvers/flattening/boolbv_struct.cpp @@ -12,7 +12,10 @@ Author: Daniel Kroening, kroening@kroening.com bvt boolbvt::convert_struct(const struct_exprt &expr) { - const struct_typet &struct_type=to_struct_type(ns.follow(expr.type())); + const struct_typet &struct_type = + expr.type().id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(expr.type())) + : to_struct_type(expr.type()); std::size_t width=boolbv_width(struct_type); diff --git a/src/solvers/flattening/boolbv_typecast.cpp b/src/solvers/flattening/boolbv_typecast.cpp index 6f14f1c6d1a..3be48023153 100644 --- a/src/solvers/flattening/boolbv_typecast.cpp +++ b/src/solvers/flattening/boolbv_typecast.cpp @@ -538,17 +538,23 @@ bool boolbvt::type_conversion( return false; } } - else if(ns.follow(dest_type).id() == ID_struct) + else if(dest_type.id() == ID_struct || dest_type.id() == ID_struct_tag) { - const struct_typet &dest_struct = to_struct_type(ns.follow(dest_type)); + const struct_typet &dest_struct = + dest_type.id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(dest_type)) + : to_struct_type(dest_type); - if(ns.follow(src_type).id() == ID_struct) + if(src_type.id() == ID_struct || src_type.id() == ID_struct_tag) { // we do subsets dest.resize(dest_width, const_literal(false)); - const struct_typet &op_struct = to_struct_type(ns.follow(src_type)); + const struct_typet &op_struct = + src_type.id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(src_type)) + : to_struct_type(src_type); const struct_typet::componentst &dest_comp = dest_struct.components(); diff --git a/src/solvers/flattening/boolbv_update.cpp b/src/solvers/flattening/boolbv_update.cpp index 2a1ce9a5e23..5af9c09874a 100644 --- a/src/solvers/flattening/boolbv_update.cpp +++ b/src/solvers/flattening/boolbv_update.cpp @@ -107,9 +107,11 @@ void boolbvt::convert_update_rec( { const irep_idt &component_name=designator.get(ID_component_name); - if(ns.follow(type).id() == ID_struct) + if(type.id() == ID_struct || type.id() == ID_struct_tag) { - const struct_typet &struct_type = to_struct_type(ns.follow(type)); + const struct_typet &struct_type = + type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type)) + : to_struct_type(type); std::size_t struct_offset=0; @@ -144,9 +146,11 @@ void boolbvt::convert_update_rec( convert_update_rec( designators, d+1, new_type, new_offset, new_value, bv); } - else if(ns.follow(type).id() == ID_union) + else if(type.id() == ID_union || type.id() == ID_union_tag) { - const union_typet &union_type = to_union_type(ns.follow(type)); + const union_typet &union_type = type.id() == ID_union_tag + ? ns.follow_tag(to_union_tag_type(type)) + : to_union_type(type); const union_typet::componentt &component= union_type.get_component(component_name); diff --git a/src/solvers/flattening/bv_pointers.cpp b/src/solvers/flattening/bv_pointers.cpp index e40162aa1f2..a022801df4b 100644 --- a/src/solvers/flattening/bv_pointers.cpp +++ b/src/solvers/flattening/bv_pointers.cpp @@ -312,7 +312,6 @@ std::optional bv_pointerst::convert_address_of_rec(const exprt &expr) { const member_exprt &member_expr=to_member_expr(expr); const exprt &struct_op = member_expr.compound(); - const typet &struct_op_type=ns.follow(struct_op.type()); // recursive call auto bv_opt = convert_address_of_rec(struct_op); @@ -320,10 +319,16 @@ std::optional bv_pointerst::convert_address_of_rec(const exprt &expr) return {}; bvt bv = std::move(*bv_opt); - if(struct_op_type.id()==ID_struct) + if( + struct_op.type().id() == ID_struct || + struct_op.type().id() == ID_struct_tag) { - auto offset = member_offset( - to_struct_type(struct_op_type), member_expr.get_component_name(), ns); + const struct_typet &struct_op_type = + struct_op.type().id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(struct_op.type())) + : to_struct_type(struct_op.type()); + auto offset = + member_offset(struct_op_type, member_expr.get_component_name(), ns); CHECK_RETURN(offset.has_value()); // add offset @@ -333,7 +338,8 @@ std::optional bv_pointerst::convert_address_of_rec(const exprt &expr) else { INVARIANT( - struct_op_type.id() == ID_union, + struct_op.type().id() == ID_union || + struct_op.type().id() == ID_union_tag, "member expression should operate on struct or union"); // nothing to do, all members have offset 0 } @@ -551,21 +557,26 @@ bvt bv_pointerst::convert_pointer_type(const exprt &expr) else if(expr.id() == ID_field_address) { const auto &field_address_expr = to_field_address_expr(expr); - const typet &compound_type = ns.follow(field_address_expr.compound_type()); + const typet &compound_type = field_address_expr.compound_type(); // recursive call auto bv = convert_bitvector(field_address_expr.base()); - if(compound_type.id() == ID_struct) + if(compound_type.id() == ID_struct || compound_type.id() == ID_struct_tag) { - auto offset = member_offset( - to_struct_type(compound_type), field_address_expr.component_name(), ns); + const struct_typet &struct_type = + compound_type.id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(compound_type)) + : to_struct_type(compound_type); + auto offset = + member_offset(struct_type, field_address_expr.component_name(), ns); CHECK_RETURN(offset.has_value()); // add offset bv = offset_arithmetic(field_address_expr.type(), bv, *offset); } - else if(compound_type.id() == ID_union) + else if( + compound_type.id() == ID_union || compound_type.id() == ID_union_tag) { // nothing to do, all fields have offset 0 } diff --git a/src/solvers/smt2/smt2_conv.cpp b/src/solvers/smt2/smt2_conv.cpp index 0d63d2863c5..d132114cb22 100644 --- a/src/solvers/smt2/smt2_conv.cpp +++ b/src/solvers/smt2/smt2_conv.cpp @@ -846,7 +846,10 @@ void smt2_convt::convert_address_of_rec( struct_op_type.id() == ID_struct || struct_op_type.id() == ID_struct_tag, "member expression operand shall have struct type"); - const struct_typet &struct_type = to_struct_type(ns.follow(struct_op_type)); + const struct_typet &struct_type = + struct_op_type.id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(struct_op_type)) + : to_struct_type(struct_op_type); const irep_idt &component_name = member_expr.get_component_name(); @@ -3159,7 +3162,10 @@ void smt2_convt::convert_floatbv_typecast(const floatbv_typecast_exprt &expr) void smt2_convt::convert_struct(const struct_exprt &expr) { - const struct_typet &struct_type = to_struct_type(ns.follow(expr.type())); + const struct_typet &struct_type = + expr.type().id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(expr.type())) + : to_struct_type(expr.type()); const struct_typet::componentst &components= struct_type.components(); @@ -3262,10 +3268,9 @@ void smt2_convt::flatten_array(const exprt &expr) void smt2_convt::convert_union(const union_exprt &expr) { - const union_typet &union_type = to_union_type(ns.follow(expr.type())); const exprt &op=expr.op(); - std::size_t total_width=boolbv_width(union_type); + std::size_t total_width = boolbv_width(expr.type()); std::size_t member_width=boolbv_width(op.type()); @@ -4182,7 +4187,10 @@ void smt2_convt::convert_with(const with_exprt &expr) } else if(expr_type.id() == ID_struct || expr_type.id() == ID_struct_tag) { - const struct_typet &struct_type = to_struct_type(ns.follow(expr_type)); + const struct_typet &struct_type = + expr_type.id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(expr_type)) + : to_struct_type(expr_type); const exprt &index = expr.where(); const exprt &value = expr.new_value(); @@ -4253,11 +4261,9 @@ void smt2_convt::convert_with(const with_exprt &expr) } else if(expr_type.id() == ID_union || expr_type.id() == ID_union_tag) { - const union_typet &union_type = to_union_type(ns.follow(expr_type)); - const exprt &value = expr.new_value(); - std::size_t total_width=boolbv_width(union_type); + std::size_t total_width = boolbv_width(expr_type); std::size_t member_width=boolbv_width(value.type()); @@ -4399,7 +4405,10 @@ void smt2_convt::convert_member(const member_exprt &expr) if(struct_op_type.id() == ID_struct || struct_op_type.id() == ID_struct_tag) { - const struct_typet &struct_type = to_struct_type(ns.follow(struct_op_type)); + const struct_typet &struct_type = + struct_op_type.id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(struct_op_type)) + : to_struct_type(struct_op_type); INVARIANT( struct_type.has_component(name), "struct should have accessed component"); @@ -4496,7 +4505,9 @@ void smt2_convt::flatten2bv(const exprt &expr) if(use_datatypes) { // concatenate elements - const struct_typet &struct_type = to_struct_type(ns.follow(type)); + const struct_typet &struct_type = + type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type)) + : to_struct_type(type); const struct_typet::componentst &components= struct_type.components(); @@ -4622,7 +4633,9 @@ void smt2_convt::unflatten( out << "(mk-" << smt_typename; - const struct_typet &struct_type = to_struct_type(ns.follow(type)); + const struct_typet &struct_type = + type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type)) + : to_struct_type(type); const struct_typet::componentst &components= struct_type.components(); @@ -5501,8 +5514,11 @@ void smt2_convt::convert_type(const typet &type) else if(type.id() == ID_union || type.id() == ID_union_tag) { std::size_t width=boolbv_width(type); + const union_typet &union_type = type.id() == ID_union_tag + ? ns.follow_tag(to_union_tag_type(type)) + : to_union_type(type); CHECK_RETURN_WITH_DIAGNOSTICS( - to_union_type(ns.follow(type)).components().empty() || width != 0, + union_type.components().empty() || width != 0, "failed to get width of union"); out << "(_ BitVec " << width << ")"; diff --git a/src/solvers/smt2_incremental/encoding/struct_encoding.cpp b/src/solvers/smt2_incremental/encoding/struct_encoding.cpp index f57b943ff81..7e8e946d86b 100644 --- a/src/solvers/smt2_incremental/encoding/struct_encoding.cpp +++ b/src/solvers/smt2_incremental/encoding/struct_encoding.cpp @@ -105,8 +105,7 @@ extricate_updates(const with_exprt &struct_expr) static exprt encode(const with_exprt &with, const namespacet &ns) { const auto tag_type = type_checked_cast(with.type()); - const auto struct_type = - type_checked_cast(ns.follow(with.type())); + const auto struct_type = ns.follow_tag(tag_type); const auto updates = extricate_updates(with); const auto components = make_range(struct_type.components()) @@ -194,11 +193,19 @@ static std::size_t count_trailing_bit_width( /// the combined width of the fields which follow the field being selected. exprt struct_encodingt::encode_member(const member_exprt &member_expr) const { - const auto &type = ns.get().follow(member_expr.compound().type()); + const auto &compound_type = member_expr.compound().type(); const auto offset_bits = [&]() -> std::size_t { - if(can_cast_type(type)) + if( + can_cast_type(compound_type) || + can_cast_type(compound_type)) + { return 0; - const auto &struct_type = type_checked_cast(type); + } + const auto &struct_type = + compound_type.id() == ID_struct_tag + ? ns.get().follow_tag( + type_checked_cast(compound_type)) + : type_checked_cast(compound_type); return count_trailing_bit_width( struct_type, member_expr.get_component_name(), *boolbv_width); }(); diff --git a/src/solvers/strings/string_refinement.cpp b/src/solvers/strings/string_refinement.cpp index d9e682ae64c..e1980bed5f4 100644 --- a/src/solvers/strings/string_refinement.cpp +++ b/src/solvers/strings/string_refinement.cpp @@ -341,7 +341,10 @@ static void add_equations_for_symbol_resolution( { if(rhs.type().id() == ID_struct || rhs.type().id() == ID_struct_tag) { - const struct_typet &struct_type = to_struct_type(ns.follow(rhs.type())); + const struct_typet &struct_type = + rhs.type().id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(rhs.type())) + : to_struct_type(rhs.type()); for(const auto &comp : struct_type.components()) { if(is_char_pointer_type(comp.type())) @@ -377,7 +380,10 @@ extract_strings_from_lhs(const exprt &lhs, const namespacet &ns) result.push_back(lhs); else if(lhs.type().id() == ID_struct || lhs.type().id() == ID_struct_tag) { - const struct_typet &struct_type = to_struct_type(ns.follow(lhs.type())); + const struct_typet &struct_type = + lhs.type().id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(lhs.type())) + : to_struct_type(lhs.type()); for(const auto &comp : struct_type.components()) { const std::vector strings_in_comp = extract_strings_from_lhs( @@ -439,7 +445,9 @@ static void add_string_equation_to_symbol_resolution( eq.rhs().type().id() == ID_struct_tag) { const struct_typet &struct_type = - to_struct_type(ns.follow(eq.rhs().type())); + eq.rhs().type().id() == ID_struct_tag + ? ns.follow_tag(to_struct_tag_type(eq.rhs().type())) + : to_struct_type(eq.rhs().type()); for(const auto &comp : struct_type.components()) { const member_exprt lhs_data(eq.lhs(), comp.get_name(), comp.type());