Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 230 additions & 45 deletions .claude/skills/create-prognostic-wrapper/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -981,73 +981,258 @@ debug and fix the wrapper or test, then re-run.

---

## Step 12 — Provide Side-by-Side Comparison Scripts
## Step 12 — Provide E2E Validation Scripts and Comparison Plots

Present two scripts to the user:
Create three runnable scripts that validate the Earth2Studio
wrapper produces output equivalent to the model's native
inference pipeline. These scripts are intended to be
**attached to the PR** for reviewer verification and serve
as end-to-end integration tests.

### Reference script (without Earth2Studio)
### 12a. Earth2Studio reference script (`e2s_reference.py`)

Reconstruct a minimal inference script based on the original reference code:
A self-contained script that runs inference through
`run.deterministic` and writes output to a single file.

```python
# Reference inference (no Earth2Studio)
import torch
# ... original model imports ...
"""<ModelName> inference using the Earth2Studio run.deterministic API.

# Load model
model = OriginalModel.from_pretrained("path/to/checkpoint")
model.eval().cuda()
Loads the model, fetches initial conditions from the
appropriate data source, and runs a multi-step rollout.
Output is written to a single NetCDF file for comparison
with the native-framework reference output.
"""

# Prepare input
input_data = ... # Load/prepare input data per original repo instructions
import earth2studio.run as run
from earth2studio.data import ModelDataSource # e.g., SamudrACEData
from earth2studio.io import NetCDF4Backend
from earth2studio.models.px import ModelName

# Run inference
with torch.no_grad():
output = model(input_data)
# Configuration — match exactly what the native reference uses
SCENARIO = "..." # If applicable
N_STEPS = 40 # Total forward steps
TIME = ["2000-01-01T00:00:00"]
OUTPUT_FILE = "outputs/e2s_reference_output.nc"

print("Loading model via Earth2Studio...")
package = ModelName.load_default_package()
model = ModelName.load_model(package, ...)
data = ModelDataSource(...)

print(f"Running deterministic forecast ({N_STEPS} steps)...")
io = run.deterministic(
time=TIME,
nsteps=N_STEPS,
prognostic=model,
data=data,
io=NetCDF4Backend(
file_name=OUTPUT_FILE,
backend_kwargs={"mode": "w"},
),
)

print(f"Output shape: {output.shape}")
print(f"Results saved to {OUTPUT_FILE}")
```

### Earth2Studio equivalent
**Key requirements:**

- Use the **same initial conditions, forcing data, and
scenario** as the native reference script
- Use `run.deterministic` (not manual iteration) so the
full E2S pipeline (data fetch, coord alignment, IO
write) is exercised
- Output to a **NetCDF file** with all model variables
- Pin random seeds if the model has any stochastic
components (`torch.manual_seed(0)`,
`np.random.seed(0)`) and enable torch deterministic
mode when possible (`torch.use_deterministic_algorithms(True)`,
`torch.backends.cudnn.deterministic = True`,
`torch.backends.cudnn.benchmark = False`)

### 12b. Native-framework reference script (`native_reference.py`)

A self-contained script that runs inference using the
model's **original framework** (not Earth2Studio).

```python
# Earth2Studio inference
import torch
import numpy as np
from earth2studio.models.px import ModelName
from earth2studio.data import Random, fetch_data # or GFS, ERA5, etc.
"""<ModelName> inference using the native framework.

# Load model
model = ModelName.from_pretrained()
model = model.to("cuda")
Downloads model artifacts, then runs inference using the
original API / config-driven pipeline.
"""

# Prepare input via Earth2Studio data pipeline
time = np.array([np.datetime64("2024-01-01T00:00")])
input_coords = model.input_coords()
input_coords["time"] = time
ds = Random(input_coords) # Replace with real data source
x, coords = fetch_data(ds, time, input_coords["variable"], device="cuda")
import os
# ... native framework imports ...

# Single step
with torch.no_grad():
output, out_coords = model(x, coords)
# Download / resolve model artifacts
# (e.g., snapshot_download from HuggingFace)
repo_dir = download_model(...)

print(f"Output shape: {output.shape}")
print(f"Lead time: {out_coords['lead_time']}")
# Run inference using the native API
native_inference(
config="inference_config.yaml",
n_steps=...,
output_dir="outputs/native_reference_output",
)

# Multi-step forecast
iterator = model.create_iterator(x, coords)
for i, (step_x, step_coords) in enumerate(iterator):
print(f"Step {i}: lead_time={step_coords['lead_time']}")
if i >= 10:
break
print("Done.")
```

**Key requirements:**

- Use the **same checkpoint, IC files, and config** as
the E2S script
- Write output to a **known directory structure** with
predictable file names
- Pin random seeds if the model has any stochastic
components (`torch.manual_seed(0)`,
`np.random.seed(0)`) and enable torch deterministic
mode when possible (`torch.use_deterministic_algorithms(True)`,
`torch.backends.cudnn.deterministic = True`,
`torch.backends.cudnn.benchmark = False`)

### 12c. Comparison script (`compare.py`)

A script that loads both outputs, aligns them by time /
lead-time, and produces:

1. **Global-mean timeseries** — per variable, 3-panel
(reference, E2S, difference)
2. **Spatial maps** — at selected lead times, 3-panel
(reference, E2S, difference with RMSE/MAE stats)
3. **Summary statistics** — per-variable RMSE, MAE,
max absolute difference, and `np.allclose` result

```python
"""Compare <ModelName> forecasts from native and E2S pipelines.

Plots selected variables as timeseries (global mean) and
spatial maps at selected lead times, plus the difference.

Usage:
python compare.py
"""

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from earth2studio.lexicon.<model_lexicon> import ModelLexicon

# ---- Variable definitions ----
# Map E2S variable names to native-framework names, grouped
# by component if the model has multiple components.
VARIABLES = {
"var_key": {
"e2s_name": "...",
"native_name": "...", # Native framework name
"component": "...", # If applicable (e.g., "atmosphere", "ocean")
"label": "Description [units]",
"cmap": "RdYlBu_r",
"vmin": ..., "vmax": ...,
},
# ... more variables ...
}

# ---- Load datasets ----
ds_ref = xr.open_dataset("outputs/native_reference_output/predictions.nc")
ds_e2s = xr.open_dataset("outputs/e2s_reference_output.nc")

# ---- Helper functions ----

def get_ref_series(da):
"""Extract (n_steps, lat, lon) from native output."""
return da.isel(sample=0).values

def get_e2s_series(da, component="atmosphere"):
"""Extract (n_steps, lat, lon) from E2S output.

Skips the IC at lead_time=0 since the native output
typically contains only predictions.
"""
vals = da.isel(time=0).values[1:] # skip IC
# If component has different step cadence, sub-sample:
# if component == "ocean":
# indices = np.arange(N_INNER - 1, len(vals), N_INNER)
# vals = vals[indices]
return vals

# ---- Figure 1: Global-mean timeseries ----
# ... 3-panel plot per variable ...

# ---- Figure 2: Spatial maps at selected steps ----
# ... 3-panel (ref, E2S, diff) per step ...

# ---- Summary statistics ----
for var_key, var_info in VARIABLES.items():
# Compute RMSE, MAE, max abs diff, allclose
...

ds_ref.close()
ds_e2s.close()
```

### Important considerations for the comparison script

**Time / lead-time alignment:**

The E2S output from `run.deterministic` includes the
initial condition at `lead_time=0` as step 0, followed
by predictions at `lead_time=dt, 2*dt, ...`. The native
framework may write predictions only (no IC) starting
from the first forecast step. The comparison script must
account for this offset — typically by skipping the E2S
IC entry (`values[1:]`).

**Component-specific step cadences:**

If the model has components that update at different
rates (e.g., atmosphere every 6h, ocean every 5 days),
the native framework may write each component's output
at its own cadence while E2S writes all variables at
every atmosphere step. The comparison script must
sub-sample the E2S output at the slower component's
update boundaries to align with the native output.

**Variable name mapping:**

Use the model's lexicon class to map between E2S and
native variable names. Verify the mapping is correct by
checking a few variables at the IC step (where both
outputs should be identical).

**Grid alignment:**

Check whether both outputs use the same latitude
ordering (north-to-south vs south-to-north). If they
differ, flip one before computing differences.

### Deliverables for the PR

Attach the following to the pull request:

| File | Purpose |
|------|---------|
| `e2s_reference.py` | E2S inference script |
| `native_reference.py` (or `<framework>_reference.py`) | Native-framework inference script |
| `compare.py` | Comparison and plotting script |
| `compare_timeseries.png` | Global-mean timeseries plot |
| `compare_<var>.png` | Per-variable spatial map plots |

These do **not** get committed to the repository — they
are PR attachments for reviewer validation only.

### **[CONFIRM — Comparison Scripts]**

Ask the user to compare the two scripts and verify the
Earth2Studio version is functionally equivalent to the
reference.
Present the three scripts to the user and ask:

1. Do the E2S and native reference scripts use identical
inputs (IC, forcing, scenario, number of steps)?
2. Is the time alignment logic in `compare.py` correct
for this model's output structure?
3. Are the comparison plots showing near-zero differences
at the IC step and acceptable divergence at later
steps?

---

Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added GenCast 1 degree Mini model
- Added SamudrACE Coupled Climate Model

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/modules/models_px.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Thus are typically used to generate forecast predictions.
Pangu6
Pangu3
Persistence
SamudrACE
SFNO
StormCast
StormScopeGOES
Expand Down
1 change: 1 addition & 0 deletions earth2studio/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from .rand import Random, Random_FX, RandomDataFrame
from .rx import CosineSolarZenith, LandSeaMask, SurfaceGeoPotential
from .samudrace import SamudrACEData
from .time_window import TimeWindow
from .ufs import UFSObsConv, UFSObsSat
from .utils import datasource_to_file, fetch_data, fetch_dataframe, prep_data_array
Expand Down
5 changes: 3 additions & 2 deletions earth2studio/data/ace2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@
@check_optional_dependencies("ace2")
class ACE2ERA5Data:
"""ACE2-ERA5 data source providing forcing or initial-conditions data.
Files are downloaded on-demand and cached automatically. Data are served as-is; no transformations are applied,
with the exception of global mean CO2 concentration, which may be overridden by a user-supplied function.
Files are downloaded on-demand and cached automatically. Data are served as-is; no
transformations are applied, with the exception of global mean CO2 concentration,
which may be overridden by a user-supplied function.

Provides all input variables described in the ACE2-ERA5 paper.

Expand Down
Loading