diff --git a/_version.py b/_version.py index f17f8b4..37ca42f 100644 --- a/_version.py +++ b/_version.py @@ -28,7 +28,7 @@ commit_id: COMMIT_ID __commit_id__: COMMIT_ID -__version__ = version = '0.1.dev140+ge3e4fd1c6.d20260301' -__version_tuple__ = version_tuple = (0, 1, 'dev140', 'ge3e4fd1c6.d20260301') +__version__ = version = '0.1.dev161+g47a99b31a.d20260304' +__version_tuple__ = version_tuple = (0, 1, 'dev161', 'g47a99b31a.d20260304') -__commit_id__ = commit_id = 'ge3e4fd1c6' +__commit_id__ = commit_id = 'g47a99b31a' diff --git a/negate/__init__.py b/negate/__init__.py index 767aca6..08c769a 100644 --- a/negate/__init__.py +++ b/negate/__init__.py @@ -1,29 +1,99 @@ # SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0 # -# ruff: noqa - -from datasets import logging as ds_logging, disable_progress_bars as ds_disable_progress_bars -from huggingface_hub import logging as hf_logging -from huggingface_hub.utils.tqdm import disable_progress_bars as hf_disable_progress_bars -from transformers import logging as tf_logging -from diffusers.utils import logging as df_logging -from timm.utils.log import setup_default_logging + +from __future__ import annotations + +import importlib import logging import warnings +from typing import Any + +__all__ = [ + "Blurb", + "InferContext", + "Spec", + "build_train_call", + "chart_decompositions", + "compute_weighted_certainty", + "configure_runtime_logging", + "end_processing", + "infer_origin", + "load_metadata", + "load_spec", + "pretrain", + "root_folder", + "run_training_statistics", + "save_features", + "save_train_result", + "train_model", + "training_loop", +] + +_ATTR_SOURCES = { + "Blurb": ("negate.io.blurb", "Blurb"), + "InferContext": ("negate.inference", "InferContext"), + "Spec": ("negate.io.spec", "Spec"), + "build_train_call": ("negate.train", "build_train_call"), + "chart_decompositions": ("negate.metrics.track", "chart_decompositions"), + "compute_weighted_certainty": 
("negate.metrics.heuristics", "compute_weighted_certainty"), + "end_processing": ("negate.io.save", "end_processing"), + "infer_origin": ("negate.inference", "infer_origin"), + "load_metadata": ("negate.io.spec", "load_metadata"), + "load_spec": ("negate.io.spec", "load_spec"), + "pretrain": ("negate.train", "pretrain"), + "root_folder": ("negate.io.config", "root_folder"), + "run_training_statistics": ("negate.metrics.track", "run_training_statistics"), + "save_features": ("negate.io.save", "save_features"), + "save_train_result": ("negate.io.save", "save_train_result"), + "train_model": ("negate.train", "train_model"), + "training_loop": ("negate.train", "training_loop"), +} + +_LOGGING_CONFIGURED = False + + +def configure_runtime_logging() -> None: + """Apply quiet logging defaults for third-party ML stacks.""" + + global _LOGGING_CONFIGURED + if _LOGGING_CONFIGURED: + return + + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings("ignore", category=DeprecationWarning) + + try: + from datasets import logging as ds_logging, disable_progress_bars as ds_disable_progress_bars + from diffusers.utils import logging as df_logging + from huggingface_hub import logging as hf_logging + from huggingface_hub.utils.tqdm import disable_progress_bars as hf_disable_progress_bars + from timm.utils.log import setup_default_logging + from transformers import logging as tf_logging + except Exception: + # Keep startup resilient when optional deps are absent. 
+ _LOGGING_CONFIGURED = True + return + + setup_default_logging(logging.ERROR) + for logger in [df_logging, ds_logging, hf_logging, tf_logging]: + logger.set_verbosity_error() -warnings.filterwarnings("ignore", category=UserWarning) -warnings.filterwarnings("ignore", category=DeprecationWarning) -setup_default_logging(logging.ERROR) -for logger in [df_logging, ds_logging, hf_logging, tf_logging]: - logger.set_verbosity_error() ds_disable_progress_bars() hf_disable_progress_bars() + _LOGGING_CONFIGURED = True + + +def __getattr__(name: str) -> Any: + source = _ATTR_SOURCES.get(name) + if source is None: + raise AttributeError(name) + + module_name, attr_name = source + module = importlib.import_module(module_name) + value = getattr(module, attr_name) + globals()[name] = value + return value + -from negate.io.blurb import Blurb -from negate.io.config import root_folder -from negate.io.spec import Spec, load_spec, load_metadata -from negate.metrics.track import chart_decompositions, run_training_statistics -from negate.train import build_train_call, pretrain, train_model, training_loop -from negate.io.save import save_train_result, end_processing, save_features -from negate.metrics.heuristics import compute_weighted_certainty -from negate.inference import infer_origin, InferContext +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/negate/__main__.py b/negate/__main__.py index 9a54f2d..dfd49a7 100644 --- a/negate/__main__.py +++ b/negate/__main__.py @@ -1,87 +1,218 @@ # SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0 # -# type: ignore - -"""Command-line interface entry point for Negate package.\n -Handles CLI parsing, dataset loading, preprocessing, and result saving. -Supports 'inference' and 'train' subcommand with automatic timestamping. 
- - → Dataset (images) - → io/ops (load) - → wavelet.py (decompose) - → feature_{vit,vae}.py + residuals.py (extract) - → train.py (XGBoost grade) - → inference.py (predict from data) - → track.py (plotting/metrics) -""" - from __future__ import annotations import argparse -import os +import logging import re import time as timer_module +import tomllib +from dataclasses import dataclass, field from pathlib import Path from sys import argv -from dataclasses import dataclass -from tqdm import tqdm - - -from negate import ( - Blurb, - InferContext, - Spec, - build_train_call, - chart_decompositions, - compute_weighted_certainty, - end_processing, - infer_origin, - load_metadata, - load_spec, - pretrain, - root_folder, - run_training_statistics, - save_features, - save_train_result, - train_model, - training_loop, -) - +from typing import Any + +ROOT_FOLDER = Path(__file__).resolve().parent.parent +CONFIG_PATH = ROOT_FOLDER / "config" +BLURB_PATH = CONFIG_PATH / "blurb.toml" +CONFIG_TOML_PATH = CONFIG_PATH / "config.toml" +TIMESTAMP_PATTERN = re.compile(r"\d{8}_\d{6}") +DEFAULT_INFERENCE_PAIR = ["20260225_185933", "20260225_221149"] start_ns = timer_module.perf_counter() +CLI_LOGGER = logging.getLogger("negate.cli") +if not CLI_LOGGER.handlers: + _handler = logging.StreamHandler() + _handler.setFormatter(logging.Formatter("%(message)s")) + CLI_LOGGER.addHandler(_handler) +CLI_LOGGER.setLevel(logging.INFO) +CLI_LOGGER.propagate = False + + +@dataclass +class BlurbText: + """CLI help text defaults loaded from config/blurb.toml.""" + + # Commands + pretrain: str = "Analyze and graph performance..." + train: str = "Train XGBoost model..." + infer: str = "Infer whether features..." + + # Flags + loop: str = "Toggle training across the range..." + features_load: str = "Train from an existing set of features" + verbose: str = "Verbose console output" + label_syn: str = "Mark image as synthetic (label = 1) for evaluation." 
+ label_gne: str = "Mark image as genuine (label = 0) for evaluation." + + # Dataset paths + gne_path: str = "Genuine/Human-origin image dataset path" + syn_path: str = "Synthetic image dataset path" + unidentified_path: str = "Path to the image or directory containing images of unidentified origin" + + # Verbose output + verbose_status: str = "Checking path " + verbose_dated: str = " using models dated " + + # Errors + infer_path_error: str = "Infer requires an image path." + model_error: str = "Warning: No valid model directories found in " + model_error_hint: str = " Create or add a trained model before running inference." + model_pair: str = "Two models must be provided for inference..." + model_pattern: str = "Model format must match pattern YYYYMMDD_HHMMSS..." + + # Shared phrasing + model_desc: str = "model to use. Default : " + + +@dataclass +class ModelChoices: + """Model and VAE choices inferred from config/config.toml.""" + + default_vit: str = "" + default_vae: str = "" + model_choices: list[str] = field(default_factory=list) + ae_choices: list[str] = field(default_factory=list) @dataclass class CmdContext: - """Container for main() arguments passed to cmd().""" + """Container for parsed arguments and runtime dependencies.""" args: argparse.Namespace - blurb: Blurb - spec: Spec + blurb: Any + spec: Any results_path: Path models_path: Path list_model: list[str] | None -def cmd(ctx: CmdContext) -> None: # -> list[dict[str, str | float | int]] - """Process command arguments\n - :raises ValueError: Missing image path. - :raises ValueError: Invalid VAE choice. - :raises NotImplementedError: Unsupported command passed. 
- """ +def load_spec(model_version: str | Path = "config"): + """Backwards-compatible export used by tests and callers.""" + + from negate.io.spec import load_spec as _load_spec + + return _load_spec(str(model_version)) + + +def _list_timestamp_dirs(path: Path) -> list[str]: + if not path.exists(): + return [] + entries = [entry.name for entry in path.iterdir() if entry.is_dir() and TIMESTAMP_PATTERN.fullmatch(entry.name)] + entries.sort(reverse=True) + return entries + + +def _load_blurb_text() -> BlurbText: + blurb = BlurbText() + if not BLURB_PATH.exists(): + return blurb + + with open(BLURB_PATH, "rb") as blurb_file: + data = tomllib.load(blurb_file) + + for key, value in data.items(): + if hasattr(blurb, key): + setattr(blurb, key, value) + return blurb + + +def _load_model_choices() -> ModelChoices: + choices = ModelChoices() + if not CONFIG_TOML_PATH.exists(): + return choices + + with open(CONFIG_TOML_PATH, "rb") as config_file: + data = tomllib.load(config_file) + + model_library = data.get("model", {}).get("library", {}) + for configured_models in model_library.values(): + if isinstance(configured_models, list): + choices.model_choices.extend(str(model_name) for model_name in configured_models) + elif isinstance(configured_models, str): + choices.model_choices.append(configured_models) + if choices.model_choices: + choices.default_vit = choices.model_choices[0] + + vae_library = data.get("vae", {}).get("library", {}) + for configured_vae in vae_library.values(): + if isinstance(configured_vae, list) and configured_vae: + choices.ae_choices.append(str(configured_vae[0])) + + if "" not in choices.ae_choices: + choices.ae_choices.append("") + if choices.ae_choices: + choices.default_vae = choices.ae_choices[0] + return choices + + +def _build_parser(blurb: BlurbText, choices: ModelChoices, list_results: list[str], list_model: list[str], inference_pair: list[str]) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Negate CLI") + 
subparsers = parser.add_subparsers(dest="cmd", required=True) + + pretrain_parser = subparsers.add_parser("pretrain", help=blurb.pretrain) + train_parser = subparsers.add_parser("train", help=blurb.train) + train_parser.add_argument("-l", "--loop", action="store_true", help=blurb.loop) + train_parser.add_argument("-f", "--features", choices=list_results, default=None, help=blurb.features_load) + + vit_help = f"Vision {blurb.model_desc} {choices.default_vit}".strip() + ae_help = f"Autoencoder {blurb.model_desc} {choices.default_vae}".strip() + infer_model_help = f"Trained {blurb.model_desc} {inference_pair}".strip() + + for sub in [pretrain_parser, train_parser]: + sub.add_argument("path", help=blurb.gne_path, nargs="?", default=None) + sub.add_argument("-s", "--syn", help=blurb.syn_path, nargs="?", default=None) + if choices.model_choices: + sub.add_argument("-m", "--model", choices=choices.model_choices, default=choices.default_vit, help=vit_help) + else: + sub.add_argument("-m", "--model", default=choices.default_vit, help=vit_help) + if choices.ae_choices: + sub.add_argument("-a", "--ae", choices=choices.ae_choices, default=choices.default_vae, help=ae_help) + else: + sub.add_argument("-a", "--ae", default=choices.default_vae, help=ae_help) + + infer_parser = subparsers.add_parser("infer", help=blurb.infer) + infer_parser.add_argument("path", help=blurb.unidentified_path) + if list_model: + infer_parser.add_argument("-m", "--model", choices=list_model, default=inference_pair, nargs="+", help=infer_model_help) + else: + infer_parser.add_argument("-m", "--model", choices=None, default=None, nargs="+") + + label_grp = infer_parser.add_mutually_exclusive_group() + label_grp.add_argument("-g", "--genuine", action="store_const", const=0, dest="label", help=blurb.label_gne) + label_grp.add_argument("-s", "--synthetic", action="store_const", const=1, dest="label", help=blurb.label_syn) + infer_parser.add_argument("-v", "--verbose", action="store_true", 
help=blurb.verbose) + + return parser + + +def cmd(ctx: CmdContext) -> None: args = ctx.args + + from negate import configure_runtime_logging + + configure_runtime_logging() + match args.cmd: case "pretrain": + from negate.io.save import end_processing, save_features + from negate.metrics.track import chart_decompositions + from negate.train import build_train_call, pretrain + origin_ds = build_train_call(args=args, path_result=ctx.results_path, spec=ctx.spec) features_ds = pretrain(origin_ds, ctx.spec) end_processing("Pretraining", start_ns) save_features(features_ds) chart_decompositions(features_dataset=features_ds, spec=ctx.spec) + case "train": + from negate.io.save import end_processing, save_train_result + from negate.metrics.track import run_training_statistics + from negate.train import build_train_call, train_model, training_loop + origin_ds = build_train_call(args=args, path_result=ctx.results_path, spec=ctx.spec) if args.loop is True: training_loop(image_ds=origin_ds, spec=ctx.spec) - else: train_result = train_model(features_ds=origin_ds, spec=ctx.spec) timecode = end_processing("Training", start_ns) @@ -89,111 +220,122 @@ def cmd(ctx: CmdContext) -> None: # -> list[dict[str, str | float | int]] run_training_statistics(train_result=train_result, timecode=timecode, spec=ctx.spec) case "infer": + from tqdm import tqdm + + from negate.inference import InferContext, infer_origin, preprocessing + from negate.io.datasets import generate_dataset + from negate.io.spec import load_metadata + from negate.metrics.heuristics import compute_weighted_certainty + if args.path is None: raise ValueError(ctx.blurb.infer_path_error) if ctx.list_model is None or not ctx.list_model: raise ValueError(f"{ctx.blurb.model_error} {ctx.models_path} {ctx.blurb.model_error_hint}") - img_file_or_folder: Path = Path(args.path) - assert isinstance(args.model, list) or isinstance(args.model, tuple), ValueError(ctx.blurb.model_pair) - negate_models = {} + + img_file_or_folder = 
Path(args.path) + if not isinstance(args.model, list) and not isinstance(args.model, tuple): + raise ValueError(ctx.blurb.model_pair) + + negate_models: dict[str, Path] = {} + model_specs: dict[str, Any] = {} + model_metadata: dict[str, Any] = {} for saved_model in args.model: negate_models[saved_model] = ctx.models_path / saved_model - assert negate_models[saved_model].exists(), ValueError(ctx.blurb.model_pattern) + if not negate_models[saved_model].exists(): + raise ValueError(ctx.blurb.model_pattern) + model_specs[saved_model] = load_spec(saved_model) + model_metadata[saved_model] = load_metadata(saved_model) + if args.verbose: import warnings warnings.filterwarnings("default", category=UserWarning) warnings.filterwarnings("default", category=DeprecationWarning) - print(f"{ctx.blurb.verbose_status} {img_file_or_folder}' {ctx.blurb.verbose_dated} {args.model}") + CLI_LOGGER.info(f"{ctx.blurb.verbose_status} {img_file_or_folder} {ctx.blurb.verbose_dated} {args.model}") + + CLI_LOGGER.info("Preparing feature dataset and loading selected models...") + origin_ds = generate_dataset(img_file_or_folder, verbose=args.verbose) + feature_cache: dict[str, Any] = {} + feature_key_by_model: dict[str, str] = {} + for saved_model, model_spec in model_specs.items(): + feature_key = "|".join( + [ + str(model_spec.model), + str(model_spec.vae), + str(model_spec.dtype), + str(model_spec.device), + str(model_spec.opt.dim_factor), + str(model_spec.opt.dim_patch), + str(model_spec.opt.top_k), + str(model_spec.opt.condense_factor), + str(model_spec.opt.alpha), + str(model_spec.opt.magnitude_sampling), + ] + ) + feature_key_by_model[saved_model] = feature_key + if feature_key not in feature_cache: + feature_cache[feature_key] = preprocessing(origin_ds, model_spec, verbose=args.verbose) inference_result = {} - for saved_model, model_data in tqdm(negate_models.items(), disable=args.verbose): - if isinstance(model_data, str): - model_data = Path(model_data) + for saved_model, model_data 
in tqdm( + negate_models.items(), + total=len(negate_models), + desc="Running inference with each selected model", + disable=False, + ): context = InferContext( - spec=load_spec(saved_model), + spec=model_specs[saved_model], model_version=model_data, - train_metadata=load_metadata(saved_model), + train_metadata=model_metadata[saved_model], label=args.label, file_or_folder_path=img_file_or_folder, - dataset_feat=None, + dataset_feat=feature_cache[feature_key_by_model[saved_model]], run_heuristics=False, model=True, verbose=args.verbose, ) inference_result[saved_model] = infer_origin(context) - inference_results = (v for _, v in inference_result.items()) - compute_weighted_certainty( - *inference_results, - label=args.label, - ) - # return inferences + inference_results = (result for _, result in inference_result.items()) + compute_weighted_certainty(*inference_results, label=args.label) case _: raise NotImplementedError -def main(): - """CLI argument parser and command dispatcher.\n - :raises ValueError: Missing image path. - :raises ValueError: Invalid VAE choice. - :raises NotImplementedError: Unsupported command passed. 
- """ - - spec = Spec() - blurb = Blurb(spec) - models_path = root_folder / "models" - results_path = root_folder / "results" - - inference_pair = ["20260225_185933", "20260225_221149"] # [FLUX-AE, DC-AE] - - list_results: list[str] = [] - if len(os.listdir(results_path)) > 0: - list_results = [str(folder.stem) for folder in Path(results_path).iterdir() if folder.is_dir() and re.fullmatch(r"\d{8}_\d{6}", folder.stem)] - list_results.sort(reverse=True) +def main() -> None: + blurb_text = _load_blurb_text() + model_choices = _load_model_choices() - inference_pair = ["20260225_185933", "20260225_221149"] # [FLUX-AE, DC-AE] + models_path = ROOT_FOLDER / "models" + results_path = ROOT_FOLDER / "results" + list_results = _list_timestamp_dirs(results_path) + list_model = _list_timestamp_dirs(models_path) + inference_pair = list_model[:2] if len(list_model) >= 2 else DEFAULT_INFERENCE_PAIR - list_results = [] - if len(os.listdir(results_path)) > 0: - list_results = [str(folder.stem) for folder in Path(results_path).iterdir() if folder.is_dir() and re.fullmatch(r"\d{8}_\d{6}", folder.stem)] - list_results.sort(reverse=True) - - parser = argparse.ArgumentParser(description="Negate CLI") - subparsers = parser.add_subparsers(dest="cmd", required=True) - - pretrain_parser = subparsers.add_parser("pretrain", help=blurb.pretrain) - train_parser = subparsers.add_parser("train", help=blurb.train) - train_parser.add_argument("-l", "--loop", action="store_true", help=blurb.loop) - train_parser.add_argument("-f", "--features", choices=list_results, default=None, help=blurb.features_load) - - for sub in [pretrain_parser, train_parser]: - sub.add_argument("gne_path", help=blurb.gne_path, nargs="?", default=None) - sub.add_argument("-s", "--syn", help=blurb.syn_path, nargs="?", default=None) - sub.add_argument("-m", "--model", choices=blurb.model_choices, default=blurb.default_vit, help=blurb.vit_model_blurb()) - sub.add_argument("-a", "--ae", choices=blurb.ae_choices, 
default=blurb.default_vae, help=blurb.ae_model_blurb()) - - infer_parser = subparsers.add_parser("infer", help=blurb.infer) - infer_parser.add_argument("path", help=blurb.unidentified_path) - if len(os.listdir(models_path)) > 0: - list_model = [str(folder.stem) for folder in Path(models_path).iterdir() if folder.is_dir() and re.fullmatch(r"\d{8}_\d{6}", folder.stem)] - list_model.sort(reverse=True) - if list_model: - infer_parser.add_argument("-m", "--model", choices=list_model, default=inference_pair, help=blurb.infer_model_blurb(inference_pair)) - else: - list_model = None - infer_parser.add_argument("-m", "--model", choices=None, default=None) - - label_grp = infer_parser.add_mutually_exclusive_group() - label_grp.add_argument("-g", "--genuine", action="store_const", const=0, dest="label", help=blurb.label_gne) - label_grp.add_argument("-s", "--synthetic", action="store_const", const=1, dest="label", help=blurb.label_syn) - - infer_parser.add_argument("-v", "--verbose", action="store_true", help=blurb.verbose) + parser = _build_parser( + blurb=blurb_text, + choices=model_choices, + list_results=list_results, + list_model=list_model, + inference_pair=inference_pair, + ) args = parser.parse_args(argv[1:]) - cmd_context = CmdContext(args=args, blurb=blurb, spec=spec, results_path=results_path, models_path=models_path, list_model=list_model) + from negate.io.blurb import Blurb + from negate.io.spec import Spec + + spec = Spec() + blurb = Blurb(spec) + cmd_context = CmdContext( + args=args, + blurb=blurb, + spec=spec, + results_path=results_path, + models_path=models_path, + list_model=list_model if list_model else None, + ) cmd(cmd_context) diff --git a/negate/inference.py b/negate/inference.py index 07d3fa6..509ec0a 100644 --- a/negate/inference.py +++ b/negate/inference.py @@ -2,13 +2,14 @@ # import pickle +import logging from dataclasses import asdict, dataclass from pathlib import Path from typing import Any import numpy as np import onnxruntime as ort -from 
datasets import Dataset, enable_progress_bar, logging +from datasets import Dataset from onnxruntime.capi.onnxruntime_pybind11_state import Fail as ONNXRuntimeError from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument @@ -16,9 +17,12 @@ from negate.extract.feature_vae import VAEExtract from negate.io.config import random_state from negate.io.datasets import generate_dataset, prepare_dataset +from negate.types import ModelOutput, OriginLabel from negate.io.spec import Spec from negate.metrics.heuristics import weight_dc_gne, weight_ae_gne +LOGGER = logging.getLogger(__name__) + @dataclass class InferContext: @@ -28,7 +32,7 @@ class InferContext: model_version: Path train_metadata: dict file_or_folder_path: Path - label: int | None = None + label: OriginLabel | int | None = None dataset_feat: Dataset | None = None run_heuristics: bool = False model: bool = True @@ -69,7 +73,7 @@ def run_native(features_dataset: np.ndarray, model_version: Path, parameters: di model = xgb.Booster(params=parameters, model_file=model_file_path_named) model.load_model(model_file_path_named) if verbose: - print(f"Model '{model_file_path_named}' loaded.") + LOGGER.info("Model '%s' loaded.", model_file_path_named) result = model.predict(xgb.DMatrix(features_pca)) return result @@ -97,16 +101,16 @@ def run_onnx(features_dataset: np.ndarray, model_version: Path, parameters: dict session = ort.InferenceSession(model_file_path_named) if verbose: - print(f"Model '{model_file_path_named}' loaded.") + LOGGER.info("Model '%s' loaded.", model_file_path_named) input_name = session.get_inputs()[0].name inputs = {input_name: features_dataset.astype(np.float32)} # noqa try: result: ort.SparseTensor = session.run(None, {input_name: features_model})[0] # type: ignore - print(result) + LOGGER.info("%s", result) except (InvalidArgument, ONNXRuntimeError) as error_log: import sys - print(error_log) + LOGGER.error("%s", error_log) sys.exit() @@ -114,6 +118,7 @@ def build_map_call(spec: Spec, 
verbose: bool) -> dict[str, str | int | bool]: kwargs = {} kwargs["disable_nullable"] = spec.opt.disable_nullable kwargs["remove_columns"] = ["image"] + kwargs["desc"] = "Extracting wavelet and latent features from images" if spec.opt.batch_size > 0: kwargs["batched"] = True kwargs["batch_size"] = spec.opt.batch_size @@ -138,8 +143,7 @@ def build_map_call(spec: Spec, verbose: bool) -> dict[str, str | int | bool]: setup_default_logging(logging.INFO) for logger in [df_logging, ds_logging, hf_logging, tf_logging]: logger.set_verbosity_info() - print("Beginning preprocessing.") - kwargs["desc"] = ("Computing wavelets...",) + LOGGER.info("Beginning preprocessing.") return kwargs @@ -149,8 +153,7 @@ def batch_preprocessing(dataset: Dataset, spec: Spec, verbose: bool = False) -> :param spec: Specification container with analysis configuration. :return: Transformed dataset with 'features' column.""" kwargs = build_map_call(spec, verbose) - - with VAEExtract(spec) as extractor: # type: ignore + with VAEExtract(spec, verbose=verbose) as extractor: # type: ignore dataset = dataset.map( extractor.forward, **kwargs, # type: ignore @@ -164,7 +167,6 @@ def preprocessing(dataset: Dataset, spec: Spec, verbose: bool = False) -> Datase :param spec: Specification container with analysis configuration. 
:return: Transformed dataset with 'features' column.""" kwargs = build_map_call(spec, verbose) - context = WaveletContext(spec=spec, verbose=verbose) with WaveletAnalyze(context) as analyzer: # type: ignore dataset = dataset.map( @@ -174,7 +176,7 @@ def preprocessing(dataset: Dataset, spec: Spec, verbose: bool = False) -> Datase return dataset -def predict_gne_or_syn(context: InferContext) -> list[float]: +def predict_gne_or_syn(context: InferContext) -> list[ModelOutput]: """Returns probability results determined by decision tree model trained on dataset:\n :param data_path: Path to json file with saved parameter data""" spec = context.spec @@ -186,13 +188,13 @@ def predict_gne_or_syn(context: InferContext) -> list[float]: result = run_onnx(features_matrix, model_version, parameters=parameters, verbose=context.verbose) else: result = run_native(features_matrix, model_version, parameters=parameters, verbose=context.verbose) - prob = [] - for x in result: - prob.append(float(x)) - return prob + outputs: list[ModelOutput] = [] + for value in result: + outputs.append(ModelOutput.from_probability(float(value))) + return outputs -def infer_origin(context: InferContext) -> dict[str, list[float]]: +def infer_origin(context: InferContext) -> dict[str, list[ModelOutput] | list[float]]: """Predict synthetic or original for given image.\n :param context: Inference context containing spec, model path, and metadata. :param file_or_folder_path: Path to the image or folder to be checked. 
diff --git a/negate/io/datasets.py b/negate/io/datasets.py index 12a9073..8f95227 100644 --- a/negate/io/datasets.py +++ b/negate/io/datasets.py @@ -53,7 +53,12 @@ def generate_dataset(file_or_folder_path: Path | list[dict[str, PillowImage.Imag assert isinstance(file_or_folder_path, Path) if file_or_folder_path.is_dir(): - for img_path in tqdm(file_or_folder_path.iterdir(), total=len(os.listdir(str(file_or_folder_path))), desc="Creating dataset...", disable=not verbose): + for img_path in tqdm( + file_or_folder_path.iterdir(), + total=len(os.listdir(str(file_or_folder_path))), + desc="Scanning and validating image files for dataset", + disable=False, + ): if not (img_path.is_file() and img_path.suffix.lower() in valid_extensions): continue try: diff --git a/negate/metrics/heuristics.py b/negate/metrics/heuristics.py index 974e195..561722f 100644 --- a/negate/metrics/heuristics.py +++ b/negate/metrics/heuristics.py @@ -7,6 +7,8 @@ from typing import Any import numpy as np +from negate.types import InferenceModel, ModelOutput, OriginLabel + # Heuristics are not in use, but left for reference example # These ended up not working as effectively as the decision tree @@ -106,13 +108,13 @@ def heuristic_accuracy(result, dc=True): return result -def model_accuracy(result: np.ndarray, label: int | None = None, thresh: float = 0.5) -> list[tuple[str, int]]: +def model_accuracy(result: np.ndarray, label: OriginLabel | int | None = None, thresh: float = 0.5) -> list[tuple[OriginLabel, int]]: """Convert probability array to tuple format (label, confidence).""" thresh = 0.5 model_pred = (result > thresh).astype(int) if label is not None: - ground_truth = np.full(model_pred.shape, label, dtype=int) + ground_truth = np.full(model_pred.shape, int(OriginLabel.coerce(label)), dtype=int) acc = float(np.mean(model_pred == ground_truth)) print(f"Model Accuracy: {acc:.2%}") @@ -120,7 +122,7 @@ def model_accuracy(result: np.ndarray, label: int | None = None, thresh: float = for x in 
result: prob = float(x) conf = round((1 - prob) * 100) if prob < thresh else round(prob * 100) - label_out = "GNE" if prob < thresh else "SYN" + label_out = OriginLabel.from_probability(prob, threshold=thresh) type_conf.append((label_out, conf)) return type_conf @@ -140,10 +142,21 @@ def normalize_to_range( return out_min + (data - in_min) * (out_max - out_min) / (in_max - in_min) +def _extract_probabilities(predictions: list[ModelOutput] | list[float]) -> list[float]: + """Normalize legacy float predictions and typed ModelOutput predictions.""" + raw_predictions = [] + for entry in predictions: + if isinstance(entry, ModelOutput): + raw_predictions.append(float(entry.probability)) + else: + raw_predictions.append(float(entry)) + return raw_predictions + + def compute_weighted_certainty( - ae_inference: dict[str, list[float]], - dc_inference: dict[str, list[float]], - label: int | None = None, + ae_inference: dict[str, list[ModelOutput] | list[float]], + dc_inference: dict[str, list[ModelOutput] | list[float]], + label: OriginLabel | int | None = None, ae_low_thresh: float = 0.4, # lowering adjust certainty ae_high_thresh: float = 0.48, dc_low_thresh: float = 0.39, @@ -153,19 +166,32 @@ def compute_weighted_certainty( Compute certainty scores by combining all available inference methods.\n Each method contributes a vote (unk: 0=GNE, 1=SYN)). Certainty is the sum normalized 0-5 scale. 
""" - gne = "GNE" - syn = "SYN" - header = "" - if label is not None: - if label == 1: - header = f"{syn} [1]" - else: - header = f"{gne} (0)" + expected_origin = OriginLabel.coerce(label) if label is not None else None + + def predictor(pct: float, low_thresh: float, high_thresh: float) -> OriginLabel: + if pct < low_thresh: + return OriginLabel.GNE + if pct > high_thresh: + return OriginLabel.SYN + return OriginLabel.GNE if pct < 0.4 else OriginLabel.SYN - predictor = lambda pct, low_thresh, high_thresh,: "GNE" if pct < low_thresh else "SYN" if pct > high_thresh else "GNE" if pct < 0.4 else "SYN" predictions = [ - {"raw_pred": ae_inference["pred"], "thresh": (ae_low_thresh, ae_high_thresh), "norm": (0.02, 0.90), "norm_pred": None, "result": []}, - {"raw_pred": dc_inference["pred"], "thresh": (dc_low_thresh, dc_high_thresh), "norm": (0.15, 0.80), "norm_pred": None, "result": []}, + { + "index": InferenceModel.AE, + "raw_pred": _extract_probabilities(ae_inference["pred"]), + "thresh": (ae_low_thresh, ae_high_thresh), + "norm": (0.02, 0.90), + "norm_pred": None, + "result": [], + }, + { + "index": InferenceModel.DC, + "raw_pred": _extract_probabilities(dc_inference["pred"]), + "thresh": (dc_low_thresh, dc_high_thresh), + "norm": (0.15, 0.80), + "norm_pred": None, + "result": [], + }, ] for index in range(len(predictions)): @@ -173,9 +199,9 @@ def compute_weighted_certainty( predictions[index]["norm_pred"] = predictions[index]["norm_pred"].tolist() for image, num in enumerate(predictions[index]["norm_pred"]): origin = predictor(num, *predictions[index]["thresh"]) - predictions[index]["result"].append({"index": "dc" if index == 1 else "ae", "img": image, "num": num, "origin": origin}) + predictions[index]["result"].append({"index": predictions[index]["index"], "img": image, "num": num, "origin": origin}) - result_format = lambda x: f"{x['index']} :{x['origin']} img:{x['img']} " + f"{x['num']:.2%}" + result_format = lambda x: f"{x['index'].value} :{x['origin'].name} 
img:{x['img']} " + f"{x['num']:.2%}" final_result = [] final_numeric = [] @@ -185,27 +211,27 @@ def compute_weighted_certainty( else: low_amount_ae = (abs(predictions[0]["thresh"][0] - result["num"])), result high_amount_ae = (abs(result["num"] - predictions[0]["thresh"][1])), result - low_amount_dc = (abs(predictions[1]["thresh"][0] - predictions[1]["result"][index]["num"])), predictions[1]["result"][1] - high_amount_dc = (abs(predictions[1]["result"][index]["num"] - predictions[1]["thresh"][1])), predictions[1]["result"][1] + low_amount_dc = (abs(predictions[1]["thresh"][0] - predictions[1]["result"][index]["num"])), predictions[1]["result"][index] + high_amount_dc = (abs(predictions[1]["result"][index]["num"] - predictions[1]["thresh"][1])), predictions[1]["result"][index] most_certain = max( max(low_amount_ae, high_amount_ae, key=lambda x: x[0]), max(low_amount_dc, high_amount_dc, key=lambda x: x[0]), key=lambda x: x[0], )[1] most_certain["diffs"] = {"ae": (low_amount_ae[0], high_amount_ae[0]), "dc": (low_amount_dc[0], high_amount_dc[0])} - if label is not None: - most_certain["match"] = int(most_certain["origin"] == header[:-4]) + if expected_origin is not None: + most_certain["match"] = int(most_certain["origin"] == expected_origin) final_numeric.append(most_certain) output = result_format(most_certain) spacer = " " * (16 - len(output)) final_result.append(output + spacer) - model_pred = np.array([x["match"] for x in final_numeric]).astype(int) - if label is not None: + if expected_origin is not None: + model_pred = np.array([x["match"] for x in final_numeric], dtype=int) ground_truth = np.full(model_pred.shape, 1, dtype=int) acc = float(np.mean(model_pred == ground_truth)) pprint([x for x in final_numeric if x["match"] == 0]) - print(f"For : {header} ") + print(f"For : {expected_origin.name} [{int(expected_origin)}] ") print(f"Model Accuracy: {acc:.2%}") pprint(final_result) diff --git a/negate/types.py b/negate/types.py new file mode 100644 index 
0000000..a48d8d8 --- /dev/null +++ b/negate/types.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0 +# + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum, IntEnum + + +class OriginLabel(IntEnum): + """Discrete model output labels for image origin.""" + + GNE = 0 + SYN = 1 + + @classmethod + def coerce(cls, value: OriginLabel | int) -> OriginLabel: + """Normalize ints and enum instances to OriginLabel.""" + if isinstance(value, cls): + return value + return cls(int(value)) + + @classmethod + def from_probability(cls, probability: float, threshold: float = 0.5) -> OriginLabel: + """Map model probability to a discrete origin label.""" + return cls.SYN if probability > threshold else cls.GNE + + +class InferenceModel(str, Enum): + """Inference model role used in weighted certainty output.""" + + AE = "ae" + DC = "dc" + + +@dataclass(frozen=True, slots=True) +class ModelOutput: + """Typed model output carrying probability and enum label.""" + + probability: float + origin: OriginLabel + + @classmethod + def from_probability(cls, probability: float, threshold: float = 0.5) -> ModelOutput: + prob = float(probability) + return cls(probability=prob, origin=OriginLabel.from_probability(prob, threshold=threshold))