Problems and solutions when using Trainer

I am using the Trainer to train an ASR model. Both the dataset and the output dimension are huge, which causes several problems during training. I struggled with them for many days, so I am posting my solutions here in the hope that they help.

  1. compute_metrics out-of-memory issue
    During evaluation, the Trainer keeps all the logits in one array before calling compute_metrics. When the output dimension is large, this easily runs out of memory on a large dataset. The solution is to apply torch.argmax to the logits first, so that only the predicted ids are stored instead of the full logits.

  2. When using the Trainer on a seq2seq model whose output contains past_key_values, merging the outputs of different batches raises a length error, so past_key_values has to be dropped from the model output.

  3. group_by_length makes the start of training (trainer.train()) take a very long time, and it uses a lot of memory to compute the lengths of the data.
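A minimal sketch of the argmax trick for item 1, assuming a Transformers version whose Trainer accepts a preprocess_logits_for_metrics callable (newer releases do). The original tip only says to apply torch.argmax first; this hook is one convenient place to do it, because it runs on every evaluation batch before predictions are accumulated. accumulated_bytes is a hypothetical helper for the back-of-the-envelope memory arithmetic.

```python
from typing import Any, Optional


def preprocess_logits_for_metrics(logits: Any, labels: Any) -> Any:
    """Reduce (batch, time, vocab) logits to (batch, time) predicted ids.

    Called once per evaluation batch, so only the ids are accumulated over
    the eval set instead of the full vocab-sized logits.
    """
    # Some models return a tuple whose first element holds the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Positional -1 means `dim` for torch.Tensor and `axis` for numpy.ndarray.
    return logits.argmax(-1)


def accumulated_bytes(num_examples: int, seq_len: int,
                      vocab_size: Optional[int] = None,
                      bytes_per_value: int = 4) -> int:
    """Rough size of the predictions kept for the whole eval set.

    With vocab_size: full logits (one value per vocab entry per step).
    Without it: one predicted id per step.
    """
    values_per_step = vocab_size if vocab_size is not None else 1
    return num_examples * seq_len * values_per_step * bytes_per_value
```

Passed as Trainer(..., preprocess_logits_for_metrics=preprocess_logits_for_metrics), compute_metrics then receives the ids in eval_pred.predictions (for CTC you would decode them, e.g. with the processor's batch_decode). For a hypothetical eval set of 10,000 utterances, 1,000 frames and a 5,000-entry vocabulary, float32 logits need about 200 GB, while int64 ids need about 80 MB.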
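A minimal sketch for item 2, dropping the decoder cache from a dict-like model output. drop_cache_outputs and CACHE_KEYS are hypothetical names, and the sketch assumes the cache lives under the past_key_values key, as in Transformers' seq2seq outputs. A likely cause of the length error is that the cached key/value tensors have a sequence dimension that varies from batch to batch, so they cannot be concatenated across batches the way logits can.

```python
from collections import OrderedDict
from typing import Iterable, Mapping

# Hypothetical list of output keys holding the decoder cache.
CACHE_KEYS = ("past_key_values",)


def drop_cache_outputs(outputs: Mapping, keys: Iterable[str] = CACHE_KEYS) -> "OrderedDict":
    """Return a copy of a dict-like model output without the cache entries.

    Order of the remaining keys is preserved, so "loss" (if present) stays first.
    """
    skip = set(keys)
    return OrderedDict((k, v) for k, v in outputs.items() if k not in skip)

# Hypothetical use at the end of a wrapper model's forward:
#     return drop_cache_outputs(self.model(**kwargs))
```

In Transformers versions that have it, passing ignore_keys=["past_key_values"] to trainer.evaluate() or trainer.predict() skips those outputs without needing a wrapper.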

Note that for 3, you can have the lengths computed once and for all in your datasets.Dataset with the map method; you just have to store the results in a "lengths" column. The Trainer will then use that Dataset column and not try to access every element.

  4. model.forward must accept a labels argument if the model returns a loss; otherwise prediction_loop raises a length error.

  5. The dataset needs an input_ids field. (With Wav2Vec2Processor the input is input_values, so I had to add an extra input_ids field XD.)
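A framework-free sketch of the signature point in item 4. With remove_unused_columns=True (the default), the Trainer drops dataset columns that model.forward does not accept (apart from a few label-specific names), so a forward without a labels parameter may never receive the labels; that is one plausible route to the error above. ToyModel and forward_accepts are hypothetical, and the list arithmetic stands in for real tensor operations.

```python
import inspect
from typing import Dict, List, Optional


class ToyModel:
    """Hypothetical stand-in for a model; only the forward signature matters here."""

    def forward(self, input_values: List[float],
                labels: Optional[List[float]] = None) -> Dict[str, object]:
        logits = [2.0 * x for x in input_values]  # placeholder computation
        outputs: Dict[str, object] = {"logits": logits}
        if labels is not None:
            # Placeholder loss (mean absolute error), only returned when labels are given.
            outputs["loss"] = sum(abs(p - y) for p, y in zip(logits, labels)) / len(labels)
        return outputs


def forward_accepts(model: object, name: str = "labels") -> bool:
    """Check whether `name` is a parameter of model.forward (self is excluded)."""
    return name in inspect.signature(model.forward).parameters
```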

It should be "length" instead of "lengths", and the column name can now be customized with length_column_name. That's nice!
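A minimal sketch of precomputing the lengths with a batched datasets.Dataset.map call. add_length is a hypothetical helper; it assumes the audio features live in an input_values column (as produced by Wav2Vec2Processor) and writes to "length", the default value of length_column_name. The lengths here count raw feature values (audio samples), which should be enough for grouping because only relative lengths matter.

```python
from typing import Dict, List


def add_length(batch: Dict[str, List], input_column: str = "input_values",
               length_column: str = "length") -> Dict[str, List]:
    """Batched map function: store the length of each example in a new column.

    Computed once up front, so group_by_length can read the column instead of
    touching every example when training starts.
    """
    batch[length_column] = [len(x) for x in batch[input_column]]
    return batch

# Usage sketch (requires the `datasets` and `transformers` packages):
#     train_ds = train_ds.map(add_length, batched=True)
#     args = TrainingArguments(..., group_by_length=True, length_column_name="length")
```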