Skip to content

Commit 68ea9bb

Browse files
authored
Relax type restrictions on QuantumModel (#212)
1 parent 16c90ae commit 68ea9bb

File tree

5 files changed

+32
-27
lines changed

5 files changed

+32
-27
lines changed

lambeq/training/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from __future__ import annotations
2222

2323
from abc import ABC, abstractmethod
24-
from collections.abc import Collection
24+
from collections.abc import MutableSequence
2525
from typing import Any
2626

2727
from sympy import Symbol as SymPySymbol
@@ -33,14 +33,14 @@
3333

3434

3535
class Model(ABC):
36-
"""Model base class.
36+
"""Model abstract base class.
3737
3838
Attributes
3939
----------
4040
symbols : list of symbols
4141
A sorted list of all :py:class:`Symbols <.Symbol>` occuring in
4242
the data.
43-
weights : Collection
43+
weights : MutableSequence
4444
A data structure containing the numeric values of
4545
the model's parameters.
4646
@@ -49,7 +49,7 @@ class Model(ABC):
4949
def __init__(self) -> None:
5050
"""Initialise an instance of :py:class:`Model` base class."""
5151
self.symbols: list[Symbol] | list[SymPySymbol] = []
52-
self.weights: Collection = []
52+
self.weights: MutableSequence = []
5353

5454
def __call__(self, *args: Any, **kwds: Any) -> Any:
5555
return self.forward(*args, **kwds)

lambeq/training/numpy_model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from collections.abc import Callable, Iterable
3030
from typing import Any, TYPE_CHECKING
3131

32-
import numpy
32+
import numpy as np
3333
from numpy.typing import ArrayLike
3434

3535
from lambeq.backend import numerical_backend
@@ -46,6 +46,8 @@ class NumpyModel(QuantumModel):
4646
"""A lambeq model for an exact classical simulation of a
4747
quantum pipeline."""
4848

49+
weights: np.ndarray
50+
4951
def __init__(self, use_jit: bool = False) -> None:
5052
"""Initialise an NumpyModel.
5153
@@ -86,15 +88,16 @@ def diagram_output(x: Iterable[ArrayLike]) -> ArrayLike:
8688
assert isinstance(sub_circuit, Circuit)
8789
if not sub_circuit.is_mixed:
8890
result = backend.abs(result) ** 2
89-
return self._normalise_vector(result)
91+
normalised_result: ArrayLike = self._normalise_vector(result)
92+
return normalised_result
9093

9194
self.lambdas[diagram] = jit(diagram_output)
9295
return self.lambdas[diagram]
9396

9497
def get_diagram_output(
9598
self,
9699
diagrams: list[Diagram]
97-
) -> jnp.ndarray | numpy.ndarray:
100+
) -> jnp.ndarray | np.ndarray:
98101
"""Return the exact prediction for each diagram.
99102
100103
Parameters
@@ -141,9 +144,9 @@ def get_diagram_output(
141144
result = tn.contractors.auto(*d.to_tn()).tensor
142145
# square amplitudes to get probabilties for pure circuits
143146
if not d.is_mixed:
144-
result = numpy.abs(result) ** 2
147+
result = np.abs(result) ** 2
145148
results.append(self._normalise_vector(result))
146-
return numpy.array(results)
149+
return np.array(results)
147150

148151
def forward(self, x: list[Diagram]) -> Any:
149152
"""Perform default forward pass of a lambeq model.

lambeq/training/quantum_model.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from abc import abstractmethod
2424
from collections.abc import Iterable
2525
import pickle
26-
from typing import Any, TYPE_CHECKING
26+
from typing import Any
2727

2828
import numpy as np
2929

@@ -32,10 +32,7 @@
3232
from lambeq.backend.tensor import Diagram
3333
from lambeq.training.checkpoint import Checkpoint
3434
from lambeq.training.model import Model
35-
36-
37-
if TYPE_CHECKING:
38-
from jax import numpy as jnp
35+
from lambeq.typing import AnyTensor
3936

4037

4138
class QuantumModel(Model):
@@ -46,29 +43,31 @@ class QuantumModel(Model):
4643
symbols : list of symbols
4744
A sorted list of all :py:class:`Symbols <.Symbol>` occurring in
4845
the data.
49-
weights : array
46+
weights : AnyTensor
5047
A data structure containing the numeric values of the model
51-
parameters
48+
parameters. This could be a `torch.Tensor`, `np.ndarray`, or
49+
one from a different backend.
5250
5351
"""
54-
weights: np.ndarray
52+
53+
weights: AnyTensor
5554

5655
def __init__(self) -> None:
5756
"""Initialise a :py:class:`QuantumModel`."""
5857
super().__init__()
5958

6059
self._training = False
61-
self._train_predictions: list[Any] = []
60+
self._train_predictions: list[AnyTensor] = []
6261

63-
def _log_prediction(self, y: Any) -> None:
62+
def _log_prediction(self, y: AnyTensor) -> None:
6463
"""Log a prediction of the model."""
6564
self._train_predictions.append(y)
6665

6766
def _clear_predictions(self) -> None:
6867
"""Clear the logged predictions of the model."""
6968
self._train_predictions = []
7069

71-
def _normalise_vector(self, predictions: np.ndarray) -> np.ndarray:
70+
def _normalise_vector(self, predictions: AnyTensor) -> AnyTensor:
7271
"""Normalise the vector input.
7372
7473
Special cases:
@@ -77,11 +76,11 @@ def _normalise_vector(self, predictions: np.ndarray) -> np.ndarray:
7776
"""
7877

7978
backend = numerical_backend.get_backend()
80-
ret: np.ndarray = backend.abs(predictions)
79+
ret: AnyTensor = backend.abs(predictions)
8180

8281
if predictions.shape:
8382
# Prevent division by 0
84-
l1_norm = backend.maximum(1e-9, ret.sum())
83+
l1_norm = backend.maximum(backend.array(1e-9), ret.sum())
8584
ret = ret / l1_norm
8685

8786
return ret
@@ -158,7 +157,7 @@ def _fast_subs(self,
158157
def get_diagram_output(
159158
self,
160159
diagrams: list[Diagram]
161-
) -> jnp.ndarray | np.ndarray:
160+
) -> AnyTensor:
162161
"""Return the diagram prediction.
163162
164163
Parameters
@@ -169,14 +168,14 @@ def get_diagram_output(
169168
170169
"""
171170

172-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
171+
def __call__(self, *args: Any, **kwargs: Any) -> AnyTensor:
173172
out = self.forward(*args, **kwargs)
174173
if self._training:
175174
self._log_prediction(out)
176175
return out
177176

178177
@abstractmethod
179-
def forward(self, x: list[Diagram]) -> Any:
178+
def forward(self, x: list[Diagram]) -> AnyTensor:
180179
"""Compute the forward pass of the model using
181180
`get_model_output`
182181

lambeq/training/tket_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class TketModel(QuantumModel):
3737
3838
"""
3939

40+
weights: np.ndarray
41+
4042
def __init__(self, backend_config: dict[str, Any]) -> None:
4143
"""Initialise TketModel based on the `t|ket>` backend.
4244
@@ -105,7 +107,7 @@ def get_diagram_output(self, diagrams: list[Diagram]) -> np.ndarray:
105107
# lambeq evals a single diagram into a single result
106108
# and not a list of results
107109
if len(diagrams) == 1:
108-
result = self._normalise_vector(tensors)
110+
result: np.ndarray = self._normalise_vector(tensors)
109111
return result.reshape(1, *result.shape)
110112
return np.array([self._normalise_vector(t) for t in tensors])
111113

lambeq/typing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
2020
"""
2121
import os
22-
from typing import Union
22+
from typing import Any, Union
2323

24+
AnyTensor = Any
2425
StrPathT = Union[str, 'os.PathLike[str]']

0 commit comments

Comments
 (0)