examples/eval/evaluate.ipynb (2 additions & 2 deletions)
@@ -63,7 +63,7 @@
"# Loading paths to results. We're comparing to summed q_prime as it's a good indicator of if routing is working\n",
"summed_q_prime_path = Path(\"./summed_q_prime.zarr\") # To obtain this, please run scripts/summed_q_prime.py\n",
"predictions_path = Path(\n",
" \"./model_test.zarr\"\n",
" \"/projects/mhpi/tbindas/ddr/runs/ddr-v0.1.3-eval/2025-08-28_08-24-32/model_test.zarr\"\n",
") # To obtain this, please run scripts/test.py to evaluate a trained model\n",
"\n",
"ds_qp = xr.open_zarr(summed_q_prime_path)\n",
@@ -221,7 +221,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
"version": "3.11.5"
}
},
"nbformat": 4,
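For readers reproducing this comparison outside the notebook, the sketch below computes a per-gauge Nash-Sutcliffe efficiency between the two zarr stores. The variable name `"q"` and the `gauge`/`time` dimensions are assumptions for illustration; the actual names depend on how `scripts/summed_q_prime.py` and `scripts/test.py` write their outputs.

```python
# A minimal sketch, assuming both stores expose a "q" variable indexed by
# (gauge, time); the real variable/dimension names may differ.
from pathlib import Path

import xarray as xr

ds_qp = xr.open_zarr(Path("./summed_q_prime.zarr"))  # baseline: summed q'
ds_pred = xr.open_zarr(Path("./model_test.zarr"))    # routed predictions

obs = ds_qp["q"]    # hypothetical variable name
sim = ds_pred["q"]  # hypothetical variable name

# Nash-Sutcliffe efficiency per gauge: 1 - SSE / variance of the baseline.
sse = ((sim - obs) ** 2).sum(dim="time")
var = ((obs - obs.mean(dim="time")) ** 2).sum(dim="time")
nse = 1.0 - sse / var
print(nse.to_dataframe(name="nse").describe())
```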
examples/parameter_maps/plot_parameter_map.ipynb (3 additions & 1 deletion)
@@ -88,7 +88,9 @@
"outputs": [],
"source": [
"# Load pretrained model states\n",
"model_states = Path(\"./ddr_v0.1.0a2_trained_model_weights.pt\")\n",
"model_states = Path(\n",
" \"/projects/mhpi/tbindas/ddr/runs/ddr-v0.1.3-train/2025-08-27_19-26-09/saved_models/_ddr-v0.1.3-train_epoch_5_mb_42.pt\"\n",
")\n",
"\n",
"log.info(f\"Loading spatial_nn from checkpoint: {model_states.stem}\")\n",
"state = torch.load(model_states, map_location=device)\n",
pyproject.toml (1 addition & 2 deletions)
@@ -41,6 +41,7 @@ dependencies = [
"pykan==0.2.8",
"scikit-learn==1.7.0",
"scipy==1.16.0",
"torch==2.7.1",
"tqdm==4.67.1",
"xarray==2025.7.1",
"zarr==3.0.9",
@@ -52,7 +53,6 @@ docs = [
"sympy==1.14.0"
]
cuda = [
"torch==2.7.1",
"cupy-cuda12x==13.4.1",
]

@@ -63,7 +63,6 @@ tests = [
"ruff==0.12.2",
"nbstripout==0.8.1",
"boto3==1.39.14",

]
engine = [
"adbc-driver-sqlite==1.6.0",
shoutout.md (7 additions & 0 deletions)
@@ -0,0 +1,7 @@
### Shoutout File

This file shouts out work that served as inspiration for parts of this codebase.

#### Neural Hydrology
https://github.com/neuralhydrology/neuralhydrology
The dataset abstractions in neuralhydrology's `datasetzoo`, which hook many datasets up to one common interface, inspired the dataset abstractions in this code.
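For context, the pattern borrowed here is a base dataset class that concrete datasets subclass, plus a factory that resolves a config string to the right subclass. The sketch below illustrates that shape only; all names (`BaseDataset`, `CamelsUS`, `get_dataset`) are placeholders, not the actual API of neuralhydrology or this repo.

```python
# Illustrative sketch of the "datasetzoo" pattern (placeholder names only).
from abc import ABC, abstractmethod


class BaseDataset(ABC):
    """Shared interface every concrete dataset implements."""

    @abstractmethod
    def load_data(self) -> dict:
        """Parse this dataset's files into a common in-memory layout."""


class CamelsUS(BaseDataset):
    def load_data(self) -> dict:
        # Dataset-specific parsing lives in the subclass.
        return {"basin_ids": [], "forcings": {}}


def get_dataset(name: str) -> BaseDataset:
    """Resolve a config string to a dataset instance."""
    zoo = {"camels_us": CamelsUS}
    return zoo[name]()
```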
src/ddr/geodatazoo/Dates.py (128 additions & 0 deletions)
@@ -0,0 +1,128 @@
import logging
from datetime import datetime
from typing import Any

import numpy as np
import pandas as pd
import torch
from pydantic import BaseModel, ConfigDict, model_validator

log = logging.getLogger(__name__)


class Dates(BaseModel):
"""Dates class for handling time operations for training dMC models"""

model_config = ConfigDict(arbitrary_types_allowed=True)
daily_format: str = "%Y/%m/%d"
hourly_format: str = "%Y/%m/%d %H:%M:%S"
origin_start_date: str = "1980/01/01"
start_time: str
end_time: str
rho: int | None = None
batch_daily_time_range: pd.DatetimeIndex | None = pd.DatetimeIndex([], dtype="datetime64[ns]")
batch_hourly_time_range: pd.DatetimeIndex | None = pd.DatetimeIndex([], dtype="datetime64[ns]")
daily_time_range: pd.DatetimeIndex | None = pd.DatetimeIndex([], dtype="datetime64[ns]")
daily_indices: np.ndarray = np.empty(0)
hourly_time_range: pd.DatetimeIndex | None = pd.DatetimeIndex([], dtype="datetime64[ns]")
hourly_indices: torch.Tensor | None = torch.empty(0)
numerical_time_range: np.ndarray = np.empty(0)

def __init__(self, **kwargs):
super().__init__(
start_time=kwargs["start_time"],
end_time=kwargs["end_time"],
rho=kwargs["rho"],
)

@model_validator(mode="after")
@classmethod
def validate_dates(cls, dates: Any) -> Any:
"""Validates that the number of days you select is within the range of the start and end times"""
        rho = dates.rho
        if isinstance(rho, int) and rho > len(dates.daily_time_range):
            msg = "Rho needs to be smaller than the routed period between start and end times"
            log.error(msg)
            raise ValueError(msg)
return dates

def model_post_init(self, __context: Any) -> None:
"""Initializes the Dates object and time ranges"""
self.daily_time_range = pd.date_range(
datetime.strptime(self.start_time, self.daily_format),
datetime.strptime(self.end_time, self.daily_format),
freq="D",
inclusive="both",
)
self.hourly_time_range = pd.date_range(
start=self.daily_time_range[0],
end=self.daily_time_range[-1],
freq="h",
inclusive="left",
)
self.batch_daily_time_range = self.daily_time_range
self.set_batch_time(self.daily_time_range)

def set_batch_time(self, daily_time_range: pd.DatetimeIndex) -> None:
"""Sets the time range for the batch you're train/test/simulating

Parameters
----------
daily_time_range : pd.DatetimeIndex
The daily time range you want to select
"""
self.batch_hourly_time_range = pd.date_range(
start=daily_time_range[0],
end=daily_time_range[-1],
freq="h",
inclusive="left",
)
origin_start_date = datetime.strptime(self.origin_start_date, self.daily_format)
origin_base_start_time = int(
(daily_time_range[0].to_pydatetime() - origin_start_date).total_seconds() / 86400
)
origin_base_end_time = int(
(daily_time_range[-1].to_pydatetime() - origin_start_date).total_seconds() / 86400
)

# The indices for the dates in your selected routing time range
self.numerical_time_range = np.arange(origin_base_start_time, origin_base_end_time + 1, 1)

self._create_daily_indices()
self._create_hourly_indices()

def _create_hourly_indices(self) -> None:
common_elements = self.hourly_time_range.intersection(self.batch_hourly_time_range)
self.hourly_indices = torch.tensor([self.hourly_time_range.get_loc(time) for time in common_elements])

    def _create_daily_indices(self) -> None:
common_elements = self.daily_time_range.intersection(self.batch_daily_time_range)
self.daily_indices = np.array([self.daily_time_range.get_loc(time) for time in common_elements])

def calculate_time_period(self) -> None:
"""Calculates the time period for the dataset using rho"""
if self.rho is not None:
sample_size = len(self.daily_time_range)
random_start = torch.randint(low=0, high=sample_size - self.rho, size=(1, 1))[0][0].item()
self.batch_daily_time_range = self.daily_time_range[random_start : (random_start + self.rho)]
self.set_batch_time(self.batch_daily_time_range)

def set_date_range(self, chunk: np.ndarray) -> None:
"""Sets the date range for the dataset

Parameters
----------
chunk : np.ndarray
The chunk of the date range you want to select
"""
self.batch_daily_time_range = self.daily_time_range[chunk]
self.set_batch_time(self.batch_daily_time_range)

    def create_time_windows(self) -> np.ndarray:
        """Creates the time slices, or windows, for testing the model"""
        if self.rho is None:
            raise ValueError("rho must be set before creating time windows")
        num_pieces = self.daily_time_range.shape[0] // self.rho
        last_time = num_pieces * self.rho
        return np.reshape(self.daily_time_range[:last_time], (num_pieces, self.rho))
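A brief usage sketch of the class above (placeholder dates; the import path assumes this PR's `src/ddr/geodatazoo` layout):

```python
# Usage sketch for Dates (placeholder dates in the "%Y/%m/%d" daily format).
from ddr.geodatazoo.Dates import Dates

dates = Dates(start_time="1995/01/01", end_time="1996/12/31", rho=30)

# Sample a random rho-day batch window; hourly indices follow from it.
dates.calculate_time_period()
print(len(dates.batch_daily_time_range))   # 30 days
print(len(dates.batch_hourly_time_range))  # hourly steps over that window

# Non-overlapping rho-day windows for evaluation.
windows = dates.create_time_windows()
print(windows.shape)  # (len(daily_time_range) // rho, rho)
```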
src/ddr/geodatazoo/Gauges.py (54 additions & 0 deletions)
@@ -0,0 +1,54 @@
import csv
from pathlib import Path
from typing import Annotated

from pydantic import AfterValidator, BaseModel, ConfigDict, PositiveFloat


def zfill_usgs_id(STAID: str) -> str:
"""Ensures all USGS gauge strings that are filled to 8 digits

Parameters
----------
STAID: str
The USGS Station ID

Returns
-------
str
The eight-digit USGS Gauge ID
"""
return STAID.zfill(8)


class Gauge(BaseModel):
"""A pydantic object for managing properties for a Gauge and validating incoming CSV files"""

model_config = ConfigDict(extra="ignore")
STAID: Annotated[str, AfterValidator(zfill_usgs_id)]
DRAIN_SQKM: PositiveFloat


class GaugeSet(BaseModel):
"""A pydantic object for storing a list of Gauges"""

gauges: list[Gauge]


def validate_gages(file_path: Path) -> GaugeSet:
"""A function to read the training gauges file and validate based on a pydantic schema

Parameters
----------
file_path: Path
The path to the gauges csv file

Returns
-------
GaugeSet
A set of pydantic-validated gauges
"""
with file_path.open() as f:
reader = csv.DictReader(f)
gauges = [Gauge.model_validate(row) for row in reader]
return GaugeSet(gauges=gauges)
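A usage sketch for the helpers above, with a throwaway one-row CSV; the import path assumes the package layout added in this PR:

```python
# Usage sketch: build and validate a small gauges CSV (placeholder data).
from pathlib import Path

from ddr.geodatazoo.Gauges import validate_gages

csv_path = Path("gauges.csv")
csv_path.write_text("STAID,DRAIN_SQKM,NAME\n1013500,2252.7,Fish River\n")

gauge_set = validate_gages(csv_path)
print(gauge_set.gauges[0].STAID)       # "01013500", zero-filled to 8 digits
print(gauge_set.gauges[0].DRAIN_SQKM)  # 2252.7; the extra NAME column is ignored
```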
src/ddr/geodatazoo/__init__.py (empty file added)