+from functools import singledispatch
+
+import torch
+from numpy.random import Generator
+
+import pytensor.tensor.random.basic as ptr
+from pytensor.graph import Constant
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
+from pytensor.tensor.type_other import NoneTypeT
+
+
+@pytorch_typify.register(Generator)
+def pytorch_typify_Generator(rng, **kwargs):
+    # Carry the NumPy Generator's state dict through, augmented with a torch RNG
+    # state, since NumPy and torch generator states are not interchangeable.
+    state = rng.__getstate__()
+    state["pytorch_state"] = torch.manual_seed(123).get_state()  # XXX: placeholder seed, replace
+    return state
+
+
+@pytorch_funcify.register(ptr.RandomVariable)
+def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
+    rv = node.outputs[1]
+    out_dtype = rv.type.dtype
+    static_shape = rv.type.shape
+    batch_ndim = op.batch_ndim(node)
+
+    # Try to pass the static size directly to PyTorch
+    static_size = static_shape[:batch_ndim]
+    if None in static_size:
+        # Sometimes size can be constant folded during rewrites,
+        # without the RandomVariable node being updated with new static types
+        size_param = op.size_param(node)
+        if isinstance(size_param, Constant) and not isinstance(
+            size_param.type, NoneTypeT
+        ):
+            static_size = tuple(size_param.data)
+
+    def sample_fn(rng, size, *parameters):
+        # The runtime `size` argument is ignored; the statically determined
+        # size is baked into the closure instead.
+        return pytorch_sample_fn(op, node=node)(
+            rng, static_size, out_dtype, *parameters
+        )
+
+    return sample_fn
+
+
+@singledispatch
+def pytorch_sample_fn(op, node):
+    name = op.name
+    raise NotImplementedError(
+        f"No PyTorch implementation for the given distribution: {name}"
+    )
+
+
+@pytorch_sample_fn.register(ptr.BernoulliRV)
+def pytorch_sample_fn_bernoulli(op, node):
+    def sample_fn(rng, size, dtype, p):
+        # XXX: replace; rebuilds the torch generator from the stored state on each call
+        state_ = rng["pytorch_state"]
+        gen = torch.Generator().set_state(state_)
+        # Broadcast p to the requested size before drawing the samples
+        sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
+        return (rng, sample)
+
+    return sample_fn
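
For context, a minimal usage sketch (not part of this diff): it assumes the PyTorch linker is exposed through a `"PYTORCH"` compile mode and that `ptr.bernoulli` accepts an explicit shared `rng`; the mode string and variable names are illustrative.

```python
import numpy as np

import pytensor
import pytensor.tensor.random.basic as ptr

# Hypothetical sketch: draw Bernoulli samples through the PyTorch backend
rng = pytensor.shared(np.random.default_rng(42))
x = ptr.bernoulli(0.5, size=(3,), rng=rng)
fn = pytensor.function([], x, mode="PYTORCH")  # assumes a registered "PYTORCH" mode
print(fn())
```

Other distributions would presumably hook into the same `pytorch_sample_fn` dispatcher. A hypothetical sketch for `ptr.NormalRV`, mirroring the Bernoulli case and relying on `torch.normal(mean, std, generator=...)` (again an assumption, not part of this diff):

```python
@pytorch_sample_fn.register(ptr.NormalRV)
def pytorch_sample_fn_normal(op, node):
    def sample_fn(rng, size, dtype, loc, scale):
        # Rebuild the torch generator from the stored state, as in the Bernoulli case
        gen = torch.Generator().set_state(rng["pytorch_state"])
        # Broadcast both parameters to the requested size before sampling
        sample = torch.normal(
            torch.expand_copy(loc, size), torch.expand_copy(scale, size), generator=gen
        )
        return (rng, sample)

    return sample_fn
```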