
[Bug]: Parameter count in FNO model is too low when following specification #67

@t-muser

Description


Describe the issue:

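I think torchinfo miscounts the parameters because of how it handles the ModuleLists inside FNOBlocks. Counting directly with PyTorch (summing numel() over model.parameters()) gives a much lower total: 9,637,763 instead of 16,815,299. The difference is exactly the 7,177,536 parameters that torchinfo assigns to the first FNOBlocks row (1-2). In other words, the model built with the specified settings is much smaller than the torchinfo total suggests. A simple fix would be to increase hidden_channels to 180, which gives 19,056,783 parameters in total (PyTorch count).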
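Both PyTorch counts (128 and 180 hidden channels) can be reproduced in closed form. The sketch below is a hedged reconstruction, not neuralop's own accounting. Every per-component formula is inferred from the torchinfo rows for hidden_channels=128. It assumes a lifting and projection MLP width of 2 × hidden_channels, a 0.5 expansion in the per-block channel MLP, and a bias-free 1×1 skip. It also assumes per-channel soft gating and a spectral weight with modes_x × (modes_y // 2 + 1) complex entries per channel pair, plus a per-channel bias. The function name fno_param_count is hypothetical.

```python
def fno_param_count(in_channels, out_channels, hidden, n_layers, modes_x, modes_y):
    """Hypothetical closed-form count; formulas inferred from the torchinfo rows."""
    # Assumption: lifting and projection MLPs use a 2 * hidden intermediate width.
    width = 2 * hidden
    lifting = (in_channels * width + width) + (width * hidden + hidden)
    projection = (hidden * width + width) + (width * out_channels + out_channels)

    skip = hidden * hidden          # Flattened1dConv (1x1, no bias)
    gating = hidden                 # SoftGating: one weight per channel
    # SpectralConv: complex weights over modes_x x (modes_y // 2 + 1) retained
    # Fourier modes, plus a per-channel bias. numel() counts a complex entry once.
    spectral = hidden * hidden * modes_x * (modes_y // 2 + 1) + hidden
    mlp_width = hidden // 2         # channel MLP with 0.5 expansion
    channel_mlp = (hidden * mlp_width + mlp_width) + (mlp_width * hidden + hidden)
    per_block = skip + gating + spectral + channel_mlp

    return lifting + n_layers * per_block + projection


print(f"{fno_param_count(3, 3, 128, 4, 16, 16):,}")  # 9,637,763
print(f"{fno_param_count(3, 3, 180, 4, 16, 16):,}")  # 19,056,783
```

Under these assumptions the formula matches the PyTorch count for hidden_channels=128 and the 19,056,783 quoted above for 180. The spectral term dominates and grows with hidden_channels squared, which is why a moderate width increase roughly doubles the total.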

===================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
===================================================================================================================
FNO                                      [4, 3, 64, 64]            [4, 3, 64, 64]            --
├─ChannelMLP: 1-1                        [4, 3, 64, 64]            [4, 128, 64, 64]          --
│    └─ModuleList: 2-1                   --                        --                        --
│    │    └─Conv1d: 3-1                  [4, 3, 4096]              [4, 256, 4096]            1,024
│    │    └─Conv1d: 3-2                  [4, 256, 4096]            [4, 128, 4096]            32,896
├─FNOBlocks: 1-2                         [4, 128, 64, 64]          [4, 128, 64, 64]          7,177,536
│    └─ModuleList: 2-14                  --                        --                        (recursive)
│    │    └─Flattened1dConv: 3-3         [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─Conv1d: 4-1             [4, 128, 4096]            [4, 128, 4096]            16,384
│    └─ModuleList: 2-15                  --                        --                        (recursive)
│    │    └─SoftGating: 3-4              [4, 128, 64, 64]          [4, 128, 64, 64]          128
│    └─ModuleList: 2-16                  --                        --                        (recursive)
│    │    └─SpectralConv: 3-5            [4, 128, 64, 64]          [4, 128, 64, 64]          2,359,424
│    └─ModuleList: 2-17                  --                        --                        (recursive)
│    │    └─ChannelMLP: 3-6              [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─ModuleList: 4-2         --                        --                        16,576
├─FNOBlocks: 1-3                         [4, 128, 64, 64]          [4, 128, 64, 64]          (recursive)
│    └─ModuleList: 2-14                  --                        --                        (recursive)
│    │    └─Flattened1dConv: 3-7         [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─Conv1d: 4-3             [4, 128, 4096]            [4, 128, 4096]            16,384
│    └─ModuleList: 2-15                  --                        --                        (recursive)
│    │    └─SoftGating: 3-8              [4, 128, 64, 64]          [4, 128, 64, 64]          128
│    └─ModuleList: 2-16                  --                        --                        (recursive)
│    │    └─SpectralConv: 3-9            [4, 128, 64, 64]          [4, 128, 64, 64]          2,359,424
│    └─ModuleList: 2-17                  --                        --                        (recursive)
│    │    └─ChannelMLP: 3-10             [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─ModuleList: 4-4         --                        --                        16,576
├─FNOBlocks: 1-4                         [4, 128, 64, 64]          [4, 128, 64, 64]          (recursive)
│    └─ModuleList: 2-14                  --                        --                        (recursive)
│    │    └─Flattened1dConv: 3-11        [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─Conv1d: 4-5             [4, 128, 4096]            [4, 128, 4096]            16,384
│    └─ModuleList: 2-15                  --                        --                        (recursive)
│    │    └─SoftGating: 3-12             [4, 128, 64, 64]          [4, 128, 64, 64]          128
│    └─ModuleList: 2-16                  --                        --                        (recursive)
│    │    └─SpectralConv: 3-13           [4, 128, 64, 64]          [4, 128, 64, 64]          2,359,424
│    └─ModuleList: 2-17                  --                        --                        (recursive)
│    │    └─ChannelMLP: 3-14             [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─ModuleList: 4-6         --                        --                        16,576
├─FNOBlocks: 1-5                         [4, 128, 64, 64]          [4, 128, 64, 64]          (recursive)
│    └─ModuleList: 2-14                  --                        --                        (recursive)
│    │    └─Flattened1dConv: 3-15        [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─Conv1d: 4-7             [4, 128, 4096]            [4, 128, 4096]            16,384
│    └─ModuleList: 2-15                  --                        --                        (recursive)
│    │    └─SoftGating: 3-16             [4, 128, 64, 64]          [4, 128, 64, 64]          128
│    └─ModuleList: 2-16                  --                        --                        (recursive)
│    │    └─SpectralConv: 3-17           [4, 128, 64, 64]          [4, 128, 64, 64]          2,359,424
│    └─ModuleList: 2-17                  --                        --                        (recursive)
│    │    └─ChannelMLP: 3-18             [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─ModuleList: 4-8         --                        --                        16,576
├─ChannelMLP: 1-6                        [4, 128, 64, 64]          [4, 3, 64, 64]            --
│    └─ModuleList: 2-18                  --                        --                        --
│    │    └─Conv1d: 3-19                 [4, 128, 4096]            [4, 256, 4096]            33,024
│    │    └─Conv1d: 3-20                 [4, 256, 4096]            [4, 3, 4096]              771
===================================================================================================================
Total params: 16,815,299
Trainable params: 16,815,299
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 3.27
===================================================================================================================
Input size (MB): 0.20
Forward/backward pass size (MB): 319.16
Params size (MB): 0.80
Estimated Total Size (MB): 320.16
===================================================================================================================
Torch counting:
        9,637,763
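The two totals can be reconciled using only the numbers in the table. This plain-Python arithmetic uses values copied from the rows above. It shows that torchinfo's total equals the PyTorch count plus the 7,177,536 on the "FNOBlocks: 1-2" row, and that this extra amount is exactly three blocks' worth of parameters. That pattern is consistent with block parameters being counted twice. How torchinfo produces the extra row internally is not established here.

```python
# Per-row counts copied from the torchinfo table (hidden_channels=128).
lifting = 1_024 + 32_896
projection = 33_024 + 771
per_block = 16_384 + 128 + 2_359_424 + 16_576  # skip + gating + spectral + channel MLP
fnoblocks_row = 7_177_536                       # value on the "FNOBlocks: 1-2" row
torchinfo_total = 16_815_299

# The four blocks' leaf rows already cover every block parameter.
torch_count = lifting + 4 * per_block + projection
print(f"{torch_count:,}")                       # 9,637,763

# Adding the FNOBlocks row on top reproduces torchinfo's total.
rows_sum = torch_count + fnoblocks_row
print(rows_sum == torchinfo_total)              # True
print(fnoblocks_row == 3 * per_block)           # True
```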

Code to reproduce the issue:

import torch
from torchinfo import summary

from neuralop.models import FNO

in_channels = 3
out_channels = 3
modes = 16
hidden_channels = 128
n_layers = 4
n_spatial_dims = 2

batch_size = 4
height, width = 64, 64

device = torch.device('cpu')

dummy_input = torch.randn(batch_size, in_channels, height, width, device=device)

model = FNO(
    n_modes=(modes, modes),
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_channels=hidden_channels,
    n_layers=n_layers,
    positional_embedding=None,
).to(device)

summary(
    model,
    input_size=(batch_size, in_channels, height, width),
    depth=4,
    col_names=["input_size", "output_size", "num_params"],
    device=device,
)

num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'{num_params:,}')
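A cross-check that does not depend on torchinfo is to tally parameters per direct child module with named_children() and compare them against the depth-1 rows of the summary. The sketch uses a hypothetical ToyModel/ToyBlocks pair that mimics the pattern the table suggests for FNO. That pattern is one container module holding a ModuleList entry per layer, called once per layer with an index. This is not neuralop code, and count_params and params_per_child are illustrative helper names.

```python
import torch
from torch import nn


def count_params(module: nn.Module, trainable_only: bool = True) -> int:
    # parameters() yields each registered Parameter once, even if the owning
    # submodule is called several times in forward().
    return sum(p.numel() for p in module.parameters() if p.requires_grad or not trainable_only)


def params_per_child(model: nn.Module) -> dict:
    # One total per direct child, for comparison with depth-1 summary rows.
    return {name: count_params(child, trainable_only=False) for name, child in model.named_children()}


class ToyBlocks(nn.Module):
    # Hypothetical stand-in for a block container: one ModuleList entry per
    # layer, with the layer selected by index at call time.
    def __init__(self, channels: int, n_layers: int):
        super().__init__()
        self.convs = nn.ModuleList([nn.Conv1d(channels, channels, 1) for _ in range(n_layers)])
        self.skips = nn.ModuleList([nn.Conv1d(channels, channels, 1, bias=False) for _ in range(n_layers)])

    def forward(self, x: torch.Tensor, index: int) -> torch.Tensor:
        return self.convs[index](x) + self.skips[index](x)


class ToyModel(nn.Module):
    def __init__(self, channels: int = 8, n_layers: int = 4):
        super().__init__()
        self.n_layers = n_layers
        self.blocks = ToyBlocks(channels, n_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The same `blocks` module is called once per layer.
        for i in range(self.n_layers):
            x = self.blocks(x, i)
        return x


model = ToyModel()
out = model(torch.randn(2, 8, 16))
# Per layer: 8*8 + 8 (conv with bias) + 8*8 (skip) = 136; four layers = 544.
print(count_params(model), params_per_child(model))  # 544 {'blocks': 544}
```

Because parameters() deduplicates registered parameters, calling the same container four times in forward() does not inflate the count. Applied to the FNO instance in the reproduction script, params_per_child should give one total per top-level submodule, which can then be compared row by row with the summary.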

Version

1.1

Environment

dependencies = [
    "torch >= 2.0.1",
    "torchvision >= 0.15.2",
    "numpy < 2.0.0",
    "transformers == 4.55.0",
    "matplotlib",
    "accelerate>=0.32.0",
    "wandb==0.22.2",
    "h5py",
    "pandas",
    "pyyaml",
    "netcdf4>=1.7.2",
    "einops>=0.8.1",
    "scipy>=1.16.1",
    "pytorch-lightning>=2.3.3",
    "ninja>=1.13.0",
    "ipykernel>=6.30.1",
    "seaborn>=0.13.2",
    "huggingface-hub[cli]>=0.34.4",
    "xarray>=2025.8.0",
    "torchinfo>=1.8.0",
    "the-well>=1.1.0",
    "hydra-core>=1.3.2",
    "denoising-diffusion-pytorch>=2.2.5",
    "timm>=1.0.20",
    "neuraloperator>=2.0.0",
    "triton>=3.4.0",
    "ruff>=0.14.5",
]

Context for the issue:

No response

Metadata


Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
