DataLoader is the heart of PyTorch's data-loading utility. It represents a Python iterable over a dataset. The most important argument of DataLoader is dataset, which indicates the dataset object to load data from.

DataLoader supports automatically collating individually fetched data samples into batches via the batch_size argument. This is the most common case and corresponds to fetching a minibatch of data and collating it into batched samples.
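For example, with automatic batching enabled, the loader groups consecutive samples and converts them into tensors. The toy dataset below is a hypothetical stand-in, just to show the default behaviour:

from torch.utils.data import DataLoader

# hypothetical (feature, label) pairs
toy_dataset = [(i, i % 2) for i in range(6)]
loader = DataLoader(toy_dataset, batch_size=3)
for x_batch, y_batch in loader:
  print(x_batch, y_batch)  # first batch: tensor([0, 1, 2]) tensor([0, 1, 0])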

Internally, PyTorch uses a collate function to combine the samples in your batch. By default, a function called default_collate checks what type of data your Dataset returns and tries to combine it into a batch like (x_batch, y_batch).
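You can call default_collate directly to see this behaviour. A minimal sketch, assuming a PyTorch version recent enough to expose default_collate from torch.utils.data, and samples that are (tensor, label) pairs of identical shapes:

import torch
from torch.utils.data import default_collate

samples = [(torch.tensor([1.0, 2.0]), 0), (torch.tensor([3.0, 4.0]), 1)]
x_batch, y_batch = default_collate(samples)
print(x_batch.shape, y_batch)  # torch.Size([2, 2]) tensor([0, 1])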

What if we have custom types, or multiple different types of data, that default_collate can't handle? We could edit our Dataset so that the samples are mergeable, and that solves some of these issues, but what if how we merge them depends on batch-level information, like the largest value in the batch?
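For instance, default_collate cannot stack tensors of different lengths; with the hypothetical ragged samples below it raises a RuntimeError roughly saying that stack expects each tensor to be equal size:

import torch
from torch.utils.data import default_collate

ragged = [(torch.tensor([1, 2, 3]), 0), (torch.tensor([4, 5]), 1)]
try:
  default_collate(ragged)
except RuntimeError as e:
  print(e)  # stack expects each tensor to be equal size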

You can use a customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.
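As a quick illustration of the first case, here is a hypothetical collate_fn that stacks inputs along the second dimension instead of the first (a sequence-first layout, as some RNN code expects):

import torch

def collate_seq_first(batch):
  # hypothetical example: batch is a list of (tensor, label) pairs;
  # stack the inputs along dim=1 instead of the usual dim=0
  xs, ys = zip(*batch)
  return torch.stack(xs, dim=1), torch.tensor(ys)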

A custom collate_fn() is often used to pad variable-length sequences within a batch. So let's create a dataset in which each item is a sequence, and they're all different lengths.

import random

import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

reviews = ['No man is an island', 'Entire of itself',
           'Every man is a piece of the continent', 'part of the main',
           'If a clod be washed away by the sea', 'Europe is the less',
           'As well as if a promontory were', 'As well as if a manor of thy friend',
           'Or of thine own were', 'Any man’s death diminishes me',
           'Because I am involved in mankind',
           'And therefore never send to know for whom the bell tolls',
           'It tolls for thee']
# Random binary labels, one per review
labels = [random.randint(0, 1) for _ in range(len(reviews))]
dataset = list(zip(reviews, labels))

tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
  for text, label in data_iter:
    yield tokenizer(text)

# Build the vocabulary from the tokenized reviews; unknown tokens map to <unk>
vocab = build_vocab_from_iterator(yield_tokens(iter(dataset)), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_pipeline = lambda x: vocab(tokenizer(x))

A custom collate_fn can be used to customize collation, e.g., padding sequential data to the max length in a batch. collate_fn is called with a list of data samples each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator.

from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
  # batch is a list of (text, label) pairs produced by the dataset
  label_list, text_list = [], []
  for _text, _label in batch:
    label_list.append(_label)
    # Convert the raw text into a 1-D tensor of token indices
    processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
    text_list.append(processed_text)
  label_list = torch.tensor(label_list, dtype=torch.int64)
  # Pad every sequence to the length of the longest one in this batch
  text_list = pad_sequence(text_list, batch_first=True, padding_value=0)
  return text_list.to(device), label_list.to(device)

With variable-sized sequences, the custom collate function pads them to match the longest sequence in the batch using torch.nn.utils.rnn.pad_sequence.
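On its own, pad_sequence takes a list of tensors of unequal length and right-pads them to the longest one; a minimal sketch:

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
print(padded)  # tensor([[1, 2, 3], [4, 5, 0]])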

collate_fn operates on the batch of samples generated by DataLoader before it is sent to the model. Its input is a list of samples whose length is the batch_size set in DataLoader, and it processes them according to the data processing pipelines declared previously. Make sure that collate_fn is declared as a top-level def; this ensures that the function is available to each worker process.
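This matters mainly when num_workers > 0 on platforms that start workers with the spawn method (the default on Windows and macOS), because the collate function must be picklable. A rough sketch under that assumption, using a hypothetical simple_collate and the dataset defined above:

def simple_collate(batch):
  # top-level def: picklable, so it is available to every worker process
  return batch

loader = DataLoader(dataset, batch_size=2, collate_fn=simple_collate, num_workers=2)
# By contrast, collate_fn=lambda b: b may fail to pickle when workers
# are started with spawn, because lambdas are not picklable.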

Now, we can pass collate_batch as the collate_fn argument of DataLoader and it will pad each batch.

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_batch, shuffle=True)
for x, y in dataloader:
  print(x, "Targets", y, "\n")
Custom collate function

The default collate function preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists, if the values cannot be converted into Tensors). The same holds for lists, tuples, namedtuples, etc. It also automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
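A minimal sketch of that structure-preserving behaviour with dictionary samples (again assuming default_collate is importable from torch.utils.data):

import numpy as np
import torch
from torch.utils.data import default_collate

samples = [
  {"x": np.array([1.0, 2.0]), "y": 0},
  {"x": np.array([3.0, 4.0]), "y": 1},
]
batch = default_collate(samples)
print(batch["x"])  # tensor([[1., 2.], [3., 4.]], dtype=torch.float64)
print(batch["y"])  # tensor([0, 1])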
