Skip to content

Commit

Permalink
Fix mypy issues with NotImplemented return and other assignments(#204)
Browse files Browse the repository at this point in the history
Co-authored-by: Dimitri Kartsaklis <[email protected]>
  • Loading branch information
neiljdo and dimkart authored Feb 17, 2025
1 parent 6667756 commit b0e5765
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 11 deletions.
14 changes: 8 additions & 6 deletions lambeq/backend/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,20 +224,22 @@ def tensor(self, other: Iterable[Self]) -> Self: ...
@overload
def tensor(self, other: Self, *rest: Self) -> Self: ...

def tensor(self, other: Self | Iterable[Self], *rest: Self) -> Self:
def tensor(self,
other: Self | Iterable[Self],
*rest: Self) -> Self:
try:
tys = [*other, *rest]
except TypeError:
return NotImplemented
return NotImplemented # type: ignore[no-any-return]

# Diagrams are iterable - the identity diagram has
# an empty list for its layers but may still contain types
if getattr(other, 'is_id', False):
return NotImplemented
return NotImplemented # type: ignore[no-any-return]

if any(not isinstance(ty, type(self))
or self.category != ty.category for ty in tys):
return NotImplemented
return NotImplemented # type: ignore[no-any-return]

return self._fromiter(ob for ty in (self, *tys) for ob in ty)

Expand Down Expand Up @@ -905,7 +907,7 @@ def tensor(self, *diagrams: Diagrammable | Ty) -> Self:
try:
diags = self.lift([self, *diagrams])
except ValueError:
return NotImplemented
return NotImplemented # type: ignore[no-any-return]

right = dom = self.dom.tensor(*[
diagram.to_diagram().dom for diagram in diagrams
Expand Down Expand Up @@ -964,7 +966,7 @@ def then(self, *diagrams: Diagrammable) -> Self:
try:
diags = self.lift(diagrams)
except ValueError:
return NotImplemented
return NotImplemented # type: ignore[no-any-return]

layers = [*self.layers]
cod = self.cod
Expand Down
2 changes: 1 addition & 1 deletion lambeq/backend/pregroup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def is_same_word(self, other: object) -> bool:
which doesn't check equality of the children - essentially,
this just checks if `other` is the same token."""
if not isinstance(other, PregroupTreeNode):
return NotImplemented
return NotImplemented # type: ignore[no-any-return]
return (self.word == other.word
and self.ind == other.ind)

Expand Down
4 changes: 2 additions & 2 deletions lambeq/training/pennylane_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class PennyLaneModel(Model, torch.nn.Module):
"""

weights: torch.nn.ParameterList
weights: torch.nn.ParameterList # type: ignore[assignment]
symbols: list[Symbol]

def __init__(self,
Expand Down Expand Up @@ -131,7 +131,7 @@ def _reinitialise_modules(self) -> None:
"""Reinitialise all modules in the model."""
for module in self.modules():
try:
module.reset_parameters()
module.reset_parameters() # type: ignore[operator]
except (AttributeError, TypeError):
pass

Expand Down
4 changes: 2 additions & 2 deletions lambeq/training/pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
class PytorchModel(Model, torch.nn.Module):
"""A lambeq model for the classical pipeline using PyTorch."""

weights: torch.nn.ParameterList
weights: torch.nn.ParameterList # type: ignore[assignment]
symbols: list[Symbol]

def __init__(self) -> None:
Expand All @@ -47,7 +47,7 @@ def _reinitialise_modules(self) -> None:
"""Reinitialise all modules in the model."""
for module in self.modules():
try:
module.reset_parameters()
module.reset_parameters() # type: ignore[operator]
except (AttributeError, TypeError):
pass

Expand Down

0 comments on commit b0e5765

Please sign in to comment.