diff --git a/lambeq/__init__.py b/lambeq/__init__.py
index 35e3a768..ae1cbef3 100644
--- a/lambeq/__init__.py
+++ b/lambeq/__init__.py
@@ -92,6 +92,7 @@
     'NumpyModel',
     'PennyLaneModel',
     'PytorchModel',
+    'PytorchQuantumModel',
     'QuantumModel',
     'TketModel',
@@ -131,7 +132,8 @@
 from lambeq.training import (Checkpoint, Dataset, Optimizer,
                              NelderMeadOptimizer, RotosolveOptimizer,
                              SPSAOptimizer, Model, NumpyModel,
-                             PennyLaneModel, PytorchModel, QuantumModel,
+                             PennyLaneModel, PytorchModel,
+                             PytorchQuantumModel, QuantumModel,
                              TketModel, Trainer, PytorchTrainer,
                              QuantumTrainer, BinaryCrossEntropyLoss,
                              CrossEntropyLoss, LossFunction, MSELoss)
diff --git a/lambeq/backend/quantum.py b/lambeq/backend/quantum.py
index 47b1426f..f4e740be 100644
--- a/lambeq/backend/quantum.py
+++ b/lambeq/backend/quantum.py
@@ -833,7 +833,10 @@ def array(self):
         sin = self.modules.sin(half_theta)
         cos = self.modules.cos(half_theta)
 
-        return np.array([[cos, -1j * sin], [-1j * sin, cos]])
+        I_arr = np.eye(2)
+        X_arr = np.array([[0, 1], [1, 0]])
+
+        return cos * I_arr - 1j * sin * X_arr
 
 
 class Ry(SelfConjugate, Rotation):
@@ -846,7 +849,10 @@ def array(self):
         sin = self.modules.sin(half_theta)
         cos = self.modules.cos(half_theta)
 
-        return np.array([[cos, sin], [-sin, cos]])
+        I_arr = np.eye(2)
+        Y_arr = np.array([[0, 1j], [-1j, 0]])
+
+        return cos * I_arr - 1j * sin * Y_arr
 
 
 class Rz(AntiConjugate, Rotation):
@@ -859,7 +865,10 @@ def array(self):
         exp1 = np.e ** (-1j * half_theta)
         exp2 = np.e ** (1j * half_theta)
 
-        return np.array([[exp1, 0], [0, exp2]])
+        P_0 = np.array([[1, 0], [0, 0]])
+        P_1 = np.array([[0, 0], [0, 1]])
+
+        return exp1 * P_0 + exp2 * P_1
 
 
 class Controlled(Parametrized):
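Note on the rotation-gate rewrite above: the motivation appears to be gradient flow. When the numerical backend is PyTorch, `cos`/`sin` are torch scalars, and packing them into a fresh `np.array([[...]])` either fails outright or detaches them from the autograd graph; scaling constant matrices instead keeps every operation inside the active backend. A minimal sketch of the idea (variable names illustrative, not from the patch):

    import torch

    theta = torch.tensor(0.3, requires_grad=True)
    half_theta = theta / 2

    # np.array([[torch.cos(half_theta), ...]]) would raise
    # "Can't call numpy() on Tensor that requires grad".

    # Linear combination of constant matrices: stays in torch,
    # so the graph from `theta` to the gate array is preserved.
    I_arr = torch.eye(2, dtype=torch.complex64)
    X_arr = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64)
    rx = torch.cos(half_theta) * I_arr - 1j * torch.sin(half_theta) * X_arr

    rx.abs().sum().backward()
    print(theta.grad)  # non-None: the parameter received a gradient

The new forms are numerically identical to the old ones, e.g. for Rx: cos(t/2) * I - 1j * sin(t/2) * X equals [[cos(t/2), -1j*sin(t/2)], [-1j*sin(t/2), cos(t/2)]]. (The `Y_arr` used for Ry is the sign-flipped Pauli Y, matching lambeq's existing Ry convention.)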
+ +""" +from __future__ import annotations + +import torch + +from lambeq.ansatz.base import Symbol +from lambeq.backend.numerical_backend import backend +from lambeq.backend.quantum import Diagram as Circuit +from lambeq.backend.tensor import Diagram +from lambeq.training.checkpoint import Checkpoint +from lambeq.training.quantum_model import QuantumModel + + +class PytorchQuantumModel(QuantumModel, torch.nn.Module): + """A lambeq model for the quantum pipeline using PyTorch + with automatic gradient tracking.""" + + weights: torch.nn.Parameter # type: ignore[assignment] + symbols: list[Symbol] + + def __init__(self) -> None: + """Initialise a PytorchQuantumModel.""" + QuantumModel.__init__(self) + torch.nn.Module.__init__(self) + + def _reinitialise_modules(self) -> None: + """Reinitialise all modules in the model.""" + for module in self.modules(): + try: + module.reset_parameters() # type: ignore[operator] + except (AttributeError, TypeError): + pass + + def initialise_weights(self) -> None: + self._reinitialise_modules() + if not self.symbols: + raise ValueError('Symbols not initialised. Instantiate through ' + '`PytorchQuantumModel.from_diagrams()`.') + + self.weights = torch.nn.Parameter(torch.rand(len(self.symbols))) + + def _load_checkpoint(self, checkpoint: Checkpoint) -> None: + """Load the model weights and symbols from a lambeq + :py:class:`.Checkpoint`. + + Parameters + ---------- + checkpoint : :py:class:`.Checkpoint` + Checkpoint containing the model weights, + symbols and additional information. + + """ + + self.symbols = checkpoint['model_symbols'] + self.weights = checkpoint['model_weights'] + self.load_state_dict(checkpoint['model_state_dict']) + + def _make_checkpoint(self) -> Checkpoint: + """Create checkpoint that contains the model weights and symbols. + + Returns + ------- + :py:class:`.Checkpoint` + Checkpoint containing the model weights, symbols and + additional information. + + """ + checkpoint = Checkpoint() + checkpoint.add_many({'model_symbols': self.symbols, + 'model_weights': self.weights, + 'model_state_dict': self.state_dict()}) + return checkpoint + + def get_diagram_output(self, diagrams: list[Diagram]) -> torch.Tensor: + import tensornetwork as tn + + diagrams = self._fast_subs(diagrams, self.weights) + with backend('pytorch'), tn.DefaultBackend('pytorch'): + results = [] + for d in diagrams: + assert isinstance(d, Circuit) + nodes, edges = d.to_tn() + + # Ensure uniform tensor dtypes for contraction. + dominant_dtype = torch.bool + for node in nodes: + dominant_dtype = torch.promote_types( + dominant_dtype, node.tensor.dtype) + for node in nodes: + if node.tensor.dtype != dominant_dtype: + node.tensor = node.tensor.to(dominant_dtype) + + result = tn.contractors.auto(nodes, edges).tensor + if not d.is_mixed: + result = torch.square(torch.abs(result)) + results.append(self._normalise_vector(result)) + return torch.stack(results) + + def forward(self, x: list[Diagram]) -> torch.Tensor: + """Perform default forward pass by contracting tensors. + + In case of a different datapoint (e.g. list of tuple) or + additional computational steps, please override this method. + + Parameters + ---------- + x : list of :py:class:`~lambeq.backend.tensor.Diagram` + The :py:class:`Diagrams ` to be + evaluated. + + Returns + ------- + torch.Tensor + Tensor containing model's prediction. 
+ + """ + return self.get_diagram_output(x) diff --git a/tests/training/test_pytorch_quantum_model.py b/tests/training/test_pytorch_quantum_model.py new file mode 100644 index 00000000..5fe1704e --- /dev/null +++ b/tests/training/test_pytorch_quantum_model.py @@ -0,0 +1,142 @@ +import pickle +import pytest +from copy import deepcopy +from unittest.mock import mock_open, patch + +import numpy as np +import torch +from lambeq.backend.grammar import Cup, Id, Word +from lambeq.backend.quantum import CRz, CX, H, Ket, Measure, SWAP, Discard, qubit + +from lambeq import AtomicType, IQPAnsatz, PytorchQuantumModel, Symbol + +def test_init(): + N = AtomicType.NOUN + S = AtomicType.SENTENCE + + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagrams = [ansatz((Word("Alice", N) @ Word("runs", N >> S) >> Cup(N, N.r) @ Id(S)))] + model = PytorchQuantumModel.from_diagrams(diagrams) + model.initialise_weights() + assert len(model.weights) == 4 + assert isinstance(model.weights, torch.nn.Parameter) + +def test_forward(): + N = AtomicType.NOUN + S = AtomicType.SENTENCE + + s_dim = 2 + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagrams = [ansatz((Word("Alice", N) @ Word("runs", N >> S) >> Cup(N, N.r) @ Id(S)))] + model = PytorchQuantumModel.from_diagrams(diagrams) + model.initialise_weights() + pred = model.forward(diagrams) + assert pred.shape == (len(diagrams), s_dim) + +def test_forward_mixed(): + N = AtomicType.NOUN + S = AtomicType.SENTENCE + + density_matrix_dim = (2, 2, 2, 2) + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagrams = [ansatz((Word("Alice", N) @ Word("runs", N >> S))) >> (Discard() @ qubit @ qubit)] + model = PytorchQuantumModel.from_diagrams(diagrams) + model.initialise_weights() + pred = model.forward(diagrams) + assert pred.shape == (len(diagrams), *density_matrix_dim) + + +def test_backward(): + N = AtomicType.NOUN + S = AtomicType.SENTENCE + + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagrams = [ansatz((Word("Alice", N) @ Word("runs", N >> S) >> Cup(N, N.r) @ Id(S)))] + model = PytorchQuantumModel.from_diagrams(diagrams) + model.initialise_weights() + pred = model.forward(diagrams) + loss = torch.nn.MSELoss()(pred, torch.zeros_like(pred)) + loss.backward() + assert model.weights.grad is not None + assert model.weights.grad.shape == model.weights.shape + + +def test_initialise_weights_error(): + with pytest.raises(ValueError): + model = PytorchQuantumModel() + model.initialise_weights() + +def test_get_diagram_output_error(): + N = AtomicType.NOUN + S = AtomicType.SENTENCE + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagram = ansatz((Word("Alice", N) @ Word("runs", N >> S) >> Cup(N, N.r) @ Id(S))) + with pytest.raises(KeyError): + model = PytorchQuantumModel() + model.get_diagram_output([diagram]) + +def test_checkpoint_loading(): + N = AtomicType.NOUN + S = AtomicType.SENTENCE + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagram = ansatz((Word("Alice", N) @ Word("runs", N >> S) >> Cup(N, N.r) @ Id(S))) + model = PytorchQuantumModel.from_diagrams([diagram]) + model.initialise_weights() + + checkpoint = {'model_weights': model.weights, + 'model_symbols': model.symbols, + 'model_state_dict': model.state_dict()} + with patch('lambeq.training.checkpoint.open', mock_open(read_data=pickle.dumps(checkpoint))) as m, \ + patch('lambeq.training.checkpoint.os.path.exists', lambda x: True) as p: + model_new = PytorchQuantumModel.from_checkpoint('model.lt') + assert len(model_new.weights) == len(model.weights) + assert model_new.symbols == model.symbols + assert 
torch.allclose(model([diagram]), model_new([diagram])) + m.assert_called_with('model.lt', 'rb') + +def test_checkpoint_loading_errors(): + checkpoint = {'model_weights': np.array([1,2,3])} + with patch('lambeq.training.checkpoint.open', mock_open(read_data=pickle.dumps(checkpoint))) as m, \ + patch('lambeq.training.checkpoint.os.path.exists', lambda x: True) as p: + with pytest.raises(KeyError): + _ = PytorchQuantumModel.from_checkpoint('model.lt') + m.assert_called_with('model.lt', 'rb') + +def test_checkpoint_loading_file_not_found_errors(): + with patch('lambeq.training.checkpoint.open', mock_open(read_data='Not a valid checkpoint.')) as m, \ + patch('lambeq.training.checkpoint.os.path.exists', lambda x: False) as p: + with pytest.raises(FileNotFoundError): + _ = PytorchQuantumModel.from_checkpoint('model.lt') + m.assert_not_called() + + +def test_pickling(): + phi = Symbol('phi', directed_dom=123) + diagram = Ket(0, 0) >> CRz(phi) >> H @ H >> CX >> SWAP >> Measure() @ Measure() + + deepcopied_diagram = deepcopy(diagram) + pickled_diagram = pickle.loads(pickle.dumps(diagram)) + assert pickled_diagram == diagram + pickled_diagram.data = 'new data' + for box in pickled_diagram.boxes: + box.name = 'Jim' + box.data = ['random', 'data'] + assert diagram == deepcopied_diagram + assert diagram != pickled_diagram + assert deepcopied_diagram != pickled_diagram + +def test_normalise(): + model = PytorchQuantumModel() + input1 = np.linspace(-10, 10, 21) + input2 = np.array(-0.5) + normalised1 = model._normalise_vector(input1) + normalised2 = model._normalise_vector(input2) + assert abs(normalised1.sum() - 1.0) < 1e-8 + assert abs(normalised2 - 0.5) < 1e-8 + assert np.all(normalised1 >= 0) + +def test_fast_subs_error(): + with pytest.raises(KeyError): + diag = Ket(0, 0) >> CRz(Symbol('phi', directed_dom=123)) >> H @ H >> CX >> SWAP >> Measure() @ Measure() + model = PytorchQuantumModel() + model._fast_subs([diag], [])
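Since the model is both a `QuantumModel` and a `torch.nn.Module`, it should also slot into the existing `PytorchTrainer` in place of the hand-written loop sketched earlier. A sketch, assuming `PytorchTrainer` accepts the new model as-is (which this PR seems to intend; `circuits`/`labels` continue the earlier sketch, hyperparameters invented):

    from lambeq import Dataset, PytorchTrainer

    trainer = PytorchTrainer(
        model=model,                       # PytorchQuantumModel.from_diagrams(...)
        loss_function=torch.nn.MSELoss(),
        optimizer=torch.optim.Adam,
        learning_rate=0.05,
        epochs=20)
    trainer.fit(Dataset(circuits, labels))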