Labels: bug (Something isn't working)
Description
Describe the issue:
I think torchinfo miscounts the parameters because of how the ModuleLists inside FNOBlocks are handled. Counting with plain PyTorch (a simple loop over model.parameters()) gives far fewer parameters, and the difference (16,815,299 - 9,637,763 = 7,177,536, about 7.18 million) is exactly the number assigned to the first FNOBlocks row (FNOBlocks: 1-2). On my side, a simple workaround is to increase hidden_channels to 180, which gives 19,056,783 total parameters.
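Both totals can be recomputed by hand from the per-layer counts in the summary below. The following Python sketch does that. The layer widths it uses are assumptions inferred from the table, not taken from the neuraloperator source: lifting and projection MLPs at 2 × hidden_channels, the in-block ChannelMLP at hidden_channels // 2, a bias-free skip Conv1d, and modes × (modes // 2 + 1) spectral weight entries per channel pair. `fno_param_count` is a hypothetical helper.

```python
def fno_param_count(hidden: int, in_ch: int = 3, out_ch: int = 3,
                    modes: int = 16, n_layers: int = 4) -> int:
    """Recompute the FNO parameter count from per-layer sizes.

    All layer sizes are inferred from the torchinfo table (assumption),
    not read from neuralop's implementation.
    """
    def conv1d(c_in: int, c_out: int, bias: bool = True) -> int:
        # 1x1 Conv1d: c_out * c_in weights, plus c_out bias terms if present
        return c_in * c_out + (c_out if bias else 0)

    lift_width = 2 * hidden                       # 256 when hidden=128
    proj_width = 2 * hidden                       # 256 when hidden=128
    mlp_width = hidden // 2                       # 64 when hidden=128
    spectral_entries = modes * (modes // 2 + 1)   # 16 * 9 = 144 when modes=16

    lifting = conv1d(in_ch, lift_width) + conv1d(lift_width, hidden)
    block = (
        conv1d(hidden, hidden, bias=False)              # Flattened1dConv skip
        + hidden                                        # SoftGating
        + hidden * hidden * spectral_entries + hidden   # SpectralConv weight + bias
        + conv1d(hidden, mlp_width) + conv1d(mlp_width, hidden)  # in-block ChannelMLP
    )
    projection = conv1d(hidden, proj_width) + conv1d(proj_width, out_ch)
    return lifting + n_layers * block + projection


print(f"{fno_param_count(128):,}")  # → 9,637,763
print(f"{fno_param_count(180):,}")  # → 19,056,783
```

With hidden_channels=128 this reproduces the plain PyTorch count rather than the torchinfo total, which is consistent with torchinfo being the one that overcounts.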
===================================================================================================================
Layer (type:depth-idx)                  Input Shape         Output Shape        Param #
===================================================================================================================
FNO                                     [4, 3, 64, 64]      [4, 3, 64, 64]      --
├─ChannelMLP: 1-1                       [4, 3, 64, 64]      [4, 128, 64, 64]    --
│    └─ModuleList: 2-1                  --                  --                  --
│    │    └─Conv1d: 3-1                 [4, 3, 4096]        [4, 256, 4096]      1,024
│    │    └─Conv1d: 3-2                 [4, 256, 4096]      [4, 128, 4096]      32,896
├─FNOBlocks: 1-2                        [4, 128, 64, 64]    [4, 128, 64, 64]    7,177,536
│    └─ModuleList: 2-14                 --                  --                  (recursive)
│    │    └─Flattened1dConv: 3-3        [4, 128, 64, 64]    [4, 128, 64, 64]    --
│    │    │    └─Conv1d: 4-1            [4, 128, 4096]      [4, 128, 4096]      16,384
│    └─ModuleList: 2-15                 --                  --                  (recursive)
│    │    └─SoftGating: 3-4             [4, 128, 64, 64]    [4, 128, 64, 64]    128
│    └─ModuleList: 2-16                 --                  --                  (recursive)
│    │    └─SpectralConv: 3-5           [4, 128, 64, 64]    [4, 128, 64, 64]    2,359,424
│    └─ModuleList: 2-17                 --                  --                  (recursive)
│    │    └─ChannelMLP: 3-6             [4, 128, 64, 64]    [4, 128, 64, 64]    --
│    │    │    └─ModuleList: 4-2        --                  --                  16,576
├─FNOBlocks: 1-3                        [4, 128, 64, 64]    [4, 128, 64, 64]    (recursive)
│    └─ModuleList: 2-14                 --                  --                  (recursive)
│    │    └─Flattened1dConv: 3-7        [4, 128, 64, 64]    [4, 128, 64, 64]    --
│    │    │    └─Conv1d: 4-3            [4, 128, 4096]      [4, 128, 4096]      16,384
│    └─ModuleList: 2-15                 --                  --                  (recursive)
│    │    └─SoftGating: 3-8             [4, 128, 64, 64]    [4, 128, 64, 64]    128
│    └─ModuleList: 2-16                 --                  --                  (recursive)
│    │    └─SpectralConv: 3-9           [4, 128, 64, 64]    [4, 128, 64, 64]    2,359,424
│    └─ModuleList: 2-17                 --                  --                  (recursive)
│    │    └─ChannelMLP: 3-10            [4, 128, 64, 64]    [4, 128, 64, 64]    --
│    │    │    └─ModuleList: 4-4        --                  --                  16,576
├─FNOBlocks: 1-4                        [4, 128, 64, 64]    [4, 128, 64, 64]    (recursive)
│    └─ModuleList: 2-14                 --                  --                  (recursive)
│    │    └─Flattened1dConv: 3-11       [4, 128, 64, 64]    [4, 128, 64, 64]    --
│    │    │    └─Conv1d: 4-5            [4, 128, 4096]      [4, 128, 4096]      16,384
│    └─ModuleList: 2-15                 --                  --                  (recursive)
│    │    └─SoftGating: 3-12            [4, 128, 64, 64]    [4, 128, 64, 64]    128
│    └─ModuleList: 2-16                 --                  --                  (recursive)
│    │    └─SpectralConv: 3-13          [4, 128, 64, 64]    [4, 128, 64, 64]    2,359,424
│    └─ModuleList: 2-17                 --                  --                  (recursive)
│    │    └─ChannelMLP: 3-14            [4, 128, 64, 64]    [4, 128, 64, 64]    --
│    │    │    └─ModuleList: 4-6        --                  --                  16,576
├─FNOBlocks: 1-5                        [4, 128, 64, 64]    [4, 128, 64, 64]    (recursive)
│    └─ModuleList: 2-14                 --                  --                  (recursive)
│    │    └─Flattened1dConv: 3-15       [4, 128, 64, 64]    [4, 128, 64, 64]    --
│    │    │    └─Conv1d: 4-7            [4, 128, 4096]      [4, 128, 4096]      16,384
│    └─ModuleList: 2-15                 --                  --                  (recursive)
│    │    └─SoftGating: 3-16            [4, 128, 64, 64]    [4, 128, 64, 64]    128
│    └─ModuleList: 2-16                 --                  --                  (recursive)
│    │    └─SpectralConv: 3-17          [4, 128, 64, 64]    [4, 128, 64, 64]    2,359,424
│    └─ModuleList: 2-17                 --                  --                  (recursive)
│    │    └─ChannelMLP: 3-18            [4, 128, 64, 64]    [4, 128, 64, 64]    --
│    │    │    └─ModuleList: 4-8        --                  --                  16,576
├─ChannelMLP: 1-6                       [4, 128, 64, 64]    [4, 3, 64, 64]      --
│    └─ModuleList: 2-18                 --                  --                  --
│    │    └─Conv1d: 3-19                [4, 128, 4096]      [4, 256, 4096]      33,024
│    │    └─Conv1d: 3-20                [4, 256, 4096]      [4, 3, 4096]        771
===================================================================================================================
Total params: 16,815,299
Trainable params: 16,815,299
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 3.27
===================================================================================================================
Input size (MB): 0.20
Forward/backward pass size (MB): 319.16
Params size (MB): 0.80
Estimated Total Size (MB): 320.16
===================================================================================================================
Plain PyTorch count (sum of param.numel() over model.parameters()):
9,637,763
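Summing the rows of the table reconciles the two numbers. The plain PyTorch count is lifting + 4 FNO layers + projection, while torchinfo's total additionally includes the 7,177,536 printed on the first FNOBlocks row, which is exactly three layers' worth of parameters (3 × 2,392,512). My reading (a hypothesis, not checked against torchinfo's source) is that the parameters of the three layers not displayed under the first FNOBlocks call get attributed to that row, and are then counted again on the later recursive FNOBlocks rows. The arithmetic, using only numbers copied from the table:

```python
# Per-row parameter counts copied from the torchinfo table (hidden_channels=128).
lifting = 1_024 + 32_896                    # ChannelMLP: 1-1 (two Conv1d rows)
block = 16_384 + 128 + 2_359_424 + 16_576   # one FNO layer: skip conv, gating, spectral conv, MLP
projection = 33_024 + 771                   # ChannelMLP: 1-6 (two Conv1d rows)
fnoblocks_row = 7_177_536                   # value printed on the "FNOBlocks: 1-2" row

true_total = lifting + 4 * block + projection
torchinfo_total = lifting + fnoblocks_row + 4 * block + projection

print(f"{true_total:,}")           # → 9,637,763 (matches the plain PyTorch count)
print(f"{torchinfo_total:,}")      # → 16,815,299 (matches torchinfo's "Total params")
print(fnoblocks_row == 3 * block)  # → True
```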
Code to reproduce the issue:
import torch
from torchinfo import summary
from neuralop.models import FNO
in_channels = 3
out_channels = 3
modes = 16
hidden_channels = 128
n_layers = 4
n_spatial_dims = 2
batch_size = 4
height, width = 64, 64
device = torch.device('cpu')
dummy_input = torch.randn(batch_size, in_channels, height, width, device=device)
model = FNO(
    n_modes=(modes, modes),
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_channels=hidden_channels,
    n_layers=n_layers,
    positional_embedding=None,
).to(device)
summary(
    model,
    input_size=(batch_size, in_channels, height, width),
    depth=4,
    col_names=["input_size", "output_size", "num_params"],
    device=device,
)
# Reference count taken directly from PyTorch, for comparison with torchinfo
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'{num_params:,}')
Version
1.1
Environment
dependencies = [
"torch >= 2.0.1",
"torchvision >= 0.15.2",
"numpy < 2.0.0",
"transformers == 4.55.0",
"matplotlib",
"accelerate>=0.32.0",
"wandb==0.22.2",
"h5py",
"pandas",
"pyyaml",
"netcdf4>=1.7.2",
"einops>=0.8.1",
"scipy>=1.16.1",
"pytorch-lightning>=2.3.3",
"ninja>=1.13.0",
"ipykernel>=6.30.1",
"seaborn>=0.13.2",
"huggingface-hub[cli]>=0.34.4",
"xarray>=2025.8.0",
"torchinfo>=1.8.0",
"the-well>=1.1.0",
"hydra-core>=1.3.2",
"denoising-diffusion-pytorch>=2.2.5",
"timm>=1.0.20",
"neuraloperator>=2.0.0",
"triton>=3.4.0",
"ruff>=0.14.5",
]
Context for the issue:
No response