DATA LOADING (Auto Encoder with dictionary sample in output) TUTORIAL

Prerequisites

To run this tutorial, please make sure the following packages are installed:
- PyTorch 0.4.1
- TorchVision 0.2.1
- PIL: For image io and transforms
- Matplotlib: To generate plots, histograms, etc.
from torch.utils.data import Dataset
from os import listdir
from os.path import join
import matplotlib.pyplot as plt
from PIL import Image

DATASET CLASS

torch.utils.data.Dataset is an abstract class representing a dataset. Your custom dataset should inherit Dataset and override the following methods:
- __len__ so that len(dataset) returns the size of the dataset.
- __getitem__ to support indexing such that dataset[i] can be used to get the i-th sample.
Let’s create a dataset class for our Auto Encoder dataset. We will read the 'Input' image directory and the 'Ground Truth' image directory in __init__ but leave the reading of images to __getitem__. This is memory efficient because the images are not all stored in memory at once but are read as required. A sample of our dataset will be a dictionary of the form {'img_in': img_in, 'img_gt': img_gt}.
class AutoEncoderDataSet(Dataset):
    """Auto-encoder dataset pairing 'Input' images with 'Ground Truth' images.

    Only the file paths are collected in ``__init__``; the images themselves
    are read lazily in ``__getitem__``, so the whole dataset never sits in
    memory at once. Each sample is a dict:
    ``{'img_in': <PIL.Image>, 'img_gt': <PIL.Image>}``.
    """

    # Accepted image extensions (matched case-insensitively).
    IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, dir_in, dir_gt):
        # Both listings are sorted (see load_dir_single) so the i-th input
        # is paired with the i-th ground-truth image; os.listdir order is
        # arbitrary, so relying on it could silently mismatch the pairs.
        self.dir_in = self.load_dir_single(dir_in)
        self.dir_gt = self.load_dir_single(dir_gt)

    def is_image_file(self, filename):
        """Return True if *filename* has a recognized image extension."""
        return filename.lower().endswith(self.IMG_EXTENSIONS)

    def load_img(self, filename):
        """Open and return the image at *filename* as a PIL image."""
        return Image.open(filename)

    def load_dir_single(self, directory):
        """Return the sorted list of image file paths found in *directory*."""
        return sorted(join(directory, x)
                      for x in listdir(directory) if self.is_image_file(x))

    def __len__(self):
        # Number of samples = number of input images.
        return len(self.dir_in)

    def __getitem__(self, index):
        """Read the index-th input / ground-truth pair from disk."""
        img_in = self.load_img(self.dir_in[index])
        img_gt = self.load_img(self.dir_gt[index])
        return {'img_in': img_in, 'img_gt': img_gt}

# Let's instantiate this class and iterate through the data samples. We will
# show the first 4 samples for the 'Input' image and the 'Ground Truth' image.
# Instantiate the dataset and show the first 4 input / ground-truth pairs.
plt.close('all')
f, axarr = plt.subplots(4, 2)
auto_encoder_dataset = AutoEncoderDataSet('img/tr/in/', 'img/tr/gt/')
for i in range(len(auto_encoder_dataset)):
    # __getitem__ returns a dict, so index it by key. (Tuple-unpacking the
    # dict itself — img_in, img_gt = dataset[i] — would only yield the key
    # strings 'img_in' and 'img_gt', not the images.)
    sample = auto_encoder_dataset[i]
    img_in, img_gt = sample['img_in'], sample['img_gt']
    axarr[i, 0].imshow(img_in)
    axarr[i, 0].set_title('Input image #{}'.format(i))
    axarr[i, 0].axis('off')
    axarr[i, 1].imshow(img_gt)
    axarr[i, 1].set_title('Ground truth image #{}'.format(i))
    axarr[i, 1].axis('off')
    if i == 3:
        # The figure only has 4 rows of axes; stop after filling them.
        f.subplots_adjust(hspace=0.5)
        plt.show()
        break
# Out:
The full example code:
from torch.utils.data import Dataset
from os import listdir
from os.path import join

import matplotlib.pyplot as plt
from PIL import Image


class AutoEncoderDataSet(Dataset):
    """Auto-encoder dataset pairing 'Input' images with 'Ground Truth' images.

    Only the file paths are collected in ``__init__``; the images themselves
    are read lazily in ``__getitem__``, so the whole dataset never sits in
    memory at once. Each sample is a dict:
    ``{'img_in': <PIL.Image>, 'img_gt': <PIL.Image>}``.
    """

    # Accepted image extensions (matched case-insensitively).
    IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, dir_in, dir_gt):
        # Both listings are sorted (see load_dir_single) so the i-th input
        # is paired with the i-th ground-truth image; os.listdir order is
        # arbitrary, so relying on it could silently mismatch the pairs.
        self.dir_in = self.load_dir_single(dir_in)
        self.dir_gt = self.load_dir_single(dir_gt)

    def is_image_file(self, filename):
        """Return True if *filename* has a recognized image extension."""
        return filename.lower().endswith(self.IMG_EXTENSIONS)

    def load_img(self, filename):
        """Open and return the image at *filename* as a PIL image."""
        return Image.open(filename)

    def load_dir_single(self, directory):
        """Return the sorted list of image file paths found in *directory*."""
        return sorted(join(directory, x)
                      for x in listdir(directory) if self.is_image_file(x))

    def __len__(self):
        # Number of samples = number of input images.
        return len(self.dir_in)

    def __getitem__(self, index):
        """Read the index-th input / ground-truth pair from disk."""
        img_in = self.load_img(self.dir_in[index])
        img_gt = self.load_img(self.dir_gt[index])
        return {'img_in': img_in, 'img_gt': img_gt}


def main(ps):
    """Plot the first 4 input / ground-truth pairs of the dataset.

    ``ps`` is a parameter dict with keys 'DIR_IMG_IN' and 'DIR_IMG_GT'
    pointing to the input and ground-truth image directories.
    """
    plt.close('all')
    f, axarr = plt.subplots(4, 2)
    auto_encoder_dataset = AutoEncoderDataSet(ps['DIR_IMG_IN'], ps['DIR_IMG_GT'])
    for i in range(len(auto_encoder_dataset)):
        # __getitem__ returns a dict, so index it by key. (Tuple-unpacking
        # the dict itself would only yield the key strings 'img_in' and
        # 'img_gt', not the images.)
        sample = auto_encoder_dataset[i]
        img_in, img_gt = sample['img_in'], sample['img_gt']
        axarr[i, 0].imshow(img_in)
        axarr[i, 0].set_title('Input image #{}'.format(i))
        axarr[i, 0].axis('off')
        axarr[i, 1].imshow(img_gt)
        axarr[i, 1].set_title('Ground truth image #{}'.format(i))
        axarr[i, 1].axis('off')
        if i == 3:
            # The figure only has 4 rows of axes; stop after filling them.
            f.subplots_adjust(hspace=0.5)
            plt.show()
            break


if __name__ == "__main__":
    ps = {
        'DIR_IMG_IN': 'img/tr/in/',
        'DIR_IMG_GT': 'img/tr/gt/'
    }
    main(ps)

# REFERENCES: