4
4
from numpy .random import Generator
5
5
6
6
import pytensor .tensor .random .basic as ptr
7
- from pytensor .graph import Constant
8
7
from pytensor .link .pytorch .dispatch .basic import pytorch_funcify , pytorch_typify
9
- from pytensor .tensor .type_other import NoneTypeT
10
8
11
9
12
10
@pytorch_typify.register(Generator)
def pytorch_typify_Generator(rng, **kwargs):
    """Typify a NumPy ``Generator`` for the PyTorch backend.

    Returns the NumPy generator's state dict augmented with a
    ``"pytorch_gen"`` entry holding a ``torch.Generator`` seeded from a
    draw of ``rng``, which downstream sample functions use as the
    ``generator=`` argument of torch sampling ops.
    """
    # XXX: Check if there is a better way.
    # NumPy uses PCG64 while Torch uses Mersenne-Twister
    # (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp),
    # so the bit-generator state cannot be translated directly; instead we
    # seed a fresh torch generator from a draw of the NumPy rng.
    state = rng.__getstate__()

    # Draw a scalar seed. The previous code used ``rng.integers([2**32])``,
    # which builds a shape-(1,) array and relies on implicit tensor->int
    # coercion inside ``manual_seed``.
    seed = int(rng.integers(2**32))

    # Use a dedicated Generator rather than ``torch.manual_seed``: the latter
    # re-seeds the *global* default generator as a side effect, which would
    # leak across unrelated torch code.
    gen = torch.Generator()
    gen.manual_seed(seed)
    state["pytorch_gen"] = gen

    return state
@pytorch_funcify.register(ptr.RandomVariable)
def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
    """Build the PyTorch sampling closure for a ``RandomVariable`` node.

    The returned ``sample_fn`` forwards the typified rng, the RV's static
    shape and dtype, and the distribution parameters to the sampler selected
    by ``pytorch_sample_fn`` for the concrete RV type.
    """
    rv_out = node.outputs[1]
    rv_dtype = rv_out.type.dtype
    rv_shape = rv_out.type.shape

    def sample_fn(rng, size, *parameters):
        # Dispatch on the concrete RV type; static shape/dtype are closed over.
        sampler = pytorch_sample_fn(op, node=node)
        return sampler(rng, rv_shape, rv_dtype, *parameters)

    return sample_fn
@@ -53,10 +40,22 @@ def pytorch_sample_fn(op, node):
53
40
@pytorch_sample_fn.register(ptr.BernoulliRV)
def pytorch_sample_fn_bernoulli(op, node):
    """Sampler for Bernoulli RVs on the PyTorch backend."""

    def sample_fn(rng, size, dtype, p):
        # The typified rng is a state dict carrying the torch generator.
        gen = rng["pytorch_gen"]
        probs = torch.broadcast_to(p, size)
        draw = torch.bernoulli(probs, generator=gen)
        return (gen, draw)

    return sample_fn
+
49
+
50
@pytorch_sample_fn.register(ptr.BinomialRV)
def pytorch_sample_fn_binomial(op, node):
    """Sampler for Binomial RVs on the PyTorch backend."""

    def sample_fn(rng, size, dtype, n, p):
        gen = rng["pytorch_gen"]
        # torch.binomial wants both arguments as floating tensors of the same
        # shape, so cast the counts to p's dtype before broadcasting.
        counts = torch.broadcast_to(n.to(p.dtype), size)
        probs = torch.broadcast_to(p, size)
        draw = torch.binomial(counts, probs, generator=gen)
        return (gen, draw)

    return sample_fn
0 commit comments