Carrying Gradients Through Generate

Hi folks,

How would you recommend that I pass gradients through `generate()`? Below is a rough code snippet explaining the objective.

I am thinking that I could take the hypo_ids directly from the model output (instead of from `generate()`), but then they are no longer natural generations, because those outputs are produced with teacher forcing.
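Here is a minimal sketch of one way to keep the hypotheses tied to the model output and still differentiable. Instead of taking argmax ids, it weights the embedding table by the softmax over the (teacher-forced) logits and compares the resulting "soft" embeddings with the target embeddings. This soft-token relaxation is a swapped-in technique and was not proposed in the thread. Everything here is hypothetical: random `logits` and a fresh `nn.Embedding` stand in for BART's decoder logits and token embeddings, and mean pooling stands in for a real sentence encoder. It still relies on teacher-forced logits, so it inherits the concern above.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
vocab_size, emb_dim, seq_len = 50, 16, 6
torch.manual_seed(0)

embedding = torch.nn.Embedding(vocab_size, emb_dim)            # stands in for the model's token embeddings
logits = torch.randn(seq_len, vocab_size, requires_grad=True)  # stands in for teacher-forced decoder logits
target_ids = torch.randint(0, vocab_size, (seq_len,))

# Instead of argmax -> ids -> embeddings (non-differentiable), take the
# probability-weighted average of the embedding table at each position.
probs = F.softmax(logits, dim=-1)           # (seq_len, vocab_size)
soft_hypo_emb = probs @ embedding.weight    # (seq_len, emb_dim)
target_emb = embedding(target_ids)          # (seq_len, emb_dim)

# Mean-pool over positions, then take the cosine similarity of the pooled vectors.
semsim_loss = 1 - F.cosine_similarity(soft_hypo_emb.mean(0), target_emb.mean(0), dim=0)
semsim_loss.backward()
```

Every operation from `logits` to `semsim_loss` is differentiable, so `logits.grad` is populated after `backward()`. In a real model the gradient would continue into the decoder.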

Thoughts?

Context from my PyTorch Lightning implementation:


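```python
import torch.nn as nn
import pytorch_lightning as pl
from transformers import BartForConditionalGeneration


class LitBart(pl.LightningModule):  # module name is illustrative
    def __init__(self):
        super().__init__()
        self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    def forward(self, batch):
        return self.model(
            input_ids=batch["x"],
            decoder_input_ids=batch["decoder_inputs"],
            labels=batch["decoder_labels"],
        )

    def training_step(self, batch, batch_idx):
        """Want two losses: a language modelling loss and a semantic similarity loss."""
        # language modelling loss (first output when `labels` is passed)
        language_modelling_loss = self(batch)[0]

        # semantic similarity loss
        target_ids = batch["target_ids"]
        hypo_ids = self.model.generate(batch["x"])  # no gradients pass through generate(), of course
        # rough placeholder: assumes equal-length id tensors; raw ids are not a real semantic representation
        semsim_loss = 1 - nn.CosineSimilarity(dim=0)(target_ids.float(), hypo_ids.float())

        return {"loss": language_modelling_loss + semsim_loss}
```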
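The comment on the `generate()` line can be checked in isolation. Here is a minimal sketch that assumes greedy decoding, which picks the highest-scoring token at each step. The ids come out of an `argmax`, which returns an integer tensor detached from the autograd graph. The random `logits` are hypothetical stand-ins for decoder scores. As far as I know, `generate()` in Transformers also runs under `torch.no_grad()`.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)  # hypothetical decoder logits: 4 steps, vocab of 10

# Greedy decoding picks each token with argmax.
hypo_ids = logits.argmax(dim=-1)

# The result is an integer tensor that is not attached to the autograd graph,
# so no loss computed from it can send gradients back to `logits`.
print(hypo_ids.dtype, hypo_ids.requires_grad, hypo_ids.grad_fn)
```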
    
    

EDIT: The only approach seems to be to use RL to handle the non-differentiable sampling that happens during generation.

See https://papers.nips.cc/paper/8682-training-language-gans-from-scratch.pdf
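Here is a minimal sketch of the policy-gradient idea using plain REINFORCE, the simplest such estimator. The linked paper's full setup is more involved. The sketch samples discrete tokens, scores them with a non-differentiable reward, and scales the sample's log-probability by that reward. The per-step `logits`, `target_ids`, and the match-fraction `reward_fn` are hypothetical placeholders. With BART, the reward could be the semantic similarity score from the original post.

```python
import torch

torch.manual_seed(0)
seq_len, vocab_size = 5, 20
# Hypothetical per-step logits standing in for a decoder's outputs.
logits = torch.randn(seq_len, vocab_size, requires_grad=True)
target_ids = torch.randint(0, vocab_size, (seq_len,))

def reward_fn(sampled_ids, target_ids):
    # Hypothetical non-differentiable reward: fraction of positions matching the target.
    return (sampled_ids == target_ids).float().mean()

dist = torch.distributions.Categorical(logits=logits)
sampled_ids = dist.sample()                  # discrete sample, no gradient flows through it
reward = reward_fn(sampled_ids, target_ids)  # any black-box score could be used here
baseline = 0.0                               # a running-average baseline would reduce variance

# REINFORCE: the gradient of -(R - b) * log p(sample) is an unbiased estimate of -grad E[R].
log_prob = dist.log_prob(sampled_ids).sum()
pg_loss = -(reward - baseline) * log_prob
pg_loss.backward()
```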

@yjernite is also interested in this line of work.
I would write a method similar to ParlAI’s `decode_forced`, which forces the model to decode the target (tgt) sequence and estimates its probability, and then backprop through the summed log-probability of the ground-truth (GT) sequence. I’m not sure whether that will give results very similar to the current teacher-forcing training approach, but it would be interesting to test!
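Here is a minimal sketch of what such a forced-decoding score could look like. It assumes decoder logits from a teacher-forced pass over the target. This is not ParlAI's actual `decode_forced`. The helper `forced_sequence_log_prob`, the toy tensors, and `pad_id = 0` are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab_size, pad_id = 2, 4, 10, 0

# Hypothetical decoder logits from a teacher-forced pass over the target:
# logits[:, t] scores token t given the target prefix tgt[:, :t].
logits = torch.randn(batch, seq_len, vocab_size, requires_grad=True)
tgt = torch.tensor([[5, 3, 7, 2], [4, 9, 0, 0]])  # second row is padded with pad_id

def forced_sequence_log_prob(logits, tgt, pad_id):
    """Sum of log p(tgt_t | tgt_<t) for each sequence, ignoring padding."""
    log_probs = F.log_softmax(logits, dim=-1)
    token_lp = log_probs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1)  # (batch, seq_len)
    mask = (tgt != pad_id).float()
    return (token_lp * mask).sum(dim=-1)                            # (batch,)

seq_lp = forced_sequence_log_prob(logits, tgt, pad_id)
loss = -seq_lp.mean()  # maximise the probability of the GT sequences
loss.backward()
```

Written this way, the objective is the teacher-forced cross-entropy summed per sequence instead of averaged per token. That suggests the results may indeed end up close to the current training approach.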


I just tried a simple feed-forward network (FFNN) to approximate argmax, but found that the gradients are almost always zero. I guess that makes sense: argmax is piecewise constant, so small changes to the input values almost never change which entry is the maximum.
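One common workaround for this zero-gradient problem is the straight-through Gumbel-softmax estimator, swapped in here rather than taken from the thread. The forward pass uses a hard one-hot sample, while the backward pass uses the gradient of the soft relaxation. Below is a minimal sketch with PyTorch's `torch.nn.functional.gumbel_softmax`. The logits, embedding table, and target vector are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 8, requires_grad=True)  # hypothetical scores: 3 steps, vocab of 8
embedding = torch.randn(8, 4)                   # hypothetical embedding table
target = torch.randn(4)                         # hypothetical target sentence vector

# hard=True: the forward pass returns a one-hot vector (argmax of the Gumbel-perturbed logits),
# while the backward pass uses the soft sample's gradient (straight-through estimator).
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
picked_emb = one_hot @ embedding                # (3, 4): embeddings of the picked tokens

loss = 1 - F.cosine_similarity(picked_emb.mean(0), target, dim=0)
loss.backward()
```

The resulting gradient is biased, because it is the soft sample's gradient and not the true one. The temperature `tau` roughly trades that bias off against variance.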

This should also be interesting: Big `generate()` refactor

Hello,
I’m trying to do something similar. Did you manage to get something working?