The pytorch-focalloss package contains the python package torch_focalloss, which provides PyTorch implementations of binary and multi-class focal loss functions.
pytorch-focalloss is installable from PyPI.
pip install pytorch-focalloss
The python package is importable as torch_focalloss. The only components in the package are the BinaryFocalLoss and MultiClassFocalLoss classes, whose interfaces allow them to be drop-in replacements for PyTorch's BCEWithLogitsLoss and CrossEntropyLoss classes, respectively. All of the same keyword arguments are supported, as well as the focusing parameter γ (gamma).
Benchmarks comparing the run times and memory usage of the focal loss implementations against their standard counterparts can be run using python ./benchmarking/benchmark_X.py (where X is one of basic, advanced, or training) from the repository's root directory.
Focal loss was first described in Lin et al.'s "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002).
This implementation of binary focal loss extends the original slightly, allowing for multi-label classification with no modification needed, including support for using a different value of α (alpha) for each label.
It is built on top of PyTorch's BCEWithLogitsLoss class and supports all of the same arguments. The pos_weight argument is given as alpha (but can alternatively be given as pos_weight for drop-in compatibility with BCEWithLogitsLoss), and the reduction and weight arguments are the same as for BCEWithLogitsLoss.
The multi-class focal loss is a logical extension of the original binary focal loss to the N-class case. It similarly accepts a tensor of weights, one for each class, as α (alpha).
It is built on top of PyTorch's CrossEntropyLoss class and supports all of the same arguments. The weight argument is given as alpha (but can alternatively be given as weight for drop-in compatibility with CrossEntropyLoss), and the reduction, ignore_index, and label_smoothing arguments are the same as for CrossEntropyLoss.
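A minimal sketch of how a focal term can be layered on top of binary cross entropy is shown below. It assumes the approach of Lin et al.: compute the unreduced BCE with torch.nn.functional.binary_cross_entropy_with_logits (passing alpha as pos_weight), then scale each element by the modulating factor (1 - p_t)^γ, where p_t is the predicted probability of the true label. The function name binary_focal_loss_sketch is hypothetical; this is an illustration, not the torch_focalloss source.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss_sketch(logits, target, gamma=2.0, alpha=None, weight=None, reduction="mean"):
    """Hypothetical sketch: element-wise BCE-with-logits scaled by the focal factor."""
    # Unreduced BCE; alpha plays the role of BCEWithLogitsLoss's pos_weight.
    bce = F.binary_cross_entropy_with_logits(
        logits, target, weight=weight, pos_weight=alpha, reduction="none"
    )
    # p_t is the predicted probability of the true label for each element.
    p = torch.sigmoid(logits)
    p_t = target * p + (1 - target) * (1 - p)
    # Down-weight well-classified elements (p_t close to 1).
    loss = (1 - p_t) ** gamma * bce
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"


torch.manual_seed(0)
logits = torch.randn(5)
target = torch.randint(2, (5,)).float()

# With gamma = 0 the modulating factor is 1, so the result matches BCEWithLogitsLoss.
bce = torch.nn.BCEWithLogitsLoss()(logits, target)
print(torch.allclose(binary_focal_loss_sketch(logits, target, gamma=0.0), bce))  # True
```

Because the BCE term is computed element-wise, the same function accepts multi-label targets of shape (N, C) without changes.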
Note that one difference from CrossEntropyLoss is that if all samples have target value ignore_index, then MultiClassFocalLoss returns 0 where CrossEntropyLoss would return nan.
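The following is a rough sketch of a multi-class focal loss with the behavior described above, assuming each sample's cross entropy term -log p_t is scaled by α for its target class and by (1 - p_t)^γ. The function name multiclass_focal_loss_sketch is hypothetical, label_smoothing is omitted for brevity, and the code is not taken from torch_focalloss. With reduction="mean" it divides by the summed class weights of the non-ignored samples, as CrossEntropyLoss does, and returns 0 when every target equals ignore_index.

```python
import torch
import torch.nn.functional as F


def multiclass_focal_loss_sketch(
    logits, target, gamma=2.0, alpha=None, ignore_index=-100, reduction="mean"
):
    """Hypothetical sketch of multi-class focal loss (label_smoothing omitted)."""
    keep = target != ignore_index
    # Swap ignored targets for class 0 so gather() is valid; their weight is zeroed below.
    safe_target = torch.where(keep, target, torch.zeros_like(target))
    # Log-probability and probability of each sample's true class.
    log_p_t = F.log_softmax(logits, dim=1).gather(1, safe_target.unsqueeze(1)).squeeze(1)
    p_t = log_p_t.exp()
    # Per-sample class weight alpha[y_i] (1 if alpha is None), and 0 for ignored samples.
    w = alpha[safe_target] if alpha is not None else torch.ones_like(p_t)
    w = w * keep
    loss = -w * (1 - p_t) ** gamma * log_p_t
    if reduction == "none":
        return loss
    if reduction == "sum" or not keep.any():
        # With every sample ignored the sum is 0, so "mean" returns 0 rather than nan.
        return loss.sum()
    # Weighted mean: divide by the effective number of samples.
    return loss.sum() / w.sum()


torch.manual_seed(0)
logits = torch.randn(5, 4)
target = torch.tensor([3, 1, 0, -100, 0])   # one ignored sample
alpha = torch.tensor([1.0, 0.2, 0.9, 1.4])  # hypothetical class weights

ce = F.cross_entropy(logits, target, weight=alpha, ignore_index=-100)
fl = multiclass_focal_loss_sketch(logits, target, gamma=0.0, alpha=alpha)
print(torch.allclose(fl, ce))  # True: gamma = 0 recovers weighted cross entropy

all_ignored = torch.full((5,), -100, dtype=torch.long)
print(F.cross_entropy(logits, all_ignored))               # nan
print(multiclass_focal_loss_sketch(logits, all_ignored))  # zero instead of nan
```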
See below, or check out DEMO.ipynb in the repository, for a demonstration of how the binary and multi-class focal losses work and how they compare to the standard cross entropy versions.
from torch import cuda, float32, ones, randint, randn, tensor
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch_focalloss import BinaryFocalLoss, MultiClassFocalLoss
device = "cuda" if cuda.is_available() else "cpu"
We'll use the same inputs for the whole example to demonstrate how changes in the parameters change the loss value.
First we create our simulated batch of 5 binary labels and raw logits.
preds = randn(5, device=device)
target = randint(2, size=(5,), dtype=float32, device=device)
print("Logits: ", preds)
print("Target: ", target)
Logits: tensor([ 0.5257, 0.9124, -0.9304, 1.0868, 1.2109], device='cuda:0')
Target: tensor([1., 1., 1., 0., 1.], device='cuda:0')
The standard binary cross entropy loss is the same as focal loss when γ = 0:
gamma = 0
bce = BCEWithLogitsLoss()
bfl = BinaryFocalLoss(gamma=gamma)
print(f"BCE Loss: {bce(preds, target).item():.5f}")
print(f"Focal Loss: {bfl(preds, target).item():.5f}")
BCE Loss: 0.74063
Focal Loss: 0.74063
This is also true when the weight applied to the positive class (1) relative to the negative class (0) is not 1. In BinaryFocalLoss this weight is called alpha; it corresponds to the pos_weight parameter of the BCEWithLogitsLoss class and is used to help manage class imbalance.
gamma = 0
alpha = tensor(1.5, device=device)
bce = BCEWithLogitsLoss(pos_weight=alpha)
bfl = BinaryFocalLoss(gamma=gamma, alpha=alpha)
print(f"BCE Loss: {bce(preds, target).item():.5f}")
print(f"Focal Loss: {bfl(preds, target).item():.5f}")
BCE Loss: 0.97320
Focal Loss: 0.97320
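Both results follow from the form of the loss, assuming (as in Lin et al.) that the focal loss multiplies each element's binary cross entropy term by (1 - p_t)^γ. Writing p = σ(x) for the sigmoid of a logit x, y ∈ {0, 1} for its target, and α for the positive-class weight (pos_weight):

```latex
\mathrm{FL}(x, y) = -\left[\alpha\, y\,(1 - p)^{\gamma} \log p + (1 - y)\, p^{\gamma} \log(1 - p)\right]

\gamma = 0 \;\Longrightarrow\; \mathrm{FL}(x, y) = -\left[\alpha\, y \log p + (1 - y) \log(1 - p)\right] = \mathrm{BCE}(x, y)
```

Setting γ = 0 removes both modulating factors for any α, which is why the outputs match with and without a positive-class weight.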
Note that our
The formula
Focal loss differs from binary cross entropy loss when γ > 0:
gamma = 2
bce = BCEWithLogitsLoss()
bfl = BinaryFocalLoss(gamma=gamma)
print(f"BCE Loss: {bce(preds, target).item():.5f}")
print(f"Focal Loss: {bfl(preds, target).item():.5f}")
BCE Loss: 0.74063
Focal Loss: 0.30507
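The drop from 0.74063 to 0.30507 reflects the focal modulating factor (1 - p_t)^γ from Lin et al., where p_t is the predicted probability of the true label. As an illustration (the probabilities below are hypothetical, not taken from the example above), this snippet evaluates the factor for a hard, a borderline, and an easy example. At γ = 2 the factors are 0.81, 0.25, and 0.01, so well-classified examples contribute far less to the loss.

```python
import torch

# Hypothetical probabilities assigned to the true label:
# a hard (0.1), a borderline (0.5), and an easy (0.9) example.
p_t = torch.tensor([0.1, 0.5, 0.9])
ce = -torch.log(p_t)  # per-example cross entropy term, -log(p_t)

for gamma in (0.0, 1.0, 2.0, 5.0):
    factor = (1 - p_t) ** gamma  # focal modulating factor (1 - p_t)^gamma
    focal = factor * ce
    print(f"gamma={gamma}: factor={factor.tolist()}, focal={focal.tolist()}")
```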
Just like binary cross entropy loss, we can use our binary focal loss for multi-label classification without modification.
We will simulate a batch of 5 samples, each with 3 binary labels.
preds = randn(5, 3, device=device)
target = randint(2, size=(5, 3), dtype=float32, device=device)
print("Logits: \n", preds)
print("Target: \n", target)
Logits:
tensor([[-0.8072, 0.0658, 1.5409],
[-1.1151, 0.9102, 0.3073],
[ 2.3941, 2.0975, -0.3208],
[ 0.2687, 0.0528, 0.5680],
[-1.3618, -0.4430, -1.3281]], device='cuda:0')
Target:
tensor([[1., 0., 1.],
[1., 1., 0.],
[0., 1., 1.],
[1., 1., 1.],
[0., 0., 0.]], device='cuda:0')
gamma = 2
alpha = tensor(1.5, device=device)
bce = BCEWithLogitsLoss(pos_weight=alpha)
bfl = BinaryFocalLoss(gamma=gamma, alpha=alpha)
print(f"BCE Loss: {bce(preds, target).item():.5f}")
print(f"Focal Loss: {bfl(preds, target).item():.5f}")
BCE Loss: 0.91235
Focal Loss: 0.37774
When doing multi-label classification, you can also specify a different value of α (alpha) for each label:
gamma = 2
alpha = tensor([0.5, 1, 1.5], device=device)
bce = BCEWithLogitsLoss(pos_weight=alpha)
bfl = BinaryFocalLoss(gamma=gamma, alpha=alpha)
print(f"BCE Loss: {bce(preds, target).item():.5f}")
print(f"Focal Loss: {bfl(preds, target).item():.5f}")
BCE Loss: 0.66547
Focal Loss: 0.27402
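A per-label alpha is used like BCEWithLogitsLoss's pos_weight: a tensor with one entry per label broadcasts along the last dimension, so the positive targets in column j are weighted by alpha[j]. The check below uses only PyTorch (not torch_focalloss) and random, hypothetical inputs; it writes the weighted BCE term out by hand to confirm this broadcasting.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 3)                  # 5 samples, 3 binary labels each
target = torch.randint(2, (5, 3)).float()
alpha = torch.tensor([0.5, 1.0, 1.5])       # one positive-class weight per label

# PyTorch's built-in BCE with a per-label pos_weight (unreduced, shape 5 x 3).
builtin = F.binary_cross_entropy_with_logits(logits, target, pos_weight=alpha, reduction="none")

# The same quantity written out; log(1 - sigmoid(x)) == logsigmoid(-x).
manual = -(alpha * target * F.logsigmoid(logits) + (1 - target) * F.logsigmoid(-logits))

print(torch.allclose(builtin, manual, atol=1e-6))  # True
```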
We also extended Lin et al.'s focal loss, which they defined only for the binary case, to the multi-class case.
Our example input will be for a 4-class classification problem, so we will create a sample of 5 labels and 5 sets of 4 logits.
preds = randn(5, 4, device=device)
target = randint(4, size=(5,), device=device)
print("Logits: \n", preds)
print("Target: \n", target)
Logits:
tensor([[ 0.6680, -0.9365, 0.1303, -0.6680],
[-0.0752, 1.0425, -0.1543, -0.7228],
[-1.1970, 0.5895, 0.3956, 1.9686],
[-0.0353, 1.0202, 0.6165, -1.0623],
[-1.9054, -0.4874, -1.2124, 0.5739]], device='cuda:0')
Target:
tensor([3, 1, 0, 0, 0], device='cuda:0')
Like binary focal loss and binary cross entropy loss, multi-class focal loss and cross entropy loss are the same when γ = 0:
gamma = 0
cel = CrossEntropyLoss()
mcfl = MultiClassFocalLoss(gamma=gamma)
print(f"Cross Entropy Loss: {cel(preds, target).item():.5f}")
print(f"Multi-Class Focal Loss: {mcfl(preds, target).item():.5f}")
Cross Entropy Loss: 2.19540
Multi-Class Focal Loss: 2.19540
This is also true when we apply class balancing weights. We also call these weights α (alpha); they correspond to the weight parameter of the CrossEntropyLoss class. Note that when using the reduction option "mean", the weighted mean is taken: the sum is divided by the effective number of samples according to the class weights (the sum of the weights of each sample's target class). This is the same behavior as for the standard CrossEntropyLoss class.
gamma = 0
alpha = (ones(4) + randn(4)).abs().to(device=device)
print(f"Alpha: {alpha}\n")
cel = CrossEntropyLoss(weight=alpha)
mcfl = MultiClassFocalLoss(gamma=gamma, alpha=alpha)
print(f"Cross Entropy Loss: {cel(preds, target).item():.5f}")
print(f"Multi-Class Focal Loss: {mcfl(preds, target).item():.5f}")
Alpha: tensor([1.0490, 0.1758, 0.8946, 1.3543], device='cuda:0')
Cross Entropy Loss: 2.48619
Multi-Class Focal Loss: 2.48619
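The weighted mean used for reduction="mean" can be checked with plain PyTorch. In the sketch below (hypothetical logits and weights, using torch.nn.functional.cross_entropy rather than torch_focalloss), the per-sample losses are summed and divided by the summed weights of each sample's target class rather than by the batch size.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 4)                   # hypothetical logits for 4 classes
target = torch.tensor([3, 1, 0, 0, 0])
alpha = torch.tensor([1.0, 0.2, 0.9, 1.4])   # hypothetical class weights

# Unreduced losses are already scaled by alpha[target].
per_sample = F.cross_entropy(logits, target, weight=alpha, reduction="none")

weighted_mean = per_sample.sum() / alpha[target].sum()  # divide by 1.4 + 0.2 + 3 * 1.0 = 4.6
plain_mean = per_sample.mean()                          # divide by the batch size, 5

builtin = F.cross_entropy(logits, target, weight=alpha)  # reduction="mean"
print(torch.allclose(builtin, weighted_mean))  # True
print(torch.allclose(builtin, plain_mean))     # False
```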
As in the binary case, multi-class focal loss differs from cross entropy loss when γ > 0:
gamma = 2
cel = CrossEntropyLoss(weight=alpha)
mcfl = MultiClassFocalLoss(gamma=gamma, alpha=alpha)
print(f"Cross Entropy Loss: {cel(preds, target).item():.5f}")
print(f"Multi-Class Focal Loss: {mcfl(preds, target).item():.5f}")
Cross Entropy Loss: 2.48619
Multi-Class Focal Loss: 2.09198
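For reference, the multi-class loss can be written as follows, assuming the same per-sample modulating factor as in the binary case and the weighted mean described earlier. Here x_{i,c} is the logit of sample i for class c, y_i is its target, p_{i,c} is the softmax probability, and the sums run over samples whose target is not ignore_index (label smoothing is left out):

```latex
p_{i,c} = \frac{\exp(x_{i,c})}{\sum_{k} \exp(x_{i,k})}, \qquad
\mathrm{FL} = \frac{\sum_{i} \alpha_{y_i}\,(1 - p_{i,y_i})^{\gamma}\,\bigl(-\log p_{i,y_i}\bigr)}{\sum_{i} \alpha_{y_i}}
```

With γ = 0 the factor (1 - p_{i,y_i})^γ is 1 and this is exactly CrossEntropyLoss's weighted mean; with γ > 0 every factor is below 1, so the value drops (from 2.48619 to 2.09198 in this example).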
Use the Issues section for questions, feedback, and concerns, or create a Pull Request if you want to contribute!