Skip to content

Conversation

vedantdalimkar
Copy link
Contributor

@vedantdalimkar vedantdalimkar commented Sep 14, 2025

This PR addresses #1235

The current focal loss implementation iterates over each class and calculates focal loss in a class-wise manner. This is slightly inefficient and can be optimised by vectorising the loss computation in multiclass mode. Also, the current implementation uses expensive masking operations for filtering out pixels belonging to ignore_index class

I have also attached a notebook that benchmarks the new approach against the old one. The time improvement is significant, often speeding up the code by more than 2x! The notebook also shows that the output of the new function is consistent with the new one.

@qubvel let me know if I need to add anymore tests.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant