This post covers GPU-accelerated PyTorch training with the MPS backend on Mac. MPS enables high-performance GPU training on macOS devices through the Metal programming framework. I’ll also give an overview of the software stack, so let’s briefly look at the MPS backend and the software components it relies on.

In 2022, Apple introduced Metal acceleration in PyTorch. Metal is the GPU programming API on Apple platforms. Metal Performance Shaders (MPS) is a collection of high-performance GPU primitives for fields such as image processing, linear algebra, ray tracing, and machine learning. These Metal kernels are optimized to deliver the best performance on all of Apple’s platforms.

The MPS Software Stack

Apple’s Metal Performance Shaders (MPS) Graph is a general-purpose compute graph framework for GPUs. It extends support to multi-dimensional tensors, and its compiler technology performs various optimizations such as operator fusion, constant folding, dead code elimination, and many more.

The PyTorch MPS backend implements both the operation kernels and the runtime framework that calls into MPS Graph. This gives PyTorch access to highly efficient kernels from MPS, along with Metal’s command queues, command buffers, and synchronization primitives.

Operations in PyTorch create unique graphs that are cached, reducing CPU overhead; the cached operations are then encoded onto the command stream.
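As a rough sketch of what this means in practice: the first time a given operation runs on the MPS device, the backend builds and caches a graph for it, and identical subsequent calls reuse that graph. The CPU fallback below is my addition so the snippet also runs on machines without Metal.

```python
import torch

# Fall back to CPU so this sketch also runs on machines without Metal.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

a = torch.randn(256, 256, device=device)
b = torch.randn(256, 256, device=device)

# First call: the backend builds a graph for this operation and caches it.
c = a @ b
# Identical subsequent calls reuse the cached graph, so only the
# encoding onto the command stream remains on the hot path.
d = a @ b

print(c.shape, c.device.type)
```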


Requirements

  • Mac computers with Apple silicon or AMD GPUs
  • macOS 12.3 or later
  • Python 3.7 or later
  • PyTorch 2.0.0 or later

To get started, just install the latest PyTorch build on your Apple silicon Mac running macOS 12.3 or later with a native version (arm64) of Python.
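For example, using pip under a native arm64 Python (the package names are the standard PyTorch distribution names; adjust to taste):

```shell
# Install the latest stable PyTorch build; on Apple silicon under a
# native arm64 Python, this build includes MPS support.
pip3 install --upgrade torch torchvision torchaudio
```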

Verify PyTorch Installation

You can verify your PyTorch MPS support installation using a simple Python script:

import torch

if torch.backends.mps.is_built():
    print("PyTorch is built with MPS support")
else:
    print("MPS not available because the current PyTorch install was not "
          "built with MPS enabled.")

This function returns whether PyTorch is built with MPS support. Note that this doesn’t necessarily mean MPS is available; it just indicates that if this PyTorch binary were run on a machine with working MPS drivers and devices, we would be able to use it.

To build PyTorch from source, follow the instructions in the Install PyTorch 2.0 GPU/MPS for Mac M1/M2 with Conda guide.

Check MPS Support

You can verify MPS support using a simple Python script:

import torch

if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
else:
    print("MPS device not found.")
There is only ever one MPS device, though, so there is no equivalent of device_count in the Python API.

Change Default Device to MPS

It is common practice to write PyTorch code in a device-agnostic way, and then switch between CPU and MPS/CUDA depending on what hardware is available.

x = torch.ones(1, device=mps_device)
print(x.device)  # mps:0
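A common device-agnostic selection pattern looks like the sketch below; the CUDA branch is included only to illustrate the usual idiom.

```python
import torch

# Pick the best available backend, falling back to CPU otherwise.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

x = torch.ones(1, device=device)
print(x.device)
```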

PyTorch now also has a context manager which can take care of the device transfer automatically.

with torch.device(mps_device):
    x = torch.ones(1)
    print(x.device) #mps:0

You can also set it globally like this:


torch.set_default_device("mps")

mod = torch.nn.Linear(20, 30)
print(mod.weight.device)  # mps:0
print(mod(torch.randn(128, 20)).device)  # mps:0

Note that setting a global default device imposes a slight performance cost on every Python call to the torch API.
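Putting it together, a minimal training step on the MPS device might look like the sketch below. The model, data shapes, and hyperparameters are placeholders of my own, and the CPU fallback keeps the snippet runnable on machines without Metal.

```python
import torch
import torch.nn as nn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Toy model and synthetic data -- placeholders, not from a real workload.
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

inputs = torch.randn(128, 20, device=device)
targets = torch.randn(128, 1, device=device)

# A few optimization steps; every op here dispatches through the
# selected backend (MPS when available).
for step in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

final_loss = loss.item()
print(f"final loss: {final_loss:.4f}")
```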