Converting dataloader items to PyTorch model inputs

Short summary: when ENOT needs to pass items from your dataloader to your model, it requires a special mapping function. This function is used as follows:

# Your code:
your_model = ...
your_dataloader = ...
your_mapping_function = ...

# In the ENOT framework:
sample = next(your_dataloader)
model_args, model_kwargs = your_mapping_function(sample)
model_result = your_model(*model_args, **model_kwargs)
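To make this contract concrete, here is a self-contained illustration using a plain Python callable and a tuple sample in place of a real model and dataloader (these stand-ins are hypothetical and not part of the ENOT API):

```python
# A stand-in "model": any callable accepting positional and keyword arguments.
def toy_model(images, verbose=False):
    if verbose:
        print("running on", len(images), "images")
    return [pixel * 2 for pixel in images]

# A stand-in dataloader item: (images, labels).
sample = ([1, 2, 3], [0, 1, 0])

# The mapping function: one positional argument, one keyword argument.
def mapping_function(sample):
    images, _labels = sample
    return (images,), {"verbose": False}

# This mirrors how ENOT applies the mapping function.
model_args, model_kwargs = mapping_function(sample)
result = toy_model(*model_args, **model_kwargs)
print(result)  # [2, 4, 6]
```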

When working with the ENOT framework, it is sometimes necessary to write a function that transforms dataloader items into your model's input format. ENOT later uses this function to convert items extracted from the dataloader (via its __next__ method) into the input format of your custom model.

Such a function should take a single argument of type Any - a single item from your dataloader - and return a tuple with two elements. The first element is a tuple of model positional arguments, which will be passed to the model's __call__ method via the unpacking operator *args. The second element is a dictionary with string keys and values of any type, which defines the model keyword arguments and will be passed to the model's __call__ method via the unpacking operator **kwargs.

Let’s see the most basic example:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor)

def my_conversion_function(x):
    return (x[0], ), {}  # Single positional argument for model, no keyword arguments.

The same functionality is provided by our function default_sample_to_model_inputs().
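For reference, the behavior of such a default can be sketched in a few lines (this is an illustrative re-implementation matching the description in this document, not ENOT's actual source):

```python
def sample_to_model_inputs(x):
    # Pass only the first element (e.g. images) as a single positional
    # argument; no keyword arguments.
    return (x[0],), {}

# Example: a (images, labels) pair.
args, kwargs = sample_to_model_inputs(("images-tensor", "labels-tensor"))
print(args, kwargs)  # ('images-tensor',) {}
```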

An example function that changes the default value of a model forward keyword argument:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor, should_profile: bool = True)

def my_conversion_function(x):
    return (x[0], ), {'should_profile': False}  # We do not want to profile!

An example function that also pre-processes the data and moves it to the GPU:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor)

def my_conversion_function(x):
    # Normalize images, cast them to float32, and move them to CUDA.
    return ((x[0].float() / 255).cuda(), ), {}

Advanced case with complex dataloader item structure and model positional and keyword arguments:

DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'masks': List[torch.Tensor]}
Model expects: (sequence: torch.Tensor, mask: Optional[torch.Tensor] = None, unroll_attention: bool = False)

def my_conversion_function(x):
    sequence = x['sequence']
    mask = x['masks'][0]
    return (sequence, ), {'mask': mask, 'unroll_attention': True}
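To make the advanced case concrete, here is a runnable end-to-end sketch with plain Python lists standing in for tensors (the model and data below are hypothetical stand-ins matching the signatures above):

```python
# Stand-in model matching the signature above.
def model(sequence, mask=None, unroll_attention=False):
    if mask is not None:
        # Keep only the non-padded positions.
        sequence = [s for s, m in zip(sequence, mask) if m]
    return {"length": len(sequence), "unrolled": unroll_attention}

# Stand-in dataloader item matching the structure above.
sample = {
    "sequence": [10, 20, 30, 40],
    "translation": [1, 2],
    "masks": [[1, 1, 1, 0]],  # First mask: the last position is padding.
}

def my_conversion_function(x):
    sequence = x["sequence"]
    mask = x["masks"][0]
    return (sequence,), {"mask": mask, "unroll_attention": True}

model_args, model_kwargs = my_conversion_function(sample)
result = model(*model_args, **model_kwargs)
print(result)  # {'length': 3, 'unrolled': True}
```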
default_sample_to_model_inputs(x)

Default function for sample to model input conversion.

This function covers the simplest case, when the dataloader returns pairs of images and labels in the form (images, labels). If the model receives a single positional argument, you can use this function to convert dataloader output to model input.

Parameters

x (tuple) – Tuple with one or more elements from which only the first item will be passed to the model. Accordingly, the user-defined model must have a single positional argument in its forward method definition.

Returns

Model input args and kwargs, as described in the conversion function contract above.

Return type

tuple with two items - tuple and dict with str keys

Extracting the number of samples from dataloader item

Sometimes ENOT needs to know the number of samples of interest in a dataloader item. A sample of interest could be an image, one second of audio, a bounding box, etc. - what a single sample means is entirely up to you. To provide this information to ENOT, you should define a special function. This function is used as follows:

# Your code:
your_model = ...
your_dataloader = ...
your_function = ...

# In the ENOT framework:
sample = next(your_dataloader)
n_base_samples: int = your_function(sample)
# process n_base_samples

Such a function should take a single argument of type Any - a single item from your dataloader - and return a single integer value: the number of samples of interest in that dataloader item.

Let’s see the most basic example:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)

def my_function(x):
    # Extract the number of images from the batch dimension.
    return x[0].shape[0]

The same functionality is provided by our function default_sample_to_n_samples().

Advanced case with complex dataloader item structure:

DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'mask': torch.Tensor}

def my_function(x):
    # sequence is a tensor of shape (N, T1) with int64 dtype.
    # translation is a tensor of shape (N, T2) with int64 dtype.
    # mask is a tensor of shape (N, T1) with int64 dtype.
    # mask is equal to 1 where sequence values are valid (not padded) and 0 elsewhere.
    total_time_steps = x['mask'].sum().item()
    total_seconds = total_time_steps // 100  # Suppose that one time step is 10ms.
    return total_seconds  # Total audio length in seconds across all samples.
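The counting logic above can be checked with a framework-free sketch that uses a nested list in place of the mask tensor (the 10 ms step size is the same assumption as in the comments above):

```python
def sample_to_n_samples(x):
    # Sum all mask entries to count the valid (non-padded) time steps.
    total_time_steps = sum(sum(row) for row in x["mask"])
    # One time step is 10 ms, so 100 time steps make one second.
    return total_time_steps // 100

# Two audio clips: 250 and 150 valid time steps (2.5 s + 1.5 s = 4 s total).
sample = {"mask": [[1] * 250 + [0] * 50, [1] * 150 + [0] * 150]}
print(sample_to_n_samples(sample))  # 4
```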
default_sample_to_n_samples(x)

Default function to extract the number of samples of interest from dataloader item.

This function covers the simplest case, when the dataloader returns pairs of images and labels in the form (images, labels).

Parameters

x (tuple) – Tuple with one or more elements, where the first element is a tensor whose batch dimension is axis 0.

Returns

Number of samples of interest in the dataloader item, as described in the extraction function contract above.

Return type

int
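The default extraction function described above can be sketched as follows (an illustrative re-implementation, shown with a minimal hypothetical stand-in for a batched tensor, not ENOT's actual source):

```python
class FakeTensor:
    # Minimal stand-in exposing only the .shape attribute used below.
    def __init__(self, shape):
        self.shape = shape

def sample_to_n_samples(x):
    # The first element is the batch tensor; axis 0 is the batch dimension.
    return x[0].shape[0]

images = FakeTensor(shape=(32, 3, 224, 224))  # Batch of 32 images.
labels = FakeTensor(shape=(32,))
print(sample_to_n_samples((images, labels)))  # 32
```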