1
1
from functools import singledispatch
2
2
3
- import numpy as np
3
+ import numpy . random
4
4
import torch
5
- from numpy .random import Generator
6
5
7
6
import pytensor .tensor .random .basic as ptr
8
7
from pytensor .link .pytorch .dispatch .basic import pytorch_funcify , pytorch_typify
9
8
10
9
11
- @pytorch_typify .register (Generator )
10
+ @pytorch_typify .register (numpy . random . Generator )
12
11
def pytorch_typify_Generator (rng , ** kwargs ):
13
12
# XXX: Check if there is a better way.
14
13
# Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
15
- state = rng .__getstate__ ()
16
- rng_copy = np .random .default_rng ()
17
- rng_copy .bit_generator .state = rng .bit_generator .state
18
- seed = torch .from_numpy (rng_copy .integers ([2 ** 32 ]))
19
- state ["pytorch_gen" ] = torch .manual_seed (seed )
20
- return state
14
+ seed = torch .from_numpy (rng .integers ([2 ** 32 ]))
15
+ return torch .manual_seed (seed )
16
+
17
+
18
+ @pytorch_typify .register (torch ._C .Generator )
19
+ def pytorch_typify_pass_generator (rng , ** kwargs ):
20
+ return rng
21
21
22
22
23
23
@pytorch_funcify .register (ptr .RandomVariable )
@@ -28,7 +28,9 @@ def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
28
28
rv_sample = pytorch_sample_fn (op , node = node )
29
29
30
30
def sample_fn (rng , size , * args ):
31
- return rv_sample (rng , shape , out_dtype , * args )
31
+ new_rng = torch .Generator (device = "cpu" )
32
+ new_rng .set_state (rng .get_state ().clone ())
33
+ return rv_sample (new_rng , shape , out_dtype , * args )
32
34
33
35
return sample_fn
34
36
@@ -43,25 +45,31 @@ def pytorch_sample_fn(op, node):
43
45
44
46
@pytorch_sample_fn .register (ptr .BernoulliRV )
45
47
def pytorch_sample_fn_bernoulli (op , node ):
46
- def sample_fn (rng , size , dtype , p ):
47
- gen = rng ["pytorch_gen" ]
48
+ def sample_fn (gen , size , dtype , p ):
48
49
sample = torch .bernoulli (torch .broadcast_to (p , size ), generator = gen )
49
- rng ["pytorch_gen" ] = gen
50
- return (rng , sample )
50
+ return (gen , sample )
51
51
52
52
return sample_fn
53
53
54
54
55
55
@pytorch_sample_fn .register (ptr .BinomialRV )
56
56
def pytorch_sample_fn_binomial (op , node ):
57
- def sample_fn (rng , size , dtype , n , p ):
58
- gen = rng ["pytorch_gen" ]
57
+ def sample_fn (gen , size , dtype , n , p ):
59
58
sample = torch .binomial (
60
- torch .broadcast_to (n , size ),
61
- torch .broadcast_to (p , size ),
59
+ torch .broadcast_to (n , size ). to ( torch . float32 ) ,
60
+ torch .broadcast_to (p , size ). to ( torch . float32 ) ,
62
61
generator = gen ,
63
62
)
64
- rng ["pytorch_gen" ] = gen
65
- return (rng , sample )
63
+ return (gen , sample )
64
+
65
+ return sample_fn
66
+
67
+
68
+ @pytorch_sample_fn .register (ptr .UniformRV )
69
+ def pytorch_sample_fn_uniform (op , node ):
70
+ def sample_fn (gen , size , dtype , low , high ):
71
+ sample = torch .FloatTensor (size )
72
+ sample .uniform_ (low .item (), high .item (), generator = gen )
73
+ return (gen , sample )
66
74
67
75
return sample_fn
0 commit comments