Custom loss functions in `Gale`
from fastcore.all import *  # provides test_close, test_ne
import torch
import torch.nn as nn
criterion = LabelSmoothingCrossEntropy(reduction="mean")
output = torch.randn(32, 5, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(5)
loss = criterion(output, target)
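The example above calls `LabelSmoothingCrossEntropy` without showing what it computes. Below is a minimal sketch of the usual label-smoothing formulation in plain PyTorch, assuming Gale follows it. The target is a mix of the one-hot label and a uniform distribution over the $K$ classes, `(1 - eps) * one_hot + eps / K`. The function name `label_smoothing_ce` and the `eps=0.1` default are illustrative assumptions, not Gale's API.

```python
import torch
import torch.nn.functional as F


def label_smoothing_ce(logits: torch.Tensor, target: torch.Tensor,
                       eps: float = 0.1, reduction: str = "mean") -> torch.Tensor:
    """Hypothetical sketch: cross entropy against (1 - eps) * one_hot + eps / K."""
    log_probs = F.log_softmax(logits, dim=-1)                        # (N, K)
    nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)   # (N,) true-class term
    smooth = -log_probs.mean(dim=-1)                                 # (N,) uniform-target term
    loss = (1 - eps) * nll + eps * smooth
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(f"unknown reduction: {reduction!r}")


torch.manual_seed(0)
logits = torch.randn(32, 5)
target = torch.randint(0, 5, (32,))
# Should match PyTorch's built-in label smoothing (available since PyTorch 1.10).
print(label_smoothing_ce(logits, target), F.cross_entropy(logits, target, label_smoothing=0.1))
```

Splitting the loss into an `nll` term and a `smooth` term avoids building the dense smoothed target matrix. With `eps=0` the sketch is plain cross entropy.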
Focal loss is the same as cross entropy, except that easy-to-classify observations are down-weighted in the loss calculation. The strength of the down-weighting grows with the focusing parameter gamma: the larger gamma is, the less the easy-to-classify observations contribute to the loss.
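For reference, this is the standard definition from the focal loss paper (Lin et al., 2017). Here $p_t$ is the predicted probability of the true class and $\alpha_t$ is a class weight. It is an assumption that Gale's focal losses follow exactly this form. The factor $(1 - p_t)^{\gamma}$ performs the down-weighting, and the worked numbers show how quickly it shrinks for confident (easy) predictions.

```latex
\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log(p_t),
\qquad
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}

% gamma = 0, alpha_t = 1:            FL(p_t) = -log(p_t)   (plain cross entropy)
% gamma = 2, easy example p_t = 0.9: (1 - 0.9)^2 = 0.01
% gamma = 2, hard example p_t = 0.5: (1 - 0.5)^2 = 0.25
```

With $\gamma = 2$, the easy example's cross-entropy term is scaled by 0.01 and the uncertain one's by 0.25. Relative to the hard example, the easy example is down-weighted 25 times more.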
criterion = BinarySigmoidFocalLoss(reduction="mean")
target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)
loss = criterion(output, target)
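A minimal sketch of a sigmoid (binary) focal loss in plain PyTorch, applied to the same inputs as the example above. It follows the formulation used by `torchvision.ops.sigmoid_focal_loss`. Whether `BinarySigmoidFocalLoss` uses the same defaults (`alpha=0.25`, `gamma=2`) is an assumption. The name `binary_sigmoid_focal_loss` is hypothetical, and passing `alpha=None` to disable class weighting is a choice made for this sketch.

```python
import torch
import torch.nn.functional as F


def binary_sigmoid_focal_loss(logits: torch.Tensor, targets: torch.Tensor,
                              alpha=0.25, gamma: float = 2.0,
                              reduction: str = "mean") -> torch.Tensor:
    """Hypothetical sketch of a per-element sigmoid focal loss."""
    # Per-element BCE on logits equals -log(p_t) for targets in {0, 1}.
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce * (1 - p_t) ** gamma          # down-weight confident predictions
    if alpha is not None:                   # optional class-balancing weight alpha_t
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(f"unknown reduction: {reduction!r}")


target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)
print(binary_sigmoid_focal_loss(output, target))
```

Setting `gamma=0` and `alpha=None` turns the sketch back into plain binary cross entropy on logits.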
Arguments to `FocalLoss`:
- `alpha` (float): Weighting factor $\alpha \in [0, 1]$.
- `gamma` (float, optional): Focusing parameter $\gamma \geq 0$. Default: 2.
- `reduction` (str, optional): Specifies the reduction to apply to the output: `none` | `mean` | `sum`. `none`: no reduction will be applied; `mean`: the sum of the output will be divided by the number of elements in the output; `sum`: the output will be summed. Default: `none`.
- `eps` (float, optional): Scalar to enforce numerical stability. Default: 1e-8.
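To show how these four arguments enter the computation, here is a minimal multi-class sketch in plain PyTorch. Several details are assumptions, not taken from Gale's source: the scalar `alpha` multiplies every term, and `eps` is added inside the logarithm to avoid `log(0)`. The function name `focal_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits: torch.Tensor, target: torch.Tensor, alpha: float = 1.0,
               gamma: float = 2.0, reduction: str = "none",
               eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch: -alpha * (1 - p_t)^gamma * log(p_t + eps)."""
    probs = F.softmax(logits, dim=1)                         # (N, C)
    p_t = probs.gather(1, target.unsqueeze(1)).squeeze(1)    # (N,) true-class probability
    loss = -alpha * (1.0 - p_t) ** gamma * torch.log(p_t + eps)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(f"unknown reduction: {reduction!r}")


torch.manual_seed(0)
logits = torch.randn(32, 5)
target = torch.randint(0, 5, (32,))
print(focal_loss(logits, target, alpha=1.0, gamma=0.0, reduction="mean"),
      F.cross_entropy(logits, target))
```

With `alpha=1` and `gamma=0` the sketch matches cross entropy up to the tiny `eps` shift. The tests below check the same property for Gale's `FocalLoss`.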
criterion = FocalLoss(alpha=0.5, gamma=2.0, reduction="mean")
N = 5 # num_classes
input = torch.randn(32, N, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(N)
loss = criterion(input, target)
# With alpha=1 and gamma=0, focal loss reduces to standard cross entropy
fl = FocalLoss(alpha=1, gamma=0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
output = torch.randn(32, N, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(N)
test_close(fl(output, target), ce(output, target))
# With gamma > 0, focal loss differs from cross entropy
# (reduction="mean" so both losses are scalars and the comparison is meaningful)
fl = FocalLoss(gamma=2, reduction="mean")
with torch.no_grad():
    test_ne(fl(output, target), ce(output, target))
Losses are created by the Lightning-Tasks in Gale from the config. To load a loss via the Gale config, the loss must either be registered in `LOSS_REGISTRY` or be one of the losses available in the `torch.nn.modules.loss` module.
For image classification, a loss is created like so:
from gale.config import get_config
cfg = get_config(config_name="classification")
# grab the config for the Loss Function
loss_cfg = cfg.training.train_loss_fn
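# print(OmegaConf.to_yaml(loss_cfg))  # optional: inspect the loss config (needs `from omegaconf import OmegaConf`)
# build the loss module described by the config with Gale's `build_loss`
loss = build_loss(loss_cfg)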
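The lookup order described above checks `LOSS_REGISTRY` first and then falls back to `torch.nn.modules.loss`. It can be sketched as follows. Everything in this block is hypothetical and is not Gale's actual `build_loss` or config schema: the plain-dict registry, `register_loss`, `build_loss_sketch`, the `ScaledMSELoss` example, and the `name`/`init_args` config keys.

```python
import torch
import torch.nn as nn
import torch.nn.modules.loss as torch_losses

# Hypothetical stand-in for Gale's LOSS_REGISTRY: maps a class name to a loss class.
LOSS_REGISTRY: dict = {}


def register_loss(cls):
    """Decorator that adds a custom loss class to the hypothetical registry."""
    LOSS_REGISTRY[cls.__name__] = cls
    return cls


@register_loss
class ScaledMSELoss(nn.Module):
    """Hypothetical custom loss, used only to exercise the registry."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.scale * nn.functional.mse_loss(input, target)


def build_loss_sketch(cfg: dict) -> nn.Module:
    """Resolve cfg["name"] in the registry first, then in torch.nn.modules.loss."""
    name = cfg["name"]
    kwargs = dict(cfg.get("init_args") or {})
    if name in LOSS_REGISTRY:
        cls = LOSS_REGISTRY[name]
    else:
        cls = getattr(torch_losses, name, None)
        if not (isinstance(cls, type) and issubclass(cls, nn.Module)):
            raise KeyError(f"Unknown loss: {name!r}")
    return cls(**kwargs)


print(build_loss_sketch({"name": "CrossEntropyLoss", "init_args": {"reduction": "sum"}}))
print(build_loss_sketch({"name": "ScaledMSELoss", "init_args": {"scale": 2.0}}))
```

Checking the registry before PyTorch lets a project-specific loss take precedence over a built-in loss of the same name.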