Converting dataloader items to PyTorch model inputs

Short summary: When ENOT needs to pass items from your dataloader to your model, it requires a special mapping function. This function is used as follows:

# Your code:
your_model = ...
your_dataloader = ...
your_mapping_function = ...

# In the ENOT framework:
sample = next(iter(your_dataloader))
model_args, model_kwargs = your_mapping_function(sample)
model_result = your_model(*model_args, **model_kwargs)
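The snippet above can be fleshed out into a runnable sketch. The toy model, the random dataset and the concrete mapping function below are illustrative assumptions; only the `(args, kwargs)` calling convention comes from this page. Note that a standard PyTorch `DataLoader` is an iterable, not an iterator, so `iter()` is needed before `next()`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 8 "images" flattened to 16 features, with integer labels.
images = torch.randn(8, 16)
labels = torch.randint(0, 3, (8,))
your_dataloader = DataLoader(TensorDataset(images, labels), batch_size=4)

# Toy model whose forward takes a single positional argument.
your_model = nn.Linear(16, 3)

def your_mapping_function(sample):
    # sample is the collated [images, labels] batch; only images go to the model.
    return (sample[0],), {}

# Roughly what the framework does with one dataloader item.
sample = next(iter(your_dataloader))
model_args, model_kwargs = your_mapping_function(sample)
model_result = your_model(*model_args, **model_kwargs)
print(model_result.shape)  # torch.Size([4, 3])
```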

When working with the ENOT framework, you sometimes need to write functions that transform dataloader items into your model's input format. ENOT takes items from the dataloader (through the __next__ method of the dataloader's iterator) and uses such a function to convert each item into the input format of your custom model.

Such a function should take a single input of type Any: a single item from the user's dataloader. It should return a tuple with two elements. The first element is a tuple of the model's positional arguments, which will be passed to the model's __call__ method by unpacking with *args. The second element is a dictionary with string keys and values of any type that defines the model's keyword arguments; it will be passed to the model's __call__ method by unpacking with **kwargs.
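The return contract can be checked mechanically. The sketch below uses a hypothetical helper, `check_model_inputs`, which is not part of ENOT; it only verifies the `(tuple, dict with str keys)` shape described above before the result is unpacked into a model call.

```python
from typing import Any, Dict, Tuple

def check_model_inputs(result: Any) -> Tuple[tuple, Dict[str, Any]]:
    """Hypothetical helper: verify that a mapping function's result follows
    the (args, kwargs) contract and return it unchanged."""
    if not isinstance(result, tuple) or len(result) != 2:
        raise TypeError("expected a tuple of two elements: (args, kwargs)")
    args, kwargs = result
    if not isinstance(args, tuple):
        raise TypeError("the first element (positional arguments) must be a tuple")
    if not isinstance(kwargs, dict) or not all(isinstance(k, str) for k in kwargs):
        raise TypeError("the second element (keyword arguments) must be a dict with str keys")
    return args, kwargs

def call_model(model, args: tuple, kwargs: Dict[str, Any]):
    # Same as model.__call__(*args, **kwargs).
    return model(*args, **kwargs)

args, kwargs = check_model_inputs(((1, 2), {"scale": 10}))
print(call_model(lambda a, b, scale=1: (a + b) * scale, args, kwargs))  # 30
```

A common slip this catches is dropping the trailing comma: `(x[0])` is just `x[0]`, not a one-element tuple, so it must be written `(x[0], )`.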

Let’s see the most basic example:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor)

def my_conversion_function(x):
    return (x[0], ), {}  # Single positional argument for model, no keyword arguments.

The same functionality is provided by our function default_sample_to_model_inputs().

Example function that overrides the default value of a keyword argument of the model's forward method:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor, should_profile: bool = True)

def my_conversion_function(x):
    return (x[0], ), {'should_profile': False}  # We do not want to profile!
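To see the override take effect, here is a sketch with a hypothetical `ProfilableModel` whose `forward` records the flag it received. The class, its `last_should_profile` attribute and the layer sizes are illustrative assumptions, not part of ENOT.

```python
import torch
from torch import nn

class ProfilableModel(nn.Module):
    """Hypothetical model whose forward has a keyword argument with a default."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 3)
        self.last_should_profile = None  # records the flag seen by forward

    def forward(self, images, should_profile=True):
        self.last_should_profile = should_profile
        return self.linear(images)

def my_conversion_function(x):
    return (x[0], ), {'should_profile': False}  # We do not want to profile!

model = ProfilableModel()
sample = (torch.randn(4, 16), torch.zeros(4, dtype=torch.long))  # (images, labels)
args, kwargs = my_conversion_function(sample)
out = model(*args, **kwargs)
print(model.last_should_profile)  # False
```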

Example function that also pre-processes the data and moves it to the GPU:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor)

def my_conversion_function(x):
    # Cast images to float32, scale them to [0, 1], and move them to the GPU.
    return ((x[0].float() / 255).cuda(), ), {}
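The variant below is a sketch that swaps the unconditional `.cuda()` for a device chosen at runtime, so it also runs on machines without CUDA; that fallback is an assumption of this example, not something ENOT requires. The uint8 image batch is made up for illustration.

```python
import torch

# Use the GPU when present; fall back to CPU otherwise (assumption of this sketch).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def my_conversion_function(x):
    images = x[0].float() / 255  # uint8 in [0, 255] -> float32 in [0, 1]
    return (images.to(device), ), {}

# Hypothetical batch of uint8 images (N, C, H, W) with labels.
raw_images = torch.randint(0, 256, (2, 3, 4, 4), dtype=torch.uint8)
labels = torch.tensor([0, 1])
args, kwargs = my_conversion_function((raw_images, labels))
print(args[0].dtype, args[0].device.type)
```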

Advanced case with a nested dataloader item structure and both positional and keyword model arguments:

DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'masks': List[torch.Tensor]}
Model expects: (sequence: torch.Tensor, mask: Optional[torch.Tensor] = None, unroll_attention: bool = False)

def my_conversion_function(x):
    sequence = x['sequence']
    mask = x['masks'][0]
    return (sequence, ), {'mask': mask, 'unroll_attention': True}
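A runnable sketch of the advanced case follows. `Translator` is a hypothetical stand-in with the signature given above; its masking logic is a placeholder, and the `unroll_attention` flag is accepted but ignored. The sample dict is made up for illustration.

```python
from typing import Optional

import torch
from torch import nn

class Translator(nn.Module):
    """Hypothetical model matching the signature above."""

    def forward(self, sequence: torch.Tensor, mask: Optional[torch.Tensor] = None,
                unroll_attention: bool = False) -> torch.Tensor:
        # Placeholder logic: zero out masked positions; unroll_attention is unused.
        return sequence if mask is None else sequence * mask

def my_conversion_function(x):
    sequence = x['sequence']
    mask = x['masks'][0]  # only the first mask is passed to the model
    return (sequence, ), {'mask': mask, 'unroll_attention': True}

sample = {
    'sequence': torch.ones(2, 5),
    'translation': torch.zeros(2, 7),  # not needed by the model, so it is dropped
    'masks': [torch.tensor([[1., 1., 1., 0., 0.]] * 2), torch.ones(2, 5)],
}
args, kwargs = my_conversion_function(sample)
output = Translator()(*args, **kwargs)
print(output.sum().item())  # 6.0
```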
default_sample_to_model_inputs(x)

Default function for sample to model input conversion.

This function covers the simplest case, in which the dataloader returns pairs of images and labels in the form (images, labels). If your model takes a single positional argument, you can use this function to convert dataloader output to model input.

Parameters

x (tuple) – Tuple with one or more elements, of which only the first is passed to the model. The user-defined model must therefore take a single positional argument in its forward method.

Returns

Model input args and kwargs; see Converting dataloader items to PyTorch model inputs for details.

Return type

tuple of two items: a tuple and a dict with str keys
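The documented behavior can be mimicked in a few lines. The function below, `sample_to_model_inputs_sketch`, is a hypothetical re-implementation written from the parameter and return descriptions above, not ENOT's actual source; it shows that items after the first are ignored.

```python
from typing import Any, Dict, Tuple

def sample_to_model_inputs_sketch(x: tuple) -> Tuple[tuple, Dict[str, Any]]:
    """Hypothetical re-implementation of the behavior documented for
    default_sample_to_model_inputs; not ENOT's actual code."""
    return (x[0], ), {}

# Only the first item reaches the model, whatever the tuple length.
args, kwargs = sample_to_model_inputs_sketch(('images', 'labels', 'metadata'))
print(args, kwargs)  # ('images',) {}
```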