Problems and solutions with Trainer

I am using the Trainer to train an ASR model; both the dataset and the output dimension are huge, which causes several problems during training. I struggled with them for many days, so I'm posting my solutions here in the hope they help.

  1. compute_metrics out-of-memory issue
    During evaluation, all the logits are accumulated in an array before compute_metrics is called; when the output dimension is large, this easily causes out-of-memory errors on a large dataset. The solution is to apply torch.argmax to the logits first so the full logits are never stored (see the first sketch after this list).

  2. When using the Trainer on a seq2seq model, if the model output contains past_key_values, merging the outputs of different batches causes a length error, so past_key_values needs to be dropped from the model output (see the second sketch after this list).

  3. group_by_length takes a very long time to process when training starts, and it uses a lot of memory to compute the length of every example.
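For 1, here is a minimal sketch, assuming a transformers version that supports the preprocess_logits_for_metrics hook on Trainer (model, training_args, and compute_metrics stand in for your own objects):

```python
import torch
from transformers import Trainer

def keep_only_argmax(logits, labels):
    # Collapse the (batch, time, vocab) logits to (batch, time) token ids
    # so the full vocab dimension is never accumulated over the eval set.
    return torch.argmax(logits, dim=-1)

trainer = Trainer(
    model=model,                      # your model
    args=training_args,               # your TrainingArguments
    compute_metrics=compute_metrics,  # now receives ids instead of raw logits
    preprocess_logits_for_metrics=keep_only_argmax,
)
```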
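For 2, two possible options, both assuming a transformers version that exposes them (trainer and model are your own objects):

```python
# Option 1: tell the evaluation loop to skip the key when merging outputs.
metrics = trainer.evaluate(ignore_keys=["past_key_values"])

# Option 2: disable the cache so the model never returns past_key_values.
model.config.use_cache = False
```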


Note that for 3, you can compute the lengths once and for all in your datasets.Dataset with the map method; you just have to store the results in a "lengths" column. The Trainer will then use the Dataset features and not try to access every element.
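A minimal sketch of that precomputation, assuming the processed audio lives in an "input_values" column (adjust the names to your dataset, and see the correction below about the column name the Trainer expects):

```python
def add_length(example):
    # Store each example's length once so group_by_length can read it
    # from the column instead of loading every element.
    example["lengths"] = len(example["input_values"])
    return example

dataset = dataset.map(add_length)
```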

  1. model.forward should accept a labels argument if it returns a loss; otherwise the prediction_loop will raise a length error (see the sketch after this list).

  2. The dataset needs an input_ids column (for Wav2Vec2Processor the input is input_values, so I had to add an extra input_ids field XD).
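For 1, a minimal sketch of a forward that accepts labels and only returns a loss when they are provided; the layer and loss function here are purely illustrative:

```python
import torch.nn as nn

class TinyASRModel(nn.Module):
    def __init__(self, input_dim=80, vocab_size=32):
        super().__init__()
        self.proj = nn.Linear(input_dim, vocab_size)

    def forward(self, input_values, labels=None):
        logits = self.proj(input_values)  # (batch, time, vocab)
        if labels is None:
            # No labels: return logits only, so prediction still works.
            return {"logits": logits}
        loss = nn.functional.cross_entropy(
            logits.transpose(1, 2),  # cross_entropy expects (batch, vocab, time)
            labels,
        )
        return {"loss": loss, "logits": logits}
```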

It should be length instead of lengths; it can be customized using length_column_name now! That's nice!
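A quick sketch of that option, assuming a transformers version that has length_column_name (its default is "length"):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="asr-out",         # illustrative output path
    group_by_length=True,
    length_column_name="length",  # must match the precomputed column
)
```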