Implementation of random variables with PyTorch backend #1075
base: main
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@           Coverage Diff            @@
##             main    #1075    +/-  ##
=========================================
  Coverage   82.10%   82.11%
=========================================
  Files         185      186      +1
  Lines       48089    48184     +95
  Branches     8659     8673     +14
=========================================
+ Hits        39485    39564     +79
- Misses       6439     6452     +13
- Partials     2165     2168      +3
static_shape = rv.type.shape
batch_ndim = op.batch_ndim(node)

# Try to pass static size directly to JAX
nit: pytorch
# XXX replace
state_ = rng["pytorch_state"]
gen = torch.Generator().set_state(state_)
sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
I actually don't mind this approach! Torch has a lot of wrapping and abstraction on top of its random generation, so if we just keep a little bit of state around it feels a bit simpler.
thunk_inputs = []
for n in self.fgraph.inputs:
    sinput = storage_map[n]
    if isinstance(sinput[0], RandomState | Generator):
        new_value = pytorch_typify(
            sinput[0], dtype=getattr(sinput[0], "dtype", None)
Why is this needed?
static_shape = rv.type.shape
batch_ndim = op.batch_ndim(node)

# Try to pass static size directly to JAX
This static size is a JAX limitation that shouldn't exist in PyTorch
# XXX replace
state_ = rng["pytorch_state"]
gen = torch.Generator().set_state(state_)
sample = torch.bernoulli(torch.expand_copy(p, size), generator=gen)
Shouldn't it just broadcast? Why copy?
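A hedged sketch of what this could look like (bernoulli_sample is an illustrative helper, not code from the PR): torch.broadcast_to returns a view, so no copy of p is materialised before sampling.

import torch

def bernoulli_sample(p, size, gen):
    # Use a broadcast view instead of copying `p`; torch.bernoulli reads the
    # probabilities element-wise, so a view is enough.
    p_b = torch.broadcast_to(p, size) if size is not None else p
    return torch.bernoulli(p_b, generator=gen)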
85d6080 to 1c8dc80 (Compare)
def pytorch_typify(data, dtype=None, **kwargs):
    if dtype is None:
        return data
    else:
        return torch.tensor(data, dtype=dtype)
We should change this approach. You need to dispatch on the RNG type and decide what to do with it. The base case is to raise.
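A hedged sketch of the dispatch-based approach, assuming a functools.singledispatch-style pytorch_typify (the overload names are illustrative, not the actual PyTensor API):

from functools import singledispatch

import numpy as np
import torch


@singledispatch
def pytorch_typify(data, **kwargs):
    # Base case: refuse types we don't know how to convert to torch.
    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")


@pytorch_typify.register(np.ndarray)
def pytorch_typify_ndarray(data, dtype=None, **kwargs):
    # dtype is assumed to be a torch dtype (or None, letting torch infer it).
    return torch.as_tensor(data, dtype=dtype)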
# XXX: Check if there is a better way.
# Numpy uses PCG64 while Torch uses Mersenne-Twister
# (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
state = rng.__getstate__()
seed = torch.from_numpy(rng.integers([2**32]))
You have to copy the rng before calling rng.integers; we don't want to modify the original one.
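A hedged sketch of the suggested fix (seed_torch_generator is an illustrative name): copy the numpy generator's state first, so drawing the seed does not advance the caller's rng.

import numpy as np
import torch


def seed_torch_generator(rng: np.random.Generator) -> torch.Generator:
    # Work on a copy so the original generator is left untouched.
    rng_copy = np.random.default_rng()
    rng_copy.bit_generator.state = rng.bit_generator.state
    seed = int(rng_copy.integers(2**32))
    return torch.Generator().manual_seed(seed)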
def sample_fn(rng, size, *parameters):
    return pytorch_sample_fn(op, node=node)(rng, shape, out_dtype, *parameters)

return sample_fn
Call pytorch_sample_fn outside of sample_fn.
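A hedged sketch of the suggested restructuring, reusing the names from the snippet above: the dispatch happens once at funcify time, and sample_fn only closes over the result.

def torch_funcify_RandomVariable(op, node, **kwargs):
    rv = node.outputs[1]
    out_dtype = rv.type.dtype
    shape = rv.type.shape

    # Resolve the dispatched sampler once, outside of sample_fn.
    rv_sample_fn = pytorch_sample_fn(op, node=node)

    def sample_fn(rng, size, *parameters):
        return rv_sample_fn(rng, shape, out_dtype, *parameters)

    return sample_fn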
sample = torch.binomial(
    torch.broadcast_to(n.to(p.dtype), size),
    torch.broadcast_to(p, size),
    generator=gen,
)
return (gen, sample)
size may be None, in which case you should do n, p = torch.broadcast_arrays(n, p), or whatever it's called.
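A hedged sketch of handling size=None (binomial_sample is an illustrative helper); note the PyTorch function is torch.broadcast_tensors, there is no torch.broadcast_arrays.

import torch


def binomial_sample(n, p, size, gen):
    if size is None:
        # No explicit size: let broadcasting of the parameters decide the shape.
        n_b, p_b = torch.broadcast_tensors(n.to(p.dtype), p)
    else:
        n_b = torch.broadcast_to(n.to(p.dtype), size)
        p_b = torch.broadcast_to(p, size)
    return torch.binomial(n_b, p_b, generator=gen)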
@@ -84,9 +86,16 @@ def fn(*inputs, inner_fn=inner_fn):
    return fn


def create_thunk_inputs(self, storage_map):
    from pytensor.link.pytorch.dispatch import pytorch_typify
You'll need to copy the logic for SharedVariables in JAX to emit a warning and use different variables. You can refactor the logic so it's not duplicated.
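A hedged sketch of what that could look like, loosely modelled on the JAX linker's handling of shared RNG inputs (the warning text and exact structure here are illustrative, not copied from the JAX code):

import copy
import warnings

import numpy as np


def create_thunk_inputs(self, storage_map):
    from pytensor.link.pytorch.dispatch import pytorch_typify

    thunk_inputs = []
    for n in self.fgraph.inputs:
        sinput = storage_map[n]
        if isinstance(sinput[0], np.random.Generator):
            warnings.warn(
                f"The RandomType SharedVariable {n} will not be used in the "
                "compiled graph; a copy will be used instead.",
                UserWarning,
            )
            # Use a copy so the user's shared generator is not mutated in place.
            sinput = [pytorch_typify(copy.deepcopy(sinput[0]))]
        thunk_inputs.append(sinput)
    return thunk_inputs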
tests/link/pytorch/test_random.py (Outdated)
4,
),
10,
0.5,
If you take out some of these trailing commas, pre-commit won't force it to be multi-line, which is very unreadable here.
    ],
)
def test_binomial(n, p, size):
    rng = shared(np.random.default_rng(123))
We need tests that confirm the original rng was not affected
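A hedged sketch of such a test, following the helpers already used in this test module (shared, function, pytorch_mode, pt, np):

def test_rng_not_mutated():
    rng = shared(np.random.default_rng(123))
    state_before = rng.get_value(borrow=True).bit_generator.state

    g = pt.random.bernoulli(0.5, size=(3,), rng=rng)
    g_fn = function([], g, mode=pytorch_mode)
    g_fn()

    # Compiling and sampling must not advance the caller's generator.
    assert rng.get_value(borrow=True).bit_generator.state == state_before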
rng = shared(np.random.default_rng(123))
g = pt.random.binomial(n, p, size=size, rng=rng)
g_fn = function([], g, mode=pytorch_mode)
samples = g_fn()
You should call it twice. In this case, because you did not set updates, you should get the same draws back. See https://pytensor.readthedocs.io/en/latest/tutorial/prng.html for details. You should also test with updates separately.
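A hedged sketch of the no-updates case (test name and parameters are illustrative):

def test_binomial_no_updates_same_draws():
    rng = shared(np.random.default_rng(123))
    g = pt.random.binomial(10, 0.5, size=(4,), rng=rng)
    g_fn = function([], g, mode=pytorch_mode)

    # No updates were passed to `function`, so the rng state is not advanced
    # between calls and both calls should return identical draws.
    np.testing.assert_allclose(g_fn(), g_fn())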
I updated this to include a test without the update, but I'm not getting the same draws. I'll read through the article and see if I can see why
- Copied generator before sampling from it
6176479 to 1f863f5 (Compare)
rng_copy = np.random.default_rng()
rng_copy.bit_generator.state = rng.bit_generator.state
seed = torch.from_numpy(rng_copy.integers([2**32]))
state["pytorch_gen"] = torch.manual_seed(seed)
I don't think we need this monkeypatching on the original state anymore; just work directly with torch.manual_seed. It's not any better to pretend this is still a valid numpy generator.
I guess at this point in the pipeline, it doesn't matter if it is still a torch generator, since we're calling typify.
def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
    rv = node.outputs[1]
    out_dtype = rv.type.dtype
    shape = rv.type.shape
shape is not guaranteed to be static. Use the size argument passed at runtime? Or add an if/else if this was meant as an optimization.
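A hedged sketch of that if/else, assuming static_shape = rv.type.shape and a dispatched rv_sample_fn as in the snippets above:

def sample_fn(rng, size, *parameters):
    # rv.type.shape may contain None entries; only use it when it is fully
    # static, otherwise fall back to the runtime `size` argument.
    if size is None and all(s is not None for s in static_shape):
        size = static_shape
    return rv_sample_fn(rng, size, out_dtype, *parameters)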
@pytorch_sample_fn.register(ptr.BernoulliRV)
def pytorch_sample_fn_bernoulli(op, node):
    def sample_fn(rng, size, dtype, p):
        gen = rng["pytorch_gen"]
Yeah, let's not do this indirection; just work with rng directly, not with the rng stored inside the dictionary.
rv_sample = pytorch_sample_fn(op, node=node)

def sample_fn(rng, size, *args):
    _rng = deepcopy(rng)
you shouldn't always deepcopy, only when op.inplace=False
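A hedged sketch of the conditional copy, assuming the op exposes an inplace flag as PyTensor RandomVariable ops do, and reusing rv_sample and deepcopy from the snippet above:

def sample_fn(rng, size, *args):
    # Copy the generator only when the op may not mutate it in place;
    # inplace ops are allowed to consume `rng` directly.
    _rng = rng if op.inplace else deepcopy(rng)
    return rv_sample(_rng, size, *args)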
Looking nearly ready!
9f1416a to 4b4f8d0 (Compare)
Description
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1075.org.readthedocs.build/en/1075/