T5 Model Problems - Constant Loss (doesn't go down)

I’ve been trying to train T5 on a custom dataset similar to Squad v2 by modifying the T5 on TPU colab written by Suraj Patil. (I have Colab Pro so I train on a high-ram TPU instance). The link I have shared is training on Squad v2 directly, which seems to have the same problem.

link: Google Colaboratory

However, no matter how long I train, the loss stays constant, i.e. the model does not seem to learn. This is the case both for the loss recorded during training and for the loss during the evaluation phase.

Could someone please tell me what I am doing wrong? I am going a little crazy trying to figure it out.

Thank you in advance all.

I’ve corrected the following issues thus far:

  1. Different XLA import at start
  2. Modification of code to allow for answer-less questions under Squadv2 (as opposed to Squadv1 for Suraj’s original code) under the eos/encoder section
  3. Edited data imports to use huggingface’s datasets.load_dataset instead of the nlp library
  4. Under T2TDataCollator, modified the batching to ensure that the inputs are tensors instead of lists, e.g.: torch.FloatTensor(example['input_ids']).to(torch.int64)
  5. Specifying transformers version 2.9.1 to allow for Suraj’s particular usage of T5DataCollator (although I created a version using the current version of transformers, this also has the same problem described above).

[5b. If I use the current version of transformers instead of 2.9.1, I make various modifications to T5DataCollator and to the labels generated in the training phase, so that the keys are 'labels' and 'decoder_attention_mask' instead of 'target_ids' and 'target_attention_mask'.]
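For reference, the collator changes in items 4 and 5b can be sketched roughly as below. This is only a minimal sketch, not the notebook's actual T2TDataCollator: the function name collate_t2t is hypothetical, it assumes every example is already padded to the same length, and it assumes the feature keys are the ones named above. It also replaces pad positions in the labels with -100 so the loss skips them, which is a common convention for T5 fine-tuning but may or may not be related to the constant loss here.

```python
import torch

PAD_TOKEN_ID = 0     # T5's pad token id
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips


def collate_t2t(batch):
    """Stack already-padded per-example token lists into int64 tensors.

    Hypothetical helper: `batch` is a list of dicts with the keys
    'input_ids', 'attention_mask', 'target_ids', 'target_attention_mask'.
    """
    def stack(key):
        # Build an integer tensor directly instead of going through FloatTensor.
        return torch.tensor([example[key] for example in batch], dtype=torch.long)

    labels = stack("target_ids")
    # Mask padding in the targets so it does not contribute to the loss.
    labels[labels == PAD_TOKEN_ID] = IGNORE_INDEX

    return {
        "input_ids": stack("input_ids"),
        "attention_mask": stack("attention_mask"),
        "labels": labels,
        "decoder_attention_mask": stack("target_attention_mask"),
    }


# Toy examples with made-up token ids, padded to the same length.
examples = [
    {"input_ids": [13, 7, 1, 0], "attention_mask": [1, 1, 1, 0],
     "target_ids": [42, 1, 0], "target_attention_mask": [1, 1, 0]},
    {"input_ids": [5, 9, 11, 1], "attention_mask": [1, 1, 1, 1],
     "target_ids": [8, 3, 1], "target_attention_mask": [1, 1, 1]},
]
collated = collate_t2t(examples)
```

Printing collated["labels"] for a real batch is a quick way to confirm that the targets are not all padding or all -100, which would also give a loss that never moves.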

Hey @pjahn89, I am facing a similar issue while trying to fine-tune T5 on XSum using TPU/GPU. The training loss is not constant (it varies, but doesn’t converge), but my validation loss is constant: it literally does not change even in the 5th decimal place. I have tried many things, such as creating my own nn.Module compatible with the Trainer and subclassing the Trainer to modify compute_loss(), but I am not seeing any change. If you have solved this issue, can you please tell me how?
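One generic check for a validation loss that is identical to the 5th decimal place is to confirm that the weights actually change after an optimizer step. The snippet below is only a sketch in plain PyTorch: the helpers snapshot and changed_parameters are hypothetical names, and the tiny nn.Linear model with random data is a stand-in for T5 and a real batch. It does not cover anything TPU/XLA specific.

```python
import torch
from torch import nn


def snapshot(model):
    """Copy every parameter so it can be compared after a training step."""
    return {name: p.detach().clone() for name, p in model.named_parameters()}


def changed_parameters(model, before):
    """Return the names of parameters that differ from the snapshot."""
    return [
        name
        for name, p in model.named_parameters()
        if not torch.equal(p.detach(), before[name])
    ]


torch.manual_seed(0)
model = nn.Linear(4, 2)  # stand-in for the real model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(8, 4)   # random stand-in batch
targets = torch.randn(8, 2)

before = snapshot(model)
loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()

updated = changed_parameters(model, before)
print("parameters updated by one step:", updated)
```

If the list comes back empty for the real model, the optimizer is not touching the weights that evaluation uses (for example a zero learning rate, or the optimizer holding a different copy of the parameters), which would match a validation loss that never changes.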

Here’s a link to my colab for reference.

Thank you!