-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcompute_fid.py
34 lines (26 loc) · 1.21 KB
/
compute_fid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
from scipy.linalg import sqrtm
def compute_fid(act1, act2, eps=1e-6):
    """Calculate the Frechet distance (FID) between two distributions of data.

    Each dataset is summarised by the mean vector and covariance matrix of its
    features, and the score is
    ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))``.

    Inputs:
    - act1: numpy array containing the first dataset. One row per entry, one
      column per feature. A 1-D array is treated as a single feature.
    - act2: numpy array containing the second dataset. One row per entry, one
      column per feature. Must have the same number of features as act1.
    - eps: diagonal offset added to both covariance matrices only when the
      matrix square root of their product is not finite (e.g. near-singular
      covariances). Ignored otherwise.
    Outputs:
    - fid: float containing the FID between the two distributions.
    """
    # Workaround to ensure there are no NaNs/infs in the data: NaN, +inf and
    # -inf are all replaced by 10.0 so the statistics below stay finite.
    # NOTE(review): 10.0 is a fixed fill value, not derived from the data.
    act1 = np.nan_to_num(act1, nan=10.0, posinf=10.0, neginf=10.0)
    act2 = np.nan_to_num(act2, nan=10.0, posinf=10.0, neginf=10.0)

    # Mean and covariance statistics. np.cov returns a 0-d array when there is
    # only one feature; atleast_2d keeps it a (1, 1) matrix so that the matrix
    # product, sqrtm and trace below still work.
    mu1 = act1.mean(axis=0)
    mu2 = act2.mean(axis=0)
    sigma1 = np.atleast_2d(np.cov(act1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(act2, rowvar=False))

    # Sum of squared differences between the means.
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Matrix square root of the product of the covariances.
    covmean = sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # The product is (near-)singular: regularise both covariances with a
        # small diagonal offset and retry instead of returning inf/NaN.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # sqrtm can return a complex array; only the real part is kept.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Calculate score.
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid