diff --git a/pysteps/mongo/README.md b/pysteps/mongo/README.md
new file mode 100644
index 000000000..1c97b162a
--- /dev/null
+++ b/pysteps/mongo/README.md
@@ -0,0 +1,42 @@
+# mongo
+
+## Executable scripts
+
+### create_mongo_user.py
+
+This script is run by the database administrator to register a new user for the STEPS database.
+
+### delete_files.py
+
+Housekeeping utility to delete records from the database.
+
+### init_steps_db.py
+
+This script creates the STEPS database with the expected collections and indices.
+
+### load_config.py
+
+This script loads the JSON configuration file into the STEPS database.
+
+### write_nc_files.py
+
+Read the database and generate the netCDF files for exporting to users.
+
+### write_ensemble.py
+
+An example of a product that is supplied to an end user.
+
+## Modules
+
+### gridfs_io.py
+
+Functions to read and write the binary data to GridFS.
+
+### mongo_access.py
+
+Functions to read and write the metadata and parameters.
+
+### nc_utils.py
+
+Functions to read and write the rain fields as CF netCDF binaries.
+
diff --git a/pysteps/mongo/create_mongo_user.py b/pysteps/mongo/create_mongo_user.py
new file mode 100644
index 000000000..8f5bce982
--- /dev/null
+++ b/pysteps/mongo/create_mongo_user.py
@@ -0,0 +1,44 @@
+import secrets
+import string
+import os
+from pymongo import MongoClient
+from urllib.parse import quote_plus
+
+# === CONFIGURATION ===
+MONGO_HOST = "localhost"
+MONGO_PORT = 27017
+AUTH_DB = "admin"
+MONGO_ADMIN_USER = os.getenv("MONGO_USER")
+MONGO_ADMIN_PASS = os.getenv("MONGO_PWD")
+TARGET_DB = "STEPS"
+PWD_DEFAULT = "c-bandBox"
+
+# === FUNCTIONS ===
+def generate_password(length=16):
+    alphabet = string.ascii_letters + string.digits + "!@#$%^&*()-_=+"
+    return ''.join(secrets.choice(alphabet) for _ in range(length))
+
+def create_user(username, role="readWrite"):
+    # password = generate_password()  # random password generation currently disabled; shared default used instead
+    password = PWD_DEFAULT
+    client = MongoClient(f"mongodb://{quote_plus(MONGO_ADMIN_USER)}:{quote_plus(MONGO_ADMIN_PASS)}@{MONGO_HOST}:{MONGO_PORT}/?authSource={AUTH_DB}")
+    db = client[TARGET_DB]
+
+    try:
+        db.command("createUser", username, pwd=password, roles=[{"role": role, "db": TARGET_DB}])
+        print(f"\n✅ User '{username}' created with role '{role}'.\n")
+        print("Connection string:")
+        print(f"  mongodb://{quote_plus(username)}:{quote_plus(password)}@{MONGO_HOST}:{MONGO_PORT}/{TARGET_DB}?authSource={TARGET_DB}\n")
+    except Exception as e:
+        print(f"❌ Failed to create user '{username}': {e}")
+
+# === ENTRY POINT ===
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Create a MongoDB user with a random password.")
+    parser.add_argument("username", help="Username to create")
+    parser.add_argument("--role", default="readWrite", help="MongoDB role (default: readWrite)")
+    args = parser.parse_args()
+    create_user(args.username, args.role)
+
diff --git a/pysteps/mongo/delete_files.py b/pysteps/mongo/delete_files.py
new file mode 100644
index 000000000..71279662d
--- /dev/null
+++ b/pysteps/mongo/delete_files.py
@@ -0,0 +1,125 @@
+from models import get_db
+from pymongo import MongoClient
+import logging
+import argparse
+import gridfs
+import pymongo
+import datetime
+
+def is_valid_iso8601(time_str: str) -> bool:
+    """Check if the given string is a valid ISO 8601 datetime."""
+    try:
+        datetime.datetime.fromisoformat(time_str)
+        return True
+    except ValueError:
+        return False
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Delete rainfall and/or state GridFS files.")
+
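+    # Example invocation (illustrative dates; the domain and product must exist
+    # in the database): preview the QPE rain files for one day on the AKL
+    # domain without deleting anything:
+    #
+    #   python delete_files.py -s 2025-01-01T00:00:00 -e 2025-01-02T00:00:00 \
+    #       -n AKL -p QPE --rain --dry_run
+
+    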
parser.add_argument('-s', '--start', type=str, required=True, + help='Start time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-e', '--end', type=str, required=True, + help='End time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-n', '--name', type=str, required=True, + help='Name of domain [AKL]') + parser.add_argument('-p', '--product', type=str, required=True, + help='Name of product to delete [QPE, auckprec, qpesim]') + parser.add_argument('-c', '--cascade', default=False, action='store_true', + help='Delete the cascade files') + parser.add_argument('-r', '--rain', default=False, action='store_true', + help='Delete the rainfall files') + parser.add_argument('--params', default=False, action='store_true', + help='Delete the parameter documents') + + parser.add_argument('--dry_run', default=False, action='store_true', + help='Only list files that would be deleted, don’t delete them.') + + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + + if not (args.rain or args.cascade or args.params): + logging.warning("Nothing to delete: specify --rain, --cascade, or --params") + return + + # Validate and parse times + def parse_time(time_str): + if not is_valid_iso8601(time_str): + logging.error(f"Invalid time format: {time_str}") + exit(1) + t = datetime.datetime.fromisoformat(time_str) + return t.replace(tzinfo=datetime.timezone.utc) if t.tzinfo is None else t + + start_time = parse_time(args.start) + end_time = parse_time(args.end) + + name = args.name + product = args.product + dry_run = args.dry_run + + if product not in ["QPE", "auckprec", "qpesim", "nwpblend"]: + logging.error(f"Invalid product: {product}") + return + + db = get_db() + + def delete_files(collection_name): + coll = db[f"{collection_name}.files"] + fs = gridfs.GridFS(db, collection=collection_name) + + if product == "QPE": + query = { + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time} + } + else: + query = { + "metadata.product": product, + "metadata.base_time": {"$gte": start_time, "$lte": end_time} + } + + ids = list(coll.find(query, {"_id": 1,"filename":1})) + count = len(ids) + + if dry_run: + logging.info(f"[Dry Run] {count} files matched in {collection_name}. 
Listing _id values:") + for doc in ids: + logging.info(f" Would delete: {doc['filename']}") + else: + for doc in ids: + fs.delete(doc["_id"]) + logging.info(f"Deleted {count} files from {collection_name}") + + if args.rain: + delete_files(f"{name}.rain") + + if args.cascade: + delete_files(f"{name}.state") + + if args.params: + collection_name = f"{name}.params" + coll = db[collection_name] + if product == "QPE": + query = { + "metadata.product": product, + "metadata.valid_time": {"$gte": start_time, "$lte": end_time} + } + else: + query = { + "metadata.product": product, + "metadata.base_time": {"$gte": start_time, "$lte": end_time} + } + + ids = list(coll.find(query, {"_id": 1})) + count = len(ids) + if dry_run: + logging.info(f"[Dry Run] {count} files matched in {collection_name}") + else: + coll.delete_many(query) + logging.info(f"Deleted {count} files from {collection_name}") + + +if __name__ == "__main__": + main() diff --git a/pysteps/mongo/gridfs_io.py b/pysteps/mongo/gridfs_io.py new file mode 100644 index 000000000..dc5cf68da --- /dev/null +++ b/pysteps/mongo/gridfs_io.py @@ -0,0 +1,266 @@ +# Contains: store_cascade_to_gridfs, load_cascade_from_gridfs, load_rain_field, get_rain_fields, get_states +from io import BytesIO +import gridfs +import numpy as np +import pymongo +import copy +import datetime +from typing import Dict, Any, Optional, Union, Tuple + +def store_cascade_to_gridfs(db, name, cascade_dict, oflow, file_name, field_metadata): + """ + Stores a pysteps cascade decomposition dictionary into MongoDB's GridFS. + + Parameters: + db (pymongo.database.Database): The MongoDB database object. + cascade_dict (dict): The pysteps cascade decomposition dictionary. + oflow (np.ndarray): The optical flow field. + file_name (str): The (unique) name of the file to be stored. + field_metadata (dict): Additional metadata related to the field. + + Returns: + bson.ObjectId: The GridFS file ID. + """ + assert cascade_dict["domain"] == "spatial", "Only 'spatial' domain is supported." + state_col_name = f"{name}.state" + fs = gridfs.GridFS(db, collection=state_col_name) + + # Delete existing file with same filename + for old_file in fs.find({"filename": file_name}): + fs.delete(old_file._id) + + # Convert cascade_levels and oflow to a compressed format + buffer = BytesIO() + np.savez_compressed( + buffer, cascade_levels=cascade_dict["cascade_levels"], oflow=oflow) + buffer.seek(0) + + # Prepare metadata + metadata = { + "filename": file_name, + "domain": cascade_dict["domain"], + "normalized": cascade_dict["normalized"], + "transform": cascade_dict.get("transform"), + "threshold": cascade_dict.get("threshold"), + "zerovalue": cascade_dict.get("zerovalue") + } + metadata.update(field_metadata) # Merge additional metadata + + # Add optional statistics if available + if "means" in cascade_dict: + metadata["means"] = cascade_dict["means"] + if "stds" in cascade_dict: + metadata["stds"] = cascade_dict["stds"] + + # Store binary data and metadata atomically in GridFS + file_id = fs.put(buffer.getvalue(), filename=file_name, metadata=metadata) + + return file_id + + +def load_cascade_from_gridfs(db, name, file_name): + """ + Loads a pysteps cascade decomposition dictionary and optical flow from MongoDB's GridFS. + + Parameters: + db (pymongo.database.Database): The MongoDB database object. + file_name (str): The name of the file to retrieve. 
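+        name (str): Domain name used as the collection prefix (e.g. "AKL").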
+ + Returns: + tuple: (cascade_dict, oflow, metadata) + """ + state_col_name = f"{name}.state" + fs = gridfs.GridFS(db, collection=state_col_name) + + # Retrieve the file from GridFS + grid_out = fs.find_one({"filename": file_name}) + if grid_out is None: + raise ValueError(f"No file found with filename: {file_name}") + + # Retrieve metadata + metadata = grid_out.metadata + + # Read and decompress stored arrays + buffer = BytesIO(grid_out.read()) + npzfile = np.load(buffer) + + # Reconstruct cascade dictionary including the initial field transformation + cascade_dict = { + "cascade_levels": npzfile["cascade_levels"], + "domain": metadata["domain"], + "normalized": metadata["normalized"], + "transform": metadata.get("transform"), + "threshold": metadata.get("threshold"), + "zerovalue": metadata.get("zerovalue") + } + + # Restore optional statistics if they exist + if "means" in metadata: + cascade_dict["means"] = metadata["means"] + if "stds" in metadata: + cascade_dict["stds"] = metadata["stds"] + + oflow = npzfile["oflow"] # Optical flow field + + return cascade_dict, oflow, metadata + + +def load_rain_field(db, name, filename, nc_buf, metadata): + + # Check if the file exists, if yes then delete it + rain_col_name = f"{name}.rain" + + fs = gridfs.GridFS(db, collection=rain_col_name) + + existing_file = fs.find_one( + {"filename": filename}) + if existing_file: + fs.delete(existing_file._id) + + # Upload to GridFS + fs.put(nc_buf.tobytes(), + filename=filename, metadata=metadata) + + +def get_rain_fields(db: pymongo.MongoClient, name: str, query: dict): + rain_col_name = f"{name}.rain" + meta_col_name = f"{name}.rain.files" + fs = gridfs.GridFS(db, collection=rain_col_name) + meta_coll = db[meta_col_name] + + # Fetch matching filenames and metadata in a single query + fields_projection = {"_id": 0, "filename": 1, "metadata": 1} + results = meta_coll.find(query, projection=fields_projection).sort( + "filename", pymongo.ASCENDING) + + fields = [] + + # Process each matching file + for doc in results: + filename = doc["filename"] + + # Fetch metadata from GridFS + grid_out = fs.find_one({"filename": filename}) + if grid_out is None: + logging.warning(f"File {filename} not found in GridFS, skipping.") + continue + + rain_fs_metadata = grid_out.metadata if hasattr( + grid_out, "metadata") else {} + + # Copy relevant metadata + field_metadata = { + "filename": filename, + "product": rain_fs_metadata.get("product", "unknown"), + "domain": rain_fs_metadata.get("domain", "AKL"), + "ensemble": rain_fs_metadata.get("ensemble", None), + "base_time": rain_fs_metadata.get("base_time", None), + "valid_time": rain_fs_metadata.get("valid_time", None), + "mean": rain_fs_metadata.get("mean", 0), + "std_dev": rain_fs_metadata.get("std_dev", 0), + "wetted_area_ratio": rain_fs_metadata.get("wetted_area_ratio", 0) + } + + # Stream and decompress data + buffer = BytesIO(grid_out.read()) + rain_geodata, _, rain_data = read_nc(buffer) # Fixed variable name + + # Add the georeferencing metadata dictionary + field_metadata["geo_data"] = rain_geodata + + # Store the final record + record = {"rain": rain_data.copy( + ), "metadata": copy.deepcopy(field_metadata)} + fields.append(record) # Append the record to the list + + return fields + + +def get_states(db: pymongo.MongoClient, name: str, query: dict, + get_cascade: Optional[bool] = True, + get_optical_flow: Optional[bool] = True + ) -> Dict[Tuple[Any, Any, Any], Dict[str, Optional[Union[dict, np.ndarray]]]]: + """ + Retrieve state fields (cascade and/or optical 
flow) from a GridFS collection, + indexed by (valid_time, base_time, ensemble). + + Args: + db (pymongo.MongoClient): Database with the state collections. + name (str): Name prefix of the state collections. + query (dict): Mongo query for filtering state files. + get_cascade (bool, optional): Whether to retrieve cascade state. Defaults to True. + get_optical_flow (bool, optional): Whether to retrieve optical flow. Defaults to True. + + Returns: + dict: {(valid_time, base_time, ensemble): {"cascade": dict or None, + "optical_flow": np.ndarray or None, + "metadata": dict}} + """ + state_col_name = f"{name}.state" + meta_col_name = f"{name}.state.files" + fs = gridfs.GridFS(db, collection=state_col_name) + meta_coll = db[meta_col_name] + + fields = {"_id": 0, "filename": 1, "metadata": 1} + results = meta_coll.find(query, projection=fields).sort("filename", pymongo.ASCENDING) + + states = {} + + for doc in results: + state_file = doc["filename"] + metadata_dict = doc.get("metadata", {}) + + valid_time = metadata_dict.get("valid_time") + if valid_time is None: + logging.warning(f"No valid_time in metadata for file {state_file}, skipping.") + continue + if valid_time.tzinfo is None: + valid_time = valid_time.replace(tzinfo=datetime.timezone.utc) + + base_time = metadata_dict.get("base_time", "NA") + if base_time is not None and base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + ensemble = metadata_dict.get("ensemble", "NA") + + # Set missing base_time or ensemble to "NA" + if base_time is None: + base_time = "NA" + if ensemble is None: + ensemble = "NA" + + + grid_out = fs.find_one({"filename": state_file}) + if grid_out is None: + logging.warning(f"File {state_file} not found in GridFS, skipping.") + continue + + buffer = BytesIO(grid_out.read()) + npzfile = np.load(buffer) + + cascade_dict = None + if get_cascade: + cascade_dict = { + "cascade_levels": npzfile["cascade_levels"], + "domain": metadata_dict.get("domain"), + "normalized": metadata_dict.get("normalized"), + "transform": metadata_dict.get("transform"), + "threshold": metadata_dict.get("threshold"), + "zerovalue": metadata_dict.get("zerovalue"), + "means": metadata_dict.get("means"), + "stds": metadata_dict.get("stds"), + } + + oflow = None + if get_optical_flow: + oflow = npzfile["oflow"] + + key = (valid_time, base_time, ensemble) + states[key] = { + "cascade": copy.deepcopy(cascade_dict) if cascade_dict is not None else None, + "optical_flow": oflow.copy() if oflow is not None else None, + "metadata": copy.deepcopy(metadata_dict) + } + + return states + diff --git a/pysteps/mongo/init_steps_db.py b/pysteps/mongo/init_steps_db.py new file mode 100644 index 000000000..dba03afd9 --- /dev/null +++ b/pysteps/mongo/init_steps_db.py @@ -0,0 +1,71 @@ +from pymongo import MongoClient, ASCENDING +import argparse +import os +from pymongo import MongoClient +from urllib.parse import quote_plus + +# === Configuration === +AUTH_DB = "STEPS" +TARGET_DB = "STEPS" +MONGO_HOST = os.getenv("MONGO_HOST","localhost") +MONGO_PORT = os.getenv("MONGO_PORT",27017) +STEPS_USER = os.getenv("STEPS_USER","radar") +STEPS_PWD = os.getenv("STEPS_PWD","c-bandBox") + +# === Functions === +def setup_domain(db, domain_name): + print(f"⏳ Setting up domain: {domain_name}") + + for product in ["rain", "state"]: + files_coll = f"{domain_name}.{product}.files" + chunks_coll = f"{domain_name}.{product}.chunks" + + # Create empty collections (MongoDB creates on first insert, but we want indexes now) + 
db[files_coll].insert_one({"temp": True}) # insert dummy + db[chunks_coll].insert_one({"temp": True}) + + # Create compound index on files + db[files_coll].create_index([ + ("metadata.product", ASCENDING), + ("metadata.valid_time", ASCENDING), + ("metadata.base_time", ASCENDING), + ("metadata.ensemble", ASCENDING) + ], name="product_valid_base_ensemble_idx") + + # Index for GridFS pre-deletion lookups + db[files_coll].create_index([("filename", ASCENDING)], name="filename_idx") + + # Remove dummy record + db[files_coll].delete_many({"temp": True}) + db[chunks_coll].delete_many({"temp": True}) + + print(f"✅ {files_coll} and {chunks_coll} initialized with index") + + # Create a per-domain params collection + params_coll = f"{domain_name}.params" + db[params_coll].insert_one({"_test": True}) + db[params_coll].delete_many({"_test": True}) + print(f"✅ {params_coll} initialized") + +def setup_config(db): + config_coll = "config" + db[config_coll].insert_one({"_test": True}) + db[config_coll].delete_many({"_test": True}) + print(f"✅ {config_coll} initialized (shared)") + +# === Main === +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Initialize STEPS MongoDB structure") + parser.add_argument("domains", nargs="+", help="List of domain names to set up (e.g. AKL WLG CHC)") + args = parser.parse_args() + connect_string = f"mongodb://{quote_plus(STEPS_USER)}:{quote_plus(STEPS_PWD)}@{MONGO_HOST}:{MONGO_PORT}/STEPS?authSource={AUTH_DB}" + print(f"Connecting to {connect_string}") + client = MongoClient(connect_string) + db = client[TARGET_DB] + + for domain in args.domains: + setup_domain(db, domain) + + setup_config(db) + print("🎉 Setup complete.") + diff --git a/pysteps/mongo/load_config.py b/pysteps/mongo/load_config.py new file mode 100644 index 000000000..f20a92a47 --- /dev/null +++ b/pysteps/mongo/load_config.py @@ -0,0 +1,283 @@ +import argparse +import json +import logging +import os +from pathlib import Path +import datetime +from pymongo import MongoClient, errors +from urllib.parse import quote_plus +from models import get_db + +# Default pysteps configuration values +DEFAULT_PYSTEPS_CONFIG = { + "precip_threshold": None, + "extrapolation_method": "semilagrangian", + "decomposition_method": "fft", + "bandpass_filter_method": "gaussian", + "noise_method": "nonparametric", + "noise_stddev_adj": None, + "ar_order": 1, + "scale_break": None, + "velocity_perturbation_method": None, + "conditional": False, + "probmatching_method": "cdf", + "mask_method": "incremental", + "seed": None, + "num_workers": 1, + "fft_method": "numpy", + "domain": "spatial", + "extrapolation_kwargs": {}, + "filter_kwargs": {}, + "noise_kwargs": {}, + "velocity_perturbation_kwargs": {}, + "mask_kwargs": {}, + "measure_time": False, + "callback": None, + "return_output": True +} + +valid_product_list = ["qpesim", "auckprec", "nowcast", "nwpblend"] + +# Default output configuration +DEFAULT_OUTPUT_CONFIG = { + "qpesim":{ + "gridfs_out": True, + "nc_out": False, + "out_product": "qpesim", + "out_dir_name": None, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "auckprec":{ + "gridfs_out": True, + "nc_out": False, + "out_product": "auckprec", + "tmp_dir": "$HOME/tmp", + "out_dir_name": None, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "nowcast":{ + "gridfs_out": False, + "nc_out": False, + "out_product": "nowcast", + "out_dir_name": None, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + 
"nwpblend":{ + "gridfs_out": True, + "nc_out": False, + "out_product": "nwpblend", + "out_dir_name": None, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + } +} + +# Default domain configuration +DEFAULT_DOMAIN_CONFIG = { + "n_rows": None, + "n_cols": None, + "p_size": None, + "start_x": None, + "start_y": None +} + +# Default projection configuration for NZ +DEFAULT_PROJECTION_CONFIG = { + "epsg": "EPSG:2193", + "name": "transverse_mercator", + "central_meridian": 173.0, + "latitude_of_origin": 0.0, + "scale_factor": 0.9996, + "false_easting": 1600000.0, + "false_northing": 10000000.0 +} + + +def file_exists(file_path: Path) -> bool: + """Check if the given file path exists.""" + return file_path.is_file() + + +def load_config(config_path: Path) -> dict: + """Load the full configuration from a JSON file, applying defaults for missing fields.""" + try: + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + except json.JSONDecodeError: + logging.error(f"Error decoding JSON file: {config_path}") + return {} + + name = config.get("name", None) + if name is None: + logging.error( + f"Domain name not found") + return {} + + # Extract and validate pysteps configuration + pysteps_config = config.get("pysteps", {}) + if not isinstance(pysteps_config, dict): + logging.error( + f"Malformed pysteps configuration in {config_path}, expected a dictionary.") + return {} + + # Apply default values + for key, default_value in DEFAULT_PYSTEPS_CONFIG.items(): + if key not in pysteps_config: + logging.warning( + f"Missing key '{key}' in pysteps configuration, using default value.") + pysteps_config[key] = default_value + + # Validate mandatory keys + required_pysteps_keys = [ + "n_cascade_levels", "timestep", "kmperpixel" + ] + for key in required_pysteps_keys: + if key not in pysteps_config: + logging.error( + f"Missing mandatory key '{key}' in pysteps configuration.") + return {} + + # Extract and validate output configurations + output_config = config.get("output", {}) + if not isinstance(output_config, dict): + logging.error( + f"Malformed output configuration in {config_path}, expected a dictionary." + ) + return {} + + # Ensure "products" key exists and is a list + valid_product_list = ["qpesim", "auckprec", "nowcast", "nwpblend"] + + products = output_config.get("products", []) + if not isinstance(products, list): + logging.error( + f"Malformed 'products' key in output configuration, expected a list." + ) + return {} + + # Dictionary to store parsed output configurations + parsed_output_config = {} + + # Iterate over each product and extract its configuration + for product in products: + + if product not in valid_product_list: + logging.error( + f"Unexpected product found, '{product}' not in {valid_product_list}." + ) + continue + + product_config = output_config.get(product, {}) + + if not isinstance(product_config, dict): + logging.error( + f"Malformed configuration for product '{product}', expected a dictionary." 
+ ) + continue + + # Merge with defaults + complete_config = DEFAULT_OUTPUT_CONFIG[product].copy() + complete_config.update(product_config) + + parsed_output_config[product] = complete_config + + # Extract and validate the domain location configuration + domain_config = config.get("domain", {}) + if not isinstance(domain_config, dict): + logging.error( + f"Malformed domain configuration in {config_path}, expected a dictionary.") + return {} + + for key, default_value in DEFAULT_DOMAIN_CONFIG.items(): + if key not in domain_config: + logging.error(f"Missing key '{key}' in domain configuration.") + return {} + + # Extract and validate the projection configuration - assumes CF fields for Transverse Mercator + projection_config = config.get("projection", {}) + if not isinstance(projection_config, dict): + logging.error( + f"Malformed projection configuration in {config_path}, expected a dictionary.") + return {} + + for key, default_value in DEFAULT_PROJECTION_CONFIG.items(): + if key not in projection_config: + logging.warning( + f"Missing key '{key}' in projection configuration, using default value") + projection_config[key] = default_value + + # Get the dynamic scaling if present + dynamic_scaling_config = config.get("dynamic_scaling", {}) + + # Only check for required keys if the dictionary is not empty + if dynamic_scaling_config: + required_ds_keys = ["central_wave_lengths", + "space_time_exponent", "lag2_constants", "lag2_exponents"] + for key in required_ds_keys: + if key not in dynamic_scaling_config: + logging.error( + f"Missing mandatory key '{key}' in dynamic_scaling configuration.") + return {} + + return { + "name": name, + "pysteps": pysteps_config, + "output": parsed_output_config, + "domain": domain_config, + "projection": projection_config, + "dynamic_scaling": dynamic_scaling_config + } + + +def insert_config_into_mongodb(config: dict): + """Insert the configuration into the MongoDB config collection.""" + record = { + "time": datetime.datetime.now(datetime.timezone.utc), + "config": config + } + + try: + db = get_db() + collection = db["config"] + + # Insert the record + result = collection.insert_one(record) + logging.info( + f"Configuration inserted successfully. Document ID: {result.inserted_id}") + + except errors.ServerSelectionTimeoutError: + logging.error( + "Failed to connect to MongoDB. 
Check if MongoDB is running and the URI is correct.") + except errors.PyMongoError as e: + logging.error(f"MongoDB error: {e}") + + +def main(): + parser = argparse.ArgumentParser( + description="Insert pysteps configuration into MongoDB" + ) + parser.add_argument('-c', '--config', type=Path, + help='Path to configuration file') + parser.add_argument('-v', '--verbose', action='store_true', + help='Enable verbose logging') + + args = parser.parse_args() + + # Configure logging + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + # Validate config file path + if not args.config or not file_exists(args.config): + logging.error(f"Configuration file does not exist: {args.config}") + return + + # Load the full configuration + config = load_config(args.config) + + if config: + logging.info("Final loaded configuration:\n%s", + json.dumps(config, indent=2)) + insert_config_into_mongodb(config) + + +if __name__ == "__main__": + main() diff --git a/pysteps/mongo/mongo_access.py b/pysteps/mongo/mongo_access.py new file mode 100644 index 000000000..9382a156d --- /dev/null +++ b/pysteps/mongo/mongo_access.py @@ -0,0 +1,184 @@ +# Contains: get_db, get_config, get_parameters_df, get_parameters, to_utc_naive +from typing import Dict +import pandas as pd +import datetime +import os +import logging +import pymongo.collection +from pymongo import MongoClient +from urllib.parse import quote_plus +from models.steps_params import StochasticRainParameters +from models.cascade_utils import get_cascade_wavelengths + + +def get_parameters(query: Dict, param_coll) -> Dict: + """ + Get the parameters matching the query, indexed by valid_time. + + Args: + query (dict): MongoDB query dictionary. + param_coll (pymongo collection): Collection with the parameters. + + Returns: + dict: Dictionary {valid_time: StochasticRainParameters} + """ + result = {} + for doc in param_coll.find(query).sort("metadata.valid_time", pymongo.ASCENDING): + try: + param = StochasticRainParameters.from_dict(doc) + param.calc_corl() + result[param.valid_time] = param + except Exception as e: + print( + f"Warning: could not parse parameter for valid_time {doc.get('valid_time')}: {e}") + return result + + +def get_parameters_df(query: Dict, param_coll: pymongo.collection.Collection) -> pd.DataFrame: + """ + Retrieve STEPS parameters from the database and return a DataFrame + indexed by (valid_time, base_time, ensemble), using 'NA' as sentinel for missing values. + + Args: + query (dict): MongoDB query dictionary. + param_coll (pymongo.collection.Collection): MongoDB collection. + + Returns: + pd.DataFrame: Indexed by (valid_time, base_time, ensemble), with a 'param' column. 
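+
+    Example (illustrative query; the domain/collection name is an assumption):
+        coll = db["AKL.params"]
+        df = get_parameters_df({"metadata.product": "QPE"}, coll)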
+ """ + records = [] + + for doc in param_coll.find(query).sort("metadata.valid_time", pymongo.ASCENDING): + try: + metadata = doc.get("metadata", {}) + if metadata is None: + continue + + if doc["cascade"]["lag1"] is None or doc["cascade"]["lag2"] is None: + continue + + valid_time = metadata.get("valid_time") + if valid_time is not None and valid_time.tzinfo is None: + valid_time = valid_time.replace(tzinfo=datetime.timezone.utc) + + base_time = metadata.get("base_time") + if base_time is None: + base_time = "NA" + elif base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + ensemble = metadata.get("ensemble") if metadata.get( + "ensemble") is not None else "NA" + param = StochasticRainParameters.from_dict(doc) + + param.calc_corl() + records.append({ + "valid_time": valid_time, + "base_time": base_time, + "ensemble": ensemble, + "param": param + }) + except Exception as e: + print( + f"Warning: could not parse parameter for {metadata.get('valid_time')}: {e}") + + if not records: + return pd.DataFrame(columns=["valid_time", "base_time", "ensemble", "param"]) + + df = pd.DataFrame(records) + return df + + +def get_config(db: pymongo.MongoClient, name: str) -> Dict: + """_summary_ + Return the most recent configuration setting + Args: + db (pymongo.MongoClient): Project database + + Returns: + Dict: Project configuration dictionary + """ + + config_coll = db["config"] + record = config_coll.find_one({'config.name': name}, sort=[ + ('time', pymongo.DESCENDING)]) + if record is None: + logging.error(f"Could not find configuration for domain {name}") + return None + + config = record['config'] + return config + + +def get_db(mongo_port=None): + MONGO_HOST = os.getenv("MONGO_HOST", "localhost") + # Use the function argument if provided, otherwise fall back to the environment variable, then default + MONGO_PORT = mongo_port if mongo_port is not None else int( + os.getenv("MONGO_PORT", 27017)) + + if mongo_port is None: + logging.info(f"Using MONGO_PORT from env: {MONGO_PORT}") + else: + logging.info(f"Using MONGO_PORT from argument: {mongo_port}") + + STEPS_USER = os.getenv("STEPS_USER", "radar") + STEPS_PWD = os.getenv("STEPS_PWD", "c-bandBox") + AUTH_DB = "STEPS" + TARGET_DB = "STEPS" + + conect_string = ( + f"mongodb://{quote_plus(STEPS_USER)}:{quote_plus(STEPS_PWD)}" + f"@{MONGO_HOST}:{MONGO_PORT}/STEPS?authSource={AUTH_DB}" + ) + logging.info(f"Connecting to {conect_string}") + client = MongoClient(conect_string) + db = client[TARGET_DB] + return db + + +def to_utc_naive(dt): + if dt.tzinfo is not None: + return dt.astimezone(datetime.timezone.utc).replace(tzinfo=None) + return dt + + +def get_central_wavelengths(db, name): + config = get_config(db, name) + n_levels = config["pysteps"].get("n_cascade_levels") + domain = config["domain"] + n_rows = domain.get("n_rows") + n_cols = domain.get("n_cols") + p_size = domain.get("p_size") + p_size_km = p_size / 1000.0 + domain_size_km = max(n_rows, n_cols) * p_size_km + + # Get central wavelengths + wavelengths_km = get_cascade_wavelengths( + n_levels, domain_size_km, p_size_km) + return wavelengths_km + +def get_base_time(valid_time, product, name, db): + # Get the base_time for the nwp run nearest to the valid_time in UTC zone + # Assume spin-up of 3 hours + start_base_time = valid_time - datetime.timedelta(hours=27) + end_base_time = valid_time - datetime.timedelta(hours=3) + base_time_query = { + "metadata.product": product, + "metadata.base_time": {"$gte": start_base_time, "$lte": end_base_time} + } + 
col_name = f"{name}.rain.files" + nwp_base_times = db[col_name].distinct( + "metadata.base_time", base_time_query) + + if nwp_base_times is None: + logging.warning( + f"Failed to find {product} data for {valid_time}") + return None + + nwp_base_times.sort(reverse=True) + base_time = nwp_base_times[0] + + if base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + return base_time diff --git a/pysteps/mongo/mongodb_port_forwarding.md b/pysteps/mongo/mongodb_port_forwarding.md new file mode 100644 index 000000000..75cb2ea3d --- /dev/null +++ b/pysteps/mongo/mongodb_port_forwarding.md @@ -0,0 +1,128 @@ +# Port Forwarding to Remote MongoDB on Localhost + +This guide explains how to open port **27018** on your local machine and forward it to a **remote MongoDB** instance running on port **27017**, either temporarily or permanently using `systemd` and `autossh`. + +--- + +## 🔁 SSH Port Forwarding (Temporary) + +To forward local port `27018` to the remote MongoDB server on port `27017`, run: + +```bash +ssh -L 27018:localhost:27017 your_user@remote_host +``` + +Once the tunnel is active, connect to MongoDB using: + +```bash +mongosh --port 27018 +``` + +Or with a URI: + +```bash +mongodb://localhost:27018 +``` + +--- + +## 🔄 Making Port Forwarding Persistent with systemd and autossh + +To create a self-healing SSH tunnel that auto-reconnects on failure, use `autossh` with a `systemd` user service. + +### 1. Install autossh + +On Fedora: + +```bash +sudo dnf install autossh +``` + +On Debian/Ubuntu: + +```bash +sudo apt install autossh +``` + +--- + +### 2. Set up SSH keys + +```bash +ssh-keygen +ssh-copy-id radar@remote_host +``` + +Make sure `ssh radar@remote_host` works without a password. + +--- + +### 3. Create the systemd user service + +Create the file: +`~/.config/systemd/user/mongodb-tunnel.service` + +```ini +[Unit] +Description=Persistent SSH tunnel to radar MongoDB +After=network.target + +[Service] +Environment=AUTOSSH_GATETIME=0 +ExecStart=/usr/bin/autossh -M 0 -N -L 27018:10.8.0.41:27017 radar +Restart=always +RestartSec=10 + +[Install] +WantedBy=default.target +``` + +> - Replace `10.8.0.41` with the IP of the remote MongoDB server (not necessarily `localhost` on the remote if it’s bound to a specific interface). +> - `radar` is your SSH alias or username. Make sure it’s configured in `~/.ssh/config` if using an alias. + +--- + +### 4. Enable and start the tunnel + +```bash +systemctl --user daemon-reexec +systemctl --user daemon-reload +systemctl --user enable mongodb-tunnel +systemctl --user start mongodb-tunnel +``` + +To check the status: + +```bash +systemctl --user status mongodb-tunnel +``` + +--- + +### 5. Optional: Ensure ssh-agent is running + +Add to your shell startup script: + +```bash +eval "$(ssh-agent -s)" +ssh-add ~/.ssh/id_rsa +``` + +Or use your desktop’s SSH key manager. + +--- + +## ✅ Verifying the Tunnel + +Once active, test it with: + +```bash +mongosh --port 27018 +``` + +--- + +## 🔐 Security Note + +- Keep your SSH key safe with a passphrase. +- Use `ufw`, `firewalld`, or similar to restrict access if needed. diff --git a/pysteps/mongo/nc_utils.py b/pysteps/mongo/nc_utils.py new file mode 100644 index 000000000..cc5690cf7 --- /dev/null +++ b/pysteps/mongo/nc_utils.py @@ -0,0 +1,390 @@ +""" +Refactored IO utilities for pysteps. 
+""" + +import numpy as np +from pyproj import CRS +import netCDF4 +import datetime +from typing import Optional +import io +from pathlib import Path + + +def replace_extension(filename: str, new_ext: str) -> str: + return f"{filename.rsplit('.', 1)[0]}{new_ext}" + + +def convert_timestamps_to_datetimes(timestamps): + """Convert POSIX timestamps to datetime objects.""" + return [ + datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc) + for ts in timestamps + ] + + +def write_netcdf_file( + file_path: Path, + rain: np.ndarray, + geo_data: dict, + valid_times: list[datetime.datetime], + ensembles: list[int] | None, +) -> None: + """ + Write a set of rainfall grids to a CF-compliant NetCDF file using i2 data and scale_factor. + + Args: + file_path (Path): Full path to the output file. + rain (np.ndarray): Rainfall array. Shape is [ensemble, time, y, x] if ensembles is provided, + otherwise [time, y, x], with units in mm/h as float. + geo_data (dict): Geospatial metadata (must include 'x', 'y', and optionally 'projection'). + valid_times (list[datetime.datetime]): List of timezone-aware valid times. + ensembles (list[int] | None): Optional list of ensemble member IDs. + """ + # Convert datetime to seconds since epoch + time_stamps = [vt.timestamp() for vt in valid_times] + + x = geo_data["x"] + y = geo_data["y"] + projection = geo_data.get("projection", "EPSG:4326") + rain_fill_value = -1 + + with netCDF4.Dataset(file_path, mode="w", format="NETCDF4") as ds: + # Define dimensions + ds.createDimension("y", len(y)) + ds.createDimension("x", len(x)) + ds.createDimension("time", len(valid_times)) + if ensembles is not None: + ds.createDimension("ensemble", len(ensembles)) + + # Define coordinate variables + x_var = ds.createVariable("x", "f4", ("x",)) + x_var[:] = x + x_var.standard_name = "projection_x_coordinate" + x_var.units = "m" + + y_var = ds.createVariable("y", "f4", ("y",)) + y_var[:] = y + y_var.standard_name = "projection_y_coordinate" + y_var.units = "m" + + t_var = ds.createVariable("time", "f8", ("time",)) + t_var[:] = time_stamps + t_var.standard_name = "time" + t_var.units = "seconds since 1970-01-01T00:00:00Z" + t_var.calendar = "standard" + + if ensembles is not None: + e_var = ds.createVariable("ensemble", "i4", ("ensemble",)) + e_var[:] = ensembles + e_var.standard_name = "ensemble" + e_var.units = "1" + + # Define the rainfall variable with proper fill_value + rain_dims = ( + ("time", "y", "x") if ensembles is None else ("ensemble", "time", "y", "x") + ) + rain_var = ds.createVariable( + "rainfall", + "i2", + rain_dims, + zlib=True, + complevel=5, + fill_value=rain_fill_value, + ) + + # Scale and store rainfall + rain_scaled = np.where( + np.isnan(rain), rain_fill_value, np.round(rain * 10).astype(np.int16) + ) + rain_var[...] 
= rain_scaled + + # Metadata + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "projection" + rain_var.coordinates = " ".join(rain_dims) + + # CRS + crs = CRS.from_user_input(projection) + cf_grid_mapping = crs.to_cf() + spatial_ref = ds.createVariable("projection", "i4") + for key, value in cf_grid_mapping.items(): + setattr(spatial_ref, key, value) + + # Global attributes + ds.Conventions = "CF-1.8" + ds.title = "" + ds.institution = "" + ds.references = "" + ds.comment = "" + + +import io +import tempfile +import netCDF4 +import os +import numpy as np +from pyproj import CRS + + +def make_netcdf_buffer(rain: np.ndarray, geo_data: dict, time: int) -> io.BytesIO: + """ + Make the BytesIO netcdf object that is needed for writing to GridFS database + Args: + rain (np.ndarray): array of rain rates in mm/h as float + geo_data (dict): spatial metadata + time (int): seconds since 1970-01-01T00:00:00Z + + Returns: + io.BytesIO: _description_ + """ + x = geo_data["x"] + y = geo_data["y"] + projection = geo_data.get("projection", "EPSG:4326") + + # Use NamedTemporaryFile to create a temp NetCDF file + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + tmp_path = tmp.name + + # Create NetCDF file on disk + ds = netCDF4.Dataset(tmp_path, mode="w", format="NETCDF4") + + # Define dimensions + ds.createDimension("y", len(y)) + ds.createDimension("x", len(x)) + ds.createDimension("time", 1) + + # Coordinate variables + y_var = ds.createVariable("y", "f4", ("y",)) + x_var = ds.createVariable("x", "f4", ("x",)) + t_var = ds.createVariable("time", "i8", ("time",)) + + # Rainfall variable, + # Expects a float input array and the packing to i2 is done by the netCDF4 library + rain_var = ds.createVariable( + "rainfall", "i2", ("time", "y", "x"), zlib=True, complevel=5, fill_value=-1 + ) + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "projection" + + # Assign values + y_var[:] = y + x_var[:] = x + t_var[:] = [time] + rain_var[0, :, :] = np.nan_to_num(rain, nan=-1) + + y_var.standard_name = "projection_y_coordinate" + y_var.units = "m" + x_var.standard_name = "projection_x_coordinate" + x_var.units = "m" + t_var.standard_name = "time" + t_var.units = "seconds since 1970-01-01T00:00:00Z" + + # CRS + crs = CRS.from_user_input(projection) + cf_grid_mapping = crs.to_cf() + spatial_ref = ds.createVariable("projection", "i4") + for key, value in cf_grid_mapping.items(): + setattr(spatial_ref, key, value) + + # Global attributes + ds.Conventions = "CF-1.8" + ds.title = "Rainfall data" + ds.institution = "" + ds.references = "" + ds.comment = "" + + ds.close() + + # Now read into memory + with open(tmp_path, "rb") as f: + nc_bytes = f.read() + + os.remove(tmp_path) + return io.BytesIO(nc_bytes) + + +def generate_geo_dict(domain: dict) -> dict: + """ + Generate the pysteps geo-spatial metadata from a domain dictionary. 
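+    The returned dictionary follows the pysteps geodata conventions
+    (coordinate arrays, pixel sizes, bounding box, projection and units).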
+ + Args: + domain (dict): pysteps_param domain dictionary + + Returns: + dict: pysteps geo-data dictionary, or {} if required keys are missing + """ + required_keys = {"n_cols", "n_rows", "p_size", "start_x", "start_y"} + missing = required_keys - domain.keys() + if missing: + # Missing keys, return empty dict + return {} + + ncols = domain.get("n_cols") + nrows = domain.get("n_rows") + psize = domain.get("p_size") + start_x = domain.get("start_x") + start_y = domain.get("start_y") + + x = [start_x + i * psize for i in range(ncols)] # type: ignore + y = [start_y + i * psize for i in range(nrows)] # type: ignore + + out_geo = {} + out_geo["x"] = x + out_geo["y"] = y + out_geo["xpixelsize"] = psize + out_geo["ypixelsize"] = psize + out_geo["x1"] = start_x + out_geo["y1"] = start_y + out_geo["x2"] = start_x + (ncols - 1) * psize # type: ignore + out_geo["y2"] = start_y + (nrows - 1) * psize # type: ignore + out_geo["projection"] = domain["projection"]["epsg"] + out_geo["cartesian_unit"] = ("m",) + out_geo["yorigin"] = ("lower",) + out_geo["unit"] = "mm/h" + out_geo["threshold"] = 0 + out_geo["transform"] = None + + return out_geo + + +def read_nc(buffer: bytes): + """ + Read netCDF file from a memory buffer and return geo-referencing data and rain rates. + + :param buffer: Byte data of the NetCDF file from GridFS. + :return: Tuple containing geo-referencing data, valid times, and rain rate array. + """ + # Convert the byte buffer to a BytesIO object + byte_stream = io.BytesIO(buffer) + + # Open the NetCDF dataset + with netCDF4.Dataset("inmemory", mode="r", memory=byte_stream.getvalue()) as ds: + + # Extract geo-referencing data + x = ds.variables["x"][:] + y = ds.variables["y"][:] + + domain = {} + domain["ncols"] = len(x) + domain["nrows"] = len(y) + domain["psize"] = abs(x[1] - x[0]) + domain["start_x"] = x[0] + domain["start_y"] = y[0] + geo_data = generate_geo_dict(domain) + + # Convert timestamps to datetime + valid_times = convert_timestamps_to_datetimes(ds.variables["time"][:]) + + # Extract rain rates + rain_rate = ds.variables["rainfall"][:] + + # Replace invalid data with NaN and squeeze dimensions of np.ndarray + rain_rate = np.squeeze(rain_rate) + rain_rate[rain_rate < 0] = np.nan + + return geo_data, valid_times, rain_rate + + +def validate_keys(keys, mandatory_keys): + """Validate the presence of mandatory keys.""" + missing_keys = [key for key in mandatory_keys if key not in keys] + if missing_keys: + raise KeyError(f"Missing mandatory keys: {', '.join(missing_keys)}") + + +def make_nc_name( + domain: str, + prod: str, + valid_time: datetime.datetime, + base_time: Optional[datetime.datetime] = None, + ens: Optional[int] = None, + name_template: Optional[str] = None, +) -> str: + """ + Generate a unique name for a single rain field using a formatting template. + + Default templates: + Forecast products: "$D_$P_$V{%Y%m%dT%H%M%S}_$B{%Y%m%dT%H%M%S}_$E.nc" + QPE products: "$D_$P_$V{%Y%m%dT%H%M%S}.nc" + + Where: + $D = Domain name + $P = Product name + $V = Valid time (with strftime format) + $B = Base time (with strftime format) + $E = Ensemble number (zero-padded 2-digit) + + Returns: + str: Unique NetCDF file name. 
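+
+    Example (illustrative times, default template):
+        make_nc_name("AKL", "nowcast",
+                     datetime.datetime(2025, 1, 1, 6, tzinfo=datetime.timezone.utc),
+                     datetime.datetime(2025, 1, 1, 0, tzinfo=datetime.timezone.utc),
+                     ens=3)
+        -> "AKL_nowcast_2025-01-01T06:00:00_2025-01-01T00:00:00_03.nc"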
+ """ + + if not isinstance(valid_time, datetime.datetime): + raise TypeError(f"valid_time must be datetime, got {type(valid_time)}") + + if base_time is not None and not isinstance(base_time, datetime.datetime): + raise TypeError(f"base_time must be datetime or None, got {type(base_time)}") + + # Default template logic + if name_template is None: + name_template = "$D_$P_$V{%Y-%m-%dT%H:%M:%S}" + if base_time is not None: + name_template += "_$B{%Y-%m-%dT%H:%M:%S}" + if ens is not None: + name_template += "_$E" + name_template += ".nc" + + result = name_template + + # Ensure timezone-aware times + if valid_time.tzinfo is None: + valid_time = valid_time.replace(tzinfo=datetime.timezone.utc) + if base_time is not None and base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + + # Replace flags + while "$" in result: + flag_posn = result.find("$") + if flag_posn == -1: + break + f_type = result[flag_posn + 1] + + try: + if f_type in ["V", "B"]: + field_start = result.find("{", flag_posn + 1) + field_end = result.find("}", flag_posn + 1) + if field_start == -1 or field_end == -1: + raise ValueError( + f"Missing braces for format of '${f_type}' in template." + ) + + fmt = result[field_start + 1 : field_end] + if f_type == "V": + time_str = valid_time.strftime(fmt) + elif f_type == "B" and base_time is not None: + time_str = base_time.strftime(fmt) + else: + time_str = "" + + result = result[:flag_posn] + time_str + result[field_end + 1 :] + + elif f_type == "D": + result = result[:flag_posn] + domain + result[flag_posn + 2 :] + elif f_type == "P": + result = result[:flag_posn] + prod + result[flag_posn + 2 :] + elif f_type == "E" and ens is not None: + result = result[:flag_posn] + f"{ens:02d}" + result[flag_posn + 2 :] + else: + raise ValueError( + f"Unknown or unsupported flag '${f_type}' in template." 
+ ) + except Exception as e: + raise ValueError(f"Error processing flag '${f_type}': {e}") + + return result diff --git a/pysteps/mongo/pysteps_config.json b/pysteps/mongo/pysteps_config.json new file mode 100644 index 000000000..422b398dc --- /dev/null +++ b/pysteps/mongo/pysteps_config.json @@ -0,0 +1,133 @@ +{ + "name": "AKL", + "pysteps": { + "n_cascade_levels": 5, + "timestep": 600, + "kmperpixel": 2.0, + "precip_threshold": 1.0, + "transform":"dB", + "threshold":-10, + "zerovalue":-11, + "scale_break": 20, + "extrapolation_method": "semilagrangian", + "decomposition_method": "fft", + "bandpass_filter_method": "gaussian", + "noise_method": "nonparametric", + "noise_stddev_adj": null, + "ar_order": 2, + "velocity_perturbation_method": null, + "conditional": false, + "probmatching_method": "cdf", + "mask_method": "incremental", + "seed": null, + "num_workers": 1, + "fft_method": "numpy", + "domain": "spatial", + "extrapolation_kwargs": {}, + "filter_kwargs": {}, + "noise_kwargs": {}, + "velocity_perturbation_kwargs": {}, + "mask_kwargs": {}, + "measure_time": false, + "callback": null, + "return_output": true + }, + "output": { + "products":["qpesim", "auckprec","nowcast", "nwpblend"], + "qpesim":{ + "n_ens_members": 50, + "n_forecasts": 144, + "fx_update":10800, + "gridfs_out": true, + "nc_out": false, + "nwp_product": "auckprec", + "rad_product": "QPE", + "out_product": "qpesim", + "out_dir_name": null, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "auckprec":{ + "gridfs_out": true, + "nc_out": false, + "out_product": "auckprec", + "tmp_dir": "$HOME/tmp", + "out_dir_name": null, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "nowcast":{ + "n_ens_members": 25, + "n_forecasts": 12, + "fx_update":1800, + "gridfs_out": true, + "nc_out": false, + "nwp_product": "auckprec", + "rad_product": "QPE", + "out_product": "nowcast", + "out_dir_name": null, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + }, + "nwpblend":{ + "n_ens_members": 25, + "n_forecasts": 72, + "blend_width":180, + "gridfs_out": true, + "nc_out": false, + "nwp_product": "auckprec", + "rad_product": "QPE", + "out_product": "nwpblend", + "out_dir_name": null, + "out_file_name": "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}_$E.nc" + } + + }, + "domain": { + "n_rows": 128, + "n_cols": 128, + "p_size": 2000, + "start_x": 1627000, + "start_y": 5854000 + }, + "projection": { + "epsg": "EPSG:2193", + "name": "transverse_mercator", + "central_meridian": 173.0, + "latitude_of_origin": 0.0, + "scale_factor": 0.9996, + "false_easting": 1600000.0, + "false_northing": 10000000.0 + }, + "dynamic_scaling": { + "central_wave_lengths": [ + 128.0, + 33.793468576034414, + 14.709461552298315, + 6.402665020067988, + 2.0 + ], + "space_time_exponent": 0.8471382397261171, + "lag2_constants": [ + 1.0019662531226008, + 0.9895839949795303, + 0.9544104783679567, + 0.8610790248307003, + 0.6447730842290677 + ], + "lag2_exponents": [ + 2.576213907004238, + 2.7557407261557945, + 2.6102093715829717, + 2.222301153284, + 1.6742864097867338 + ], + "cor_len_percentiles": [ + 95, + 50, + 5 + ], + "cor_len_pvals": [ + 894.5855025523205, + 253.723822128425, + 84.28166129094791 + ] + } +} diff --git a/pysteps/mongo/write_ensemble.py b/pysteps/mongo/write_ensemble.py new file mode 100644 index 000000000..7c06713e5 --- /dev/null +++ b/pysteps/mongo/write_ensemble.py @@ -0,0 +1,335 @@ +""" +Output an nc file with past and forcast ensemble +""" + +from 
models.mongo_access import get_db, get_config +from models.nc_utils import convert_timestamps_to_datetimes, make_nc_name +from pymongo import MongoClient +import logging +import argparse +import pymongo +from gridfs import GridFSBucket, NoFile +import numpy as np +import datetime +import netCDF4 +import pandas as pd +from pathlib import Path +import xarray as xr +from pyproj import CRS +import io + +import numpy as np +import netCDF4 +from pathlib import Path + + +def write_rainfall_netcdf(filename: Path, rainfall: np.ndarray, + x: np.ndarray, y: np.ndarray, + time: list, ensemble: np.ndarray): + """ + Write rainfall data to NetCDF using low-level netCDF4 interface. + - rainfall: 4D np.ndarray (ensemble, time, y, x), float32, mm/h with NaNs + - x, y: 1D arrays of projection coordinates in meters + - time: list of timezone-aware datetime.datetime objects + - ensemble: 1D array of ensemble member IDs (int) + """ + + n_ens, n_times, ny, nx = rainfall.shape + assert len(time) == n_times + assert len(ensemble) == n_ens + + with netCDF4.Dataset(filename, "w", format="NETCDF4") as ds: + # Create dimensions + ds.createDimension("ensemble", n_ens) + ds.createDimension("time", n_times) + ds.createDimension("y", ny) + ds.createDimension("x", nx) + + # Coordinate variables + x_var = ds.createVariable("x", "f4", ("x",)) + y_var = ds.createVariable("y", "f4", ("y",)) + t_var = ds.createVariable("time", "i4", ("time",)) + ens_var = ds.createVariable("ensemble", "i4", ("ensemble",)) + + x_var[:] = x + y_var[:] = y + ens_var[:] = ensemble + t_var[:] = netCDF4.date2num( + time, units="seconds since 1970-01-01T00:00:00", calendar="standard") + + x_var.units = "m" + x_var.standard_name = "projection_x_coordinate" + y_var.units = "m" + y_var.standard_name = "projection_y_coordinate" + t_var.units = "seconds since 1970-01-01 00:00:00" + t_var.standard_name = "time" + ens_var.long_name = "ensemble member" + + # CRS variable (dummy scalar) + crs_var = ds.createVariable("crs", "i4") + crs_var.grid_mapping_name = "transverse_mercator" + crs_var.scale_factor_at_central_meridian = 0.9996 + crs_var.longitude_of_central_meridian = 173.0 + crs_var.latitude_of_projection_origin = 0.0 + crs_var.false_easting = 1600000.0 + crs_var.false_northing = 10000000.0 + crs_var.semi_major_axis = 6378137.0 + crs_var.inverse_flattening = 298.257222101 + crs_var.spatial_ref = "EPSG:2193" + + # Rainfall variable (compressed int16 with scale) + rain_var = ds.createVariable( + "rainfall", "i2", ("ensemble", "time", "y", "x"), + zlib=True, complevel=5, fill_value=-1 + ) + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "crs" + + rainfall[np.isnan(rainfall)] = -1 + rain_var[:, :, :, :] = rainfall + + +def is_valid_iso8601(time_str: str) -> bool: + """Check if the given string is a valid ISO 8601 datetime.""" + try: + datetime.datetime.fromisoformat(time_str) + return True + except ValueError: + return False + + +def get_filenames(db: MongoClient, name: str, query: dict): + meta_coll = db[f"{name}.rain.files"] + + # Fetch matching filenames and metadata in a single query + fields_projection = {"_id": 1, "filename": 1, "metadata": 1} + results = meta_coll.find(query, projection=fields_projection).sort( + "filename", pymongo.ASCENDING) + files = [] + for doc in results: + record = { + "valid_time": doc["metadata"]["valid_time"], + "base_time": doc["metadata"]["base_time"], + "ensemble": doc["metadata"]["ensemble"], + "_id": doc["_id"], + 
"filename": doc["filename"] + } + files.append(record) + + files_df = pd.DataFrame(files) + return files_df + + +def main(): + parser = argparse.ArgumentParser( + description="Write rainfall fields to a netCDF file") + parser.add_argument('-n', '--name', required=True, + help='Domain name (e.g., AKL)') + parser.add_argument('-b', '--base_time', type=str, required=True, + help='Base time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-d', '--directory', required=True, type=Path, + help='Path to output directory for the figures') + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + # Validate start and end time and read them in + if args.base_time and is_valid_iso8601(args.base_time): + base_time = datetime.datetime.fromisoformat(str(args.base_time)) + if base_time.tzinfo is None: + base_time = base_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid base time format. Please provide a valid ISO 8601 time string.") + return + + file_dir = args.directory + if not file_dir.exists(): + logging.error(f"Invalid output diectory {file_dir}") + return + + name = args.name + db = get_db() + + # Get the domain geometry + config = get_config(db, name) + nwpblend_config = config["output"]["nwpblend"] + n_ens = nwpblend_config.get("n_ens_members") + n_fx = nwpblend_config.get("n_forecasts") + n_qpe = n_fx + ts_seconds = config["pysteps"]["timestep"] + ts = datetime.timedelta(seconds=ts_seconds) + + # Get the file names for the input data + start_qpe = base_time - n_qpe * ts + end_qpe = base_time + query = { + "metadata.product": "QPE", + "metadata.valid_time": {"$gte": start_qpe, "$lte": end_qpe} + } + qpe_df = get_filenames(db, name, query) + + start_blend = base_time + end_blend = base_time + n_fx*ts + + query = { + "metadata.product": "nwpblend", + "metadata.valid_time": {"$gt": base_time, "$lte": end_blend}, + "metadata.base_time": base_time + } + blend_df = get_filenames(db, name, query) + + qpe_fields = [] + qpe_times = [] + + bucket_name = f"{name}.rain" + bucket = GridFSBucket(db, bucket_name=bucket_name) + + for index, row in qpe_df.iterrows(): + filename = row["filename"] + with bucket.open_download_stream_by_name(filename) as stream: + buffer = stream.read() + byte_stream = io.BytesIO(buffer) + ds = netCDF4.Dataset('inmemory', mode='r', + memory=byte_stream.getvalue()) + + # Extract rain rate and handle 3D (time, y, x) or 2D (y, x) + rain_rate = ds.variables["rainfall"][:] + if rain_rate.ndim == 3: + rain_rate = rain_rate[0, :, :] # Take first time slice if present + + # Get valid time (assuming one timestamp per file) + time_var = ds.variables["time"][:] + valid_time = convert_timestamps_to_datetimes( + time_var)[0] # e.g., returns a list + + if index == 0: + y_ref = ds.variables["y"][:] + x_ref = ds.variables["x"][:] + else: + assert np.allclose(ds.variables["y"][:], y_ref) + assert np.allclose(ds.variables["x"][:], x_ref) + + # Accumulate + qpe_fields.append(rain_rate) + qpe_times.append(valid_time) + + # Convert to xarray.DataArray + qpe_array = xr.DataArray( + data=np.stack(qpe_fields), # shape: (time, y, x) + coords={"time": qpe_times, "y": y_ref, "x": x_ref}, + dims=["time", "y", "x"], + name="qpe" + ) + + # Ensure sorted and aligned valid_times across all ensemble members + ensembles = np.sort(blend_df["ensemble"].unique()) + + blend_times = np.sort(blend_df["valid_time"].unique()) + # ensures tz-aware datetime64[ns, UTC] + blend_times = pd.to_datetime(blend_times, utc=True) + # convert to native datetime.datetime + blend_times = 
[dt.to_pydatetime() for dt in blend_times] + + n_ens = len(ensembles) + n_time = len(blend_times) + ny, nx = y_ref.shape[0], x_ref.shape[0] + + # Initialize a 4D array (ensemble, time, y, x) + blend_data = np.full((n_ens, n_time, ny, nx), np.nan, dtype=np.float32) + + # Mapping from value to index + ensemble_to_idx = {ens: i for i, ens in enumerate(ensembles)} + time_to_idx = {vt: i for i, vt in enumerate(blend_times)} + + for index, row in blend_df.iterrows(): + filename = row["filename"] + ensemble = row["ensemble"] + with bucket.open_download_stream_by_name(filename) as stream: + buffer = stream.read() + byte_stream = io.BytesIO(buffer) + ds = netCDF4.Dataset('inmemory', mode='r', + memory=byte_stream.getvalue()) + + rain_rate = ds.variables["rainfall"][:] + if rain_rate.ndim == 3: + rain_rate = rain_rate[0, :, :] + + time_var = ds.variables["time"][:] + valid_time = convert_timestamps_to_datetimes(time_var)[0] + + assert np.allclose(ds.variables["y"][:], y_ref) + assert np.allclose(ds.variables["x"][:], x_ref) + + # Write into 4D array + ei = ensemble_to_idx[ensemble] + ti = time_to_idx[valid_time] + blend_data[ei, ti, :, :] = rain_rate + + # Build DataArray + blend_array = xr.DataArray( + data=blend_data, + coords={ + "ensemble": ensembles, + "time": blend_times, + "y": y_ref, + "x": x_ref + }, + dims=["ensemble", "time", "y", "x"], + name="blend" + ) + + qpe_times = list(qpe_array.coords["time"].values) + blend_times = list(blend_array.coords["time"].values) + combined_times = qpe_times + blend_times + + # Convert to tz-aware datetime.datetime + combined_times = pd.to_datetime(combined_times, utc=True) + combined_times = [t.to_pydatetime() for t in combined_times] + qpe_data = qpe_array.values + + # Tile across ensemble: + n_ens = blend_array.sizes["ensemble"] + qpe_broadcast = np.tile(qpe_data[None, :, :, :], (n_ens, 1, 1, 1)) + + # Stack QPE and forecasts: + combined_data = np.concatenate([qpe_broadcast, blend_array.values], axis=1) + + # Create combined xarray + combined_array = xr.DataArray( + data=combined_data, + coords={ + "ensemble": blend_array.coords["ensemble"], + "time": combined_times, + "y": y_ref, + "x": x_ref + }, + dims=["ensemble", "time", "y", "x"], + name="rainfall" + ) + template = "$N_$P_$V{%Y%m%d_%H%M%S}.nc" + tstamp = base_time.timestamp() + product = "qpe_nwpblend" + fname = make_nc_name(template, name, product, tstamp, None, None) + fdir = args.directory + file_name = fdir / fname + logging.info(f"Writing data to {file_name}") + + write_rainfall_netcdf( + filename=file_name, + rainfall=combined_array.values, + x=x_ref, + y=y_ref, + time=combined_times, + ensemble=combined_array.coords["ensemble"].values + ) + + return + + +if __name__ == "__main__": + main() diff --git a/pysteps/mongo/write_nc_files.py b/pysteps/mongo/write_nc_files.py new file mode 100644 index 000000000..53192a53a --- /dev/null +++ b/pysteps/mongo/write_nc_files.py @@ -0,0 +1,307 @@ +""" +Write rainfall grids to a netCDF file +""" + +from models import read_nc, make_nc_name_dt +from models import get_db, get_config +from pymongo import MongoClient +import logging +import argparse +import gridfs +import pymongo +import numpy as np +import datetime +import os +import netCDF4 +from pyproj import CRS +import pandas as pd + + +def is_valid_iso8601(time_str: str) -> bool: + """Check if the given string is a valid ISO 8601 datetime.""" + try: + datetime.datetime.fromisoformat(time_str) + return True + except ValueError: + return False + + +def get_base_times(db, base_time_query): + meta_coll = 
db["AKL.rain.files"] + base_times = list(meta_coll.distinct( + "metadata.base_time", base_time_query)) + return base_times + + +def get_valid_times(db, valid_time_query): + meta_coll = db["AKL.rain.files"] + valid_times = list(meta_coll.distinct( + "metadata.valid_time", valid_time_query)) + return valid_times + + +def get_rain_fields(db: pymongo.MongoClient, query: dict): + meta_coll = db["AKL.rain.files"] + + # Fetch matching filenames and metadata in a single query + fields_projection = {"_id": 1, "filename": 1, "metadata": 1} + results = meta_coll.find(query, projection=fields_projection).sort( + "filename", pymongo.ASCENDING) + files = [] + for doc in results: + record = {"_id": doc["_id"], + "valid_time": doc["metadata"]["valid_time"]} + files.append(record) + return files + + +def load_rain_field(db, file_id): + """Retrieve a specific rain field NetCDF file from GridFS and return as numpy array""" + fs = gridfs.GridFS(db, collection='AKL.rain') + file_obj = fs.get(file_id) + metadata = file_obj.metadata + data_bytes = file_obj.read() + geo_data, valid_time, rain_rate = read_nc(data_bytes) + if isinstance(valid_time, np.ndarray): + valid_time = valid_time.tolist() + return geo_data, metadata, valid_time, rain_rate + + +def write_netcdf(file_path: str, rain: np.ndarray, geo_data: dict, times: list[datetime.datetime], ensembles: list[int]) -> None: + """ + Write a set of rainfall grids to a CF netCDF file + Args: + file_path (str): Full path to the output file + rain (np.ndarray): Rainfall array. Shape is [ensemble, time, y, x] if ensembles is provided, + otherwise [time, y, x] + geo_data (dict): Geospatial information + times (list[datetime.datetime]): list of valid times + ensembles (list[int]): Optional list of valid ensemble numbers + """ + # Convert the times to seconds since 1970-01-01T00:00:00Z + time_stamps = [] + for time in times: + if time.tzinfo is None: + time = time.replace(tzinfo=datetime.timezone.utc) + time_stamp = time.timestamp() + time_stamps.append(time_stamp) + + x = geo_data['x'] + y = geo_data['y'] + projection = geo_data.get('projection', 'EPSG:4326') + + # Create NetCDF file on disk + with netCDF4.Dataset(file_path, mode='w', format='NETCDF4') as ds: + + # Define dimensions + ds.createDimension("y", len(y)) + ds.createDimension("x", len(x)) + ds.createDimension("time", len(times)) + + # Coordinate variables + y_var = ds.createVariable("y", "f4", ("y",)) + y_var[:] = y + y_var.standard_name = "projection_y_coordinate" + y_var.units = "m" + + x_var = ds.createVariable("x", "f4", ("x",)) + x_var[:] = x + x_var.standard_name = "projection_x_coordinate" + x_var.units = "m" + + t_var = ds.createVariable("time", "f8", ("time",)) + t_var[:] = time_stamps + t_var.standard_name = "time" + t_var.units = "seconds since 1970-01-01T00:00:00Z" + t_var.calendar = "standard" + + # Set up the ensemble if we have one + if ensembles is not None: + ds.createDimension("ensemble", len(ensembles)) + e_var = ds.createVariable("ensemble", "i4", ("ensemble",)) + e_var[:] = ensembles + e_var.standard_name = "ensemble" + e_var.units = "1" + + # Rainfall + if ensembles is None: + rain_var = ds.createVariable( + "rainfall", "i2", ("time", "y", "x"), + zlib=True, complevel=5, fill_value=-1 + ) + rain_var[:, :, :] = np.nan_to_num(rain, nan=-1) + + else: + rain_var = ds.createVariable( + "rainfall", "i2", ("ensemble", "time", "y", "x"), + zlib=True, complevel=5, fill_value=-1 + ) + rain_var[:, :, :, :] = np.nan_to_num(rain, nan=-1) + + rain_var.scale_factor = 0.1 + rain_var.add_offset = 0.0 + 
rain_var.units = "mm/h" + rain_var.long_name = "Rainfall rate" + rain_var.grid_mapping = "projection" + rain_var.coordinates = "time y x" if ensembles is None else "ensemble time y x" + + # CRS + crs = CRS.from_user_input(projection) + cf_grid_mapping = crs.to_cf() + spatial_ref = ds.createVariable("projection", "i4") + for key, value in cf_grid_mapping.items(): + setattr(spatial_ref, key, value) + + # Global attributes + ds.Conventions = "CF-1.10" + ds.title = "Rainfall data" + ds.institution = "Weather Radar New Zealand Ltd" + ds.references = "" + ds.comment = "" + return + + +def main(): + parser = argparse.ArgumentParser( + description="Write rainfall fields to a netCDF file") + + parser.add_argument('-s', '--start', type=str, required=True, + help='Start time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-e', '--end', type=str, required=True, + help='End time yyyy-mm-ddTHH:MM:SS') + parser.add_argument('-n', '--name', type=str, required=True, + help='Name of domain [AKL]') + parser.add_argument('-p', '--product', type=str, required=True, + help='Name of input product [QPE, auckprec]') + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + # Validate start and end time and read them in + if args.start and is_valid_iso8601(args.start): + start_time = datetime.datetime.fromisoformat(str(args.start)) + if start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid start time format. Please provide a valid ISO 8601 time string.") + return + + if args.end and is_valid_iso8601(args.end): + end_time = datetime.datetime.fromisoformat(str(args.end)) + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=datetime.timezone.utc) + else: + logging.error( + "Invalid start time format. Please provide a valid ISO 8601 time string.") + return + + name = str(args.name) + product = str(args.product) + valid_products = ["QPE", "auckprec", "qpesim"] + if product not in valid_products: + logging.error( + f"Invalid product. 
+        return
+
+    db = get_db()
+    meta_coll = db["AKL.rain.files"]
+
+    if product == "QPE":
+        file_id_query = {'metadata.product': product,
+                         'metadata.valid_time': {"$gte": start_time, "$lte": end_time}}
+        file_ids = get_rain_fields(db, file_id_query)
+
+        out_grid = []
+        valid_times = []
+        geo_out = None
+        expected_shape = None
+        for file_id in file_ids:
+            geo_data, metadata, nc_times, rain_data = load_rain_field(
+                db, file_id["_id"])
+
+            if expected_shape is None:
+                expected_shape = rain_data.shape
+            elif rain_data.shape != expected_shape:
+                logging.error(f"Inconsistent rain_data shape: expected {expected_shape}, got {rain_data.shape}")
+                return
+
+            out_grid.append(rain_data)
+            valid_times.append(nc_times)
+            if geo_out is None:
+                geo_out = geo_data
+
+        # QPE files are named using the start and end valid times
+        name_template = "$N_$P_$V{%Y-%m-%dT%H:%M:%S}_$B{%Y-%m-%dT%H:%M:%S}.nc"
+        file_name = make_nc_name_dt(
+            name_template, name, product, start_time, end_time, None)
+        out_array = np.array(out_grid)
+
+        logging.info(f"Writing {file_name}")
+        write_netcdf(file_name, out_array, geo_out, valid_times, None)
+
+    else:
+
+        # Get the list of base times in the time period
+        base_time_query = {'metadata.product': product,
+                           'metadata.base_time': {"$gte": start_time, "$lte": end_time}}
+        base_times = list(meta_coll.distinct(
+            "metadata.base_time", base_time_query))
+
+        # Loop over the base times that have been found
+        for base_time in base_times:
+
+            # Get the sorted list of ensemble members and valid times for this base time
+            ensemble_query = {'metadata.product': product,
+                              'metadata.base_time': base_time}
+            ensembles = list(meta_coll.distinct(
+                "metadata.ensemble", ensemble_query))
+            ensembles.sort()
+            ne = len(ensembles)
+            valid_times = list(meta_coll.distinct("metadata.valid_time", ensemble_query))
+            valid_times.sort()
+            nt = len(valid_times)
+            # Loop over the ensembles and read in the grids
+            out_grid = []
+            geo_out = None
+            expected_shape = None
+            for ensemble in ensembles:
+
+                # Get all the valid times for this ensemble
+                file_id_query = {'metadata.product': product,
+                                 'metadata.base_time': base_time, 'metadata.ensemble': ensemble}
+                # Check that the expected number of fields has been found
+                file_ids = get_rain_fields(db, file_id_query)
+                if len(valid_times) != len(file_ids):
+                    logging.error(f"{base_time}: expected {len(valid_times)} files, found {len(file_ids)}")
+
+                for file_id in file_ids:
+                    geo_data, metadata, nc_times, rain_data = load_rain_field(
+                        db, file_id["_id"])
+
+                    if expected_shape is None:
+                        expected_shape = rain_data.shape
+                    elif rain_data.shape != expected_shape:
+                        logging.error(f"Inconsistent rain_data shape: expected {expected_shape}, got {rain_data.shape}")
+                        return
+
+                    out_grid.append(rain_data)
+                    if geo_out is None:
+                        geo_out = geo_data
+
+            # Forecast files are named using their base time
+            name_template = "$N_$P_$V{%Y-%m-%dT%H:%M:%S}.nc"
+            ny, nx = expected_shape
+            file_name = make_nc_name_dt(
+                name_template, name, product, base_time, None, None)
+            out_array = np.array(out_grid).reshape(ne, nt, ny, nx)
+
+            logging.info(f"Writing {file_name}")
+            write_netcdf(file_name, out_array, geo_out, valid_times, ensembles)
+
+    return
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pysteps/utils/__init__.py b/pysteps/utils/__init__.py
index 9594a75ae..4348d4db4 100644
--- a/pysteps/utils/__init__.py
+++ b/pysteps/utils/__init__.py
@@ -12,3 +12,4 @@
 from .tapering import *
 from .transformation import *
 from .reprojection import *
+from .transformer import *
diff --git a/pysteps/utils/transformer.py b/pysteps/utils/transformer.py
new file mode 100644
index 000000000..42f7475ef
--- /dev/null
+++ b/pysteps/utils/transformer.py
@@ -0,0 +1,180 @@
+import numpy as np
+import scipy.stats as scipy_stats
+from scipy.interpolate import interp1d
+from typing import Optional
+
+class BaseTransformer:
+    def __init__(self, threshold: float = 0.5, zerovalue: Optional[float] = None):
+        self.threshold = threshold
+        self.zerovalue = zerovalue
+        self.metadata = {}
+
+    def transform(self, R: np.ndarray) -> np.ndarray:
+        raise NotImplementedError
+
+    def inverse_transform(self, R: np.ndarray) -> np.ndarray:
+        raise NotImplementedError
+
+    def get_metadata(self) -> dict:
+        return self.metadata.copy()
+
+class DBTransformer(BaseTransformer):
+    """
+    DBTransformer applies a thresholded dB transform to rain rate fields.
+
+    Parameters:
+        threshold (float): Rain rate threshold (in mm/h). Values below this are set to `zerovalue` in dB.
+        zerovalue (Optional[float]): Value in dB space to assign to below-threshold pixels.
+            If None, defaults to 10 * log10(threshold) - 0.1, i.e. the threshold in dB minus 0.1.
+    """
+
+    def __init__(self, threshold: float = 0.5, zerovalue: Optional[float] = None):
+        super().__init__(threshold, zerovalue)
+        threshold_db = 10.0 * np.log10(self.threshold)
+
+        if self.zerovalue is None:
+            self.zerovalue = threshold_db - 0.1
+
+        self.metadata = {
+            "transform": "dB",
+            "threshold": self.threshold,  # stored in mm/h
+            "zerovalue": self.zerovalue  # stored in dB
+        }
+
+    def transform(self, R: np.ndarray) -> np.ndarray:
+        R = R.copy()
+        mask = R < self.threshold
+        R[~mask] = 10.0 * np.log10(R[~mask])
+        R[mask] = self.zerovalue
+        return R
+
+    def inverse_transform(self, R: np.ndarray) -> np.ndarray:
+        R = R.copy()
+        R = 10.0 ** (R / 10.0)
+        R[R < self.threshold] = 0
+        return R
+
+
+class BoxCoxTransformer(BaseTransformer):
+    def __init__(self, Lambda: float = 0.0, **kwargs):
+        super().__init__(**kwargs)
+        self.Lambda = Lambda
+
+    def transform(self, R: np.ndarray) -> np.ndarray:
+        R = R.copy()
+        mask = R < self.threshold
+
+        if self.Lambda == 0.0:
+            R[~mask] = np.log(R[~mask])
+            tval = np.log(self.threshold)
+        else:
+            R[~mask] = (R[~mask] ** self.Lambda - 1) / self.Lambda
+            tval = (self.threshold ** self.Lambda - 1) / self.Lambda
+
+        if self.zerovalue is None:
+            self.zerovalue = tval - 1
+
+        R[mask] = self.zerovalue
+
+        self.metadata = {
+            "transform": "BoxCox",
+            "lambda": self.Lambda,
+            "threshold": tval,
+            "zerovalue": self.zerovalue,
+        }
+        return R
+
+    def inverse_transform(self, R: np.ndarray) -> np.ndarray:
+        R = R.copy()
+        if self.Lambda == 0.0:
+            R = np.exp(R)
+        else:
+            R = np.exp(np.log(self.Lambda * R + 1) / self.Lambda)
+
+        threshold_inv = (
+            np.exp(np.log(self.Lambda * self.metadata["threshold"] + 1) / self.Lambda)
+            if self.Lambda != 0.0 else
+            np.exp(self.metadata["threshold"])
+        )
+
+        R[R < threshold_inv] = self.metadata["zerovalue"]
+        self.metadata["transform"] = None
+        return R
+
+class NQTransformer(BaseTransformer):
+    def __init__(self, a: float = 0.0, **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self._inverse_interp = None
+
+    def transform(self, R: np.ndarray) -> np.ndarray:
+        R = R.copy()
+        shape = R.shape
+        R = R.ravel()
+        mask = ~np.isnan(R)
+        R_ = R[mask]
+
+        n = R_.size
+        Rpp = ((np.arange(n) + 1 - self.a) / (n + 1 - 2 * self.a))
+        Rqn = scipy_stats.norm.ppf(Rpp)
+        R_sorted = R_[np.argsort(R_)]
+        R_trans = np.interp(R_, R_sorted, Rqn)
+
+        self.zerovalue = np.min(R_)
+        R_trans[R_ == self.zerovalue] = 0
+
+        self._inverse_interp =
interp1d( + Rqn, R_sorted, bounds_error=False, + fill_value=(float(R_sorted.min()), float(R_sorted.max())) # type: ignore + ) + + R[mask] = R_trans + R = R.reshape(shape) + + self.metadata = { + "transform": "NQT", + "threshold": R_trans[R_trans > 0].min(), + "zerovalue": 0, + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + if self._inverse_interp is None: + raise RuntimeError("Must call transform() before inverse_transform()") + + R = R.copy() + shape = R.shape + R = R.ravel() + mask = ~np.isnan(R) + R[mask] = self._inverse_interp(R[mask]) + R = R.reshape(shape) + + self.metadata["transform"] = None + return R + +class SqrtTransformer(BaseTransformer): + def transform(self, R: np.ndarray) -> np.ndarray: + R = np.sqrt(R) + self.metadata = { + "transform": "sqrt", + "threshold": np.sqrt(self.threshold), + "zerovalue": np.sqrt(self.zerovalue) if self.zerovalue else 0.0 + } + return R + + def inverse_transform(self, R: np.ndarray) -> np.ndarray: + R = R**2 + self.metadata["transform"] = None + return R + +def get_transformer(name: str, **kwargs) -> BaseTransformer: + name = name.lower() + if name == "boxcox": + return BoxCoxTransformer(**kwargs) + elif name == "db": + return DBTransformer(**kwargs) + elif name == "nqt": + return NQTransformer(**kwargs) + elif name == "sqrt": + return SqrtTransformer(**kwargs) + else: + raise ValueError(f"Unknown transformer type: {name}")
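For reference, here is a minimal usage sketch of the transformer factory defined above. The class and function names come from `pysteps/utils/transformer.py` in this diff; the sample array and threshold value are made up for illustration.

```python
import numpy as np
from pysteps.utils.transformer import get_transformer

# Synthetic rain-rate field in mm/h; values below the threshold are treated as "no rain".
R = np.array([[0.0, 0.2, 1.5],
              [4.0, 0.0, 12.0]])

# Thresholded dB transform; zerovalue defaults to 10 * log10(threshold) - 0.1 dB.
transformer = get_transformer("db", threshold=0.5)
R_db = transformer.transform(R)
meta = transformer.get_metadata()  # e.g. {"transform": "dB", "threshold": 0.5, "zerovalue": ...}

# Round-trip back to rain rates; below-threshold pixels come back as 0 mm/h.
R_mmh = transformer.inverse_transform(R_db)
```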