From 88e5537e0e290eac761adbb3fd9b392414c60720 Mon Sep 17 00:00:00 2001 From: VedantDalimkar Date: Sun, 14 Sep 2025 18:54:12 +0530 Subject: [PATCH 1/4] Focal loss vectorised --- .gitignore | 4 +++- segmentation_models_pytorch/losses/focal.py | 21 +++++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 33db579f7..21f4d6f90 100644 --- a/.gitignore +++ b/.gitignore @@ -109,4 +109,6 @@ venv.bak/ .mypy_cache/ # ruff -.ruff_cache/ \ No newline at end of file +.ruff_cache/ + +*.ipynb \ No newline at end of file diff --git a/segmentation_models_pytorch/losses/focal.py b/segmentation_models_pytorch/losses/focal.py index 3beb9f34e..6a9150c88 100644 --- a/segmentation_models_pytorch/losses/focal.py +++ b/segmentation_models_pytorch/losses/focal.py @@ -70,20 +70,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: elif self.mode == MULTICLASS_MODE: num_classes = y_pred.size(1) - loss = 0 - # Filter anchors with -1 label from loss computation + # If ignore_index parameter is passed, treat it as an extra class during one-hot encoding and remove it later + # One-hot encoding the labels allows us to vectorise the focal loss computation if self.ignore_index is not None: - not_ignored = y_true != self.ignore_index + y_true[y_true == self.ignore_index] = num_classes + y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1) + y_true_one_hot = y_true_one_hot[ ... , : -1] - for cls in range(num_classes): - cls_y_true = (y_true == cls).long() - cls_y_pred = y_pred[:, cls, ...] + else: + y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes) - if self.ignore_index is not None: - cls_y_true = cls_y_true[not_ignored] - cls_y_pred = cls_y_pred[not_ignored] + y_true_one_hot = torch.permute(y_true_one_hot,(0,3,1,2)) - loss += self.focal_loss_fn(cls_y_pred, cls_y_true) + # Multiplying the loss by num_classes in order to stay consistent with the earlier loss computation which did not + # take a classwise average of the loss + loss = num_classes * self.focal_loss_fn(y_pred, y_true_one_hot) return loss From a6ffe174f64e1d7886c65aaf1f14ef60c8a71a93 Mon Sep 17 00:00:00 2001 From: VedantDalimkar Date: Sun, 14 Sep 2025 19:09:16 +0530 Subject: [PATCH 2/4] Added notebook that demonstrates speedups over benchmark --- .gitignore | 4 +- ...focal_loss_optimisation_benchmarking.ipynb | 189 ++++++++++++++++++ 2 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb diff --git a/.gitignore b/.gitignore index 21f4d6f90..33db579f7 100644 --- a/.gitignore +++ b/.gitignore @@ -109,6 +109,4 @@ venv.bak/ .mypy_cache/ # ruff -.ruff_cache/ - -*.ipynb \ No newline at end of file +.ruff_cache/ \ No newline at end of file diff --git a/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb b/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb new file mode 100644 index 000000000..bf7900d63 --- /dev/null +++ b/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "245a88c9", + "metadata": {}, + "outputs": [], + "source": [ + "from operator import not_\n", + "import numpy as np\n", + "from typing import Optional\n", + "np.random.seed(42)\n", + "from time import time\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import timm\n", + "from torchvision.transforms.functional import resize\n", + "import segmentation_models_pytorch.losses\n", + "from functools import partial\n", + "from segmentation_models_pytorch.losses._functional import focal_loss_with_logits\n", + "from segmentation_models_pytorch.losses import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE\n", + "\n", + "\n", + "class FocalLossVectorised(segmentation_models_pytorch.losses.FocalLoss):\n", + " def __init__(\n", + " self,\n", + " mode: str,\n", + " alpha: Optional[float] = None,\n", + " gamma: Optional[float] = 2.0,\n", + " ignore_index: Optional[int] = None,\n", + " reduction: Optional[str] = \"mean\",\n", + " normalized: bool = False,\n", + " reduced_threshold: Optional[float] = None,\n", + " ):\n", + " \n", + " super().__init__(mode = mode,alpha = alpha,gamma = gamma, ignore_index = ignore_index,reduction = reduction,\n", + " normalized = normalized,reduced_threshold = reduced_threshold)\n", + " \n", + " def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n", + " if self.mode in {BINARY_MODE, MULTILABEL_MODE}:\n", + " y_true = y_true.view(-1)\n", + " y_pred = y_pred.view(-1)\n", + "\n", + " if self.ignore_index is not None:\n", + " # Filter predictions with ignore label from loss computation\n", + " not_ignored = y_true != self.ignore_index\n", + " y_pred = y_pred[not_ignored]\n", + " y_true = y_true[not_ignored]\n", + "\n", + " loss = self.focal_loss_fn(y_pred, y_true)\n", + "\n", + " elif self.mode == MULTICLASS_MODE:\n", + " num_classes = y_pred.size(1)\n", + "\n", + " if self.ignore_index is not None:\n", + " y_true[y_true == self.ignore_index] = num_classes\n", + " y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1)\n", + " y_true_one_hot = y_true_one_hot[ : , : , : , : -1]\n", + "\n", + " else: \n", + " y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes)\n", + "\n", + " y_true_one_hot = torch.permute(y_true_one_hot,(0,3,1,2))\n", + " loss = num_classes * self.focal_loss_fn(y_pred, y_true_one_hot)\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c64c3ea", + "metadata": {}, + "outputs": [], + "source": [ + "num_classes = 20\n", + "batch_size = 16\n", + "resolution = 512\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "d4d3a5f5", + "metadata": {}, + "outputs": [], + "source": [ + "vectorised_loss_fn = FocalLossVectorised(mode = 'multiclass')\n", + "loss_fn = segmentation_models_pytorch.losses.FocalLoss(mode = 'multiclass')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e49a5b8", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = torch.randn((batch_size,num_classes,resolution,resolution)).to(device = device)\n", + "labels = torch.randint(low = 0,high = num_classes,size = (batch_size,resolution,resolution)).to(device = device)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "36a0f89b", + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark(function,predictions,labels,benchmark_iterations = 1000):\n", + " start_time = time()\n", + "\n", + " for _ in range(benchmark_iterations):\n", + " loss = function(predictions,labels)\n", + "\n", + " end_time = time()\n", + "\n", + " average_time_taken = (end_time - start_time) / (benchmark_iterations)\n", + "\n", + " print(f\"Average time taken by function {function} is {average_time_taken} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8de20667", + "metadata": {}, + "outputs": [], + "source": [ + "benchmark(loss_fn,predictions,labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6da16fc2", + "metadata": {}, + "outputs": [], + "source": [ + "benchmark(vectorised_loss_fn,predictions,labels)" + ] + }, + { + "cell_type": "markdown", + "id": "a1083cd4", + "metadata": {}, + "source": [ + "##### CHECKING THAT OUTPUT OF NEW CLASS IS CONSISTENT WITH THE OLD ONE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8fba3182", + "metadata": {}, + "outputs": [], + "source": [ + "output_from_vectorised_fn = vectorised_loss_fn(predictions,labels)\n", + "output_from_old_fn = loss_fn(predictions,labels)\n", + "\n", + "assert torch.allclose(output_from_vectorised_fn,output_from_old_fn)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "smp_dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From f517f8a13173fae670dd17abd24784f26b5041f3 Mon Sep 17 00:00:00 2001 From: ved Date: Sun, 14 Sep 2025 19:32:05 +0530 Subject: [PATCH 3/4] Change some params in the notebook --- ...focal_loss_optimisation_benchmarking.ipynb | 52 ++++++++++++------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb b/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb index bf7900d63..77e1cf07b 100644 --- a/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb +++ b/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "245a88c9", "metadata": {}, "outputs": [], @@ -70,47 +70,47 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "9c64c3ea", "metadata": {}, "outputs": [], "source": [ "num_classes = 20\n", - "batch_size = 16\n", + "batch_size = 128\n", "resolution = 512\n", - "device = 'cuda' if torch.cuda.is_available() else 'cpu'" + "device = 'cuda:1' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 3, "id": "d4d3a5f5", "metadata": {}, "outputs": [], "source": [ - "vectorised_loss_fn = FocalLossVectorised(mode = 'multiclass')\n", - "loss_fn = segmentation_models_pytorch.losses.FocalLoss(mode = 'multiclass')" + "vectorised_loss_fn = FocalLossVectorised(mode = 'multiclass',ignore_index = num_classes)\n", + "loss_fn = segmentation_models_pytorch.losses.FocalLoss(mode = 'multiclass',ignore_index = num_classes)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "5e49a5b8", "metadata": {}, "outputs": [], "source": [ "predictions = torch.randn((batch_size,num_classes,resolution,resolution)).to(device = device)\n", - "labels = torch.randint(low = 0,high = num_classes,size = (batch_size,resolution,resolution)).to(device = device)" + "labels = torch.randint(low = 0,high = num_classes+1,size = (batch_size,resolution,resolution)).to(device = device)" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 5, "id": "36a0f89b", "metadata": {}, "outputs": [], "source": [ - "def benchmark(function,predictions,labels,benchmark_iterations = 1000):\n", + "def benchmark(function,predictions,labels,benchmark_iterations = 100):\n", " start_time = time()\n", "\n", " for _ in range(benchmark_iterations):\n", @@ -125,20 +125,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "8de20667", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average time taken by function FocalLoss() is 0.3390256547927856 seconds\n" + ] + } + ], "source": [ "benchmark(loss_fn,predictions,labels)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "6da16fc2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average time taken by function FocalLossVectorised() is 0.11771584510803222 seconds\n" + ] + } + ], "source": [ "benchmark(vectorised_loss_fn,predictions,labels)" ] @@ -153,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "8fba3182", "metadata": {}, "outputs": [], @@ -167,7 +183,7 @@ ], "metadata": { "kernelspec": { - "display_name": "smp_dev", + "display_name": "torch", "language": "python", "name": "python3" }, @@ -181,7 +197,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.21" + "version": "3.9.19" } }, "nbformat": 4, From 91ad2620ba46892502e69b739d507015844e4bad Mon Sep 17 00:00:00 2001 From: VedantDalimkar Date: Sun, 14 Sep 2025 19:38:26 +0530 Subject: [PATCH 4/4] Code linting changes --- .../focal_loss_optimisation_benchmarking.ipynb | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb b/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb index 77e1cf07b..e47ef9e07 100644 --- a/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb +++ b/segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb @@ -2,26 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "245a88c9", "metadata": {}, "outputs": [], "source": [ - "from operator import not_\n", - "import numpy as np\n", - "from typing import Optional\n", - "np.random.seed(42)\n", + "from segmentation_models_pytorch.losses import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE\n", "from time import time\n", + "from typing import Optional\n", "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import timm\n", - "from torchvision.transforms.functional import resize\n", - "import segmentation_models_pytorch.losses\n", - "from functools import partial\n", - "from segmentation_models_pytorch.losses._functional import focal_loss_with_logits\n", - "from segmentation_models_pytorch.losses import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE\n", - "\n", + "import segmentation_models_pytorch\n", "\n", "class FocalLossVectorised(segmentation_models_pytorch.losses.FocalLoss):\n", " def __init__(\n",