Custom Python Operators
Created On: Jun 18, 2024 | Last Updated: Mar 19, 2025 | Last Verified: Nov 05, 2024
What you will learn:
- How to integrate custom operators written in Python with PyTorch
- How to test custom operators using torch.library.opcheck

Prerequisites:
- PyTorch 2.4 or later
PyTorch offers a large library of operators that work on Tensors (for example,
torch.add and torch.sum). However, you might wish to use a new customized
operator with PyTorch, perhaps one written by a third-party library. This tutorial
shows how to wrap Python functions so that they behave like PyTorch native
operators. Reasons why you may wish to create a custom operator in PyTorch include:
- Treating an arbitrary Python function as an opaque callable with respect to
  torch.compile (that is, preventing torch.compile from tracing into the function).
- Adding training support to an arbitrary Python function.
- Use torch.library.custom_op() to create Python custom operators.
- Use the C++ TORCH_LIBRARY APIs to create C++ custom operators (these
  work in Python-less environments).

See the Custom Operators Landing Page for more details.
Note that if your operation can be expressed as a composition of
existing PyTorch operators, there is usually no need to use the custom operator
API: everything (for example, torch.compile and training support) should
just work.
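For contrast, here is a sketch of a function built only from existing PyTorch operators. The name scaled_tanh is hypothetical. Autograd differentiates it with no extra registration. The block checks only the training half of the claim: the gradient against the analytic formula d/dx [s·tanh(x)] = s·(1 − tanh²(x)), and against torch.autograd.gradcheck.

```python
import torch

# Hypothetical example: a composition of built-in operators only.
def scaled_tanh(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.tanh(x) * scale

# Double precision keeps gradcheck's finite differences accurate.
x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = scaled_tanh(x, 2.0)
y.sum().backward()

# Analytic gradient: scale * (1 - tanh(x)^2)
expected_grad = 2.0 * (1 - torch.tanh(x.detach()) ** 2)
grad_ok = torch.allclose(x.grad, expected_grad)

# Numerical vs. analytic gradient comparison provided by PyTorch.
gradcheck_ok = torch.autograd.gradcheck(lambda t: scaled_tanh(t, 2.0), (x,))
```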
Example: Wrapping PIL’s crop into a custom operator
Let’s say that we are using PIL’s crop operation.
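```python
import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt

def crop(pic, box):
    # Convert the (C, H, W) tensor to a PIL image; PIL only works on CPU data.
    img = to_pil_image(pic.cpu())
    # box is (left, upper, right, lower) in pixel coordinates.
    cropped_img = img.crop(box)
    # Convert back to a float tensor in [0, 1] on the original device.
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    # matplotlib expects (H, W, C), so move the channel dimension last.
    plt.imshow(img.numpy().transpose((1, 2, 0)))

# A 3x64x64 test image whose intensity increases toward the bottom-right corner.
img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)
```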
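The box argument follows PIL's Image.crop convention, (left, upper, right, lower), where right and lower are exclusive bounds. As a sketch, the hypothetical helper crop_by_slicing below expresses the same crop as plain tensor slicing on a (C, H, W) tensor. It compares the result against Pillow called directly through NumPy. That route bypasses torchvision, an assumption made to keep the block small. The two results should agree up to 8-bit quantization.

```python
import numpy as np
import torch
from PIL import Image

def crop_by_slicing(pic: torch.Tensor, box: tuple[int, int, int, int]) -> torch.Tensor:
    """Hypothetical pure-PyTorch equivalent of PIL's crop for a (C, H, W) tensor."""
    left, upper, right, lower = box
    # Rows index height (upper..lower), columns index width (left..right).
    return pic[:, upper:lower, left:right]

# Same gradient image as in the example above.
img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)

box = (10, 20, 50, 40)  # 40 pixels wide, 20 pixels tall

# Round-trip through Pillow: (C, H, W) float -> (H, W, C) uint8 -> crop -> back.
arr = (img.permute(1, 2, 0) * 255).round().to(torch.uint8).contiguous().numpy()
pil_cropped = Image.fromarray(arr).crop(box)
pil_result = torch.from_numpy(np.array(pil_cropped)).permute(2, 0, 1).float() / 255.

sliced = crop_by_slicing(img, box)
```

Both results have shape (3, 20, 40): the height comes from `lower - upper` and the width from `right - left`.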