Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PytorchQuantumModel for quantum circuits training with autograd #208

Merged
merged 6 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lambeq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
'NumpyModel',
'PennyLaneModel',
'PytorchModel',
'PytorchQuantumModel',
'QuantumModel',
'TketModel',

Expand Down Expand Up @@ -131,7 +132,8 @@
from lambeq.training import (Checkpoint, Dataset, Optimizer,
NelderMeadOptimizer, RotosolveOptimizer,
SPSAOptimizer, Model, NumpyModel,
PennyLaneModel, PytorchModel, QuantumModel,
PennyLaneModel, PytorchModel,
PytorchQuantumModel, QuantumModel,
TketModel, Trainer, PytorchTrainer,
QuantumTrainer, BinaryCrossEntropyLoss,
CrossEntropyLoss, LossFunction, MSELoss)
Expand Down
15 changes: 12 additions & 3 deletions lambeq/backend/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,10 @@ def array(self):
sin = self.modules.sin(half_theta)
cos = self.modules.cos(half_theta)

return np.array([[cos, -1j * sin], [-1j * sin, cos]])
I_arr = np.eye(2)
X_arr = np.array([[0, 1], [1, 0]])

return cos * I_arr - 1j * sin * X_arr


class Ry(SelfConjugate, Rotation):
Expand All @@ -846,7 +849,10 @@ def array(self):
sin = self.modules.sin(half_theta)
cos = self.modules.cos(half_theta)

return np.array([[cos, sin], [-sin, cos]])
I_arr = np.eye(2)
Y_arr = np.array([[0, 1j], [-1j, 0]])

return cos * I_arr - 1j * sin * Y_arr


class Rz(AntiConjugate, Rotation):
Expand All @@ -859,7 +865,10 @@ def array(self):
exp1 = np.e ** (-1j * half_theta)
exp2 = np.e ** (1j * half_theta)

return np.array([[exp1, 0], [0, exp2]])
P_0 = np.array([[1, 0], [0, 0]])
P_1 = np.array([[0, 0], [0, 1]])

return exp1 * P_0 + exp2 * P_1


class Controlled(Parametrized):
Expand Down
2 changes: 2 additions & 0 deletions lambeq/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
'NumpyModel',
'PennyLaneModel',
'PytorchModel',
'PytorchQuantumModel',
'QuantumModel',
'TketModel',

Expand All @@ -44,6 +45,7 @@
from lambeq.training.numpy_model import NumpyModel
from lambeq.training.pennylane_model import PennyLaneModel
from lambeq.training.pytorch_model import PytorchModel
from lambeq.training.pytorch_quantum_model import PytorchQuantumModel
from lambeq.training.quantum_model import QuantumModel
from lambeq.training.tket_model import TketModel

Expand Down
137 changes: 137 additions & 0 deletions lambeq/training/pytorch_quantum_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
PytorchQuantumModel
===================
Module implementing a basic lambeq model based on a Pytorch backend
for training quantum circuits with Pytorch automatic gradients.

"""
from __future__ import annotations

import torch

from lambeq.ansatz.base import Symbol
from lambeq.backend.numerical_backend import backend
from lambeq.backend.quantum import Diagram as Circuit
from lambeq.backend.tensor import Diagram
from lambeq.training.checkpoint import Checkpoint
from lambeq.training.quantum_model import QuantumModel


class PytorchQuantumModel(QuantumModel, torch.nn.Module):
    """Quantum-pipeline lambeq model evaluated with PyTorch.

    Circuits are contracted as tensor networks on the PyTorch backend,
    so the trainable :py:attr:`weights` take part in automatic gradient
    tracking.

    """

    weights: torch.nn.Parameter  # type: ignore[assignment]
    symbols: list[Symbol]

    def __init__(self) -> None:
        """Initialise a PytorchQuantumModel."""
        # Both bases are initialised explicitly, in this order.
        QuantumModel.__init__(self)
        torch.nn.Module.__init__(self)

    def _reinitialise_modules(self) -> None:
        """Call ``reset_parameters`` on every module that supports it."""
        for submodule in self.modules():
            try:
                submodule.reset_parameters()  # type: ignore[operator]
            except (AttributeError, TypeError):
                # Best effort: modules without a usable
                # `reset_parameters` are left untouched.
                continue

    def initialise_weights(self) -> None:
        """Create a fresh random weight for every model symbol.

        Submodules are reset first; afterwards :py:attr:`weights` is
        replaced by a new :py:class:`torch.nn.Parameter` holding one
        ``torch.rand`` value per symbol.

        Raises
        ------
        ValueError
            If the model has no symbols.

        """
        self._reinitialise_modules()
        if not self.symbols:
            raise ValueError('Symbols not initialised. Instantiate through '
                             '`PytorchQuantumModel.from_diagrams()`.')

        n_weights = len(self.symbols)
        initial_values = torch.rand(n_weights)
        self.weights = torch.nn.Parameter(initial_values)

    def _load_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Restore symbols, weights and module state from a checkpoint.

        Parameters
        ----------
        checkpoint : :py:class:`.Checkpoint`
            Checkpoint providing the ``model_symbols``,
            ``model_weights`` and ``model_state_dict`` entries.

        """
        self.symbols = checkpoint['model_symbols']
        self.weights = checkpoint['model_weights']

        state_dict = checkpoint['model_state_dict']
        self.load_state_dict(state_dict)

    def _make_checkpoint(self) -> Checkpoint:
        """Bundle the model's symbols, weights and state dict.

        Returns
        -------
        :py:class:`.Checkpoint`
            Checkpoint holding the ``model_symbols``, ``model_weights``
            and ``model_state_dict`` entries.

        """
        entries = {'model_symbols': self.symbols,
                   'model_weights': self.weights,
                   'model_state_dict': self.state_dict()}
        checkpoint = Checkpoint()
        checkpoint.add_many(entries)
        return checkpoint

    def get_diagram_output(self, diagrams: list[Diagram]) -> torch.Tensor:
        """Contract the given circuits using the current weights.

        Parameters
        ----------
        diagrams : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The circuits to evaluate; each one must be a
            :py:class:`lambeq.backend.quantum.Diagram`.

        Returns
        -------
        torch.Tensor
            The normalised contraction results, stacked along a new
            leading dimension.

        """
        import tensornetwork as tn

        circuits = self._fast_subs(diagrams, self.weights)
        outputs = []
        with backend('pytorch'), tn.DefaultBackend('pytorch'):
            for circuit in circuits:
                assert isinstance(circuit, Circuit)
                nodes, edges = circuit.to_tn()

                # Contraction needs a single dtype across all tensors:
                # find the promoted dtype first, then cast the rest.
                common_dtype = torch.bool
                for dtype in [node.tensor.dtype for node in nodes]:
                    common_dtype = torch.promote_types(common_dtype, dtype)
                for node in nodes:
                    if node.tensor.dtype != common_dtype:
                        node.tensor = node.tensor.to(common_dtype)

                contracted = tn.contractors.auto(nodes, edges).tensor
                if not circuit.is_mixed:
                    # Non-mixed circuits: use the squared modulus of the
                    # contraction result.
                    contracted = torch.square(torch.abs(contracted))
                outputs.append(self._normalise_vector(contracted))
        return torch.stack(outputs)

    def forward(self, x: list[Diagram]) -> torch.Tensor:
        """Run the default forward pass by contracting the circuits.

        Override this method when the datapoints have a different
        form (e.g. a list of tuples) or when extra computational steps
        are required.

        Parameters
        ----------
        x : list of :py:class:`~lambeq.backend.tensor.Diagram`
            The :py:class:`Diagrams <lambeq.backend.tensor.Diagram>` to be
            evaluated.

        Returns
        -------
        torch.Tensor
            Tensor containing model's prediction.

        """
        return self.get_diagram_output(x)
142 changes: 142 additions & 0 deletions tests/training/test_pytorch_quantum_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import pickle
import pytest
from copy import deepcopy
from unittest.mock import mock_open, patch

import numpy as np
import torch
from lambeq.backend.grammar import Cup, Id, Word
from lambeq.backend.quantum import CRz, CX, H, Ket, Measure, SWAP, Discard, qubit

from lambeq import AtomicType, IQPAnsatz, PytorchQuantumModel, Symbol

def test_init():
    """Initialised weights are a torch Parameter with four entries."""
    noun, sentence = AtomicType.NOUN, AtomicType.SENTENCE

    ansatz = IQPAnsatz({noun: 1, sentence: 1}, n_layers=1)
    parse = (Word("Alice", noun) @ Word("runs", noun >> sentence)
             >> Cup(noun, noun.r) @ Id(sentence))
    circuits = [ansatz(parse)]

    model = PytorchQuantumModel.from_diagrams(circuits)
    model.initialise_weights()

    assert len(model.weights) == 4
    assert isinstance(model.weights, torch.nn.Parameter)

def test_forward():
    """The forward pass returns one 2-dimensional output per diagram."""
    noun, sentence = AtomicType.NOUN, AtomicType.SENTENCE
    sentence_dim = 2

    ansatz = IQPAnsatz({noun: 1, sentence: 1}, n_layers=1)
    parse = (Word("Alice", noun) @ Word("runs", noun >> sentence)
             >> Cup(noun, noun.r) @ Id(sentence))
    circuits = [ansatz(parse)]

    model = PytorchQuantumModel.from_diagrams(circuits)
    model.initialise_weights()
    prediction = model.forward(circuits)

    assert prediction.shape == (len(circuits), sentence_dim)

def test_forward_mixed():
    """A circuit with a discarded qubit gives a (2, 2, 2, 2) output."""
    noun, sentence = AtomicType.NOUN, AtomicType.SENTENCE
    expected_dims = (2, 2, 2, 2)

    ansatz = IQPAnsatz({noun: 1, sentence: 1}, n_layers=1)
    words = Word("Alice", noun) @ Word("runs", noun >> sentence)
    circuits = [ansatz(words) >> (Discard() @ qubit @ qubit)]

    model = PytorchQuantumModel.from_diagrams(circuits)
    model.initialise_weights()
    prediction = model.forward(circuits)

    assert prediction.shape == (len(circuits), *expected_dims)


def test_backward():
    """Backpropagating a loss populates a gradient for every weight."""
    noun, sentence = AtomicType.NOUN, AtomicType.SENTENCE

    ansatz = IQPAnsatz({noun: 1, sentence: 1}, n_layers=1)
    parse = (Word("Alice", noun) @ Word("runs", noun >> sentence)
             >> Cup(noun, noun.r) @ Id(sentence))
    circuits = [ansatz(parse)]

    model = PytorchQuantumModel.from_diagrams(circuits)
    model.initialise_weights()

    prediction = model.forward(circuits)
    target = torch.zeros_like(prediction)
    criterion = torch.nn.MSELoss()
    loss = criterion(prediction, target)
    loss.backward()

    assert model.weights.grad is not None
    assert model.weights.grad.shape == model.weights.shape


def test_initialise_weights_error():
    """`initialise_weights` raises ValueError on a model without symbols."""
    # Build the model outside the `raises` block so that only the call
    # under test can satisfy the expectation.
    model = PytorchQuantumModel()
    with pytest.raises(ValueError):
        model.initialise_weights()

def test_get_diagram_output_error():
    """Evaluating a diagram on a bare model raises KeyError."""
    N = AtomicType.NOUN
    S = AtomicType.SENTENCE
    ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1)
    diagram = ansatz((Word("Alice", N) @ Word("runs", N >> S)
                      >> Cup(N, N.r) @ Id(S)))

    # Build the model outside the `raises` block so that only the call
    # under test can satisfy the expectation.
    model = PytorchQuantumModel()
    with pytest.raises(KeyError):
        model.get_diagram_output([diagram])

def test_checkpoint_loading():
    """A model restored from a checkpoint reproduces the original."""
    noun, sentence = AtomicType.NOUN, AtomicType.SENTENCE
    ansatz = IQPAnsatz({noun: 1, sentence: 1}, n_layers=1)
    diagram = ansatz(Word("Alice", noun) @ Word("runs", noun >> sentence)
                     >> Cup(noun, noun.r) @ Id(sentence))
    model = PytorchQuantumModel.from_diagrams([diagram])
    model.initialise_weights()

    checkpoint = {'model_weights': model.weights,
                  'model_symbols': model.symbols,
                  'model_state_dict': model.state_dict()}
    # Serve the pickled checkpoint from a mocked file and pretend the
    # path exists, so the real filesystem is not touched.
    fake_open = mock_open(read_data=pickle.dumps(checkpoint))
    with patch('lambeq.training.checkpoint.open', fake_open) as m, \
            patch('lambeq.training.checkpoint.os.path.exists',
                  lambda x: True):
        restored = PytorchQuantumModel.from_checkpoint('model.lt')
        assert len(restored.weights) == len(model.weights)
        assert restored.symbols == model.symbols
        assert torch.allclose(model([diagram]), restored([diagram]))
        m.assert_called_with('model.lt', 'rb')

def test_checkpoint_loading_errors():
    """Loading a checkpoint that only has weights raises KeyError."""
    incomplete = {'model_weights': np.array([1, 2, 3])}
    fake_open = mock_open(read_data=pickle.dumps(incomplete))
    with patch('lambeq.training.checkpoint.open', fake_open) as m, \
            patch('lambeq.training.checkpoint.os.path.exists',
                  lambda x: True):
        with pytest.raises(KeyError):
            _ = PytorchQuantumModel.from_checkpoint('model.lt')
        m.assert_called_with('model.lt', 'rb')

def test_checkpoint_loading_file_not_found_errors():
    """A missing checkpoint path raises FileNotFoundError unopened."""
    fake_open = mock_open(read_data='Not a valid checkpoint.')
    with patch('lambeq.training.checkpoint.open', fake_open) as m, \
            patch('lambeq.training.checkpoint.os.path.exists',
                  lambda x: False):
        with pytest.raises(FileNotFoundError):
            _ = PytorchQuantumModel.from_checkpoint('model.lt')
        m.assert_not_called()


def test_pickling():
    """A pickled circuit equals its source and is independent of it."""
    phi = Symbol('phi', directed_dom=123)
    circuit = (Ket(0, 0) >> CRz(phi) >> H @ H >> CX >> SWAP
               >> Measure() @ Measure())

    reference = deepcopy(circuit)
    roundtripped = pickle.loads(pickle.dumps(circuit))
    assert roundtripped == circuit

    # Mutate only the round-tripped copy; the original must stay intact.
    roundtripped.data = 'new data'
    for box in roundtripped.boxes:
        box.name = 'Jim'
        box.data = ['random', 'data']

    assert circuit == reference
    assert circuit != roundtripped
    assert reference != roundtripped

def test_normalise():
    """`_normalise_vector` handles a 1-D array and a 0-D array."""
    model = PytorchQuantumModel()
    vector = np.linspace(-10, 10, 21)
    scalar = np.array(-0.5)

    normalised_vector = model._normalise_vector(vector)
    normalised_scalar = model._normalise_vector(scalar)

    # The vector result sums to one; the 0-D input -0.5 comes back
    # as 0.5.
    assert abs(normalised_vector.sum() - 1.0) < 1e-8
    assert abs(normalised_scalar - 0.5) < 1e-8
    assert np.all(normalised_vector >= 0)

def test_fast_subs_error():
    """`_fast_subs` raises KeyError for a symbol unknown to the model."""
    # Build the fixtures outside the `raises` block so that only the call
    # under test can satisfy the expectation.
    diag = (Ket(0, 0) >> CRz(Symbol('phi', directed_dom=123)) >> H @ H
            >> CX >> SWAP >> Measure() @ Measure())
    model = PytorchQuantumModel()
    with pytest.raises(KeyError):
        model._fast_subs([diag], [])
Loading