PyTorch tensors are allocated in contiguous chunks of memory managed by torch.Storage instances. A Storage is a one-dimensional array of numerical data: a contiguous block of memory holding numbers of a given type, such as float32 or int64.

A PyTorch Tensor instance is a view of such a Storage instance, capable of indexing into that storage using an offset and per-dimension strides.
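To make this concrete, here is a minimal sketch of how a tensor exposes its underlying storage, offset, and strides (the values are just an example):

import torch

points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
print(points.storage())         # flat 1D storage: 4.0, 1.0, 5.0, 3.0, 2.0, 1.0
print(points.storage_offset())  # 0: this view starts at the first element of the storage
print(points.stride())          # (2, 1): step 2 elements per row, 1 per column

# Transposing creates a new view over the same storage, with swapped strides
points_t = points.t()
print(points_t.stride())        # (1, 2)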

Creating tensors on the fly is all well and good, but if the data inside them is valuable, we will want to save them to a file and load them back at some point. After all, we don't want to have to retrain a model from scratch every time we start running our program.

Save Tensors to File

PyTorch uses pickle under the hood to serialize the tensor object, plus dedicated serialization code for the storage. Here's how we can save our inputs tensor to an inputs.t file:

import torch

inputs = torch.ones(3, 4)
torch.save(inputs, '../content/sample_data/inputs.t')

As an alternative, we can pass a file descriptor in lieu of the filename:

with open('../content/sample_data/alternative.t','wb') as f:
  torch.save(inputs, f)
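Since torch.save relies on pickle, we can also bundle several tensors into an ordinary Python container and save them in one shot. A quick sketch (the bundle.t filename and the dictionary keys are just placeholders):

bundle = {'inputs': inputs, 'targets': torch.zeros(3)}
torch.save(bundle, '../content/sample_data/bundle.t')

bundle = torch.load('../content/sample_data/bundle.t')
print(bundle['inputs'])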

Load Tensors from File

Loading our inputs back in is similarly a one-liner.

inputs = torch.load('../content/sample_data/inputs.t')
print(inputs)

#or, equivalently,

with open('../content/sample_data/alternative.t','rb') as f:
  inputs = torch.load(f)

#tensor([[1., 1., 1., 1.],
#        [1., 1., 1., 1.],
#        [1., 1., 1., 1.]])
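One caveat: if a tensor was saved on a GPU machine, loading it on a CPU-only machine will fail unless we remap the device. torch.load accepts a map_location argument for this; a sketch, reusing the file from above:

inputs = torch.load('../content/sample_data/inputs.t', map_location='cpu')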

We can save tensors quickly this way, but the file format itself is not interoperable: we can't read the tensors with software other than PyTorch. Depending on the use case, this may or may not be a limitation, but we should learn how to save tensors interoperably for those cases when it is.

Serializing to HDF5 with h5py

Every use case is unique, but the need to save tensors interoperably will be more common when introducing PyTorch into existing systems that already rely on different libraries; new projects probably won't need to do this as often.

For those cases, we can use the HDF5 format. HDF5 is a portable, widely supported format for representing serialized multidimensional arrays, organized in a nested key-value dictionary. Python supports HDF5 through the h5py library, which accepts and returns data in the form of NumPy arrays.

We can save our inputs tensor by converting it to a NumPy array (at essentially no cost, since the CPU tensor and the array share the same underlying memory) and passing it to the create_dataset method:

import h5py

fs = h5py.File('../content/sample_data/h5inputs.hdf5', 'w')
dset = fs.create_dataset('tKey', data=inputs.numpy())
fs.close()

Here ‘tKey’ is a key into the HDF5 file. We can have other keys, or even nested ones (see the sketch at the end of this section). One of the interesting things about HDF5 is that we can index the dataset while it is on disk and access only the elements we’re interested in. Let’s suppose we want to load just the last two points in our dataset:

fs = h5py.File('../content/sample_data/h5inputs.hdf5', 'r')
dset = fs['tKey']
last_points = dset[-2:]

The data is not loaded when the file is opened or the dataset is accessed. Rather, the data stays on disk until we request the last two rows of the dataset. At that point, h5py reads those two rows and returns a NumPy array-like object encapsulating that region of the dataset, with the same API as a NumPy array.

We can pass the returned object to the torch.from_numpy function to obtain a tensor directly. Note that in this case, the data is copied over to the tensor’s storage:

last_points = torch.from_numpy(dset[-2:])
print(last_points)
fs.close()
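Because the data was copied out of the file, the resulting tensor is decoupled from what is on disk. Here's a quick sketch confirming that, reusing the file saved above:

last_points[0, 0] = 99.0  # modify the in-memory tensor only
fs = h5py.File('../content/sample_data/h5inputs.hdf5', 'r')
print(fs['tKey'][-2:])    # still all ones; the file on disk is untouched
fs.close()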

Once we’re finished loading data, we close the file. Closing the HDF5 file invalidates the datasets, and trying to access a dataset afterward will raise an exception.
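Finally, as mentioned above, keys can be nested: HDF5 organizes datasets into groups that behave like nested dictionaries. A minimal sketch, where the nested.hdf5 filename and the 'train'/'val' group names are just placeholders:

import h5py
import numpy as np

fs = h5py.File('../content/sample_data/nested.hdf5', 'w')
grp = fs.create_group('train')                          # a group acts like a nested dictionary
grp.create_dataset('inputs', data=np.ones((3, 4)))
fs.create_dataset('val/inputs', data=np.zeros((2, 4)))  # path-like keys create groups implicitly
fs.close()

fs = h5py.File('../content/sample_data/nested.hdf5', 'r')
print(fs['train/inputs'][:])  # index nested keys with a path-like string
fs.close()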
