From 1646e23bb9aa2cd4110030d782ba31c132b764f6 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Thu, 20 Feb 2025 11:43:39 -0800 Subject: [PATCH] handle more opcodes (WIP: bitwise) --- src/halmos/bitvec.py | 163 ++++++++++++++++++++++++++----------------- src/halmos/sevm.py | 157 +++++++++++++---------------------------- src/halmos/utils.py | 10 ++- 3 files changed, 157 insertions(+), 173 deletions(-) diff --git a/src/halmos/bitvec.py b/src/halmos/bitvec.py index c217ef43..146c1e82 100644 --- a/src/halmos/bitvec.py +++ b/src/halmos/bitvec.py @@ -7,12 +7,15 @@ BoolRef, BoolVal, Extract, + FuncDeclRef, If, + LShR, Not, Or, UDiv, URem, ZeroExt, + eq, is_bv_value, simplify, ) @@ -26,6 +29,10 @@ AnyBool: TypeAlias = bool | BoolRef +def is_power_of_two(x: int) -> bool: + return x > 0 and not (x & (x - 1)) + + def as_int(value: AnyValue) -> tuple[BVValue, MaybeSize]: if isinstance(value, int): return value, None @@ -73,8 +80,9 @@ def __new__(cls, value): return super().__new__(cls) def __init__(self, value: AnyBool): - if isinstance(value, HalmosBitVec): - value = value.value != 0 + # avoid reinitializing HalmosBool because of __new__ shortcut + if isinstance(value, HalmosBool): + return self._value = value if isinstance(value, bool): @@ -110,9 +118,13 @@ def unwrap(self) -> AnyBool: return self._value @property - def symbolic(self) -> bool: + def is_symbolic(self) -> bool: return self._symbolic + @property + def is_concrete(self) -> bool: + return not self._symbolic + @property def value(self) -> AnyBool: return self._value @@ -210,6 +222,10 @@ def __init__( If do_simplify is True, the value will be simplified using z3's simplify function. """ + # avoid reinitializing HalmosBitVec because of __new__ shortcut + if isinstance(value, HalmosBitVec): + return + # unwrap HalmosBool if isinstance(value, HalmosBool): value = value.unwrap() @@ -249,9 +265,13 @@ def size(self) -> int: return self._size @property - def symbolic(self) -> bool: + def is_symbolic(self) -> bool: return self._symbolic + @property + def is_concrete(self) -> bool: + return not self._symbolic + @property def value(self) -> int | BitVecRef: return self._value @@ -348,23 +368,47 @@ def __rsub__(self, other: AnyValue) -> "HalmosBitVec": return HalmosBitVec(other_value - self._value, size=size) - def __mul__(self, other: AnyValue) -> "HalmosBitVec": + def __mul__(self, other: BV) -> "HalmosBitVec": + return self.mul(other) + + def mul( + self, other: BV, *, abstraction: FuncDeclRef | None = None + ) -> "HalmosBitVec": size = self._size + assert size == other.size - if isinstance(other, HalmosBitVec): - assert size == other._size - if not other._symbolic and other._value == 1: - return self - return HalmosBitVec(self._value * other._value, size=size) + lhs, rhs = self.value, other.value + print(f"{self=}, {other=}") - other_value, other_size = as_int(other) - assert other_size is None or other_size == size + match (self.is_concrete, other.is_concrete): + case (True, True): + return HalmosBitVec(lhs * rhs, size=size) - # If we multiply by 1, no new object needed - if isinstance(other_value, int) and other_value == 1: - return self + case (True, False): + if lhs == 0: + return self + + if lhs == 1: + return other - return HalmosBitVec(self._value * other_value, size=size) + if is_power_of_two(lhs): + return other << (lhs.bit_length() - 1) + + case (False, True): + if rhs == 0: + return other + + if rhs == 1: + return self + + if is_power_of_two(rhs): + return self << (rhs.bit_length() - 1) + + return ( + HalmosBitVec(lhs * rhs, size=size) + if abstraction is None + else HalmosBitVec(abstraction(lhs, rhs), size=size) + ) def __rmul__(self, other: AnyValue) -> "HalmosBitVec": # just reuse __mul__ @@ -486,67 +530,60 @@ def __lshift__(self, shift: AnyValue) -> "HalmosBitVec": return HalmosBitVec(self._value << shift_value, size=size) def __rshift__(self, shift: AnyValue) -> "HalmosBitVec": + raise NotImplementedError("ambiguous, use lshr or ashr") + + def lshr(self, shift: BV) -> "HalmosBitVec": """ - Logical right shift by shift bits. - Python's >> is an arithmetic shift for negative ints, - but if we're dealing with unsigned logic, mask out as needed. - For Z3, use LShR if you want a logical shift: LShR(a, b). + Logical right shift """ + size = self._size - if isinstance(shift, HalmosBitVec): - assert size == shift._size - if not shift._symbolic and shift._value == 0: + # check for no-op + if shift.is_concrete: + if shift.value == 0: return self - # for symbolic, might want z3.LShR - # if you stored it in self._value, do that: - from z3 import LShR - return HalmosBitVec(LShR(self._value, shift._value), size=size) + if shift.value >= size: + return HalmosBitVec(0, size=size) - shift_value, shift_size = as_int(shift) - assert shift_size is None or shift_size == size + return HalmosBitVec(LShR(self.wrapped(), shift.wrapped()), size=size) - if isinstance(shift_value, int) and shift_value == 0: + def ashr(self, shift: BV) -> "HalmosBitVec": + """ + Arithmetic right shift + """ + + # check for no-op + if shift.is_concrete and shift.value == 0: return self - # for concrete - if isinstance(shift_value, int): - # plain Python >> does an arithmetic shift if self._value < 0, but presumably we treat as unsigned - # so do standard python right shift for positives or mask out if needed - return HalmosBitVec(self._value >> shift_value, size=size) - else: - # symbolic shift - from z3 import LShR + return HalmosBitVec(self.wrapped() >> shift.value, size=self.size) - return HalmosBitVec(LShR(self._value, shift_value), size) + def __invert__(self) -> BV: + # TODO: handle concrete case + return HalmosBitVec(~self.wrapped(), size=self.size) - def __rlshift__(self, shift: AnyValue) -> "HalmosBitVec": - """ - shift << self - """ - # same pattern as other r-operations - if isinstance(shift, HalmosBitVec): - assert shift._size == self._size - return shift.__lshift__(self) - # fallback - shift_value, shift_size = as_int(shift) - # just do shift_value << self._value - # careful about symbolic - return HalmosBitVec(shift_value << self._value, self._size) + def __lt__(self, other: BV) -> HalmosBool: + return self.ult(other) - def __rrshift__(self, shift: AnyValue) -> "HalmosBitVec": - """ - shift >> self - """ - if isinstance(shift, HalmosBitVec): - assert shift._size == self._size - return shift.__rshift__(self) - shift_value, shift_size = as_int(shift) - # do shift_value >> self._value - from z3 import LShR + def __gt__(self, other: BV) -> HalmosBool: + return self.ugt(other) + + def __le__(self, other: BV) -> HalmosBool: + return self.ule(other) + + def __ge__(self, other: BV) -> HalmosBool: + return self.uge(other) + + def __eq__(self, other: BV) -> bool: + if self.is_symbolic and other.is_symbolic: + return self.size == other.size and eq(self.value, other.value) + + if self.is_concrete and other.is_concrete: + return self.size == other.size and self.value == other.value - return HalmosBitVec(LShR(shift_value, self._value), self._size) + return False def ult(self, other: BV) -> HalmosBool: assert self._size == other._size diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index e7105039..54beb682 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -38,15 +38,12 @@ Extract, Function, If, - LShR, Or, Select, SignExt, Solver, SRem, Store, - UDiv, - URem, Xor, ZeroExt, eq, @@ -509,7 +506,7 @@ def pop(self) -> Word: def popi(self) -> Word: """The stack can contain BitVecs or Bools -- this function converts Bools to BitVecs""" - return b2i(self.pop()) + return BV(self.pop()) def peek(self, n: int = 1) -> Word: return self.stack[-n] @@ -1760,7 +1757,7 @@ def bitwise(op, x: Word, y: Word) -> Word: return bitwise(op, b2i(x), b2i(y)) -def b2i(w: BitVecRef | BoolRef) -> BitVecRef: +def b2i(w: BitVecRef | BoolRef) -> BV: """ Convert a boolean or bitvector to a bitvector. """ @@ -1776,13 +1773,6 @@ def b2i(w: BitVecRef | BoolRef) -> BitVecRef: # return w -def is_power_of_two(x: int) -> bool: - if x > 0: - return not (x & (x - 1)) - else: - return False - - class HalmosLogs: bounded_loops: list[JumpID] @@ -1878,10 +1868,6 @@ def mk_mod(self, ex: Exec, x: Any, y: Any) -> Any: # ex.path.append(Or(y == ZERO, ULT(term, y))) # (x % y) < y if y != 0 return term - def mk_mul(self, ex: Exec, x: Any, y: Any) -> Any: - term = f_mul[x.size()](x, y) - return term - def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word: w1 = b2i(w1) w2 = b2i(w2) @@ -1893,79 +1879,50 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word: return w1 - w2 if op == EVM.MUL: - is_bv_value_w1 = is_bv_value(w1) - is_bv_value_w2 = is_bv_value(w2) - - if is_bv_value_w1 and is_bv_value_w2: - return w1 * w2 - - if is_bv_value_w1: - i1: int = w1.as_long() - if i1 == 0: - return w1 - - if i1 == 1: - return w2 - - if is_power_of_two(i1): - return w2 << (i1.bit_length() - 1) - - if is_bv_value_w2: - i2: int = w2.as_long() - if i2 == 0: - return w2 - - if i2 == 1: - return w1 - - if is_power_of_two(i2): - return w1 << (i2.bit_length() - 1) - - if is_bv_value_w1 or is_bv_value_w2: - return w1 * w2 - - return self.mk_mul(ex, w1, w2) + return w1.mul(w2, abstraction=f_mul[w1.size]) if op == EVM.DIV: - div_for_overflow_check = self.div_xy_y(w1, w2) - if div_for_overflow_check is not None: # xy/x or xy/y - return div_for_overflow_check + # TODO: move to bitvec.py + # div_for_overflow_check = self.div_xy_y(w1, w2) + # if div_for_overflow_check is not None: # xy/x or xy/y + # return div_for_overflow_check - if is_bv_value(w1) and is_bv_value(w2): - if w2.as_long() == 0: - return w2 - else: - return UDiv(w1, w2) # unsigned div (bvudiv) + # if is_bv_value(w1) and is_bv_value(w2): + # if w2.as_long() == 0: + # return w2 + # else: + # return UDiv(w1, w2) # unsigned div (bvudiv) - if is_bv_value(w2): - # concrete denominator case - i2: int = w2.as_long() - if i2 == 0: - return w2 + # if is_bv_value(w2): + # # concrete denominator case + # i2: int = w2.as_long() + # if i2 == 0: + # return w2 - if i2 == 1: - return w1 + # if i2 == 1: + # return w1 - if is_power_of_two(i2): - return LShR(w1, i2.bit_length() - 1) + # if is_power_of_two(i2): + # return LShR(w1, i2.bit_length() - 1) return self.mk_div(ex, w1, w2) if op == EVM.MOD: - if is_bv_value(w1) and is_bv_value(w2): - if w2.as_long() == 0: - return w2 - else: - return URem(w1, w2) # bvurem - - if is_bv_value(w2): - i2: int = int(str(w2)) - if i2 == 0 or i2 == 1: - return con(0, w2.size()) - - if is_power_of_two(i2): - bitsize = i2.bit_length() - 1 - return ZeroExt(w2.size() - bitsize, Extract(bitsize - 1, 0, w1)) + # TODO: move to bitvec.py + # if is_bv_value(w1) and is_bv_value(w2): + # if w2.as_long() == 0: + # return w2 + # else: + # return URem(w1, w2) # bvurem + + # if is_bv_value(w2): + # i2: int = int(str(w2)) + # if i2 == 0 or i2 == 1: + # return con(0, w2.size()) + + # if is_power_of_two(i2): + # bitsize = i2.bit_length() - 1 + # return ZeroExt(w2.size() - bitsize, Extract(bitsize - 1, 0, w1)) return self.mk_mod(ex, w1, w2) @@ -2624,22 +2581,12 @@ def jumpi( jid = ex.jumpi_id() target: int = ex.int_of(ex.st.pop(), "symbolic JUMPI target") - cond: Word = ex.st.pop() + cond = Bool(ex.st.pop()) visited = ex.jumpis.get(jid, {True: 0, False: 0}) - # print(f"{cond=} {type(cond)=}") - # print(f"{is_non_zero(cond)=} {type(is_non_zero(cond))=}") - # print(f"{z3_bv(is_non_zero(cond))=} {type(z3_bv(is_non_zero(cond)))=}") - - cond_true = simplify(cond.is_non_zero().wrapped()) - cond_false = simplify(cond.is_zero().wrapped()) - - # cond_true = simplify(is_non_zero(cond)) - # cond_false = simplify(is_zero(cond)) - - # print(f"{cond_true=} {type(cond_true)=}") - # print(f"{cond_false=} {type(cond_false)=}") + cond_true = simplify(cond.wrapped()) + cond_false = simplify(cond.neg().wrapped()) potential_true: bool = ex.check(cond_true) != unsat potential_false: bool = ex.check(cond_false) != unsat @@ -2920,7 +2867,7 @@ def finalize(ex: Exec): elif opcode == EVM.LT: w1 = ex.st.popi() w2 = ex.st.popi() - ex.st.push(ULT(w1, w2)) # bvult + ex.st.push(w1 < w2) # bvult elif opcode == EVM.GT: w1 = ex.st.popi() @@ -2941,19 +2888,13 @@ def finalize(ex: Exec): w1 = ex.st.pop() w2 = ex.st.pop() - if eq(w1.sort(), w2.sort()): - ex.st.push(w1 == w2) - else: - if is_bool(w1): - if not is_bv(w2): - raise ValueError(w2) - ex.st.push(If(w1, ONE, ZERO) == w2) - else: - if not is_bv(w1): - raise ValueError(w1) - if not is_bool(w2): - raise ValueError(w2) - ex.st.push(w1 == If(w2, ONE, ZERO)) + match (w1, w2): + case (Bool(), Bool()): + ex.st.push(w1 == w2) + case (BV(), BV()): + ex.st.push(w1 == w2) + case (_, _): + ex.st.push(BV(w1) == BV(w2)) elif opcode == EVM.ISZERO: ex.st.push(is_zero(ex.st.pop())) @@ -2972,12 +2913,12 @@ def finalize(ex: Exec): elif opcode == EVM.SAR: w1 = ex.st.popi() w2 = ex.st.popi() - ex.st.push(w2 >> w1) # bvashr + ex.st.push(w2.ashr(w1)) # bvashr elif opcode == EVM.SHR: w1 = ex.st.popi() w2 = ex.st.popi() - ex.st.push(LShR(w2, w1)) # bvlshr + ex.st.push(w2.lshr(w1)) # bvlshr elif opcode == EVM.SIGNEXTEND: w = ex.int_of(ex.st.popi(), "symbolic SIGNEXTEND size") diff --git a/src/halmos/utils.py b/src/halmos/utils.py index 3b4b4707..6c71ae0d 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -212,7 +212,10 @@ def con(n: int, size_bits=256) -> Word: def z3_bv(x: Any) -> BitVecRef: if isinstance(x, BV): - return x.value if x.symbolic else con(x.unwrap(), size_bits=x.size) + return x.wrapped() + + if isinstance(x, Bool): + return BV(x).wrapped() # must check before int because isinstance(True, int) is True if isinstance(x, bool): @@ -270,7 +273,10 @@ def is_zero(x: Word) -> Bool: def is_concrete(x: Any) -> bool: if isinstance(x, BV): - return not x.symbolic + return x.is_concrete + + if isinstance(x, Bool): + return x.is_concrete return isinstance(x, int | bytes) or is_bv_value(x)