diff --git a/fnet.py b/fnet.py index 04eae94..2473816 100644 --- a/fnet.py +++ b/fnet.py @@ -27,7 +27,7 @@ def __init__(self): super().__init__() def forward(self, x): - x = torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real + x = torch.fft.fft2(x).real return x class FNet(nn.Module):