Skip to content

Commit cb26531

Browse files
committed
Update shadows.py and test_shadows.py
1 parent dd554a6 commit cb26531

File tree

2 files changed

+49
-15
lines changed

2 files changed

+49
-15
lines changed

tensorcircuit/shadows.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
def shadow_bound(
1515
observables: Union[Tensor, Sequence[int]], epsilon: float, delta: float = 0.01
1616
) -> Tuple[int, int]:
17-
r"""Calculate the shadow bound of the Pauli observables, please refer to the Theorem S1 and Lemma S3 in Huang, H.-Y., R. Kueng, and J. Preskill, 2020, Nat. Phys. 16, 1050.
17+
r"""Calculate the shadow bound of the Pauli observables, please refer to the Theorem S1 and Lemma S3 in
18+
Huang, H.-Y., R. Kueng, and J. Preskill, 2020, Nat. Phys. 16, 1050.
1819
1920
:param observables: shape = (nq,) or (M, nq), where nq is the number of qubits, M is the number of observables
2021
:type: Union[Tensor, Sequence[int]]
@@ -25,10 +26,11 @@ def shadow_bound(
2526
2627
:return Nk: number of snapshots
2728
:rtype: int
28-
:return k: Number of equal parts to split the shadow snapshot states to compute the median of means. k=1 (default) corresponds to simply taking the mean over all shadow snapshot states.
29+
:return k: Number of equal parts to split the shadow snapshot states to compute the median of means.
30+
k=1 (default) corresponds to simply taking the mean over all shadow snapshot states.
2931
:rtype: int
3032
"""
31-
count = np.sign(backend.numpy(observables))
33+
count = np.sign(np.asarray(observables))
3234
if len(count.shape) == 1:
3335
count = count[None, :]
3436
M = count.shape[0]
@@ -65,7 +67,8 @@ def shadow_snapshots(
6567
ns, nq = pauli_strings.shape
6668
if 2**nq != len(psi):
6769
raise ValueError(
68-
f"The number of qubits of psi and pauli_strings should be the same, but got {nq} and {int(np.log2(len(psi)))}."
70+
f"The number of qubits of psi and pauli_strings should be the same, "
71+
f"but got {nq} and {int(np.log2(len(psi)))}."
6972
)
7073
if status is None:
7174
status = backend.convert_to_tensor(np.random.rand(ns, 1))
@@ -225,9 +228,11 @@ def expection_ps_shadow(
225228
:type: Optional[Sequence[int]]
226229
:param z: sites to apply Z gate, defaults to None
227230
:type: Optional[Sequence[int]]
228-
:param ps: or one can apply a ps structures instead of x, y, z, e.g. [1, 1, 0, 2, 3, 0] for X_0X_1Y_3Z_4 defaults to None, ps can overwrite x, y and z
231+
:param ps: or one can apply a ps structures instead of x, y, z, e.g. [1, 1, 0, 2, 3, 0] for X_0X_1Y_3Z_4
232+
defaults to None, ps can overwrite x, y and z
229233
:type: Optional[Sequence[int]]
230-
:param k: Number of equal parts to split the shadow snapshot states to compute the median of means. k=1 (default) corresponds to simply taking the mean over all shadow snapshot states.
234+
:param k: Number of equal parts to split the shadow snapshot states to compute the median of means.
235+
k=1 (default) corresponds to simply taking the mean over all shadow snapshot states.
231236
:type: int
232237
233238
:return expectation values: shape = (k,)
@@ -328,8 +333,8 @@ def entropy_shadow(
328333

329334

330335
def renyi_entropy_2(snapshots: Tensor, sub: Optional[Sequence[int]] = None) -> Tensor:
331-
r"""To calculate the second order Renyi entropy of a subsystem from snapshot, please refer to Brydges, T. et al. Science 364, 260–263 (2019).
332-
This function is not jitable.
336+
r"""To calculate the second order Renyi entropy of a subsystem from snapshot, please refer to
337+
Brydges, T. et al. Science 364, 260–263 (2019). This function is not jitable.
333338
334339
:param snapshots: shape = (ns, repeat, nq)
335340
:type: Tensor
@@ -438,7 +443,7 @@ def global_shadow_state2(
438443
lss_states = slice_sub(snapshots, sub)
439444
else:
440445
lss_states = snapshots # (ns, repeat, nq, 2, 2)
441-
ns, repeat, nq, _, _ = lss_states.shape
446+
nq = lss_states.shape[2]
442447

443448
old_indices = [f"{ABC[2 * i: 2 + 2 * i]}" for i in range(nq)]
444449
new_indices = f"{ABC[0:2 * nq:2]}{ABC[1:2 * nq:2]}"

tests/test_shadows.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,10 @@
55
from tensorcircuit.shadows import (
66
shadow_bound,
77
shadow_snapshots,
8-
local_snapshot_states,
98
global_shadow_state,
109
entropy_shadow,
1110
renyi_entropy_2,
1211
expection_ps_shadow,
13-
global_shadow_state1,
14-
global_shadow_state2,
15-
slice_sub,
1612
)
1713

1814

@@ -54,8 +50,8 @@ def classical_shadow(psi, pauli_strings, status):
5450
expc, ent = csjit(psi, pauli_strings, status)
5551
expc = np.median(expc)
5652

57-
assert np.abs(expc - exact_expc) < error
58-
assert np.abs(ent - exact_ent) < 5 * error
53+
assert np.isclose(expc, exact_expc, atol=error)
54+
assert np.isclose(ent, exact_ent, atol=5 * error)
5955

6056

6157
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
@@ -77,6 +73,39 @@ def test_state(backend):
7773
np.allclose(sdw_state, bell_state, atol=0.01)
7874

7975

76+
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
77+
def test_ent(backend):
78+
nq, ns, repeat = 6, 2000, 1000
79+
80+
thetas = 2 * np.random.rand(2, nq) - 1
81+
82+
c = tc.Circuit(nq)
83+
for i in range(nq):
84+
c.H(i)
85+
for i in range(2):
86+
for j in range(nq):
87+
c.cnot(j, (j + 1) % nq)
88+
for j in range(nq):
89+
c.rz(j, theta=thetas[i, j] * np.pi)
90+
91+
sub = [1, 4]
92+
psi = c.state()
93+
94+
pauli_strings = tc.backend.convert_to_tensor(np.random.randint(1, 4, size=(ns, nq)))
95+
status = tc.backend.convert_to_tensor(np.random.rand(ns, repeat))
96+
snapshots = shadow_snapshots(psi, pauli_strings, status, measurement_only=True)
97+
98+
exact_rdm = tc.quantum.reduced_density_matrix(
99+
psi, cut=[i for i in range(nq) if i not in sub]
100+
)
101+
exact_ent = tc.quantum.renyi_entropy(exact_rdm, k=2)
102+
ent = entropy_shadow(snapshots, pauli_strings, sub, alpha=2)
103+
ent2 = renyi_entropy_2(snapshots, sub)
104+
105+
assert np.isclose(ent, exact_ent, atol=0.1)
106+
assert np.isclose(ent2, exact_ent, atol=0.1)
107+
108+
80109
# @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
81110
# def test_expc(backend):
82111
# import pennylane as qml

0 commit comments

Comments
 (0)