Skip to content

Commit 908e5e7

Browse files
committed
support pytorch backend (kind of)
1 parent 0565d5e commit 908e5e7

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

tensorcircuit/translation.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,16 @@ def _translate_qiskit_params(
375375
if isinstance(expr, symengine.Expr): # qiskit uses symengine if available
376376
expr = expr._sympy_() # sympy.Expr
377377

378+
if expr.is_algebraic_expr(): # numpy ufuncs are not used
379+
lambdify_module_name = "numpy"
380+
else:
381+
if backend.name == "pytorch":
382+
raise ValueError(
383+
"pytorch backend does not support sympy lambdify with non-algebraic expressions"
384+
)
385+
else:
386+
lambdify_module_name = backend.name
387+
378388
for free_symbol in expr.free_symbols:
379389
# replace names: theta[0] -> theta_0
380390
# ParameterVector creates symbols with brackets like theta[0]
@@ -388,7 +398,7 @@ def _translate_qiskit_params(
388398
sympy.Symbol(str(symbol).replace("[", "_").replace("]", ""))
389399
for symbol in sympy_symbols
390400
]
391-
lam_f = sympy.lambdify(sympy_symbols, expr, modules=backend.name)
401+
lam_f = sympy.lambdify(sympy_symbols, expr, modules=lambdify_module_name)
392402
parameters.append(
393403
lam_f(*[binding_params[param.index] for param in parameter_list])
394404
)

tests/test_circuit.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,7 @@ def test_qiskit2tc():
11491149
np.testing.assert_allclose(qis_unitary2, qis_unitary, atol=1e-5)
11501150

11511151

1152-
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
1152+
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
11531153
def test_qiskit2tc_parameterized(backend):
11541154
try:
11551155
from qiskit.circuit import Parameter, ParameterVector, QuantumCircuit
@@ -1171,11 +1171,12 @@ def test_qiskit2tc_parameterized(backend):
11711171
ansatz4_param = ParameterVector("φ", 3)
11721172
ansatz4.rx(2.0 * ansatz4_param[0] + 5.0, 0)
11731173
ansatz4.ry(ansatz4_param[0] * ansatz4_param[1] + ansatz4_param[2], 0)
1174-
ansatz4.rz(
1175-
np.exp(np.sin(ansatz4_param[0]))
1176-
+ np.abs(ansatz4_param[1]) / np.arctan(ansatz4_param[2]),
1177-
0,
1178-
)
1174+
if tc.backend.name != "pytorch": # pytorch backend with ufuncs is not supported
1175+
ansatz4.rz(
1176+
np.exp(np.sin(ansatz4_param[0]))
1177+
+ np.abs(ansatz4_param[1]) / np.arctan(ansatz4_param[2]),
1178+
0,
1179+
)
11791180
ansatz_list = [ansatz1, ansatz2, ansatz3, ansatz4]
11801181
for ansatz in ansatz_list:
11811182
n = ansatz.num_qubits
@@ -1196,7 +1197,7 @@ def get_unitary(params):
11961197

11971198
tc_unitary = get_unitary(params)
11981199
tc_unitary = np.reshape(tc_unitary, [2**n, 2**n])
1199-
p_mat = perm_matrix(n)
1200+
p_mat = tc.array_to_tensor(perm_matrix(n))
12001201
np.testing.assert_allclose(
12011202
p_mat @ tc_unitary @ p_mat, qiskit_unitary, atol=1e-5
12021203
)
@@ -1207,7 +1208,7 @@ def cost_fn(params):
12071208

12081209
if ansatz in [ansatz1, ansatz2, ansatz4]:
12091210
grad = tc.backend.grad(cost_fn)(tc.backend.convert_to_tensor(params))
1210-
assert np.sum(np.isnan(grad)) == 0
1211+
assert tc.backend.sum(tc.num_to_tensor(np.isnan(grad))) == 0
12111212
else:
12121213
# tf only supports tf tensor as input
12131214
grad = tc.backend.grad(cost_fn)(

0 commit comments

Comments
 (0)