Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Circuit] Support mutable IO #1351

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 44 additions & 23 deletions magma/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,9 @@ def make_interface(self):
def __add__(self, other):
raise NotImplementedError()

@abstractmethod
def __iadd__(self, other):
# __iadd__ is explicitly overriden to enforce that it is non-mutating.
return self + other
raise NotImplementedError()

def flip(self) -> "IOInterface":
raise NotImplementedError()
Expand All @@ -436,11 +436,9 @@ class IO(IOInterface):
# https://www.python.org/dev/peps/pep-0468/.
def __init__(self, **kwargs):
self._ports = {}
self._decl = []
self._bound = False
for name, typ in kwargs.items():
self.add(name, typ)
self._decl.extend((name, typ))

@property
def ports(self):
Expand All @@ -454,34 +452,51 @@ def bind(self, defn):
self._bound = True

def decl(self):
return self._decl
return _flatten(
(name, type(port).flip()) for name, port in self._ports.items()
)

def make_interface(self):
decl = self.decl()
name = _make_interface_name(decl)
dct = dict(_io=self, _decl=decl, _initialized=False)
return InterfaceKind(name, (_DeclareSingletonInterface,), dct)

def __add__(self, other):
"""
Attempts to combine this IO and @other. Returns a new IO object with the
combined ports, unless:
* @other is not of type IOInterface, in which case a TypeError is
raised
def __add__(self, other: 'IO') -> 'IO':
"""Attempts to combine this IO and @other. Returns a new IO object with
the combined ports, unless:
* @other is not of type IO, in which case a TypeError is raised
* this or @other has already been bound, in which case an Exception is
raised
* this and @other have common port names, in which case an Exception
is raised
"""
if not isinstance(other, IOInterface):
raise TypeError(f"unsupported operand type(s) for +: 'IO' and "
f"'{type(other).__name__}'")
if not isinstance(other, IO):
raise TypeError(
f"unsupported operand type(s) for +: 'IO' and "
f"'{type(other).__name__}'"
)
if self._bound or other._bound:
raise Exception("Adding bound IO not allowed")
if self._ports.keys() & other._ports.keys():
raise Exception("Adding IO with duplicate port names not allowed")
Comment on lines +479 to +482
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should move this to a helper function so it can share code with __iadd__.

return IO(**_dict_from_decl(self.decl() + other.decl()))

def __iadd__(self, other: 'IO') -> 'IO':
"""Attempts to combine this @IO and other in place, with the same
caveats as __add__.
"""
if not isinstance(other, IO):
raise TypeError(
f"unsupported operand type(s) for +: 'IO' and "
f"'{type(other).__name__}'"
)
if self._bound or other._bound:
raise Exception("Adding bound IO not allowed")
if self._ports.keys() & other._ports.keys():
raise Exception("Adding IO with duplicate port names not allowed")
decl = self._decl + other._decl
return IO(**_dict_from_decl(decl))
self._ports.update(other._ports)
return self

def add(self, name, typ):
if self._bound:
Expand All @@ -505,7 +520,7 @@ def __getattr__(self, key: str):
return super().__getattribute__(key)

def fields(self):
return _dict_from_decl(self._decl)
return _dict_from_decl(self.decl())

def flip(self):
return IO(**{name: T.flip() for name, T in self.fields().items()})
Expand Down Expand Up @@ -534,14 +549,19 @@ def inst_ports(self):
return self._inst_ports.copy()

def decl(self):
return _flatten((name, type(port))
for name, port in self._ports.items())
return _flatten(
(name, type(port)) for name, port in self._ports.items()
)

def make_interface(self):
decl = self.decl()
name = _make_interface_name(decl)
dct = dict(_io=self, _decl=decl, _initialized=False,
_initialized_inst=False)
dct = {
"_io": self,
"_decl": decl,
"_initialized": False,
"_initialized_inst": False,
}
return InterfaceKind(name, (_DeclareSingletonInstanceInterface,), dct)

def add(self, name, typ):
Expand All @@ -552,8 +572,9 @@ def add(self, name, typ):
self._inst_ports[name] = inst_port

def __add__(self, other):
raise NotImplementedError(f"Addition operator disallowed on "
f"{cls.__name__}")
raise NotImplementedError(
f"Addition operator disallowed on {cls.__name__}"
)

def flip(self):
raise NotImplementedError()
7 changes: 7 additions & 0 deletions tests/test_interface/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def test_add_intersecting_io(caplog):
)


def test_iadd():
io = m.IO(a=m.In(m.Bit))
a = io.a
io += m.IO(b=m.In(m.Bit))
assert io.a is a


def test_flip():
A = m.Product.from_fields("anon", dict(x=m.In(m.Bit), y=m.Out(m.Bit)))
B = m.In(m.Bits[8])
Expand Down
Loading