tst = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3, 4))
m = prepare_backbone(tst)
test_eq(len(m), 2)
m = prepare_backbone(tst, cut=3)
test_eq(len(m), 3)
m = prepare_backbone(tst, cut=-1)
test_eq(len(m), 3)
Some meta_archs in gale require backbones, and for image classification all backbones should inherit from ImageClassificationBackbone.
class TstModule(ImageClassificationBackbone):
    def __init__(self):
        super(TstModule, self).__init__()
        layers = [nn.Linear(3, 4), nn.Linear(4, 5)]
        self.layers = nn.Sequential(*layers)

    def forward(self, o):
        return self.layers(o)

    def output_shape(self):
        return ShapeSpec(4, None, None)

    def build_param_dicts(self):
        p0 = {"params": self.layers[0].parameters(), "lr": 1e-06, "weight_decay": 0.001}
        p1 = {"params": self.layers[1].parameters(), "lr": 1e-03, "weight_decay": 0.1}
        return [p0, p1]
tst = TstModule()
test_eq(tst.hypers.lr, [1e-06, 1e-03])
test_eq(tst.hypers.wd, [0.001, 0.1])
assert tst.output_shape() == ShapeSpec(4, None, None)
This class provides a simple way to load a model from timm using all of its arguments. It then cuts the model at the pooling layer before the classifier, i.e., we keep the feature extractor, and the feature extractor is converted into the backbone. You can optionally choose to partially or fully freeze the parameter groups of the backbone using freeze_at. freeze_bn sets the BatchNorm layers of the model to eval mode, and if filter_wd is set then weight decay is not applied to biases and other 1-d parameters of the backbone.
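As a rough illustration of what freeze_bn amounts to (a sketch, not gale's actual implementation; the helper name set_bn_eval is invented here), freezing BatchNorm typically means putting those layers into eval mode so their running statistics stop updating, and stopping gradients for their affine parameters:

```python
import torch.nn as nn


def set_bn_eval(module: nn.Module) -> None:
    # Sketch: put every BatchNorm layer into eval mode so its running
    # statistics stop updating, and freeze its affine parameters.
    for m in module.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.eval()
            for p in m.parameters():
                p.requires_grad_(False)
```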
TimmBackboneBase.build_param_dicts() is responsible for building the parameter groups of the model. Currently it returns the trainable_params of the model with lr and wd. The parameters are filtered by wd if filter_wd is set. For more advanced options you should probably override this method.
Arguments to TimmBackboneBase:
- input_shape (ShapeSpec): shape of the inputs.
- model_name (str): name of model to instantiate.
- act (str): name of the activation function to use. If None, uses the default activations; else the name must be in ACTIVATION_REGISTRY.
- lr (float): learning rate for the modules.
- wd (float): weight decay for the modules.
- freeze_bn (bool): freeze the batch normalization layers of the model.
- freeze_at (int): freeze the layers of the backbone up to freeze_at; False means train all.
- filter_wd (bool): filter out bias and bn parameters from weight_decay.
- pretrained (bool): load pretrained ImageNet-1k weights if True.
- drop_block_rate (float): drop block rate.
- drop_path_rate (float): drop path rate.
- bn_tf (bool): use TensorFlow BatchNorm defaults for models that support it.
- kwargs (optional): optional kwargs passed on to timm.create_model().
input_shape = ShapeSpec(channels=3, height=255, width=255)
bk = TimmBackboneBase(model_name="resnet18", pretrained=True, input_shape=input_shape)
m = timm.create_model("resnet18")
i = torch.randn(2, 3, 224, 224)
o1 = bk(i)
test_eq(o1.shape, torch.Size([2, 512, 7, 7]))
test_eq(bk.output_shape().channels, m.num_features)
The config for TimmBackboneBase is going to look like this. We need to convert the dataclass to an OmegaConf config, and then we can use the from_config_dict method to instantiate our class.
conf = TimmBackboneDataClass(model_name="resnet18", pretrained=True)
conf = OmegaConf.structured(conf)
# we need to explicitly pass in the input_shape argument
m = TimmBackboneBase.from_config_dict(conf, input_shape=input_shape)
o2 = m(i)
test_eq(o2.shape, torch.Size([2, 512, 7, 7]))
test_eq(o1.data, o2.data)
Arguments to ResNetBackbone:
- input_shape (ShapeSpec): shape of the inputs.
- model_name (str): name of model to instantiate.
- act (str): name of the activation function to use. If None, uses the default activations; else the name must be in ACTIVATION_REGISTRY.
- lr (float): learning rate for the modules.
- lr_div (int, float): factor for discriminative lrs.
- wd (float): weight decay for the modules.
- freeze_at (int): freeze the first several stages of the ResNet. Commonly used in fine-tuning. 1 means freezing the stem, 2 means freezing the stem and one residual stage, etc.
- pretrained (bool): load pretrained ImageNet-1k weights if True.
- drop_block_rate (float): drop block rate.
- drop_path_rate (float): drop path rate.
- bn_tf (bool): use TensorFlow BatchNorm defaults for models that support it.
- kwargs (optional): optional kwargs passed on to timm.create_model().
ResNetBackbone is an ImageClassificationBackbone class that is responsible for converting ResNet-based models into an appropriate backbone for image classification tasks.
Note that each ResNet model has 1 stem and 4 residual stages. You can freeze some or all of these by setting freeze_at: if 0, the whole model is trainable; 1 freezes only the stem, 2 freezes the stem and one stage, and so on. We also train the ResNet model using discriminative learning rates for fine-tuning.
So stages 3 and 4 are trained with a learning rate of lr, while the stem and stages 1 and 2 are trained with a learning rate of lr/lr_div. Weight decay wd is applied to the whole model.
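The split described above can be sketched as follows. This is a hypothetical illustration rather than gale's actual code: the helper name resnet_param_dicts is invented here, and the attribute names (conv1, bn1, layer1 ... layer4) assume timm's ResNet naming convention.

```python
import torch.nn as nn


def resnet_param_dicts(model: nn.Module, lr: float, lr_div: float, wd: float):
    # Sketch: group timm-style ResNet parameters so the stem and early
    # stages (conv1/bn1/layer1/layer2) train with lr / lr_div while the
    # later stages (layer3/layer4) train with the full lr.
    early_names = ("conv1", "bn1", "layer1", "layer2")
    late_names = ("layer3", "layer4")
    early, late = [], []
    for name, p in model.named_parameters():
        if name.startswith(early_names):
            early.append(p)
        elif name.startswith(late_names):
            late.append(p)
    return [
        {"params": early, "lr": lr / lr_div, "weight_decay": wd},
        {"params": late, "lr": lr, "weight_decay": wd},
    ]
```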
This class can be instantiated from a config as follows -
conf = OmegaConf.structured(ResNetBackboneDataClass(model_name="resnet34"))
# instantiate cls from config
m = ResNetBackbone.from_config_dict(conf, input_shape=input_shape)