Lightning Image Classification Task

predict_context

predict_context(func:Callable)

This decorator acts as a context manager: it puts the model in eval mode before running the prediction function and resets it to train mode afterwards.
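The following is a minimal sketch of how such a decorator could work. It assumes the decorated function is a method of an object with PyTorch-style eval()/train() methods and a training flag. DummyModule is a hypothetical stand-in for a torch.nn.Module, and none of this is Gale's actual implementation.

```python
import functools
from typing import Any, Callable


def predict_context(func: Callable) -> Callable:
    """Hypothetical sketch: run a method in eval mode, then switch back to train mode."""

    @functools.wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> Any:
        self.eval()  # in PyTorch: disables dropout, BatchNorm uses running stats
        try:
            return func(self, *args, **kwargs)
        finally:
            self.train()  # reset to train mode even if func raises

    return wrapper


class DummyModule:
    """Stand-in for torch.nn.Module that only tracks the `training` flag."""

    def __init__(self) -> None:
        self.training = True

    def eval(self) -> "DummyModule":
        self.training = False
        return self

    def train(self, mode: bool = True) -> "DummyModule":
        self.training = mode
        return self

    @predict_context
    def predict(self, x: int):
        # Return the mode seen during prediction together with a toy "prediction".
        return self.training, x * 2


m = DummyModule()
print(m.predict(3))  # (False, 6): predict ran in eval mode
print(m.training)    # True: train mode restored afterwards
```

Using try/finally means the model is put back into train mode even if prediction fails halfway.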

class ClassificationTask

ClassificationTask(cfg:DictConfig, trainer:Trainer=None, metrics:Union[Metric, Mapping, Sequence, NoneType]=None) :: DefaultTask

Interface for PyTorch Lightning-based Gale modules.

Internals

ClassificationTask.setup

ClassificationTask.setup(stage:Optional[str]=None)

Sets up the model and all the other components of the pl.LightningModule. This is called during task initialization.

ClassificationTask.forward

ClassificationTask.forward(x)

Forward method: passes the input through the meta_arch to get the predictions for the current image batch.

ClassificationTask.shared_step

ClassificationTask.shared_step(batch:Any, batch_idx:int, stage:str)

Common step for the training, validation and test stages. shared_step returns a dictionary containing the loss and the logs, which are the metric values computed for the given stage. It also applies mixup/cutmix to the training data if this is specified in the config.
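To make the returned structure concrete, here is a framework-free sketch of a shared step. It assumes logs are keyed as "<stage>/<metric>", which matches the train/loss, val/accuracy and test/accuracy names in the training output further down. Mixup is applied only in the "train" stage, and the loss is a soft-target cross-entropy (the training summary lists SoftTargetCrossEntropy as train_loss). Plain lists stand in for tensors, the toy model is an identity function, and the mixup pairing (batch reversed) and fixed mixing weight are simplifying assumptions, not Gale's code.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence

Vector = List[float]


def log_softmax(logits: Sequence[float]) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def soft_target_ce(logits: Sequence[float], target: Sequence[float]) -> float:
    # Cross-entropy against a probability vector (one-hot, or blended by mixup).
    return -sum(t * lp for t, lp in zip(target, log_softmax(logits)))


def blend(a: Sequence[float], b: Sequence[float], lam: float) -> Vector:
    return [lam * u + (1 - lam) * v for u, v in zip(a, b)]


def shared_step(model: Callable[[Vector], Vector], xs: List[Vector], ys: List[int],
                num_classes: int, stage: str, mixup_lam: Optional[float] = None) -> Dict:
    targets = [[1.0 if c == y else 0.0 for c in range(num_classes)] for y in ys]
    if stage == "train" and mixup_lam is not None:
        # Mixup: pair each sample with the batch in reverse order and blend
        # both the inputs and the one-hot targets with the same weight.
        xs = [blend(a, b, mixup_lam) for a, b in zip(xs, xs[::-1])]
        targets = [blend(a, b, mixup_lam) for a, b in zip(targets, targets[::-1])]
    logits = [model(x) for x in xs]
    loss = sum(soft_target_ce(z, t) for z, t in zip(logits, targets)) / len(xs)
    # Accuracy against the original labels (a simplification when mixup is on).
    preds = [max(range(num_classes), key=z.__getitem__) for z in logits]
    acc = sum(p == y for p, y in zip(preds, ys)) / len(ys)
    return {"loss": loss, "logs": {f"{stage}/loss": loss, f"{stage}/accuracy": acc}}


def identity(x: Vector) -> Vector:
    return x  # toy "model": the 2 input features are used directly as 2 logits


out = shared_step(identity, [[2.0, 0.0], [0.0, 2.0]], [0, 1], num_classes=2, stage="val")
print(out["logs"])  # val/loss is about 0.127, val/accuracy is 1.0
```

With one-hot targets, the soft-target loss reduces to ordinary cross-entropy. This is consistent with the summary listing a plain CrossEntropyLoss as eval_loss.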

ClassificationTask.generate_preds

ClassificationTask.generate_preds(batch:Tuple)

Generates predictions for a batch. Returns the images, targets and predictions for the batch.

Overrides

ClassificationTask.setup_training_data

ClassificationTask.setup_training_data(name:str=None, dls_conf:DictConfig=None)

Builds the training dataset from name and the dataloader from dls_conf. If either is None, the value is parsed from the config passed when the instance was created.
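The fallback rule works like this: explicit arguments win, and None means "read it from the config". The sketch below uses a plain dict in place of the DictConfig. The datasets.train key matches the config used in the example further down, but the dataloader.train key and the helper resolve_training_args are hypothetical.

```python
from typing import Any, Dict, Optional, Tuple


def resolve_training_args(cfg: Dict[str, Any], name: Optional[str] = None,
                          dls_conf: Optional[Dict[str, Any]] = None
                          ) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: use explicit arguments, else fall back to the config."""
    if name is None:
        name = cfg["datasets"]["train"]        # e.g. "hymenoptera_train"
    if dls_conf is None:
        dls_conf = cfg["dataloader"]["train"]  # hypothetical config key
    return name, dls_conf


cfg = {
    "datasets": {"train": "hymenoptera_train"},
    "dataloader": {"train": {"batch_size": 32, "shuffle": True}},  # illustrative values
}
print(resolve_training_args(cfg))                    # both values come from cfg
print(resolve_training_args(cfg, name="my_dataset")) # name overridden, dls_conf from cfg
```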

ClassificationTask.setup_validation_data

ClassificationTask.setup_validation_data(name:Union[List, str]=None, dls_conf:DictConfig=None)

Same as setup_training_data, but sets up the validation dataset and dataloaders.

ClassificationTask.setup_test_data

ClassificationTask.setup_test_data(name:Union[List, str]=None, dls_conf:DictConfig=None)

Same as setup_training_data, but sets up the test dataset and dataloaders.

DefaultTask.setup_optimization

DefaultTask.setup_optimization(conf:DictConfig=None)

Prepares an optimizer from a string name and its optional config parameters. You can also call this method manually with a valid optimization config to set up the optimizers and lr_schedulers.
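The example below gives the backbone and head their own lr and wd (backbone lr 1e-06, head lr 1e-02), which suggests one parameter group per component. The sketch shows that idea together with a lookup by string name. It uses the {"params", "lr", "weight_decay"} dict form that torch.optim optimizers accept as parameter groups. However, the OPTIMIZERS registry, make_adamw and build_optimizer are hypothetical, parameter names stand in for tensors, and this is not Gale's implementation.

```python
from typing import Any, Callable, Dict, List

ParamGroup = Dict[str, Any]


def make_adamw(groups: List[ParamGroup], **defaults: Any) -> Dict[str, Any]:
    # Stand-in for torch.optim.AdamW(groups, **defaults); returns a description instead.
    return {"name": "adamw", "param_groups": groups, "defaults": defaults}


# Hypothetical name -> factory registry; a real one would map to optimizer classes.
OPTIMIZERS: Dict[str, Callable[..., Dict[str, Any]]] = {"adamw": make_adamw}


def build_optimizer(name: str, components: Dict[str, Dict[str, Any]], **defaults: Any):
    # One parameter group per component, each with its own lr and weight decay.
    groups = [
        {"params": spec["params"], "lr": spec["lr"], "weight_decay": spec["wd"]}
        for spec in components.values()
    ]
    try:
        factory = OPTIMIZERS[name.lower()]
    except KeyError:
        raise ValueError(f"unknown optimizer {name!r}; choose from {sorted(OPTIMIZERS)}") from None
    return factory(groups, **defaults)


components = {
    "backbone": {"params": ["conv1.weight"], "lr": 1e-6, "wd": 0.0},
    "head": {"params": ["fc.weight", "fc.bias"], "lr": 1e-2, "wd": 0.0},
}
opt = build_optimizer("AdamW", components)
print([g["lr"] for g in opt["param_groups"]])  # [1e-06, 0.01]
```

Separate groups let a pretrained backbone move slowly while a freshly initialized head learns quickly.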

Helpers

get_grid

get_grid(n, nrows=None, ncols=None, add_vert=0, figsize=None, double=False, title=None, return_fig=False, flatten=True, imsize=3, suptitle=None, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None)

Return a grid of n axes, rows by cols

By default, there will be int(math.sqrt(n)) rows and ceil(n/rows) columns. double doubles both the number of columns and n. The default figsize is (cols*imsize, rows*imsize+add_vert). If a title is passed, it is set as the figure title. sharex, sharey, squeeze, subplot_kw and gridspec_kw are all passed down to plt.subplots. If return_fig is True, fig, axs is returned; otherwise just axs. flatten flattens the matplotlib axes so that they can be iterated over with a single loop.
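The default layout arithmetic described above can be worked through without creating a figure. grid_layout is a hypothetical helper that only computes the numbers. Its handling of explicitly passed nrows or ncols (the missing one is filled with a ceiling division) is an assumption, since the text above only specifies the default case.

```python
import math
from typing import Optional, Tuple


def grid_layout(n: int, nrows: Optional[int] = None, ncols: Optional[int] = None,
                double: bool = False, imsize: float = 3, add_vert: float = 0
                ) -> Tuple[int, int, int, Tuple[float, float]]:
    """Hypothetical helper: return (n, rows, cols, default figsize) for a grid of n axes."""
    if nrows:
        ncols = ncols or math.ceil(n / nrows)
    elif ncols:
        nrows = math.ceil(n / ncols)
    else:
        nrows = int(math.sqrt(n))      # default: int(sqrt(n)) rows
        ncols = math.ceil(n / nrows)   # and ceil(n / rows) columns
    if double:
        ncols *= 2                     # double doubles the columns and n
        n *= 2
    figsize = (ncols * imsize, nrows * imsize + add_vert)
    return n, nrows, ncols, figsize


print(grid_layout(12))               # (12, 3, 4, (12, 9))
print(grid_layout(10, double=True))  # (20, 3, 8, (24, 9))
```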

ClassificationTask.show_results

ClassificationTask.show_results(dataloader:DataLoader=None, ctxs=None, max_n:int=10, nrows:int=None, ncols:int=None, figsize:Tuple=None, **kwargs)

Displays the results for max_n items of a batch from dataloader if one is given; otherwise uses the test or validation dataloader.


Minimal Example

Now let's look at how we can train an image classifier using ClassificationTask.

import pytorch_lightning as pl
import torchmetrics
from fastcore.all import L, Path
from nbdev.export import Config

from gale.classification.augment import *
from gale.classification.data import register_dataset_from_folders
from gale.collections.callbacks.notebook import NotebookProgressCallback
from gale.collections.download import download_and_extract_archive
from gale.config import get_config

Let's take this tutorial and see how we can achieve the same result with PyTorch Gale.

Prepare the dataset

The dataset can be downloaded and decompressed with this line of code:

URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
data_path = Path(Config().path("nbs_path")) / "data"

# download a toy dataset
download_and_extract_archive(url=URL, download_root=data_path, extract_root=data_path)

data_path = data_path / "hymenoptera_data"
data_path.ls()
Using downloaded and verified file: /Users/ayushman/Desktop/gale/nbs/data/hymenoptera_data.zip
Extracting /Users/ayushman/Desktop/gale/nbs/data/hymenoptera_data.zip to /Users/ayushman/Desktop/gale/nbs/data
(#2) [Path('/Users/ayushman/Desktop/gale/nbs/data/hymenoptera_data/train'),Path('/Users/ayushman/Desktop/gale/nbs/data/hymenoptera_data/val')]

We first need to set up our datasets. Our images are stored in the train and val folders, respectively. We register these datasets into the DatasetCatalog:

train_data = data_path / "train"
val_data = data_path / "val"

train_data.ls()
val_data.ls()
(#2) [Path('/Users/ayushman/Desktop/gale/nbs/data/hymenoptera_data/val/bees'),Path('/Users/ayushman/Desktop/gale/nbs/data/hymenoptera_data/val/ants')]

We use the convenience function register_dataset_from_folders, which parses the images, creates the labels and registers the dataset into the DatasetCatalog.

image_roots = L(train_data, val_data)
augments = L(imagenet_augment_transform(224), imagenet_no_augment_transform(224))

for i, a, d in zip(image_roots, augments, ["train", "val"]):
    register_dataset_from_folders("hymenoptera_" + d, i, augmentations=a)
Dataset: hymenoptera_train registerd to DatasetCatalog
Dataset: hymenoptera_val registerd to DatasetCatalog
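register_dataset_from_folders parses the images and creates the labels. Given the <split>/ants and <split>/bees layout listed above, a plausible scheme is to label each image with the name of its parent folder. The sketch below assumes that scheme. parse_image_folder and the set of accepted image extensions are hypothetical, not Gale's code.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # assumption: which suffixes count as images


def parse_image_folder(root: Path) -> Tuple[List[Path], List[int], List[str]]:
    """Hypothetical sketch: label every image by the name of its parent folder."""
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    files: List[Path] = []
    labels: List[int] = []
    for name in classes:
        for f in sorted((root / name).iterdir()):
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS:
                files.append(f)
                labels.append(class_to_idx[name])
    return files, labels, classes


# Demo on a throwaway folder that mimics hymenoptera_data/val
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in ["ants/a.jpg", "bees/b.jpg", "bees/c.png", "bees/notes.txt"]:
        path = root / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    files, labels, classes = parse_image_folder(root)
    print(classes, labels)  # ['ants', 'bees'] [0, 1, 1]
```

Sorting the class names makes the label indices deterministic across runs and operating systems.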

Train!

Now, let's fine-tune an ImageNet-pretrained ResNet18 model on the hymenoptera dataset.

# we can also specify overrides using Hydra's override syntax
overrides = [
    "head@model.head=classification/fc_head",  # this modifies the head of the model
    "optimizer=adamw",  # this modifies the optimizer
    "scheduler=onecycle",  # this modifies the scheduler in the optimizer
]

cfg = get_config("classification", overrides=overrides)
cfg.input.channels = 3
cfg.input.height = 224
cfg.input.width = 224

# model arguments :
cfg.model.num_classes = 2  # two classes: ants and bees

cfg.model.backbone.init_args.model_name = "resnet18"
cfg.model.backbone.init_args.pretrained = True
cfg.model.backbone.init_args.freeze_at = 4
cfg.model.backbone.init_args.lr = 1e-06  # pick a good LR for the backbone
cfg.model.backbone.init_args.wd = 0.0

cfg.model.head.init_args.num_classes = "${model.num_classes}"
cfg.model.head.init_args.lr = 1e-02  # pick a good LR for the classifier
cfg.model.head.init_args.filter_wd = False
cfg.model.head.init_args.drop_rate = 0.3
cfg.model.head.init_args.wd = 0.0

# datasets :
cfg.datasets.train = "hymenoptera_train"
cfg.datasets.valid = "hymenoptera_val"
cfg.datasets.test = "hymenoptera_val"
# print(OmegaConf.to_yaml(cfg))
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# specify your model training callbacks
cbs = [
    NotebookProgressCallback(),
    LearningRateMonitor(cfg.optimization.scheduler.interval, log_momentum=True),
]

# Initialize the pytorch-lightning trainer with all its goodies
trainer = pl.Trainer(max_epochs=4, callbacks=cbs, log_every_n_steps=1)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
task = ClassificationTask(cfg, trainer, metrics=torchmetrics.Accuracy())
Building GeneralizedImageClassifier from config ...
Loading pretrained weights from url (https://download.pytorch.org/models/resnet18-5c106cde.pth)
Model created, param count: 11.2 M.
# print(task)

To verify that the data loading is correct, let's visualize the labels of randomly selected samples from the training set:

task.show_batch("train", n=8)
# Train the model ⚡
trainer.fit(task)
  | Name       | Type                       | Params | In sizes         | Out sizes
-----------------------------------------------------------------------------------------
0 | _metrics   | ModuleDict                 | 0      | ?                | ?        
1 | _model     | GeneralizedImageClassifier | 11.2 M | [1, 3, 224, 224] | [1, 2]   
2 | train_loss | SoftTargetCrossEntropy     | 0      | ?                | ?        
3 | eval_loss  | CrossEntropyLoss           | 0      | ?                | ?        
-----------------------------------------------------------------------------------------
8.4 M     Trainable params
2.8 M     Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)
Training [32/32 02:04, Epoch 3 {'loss': '0.292', 'v_num': 0}]
epoch val/loss val/accuracy train/loss train/accuracy time samples/s
0 0.323344 0.875817 0.515968 0.800000 32.280200 0.402700
1 0.182327 0.941176 0.245177 0.900000 32.835700 0.395900
2 0.182145 0.947712 0.355949 0.850000 31.498000 0.412700
3 0.204220 0.921569 0.297929 0.850000 30.248500 0.429800

1
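The summary above can be sanity-checked by hand. Lightning's size estimate is the parameter count times the bytes per parameter (4 for 32-bit floats). The check assumes a torchvision-style ResNet18 (11,689,512 parameters including its 1000-class fc layer). It also assumes fc_head reduces to a single 512 → 2 linear layer, which is an assumption about the head. Under these assumptions, the arithmetic reproduces the reported 44.710 MB.

```python
# Assumptions: torchvision-style ResNet18, head = one Linear(512 -> 2), float32 weights.
RESNET18_PARAMS = 11_689_512          # includes the original 1000-class fc layer
IMAGENET_FC = 512 * 1000 + 1000       # weights + biases of the removed classifier
NEW_FC = 512 * 2 + 2                  # weights + biases of a 2-class classifier

total = RESNET18_PARAMS - IMAGENET_FC + NEW_FC
size_mb = total * 4 / 1e6             # 4 bytes per float32 parameter
print(total, f"{size_mb:.3f}")        # 11177538 44.710
```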

Inference & evaluation using the trained model

Now, let's run inference with the trained model on the test dataset, which in our case is the same as the validation dataset.

# Evaluate the model:
trainer.test(ckpt_path="best")
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/accuracy': 0.9215686321258545, 'test/loss': 0.20422008633613586}
--------------------------------------------------------------------------------
[{'test/loss': 0.20422008633613586, 'test/accuracy': 0.9215686321258545}]

Visualize the prediction results on the test dataset

_ = task.show_results(max_n=12)

In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!

# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/