Skip to content

Commit b0e5765

Browse files
neiljdodimkart
andauthored
Fix mypy issues with NotImplemented return and other assignments(#204)
Co-authored-by: Dimitri Kartsaklis <[email protected]>
1 parent 6667756 commit b0e5765

File tree

4 files changed

+13
-11
lines changed

4 files changed

+13
-11
lines changed

lambeq/backend/grammar.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,20 +224,22 @@ def tensor(self, other: Iterable[Self]) -> Self: ...
224224
@overload
225225
def tensor(self, other: Self, *rest: Self) -> Self: ...
226226

227-
def tensor(self, other: Self | Iterable[Self], *rest: Self) -> Self:
227+
def tensor(self,
228+
other: Self | Iterable[Self],
229+
*rest: Self) -> Self:
228230
try:
229231
tys = [*other, *rest]
230232
except TypeError:
231-
return NotImplemented
233+
return NotImplemented # type: ignore[no-any-return]
232234

233235
# Diagrams are iterable - the identity diagram has
234236
# an empty list for its layers but may still contain types
235237
if getattr(other, 'is_id', False):
236-
return NotImplemented
238+
return NotImplemented # type: ignore[no-any-return]
237239

238240
if any(not isinstance(ty, type(self))
239241
or self.category != ty.category for ty in tys):
240-
return NotImplemented
242+
return NotImplemented # type: ignore[no-any-return]
241243

242244
return self._fromiter(ob for ty in (self, *tys) for ob in ty)
243245

@@ -905,7 +907,7 @@ def tensor(self, *diagrams: Diagrammable | Ty) -> Self:
905907
try:
906908
diags = self.lift([self, *diagrams])
907909
except ValueError:
908-
return NotImplemented
910+
return NotImplemented # type: ignore[no-any-return]
909911

910912
right = dom = self.dom.tensor(*[
911913
diagram.to_diagram().dom for diagram in diagrams
@@ -964,7 +966,7 @@ def then(self, *diagrams: Diagrammable) -> Self:
964966
try:
965967
diags = self.lift(diagrams)
966968
except ValueError:
967-
return NotImplemented
969+
return NotImplemented # type: ignore[no-any-return]
968970

969971
layers = [*self.layers]
970972
cod = self.cod

lambeq/backend/pregroup_tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def is_same_word(self, other: object) -> bool:
123123
which doesn't check equality of the children - essentially,
124124
this just checks if `other` is the same token."""
125125
if not isinstance(other, PregroupTreeNode):
126-
return NotImplemented
126+
return NotImplemented # type: ignore[no-any-return]
127127
return (self.word == other.word
128128
and self.ind == other.ind)
129129

lambeq/training/pennylane_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class PennyLaneModel(Model, torch.nn.Module):
4343
4444
"""
4545

46-
weights: torch.nn.ParameterList
46+
weights: torch.nn.ParameterList # type: ignore[assignment]
4747
symbols: list[Symbol]
4848

4949
def __init__(self,
@@ -131,7 +131,7 @@ def _reinitialise_modules(self) -> None:
131131
"""Reinitialise all modules in the model."""
132132
for module in self.modules():
133133
try:
134-
module.reset_parameters()
134+
module.reset_parameters() # type: ignore[operator]
135135
except (AttributeError, TypeError):
136136
pass
137137

lambeq/training/pytorch_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
class PytorchModel(Model, torch.nn.Module):
3636
"""A lambeq model for the classical pipeline using PyTorch."""
3737

38-
weights: torch.nn.ParameterList
38+
weights: torch.nn.ParameterList # type: ignore[assignment]
3939
symbols: list[Symbol]
4040

4141
def __init__(self) -> None:
@@ -47,7 +47,7 @@ def _reinitialise_modules(self) -> None:
4747
"""Reinitialise all modules in the model."""
4848
for module in self.modules():
4949
try:
50-
module.reset_parameters()
50+
module.reset_parameters() # type: ignore[operator]
5151
except (AttributeError, TypeError):
5252
pass
5353

0 commit comments

Comments
 (0)