Skip to content

Commit 9b1a70a

Browse files
committed
Check fix
1 parent 6b337b8 commit 9b1a70a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ignite/metrics/gan/fid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def fid_score(
3737
diff = mu1 - mu2
3838

3939
# Product might be almost singular
40-
covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False)
40+
covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2).numpy(), disp=False)
4141
# Numerical error might give slight imaginary component
4242
if np.iscomplexobj(covmean):
4343
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
@@ -48,7 +48,7 @@ def fid_score(
4848
tr_covmean = np.trace(covmean)
4949

5050
if not np.isfinite(covmean).all():
51-
tr_covmean = np.sum(np.sqrt(((np.diag(sigma1) * eps) * (np.diag(sigma2) * eps)) / (eps * eps)))
51+
tr_covmean = np.sum(np.sqrt(((np.diag(sigma1.numpy()) * eps) * (np.diag(sigma2.numpy()) * eps)) / (eps * eps)))
5252

5353
return float(diff.dot(diff).item() + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean)
5454

0 commit comments

Comments
 (0)