Skip to content

Commit 679ab3f

Browse files
authored
Merge pull request #6879 from thomasspriggs/tas/improved_get_value_response_validation
Add get value response validation for smt bv constant descriptors
2 parents 124c4b8 + 79169e0 commit 679ab3f

File tree

2 files changed

+149
-31
lines changed

2 files changed

+149
-31
lines changed

src/solvers/smt2_incremental/smt_response_validation.cpp

+63-14
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
/// `response_or_errort` in the case where the parse tree is of that type or
1616
/// an empty optional otherwise.
1717

18-
#include <solvers/smt2_incremental/smt_response_validation.h>
18+
#include "smt_response_validation.h"
1919

20+
#include <util/arith_tools.h>
2021
#include <util/mp_arith.h>
2122
#include <util/range.h>
2223

@@ -190,15 +191,33 @@ static bool all_subs_are_pairs(const irept &parse_tree)
190191
[](const irept &sub) { return sub.get_sub().size() == 2; });
191192
}
192193

193-
static response_or_errort<irep_idt>
194-
validate_smt_identifier(const irept &parse_tree)
194+
/// Checks for valid bit vector constants of the form `(_ bv(value) (width))`
195+
/// for example - `(_ bv4 64)`.
196+
static optionalt<smt_termt>
197+
valid_smt_indexed_bit_vector(const irept &parse_tree)
195198
{
196-
if(!parse_tree.get_sub().empty() || parse_tree.id().empty())
197-
{
198-
return response_or_errort<irep_idt>(
199-
"Expected identifier, found - \"" + print_parse_tree(parse_tree) + "\".");
200-
}
201-
return response_or_errort<irep_idt>(parse_tree.id());
199+
if(parse_tree.get_sub().size() != 3)
200+
return {};
201+
if(parse_tree.get_sub().at(0).id() != "_")
202+
return {};
203+
const auto value_string = id2string(parse_tree.get_sub().at(1).id());
204+
std::smatch match_results;
205+
static const std::regex bv_value_regex{R"(^bv(\d+)$)", std::regex::optimize};
206+
if(!std::regex_search(value_string, match_results, bv_value_regex))
207+
return {};
208+
INVARIANT(
209+
match_results.size() == 2,
210+
"Match results should include digits sub-expression if regex is matched.");
211+
const std::string value_digits = match_results[1];
212+
const auto value = string2integer(value_digits);
213+
const auto bit_width_string = id2string(parse_tree.get_sub().at(2).id());
214+
const auto bit_width =
215+
numeric_cast_v<std::size_t>(string2integer(bit_width_string));
216+
if(bit_width == 0)
217+
return {};
218+
if(value >= power(mp_integer{2}, bit_width))
219+
return {};
220+
return smt_bit_vector_constant_termt{value, bit_width};
202221
}
203222

204223
static optionalt<smt_termt> valid_smt_bool(const irept &parse_tree)
@@ -229,7 +248,7 @@ static optionalt<smt_termt> valid_smt_hex(const std::string &text)
229248
if(!std::regex_match(text, hex_format))
230249
return {};
231250
const std::string hex{text.begin() + 2, text.end()};
232-
// SMT-LIB 2 allows hex characters to be upper of lower case, but they should
251+
// SMT-LIB 2 allows hex characters to be upper or lower case, but they should
233252
// be upper case for mp_integer.
234253
const mp_integer value =
235254
string2integer(make_range(hex).map<std::function<int(int)>>(toupper), 16);
@@ -240,6 +259,8 @@ static optionalt<smt_termt> valid_smt_hex(const std::string &text)
240259
static optionalt<smt_termt>
241260
valid_smt_bit_vector_constant(const irept &parse_tree)
242261
{
262+
if(const auto indexed = valid_smt_indexed_bit_vector(parse_tree))
263+
return *indexed;
243264
if(!parse_tree.get_sub().empty() || parse_tree.id().empty())
244265
return {};
245266
const auto value_string = id2string(parse_tree.id());
@@ -250,24 +271,52 @@ valid_smt_bit_vector_constant(const irept &parse_tree)
250271
return {};
251272
}
252273

253-
static response_or_errort<smt_termt> validate_term(const irept &parse_tree)
274+
static optionalt<smt_termt> valid_term(const irept &parse_tree)
254275
{
255276
if(const auto smt_bool = valid_smt_bool(parse_tree))
256-
return response_or_errort<smt_termt>{*smt_bool};
277+
return {*smt_bool};
257278
if(const auto bit_vector_constant = valid_smt_bit_vector_constant(parse_tree))
258-
return response_or_errort<smt_termt>{*bit_vector_constant};
279+
return {*bit_vector_constant};
280+
return {};
281+
}
282+
283+
static response_or_errort<smt_termt> validate_term(const irept &parse_tree)
284+
{
285+
if(const auto term = valid_term(parse_tree))
286+
return response_or_errort<smt_termt>{*term};
259287
return response_or_errort<smt_termt>{"Unrecognised SMT term - \"" +
260288
print_parse_tree(parse_tree) + "\"."};
261289
}
262290

291+
static response_or_errort<smt_termt>
292+
validate_smt_descriptor(const irept &parse_tree, const smt_sortt &sort)
293+
{
294+
if(const auto term = valid_term(parse_tree))
295+
return response_or_errort<smt_termt>{*term};
296+
const auto id = parse_tree.id();
297+
if(!id.empty())
298+
return response_or_errort<smt_termt>{smt_identifier_termt{id, sort}};
299+
return response_or_errort<smt_termt>{
300+
"Expected descriptor SMT term, found - \"" + print_parse_tree(parse_tree) +
301+
"\"."};
302+
}
303+
263304
static response_or_errort<smt_get_value_responset::valuation_pairt>
264305
validate_valuation_pair(const irept &pair_parse_tree)
265306
{
266307
PRECONDITION(pair_parse_tree.get_sub().size() == 2);
267308
const auto &descriptor = pair_parse_tree.get_sub()[0];
268309
const auto &value = pair_parse_tree.get_sub()[1];
310+
const response_or_errort<smt_termt> value_validation = validate_term(value);
311+
if(const auto value_errors = value_validation.get_if_error())
312+
{
313+
return response_or_errort<smt_get_value_responset::valuation_pairt>{
314+
*value_errors};
315+
}
316+
const smt_termt value_term = *value_validation.get_if_valid();
269317
return validation_propagating<smt_get_value_responset::valuation_pairt>(
270-
validate_smt_identifier(descriptor), validate_term(value));
318+
validate_smt_descriptor(descriptor, value_term.get_sort()),
319+
validate_term(value));
271320
}
272321

273322
/// \returns: A response or error in the case where the parse tree appears to be

unit/solvers/smt2_incremental/smt_response_validation.cpp

+86-17
Original file line numberDiff line numberDiff line change
@@ -117,20 +117,90 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]")
117117
}
118118
SECTION("Bit vector sorted values.")
119119
{
120-
const response_or_errort<smt_responset> response_255 =
121-
validate_smt_response(*smt2irep("((a #xff))").parsed_output);
122-
CHECK(
123-
*response_255.get_if_valid() ==
124-
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
125-
smt_identifier_termt{"a", smt_bit_vector_sortt{8}},
126-
smt_bit_vector_constant_termt{255, 8}}}});
127-
const response_or_errort<smt_responset> response_42 =
128-
validate_smt_response(*smt2irep("((a #b00101010))").parsed_output);
129-
CHECK(
130-
*response_42.get_if_valid() ==
131-
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
132-
smt_identifier_termt{"a", smt_bit_vector_sortt{8}},
133-
smt_bit_vector_constant_termt{42, 8}}}});
120+
SECTION("Hex value")
121+
{
122+
const response_or_errort<smt_responset> response_255 =
123+
validate_smt_response(*smt2irep("((a #xff))").parsed_output);
124+
CHECK(
125+
*response_255.get_if_valid() ==
126+
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
127+
smt_identifier_termt{"a", smt_bit_vector_sortt{8}},
128+
smt_bit_vector_constant_termt{255, 8}}}});
129+
}
130+
SECTION("Binary value")
131+
{
132+
const response_or_errort<smt_responset> response_42 =
133+
validate_smt_response(*smt2irep("((a #b00101010))").parsed_output);
134+
CHECK(
135+
*response_42.get_if_valid() ==
136+
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
137+
smt_identifier_termt{"a", smt_bit_vector_sortt{8}},
138+
smt_bit_vector_constant_termt{42, 8}}}});
139+
}
140+
SECTION("Descriptors which are bit vector constants")
141+
{
142+
const response_or_errort<smt_responset> response_descriptor =
143+
validate_smt_response(*smt2irep("(((_ bv255 8) #x2A))").parsed_output);
144+
CHECK(
145+
*response_descriptor.get_if_valid() ==
146+
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
147+
smt_bit_vector_constant_termt{255, 8},
148+
smt_bit_vector_constant_termt{42, 8}}}});
149+
SECTION("Invalid bit vector constants")
150+
{
151+
SECTION("Value too large for width")
152+
{
153+
const response_or_errort<smt_responset> pair_value_response =
154+
validate_smt_response(
155+
*smt2irep("(((_ bv256 8) #xff))").parsed_output);
156+
CHECK(
157+
*pair_value_response.get_if_error() ==
158+
std::vector<std::string>{
159+
"Expected descriptor SMT term, found - \"\n"
160+
"0: _\n"
161+
"1: bv256\n"
162+
"2: 8\"."});
163+
}
164+
SECTION("Value missing bv prefix.")
165+
{
166+
const response_or_errort<smt_responset> pair_value_response =
167+
validate_smt_response(*smt2irep("(((_ 42 8) #xff))").parsed_output);
168+
CHECK(
169+
*pair_value_response.get_if_error() ==
170+
std::vector<std::string>{
171+
"Expected descriptor SMT term, found - \"\n"
172+
"0: _\n"
173+
"1: 42\n"
174+
"2: 8\"."});
175+
}
176+
SECTION("Hex value.")
177+
{
178+
const response_or_errort<smt_responset> pair_value_response =
179+
validate_smt_response(
180+
*smt2irep("(((_ bv2A 8) #xff))").parsed_output);
181+
CHECK(
182+
*pair_value_response.get_if_error() ==
183+
std::vector<std::string>{
184+
"Expected descriptor SMT term, found - \"\n"
185+
"0: _\n"
186+
"1: bv2A\n"
187+
"2: 8\"."});
188+
}
189+
SECTION("Zero width.")
190+
{
191+
const response_or_errort<smt_responset> pair_value_response =
192+
validate_smt_response(
193+
*smt2irep("(((_ bv0 0) #xff))").parsed_output);
194+
CHECK(
195+
*pair_value_response.get_if_error() ==
196+
std::vector<std::string>{
197+
"Expected descriptor SMT term, found - \"\n"
198+
"0: _\n"
199+
"1: bv0\n"
200+
"2: 0\"."});
201+
}
202+
}
203+
}
134204
}
135205
SECTION("Multiple valuation pairs.")
136206
{
@@ -174,12 +244,11 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]")
174244
validate_smt_response(*smt2irep("((() true))").parsed_output);
175245
CHECK(
176246
*empty_descriptor_response.get_if_error() ==
177-
std::vector<std::string>{"Expected identifier, found - \"\"."});
247+
std::vector<std::string>{"Expected descriptor SMT term, found - \"\"."});
178248
const response_or_errort<smt_responset> empty_pair =
179249
validate_smt_response(*smt2irep("((() ())))").parsed_output);
180250
CHECK(
181251
*empty_pair.get_if_error() ==
182-
std::vector<std::string>{"Expected identifier, found - \"\".",
183-
"Unrecognised SMT term - \"\"."});
252+
std::vector<std::string>{"Unrecognised SMT term - \"\"."});
184253
}
185254
}

0 commit comments

Comments
 (0)