Debugging the RAG question encoder

Hi! Thank you again for the awesome library and your work.
I have been trying to repurpose the RAG code to train on the KILT dataset. As I understand it, during training the document encoder (and the index) is fixed; only the question encoder and the generator are fine-tuned.

As I train for multiple epochs, something curious happens: the question encoder "collapses" into emitting identical predictions regardless of the input. Specifically, out1 and out2 below are identical, even though the input embeddings are different.

import torch

# two clearly different inputs in the embedding space
emb2 = torch.randn([1, 512, 768])
emb3 = torch.zeros([1, 512, 768])

# run both through the question encoder's transformer stack
out1 = model.rag.question_encoder.question_encoder.bert_model.encoder(emb2)
out2 = model.rag.question_encoder.question_encoder.bert_model.encoder(emb3)

The way this behavior manifests is that the question encoder starts pulling the same wiki entries regardless of the question.

In fact, the last hidden states are identical for each token in the sequence.
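To make the comparison above reproducible, here is a minimal, hypothetical helper (not from the original post) that checks for collapse by feeding two clearly different inputs through an encoder and testing whether the last hidden states coincide. `encoder` stands in for a stack like `model.rag.question_encoder.question_encoder.bert_model.encoder`; the toy modules only illustrate the check.

```python
import torch

def has_collapsed(encoder, seq_len=8, dim=16, atol=1e-6):
    """Return True if the encoder maps different inputs to identical outputs."""
    emb_a = torch.randn(1, seq_len, dim)
    emb_b = torch.zeros(1, seq_len, dim)
    out_a = encoder(emb_a)[0]  # first element: last hidden states
    out_b = encoder(emb_b)[0]
    return torch.allclose(out_a, out_b, atol=atol)

# Toy demonstration: a healthy layer vs. one that ignores its input.
class Healthy(torch.nn.Module):
    def forward(self, x):
        return (torch.tanh(x),)

class Collapsed(torch.nn.Module):
    def forward(self, x):
        return (torch.ones_like(x),)  # same output for every input

print(has_collapsed(Healthy()))    # False
print(has_collapsed(Collapsed()))  # True
```

The same check could be run on the real question encoder between epochs to pinpoint when the collapse happens.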

I am curious whether this type of behavior rings any bells. One hunch I have is that mixed-precision training might be the cause. Any direction or feedback would be greatly appreciated before I take the plunge and dig further.
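On the mixed-precision hunch: one quick sanity check (a hedged sketch, not a confirmed diagnosis) is to scan the question encoder's parameters for NaN/Inf values, since an fp16 overflow can corrupt weights and make every forward pass produce the same degenerate output. The `nonfinite_params` helper below is hypothetical; the `Linear` layer just demonstrates the scan.

```python
import torch

def nonfinite_params(module):
    """List parameter names containing NaN or Inf values."""
    return [name for name, p in module.named_parameters()
            if not torch.isfinite(p).all()]

# Toy check: corrupt one weight and confirm the scan catches it.
layer = torch.nn.Linear(4, 4)
print(nonfinite_params(layer))  # []
with torch.no_grad():
    layer.weight[0, 0] = float("nan")
print(nonfinite_params(layer))  # ['weight']
```

In the actual setup this would be run as `nonfinite_params(model.rag.question_encoder)` after each epoch.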

Thank you!


Hi! There's some discussion about this at Retrieval Collapse when fine-tuning RAG · Issue #9405 · huggingface/transformers · GitHub.
Apparently it can happen in some setups.


This is it! Thank you.