"""
`Learn the Basics <intro.html>`_ ||
`Quickstart <quickstart_tutorial.html>`_ ||
`Tensors <tensorqs_tutorial.html>`_ ||
**Datasets & DataLoaders** ||
`Transforms <transforms_tutorial.html>`_ ||
`Build Model <buildmodel_tutorial.html>`_ ||
`Autograd <autogradqs_tutorial.html>`_ ||
`Optimization <optimization_tutorial.html>`_ ||
`Save & Load Model <saveloadrun_tutorial.html>`_

Datasets & DataLoaders
======================

"""

#################################################################
# Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code
# to be decoupled from our model training code for better readability and modularity.
# PyTorch provides two data primitives: ``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``
# that allow you to use pre-loaded datasets as well as your own data.
# ``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around
# the ``Dataset`` to enable easy access to the samples.
#
# PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that
# subclass ``torch.utils.data.Dataset`` and implement functions specific to the particular data.
# They can be used to prototype and benchmark your model. You can find them
# here: `Image Datasets <https://pytorch.org/vision/stable/datasets.html>`_,
# `Text Datasets <https://pytorch.org/text/stable/datasets.html>`_, and
# `Audio Datasets <https://pytorch.org/audio/stable/datasets.html>`_
#

############################################################
# Loading a Dataset
# -------------------
#
# Here is an example of how to load the `Fashion-MNIST <https://research.zalando.com/project/fashion_mnist/fashion_mnist/>`_ dataset from TorchVision.
# Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples.
# Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.
#
# We load the `FashionMNIST Dataset <https://pytorch.org/vision/stable/datasets.html#fashion-mnist>`_ with the following parameters:
#  - ``root`` is the path where the train/test data is stored,
#  - ``train`` specifies training or test dataset,
#  - ``download=True`` downloads the data from the internet if it's not available at ``root``.
#  - ``transform`` and ``target_transform`` specify the feature and label transformations


import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


# The two splits share the same root directory and the same image-to-tensor
# conversion; only the ``train`` flag differs.
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())


#################################################################
# Iterating and Visualizing the Dataset
# -------------------------------------
#
# We can index ``Datasets`` manually like a list: ``training_data[index]``.
# We use ``matplotlib`` to visualize some samples in our training data.

# Class index -> human-readable name; the position in the list is the class index.
labels_map = dict(enumerate([
    "T-Shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]))

rows, cols = 3, 3
figure = plt.figure(figsize=(8, 8))
for cell in range(1, rows * cols + 1):  # subplot positions are 1-based
    # Pick one random sample index per grid cell.
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, cell)
    plt.title(labels_map[label])
    plt.axis("off")
    # squeeze() drops the size-1 dimensions before the image is handed to imshow.
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

#################################################################
# ..
#  .. figure:: /_static/img/basics/fashion_mnist.png
#    :alt: fashion_mnist

######################################################################
# --------------
#

#################################################################
# Creating a Custom Dataset for your files
# ---------------------------------------------------
#
# A custom Dataset class must implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.
# Take a look at this implementation; the FashionMNIST images are stored
# in a directory ``img_dir``, and their labels are stored separately in a CSV file ``annotations_file``.
#
# In the next sections, we'll break down what's happening in each of these functions.
import os
import pandas as pd
from torchvision.io import decode_image


class CustomImageDataset(Dataset):
    """Map-style dataset pairing image files on disk with labels from a CSV file.

    Column 0 of the annotations file holds the image file name (joined onto
    ``img_dir``) and column 1 holds the label.

    NOTE: ``pd.read_csv`` is called with its default ``header`` handling, so the
    first row of ``annotations_file`` is read as column names, not as a sample.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Read the annotations table and remember where the images live.

        Args:
            annotations_file: CSV file of (file name, label) rows, passed to
                ``pd.read_csv``.
            img_dir: Directory that contains the image files.
            transform: Optional callable applied to each decoded image.
            target_transform: Optional callable applied to each label.
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load and return the ``(image, label)`` pair at position ``idx``.

        The image is decoded from disk with ``decode_image`` on every call;
        ``transform`` / ``target_transform`` are applied when they were given.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = decode_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        # Compare against None explicitly: a transform is an arbitrary callable,
        # and its truthiness says nothing about whether one was supplied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


#################################################################
# ``__init__``
# ^^^^^^^^^^^^^^^^^^^^
#
# The ``__init__`` function is run once when instantiating the Dataset object. We initialize
# the directory containing the images, the annotations file, and both transforms (covered
# in more detail in the next section).
#
# The labels.csv file looks like: ::
#
#     tshirt1.jpg, 0
#     tshirt2.jpg, 0
#     ......
#     ankleboot999.jpg, 9


def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform


#################################################################
# ``__len__``
# ^^^^^^^^^^^^^^^^^^^^
#
# The ``__len__`` function returns the number of samples in our dataset.
#
# Example:


def __len__(self):
    return len(self.img_labels)


#################################################################
# ``__getitem__``
# ^^^^^^^^^^^^^^^^^^^^
#
# The ``__getitem__`` function loads and returns a sample from the dataset at the given index ``idx``.
# Based on the index, it identifies the image's location on disk, converts that to a tensor using ``decode_image``, retrieves the
# corresponding label from the csv data in ``self.img_labels``, calls the transform functions on them (if applicable), and returns the
# tensor image and corresponding label in a tuple.


def __getitem__(self, idx):
    """Load and return the ``(image, label)`` pair at position ``idx``."""
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = decode_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    # Compare against None explicitly: a transform is an arbitrary callable,
    # and its truthiness says nothing about whether one was supplied.
    if self.transform is not None:
        image = self.transform(image)
    if self.target_transform is not None:
        label = self.target_transform(label)
    return image, label


######################################################################
# --------------
#

#################################################################
# Preparing your data for training with DataLoaders
# -------------------------------------------------
# The ``Dataset`` retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to
# pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's ``multiprocessing`` to
# speed up data retrieval.
#
# ``DataLoader`` is an iterable that abstracts this complexity for us in an easy API.

from torch.utils.data import DataLoader

# Both loaders yield batches of 64 samples and reshuffle the data on every pass.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

###########################
# Iterate through the DataLoader
# -------------------------------
#
# We have loaded that dataset into the ``DataLoader`` and can iterate through the dataset as needed.
# Each iteration below returns a batch of ``train_features`` and ``train_labels`` (containing ``batch_size=64`` features and labels respectively).
# Because we specified ``shuffle=True``, after we iterate over all batches the data is shuffled (for finer-grained control over
# the data loading order, take a look at `Samplers <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`_).

# Display image and label.
# Pull a single batch out of the training loader.
first_batch = next(iter(train_dataloader))
train_features, train_labels = first_batch
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

# Show the first sample of that batch, then report its label.
img, label = train_features[0].squeeze(), train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

######################################################################
# --------------
#

#################################################################
# Further Reading
# ----------------
# - `torch.utils.data API <https://pytorch.org/docs/stable/data.html>`_