Skip to content

Commit 7f7bb00

Browse files
committed
Implemented training with quantum pipeline.
Added new dataset.
1 parent 83b97d8 commit 7f7bb00

9 files changed

+228
-29
lines changed

app/src/main/classical_user_input.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from app.src.main.pipelines.classical import classical_compute
2-
1+
from app.src.main.pipelines.classical import send_into_classical_pipeline
32

43
test_sentence = input("Enter a sentence: ")
5-
classical_compute(test_sentence)
4+
send_into_classical_pipeline(test_sentence)

app/src/main/constants/sample_sentences.py

-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,3 @@
77
ADS_ISSUED = "Initial ADS-Amount is equal to 6000000 ."
88

99
NON_ADR_WALKING = "John walks in the park ."
10-

app/src/main/pipelines/classical.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from discopy import grammar, Dim
22
from lambeq import BobcatParser, Rewriter, AtomicType, MPSAnsatz
33

4-
from app.src.main.constants import sample_sentences
54

6-
7-
def classical_compute(sentence):
5+
def send_into_classical_pipeline(sentence):
86
# Convert to string diagram
97
parser = BobcatParser(verbose='text')
108
diagram = parser.sentence2diagram(sentence) # syntax-based, not bag-of-words
@@ -13,7 +11,6 @@ def classical_compute(sentence):
1311
# Rewrite string diagram, to reduce performance costs / training time
1412
rewriter = Rewriter(['prepositional_phrase', 'determiner']) # lower tensor count on prepositions
1513
prep_reduced_diagram = rewriter(diagram).normal_form()
16-
prep_reduced_diagram.draw(figsize=(9, 4), fontsize=13)
1714

1815
curry_functor = Rewriter(['curry']) # reduce number of cups
1916
curried_diagram = curry_functor(prep_reduced_diagram).normal_form()
@@ -32,8 +29,3 @@ def classical_compute(sentence):
3229
mps_diagram.draw(figsize=(13, 7), fontsize=13)
3330

3431
# Todo: Training
35-
36-
if __name__ == "__main__":
37-
test_sentence = sample_sentences.ADS_ISSUED
38-
print(f"Input string: {test_sentence}")
39-
classical_compute(test_sentence)

app/src/main/pipelines/quantum.py

+112-12
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
import os
2+
import warnings
23
import webbrowser
34

5+
import numpy as np
46
from discopy import grammar
5-
from lambeq import BobcatParser, Rewriter, AtomicType, IQPAnsatz
7+
from lambeq import BobcatParser, Rewriter, AtomicType, IQPAnsatz, remove_cups, TketModel, QuantumTrainer, SPSAOptimizer, \
8+
Dataset
69
from matplotlib import pyplot
710
from pytket.circuit.display import render_circuit_as_html
8-
from pytket.extensions.qiskit import tk_to_qiskit
11+
from pytket.extensions.qiskit import tk_to_qiskit, AerBackend
912

10-
from app.src.main.constants import sample_sentences
11-
from settings import GEN_PATH
13+
from settings import GEN_PATH, PROJECT_ROOT_PATH
1214

15+
PATH_TO_TRAINING = os.path.join(PROJECT_ROOT_PATH, 'data', 'rp_train_data.txt')
16+
PATH_TO_TESTING = os.path.join(PROJECT_ROOT_PATH, 'data', 'rp_test_data.txt')
1317

14-
def quantum_compute(sentence):
18+
19+
def send_into_quantum_pipeline(sentence):
1520
# Convert to string diagram
1621
parser = BobcatParser(verbose='text')
1722
diagram = parser.sentence2diagram(sentence) # syntax-based, not bag-of-words
@@ -20,7 +25,6 @@ def quantum_compute(sentence):
2025
# Rewrite string diagram, to reduce performance costs / training time
2126
rewriter = Rewriter(['prepositional_phrase', 'determiner']) # lower tensor count on prepositions
2227
prep_reduced_diagram = rewriter(diagram).normal_form()
23-
prep_reduced_diagram.draw(figsize=(9, 4), fontsize=13)
2428

2529
curry_functor = Rewriter(['curry']) # reduce number of cups
2630
curried_diagram = curry_functor(prep_reduced_diagram).normal_form()
@@ -33,7 +37,7 @@ def quantum_compute(sentence):
3337
C = AtomicType.CONJUNCTION
3438
ansatz = IQPAnsatz({N: 1, S: 1, P: 1, C: 1}, n_layers=4)
3539

36-
discopy_circuit = ansatz(diagram) # Quantum circuit, DisCoPy format
40+
discopy_circuit = ansatz(diagram) # Quantum circuit, DisCoPy format
3741
discopy_circuit.draw(figsize=(15, 10))
3842

3943
tket_circuit = discopy_circuit.to_tk() # Quantum circuit, pytket format
@@ -46,9 +50,105 @@ def quantum_compute(sentence):
4650
qiskit_circuit.draw(output='mpl')
4751
pyplot.show()
4852

49-
# Todo: Training
5053

51-
if __name__ == "__main__":
52-
test_sentence = sample_sentences.NON_ADR_WALKING
53-
print(f"Input string: {test_sentence}")
54-
quantum_compute(test_sentence)
54+
def read_data(filename):
55+
labels, sentences = [], []
56+
with open(PATH_TO_TRAINING) as f:
57+
for line in f:
58+
t = int(line[0])
59+
labels.append([t, 1 - t])
60+
sentences.append(line[1:].strip())
61+
return labels, sentences
62+
63+
64+
def train_data():
65+
warnings.filterwarnings('ignore')
66+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
67+
68+
BATCH_SIZE = 30
69+
EPOCHS = 200
70+
SEED = 2
71+
72+
train_labels, train_data = read_data(PATH_TO_TRAINING)
73+
val_labels, val_data = read_data(PATH_TO_TESTING)
74+
75+
parser = BobcatParser(root_cats=('NP', 'N'), verbose='text')
76+
raw_train_diagrams = parser.sentences2diagrams(train_data, suppress_exceptions=True)
77+
raw_val_diagrams = parser.sentences2diagrams(val_data, suppress_exceptions=True)
78+
79+
train_diagrams = [
80+
diagram.normal_form()
81+
for diagram in raw_train_diagrams if diagram is not None
82+
]
83+
val_diagrams = [
84+
diagram.normal_form()
85+
for diagram in raw_val_diagrams if diagram is not None
86+
]
87+
88+
train_labels = [
89+
label for (diagram, label)
90+
in zip(raw_train_diagrams, train_labels)
91+
if diagram is not None
92+
]
93+
val_labels = [
94+
label for (diagram, label)
95+
in zip(raw_val_diagrams, val_labels)
96+
if diagram is not None
97+
]
98+
99+
ansatz = IQPAnsatz({AtomicType.NOUN: 1, AtomicType.SENTENCE: 0},
100+
n_layers=1, n_single_qubit_params=3)
101+
102+
train_circuits = [ansatz(remove_cups(diagram)) for diagram in train_diagrams]
103+
test_circuits = [ansatz(remove_cups(diagram)) for diagram in val_diagrams]
104+
all_circuits = train_circuits + test_circuits
105+
106+
backend = AerBackend()
107+
backend_config = {
108+
'backend': backend,
109+
'compilation': backend.default_compilation_pass(2),
110+
'shots': 8192
111+
}
112+
113+
model = TketModel.from_diagrams(all_circuits, backend_config=backend_config)
114+
loss = lambda y_hat, y: -np.sum(y * np.log(y_hat)) / len(y) # binary cross-entropy loss
115+
acc = lambda y_hat, y: np.sum(np.round(y_hat) == y) / len(y) / 2 # half due to double-counting
116+
eval_metrics = {"acc": acc}
117+
118+
trainer = QuantumTrainer(
119+
model,
120+
loss_function=loss,
121+
epochs=EPOCHS,
122+
optimizer=SPSAOptimizer,
123+
optim_hyperparams={'a': 0.05, 'c': 0.06, 'A': 0.01 * EPOCHS},
124+
evaluate_functions=eval_metrics,
125+
evaluate_on_train=True,
126+
verbose='text',
127+
seed=0
128+
)
129+
130+
train_dataset = Dataset(
131+
train_circuits,
132+
train_labels,
133+
batch_size=BATCH_SIZE)
134+
135+
test_dataset = Dataset(test_circuits, val_labels, shuffle=False)
136+
137+
# Plotting accuracy & loss for training/testing sets
138+
# fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = pyplot.subplots(2, 2, sharex=True, sharey='row', figsize=(10, 6))
139+
# ax_tl.set_title('Training set')
140+
# ax_tr.set_title('Development set')
141+
# ax_bl.set_xlabel('Iterations')
142+
# ax_br.set_xlabel('Iterations')
143+
# ax_bl.set_ylabel('Accuracy')
144+
# ax_tl.set_ylabel('Loss')
145+
#
146+
# colours = iter(pyplot.rcParams['axes.prop_cycle'].by_key()['color'])
147+
# ax_tl.plot(trainer.train_epoch_costs[::10], color=next(colours))
148+
# ax_bl.plot(trainer.train_results['acc'][::10], color=next(colours))
149+
# ax_tr.plot(trainer.val_costs[::10], color=next(colours))
150+
# ax_br.plot(trainer.val_results['acc'][::10], color=next(colours))
151+
152+
trainer.fit(train_dataset, test_dataset, evaluation_step=1, logging_step=20) # Train
153+
test_acc = acc(model(test_circuits), val_labels) # Record accuracy
154+
print('Test accuracy:', test_acc.item())

app/src/main/quantum_training.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from app.src.main.pipelines.quantum import train_data
2+
3+
train_data()

app/src/main/quantum_user_input.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from app.src.main.pipelines.quantum import quantum_compute
2-
1+
from app.src.main.pipelines.quantum import send_into_quantum_pipeline
32

43
test_sentence = input("Enter a sentence: ")
5-
quantum_compute(test_sentence)
4+
send_into_quantum_pipeline(test_sentence)

data/classify_adr_train_data.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,6 @@
7272
1 Advance amount is equal to 4000 .
7373
1 Fixed amount is equal to 3000 .
7474
1 Upfront amount is equal to 1000 .
75-
1 Contract renews 90 days before termination .
75+
1 Contract renews 90 days before termination .
76+
1 Minimum-Balance is 95 percent .
77+
1 Minimum-Balance is equal to 300000 .

data/rp_test_data.txt

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
1 organization that fleet destroy .
2+
1 person that teacher teach .
3+
1 device that air enter .
4+
1 device that water enter .
5+
1 device that astronomer use .
6+
1 document that student submit .
7+
1 document that government sell .
8+
1 player that pitcher face .
9+
1 building that monk build .
10+
1 quality that artist achieve .
11+
1 quality that species share .
12+
1 quality that vehicle increase .
13+
1 room that ship have .
14+
1 room that train feature .
15+
1 activity that festival feature .
16+
1 mammal that police have .
17+
1 material that police use .
18+
1 material that excavation remove .
19+
1 material that water have .
20+
0 organization that have team .
21+
0 building that attract sailor .
22+
0 device that show time .
23+
0 player that hit run .
24+
0 quality that win election .
25+
0 vehicle that replace horse .
26+
0 scientist that discover species .
27+
0 phenomenon that hit island .
28+
0 scientist that discover star .
29+
0 vehicle that destroy vessel .
30+
0 vehicle that cross river .
31+
0 mammal that attack ship .

data/rp_train_data.txt

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
1 organization that church establish .
2+
1 organization that team join .
3+
1 organization that company sell .
4+
1 organization that soldier serve .
5+
1 organization that sailor join .
6+
1 organization that vessel serve .
7+
1 organization that church represent .
8+
1 person that school serve .
9+
1 building that astronomer build .
10+
1 building that astronomer own .
11+
1 building that archaeologist discover .
12+
1 building that archaeologist study .
13+
1 player that batsman face .
14+
1 building that audience fill .
15+
1 device that shepherd play .
16+
1 document that company publish .
17+
1 device that people wear .
18+
1 document that election use .
19+
1 document that government offer .
20+
1 document that person submit .
21+
1 player that batter face .
22+
1 player that pitcher strike .
23+
1 person that train carry .
24+
1 quality that election reflect .
25+
1 organization that player join .
26+
1 quality that church teach .
27+
1 quality that vehicle offer .
28+
1 room that vessel contain .
29+
1 room that church have .
30+
1 woman that child love .
31+
1 material that fuel contain .
32+
1 woman that soldier use .
33+
1 woman that husband have .
34+
1 woman that husband love .
35+
1 vehicle that family own .
36+
1 material that ship strike .
37+
1 mammal that shepherd use .
38+
1 material that officer carry .
39+
1 phenomenon that engine lose .
40+
1 room that archaeologist discover .
41+
1 room that school include .
42+
1 room that student enter .
43+
1 vehicle that train have .
44+
1 vehicle that horse pull .
45+
1 vehicle that island have .
46+
1 material that engine require .
47+
0 organization that establish church .
48+
0 organization that support child .
49+
0 organization that use train .
50+
0 person that join movement .
51+
0 person that lose family .
52+
0 building that hold festival .
53+
0 device that carry water .
54+
0 device that keep time .
55+
0 player that strike batter .
56+
0 player that allow run .
57+
0 building that house monk .
58+
0 person that take ship .
59+
0 room that control movement .
60+
0 room that hold engine .
61+
0 woman that have child .
62+
0 woman that raise child .
63+
0 woman that carry pitcher .
64+
0 woman that have husband .
65+
0 woman that leave husband .
66+
0 activity that fill air .
67+
0 phenomenon that raise river .
68+
0 scientist that visit island .
69+
0 material that attract farmer .
70+
0 phenomenon that require fuel .
71+
0 vehicle that enter port .
72+
0 vehicle that transport horse .
73+
0 vehicle that haul material .
74+
0 activity that build school .

0 commit comments

Comments
 (0)