39
39
40
40
@numeric .device .register (tf .Tensor )
41
41
def _ (a : tf .Tensor ) -> TensorDeviceType :
42
- return DEVICE_MAP_REV [a .device .split ("/" )[- 1 ].split (":" )[1 ]]
42
+ if "CPU" in a .device :
43
+ return DEVICE_MAP_REV ["CPU" ]
44
+ if "GPU" in a .device :
45
+ return DEVICE_MAP_REV ["GPU" ]
43
46
44
47
45
48
@numeric .backend .register (tf .Tensor )
@@ -136,7 +139,7 @@ def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Te
136
139
137
140
@numeric .isempty .register (tf .Tensor )
138
141
def _ (a : tf .Tensor ) -> bool :
139
- return bool (tf .equal (tf .size (a ), 0 ).numpy (). T )
142
+ return bool (tf .equal (tf .size (a ), 0 ).numpy ())
140
143
141
144
142
145
@numeric .isclose .register (tf .Tensor )
@@ -199,18 +202,8 @@ def _(x: tf.Tensor, axis: int = 0) -> List[tf.Tensor]:
199
202
200
203
@numeric .moveaxis .register (tf .Tensor )
201
204
def _ (a : tf .Tensor , source : Union [int , Tuple [int , ...]], destination : Union [int , Tuple [int , ...]]) -> tf .Tensor :
202
- perm = list (range (a ._rank ()))
203
- if isinstance (source , int ):
204
- axe_to_move = perm .pop (source )
205
- if destination < 0 :
206
- destination = len (perm ) + destination + 1
207
- perm .insert (destination , axe_to_move )
208
- else :
209
- old_perm = perm [:]
210
- for i in range (len (source )):
211
- perm [destination [i ]] = old_perm [source [i ]]
212
205
with tf .device (a .device ):
213
- return tf .transpose (a , perm )
206
+ return tf .experimental . numpy . moveaxis (a , source , destination )
214
207
215
208
216
209
@numeric .mean .register (tf .Tensor )
@@ -311,6 +304,7 @@ def _(a: tf.Tensor, data: Any) -> tf.Tensor:
311
304
312
305
@numeric .item .register (tf .Tensor )
313
306
def _ (a : tf .Tensor ) -> Union [int , float , bool ]:
307
+ a = tf .reshape (a , [])
314
308
np_item = a .numpy ()
315
309
if isinstance (np_item , np .floating ):
316
310
return float (np_item )
@@ -337,11 +331,10 @@ def _(
337
331
a : tf .Tensor , axis : Optional [Union [int , Tuple [int , ...]]] = None , keepdims : bool = False , ddof : int = 0
338
332
) -> tf .Tensor :
339
333
with tf .device (a .device ):
340
- assert ddof in {0 , 1 }
341
334
tf_var = tf .math .reduce_variance (a , axis = axis , keepdims = keepdims )
342
335
if ddof :
343
336
n = tf .shape (a )[axis ] if axis is not None else tf .size (a )
344
- tf_var *= float (n ) / float (n - 1 )
337
+ tf_var *= float (n ) / float (n - ddof )
345
338
return tf_var
346
339
347
340
@@ -480,8 +473,7 @@ def zeros(
480
473
if device is not None :
481
474
device = DEVICE_MAP [device ]
482
475
with tf .device (device ):
483
- zeros = tf .zeros (shape , dtype = dtype )
484
- return zeros
476
+ return tf .zeros (shape , dtype = dtype )
485
477
486
478
487
479
def eye (
@@ -513,8 +505,7 @@ def arange(
513
505
if device is not None :
514
506
device = DEVICE_MAP [device ]
515
507
with tf .device (device ):
516
- r = tf .range (start , end , step , dtype = dtype )
517
- return r
508
+ return tf .range (start , end , step , dtype = dtype )
518
509
519
510
520
511
def from_numpy (ndarray : np .ndarray ) -> tf .Tensor :
0 commit comments