
Batched dot doesn't support complex inputs #1849

@jessegrabowski

Description


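This came up in #1840. Multiplying two complex (`complex128`) 3D tensors with `@` creates a `BatchedDot` node. Evaluating the compiled function then raises `NotImplementedError: type(x) is not double or float`. Minimal reproducer: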

```python
import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.ztensor3('x')
y = pt.ztensor3('y')
z = x @ y
fn = pytensor.function([x, y], z)

x_val, y_val = np.random.normal(size=(2, 6, 6, 2)).view(np.complex128)
fn(x_val, y_val)
```
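As written, `.view(np.complex128)` reinterprets each pair of trailing `float64` values as a single complex number. The `(2, 6, 6, 2)` array therefore becomes `(2, 6, 6, 1)`, and each input has shape `(6, 6, 1)`, which matches `Inputs shapes` in the traceback. The dtype check fails first here. Even without it, the core dimensions `(6, 1) @ (6, 1)` would not be compatible. The following is a minimal NumPy sketch that builds square complex batches and uses `np.matmul` as a reference result. It assumes that `(6, 6, 6)` inputs were what was intended.

```python
import numpy as np

rng = np.random.default_rng(0)

# Viewing a trailing axis of 2 float64 values as complex128 halves that axis,
# so (2, 6, 6, 2) -> (2, 6, 6, 1), which is what the reproducer produces.
x_bad, y_bad = rng.normal(size=(2, 6, 6, 2)).view(np.complex128)
print(x_bad.shape)  # (6, 6, 1)

# Using 12 trailing float64 values instead yields square (6, 6, 6) complex batches.
x_val, y_val = rng.normal(size=(2, 6, 6, 12)).view(np.complex128)

# NumPy reference result for the batched complex product.
expected = np.matmul(x_val, y_val)
print(x_val.dtype, x_val.shape, expected.shape)  # complex128 (6, 6, 6) (6, 6, 6)
```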
Traceback:

```
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
File ~/Documents/Python/pytensor/pytensor/compile/function/types.py:1038, in Function.__call__(self, output_subset, *args, **kwargs)
   1037 try:
-> 1038     outputs = vm() if output_subset is None else vm(output_subset=output_subset)
   1039 except Exception:

NotImplementedError: type(x) is not double or float

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
Cell In[10], line 11
      8 fn = pytensor.function([x, y], z)
     10 x_val, y_val = np.random.normal(size=(2, 6, 6, 2)).view(np.complex128)
---> 11 fn(x_val, y_val)

File ~/Documents/Python/pytensor/pytensor/compile/function/types.py:1048, in Function.__call__(self, output_subset, *args, **kwargs)
   1046     if hasattr(self.vm, "thunks"):
   1047         thunk = self.vm.thunks[self.vm.position_of_error]
-> 1048     raise_with_op(
   1049         self.maker.fgraph,
   1050         node=self.vm.nodes[self.vm.position_of_error],
   1051         thunk=thunk,
   1052         storage_map=getattr(self.vm, "storage_map", None),
   1053     )
   1054 else:
   1055     # old-style linkers raise their own exceptions
   1056     raise

File ~/Documents/Python/pytensor/pytensor/link/utils.py:526, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    521     warnings.warn(
    522         f"{exc_type} error does not allow us to add an extra error message"
    523     )
    524     # Some exception need extra parameter in inputs. So forget the
    525     # extra long error message in that case.
--> 526 raise exc_value.with_traceback(exc_trace)

File ~/Documents/Python/pytensor/pytensor/compile/function/types.py:1038, in Function.__call__(self, output_subset, *args, **kwargs)
   1036     t0_fn = time.perf_counter()
   1037 try:
-> 1038     outputs = vm() if output_subset is None else vm(output_subset=output_subset)
   1039 except Exception:
   1040     self._restore_defaults()

NotImplementedError: type(x) is not double or float
Apply node that caused the error: BatchedDot(x, y)
Toposort index: 0
Inputs types: [TensorType(complex128, shape=(None, None, None)), TensorType(complex128, shape=(None, None, None))]
Inputs shapes: [(6, 6, 1), (6, 6, 1)]
Inputs strides: [(96, 16, 16), (96, 16, 16)]
Inputs values: ['not shown', 'not shown']
Outputs clients: [[output[0](BatchedDot.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3098, in run_cell
    result = self._run_cell(
  File "/Users/jessegrabowski/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3153, in _run_cell
    result = runner(coro)
  File "/Users/jessegrabowski/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/jessegrabowski/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3362, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/jessegrabowski/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3607, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/pytensor-dev/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3667, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_46141/4000051598.py", line 7, in <module>
    z = x @ y

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
```
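Until `BatchedDot` accepts complex dtypes, one possible workaround is to split each operand into its real and imaginary parts. The complex product can then be rebuilt from four real batched products using $(A + iB)(C + iD) = (AC - BD) + i(AD + BC)$. The sketch below only checks this identity in NumPy, and the helper name `complex_batched_matmul` is hypothetical. Writing the same decomposition as a PyTensor graph, with real/imaginary extraction ops and real-valued `@`, is an assumption that has not been verified against this issue.

```python
import numpy as np


def complex_batched_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: batched complex matmul built only from real matmuls."""
    ar, ai = a.real, a.imag
    br, bi = b.real, b.imag
    # (A + iB)(C + iD) = (AC - BD) + i(AD + BC); every product here is real-valued
    real = np.matmul(ar, br) - np.matmul(ai, bi)
    imag = np.matmul(ar, bi) + np.matmul(ai, br)
    return real + 1j * imag


rng = np.random.default_rng(1)
x_val, y_val = rng.normal(size=(2, 6, 6, 12)).view(np.complex128)
result = complex_batched_matmul(x_val, y_val)
print(np.allclose(result, np.matmul(x_val, y_val)))  # True
```

The four real products are also the pieces that the existing float/double `BatchedDot` path already handles, which is why this decomposition may be a reasonable stopgap.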
