diff --git a/ami/main/admin.py b/ami/main/admin.py index c6170b153..e6ddce103 100644 --- a/ami/main/admin.py +++ b/ami/main/admin.py @@ -2,16 +2,22 @@ from django.contrib import admin from django.db import models +from django.db.models import Count from django.db.models.query import QuerySet from django.http.request import HttpRequest from django.template.defaultfilters import filesizeformat +from django.template.response import TemplateResponse +from django.urls import reverse from django.utils.formats import number_format +from django.utils.html import format_html from guardian.admin import GuardedModelAdmin import ami.utils from ami import tasks from ami.jobs.models import Job +from ami.ml.models.algorithm import Algorithm from ami.ml.models.project_pipeline_config import ProjectPipelineConfig +from ami.ml.post_processing.class_masking import update_single_occurrence from ami.ml.tasks import remove_duplicate_classifications from .models import ( @@ -288,6 +294,7 @@ class ClassificationInline(admin.TabularInline): model = Classification extra = 0 fields = ( + "classification_link", "taxon", "algorithm", "timestamp", @@ -295,6 +302,7 @@ class ClassificationInline(admin.TabularInline): "created_at", ) readonly_fields = ( + "classification_link", "taxon", "algorithm", "timestamp", @@ -302,6 +310,13 @@ class ClassificationInline(admin.TabularInline): "created_at", ) + @admin.display(description="Classification") + def classification_link(self, obj: Classification) -> str: + if obj.pk: + url = reverse("admin:main_classification_change", args=[obj.pk]) + return format_html('<a href="{}">{}</a>', url, f"Classification #{obj.pk}") + return "-" + def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: qs = super().get_queryset(request) return qs.select_related("taxon", "algorithm", "detection") @@ -311,6 +326,7 @@ class DetectionInline(admin.TabularInline): model = Detection extra = 0 fields = ( + "detection_link", "detection_algorithm", "source_image", "timestamp", @@ -318,6 
+334,7 @@ class DetectionInline(admin.TabularInline): "occurrence", ) readonly_fields = ( + "detection_link", "detection_algorithm", "source_image", "timestamp", @@ -325,6 +342,13 @@ class DetectionInline(admin.TabularInline): "occurrence", ) + @admin.display(description="ID") + def detection_link(self, obj): + if obj.pk: + url = reverse("admin:main_detection_change", args=[obj.pk]) + return format_html('<a href="{}">{}</a>', url, obj.pk) + return "-" + @admin.register(Detection) class DetectionAdmin(admin.ModelAdmin[Detection]): @@ -382,7 +406,7 @@ class OccurrenceAdmin(admin.ModelAdmin[Occurrence]): "determination__rank", "created_at", ) - search_fields = ("determination__name", "determination__search_names") + search_fields = ("id", "determination__name", "determination__search_names") def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: qs = super().get_queryset(request) @@ -404,11 +428,83 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: def detections_count(self, obj) -> int: return obj.detections_count + @admin.action(description="Run class masking (select taxa list & algorithm)") + def run_class_masking(self, request: HttpRequest, queryset: QuerySet[Occurrence]) -> TemplateResponse | None: + """ + Run class masking on selected occurrences. + Shows an intermediate page to select a TaxaList and Algorithm. 
+ """ + if request.POST.get("confirm"): + taxa_list_id = request.POST.get("taxa_list") + algorithm_id = request.POST.get("algorithm") + if not taxa_list_id or not algorithm_id: + self.message_user(request, "Please select both a taxa list and an algorithm.", level="error") + return None + + try: + taxa_list = TaxaList.objects.get(pk=taxa_list_id) + algorithm = Algorithm.objects.get(pk=algorithm_id) + except (TaxaList.DoesNotExist, Algorithm.DoesNotExist) as e: + self.message_user(request, f"Error: {e}", level="error") + return None + + if not algorithm.category_map: + self.message_user( + request, f"Algorithm '{algorithm.name}' does not have a category map.", level="error" + ) + return None + + count = 0 + for occurrence in queryset: + try: + update_single_occurrence( + occurrence=occurrence, + algorithm=algorithm, + taxa_list=taxa_list, + ) + count += 1 + except Exception as e: + self.message_user( + request, + f"Error processing occurrence {occurrence.pk}: {e}", + level="error", + ) + + self.message_user(request, f"Successfully ran class masking on {count} occurrence(s).") + return None + + # Show intermediate confirmation page + taxa_lists = TaxaList.objects.annotate(taxa_count=Count("taxa")).filter(taxa_count__gt=0).order_by("name") + algorithms = Algorithm.objects.filter(category_map__isnull=False).order_by("name") + + # Annotate algorithms with label count + alg_list = [] + for alg in algorithms: + alg.labels_count = len(alg.category_map.labels) if alg.category_map else 0 + alg_list.append(alg) + + return TemplateResponse( + request, + "admin/main/class_masking_confirmation.html", + { + **self.admin_site.each_context(request), + "title": "Run class masking", + "queryset": queryset, + "occurrence_count": queryset.count(), + "taxa_lists": taxa_lists, + "algorithms": alg_list, + "opts": self.model._meta, + "action_checkbox_name": admin.helpers.ACTION_CHECKBOX_NAME, + }, + ) + ordering = ("-created_at",) # Add classifications as inline inlines = [DetectionInline] 
+ actions = [run_class_masking] + @admin.register(Classification) class ClassificationAdmin(admin.ModelAdmin[Classification]): @@ -432,6 +528,8 @@ class ClassificationAdmin(admin.ModelAdmin[Classification]): "taxon__rank", ) + autocomplete_fields = ("taxon",) + def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: qs = super().get_queryset(request) return qs.select_related( @@ -662,10 +760,32 @@ def run_small_size_filter(self, request: HttpRequest, queryset: QuerySet[SourceI self.message_user(request, f"Queued Small Size Filter for {queryset.count()} collection(s). Jobs: {jobs}") + @admin.action(description="Run Rank Rollup post-processing task (async)") + def run_rank_rollup(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None: + """Trigger the Rank Rollup post-processing job asynchronously.""" + jobs = [] + DEFAULT_THRESHOLDS = {"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4} + + for collection in queryset: + job = Job.objects.create( + name=f"Post-processing: RankRollup on Collection {collection.pk}", + project=collection.project, + job_type_key="post_processing", + params={ + "task": "rank_rollup", + "config": {"source_image_collection_id": collection.pk, "thresholds": DEFAULT_THRESHOLDS}, + }, + ) + job.enqueue() + jobs.append(job.pk) + + self.message_user(request, f"Queued Rank Rollup for {queryset.count()} collection(s). Jobs: {jobs}") + actions = [ populate_collection, populate_collection_async, run_small_size_filter, + run_rank_rollup, ] # Hide images many-to-many field from form. This would list all source images in the database. 
diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index d49a414a5..12a7fe896 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -874,10 +874,21 @@ class ClassificationPredictionItemSerializer(serializers.Serializer): logit = serializers.FloatField(read_only=True) +class ClassificationAppliedToSerializer(serializers.ModelSerializer): + """Lightweight nested representation of the parent classification this was derived from.""" + + algorithm = AlgorithmSerializer(read_only=True) + + class Meta: + model = Classification + fields = ["id", "created_at", "algorithm"] + + class ClassificationSerializer(DefaultSerializer): taxon = TaxonNestedSerializer(read_only=True) algorithm = AlgorithmSerializer(read_only=True) top_n = ClassificationPredictionItemSerializer(many=True, read_only=True) + applied_to = ClassificationAppliedToSerializer(read_only=True) class Meta: model = Classification @@ -890,6 +901,7 @@ class Meta: "scores", "logits", "top_n", + "applied_to", "created_at", "updated_at", ] @@ -912,6 +924,8 @@ class Meta(ClassificationSerializer.Meta): class ClassificationListSerializer(DefaultSerializer): + applied_to = ClassificationAppliedToSerializer(read_only=True) + class Meta: model = Classification fields = [ @@ -920,6 +934,7 @@ class Meta: "taxon", "score", "algorithm", + "applied_to", "created_at", "updated_at", ] @@ -939,6 +954,7 @@ class Meta: "score", "terminal", "algorithm", + "applied_to", "created_at", ] diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 9a2770ac8..e5dc6149c 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -1225,7 +1225,17 @@ def get_queryset(self) -> QuerySet["Occurrence"]: if self.action != "list": qs = qs.prefetch_related( Prefetch( - "detections", queryset=Detection.objects.order_by("-timestamp").select_related("source_image") + "detections", + queryset=Detection.objects.order_by("-timestamp") + .select_related("source_image") + .prefetch_related( + Prefetch( 
+ "classifications", + queryset=Classification.objects.select_related( + "taxon", "algorithm", "applied_to__algorithm" + ), + ) + ), ) ) @@ -1640,7 +1650,7 @@ class ClassificationViewSet(DefaultViewSet, ProjectMixin): API endpoint for viewing and adding classification results from a model. """ - queryset = Classification.objects.all().select_related("taxon", "algorithm") # , "detection") + queryset = Classification.objects.all().select_related("taxon", "algorithm", "applied_to__algorithm") serializer_class = ClassificationSerializer filterset_fields = [ # Docs about slow loading API browser because of large choice fields diff --git a/ami/ml/management/commands/run_class_masking.py b/ami/ml/management/commands/run_class_masking.py new file mode 100644 index 000000000..d87375d74 --- /dev/null +++ b/ami/ml/management/commands/run_class_masking.py @@ -0,0 +1,83 @@ +from django.core.management.base import BaseCommand, CommandError + +from ami.main.models import SourceImageCollection, TaxaList +from ami.ml.models.algorithm import Algorithm +from ami.ml.post_processing.class_masking import ClassMaskingTask + + +class Command(BaseCommand): + help = ( + "Run class masking post-processing on a source image collection. " + "Masks classifier logits for species not in the given taxa list and recalculates softmax scores." 
+ ) + + def add_arguments(self, parser): + parser.add_argument("--collection-id", type=int, required=True, help="SourceImageCollection ID to process") + parser.add_argument("--taxa-list-id", type=int, required=True, help="TaxaList ID to use as the species mask") + parser.add_argument( + "--algorithm-id", type=int, required=True, help="Algorithm ID whose classifications to mask" + ) + parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes") + + def handle(self, *args, **options): + collection_id = options["collection_id"] + taxa_list_id = options["taxa_list_id"] + algorithm_id = options["algorithm_id"] + dry_run = options["dry_run"] + + # Validate inputs + try: + collection = SourceImageCollection.objects.get(pk=collection_id) + except SourceImageCollection.DoesNotExist: + raise CommandError(f"SourceImageCollection {collection_id} does not exist.") + + try: + taxa_list = TaxaList.objects.get(pk=taxa_list_id) + except TaxaList.DoesNotExist: + raise CommandError(f"TaxaList {taxa_list_id} does not exist.") + + try: + algorithm = Algorithm.objects.get(pk=algorithm_id) + except Algorithm.DoesNotExist: + raise CommandError(f"Algorithm {algorithm_id} does not exist.") + + if not algorithm.category_map: + raise CommandError(f"Algorithm '{algorithm.name}' does not have a category map.") + + from ami.main.models import Classification + + classification_count = ( + Classification.objects.filter( + detection__source_image__collections=collection, + terminal=True, + algorithm=algorithm, + scores__isnull=False, + ) + .distinct() + .count() + ) + + taxa_count = taxa_list.taxa.count() + + self.stdout.write( + f"Collection: {collection.name} (id={collection.pk})\n" + f"Taxa list: {taxa_list.name} (id={taxa_list.pk}, {taxa_count} taxa)\n" + f"Algorithm: {algorithm.name} (id={algorithm.pk})\n" + f"Classifications to process: {classification_count}" + ) + + if classification_count == 0: + raise CommandError("No terminal 
classifications with scores found for this collection/algorithm.") + + if dry_run: + self.stdout.write(self.style.WARNING("Dry run — no changes made.")) + return + + self.stdout.write("Running class masking...") + task = ClassMaskingTask( + collection_id=collection_id, + taxa_list_id=taxa_list_id, + algorithm_id=algorithm_id, + ) + task.run() + self.stdout.write(self.style.SUCCESS("Class masking completed.")) diff --git a/ami/ml/post_processing/__init__.py b/ami/ml/post_processing/__init__.py index 3517ed47c..e69de29bb 100644 --- a/ami/ml/post_processing/__init__.py +++ b/ami/ml/post_processing/__init__.py @@ -1 +0,0 @@ -from . import small_size_filter # noqa: F401 diff --git a/ami/ml/post_processing/class_masking.py b/ami/ml/post_processing/class_masking.py new file mode 100644 index 000000000..89c799a6f --- /dev/null +++ b/ami/ml/post_processing/class_masking.py @@ -0,0 +1,253 @@ +import logging + +import numpy as np +from django.db import transaction +from django.db.models import QuerySet +from django.utils import timezone + +from ami.main.models import Classification, Occurrence, SourceImageCollection, TaxaList +from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap, AlgorithmTaskType +from ami.ml.post_processing.base import BasePostProcessingTask + +logger = logging.getLogger(__name__) + + +def update_single_occurrence( + occurrence: Occurrence, + algorithm: Algorithm, + taxa_list: TaxaList, + task_logger: logging.Logger = logger, +): + task_logger.info(f"Recalculating classifications for occurrence {occurrence.pk}.") + + # Get the classifications for the occurrence in the collection + classifications = Classification.objects.filter( + detection__occurrence=occurrence, + terminal=True, + algorithm=algorithm, + scores__isnull=False, + logits__isnull=False, + ).distinct() + + # Make a new Algorithm for the filtered classifications + new_algorithm, _ = Algorithm.objects.get_or_create( + name=f"{algorithm.name} (filtered by taxa list 
{taxa_list.name})", + key=f"{algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}", + defaults={ + "description": f"Classification algorithm {algorithm.name} filtered by taxa list {taxa_list.name}", + "task_type": AlgorithmTaskType.CLASSIFICATION.value, + "category_map": algorithm.category_map, + }, + ) + + make_classifications_filtered_by_taxa_list( + classifications=classifications, + taxa_list=taxa_list, + algorithm=algorithm, + new_algorithm=new_algorithm, + ) + + +def update_occurrences_in_collection( + collection: SourceImageCollection, + taxa_list: TaxaList, + algorithm: Algorithm, + params: dict, + new_algorithm: Algorithm, + task_logger: logging.Logger = logger, + job=None, +): + task_logger.info(f"Recalculating classifications based on a taxa list. Params: {params}") + + classifications = Classification.objects.filter( + detection__source_image__collections=collection, + terminal=True, + algorithm=algorithm, + scores__isnull=False, + logits__isnull=False, + ).distinct() + + make_classifications_filtered_by_taxa_list( + classifications=classifications, + taxa_list=taxa_list, + algorithm=algorithm, + new_algorithm=new_algorithm, + ) + + +def make_classifications_filtered_by_taxa_list( + classifications: QuerySet[Classification], + taxa_list: TaxaList, + algorithm: Algorithm, + new_algorithm: Algorithm, +): + taxa_in_list = set(taxa_list.taxa.all()) + + occurrences_to_update: set[Occurrence] = set() + classification_count = classifications.count() + logger.info(f"Found {classification_count} terminal classifications with scores to update.") + + if classification_count == 0: + raise ValueError("No terminal classifications with scores found to update.") + + if not algorithm.category_map: + raise ValueError(f"Algorithm {algorithm} does not have a category map.") + category_map: AlgorithmCategoryMap = algorithm.category_map + + # @TODO find a more efficient way to get the category map with taxa. This is slow! 
+ logger.info(f"Retrieving category map with Taxa instances for algorithm {algorithm}") + category_map_with_taxa = category_map.with_taxa() + excluded_category_map_with_taxa = [ + category for category in category_map_with_taxa if category["taxon"] not in taxa_in_list + ] + + excluded_category_indices = [ + int(category["index"]) for category in excluded_category_map_with_taxa # type: ignore + ] + + # Log number of categories in the category map, num included, and num excluded, num classifications to update + logger.info( + f"Category map has {len(category_map_with_taxa)} categories, " + f"{len(excluded_category_map_with_taxa)} categories excluded, " + f"{classification_count} classifications to check" + ) + + classifications_to_add = [] + classifications_to_update = [] + + timestamp = timezone.now() + for classification in classifications: + scores, logits = classification.scores, classification.logits + + # Assert that all scores & logits are lists of numbers + if not isinstance(scores, list) or not all(isinstance(score, (int, float)) for score in scores): + raise ValueError(f"Scores for classification {classification.pk} are not a list of numbers: {scores}") + if not isinstance(logits, list) or not all(isinstance(logit, (int, float)) for logit in logits): + raise ValueError(f"Logits for classification {classification.pk} are not a list of numbers: {logits}") + + logger.debug(f"Processing classification {classification.pk} with {len(scores)} scores") + logger.info(f"Previous totals: {sum(scores)} scores, {sum(logits)} logits") + + logits_np = np.array(logits) + + # Mask excluded logits with -100 (effectively zero probability after softmax) + # @TODO consider using -np.inf for mathematically exact masking + logits_np[excluded_category_indices] = -100 + + logits: list[float] = logits_np.tolist() + + # Recalculate the softmax scores based on the filtered logits + scores_np: np.ndarray = np.exp(logits_np - np.max(logits_np)) # Subtract max for numerical stability + 
scores_np /= np.sum(scores_np) # Normalize to get probabilities + + scores: list = scores_np.tolist() # Convert back to list + + logger.info(f"New totals: {sum(scores)} scores, {sum(logits)} logits") + + # Get the taxon with the highest score using the index of the max score + top_index = scores.index(max(scores)) + top_taxon = category_map_with_taxa[top_index]["taxon"] + logger.debug(f"Top taxon: {category_map_with_taxa[top_index]}, index: {top_index}") + + # check if needs updating + if classification.scores == scores and classification.logits == logits: + logger.debug(f"Classification {classification.pk} does not need updating") + continue + + # Consider the existing classification as an intermediate classification + classification.terminal = False + classification.updated_at = timestamp + + # Recalculate the top taxon and score + new_classification = Classification( + taxon=top_taxon, + algorithm=new_algorithm, + score=max(scores), + scores=scores, + logits=logits, + detection=classification.detection, + timestamp=classification.timestamp, + terminal=True, + category_map=new_algorithm.category_map, + created_at=timestamp, + updated_at=timestamp, + ) + if new_classification.taxon is None: + raise ValueError( + f"Unable to determine top taxon after class masking for classification {classification.pk}. " + "No allowed classes found in taxa list." 
+ ) + + classifications_to_update.append(classification) + classifications_to_add.append(new_classification) + + assert new_classification.detection is not None + assert new_classification.detection.occurrence is not None + occurrences_to_update.add(new_classification.detection.occurrence) + + logger.info( + f"Adding new classification for Taxon {top_taxon} to occurrence {new_classification.detection.occurrence}" + ) + + # Bulk update/create in a single transaction for atomicity + with transaction.atomic(): + if classifications_to_update: + logger.info(f"Bulk updating {len(classifications_to_update)} existing classifications") + Classification.objects.bulk_update(classifications_to_update, ["terminal", "updated_at"]) + logger.info(f"Updated {len(classifications_to_update)} existing classifications") + + if classifications_to_add: + logger.info(f"Bulk creating {len(classifications_to_add)} new classifications") + Classification.objects.bulk_create(classifications_to_add) + logger.info(f"Added {len(classifications_to_add)} new classifications") + + # Update the occurrence determinations + logger.info(f"Updating the determinations for {len(occurrences_to_update)} occurrences") + for occurrence in occurrences_to_update: + occurrence.save(update_determination=True) + logger.info(f"Updated determinations for {len(occurrences_to_update)} occurrences") + + +class ClassMaskingTask(BasePostProcessingTask): + key = "class_masking" + name = "Class masking" + + def run(self) -> None: + """Apply class masking on a source image collection using a taxa list.""" + job = self.job + self.logger.info(f"=== Starting {self.name} ===") + + collection_id = self.config.get("collection_id") + taxa_list_id = self.config.get("taxa_list_id") + algorithm_id = self.config.get("algorithm_id") + + # Validate config parameters + if not all([collection_id, taxa_list_id, algorithm_id]): + self.logger.error("Missing required configuration: collection_id, taxa_list_id, algorithm_id") + return + + try: 
+ collection = SourceImageCollection.objects.get(pk=collection_id) + taxa_list = TaxaList.objects.get(pk=taxa_list_id) + algorithm = Algorithm.objects.get(pk=algorithm_id) + except Exception as e: + self.logger.exception(f"Failed to load objects: {e}") + return + + self.logger.info(f"Applying class masking on collection {collection_id} using taxa list {taxa_list_id}") + + # @TODO temporary, do we need a new algorithm for each class mask? + self.algorithm.category_map = algorithm.category_map # Ensure the algorithm has its category map loaded + + update_occurrences_in_collection( + collection=collection, + taxa_list=taxa_list, + algorithm=algorithm, + params=self.config, + task_logger=self.logger, + job=job, + new_algorithm=self.algorithm, + ) + + self.logger.info("Class masking completed successfully.") + self.logger.info(f"=== Completed {self.name} ===") diff --git a/ami/ml/post_processing/rank_rollup.py b/ami/ml/post_processing/rank_rollup.py new file mode 100644 index 000000000..4f788a6c7 --- /dev/null +++ b/ami/ml/post_processing/rank_rollup.py @@ -0,0 +1,167 @@ +import logging +from collections import defaultdict + +from django.db import transaction +from django.utils import timezone + +from ami.main.models import Classification, Taxon +from ami.ml.post_processing.base import BasePostProcessingTask + +logger = logging.getLogger(__name__) + + +def find_ancestor_by_parent_chain(taxon, target_rank: str): + """Climb up parent relationships until a taxon with the target rank is found.""" + if not taxon: + return None + + target_rank = target_rank.upper() + + current = taxon + while current: + if current.rank.upper() == target_rank: + return current + current = current.parent + + return None + + +class RankRollupTask(BasePostProcessingTask): + """Post-processing task that rolls up low-confidence classifications + to higher ranks using aggregated scores. 
+ """ + + key = "rank_rollup" + name = "Rank rollup" + + DEFAULT_THRESHOLDS = {"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4} + ROLLUP_ORDER = ["SPECIES", "GENUS", "FAMILY"] + + def run(self) -> None: + job = self.job + self.logger.info(f"Starting {self.name} task for job {job.pk if job else 'N/A'}") + + # ---- Read config parameters ---- + config = self.config or {} + collection_id = config.get("source_image_collection_id") + raw_thresholds = config.get("thresholds", self.DEFAULT_THRESHOLDS) + thresholds = {k.upper(): v for k, v in raw_thresholds.items()} + rollup_order = config.get("rollup_order", self.ROLLUP_ORDER) + + if not collection_id: + self.logger.info("No 'source_image_collection_id' provided in config. Aborting task.") + return + + self.logger.info( + f"Config loaded: collection_id={collection_id}, thresholds={thresholds}, rollup_order={rollup_order}" + ) + + qs = Classification.objects.filter( + terminal=True, + taxon__isnull=False, + detection__source_image__collections__id=collection_id, + ).distinct() + + total = qs.count() + self.logger.info(f"Found {total} terminal classifications to process for collection {collection_id}") + + # Pre-load all labels from category maps to avoid N+1 queries + all_labels: set[str] = set() + for clf in qs.only("category_map"): + if clf.category_map and clf.category_map.labels: + all_labels.update(label for label in clf.category_map.labels if label) + + label_to_taxon = {} + if all_labels: + for taxon in Taxon.objects.filter(name__in=all_labels).select_related("parent"): + label_to_taxon[taxon.name] = taxon + self.logger.info(f"Pre-loaded {len(label_to_taxon)} taxa from {len(all_labels)} unique labels") + + updated_occurrences = [] + + with transaction.atomic(): + for i, clf in enumerate(qs.iterator(), start=1): + score_str = f"{clf.score:.3f}" if clf.score is not None else "N/A" + self.logger.info(f"Processing classification #{clf.pk} (taxon={clf.taxon}, score={score_str})") + + if not clf.scores: + 
self.logger.info(f"Skipping classification #{clf.pk}: no scores available") + continue + if not clf.category_map: + self.logger.info(f"Skipping classification #{clf.pk}: no category_map assigned") + continue + + taxon_scores = defaultdict(float) + + for idx, score in enumerate(clf.scores): + label = clf.category_map.labels[idx] + if not label: + continue + + taxon = label_to_taxon.get(label) + if not taxon: + self.logger.debug(f"Skipping label '{label}' (no matching Taxon found)") + continue + + for rank in rollup_order: + ancestor = find_ancestor_by_parent_chain(taxon, rank) + if ancestor: + taxon_scores[ancestor] += score + self.logger.debug(f" + Added {score:.3f} to ancestor {ancestor.name} ({rank})") + + new_taxon = None + new_score = None + scores_str = {t.name: s for t, s in taxon_scores.items()} + self.logger.info(f"Aggregated taxon scores: {scores_str}") + for rank in rollup_order: + threshold = thresholds.get(rank, 1.0) + candidates = {t: s for t, s in taxon_scores.items() if t.rank == rank} + + if not candidates: + self.logger.info(f"No candidates found at rank {rank}") + continue + + best_taxon, best_score = max(candidates.items(), key=lambda kv: kv[1]) + self.logger.info( + f"Best at rank {rank}: {best_taxon.name} ({best_score:.3f}) [threshold={threshold}]" + ) + + if best_score >= threshold: + new_taxon, new_score = best_taxon, best_score + self.logger.info(f"Rollup decision: {new_taxon.name} ({rank}) with score {new_score:.3f}") + break + + if new_taxon and new_taxon != clf.taxon: + self.logger.info(f"Rolling up {clf.taxon} => {new_taxon} ({new_taxon.rank})") + + # Mark all classifications for this detection as non-terminal + Classification.objects.filter(detection=clf.detection).update(terminal=False) + Classification.objects.create( + detection=clf.detection, + taxon=new_taxon, + score=new_score, + terminal=True, + algorithm=self.algorithm, + timestamp=timezone.now(), + applied_to=clf, + ) + + occurrence = clf.detection.occurrence + if occurrence: + 
updated_occurrences.append(occurrence) + self.logger.info( + f"Rolled up occurrence {occurrence.pk}: {clf.taxon} => {new_taxon} " + f"({new_taxon.rank}) with rolled-up score={new_score:.3f}" + ) + else: + self.logger.warning(f"Detection #{clf.detection.pk} has no occurrence; skipping.") + else: + self.logger.info(f"No rollup applied for classification #{clf.pk} (taxon={clf.taxon})") + + # Update progress every 10 iterations + if i % 10 == 0 or i == total: + progress = i / total if total > 0 else 1.0 + self.update_progress(progress) + + self.logger.info(f"Rank rollup completed. Updated {len(updated_occurrences)} occurrences.") + self.logger.info(f"{self.name} task finished for collection {collection_id}.") diff --git a/ami/ml/post_processing/registry.py b/ami/ml/post_processing/registry.py index c85f607f9..28fa7fb2f 100644 --- a/ami/ml/post_processing/registry.py +++ b/ami/ml/post_processing/registry.py @@ -1,8 +1,12 @@ # Registry of available post-processing tasks +from ami.ml.post_processing.class_masking import ClassMaskingTask +from ami.ml.post_processing.rank_rollup import RankRollupTask from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask POSTPROCESSING_TASKS = { SmallSizeFilterTask.key: SmallSizeFilterTask, + ClassMaskingTask.key: ClassMaskingTask, + RankRollupTask.key: RankRollupTask, } diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 6d029492b..ab24b693a 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -10,14 +10,20 @@ from ami.main.models import ( Classification, Detection, + Occurrence, Project, SourceImage, SourceImageCollection, + TaxaList, Taxon, + TaxonRank, group_images_into_events, ) -from ami.ml.models import Algorithm, Pipeline, ProcessingService +from ami.ml.models import Algorithm, AlgorithmCategoryMap, Pipeline, ProcessingService +from ami.ml.models.algorithm import AlgorithmTaskType from ami.ml.models.pipeline import collect_images, get_or_create_algorithm_and_category_map, save_results +from 
ami.ml.post_processing.class_masking import make_classifications_filtered_by_taxa_list +from ami.ml.post_processing.rank_rollup import RankRollupTask from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask from ami.ml.schemas import ( AlgorithmConfigResponse, @@ -827,6 +833,13 @@ def setUp(self): ) self.collection.populate_sample() + # Select example taxa + self.species_taxon = Taxon.objects.filter(rank=TaxonRank.SPECIES.name).first() + self.genus_taxon = self.species_taxon.parent if self.species_taxon else None + self.assertIsNotNone(self.species_taxon) + self.assertIsNotNone(self.genus_taxon) + self.algorithm = self._create_category_map_with_algorithm() + def _create_images_with_dimensions( self, deployment, @@ -892,7 +905,9 @@ def test_small_size_filter_assigns_not_identifiable(self): not_identifiable_taxon, f"Detection {det.pk} should be classified as 'Not identifiable'", ) + occurrence = det.occurrence + assert occurrence self.assertIsNotNone(occurrence, f"Detection {det.pk} should belong to an occurrence.") occurrence.refresh_from_db() self.assertEqual( @@ -901,6 +916,413 @@ def test_small_size_filter_assigns_not_identifiable(self): f"Occurrence {occurrence.pk} should have its determination set to 'Not identifiable'.", ) + def _create_occurrences_with_classifications(self, num=3): + """Helper to create occurrences and terminal classifications below species threshold.""" + occurrences = [] + now = datetime.datetime.now(datetime.timezone.utc) + for i in range(num): + det = Detection.objects.create( + source_image=self.collection.images.first(), + bbox=[0, 0, 200, 200], + ) + occ = Occurrence.objects.create(project=self.project, event=self.deployment.events.first()) + occ.detections.add(det) + classification = Classification.objects.create( + detection=det, + taxon=self.species_taxon, + score=0.5, + scores=[0.5, 0.2, 0.1], + terminal=True, + timestamp=now, + algorithm=self.algorithm, + ) + occurrences.append((occ, classification)) + return 
occurrences + + def _create_category_map_with_algorithm(self): + """Create a simple AlgorithmCategoryMap and Algorithm to attach to classifications.""" + species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name).order_by("name")[:3]) + assert species_taxa, "No species taxa found in project; run create_taxa() first." + + data = [ + { + "index": i, + "label": taxon.name, + "taxon_rank": taxon.rank, + "gbif_key": getattr(taxon, "gbif_key", None), + } + for i, taxon in enumerate(species_taxa) + ] + labels = [item["label"] for item in data] + + category_map = AlgorithmCategoryMap.objects.create( + data=data, + labels=labels, + version="v1.0", + description="Species-level category map for testing RankRollupTask", + ) + + algorithm = Algorithm.objects.create( + name="Test Species Classifier", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=category_map, + ) + + return algorithm + + def test_rank_rollup_creates_new_terminal_classifications(self): + occurrences = self._create_occurrences_with_classifications(num=3) + + task = RankRollupTask( + source_image_collection_id=self.collection.pk, + thresholds={"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4}, + ) + task.run() + + # Validate results + for occ, original_cls in occurrences: + detection = occ.detections.first() + original_cls.refresh_from_db(fields=["terminal"]) + rolled_up_cls = Classification.objects.filter(detection=detection, terminal=True).first() + + self.assertIsNotNone( + rolled_up_cls, + f"Expected a new rolled-up classification for original #{original_cls.pk}", + ) + self.assertTrue( + rolled_up_cls.terminal, + "New rolled-up classification should be marked as terminal.", + ) + self.assertFalse( + original_cls.terminal, + "Original classification should be marked as non-terminal after roll-up.", + ) + self.assertEqual( + rolled_up_cls.taxon, + self.genus_taxon, + "Rolled-up classification should have genus-level taxon.", + ) + self.assertEqual( + rolled_up_cls.applied_to, + 
original_cls, + "Rolled-up classification should reference the original classification.", + ) + + def _create_classification_with_logits(self, detection, taxon, score, scores, logits): + """Helper to create a classification with explicit scores and logits.""" + now = datetime.datetime.now(datetime.timezone.utc) + return Classification.objects.create( + detection=detection, + taxon=taxon, + score=score, + scores=scores, + logits=logits, + terminal=True, + timestamp=now, + algorithm=self.algorithm, + ) + + def test_class_masking_redistributes_scores(self): + """ + Test that class masking correctly recalculates softmax after masking excluded species. + + Setup: 3 species in category map (indices 0, 1, 2). + Taxa list contains only species at indices 0 and 1. + Original classification has species at index 2 as the top prediction. + After masking, the top prediction should shift to species 0 or 1. + """ + import math + + species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name).order_by("name")[:3]) + self.assertEqual(len(species_taxa), 3) + + # Create a taxa list with only the first 2 species (exclude species_taxa[2]) + partial_taxa_list = TaxaList.objects.create(name="Partial Species List") + partial_taxa_list.taxa.set(species_taxa[:2]) + + # Logits where excluded species (index 2) has the highest value + logits = [2.0, 1.0, 5.0] # species[2] dominates + # Compute original softmax + max_logit = max(logits) + exp_logits = [math.exp(x - max_logit) for x in logits] + total = sum(exp_logits) + original_scores = [e / total for e in exp_logits] + + # Original top prediction is species[2] (the excluded one) + self.assertEqual(original_scores.index(max(original_scores)), 2) + + det = Detection.objects.create( + source_image=self.collection.images.first(), + bbox=[0, 0, 200, 200], + ) + occ = Occurrence.objects.create(project=self.project, event=self.deployment.events.first()) + occ.detections.add(det) + + original_clf = self._create_classification_with_logits( 
+ detection=det, + taxon=species_taxa[2], # top prediction is the excluded species + score=max(original_scores), + scores=original_scores, + logits=logits, + ) + + # Create a new algorithm for masked output + new_algorithm, _ = Algorithm.objects.get_or_create( + name=f"{self.algorithm.name} (filtered by {partial_taxa_list.name})", + key=f"{self.algorithm.key}_filtered_{partial_taxa_list.pk}", + defaults={ + "task_type": AlgorithmTaskType.CLASSIFICATION.value, + "category_map": self.algorithm.category_map, + }, + ) + + classifications = Classification.objects.filter(pk=original_clf.pk) + make_classifications_filtered_by_taxa_list( + classifications=classifications, + taxa_list=partial_taxa_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + + # Original classification should be non-terminal + original_clf.refresh_from_db() + self.assertFalse(original_clf.terminal, "Original classification should be non-terminal after masking.") + + # New terminal classification should exist + new_clf = Classification.objects.filter(detection=det, terminal=True).first() + self.assertIsNotNone(new_clf, "A new terminal classification should be created.") + self.assertEqual(new_clf.algorithm, new_algorithm) + + # New top prediction should be species[0] (highest logit among allowed species) + self.assertEqual( + new_clf.taxon, + species_taxa[0], + "Top prediction should be the highest-scoring species remaining in the taxa list.", + ) + + # Scores should sum to ~1.0 (valid probability distribution) + self.assertAlmostEqual(sum(new_clf.scores), 1.0, places=5, msg="Masked scores should sum to 1.0") + + # Excluded species score should be ~0.0 + self.assertAlmostEqual( + new_clf.scores[2], + 0.0, + places=10, + msg="Excluded species score should be effectively zero.", + ) + + # New top score should be higher than original (probability mass redistributed) + self.assertGreater( + new_clf.score, + original_scores[0], + "In-list species score should increase after masking out 
the dominant excluded species.", + ) + + def test_class_masking_improves_accuracy(self): + """ + Test the key use case: class masking improves accuracy when the true species is in + the taxa list but was originally outscored by an out-of-list species. + + Scenario: True species is "Vanessa cardui" (in list). The classifier's top prediction + is an out-of-list species. After masking, "Vanessa cardui" should become the top + prediction, and the occurrence determination should update. + """ + species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name).order_by("name")[:3]) + self.assertEqual(len(species_taxa), 3) + # species_taxa sorted by name: [Vanessa atalanta, Vanessa cardui, Vanessa itea] + + true_species = species_taxa[1] # Vanessa cardui — the "ground truth" + excluded_species = species_taxa[2] # Vanessa itea — not in the regional list + + # Taxa list: contains atalanta and cardui, but NOT itea + regional_list = TaxaList.objects.create(name="Regional Species List") + regional_list.taxa.set([species_taxa[0], species_taxa[1]]) + + # Logits: itea (index 2) is top, cardui (index 1) is close second, atalanta (index 0) is low + logits = [0.5, 3.0, 3.5] + + import math + + max_logit = max(logits) + exp_logits = [math.exp(x - max_logit) for x in logits] + total = sum(exp_logits) + scores = [e / total for e in exp_logits] + + # Original top prediction is the excluded species + self.assertEqual(scores.index(max(scores)), 2) + + det = Detection.objects.create( + source_image=self.collection.images.first(), + bbox=[0, 0, 200, 200], + ) + occ = Occurrence.objects.create(project=self.project, event=self.deployment.events.first()) + occ.detections.add(det) + + self._create_classification_with_logits( + detection=det, + taxon=excluded_species, + score=max(scores), + scores=scores, + logits=logits, + ) + # Occurrence determination is currently the excluded species + occ.save(update_determination=True) + occ.refresh_from_db() + self.assertEqual(occ.determination, 
excluded_species) + + new_algorithm, _ = Algorithm.objects.get_or_create( + name=f"{self.algorithm.name} (filtered by {regional_list.name})", + key=f"{self.algorithm.key}_filtered_{regional_list.pk}", + defaults={ + "task_type": AlgorithmTaskType.CLASSIFICATION.value, + "category_map": self.algorithm.category_map, + }, + ) + + classifications = Classification.objects.filter( + detection__occurrence=occ, + terminal=True, + algorithm=self.algorithm, + scores__isnull=False, + ) + make_classifications_filtered_by_taxa_list( + classifications=classifications, + taxa_list=regional_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + + # After masking, occurrence determination should be the true species + occ.refresh_from_db() + self.assertEqual( + occ.determination, + true_species, + "After class masking, occurrence determination should update to the correct in-list species.", + ) + + # Verify the new classification's taxon + new_clf = Classification.objects.filter(detection=det, terminal=True).first() + self.assertEqual(new_clf.taxon, true_species) + self.assertGreater(new_clf.score, 0.5, "Masked score for true species should be > 0.5") + + def test_class_masking_no_change_when_all_species_in_list(self): + """When all category map species are in the taxa list, no new classifications should be created.""" + species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name).order_by("name")[:3]) + + # Taxa list contains ALL species + full_list = TaxaList.objects.create(name="Full Species List") + full_list.taxa.set(species_taxa) + + logits = [3.0, 1.0, 0.5] + import math + + max_logit = max(logits) + exp_logits = [math.exp(x - max_logit) for x in logits] + total = sum(exp_logits) + scores = [e / total for e in exp_logits] + + det = Detection.objects.create( + source_image=self.collection.images.first(), + bbox=[0, 0, 200, 200], + ) + occ = Occurrence.objects.create(project=self.project, event=self.deployment.events.first()) + 
occ.detections.add(det) + + original_clf = self._create_classification_with_logits( + detection=det, + taxon=species_taxa[0], + score=max(scores), + scores=scores, + logits=logits, + ) + + new_algorithm, _ = Algorithm.objects.get_or_create( + name=f"{self.algorithm.name} (filtered full)", + key=f"{self.algorithm.key}_filtered_full", + defaults={ + "task_type": AlgorithmTaskType.CLASSIFICATION.value, + "category_map": self.algorithm.category_map, + }, + ) + + classifications = Classification.objects.filter(pk=original_clf.pk) + make_classifications_filtered_by_taxa_list( + classifications=classifications, + taxa_list=full_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + + # Original should still be terminal (no change needed) + original_clf.refresh_from_db() + self.assertTrue(original_clf.terminal, "Original should remain terminal when all species are in the list.") + + # No new classifications created + clf_count = Classification.objects.filter(detection=det).count() + self.assertEqual(clf_count, 1, "No new classification should be created when masking changes nothing.") + + def test_class_masking_softmax_correctness(self): + """Verify that masked softmax produces mathematically correct results.""" + import math + + species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name).order_by("name")[:3]) + + # Only keep species at index 0 + single_species_list = TaxaList.objects.create(name="Single Species List") + single_species_list.taxa.set([species_taxa[0]]) + + logits = [2.0, 3.0, 4.0] + max_logit = max(logits) + exp_logits = [math.exp(x - max_logit) for x in logits] + total = sum(exp_logits) + scores = [e / total for e in exp_logits] + + det = Detection.objects.create( + source_image=self.collection.images.first(), + bbox=[0, 0, 200, 200], + ) + occ = Occurrence.objects.create(project=self.project, event=self.deployment.events.first()) + occ.detections.add(det) + + self._create_classification_with_logits( + detection=det, + 
taxon=species_taxa[2], # original top is index 2 + score=max(scores), + scores=scores, + logits=logits, + ) + + new_algorithm, _ = Algorithm.objects.get_or_create( + name=f"{self.algorithm.name} (single species)", + key=f"{self.algorithm.key}_single", + defaults={ + "task_type": AlgorithmTaskType.CLASSIFICATION.value, + "category_map": self.algorithm.category_map, + }, + ) + + classifications = Classification.objects.filter(detection=det, terminal=True) + make_classifications_filtered_by_taxa_list( + classifications=classifications, + taxa_list=single_species_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + + new_clf = Classification.objects.filter(detection=det, terminal=True).first() + self.assertIsNotNone(new_clf) + + # With only 1 allowed species, its score should be ~1.0 + self.assertAlmostEqual( + new_clf.scores[0], + 1.0, + places=5, + msg="With only one allowed species, its softmax score should be ~1.0", + ) + self.assertAlmostEqual(new_clf.scores[1], 0.0, places=10) + self.assertAlmostEqual(new_clf.scores[2], 0.0, places=10) + self.assertAlmostEqual(sum(new_clf.scores), 1.0, places=5) + class TestTaskStateManager(TestCase): """Test TaskStateManager for job progress tracking.""" diff --git a/ami/templates/admin/main/class_masking_confirmation.html b/ami/templates/admin/main/class_masking_confirmation.html new file mode 100644 index 000000000..01e6652d4 --- /dev/null +++ b/ami/templates/admin/main/class_masking_confirmation.html @@ -0,0 +1,58 @@ +{% extends "admin/base_site.html" %} + +{% load i18n admin_urls %} + +{% block title %} + {% translate "Run class masking" %} | {{ site_title|default:_("Django site admin") }} +{% endblock title %} +{% block breadcrumbs %} +
+{% endblock breadcrumbs %} +{% block content %} + +{% endblock content %} diff --git a/docs/screenshots/admin-class-masking-form.png b/docs/screenshots/admin-class-masking-form.png new file mode 100644 index 000000000..ac67a5949 Binary files /dev/null and b/docs/screenshots/admin-class-masking-form.png differ diff --git a/docs/screenshots/admin-classifications-after-masking.png b/docs/screenshots/admin-classifications-after-masking.png new file mode 100644 index 000000000..fec9744d4 Binary files /dev/null and b/docs/screenshots/admin-classifications-after-masking.png differ diff --git a/docs/screenshots/ui-occurrence-43498-identification.png b/docs/screenshots/ui-occurrence-43498-identification.png new file mode 100644 index 000000000..3dd6e44bf Binary files /dev/null and b/docs/screenshots/ui-occurrence-43498-identification.png differ