Datasets & DataLoaders

Created On: Feb 09, 2021 | Last Updated: Sep 24, 2025 | Last Verified: Nov 05, 2024

Code for processing data samples can get messy and hard to maintain; ideally, we want our dataset code decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives, torch.utils.data.Dataset and torch.utils.data.DataLoader, that let you use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
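To make the division of labour concrete, here is a dependency-free sketch of the two roles. It is not PyTorch's implementation: `ToyDataset` and `simple_loader` are hypothetical names, and the sketch assumes a map-style dataset, i.e. one that implements `__len__` and `__getitem__`, which is the interface `torch.utils.data.Dataset` subclasses usually provide. The real `DataLoader` additionally handles shuffling, collation into tensors, and multi-process loading.

```python
from typing import Iterator, List, Sequence, Tuple


class ToyDataset:
    """Hypothetical map-style dataset: stores samples and labels, returns one pair per index."""

    def __init__(self, samples: Sequence[float], labels: Sequence[int]) -> None:
        assert len(samples) == len(labels), "every sample needs a label"
        self.samples = list(samples)
        self.labels = list(labels)

    def __len__(self) -> int:
        # Number of (sample, label) pairs in the dataset.
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[float, int]:
        # Return a single (sample, label) pair, like training_data[idx].
        return self.samples[idx], self.labels[idx]


def simple_loader(dataset: ToyDataset, batch_size: int) -> Iterator[Tuple[List[float], List[int]]]:
    """Hypothetical stand-in for DataLoader: walks indices in order and groups them into batches."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        batch = [dataset[i] for i in range(start, stop)]
        # Regroup a list of pairs into a pair of lists (a crude "collate" step).
        xs = [x for x, _ in batch]
        ys = [y for _, y in batch]
        yield xs, ys


ds = ToyDataset([0.1, 0.2, 0.3, 0.4, 0.5], [0, 1, 0, 1, 0])
for xs, ys in simple_loader(ds, batch_size=2):
    print(xs, ys)
```

The training loop only ever sees batches, so it does not need to know how the samples are stored. The last batch holds a single sample because five samples do not divide evenly into batches of two; `DataLoader` behaves the same way by default, since its `drop_last` parameter defaults to `False`.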

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass torch.utils.data.Dataset and implement functions specific to the particular data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets.

Loading a Dataset

Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST Dataset with the following parameters:
  • root is the path where the train/test data is stored.
  • train specifies whether to load the training split (True) or the test split (False).
  • download=True downloads the data from the internet if it’s not already available at root.
  • transform and target_transform specify the feature and label transformations.
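In torchvision's datasets, transform is applied to the image and target_transform to the label each time a sample is retrieved through `__getitem__`. The plain-Python sketch below mimics that flow under stated assumptions: images are nested lists of 8-bit pixel values rather than PIL images, `to_tensor_like` only imitates what `ToTensor` does to an 8-bit grayscale image (add a leading channel dimension and scale values from [0, 255] to [0.0, 1.0]), and `one_hot` and `TransformedDataset` are hypothetical names.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Image = List[List[int]]             # H x W grid of 0-255 pixel values (assumption)
Tensor3D = List[List[List[float]]]  # C x H x W, a stand-in for a torch tensor


def to_tensor_like(img: Image) -> Tensor3D:
    """Imitate ToTensor on an 8-bit grayscale image: add a channel dim, scale to [0.0, 1.0]."""
    return [[[px / 255.0 for px in row] for row in img]]


def one_hot(label: int, num_classes: int = 10) -> List[float]:
    """Hypothetical target transform: integer class index -> one-hot vector."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


class TransformedDataset:
    """Hypothetical dataset that applies transform to features and target_transform to labels."""

    def __init__(
        self,
        images: Sequence[Image],
        labels: Sequence[int],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple:
        img, label = self.images[idx], self.labels[idx]
        # Transforms run lazily, once per retrieved sample, not once for the whole dataset.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label


ds = TransformedDataset(
    images=[[[0, 255], [51, 102]]],  # one tiny 2x2 "image"
    labels=[3],
    transform=to_tensor_like,
    target_transform=one_hot,
)
img, target = ds[0]
print(img)     # [[[0.0, 1.0], [0.2, 0.4]]]
print(target)  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Leaving either argument as None returns that part of the sample untouched, which is also what happens when target_transform is omitted in the FashionMNIST calls.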

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Iterating and Visualizing the Dataset

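We can index Datasets manually like a list: training_data[index] returns an (image, label) tuple. With ToTensor applied, the image is a tensor of shape [1, 28, 28] and the label is an integer class index, so the code below calls img.squeeze() to drop the channel dimension before plotting and uses labels_map to turn the label into a class name. We use matplotlib to visualize some samples in our training data.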

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
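The loop above calls torch.randint once per subplot, so the same sample can occasionally be drawn twice, and figure.add_subplot(rows, cols, i) places panel i counting from 1 in the upper-left corner, left to right, row by row. The plain-Python sketch below illustrates both points under its own assumptions: `pick_distinct` and `grid_position` are hypothetical helpers, and `random.sample` is swapped in for torch.randint to draw indices without replacement. In PyTorch itself, torch.randperm(len(training_data))[:9] is one way to get distinct indices.

```python
import random
from typing import List, Tuple


def pick_distinct(n_items: int, k: int, seed: int = 0) -> List[int]:
    """Choose k different indices from range(n_items); unlike repeated randint calls, no repeats."""
    rng = random.Random(seed)  # seeded so the same samples are picked on every run
    return rng.sample(range(n_items), k)


def grid_position(i: int, cols: int) -> Tuple[int, int]:
    """Map a 1-based subplot index to a 0-based (row, col), filled row by row."""
    return divmod(i - 1, cols)


cols, rows = 3, 3
indices = pick_distinct(60_000, cols * rows)  # 60,000 = size of the FashionMNIST training split
for i, idx in enumerate(indices, start=1):
    print(grid_position(i, cols), idx)
```

Seeding the generator is a design choice for reproducibility: rerunning the notebook shows the same nine samples, which makes it easier to compare plots across edits.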