diff --git a/.github/workflows/runPytest.yml b/.github/workflows/runPytest.yml new file mode 100644 index 0000000..09c957e --- /dev/null +++ b/.github/workflows/runPytest.yml @@ -0,0 +1,43 @@ +name: Run Pytest + +on: + push: + branches: + - "**" + pull_request: + branches: + - "**" + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + os: ["ubuntu-latest", "macos-latest", "windows-latest"] + # python-version: ["3.11", "3.12", "3.13", "3.14"] + python-version: ["3.14"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + pip install -e .[test] + + - name: Run tests + run: | + pytest tests/ -v + +# pytest tests/ -v --cov-report xml --cov=NOCAT + # - name: Upload coverage to Codecov + # uses: codecov/codecov-action@v4 + # with: + # fail_ci_if_error: false + # flags: pytest + # files: ./coverage.xml + # token: ${{ secrets.CODECOV_TOKEN }} # required \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e712f6c..9332c16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,3 +70,9 @@ exclude = [ verbose = 1 quiet = false color = true + +[project.optional-dependencies] +test = [ + "pytest", + "pytest-cov", +] diff --git a/requirements.txt b/requirements.txt index 337e900..b39e4f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,4 +21,5 @@ geopandas pvlib tqdm glidertools -ipykernel \ No newline at end of file +ipykernel +pytest \ No newline at end of file diff --git a/src/toolbox/pipeline.py b/src/toolbox/pipeline.py index ffb2ced..9f87258 100644 --- a/src/toolbox/pipeline.py +++ b/src/toolbox/pipeline.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Pipeline Class""" +"""Pipeline class definition to handle configuration and step execution.""" import yaml import pandas as pd @@ -34,9 +34,24 @@ ) _PIPELINE_LOGGER_NAME = "toolbox.pipeline" +"""Global logger name for the pipeline. Used to create child loggers for steps.""" def _setup_logging(log_file=None, level=logging.INFO): - """Set up logging for the entire pipeline.""" + """ + Set up logging for the entire pipeline. + + Parameters + ---------- + log_file : str, optional + Path to the log file. If provided, logs will be written to this file. + level : int, optional + Logging level (e.g., logging.INFO, logging.DEBUG). + + Returns + ------- + logging.Logger + Configured logger instance. + """ logger = logging.getLogger(_PIPELINE_LOGGER_NAME) logger.setLevel(level) logger.propagate = False @@ -69,14 +84,29 @@ def _setup_logging(log_file=None, level=logging.INFO): class Pipeline(ConfigMirrorMixin): """ + Pipeline that manages a sequence of processing steps. + Config-aware pipeline that can: - Load config YAML into private self._parameters - Keep global_parameters mirrored to _parameters['pipeline'] - Build, run, and export steps as before + + Parameters + ---------- + ConfigMirrorMixin : Class + Class to handle configuration + """ def __init__(self, config_path=None): - """Initialize pipeline with optional config file""" + """ + Initialize pipeline with optional config file. + + Parameters + ---------- + config_path : str, optional + Path to the YAML configuration file. + """ self.steps = [] # hierarchical step configs self.graph = Digraph("Pipeline", format="png", graph_attr={"rankdir": "TB"}) self.global_parameters = {} # mirrors _parameters["pipeline"] @@ -95,7 +125,16 @@ def __init__(self, config_path=None): self.logger.info("Pipeline initialised") def build_steps(self, steps_config, parent_name=None): - """Recursively build steps from configuration""" + """ + Recursively build steps from configuration. 
+ + Parameters + ---------- + steps_config : list of dict + List of step configurations. + parent_name : str, optional + Name of the parent step, if any. + """ for step in steps_config: REQUIRED_STEPS = STEP_DEPENDENCIES.get(step["name"], []) for required_step in REQUIRED_STEPS: @@ -121,7 +160,27 @@ def add_step( parent_name=None, run_immediately=False, ): - """Dynamically adds a step and optionally runs it immediately""" + """ + Dynamically adds a step and optionally runs it immediately. + + Parameters + ---------- + step_name : str + Name of the step to add. + parameters : dict, optional + Parameters for the step. + diagnostics : bool, optional + Whether to enable diagnostics for this step. + parent_name : str, optional + Name of the parent step, if any. + run_immediately : bool, optional + Whether to run the step immediately after adding it. + + Raises + ------ + ValueError + If the step name is not recognized or a specified parent step is not found. + """ if step_name not in STEP_CLASSES: raise ValueError( f"Step '{step_name}' is not recognized or missing @register_step." @@ -150,7 +209,16 @@ def add_step( self._context = self.execute_step(step_config, self._context) def _find_step(self, steps_list, step_name): - """Recursively find a step by name""" + """ + Recursively find a step by name. + + Parameters + ---------- + steps_list : list of dict + List of step configurations. + step_name : str + Name of the step to find. + """ for step in steps_list: if step["name"] == step_name: return step @@ -160,13 +228,24 @@ def _find_step(self, steps_list, step_name): return None def execute_step(self, step_config, _context): - """Executes a single step""" + """ + Executes a single step. + + Parameters + ---------- + step_config : dict + Configuration for the step to execute. + _context : dict + Current context to pass to the step. 
+ """ step = create_step(step_config, _context) self.logger.info(f"Executing: {step.name}") return step.run() def run_last_step(self): - """Runs only the most recently added step""" + """ + Runs only the most recently added step based on the index in self.steps. + """ if not self.steps: self.logger.info("No steps to run.") return @@ -175,7 +254,12 @@ def run_last_step(self): self._context = self.execute_step(last_step, self._context) def run(self): - """Runs the entire pipeline""" + """ + Runs the entire pipeline. + + If visualisation is specified in the configuration parameters, a visualisation + of the pipeline execution will be generated. + """ for step in self.steps: self._context = self.execute_step(step, self._context) @@ -183,10 +267,13 @@ def run(self): self.visualise_pipeline() def visualise_pipeline(self): - """Generates a visualisation of the pipeline execution""" + """ + Generates a visualisation of the pipeline execution. + """ self.graph.clear() def add_to_graph(step_config, parent_name=None, step_order=None): + """Add a step to the graph, intended for recursive use.""" step_name = step_config["name"] diagnostics = step_config.get("diagnostics", False) color = "red" if diagnostics else "black" @@ -213,7 +300,14 @@ def add_to_graph(step_config, parent_name=None, step_order=None): self.graph.render("pipeline_visualisation", view=True) def generate_config(self): - """Generate a configuration dictionary from the current pipeline setup""" + """ + Generate a configuration dictionary from the current pipeline setup. + + returns + ------- + dict + Configuration dictionary of the current pipeline. 
+ """ cfg = { "pipeline": self.global_parameters, "steps": self.steps, @@ -223,7 +317,19 @@ def generate_config(self): return cfg def export_config(self, output_path="generated_pipeline.yaml"): - """Write current config to file (respects private _parameters)""" + """ + Write current config to file (respects private _parameters) + + parameters + ---------- + output_path : str + Path to save the exported configuration YAML file. + + returns + ------- + dict + Configuration dictionary of the current pipeline. + """ cfg = self.generate_config() with open(output_path, "w") as f: yaml.safe_dump(cfg, f, sort_keys=False) @@ -231,7 +337,14 @@ def export_config(self, output_path="generated_pipeline.yaml"): return cfg def save_config(self, path="pipeline_config.yaml"): - """Save the canonical private config (same as manager.save_config).""" + """ + Save the canonical private config (same as manager.save_config). + + parameters + ---------- + path : str + Path to save the exported configuration YAML file. + """ # ensure _parameters is up to date self._parameters.update(self.generate_config()) super().save_config(path) \ No newline at end of file diff --git a/src/toolbox/steps/__init__.py b/src/toolbox/steps/__init__.py index 97747bb..a1feb92 100644 --- a/src/toolbox/steps/__init__.py +++ b/src/toolbox/steps/__init__.py @@ -30,14 +30,20 @@ # Global registries STEP_CLASSES = {} +"""Dictionary mapping step names to their implementing classes.""" QC_CLASSES = {} +"""Dictionary mapping QC test names to their implementing classes.""" STEP_DEPENDENCIES = { "QC: Salinity": ["Load OG1"], } +"""Dictionary of explicit dependencies between steps by name.""" def discover_steps(): - """Dynamically discover and import step modules from the custom directory.""" + """ + Dynamically discover and import step modules from the custom directory. + This populates the global STEP_CLASSES and QC_CLASSES registries for use elsewhere. 
+ """ base_dir = pathlib.Path(__file__).parent.resolve() custom_dir = base_dir / "custom" print(f"[Discovery] Scanning for step modules in {custom_dir}") diff --git a/src/toolbox/steps/base_step.py b/src/toolbox/steps/base_step.py index b1f2a0b..8ecea2b 100644 --- a/src/toolbox/steps/base_step.py +++ b/src/toolbox/steps/base_step.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""This module defines the base class for pipeline steps and configurations.""" from toolbox.utils.config_mirror import ConfigMirrorMixin import warnings @@ -21,8 +22,8 @@ warnings.formatwarning = lambda msg, *args, **kwargs: f"{msg}\n" -# Registry of explicitly registered step classes REGISTERED_STEPS = {} +"""Registry of explicitly registered step classes.""" def register_step(cls): diff --git a/src/toolbox/steps/base_test.py b/src/toolbox/steps/base_test.py index e731cb2..08abf55 100644 --- a/src/toolbox/steps/base_test.py +++ b/src/toolbox/steps/base_test.py @@ -13,9 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""This module defines the base class for QC tests and a registry for QC test classes.""" -# Registry of explicitly registered step classes REGISTERED_QC = {} +"""Registry of explicitly registered QC test classes.""" + flag_cols = { 0: "gray", 1: "blue", @@ -28,6 +30,7 @@ 8: "cyan", 9: "black", } +"""Map of QC flag values to colors for diagnostics plotting.""" def register_qc(cls): @@ -43,10 +46,19 @@ def register_qc(cls): class BaseTest: """ + Initializes a base class for quality control, to be further tweaked when inherited. + + Follow the docstring format below when creating new QC tests. 
+ + Target Variable: "Any" or a specific variable name (see impossible_location_test.py) + Flag Number: "Any" or a specific ARGO flag number + Variables Flagged: "Any" or specific variable names, possibly external to the target variable (see valid_profile_test.py) + Your description follows here. + + Target Variable: Flag Number: Variables Flagged: - ? description ? + """ test_name = None @@ -72,9 +84,17 @@ def __init__(self, data, **kwargs): self.flags = None def return_qc(self): + """Representative of QC processing, to be overridden by subclasses. + + Returns + ------- + flags : array-like + Output QC flags for the data specific to the test. + """ self.flags = None # replace with processing of some kind return self.flags def plot_diagnostics(self): - # Any relevant diagnostic + """Representative of diagnostic plotting (optional).""" + # Any relevant diagnostic is generated or written out here pass diff --git a/src/toolbox/steps/custom/apply_qc.py b/src/toolbox/steps/custom/apply_qc.py index 7209077..101408b 100644 --- a/src/toolbox/steps/custom/apply_qc.py +++ b/src/toolbox/steps/custom/apply_qc.py @@ -57,6 +57,11 @@ def organise_flags(self, new_flags): As an example, if an existing flag is 2 (probably good data) and a new flag is 4 (bad data), the resulting flag will be 4. 2 (probably good data) + 4 (bad data) -> 4 (bad data) 3 (probably bad data) + 5 (value changed) -> 3 (probably bad data) + + Parameters + ---------- + new_flags : xarray.Dataset + Dataset containing new QC flag variables to be merged into the existing flag store. """ # Define combinatrix for handling flag upgrade behaviour @@ -91,8 +96,16 @@ def organise_flags(self, new_flags): self.flag_store[column_name] = new_flags[column_name] def run(self): - """Run the Apply QC step.""" - + """ + Run the Apply QC step. + + Raises + ------ + KeyError + If no QC operations are specified, if requested QC tests are invalid, or essential variables are missing. + ValueError + If no data is found in context. 
+ """ # Defining the order of operations if len(self.qc_settings.keys()) == 0: raise KeyError( @@ -158,9 +171,14 @@ def run(self): # Update QC history for flagged_var in returned_flags.data_vars: + # Track percent of flags no longer 0 (following ARGO convention) percent_flagged = ( returned_flags[flagged_var].to_numpy() != 0 - ).sum() / len(returned_flags) + ).sum() / len(returned_flags[flagged_var]) + if percent_flagged == 0: + self.log_warn(f"All flags for {flagged_var} remain 0 after {qc_test_name}") + # else: # TODO: Add 'verbose' log option if needed. Might not need to happen at this point. + # self.log(f"{percent_flagged*100:.2f}% of {flagged_var} points accounted for by {qc_test_name}") qc_history.setdefault(flagged_var, []).append( (qc_test_name, percent_flagged) ) diff --git a/src/toolbox/steps/custom/blank_step.py b/src/toolbox/steps/custom/blank_step.py index fef1fd0..31b14e4 100644 --- a/src/toolbox/steps/custom/blank_step.py +++ b/src/toolbox/steps/custom/blank_step.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Example step template. Copy and populate this example, which will inherit additional functionality from BaseStep.""" + #### Mandatory imports #### from toolbox.steps.base_step import BaseStep, register_step from toolbox.utils.qc_handling import QCHandlingMixin diff --git a/src/toolbox/steps/custom/export.py b/src/toolbox/steps/custom/export.py index dd22777..cb657e1 100644 --- a/src/toolbox/steps/custom/export.py +++ b/src/toolbox/steps/custom/export.py @@ -19,10 +19,13 @@ #### Mandatory imports #### from ..base_step import BaseStep, register_step import toolbox.utils.diagnostics as diag - +import json @register_step class ExportStep(BaseStep): + """ + Step to export data in various formats. 
+ """ step_name = "Data Export" def run(self): @@ -36,6 +39,11 @@ def run(self): else: self.log(f"Data found in context.") data = self.context["data"] + # Add exiting notes on QC history if available TODO: Move earlier to individual QC steps on each data variable attribute + if "qc_history" in self.context: + self.log(f"QC history found in context.") + data.attrs["delayed_qc_history"] = json.dumps(self.context["qc_history"]) + export_format = self.parameters["export_format"] output_path = self.parameters["output_path"] @@ -65,7 +73,9 @@ def run(self): return self.context def generate_diagnostics(self): - """Generate diagnostics for the export step.""" + """ + Generate diagnostics for the export step. + """ self.log(f"Generating diagnostics for {self.step_name}") diag.generate_diagnostics(self.context, self.step_name) self.log(f"Diagnostics generated successfully.") diff --git a/src/toolbox/steps/custom/gen_data.py b/src/toolbox/steps/custom/gen_data.py index 24e4a27..ad80a56 100644 --- a/src/toolbox/steps/custom/gen_data.py +++ b/src/toolbox/steps/custom/gen_data.py @@ -26,6 +26,8 @@ @register_step class GenerateData(BaseStep): """ + Step for generating synthetic data for testing pipelines. 
+ Example config setup: """ diff --git a/src/toolbox/steps/custom/load_data.py b/src/toolbox/steps/custom/load_data.py index 2f59624..bc5347d 100644 --- a/src/toolbox/steps/custom/load_data.py +++ b/src/toolbox/steps/custom/load_data.py @@ -48,6 +48,7 @@ def run(self): # load data from xarray self.data = xr.open_dataset(self.file_path) + self.log(f"Loaded data from {self.file_path}") # Check that the "TIME" variable is monotonic and nanless - then make it a coordinate if "TIME" in self.data.coords: # Temporary fix for BODC OG1 files where TIME is a coord diff --git a/src/toolbox/steps/custom/qc/archive.py b/src/toolbox/steps/custom/qc/archive.py index 8b32668..fadd91d 100644 --- a/src/toolbox/steps/custom/qc/archive.py +++ b/src/toolbox/steps/custom/qc/archive.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Archive file of unused QC tests for reference.""" + # import xarray as xr # from datetime import datetime # from geodatasets import get_path diff --git a/src/toolbox/steps/custom/qc/blank_test.py b/src/toolbox/steps/custom/qc/blank_test.py index ce7309d..6b5fd59 100644 --- a/src/toolbox/steps/custom/qc/blank_test.py +++ b/src/toolbox/steps/custom/qc/blank_test.py @@ -14,13 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""Example QC test template, using parts of impossible_date_test as a skeleton.""" + #### Mandatory imports #### -from toolbox.steps.base_test import BaseTest, register_qc, flag_cols +# from toolbox.steps.base_test import BaseTest, register_qc, flag_cols # Uncomment when implementing +from toolbox.steps.base_test import BaseTest #### Custom imports #### +# any additional imports required for the test go here -@register_qc +# @register_qc # Uncomment when implementing class impossible_date_test(BaseTest): """ Target Variable: TIME @@ -40,4 +44,5 @@ def return_qc(self): return self.flags def plot_diagnostics(self): - plt.show(block=True) \ No newline at end of file + # plt.show(block=True) + pass diff --git a/src/toolbox/steps/custom/qc/flag_full_profile.py b/src/toolbox/steps/custom/qc/flag_full_profile.py index edf8af7..baa115e 100644 --- a/src/toolbox/steps/custom/qc/flag_full_profile.py +++ b/src/toolbox/steps/custom/qc/flag_full_profile.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""QC test to flag entire glider profiles based on number of bad flags.""" + #### Mandatory imports #### import numpy as np from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -45,24 +47,29 @@ class flag_full_profile(BaseTest): } diagnostics: true """ + test_name = "flag full profile" # Specify if test target variable is user-defined (if True, __init__ has to be redefined) dynamic = True def __init__(self, data, **kwargs): - # Check the necessary kwargs are available required_kwargs = {"check_vars"} if not required_kwargs.issubset(set(kwargs.keys())): - raise KeyError(f"{required_kwargs - set(kwargs.keys())} are missing from {self.test_name} settings") + raise KeyError( + f"{required_kwargs - set(kwargs.keys())} are missing from {self.test_name} settings" + ) # Specify the tests paramters from kwargs (config) - self.expected_parameters = {k: v for k, v in kwargs.items() if k in required_kwargs} + self.expected_parameters = { + k: v for k, v in kwargs.items() if k in required_kwargs + } self.required_variables = ( - list(self.expected_parameters["check_vars"].keys()) + - [f"{k}_QC" for k in self.expected_parameters["check_vars"].keys()] + - ["PROFILE_NUMBER"]) + list(self.expected_parameters["check_vars"].keys()) + + [f"{k}_QC" for k in self.expected_parameters["check_vars"].keys()] + + ["PROFILE_NUMBER"] + ) if data is not None: self.data = data.copy(deep=True) @@ -73,21 +80,27 @@ def __init__(self, data, **kwargs): self.flags = None def return_qc(self): - + # TODO: Add support for flagging if threshold is a mix of 3 (questionable) and 4 (definitely bad) flags # Subset the data self.data = self.data[self.required_variables] for var, threshold in self.check_vars.items(): - flag_counts = (self.data[f"{var}_QC"] == 4).groupby(self.data["PROFILE_NUMBER"]).sum() - bad_profiles = flag_counts.where(flag_counts >= threshold, drop=True)["PROFILE_NUMBER"] + flag_counts = ( + (self.data[f"{var}_QC"] == 4).groupby(self.data["PROFILE_NUMBER"]).sum() + ) # 
Default to flag 4 (definitely bad) + bad_profiles = flag_counts.where(flag_counts >= threshold, drop=True)[ + "PROFILE_NUMBER" + ] self.data[f"{var}_QC"] = xr.where( self.data[f"PROFILE_NUMBER"].isin(bad_profiles), 4, - self.data[f"{var}_QC"] + self.data[f"{var}_QC"], ) # Select just the flags - self.flags = self.data[[var_qc for var_qc in self.data.data_vars if "_QC" in var_qc]] + self.flags = self.data[ + [var_qc for var_qc in self.data.data_vars if "_QC" in var_qc] + ] return self.flags @@ -96,17 +109,16 @@ def plot_diagnostics(self): # Plot the QC output n_plots = len(self.check_vars.keys()) - fig, axs = plt.subplots(nrows=n_plots, figsize=(8, 4*n_plots), dpi=200) + fig, axs = plt.subplots(nrows=n_plots, figsize=(8, 4 * n_plots), dpi=200) if n_plots == 1: axs = [axs] for ax, var in zip(axs, self.check_vars.keys()): - for i in range(10): # Plot by flag number - plot_data = self.data[ - [var, "N_MEASUREMENTS"] - ].where(self.data[f"{var}_QC"] == i, drop=True) + plot_data = self.data[[var, "N_MEASUREMENTS"]].where( + self.data[f"{var}_QC"] == i, drop=True + ) if len(plot_data[var]) == 0: continue @@ -130,4 +142,4 @@ def plot_diagnostics(self): ax.legend(title="Flags", loc="upper right") fig.tight_layout() - plt.show(block=True) \ No newline at end of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/impossible_date_test.py b/src/toolbox/steps/custom/qc/impossible_date_test.py index 7a79b71..c50ce9c 100644 --- a/src/toolbox/steps/custom/qc/impossible_date_test.py +++ b/src/toolbox/steps/custom/qc/impossible_date_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""QC test to identify impossible dates in TIME variable.""" + #### Mandatory imports #### from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -24,6 +26,7 @@ import matplotlib import matplotlib.pyplot as plt + @register_qc class impossible_date_test(BaseTest): """ @@ -39,22 +42,23 @@ class impossible_date_test(BaseTest): qc_outputs = ["TIME_QC"] def return_qc(self): - # Convert to polars self.df = pl.from_pandas( - self.data[self.required_variables].to_dataframe(), - nan_to_null=False + self.data[self.required_variables].to_dataframe(), nan_to_null=False ) # Check if any of the datetime stamps fall outside 1985 and the current datetime + # TODO: Add optional bounds via parameters (such as known deployment dates, for example) self.df = self.df.with_columns( - pl.when( - pl.col("TIME").is_null() - ).then(9) + pl.when(pl.col("TIME").is_null()) + .then(9) .when( - ((pl.col("TIME") > datetime(1985, 1, 1)) - & (pl.col("TIME") < datetime.now())) - ).then(1) + ( + (pl.col("TIME") > datetime(1985, 1, 1)) + & (pl.col("TIME") < datetime.now()) + ) + ) + .then(1) .otherwise(4) .alias("TIME_QC") ) @@ -63,10 +67,9 @@ def return_qc(self): flags = self.df.select(pl.col("^.*_QC$")) self.flags = xr.Dataset( data_vars={ - col: ("N_MEASUREMENTS", flags[col].to_numpy()) - for col in flags.columns + col: ("N_MEASUREMENTS", flags[col].to_numpy()) for col in flags.columns }, - coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]} + coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]}, ) return self.flags @@ -76,9 +79,7 @@ def plot_diagnostics(self): fig, ax = plt.subplots(figsize=(6, 4), dpi=200) for i in range(10): # Plot by flag number - plot_data = self.df.with_row_index().filter( - pl.col("TIME_QC") == i - ) + plot_data = self.df.with_row_index().filter(pl.col("TIME_QC") == i) if len(plot_data) == 0: continue @@ -98,4 +99,4 @@ def plot_diagnostics(self): ) ax.legend(title="Flags", loc="upper right") fig.tight_layout() - plt.show(block=True) \ No newline at end 
of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/impossible_location_test.py b/src/toolbox/steps/custom/qc/impossible_location_test.py index 2ae95ea..7bb42af 100644 --- a/src/toolbox/steps/custom/qc/impossible_location_test.py +++ b/src/toolbox/steps/custom/qc/impossible_location_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""QC test to identify impossible locations in LATITUDE and LONGITUDE variables.""" + #### Mandatory imports #### from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -23,6 +25,7 @@ import matplotlib import matplotlib.pyplot as plt + @register_qc class impossible_location_test(BaseTest): """ @@ -38,33 +41,30 @@ class impossible_location_test(BaseTest): qc_outputs = ["LATITUDE_QC", "LONGITUDE_QC"] def return_qc(self): - # Convert to polars self.df = pl.from_pandas( - self.data[self.required_variables].to_dataframe(), - nan_to_null=False + self.data[self.required_variables].to_dataframe(), nan_to_null=False ) # Check LAT/LONG exist within expected bounds + # TODO: Add optional bounds via parameters (such as Southern Hemisphere, for example) for label, bounds in zip(["LATITUDE", "LONGITUDE"], [(-90, 90), (-180, 180)]): self.df = self.df.with_columns( - pl.when( - pl.col(label).is_nan() - ).then(9) - .when( - (pl.col(label) > bounds[0]) & (pl.col(label) < bounds[1]) - ).then(1) - .otherwise(4).alias(f"{label}_QC") + pl.when(pl.col(label).is_nan()) + .then(9) + .when((pl.col(label) > bounds[0]) & (pl.col(label) < bounds[1])) + .then(1) + .otherwise(4) + .alias(f"{label}_QC") ) # Convert back to xarray flags = self.df.select(pl.col("^.*_QC$")) self.flags = xr.Dataset( data_vars={ - col: ("N_MEASUREMENTS", flags[col].to_numpy()) - for col in flags.columns + col: ("N_MEASUREMENTS", flags[col].to_numpy()) for col in flags.columns }, - coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]} + coords={"N_MEASUREMENTS": 
self.data["N_MEASUREMENTS"]}, ) return self.flags @@ -78,9 +78,7 @@ def plot_diagnostics(self): ): for i in range(10): # Plot by flag number - plot_data = self.df.with_row_index().filter( - pl.col(f"{var}_QC") == i - ) + plot_data = self.df.with_row_index().filter(pl.col(f"{var}_QC") == i) if len(plot_data) == 0: continue @@ -103,4 +101,4 @@ def plot_diagnostics(self): fig.suptitle("Impossible Location Test") fig.tight_layout() - plt.show(block=True) \ No newline at end of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/impossible_speed_test.py b/src/toolbox/steps/custom/qc/impossible_speed_test.py index 8196e49..35d0296 100644 --- a/src/toolbox/steps/custom/qc/impossible_speed_test.py +++ b/src/toolbox/steps/custom/qc/impossible_speed_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""QC test to identify impossible speeds in glider data.""" + #### Mandatory imports #### from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -24,6 +26,7 @@ import numpy as np import matplotlib + @register_qc class impossible_speed_test(BaseTest): """ @@ -39,14 +42,14 @@ class impossible_speed_test(BaseTest): qc_outputs = ["TIME_QC", "LATITUDE_QC", "LONGITUDE_QC"] def return_qc(self): - # Convert to polars self.df = pl.from_pandas( - self.data[self.required_variables].to_dataframe(), - nan_to_null=False + self.data[self.required_variables].to_dataframe(), nan_to_null=False ) - self.df = self.df.with_columns((pl.col("TIME").diff().cast(pl.Float64) * 1e-9).alias("dt")) + self.df = self.df.with_columns( + (pl.col("TIME").diff().cast(pl.Float64) * 1e-9).alias("dt") + ) for label in ["LATITUDE", "LONGITUDE"]: self.df = self.df.with_columns( pl.col(label) @@ -58,16 +61,17 @@ def return_qc(self): self.df = self.df.with_columns( (pl.col(f"delta_{label}") / pl.col("dt")).alias(f"{label}_speed") ) + # Define absolute speed self.df = self.df.with_columns( - 
((pl.col("LATITUDE_speed") ** 2 + pl.col("LONGITUDE_speed") ** 2) ** 0.5).alias( - "absolute_speed" - ) + ( + (pl.col("LATITUDE_speed") ** 2 + pl.col("LONGITUDE_speed") ** 2) ** 0.5 + ).alias("absolute_speed") ) # TODO: Does this need a flag for potentially bad data for cases where speed is inf? self.df = self.df.with_columns( ( - (pl.col("absolute_speed") < 3) + (pl.col("absolute_speed") < 3) # Speed threshold & pl.col("absolute_speed").is_not_null() & pl.col("absolute_speed").is_finite() ).alias("speed_is_valid") @@ -75,9 +79,8 @@ def return_qc(self): for label in ["LATITUDE", "LONGITUDE", "TIME"]: self.df = self.df.with_columns( - pl.when( - pl.col("speed_is_valid") - ).then(1) + pl.when(pl.col("speed_is_valid")) + .then(1) .otherwise(4) .alias(f"{label}_QC") ) @@ -86,10 +89,9 @@ def return_qc(self): flags = self.df.select(pl.col("^.*_QC$")) self.flags = xr.Dataset( data_vars={ - col: ("N_MEASUREMENTS", flags[col].to_numpy()) - for col in flags.columns + col: ("N_MEASUREMENTS", flags[col].to_numpy()) for col in flags.columns }, - coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]} + coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]}, ) return self.flags @@ -100,9 +102,7 @@ def plot_diagnostics(self): for i in range(10): # Plot by flag number - plot_data = self.df.filter( - pl.col("LATITUDE_QC") == i - ) + plot_data = self.df.filter(pl.col("LATITUDE_QC") == i) if len(plot_data) == 0: continue @@ -120,10 +120,10 @@ def plot_diagnostics(self): title="Impossible Speed Test", xlabel="Time (s)", ylabel="Absolute Horizontal Speed (m/s)", - ylim=(0, 4) + ylim=(0, 4), ) ax.axhline(3, ls="--", c="k") ax.legend(title="Flags", loc="upper right") fig.tight_layout() - plt.show(block=True) \ No newline at end of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/par_irregularity_test.py b/src/toolbox/steps/custom/qc/par_irregularity_test.py index bbf5a62..0ec83a0 100644 --- a/src/toolbox/steps/custom/qc/par_irregularity_test.py +++ 
b/src/toolbox/steps/custom/qc/par_irregularity_test.py @@ -14,6 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +QC tests to identify irregularities in PAR profiles based on La Forgia & Organelli (2025). +* Shapiro–Wilk test +* Day and night sequences +""" + #### Mandatory imports #### from IPython.core.pylabtools import figsize from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -61,10 +67,13 @@ def calculate_solar_elevation(latitude, longitude, datetime): time_utc = pd.to_datetime(datetime).tz_localize("UTC") # Compute solar position - solar_position = pvlib.solarposition.get_solarposition(time_utc, latitude, longitude) + solar_position = pvlib.solarposition.get_solarposition( + time_utc, latitude, longitude + ) return solar_position["elevation"].values + def qc_par_flagging(pres, par, sun_elev, nei_par=3e-2): """ Real-time quality control (RT-QC) for PAR profiles @@ -249,10 +258,13 @@ def qc_par_flagging(pres, par, sun_elev, nei_par=3e-2): n_bad = np.sum(np.isin(flags, [4, 9])) n_prob_bad = np.sum(flags == 3) profile_flag = ( - 4 if n_bad >= n_good + n_prob_bad else - 1 if n_good / N >= 0.25 else - 2 if np.sum(flags == 2) >= np.sum(flags == 3) else - 3 + 4 + if n_bad >= n_good + n_prob_bad + else 1 + if n_good / N >= 0.25 + else 2 + if np.sum(flags == 2) >= np.sum(flags == 3) + else 3 ) return flags.astype(int), profile_flag, pa @@ -261,18 +273,22 @@ def qc_par_flagging(pres, par, sun_elev, nei_par=3e-2): @register_qc class par_irregularity_test(BaseTest): """ + Wrapper for qc_par_flagging, defining solar_elevation if it is not provided. 
""" test_name = "PAR irregularity test" - expected_parameters = { - "noise_equivalent_estimate": 3e-2, - "plot_profiles": [] - } - required_variables = ["LATITUDE", "LONGITUDE", "TIME", "PRES", "DOWNWELLING_PAR", "PROFILE_NUMBER"] + expected_parameters = {"noise_equivalent_estimate": 3e-2, "plot_profiles": []} + required_variables = [ + "LATITUDE", + "LONGITUDE", + "TIME", + "PRES", + "DOWNWELLING_PAR", + "PROFILE_NUMBER", + ] qc_outputs = ["DOWNWELLING_PAR_QC"] def return_qc(self): - # Subset the data self.data = self.data[self.required_variables] @@ -280,17 +296,25 @@ def return_qc(self): par_qc = np.full(len(self.data["DOWNWELLING_PAR"]), 0) # Apply the checks across individual profiles - profile_numbers = np.unique(self.data["PROFILE_NUMBER"].dropna(dim="N_MEASUREMENTS")) - for profile_number in tqdm(profile_numbers, colour="green", desc='\033[97mProgress\033[0m', unit="profile"): - + profile_numbers = np.unique( + self.data["PROFILE_NUMBER"].dropna(dim="N_MEASUREMENTS") + ) + for profile_number in tqdm( + profile_numbers, + colour="green", + desc="\033[97mProgress\033[0m", + unit="profile", + ): # Subset the data - profile = self.data.where(self.data["PROFILE_NUMBER"] == profile_number, drop=True) + profile = self.data.where( + self.data["PROFILE_NUMBER"] == profile_number, drop=True + ) # Find the solar elevation solar_elevation = calculate_solar_elevation( profile["LATITUDE"][0].values, profile["LONGITUDE"][0].values, - profile["TIME"][0].values + profile["TIME"][0].values, ) # Apply the QC opperation @@ -298,11 +322,13 @@ def return_qc(self): profile["PRES"], profile["DOWNWELLING_PAR"], solar_elevation, - self.noise_equivalent_estimate + self.noise_equivalent_estimate, ) # Stitch the QC results back into the QC container - profile_element_indices = np.where(self.data["PROFILE_NUMBER"] == profile_number) + profile_element_indices = np.where( + self.data["PROFILE_NUMBER"] == profile_number + ) par_qc[profile_element_indices] = profile_element_qc # any 
remaining flags that are 0 (unchecked) are updated to 1 (good) @@ -327,17 +353,20 @@ def plot_diagnostics(self): for profile_number, ax in zip(self.plot_profiles, axs.flatten()): # Select the profile data profile = self.data.where( - self.data["PROFILE_NUMBER"] == profile_number, - drop=True + self.data["PROFILE_NUMBER"] == profile_number, drop=True ).dropna(dim="N_MEASUREMENTS", subset=["DOWNWELLING_PAR", "PRES"]) if len(profile["DOWNWELLING_PAR"]) == 0: - ax.legend(title=f"Prof. {profile_number} (data missing)", loc="upper right") + ax.legend( + title=f"Prof. {profile_number} (data missing)", loc="upper right" + ) continue for flag in range(10): # Get the data for each flag and check it isn't empty - plot_data = profile.where(profile["DOWNWELLING_PAR_QC"] == flag, drop=True) + plot_data = profile.where( + profile["DOWNWELLING_PAR_QC"] == flag, drop=True + ) if len(plot_data["DOWNWELLING_PAR"]) == 0: continue @@ -360,4 +389,4 @@ def plot_diagnostics(self): fig.suptitle("PAR irregularity test") fig.tight_layout() - plt.show(block=True) \ No newline at end of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/position_on_land_test.py b/src/toolbox/steps/custom/qc/position_on_land_test.py index ab16f00..e6bf8ac 100644 --- a/src/toolbox/steps/custom/qc/position_on_land_test.py +++ b/src/toolbox/steps/custom/qc/position_on_land_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""QC test that identifies glider positions not located on land and flags accordingly.""" + #### Mandatory imports #### from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -26,6 +28,7 @@ import matplotlib import geopandas + @register_qc class position_on_land_test(BaseTest): """ @@ -41,20 +44,14 @@ class position_on_land_test(BaseTest): qc_outputs = ["LATITUDE_QC", "LONGITUDE_QC"] def return_qc(self): - # Convert to polars self.df = pl.from_pandas( - self.data[self.required_variables].to_dataframe(), - nan_to_null=False + self.data[self.required_variables].to_dataframe(), nan_to_null=False ) # Concat the polygons into a MultiPolygon object - self.world = geopandas.read_file( - get_path("naturalearth.land") - ) - land_polygons = sh.ops.unary_union( - self.world.geometry - ) + self.world = geopandas.read_file(get_path("naturalearth.land")) + land_polygons = sh.ops.unary_union(self.world.geometry) # Check if lat, long coords fall within the area of the land polygons self.df = self.df.with_columns( @@ -63,10 +60,11 @@ def return_qc(self): lambda x: sh.contains_xy( land_polygons, x.struct.field("LONGITUDE").to_numpy(), - x.struct.field("LATITUDE").to_numpy() + x.struct.field("LATITUDE").to_numpy(), ) * 4 - ).replace({0: 1}) + ) + .replace({0: 1}) .alias("LONGITUDE_QC") ) # Add the flags to LATITUDE as well. 
@@ -76,10 +74,9 @@ def return_qc(self): flags = self.df.select(pl.col("^.*_QC$")) self.flags = xr.Dataset( data_vars={ - col: ("N_MEASUREMENTS", flags[col].to_numpy()) - for col in flags.columns + col: ("N_MEASUREMENTS", flags[col].to_numpy()) for col in flags.columns }, - coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]} + coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]}, ) return self.flags @@ -93,9 +90,7 @@ def plot_diagnostics(self): for i in range(10): # Plot by flag number - plot_data = self.df.filter( - pl.col("LATITUDE_QC") == i - ) + plot_data = self.df.filter(pl.col("LATITUDE_QC") == i) if len(plot_data) == 0: continue @@ -116,4 +111,4 @@ def plot_diagnostics(self): ) ax.legend(title="Flags", loc="upper right") fig.tight_layout() - plt.show(block=True) \ No newline at end of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/range_test.py b/src/toolbox/steps/custom/qc/range_test.py index fa87ced..1f07cbd 100644 --- a/src/toolbox/steps/custom/qc/range_test.py +++ b/src/toolbox/steps/custom/qc/range_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""QC test(s) for flagging based on value ranges.""" + #### Mandatory imports #### import numpy as np from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -30,7 +32,7 @@ class range_test(BaseTest): Target Variable: Any Flag Number: Any Variables Flagged: Any - Checks that a meausrement is within a reasonable range. + Checks that a meausurement is within a reasonable range. 
EXAMPLE ------- @@ -46,29 +48,38 @@ class range_test(BaseTest): } diagnostics: true """ + test_name = "range test" # Specify if test target variable is user-defined (if True, __init__ has to be redefined) dynamic = True def __init__(self, data, **kwargs): - # Check the necessary kwargs are available required_kwargs = {"variable_ranges", "also_flag", "plot"} if not required_kwargs.issubset(set(kwargs.keys())): - raise KeyError(f"{required_kwargs - set(kwargs.keys())} are missing from {self.test_name} settings") + raise KeyError( + f"{required_kwargs - set(kwargs.keys())} are missing from {self.test_name} settings" + ) # Specify the tests paramters from kwargs (config) - self.expected_parameters = {k: v for k, v in kwargs.items() if k in required_kwargs} - self.required_variables = list(set(self.expected_parameters["variable_ranges"].keys())) + self.expected_parameters = { + k: v for k, v in kwargs.items() if k in required_kwargs + } + self.required_variables = list( + set(self.expected_parameters["variable_ranges"].keys()) + ) self.tested_variables = self.required_variables.copy() if "test_depth_range" in kwargs.keys(): self.required_variables.append("DEPTH") self.test_depth_range = kwargs["test_depth_range"] self.qc_outputs = list( - set(f"{var}_QC" for var in self.tested_variables) | - set(f"{var}_QC" for var in sum(self.expected_parameters["also_flag"].values(), [])) + set(f"{var}_QC" for var in self.tested_variables) + | set( + f"{var}_QC" + for var in sum(self.expected_parameters["also_flag"].values(), []) + ) ) if data is not None: @@ -80,44 +91,44 @@ def __init__(self, data, **kwargs): self.flags = None def return_qc(self): - # Subset the data self.data = self.data[self.required_variables] # If the user specified a depth range, limit the checks to that range if hasattr(self, "test_depth_range"): # TODO: -DEPTH - depth_range_mask = ( - (self.data["DEPTH"] >= self.test_depth_range[0]) & - (self.data["DEPTH"] <= self.test_depth_range[1]) + depth_range_mask = 
(self.data["DEPTH"] >= self.test_depth_range[0]) & ( + self.data["DEPTH"] <= self.test_depth_range[1] ) else: depth_range_mask = True # Make the empty QC columns for var in self.tested_variables: - self.data[f"{var}_QC"] = (["N_MEASUREMENTS"], np.full(len(self.data[var]), 0)) + self.data[f"{var}_QC"] = ( + ["N_MEASUREMENTS"], + np.full(len(self.data[var]), 0), + ) # Generate the variable-specific flags for var, meta in self.variable_ranges.items(): for flag, bounds in meta.items(): self.data[f"{var}_QC"] = xr.where( ( - depth_range_mask & - (self.data[var] > bounds[0]) & - (self.data[var] < bounds[1]) & - (self.data[f"{var}_QC"] == 0) + depth_range_mask + & (self.data[var] > bounds[0]) + & (self.data[var] < bounds[1]) + & (self.data[f"{var}_QC"] == 0) ), flag, - 0 + 0, ) # Replace all remaining 0s (unchecked) with 1s (good) self.data[f"{var}_QC"] = xr.where( - depth_range_mask & - (self.data[f"{var}_QC"] == 0), + depth_range_mask & (self.data[f"{var}_QC"] == 0), 1, - self.data[f"{var}_QC"] + self.data[f"{var}_QC"], ) # Broadcast the QC found for var into variables specified by "also_flag" @@ -126,7 +137,9 @@ def return_qc(self): self.data[f"{extra_var}_QC"] = self.data[f"{var}_QC"] # Select just the flags - self.flags = self.data[[var_qc for var_qc in self.data.data_vars if "_QC" in var_qc]] + self.flags = self.data[ + [var_qc for var_qc in self.data.data_vars if "_QC" in var_qc] + ] return self.flags @@ -135,7 +148,9 @@ def plot_diagnostics(self): # If not plots were specified if len(self.plot) == 0: - print("WARNING: In 'range test' diagnostics were called but no plots were specified.") + print( + "WARNING: In 'range test' diagnostics were called but no plots were specified." 
+ ) return # Plot the QC output @@ -144,17 +159,18 @@ def plot_diagnostics(self): axs = [axs] for ax, var in zip(axs, self.plot): - # Check that the user specified var exists in the test set if f"{var}_QC" not in self.qc_outputs: - print(f"WARNING: Cannot plot {var}_QC as it was not included in this test.") + print( + f"WARNING: Cannot plot {var}_QC as it was not included in this test." + ) continue for i in range(10): # Plot by flag number - plot_data = self.data[ - [var, "N_MEASUREMENTS"] - ].where(self.data[f"{var}_QC"] == i, drop=True) + plot_data = self.data[[var, "N_MEASUREMENTS"]].where( + self.data[f"{var}_QC"] == i, drop=True + ) if len(plot_data[var]) == 0: continue @@ -182,4 +198,4 @@ def plot_diagnostics(self): ax.legend(title="Flags", loc="upper right") fig.tight_layout() - plt.show(block=True) \ No newline at end of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/range_test_gross.py b/src/toolbox/steps/custom/qc/range_test_gross.py new file mode 100644 index 0000000..6e1abd1 --- /dev/null +++ b/src/toolbox/steps/custom/qc/range_test_gross.py @@ -0,0 +1,150 @@ +"""Gross Range Test QC Step.""" + +#### Mandatory imports #### +import numpy as np +from toolbox.steps.base_test import BaseTest, register_qc, flag_cols + +#### Custom imports #### +import matplotlib.pyplot as plt +import xarray as xr +import matplotlib + + +# TODO: Could be registered within range_test.py +@register_qc +class gross_range_test(BaseTest): + """ + Outside range test similar to IOOS QC gross range test. Not to be confused with `range test`, which flags within a range. + + Given two values it checks for data points outside of this range and assigns a corresponding flag as defined in the configuration. 
+ + Target Variable: Any + Flag Number: Any + Variables Flagged: Any + + EXAMPLE + ------- + gross range test: + variable_ranges: + TEMP: + 3: [0, 30] # Flags temperature data outside of this range as probably bad (3) + 4: [-2.5, 40] # Flags temperature data outside of this range as bad (4) + CNDC: + 3: [5, 42] + 4: [2, 45] + also_flag: + TEMP: [DOXY] # Flag DOXY based on TEMP flags + """ + + test_name = "gross range test" + dynamic = True + + def __init__(self, data, **kwargs): + required_kwargs = {"variable_ranges", "also_flag"} + if not required_kwargs.issubset(kwargs): + raise KeyError( + f"{required_kwargs - set(kwargs)} missing from gross range test" + ) + + self.variable_ranges = kwargs["variable_ranges"] + self.also_flag = kwargs["also_flag"] + self.plot = kwargs.get("plot", []) # Make plotting optional + + self.required_variables = list(self.variable_ranges.keys()) + self.tested_variables = self.required_variables.copy() + + self.qc_outputs = list( + set(f"{v}_QC" for v in self.tested_variables) + | set(f"{v}_QC" for v in sum(self.also_flag.values(), [])) + ) + + if data is not None: + self.data = data.copy(deep=True) + + def return_qc(self): + """Select data outside of the ranges and flag accordingly.""" + # Subset the data + self.data = self.data[self.required_variables] + + for var in self.tested_variables: + qc = xr.zeros_like(self.data[var], dtype=int) + + # Apply flags from most severe to least + for flag in sorted(self.variable_ranges[var], reverse=True): + low, high = self.variable_ranges[var][flag] + + outside = (self.data[var] < low) | (self.data[var] > high) + + qc = xr.where((qc == 0) & outside, flag, qc) + + # Anything not flagged is good + qc = xr.where(qc == 0, 1, qc) + + self.data[f"{var}_QC"] = qc + + # Propagate flags + for extra_var in self.also_flag.get(var, []): + self.data[f"{extra_var}_QC"] = qc + + # Select just the flags + self.flags = self.data[[v for v in self.data.data_vars if v.endswith("_QC")]] + + return self.flags + + def 
plot_diagnostics(self): + """Visualise the QC results in a similar manner to range_test""" + matplotlib.use("tkagg") + + # If not plots were specified + if len(self.plot) == 0: + self.log_warn( + "WARNING: In 'range test gross' diagnostics were called but no plots were specified." + ) + return + + # Plot the QC output + fig, axs = plt.subplots(nrows=len(self.plot), figsize=(8, 6), dpi=200) + if len(self.plot) == 1: + axs = [axs] + + for ax, var in zip(axs, self.plot): + # Check that the user specified var exists in the test set + if f"{var}_QC" not in self.qc_outputs: + self.log_warn( + f"WARNING: Cannot plot {var}_QC as it was not included in this test." + ) + continue + + for i in range(10): + # Plot by flag number + plot_data = self.data[[var, "N_MEASUREMENTS"]].where( + self.data[f"{var}_QC"] == i, drop=True + ) + + if len(plot_data[var]) == 0: + continue + + # Plot the data + ax.plot( + plot_data["N_MEASUREMENTS"], + plot_data[var], + c=flag_cols[i], + ls="", + marker="o", + label=f"{i}", + ) + + for bounds in self.variable_ranges[var].values(): + for bound in bounds: + ax.axhline(bound, ls="--", c="k") + + ax.set( + xlabel="Index", + ylabel=var, + title=f"{var} Range Test", + ) + + ax.legend(title="Flags", loc="upper right") + + fig.tight_layout() + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/spike_test.py b/src/toolbox/steps/custom/qc/spike_test.py index d82518e..f0d6a35 100644 --- a/src/toolbox/steps/custom/qc/spike_test.py +++ b/src/toolbox/steps/custom/qc/spike_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""QC test for flagging using spike/despike detection methods.""" + #### Mandatory imports #### import numpy as np from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -48,24 +50,33 @@ class spike_test(BaseTest): } diagnostics: true """ + test_name = "spike test" # Specify if test target variable is user-defined (if True, __init__ has to be redefined) dynamic = True def __init__(self, data, **kwargs): - # Check the necessary kwargs are available required_kwargs = {"variables", "also_flag", "plot"} if not required_kwargs.issubset(set(kwargs.keys())): - raise KeyError(f"{required_kwargs - set(kwargs.keys())} are missing from {self.test_name} settings") + raise KeyError( + f"{required_kwargs - set(kwargs.keys())} are missing from {self.test_name} settings" + ) # Specify the tests paramters from kwargs (config) - self.expected_parameters = {k: v for k, v in kwargs.items() if k in required_kwargs} - self.required_variables = list(set(self.expected_parameters["variables"].keys())) + ["PROFILE_NUMBER"] + self.expected_parameters = { + k: v for k, v in kwargs.items() if k in required_kwargs + } + self.required_variables = list( + set(self.expected_parameters["variables"].keys()) + ) + ["PROFILE_NUMBER"] self.qc_outputs = list( - set(f"{var}_QC" for var in self.required_variables) | - set(f"{var}_QC" for var in sum(self.expected_parameters["also_flag"].values(), [])) + set(f"{var}_QC" for var in self.required_variables) + | set( + f"{var}_QC" + for var in sum(self.expected_parameters["also_flag"].values(), []) + ) ) if data is not None: @@ -79,21 +90,27 @@ def __init__(self, data, **kwargs): self.flags = None def return_qc(self): - # Subset the data self.data = self.data[self.required_variables] # Generate the variable-specific flags for var, sensitivity in self.variables.items(): - spike_qc = np.full(len(self.data[var]), 0) # Apply the checks across individual profiles - profile_numbers = 
np.unique(self.data["PROFILE_NUMBER"].dropna(dim="N_MEASUREMENTS")) - for profile_number in tqdm(profile_numbers, colour="green", desc=f'\033[97mProgress [{var}]\033[0m', unit="prof"): - + profile_numbers = np.unique( + self.data["PROFILE_NUMBER"].dropna(dim="N_MEASUREMENTS") + ) + for profile_number in tqdm( + profile_numbers, + colour="green", + desc=f"\033[97mProgress [{var}]\033[0m", + unit="prof", + ): # Subset the data - profile = self.data.where(self.data["PROFILE_NUMBER"] == profile_number, drop=True) + profile = self.data.where( + self.data["PROFILE_NUMBER"] == profile_number, drop=True + ) # remove nans var_data = profile[var].dropna(dim="N_MEASUREMENTS") @@ -101,10 +118,12 @@ def return_qc(self): continue # Calculate the residules from the running median of the data - rolling_median = var_data.to_pandas().rolling( - window=self.window_size, - center=True - ).median().to_numpy() + rolling_median = ( + var_data.to_pandas() + .rolling(window=self.window_size, center=True) + .median() + .to_numpy() + ) residules = var_data - rolling_median # Define the residule threshold @@ -119,7 +138,9 @@ def return_qc(self): profile_flags[np.where(~nan_mask)] = spike_flags # Stitch the QC results back into the QC container - profile_indices = np.where(self.data["PROFILE_NUMBER"] == profile_number) + profile_indices = np.where( + self.data["PROFILE_NUMBER"] == profile_number + ) spike_qc[profile_indices] = profile_flags # Add the flags to the data @@ -131,7 +152,9 @@ def return_qc(self): self.data[f"{extra_var}_QC"] = self.data[f"{var}_QC"] # Select just the flags - self.flags = self.data[[var_qc for var_qc in self.data.data_vars if "_QC" in var_qc]] + self.flags = self.data[ + [var_qc for var_qc in self.data.data_vars if "_QC" in var_qc] + ] return self.flags @@ -140,26 +163,31 @@ def plot_diagnostics(self): # If not plots were specified if len(self.plot) == 0: - print(f"WARNING: In '{self.test_name}', diagnostics were called but no variables were specified for 
plotting.") + print( + f"WARNING: In '{self.test_name}', diagnostics were called but no variables were specified for plotting." + ) return # Plot the QC output - fig, axs = plt.subplots(nrows=len(self.plot), figsize=(8, 6), sharex=True, dpi=200) + fig, axs = plt.subplots( + nrows=len(self.plot), figsize=(8, 6), sharex=True, dpi=200 + ) if len(self.plot) == 1: axs = [axs] for ax, var in zip(axs, self.plot): - # Check that the user specified var exists in the test set if f"{var}_QC" not in self.qc_outputs: - print(f"WARNING: Cannot plot {var}_QC as it was not included in this test.") + print( + f"WARNING: Cannot plot {var}_QC as it was not included in this test." + ) continue for i in range(10): # Plot by flag number - plot_data = self.data[ - [var, "N_MEASUREMENTS"] - ].where(self.data[f"{var}_QC"] == i, drop=True) + plot_data = self.data[[var, "N_MEASUREMENTS"]].where( + self.data[f"{var}_QC"] == i, drop=True + ) if len(plot_data[var]) == 0: continue @@ -183,4 +211,4 @@ def plot_diagnostics(self): ax.legend(title="Flags", loc="upper right") fig.tight_layout() - plt.show(block=True) \ No newline at end of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/stuck_value_test.py b/src/toolbox/steps/custom/qc/stuck_value_test.py index f5dab20..67c331a 100644 --- a/src/toolbox/steps/custom/qc/stuck_value_test.py +++ b/src/toolbox/steps/custom/qc/stuck_value_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+"""QC test(s) for flagging stuck, static, or otherwise unchanged data (which should be changing).""" + #### Mandatory imports #### import numpy as np from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -45,24 +47,33 @@ class stuck_value_test(BaseTest): } diagnostics: true """ + test_name = "stuck value test" # Specify if test target variable is user-defined (if True, __init__ has to be redefined) dynamic = True def __init__(self, data, **kwargs): - # Check the necessary kwargs are available required_kwargs = {"variables", "also_flag", "plot"} if not required_kwargs.issubset(set(kwargs.keys())): - raise KeyError(f"{required_kwargs - set(kwargs.keys())} are missing from {self.test_name} settings") + raise KeyError( + f"{required_kwargs - set(kwargs.keys())} are missing from {self.test_name} settings" + ) # Specify the tests paramters from kwargs (config) - self.expected_parameters = {k: v for k, v in kwargs.items() if k in required_kwargs} - self.required_variables = list(set(self.expected_parameters["variables"].keys())) + self.expected_parameters = { + k: v for k, v in kwargs.items() if k in required_kwargs + } + self.required_variables = list( + set(self.expected_parameters["variables"].keys()) + ) self.qc_outputs = list( - set(f"{var}_QC" for var in self.required_variables) | - set(f"{var}_QC" for var in sum(self.expected_parameters["also_flag"].values(), [])) + set(f"{var}_QC" for var in self.required_variables) + | set( + f"{var}_QC" + for var in sum(self.expected_parameters["also_flag"].values(), []) + ) ) if data is not None: @@ -74,13 +85,11 @@ def __init__(self, data, **kwargs): self.flags = None def return_qc(self): - # Subset the data self.data = self.data[self.required_variables] # Generate the variable-specific flags for var, n_stuck in self.variables.items(): - # remove nans var_data = self.data[var].dropna(dim="N_MEASUREMENTS") @@ -93,9 +102,7 @@ def return_qc(self): # Handle edge cases for index, step in zip([0, -1], [1, -1]): - 
stuck_value_mask[index] = ( - var_data[index] == var_data[index + step] - ) + stuck_value_mask[index] = var_data[index] == var_data[index + step] # The remaining processing has to be in int dtype stuck_value_mask = stuck_value_mask.astype(int) @@ -110,7 +117,7 @@ def return_qc(self): stuck_value_mask[start:end] = end - start # Convert the stuck values mask into flags - bad_values = (stuck_value_mask > n_stuck) + bad_values = stuck_value_mask > n_stuck stuck_value_mask[bad_values] = 4 stuck_value_mask[~bad_values] = 1 @@ -125,7 +132,9 @@ def return_qc(self): self.data[f"{extra_var}_QC"] = self.data[f"{var}_QC"] # Select just the flags - self.flags = self.data[[var_qc for var_qc in self.data.data_vars if "_QC" in var_qc]] + self.flags = self.data[ + [var_qc for var_qc in self.data.data_vars if "_QC" in var_qc] + ] return self.flags @@ -134,26 +143,31 @@ def plot_diagnostics(self): # If not plots were specified if len(self.plot) == 0: - print(f"WARNING: In '{self.test_name}', diagnostics were called but no variables were specified for plotting.") + print( + f"WARNING: In '{self.test_name}', diagnostics were called but no variables were specified for plotting." + ) return # Plot the QC output - fig, axs = plt.subplots(nrows=len(self.plot), figsize=(8, 6), sharex=True, dpi=200) + fig, axs = plt.subplots( + nrows=len(self.plot), figsize=(8, 6), sharex=True, dpi=200 + ) if len(self.plot) == 1: axs = [axs] for ax, var in zip(axs, self.plot): - # Check that the user specified var exists in the test set if f"{var}_QC" not in self.qc_outputs: - print(f"WARNING: Cannot plot {var}_QC as it was not included in this test.") + print( + f"WARNING: Cannot plot {var}_QC as it was not included in this test." 
+ ) continue for i in range(10): # Plot by flag number - plot_data = self.data[ - [var, "N_MEASUREMENTS"] - ].where(self.data[f"{var}_QC"] == i, drop=True) + plot_data = self.data[[var, "N_MEASUREMENTS"]].where( + self.data[f"{var}_QC"] == i, drop=True + ) if len(plot_data[var]) == 0: continue @@ -177,4 +191,4 @@ def plot_diagnostics(self): ax.legend(title="Flags", loc="upper right") fig.tight_layout() - plt.show(block=True) \ No newline at end of file + plt.show(block=True) diff --git a/src/toolbox/steps/custom/qc/valid_profile_test.py b/src/toolbox/steps/custom/qc/valid_profile_test.py index 93fe402..fac5676 100644 --- a/src/toolbox/steps/custom/qc/valid_profile_test.py +++ b/src/toolbox/steps/custom/qc/valid_profile_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""QC tests for assessing validity of a glider profile, based on different definitions of successful data.""" + #### Mandatory imports #### from toolbox.steps.base_test import BaseTest, register_qc, flag_cols @@ -23,12 +25,13 @@ import xarray as xr import matplotlib + @register_qc class valid_profile_test(BaseTest): """ - Target Variable: PROFILE_NUMEBER + Target Variable: PROFILE_NUMBER Flag Number: 4 (bad data), 3 (potentially bad) - Variables Flagged: PROFILE_NUMEBER + Variables Flagged: PROFILE_NUMBER Checks that each profile is of a certain length (in number of points) and contains points within a specified depth range. 
""" @@ -42,11 +45,9 @@ class valid_profile_test(BaseTest): qc_outputs = ["PROFILE_NUMBER"] def return_qc(self): - # Convert to polars self.df = pl.from_pandas( - self.data[self.required_variables].to_dataframe(), - nan_to_null=False + self.data[self.required_variables].to_dataframe(), nan_to_null=False ) # Check profiles are of a given length @@ -55,20 +56,19 @@ def return_qc(self): # Find profiles that have no data between the sepcified depth ranges profile_ranges = self.df.group_by("PROFILE_NUMBER").agg( - (pl.col("DEPTH").is_between(*self.depth_range).any()).alias("in_depth_range") + (pl.col("DEPTH").is_between(*self.depth_range).any()).alias( + "in_depth_range" + ) ) self.df = self.df.join(profile_ranges, on="PROFILE_NUMBER", how="left") self.df = self.df.with_columns( - pl.when( - pl.col("PROFILE_NUMBER").is_nan() - ).then(9) - .when( - pl.col("count") < self.profile_length - ).then(4) - .when( - pl.col("in_depth_range").not_() - ).then(3) + pl.when(pl.col("PROFILE_NUMBER").is_nan()) + .then(9) + .when(pl.col("count") < self.profile_length) + .then(4) + .when(pl.col("in_depth_range").not_()) + .then(3) .otherwise(1) .alias("PROFILE_NUMBER_QC") ) @@ -77,10 +77,9 @@ def return_qc(self): flags = self.df.select(pl.col("^.*_QC$")) self.flags = xr.Dataset( data_vars={ - col: ("N_MEASUREMENTS", flags[col].to_numpy()) - for col in flags.columns + col: ("N_MEASUREMENTS", flags[col].to_numpy()) for col in flags.columns }, - coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]} + coords={"N_MEASUREMENTS": self.data["N_MEASUREMENTS"]}, ) return self.flags @@ -116,4 +115,3 @@ def plot_diagnostics(self): fig.tight_layout() plt.show(block=True) - diff --git a/src/toolbox/utils/qc_handling.py b/src/toolbox/utils/qc_handling.py index ad445c5..741e5a9 100644 --- a/src/toolbox/utils/qc_handling.py +++ b/src/toolbox/utils/qc_handling.py @@ -77,7 +77,12 @@ def filter_qc(self): def reconstruct_data(self): """ - Reconstruct data by replacing flagged values with original values + 
Reconstruct data by replacing flagged values with original values. + + raises + ------ + KeyError + If the specified behaviour is not specified in this method. """ if self.behaviour == "replace": pass @@ -119,7 +124,12 @@ def update_qc(self): def generate_qc(self, qc_constituents: dict): """ - Generate QC flags for child variables based on parent variables' QC flags + Generate QC flags for child variables based on parent variables' QC flags. + + parameters + ---------- + qc_constituents : dict + A dictionary mapping child QC variable names to lists of parent QC variable names. """ # Unpack the parent qc for qc_child, qc_parents in qc_constituents.items(): diff --git a/tests/test_impossible_date_test.py b/tests/test_impossible_date_test.py new file mode 100644 index 0000000..4f9bf91 --- /dev/null +++ b/tests/test_impossible_date_test.py @@ -0,0 +1,41 @@ +import pytest +import xarray as xr +import polars as pl +import numpy as np +import pandas as pd +from datetime import datetime +from toolbox.steps.custom.qc.impossible_date_test import impossible_date_test + +# Helper function to create a test xarray Dataset +def create_test_dataset(times): + return xr.Dataset( + { + "TIME": ("N_MEASUREMENTS", times), + }, + coords={"N_MEASUREMENTS": range(len(times))}, + ) + +def test_all_valid_dates(): + times_good = pd.date_range(start="2000-01-01", periods=10, freq="D") + data = create_test_dataset(times_good) + test = impossible_date_test(data) + flags = test.return_qc() + assert (flags["TIME_QC"] == 1).all() + + bad_times = pd.date_range(start="1900-01-01", periods=10, freq="D") + data = create_test_dataset(bad_times) + test = impossible_date_test(data) + flags = test.return_qc() + assert (flags["TIME_QC"] == 4).all() + + others = pd.to_datetime( + ["1677-12-07", "1989-11-09", "1978-09-21", "2020-01-01", "1975-11-10"] + ) + expected_flags = np.array([4, 1, 4, 1, 4]) + data = create_test_dataset(others) + test = impossible_date_test(data) + flags = test.return_qc() + assert 
(flags["TIME_QC"] == expected_flags).all()
+
+# TODO: add a test covering impossible_date_test.plot_diagnostics