Backbones/feature extractors for use in Image Classification Tasks

Utility functions

has_pool_type[source]

has_pool_type(m:Module)

Return True if m is a pooling layer or has one in its children. From: https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py#L76
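
For example, a minimal sketch (assuming standard torch.nn pooling layers such as nn.AdaptiveAvgPool2d are recognised):

assert has_pool_type(nn.Sequential(nn.Conv2d(3, 5, 3), nn.AdaptiveAvgPool2d(1)))
assert not has_pool_type(nn.Linear(3, 4))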

prepare_backbone[source]

prepare_backbone(model:Module, cut=None)

Cut off the body of a typically pretrained model as determined by cut

tst = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3, 4))

m = prepare_backbone(tst)
test_eq(len(m), 2)

m = prepare_backbone(tst, cut=3)
test_eq(len(m), 3)

m = prepare_backbone(tst, cut=-1)
test_eq(len(m), 3)

filter_weight_decay[source]

filter_weight_decay(model:Module, lr:float, weight_decay:float=1e-05, skip_list=())

Filter out bias, bn and other 1d params from weight decay. Modified from: https://github.com/rwightman/pytorch-image-models/timm/optim/optim_factory.py
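
A minimal usage sketch, assuming the function returns a list of optimizer-style param-group dicts (as the timm helper it is modified from does):

model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
param_groups = filter_weight_decay(model, lr=1e-3, weight_decay=1e-5)
# 1d params (bias, bn) should end up in a group with weight_decay=0
opt = torch.optim.AdamW(param_groups)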

class ImageClassificationBackbone[source]

ImageClassificationBackbone() :: BasicModule

Abstract class for Image Classification backbones

Some meta_arch's in gale require backbones, and for image classification all backbones should inherit from ImageClassificationBackbone

class TstModule(ImageClassificationBackbone):
    def __init__(self):
        super().__init__()
        layers = [nn.Linear(3, 4), nn.Linear(4, 5)]
        self.layers = nn.Sequential(*layers)

    def forward(self, o):
        return self.layers(o)

    def output_shape(self):
        return ShapeSpec(4, None, None)

    def build_param_dicts(self):
        # two parameter groups with different lr / weight_decay, used below to test `hypers`
        p0 = {"params": self.layers[0].parameters(), "lr": 1e-06, "weight_decay": 0.001}
        p1 = {"params": self.layers[1].parameters(), "lr": 1e-03, "weight_decay": 0.1}
        return [p0, p1]


tst = TstModule()

ImageClassificationBackbone.hypers[source]

Returns the hyperparameters (lr and wd) for each parameter group

test_eq(tst.hypers.lr, [1e-06, 1e-03])
test_eq(tst.hypers.wd, [0.001, 0.1])

ImageClassificationBackbone.output_shape[source]

ImageClassificationBackbone.output_shape()

Returns the output shape. For most backbones this is a ShapeSpec containing the number of channels in the output layer.

assert tst.output_shape() == ShapeSpec(4, None, None)

ImageClassificationBackbone.filter_params[source]

ImageClassificationBackbone.filter_params(parameters:List[Dict])

Filters out any empty parameter groups in parameters
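
For example, a small sketch assuming filter_params returns only the non-empty groups:

groups = tst.build_param_dicts() + [{"params": [], "lr": 1e-02}]  # append an empty group
test_eq(len(tst.filter_params(groups)), 2)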

class TimmBackboneBase[source]

TimmBackboneBase(model_name:str, input_shape:ShapeSpec, act:str=None, lr:float=0.001, wd:float=0, freeze_bn:bool=False, freeze_at:int=False, filter_wd:bool=False, pretrained=True, drop_block_rate=None, drop_path_rate=None, bn_tf=False, **kwargs) :: ImageClassificationBackbone

Creates a model from timm and converts it into an Image Classification Backbone

This class provides a simple way to load a model from timm using all of its arguments. It then cuts the model at the pooling layer before the classifier, i.e., we keep only the feature extractor, which is converted into the backbone. You can optionally choose to partially or fully freeze the parameter groups of the backbone using freeze_at. freeze_bn sets the BatchNorm layers of the model to eval mode, and if filter_wd is set, weight decay is not applied to bias and other 1d parameters of the backbone.

TimmBackboneBase.build_param_dicts() is responsible for building the parameter groups of the model. Currently it returns the trainable_params of the model with lr and wd. If filter_wd is set, the parameters are filtered via filter_weight_decay. For more advanced options you should probably override this method.

Arguments to TimmBackboneBase:

  • input_shape (ShapeSpec): Shape of the Inputs
  • model_name (str): name of model to instantiate.
  • act (str): name of the activation function to use. If None, the default activations are used; otherwise the name must be in ACTIVATION_REGISTRY.
  • lr (float): learning rate for the modules.
  • wd (float): weight decay for the modules.
  • freeze_bn (bool): freeze the batch normalization layers of the model.
  • freeze_at (int): freeze the layers of the backbone up to freeze_at; False means train all.
  • filter_wd (bool): Filter out bias, bn from weight_decay.
  • pretrained (bool): load pretrained ImageNet-1k weights if true.
  • drop_block_rate (float): Drop block rate.
  • drop_path_rate (float): Drop path rate.
  • bn_tf (bool): Use Tensorflow BatchNorm defaults for models that support it.
  • kwargs (optional): Optional kwargs passed onto timm.create_model()

input_shape = ShapeSpec(channels=3, height=255, width=255)
bk = TimmBackboneBase(model_name="resnet18", pretrained=True, input_shape=input_shape)
m = timm.create_model("resnet18")

i = torch.randn(2, 3, 224, 224)
o1 = bk(i)
test_eq(o1.shape, torch.Size([2, 512, 7, 7]))
test_eq(bk.output_shape().channels, m.num_features)
Loading pretrained weights from url (https://download.pytorch.org/models/resnet18-5c106cde.pth)
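
Since build_param_dicts returns optimizer-style parameter groups (each carrying its own lr and weight_decay), they can be handed straight to a PyTorch optimizer. A minimal sketch:

opt = torch.optim.AdamW(bk.build_param_dicts())  # per-group lr/wd override the optimizer defaults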

Dataclass

class TimmBackboneDataClass[source]

TimmBackboneDataClass(model_name:str='???', act:Optional[str]=None, lr:Any=0.001, wd:Any=0.0, freeze_bn:bool=False, freeze_at:Any=False, filter_wd:bool=False, pretrained:bool=True, drop_block_rate:Optional[float]=None, drop_path_rate:Optional[float]=None, bn_tf:bool=False)

Base config file for TimmBackboneBase. You need to pass in a model_name; the other parameters are optional.

The config for TimmBackboneBase is going to look like this. We need to convert the dataclass to an OmegaConf config, and then we can use the from_config_dict method to instantiate our class:

conf = TimmBackboneDataClass(model_name="resnet18", pretrained=True)
conf = OmegaConf.structured(conf)

# we need to explicitly pass in the input_shape argument
m = TimmBackboneBase.from_config_dict(conf, input_shape=input_shape)

o2 = m(i)
test_eq(o2.shape, torch.Size([2, 512, 7, 7]))

test_eq(o1.data, o2.data)
Loading pretrained weights from url (https://download.pytorch.org/models/resnet18-5c106cde.pth)
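
Because the config is a structured OmegaConf object, individual fields can also be overridden before instantiation using the standard OmegaConf API. A small sketch:

conf = OmegaConf.structured(TimmBackboneDataClass(model_name="resnet18"))
conf = OmegaConf.merge(conf, OmegaConf.create({"lr": 1e-4, "freeze_bn": True}))
m = TimmBackboneBase.from_config_dict(conf, input_shape=input_shape)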

class ResNetBackbone[source]

ResNetBackbone(model_name:str, input_shape:ShapeSpec, act:str=None, lr:float=0.001, wd:float=0.01, lr_div:float=100, freeze_at:int=0, freeze_bn:bool=False, pretrained=True, drop_block_rate=0.0, drop_path_rate=0.0, **kwargs) :: ImageClassificationBackbone

A Backbone for ResNet-based models from timm. Note: this class supports all the models listed here

Arguments to ResNetBackbone:

  • input_shape (ShapeSpec): Shape of the Inputs
  • model_name (str): name of model to instantiate.
  • act (str): name of the activation function to use. If None, the default activations are used; otherwise the name must be in ACTIVATION_REGISTRY.
  • lr (float): learning rate for the modules.
  • lr_div (int, float): factor for discriminative lrs.
  • wd (float): weight decay for the modules.
  • freeze_at (int): Freeze the first several stages of the ResNet. Commonly used in fine-tuning. 1 means freezing the stem. 2 means freezing the stem and one residual stage, etc.
  • pretrained (bool): load pretrained ImageNet-1k weights if true.
  • drop_block_rate (float): Drop block rate.
  • drop_path_rate (float): Drop path rate.
  • bn_tf (bool): Use Tensorflow BatchNorm defaults for models that support it.
  • kwargs (optional): Optional kwargs passed onto timm.create_model()

ResNetBackbone is an ImageClassificationBackbone class that is responsible for converting ResNet-based models into an appropriate backbone for Image Classification tasks.

Note that each ResNet model has 1 stem and 4 convolutional blocks. You can freeze some or all of these blocks by setting freeze_at: if 0, the whole model is trainable; 1 freezes only the stem; 2 freezes the stem and one block; and so on. We also train the ResNet model using discriminative learning rates for fine-tuning: blocks 3 and 4 are trained with a learning rate of lr, while the stem and blocks 1 and 2 are trained with a learning rate of lr/lr_div. Weight decay wd is applied to the whole model.
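
For instance, a minimal sketch using the arguments documented above:

bk = ResNetBackbone(model_name="resnet34", input_shape=input_shape, lr=1e-3, lr_div=100, freeze_at=2)
# stem and block 1 are frozen; block 2 trains at lr / lr_div; blocks 3 and 4 at lr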

ResNetBackbone.prepare_model[source]

ResNetBackbone.prepare_model(m:Module)

Freeze the first several stages of the ResNet. Commonly used in fine-tuning.

ResNetBackbone.freeze_block[source]

ResNetBackbone.freeze_block(m:Module)

Make this block m not trainable.
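
Conceptually, making a block untrainable amounts to disabling gradients on its parameters. A sketch of the idea (not necessarily the actual implementation):

def freeze_block_sketch(m: nn.Module):
    # disable gradient updates for every parameter in the block
    for p in m.parameters():
        p.requires_grad_(False)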

Dataclass

This class can be instantiated from a config as follows:

class ResNetBackboneDataClass[source]

ResNetBackboneDataClass(model_name:str='???', act:Optional[str]=None, lr:Any=0.001, lr_div:Any=10, wd:Any=0.0, freeze_at:int=0, pretrained:bool=True, drop_block_rate:Optional[float]=None, drop_path_rate:Optional[float]=None, bn_tf:bool=False)

Base config file for ResNetBackbone

conf = OmegaConf.structured(ResNetBackboneDataClass(model_name="resnet34"))
# instantiate cls from config
m = ResNetBackbone.from_config_dict(conf, input_shape=input_shape)
Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth)