Skip to content

Commit 0565d5e

Browse files
committed
allow various parameterexpressions
1 parent b042a74 commit 0565d5e

File tree

4 files changed

+77
-66
lines changed

4 files changed

+77
-66
lines changed

requirements/requirements-extra.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# extra dependencies for ci
22
qiskit<1.0
3-
qiskit_aer<1.0
4-
# qiskit-nature
3+
qiskit-aer<1.0
4+
qiskit-nature
55
mitiq
66
cirq
77
torch

tensorcircuit/abstractcircuit.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44

55
# pylint: disable=invalid-name
66

7-
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
7+
import json
8+
import logging
89
from copy import deepcopy
910
from functools import reduce
1011
from operator import add
11-
import json
12-
import logging
12+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
1313

1414
import numpy as np
1515
import tensornetwork as tn
1616

1717
from . import gates
1818
from .cons import backend, dtypestr
19-
from .vis import qir2tex
2019
from .quantum import QuOperator
2120
from .utils import is_sequence
21+
from .vis import qir2tex
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -756,7 +756,7 @@ def to_qiskit(
756756
:type enable_inputs: bool, defaults to False
757757
:return: A qiskit object of this circuit.
758758
"""
759-
from .translation import qir2qiskit, perm_matrix
759+
from .translation import perm_matrix, qir2qiskit
760760

761761
qir = self.to_qir()
762762
if enable_instruction:
@@ -887,7 +887,7 @@ def from_qiskit(
887887
n = qc.num_qubits
888888

889889
return qiskit2tc( # type: ignore
890-
qc.data,
890+
qc,
891891
n,
892892
inputs,
893893
circuit_constructor=cls,

tensorcircuit/translation.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,26 @@
22
Circuit object translation in different packages
33
"""
44

5-
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
6-
from copy import deepcopy
75
import logging
6+
from copy import deepcopy
7+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
8+
89
import numpy as np
910

1011
logger = logging.getLogger(__name__)
1112

1213

1314
try:
15+
import qiskit.quantum_info as qi
16+
import symengine
17+
import sympy
1418
from qiskit import QuantumCircuit
19+
from qiskit.circuit import Parameter, ParameterExpression
1520
from qiskit.circuit.library import XXPlusYYGate
21+
from qiskit.circuit.parametervector import ParameterVectorElement
22+
from qiskit.circuit.quantumcircuitdata import CircuitInstruction
1623
from qiskit.extensions import UnitaryGate
17-
import qiskit.quantum_info as qi
1824
from qiskit.extensions.exceptions import ExtensionError
19-
from qiskit.circuit.quantumcircuitdata import CircuitInstruction
20-
from qiskit.circuit.parametervector import ParameterVectorElement
21-
from qiskit.circuit import Parameter, ParameterExpression
2225
except ImportError:
2326
logger.warning(
2427
"Please first ``pip install -U qiskit`` to enable related functionality in translation module"
@@ -34,11 +37,10 @@
3437

3538
from . import gates
3639
from .circuit import Circuit
37-
from .densitymatrix import DMCircuit2
3840
from .cons import backend
41+
from .densitymatrix import DMCircuit2
3942
from .interfaces.tensortrans import tensor_to_numpy
4043

41-
4244
Tensor = Any
4345

4446

@@ -358,7 +360,7 @@ def _translate_qiskit_params(
358360
gate_info: CircuitInstruction, binding_params: Any
359361
) -> List[float]:
360362
parameters = []
361-
for p in gate_info[0].params:
363+
for p in gate_info.operation.params:
362364
if isinstance(p, ParameterVectorElement):
363365
parameters.append(binding_params[p.index])
364366
elif isinstance(p, Parameter):
@@ -367,30 +369,29 @@ def _translate_qiskit_params(
367369
if len(p.parameters) == 0:
368370
parameters.append(float(p))
369371
continue
370-
if len(p.parameters) != 1:
371-
raise ValueError(
372-
f"Can't translate parameter expression with more than 1 parameters: {p}"
373-
)
374-
p_real = list(p.parameters)[0]
375-
if not isinstance(p_real, ParameterVectorElement):
376-
raise TypeError(
377-
"Parameters in parameter expression should be ParameterVectorElement"
378-
)
372+
379373
# note "sym" != "sim"
380374
expr = p.sympify().simplify()
381-
# only allow simple expressions like 1.0 * theta
382-
if not expr.is_Mul:
383-
raise ValueError(f"Unsupported parameter expression: {p}")
384-
arg1, arg2 = expr.args
385-
if arg1.is_number and arg2.is_symbol:
386-
coeff = arg1
387-
elif arg1.is_symbol and arg2.is_number:
388-
coeff = arg2
389-
else:
390-
raise ValueError(f"Unsupported parameter expression: {p}")
391-
# taking real part here because using complex type will result in a type error
392-
# for tf backend when the binding parameter is real
393-
parameters.append(float(coeff) * binding_params[p_real.index])
375+
if isinstance(expr, symengine.Expr): # qiskit uses symengine if available
376+
expr = expr._sympy_() # sympy.Expr
377+
378+
for free_symbol in expr.free_symbols:
379+
# replace names: theta[0] -> theta_0
380+
# ParameterVector creates symbols with brackets like theta[0]
381+
# but sympy.lambdify does not allow brackets in symbol names
382+
free_symbol.name = free_symbol.name.replace("[", "_").replace("]", "")
383+
384+
parameter_list = list(p.parameters)
385+
sympy_symbols = [param._symbol_expr for param in parameter_list]
386+
# replace names again: theta[0] -> theta_0
387+
sympy_symbols = [
388+
sympy.Symbol(str(symbol).replace("[", "_").replace("]", ""))
389+
for symbol in sympy_symbols
390+
]
391+
lam_f = sympy.lambdify(sympy_symbols, expr, modules=backend.name)
392+
parameters.append(
393+
lam_f(*[binding_params[param.index] for param in parameter_list])
394+
)
394395
else:
395396
# numbers, arrays, etc.
396397
parameters.append(p)
@@ -403,7 +404,7 @@ def ctrl_str2ctrl_state(ctrl_str: str, nctrl: int) -> List[int]:
403404

404405

405406
def qiskit2tc(
406-
qcdata: List[CircuitInstruction],
407+
qc: QuantumCircuit,
407408
n: int,
408409
inputs: Optional[List[float]] = None,
409410
is_dm: bool = False,
@@ -412,19 +413,18 @@ def qiskit2tc(
412413
binding_params: Optional[Union[Sequence[float], Dict[Any, float]]] = None,
413414
) -> Any:
414415
r"""
415-
Generate a tensorcircuit circuit using the quantum circuit data in qiskit.
416+
Generate a tensorcircuit circuit from the qiskit circuit.
416417
417418
:Example:
418419
419420
>>> qisc = QuantumCircuit(2)
420421
>>> qisc.h(0)
421422
>>> qisc.x(1)
422-
>>> qc = tc.translation.qiskit2tc(qisc.data, 2)
423+
>>> qc = tc.translation.qiskit2tc(qisc, 2)
423424
>>> qc.to_qir()[0]['gatef']
424-
h
425425
426-
:param qcdata: Quantum circuit data from qiskit.
427-
:type qcdata: List[CircuitInstruction]
426+
:param qc: A quantum circuit in qiskit.
427+
:type qc: QuantumCircuit
428428
:param n: # of qubits
429429
:type n: int
430430
:param inputs: Input state of the circuit. Default is None.
@@ -435,7 +435,7 @@ def qiskit2tc(
435435
:type circuit_params: Optional[Dict[str, Any]]
436436
:param binding_params: (variational) parameters for the circuit.
437437
Could be either a sequence or dictionary depending on the type of parameters in the Qiskit circuit.
438-
For ``ParameterVectorElement`` use sequence. For ``Parameter`` use dictionary
438+
For ``ParameterVectorElement`` use sequence. For ``Parameter`` use dictionary.
439439
:type binding_params: Optional[Union[Sequence[float], Dict[Any, float]]]
440440
:return: A quantum circuit in tensorcircuit
441441
:rtype: Any
@@ -451,17 +451,17 @@ def qiskit2tc(
451451
if "nqubits" not in circuit_params:
452452
circuit_params["nqubits"] = n
453453
if (
454-
len(qcdata) > 0
455-
and qcdata[0][0].name == "initialize"
454+
len(qc.data) > 0
455+
and qc.data[0][0].name == "initialize"
456456
and "inputs" not in circuit_params
457457
):
458-
circuit_params["inputs"] = perm_matrix(n) @ np.array(qcdata[0][0].params)
458+
circuit_params["inputs"] = perm_matrix(n) @ np.array(qc.data[0][0].params)
459459
if inputs is not None:
460460
circuit_params["inputs"] = inputs
461461

462462
tc_circuit: Any = Circ(**circuit_params)
463-
for gate_info in qcdata:
464-
idx = [qb.index for qb in gate_info[1]]
463+
for gate_info in qc.data:
464+
idx = [qc.find_bit(qb).index for qb in gate_info.qubits]
465465
gate_name = gate_info[0].name
466466
parameters = _translate_qiskit_params(gate_info, binding_params)
467467
if gate_name in [

tests/test_circuit.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# pylint: disable=invalid-name
22

3-
import sys
43
import os
4+
import sys
55
from functools import partial
6+
67
import numpy as np
78
import opt_einsum as oem
89
import pytest
@@ -975,6 +976,7 @@ def test_qir2cirq(backend):
975976
def test_qir2qiskit(backend):
976977
try:
977978
import qiskit.quantum_info as qi
979+
978980
from tensorcircuit.translation import perm_matrix
979981
except ImportError:
980982
pytest.skip("qiskit is not installed")
@@ -1074,9 +1076,10 @@ def test_qir2qiskit(backend):
10741076

10751077
def test_qiskit2tc():
10761078
try:
1077-
from qiskit import QuantumCircuit
10781079
import qiskit.quantum_info as qi
1080+
from qiskit import QuantumCircuit
10791081
from qiskit.circuit.library.standard_gates import MCXGate, SwapGate
1082+
10801083
from tensorcircuit.translation import perm_matrix
10811084
except ImportError:
10821085
pytest.skip("qiskit is not installed")
@@ -1149,26 +1152,34 @@ def test_qiskit2tc():
11491152
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
11501153
def test_qiskit2tc_parameterized(backend):
11511154
try:
1152-
from qiskit.circuit import QuantumCircuit, Parameter
1153-
from qiskit.quantum_info import Operator
1155+
from qiskit.circuit import Parameter, ParameterVector, QuantumCircuit
11541156
from qiskit.circuit.library import TwoLocal
1157+
from qiskit.quantum_info import Operator
11551158
from qiskit_nature.second_q.circuit.library import UCCSD
1156-
from qiskit_nature.second_q.mappers import ParityMapper, QubitConverter
1159+
from qiskit_nature.second_q.mappers import ParityMapper
11571160
except ImportError:
11581161
pytest.skip("qiskit or qiskit-nature is not installed")
11591162
from tensorcircuit.translation import perm_matrix
11601163

11611164
mapper = ParityMapper()
1162-
converter = QubitConverter(mapper=mapper, two_qubit_reduction=True)
1163-
ansatz1 = UCCSD(2, [1, 1], converter)
1165+
ansatz1 = UCCSD(2, [1, 1], mapper)
11641166
ansatz2 = TwoLocal(2, rotation_blocks="ry", entanglement_blocks="cz")
11651167
ansatz3 = QuantumCircuit(1)
11661168
ansatz3_param = Parameter("θ")
11671169
ansatz3.rx(ansatz3_param, 0)
1168-
ansatz_list = [ansatz1, ansatz2, ansatz3]
1170+
ansatz4 = QuantumCircuit(1)
1171+
ansatz4_param = ParameterVector("φ", 3)
1172+
ansatz4.rx(2.0 * ansatz4_param[0] + 5.0, 0)
1173+
ansatz4.ry(ansatz4_param[0] * ansatz4_param[1] + ansatz4_param[2], 0)
1174+
ansatz4.rz(
1175+
np.exp(np.sin(ansatz4_param[0]))
1176+
+ np.abs(ansatz4_param[1]) / np.arctan(ansatz4_param[2]),
1177+
0,
1178+
)
1179+
ansatz_list = [ansatz1, ansatz2, ansatz3, ansatz4]
11691180
for ansatz in ansatz_list:
11701181
n = ansatz.num_qubits
1171-
if ansatz in [ansatz1, ansatz2]:
1182+
if ansatz in [ansatz1, ansatz2, ansatz4]:
11721183
params = np.random.rand(ansatz.num_parameters)
11731184
else:
11741185
params = {ansatz3_param: 0.618}
@@ -1178,9 +1189,9 @@ def test_qiskit2tc_parameterized(backend):
11781189

11791190
# test jit
11801191
@tc.backend.jit
1181-
def get_unitary(_params):
1192+
def get_unitary(params):
11821193
return tc.Circuit.from_qiskit(
1183-
ansatz, inputs=np.eye(2**n), binding_params=_params
1194+
ansatz, inputs=np.eye(2**n), binding_params=params
11841195
).state()
11851196

11861197
tc_unitary = get_unitary(params)
@@ -1191,10 +1202,10 @@ def get_unitary(_params):
11911202
)
11921203

11931204
# test grad
1194-
def cost_fn(_params):
1195-
return tc.backend.real(tc.backend.sum(get_unitary(_params)))
1205+
def cost_fn(params):
1206+
return tc.backend.real(tc.backend.sum(get_unitary(params)))
11961207

1197-
if ansatz in [ansatz1, ansatz2]:
1208+
if ansatz in [ansatz1, ansatz2, ansatz4]:
11981209
grad = tc.backend.grad(cost_fn)(tc.backend.convert_to_tensor(params))
11991210
assert np.sum(np.isnan(grad)) == 0
12001211
else:
@@ -1208,8 +1219,8 @@ def cost_fn(_params):
12081219
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
12091220
def test_qiskit_vs_tc_intialization(backend):
12101221
try:
1211-
from qiskit import QuantumCircuit
12121222
import qiskit.quantum_info as qi
1223+
from qiskit import QuantumCircuit
12131224
except ImportError:
12141225
pytest.skip("qiskit is not installed")
12151226

0 commit comments

Comments
 (0)