File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -37,7 +37,7 @@ def fid_score(
37
37
diff = mu1 - mu2
38
38
39
39
# Product might be almost singular
40
- covmean , _ = scipy .linalg .sqrtm (sigma1 .mm (sigma2 ), disp = False )
40
+ covmean , _ = scipy .linalg .sqrtm (sigma1 .mm (sigma2 ). numpy () , disp = False )
41
41
# Numerical error might give slight imaginary component
42
42
if np .iscomplexobj (covmean ):
43
43
if not np .allclose (np .diagonal (covmean ).imag , 0 , atol = 1e-3 ):
@@ -48,7 +48,7 @@ def fid_score(
48
48
tr_covmean = np .trace (covmean )
49
49
50
50
if not np .isfinite (covmean ).all ():
51
- tr_covmean = np .sum (np .sqrt (((np .diag (sigma1 ) * eps ) * (np .diag (sigma2 ) * eps )) / (eps * eps )))
51
+ tr_covmean = np .sum (np .sqrt (((np .diag (sigma1 . numpy ()) * eps ) * (np .diag (sigma2 . numpy () ) * eps )) / (eps * eps )))
52
52
53
53
return float (diff .dot (diff ).item () + torch .trace (sigma1 ) + torch .trace (sigma2 ) - 2 * tr_covmean )
54
54
You can’t perform that action at this time.
0 commit comments