
Extend verify_grad to complex gradient #1367


Draft: wants to merge 3 commits into main
45 changes: 36 additions & 9 deletions pytensor/gradient.py
@@ -742,7 +742,7 @@
for var in grad_dict:
g = grad_dict[var]
if hasattr(g.type, "dtype"):
- assert g.type.dtype in pytensor.tensor.type.float_dtypes
+ assert g.type.dtype in pytensor.tensor.type.continuous_dtypes

_rval: Sequence[Variable] = _populate_grad_dict(
var_to_app_to_idx, grad_dict, _wrt, cost_name
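Note: continuous_dtypes in pytensor.tensor.type extends float_dtypes with the complex dtypes, so this assertion (and the matching checks in the hunks below) now admits complex-valued gradients. Roughly, from memory of the type module rather than from this diff:

float_dtypes = ("float16", "float32", "float64")
complex_dtypes = ("complex64", "complex128")
continuous_dtypes = float_dtypes + complex_dtypes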
@@ -1411,7 +1411,7 @@
)

if not isinstance(term.type, NullType | DisconnectedType):
- if term.type.dtype not in pytensor.tensor.type.float_dtypes:
+ if term.type.dtype not in pytensor.tensor.type.continuous_dtypes:
raise TypeError(
str(node.op) + ".grad illegally "
" returned an integer-valued variable."
@@ -1562,7 +1562,7 @@
"""

dtype = x.type.dtype
- if dtype not in pytensor.tensor.type.float_dtypes:
+ if dtype not in pytensor.tensor.type.continuous_dtypes:
dtype = config.floatX

return x.ones_like(dtype=dtype)
@@ -1633,16 +1633,35 @@
rval *= i
return rval

+ def real_array(x):
+     dtype = np.array(x).dtype
+     if str(dtype).startswith("complex"):
+         return (
+             np.stack([np.real(x), np.imag(x)], axis=-1),
+             np.array([1.0, 1j]),
+         )
+     else:
+         return (np.expand_dims(np.array(x), axis=-1), np.array([1.0]))
+
+ def real_wrapper(f, c_stack):
+     def wrapped_f(*arr):
+         c_arr = [t @ s for (t, s) in zip(arr, c_stack, strict=True)]
+         return f(*c_arr)
+
+     return wrapped_f

packed_pt = False
if not isinstance(pt, list | tuple):
pt = [pt]
packed_pt = True

- apt = [np.array(p) for p in pt]
+ apt, complex_stack = list(map(list, zip(*map(real_array, pt), strict=True)))

shapes = [p.shape for p in apt]
dtypes = [str(p.dtype) for p in apt]

+ real_f = real_wrapper(f, complex_stack)

# TODO: remove this eventually (why was this here in the first place ?)
# In the case of CSM, the arguments are a mixture of floats and
# integers...
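Note on the hunk above: real_array turns each input into a real-valued array with a trailing axis holding its real and imaginary parts (a length-1 axis for real inputs), together with the vector that recombines that axis into the original values, and real_wrapper applies that contraction before calling the user's function, so the purely real finite-difference loop below can be reused unchanged. A minimal standalone sketch of the same packing trick (illustrative only, not part of the diff):

import numpy as np

def real_array(x):
    # Complex input: stack real and imaginary parts along a new last axis
    # and remember the vector that recombines them into complex values.
    x = np.array(x)
    if np.iscomplexobj(x):
        return np.stack([x.real, x.imag], axis=-1), np.array([1.0, 1j])
    # Real input: add a dummy trailing axis so both cases look alike.
    return np.expand_dims(x, axis=-1), np.array([1.0])

z = np.array([1.0 + 2.0j, 3.0 - 1.0j])
packed, stack = real_array(z)          # packed is real-valued, shape (2, 2)
assert np.allclose(packed @ stack, z)  # the contraction restores z exactly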
@@ -1677,22 +1696,26 @@
apt[i][...] = p
cur_pos += p_size

- f_x = f(*[p.copy() for p in apt])
+ real_f_x = real_f(*[p.copy() for p in apt])

# now iterate over the elements of x, and call f on apt.
x_copy = x.copy()
for i in range(total_size):
x[:] = x_copy

x[i] += eps
- f_eps = f(*apt)
+ real_f_eps = real_f(*apt)

# TODO: remove this when it is clear that the next
# replacemement does not pose problems of its own. It was replaced
# for its inability to handle complex variables.
# gx[i] = numpy.asarray((f_eps - f_x) / eps)

- gx[i] = (f_eps - f_x) / eps
+ gx[i] = (real_f_eps - real_f_x) / eps

+ self.gf = [
+     (p @ s).conj() for (p, s) in zip(self.gf, complex_stack, strict=True)
+ ]

if packed_pt:
self.gf = self.gf[0]
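Note on the hunk above: because the inputs are packed, the raw numerical gradient carries the same trailing real/imag axis; contracting it with the stored [1.0, 1j] stack and conjugating folds the pair (∂f/∂Re x, ∂f/∂Im x) into one complex number per element, i.e. ∂f/∂Re x - i·∂f/∂Im x. A quick hand check with f(z) = |z|², whose gradient in this convention is 2·conj(z) (illustrative only, not part of the diff):

import numpy as np

z = 1.5 - 0.5j
eps = 1e-7

def f(re, im):
    return re**2 + im**2  # |z|**2 written over the real components

d_re = (f(z.real + eps, z.imag) - f(z.real, z.imag)) / eps
d_im = (f(z.real, z.imag + eps) - f(z.real, z.imag)) / eps

g = (np.array([d_re, d_im]) @ np.array([1.0, 1j])).conj()
assert abs(g - 2 * np.conj(z)) < 1e-5  # matches 2*conj(z) up to finite-difference error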
@@ -1874,14 +1897,18 @@
pt = [np.array(p) for p in pt]

for i, p in enumerate(pt):
if p.dtype not in ("float16", "float32", "float64"):
if p.dtype not in pytensor.tensor.type.continuous_dtypes:
raise TypeError(
"verify_grad can work only with floating point "
f'inputs, but input {i} has dtype "{p.dtype}".'
)

_type_tol = dict( # relative error tolerances for different types
- float16=5e-2, float32=1e-2, float64=1e-4
+ float16=5e-2,
+ float32=1e-2,
+ float64=1e-4,
+ complex64=1e-2,
+ complex128=1e-4,
)

if abs_tol is None:
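With the tolerance table extended to the complex dtypes, a test could hand verify_grad a complex point directly. A hypothetical call, assuming complex-aware ops such as pt.real and pt.conj are available for building a real-valued cost and that rng accepts a NumPy Generator (none of which is established by this diff):

import numpy as np
import pytensor.tensor as pt
from pytensor.gradient import verify_grad

def cost(x):
    # real-valued cost of a complex input: sum of squared magnitudes
    return pt.real(x * pt.conj(x)).sum()

verify_grad(cost, [np.array([1.0 + 2.0j, 3.0 - 1.0j])], rng=np.random.default_rng(42))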