Skip to content

Commit f2d48ff

Browse files
add probability_sample method for backend
1 parent d597f90 commit f2d48ff

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
- Add `searchsorted` method for backend
2222

23+
- Add `probability_sample` method for backend as an alternative for `random_choice` since it supports `status` as external randomness format
24+
2325
### Changed
2426

2527
- The inner mechanism for `sample_expectation_ps` is changed to sample representation from count representation for a fast speed

tensorcircuit/backends/abstract_backend.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,41 @@ def stateful_randc(
10671067
"Backend '{}' has not implemented `stateful_randc`.".format(self.name)
10681068
)
10691069

1070+
def probability_sample(
1071+
self: Any, shots: int, p: Tensor, status: Optional[Tensor] = None, g: Any = None
1072+
) -> Tensor:
1073+
"""
1074+
Drawn ``shots`` samples from probability distribution p, given the external randomness
1075+
determined by uniform distributed ``status`` tensor or backend random generator ``g``.
1076+
This method is similar with ``stateful_randc``, but it supports ``status`` beyond ``g``,
1077+
which is convenient when jit or vmap
1078+
1079+
:param shots: Number of samples to draw with replacement
1080+
:type shots: int
1081+
:param p: prbability vector
1082+
:type p: Tensor
1083+
:param status: external randomness as a tensor with each element drawn uniformly from [0, 1],
1084+
defaults to None
1085+
:type status: Optional[Tensor], optional
1086+
:param g: backend random genrator, defaults to None
1087+
:type g: Any, optional
1088+
:return: The drawn sample as an int tensor
1089+
:rtype: Tensor
1090+
"""
1091+
if status is not None:
1092+
status = self.convert_to_tensor(status)
1093+
elif g is not None:
1094+
status = self.stateful_randu(g, shape=[shots])
1095+
else:
1096+
status = self.implicit_randu(shape=[shots])
1097+
p = p / self.sum(p)
1098+
p_cuml = self.cumsum(p)
1099+
r = p_cuml[-1] * (1 - self.cast(status, p.dtype))
1100+
ind = self.searchsorted(p_cuml, r)
1101+
a = self.arange(shots)
1102+
res = self.gather1d(a, ind)
1103+
return res
1104+
10701105
def gather1d(self: Any, operand: Tensor, indices: Tensor) -> Tensor:
10711106
"""
10721107
Return ``operand[indices]``, both ``operand`` and ``indices`` are rank-1 tensor.

tests/test_backends.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,16 @@ def test_backend_methods_2(backend):
283283
values = tc.backend.convert_to_tensor(np.array([0.0, 4.1, 12.0], dtype=np.float32))
284284
r = tc.backend.numpy(tc.backend.searchsorted(edges, values))
285285
np.testing.assert_allclose(r, np.array([1, 2, 4]))
286+
p = tc.backend.convert_to_tensor(
287+
np.array(
288+
[0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.2, 0.4], dtype=np.float32
289+
)
290+
)
291+
r = tc.backend.probability_sample(10000, p, status=np.random.uniform(size=[10000]))
292+
_, r = np.unique(r, return_counts=True)
293+
np.testing.assert_allclose(
294+
r - tc.backend.numpy(p) * 10000.0, np.zeros([10]), atol=100, rtol=1
295+
)
286296

287297

288298
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])

0 commit comments

Comments
 (0)