
Custom Python Operators

Created On: Jun 18, 2024 | Last Updated: Mar 19, 2025 | Last Verified: Nov 05, 2024

What you will learn
  • How to integrate custom operators written in Python with PyTorch

  • How to test custom operators using torch.library.opcheck

Prerequisites
  • PyTorch 2.4 or later

PyTorch offers a large library of operators that work on Tensors (for example, torch.add and torch.sum). However, you may want to use a new, custom operator with PyTorch, perhaps one written by a third-party library. This tutorial shows how to wrap Python functions so that they behave like native PyTorch operators. Reasons you may want to create a custom operator in PyTorch include:

  • Treating an arbitrary Python function as an opaque callable with respect to torch.compile (that is, preventing torch.compile from tracing into the function).

  • Adding training support to an arbitrary Python function.

Use torch.library.custom_op() to create Python custom operators. Use the C++ TORCH_LIBRARY APIs to create C++ custom operators (these work in Python-less environments). See the Custom Operators Landing Page for more details.

Note that if your operation can be expressed as a composition of existing PyTorch operators, there is usually no need to use the custom operator API: everything (for example, torch.compile and training support) should just work.
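For contrast, a function built only from existing PyTorch operators gets autograd and torch.compile support with no registration at all. The function smooth_clip below is a hypothetical example, and backend="eager" is used only to keep the sketch free of compiler requirements.

```python
import torch


# Hypothetical function composed only of existing PyTorch ops, so autograd
# and torch.compile handle it without any custom-operator registration.
def smooth_clip(x: torch.Tensor, limit: float) -> torch.Tensor:
    return limit * torch.tanh(x / limit)


x = torch.randn(8, requires_grad=True)
smooth_clip(x, 2.0).sum().backward()  # gradient is 1 - tanh(x / limit) ** 2

compiled = torch.compile(smooth_clip, fullgraph=True, backend="eager")
y_compiled = compiled(x.detach(), 2.0)
```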

Example: Wrapping PIL’s crop into a custom operator

Let’s say that we are using PIL’s crop operation.

import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt

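def crop(pic, box):
    # Round-trip through PIL: convert to a PIL image on the CPU, crop with
    # box = (left, upper, right, lower), then convert back and rescale to [0, 1].
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    # Matplotlib expects H x W x C, so move the channel dimension last.
    plt.imshow(img.numpy().transpose((1, 2, 0)))

# Build a 3 x 64 x 64 image whose values ramp from 0 (top-left) to 1 (bottom-right).
img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)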
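PIL’s Image.crop takes a box of (left, upper, right, lower) pixel coordinates, so for a box that lies inside the image the same region can be selected with tensor slicing. That makes crop the kind of operation that, per the note above, would not need a custom operator; PIL serves here only as a stand-in for third-party code. The helper crop_by_slicing is hypothetical, and it differs from crop in a few ways: the PIL round trip converts to 8-bit values, so results agree only up to that quantization; slicing returns a view of pic rather than a new tensor (call .clone() if a copy is needed); and slicing silently clamps a box that extends past the image border.

```python
from typing import Tuple

import torch


# Hypothetical helper: select box = (left, upper, right, lower) from a
# C x H x W tensor by slicing. Assumes the box lies inside the image;
# the result is a view that shares storage with `pic`.
def crop_by_slicing(pic: torch.Tensor, box: Tuple[int, int, int, int]) -> torch.Tensor:
    left, upper, right, lower = box
    return pic[:, upper:lower, left:right]


# Same ramp image as above.
img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
patch = crop_by_slicing(img, (10, 10, 50, 50))  # shape (3, 40, 40)
```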