Skip to content

Commit 7f92c8c

Browse files
authored
AC: fix scipy import in GAN metrics (openvinotoolkit#1647)
1 parent ea45e67 commit 7f92c8c

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

tools/accuracy_checker/accuracy_checker/metrics/gan_metrics.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
"""
1313

1414
import numpy as np
15-
from scipy.linalg import sqrtm
1615

1716
from ..representation import (
1817
RawTensorAnnotation,
@@ -21,6 +20,12 @@
2120

2221
from .metric import FullDatasetEvaluationMetric
2322
from ..config import NumberField
23+
from ..utils import UnsupportedPackage
24+
25+
try:
26+
from scipy.linalg import sqrtm
27+
except ImportError as error:
28+
sqrtm = UnsupportedPackage('scipy.linalg.sqrtm', error)
2429

2530

2631
class BaseGanMetric(FullDatasetEvaluationMetric):
@@ -98,6 +103,8 @@ def score_calc(self, annotations, predictions):
98103
"""
99104
Calculate FID between feature vector of the real and generated images.
100105
"""
106+
if isinstance(sqrtm, UnsupportedPackage):
107+
sqrtm.raise_error(self.__provider__)
101108

102109
assert annotations.shape[1] == predictions.shape[1], "Expected equal length of feature vectors"
103110

@@ -112,5 +119,5 @@ def score_calc(self, annotations, predictions):
112119
if np.iscomplexobj(covmean):
113120
covmean = covmean.real
114121

115-
FID = mdiff.dot(mdiff) + np.trace(cov_real + cov_gen - 2 * covmean)
116-
return FID
122+
fid = mdiff.dot(mdiff) + np.trace(cov_real + cov_gen - 2 * covmean)
123+
return fid

0 commit comments

Comments
 (0)