12 minute read

Most PyTorch courses and tutorials show how to train a model using the pre-loaded datasets (such as MNIST) that subclass the torch.utils.data.Dataset. But in realistic scenarios, we have to train models on our own datasets and implement functions specific to them. Therefore, many of us don’t know the underlying operations and functions to implement to get PyTorch to work with our own dataset.

PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.

In this post, we will learn how to use the Dataset and DataLoader class. make subclasses from them to use for our custom dataset where samples and labels both are images. for classification training, we have samples as images and their class annotations in a CSV file (0 or 1). Pytorch official documentation provides a tutorial to work with that sort of dataset

Note: In this particular example, i used DRIVE dataset. DRIVE is a fundus images datset where samples are the retinal images and the labels are the corresponding segmentation map of retinal blood vessels (samples and labels are both images). With slight changes, this example can be used to load any type of dataset for training in pytorch

Subclassing torch.utils.data.Dataset to generate samples and labels

The whole code for making a dataset generator using torch.utils.data.Dataset that will be explained line by line:

Dataset subclass:


from torch.utils.data import Dataset
import os
import natsort
from PIL import Image
import numpy as np
import cv2

class CustomDataset(Dataset):
    """the class for loading the dataset from the directories
    Arguments:

        img_dir: directory for the dataset images

        label_dir: labels of the images directory

        transform: transforms applied to inputs

        transform_label: transforms applied to the labels
    """

    def __init__(self, img_dir, label_dir, 
                transform_image,transform_label,
                image_scale = None):

        self.image_scale = image_scale
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform_image = transform_image
        self.transform_label = transform_label
        all_images = os.listdir(self.img_dir)
        all_lables = os.listdir(self.label_dir)

        self.total_imgs = natsort.natsorted(all_images)
        self.total_labels = natsort.natsorted(all_lables)
    

    def __len__(self):
        return len(self.total_imgs)

    @classmethod
    def preprocessing(cls, image, label, scale=None):
        """class method for preprocessing of the image 
        and label
        usage: preprocessing the images before feeding
        to the network for training as well as before
        making predictions dataset class dishes out 
        pre-processed bathes of images and labels """
        return(image, label)


    def __getitem__(self, idx):
        """ Generator to yield a tuple of image and
        label
        idx: the index to iterate over the dataset in
        directories of both images and labels
        ---------------------

        :return:  image, label
        :rtype: torch tensor
        """

        img_loc = os.path.join(self.img_dir,
                         self.total_imgs[idx])
        label_loc = os.path.join(self.label_dir,
                         self.total_labels[idx])
        # opening image using cv2 function
        image = cv2.imread(img_loc)
        # opening image with PIL package
        label = Image.open(label_loc)  


        image, label = self.preprocessing(image, label,
                         mask, scale=self.image_scale)

        label = np.asarray(label).astype(np.uint8)


        '''=====applying transformations ======='''

        label = self.transform_label(label)
        image = self.transform_label(image)

        return image, label

Writing the CustomDataset class:

import torch.utils.data.Dataset and all the other necessary packages according to your data.

from torch.utils.data import Dataset
import os
import natsort
from PIL import Image
import numpy as np
import cv2

make a subclass from Dataset and initializing it

__init__ function is a class constructor that is run once to initializing the class instance


    def __init__(self, img_dir, label_dir, transform_image, transform_label, image_scale = None):

        self.image_scale = image_scale
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform_image = transform_image
        self.transform_label = transform_label
        all_images = os.listdir(self.img_dir)
        all_lables = os.listdir(self.label_dir)

        self.total_imgs = natsort.natsorted(all_images)
        self.total_labels = natsort.natsorted(all_lables)

we are providing the image and labels directories along with separate transforms for images and labels as parameters to the class __ini__ function. we list samples and labels from their respective directories and then sort them according to their names with natsort:

        all_images = os.listdir(self.img_dir)
        all_lables = os.listdir(self.label_dir)

        self.total_imgs = natsort.natsorted(all_images)
        self.total_labels = natsort.natsorted(all_lables)

these sorted lists will be used to get each individual image and its corresponding label (labels are also images for segmentation tasks).

Breaking down __getitem__:

__getitem__ loads the images and labels and iterated through them using the idx index. in this function we can apply necessary transforms from torchvision.transforms or other necessary pre-processing steps with a withing this function (or using a separate class method)

    def __getitem__(self, idx):

        img_loc = os.path.join(self.img_dir, self.total_imgs[idx])
        label_loc = os.path.join(self.label_dir, self.total_labels[idx])
        # opening image using cv2 function
        image = cv2.imread(img_loc)
        # opening image with PIL package
        label = Image.open(label_loc)  

        image, label = self.preprocessing(image, label, scale=self.image_scale)

        '''===============applying transformations ===================='''

        label = self.transform_label(label)
        image = self.transform_label(image)

        return image, label

i used the cv2 as well as PIL.Image to load images from the label or image directory just to demonstrate the flexibility and how you can use either of the two you are most comfortable with.

        img_loc = os.path.join(self.img_dir, self.total_imgs[idx])
        label_loc = os.path.join(self.label_dir, self.total_labels[idx])

we join paths of the image/label directory with the image/label list we sorted earlier. for example: total_images is the list of all the sorted images in the img_dir and idx provides an index for each individual image so that we have a complete path to the image/label to be loaded.

we can invoke the preprocessing() class method to perform some additional processing steps and then finally apply the transforms to the image and label and then return an image with its corresponding label.

__len__ method:

returns the len of the sample size so that the Dataset class knows how the number of iterations to be performed to load the entire dataset.

preprocessing method:

a class method implemented which is exclusively not a part of the torch.utils.data.Dataset class but additionally added to the CustomDataset class as a class method. can be excluded if not needed. I intentionally left it empty however it can be used to perform some preprocessing using cv2 for example: resizing, grayscale conversion, applying CLAHE, gemma correction and other similar cv2 operations.

Perparing data for training with DataLoader:

DataLoader to load CustomDataset:

import CustomDataset
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as transforms

transform_label = transforms.Compose([transforms.ToTensor() ])

transform_image = transforms.Compose([transforms.ToTensor()])

dataset = CustomDataset(img_dir, label_dir, 
                        transform_image= transform_image,
                        transform_label=transform_label,
                        image_scale=.5)

n_val = int(len(dataset) * 0.2)
n_train = int(len(dataset) - n_val)
train, val = random_split(dataset, [n_train, n_val])

train_loader = DataLoader(train, batch_size=15, shuffle=True,
num_workers=0, pin_memory=False)

val_loader = DataLoader(val, batch_size=10, shuffle=False,
num_workers=0, pin_memory=False)

importing all the necessary packages

import CustomDataset
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as transforms

transforms like random rotation, resize, random crop and other PyTorch transforms can be applied with `torchvision. transform. Here I only applied the tensor conversion that converts the data to a PyTorch tensor.

transform_label = transforms.Compose([transforms.ToTensor() ])

transform_image = transforms.Compose([transforms.ToTensor()])

Now the next step is to make an instance of our CustomDataset class with all the necessary parameters.

dataset = CustomDataset(img_dir, label_dir, 
                        transform_image= transform_image,
                        transform_label=transform_label,
                        image_scale=.5)

we can split our data set into train and validation set using random split

n_val = int(len(dataset) * 0.2)
n_train = int(len(dataset) - n_val)
train, val = random_split(dataset, [n_train, n_val])

make an iteratable generator from dataset for train and validation for training.

train_loader = DataLoader(train, batch_size=15, shuffle=True,
num_workers=0, pin_memory=False)

val_loader = DataLoader(val, batch_size=10, shuffle=False,
num_workers=0, pin_memory=False)

Iterate through the dataloader:

dataloader can be iterated and returns an image, label pair with the specified batch size.

for example, we defined our batch size to be 15 for train_loader. Considering our images are all grey-scaled (1 channelled) with a size of 512x512, This will yield a PyTorch tensor of (15,1,512,512) with 15 samples each sample of 1 channel and 512 height and width

for img, label in (train_loader):
    
    n,c,h,w =(img.shape)  #shape of tensor = (15,1,512,512)
    for i in range(n):
        im = np.squeeze(img[i, :, :, :].numpy())
        la = np.squeeze(label[i, :, :, :].numpy())
        visualize(im, la)


we can convert the tensors to NumPy arrays and plot them or use them straight for training just like pre-loaded datasets.