Skip to content

Commit

Permalink
Merge pull request #852 from phanrahan/struct-eq
Browse files Browse the repository at this point in the history
Use Kind.__eq__ for structural equality
  • Loading branch information
leonardt authored Sep 22, 2020
2 parents 2a78772 + 2582d93 commit b519419
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 56 deletions.
2 changes: 1 addition & 1 deletion magma/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __getitem__(cls, index: tuple) -> 'ArrayMeta':


if cls.is_concrete:
if index == (cls.N, cls.T):
if index[0] == cls.N and index[1] is cls.T:
return cls
else:
return cls.abstract_t[index]
Expand Down
16 changes: 10 additions & 6 deletions magma/clock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from .t import Direction, In
from .digital import DigitalMeta, Digital
from .wire import wire
Expand Down Expand Up @@ -160,16 +162,18 @@ def wireclock(define, circuit):
wireclocktype(define, circuit, Enable)


def get_reset_args(reset_type: AbstractReset = None):
if reset_type is not None and not issubclass(reset_type, AbstractReset):
def get_reset_args(reset_type: Optional[AbstractReset]):
if reset_type is None:
return tuple(False for _ in range(4))
if not issubclass(reset_type, AbstractReset):
raise TypeError(
f"Expected subclass of AbstractReset for argument reset_type, "
f"not {type(reset_type)}")

has_async_reset = reset_type == AsyncReset
has_async_resetn = reset_type == AsyncResetN
has_reset = reset_type == Reset
has_resetn = reset_type == ResetN
has_async_reset = issubclass(reset_type, AsyncReset)
has_async_resetn = issubclass(reset_type, AsyncResetN)
has_reset = issubclass(reset_type, Reset)
has_resetn = issubclass(reset_type, ResetN)
return (has_async_reset, has_async_resetn, has_reset, has_resetn)


Expand Down
2 changes: 1 addition & 1 deletion magma/digital.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def qualify(cls, direction):
return cls[direction]

def __eq__(cls, rhs):
return cls is rhs
return isinstance(rhs, DigitalMeta)

def is_wireable(cls, rhs):
rhs = magma_type(rhs)
Expand Down
7 changes: 6 additions & 1 deletion magma/syntax/sequential2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __setitem__(self, i, value):


class _SequentialRegisterWrapperMeta(MagmaProtocolMeta):
_cache = {}

def _to_magma_(cls):
return cls.T

Expand All @@ -55,7 +57,10 @@ def _is_oriented_magma_(cls, direction):
return cls.T.is_oriented(direction)

def __getitem__(cls, T):
return type(cls)(f"_SequentialRegisterWrapper{T}", (cls, ), {"T": T})
if T not in cls._cache:
cls._cache[T] = type(cls)(f"_SequentialRegisterWrapper{T}",
(cls, ), {"T": T})
return cls._cache[T]

def __eq__(cls, rhs):
if not isinstance(rhs, _SequentialRegisterWrapperMeta):
Expand Down
23 changes: 15 additions & 8 deletions magma/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _from_idx(cls, idx):
raise TypeError('Type is already bound')

undirected_idx = tuple(v.qualify(Direction.Undirected) for v in idx)
if undirected_idx != idx:
if any(x is not y for x, y in zip(undirected_idx, idx)):
bases = [cls[undirected_idx]]
else:
bases = [cls]
Expand Down Expand Up @@ -159,6 +159,17 @@ def is_bindable(cls, rhs):
return False
return True

def __eq__(cls, rhs):
if not isinstance(rhs, TupleKind):
return False

if not cls.is_bound:
return not rhs.is_bound

return cls.fields == rhs.fields

__hash__ = TupleMeta.__hash__


class Tuple(Type, Tuple_, metaclass=TupleKind):
def __init__(self, *largs, **kwargs):
Expand Down Expand Up @@ -463,8 +474,8 @@ def qualify(cls, direction):
for k, v in cls.field_dict.items():
if not issubclass(new_fields[k], v):
base = cls.unbound_t
if base.is_bound and all(v == base.field_dict[k] for k, v in
new_fields.items()):
if base.is_bound and all(v is base.field_dict[k]
for k, v in new_fields.items()):
return base

if cls.unbound_t is AnonProduct:
Expand Down Expand Up @@ -494,11 +505,7 @@ def __eq__(cls, rhs):
if not cls.is_bound:
return not rhs.is_bound

for k, v in cls.field_dict.items():
if getattr(rhs, k) is not v:
return False

return True
return cls.field_dict == rhs.field_dict

def is_wireable(cls, rhs):
rhs = magma_type(rhs)
Expand Down
23 changes: 17 additions & 6 deletions tests/test_type/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@ def test_array():
assert B2 == B2
assert C2 == C2

assert A2 != B2
assert A2 != C2
assert B2 != C2
# Structural equality
assert A2 == B2
assert A2 == C2
assert B2 == C2

# Nominal equality
assert A2 is not B2
assert A2 is not C2
assert B2 is not C2

A4 = Array[4,Bit]
assert A4 == Array4
Expand Down Expand Up @@ -81,16 +87,21 @@ def test_val():

a3 = a1[0:2]


def test_flip():
AIn = In(Array2)
AOut = Out(Array2)

print(AIn)
print(AOut)

assert AIn != Array2
assert AOut != Array2
assert AIn != AOut
assert AIn == Array2
assert AOut == Array2
assert AIn == AOut

assert AIn is not Array2
assert AOut is not Array2
assert AIn is not AOut

A = In(AOut)
assert A == AIn
Expand Down
10 changes: 7 additions & 3 deletions tests/test_type/test_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ def test_bit():
assert m.BitIn == m.BitIn
assert m.BitOut == m.BitOut

assert m.Bit != m.BitIn
assert m.Bit != m.BitOut
assert m.BitIn != m.BitOut
assert m.Bit == m.BitIn
assert m.Bit == m.BitOut
assert m.BitIn == m.BitOut

assert m.Bit is not m.BitIn
assert m.Bit is not m.BitOut
assert m.BitIn is not m.BitOut

assert str(m.Bit) == 'Bit'
assert str(m.BitIn) == 'In(Bit)'
Expand Down
20 changes: 14 additions & 6 deletions tests/test_type/test_bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ def test_bits_basic():
assert bits_in_2 == m.In(bits_2)
assert bits_out_2 == m.Out(bits_2)

assert bits_2 != bits_in_2
assert bits_2 != bits_out_2
assert bits_in_2 != bits_out_2
assert bits_2 == bits_in_2
assert bits_2 == bits_out_2
assert bits_in_2 == bits_out_2

assert bits_2 is not bits_in_2
assert bits_2 is not bits_out_2
assert bits_in_2 is not bits_out_2

bits_4 = m.Bits[4]
assert bits_4 == ARRAY4
Expand Down Expand Up @@ -90,9 +94,13 @@ def test_flip():
print(a_in)
print(a_out)

assert a_in != ARRAY2
assert a_out != ARRAY2
assert a_in != a_out
assert a_in == ARRAY2
assert a_out == ARRAY2
assert a_in == a_out

assert a_in is not ARRAY2
assert a_out is not ARRAY2
assert a_in is not a_out

in_a_out = m.In(a_out)
assert in_a_out == a_in
Expand Down
32 changes: 32 additions & 0 deletions tests/test_type/test_equality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import magma as m


def test_struct_eq():
assert m.Array[3, m.In(m.Bit)] == m.Array[3, m.Out(m.Bit)], """\
Direction should not matter when checking structural equality\
"""

class T1(m.Product):
x = m.Bit
y = m.Out(m.Bits[1])

class T2(m.Product):
x = m.In(m.Bit)
y = m.Bits[1]
assert T1 is not T2, "Different names are not nominally equal"
assert m.Array[3, T1] == m.Array[3, T2], """\
Products should match structurally, direction does not matter
"""

class T3(m.Product):
x = m.In(m.Bit)
assert m.Array[3, T1] != m.Array[3, T3], """\
Missing field should not match
"""

class T4(m.Product):
x = m.In(m.Bit)
z = m.Bits[1]
assert m.Array[3, T1] != m.Array[3, T4], """\
Different fields should not match
"""
26 changes: 17 additions & 9 deletions tests/test_type/test_sint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ def test():
assert B2 == B2
assert C2 == C2

assert A2 != B2
assert A2 != C2
assert B2 != C2
assert A2 == B2
assert A2 == C2
assert B2 == C2

# A4 = SInt[4]
# assert A4 == Array4
# assert A2 != A4
assert A2 is not B2
assert A2 is not C2
assert B2 is not C2

A4 = SInt[4]
assert A4 == Array4
assert A2 != A4


def test_val():
Expand Down Expand Up @@ -56,9 +60,13 @@ def test_flip():
print(AIn)
print(AOut)

assert AIn != Array2
assert AOut != Array2
assert AIn != AOut
assert AIn == Array2
assert AOut == Array2
assert AIn == AOut

assert AIn is not Array2
assert AOut is not Array2
assert AIn is not AOut

A = In(AOut)
assert A == AIn
Expand Down
20 changes: 14 additions & 6 deletions tests/test_type/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ class C2(Product, cache=True):
#assert str(C2) == 'Tuple(x=Out(Bit),y=Out(Bit))'
assert C2 == C2

assert A2 != B2
assert A2 != C2
assert B2 != C2
assert A2 == B2
assert A2 == C2
assert B2 == C2

assert A2 is not B2
assert A2 is not C2
assert B2 is not C2


def test_flip():
Expand All @@ -120,9 +124,13 @@ class Product2(Product):
print(Tin)
print(Tout)

assert Tin != Product2
assert Tout != Product2
assert Tin != Tout
assert Tin == Product2
assert Tout == Product2
assert Tin == Tout

assert Tin is not Product2
assert Tout is not Product2
assert Tin is not Tout

T = In(Tout)
assert T == Tin
Expand Down
26 changes: 17 additions & 9 deletions tests/test_type/test_uint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ def test():
assert B2 == B2
assert C2 == C2

assert A2 != B2
assert A2 != C2
assert B2 != C2
assert A2 == B2
assert A2 == C2
assert B2 == C2

# A4 = UInt[4]
# assert A4 == Array4
# assert A2 != A4
assert A2 is not B2
assert A2 is not C2
assert B2 is not C2

A4 = UInt[4]
assert A4 == Array4
assert A2 != A4


def test_val():
Expand Down Expand Up @@ -56,9 +60,13 @@ def test_flip():
print(AIn)
print(AOut)

assert AIn != Array2
assert AOut != Array2
assert AIn != AOut
assert AIn == Array2
assert AOut == Array2
assert AIn == AOut

assert AIn is not Array2
assert AOut is not Array2
assert AIn is not AOut

A = In(AOut)
assert A == AIn
Expand Down

0 comments on commit b519419

Please sign in to comment.