I am using the Trainer to train an ASR model, and both the dataset and the output dimension are huge. This causes several problems during training. I struggled with them for many days, so I am posting my solutions here in the hope that they help someone else.
- `compute_metrics` out-of-memory issue: during evaluation, the Trainer accumulates the logits of every batch in one array. When the output dimension is large, this quickly runs out of memory on a big dataset. The fix is to apply `torch.argmax` to the logits first, so only the predicted ids are stored instead of the full logits (first sketch below).
- when using the Trainer with a seq2seq model, if the model output contains `past_key_values`, the Trainer raises a length-mismatch error while concatenating the outputs of different batches, so `past_key_values` has to be dropped from the model output (second sketch below).
- `group_by_length` adds a very long delay when `trainer.train()` starts and uses a lot of memory, because it has to compute the length of every sample. Precomputing a length column avoids recomputing the lengths at startup (third sketch below).
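For the first issue, recent `transformers` versions expose a `preprocess_logits_for_metrics` hook on the Trainer that runs per batch, before the logits are gathered. A minimal sketch of the argmax trick with it; the wav2vec2 checkpoint is just an example, and the datasets/collator are omitted:

```python
import torch
from transformers import AutoModelForCTC, Trainer, TrainingArguments

def preprocess_logits_for_metrics(logits, labels):
    # Runs on each batch before the Trainer gathers predictions:
    # keep one predicted id per frame instead of a vocabulary-sized
    # float vector, which is what blows up the memory.
    if isinstance(logits, tuple):  # some models return extra tensors
        logits = logits[0]
    return torch.argmax(logits, dim=-1)

def compute_metrics(eval_pred):
    pred_ids = eval_pred.predictions   # already argmax-ed ids
    label_ids = eval_pred.label_ids
    # ... decode pred_ids / label_ids and compute WER/CER here ...
    return {}

model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    # train_dataset / eval_dataset / data_collator omitted for brevity
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
```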
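For the `past_key_values` problem, two workarounds I know of, both using standard Trainer APIs (the exact output key name can differ per model; `trainer`, `model`, and `test_dataset` are assumed to be defined as above):

```python
# Option A: tell the Trainer to drop the cache from the gathered outputs.
metrics = trainer.evaluate(ignore_keys=["past_key_values"])
preds = trainer.predict(test_dataset, ignore_keys=["past_key_values"])

# The same list can be passed to evaluations that run during training:
trainer.train(ignore_keys_for_eval=["past_key_values"])

# Option B: disable the cache so the model never returns it at all.
model.config.use_cache = False
```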
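For the `group_by_length` slowdown, if the training set is a `datasets.Dataset`, the Trainer will read a precomputed length column (named by `length_column_name`, default `"length"`) instead of measuring every example itself. A sketch of the idea, assuming the audio feature column is called `input_values`:

```python
from transformers import TrainingArguments

# Precompute a "length" column once, so the length-grouped sampler
# reads it directly instead of touching every example at startup.
def add_length(example):
    # "input_values" is assumed to be the audio feature column of a
    # CTC-style ASR dataset; adjust to your own column name.
    example["length"] = len(example["input_values"])
    return example

dataset = dataset.map(add_length)  # `dataset` assumed: a datasets.Dataset

training_args = TrainingArguments(
    output_dir="out",
    group_by_length=True,
    length_column_name="length",  # the default name the Trainer looks for
)
```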