Custom loss functions in `Gale`
import torch
from torch import nn
from fastcore.all import *

class LabelSmoothingCrossEntropy[source]

LabelSmoothingCrossEntropy(eps:float=0.1, reduction:str='mean', weight:Optional[Tensor]=None) :: Module

Cross Entropy Loss with Label Smoothing
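Concretely, label smoothing softens the one-hot target: assuming the standard formulation (the exact variant in Gale may differ slightly), with $C$ classes and smoothing $\epsilon$ the loss is $\mathcal{L} = (1 - \epsilon)\,\mathrm{CE}(p, y) + \frac{\epsilon}{C}\sum_{c=1}^{C} -\log p_c$, i.e. plain cross entropy blended with a uniform-target term.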

criterion = LabelSmoothingCrossEntropy(reduction="mean")

output = torch.randn(32, 5, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(5)

loss = criterion(output, target)
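Since the smoothing term is scaled by eps, setting eps=0 should reduce to plain cross entropy (again assuming the standard formulation). A quick sanity check:

# eps=0 disables the smoothing term, so the result should match
# nn.CrossEntropyLoss (assumes the standard formulation above)
lsce = LabelSmoothingCrossEntropy(eps=0.0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
test_close(lsce(output, target), ce(output, target))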

class BinarySigmoidFocalLoss[source]

BinarySigmoidFocalLoss(alpha:float=-1, gamma:float=2, reduction:str='mean') :: Module

Creates a criterion that computes the focal loss between binary input and target. Focal Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

Source: https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

Focal Loss is the same as cross entropy except that easy-to-classify observations are down-weighted in the loss calculation. The strength of the down-weighting is controlled by the gamma parameter: the larger gamma is, the less the easy-to-classify observations contribute to the loss.

criterion = BinarySigmoidFocalLoss(reduction="mean")

target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)

loss = criterion(output, target)
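The logits above are confidently positive and every target is 1, i.e. these are easy examples, so increasing gamma should shrink the loss. A small illustration (assuming the fvcore formulation, where the per-element BCE term is scaled by $(1 - p_t)^{\gamma}$):

# Easy examples: a larger gamma should down-weight them more strongly
# (assumes the fvcore scaling of BCE by (1 - p_t)**gamma)
loss_g0 = BinarySigmoidFocalLoss(gamma=0, reduction="mean")(output, target)
loss_g2 = BinarySigmoidFocalLoss(gamma=2, reduction="mean")(output, target)
assert loss_g2 < loss_g0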

class FocalLoss[source]

FocalLoss(alpha:float=1, gamma:float=2, reduction:str='mean', eps:float=1e-08) :: Module

Same as nn.CrossEntropyLoss but with a focal parameter, gamma. Focal Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Focal loss is computed as follows: $\mathrm{FL}(p_t) = -\alpha(1 - p_t)^{\gamma}\log(p_t)$

Source: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/focal.html

Arguments to FocalLoss:

  • alpha (float): Weighting factor $\alpha$ in [0, 1]. Default: 1.
  • gamma (float, optional): Focusing parameter $\gamma$ >= 0. Default: 2.
  • reduction (str, optional): Specifies the reduction to apply to the output: none | mean | sum.
    • none: no reduction will be applied.
    • mean: the sum of the output will be divided by the number of elements in the output.
    • sum: the output will be summed.
    • Default: 'mean'.
  • eps (float, optional): Scalar to enforce numerical stability. Default: 1e-8.

criterion = FocalLoss(alpha=0.5, gamma=2.0, reduction="mean")

N = 5  # num_classes
input = torch.randn(32, N, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(N)
loss = criterion(input, target)
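To tie the code back to the formula, the same value can be computed by hand from the softmax probabilities. This is a sketch: it assumes 'mean' averages the per-sample losses and ignores the tiny eps stabilizer, so the match is only approximate:

# Manual computation mirroring FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
# (a sketch: assumes 'mean' averages per-sample losses; ignores eps)
p = torch.softmax(input, dim=1)
p_t = p[torch.arange(input.size(0)), target]  # probability of the true class
manual = (0.5 * (1 - p_t) ** 2.0 * -torch.log(p_t)).mean()
assert torch.allclose(loss, manual, atol=1e-4)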


# Compare focal loss with gamma = 0 to cross entropy
fl = FocalLoss(alpha=1, gamma=0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
output = torch.randn(32, N, requires_grad=True)
target = torch.empty(32, dtype=torch.long).random_(N)
test_close(fl(output, target), ce(output, target))

# Test that focal loss with gamma > 0 differs from cross entropy
fl = FocalLoss(gamma=2)
with torch.no_grad():
    test_ne(fl(output, target), ce(output, target))

Build

Losses are created by the Lightning-Tasks in Gale from the config. To load a loss via a Gale config, the loss must either be present in LOSS_REGISTRY or be one of the losses available in the torch.nn.modules.loss module.

build_loss[source]

build_loss(config:DictConfig)

Builds a loss from a config. This assumes a 'name' key in the config, which is used to determine which loss class to instantiate. For instance, a config {"name": "my_loss", "foo": "bar"} will find a class that was registered as "my_loss". A custom loss must first be registered into LOSS_REGISTRY.
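As a sketch of what registering a custom loss could look like, assuming LOSS_REGISTRY follows the common fvcore-style Registry API and that register() records the class under its __name__ (neither assumption is confirmed by this page):

# Hypothetical sketch: LOSS_REGISTRY's import path and decorator API are
# assumptions (fvcore-style Registry); adjust to Gale's actual exports.
import torch.nn.functional as F
from omegaconf import OmegaConf

@LOSS_REGISTRY.register()
class MyLoss(nn.Module):
    def __init__(self, foo: str = "bar"):
        super().__init__()
        self.foo = foo

    def forward(self, input, target):
        return F.cross_entropy(input, target)

# build_loss can then instantiate it from a config; the remaining keys
# are assumed to be forwarded to __init__
loss = build_loss(OmegaConf.create({"name": "MyLoss", "foo": "bar"}))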

For Image Classification, a loss is created like so:

from gale.config import get_config

cfg = get_config(config_name="classification")

# grab the config for the Loss Function
loss_cfg = cfg.training.train_loss_fn

# print(OmegaConf.to_yaml(loss_cfg))
loss = build_loss(loss_cfg)