Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions segmentation_models_pytorch/losses/focal.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

elif self.mode == MULTICLASS_MODE:
num_classes = y_pred.size(1)
loss = 0

# Filter anchors with -1 label from loss computation
# If ignore_index is set, remap it to an extra class index (num_classes) before one-hot encoding,
# then drop that extra channel so the encoded labels keep exactly num_classes channels.
# One-hot encoding the labels lets the focal loss be computed for all classes in one vectorised call.
if self.ignore_index is not None:
not_ignored = y_true != self.ignore_index
y_true[y_true == self.ignore_index] = num_classes
y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1)
y_true_one_hot = y_true_one_hot[ ... , : -1]

for cls in range(num_classes):
cls_y_true = (y_true == cls).long()
cls_y_pred = y_pred[:, cls, ...]
else:
y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes)

if self.ignore_index is not None:
cls_y_true = cls_y_true[not_ignored]
cls_y_pred = cls_y_pred[not_ignored]
y_true_one_hot = torch.permute(y_true_one_hot,(0,3,1,2))

loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
# Multiply by num_classes to stay consistent with the earlier per-class loop, which summed the
# per-class losses instead of averaging them over classes.
loss = num_classes * self.focal_loss_fn(y_pred, y_true_one_hot)

return loss
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "245a88c9",
"metadata": {},
"outputs": [],
"source": [
"from segmentation_models_pytorch.losses import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE\n",
"from time import time\n",
"from typing import Optional\n",
"import torch\n",
"import segmentation_models_pytorch\n",
"\n",
"class FocalLossVectorised(segmentation_models_pytorch.losses.FocalLoss):\n",
" def __init__(\n",
" self,\n",
" mode: str,\n",
" alpha: Optional[float] = None,\n",
" gamma: Optional[float] = 2.0,\n",
" ignore_index: Optional[int] = None,\n",
" reduction: Optional[str] = \"mean\",\n",
" normalized: bool = False,\n",
" reduced_threshold: Optional[float] = None,\n",
" ):\n",
" \n",
" super().__init__(mode = mode,alpha = alpha,gamma = gamma, ignore_index = ignore_index,reduction = reduction,\n",
" normalized = normalized,reduced_threshold = reduced_threshold)\n",
" \n",
" def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n",
" if self.mode in {BINARY_MODE, MULTILABEL_MODE}:\n",
" y_true = y_true.view(-1)\n",
" y_pred = y_pred.view(-1)\n",
"\n",
" if self.ignore_index is not None:\n",
" # Filter predictions with ignore label from loss computation\n",
" not_ignored = y_true != self.ignore_index\n",
" y_pred = y_pred[not_ignored]\n",
" y_true = y_true[not_ignored]\n",
"\n",
" loss = self.focal_loss_fn(y_pred, y_true)\n",
"\n",
" elif self.mode == MULTICLASS_MODE:\n",
" num_classes = y_pred.size(1)\n",
"\n",
" if self.ignore_index is not None:\n",
" y_true[y_true == self.ignore_index] = num_classes\n",
" y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1)\n",
" y_true_one_hot = y_true_one_hot[ : , : , : , : -1]\n",
"\n",
" else: \n",
" y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes)\n",
"\n",
" y_true_one_hot = torch.permute(y_true_one_hot,(0,3,1,2))\n",
" loss = num_classes * self.focal_loss_fn(y_pred, y_true_one_hot)\n",
"\n",
" return loss"
]
},
{
"cell_type": "code",

Check failure on line 62 in segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb

View workflow job for this annotation

GitHub Actions / style

Ruff (F841)

segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb:1:1: F841 Local variable `loss` is assigned to but never used
"execution_count": 2,
"id": "9c64c3ea",
"metadata": {},
"outputs": [],
"source": [
"num_classes = 20\n",
"batch_size = 128\n",
"resolution = 512\n",
"device = 'cuda:1' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d4d3a5f5",
"metadata": {},
"outputs": [],
"source": [
"vectorised_loss_fn = FocalLossVectorised(mode = 'multiclass',ignore_index = num_classes)\n",
"loss_fn = segmentation_models_pytorch.losses.FocalLoss(mode = 'multiclass',ignore_index = num_classes)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5e49a5b8",
"metadata": {},
"outputs": [],
"source": [
"predictions = torch.randn((batch_size,num_classes,resolution,resolution)).to(device = device)\n",
"labels = torch.randint(low = 0,high = num_classes+1,size = (batch_size,resolution,resolution)).to(device = device)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "36a0f89b",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(function,predictions,labels,benchmark_iterations = 100):\n",
" start_time = time()\n",
"\n",
" for _ in range(benchmark_iterations):\n",
" loss = function(predictions,labels)\n",
"\n",
" end_time = time()\n",
"\n",
" average_time_taken = (end_time - start_time) / (benchmark_iterations)\n",
"\n",
" print(f\"Average time taken by function {function} is {average_time_taken} seconds\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8de20667",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average time taken by function FocalLoss() is 0.3390256547927856 seconds\n"
]
}
],
"source": [
"benchmark(loss_fn,predictions,labels)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6da16fc2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average time taken by function FocalLossVectorised() is 0.11771584510803222 seconds\n"
]
}
],
"source": [
"benchmark(vectorised_loss_fn,predictions,labels)"
]
},
{
"cell_type": "markdown",
"id": "a1083cd4",
"metadata": {},
"source": [
"##### CHECKING THAT OUTPUT OF NEW CLASS IS CONSISTENT WITH THE OLD ONE"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8fba3182",
"metadata": {},
"outputs": [],
"source": [
"output_from_vectorised_fn = vectorised_loss_fn(predictions,labels)\n",
"output_from_old_fn = loss_fn(predictions,labels)\n",
"\n",
"assert torch.allclose(output_from_vectorised_fn,output_from_old_fn)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading