Implement class masking using the post-processing framework #999

base: main

Changes from all commits
f46e88c
d86ea4d
2c0f78f
ffba709
63cd84b
cab62bf
6d0e284
4cfe2d8
b42e069
cb7c83a
10103db
2e81d90
0baf8ce
f3caa18
65d4fef
e13afc1
7ecc18c
f214025
20ff4b6
5b66ae3
0419eff
1ad1e76
d97e8e0
2922c86
9012d7f
319bb3d
787ac0b
5e85b75
88ffba8
9519600
21e6648
7135e15
6632c31
23f80fb
336636a
0d90cde
23469e2
001464e
916d652
1b8700e
e4639f6
a466a52
a107597
da9b081
fc3f9e1
c96a865
6be1239
c4311aa
daed538
ami/main/api/serializers.py

```diff
@@ -874,10 +874,21 @@ class ClassificationPredictionItemSerializer(serializers.Serializer):
     logit = serializers.FloatField(read_only=True)
 
 
+class ClassificationAppliedToSerializer(serializers.ModelSerializer):
+    """Lightweight nested representation of the parent classification this was derived from."""
+
+    algorithm = AlgorithmSerializer(read_only=True)
+
+    class Meta:
+        model = Classification
+        fields = ["id", "created_at", "algorithm"]
```
Comment on lines +877 to +884

Contributor

Ruff RUF012 false positive — no change needed, but suppress if CI enforces it. The `fields` list on a DRF serializer `Meta` class is declarative configuration that is never mutated, so the warning can be silenced safely.

🧰 Tools
🪛 Ruff (0.15.1)
[warning] 884-884: Mutable default value for class attribute (RUF012)
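If CI does treat RUF012 as an error, an inline suppression on the flagged line is the least invasive fix. A minimal sketch, reusing the serializer exactly as added in this PR together with Ruff's standard `noqa` mechanism:

```python
class ClassificationAppliedToSerializer(serializers.ModelSerializer):
    """Lightweight nested representation of the parent classification this was derived from."""

    algorithm = AlgorithmSerializer(read_only=True)

    class Meta:
        model = Classification
        # DRF never mutates Meta.fields, so RUF012 is a false positive here;
        # the suppression only matters if CI escalates the warning to an error.
        fields = ["id", "created_at", "algorithm"]  # noqa: RUF012
```

A per-file ignore for `RUF012` in the Ruff configuration would also work and avoids repeating the comment on every serializer.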
The diff continues below the comment:

```diff
+
+
 class ClassificationSerializer(DefaultSerializer):
     taxon = TaxonNestedSerializer(read_only=True)
     algorithm = AlgorithmSerializer(read_only=True)
     top_n = ClassificationPredictionItemSerializer(many=True, read_only=True)
+    applied_to = ClassificationAppliedToSerializer(read_only=True)
 
     class Meta:
         model = Classification
@@ -890,6 +901,7 @@ class Meta:
             "scores",
             "logits",
             "top_n",
+            "applied_to",
             "created_at",
             "updated_at",
         ]
```
```diff
@@ -912,6 +924,8 @@ class Meta(ClassificationSerializer.Meta):
 
 
 class ClassificationListSerializer(DefaultSerializer):
+    applied_to = ClassificationAppliedToSerializer(read_only=True)
+
     class Meta:
         model = Classification
         fields = [
@@ -920,6 +934,7 @@ class Meta:
             "taxon",
             "score",
             "algorithm",
+            "applied_to",
             "created_at",
             "updated_at",
         ]
@@ -939,6 +954,7 @@ class Meta:
             "score",
             "terminal",
             "algorithm",
+            "applied_to",
             "created_at",
         ]
```
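For orientation, a response containing the new nested field would look roughly like the dictionary below. All values are invented for illustration, the non-nested fields are abbreviated, and the exact shape of `algorithm` depends on `AlgorithmSerializer`:

```python
# Hypothetical serialized Classification produced by class masking (values made up).
example_classification = {
    "id": 456,
    "taxon": {"id": 12, "name": "Noctua pronuba"},
    "score": 0.93,
    "algorithm": {"id": 9, "name": "Class-masking post-processor"},
    # Rendered by ClassificationAppliedToSerializer: the parent classification
    # this masked result was derived from.
    "applied_to": {
        "id": 123,
        "created_at": "2025-01-01T00:00:00Z",
        "algorithm": {"id": 3, "name": "Species classifier"},
    },
}
```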
ami/main/api/views.py

```diff
@@ -1225,7 +1225,17 @@ def get_queryset(self) -> QuerySet["Occurrence"]:
         if self.action != "list":
             qs = qs.prefetch_related(
                 Prefetch(
-                    "detections", queryset=Detection.objects.order_by("-timestamp").select_related("source_image")
+                    "detections",
+                    queryset=Detection.objects.order_by("-timestamp")
+                    .select_related("source_image")
+                    .prefetch_related(
+                        Prefetch(
+                            "classifications",
+                            queryset=Classification.objects.select_related(
+                                "taxon", "algorithm", "applied_to__algorithm"
+                            ),
+                        )
+                    ),
```
Comment on lines +1228 to +1238

Contributor

🧩 Analysis chain: scripts run against RolnickLab/antenna to trace how the occurrence list view loads predictions (command output omitted here):

```shell
#!/bin/bash
# Find the best_prediction definition on the Occurrence model/manager
rg -n "best_prediction" --type py -A 5 -B 2
# Get the full best_prediction property implementation
sed -n '2912,2930p' ami/main/models.py
# Find ClassificationNestedSerializer definition and fields
rg -n "class ClassificationNestedSerializer" --type py -A 15
# Check the exact serializer method mentioned in the review
sed -n '1331,1358p' ami/main/api/serializers.py
# Verify the prefetch setup in lines 1228-1238 of views.py
sed -n '1228,1238p' ami/main/api/views.py
# Find the predictions() method on the Occurrence model
rg -n "def predictions" ami/main/models.py -A 10
# Check if there's any prefetch for best_prediction in the viewset
rg -n "best_prediction\|Prefetch.*prediction" ami/main/api/views.py -B 3 -A 3
# Check the full context of the prefetch in views.py around line 1228
sed -n '1210,1250p' ami/main/api/views.py
# Find OccurrenceListSerializer and check if it uses get_determination_details
rg -n "class OccurrenceListSerializer" ami/main/api/serializers.py -A 30
# Check if determination_details field is in the serializer
rg -n "determination_details\|get_determination_details" ami/main/api/serializers.py -B 2 -A 2
# Find the method that serves determination_details
rg -n "def.*determination_details" ami/main/api/serializers.py -A 5
rg -n "get_determination_details" ami/main/api/serializers.py -B 5 -A 20
# Check which serializer is used for list vs detail views
rg -n "get_serializer_class\|OccurrenceListSerializer\|OccurrenceSerializer" ami/main/api/views.py -B 2 -A 3 | head -40
# Verify the action check for prefetch application
sed -n '1225,1245p' ami/main/api/views.py
# Check whether best_identification has the same issue
rg -n "best_identification" ami/main/api/serializers.py -B 2 -A 2
# Verify which serializer is used for the list action
rg -n "def get_serializer_class" ami/main/api/views.py -A 10 | grep -A 10 "Occurrence"
# Verify the prefetch condition once more
sed -n '1220,1250p' ami/main/api/views.py
# Double-check that the list serializer uses the problematic method
sed -n '1282,1330p' ami/main/api/serializers.py | grep -A 50 "class OccurrenceListSerializer"
# Verify the predictions() query more carefully
sed -n '2941,2960p' ami/main/models.py
# Check what with_identifications does (might it prefetch best_identification?)
rg -n "def with_identifications" ami/main/models.py -A 15
```

The N+1 issue is valid. The list view queryset explicitly excludes the detections prefetch (it is only applied when `self.action != "list"`), so each occurrence's prediction lookup hits the database on its own. For a list of 10 occurrences where the top classifications are masked results, this adds 30+ extra queries beyond the base query cost. The long-term fix is to annotate and prefetch what `best_prediction` needs on the list queryset as well.
The hunk closes with its existing context lines:

```diff
                 )
             )
```
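To make the reviewer's suggested direction concrete, here is a hedged sketch of a helper that the list branch of `get_queryset` could call. The relation path and the exact data `best_prediction` reads are assumptions, not code from this PR:

```python
from django.db.models import Prefetch, QuerySet

from ami.main.models import Classification, Occurrence


def prefetch_predictions_for_list(qs: QuerySet[Occurrence]) -> QuerySet[Occurrence]:
    """Hypothetical helper: prefetch the classifications that per-occurrence
    prediction lookups (e.g. best_prediction) read, so the list endpoint does
    not issue one query per occurrence."""
    return qs.prefetch_related(
        Prefetch(
            "detections__classifications",
            queryset=Classification.objects.select_related(
                "taxon", "algorithm", "applied_to__algorithm"
            ),
        )
    )
```

Whether this is the right trade-off depends on how many detections a listed occurrence typically has; annotating a single best prediction per occurrence would avoid loading the rest, which is why the comment frames that as the long-term fix.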
```diff
@@ -1640,7 +1650,7 @@ class ClassificationViewSet(DefaultViewSet, ProjectMixin):
     API endpoint for viewing and adding classification results from a model.
     """
 
-    queryset = Classification.objects.all().select_related("taxon", "algorithm")  # , "detection")
+    queryset = Classification.objects.all().select_related("taxon", "algorithm", "applied_to__algorithm")
     serializer_class = ClassificationSerializer
     filterset_fields = [
         # Docs about slow loading API browser because of large choice fields
```
New file (83 added lines; the file path is not shown in this excerpt): a Django management command that runs class masking on a collection.

```python
from django.core.management.base import BaseCommand, CommandError

from ami.main.models import SourceImageCollection, TaxaList
from ami.ml.models.algorithm import Algorithm
from ami.ml.post_processing.class_masking import ClassMaskingTask


class Command(BaseCommand):
    help = (
        "Run class masking post-processing on a source image collection. "
        "Masks classifier logits for species not in the given taxa list and recalculates softmax scores."
    )

    def add_arguments(self, parser):
        parser.add_argument("--collection-id", type=int, required=True, help="SourceImageCollection ID to process")
        parser.add_argument("--taxa-list-id", type=int, required=True, help="TaxaList ID to use as the species mask")
        parser.add_argument(
            "--algorithm-id", type=int, required=True, help="Algorithm ID whose classifications to mask"
        )
        parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes")

    def handle(self, *args, **options):
        collection_id = options["collection_id"]
        taxa_list_id = options["taxa_list_id"]
        algorithm_id = options["algorithm_id"]
        dry_run = options["dry_run"]

        # Validate inputs
        try:
            collection = SourceImageCollection.objects.get(pk=collection_id)
        except SourceImageCollection.DoesNotExist:
            raise CommandError(f"SourceImageCollection {collection_id} does not exist.")

        try:
            taxa_list = TaxaList.objects.get(pk=taxa_list_id)
        except TaxaList.DoesNotExist:
            raise CommandError(f"TaxaList {taxa_list_id} does not exist.")

        try:
            algorithm = Algorithm.objects.get(pk=algorithm_id)
        except Algorithm.DoesNotExist:
            raise CommandError(f"Algorithm {algorithm_id} does not exist.")

        if not algorithm.category_map:
            raise CommandError(f"Algorithm '{algorithm.name}' does not have a category map.")

        from ami.main.models import Classification

        classification_count = (
            Classification.objects.filter(
                detection__source_image__collections=collection,
                terminal=True,
                algorithm=algorithm,
                scores__isnull=False,
            )
            .distinct()
            .count()
        )

        taxa_count = taxa_list.taxa.count()

        self.stdout.write(
            f"Collection: {collection.name} (id={collection.pk})\n"
            f"Taxa list: {taxa_list.name} (id={taxa_list.pk}, {taxa_count} taxa)\n"
            f"Algorithm: {algorithm.name} (id={algorithm.pk})\n"
            f"Classifications to process: {classification_count}"
        )

        if classification_count == 0:
            raise CommandError("No terminal classifications with scores found for this collection/algorithm.")

        if dry_run:
            self.stdout.write(self.style.WARNING("Dry run — no changes made."))
            return

        self.stdout.write("Running class masking...")
        task = ClassMaskingTask(
            collection_id=collection_id,
            taxa_list_id=taxa_list_id,
            algorithm_id=algorithm_id,
        )
        task.run()
        self.stdout.write(self.style.SUCCESS("Class masking completed."))
```
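The command's help text summarizes the core operation: mask the logits of classes outside the taxa list, then recompute softmax scores. The snippet below only illustrates that idea with NumPy; it is not the `ClassMaskingTask` implementation, and the function name is invented:

```python
import numpy as np


def mask_logits_and_rescore(logits: np.ndarray, keep_indices: list[int]) -> np.ndarray:
    """Illustrative only: exclude classes by masking their logits to -inf,
    then re-apply softmax so the kept classes' scores sum to 1."""
    masked = np.full(logits.shape, -np.inf)
    masked[keep_indices] = logits[keep_indices]
    # Numerically stable softmax over the masked logits.
    shifted = masked - masked[keep_indices].max()
    exp = np.exp(shifted)
    return exp / exp.sum()


# Example: a 6-class classifier restricted to classes 2 and 5.
scores = mask_logits_and_rescore(np.array([1.2, 0.3, 2.1, -0.5, 0.0, 1.9]), keep_indices=[2, 5])
# scores[2] and scores[5] sum to 1; all other entries are exactly 0.
```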
Another file in the diff (path not shown in this excerpt) drops a post-processing import:

```diff
@@ -1 +0,0 @@
-from . import small_size_filter  # noqa: F401
```