What is torch.nn really?
Created On: Dec 26, 2018 | Last Updated: Jan 24, 2025 | Last Verified: Nov 05, 2024
Authors: Jeremy Howard, fast.ai. Thanks to Rachel Thomas and Francisco Ingham.
We recommend running this tutorial as a notebook, not a script. To download the notebook (.ipynb) file,
click the link at the top of the page.
PyTorch provides the elegantly designed modules and classes torch.nn, torch.optim, Dataset, and DataLoader to help you create and train neural networks.
To use their full power and customize them for your problem, you need to
understand exactly what they are doing. To develop this understanding, we will
first train a basic neural net on the MNIST dataset without using any features
from these modules; initially we will use only the most basic PyTorch tensor
functionality. Then we will incrementally add one feature at a time from
torch.nn, torch.optim, Dataset, or DataLoader, showing exactly what each piece
does and how it makes the code either more concise or more flexible.
This tutorial assumes you already have PyTorch installed and are familiar with the basics of tensor operations. (If you’re familiar with NumPy array operations, you’ll find the PyTorch tensor operations used here nearly identical.)
MNIST data setup
We will use the classic MNIST dataset, which consists of black-and-white images of hand-drawn digits (between 0 and 9).
We will use pathlib for dealing with paths (part of the Python 3 standard library), and will download the dataset using requests. We will only import modules when we use them, so you can see exactly what’s being used at each point.
from pathlib import Path
import requests
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)
URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"
# Download the archive only if it is not already cached locally.
if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).write_bytes(content)
This dataset is in NumPy array format and has been stored using pickle, a Python-specific format for serializing data.
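The plotting code further down uses an `x_train` array, so the archive has to be unpickled first. The following is a minimal sketch of that step, not necessarily the exact code the tutorial uses. It assumes the pickle holds a 3-tuple of (inputs, labels) pairs for training, validation, and test data; the helper name `load_mnist_pickle` is hypothetical.

```python
import gzip
import pickle
from pathlib import Path


def load_mnist_pickle(path):
    """Return x_train, y_train, x_valid, y_valid from a gzipped pickle.

    Assumes the file stores ((x_train, y_train), (x_valid, y_valid), test),
    which is an assumption about the archive layout.
    """
    with gzip.open(Path(path), "rb") as f:
        # encoding="latin-1" lets Python 3 read arrays pickled under Python 2.
        (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding="latin-1")
    return x_train, y_train, x_valid, y_valid
```

With the variables defined earlier, this would be called as `x_train, y_train, x_valid, y_valid = load_mnist_pickle(PATH / FILENAME)`. Unpickling can execute arbitrary code, so only load pickle files from a source you trust.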
Each image is 28 x 28 and is stored as a flattened row of length 784 (= 28 x 28). Let’s take a look at one; we need to reshape it to 2D first.
from matplotlib import pyplot
import numpy as np
# x_train holds the training images, one flattened 784-value row per image.
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
# Call pyplot.show() only when not running on Colab
try:
    import google.colab
except ImportError:
    pyplot.show()
print(x_train.shape)
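To make the flattened layout concrete, here is a small sketch using a synthetic array in place of a real image (the values are just 0 to 783, not MNIST pixels). It relies on NumPy's default row-major ordering, under which pixel (row, col) of the 2D image is element row * 28 + col of the flat row.

```python
import numpy as np

# Synthetic stand-in for one flattened image: values 0..783, not real MNIST pixels.
flat = np.arange(28 * 28)

# Reshape to 2D; NumPy fills rows first (row-major order) by default.
image = flat.reshape((28, 28))

# Pixel (row, col) of the 2D image is element row * 28 + col of the flat row.
row, col = 3, 5
assert image[row, col] == flat[row * 28 + col]

# Flattening again recovers the original row.
restored = image.reshape(-1)
```

The same reshape applied to `x_train[0]` is what turns a stored row back into a displayable 28 x 28 image.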