1 change: 1 addition & 0 deletions .gitignore
@@ -29,6 +29,7 @@ htmlcov
 .DS_Store
 *.swp
 .envrc
+uv.lock
 
 checkpoints/
 mlflow_tmp/
4 changes: 4 additions & 0 deletions aurora/model/aurora.py
@@ -92,6 +92,7 @@ def __init__(
         positive_atmos_vars: tuple[str, ...] = (),
         clamp_at_first_step: bool = False,
         simulate_indexing_bug: bool = False,
+        use_chunked_checkpointing: bool = False,
     ) -> None:
         """Construct an instance of the model.
 
@@ -175,6 +176,7 @@ def __init__(
             simulate_indexing_bug (bool, optional): Simulate an indexing bug that's present for the
                 air pollution version of Aurora. This is necessary to obtain numerical equivalence
                 to the original implementation. Defaults to `False`.
+            use_chunked_checkpointing (bool, optional): Enable chunked checkpointing. Defaults to `False`.
         """
         super().__init__()
         self.surf_vars = surf_vars
@@ -215,6 +217,7 @@ def __init__(
             dynamic_vars=dynamic_vars,
             atmos_static_vars=atmos_static_vars,
             simulate_indexing_bug=simulate_indexing_bug,
+            use_chunked_checkpointing=use_chunked_checkpointing,
         )
 
         self.backbone = Swin3DTransformerBackbone(
@@ -249,6 +252,7 @@
             level_condition=level_condition,
             separate_perceiver=separate_perceiver,
             modulation_heads=modulation_heads,
+            use_chunked_checkpointing=use_chunked_checkpointing,
         )
 
         if bf16_mode and not autocast:
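Since the new flag is only surfaced through the top-level constructor and then plumbed through the encoder and decoder, callers need no other changes. A minimal usage sketch, not part of the diff, assuming `Aurora` is re-exported from the top-level `aurora` package as in the released versions:

    from aurora import Aurora

    # Hypothetical usage: enable chunked checkpointing at construction time.
    # The flag is forwarded to the encoder's and decoder's Perceiver modules
    # exactly as in the hunks above; no other call sites change.
    model = Aurora(use_chunked_checkpointing=True)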
4 changes: 4 additions & 0 deletions aurora/model/decoder.py
@@ -42,6 +42,7 @@ def __init__(
         level_condition: Optional[tuple[int | float, ...]] = None,
         separate_perceiver: tuple[str, ...] = (),
         modulation_heads: tuple[str, ...] = (),
+        use_chunked_checkpointing: bool = False,
     ) -> None:
         """Initialise.
 
@@ -72,6 +73,7 @@ def __init__(
             modulation_heads (tuple[str, ...], optional): Names of every variable for which to
                 enable an additional head, the so-called modulation head, that can be used to
                 predict the difference.
+            use_chunked_checkpointing (bool, optional): Enable chunked checkpointing. Defaults to `False`.
         """
         super().__init__()
 
@@ -100,6 +102,7 @@ def __init__(
             drop=drop_rate,
             residual_latent=True,
             ln_eps=perceiver_ln_eps,
+            use_chunked_checkpointing=use_chunked_checkpointing,
         )
         if self.separate_perceiver:
             self.level_decoder_alternate = PerceiverResampler(
@@ -112,6 +115,7 @@
                 drop=drop_rate,
                 residual_latent=True,
                 ln_eps=perceiver_ln_eps,
+                use_chunked_checkpointing=use_chunked_checkpointing,
             )
 
         self.surf_heads = nn.ParameterDict(
3 changes: 3 additions & 0 deletions aurora/model/encoder.py
@@ -51,6 +51,7 @@ def __init__(
         dynamic_vars: bool = False,
         atmos_static_vars: bool = False,
         simulate_indexing_bug: bool = False,
+        use_chunked_checkpointing: bool = False,
     ) -> None:
         """Initialise.
 
@@ -87,6 +88,7 @@ def __init__(
             simulate_indexing_bug (bool, optional): Simulate an indexing bug that's present for the
                 air pollution version of Aurora. This is necessary to obtain numerical equivalence
                 to the original implementation. Defaults to `False`.
+            use_chunked_checkpointing (bool, optional): Enable chunked checkpointing. Defaults to `False`.
         """
         super().__init__()
 
@@ -156,6 +158,7 @@ def __init__(
             mlp_ratio=mlp_ratio,
             ln_eps=perceiver_ln_eps,
             ln_k_q=stabilise_level_agg,
+            use_chunked_checkpointing=use_chunked_checkpointing,
         )
 
         # Drop patches after encoding.
8 changes: 8 additions & 0 deletions aurora/model/perceiver.py
@@ -60,6 +60,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
+from chunkcheck import chunk_and_checkpoint
 from einops import rearrange
 
 __all__ = ["MLP", "PerceiverResampler"]

@@ -167,6 +168,7 @@ def __init__(
         residual_latent: bool = True,
         ln_eps: float = 1e-5,
         ln_k_q: bool = False,
+        use_chunked_checkpointing: bool = False,
     ) -> None:
         """Initialise.
 
@@ -190,6 +192,7 @@
 
         self.residual_latent = residual_latent
         self.layers = nn.ModuleList([])
+        self.use_chunked_checkpointing = use_chunked_checkpointing
         mlp_hidden_dim = int(latent_dim * mlp_ratio)
         for i in range(depth):
             self.layers.append(
@@ -219,6 +222,11 @@ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Latent features of shape `(B, L1, D1)`.
         """
+        if self.use_chunked_checkpointing:
+            return chunk_and_checkpoint(self._forward, latents, x, chunk_size=2025)
+        return self._forward(latents, x)
+
+    def _forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         for attn, ff, ln1, ln2 in self.layers:
             # We use post-res-norm like in Swin v2 and most Transformer architectures these days.
             # This empirically works better than the pre-norm used in the original Perceiver.
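The `chunkcheck` package itself is not shown in this diff, but the call in `forward` suggests its contract: split the inputs into fixed-size chunks along the batch dimension and run each chunk under gradient checkpointing, trading backward-pass recompute for lower peak activation memory. A rough sketch of that presumed behaviour, purely as an assumption about what `chunk_and_checkpoint` does and not its actual implementation:

    import torch
    from torch.utils.checkpoint import checkpoint

    def chunk_and_checkpoint(fn, *tensors: torch.Tensor, chunk_size: int) -> torch.Tensor:
        # Assumed semantics: apply `fn` to `chunk_size`-sized slices of the
        # batch dimension, checkpointing each slice so its activations are
        # recomputed during the backward pass instead of stored.
        chunks = zip(*(t.split(chunk_size, dim=0) for t in tensors))
        outs = [checkpoint(fn, *c, use_reentrant=False) for c in chunks]
        return torch.cat(outs, dim=0)

Under this reading, the hard-coded `chunk_size=2025` bounds how many batch elements pass through the resampler per checkpoint segment.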
1 change: 1 addition & 0 deletions aurora/model/swin3d.py
@@ -14,6 +14,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from chunkcheck import chunk_and_checkpoint
 from einops import rearrange
 from timm.layers import DropPath, to_3tuple
 
1 change: 1 addition & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
     "xarray",
     "netcdf4",
     "azure-storage-blob",
+    "chunkcheck>=0.1.1",
 ]
 
 [project.optional-dependencies]