# scalar.py (forked from pymc-devs/pytensor)
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
    Cast,
    ScalarOp,
)


@pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs):
    """Return a pytorch function that implements the same computation as the Scalar Op.

    This dispatch is expected to return a pytorch function that works on Array inputs
    as Elemwise does, even though it is dispatched on the Scalar Op.
    """
    nfunc_spec = getattr(op, "nfunc_spec", None)
    if nfunc_spec is None:
        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")

    func_name = nfunc_spec[0]
    pytorch_func = getattr(torch, func_name)

    if len(node.inputs) > nfunc_spec[1]:
        # Some Scalar Ops accept a variable number of inputs, behaving as variadic
        # functions, even though the base Op from `func_name` is specified as a
        # binary Op. This happens with `Add`, which can work as a `Sum` for
        # multiple scalars.
        pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None)
        if not pytorch_variadic_func:
            raise NotImplementedError(
                f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
            )

        def pytorch_func(*args):
            # Broadcast all inputs to a common shape, stack them along a new
            # leading axis, and reduce with the variadic torch function.
            return pytorch_variadic_func(
                torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0
            )

    return pytorch_func
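

# Illustration of the variadic branch above, assuming `Add.nfunc_spec == ("add", 2, 1)`
# and `Add.nfunc_variadic == "sum"`: with three inputs, the returned function behaves
# roughly like
#
#     def add3(a, b, c):
#         return torch.sum(torch.stack(torch.broadcast_tensors(a, b, c), axis=0), axis=0)
#
# which sums the broadcast inputs elementwise, matching repeated binary `torch.add`.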


@pytorch_funcify.register(Cast)
def pytorch_funcify_Cast(op: Cast, node, **kwargs):
    # Map the Op's output dtype name (e.g. "float64") to the torch dtype object.
    dtype = getattr(torch, op.o_type.dtype)

    def cast(x):
        return x.to(dtype=dtype)

    return cast
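

# A minimal usage sketch, assuming the `add` singleton from `pytensor.scalar.basic`
# defines `nfunc_spec == ("add", 2, 1)` and `nfunc_variadic == "sum"`. The Apply
# node is faked with a SimpleNamespace here, since only `node.inputs` is read.
if __name__ == "__main__":
    from types import SimpleNamespace

    from pytensor.scalar.basic import add, float64

    # Binary path: two inputs dispatch directly to `torch.add`.
    add2 = pytorch_funcify_ScalarOp(add, SimpleNamespace(inputs=[None, None]))
    assert torch.equal(add2(torch.tensor([1.0]), torch.tensor([2.0])), torch.tensor([3.0]))

    # Variadic path: three inputs fall back to `torch.sum` over the stacked,
    # broadcast inputs.
    add3 = pytorch_funcify_ScalarOp(add, SimpleNamespace(inputs=[None] * 3))
    assert torch.equal(
        add3(torch.tensor(1.0), torch.tensor(2.0), torch.tensor([3.0, 4.0])),
        torch.tensor([6.0, 7.0]),
    )

    # Cast: the returned function converts its input to the Op's target dtype.
    to_float64 = pytorch_funcify_Cast(Cast(float64), None)
    assert to_float64(torch.tensor([1.5], dtype=torch.float32)).dtype == torch.float64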