Commit 18ff63a

add tensor consitency chapter in sharp bits
1 parent 94dd9ef commit 18ff63a

File tree

1 file changed (+47, -0 lines)

docs/source/sharpbits.rst

@@ -141,6 +141,53 @@ Similarly, conditional gate application must be taken carefully.
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([0.99999994+0.j, 0. +0.j], dtype=complex64)>

Tensor variables consistency
----------------------------

All tensor variables' backend (TensorFlow vs. JAX vs. ...), dtype (float vs. complex), shape, and device (CPU vs. GPU) must be compatible and consistent.

Inspect the backend, dtype, shape, and device of a tensor using the following code.

.. code-block:: python

    import tensorcircuit as tc

    for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
        with tc.runtime_backend(backend):
            a = tc.backend.ones([2, 3])
            print("tensor backend:", tc.interfaces.which_backend(a))
            print("tensor dtype:", tc.backend.dtype(a))
            print("tensor shape:", tc.backend.shape_tuple(a))
            print("tensor device:", tc.backend.device(a))

If the backend is inconsistent, one can convert a tensor to another backend via :py:meth:`tensorcircuit.interfaces.tensortrans.general_args_to_backend`.

.. code-block:: python

    for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
        with tc.runtime_backend(backend):
            a = tc.backend.ones([2, 3])
            print("tensor backend:", tc.interfaces.which_backend(a))
            b = tc.interfaces.general_args_to_backend(a, target_backend="jax", enable_dlpack=False)
            print("tensor backend:", tc.interfaces.which_backend(b))

If the dtype is inconsistent, one can cast a tensor to another dtype using ``tc.backend.cast``.

.. code-block:: python

    for backend in ["numpy", "tensorflow", "jax", "pytorch"]:
        with tc.runtime_backend(backend):
            a = tc.backend.ones([2, 3])
            print("tensor dtype:", tc.backend.dtype(a))
            b = tc.backend.cast(a, dtype="float64")
            print("tensor dtype:", tc.backend.dtype(b))

Also note the JAX gotcha regarding float64/complex128 support; see `JAX gotchas <https://github.com/google/jax#current-gotchas>`_.

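As a minimal sketch of this gotcha (assuming the ``jax`` package is installed): JAX silently downcasts to 32-bit precision unless x64 mode is enabled before any arrays are created.

```python
# Sketch of the JAX float64 gotcha (assumes `jax` is installed):
# by default JAX downcasts to 32-bit unless x64 mode is enabled
# before any arrays are created.
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp

a = jnp.ones([2, 3], dtype="float64")
print(a.dtype)  # float64 only because x64 mode was enabled above
```

Without the ``config.update`` line, the same call would produce a float32 array and emit a truncation warning.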
If the shapes are not consistent, one can reshape a tensor via ``tc.backend.reshape``.

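As a backend-agnostic sketch (NumPy standing in for ``tc.backend``), reshaping makes two same-sized tensors shape-compatible before an elementwise operation:

```python
import numpy as np

# NumPy stands in for tc.backend here; tc.backend.reshape behaves analogously.
a = np.ones([2, 3])
b = np.ones([6])        # same 6 elements, incompatible shape
b = b.reshape([2, 3])   # analogous to tc.backend.reshape(b, [2, 3])
c = a + b               # shapes now agree
print(c.shape)  # (2, 3)
```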
If the devices are not consistent, one can move a tensor between devices via ``tc.backend.device_move``.

AD Consistency
---------------------
