@@ -141,6 +141,53 @@ Similarly, conditional gate application must be takend carefully.
141
141
# <tf.Tensor: shape=(2,), dtype=complex64, numpy=array([0.99999994+0.j, 0. +0.j], dtype=complex64)>
142
142
143
143
144
+ Tensor variables consistency
145
+ -------------------------------------------------------
146
+
147
+
148
+ All tensor variables' backend (tf vs jax vs ..), dtype (float vs complex), shape and device (cpu vs gpu) must be compatible/consistent.
149
+
150
+ Inspect the backend, dtype, shape and device using the following codes.
151
+
152
+ .. code-block :: python
153
+
154
+ for backend in [" numpy" , " tensorflow" , " jax" , " pytorch" ]:
155
+ with tc.runtime_backend(backend):
156
+ a = tc.backend.ones([2 , 3 ])
157
+ print (" tensor backend:" , tc.interfaces.which_backend(a))
158
+ print (" tensor dtype:" , tc.backend.dtype(a))
159
+ print (" tensor shape:" , tc.backend.shape_tuple(a))
160
+ print (" tensor device:" , tc.backend.device(a))
161
+
162
+ If the backend is inconsistent, one can convert the tensor backend via :py:meth: `tensorcircuit.interfaces.tensortrans.general_args_to_backend `.
163
+
164
+ .. code-block :: python
165
+
166
+ for backend in [" numpy" , " tensorflow" , " jax" , " pytorch" ]:
167
+ with tc.runtime_backend(backend):
168
+ a = tc.backend.ones([2 , 3 ])
169
+ print (" tensor backend:" , tc.interfaces.which_backend(a))
170
+ b = tc.interfaces.general_args_to_backend(a, target_backend = " jax" , enable_dlpack = False )
171
+ print (" tensor backend:" , tc.interfaces.which_backend(b))
172
+
173
+ If the dtype is inconsistent, one can convert the tensor dtype using ``tc.backend.cast ``.
174
+
175
+ .. code-block :: python
176
+
177
+ for backend in [" numpy" , " tensorflow" , " jax" , " pytorch" ]:
178
+ with tc.runtime_backend(backend):
179
+ a = tc.backend.ones([2 , 3 ])
180
+ print (" tensor dtype:" , tc.backend.dtype(a))
181
+ b = tc.backend.cast(a, dtype = " float64" )
182
+ print (" tensor dtype:" , tc.backend.dtype(b))
183
+
184
+ Also note the jax issue on float64/complex128, see `jax gotcha <https://github.com/google/jax#current-gotchas >`_.
185
+
186
+ If the shape is not consistent, one can convert the shape by ``tc.backend.reshape ``.
187
+
188
+ If the device is not consistent, one can move the tensor between devices by ``tc.backend.device_move ``.
189
+
190
+
144
191
AD Consistency
145
192
---------------------
146
193
0 commit comments