11import os
2+ import warnings
23import webbrowser
34
5+ import numpy as np
46from discopy import grammar
5- from lambeq import BobcatParser , Rewriter , AtomicType , IQPAnsatz
7+ from lambeq import BobcatParser , Rewriter , AtomicType , IQPAnsatz , remove_cups , TketModel , QuantumTrainer , SPSAOptimizer , \
8+ Dataset
69from matplotlib import pyplot
710from pytket .circuit .display import render_circuit_as_html
8- from pytket .extensions .qiskit import tk_to_qiskit
11+ from pytket .extensions .qiskit import tk_to_qiskit , AerBackend
912
10- from app .src .main .constants import sample_sentences
11- from settings import GEN_PATH
13+ from settings import GEN_PATH , PROJECT_ROOT_PATH
1214
15+ PATH_TO_TRAINING = os .path .join (PROJECT_ROOT_PATH , 'data' , 'rp_train_data.txt' )
16+ PATH_TO_TESTING = os .path .join (PROJECT_ROOT_PATH , 'data' , 'rp_test_data.txt' )
1317
14- def quantum_compute (sentence ):
18+
19+ def send_into_quantum_pipeline (sentence ):
1520 # Convert to string diagram
1621 parser = BobcatParser (verbose = 'text' )
1722 diagram = parser .sentence2diagram (sentence ) # syntax-based, not bag-of-words
@@ -20,7 +25,6 @@ def quantum_compute(sentence):
2025 # Rewrite string diagram, to reduce performance costs / training time
2126 rewriter = Rewriter (['prepositional_phrase' , 'determiner' ]) # lower tensor count on prepositions
2227 prep_reduced_diagram = rewriter (diagram ).normal_form ()
23- prep_reduced_diagram .draw (figsize = (9 , 4 ), fontsize = 13 )
2428
2529 curry_functor = Rewriter (['curry' ]) # reduce number of cups
2630 curried_diagram = curry_functor (prep_reduced_diagram ).normal_form ()
@@ -33,7 +37,7 @@ def quantum_compute(sentence):
3337 C = AtomicType .CONJUNCTION
3438 ansatz = IQPAnsatz ({N : 1 , S : 1 , P : 1 , C : 1 }, n_layers = 4 )
3539
36- discopy_circuit = ansatz (diagram ) # Quantum circuit, DisCoPy format
40+ discopy_circuit = ansatz (diagram ) # Quantum circuit, DisCoPy format
3741 discopy_circuit .draw (figsize = (15 , 10 ))
3842
3943 tket_circuit = discopy_circuit .to_tk () # Quantum circuit, pytket format
@@ -46,9 +50,105 @@ def quantum_compute(sentence):
4650 qiskit_circuit .draw (output = 'mpl' )
4751 pyplot .show ()
4852
49- # Todo: Training
5053
51- if __name__ == "__main__" :
52- test_sentence = sample_sentences .NON_ADR_WALKING
53- print (f"Input string: { test_sentence } " )
54- quantum_compute (test_sentence )
54+ def read_data (filename ):
55+ labels , sentences = [], []
56+ with open (PATH_TO_TRAINING ) as f :
57+ for line in f :
58+ t = int (line [0 ])
59+ labels .append ([t , 1 - t ])
60+ sentences .append (line [1 :].strip ())
61+ return labels , sentences
62+
63+
64+ def train_data ():
65+ warnings .filterwarnings ('ignore' )
66+ os .environ ['TOKENIZERS_PARALLELISM' ] = 'true'
67+
68+ BATCH_SIZE = 30
69+ EPOCHS = 200
70+ SEED = 2
71+
72+ train_labels , train_data = read_data (PATH_TO_TRAINING )
73+ val_labels , val_data = read_data (PATH_TO_TESTING )
74+
75+ parser = BobcatParser (root_cats = ('NP' , 'N' ), verbose = 'text' )
76+ raw_train_diagrams = parser .sentences2diagrams (train_data , suppress_exceptions = True )
77+ raw_val_diagrams = parser .sentences2diagrams (val_data , suppress_exceptions = True )
78+
79+ train_diagrams = [
80+ diagram .normal_form ()
81+ for diagram in raw_train_diagrams if diagram is not None
82+ ]
83+ val_diagrams = [
84+ diagram .normal_form ()
85+ for diagram in raw_val_diagrams if diagram is not None
86+ ]
87+
88+ train_labels = [
89+ label for (diagram , label )
90+ in zip (raw_train_diagrams , train_labels )
91+ if diagram is not None
92+ ]
93+ val_labels = [
94+ label for (diagram , label )
95+ in zip (raw_val_diagrams , val_labels )
96+ if diagram is not None
97+ ]
98+
99+ ansatz = IQPAnsatz ({AtomicType .NOUN : 1 , AtomicType .SENTENCE : 0 },
100+ n_layers = 1 , n_single_qubit_params = 3 )
101+
102+ train_circuits = [ansatz (remove_cups (diagram )) for diagram in train_diagrams ]
103+ test_circuits = [ansatz (remove_cups (diagram )) for diagram in val_diagrams ]
104+ all_circuits = train_circuits + test_circuits
105+
106+ backend = AerBackend ()
107+ backend_config = {
108+ 'backend' : backend ,
109+ 'compilation' : backend .default_compilation_pass (2 ),
110+ 'shots' : 8192
111+ }
112+
113+ model = TketModel .from_diagrams (all_circuits , backend_config = backend_config )
114+ loss = lambda y_hat , y : - np .sum (y * np .log (y_hat )) / len (y ) # binary cross-entropy loss
115+ acc = lambda y_hat , y : np .sum (np .round (y_hat ) == y ) / len (y ) / 2 # half due to double-counting
116+ eval_metrics = {"acc" : acc }
117+
118+ trainer = QuantumTrainer (
119+ model ,
120+ loss_function = loss ,
121+ epochs = EPOCHS ,
122+ optimizer = SPSAOptimizer ,
123+ optim_hyperparams = {'a' : 0.05 , 'c' : 0.06 , 'A' : 0.01 * EPOCHS },
124+ evaluate_functions = eval_metrics ,
125+ evaluate_on_train = True ,
126+ verbose = 'text' ,
127+ seed = 0
128+ )
129+
130+ train_dataset = Dataset (
131+ train_circuits ,
132+ train_labels ,
133+ batch_size = BATCH_SIZE )
134+
135+ test_dataset = Dataset (test_circuits , val_labels , shuffle = False )
136+
137+ # Plotting accuracy & loss for training/testing sets
138+ # fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = pyplot.subplots(2, 2, sharex=True, sharey='row', figsize=(10, 6))
139+ # ax_tl.set_title('Training set')
140+ # ax_tr.set_title('Development set')
141+ # ax_bl.set_xlabel('Iterations')
142+ # ax_br.set_xlabel('Iterations')
143+ # ax_bl.set_ylabel('Accuracy')
144+ # ax_tl.set_ylabel('Loss')
145+ #
146+ # colours = iter(pyplot.rcParams['axes.prop_cycle'].by_key()['color'])
147+ # ax_tl.plot(trainer.train_epoch_costs[::10], color=next(colours))
148+ # ax_bl.plot(trainer.train_results['acc'][::10], color=next(colours))
149+ # ax_tr.plot(trainer.val_costs[::10], color=next(colours))
150+ # ax_br.plot(trainer.val_results['acc'][::10], color=next(colours))
151+
152+ trainer .fit (train_dataset , test_dataset , evaluation_step = 1 , logging_step = 20 ) # Train
153+ test_acc = acc (model (test_circuits ), val_labels ) # Record accuracy
154+ print ('Test accuracy:' , test_acc .item ())
0 commit comments