
Commit 69d62c7

twaclaw authored and Ch0ronomato committed

Started implementation of random variables with PyTorch backend.

1 parent 5d4e9e0 · commit 69d62c7

File tree: 5 files changed (+97 −2 lines)

pytensor/link/pytorch/dispatch/__init__.py (+1)

@@ -12,4 +12,5 @@
 import pytensor.link.pytorch.dispatch.sort
 import pytensor.link.pytorch.dispatch.subtensor
 import pytensor.link.pytorch.dispatch.blockwise
+import pytensor.link.pytorch.dispatch.random
 # isort: on
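
The only change here is a new import, because registration into the `pytorch_funcify`/`pytorch_typify` singledispatch registries happens as an import side effect. A toy sketch of that pattern (the names below are illustrative, not PyTensor's):

```python
# Toy illustration of registration-by-import: merely importing the module
# that contains the decorated function fills the dispatch registry.
from functools import singledispatch


@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No implementation for {type(op)}")


class BernoulliOp:
    """Stand-in for an Op class; purely illustrative."""


# In the real backend this registration lives in the dispatch module added
# below; importing that module in __init__.py is what makes it take effect.
@funcify.register(BernoulliOp)
def _(op, **kwargs):
    return lambda p: f"bernoulli sample with p={p}"


print(funcify(BernoulliOp())(0.5))  # -> "bernoulli sample with p=0.5"
```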

pytensor/link/pytorch/dispatch/basic.py (+5 −2)

@@ -26,8 +26,11 @@


 @singledispatch
-def pytorch_typify(data, **kwargs):
-    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
+def pytorch_typify(data, dtype=None, **kwargs):
+    if dtype is None:
+        return data
+    else:
+        return torch.tensor(data, dtype=dtype)


 @pytorch_typify.register(np.ndarray)
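
A minimal standalone sketch of the new default behavior, mirroring the diff above outside PyTensor (the real `pytorch_typify` also has registrations, such as the `np.ndarray` one visible in the context line, that take precedence over this fallback):

```python
from functools import singledispatch

import numpy as np
import torch


@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
    # New fallback: pass the object through untouched unless an explicit
    # dtype asks for conversion to a torch tensor.
    if dtype is None:
        return data
    return torch.tensor(data, dtype=dtype)


print(pytorch_typify({"rng_state": 123}))                 # dict returned as-is
print(pytorch_typify(np.arange(3), dtype=torch.float32))  # tensor([0., 1., 2.])
```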
pytensor/link/pytorch/dispatch/random.py (new file, +62)

@@ -0,0 +1,62 @@
+from functools import singledispatch
+
+import torch
+from numpy.random import Generator
+
+import pytensor.tensor.random.basic as ptr
+from pytensor.graph import Constant
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
+from pytensor.tensor.type_other import NoneTypeT
+
+
+@pytorch_typify.register(Generator)
+def pytorch_typify_Generator(rng, **kwargs):
+    state = rng.__getstate__()
+    state["pytorch_state"] = torch.manual_seed(123).get_state()  # XXX: replace
+    return state
+
+
+@pytorch_funcify.register(ptr.RandomVariable)
+def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
+    rv = node.outputs[1]
+    out_dtype = rv.type.dtype
+    static_shape = rv.type.shape
+    batch_ndim = op.batch_ndim(node)
+
+    # Try to pass static size directly to PyTorch
+    static_size = static_shape[:batch_ndim]
+    if None in static_size:
+        # Sometimes size can be constant folded during rewrites,
+        # without the RandomVariable node being updated with new static types
+        size_param = op.size_param(node)
+        if isinstance(size_param, Constant) and not isinstance(
+            size_param.type, NoneTypeT
+        ):
+            static_size = tuple(size_param.data)
+
+    def sample_fn(rng, size, *parameters):
+        return pytorch_sample_fn(op, node=node)(
+            rng, static_size, out_dtype, *parameters
+        )
+
+    return sample_fn
+
+
+@singledispatch
+def pytorch_sample_fn(op, node):
+    name = op.name
+    raise NotImplementedError(
+        f"No PyTorch implementation for the given distribution: {name}"
+    )
+
+
+@pytorch_sample_fn.register(ptr.BernoulliRV)
+def pytorch_sample_fn_bernoulli(op, node):
+    def sample_fn(rng, size, dtype, p):
+        # XXX replace
+        state_ = rng["pytorch_state"]
+        gen = torch.Generator().set_state(state_)
+        sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
+        return (rng, sample)
+
+    return sample_fn
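
Other distributions would plug into the same `pytorch_sample_fn` hook. The following is a hypothetical sketch, not part of this commit, assuming the new module is importable as `pytensor.link.pytorch.dispatch.random` (the name added to `__init__.py` above), that the distribution parameters arrive as torch tensors, and reusing the provisional generator-state handling from the Bernoulli sampler:

```python
import torch

import pytensor.tensor.random.basic as ptr
from pytensor.link.pytorch.dispatch.random import pytorch_sample_fn


@pytorch_sample_fn.register(ptr.NormalRV)
def pytorch_sample_fn_normal(op, node):
    def sample_fn(rng, size, dtype, loc, scale):
        # Same provisional state handling as the Bernoulli sampler (XXX: replace)
        gen = torch.Generator().set_state(rng["pytorch_state"])
        sample = torch.normal(
            torch.expand_copy(loc, size),
            torch.expand_copy(scale, size),
            generator=gen,
        )
        return (rng, sample)

    return sample_fn
```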

pytensor/link/pytorch/linker.py (+9)

@@ -1,3 +1,5 @@
+from numpy.random import Generator, RandomState
+
 from pytensor.link.basic import JITLinker
 from pytensor.link.utils import unique_name_generator

@@ -83,9 +85,16 @@ def __del__(self):
         return inner_fn

     def create_thunk_inputs(self, storage_map):
+        from pytensor.link.pytorch.dispatch import pytorch_typify
+
         thunk_inputs = []
         for n in self.fgraph.inputs:
             sinput = storage_map[n]
+            if isinstance(sinput[0], RandomState | Generator):
+                new_value = pytorch_typify(
+                    sinput[0], dtype=getattr(sinput[0], "dtype", None)
+                )
+                sinput[0] = new_value
             thunk_inputs.append(sinput)

         return thunk_inputs
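
A standalone sketch of what the new branch in `create_thunk_inputs` does: RNG storage cells are converted in place before the compiled function runs, while other inputs pass through untouched. The conversion function below is a stand-in for `pytorch_typify`, and `isinstance(..., RandomState | Generator)` relies on PEP 604 unions, i.e. Python 3.10+:

```python
import numpy as np
from numpy.random import Generator, RandomState


def convert_rng_cells(storage_cells, typify):
    # Each "cell" is a one-element list, mimicking PyTensor's storage_map values.
    thunk_inputs = []
    for cell in storage_cells:
        if isinstance(cell[0], RandomState | Generator):
            cell[0] = typify(cell[0], dtype=getattr(cell[0], "dtype", None))
        thunk_inputs.append(cell)
    return thunk_inputs


cells = [[np.random.default_rng(0)], [np.array([1.0, 2.0])]]
out = convert_rng_cells(
    cells, lambda rng, dtype=None: {"numpy_state": rng.bit_generator.state}
)
print(type(out[0][0]))  # <class 'dict'>: the Generator was replaced by its state
print(out[1][0])        # [1. 2.]: non-RNG inputs are untouched
```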

tests/link/pytorch/test_random.py (new file, +20)

@@ -0,0 +1,20 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.compile.function import function
+from pytensor.compile.sharedvalue import shared
+from tests.link.pytorch.test_basic import pytorch_mode
+
+
+torch = pytest.importorskip("torch")
+
+
+@pytest.mark.parametrize("size", [(), (4,)])
+def test_random_bernoulli(size):
+    rng = shared(np.random.default_rng(123))
+
+    g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng)
+    g_fn = function([], g, mode=pytorch_mode)
+    samples = g_fn()
+    np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
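
For reference, a rough pure-PyTorch sketch of what the dispatched sampler computes for the `size=()` case, independent of PyTensor (the seeding here is illustrative, not the backend's final scheme; the `1` in `assert_allclose` above is a loose `rtol`):

```python
import torch

gen = torch.Generator().manual_seed(123)
p = torch.tensor(0.5)
samples = torch.bernoulli(torch.expand_copy(p, (1000,)), generator=gen)
print(samples.shape)   # torch.Size([1000])
print(samples.mean())  # close to 0.5
```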
