Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions aurora/model/aurora.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
drop_rate (float, optional): Drop-out rate.
drop_path (float, optional): Drop-path rate.
enc_depth (int, optional): Number of Perceiver blocks in the encoder.
dec_depth (int, optioanl): Number of Perceiver blocks in the decoder.
dec_depth (int, optional): Number of Perceiver blocks in the decoder.
dec_mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs in the
decoder. The embedding dimensionality here is different, which is why this is a
separate parameter.
Expand Down Expand Up @@ -266,7 +266,7 @@ def forward(self, batch: Batch) -> Batch:
"""Forward pass.

Args:
batch (:class:`Batch`): Batch to run the model on.
batch (:class:`aurora.Batch`): Batch to run the model on.

Returns:
:class:`Batch`: Prediction for the batch.
Expand Down Expand Up @@ -472,7 +472,7 @@ def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor])

If a checkpoint was trained with a larger `max_history_size` than the current model,
this function will assert fail to prevent loading the checkpoint. This is to
prevent loading a checkpoint which will likely cause the checkpoint to degrade is
prevent loading a checkpoint which will likely cause the checkpoint to degrade its
performance.

This implementation copies weights from the checkpoint to the model and fills zeros
Expand Down
2 changes: 1 addition & 1 deletion aurora/model/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def forward(

Args:
x (torch.Tensor): Backbone output of shape `(B, L, D)`.
batch (:class:`aurora.batch.Batch`): Batch to make predictions for.
batch (:class:`aurora.Batch`): Batch to make predictions for.
patch_res (tuple[int, int, int]): Patch resolution
lead_time (timedelta): Lead time.

Expand Down
2 changes: 1 addition & 1 deletion aurora/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:
"""Peform encoding.

Args:
batch (:class:`.Batch`): Batch to encode.
batch (:class:`aurora.Batch`): Batch to encode.
lead_time (timedelta): Lead time.

Returns:
Expand Down
4 changes: 2 additions & 2 deletions aurora/model/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
r: int = 4,
alpha: int = 1,
dropout: float = 0.0,
):
) -> None:
"""Initialise.

Args:
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
dropout: float = 0.0,
max_steps: int = 40,
mode: LoRAMode = "single",
):
) -> None:
"""Initialise.

Args:
Expand Down
11 changes: 7 additions & 4 deletions aurora/model/swin3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def forward(
mask (torch.Tensor, optional): Attention mask of floating points in the range
`[-inf, 0)` with shape of `(nW, ws, ws)`, where `nW` is the number of windows,
and `ws` is the window size (i.e. total tokens inside the window).
rollout_step (int, optional): Roll-out step. Defaults to `0`.

Returns:
torch.Tensor: Output of shape `(nW*B, N, C)`.
Expand Down Expand Up @@ -198,8 +199,8 @@ def window_partition_3d(x: torch.Tensor, ws: tuple[int, int, int]) -> torch.Tens
"""Partition into windows.

Args:
x: (torch.Tensor): Input tensor of shape `(B, C, H, W, D)`.
ws: (tuple[int, int, int]): A 3D window size `(Wc, Wh, Ww)`.
x (torch.Tensor): Input tensor of shape `(B, C, H, W, D)`.
ws (tuple[int, int, int]): A 3D window size `(Wc, Wh, Ww)`.

Returns:
torch.Tensor: Partitioning of shape `(num_windows*B, Wc, Wh, Ww, D)`.
Expand Down Expand Up @@ -318,7 +319,8 @@ def compute_3d_shifted_window_mask(
H (int): Height of the image.
W (int): Width of the image.
ws (tuple[int, int, int]): Window sizes of the form `(Wc, Wh, Ww)`.
ss (tuple[int, int, int]): Shift sizes of the form `(Sc, Sh, Sw)`
ss (tuple[int, int, int]): Shift sizes of the form `(Sc, Sh, Sw)`.
device (torch.device): Device of the mask.
dtype (torch.dtype, optional): Data type of the mask. Defaults to `torch.bfloat16`.
warped (bool): If `True`,assume that the left and right sides of the image are connected.
Defaults to `True`.
Expand Down Expand Up @@ -768,7 +770,8 @@ def __init__(
lora_mode: LoRAMode = "single",
use_lora: bool = False,
) -> None:
"""
"""Initialise.

Args:
embed_dim (int): Patch embedding dimension. Default to `96`.
encoder_depths (tuple[int, ...]): Number of blocks in each encoder layer. Defaults to
Expand Down
1 change: 1 addition & 0 deletions aurora/model/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def unpatchify(x: torch.Tensor, V: int, H: int, W: int, P: int) -> torch.Tensor:
V (int): Number of variables.
H (int): Number of latitudes.
W (int): Number of longitudes.
P (int): Patch size.

Returns:
torch.Tensor: Unpatchified representation of shape `(B, V, C, H, W)`.
Expand Down
6 changes: 3 additions & 3 deletions aurora/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def rollout(model: Aurora, batch: Batch, steps: int) -> Generator[Batch, None, N
"""Perform a roll-out to make long-term predictions.

Args:
model (:class:`aurora.model.aurora.Aurora`): The model to roll out.
batch (:class:`aurora.batch.Batch`): The batch to start the roll-out from.
model (:class:`aurora.Aurora`): The model to roll out.
batch (:class:`aurora.Batch`): The batch to start the roll-out from.
steps (int): The number of roll-out steps.

Yields:
:class:`aurora.batch.Batch`: The prediction after every step.
:class:`aurora.Batch`: The prediction after every step.
"""
# We will need to concatenate data, so ensure that everything is already of the right form.
batch = model.batch_transform_hook(batch) # This might modify the available variables.
Expand Down
2 changes: 1 addition & 1 deletion aurora/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def step(self, batch: Batch) -> None:
"""Track the next step.

Args:
batch (:class:`aurora.batch.Batch`): Prediction.
batch (:class:`aurora.Batch`): Prediction.
"""
# Check that there is only one prediction. We don't support batched tracking.
if len(batch.metadata.time) != 1:
Expand Down
Loading