By default, there will be int(math.sqrt(n)) rows and ceil(n/rows) columns. double will double the number of columns and n. The default figsize is (cols*imsize, rows*imsize+add_vert). If a title is passed, it is set on the figure. sharex, sharey, squeeze, subplot_kw and gridspec_kw are all passed down to plt.subplots. If return_fig is True, returns fig,axs; otherwise just axs. flatten will flatten the matplotlib axes so that they can be iterated over with a single loop.
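For example, with n=8 the defaults above give 2 rows and 4 columns. Here is a minimal sketch of how those defaults are computed (illustrative only; n, imsize, add_vert and double mirror the parameters described above):
import math

def default_grid(n, imsize=3, add_vert=0, double=False):
    rows = int(math.sqrt(n))    # default number of rows
    cols = math.ceil(n / rows)  # default number of columns
    if double:                  # double doubles the columns and n
        cols, n = cols * 2, n * 2
    figsize = (cols * imsize, rows * imsize + add_vert)
    return rows, cols, figsize

default_grid(8)  # -> (2, 4, (12, 6))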
Now let's look at how we can train an image classifier using ClassificationTask.
import pytorch_lightning as pl
import torchmetrics
from fastcore.all import L, Path
from nbdev.export import Config
from gale.classification.augment import *
from gale.classification.data import register_dataset_from_folders
from gale.classification import ClassificationTask  # import path is an assumption; adjust to your gale install
from gale.collections.callbacks.notebook import NotebookProgressCallback
from gale.collections.download import download_and_extract_archive
from gale.config import get_config
Let's take the PyTorch transfer learning tutorial and see how we can achieve the same result in PyTorch Gale.
The dataset can be downloaded and decompressed with the following code:
URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
data_path = Path(Config().path("nbs_path")) / "data"
# download a toy dataset
download_and_extract_archive(url=URL, download_root=data_path, extract_root=data_path)
data_path = data_path / "hymenoptera_data"
data_path.ls()
We first need to set up our datasets. Our images are stored in the train and val folders respectively. We register these datasets into the DatasetCatalog:
train_data = data_path / "train"
val_data = data_path / "val"
train_data.ls()
val_data.ls()
We use the convenience function register_dataset_from_folders, which parses the images, creates the labels and registers the dataset into the DatasetCatalog:
image_roots = L(train_data, val_data)
augments = L(imagenet_augment_transform(224), imagenet_no_augment_transform(224))
for i, a, d in zip(image_roots, augments, ["train", "val"]):
    register_dataset_from_folders("hymenoptera_" + d, i, augmentations=a)
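Once registered, the datasets can be looked up by name. Purely as an illustration, assuming gale's DatasetCatalog exposes a detectron2-style registry API (the import path and method names below are unverified assumptions):
from gale.classification.data import DatasetCatalog  # hypothetical import path

print(DatasetCatalog.list())                        # names of every registered dataset
train_ds = DatasetCatalog.get("hymenoptera_train")  # fetch the dataset we registered above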
Now, let's take an ImageNet-pretrained ResNet18 model and fine-tune it on the hymenoptera dataset.
# we can also specify overrides using Hydra's override syntax
overrides = [
"head@model.head=classification/fc_head", # this modifies the head of the model
"optimizer=adamw", # this modifies the optimizer
"scheduler=onecycle", # this modifies the scheduler in the optimizer
]
cfg = get_config("classification", overrides=overrides)
cfg.input.channels = 3
cfg.input.height = 224
cfg.input.width = 224
# model arguments:
cfg.model.num_classes = 2  # the dataset has two classes (ants, bees)
# NOTE: this sets the number of output classes for the model
cfg.model.backbone.init_args.model_name = "resnet18"
cfg.model.backbone.init_args.pretrained = True
cfg.model.backbone.init_args.freeze_at = 4
cfg.model.backbone.init_args.lr = 1e-06 # pick a good LR for the backbone
cfg.model.backbone.init_args.wd = 0.0
cfg.model.head.init_args.num_classes = "${model.num_classes}"
cfg.model.head.init_args.lr = 1e-02 # pick a good LR for the classifier
cfg.model.head.init_args.filter_wd = False
cfg.model.head.init_args.drop_rate = 0.3
cfg.model.head.init_args.wd = 0.0
# datasets:
cfg.datasets.train = "hymenoptera_train"
cfg.datasets.valid = "hymenoptera_val"
cfg.datasets.test = "hymenoptera_val"
# print(OmegaConf.to_yaml(cfg))
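Note the "${model.num_classes}" value above: it is an OmegaConf interpolation, so the head's class count always resolves to whatever cfg.model.num_classes currently holds. A minimal standalone illustration of this mechanism:
from omegaconf import OmegaConf

conf = OmegaConf.create(
    {"model": {"num_classes": 2, "head": {"num_classes": "${model.num_classes}"}}}
)
print(conf.model.head.num_classes)  # 2, resolved from model.num_classes
conf.model.num_classes = 5
print(conf.model.head.num_classes)  # 5, interpolations track the source value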
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
# specify your model training callbacks
cbs = [
NotebookProgressCallback(),
LearningRateMonitor(cfg.optimization.scheduler.interval, log_momentum=True),
]
# Initialize the pytorch-lightning trainer with all its goodies
trainer = pl.Trainer(max_epochs=4, callbacks=cbs, log_every_n_steps=1)
task = ClassificationTask(cfg, trainer, metrics=torchmetrics.Accuracy())
# print(task)
To verify the data loading is correct, let's visualize the annotations of randomly selected samples in the training set:
task.show_batch("train", n=8)
# Train the model ⚡
trainer.fit(task)
Now, let's run inference with the trained model on the test set, which in our case is the same as the validation dataset:
# Evaluate the model:
trainer.test(ckpt_path="best")
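ckpt_path="best" resolves through Lightning's ModelCheckpoint callback. To make "best" track a particular validation metric rather than simply the most recent checkpoint, add an explicit ModelCheckpoint to cbs before building the trainer. A sketch, where "val_loss" is an assumed metric key:
from pytorch_lightning.callbacks import ModelCheckpoint

# "val_loss" is an assumed metric name; replace it with whatever your task logs
cbs.append(ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1))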
Visualize the prediction results on the test dataset:
_ = task.show_results(max_n=12)
In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/