Skip to content

Commit

Permalink
Relax type restrictions on QuantumModel (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
neiljdo authored Feb 20, 2025
1 parent 16c90ae commit 68ea9bb
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 27 deletions.
8 changes: 4 additions & 4 deletions lambeq/training/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Collection
from collections.abc import MutableSequence
from typing import Any

from sympy import Symbol as SymPySymbol
Expand All @@ -33,14 +33,14 @@


class Model(ABC):
"""Model base class.
"""Model abstract base class.
Attributes
----------
symbols : list of symbols
A sorted list of all :py:class:`Symbols <.Symbol>` occuring in
the data.
weights : Collection
weights : MutableSequence
A data structure containing the numeric values of
the model's parameters.
Expand All @@ -49,7 +49,7 @@ class Model(ABC):
def __init__(self) -> None:
"""Initialise an instance of :py:class:`Model` base class."""
self.symbols: list[Symbol] | list[SymPySymbol] = []
self.weights: Collection = []
self.weights: MutableSequence = []

def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.forward(*args, **kwds)
Expand Down
13 changes: 8 additions & 5 deletions lambeq/training/numpy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from collections.abc import Callable, Iterable
from typing import Any, TYPE_CHECKING

import numpy
import numpy as np
from numpy.typing import ArrayLike

from lambeq.backend import numerical_backend
Expand All @@ -46,6 +46,8 @@ class NumpyModel(QuantumModel):
"""A lambeq model for an exact classical simulation of a
quantum pipeline."""

weights: np.ndarray

def __init__(self, use_jit: bool = False) -> None:
"""Initialise an NumpyModel.
Expand Down Expand Up @@ -86,15 +88,16 @@ def diagram_output(x: Iterable[ArrayLike]) -> ArrayLike:
assert isinstance(sub_circuit, Circuit)
if not sub_circuit.is_mixed:
result = backend.abs(result) ** 2
return self._normalise_vector(result)
normalised_result: ArrayLike = self._normalise_vector(result)
return normalised_result

self.lambdas[diagram] = jit(diagram_output)
return self.lambdas[diagram]

def get_diagram_output(
self,
diagrams: list[Diagram]
) -> jnp.ndarray | numpy.ndarray:
) -> jnp.ndarray | np.ndarray:
"""Return the exact prediction for each diagram.
Parameters
Expand Down Expand Up @@ -141,9 +144,9 @@ def get_diagram_output(
result = tn.contractors.auto(*d.to_tn()).tensor
# square amplitudes to get probabilties for pure circuits
if not d.is_mixed:
result = numpy.abs(result) ** 2
result = np.abs(result) ** 2
results.append(self._normalise_vector(result))
return numpy.array(results)
return np.array(results)

def forward(self, x: list[Diagram]) -> Any:
"""Perform default forward pass of a lambeq model.
Expand Down
31 changes: 15 additions & 16 deletions lambeq/training/quantum_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from abc import abstractmethod
from collections.abc import Iterable
import pickle
from typing import Any, TYPE_CHECKING
from typing import Any

import numpy as np

Expand All @@ -32,10 +32,7 @@
from lambeq.backend.tensor import Diagram
from lambeq.training.checkpoint import Checkpoint
from lambeq.training.model import Model


if TYPE_CHECKING:
from jax import numpy as jnp
from lambeq.typing import AnyTensor


class QuantumModel(Model):
Expand All @@ -46,29 +43,31 @@ class QuantumModel(Model):
symbols : list of symbols
A sorted list of all :py:class:`Symbols <.Symbol>` occurring in
the data.
weights : array
weights : AnyTensor
A data structure containing the numeric values of the model
parameters
parameters. This could be a `torch.Tensor`, `np.ndarray`, or
one from a different backend.
"""
weights: np.ndarray

weights: AnyTensor

def __init__(self) -> None:
"""Initialise a :py:class:`QuantumModel`."""
super().__init__()

self._training = False
self._train_predictions: list[Any] = []
self._train_predictions: list[AnyTensor] = []

def _log_prediction(self, y: Any) -> None:
def _log_prediction(self, y: AnyTensor) -> None:
"""Log a prediction of the model."""
self._train_predictions.append(y)

def _clear_predictions(self) -> None:
"""Clear the logged predictions of the model."""
self._train_predictions = []

def _normalise_vector(self, predictions: np.ndarray) -> np.ndarray:
def _normalise_vector(self, predictions: AnyTensor) -> AnyTensor:
"""Normalise the vector input.
Special cases:
Expand All @@ -77,11 +76,11 @@ def _normalise_vector(self, predictions: np.ndarray) -> np.ndarray:
"""

backend = numerical_backend.get_backend()
ret: np.ndarray = backend.abs(predictions)
ret: AnyTensor = backend.abs(predictions)

if predictions.shape:
# Prevent division by 0
l1_norm = backend.maximum(1e-9, ret.sum())
l1_norm = backend.maximum(backend.array(1e-9), ret.sum())
ret = ret / l1_norm

return ret
Expand Down Expand Up @@ -158,7 +157,7 @@ def _fast_subs(self,
def get_diagram_output(
self,
diagrams: list[Diagram]
) -> jnp.ndarray | np.ndarray:
) -> AnyTensor:
"""Return the diagram prediction.
Parameters
Expand All @@ -169,14 +168,14 @@ def get_diagram_output(
"""

def __call__(self, *args: Any, **kwargs: Any) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> AnyTensor:
out = self.forward(*args, **kwargs)
if self._training:
self._log_prediction(out)
return out

@abstractmethod
def forward(self, x: list[Diagram]) -> Any:
def forward(self, x: list[Diagram]) -> AnyTensor:
"""Compute the forward pass of the model using
`get_model_output`
Expand Down
4 changes: 3 additions & 1 deletion lambeq/training/tket_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class TketModel(QuantumModel):
"""

weights: np.ndarray

def __init__(self, backend_config: dict[str, Any]) -> None:
"""Initialise TketModel based on the `t|ket>` backend.
Expand Down Expand Up @@ -105,7 +107,7 @@ def get_diagram_output(self, diagrams: list[Diagram]) -> np.ndarray:
# lambeq evals a single diagram into a single result
# and not a list of results
if len(diagrams) == 1:
result = self._normalise_vector(tensors)
result: np.ndarray = self._normalise_vector(tensors)
return result.reshape(1, *result.shape)
return np.array([self._normalise_vector(t) for t in tensors])

Expand Down
3 changes: 2 additions & 1 deletion lambeq/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
import os
from typing import Union
from typing import Any, Union

AnyTensor = Any
StrPathT = Union[str, 'os.PathLike[str]']

0 comments on commit 68ea9bb

Please sign in to comment.