Skip to content

Commit 298891b

Browse files
committed
Added numerics methods for tensorflow tensor
1 parent 6031ccc commit 298891b

File tree

9 files changed

+779
-13
lines changed

9 files changed

+779
-13
lines changed

docs/api/source/conf.py

+2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def collect_api_entities() -> APIInfo:
143143
"nncf.tensor.functions.numpy_linalg",
144144
"nncf.tensor.functions.torch_numeric",
145145
"nncf.tensor.functions.torch_linalg",
146+
"nncf.tensor.functions.tf_numeric",
147+
"nncf.tensor.functions.tf_linalg",
146148
]
147149

148150
with mock(mock_modules):

nncf/tensor/definitions.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class TensorBackend(Enum):
2020
"""
2121

2222
numpy = auto()
23+
tf = auto()
2324
torch = auto()
2425

2526

nncf/tensor/functions/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def _initialize_backends():
7171
import nncf.tensor.functions.numpy_linalg
7272
import nncf.tensor.functions.numpy_numeric
7373

74+
with contextlib.suppress(ImportError):
75+
import nncf.tensor.functions.tf_linalg
76+
import nncf.tensor.functions.tf_numeric
77+
7478
with contextlib.suppress(ImportError):
7579
import nncf.tensor.functions.torch_linalg
7680
import nncf.tensor.functions.torch_numeric # noqa: F401

nncf/tensor/functions/dispatcher.py

+5
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,8 @@ def get_numeric_backend_fn(fn_name: str, backend: TensorBackend) -> Callable:
7575
from nncf.tensor.functions import torch_numeric
7676

7777
return getattr(torch_numeric, fn_name)
78+
79+
if backend == TensorBackend.tf:
80+
from nncf.tensor.functions import tf_numeric
81+
82+
return getattr(tf_numeric, fn_name)

nncf/tensor/functions/tf_linalg.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import warnings
13+
from typing import Optional, Tuple, Union
14+
15+
import tensorflow as tf
16+
17+
from nncf.tensor.functions import linalg
18+
19+
20+
@linalg.norm.register(tf.Tensor)
21+
def _(
22+
a: tf.Tensor,
23+
ord: Optional[Union[str, float, int]] = None,
24+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
25+
keepdims: bool = False,
26+
) -> tf.Tensor:
27+
if axis is None:
28+
axis = 0 if a._rank() == 1 else (0, 1)
29+
30+
if ord is None or (a._rank() == 1 and ord == "fro"):
31+
ord = "euclidean"
32+
33+
with tf.device(a.device):
34+
if ord == "nuc":
35+
s, _, _ = tf.linalg.svd(a)
36+
return tf.reduce_sum(s)
37+
38+
return tf.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims)
39+
40+
41+
@linalg.cholesky.register(tf.Tensor)
42+
def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor:
43+
with tf.device(a.device):
44+
cholesky = tf.linalg.cholesky(a)
45+
if upper:
46+
perm = list(range(tf.rank(a)))
47+
perm[-1], perm[-2] = perm[-2], perm[-1]
48+
cholesky = tf.transpose(cholesky, perm=perm)
49+
return cholesky
50+
51+
52+
@linalg.cholesky_inverse.register(tf.Tensor)
53+
def _(a: tf.Tensor, upper: bool = False) -> tf.Tensor:
54+
with tf.device(a.device):
55+
if upper:
56+
perm = list(range(tf.rank(a)))
57+
perm[-1], perm[-2] = perm[-2], perm[-1]
58+
a = tf.transpose(a, perm=perm)
59+
60+
eye = tf.eye(a.shape[0], dtype=a.dtype)
61+
return tf.linalg.cholesky_solve(a, eye)
62+
63+
64+
@linalg.inv.register(tf.Tensor)
65+
def _(a: tf.Tensor) -> tf.Tensor:
66+
with tf.device(a.device):
67+
return tf.linalg.inv(a)
68+
69+
70+
@linalg.pinv.register(tf.Tensor)
71+
def _(a: tf.Tensor) -> tf.Tensor:
72+
with tf.device(a.device):
73+
return tf.linalg.pinv(a)
74+
75+
76+
@linalg.lstsq.register(tf.Tensor)
77+
def _(a: tf.Tensor, b: tf.Tensor, driver: Optional[str] = None) -> tf.Tensor:
78+
with tf.device(a.device):
79+
if driver is not None:
80+
warnings.warn("Driver specifying is not supported in TensorFlow lstsq method")
81+
if tf.rank(b) == 1:
82+
b = tf.expand_dims(b, axis=0)
83+
perm = list(range(tf.rank(b)))
84+
perm[-1], perm[-2] = perm[-2], perm[-1]
85+
b = tf.transpose(b, perm=perm)
86+
87+
return tf.linalg.lstsq(a, b)
88+
89+
90+
@linalg.svd.register(tf.Tensor)
91+
def _(a: tf.Tensor, full_matrices: Optional[bool] = True) -> tf.Tensor:
92+
with tf.device(a.device):
93+
s, u, v = tf.linalg.svd(a, full_matrices=full_matrices)
94+
95+
return u, s, tf.transpose(v)

0 commit comments

Comments
 (0)