The PIL library loads an image from a path and returns it as a PIL.Image.Image object:
im = pil_loader(_IMAGE)
show_image(im)
assert isinstance(im, Image.Image)
In contrast to PIL, cv2 loads the image and returns it as an np.ndarray:
im = cv2_loader(_IMAGE)
show_image(im)
assert isinstance(im, np.ndarray)
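The two return types are easy to convert between. The following is a minimal sketch, assuming only Pillow and NumPy are installed; it builds a tiny image in memory instead of calling the loaders above, and all variable names are illustrative.

```python
import numpy as np
from PIL import Image

# Build a tiny 2x3 RGB image in memory (a stand-in for a loaded file).
arr = np.zeros((2, 3, 3), dtype=np.uint8)
arr[..., 0] = 255  # fill the first (red) channel

# ndarray -> PIL.Image.Image
pil_im = Image.fromarray(arr)

# PIL.Image.Image -> ndarray with layout (height, width, channels)
back = np.asarray(pil_im)

# PIL reports size as (width, height); NumPy reports shape as (height, width, channels).
print(pil_im.size, back.shape)
```

Note that plain cv2.imread returns channels in BGR order; whether cv2_loader converts to RGB is not stated here, so check the channel order before mixing the two loaders.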
augs = [T.RandomRotation(90), T.RandomHorizontalFlip(1.0), T.ColorJitter(0.4, 0.4, 0.4)]
augs = T.Compose(augs)
mapper = ClassificationMapper(augmentations=augs)
mapper
For training datasets, ClassificationMapper takes in a DatasetDict containing the keys file_name and target; these correspond to the file path of the image and the integer target label for the image, respectively:
datas = DatasetDict(file_name=_IMAGE, target=0)
im, targ = mapper.encodes(datas)
# de-normalize image
im = im.permute(1, 2, 0) * torch.tensor(mapper.std) + torch.tensor(mapper.mean)
show_images([im], titles=[targ], imsize=5)
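The de-normalization line undoes a per-channel `(x - mean) / std` normalization after moving the channels back to the last axis. Here is a minimal NumPy sketch of that round trip, assuming ImageNet-style statistics; the actual values of mapper.mean and mapper.std are not shown here.

```python
import numpy as np

# Assumed statistics for illustration only; not necessarily Gale's defaults.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

rng = np.random.default_rng(0)
image = rng.random((4, 4, 3))  # HWC image with values in [0, 1)

# Normalize per channel, then move channels first, as a tensor pipeline typically does.
normalized = ((image - mean) / std).transpose(2, 0, 1)  # CHW

# De-normalize: back to HWC, then multiply by std and add mean.
restored = normalized.transpose(1, 2, 0) * std + mean

print(np.allclose(restored, image))
```

The permute call in the original code plays the role of the second transpose here: mean and std have shape (3,), so they only broadcast correctly when the channel axis is last.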
ClassificationMapper can also be used with torchvision's default datasets:
import torchvision
from gale.collections.download import download_and_extract_archive
from nbdev.export import Config
from fastcore.all import Path
URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
data_path = Path(Config().path("nbs_path")) / "data"
# download a toy dataset
download_and_extract_archive(url=URL, download_root=data_path, extract_root=data_path)
# create torchvision dataset instance
path = data_path / "hymenoptera_data"
ds = torchvision.datasets.ImageFolder(root=path / "train")
im, targ = mapper.encodes(ds[0])
# de-normalize image
im = im.permute(1, 2, 0) * torch.tensor(mapper.std) + torch.tensor(mapper.mean)
show_images([im], titles=[targ], imsize=5)
uint8 or PIL images. Normalization and conversion to tensors are handled independently by the library. ClassificationMapper is compatible with both albumentations and torchvision augmentations.
Arguments to FolderParser:
root: Root directory path.
class_map: Path to a .txt file which contains the class mapping.
parser = FolderParser(root=path / "train")
img, targ = mapper.encodes(parser[0])
img = img.permute(1, 2, 0) * torch.tensor(mapper.std) + torch.tensor(mapper.mean)
show_images([img], titles=[targ], imsize=5)
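To make the root-directory convention concrete, here is a rough sketch of how a folder-based parser might turn class sub-directories into (file path, integer label) records. This is not FolderParser's implementation: the function name `parse_folder`, the alphabetical label assignment, and the extension filter are all assumptions made for illustration.

```python
import tempfile
from pathlib import Path


def parse_folder(root, extensions=(".jpg", ".jpeg", ".png")):
    """Hypothetical helper: list file_name/target records found under root."""
    root = Path(root)
    # One class per sub-directory; labels follow alphabetical order (an assumption).
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    records = []
    for name in classes:
        for f in sorted((root / name).iterdir()):
            if f.suffix.lower() in extensions:
                records.append({"file_name": str(f), "target": class_to_idx[name]})
    return records, class_to_idx


# Build a throwaway directory tree with empty files to exercise the helper.
with tempfile.TemporaryDirectory() as tmp:
    for cls, n in [("ants", 2), ("bees", 1)]:
        (Path(tmp) / cls).mkdir()
        for i in range(n):
            (Path(tmp) / cls / f"{i}.jpg").touch()
    records, class_to_idx = parse_folder(tmp)

print(class_to_idx)
print([r["target"] for r in records])
```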
from gale.collections.pandas import folder2df, dataframe_labels_2_int
df = folder2df(path / "train")
df, class_map = dataframe_labels_2_int(df, label_column="target", return_labelling=True)
df.head()
parser = PandasParser(df, path_column="image_id", label_column="target")
img, targ = mapper.encodes(parser[0])
img = img.permute(1, 2, 0) * torch.tensor(mapper.std) + torch.tensor(mapper.mean)
show_images([img], titles=[targ], imsize=5)
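The label-to-integer step can be pictured with plain pandas. The sketch below uses `pd.factorize` on a toy frame; it only illustrates the idea, since what dataframe_labels_2_int does internally and the exact shape of its returned class_map are not shown here. The file names are made up.

```python
import pandas as pd

# Toy frame with string labels; column names mirror the ones used above.
df = pd.DataFrame({
    "image_id": ["a.jpg", "b.jpg", "c.jpg", "d.jpg"],
    "target": ["bees", "ants", "bees", "ants"],
})

# factorize assigns one integer per distinct label; sort=True makes the
# assignment follow alphabetical order rather than order of appearance.
codes, uniques = pd.factorize(df["target"], sort=True)
df["target"] = codes

# Keep the integer -> label mapping so predictions can be decoded later.
class_map = dict(enumerate(uniques))

print(df)
print(class_map)
```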
To create an image classification dataset for Gale, we need to do the following:
URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
data_path = Path(Config().path("nbs_path")) / "data"
# download a toy dataset
download_and_extract_archive(url=URL, download_root=data_path, extract_root=data_path)
# take a peek at the structure of the dataset
path = data_path / "hymenoptera_data"
path.ls(), (path / "train").ls()
Data is stored in folders named after the classes. The data is also divided into two subsets; we will only work with the train subset of the current dataset. We first initialize our parser:
parser = FolderParser(root=path / "train")
ims = []
targs = []
# Let's check some samples from our parser
for i in range(5):
    img = parser[i].file_name
    targ = parser[i].target
    ims.append(Image.open(img))
    targs.append(targ)
show_images(ims)
Now that the parser is correctly loading the images, we need to create a mapper for ClassificationDataset in order to get our data into Datasets and DataLoaders.
The mapper will contain all the image augmentations and preprocessing necessary to prepare our data. When applying augmentations, we do not need to normalize the data or convert it to tensors; this is done automatically by the mapper:
# you can add even more fancy augmentations here
augs = A.Compose([A.Resize(128, 128, p=1.0), A.HueSaturationValue(p=1.0)])
# also supports torchvision augmentations
# augs = T.Compose([T.Resize((128,128)), T.ColorJitter(0.3, 0.3, 0.3)])
mapper = ClassificationMapper(augmentations=augs)
dset = ClassificationDataset(mapper=mapper, parser=parser)
# Let's create the dataloader
loader = torch.utils.data.DataLoader(dset, batch_size=8, shuffle=True)
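As a rough picture of what it means for the mapper to handle normalization and tensor conversion, the sketch below applies an augmentation to a uint8 image and only afterwards scales, normalizes, and reorders it to channels-first. It is a NumPy stand-in, not ClassificationMapper's code: `TinyMapper`, its default mean and std, and the `horizontal_flip` augmentation are hypothetical.

```python
import numpy as np


class TinyMapper:
    """Hypothetical stand-in: augment on uint8, then normalize to CHW floats."""

    def __init__(self, augmentations=None, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        self.augmentations = augmentations
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def encodes(self, image, target):
        if self.augmentations is not None:
            image = self.augmentations(image)   # still uint8, HWC
        x = image.astype(np.float32) / 255.0    # scale to [0, 1]
        x = (x - self.mean) / self.std          # per-channel normalization
        return x.transpose(2, 0, 1), target     # HWC -> CHW


def horizontal_flip(image):
    # Reverse the width axis of an HWC image.
    return image[:, ::-1, :]


image = np.arange(2 * 4 * 3, dtype=np.uint8).reshape(2, 4, 3)
tiny_mapper = TinyMapper(augmentations=horizontal_flip)
x, y = tiny_mapper.encodes(image, target=1)
print(x.shape, x.dtype, y)
```

Keeping the augmentations on uint8 data and normalizing last means the same augmentation pipeline can be reused regardless of the normalization statistics.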
View images from the DataLoader as a sanity check:
samples = next(iter(loader))
show_image_batch(samples)