23
23
from abc import abstractmethod
24
24
from collections .abc import Iterable
25
25
import pickle
26
- from typing import Any , TYPE_CHECKING
26
+ from typing import Any
27
27
28
28
import numpy as np
29
29
32
32
from lambeq .backend .tensor import Diagram
33
33
from lambeq .training .checkpoint import Checkpoint
34
34
from lambeq .training .model import Model
35
-
36
-
37
- if TYPE_CHECKING :
38
- from jax import numpy as jnp
35
+ from lambeq .typing import AnyTensor
39
36
40
37
41
38
class QuantumModel (Model ):
@@ -46,29 +43,31 @@ class QuantumModel(Model):
46
43
symbols : list of symbols
47
44
A sorted list of all :py:class:`Symbols <.Symbol>` occurring in
48
45
the data.
49
- weights : array
46
+ weights : AnyTensor
50
47
A data structure containing the numeric values of the model
51
- parameters
48
+ parameters. This could be a `torch.Tensor`, `np.ndarray`, or
49
+ one from a different backend.
52
50
53
51
"""
54
- weights : np .ndarray
52
+
53
+ weights : AnyTensor
55
54
56
55
def __init__ (self ) -> None :
57
56
"""Initialise a :py:class:`QuantumModel`."""
58
57
super ().__init__ ()
59
58
60
59
self ._training = False
61
- self ._train_predictions : list [Any ] = []
60
+ self ._train_predictions : list [AnyTensor ] = []
62
61
63
- def _log_prediction (self , y : Any ) -> None :
62
+ def _log_prediction (self , y : AnyTensor ) -> None :
64
63
"""Log a prediction of the model."""
65
64
self ._train_predictions .append (y )
66
65
67
66
def _clear_predictions (self ) -> None :
68
67
"""Clear the logged predictions of the model."""
69
68
self ._train_predictions = []
70
69
71
- def _normalise_vector (self , predictions : np . ndarray ) -> np . ndarray :
70
+ def _normalise_vector (self , predictions : AnyTensor ) -> AnyTensor :
72
71
"""Normalise the vector input.
73
72
74
73
Special cases:
@@ -77,11 +76,11 @@ def _normalise_vector(self, predictions: np.ndarray) -> np.ndarray:
77
76
"""
78
77
79
78
backend = numerical_backend .get_backend ()
80
- ret : np . ndarray = backend .abs (predictions )
79
+ ret : AnyTensor = backend .abs (predictions )
81
80
82
81
if predictions .shape :
83
82
# Prevent division by 0
84
- l1_norm = backend .maximum (1e-9 , ret .sum ())
83
+ l1_norm = backend .maximum (backend . array ( 1e-9 ) , ret .sum ())
85
84
ret = ret / l1_norm
86
85
87
86
return ret
@@ -158,7 +157,7 @@ def _fast_subs(self,
158
157
def get_diagram_output (
159
158
self ,
160
159
diagrams : list [Diagram ]
161
- ) -> jnp . ndarray | np . ndarray :
160
+ ) -> AnyTensor :
162
161
"""Return the diagram prediction.
163
162
164
163
Parameters
@@ -169,14 +168,14 @@ def get_diagram_output(
169
168
170
169
"""
171
170
172
- def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
171
+ def __call__ (self , * args : Any , ** kwargs : Any ) -> AnyTensor :
173
172
out = self .forward (* args , ** kwargs )
174
173
if self ._training :
175
174
self ._log_prediction (out )
176
175
return out
177
176
178
177
@abstractmethod
179
- def forward (self , x : list [Diagram ]) -> Any :
178
+ def forward (self , x : list [Diagram ]) -> AnyTensor :
180
179
"""Compute the forward pass of the model using
181
180
`get_model_output`
182
181
0 commit comments