Hi, and thank you again for the awesome library and your work!
I have been trying to repurpose the RAG code to train on the KILT dataset. As I understand it, during the training phase the document encoder (and the index) stays fixed, and only the question encoder and the generator are fine-tuned.
As I train for multiple epochs, something curious happens: the question encoder "collapses" into emitting identical predictions regardless of the input. Specifically,
`out1` and `out2` below are identical, even though the input embeddings are different:

```python
emb2 = torch.randn([1, 512, 768])
emb3 = torch.zeros([1, 512, 768])

# run both inputs through the question encoder's underlying BERT encoder
out1 = model.rag.question_encoder.question_encoder.bert_model.encoder(emb2)
out2 = model.rag.question_encoder.question_encoder.bert_model.encoder(emb3)
```
The way this behavior manifests is that the question encoder starts retrieving the same wiki entries regardless of the question.
In fact, the last hidden states are identical for every token in the sequence.
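For anyone trying to reproduce this, a simple way to detect the collapse without inspecting retrieval results is to compare the pooled query embeddings for several distinct questions: if every pairwise cosine similarity is near 1.0, the encoder has degenerated to a (near-)constant mapping. This is just a minimal sketch with NumPy and synthetic embeddings; `is_collapsed` and the tolerance are my own hypothetical names, not part of the library.

```python
import numpy as np

def pairwise_cosine(embs: np.ndarray) -> np.ndarray:
    """Cosine similarity between every pair of row vectors."""
    normed = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    return normed @ normed.T

def is_collapsed(embs: np.ndarray, tol: float = 1e-3) -> bool:
    """True if all query embeddings point in (nearly) the same direction."""
    sims = pairwise_cosine(embs)
    return bool(np.all(sims > 1.0 - tol))

# Healthy encoder: distinct questions produce distinct embeddings.
healthy = np.random.RandomState(0).randn(4, 768)

# Collapsed encoder: every question maps to (almost) the same vector.
base = np.random.RandomState(1).randn(768)
collapsed = np.stack(
    [base + 1e-6 * np.random.RandomState(i).randn(768) for i in range(4)]
)

print(is_collapsed(healthy))    # -> False
print(is_collapsed(collapsed))  # -> True
```

In practice you would stack the real pooled outputs of the question encoder for a handful of different questions in place of the synthetic arrays above.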
I am curious whether this type of behavior rings any bells. One hunch I have is that mixed-precision training might be the cause. Any direction or feedback would be greatly appreciated before I take the plunge and dig further.