Skip to content

Commit c687e6a

Browse files
committed
Address comments
1 parent 298891b commit c687e6a

File tree

2 files changed

+11
-20
lines changed

2 files changed

+11
-20
lines changed

Diff for: nncf/tensor/functions/tf_numeric.py

+10-19
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@
3939

4040
@numeric.device.register(tf.Tensor)
4141
def _(a: tf.Tensor) -> TensorDeviceType:
42-
return DEVICE_MAP_REV[a.device.split("/")[-1].split(":")[1]]
42+
if "CPU" in a.device:
43+
return DEVICE_MAP_REV["CPU"]
44+
if "GPU" in a.device:
45+
return DEVICE_MAP_REV["GPU"]
4346

4447

4548
@numeric.backend.register(tf.Tensor)
@@ -136,7 +139,7 @@ def _(a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> tf.Te
136139

137140
@numeric.isempty.register(tf.Tensor)
138141
def _(a: tf.Tensor) -> bool:
139-
return bool(tf.equal(tf.size(a), 0).numpy().T)
142+
return bool(tf.equal(tf.size(a), 0).numpy())
140143

141144

142145
@numeric.isclose.register(tf.Tensor)
@@ -199,18 +202,8 @@ def _(x: tf.Tensor, axis: int = 0) -> List[tf.Tensor]:
199202

200203
@numeric.moveaxis.register(tf.Tensor)
201204
def _(a: tf.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> tf.Tensor:
202-
perm = list(range(a._rank()))
203-
if isinstance(source, int):
204-
axe_to_move = perm.pop(source)
205-
if destination < 0:
206-
destination = len(perm) + destination + 1
207-
perm.insert(destination, axe_to_move)
208-
else:
209-
old_perm = perm[:]
210-
for i in range(len(source)):
211-
perm[destination[i]] = old_perm[source[i]]
212205
with tf.device(a.device):
213-
return tf.transpose(a, perm)
206+
return tf.experimental.numpy.moveaxis(a, source, destination)
214207

215208

216209
@numeric.mean.register(tf.Tensor)
@@ -311,6 +304,7 @@ def _(a: tf.Tensor, data: Any) -> tf.Tensor:
311304

312305
@numeric.item.register(tf.Tensor)
313306
def _(a: tf.Tensor) -> Union[int, float, bool]:
307+
a = tf.reshape(a, [])
314308
np_item = a.numpy()
315309
if isinstance(np_item, np.floating):
316310
return float(np_item)
@@ -337,11 +331,10 @@ def _(
337331
a: tf.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ddof: int = 0
338332
) -> tf.Tensor:
339333
with tf.device(a.device):
340-
assert ddof in {0, 1}
341334
tf_var = tf.math.reduce_variance(a, axis=axis, keepdims=keepdims)
342335
if ddof:
343336
n = tf.shape(a)[axis] if axis is not None else tf.size(a)
344-
tf_var *= float(n) / float(n - 1)
337+
tf_var *= float(n) / float(n - ddof)
345338
return tf_var
346339

347340

@@ -480,8 +473,7 @@ def zeros(
480473
if device is not None:
481474
device = DEVICE_MAP[device]
482475
with tf.device(device):
483-
zeros = tf.zeros(shape, dtype=dtype)
484-
return zeros
476+
return tf.zeros(shape, dtype=dtype)
485477

486478

487479
def eye(
@@ -513,8 +505,7 @@ def arange(
513505
if device is not None:
514506
device = DEVICE_MAP[device]
515507
with tf.device(device):
516-
r = tf.range(start, end, step, dtype=dtype)
517-
return r
508+
return tf.range(start, end, step, dtype=dtype)
518509

519510

520511
def from_numpy(ndarray: np.ndarray) -> tf.Tensor:

Diff for: tests/tensorflow/test_tensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def cast_to(x: tf.Tensor, dtype: TensorDataType) -> tf.Tensor:
2929
class TestTFNNCFTensorOperators(TemplateTestNNCFTensorOperators):
3030
@staticmethod
3131
def to_tensor(x):
32-
with tf.device("/CPU:0"):
32+
with tf.device("CPU"):
3333
return tf.constant(x)
3434

3535
@staticmethod

0 commit comments

Comments
 (0)