# Build datasets and dataloaders for training
from fastcore.all import Path
from nbdev.export import Config
from torchvision.datasets import CIFAR10
from gale.classification.core import show_image_batch
from gale.classification.augment import aug_transforms
from gale.collections.download import download_and_extract_archive
# Archive of the "hymenoptera" toy image dataset; it is downloaded and
# extracted into `data_path` further down in this script.
URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
# Root directory for all datasets used here: the `data` folder under the
# nbdev project's configured `nbs_path`.
data_path = Path(Config().path("nbs_path")) / "data"
# Example 1: register a torchvision dataset (CIFAR-10) with gale.
# `download=True` lets torchvision fetch the dataset files into `data_path`.
dset = CIFAR10(data_path, download=True)
# create your transforms
# NOTE(review): presumably 70 is a resize size, 64 a crop size and hflip/vflip
# flip probabilities -- confirm against gale.classification.augment.
aug_tfms = aug_transforms(70, 64, hflip=0.5, vflip=0.5, mult=2)
# now register the dataset to gale DatasetCatalog under the name "cifar_10"
# NOTE(review): register_torchvision_dataset / DatasetCatalog / DataLoader are
# not imported in the visible part of this file; they must come from elsewhere.
register_torchvision_dataset("cifar_10", dset, augmentations=aug_tfms)
# Look the dataset back up by the name it was registered under.
ds = DatasetCatalog.get("cifar_10")
# Let's create a dataloader & view the Images (one shuffled batch of 8)
loader = DataLoader(ds, batch_size=8, shuffle=True)
show_image_batch(next(iter(loader)))
# Example 2: register a dataset from an on-disk folder tree.
# download a toy dataset (archive at URL) and extract it into `data_path`
download_and_extract_archive(url=URL, download_root=data_path, extract_root=data_path)
# train data is present in:
hymenoptera_data = data_path / "hymenoptera_data/train"
# Register the folder under "hymenoptera_train_ds", reusing the same augmentations.
# NOTE(review): presumably labels are derived from class sub-folder names --
# confirm in register_dataset_from_folders.
register_dataset_from_folders(name="hymenoptera_train_ds", image_root=hymenoptera_data, augmentations=aug_tfms)
ds = DatasetCatalog.get("hymenoptera_train_ds")
# Let's create a dataloader & view the Images (one shuffled batch of 8)
loader = DataLoader(ds, batch_size=8, shuffle=True)
show_image_batch(next(iter(loader)))
# Example 3: register a dataset from a pandas DataFrame.
from gale.collections.pandas import dataframe_labels_2_int, folder2df
# Build a DataFrame describing the images under the train folder; the
# registration below expects it to have "image_id" and "target" columns.
df = folder2df(hymenoptera_data)
# NOTE(review): presumably converts the "target" labels to integer ids --
# confirm in gale.collections.pandas.
df = dataframe_labels_2_int(df, label_column="target")
# Only displays a preview when run as a notebook cell; a no-op in a plain script.
df.head()
# Register the DataFrame as "hymenoptera_train_v0": image paths come from
# "image_id" and labels from "target".
register_dataset_from_df(
name="hymenoptera_train_v0",
df=df,
path_column="image_id",
label_column="target",
augmentations=aug_tfms,
)
ds = DatasetCatalog.get("hymenoptera_train_v0")
# Let's create a dataloader & view the Images (one shuffled batch of 8)
loader = DataLoader(ds, batch_size=8, shuffle=True)
show_image_batch(next(iter(loader)))
# Example 4: build the training loader from gale's config system.
from gale.config import get_config
# Load the default "classification" config.
cfg = get_config("classification")
# Point the training split at the DataFrame-backed dataset registered above.
cfg.datasets.train = "hymenoptera_train_v0"
# Build the loader from the dataset name plus the train dataloader settings.
dls = build_classification_loader_from_config(cfg.datasets.train, cfg.dataloader.train)
# Sanity check (notebook-style; asserts are stripped under `python -O`).
assert isinstance(dls, DataLoader)
show_image_batch(next(iter(dls)))