Changes from all commits (49 commits)
f46e88c
feat: add base and runner classes for generic post-processing framework
mohamedelabbas1996 Sep 18, 2025
d86ea4d
feat: add post-processing framework base post-processing task class
mohamedelabbas1996 Sep 30, 2025
2c0f78f
feat: add small size filter post-processing task class
mohamedelabbas1996 Sep 30, 2025
ffba709
feat: add post processing job type
mohamedelabbas1996 Sep 30, 2025
63cd84b
feat: trigger small size filter post processing task from admin page
mohamedelabbas1996 Sep 30, 2025
cab62bf
feat: add a new algorithm task type for post-processing
mohamedelabbas1996 Sep 30, 2025
6d0e284
chore: deleted runner.py
mohamedelabbas1996 Sep 30, 2025
4cfe2d8
feat: add migration for creating a new job type
mohamedelabbas1996 Sep 30, 2025
b42e069
fix: fix an import error with the AlgorithmTaskType
mohamedelabbas1996 Sep 30, 2025
cb7c83a
feat: update identification history of occurrences in SmallSizeFilter
mohamedelabbas1996 Oct 2, 2025
10103db
feat: add rank rollup
mohamedelabbas1996 Oct 6, 2025
2e81d90
feat: add class masking post processing task
mohamedelabbas1996 Oct 7, 2025
0baf8ce
feat: trigger class masking from admin page
mohamedelabbas1996 Oct 7, 2025
f3caa18
fix: modified log messages
mohamedelabbas1996 Oct 8, 2025
65d4fef
fix: set the classification algorithm to the rank rollup Algorithm w…
mohamedelabbas1996 Oct 8, 2025
e13afc1
feat: trigger rank rollup from admin page
mohamedelabbas1996 Oct 8, 2025
7ecc18c
Remove class_masking.py from framework branch
mohamedelabbas1996 Oct 14, 2025
f214025
fix: initialize post-processing tasks with job context and simplify r…
mohamedelabbas1996 Oct 14, 2025
20ff4b6
feat: add permission to run post-processing jobs
mohamedelabbas1996 Oct 14, 2025
5b66ae3
chore: remove class_masking import
mohamedelabbas1996 Oct 14, 2025
0419eff
refactor: redesign BasePostProcessingTask with job-aware logging, pro…
mohamedelabbas1996 Oct 14, 2025
1ad1e76
refactor: adapt RankRollupTask to new BasePostProcessingTask with sel…
mohamedelabbas1996 Oct 14, 2025
d97e8e0
refactor: update SmallSizeFilter to use BasePostProcessingTask loggin…
mohamedelabbas1996 Oct 14, 2025
2922c86
migrations: update Project options to include post-processing job per…
mohamedelabbas1996 Oct 14, 2025
9012d7f
migrations: update Algorithm.task_type choices to include post-proces…
mohamedelabbas1996 Oct 14, 2025
319bb3d
Merge branch 'main' into feat/postprocessing-framework
mohamedelabbas1996 Oct 14, 2025
787ac0b
migrations: merged migrations
mohamedelabbas1996 Oct 14, 2025
5e85b75
refactor: refactor job runner to initialize post-processing tasks wit…
mohamedelabbas1996 Oct 10, 2025
88ffba8
chore: rebase feat/postprocessing-class-masking onto feat/postprocess…
mohamedelabbas1996 Oct 14, 2025
9519600
chore: remove class masking trigger (moved to feat/postprocessing-cla…
mohamedelabbas1996 Oct 14, 2025
21e6648
feat: improved progress tracking
mohamedelabbas1996 Oct 14, 2025
7135e15
Merge branch 'feat/postprocessing-framework' into feat/postprocessing…
mohamedelabbas1996 Oct 14, 2025
6632c31
feat: add applied_to field to Classification to track source classifi…
mohamedelabbas1996 Oct 15, 2025
23f80fb
tests: added tests for small size filter and rank roll up post-proces…
mohamedelabbas1996 Oct 15, 2025
336636a
fix: create only terminal classifications and remove identification c…
mohamedelabbas1996 Oct 15, 2025
0d90cde
refactor: remove inner transaction.atomic for cleaner transaction man…
mohamedelabbas1996 Oct 15, 2025
23469e2
tests: fixed small size filter test
mohamedelabbas1996 Oct 15, 2025
001464e
Merge branch 'feat/postprocessing-framework' into feat/postprocessing…
mohamedelabbas1996 Oct 15, 2025
916d652
Merge branch 'main' of github.com:RolnickLab/antenna into feat/postpr…
mihow Oct 16, 2025
1b8700e
draft: work towards class masking in new framework
mihow Oct 16, 2025
e4639f6
Merge remote-tracking branch 'origin/main' into feat/postprocessing-c…
mihow Feb 18, 2026
a466a52
feat: add class masking tests, management command, and fix registry
mihow Feb 18, 2026
a107597
fix: address review feedback on class masking and rank rollup
mihow Feb 18, 2026
da9b081
feat: replace hardcoded admin action with dynamic class masking form
mihow Feb 18, 2026
fc3f9e1
docs: add class masking screenshots for PR review
mihow Feb 18, 2026
c96a865
fix: address review feedback — N+1 query, distinct, HTML, test ordering
mihow Feb 18, 2026
6be1239
feat: expose applied_to field in Classification API serializers
mihow Feb 18, 2026
c4311aa
feat: make applied_to a nested object with algorithm details
mihow Feb 18, 2026
daed538
fix: add prefetch for applied_to on occurrence detail endpoint
mihow Feb 18, 2026
122 changes: 121 additions & 1 deletion ami/main/admin.py
@@ -2,16 +2,22 @@

from django.contrib import admin
from django.db import models
from django.db.models import Count
from django.db.models.query import QuerySet
from django.http.request import HttpRequest
from django.template.defaultfilters import filesizeformat
from django.template.response import TemplateResponse
from django.urls import reverse
from django.utils.formats import number_format
from django.utils.html import format_html
from guardian.admin import GuardedModelAdmin

import ami.utils
from ami import tasks
from ami.jobs.models import Job
from ami.ml.models.algorithm import Algorithm
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
from ami.ml.post_processing.class_masking import update_single_occurrence
from ami.ml.tasks import remove_duplicate_classifications

from .models import (
@@ -288,20 +294,29 @@ class ClassificationInline(admin.TabularInline):
model = Classification
extra = 0
fields = (
"classification_link",
"taxon",
"algorithm",
"timestamp",
"terminal",
"created_at",
)
readonly_fields = (
"classification_link",
"taxon",
"algorithm",
"timestamp",
"terminal",
"created_at",
)

@admin.display(description="Classification")
def classification_link(self, obj: Classification) -> str:
if obj.pk:
url = reverse("admin:main_classification_change", args=[obj.pk])
return format_html('<a href="{}">{}</a>', url, f"Classification #{obj.pk}")
return "-"

def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
return qs.select_related("taxon", "algorithm", "detection")
@@ -311,20 +326,29 @@ class DetectionInline(admin.TabularInline):
model = Detection
extra = 0
fields = (
"detection_link",
"detection_algorithm",
"source_image",
"timestamp",
"created_at",
"occurrence",
)
readonly_fields = (
"detection_link",
"detection_algorithm",
"source_image",
"timestamp",
"created_at",
"occurrence",
)

@admin.display(description="ID")
def detection_link(self, obj):
if obj.pk:
url = reverse("admin:main_detection_change", args=[obj.pk])
return format_html('<a href="{}">{}</a>', url, obj.pk)
return "-"


@admin.register(Detection)
class DetectionAdmin(admin.ModelAdmin[Detection]):
@@ -382,7 +406,7 @@ class OccurrenceAdmin(admin.ModelAdmin[Occurrence]):
"determination__rank",
"created_at",
)
search_fields = ("determination__name", "determination__search_names")
search_fields = ("id", "determination__name", "determination__search_names")

def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
@@ -404,11 +428,83 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
def detections_count(self, obj) -> int:
return obj.detections_count

@admin.action(description="Run class masking (select taxa list & algorithm)")
def run_class_masking(self, request: HttpRequest, queryset: QuerySet[Occurrence]) -> TemplateResponse | None:
"""
Run class masking on selected occurrences.
Shows an intermediate page to select a TaxaList and Algorithm.
"""
if request.POST.get("confirm"):
taxa_list_id = request.POST.get("taxa_list")
algorithm_id = request.POST.get("algorithm")
if not taxa_list_id or not algorithm_id:
self.message_user(request, "Please select both a taxa list and an algorithm.", level="error")
return None

try:
taxa_list = TaxaList.objects.get(pk=taxa_list_id)
algorithm = Algorithm.objects.get(pk=algorithm_id)
except (TaxaList.DoesNotExist, Algorithm.DoesNotExist) as e:
self.message_user(request, f"Error: {e}", level="error")
return None

if not algorithm.category_map:
self.message_user(
request, f"Algorithm '{algorithm.name}' does not have a category map.", level="error"
)
return None

count = 0
for occurrence in queryset:
try:
update_single_occurrence(
occurrence=occurrence,
algorithm=algorithm,
taxa_list=taxa_list,
)
count += 1
except Exception as e:
self.message_user(
request,
f"Error processing occurrence {occurrence.pk}: {e}",
level="error",
)

self.message_user(request, f"Successfully ran class masking on {count} occurrence(s).")
return None

# Show intermediate confirmation page
taxa_lists = TaxaList.objects.annotate(taxa_count=Count("taxa")).filter(taxa_count__gt=0).order_by("name")
algorithms = Algorithm.objects.filter(category_map__isnull=False).order_by("name")

# Annotate algorithms with label count
alg_list = []
for alg in algorithms:
alg.labels_count = len(alg.category_map.labels) if alg.category_map else 0
alg_list.append(alg)

return TemplateResponse(
request,
"admin/main/class_masking_confirmation.html",
{
**self.admin_site.each_context(request),
"title": "Run class masking",
"queryset": queryset,
"occurrence_count": queryset.count(),
"taxa_lists": taxa_lists,
"algorithms": alg_list,
"opts": self.model._meta,
"action_checkbox_name": admin.helpers.ACTION_CHECKBOX_NAME,
},
)

ordering = ("-created_at",)

# Add classifications as inline
inlines = [DetectionInline]

actions = [run_class_masking]


@admin.register(Classification)
class ClassificationAdmin(admin.ModelAdmin[Classification]):
@@ -432,6 +528,8 @@ class ClassificationAdmin(admin.ModelAdmin[Classification]):
"taxon__rank",
)

autocomplete_fields = ("taxon",)

def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
return qs.select_related(
@@ -662,10 +760,32 @@ def run_small_size_filter(self, request: HttpRequest, queryset: QuerySet[SourceI

self.message_user(request, f"Queued Small Size Filter for {queryset.count()} collection(s). Jobs: {jobs}")

@admin.action(description="Run Rank Rollup post-processing task (async)")
def run_rank_rollup(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
"""Trigger the Rank Rollup post-processing job asynchronously."""
jobs = []
DEFAULT_THRESHOLDS = {"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4}

for collection in queryset:
job = Job.objects.create(
name=f"Post-processing: RankRollup on Collection {collection.pk}",
project=collection.project,
job_type_key="post_processing",
params={
"task": "rank_rollup",
"config": {"source_image_collection_id": collection.pk, "thresholds": DEFAULT_THRESHOLDS},
},
)
job.enqueue()
jobs.append(job.pk)

self.message_user(request, f"Queued Rank Rollup for {queryset.count()} collection(s). Jobs: {jobs}")

actions = [
populate_collection,
populate_collection_async,
run_small_size_filter,
run_rank_rollup,
]

# Hide images many-to-many field from form. This would list all source images in the database.
16 changes: 16 additions & 0 deletions ami/main/api/serializers.py
@@ -874,10 +874,21 @@ class ClassificationPredictionItemSerializer(serializers.Serializer):
logit = serializers.FloatField(read_only=True)


class ClassificationAppliedToSerializer(serializers.ModelSerializer):
"""Lightweight nested representation of the parent classification this was derived from."""

algorithm = AlgorithmSerializer(read_only=True)

class Meta:
model = Classification
fields = ["id", "created_at", "algorithm"]
Comment on lines +877 to +884
⚠️ Potential issue | 🟡 Minor

Ruff RUF012 false positive — no change needed, but suppress if CI enforces it

The fields = [...] list on line 884 triggers RUF012 ("mutable default class attribute"), but the identical pattern is used in every Meta class throughout the file without being flagged. This is a diff-scoped linter artifact, not a real risk. If CI is strict about RUF012, the quickest suppression is a # noqa: RUF012 comment on the fields line — otherwise leave it.

🧰 Tools
🪛 Ruff (0.15.1)

[warning] 884-884: Mutable default value for class attribute

(RUF012)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ami/main/api/serializers.py` around lines 877 - 884, This is a RUF012 false
positive for the mutable default class attribute on the Meta.fields list in
ClassificationAppliedToSerializer; if CI enforces RUF012, suppress it by adding
a local noqa comment on the Meta.fields line (i.e., add "# noqa: RUF012" to the
fields = ["id", "created_at", "algorithm"] line in the
ClassificationAppliedToSerializer.Meta block) so the linter ignores this
diff-scoped artifact.



class ClassificationSerializer(DefaultSerializer):
taxon = TaxonNestedSerializer(read_only=True)
algorithm = AlgorithmSerializer(read_only=True)
top_n = ClassificationPredictionItemSerializer(many=True, read_only=True)
applied_to = ClassificationAppliedToSerializer(read_only=True)

class Meta:
model = Classification
@@ -890,6 +901,7 @@ class Meta:
"scores",
"logits",
"top_n",
"applied_to",
"created_at",
"updated_at",
]
@@ -912,6 +924,8 @@ class Meta(ClassificationSerializer.Meta):


class ClassificationListSerializer(DefaultSerializer):
applied_to = ClassificationAppliedToSerializer(read_only=True)

class Meta:
model = Classification
fields = [
@@ -920,6 +934,7 @@ class Meta:
"taxon",
"score",
"algorithm",
"applied_to",
"created_at",
"updated_at",
]
@@ -939,6 +954,7 @@ class Meta:
"score",
"terminal",
"algorithm",
"applied_to",
"created_at",
]

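For API consumers, a rough sketch of what the new applied_to field adds to a serialized classification. The nested keys follow ClassificationAppliedToSerializer above; all values are illustrative, and the algorithm entries stand in for whatever AlgorithmSerializer normally returns:

# Illustrative response fragment for a masked/rolled-up classification (values made up).
classification_payload = {
    "id": 123,
    "taxon": "...",                 # TaxonNestedSerializer output, omitted here
    "score": 0.91,
    "algorithm": "...",             # the post-processing algorithm that produced this record
    "applied_to": {                 # parent classification this one was derived from
        "id": 101,
        "created_at": "2025-10-15T12:00:00Z",
        "algorithm": "...",         # original classifier, via AlgorithmSerializer
    },
    "created_at": "2025-10-15T12:05:00Z",
    "updated_at": "2025-10-15T12:05:00Z",
}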
14 changes: 12 additions & 2 deletions ami/main/api/views.py
@@ -1225,7 +1225,17 @@ def get_queryset(self) -> QuerySet["Occurrence"]:
if self.action != "list":
qs = qs.prefetch_related(
Prefetch(
"detections", queryset=Detection.objects.order_by("-timestamp").select_related("source_image")
"detections",
queryset=Detection.objects.order_by("-timestamp")
.select_related("source_image")
.prefetch_related(
Prefetch(
"classifications",
queryset=Classification.objects.select_related(
"taxon", "algorithm", "applied_to__algorithm"
),
)
),
Comment on lines +1228 to +1238
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find the best_prediction definition on the Occurrence model/manager
rg -n "best_prediction" --type py -A 5 -B 2

Repository: RolnickLab/antenna

Length of output: 4132


🏁 Script executed:

#!/bin/bash
# Get the full best_prediction property implementation
sed -n '2912,2930p' ami/main/models.py

Repository: RolnickLab/antenna

Length of output: 945


🏁 Script executed:

#!/bin/bash
# Find ClassificationNestedSerializer definition and fields
rg -n "class ClassificationNestedSerializer" --type py -A 15

Repository: RolnickLab/antenna

Length of output: 1020


🏁 Script executed:

#!/bin/bash
# Check the exact serializer method mentioned in the review
sed -n '1331,1358p' ami/main/api/serializers.py

Repository: RolnickLab/antenna

Length of output: 1288


🏁 Script executed:

#!/bin/bash
# Verify the prefetch setup mentioned in lines 1228-1238 of views.py
sed -n '1228,1238p' ami/main/api/views.py

Repository: RolnickLab/antenna

Length of output: 571


🏁 Script executed:

#!/bin/bash
# Find the predictions() method on Occurrence model
rg -n "def predictions" ami/main/models.py -A 10

Repository: RolnickLab/antenna

Length of output: 1823


🏁 Script executed:

#!/bin/bash
# Check if there's any prefetch for best_prediction in the viewset
rg -n "best_prediction\|Prefetch.*prediction" ami/main/api/views.py -B 3 -A 3

Repository: RolnickLab/antenna

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Check the full context of the prefetch in views.py around line 1228
sed -n '1210,1250p' ami/main/api/views.py

Repository: RolnickLab/antenna

Length of output: 1576


🏁 Script executed:

#!/bin/bash
# Find OccurrenceListSerializer and check if it uses get_determination_details
rg -n "class OccurrenceListSerializer" ami/main/api/serializers.py -A 30

Repository: RolnickLab/antenna

Length of output: 1498


🏁 Script executed:

#!/bin/bash
# Check if determination_details field is in the serializer
rg -n "determination_details\|get_determination_details" ami/main/api/serializers.py -B 2 -A 2

Repository: RolnickLab/antenna

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Find the method that serves determination_details
rg -n "def.*determination_details" ami/main/api/serializers.py -A 5

Repository: RolnickLab/antenna

Length of output: 407


🏁 Script executed:

#!/bin/bash
# Since the field is called determination_details, look for get_determination_details method
rg -n "get_determination_details" ami/main/api/serializers.py -B 5 -A 20

Repository: RolnickLab/antenna

Length of output: 1334


🏁 Script executed:

#!/bin/bash
# Check what OccurrenceSerializer is used for list vs detail views
rg -n "get_serializer_class\|OccurrenceListSerializer\|OccurrenceSerializer" ami/main/api/views.py -B 2 -A 3 | head -40

Repository: RolnickLab/antenna

Length of output: 44


🏁 Script executed:

#!/bin/bash
# Verify the action check for prefetch application
sed -n '1225,1245p' ami/main/api/views.py

Repository: RolnickLab/antenna

Length of output: 762


🏁 Script executed:

#!/bin/bash
# Check if best_identification also has the same issue
rg -n "best_identification" ami/main/api/serializers.py -B 2 -A 2

Repository: RolnickLab/antenna

Length of output: 397


🏁 Script executed:

#!/bin/bash
# Verify which serializer is used for list action
rg -n "def get_serializer_class" ami/main/api/views.py -A 10 | grep -A 10 "Occurrence"

Repository: RolnickLab/antenna

Length of output: 577


🏁 Script executed:

#!/bin/bash
# Verify the prefetch condition one more time - check if list action gets prefetch
sed -n '1220,1250p' ami/main/api/views.py

Repository: RolnickLab/antenna

Length of output: 1226


🏁 Script executed:

#!/bin/bash
# Double-check the list serializer uses the problematic method
sed -n '1282,1330p' ami/main/api/serializers.py | grep -A 50 "class OccurrenceListSerializer"

Repository: RolnickLab/antenna

Length of output: 2011


🏁 Script executed:

#!/bin/bash
# Let's verify the predictions() query more carefully
sed -n '2941,2960p' ami/main/models.py

Repository: RolnickLab/antenna

Length of output: 882


🏁 Script executed:

#!/bin/bash
# Check what with_identifications does - might it prefetch best_identification?
rg -n "def with_identifications" ami/main/models.py -A 15

Repository: RolnickLab/antenna

Length of output: 731


best_prediction access in list view causes N+1 queries for masked classifications

The N+1 issue is valid. The list view queryset explicitly excludes the detections prefetch (if self.action != "list"), so OccurrenceListSerializer.get_determination_details() accesses obj.best_prediction without any eager loading. Since best_prediction is a @functools.cached_property that executes self.predictions().order_by("-terminal", "-score").first(), it triggers a fresh query. When the returned classification is serialized with ClassificationNestedSerializer (which includes the applied_to field), accessing applied_to and applied_to.algorithm on masked classifications causes two additional lazy-load queries per occurrence.

For a list of 10 occurrences where top classifications are masked results, this adds 30+ extra queries beyond the base query cost.

The long-term fix is to annotate and prefetch best_prediction at the queryset level (as the existing TODO suggests). Short-term: ensure the queryset used in predictions() eagerly loads applied_to__algorithm for classifications where it applies, or refactor get_determination_details() to selectively avoid serializing applied_to for list views.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ami/main/api/views.py` around lines 1228 - 1238,
OccurrenceListSerializer.get_determination_details() is triggering N+1 queries
by accessing Detection.best_prediction (a cached_property that calls
Detection.predictions().order_by(...).first()) without eager-loading applied_to
and applied_to.algorithm; fix by ensuring the queryset used when building
detections for list view (or inside Detection.predictions()) prefetches
applied_to__algorithm for Classification objects, or alternatively adjust
OccurrenceListSerializer to avoid serializing applied_to in list action (e.g.,
use a lightweight nested serializer for list views); target symbols:
OccurrenceListSerializer.get_determination_details, Detection.best_prediction,
Detection.predictions, and ClassificationNestedSerializer (and the
applied_to__algorithm prefetch) so the top prediction’s applied_to and its
algorithm are loaded eagerly.
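For concreteness, one possible shape of the short-term fix described above. This is a hypothetical sketch, not the author's implementation: the real predictions() body is not part of this diff and its related name and base queryset may differ; the only point is the extra "applied_to__algorithm" entry in the select_related chain.

# Hypothetical sketch: wherever the classifications behind best_prediction are
# fetched, eagerly load the masking provenance so list views avoid lazy loads.
def predictions(self):
    return self.classifications.select_related(
        "taxon",
        "algorithm",
        "applied_to__algorithm",  # avoids two extra queries per masked classification
    )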

)
)

@@ -1640,7 +1650,7 @@ class ClassificationViewSet(DefaultViewSet, ProjectMixin):
API endpoint for viewing and adding classification results from a model.
"""

queryset = Classification.objects.all().select_related("taxon", "algorithm") # , "detection")
queryset = Classification.objects.all().select_related("taxon", "algorithm", "applied_to__algorithm")
serializer_class = ClassificationSerializer
filterset_fields = [
# Docs about slow loading API browser because of large choice fields
83 changes: 83 additions & 0 deletions ami/ml/management/commands/run_class_masking.py
@@ -0,0 +1,83 @@
from django.core.management.base import BaseCommand, CommandError

from ami.main.models import SourceImageCollection, TaxaList
from ami.ml.models.algorithm import Algorithm
from ami.ml.post_processing.class_masking import ClassMaskingTask


class Command(BaseCommand):
help = (
"Run class masking post-processing on a source image collection. "
"Masks classifier logits for species not in the given taxa list and recalculates softmax scores."
)

def add_arguments(self, parser):
parser.add_argument("--collection-id", type=int, required=True, help="SourceImageCollection ID to process")
parser.add_argument("--taxa-list-id", type=int, required=True, help="TaxaList ID to use as the species mask")
parser.add_argument(
"--algorithm-id", type=int, required=True, help="Algorithm ID whose classifications to mask"
)
parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes")

def handle(self, *args, **options):
collection_id = options["collection_id"]
taxa_list_id = options["taxa_list_id"]
algorithm_id = options["algorithm_id"]
dry_run = options["dry_run"]

# Validate inputs
try:
collection = SourceImageCollection.objects.get(pk=collection_id)
except SourceImageCollection.DoesNotExist:
raise CommandError(f"SourceImageCollection {collection_id} does not exist.")

try:
taxa_list = TaxaList.objects.get(pk=taxa_list_id)
except TaxaList.DoesNotExist:
raise CommandError(f"TaxaList {taxa_list_id} does not exist.")

try:
algorithm = Algorithm.objects.get(pk=algorithm_id)
except Algorithm.DoesNotExist:
raise CommandError(f"Algorithm {algorithm_id} does not exist.")

if not algorithm.category_map:
raise CommandError(f"Algorithm '{algorithm.name}' does not have a category map.")

from ami.main.models import Classification

classification_count = (
Classification.objects.filter(
detection__source_image__collections=collection,
terminal=True,
algorithm=algorithm,
scores__isnull=False,
)
.distinct()
.count()
)

taxa_count = taxa_list.taxa.count()

self.stdout.write(
f"Collection: {collection.name} (id={collection.pk})\n"
f"Taxa list: {taxa_list.name} (id={taxa_list.pk}, {taxa_count} taxa)\n"
f"Algorithm: {algorithm.name} (id={algorithm.pk})\n"
f"Classifications to process: {classification_count}"
)

if classification_count == 0:
raise CommandError("No terminal classifications with scores found for this collection/algorithm.")

if dry_run:
self.stdout.write(self.style.WARNING("Dry run — no changes made."))
return

self.stdout.write("Running class masking...")
task = ClassMaskingTask(
collection_id=collection_id,
taxa_list_id=taxa_list_id,
algorithm_id=algorithm_id,
)
task.run()
self.stdout.write(self.style.SUCCESS("Class masking completed."))
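A quick usage sketch for the new command (the IDs below are placeholders). The equivalent CLI form is python manage.py run_class_masking with the --collection-id, --taxa-list-id, --algorithm-id and --dry-run flags defined above.

# Placeholder IDs; keyword names mirror the add_arguments() definitions above.
from django.core.management import call_command

# Preview what would be processed without writing anything:
call_command("run_class_masking", collection_id=42, taxa_list_id=7, algorithm_id=3, dry_run=True)

# Apply the masking for real:
call_command("run_class_masking", collection_id=42, taxa_list_id=7, algorithm_id=3)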
1 change: 0 additions & 1 deletion ami/ml/post_processing/__init__.py
@@ -1 +0,0 @@
from . import small_size_filter # noqa: F401