diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml index bfdf7ff3..0b7c217e 100644 --- a/.github/workflows/test-external.yml +++ b/.github/workflows/test-external.yml @@ -47,7 +47,7 @@ jobs: profile: "" - repo: "zobront/halmos-solady" dir: "halmos-solady" - cmd: "--function testCheck --solver-command yices-smt2 --solver-threads 3" + cmd: "--function testCheck --solver-command 'bitwuzla --produce-models' --solver-threads 3" branch: "" profile: "" - repo: "pcaversaccio/snekmate" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a21a1780..a517f7ab 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,7 +47,7 @@ jobs: uses: astral-sh/setup-uv@v4 with: # Install a specific version of uv. - version: "0.5.6" + version: "0.5.21" - name: Set up python ${{ matrix.python-version }} run: uv python install ${{ matrix.python-version }} diff --git a/.gitignore b/.gitignore index 2f368a47..de66ab4d 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,6 @@ out/ # - adds friction to CI # (at the cost of reproducible builds) uv.lock + +# https://docs.astral.sh/uv/concepts/projects/layout/#the-build-directory +build/ diff --git a/pyproject.toml b/pyproject.toml index 796b5820..c2566bff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,8 @@ dependencies = [ "z3-solver==4.12.6.0", "eth_hash[pysha3]>=0.7.0", "rich>=13.9.4", - "xxhash>=3.5.0" + "xxhash>=3.5.0", + "psutil>=6.1.0", ] dynamic = ["version"] diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index b6daa409..16d33c10 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: AGPL-3.0 import gc -import io import json import logging import os @@ -11,9 +10,8 @@ import sys import time import traceback -import uuid from collections import Counter -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future from copy import deepcopy from dataclasses import asdict, dataclass from datetime import timedelta @@ -21,34 +19,34 @@ from importlib import metadata from z3 import ( - Z3_OP_CONCAT, BitVec, - BitVecNumRef, - BitVecRef, - Bool, - CheckSatResult, - Context, - ModelRef, - Solver, - is_app, - is_bv, - sat, set_option, - unknown, unsat, ) +from .build import ( + build_output_iterator, + import_libs, + parse_build_out, + parse_devdoc, + parse_natspec, +) from .bytevec import ByteVec from .calldata import FunctionInfo, get_abi, mk_calldata from .config import Config as HalmosConfig from .config import arg_parser, default_config, resolve_config_files, toml_parser +from .constants import ( + VERBOSITY_TRACE_CONSTRUCTOR, + VERBOSITY_TRACE_COUNTEREXAMPLE, + VERBOSITY_TRACE_PATHS, + VERBOSITY_TRACE_SETUP, +) from .exceptions import HalmosException from .logs import ( COUNTEREXAMPLE_INVALID, COUNTEREXAMPLE_UNKNOWN, INTERNAL_ERROR, LOOP_BOUND, - PARSING_ERROR, REVERT_ALL, debug, error, @@ -57,7 +55,8 @@ warn, warn_code, ) -from .mapper import BuildOut, DeployAddressMapper, Mapper +from .mapper import BuildOut, DeployAddressMapper +from .processes import ExecutorRegistry, ShutdownError from .sevm import ( EMPTY_BALANCE, EVM, @@ -72,7 +71,6 @@ CallContext, CallOutput, Contract, - EventLog, Exec, FailCheatcode, Message, @@ -83,26 +81,29 @@ jumpid_str, mnemonic, ) +from .solve import ( + ContractContext, + FunctionContext, + PathContext, + SolverOutput, + solve_end_to_end, + solve_low_level, +) +from .traces import render_trace, rendered_trace from .utils import ( NamedTimer, address, - 
byte_length, color_error, con, create_solver, - cyan, green, hexify, indent_text, red, - stringify, unbox_int, yellow, ) -StrModel = dict[str, str] -AnyModel = StrModel | str - # Python version >=3.8.14, >=3.9.14, >=3.10.7, or >=3.11 if hasattr(sys, "set_int_max_str_digits"): sys.set_int_max_str_digits(0) @@ -114,10 +115,25 @@ if sys.stdout.encoding != "utf-8": sys.stdout.reconfigure(encoding="utf-8") -VERBOSITY_TRACE_COUNTEREXAMPLE = 2 -VERBOSITY_TRACE_SETUP = 3 -VERBOSITY_TRACE_PATHS = 4 -VERBOSITY_TRACE_CONSTRUCTOR = 5 + +@dataclass(frozen=True) +class TestResult: + name: str # test function name + exitcode: int + num_models: int = None + models: list[SolverOutput] = None + num_paths: tuple[int, int, int] = None # number of paths: [total, success, blocked] + time: tuple[int, int, int] = None # time: [total, paths, models] + num_bounded_loops: int = None # number of incomplete loops + + +class Exitcode(Enum): + PASS = 0 + COUNTEREXAMPLE = 1 + TIMEOUT = 2 + STUCK = 3 + REVERT_ALL = 4 + EXCEPTION = 5 def with_devdoc(args: HalmosConfig, fn_sig: str, contract_json: dict) -> HalmosConfig: @@ -146,6 +162,11 @@ def with_natspec( def load_config(_args) -> HalmosConfig: config = default_config() + if not config.solver_command: + warn( + "could not find z3 on the PATH -- check your PATH/venv or pass --solver-command explicitly" + ) + # parse CLI args first, so that can get `--help` out of the way and resolve `--debug` # but don't apply the CLI overrides yet cli_overrides = arg_parser().parse_args(_args) @@ -191,162 +212,16 @@ def mk_this() -> Address: return con_addr(FOUNDRY_TEST) -def mk_solver(args: HalmosConfig, logic="QF_AUFBV", ctx=None, assertion=False): - timeout = ( - args.solver_timeout_assertion if assertion else args.solver_timeout_branching +def mk_solver(args: HalmosConfig, logic="QF_AUFBV", ctx=None): + return create_solver( + logic=logic, + ctx=ctx, + timeout=args.solver_timeout_branching, + max_memory=args.solver_max_memory, ) - return create_solver(logic, ctx, timeout, args.solver_max_memory) - - -def rendered_initcode(context: CallContext) -> str: - message = context.message - data = message.data - - initcode_str = "" - args_str = "" - - if ( - isinstance(data, BitVecRef) - and is_app(data) - and data.decl().kind() == Z3_OP_CONCAT - ): - children = [arg for arg in data.children()] - if isinstance(children[0], BitVecNumRef): - initcode_str = hex(children[0].as_long()) - args_str = ", ".join(map(str, children[1:])) - else: - initcode_str = hexify(data) - - return f"{initcode_str}({cyan(args_str)})" - - -def render_output(context: CallContext, file=sys.stdout) -> None: - output = context.output - returndata_str = "0x" - failed = output.error is not None - - if not failed and context.is_stuck(): - return - - data = output.data - if data is not None: - is_create = context.message.is_create() - if hasattr(data, "unwrap"): - data = data.unwrap() - - returndata_str = ( - f"<{byte_length(data)} bytes of code>" - if (is_create and not failed) - else hexify(data) - ) - - ret_scheme = context.output.return_scheme - ret_scheme_str = f"{cyan(mnemonic(ret_scheme))} " if ret_scheme is not None else "" - error_str = f" (error: {repr(output.error)})" if failed else "" - - color = red if failed else green - indent = context.depth * " " - print( - f"{indent}{color('↩ ')}{ret_scheme_str}{color(returndata_str)}{color(error_str)}", - file=file, - ) - - -def rendered_log(log: EventLog) -> str: - opcode_str = f"LOG{len(log.topics)}" - topics = [ - f"{cyan(f'topic{i}')}={hexify(topic)}" for i, topic in 
enumerate(log.topics) - ] - data_str = f"{cyan('data')}={hexify(log.data)}" - args_str = ", ".join(topics + [data_str]) - - return f"{opcode_str}({args_str})" - - -def rendered_trace(context: CallContext) -> str: - with io.StringIO() as output: - render_trace(context, file=output) - return output.getvalue() - - -def rendered_calldata(calldata: ByteVec, contract_name: str | None = None) -> str: - if not calldata: - return "0x" - - if len(calldata) < 4: - return hexify(calldata) - - if len(calldata) == 4: - return f"{hexify(calldata.unwrap(), contract_name)}()" - - selector = calldata[:4].unwrap() - args = calldata[4:].unwrap() - return f"{hexify(selector, contract_name)}({hexify(args)})" - - -def render_trace(context: CallContext, file=sys.stdout) -> None: - message = context.message - addr = unbox_int(message.target) - addr_str = str(addr) if is_bv(addr) else hex(addr) - # check if we have a contract name for this address in our deployment mapper - addr_str = DeployAddressMapper().get_deployed_contract(addr_str) - - value = unbox_int(message.value) - value_str = f" (value: {value})" if is_bv(value) or value > 0 else "" - - call_scheme_str = f"{cyan(mnemonic(message.call_scheme))} " - indent = context.depth * " " - - if message.is_create(): - # TODO: select verbosity level to render full initcode - # initcode_str = rendered_initcode(context) - - try: - if context.output.error is None: - target = hex(int(str(message.target))) - bytecode = context.output.data.unwrap().hex() - contract_name = Mapper().get_by_bytecode(bytecode).contract_name - - DeployAddressMapper().add_deployed_contract(target, contract_name) - addr_str = contract_name - except Exception: - # TODO: print in debug mode - ... - - initcode_str = f"<{byte_length(message.data)} bytes of initcode>" - print( - f"{indent}{call_scheme_str}{addr_str}::{initcode_str}{value_str}", file=file - ) - - else: - calldata = rendered_calldata(message.data, addr_str) - call_str = f"{addr_str}::{calldata}" - static_str = yellow(" [static]") if message.is_static else "" - print(f"{indent}{call_scheme_str}{call_str}{static_str}{value_str}", file=file) - - log_indent = (context.depth + 1) * " " - for trace_element in context.trace: - if isinstance(trace_element, CallContext): - render_trace(trace_element, file=file) - elif isinstance(trace_element, EventLog): - print(f"{log_indent}{rendered_log(trace_element)}", file=file) - else: - raise HalmosException(f"unexpected trace element: {trace_element}") - - render_output(context, file=file) - - if context.depth == 1: - print(file=file) -def deploy_test( - creation_hexcode: str, - deployed_hexcode: str, - sevm: SEVM, - args: HalmosConfig, - libs: dict, - solver: Solver, -) -> Exec: +def deploy_test(ctx: FunctionContext, sevm: SEVM) -> Exec: this = mk_this() message = Message( target=this, @@ -364,12 +239,13 @@ def deploy_test( block=mk_block(), context=CallContext(message=message), pgm=None, # to be added - path=Path(solver), + path=Path(ctx.solver), ) # deploy libraries and resolve library placeholders in hexcode - (creation_hexcode, deployed_hexcode) = ex.resolve_libs( - creation_hexcode, deployed_hexcode, libs + contract_ctx = ctx.contract_ctx + (creation_hexcode, _) = ex.resolve_libs( + contract_ctx.creation_hexcode, contract_ctx.deployed_hexcode, contract_ctx.libs ) # test contract creation bytecode @@ -383,9 +259,9 @@ def deploy_test( if len(exs) != 1: raise ValueError(f"constructor: # of paths: {len(exs)}") - ex = exs[0] + [ex] = exs - if args.verbose >= VERBOSITY_TRACE_CONSTRUCTOR: + if 
ctx.args.verbose >= VERBOSITY_TRACE_CONSTRUCTOR: print("Constructor trace:") render_trace(ex.context) @@ -409,100 +285,103 @@ def deploy_test( return ex -def setup( - creation_hexcode: str, - deployed_hexcode: str, - abi: dict, - setup_info: FunctionInfo, - args: HalmosConfig, - libs: dict, - solver: Solver, -) -> Exec: +def setup(ctx: FunctionContext) -> Exec: setup_timer = NamedTimer("setup") setup_timer.create_subtimer("decode") + args, setup_info = ctx.args, ctx.info sevm = SEVM(args, setup_info) - setup_ex = deploy_test(creation_hexcode, deployed_hexcode, sevm, args, libs, solver) + setup_ex = deploy_test(ctx, sevm) setup_timer.create_subtimer("run") setup_sig = setup_info.sig - if setup_sig: - # TODO: dyn_params may need to be passed to mk_calldata in run() - calldata, dyn_params = mk_calldata(abi, setup_info, args) - setup_ex.path.process_dyn_params(dyn_params) - - parent_message = setup_ex.message() - setup_ex.context = CallContext( - message=Message( - target=parent_message.target, - caller=parent_message.caller, - origin=parent_message.origin, - value=0, - data=calldata, - call_scheme=EVM.CALL, - ), - ) + if not setup_sig: + if args.statistics: + print(setup_timer.report()) + return setup_ex + + # TODO: dyn_params may need to be passed to mk_calldata in run() + calldata, dyn_params = mk_calldata(ctx.contract_ctx.abi, setup_info, args) + setup_ex.path.process_dyn_params(dyn_params) + + parent_message = setup_ex.message() + setup_ex.context = CallContext( + message=Message( + target=parent_message.target, + caller=parent_message.caller, + origin=parent_message.origin, + value=0, + data=calldata, + call_scheme=EVM.CALL, + ), + ) - setup_exs_all = sevm.run(setup_ex) - setup_exs_no_error = [] + setup_exs_all = sevm.run(setup_ex) + setup_exs_no_error: list[tuple[Exec, SMTQuery]] = [] - for idx, setup_ex in enumerate(setup_exs_all): - if args.verbose >= VERBOSITY_TRACE_SETUP: - print(f"{setup_sig} trace #{idx+1}:") - render_trace(setup_ex.context) + for path_id, setup_ex in enumerate(setup_exs_all): + if args.verbose >= VERBOSITY_TRACE_SETUP: + print(f"{setup_sig} trace #{path_id}:") + render_trace(setup_ex.context) - if not (err := setup_ex.context.output.error): - setup_exs_no_error.append((setup_ex, setup_ex.path.to_smt2(args))) + if err := setup_ex.context.output.error: + opcode = setup_ex.current_opcode() + if opcode not in [EVM.REVERT, EVM.INVALID]: + warn_code( + INTERNAL_ERROR, + f"in {setup_sig}, executing {mnemonic(opcode)} failed with: {err}", + ) - else: - opcode = setup_ex.current_opcode() - if opcode not in [EVM.REVERT, EVM.INVALID]: - warn_code( - INTERNAL_ERROR, - f"in {setup_sig}, executing {mnemonic(opcode)} failed with: {err}", - ) + # only render the trace if we didn't already do it + if VERBOSITY_TRACE_COUNTEREXAMPLE <= args.verbose < VERBOSITY_TRACE_SETUP: + print(f"{setup_sig} trace:") + render_trace(setup_ex.context) - # only render the trace if we didn't already do it - if ( - args.verbose < VERBOSITY_TRACE_SETUP - and args.verbose >= VERBOSITY_TRACE_COUNTEREXAMPLE - ): - print(f"{setup_sig} trace:") - render_trace(setup_ex.context) - - setup_exs = [] - - if len(setup_exs_no_error) > 1: - for setup_ex, query in setup_exs_no_error: - res, _, _ = solve(query, args) - if res != unsat: - setup_exs.append(setup_ex) + else: + # note: ex.path.to_smt2() needs to be called at this point. The solver object is shared across paths, + # and solver.to_smt2() will return a different query if it is called after a different path is explored. 
+ setup_exs_no_error.append((setup_ex, setup_ex.path.to_smt2(args))) + + setup_exs: list[Exec] = [] + + match setup_exs_no_error: + case []: + pass + case [(ex, _)]: + setup_exs.append(ex) + case _: + for path_id, (ex, query) in enumerate(setup_exs_no_error): + path_ctx = PathContext( + args=args, + path_id=path_id, + query=query, + solving_ctx=ctx.solving_ctx, + ) + solver_output = solve_low_level(path_ctx) + if solver_output.result != unsat: + setup_exs.append(ex) if len(setup_exs) > 1: break - elif len(setup_exs_no_error) == 1: - setup_exs.append(setup_exs_no_error[0][0]) - - if len(setup_exs) == 0: + match len(setup_exs): + case 0: raise HalmosException(f"No successful path found in {setup_sig}") - - if len(setup_exs) > 1: + case n if n > 1: debug("\n".join(map(str, setup_exs))) - raise HalmosException(f"Multiple paths were found in {setup_sig}") - setup_ex = setup_exs[0] + [setup_ex] = setup_exs - if args.print_setup_states: - print(setup_ex) + if args.print_setup_states: + print(setup_ex) - if sevm.logs.bounded_loops: - warn_code( - LOOP_BOUND, - f"{setup_sig}: paths have not been fully explored due to the loop unrolling bound: {args.loop}", - ) - debug("\n".join(jumpid_str(x) for x in sevm.logs.bounded_loops)) + if sevm.logs.bounded_loops: + warn_code( + LOOP_BOUND, + f"{setup_sig}: paths have not been fully explored due to the loop unrolling bound: {args.loop}", + ) + debug("\n".join(jumpid_str(x) for x in sevm.logs.bounded_loops)) if args.statistics: print(setup_timer.report()) @@ -510,85 +389,28 @@ def setup( return setup_ex -@dataclass -class PotentialModel: - model: AnyModel - is_valid: bool - - def __init__(self, model: ModelRef | str, args: HalmosConfig) -> None: - # convert model into string to avoid pickling errors for z3 (ctypes) objects containing pointers - self.model = ( - to_str_model(model, args.print_full_model) - if isinstance(model, ModelRef) - else model - ) - self.is_valid = is_model_valid(model) - - def __str__(self) -> str: - # expected to be a filename - if isinstance(self.model, str): - return f"see {self.model}" - - formatted = [f"\n {decl} = {val}" for decl, val in self.model.items()] - return "".join(sorted(formatted)) if formatted else "∅" - - -@dataclass(frozen=True) -class ModelWithContext: - # can be a filename containing the model or a dict with variable assignments - model: PotentialModel | None - index: int - result: CheckSatResult - unsat_core: list | None - - -@dataclass(frozen=True) -class TestResult: - name: str # test function name - exitcode: int - num_models: int = None - models: list[ModelWithContext] = None - num_paths: tuple[int, int, int] = None # number of paths: [total, success, blocked] - time: tuple[int, int, int] = None # time: [total, paths, models] - num_bounded_loops: int = None # number of incomplete loops - - -class Exitcode(Enum): - PASS = 0 - COUNTEREXAMPLE = 1 - TIMEOUT = 2 - STUCK = 3 - REVERT_ALL = 4 - EXCEPTION = 5 - - def is_global_fail_set(context: CallContext) -> bool: hevm_fail = isinstance(context.output.error, FailCheatcode) return hevm_fail or any(is_global_fail_set(x) for x in context.subcalls()) -def run( - setup_ex: Exec, - abi: dict, - fun_info: FunctionInfo, - args: HalmosConfig, - solver: Solver, -) -> TestResult: +def run_test(ctx: FunctionContext) -> TestResult: + args = ctx.args + fun_info = ctx.info funname, funsig = fun_info.name, fun_info.sig if args.verbose >= 1: print(f"Executing {funname}") - dump_dirname = f"/tmp/{funname}-{uuid.uuid4().hex}" - # - # calldata + # prepare calldata # + setup_ex = 
ctx.setup_ex sevm = SEVM(args, fun_info) - path = Path(solver) + path = Path(ctx.solver) path.extend_path(setup_ex.path) - cd, dyn_params = mk_calldata(abi, fun_info, args) + cd, dyn_params = mk_calldata(ctx.contract_ctx.abi, fun_info, args) path.process_dyn_params(dyn_params) message = Message( @@ -634,70 +456,83 @@ def run( ) ) - (logs, steps) = (sevm.logs, sevm.steps) - - # check assertion violations normal = 0 - models: list[ModelWithContext] = [] + potential = 0 stuck = [] - thread_pool = ThreadPoolExecutor(max_workers=args.solver_threads) - future_models = [] - counterexamples = [] - unsat_cores = [] - traces: dict[int, str] = {} - exec_cache: dict[int, Exec] = {} + def solve_end_to_end_callback(future: Future): + # beware: this function may be called from threads other than the main thread, + # so we must be careful to avoid referencing any z3 objects / contexts + + if e := future.exception(): + error(f"encountered exception during assertion solving: {e!r}") - def future_callback(future_model): - m = future_model.result() - models.append(m) + # + # we are done solving, process and triage the result + # - model, index, result = m.model, m.index, m.result + solver_output = future.result() + result, model = solver_output.result, solver_output.model - # retrieve cached exec and clear the cache entry - exec = exec_cache.pop(index, None) + if ctx.solving_ctx.executor.is_shutdown(): + # if the thread pool is in the process of shutting down, + # we want to stop processing remaining models/timeouts/errors, etc. + return + + # keep track of the solver outputs, so that we can display PASS/FAIL/TIMEOUT/ERROR later + ctx.solver_outputs.append(solver_output) if result == unsat: - if m.unsat_core: - unsat_cores.append(m.unsat_core) + if solver_output.unsat_core: + ctx.append_unsat_core(solver_output.unsat_core) return - # model could be an empty dict here - if model is not None: - if model.is_valid: - print(red(f"Counterexample: {model}")) - counterexamples.append(model) - else: - warn_code( - COUNTEREXAMPLE_INVALID, - f"Counterexample (potentially invalid): {model}", - ) - counterexamples.append(model) - else: + # model could be an empty dict here, so compare to None explicitly + if model is None: warn_code(COUNTEREXAMPLE_UNKNOWN, f"Counterexample: {result}") + return + # print counterexample trace if args.verbose >= VERBOSITY_TRACE_COUNTEREXAMPLE: - print( - f"Trace #{index + 1}:" - if args.verbose == VERBOSITY_TRACE_PATHS - else "Trace:" - ) - print(traces[index], end="") + path_id = solver_output.path_id + id_str = f" #{path_id}" if args.verbose >= VERBOSITY_TRACE_PATHS else "" + print(f"Trace{id_str}:") + print(ctx.traces[path_id], end="") + + if model.is_valid: + print(red(f"Counterexample: {model}")) + ctx.valid_counterexamples.append(model) + + # we have a valid counterexample, so we are eligible for early exit + if args.early_exit: + debug(f"Shutting down {ctx.info.name}'s solver executor") + ctx.solving_ctx.executor.shutdown(wait=False) + else: + warn_str = f"Counterexample (potentially invalid): {model}" + warn_code(COUNTEREXAMPLE_INVALID, warn_str) - if args.print_failed_states: - print(f"# {index + 1}") - print(exec) + ctx.invalid_counterexamples.append(model) + + # + # consume the sevm.run() generator + # (actually triggers path exploration) + # - # initialize with default value in case we don't enter the loop body - idx = -1 + path_id = 0 # default value in case we don't enter the loop body + submitted_futures = [] + for path_id, ex in enumerate(exs): + # check if early exit is 
triggered + if ctx.solving_ctx.executor.is_shutdown(): + if args.debug: + print("aborting path exploration, executor has been shutdown") + break - for idx, ex in enumerate(exs): # cache exec in case we need to print it later if args.print_failed_states: - exec_cache[idx] = ex + ctx.exec_cache[path_id] = ex if args.verbose >= VERBOSITY_TRACE_PATHS: - print(f"Path #{idx+1}:") + print(f"Path #{path_id}:") print(indent_text(hexify(ex.path))) print("\nTrace:") @@ -706,76 +541,108 @@ def future_callback(future_model): output = ex.context.output error_output = output.error if ex.is_panic_of(args.panic_error_codes) or is_global_fail_set(ex.context): + potential += 1 + if args.verbose >= 1: - print(f"Found potential path (id: {idx+1})") + print(f"Found potential path (id: {path_id})") panic_code = unbox_int(output.data[4:36].unwrap()) print(f"Panic(0x{panic_code:02x}) {error_output}") + # we don't know yet if this will lead to a counterexample + # so we save the rendered trace here and potentially print it later + # if a valid counterexample is found if args.verbose >= VERBOSITY_TRACE_COUNTEREXAMPLE: - traces[idx] = rendered_trace(ex.context) + ctx.traces[path_id] = rendered_trace(ex.context) - query = ex.path.to_smt2(args) + query: SMTQuery = ex.path.to_smt2(args) - future_model = thread_pool.submit( - gen_model_from_sexpr, - GenModelArgs(args, idx, query, unsat_cores, dump_dirname), + # beware: because this object crosses thread boundaries, we must be careful to + # avoid any reference to z3 objects + path_ctx = PathContext( + args=args, + path_id=path_id, + query=query, + solving_ctx=ctx.solving_ctx, ) - future_model.add_done_callback(future_callback) - future_models.append(future_model) + + try: + solve_future = ctx.thread_pool.submit(solve_end_to_end, path_ctx) + solve_future.add_done_callback(solve_end_to_end_callback) + submitted_futures.append(solve_future) + except ShutdownError: + if args.debug: + print("aborting path exploration, executor has been shutdown") + break elif ex.context.is_stuck(): - debug(f"Potential error path (id: {idx+1})") - res, _, _ = solve(ex.path.to_smt2(args), args) - if res != unsat: - stuck.append((idx, ex, ex.context.get_stuck_reason())) + debug(f"Potential error path (id: {path_id})") + path_ctx = PathContext( + args=args, + path_id=path_id, + query=ex.path.to_smt2(args), + solving_ctx=ctx.solving_ctx, + ) + solver_output = solve_low_level(path_ctx) + if solver_output.result != unsat: + stuck.append((path_id, ex, ex.context.get_stuck_reason())) if args.print_blocked_states: - traces[idx] = f"{hexify(ex.path)}\n{rendered_trace(ex.context)}" + ctx.traces[path_id] = ( + f"{hexify(ex.path)}\n{rendered_trace(ex.context)}" + ) elif not error_output: if args.print_success_states: - print(f"# {idx+1}") + print(f"# {path_id}") print(ex) normal += 1 # print post-states if args.print_states: - print(f"# {idx+1}") + print(f"# {path_id}") print(ex) # 0 width is unlimited - if args.width and idx >= args.width: - warn( - f"{funsig}: incomplete execution due to the specified limit: --width {args.width}" - ) + if args.width and path_id >= args.width: + msg = "incomplete execution due to the specified limit" + warn(f"{funsig}: {msg}: --width {args.width}") break - num_execs = idx + 1 + num_execs = path_id + 1 + + # the name is a bit misleading: this timer only starts after the exploration phase is complete + # but it's possible that solvers have already been running for a while timer.create_subtimer("models") - if future_models and args.verbose >= 1: + if potential > 0 and 
args.verbose >= 1: print( - f"# of potential paths involving assertion violations: {len(future_models)} / {num_execs} (--solver-threads {args.solver_threads})" + f"# of potential paths involving assertion violations: {potential} / {num_execs}" + f" (--solver-threads {args.solver_threads})" ) + # # display assertion solving progress - if not args.no_status or args.early_exit: + # + + if not args.no_status: while True: - if args.early_exit and len(counterexamples) > 0: - break - done = sum(fm.done() for fm in future_models) - total = len(future_models) + done = sum(fm.done() for fm in submitted_futures) + total = potential if done == total: break elapsed = timedelta(seconds=int(timer.elapsed())) sevm.status.update(f"[{elapsed}] solving queries: {done} / {total}") time.sleep(0.1) - if args.early_exit: - thread_pool.shutdown(wait=False, cancel_futures=True) - else: - thread_pool.shutdown(wait=True) + ctx.thread_pool.shutdown(wait=True) + + timer.stop() + time_info = timer.report(include_subtimers=args.statistics) - counter = Counter(str(m.result) for m in models) + # + # print test result + # + + counter = Counter(str(m.result) for m in ctx.solver_outputs) if counter["sat"] > 0: passfail = red("[FAIL]") exitcode = Exitcode.COUNTEREXAMPLE.value @@ -800,16 +667,19 @@ def future_callback(future_model): timer.stop() time_info = timer.report(include_subtimers=args.statistics) - # print result + # print test result print( - f"{passfail} {funsig} (paths: {num_execs}, {time_info}, bounds: [{', '.join([str(x) for x in dyn_params])}])" + f"{passfail} {funsig} (paths: {num_execs}, {time_info}, " + f"bounds: [{', '.join([str(x) for x in dyn_params])}])" ) - for idx, _, err in stuck: + for path_id, _, err in stuck: warn_code(INTERNAL_ERROR, f"Encountered {err}") if args.print_blocked_states: - print(f"\nPath #{idx+1}") - print(traces[idx], end="") + print(f"\nPath #{path_id}") + print(ctx.traces[path_id], end="") + + (logs, steps) = (sevm.logs, sevm.steps) if logs.bounded_loops: warn_code( @@ -824,14 +694,15 @@ def future_callback(future_model): json.dump(steps, json_file) # return test result + num_cexes = len(ctx.valid_counterexamples) + len(ctx.invalid_counterexamples) if args.minimal_json_output: - return TestResult(funsig, exitcode, len(counterexamples)) + return TestResult(funsig, exitcode, num_cexes) else: return TestResult( funsig, exitcode, - len(counterexamples), - counterexamples, + num_cexes, + ctx.valid_counterexamples + ctx.invalid_counterexamples, (num_execs, normal, len(stuck)), (timer.elapsed(), timer["paths"].elapsed(), timer["models"].elapsed()), len(logs.bounded_loops), @@ -855,61 +726,51 @@ def extract_setup(methodIdentifiers: dict[str, str]) -> FunctionInfo: return FunctionInfo(setup_name, setup_sig, setup_selector) -@dataclass(frozen=True) -class RunArgs: - # signatures of test functions to run - funsigs: list[str] - - # code of the current contract - creation_hexcode: str - deployed_hexcode: str - - abi: dict - methodIdentifiers: dict[str, str] - - args: HalmosConfig - contract_json: dict - libs: dict - - build_out_map: dict +def run_contract(ctx: ContractContext) -> list[TestResult]: + BuildOut().set_build_out(ctx.build_out_map) - -def run_sequential(run_args: RunArgs) -> list[TestResult]: - BuildOut().set_build_out(run_args.build_out_map) - - args = run_args.args - setup_info = extract_setup(run_args.methodIdentifiers) + args = ctx.args + setup_info = extract_setup(ctx.method_identifiers) try: - setup_config = with_devdoc(args, setup_info.sig, run_args.contract_json) + 
setup_config = with_devdoc(args, setup_info.sig, ctx.contract_json) setup_solver = mk_solver(setup_config) - setup_ex = setup( - run_args.creation_hexcode, - run_args.deployed_hexcode, - run_args.abi, - setup_info, - setup_config, - run_args.libs, - setup_solver, + setup_ctx = FunctionContext( + args=setup_config, + info=setup_info, + solver=setup_solver, + contract_ctx=ctx, ) + + setup_ex = setup(setup_ctx) except Exception as err: error(f"{setup_info.sig} failed: {type(err).__name__}: {err}") if args.debug: traceback.print_exc() + # reset any remaining solver states from the default context setup_solver.reset() + return [] test_results = [] - for funsig in run_args.funsigs: - fun_info = FunctionInfo( - funsig.split("(")[0], funsig, run_args.methodIdentifiers[funsig] - ) + for funsig in ctx.funsigs: + selector = ctx.method_identifiers[funsig] + fun_info = FunctionInfo(funsig.split("(")[0], funsig, selector) try: - test_config = with_devdoc(args, funsig, run_args.contract_json) + test_config = with_devdoc(args, funsig, ctx.contract_json) solver = mk_solver(test_config) debug(f"{test_config.formatted_layers()}") - test_result = run(setup_ex, run_args.abi, fun_info, test_config, solver) + + test_ctx = FunctionContext( + args=test_config, + info=fun_info, + solver=solver, + contract_ctx=ctx, + setup_ex=setup_ex, + ) + + test_result = run_test(test_ctx) except Exception as err: print(f"{color_error('[ERROR]')} {funsig}") error(f"{type(err).__name__}: {err}") @@ -929,394 +790,6 @@ def run_sequential(run_args: RunArgs) -> list[TestResult]: return test_results -@dataclass(frozen=True) -class GenModelArgs: - args: HalmosConfig - idx: int - sexpr: SMTQuery - known_unsat_cores: list[list] - dump_dirname: str | None = None - - -def parse_unsat_core(output) -> list | None: - # parsing example: - # unsat - # (error "the context is unsatisfiable") # <-- this line is optional - # (<41702> <37030> <36248> <47880>) - # result: - # [41702, 37030, 36248, 47880] - pattern = r"unsat\s*(\(\s*error\s+[^)]*\)\s*)?\(\s*((<[0-9]+>\s*)*)\)" - match = re.search(pattern, output) - if match: - result = [re.sub(r"<([0-9]+)>", r"\1", name) for name in match.group(2).split()] - return result - else: - warn(f"error in parsing unsat core: {output}") - return None - - -def solve( - query: SMTQuery, args: HalmosConfig, dump_filename: str | None = None -) -> tuple[CheckSatResult, PotentialModel | None, list | None]: - if args.dump_smt_queries or args.solver_command: - if not dump_filename: - dump_filename = f"/tmp/{uuid.uuid4().hex}.smt2" - - # for each implication assertion, `(assert (=> |id| c))`, in query.smtlib, - # generate a corresponding named assertion, `(assert (! |id| :named ))`. - # see `svem.Path.to_smt2()` for more details. - if args.cache_solver: - named_assertions = "".join( - [ - f"(assert (! 
|{assert_id}| :named <{assert_id}>))\n" - for assert_id in query.assertions - ] - ) - - with open(dump_filename, "w") as f: - if args.verbose >= 1: - debug(f"Writing SMT query to {dump_filename}") - if args.cache_solver: - f.write("(set-option :produce-unsat-cores true)\n") - f.write("(set-logic QF_AUFBV)\n") - f.write(query.smtlib) - if args.cache_solver: - f.write(named_assertions) - f.write("(check-sat)\n") - f.write("(get-model)\n") - if args.cache_solver: - f.write("(get-unsat-core)\n") - - if args.solver_command: - if args.verbose >= 1: - debug(" Checking with external solver process") - debug(f" {args.solver_command} {dump_filename} >{dump_filename}.out") - - # solver_timeout_assertion == 0 means no timeout, - # which translates to timeout_seconds=None for subprocess.run - timeout_seconds = None - if timeout_millis := args.solver_timeout_assertion: - timeout_seconds = timeout_millis / 1000 - - cmd = args.solver_command.split() + [dump_filename] - try: - res_str = subprocess.run( - cmd, capture_output=True, text=True, timeout=timeout_seconds - ).stdout.strip() - res_str_head = res_str.split("\n", 1)[0] - - with open(f"{dump_filename}.out", "w") as f: - f.write(res_str) - - if args.verbose >= 1: - debug(f" {res_str_head}") - - if res_str_head == "unsat": - unsat_core = parse_unsat_core(res_str) if args.cache_solver else None - return unsat, None, unsat_core - elif res_str_head == "sat": - return sat, PotentialModel(f"{dump_filename}.out", args), None - else: - return unknown, None, None - except subprocess.TimeoutExpired: - return unknown, None, None - - else: - ctx = Context() - solver = mk_solver(args, ctx=ctx, assertion=True) - solver.from_string(query.smtlib) - if args.cache_solver: - solver.set(unsat_core=True) - ids = [Bool(f"{x}", ctx) for x in query.assertions] - result = solver.check(*ids) - else: - result = solver.check() - model = PotentialModel(solver.model(), args) if result == sat else None - unsat_core = ( - [str(core) for core in solver.unsat_core()] - if args.cache_solver and result == unsat - else None - ) - solver.reset() - return result, model, unsat_core - - -def check_unsat_cores(query, unsat_cores) -> bool: - # return true if the given query contains any given unsat core - for unsat_core in unsat_cores: - if all(core in query.assertions for core in unsat_core): - return True - return False - - -def gen_model_from_sexpr(fn_args: GenModelArgs) -> ModelWithContext: - args, idx, sexpr = fn_args.args, fn_args.idx, fn_args.sexpr - - dump_dirname = fn_args.dump_dirname - dump_filename = f"{dump_dirname}/{idx+1}.smt2" - should_dump = args.dump_smt_queries or args.solver_command - if should_dump and not os.path.isdir(dump_dirname): - os.makedirs(dump_dirname) - print(f"Generating SMT queries in {dump_dirname}") - - if args.verbose >= 1: - print(f"Checking path condition (path id: {idx+1})") - - if check_unsat_cores(sexpr, fn_args.known_unsat_cores): - # if the given query contains an unsat-core, it is unsat; no need to run the solver. 
- if args.verbose >= 1: - print(" Already proven unsat") - return package_result(None, idx, unsat, None, args) - - res, model, unsat_core = solve(sexpr, args, dump_filename) - - if res == sat and not model.is_valid: - if args.verbose >= 1: - print(" Checking again with refinement") - - refined_filename = dump_filename.replace(".smt2", ".refined.smt2") - res, model, unsat_core = solve(refine(sexpr), args, refined_filename) - - return package_result(model, idx, res, unsat_core, args) - - -def refine(query: SMTQuery) -> SMTQuery: - smtlib = query.smtlib - - # replace uninterpreted abstraction with actual symbols for assertion solving - smtlib = re.sub( - r"\(declare-fun f_evm_(bvmul)_([0-9]+) \(\(_ BitVec \2\) \(_ BitVec \2\)\) \(_ BitVec \2\)\)", - r"(define-fun f_evm_\1_\2 ((x (_ BitVec \2)) (y (_ BitVec \2))) (_ BitVec \2) (\1 x y))", - smtlib, - ) - - # replace `(f_evm_bvudiv_N x y)` with `(ite (= y (_ bv0 N)) (_ bv0 N) (bvudiv x y))` - # similarly for bvurem, bvsdiv, and bvsrem - # NOTE: (bvudiv x (_ bv0 N)) is *defined* to (bvneg (_ bv1 N)); while (div x 0) is undefined - smtlib = re.sub( - r"\(declare-fun f_evm_(bvudiv|bvurem|bvsdiv|bvsrem)_([0-9]+) \(\(_ BitVec \2\) \(_ BitVec \2\)\) \(_ BitVec \2\)\)", - r"(define-fun f_evm_\1_\2 ((x (_ BitVec \2)) (y (_ BitVec \2))) (_ BitVec \2) (ite (= y (_ bv0 \2)) (_ bv0 \2) (\1 x y)))", - smtlib, - ) - - return SMTQuery(smtlib, query.assertions) - - -def package_result( - model: PotentialModel | None, - idx: int, - result: CheckSatResult, - unsat_core: list | None, - args: HalmosConfig, -) -> ModelWithContext: - if result == unsat: - if args.verbose >= 1: - print(f" Invalid path; ignored (path id: {idx+1})") - return ModelWithContext(None, idx, result, unsat_core) - - if result == sat: - if args.verbose >= 1: - print(f" Valid path; counterexample generated (path id: {idx+1})") - return ModelWithContext(model, idx, result, None) - - else: - if args.verbose >= 1: - print(f" Timeout (path id: {idx+1})") - return ModelWithContext(None, idx, result, None) - - -def is_model_valid(model: ModelRef | str) -> bool: - # TODO: evaluate the path condition against the given model after excluding f_evm_* symbols, - # since the f_evm_* symbols may still appear in valid models. 
- - # model is a filename, containing solver output - if isinstance(model, str): - with open(model) as f: - for line in f: - if "f_evm_" in line: - return False - return True - - # z3 model object - else: - return all(not str(decl).startswith("f_evm_") for decl in model) - - -def to_str_model(model: ModelRef, print_full_model: bool) -> StrModel: - def select(var): - name = str(var) - return name.startswith("p_") or name.startswith("halmos_") - - select_model = filter(select, model) if not print_full_model else model - return {str(decl): stringify(str(decl), model[decl]) for decl in select_model} - - -def get_contract_type( - ast_nodes: list, contract_name: str -) -> tuple[str | None, str | None]: - for node in ast_nodes: - if node["nodeType"] == "ContractDefinition" and node["name"] == contract_name: - abstract = "abstract " if node.get("abstract") else "" - contract_type = abstract + node["contractKind"] - natspec = node.get("documentation") - return contract_type, natspec - - return None, None - - -def parse_build_out(args: HalmosConfig) -> dict: - result = {} # compiler version -> source filename -> contract name -> (json, type) - - out_path = os.path.join(args.root, args.forge_build_out) - if not os.path.exists(out_path): - raise FileNotFoundError( - f"The build output directory `{out_path}` does not exist" - ) - - for sol_dirname in os.listdir(out_path): # for each source filename - if not sol_dirname.endswith(".sol"): - continue - - sol_path = os.path.join(out_path, sol_dirname) - if not os.path.isdir(sol_path): - continue - - for json_filename in os.listdir(sol_path): # for each contract name - try: - if not json_filename.endswith(".json"): - continue - if json_filename.startswith("."): - continue - - json_path = os.path.join(sol_path, json_filename) - with open(json_path, encoding="utf8") as f: - json_out = json.load(f) - - # cut off compiler version number as well - contract_name = json_filename.split(".")[0] - ast_nodes = json_out["ast"]["nodes"] - contract_type, natspec = get_contract_type(ast_nodes, contract_name) - - # can happen to solidity files for multiple reasons: - # - import only (like console2.log) - # - defines only structs or enums - # - defines only free functions - # - ... 
- if contract_type is None: - debug(f"Skipped {json_filename}, no contract definition found") - continue - - compiler_version = json_out["metadata"]["compiler"]["version"] - result.setdefault(compiler_version, {}) - result[compiler_version].setdefault(sol_dirname, {}) - contract_map = result[compiler_version][sol_dirname] - - if contract_name in contract_map: - raise ValueError( - "duplicate contract names in the same file", - contract_name, - sol_dirname, - ) - - contract_map[contract_name] = (json_out, contract_type, natspec) - parse_symbols(args, contract_map, contract_name) - - except Exception as err: - warn_code( - PARSING_ERROR, - f"Skipped {json_filename} due to parsing failure: {type(err).__name__}: {err}", - ) - if args.debug: - traceback.print_exc() - continue - - return result - - -def parse_symbols(args: HalmosConfig, contract_map: dict, contract_name: str) -> None: - try: - json_out = contract_map[contract_name][0] - bytecode = json_out["bytecode"]["object"] - contract_mapping_info = Mapper().get_or_create(contract_name) - contract_mapping_info.bytecode = bytecode - - Mapper().parse_ast(json_out["ast"]) - - except Exception: - debug(f"error parsing symbols for contract {contract_name}") - debug(traceback.format_exc()) - - # we parse symbols as best effort, don't propagate exceptions - pass - - -def parse_devdoc(funsig: str, contract_json: dict) -> str | None: - try: - return contract_json["metadata"]["output"]["devdoc"]["methods"][funsig][ - "custom:halmos" - ] - except KeyError: - return None - - -def parse_natspec(natspec: dict) -> str: - # This parsing scheme is designed to handle: - # - # - multiline tags: - # /// @custom:halmos --x - # /// --y - # - # - multiple tags: - # /// @custom:halmos --x - # /// @custom:halmos --y - # - # - tags that start in the middle of line: - # /// blah blah @custom:halmos --x - # /// --y - # - # In all the above examples, this scheme returns "--x (whitespaces) --y" - isHalmosTag = False - result = "" - for item in re.split(r"(@\S+)", natspec.get("text", "")): - if item == "@custom:halmos": - isHalmosTag = True - elif re.match(r"^@\S", item): - isHalmosTag = False - elif isHalmosTag: - result += item - return result.strip() - - -def import_libs(build_out_map: dict, hexcode: str, linkReferences: dict) -> dict: - libs = {} - - for filepath in linkReferences: - file_name = filepath.split("/")[-1] - - for lib_name in linkReferences[filepath]: - (lib_json, _, _) = build_out_map[file_name][lib_name] - lib_hexcode = lib_json["deployedBytecode"]["object"] - - # in bytes, multiply indices by 2 and offset 0x - placeholder_index = linkReferences[filepath][lib_name][0]["start"] * 2 + 2 - placeholder = hexcode[placeholder_index : placeholder_index + 40] - - libs[f"{filepath}:{lib_name}"] = { - "placeholder": placeholder, - "hexcode": lib_hexcode, - } - - return libs - - -def build_output_iterator(build_out: dict): - for compiler_version in sorted(build_out): - build_out_map = build_out[compiler_version] - for filename in sorted(build_out_map): - for contract_name in sorted(build_out_map[filename]): - yield (build_out_map, filename, contract_name) - - def contract_regex(args): if args.contract: return f"^{args.contract}$" @@ -1417,6 +890,8 @@ def _main(_args=None) -> MainResult: # def on_exit(exitcode: int) -> MainResult: + ExecutorRegistry().shutdown_all() + result = MainResult(exitcode, test_results_map) if args.json_output: @@ -1470,26 +945,28 @@ def on_signal(signum, frame): # support for `/// @custom:halmos` annotations contract_args = 
with_natspec(args, contract_name, natspec) - run_args = RunArgs( - funsigs, - creation_hexcode, - deployed_hexcode, - abi, - methodIdentifiers, - contract_args, - contract_json, - libs, - build_out_map, + contract_ctx = ContractContext( + args=contract_args, + name=contract_name, + funsigs=funsigs, + creation_hexcode=creation_hexcode, + deployed_hexcode=deployed_hexcode, + abi=abi, + method_identifiers=methodIdentifiers, + contract_json=contract_json, + libs=libs, + build_out_map=build_out_map, ) - test_results = run_sequential(run_args) - + test_results = run_contract(contract_ctx) num_passed = sum(r.exitcode == 0 for r in test_results) num_failed = num_found - num_passed print( - f"Symbolic test result: {num_passed} passed; " - f"{num_failed} failed; {contract_timer.report()}" + "Symbolic test result: " + f"{num_passed} passed; " + f"{num_failed} failed; " + f"{contract_timer.report()}" ) total_found += num_found diff --git a/src/halmos/build.py b/src/halmos/build.py new file mode 100644 index 00000000..b432826d --- /dev/null +++ b/src/halmos/build.py @@ -0,0 +1,174 @@ +import json +import os +import re +import traceback + +from halmos.config import Config as HalmosConfig +from halmos.logs import PARSING_ERROR, debug, warn_code +from halmos.mapper import Mapper + + +def get_contract_type( + ast_nodes: list, contract_name: str +) -> tuple[str | None, str | None]: + for node in ast_nodes: + if node["nodeType"] == "ContractDefinition" and node["name"] == contract_name: + abstract = "abstract " if node.get("abstract") else "" + contract_type = abstract + node["contractKind"] + natspec = node.get("documentation") + return contract_type, natspec + + return None, None + + +def parse_build_out(args: HalmosConfig) -> dict: + result = {} # compiler version -> source filename -> contract name -> (json, type) + + out_path = os.path.join(args.root, args.forge_build_out) + if not os.path.exists(out_path): + raise FileNotFoundError( + f"The build output directory `{out_path}` does not exist" + ) + + for sol_dirname in os.listdir(out_path): # for each source filename + if not sol_dirname.endswith(".sol"): + continue + + sol_path = os.path.join(out_path, sol_dirname) + if not os.path.isdir(sol_path): + continue + + for json_filename in os.listdir(sol_path): # for each contract name + try: + if not json_filename.endswith(".json"): + continue + if json_filename.startswith("."): + continue + + json_path = os.path.join(sol_path, json_filename) + with open(json_path, encoding="utf8") as f: + json_out = json.load(f) + + # cut off compiler version number as well + contract_name = json_filename.split(".")[0] + ast_nodes = json_out["ast"]["nodes"] + contract_type, natspec = get_contract_type(ast_nodes, contract_name) + + # can happen to solidity files for multiple reasons: + # - import only (like console2.log) + # - defines only structs or enums + # - defines only free functions + # - ... 
+ if contract_type is None: + debug(f"Skipped {json_filename}, no contract definition found") + continue + + compiler_version = json_out["metadata"]["compiler"]["version"] + result.setdefault(compiler_version, {}) + result[compiler_version].setdefault(sol_dirname, {}) + contract_map = result[compiler_version][sol_dirname] + + if contract_name in contract_map: + raise ValueError( + "duplicate contract names in the same file", + contract_name, + sol_dirname, + ) + + contract_map[contract_name] = (json_out, contract_type, natspec) + parse_symbols(args, contract_map, contract_name) + + except Exception as err: + warn_code( + PARSING_ERROR, + f"Skipped {json_filename} due to parsing failure: {type(err).__name__}: {err}", + ) + if args.debug: + traceback.print_exc() + continue + + return result + + +def parse_symbols(args: HalmosConfig, contract_map: dict, contract_name: str) -> None: + try: + json_out = contract_map[contract_name][0] + bytecode = json_out["bytecode"]["object"] + contract_mapping_info = Mapper().get_or_create(contract_name) + contract_mapping_info.bytecode = bytecode + + Mapper().parse_ast(json_out["ast"]) + + except Exception: + debug(f"error parsing symbols for contract {contract_name}") + debug(traceback.format_exc()) + + # we parse symbols as best effort, don't propagate exceptions + pass + + +def parse_devdoc(funsig: str, contract_json: dict) -> str | None: + try: + return contract_json["metadata"]["output"]["devdoc"]["methods"][funsig][ + "custom:halmos" + ] + except KeyError: + return None + + +def parse_natspec(natspec: dict) -> str: + # This parsing scheme is designed to handle: + # + # - multiline tags: + # /// @custom:halmos --x + # /// --y + # + # - multiple tags: + # /// @custom:halmos --x + # /// @custom:halmos --y + # + # - tags that start in the middle of line: + # /// blah blah @custom:halmos --x + # /// --y + # + # In all the above examples, this scheme returns "--x (whitespaces) --y" + isHalmosTag = False + result = "" + for item in re.split(r"(@\S+)", natspec.get("text", "")): + if item == "@custom:halmos": + isHalmosTag = True + elif re.match(r"^@\S", item): + isHalmosTag = False + elif isHalmosTag: + result += item + return result.strip() + + +def import_libs(build_out_map: dict, hexcode: str, linkReferences: dict) -> dict: + libs = {} + + for filepath in linkReferences: + file_name = filepath.split("/")[-1] + + for lib_name in linkReferences[filepath]: + (lib_json, _, _) = build_out_map[file_name][lib_name] + lib_hexcode = lib_json["deployedBytecode"]["object"] + + # in bytes, multiply indices by 2 and offset 0x + placeholder_index = linkReferences[filepath][lib_name][0]["start"] * 2 + 2 + placeholder = hexcode[placeholder_index : placeholder_index + 40] + + libs[f"{filepath}:{lib_name}"] = { + "placeholder": placeholder, + "hexcode": lib_hexcode, + } + + return libs + + +def build_output_iterator(build_out: dict): + for compiler_version in sorted(build_out): + build_out_map = build_out[compiler_version] + for filename in sorted(build_out_map): + for contract_name in sorted(build_out_map[filename]): + yield (build_out_map, filename, contract_name) diff --git a/src/halmos/config.py b/src/halmos/config.py index 2d28e70f..eaa78270 100644 --- a/src/halmos/config.py +++ b/src/halmos/config.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Generator from dataclasses import MISSING, dataclass, fields from dataclasses import field as dataclass_field +from pathlib import Path from typing import Any import toml @@ -25,6 +26,29 @@ ) +def find_venv_root() -> 
Path | None: + # If the environment variable is set, use that + if "VIRTUAL_ENV" in os.environ: + return Path(os.environ["VIRTUAL_ENV"]) + + # Otherwise, if we're in a venv, sys.prefix != sys.base_prefix + if sys.prefix != sys.base_prefix: + return Path(sys.prefix) + + # Not in a virtual environment + return None + + +def find_z3_path() -> Path | None: + venv_path = find_venv_root() + if venv_path: + z3_bin = "z3.exe" if sys.platform == "win32" else "z3" + z3_path = venv_path / "bin" / z3_bin + if z3_path.exists(): + return z3_path + return None + + # helper to define config fields def arg( help: str, @@ -441,7 +465,7 @@ class Config: solver_command: str = arg( help="use the given command when invoking the solver", - global_default=None, + global_default=str(find_z3_path()), metavar="COMMAND", group=solver, ) diff --git a/src/halmos/constants.py b/src/halmos/constants.py new file mode 100644 index 00000000..f4583ae9 --- /dev/null +++ b/src/halmos/constants.py @@ -0,0 +1,4 @@ +VERBOSITY_TRACE_COUNTEREXAMPLE = 2 +VERBOSITY_TRACE_SETUP = 3 +VERBOSITY_TRACE_PATHS = 4 +VERBOSITY_TRACE_CONSTRUCTOR = 5 diff --git a/src/halmos/processes.py b/src/halmos/processes.py new file mode 100644 index 00000000..05109a75 --- /dev/null +++ b/src/halmos/processes.py @@ -0,0 +1,246 @@ +import concurrent.futures +import contextlib +import subprocess +import threading +import time +import weakref +from subprocess import PIPE, Popen, TimeoutExpired + +import psutil + + +class ExecutorRegistry: + _instance = None + + # Singleton pattern + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._executors = weakref.WeakSet() + return cls._instance + + def register(self, executor): + self._executors.add(executor) + + def shutdown_all(self): + """Shuts down all registered executors.""" + + for ex in list(self._executors): + ex.shutdown(wait=False) + + +class PopenFuture(concurrent.futures.Future): + cmd: list[str] + timeout: float | None # in seconds, None means no timeout + process: subprocess.Popen | None + stdout: str | None + stderr: str | None + returncode: int | None + start_time: float | None + end_time: float | None + _exception: Exception | None + + def __init__(self, cmd: list[str], timeout: float | None = None): + super().__init__() + self.cmd = cmd + self.timeout = timeout + self.process = None + self.stdout = None + self.stderr = None + self.returncode = None + self.start_time = None + self.end_time = None + self._exception = None + + def start(self): + """Starts the subprocess and immediately returns.""" + + def run(): + try: + self.start_time = time.time() + self.process = Popen(self.cmd, stdout=PIPE, stderr=PIPE, text=True) + + # blocks until the process terminates + self.stdout, self.stderr = self.process.communicate( + timeout=self.timeout + ) + self.end_time = time.time() + self.returncode = self.process.returncode + except TimeoutExpired as e: + self._exception = e + self.cancel() + except Exception as e: + self._exception = e + finally: + self.set_result((self.stdout, self.stderr, self.returncode)) + + # avoid daemon threads because they can cause issues during shutdown + # we don't expect them to actually prevent halmos from terminating, + # as long as the underlying processes are terminated (either by natural + # causes or by forceful termination) + threading.Thread(target=run, daemon=False).start() + + return self + + def cancel(self): + """Attempts to terminate and then kill the process and its children.""" + if not self.is_running(): + return + + # 
use psutil to kill the entire process tree (including children) + try: + parent_process = psutil.Process(self.process.pid) + processes = parent_process.children(recursive=True) + processes.append(parent_process) + + # ask politely to terminate first + for process in processes: + process.terminate() + + # termination grace period + with contextlib.suppress(TimeoutExpired): + parent_process.wait(timeout=0.5) + + # after grace period, force kill + for process in processes: + if process.is_running(): + process.kill() + + except psutil.NoSuchProcess: + # process already terminated, nothing to do + pass + + def exception(self) -> Exception | None: + """Returns any exception raised during the process.""" + + return self._exception + + def result(self, timeout=None) -> tuple[str | None, str | None, int]: + """Blocks until the process is finished and returns the result (stdout, stderr, returncode). + + Can raise TimeoutError or some Exception raised during execution""" + + return super().result(timeout=timeout) + + def done(self): + """Returns True if the process has finished.""" + + return super().done() + + def is_running(self): + """Returns True if the process is still running. + + Returns False before start() and after termination.""" + + return self.process and self.process.poll() is None + + +class ShutdownError(RuntimeError): + """Raised when submitting a future to an executor that has been shutdown.""" + + +class PopenExecutor(concurrent.futures.Executor): + """ + An executor that runs commands in subprocesses. + + Simple implementation that has no concept of max workers or pending futures. + + The explicit goal is to support killing running subprocesses. + """ + + def __init__(self, max_workers: int = 1): + self._futures: list[PopenFuture] = list() + self._shutdown = threading.Event() + self._lock = threading.Lock() + + # TODO: support max_workers + + @property + def futures(self): + return self._futures + + def submit(self, future: PopenFuture) -> PopenFuture: + """Accepts an unstarted PopenFuture and schedules it for execution. 
+ + Raises ShutdownError if the executor has been shutdown.""" + + if self._shutdown.is_set(): + raise ShutdownError() + + with self._lock: + self._futures.append(future) + future.start() + return future + + def is_shutdown(self) -> bool: + return self._shutdown.is_set() + + def shutdown(self, wait=True, cancel_futures=False): + # TODO: support max_workers / pending futures + + self._shutdown.set() + + # we have no concept of pending futures, + # therefore no cancellation of pending futures + if wait: + self._join() + + # if asked for immediate shutdown we cancel everything + else: + with self._lock, concurrent.futures.ThreadPoolExecutor() as executor: + # kick off all cancellations in parallel + cancel_tasks = [executor.submit(f.cancel) for f in self._futures] + + # wait for them to finish + concurrent.futures.wait(cancel_tasks) + + def map(self, fn, *iterables, timeout=None, chunksize=1): + raise NotImplementedError() + + def _join(self): + """Wait until all futures are finished or cancelled.""" + + # submitting new futures after join() would be bad, + # so we make this internal and only call it from shutdown() + with contextlib.suppress(concurrent.futures.CancelledError): + for future in list(self._futures): + future.result() + + +def main(): + with PopenExecutor() as executor: + # example usage + def done_callback(future: PopenFuture): + stdout, stderr, exitcode = future.result() + cmd = " ".join(future.cmd) + elapsed = future.end_time - future.start_time + print( + f"{cmd}\n" + f" exitcode={exitcode}\n" + f" stdout={stdout.strip()}\n" + f" stderr={stderr.strip()}\n" + f" elapsed={elapsed:.2f}s" + ) + executor.shutdown(wait=False) + + # Submit multiple commands + commands = [ + "sleep 1", + "sleep 10", + "echo hello", + ] + + futures = [PopenFuture(command.split()) for command in commands] + + for future in futures: + future.add_done_callback(done_callback) + executor.submit(future) + + # exiting the context manager will shutdown the executor with wait=True + # so no new futures can be submitted + # the first call to done_callback will cause the remaining futures to be cancelled + # (and the underlying processes to be terminated) + + +if __name__ == "__main__": + main() diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index fbfe35a8..d319588c 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -484,6 +484,7 @@ def push(self, v: Word) -> None: else: if not (eq(v.sort(), BitVecSort256) or is_bool(v)): raise ValueError(v) + self.stack.append(simplify(v)) def pop(self) -> Word: @@ -2608,12 +2609,14 @@ def jumpi( follow_false = visited[False] < self.options.loop if not (follow_true and follow_false): self.logs.bounded_loops.append(jid) - debug( - f"\nloop id: {jid}\n" - f"loop condition: {cond}\n" - f"calldata: {ex.calldata()}\n" - f"path condition:\n{ex.path}\n" - ) + if self.options.debug: + # rendering ex.path to string can be expensive, so only do it if debug is enabled + debug( + f"\nloop id: {jid}\n" + f"loop condition: {cond}\n" + f"calldata: {ex.calldata()}\n" + f"path condition:\n{ex.path}\n" + ) else: # for constant-bounded loops follow_true = potential_true diff --git a/src/halmos/solve.py b/src/halmos/solve.py new file mode 100644 index 00000000..3f4eabeb --- /dev/null +++ b/src/halmos/solve.py @@ -0,0 +1,482 @@ +import re +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from pathlib import Path +from tempfile import TemporaryDirectory + +from z3 import CheckSatResult, Solver, sat, unknown, unsat + +from halmos.calldata 
import FunctionInfo +from halmos.config import Config as HalmosConfig +from halmos.logs import ( + debug, + error, + warn, +) +from halmos.processes import ( + ExecutorRegistry, + PopenExecutor, + PopenFuture, + TimeoutExpired, +) +from halmos.sevm import Exec, SMTQuery +from halmos.utils import hexify + + +@dataclass +class ModelVariable: + full_name: str + variable_name: str + solidity_type: str + smt_type: str + size_bits: int + value: int + + +ModelVariables = dict[str, ModelVariable] + +# Regular expression for capturing halmos variables +halmos_var_pattern = re.compile( + r""" + \(\s*define-fun\s+ # Match "(define-fun" + \|?((?:halmos_|p_)[^ |]+)\|?\s+ # Capture either halmos_* or p_*, optionally wrapped in "|" + \(\)\s+\(_\s+([^ ]+)\s+ # Capture the SMTLIB type (e.g., "BitVec 256") + (\d+)\)\s+ # Capture the bit-width or type argument + ( # Group for the value + \#b[01]+ # Binary value (e.g., "#b1010") + |\#x[0-9a-fA-F]+ # Hexadecimal value (e.g., "#xFF") + |\(_\s+bv\d+\s+\d+\) # Decimal value (e.g., "(_ bv42 256)") + ) + """, + re.VERBOSE, +) + + +@dataclass(frozen=True) +class PotentialModel: + model: ModelVariables + is_valid: bool + + def __str__(self) -> str: + formatted = [] + for v in self.model.values(): + # TODO: parse type and render accordingly + formatted.append(f"\n {v.full_name} = {hexify(v.value)}") + return "".join(sorted(formatted)) if formatted else "∅" + + +@dataclass(frozen=True) +class ContractContext: + # config with contract-specific overrides + args: HalmosConfig + + # name of this contract + name: str + + # signatures of test functions to run + funsigs: list[str] + + # data parsed from the build output for this contract + creation_hexcode: str + deployed_hexcode: str + abi: dict + method_identifiers: dict[str, str] + contract_json: dict + libs: dict + + # note: build_out_map is shared across all contracts compiled using the same compiler version + # so in principle, we could consider having another context, say CompileUnitContext, and put build_out_map there + build_out_map: dict + + +@dataclass(frozen=True) +class SolvingContext: + # directory for dumping solver files + dump_dir: TemporaryDirectory + + # shared solver executor for all paths in the same function + executor: PopenExecutor = field(default_factory=PopenExecutor) + + # list of unsat cores + unsat_cores: list[list] = field(default_factory=list) + + +@dataclass(frozen=True) +class FunctionContext: + # config with function-specific overrides + args: HalmosConfig + + # function name, signature, and selector + info: FunctionInfo + + # solver using the function-specific config + solver: Solver + + # backlink to the parent contract context + contract_ctx: ContractContext + + # optional starting state + setup_ex: Exec | None = None + + # function-level solving context + # the FunctionContext initializes and owns the SolvingContext + solving_ctx: SolvingContext = field(init=False) + + # function-level thread pool that drives assertion solving + thread_pool: ThreadPoolExecutor = field(init=False) + + # list of solver outputs for this function + solver_outputs: list["SolverOutput"] = field(default_factory=list) + + # list of valid counterexamples for this function + valid_counterexamples: list[PotentialModel] = field(default_factory=list) + + # list of potentially invalid counterexamples for this function + invalid_counterexamples: list[PotentialModel] = field(default_factory=list) + + # map from path id to trace + traces: dict[int, str] = field(default_factory=dict) + + # map from path id to execution + 
exec_cache: dict[int, Exec] = field(default_factory=dict) + + def __post_init__(self): + args = self.args + + # create a temporary directory for dumping solver files + prefix = ( + f"{self.info.name}-" + if self.info.name + else f"{self.contract_ctx.name}-constructor-" + ) + + # if the user explicitly enabled dumping, we don't want to delete the directory on exit + delete = not self.args.dump_smt_queries + + # ideally we would pass `delete=delete` to the constructor, but it's in >=3.12 + dump_dir = TemporaryDirectory(prefix=prefix, ignore_cleanup_errors=True) + + # If user wants to keep the files, prevent cleanup on exit + if not delete: + dump_dir._finalizer.detach() + + if args.verbose >= 1 or args.dump_smt_queries: + print(f"Generating SMT queries in {dump_dir.name}") + + solving_ctx = SolvingContext(dump_dir=dump_dir) + object.__setattr__(self, "solving_ctx", solving_ctx) + + thread_pool = ThreadPoolExecutor( + max_workers=self.args.solver_threads, + thread_name_prefix=f"{self.info.name}-", + ) + object.__setattr__(self, "thread_pool", thread_pool) + + # register the solver executor to be shutdown on exit + ExecutorRegistry().register(solving_ctx.executor) + + def append_unsat_core(self, unsat_core: list[str]) -> None: + self.solving_ctx.unsat_cores.append(unsat_core) + + +@dataclass(frozen=True) +class PathContext: + args: HalmosConfig + path_id: int + solving_ctx: SolvingContext + query: SMTQuery + is_refined: bool = False + + @property + def dump_file(self) -> Path: + refined_str = ".refined" if self.is_refined else "" + filename = f"{self.path_id}{refined_str}.smt2" + + return Path(self.solving_ctx.dump_dir.name) / filename + + def refine(self) -> "PathContext": + return PathContext( + args=self.args, + path_id=self.path_id, + solving_ctx=self.solving_ctx, + query=refine(self.query), + is_refined=True, + ) + + +@dataclass(frozen=True) +class SolverOutput: + # solver result + result: CheckSatResult + + # we don't backlink to the parent path context to avoid extra + # references to Exec objects past the lifetime of the path + path_id: int + + # solver model + model: PotentialModel | None = None + + # optional unsat core + unsat_core: list[str] | None = None + + @staticmethod + def from_result( + stdout: str, stderr: str, returncode: int, path_ctx: PathContext + ) -> "SolverOutput": + # extract the first line (we expect sat/unsat/unknown) + newline_idx = stdout.find("\n") + first_line = stdout[:newline_idx] if newline_idx != -1 else stdout + + args, path_id = path_ctx.args, path_ctx.path_id + if args.verbose >= 1: + debug(f" {first_line}") + + match first_line: + case "unsat": + unsat_core = parse_unsat_core(stdout) if args.cache_solver else None + return SolverOutput(unsat, path_id, unsat_core=unsat_core) + case "sat": + is_valid = is_model_valid(stdout) + model = PotentialModel(model=parse_model_str(stdout), is_valid=is_valid) + return SolverOutput(sat, path_id, model=model) + case _: + return SolverOutput(unknown, path_id) + + +def parse_const_value(value: str) -> int: + match value[:2]: + case "#b": + return int(value[2:], 2) + case "#x": + return int(value[2:], 16) + case "bv": + return int(value[2:]) + case _: + # we may have a group like (_ bv123 256) + tokens = value.split() + for token in tokens: + if token.startswith("bv"): + return int(token[2:]) + + raise ValueError(f"unknown value format: {value}") + + +def _parse_halmos_var_match(match: re.Match) -> ModelVariable: + full_name = match.group(1).strip() + smt_type = f"{match.group(2)} {match.group(3)}" + size_bits = 
int(match.group(3))
+    value = parse_const_value(match.group(4))
+
+    # Extract name and typename from the variable name
+    parts = full_name.split("_")
+    variable_name = parts[1]
+    solidity_type = parts[2]
+
+    return ModelVariable(
+        full_name=full_name,
+        variable_name=variable_name,
+        solidity_type=solidity_type,
+        smt_type=smt_type,
+        size_bits=size_bits,
+        value=value,
+    )
+
+
+def parse_model_str(smtlib_str: str) -> ModelVariables:
+    """Expects a whole smtlib model output file, as produced by a solver
+    in response to a `(check-sat)` + `(get-model)` command.
+
+    Extracts halmos variables and returns them grouped by their full name."""
+
+    model_variables: dict[str, ModelVariable] = {}
+
+    # use a regex to find all the variables
+    # for now we explicitly don't try to properly parse the smtlib output
+    # because of idiosyncrasies of different solvers:
+    # - ignores the initial sat/unsat on the first line
+    # - ignores the occasional `(model)` command used by yices, stp, cvc4, etc.
+
+    for match in halmos_var_pattern.finditer(smtlib_str):
+        try:
+            variable = _parse_halmos_var_match(match)
+            model_variables[variable.full_name] = variable
+        except Exception as e:
+            error(f"error parsing smtlib string '{match.string.strip()}': {e}")
+            raise e
+
+    return model_variables
+
+
+def parse_model_file(file_path: str) -> ModelVariables:
+    with open(file_path) as file:
+        return parse_model_str(file.read())
+
+
+def parse_unsat_core(output: str) -> list[str] | None:
+    # parsing example:
+    # unsat
+    # (error "the context is unsatisfiable") # <-- this line is optional
+    # (<41702> <37030> <36248> <47880>)
+    # result:
+    # ["41702", "37030", "36248", "47880"]
+    pattern = r"unsat\s*(\(\s*error\s+[^)]*\)\s*)?\(\s*((<[0-9]+>\s*)*)\)"
+    match = re.search(pattern, output)
+    if match:
+        result = [re.sub(r"<([0-9]+)>", r"\1", name) for name in match.group(2).split()]
+        return result
+    else:
+        warn(f"error in parsing unsat core: {output}")
+        return None
+
+
+def dump(path_ctx: PathContext) -> None:
+    args, query, dump_file = path_ctx.args, path_ctx.query, path_ctx.dump_file
+
+    if args.verbose >= 1:
+        debug(f"Writing SMT query to {dump_file}")
+
+    # for each implication assertion, `(assert (=> |id| c))`, in query.smtlib,
+    # generate a corresponding named assertion, `(assert (! |id| :named <id>))`.
+    # see `sevm.Path.to_smt2()` for more details.
+    if args.cache_solver:
+        named_assertions = "".join(
+            [
+                f"(assert (! |{assert_id}| :named <{assert_id}>))\n"
+                for assert_id in query.assertions
+            ]
+        )
+
+        dump_file.write_text(
+            "(set-option :produce-unsat-cores true)\n"
+            "(set-logic QF_AUFBV)\n"
+            f"{query.smtlib}\n"
+            f"{named_assertions}"
+            "(check-sat)\n"
+            "(get-model)\n"
+            "(get-unsat-core)\n"
+        )
+
+    else:
+        dump_file.write_text(
+            f"(set-logic QF_AUFBV)\n{query.smtlib}\n(check-sat)\n(get-model)\n"
+        )
+
+
+def is_model_valid(solver_stdout: str) -> bool:
+    # TODO: evaluate the path condition against the given model after excluding f_evm_* symbols,
+    # since the f_evm_* symbols may still appear in valid models.
+
+    return "f_evm_" not in solver_stdout
+
+
+def solve_low_level(path_ctx: PathContext) -> SolverOutput:
+    """Invokes an external solver process to solve the given query.
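+
+    The query is written to a .smt2 file (see `dump`) and passed to the
+    configured `--solver-command`; the solver's stdout is parsed into a
+    SolverOutput.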
+ + Can raise TimeoutError or some Exception raised during execution""" + + args, smt2_filename = path_ctx.args, str(path_ctx.dump_file) + + # make sure the smt2 file has been written + dump(path_ctx) + + if args.verbose >= 1: + print(" Checking with external solver process") + print(f" {args.solver_command} {smt2_filename} > {smt2_filename}.out") + + # solver_timeout_assertion == 0 means no timeout, + # which translates to timeout_seconds=None for subprocess.run + timeout_seconds = t / 1000 if (t := args.solver_timeout_assertion) else None + + cmd = args.solver_command.split() + [smt2_filename] + future = PopenFuture(cmd, timeout=timeout_seconds) + + # starts the subprocess asynchronously + path_ctx.solving_ctx.executor.submit(future) + + # block until the external solver returns, times out, is interrupted, fails, etc. + try: + stdout, stderr, returncode = future.result() + except TimeoutExpired: + return SolverOutput(result=unknown, path_id=path_ctx.path_id) + + # save solver stdout to file + with open(f"{smt2_filename}.out", "w") as f: + f.write(stdout) + + # save solver stderr to file (only if there is an error) + if stderr: + with open(f"{smt2_filename}.err", "w") as f: + f.write(stderr) + + return SolverOutput.from_result(stdout, stderr, returncode, path_ctx) + + +def solve_end_to_end(ctx: PathContext) -> SolverOutput: + """Synchronously resolves a query in a given context, which may result in 0, 1 or multiple solver invocations. + + - may result in 0 invocations if the query contains a known unsat core (hence the need for the context) + - may result in exactly 1 invocation if the query is unsat, or sat with a valid model + - may result in multiple invocations if the query is sat and the model is invalid (needs refinement) + + If this produces a model, it _should_ be valid. 
+ """ + path_id, query = ctx.path_id, ctx.query + + verbose = print if ctx.args.verbose >= 1 else lambda *args, **kwargs: None + verbose(f"Checking path condition {path_id=}") + + # if the query contains an unsat-core, it is unsat; no need to run the solver + if check_unsat_cores(query, ctx.solving_ctx.unsat_cores): + verbose(" Already proven unsat") + return SolverOutput(unsat, path_id) + + solver_output = solve_low_level(ctx) + result, model = solver_output.result, solver_output.model + + # if the ctx is already refined, we don't need to solve again + if result == sat and not model.is_valid and not ctx.is_refined: + verbose(" Checking again with refinement") + + refined_ctx = ctx.refine() + + if refined_ctx.query.smtlib != query.smtlib: + # note that check_unsat_cores cannot return true for the refined query, because it relies on only + # constraint ids, which don't change after refinement + # therefore we can skip the unsat core check in solve_end_to_end and go directly to solve_low_level + return solve_low_level(refined_ctx) + else: + verbose(" Refinement did not change the query, no need to solve again") + + return solver_output + + +def check_unsat_cores(query: SMTQuery, unsat_cores: list[list]) -> bool: + # return true if the given query contains any given unsat core + for unsat_core in unsat_cores: + if all(core in query.assertions for core in unsat_core): + return True + return False + + +def refine(query: SMTQuery) -> SMTQuery: + smtlib = query.smtlib + + # replace uninterpreted abstraction with actual symbols for assertion solving + smtlib = re.sub( + r"\(declare-fun f_evm_(bvmul)_([0-9]+) \(\(_ BitVec \2\) \(_ BitVec \2\)\) \(_ BitVec \2\)\)", + r"(define-fun f_evm_\1_\2 ((x (_ BitVec \2)) (y (_ BitVec \2))) (_ BitVec \2) (\1 x y))", + smtlib, + ) + + # replace `(f_evm_bvudiv_N x y)` with `(ite (= y (_ bv0 N)) (_ bv0 N) (bvudiv x y))` + # similarly for bvurem, bvsdiv, and bvsrem + # NOTE: (bvudiv x (_ bv0 N)) is *defined* to (bvneg (_ bv1 N)); while (div x 0) is undefined + smtlib = re.sub( + r"\(declare-fun f_evm_(bvudiv|bvurem|bvsdiv|bvsrem)_([0-9]+) \(\(_ BitVec \2\) \(_ BitVec \2\)\) \(_ BitVec \2\)\)", + r"(define-fun f_evm_\1_\2 ((x (_ BitVec \2)) (y (_ BitVec \2))) (_ BitVec \2) (ite (= y (_ bv0 \2)) (_ bv0 \2) (\1 x y)))", + smtlib, + ) + + return SMTQuery(smtlib, query.assertions) diff --git a/src/halmos/traces.py b/src/halmos/traces.py new file mode 100644 index 00000000..c5d58bc8 --- /dev/null +++ b/src/halmos/traces.py @@ -0,0 +1,160 @@ +import io +import sys + +from z3 import Z3_OP_CONCAT, BitVecNumRef, BitVecRef, is_app + +from halmos.bytevec import ByteVec +from halmos.exceptions import HalmosException +from halmos.mapper import DeployAddressMapper, Mapper +from halmos.sevm import CallContext, EventLog, mnemonic +from halmos.utils import ( + byte_length, + cyan, + green, + hexify, + is_bv, + red, + unbox_int, + yellow, +) + + +def rendered_initcode(context: CallContext) -> str: + message = context.message + data = message.data + + initcode_str = "" + args_str = "" + + if ( + isinstance(data, BitVecRef) + and is_app(data) + and data.decl().kind() == Z3_OP_CONCAT + ): + children = [arg for arg in data.children()] + if isinstance(children[0], BitVecNumRef): + initcode_str = hex(children[0].as_long()) + args_str = ", ".join(map(str, children[1:])) + else: + initcode_str = hexify(data) + + return f"{initcode_str}({cyan(args_str)})" + + +def render_output(context: CallContext, file=sys.stdout) -> None: + output = context.output + returndata_str = "0x" + failed = 
output.error is not None + + if not failed and context.is_stuck(): + return + + data = output.data + if data is not None: + is_create = context.message.is_create() + if hasattr(data, "unwrap"): + data = data.unwrap() + + returndata_str = ( + f"<{byte_length(data)} bytes of code>" + if (is_create and not failed) + else hexify(data) + ) + + ret_scheme = context.output.return_scheme + ret_scheme_str = f"{cyan(mnemonic(ret_scheme))} " if ret_scheme is not None else "" + error_str = f" (error: {repr(output.error)})" if failed else "" + + color = red if failed else green + indent = context.depth * " " + print( + f"{indent}{color('↩ ')}{ret_scheme_str}{color(returndata_str)}{color(error_str)}", + file=file, + ) + + +def rendered_log(log: EventLog) -> str: + opcode_str = f"LOG{len(log.topics)}" + topics = [ + f"{cyan(f'topic{i}')}={hexify(topic)}" for i, topic in enumerate(log.topics) + ] + data_str = f"{cyan('data')}={hexify(log.data)}" + args_str = ", ".join(topics + [data_str]) + + return f"{opcode_str}({args_str})" + + +def rendered_trace(context: CallContext) -> str: + with io.StringIO() as output: + render_trace(context, file=output) + return output.getvalue() + + +def rendered_calldata(calldata: ByteVec, contract_name: str | None = None) -> str: + if not calldata: + return "0x" + + if len(calldata) < 4: + return hexify(calldata) + + if len(calldata) == 4: + return f"{hexify(calldata.unwrap(), contract_name)}()" + + selector = calldata[:4].unwrap() + args = calldata[4:].unwrap() + return f"{hexify(selector, contract_name)}({hexify(args)})" + + +def render_trace(context: CallContext, file=sys.stdout) -> None: + message = context.message + addr = unbox_int(message.target) + addr_str = str(addr) if is_bv(addr) else hex(addr) + # check if we have a contract name for this address in our deployment mapper + addr_str = DeployAddressMapper().get_deployed_contract(addr_str) + + value = unbox_int(message.value) + value_str = f" (value: {value})" if is_bv(value) or value > 0 else "" + + call_scheme_str = f"{cyan(mnemonic(message.call_scheme))} " + indent = context.depth * " " + + if message.is_create(): + # TODO: select verbosity level to render full initcode + # initcode_str = rendered_initcode(context) + + try: + if context.output.error is None: + target = hex(int(str(message.target))) + bytecode = context.output.data.unwrap().hex() + contract_name = Mapper().get_by_bytecode(bytecode).contract_name + + DeployAddressMapper().add_deployed_contract(target, contract_name) + addr_str = contract_name + except Exception: + # TODO: print in debug mode + ... 
+ + initcode_str = f"<{byte_length(message.data)} bytes of initcode>" + print( + f"{indent}{call_scheme_str}{addr_str}::{initcode_str}{value_str}", file=file + ) + + else: + calldata = rendered_calldata(message.data, addr_str) + call_str = f"{addr_str}::{calldata}" + static_str = yellow(" [static]") if message.is_static else "" + print(f"{indent}{call_scheme_str}{call_str}{static_str}{value_str}", file=file) + + log_indent = (context.depth + 1) * " " + for trace_element in context.trace: + if isinstance(trace_element, CallContext): + render_trace(trace_element, file=file) + elif isinstance(trace_element, EventLog): + print(f"{log_indent}{rendered_log(trace_element)}", file=file) + else: + raise HalmosException(f"unexpected trace element: {trace_element}") + + render_output(context, file=file) + + if context.depth == 1: + print(file=file) diff --git a/tests/test_halmos.py b/tests/test_halmos.py index 3cd40d41..a8aba028 100644 --- a/tests/test_halmos.py +++ b/tests/test_halmos.py @@ -3,9 +3,10 @@ import pytest -from halmos.__main__ import _main, rendered_calldata +from halmos.__main__ import _main from halmos.bytevec import ByteVec from halmos.sevm import con +from halmos.traces import rendered_calldata @pytest.mark.parametrize( diff --git a/tests/test_solve.py b/tests/test_solve.py new file mode 100644 index 00000000..7f9d76cb --- /dev/null +++ b/tests/test_solve.py @@ -0,0 +1,112 @@ +import pytest + +from halmos.solve import ModelVariable, parse_model_str + + +@pytest.mark.parametrize( + "full_name", + [ + "halmos_y_uint256_043cfd7_01", + "p_y_uint256_043cfd7_01", + ], +) +def test_smtlib_z3_bv_output(full_name): + smtlib_str = f""" + (define-fun {full_name} () (_ BitVec 256) + #x0000000000000000000000000000000000000000000000000000000000000000) + """ + model = parse_model_str(smtlib_str) + + assert model[full_name] == ModelVariable( + full_name=full_name, + variable_name="y", + solidity_type="uint256", + smt_type="BitVec 256", + size_bits=256, + value=0, + ) + + +# note that yices only produces output like this with --smt2-model-format +# otherwise we get something like (= x #b00000100) +@pytest.mark.parametrize( + "full_name", + [ + "halmos_z_uint256_cabf047_02", + "p_z_uint256_cabf047_02", + ], +) +def test_smtlib_yices_binary_output(full_name): + smtlib_str = f""" + (define-fun + {full_name} + () + (_ BitVec 256) + #b1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000) + """ + model = parse_model_str(smtlib_str) + assert model[full_name] == ModelVariable( + full_name=full_name, + variable_name="z", + solidity_type="uint256", + smt_type="BitVec 256", + size_bits=256, + value=1 << 255, + ) + + +@pytest.mark.parametrize( + "full_name", + [ + "halmos_z_uint256_11ce021_08", + "p_z_uint256_11ce021_08", + ], +) +def test_smtlib_yices_decimal_output(full_name): + val = 57896044618658097711785492504343953926634992332820282019728792003956564819968 + smtlib_str = f""" + (define-fun {full_name} () (_ BitVec 256) (_ bv{val} 256)) + """ + model = parse_model_str(smtlib_str) + assert model[full_name] == ModelVariable( + full_name=full_name, + variable_name="z", + solidity_type="uint256", + smt_type="BitVec 256", + size_bits=256, + value=val, + ) + + +@pytest.mark.parametrize( + "full_name", + [ + "halmos_x_uint8_043cfd7_01", + "p_x_uint8_043cfd7_01", + ], +) +def test_smtlib_stp_output(full_name): + 
# we should tolerate: + # - the extra (model) command + # - duplicate variable names + # - the initial `sat` result + # - the `|` around the variable name + # - the space in `( define-fun ...)` + smtlib_str = f""" + sat + (model + ( define-fun |{full_name}| () (_ BitVec 8) #x04 ) + ) + (model + ( define-fun |{full_name}| () (_ BitVec 8) #x04 ) + ) + """ + model = parse_model_str(smtlib_str) + assert model[full_name] == ModelVariable( + full_name=full_name, + variable_name="x", + solidity_type="uint8", + smt_type="BitVec 8", + size_bits=8, + value=4, + )
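+
+
+# A minimal illustrative sketch for parse_unsat_core, based on the example
+# output format documented in its source comment; note that the core ids are
+# returned as strings.
+def test_parse_unsat_core():
+    from halmos.solve import parse_unsat_core
+
+    output = (
+        "unsat\n"
+        '(error "the context is unsatisfiable")\n'
+        "(<41702> <37030> <36248> <47880>)"
+    )
+    assert parse_unsat_core(output) == ["41702", "37030", "36248", "47880"]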