Utilities for Trainer

Utilities for Trainer

This page lists all the utility functions used by Trainer.

Most of those are only useful if you are studying the code of the Trainer in the library.

Utilities

class transformers.EvalPrediction

<source>

( predictions: typing.Union[numpy.ndarray, typing.Tuple[numpy.ndarray]]label_ids: typing.Union[numpy.ndarray, typing.Tuple[numpy.ndarray]]inputs: typing.Union[numpy.ndarray, typing.Tuple[numpy.ndarray], NoneType] = None )

Parameters

  • predictions (np.ndarray) β€” Predictions of the model.

  • label_ids (np.ndarray) β€” Targets to be matched.

  • inputs (np.ndarray, optional) β€”

Evaluation output (always contains labels), to be used to compute metrics.

class transformers.IntervalStrategy

<source>

( valuenames = Nonemodule = Nonequalname = Nonetype = Nonestart = 1 )

An enumeration.

transformers.enable_full_determinism

<source>

( seed: intwarn_only: bool = False )

Helper function for reproducible behavior during distributed training. See

transformers.set_seed

<source>

( seed: int )

Parameters

  • seed (int) β€” The seed to set.

Helper function for reproducible behavior to set the seed in random, numpy, torch and/or tf (if installed).

transformers.torch_distributed_zero_first

<source>

( local_rank: int )

Parameters

  • local_rank (int) β€” The rank of the local process.

Decorator to make all processes in distributed training wait for each local_master to do something.

Callbacks internals

class transformers.trainer_callback.CallbackHandler

<source>

( callbacksmodeltokenizeroptimizerlr_scheduler )

Internal class that just calls the list of callbacks in order.

Distributed Evaluation

class transformers.trainer_pt_utils.DistributedTensorGatherer

<source>

( world_sizenum_samplesmake_multiple_of = Nonepadding_index = -100 )

Parameters

  • world_size (int) β€” The number of processes used in the distributed training.

  • num_samples (int) β€” The number of samples in our dataset.

  • make_multiple_of (int, optional) β€” If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument (by adding samples).

  • padding_index (int, optional, defaults to -100) β€” The padding index to use if the arrays don’t all have the same sequence length.

A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.

If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on CPU at every step, our sampler will generate the following indices:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]

to get something of size a multiple of 3 (so that each process gets the same dataset length). Then process 0, 1 and 2 will be responsible of making predictions for the following samples:

  • P0: [0, 1, 2, 3, 4, 5]

  • P1: [6, 7, 8, 9, 10, 11]

  • P2: [12, 13, 14, 15, 0, 1]

The first batch treated on each process will be

  • P0: [0, 1]

  • P1: [6, 7]

  • P2: [12, 13]

So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor) corresponding to the following indices:

[0, 1, 6, 7, 12, 13]

If we directly concatenate our results without taking any precautions, the user will then get the predictions for the indices in this order at the end of the prediction loop:

[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]

For some reason, that’s not going to roll their boat. This class is there to solve that problem.

add_arrays

<source>

( arrays )

Add arrays to the internal storage, Will initialize the storage to the full size at the first arrays passed so that if we’re bound to get an OOM, it happens at the beginning.

finalize

<source>

( )

Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras to get each process a dataset of the same length).

Distributed Evaluation

class transformers.HfArgumentParser

<source>

( dataclass_types: typing.Union[DataClassType, typing.Iterable[DataClassType]]**kwargs )

This subclass of argparse.ArgumentParser uses type hints on dataclasses to generate arguments.

The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) arguments to the parser after initialization and you’ll get the output back after parsing as an additional namespace. Optional: To create sub argument groups use the _argument_group_name attribute in the dataclass.

parse_args_into_dataclasses

<source>

( args = Nonereturn_remaining_strings = Falselook_for_args_file = Trueargs_filename = Noneargs_file_flag = None ) β†’ Tuple consisting of

Returns

Tuple consisting of

  • the dataclass instances in the same order as they were passed to the initializer.abspath

  • if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser after initialization.

  • The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)

Parse command-line args into instances of the specified dataclass types.

This relies on argparse’s ArgumentParser.parse_known_args. See the doc at: docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

parse_dict

<source>

( args: typing.Dict[str, typing.Any]allow_extra_keys: bool = False ) β†’ Tuple consisting of

Parameters

  • args (dict) β€” dict containing config values

  • allow_extra_keys (bool, optional, defaults to False) β€” Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.

Returns

Tuple consisting of

  • the dataclass instances in the same order as they were passed to the initializer.

Alternative helper method that does not use argparse at all, instead uses a dict and populating the dataclass types.

parse_json_file

<source>

( json_file: strallow_extra_keys: bool = False ) β†’ Tuple consisting of

Parameters

  • json_file (str or os.PathLike) β€” File name of the json file to parse

  • allow_extra_keys (bool, optional, defaults to False) β€” Defaults to False. If False, will raise an exception if the json file contains keys that are not parsed.

Returns

Tuple consisting of

  • the dataclass instances in the same order as they were passed to the initializer.

Alternative helper method that does not use argparse at all, instead loading a json file and populating the dataclass types.

parse_yaml_file

<source>

( yaml_file: strallow_extra_keys: bool = False ) β†’ Tuple consisting of

Parameters

  • yaml_file (str or os.PathLike) β€” File name of the yaml file to parse

  • allow_extra_keys (bool, optional, defaults to False) β€” Defaults to False. If False, will raise an exception if the json file contains keys that are not parsed.

Returns

Tuple consisting of

  • the dataclass instances in the same order as they were passed to the initializer.

Alternative helper method that does not use argparse at all, instead loading a yaml file and populating the dataclass types.

Debug Utilities

class transformers.debug_utils.DebugUnderflowOverflow

<source>

( modelmax_frames_to_save = 21trace_batch_nums = []abort_after_batch_num = None )

Parameters

  • model (nn.Module) β€” The model to debug.

  • max_frames_to_save (int, optional, defaults to 21) β€” How many frames back to record

  • trace_batch_nums(List[int], optional, defaults to []) β€” Which batch numbers to trace (turns detection off)

  • abort_after_batch_num (`intβ€œ, optional) β€” Whether to abort after a certain batch number has finished

This debug class helps detect and understand where the model starts getting very large or very small, and more importantly nan or inf weight and activation elements.

There are 2 working modes:

  1. Underflow/overflow detection (default)

  2. Specific batch absolute min/max tracing without detection

Mode 1: Underflow/overflow detection

To activate the underflow/overflow detection, initialize the object with the model :

Copied

debug_overflow = DebugUnderflowOverflow(model)

then run the training as normal and if nan or inf gets detected in at least one of the weight, input or output elements this module will throw an exception and will print max_frames_to_save frames that lead to this event, each frame reporting

  1. the fully qualified module name plus the class name whose forward was run

  2. the absolute min and max value of all elements for each module weights, and the inputs and output

For example, here is the header and the last few frames in detection report for google/mt5-small run in fp16

mixed precision :

Copied

Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min  abs max  metadata
[...]
                  encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
                  encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
                  encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
                  encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
                  encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00      inf output

You can see here, that T5DenseGatedGeluDense.forward resulted in output activations, whose absolute max value was around 62.7K, which is very close to fp16’s top limit of 64K. In the next frame we have Dropout which renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than 64K, and we get an overlow.

As you can see it’s the previous frames that we need to look into when the numbers start going into very large for fp16 numbers.

The tracking is done in a forward hook, which gets invoked immediately after forward has completed.

By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :

Copied

debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)

To validate that you have set up this debugging feature correctly, and you intend to use it in a training that may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in the next section.

Mode 2. Specific batch absolute min/max tracing without detection

The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.

Let’s say you want to watch the absolute min and max values for all the ingredients of each forward call of a

given batch, and only do that for batches 1 and 3. Then you instantiate this class as :

Copied

debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])

And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.

This is helpful if you know that the program starts misbehaving after a certain batch number, so you can fast-forward right to that area.

Early stopping:

You can also specify the batch number after which to stop the training, with :

Copied

debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)

This feature is mainly useful in the tracing mode, but you can use it for any mode.

Performance:

As this module measures absolute min/`max of each weight of the model on every forward it’ll slow the training down. Therefore remember to turn it off once the debugging needs have been met.

Last updated