1
1
import os
2
+ import warnings
2
3
import webbrowser
3
4
5
+ import numpy as np
4
6
from discopy import grammar
5
- from lambeq import BobcatParser , Rewriter , AtomicType , IQPAnsatz
7
+ from lambeq import BobcatParser , Rewriter , AtomicType , IQPAnsatz , remove_cups , TketModel , QuantumTrainer , SPSAOptimizer , \
8
+ Dataset
6
9
from matplotlib import pyplot
7
10
from pytket .circuit .display import render_circuit_as_html
8
- from pytket .extensions .qiskit import tk_to_qiskit
11
+ from pytket .extensions .qiskit import tk_to_qiskit , AerBackend
9
12
10
- from app .src .main .constants import sample_sentences
11
- from settings import GEN_PATH
13
+ from settings import GEN_PATH , PROJECT_ROOT_PATH
12
14
15
# Locations of the labelled sentence files: PATH_TO_TRAINING feeds the training
# set and PATH_TO_TESTING the held-out set. Both live in <project root>/data.
_DATA_DIR = os.path.join(PROJECT_ROOT_PATH, 'data')
PATH_TO_TRAINING = os.path.join(_DATA_DIR, 'rp_train_data.txt')
PATH_TO_TESTING = os.path.join(_DATA_DIR, 'rp_test_data.txt')
13
17
14
- def quantum_compute (sentence ):
18
+
19
+ def send_into_quantum_pipeline (sentence ):
15
20
# Convert to string diagram
16
21
parser = BobcatParser (verbose = 'text' )
17
22
diagram = parser .sentence2diagram (sentence ) # syntax-based, not bag-of-words
@@ -20,7 +25,6 @@ def quantum_compute(sentence):
20
25
# Rewrite string diagram, to reduce performance costs / training time
21
26
rewriter = Rewriter (['prepositional_phrase' , 'determiner' ]) # lower tensor count on prepositions
22
27
prep_reduced_diagram = rewriter (diagram ).normal_form ()
23
- prep_reduced_diagram .draw (figsize = (9 , 4 ), fontsize = 13 )
24
28
25
29
curry_functor = Rewriter (['curry' ]) # reduce number of cups
26
30
curried_diagram = curry_functor (prep_reduced_diagram ).normal_form ()
@@ -33,7 +37,7 @@ def quantum_compute(sentence):
33
37
C = AtomicType .CONJUNCTION
34
38
ansatz = IQPAnsatz ({N : 1 , S : 1 , P : 1 , C : 1 }, n_layers = 4 )
35
39
36
- discopy_circuit = ansatz (diagram ) # Quantum circuit, DisCoPy format
40
+ discopy_circuit = ansatz (diagram ) # Quantum circuit, DisCoPy format
37
41
discopy_circuit .draw (figsize = (15 , 10 ))
38
42
39
43
tket_circuit = discopy_circuit .to_tk () # Quantum circuit, pytket format
@@ -46,9 +50,105 @@ def quantum_compute(sentence):
46
50
qiskit_circuit .draw (output = 'mpl' )
47
51
pyplot .show ()
48
52
49
- # Todo: Training
50
53
51
- if __name__ == "__main__" :
52
- test_sentence = sample_sentences .NON_ADR_WALKING
53
- print (f"Input string: { test_sentence } " )
54
- quantum_compute (test_sentence )
54
def read_data(filename):
    """Load labelled sentences from a text file.

    Every non-blank line is expected to start with a single digit label,
    followed by the sentence text.

    Args:
        filename: Path of the data file to read.

    Returns:
        A ``(labels, sentences)`` tuple of equal-length lists. Each label is
        the two-element list ``[t, 1 - t]`` where ``t`` is the leading digit
        of the line; each sentence is the rest of the line, stripped of
        surrounding whitespace.

    Raises:
        OSError: If ``filename`` cannot be opened.
        ValueError: If a non-blank line does not start with a digit.
    """
    labels, sentences = [], []
    # Read the file the caller asked for; a hard-coded path here would make
    # every dataset (including the validation set) a copy of the training set.
    with open(filename) as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            # instead of crashing on int('\n').
            if not line.strip():
                continue
            t = int(line[0])
            labels.append([t, 1 - t])
            sentences.append(line[1:].strip())
    return labels, sentences
62
+
63
+
64
# Lower bound applied to predicted probabilities before taking the logarithm.
_LOG_EPSILON = 1e-9


def _parse_labelled_diagrams(parser, sentences, labels):
    """Parse sentences and keep only the (diagram, label) pairs that parsed.

    Sentences the parser fails on come back as ``None`` (exceptions are
    suppressed); those entries and their labels are dropped together so the
    two returned lists stay aligned.
    """
    raw_diagrams = parser.sentences2diagrams(sentences, suppress_exceptions=True)
    parsed = [
        (diagram.normal_form(), label)
        for diagram, label in zip(raw_diagrams, labels)
        if diagram is not None
    ]
    diagrams = [diagram for diagram, _ in parsed]
    kept_labels = [label for _, label in parsed]
    return diagrams, kept_labels


def _cross_entropy_loss(y_hat, y):
    """Return the cross-entropy of predictions ``y_hat`` against labels ``y``."""
    # A shot-based backend can report a probability of exactly 0; clamp it so
    # log() never yields -inf (and 0 * -inf never yields NaN).
    safe_y_hat = np.clip(y_hat, _LOG_EPSILON, None)
    return -np.sum(y * np.log(safe_y_hat)) / len(y)


def _accuracy(y_hat, y):
    """Return the fraction of rounded predictions that match the labels."""
    return np.sum(np.round(y_hat) == y) / len(y) / 2  # half due to double-counting


def train_data():
    """Train and evaluate a quantum sentence classifier.

    Reads the sentences at ``PATH_TO_TRAINING`` and ``PATH_TO_TESTING``,
    parses them into diagrams, converts the diagrams to circuits with an IQP
    ansatz, fits a ``TketModel`` using ``QuantumTrainer`` with the SPSA
    optimiser on an ``AerBackend``, and prints the accuracy obtained on the
    testing set.

    Side effects: silences all Python warnings for the process, sets the
    ``TOKENIZERS_PARALLELISM`` environment variable, and writes progress and
    the final accuracy to standard output.
    """
    warnings.filterwarnings('ignore')
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'

    BATCH_SIZE = 30
    EPOCHS = 200

    train_labels, train_sentences = read_data(PATH_TO_TRAINING)
    val_labels, val_sentences = read_data(PATH_TO_TESTING)

    parser = BobcatParser(root_cats=('NP', 'N'), verbose='text')
    train_diagrams, train_labels = _parse_labelled_diagrams(
        parser, train_sentences, train_labels)
    val_diagrams, val_labels = _parse_labelled_diagrams(
        parser, val_sentences, val_labels)

    # NOUN wires map to one qubit and SENTENCE wires to none.
    ansatz = IQPAnsatz(
        {AtomicType.NOUN: 1, AtomicType.SENTENCE: 0},
        n_layers=1,
        n_single_qubit_params=3,
    )

    train_circuits = [ansatz(remove_cups(diagram)) for diagram in train_diagrams]
    val_circuits = [ansatz(remove_cups(diagram)) for diagram in val_diagrams]

    backend = AerBackend()
    backend_config = {
        'backend': backend,
        'compilation': backend.default_compilation_pass(2),
        'shots': 8192,
    }

    # The model is initialised from the training AND validation circuits --
    # presumably so that symbols occurring only in validation circuits also
    # receive parameters.
    model = TketModel.from_diagrams(
        train_circuits + val_circuits, backend_config=backend_config)

    trainer = QuantumTrainer(
        model,
        loss_function=_cross_entropy_loss,
        epochs=EPOCHS,
        optimizer=SPSAOptimizer,
        optim_hyperparams={'a': 0.05, 'c': 0.06, 'A': 0.01 * EPOCHS},
        evaluate_functions={"acc": _accuracy},
        evaluate_on_train=True,
        verbose='text',
        seed=0,
    )

    train_dataset = Dataset(train_circuits, train_labels, batch_size=BATCH_SIZE)
    val_dataset = Dataset(val_circuits, val_labels, shuffle=False)

    trainer.fit(train_dataset, val_dataset, evaluation_step=1, logging_step=20)  # Train

    # TODO: plot loss/accuracy curves (trainer.train_epoch_costs,
    # trainer.train_results['acc'], trainer.val_costs,
    # trainer.val_results['acc']) here, once fitting has populated them.

    test_acc = _accuracy(model(val_circuits), val_labels)  # Record accuracy
    print('Test accuracy:', test_acc.item())
0 commit comments