Enot utility package

The enot.utils module contains utility functions.

batch_norm

This package contains utilities for batch normalization layers. Examples include tuning batch norm running statistics, resetting the running statistics of all batch norm layers, and checking whether a module is an instance of PyTorch’s BatchNorm class.

tune_bn_stats(model, dataloader, reset_bns=False, set_momentums_none=False, n_steps=None, epochs=1, sample_to_model_inputs=<function default_sample_to_model_inputs>, verbose=0)

Tunes batch norm running statistics of the model.

Parameters:
  • model (Module) – Model whose batch norm layers will be updated.

  • dataloader (torch.utils.data.DataLoader) – Dataloader that generates the data used to update the model’s running statistics.

  • reset_bns (bool, optional) – Whether to reset the running statistics of all batch norm layers before tuning. Default value is False.

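  • set_momentums_none (bool, optional) – Whether to set the momentum of all batch norm layers to None. With momentum set to None, PyTorch batch norm layers compute a cumulative moving average (a simple average over all batches seen) instead of an exponential moving average. Default value is False.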

  • n_steps (int or None, optional) – Number of steps in one epoch of batch norm tuning. Default value is None, which runs each epoch over the full dataloader.

  • epochs (int, optional) – Number of epochs to tune batch norms. Default value is 1.

  • sample_to_model_inputs (Callable, optional) – Function that maps dataloader samples to the model input format. Default value is default_sample_to_model_inputs.

  • verbose (int, optional) – Verbosity level of the tuning procedure: 0 disables all messages, 1 enables a tqdm progress bar. Default value is 0.

Return type:

None

Notes

Typically, it is better to tune batch norm statistics on the same data your model was trained on. Tuning batch norms on a holdout set may degrade performance due to distribution shift.
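Conceptually, tuning batch norm statistics means running forward passes in training mode without gradient updates, so that each batch norm layer refreshes its running mean and variance. The following is a minimal sketch of that idea in plain PyTorch, not ENOT's implementation. The names sketch_tune_bn_stats and hypothetical_sample_to_inputs are hypothetical, and the (args, kwargs) return format of the sample mapping is an assumption that need not match default_sample_to_model_inputs.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def hypothetical_sample_to_inputs(sample):
    # Assumption: the first element of each sample is the model input.
    # Returns positional args and keyword args for model(*args, **kwargs).
    return (sample[0],), {}


def sketch_tune_bn_stats(model, dataloader, n_steps=None, epochs=1,
                         sample_to_model_inputs=hypothetical_sample_to_inputs):
    """Hypothetical sketch: refresh BatchNorm running stats with forward passes."""
    was_training = model.training
    model.train()  # BatchNorm updates running stats only in training mode
    with torch.no_grad():  # no gradients or weight updates are needed
        for _ in range(epochs):
            for step, sample in enumerate(dataloader):
                if n_steps is not None and step >= n_steps:
                    break  # end this epoch early after n_steps batches
                args, kwargs = sample_to_model_inputs(sample)
                model(*args, **kwargs)
    model.train(was_training)  # restore the original train/eval mode


# Usage: inputs whose distribution differs from BatchNorm's initial stats.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
inputs = torch.randn(64, 4) * 3 + 5
loader = DataLoader(TensorDataset(inputs, torch.zeros(64)), batch_size=16)
model[1].momentum = None  # cumulative average, analogous to set_momentums_none=True
model.eval()
sketch_tune_bn_stats(model, loader)
```

Restoring the original mode at the end matters because tuning is usually done on a model that is about to be evaluated.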

is_bn(module)

Checks whether a torch.nn.Module is a BatchNorm instance.

Parameters:

module (torch.nn.Module) – Module to check.

Returns:

Whether the input module is a BatchNorm instance.

Return type:

bool
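A check like this is typically an isinstance test against PyTorch's batch norm classes. The sketch below assumes that "BatchNorm instance" means one of the public classes BatchNorm1d, BatchNorm2d, BatchNorm3d, or SyncBatchNorm; the real is_bn may check a different set of types, and sketch_is_bn is a hypothetical name.

```python
from torch import nn

# Assumption: these public classes are what "BatchNorm instance" means here;
# ENOT's is_bn may check a different (e.g. broader) set of types.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def sketch_is_bn(module: nn.Module) -> bool:
    """Hypothetical equivalent of is_bn: True for PyTorch batch norm layers."""
    return isinstance(module, BN_TYPES)


# Usage: count batch norm layers in a small model.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
                      nn.Flatten(), nn.LayerNorm(8))
num_bn = sum(sketch_is_bn(m) for m in model.modules())
```

Other normalization layers such as LayerNorm and InstanceNorm2d are not batch norm layers, so the check returns False for them.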

reset_bn(module)

Resets the module’s running statistics if the module is a batch norm layer.

Parameters:

module (torch.nn.Module) – Module whose running statistics are reset (only if it is an instance of torch BatchNorm).

Return type:

None
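PyTorch batch norm layers already provide a reset_running_stats() method, so a function like reset_bn can be sketched as a type check followed by that call. This is a hypothetical sketch (sketch_reset_bn and the BN_TYPES tuple are illustrative assumptions), not ENOT's code.

```python
import torch
from torch import nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def sketch_reset_bn(module: nn.Module) -> None:
    """Hypothetical equivalent of reset_bn."""
    if isinstance(module, BN_TYPES):
        # Built-in PyTorch method: running_mean -> 0, running_var -> 1,
        # num_batches_tracked -> 0 (no-op if track_running_stats=False).
        module.reset_running_stats()


bn = nn.BatchNorm1d(3)
bn(torch.randn(8, 3) + 2.0)  # a training-mode forward pass updates running stats
sketch_reset_bn(bn)
sketch_reset_bn(nn.Linear(3, 3))  # non-batch-norm modules are left untouched
```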

model_reset_bn(model)

Resets the running statistics of all batch norm layers in the model.

Parameters:

model (torch.nn.Module) – Model in which the running stats of all batch norm layers will be reset.

Return type:

None
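Resetting every batch norm layer in a model can be sketched with nn.Module.apply, which calls a function on every submodule and on the model itself. The helper names below are hypothetical, and the set of batch norm types is an assumption.

```python
import torch
from torch import nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def _reset_if_bn(module: nn.Module) -> None:
    if isinstance(module, BN_TYPES):
        module.reset_running_stats()  # zero mean, unit variance, zero batch count


def sketch_model_reset_bn(model: nn.Module) -> None:
    """Hypothetical equivalent of model_reset_bn."""
    # nn.Module.apply visits every submodule recursively, plus model itself.
    model.apply(_reset_if_bn)


model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU(),
                      nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4))
model(torch.randn(2, 3, 8, 8) + 1.0)  # populate running stats (training mode)
sketch_model_reset_bn(model)
```

Using apply keeps the traversal logic in PyTorch, so nested submodules are handled without writing a recursive walk by hand.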

model_set_bn_momentum(model, momentum=None)

Sets momentum value in all batch norm layers.

Parameters:
  • model (torch.nn.Module) – Model whose batch norm layers will be updated.

  • momentum (float or None, optional) – Momentum value to set in all batch norm layers. Default value is None.

Return type:

None
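Setting the momentum amounts to assigning the momentum attribute of each batch norm layer. The sketch below (a hypothetical sketch_model_set_bn_momentum, not ENOT's code) also shows why momentum=None is useful for tuning: PyTorch then keeps a simple cumulative average of batch statistics instead of an exponential moving average.

```python
import torch
from torch import nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def sketch_model_set_bn_momentum(model: nn.Module, momentum=None) -> None:
    """Hypothetical equivalent of model_set_bn_momentum."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.momentum = momentum  # None -> cumulative moving average


def running_mean_after(momentum):
    # Feed two constant batches (means 1.0 and 3.0) to a fresh batch norm layer.
    model = nn.Sequential(nn.BatchNorm1d(2))
    sketch_model_set_bn_momentum(model, momentum)
    for value in (1.0, 3.0):
        model(torch.full((4, 2), value))
    return model[0].running_mean


cumulative = running_mean_after(None)  # simple average of batch means: 2.0
exponential = running_mean_after(0.1)  # 0.9 * (0.1 * 1.0) + 0.1 * 3.0 = 0.39
```

With the default momentum of 0.1, the running mean is still far from the data after only two batches, whereas the cumulative average weights every batch equally.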