#!/usr/bin/env python
"""
Fetch netCDF event files from a beaker dataset and generate histogram plots
comparing ensemble predictions against targets for each variable.

This will work for saved event outputs from `fme.downscaling.evaluator`
from a beaker experiment. It downloads the experiment files to a temporary
directory and then parses the filenames for <event_name>_YYYYMMDD.nc to look
for single-event outputs.

Usage:
    python plot_beaker_histograms.py <beaker_dataset_id> [--output-dir <dir>]

Requires:
    beaker CLI to be installed and authenticated (https://github.com/allenai/beaker).
"""

import argparse
import re
import subprocess
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Heavy optional dependency; only needed at runtime inside main() and
    # plot_histogram_lines(), so the pure helpers stay importable without it.
    import xarray as xr

# Matches <event_name>_YYYYMMDD.nc
_EVENT_FILE_RE = re.compile(r"(.+)_(\d{8})\.nc$")


def parse_args() -> argparse.Namespace:
    """Parse the beaker dataset id and output directory from the command line."""
    parser = argparse.ArgumentParser(
        description="Generate histogram plots from beaker dataset event files"
    )
    parser.add_argument(
        "beaker_dataset_id",
        help="The beaker dataset ID to fetch",
    )
    parser.add_argument(
        "--output-dir",
        default="./histogram_outputs",
        help="Output directory for figures (default: ./histogram_outputs)",
    )
    return parser.parse_args()


def fetch_beaker_dataset(dataset_id: str, target_dir: str) -> None:
    """Fetch a beaker dataset into ``target_dir`` via the beaker CLI.

    Raises:
        subprocess.CalledProcessError: if the ``beaker`` command exits non-zero.
    """
    subprocess.run(
        ["beaker", "dataset", "fetch", dataset_id, "--output", target_dir],
        check=True,
    )


def find_event_files(directory: str) -> dict[str, Path]:
    """Find netCDF files matching <event_name>_YYYYMMDD.nc, keyed by event name.

    NOTE: if the same event name occurs with several dates, the file with the
    lexicographically last name wins (files are visited in sorted order).
    """
    event_files: dict[str, Path] = {}
    for p in sorted(Path(directory).glob("*.nc")):
        matched = _EVENT_FILE_RE.match(p.name)
        if matched:
            event_files[matched.group(1)] = p
    return event_files


def detect_variable_pairs(ds: "xr.Dataset") -> list[str]:
    """Detect variables that have both _predicted and _target versions."""
    predicted = {
        v[: -len("_predicted")] for v in ds.data_vars if v.endswith("_predicted")
    }
    target = {v[: -len("_target")] for v in ds.data_vars if v.endswith("_target")}
    return sorted(predicted & target)


def plot_histogram_lines(
    ds: "xr.Dataset",
    key_prefix: str,
    title_prefix: str,
    save_path: Path,
) -> None:
    """
    Plot a histogram comparing ensemble predictions against the target.

    Args:
        ds: dataset containing ``{key_prefix}_predicted`` (3D: sample, lat,
            lon) and ``{key_prefix}_target`` variables.
        key_prefix: variable name prefix selecting the pair to plot.
        title_prefix: prefix for the figure title (e.g. the event name).
        save_path: destination PNG; parent directories are created as needed.

    Raises:
        ValueError: if the predicted data is not 3-dimensional.
    """
    # Local import so the pure helpers above import without matplotlib.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 6))

    target_data = ds[f"{key_prefix}_target"]
    predicted_data = ds[f"{key_prefix}_predicted"]

    sample_data = predicted_data.values
    if sample_data.ndim != 3:
        raise ValueError(
            f"Expected predicted data to be 3D (samples, lat, lon), "
            f"got shape {sample_data.shape}"
        )

    # Shared bin edges spanning the combined range of target and all samples.
    all_data = np.concatenate([target_data.values[None], sample_data])
    bins = 50
    data_min, data_max = np.min(all_data), np.max(all_data)
    bin_edges = np.linspace(data_min, data_max, bins + 1)
    bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])

    # Calculate the tail percentile values for each generated sample
    # and generate a 95% confidence interval across samples.
    lower_bounds = np.percentile(sample_data, 0.01, axis=(1, 2))
    upper_bounds = np.percentile(sample_data, 99.99, axis=(1, 2))
    lower_bound_2p5 = np.percentile(lower_bounds, 2.5)
    lower_bound_97p5 = np.percentile(lower_bounds, 97.5)
    upper_bound_2p5 = np.percentile(upper_bounds, 2.5)
    upper_bound_97p5 = np.percentile(upper_bounds, 97.5)

    # Target tail percentiles and histogram counts.
    target_lower_0p01_percentile = np.percentile(target_data.values, 0.01)
    target_upper_99p99_percentile = np.percentile(target_data.values, 99.99)
    counts, _ = np.histogram(target_data.values, bins=bin_edges)

    # Pre-compute y-axis limits so fill_betweenx spans exactly the plot area.
    ylim_min = 0.1
    ylim_max = 10 ** (np.log10(np.max(counts)) + 1)

    # One faint step line per predicted ensemble member.
    all_counts = []
    num_samples = sample_data.shape[0]
    for i in range(num_samples):
        sample_flat = sample_data[i].flatten()
        sample_counts, _ = np.histogram(sample_flat, bins=bin_edges)
        ax.step(
            bin_centers,
            sample_counts,
            where="mid",
            alpha=0.1,
            label="Samples" if i == 0 else None,
        )
        all_counts.append(sample_counts)
    all_counts = np.stack(all_counts)

    ax.step(
        bin_centers, counts, where="mid", color="black", linewidth=2, label="Target"
    )
    ax.axvline(
        target_lower_0p01_percentile,
        color="black",
        linestyle="dashed",
        linewidth=1,
        label="Target 0.01%",
    )
    ax.axvline(
        target_upper_99p99_percentile,
        color="black",
        linestyle="dashed",
        linewidth=1,
        label="Target 99.99%",
    )
    ax.fill_betweenx(
        [ylim_min, ylim_max],
        upper_bound_2p5,
        upper_bound_97p5,
        color="gray",
        alpha=0.2,
        label="Pred upper tail 95% CI",
    )
    ax.fill_betweenx(
        [ylim_min, ylim_max],
        lower_bound_2p5,
        lower_bound_97p5,
        color="gray",
        alpha=0.2,
        label="Pred lower tail 95% CI",
    )

    avg_counts = np.mean(all_counts, axis=0)
    ax.step(
        bin_centers,
        avg_counts,
        where="mid",
        color="C0",
        linewidth=2,
        label="Average Predicted",
    )
    var_label = key_prefix.replace("_", " ").title()
    ax.set_xlabel(var_label)
    ax.set_ylabel("Count")
    ax.set_yscale("log")
    ax.set_ylim(ylim_min, ylim_max)
    ax.grid(which="major", linestyle="--", linewidth=0.5, alpha=0.5)
    ax.set_title(f"{title_prefix} distribution: {var_label}")
    ax.legend()

    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def main() -> None:
    """Fetch the dataset, then plot histograms for every event/variable pair."""
    # Local import so the pure helpers above import without xarray.
    import xarray as xr

    args = parse_args()
    beaker_id = args.beaker_dataset_id
    output_dir = Path(args.output_dir)

    print(f"Fetching beaker dataset: {beaker_id}")

    with tempfile.TemporaryDirectory() as temp_dir:
        fetch_beaker_dataset(beaker_id, temp_dir)

        event_files = find_event_files(temp_dir)
        if not event_files:
            print(f"No event files found in dataset {beaker_id}")
            return

        print(f"Found {len(event_files)} event file(s)")

        for event_name, nc_file in event_files.items():
            output_event_dir = output_dir / beaker_id / event_name

            print(f"Processing: {nc_file.name} -> {output_event_dir}")

            # Context manager guarantees the netCDF handle is closed even on
            # the `continue` path or if plotting raises (the original called
            # ds.close() manually and skipped it on `continue`).
            with xr.open_dataset(nc_file) as ds:
                variables = detect_variable_pairs(ds)

                if not variables:
                    print(f"  No variable pairs found in {nc_file.name}")
                    continue

                for var_prefix in variables:
                    fig_path = output_event_dir / f"{var_prefix}.png"
                    plot_histogram_lines(ds, var_prefix, event_name, save_path=fig_path)
                    print(f"  Saved: {fig_path}")

    print("Done!")


if __name__ == "__main__":
    main()