Cache T5 encoder results within batch when training

My batches are a little unorthodox: each one consists of the same input paired with different targets. Basically, I moved the prefixes T5 uses from the input to the target and trained the model that way, like this:

input: “I am Alex, 33 years old”, target: “name: Alex”
input: “I am Alex, 33 years old”, target: “age: 33”
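A plain-Python sketch of how pairs like these could be collapsed in a collator so that each distinct input appears only once, with an index recording which input every target belongs to. The helper name `group_by_input` is hypothetical, not part of any library.

```python
from typing import Dict, List, Sequence, Tuple


def group_by_input(
    pairs: Sequence[Tuple[str, str]],
) -> Tuple[List[str], List[str], List[int]]:
    """Collapse (input, target) pairs into the unique inputs, the targets,
    and, for each target, the index of the unique input it belongs to."""
    unique_inputs: List[str] = []
    index_of: Dict[str, int] = {}  # input text -> position in unique_inputs
    targets: List[str] = []
    group: List[int] = []
    for source, target in pairs:
        if source not in index_of:
            index_of[source] = len(unique_inputs)
            unique_inputs.append(source)
        targets.append(target)
        group.append(index_of[source])
    return unique_inputs, targets, group


pairs = [
    ("I am Alex, 33 years old", "name: Alex"),
    ("I am Alex, 33 years old", "age: 33"),
]
inputs, targets, group = group_by_input(pairs)
# inputs  -> ["I am Alex, 33 years old"]
# targets -> ["name: Alex", "age: 33"]
# group   -> [0, 0]
```

The unique inputs would then be tokenized once and the `group` list turned into an index tensor for the model side.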

It would make training much faster if I could compute the encoder output just once per batch and reuse it for every target in the batch. So: one forward pass through the encoder, but one decoder pass (and loss) per target, with all of their gradients flowing back through that shared encoder output. Is there a library-supported way to do this, or do I need to go completely custom and write the training loop from scratch?
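A minimal sketch of the idea, assuming Hugging Face `transformers` with PyTorch. `T5ForConditionalGeneration`'s forward accepts precomputed `encoder_outputs`, and when they are given it skips its own encoder call. The function name `grouped_t5_loss`, the `group` index tensor, the token ids, and the tiny random config are illustrative assumptions, not an existing library feature. Each unique input is encoded once, its hidden states are gathered for every target, and a single `backward()` sends the gradients of all targets back through the shared encoder output.

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput


def grouped_t5_loss(model, input_ids, attention_mask, labels, group):
    """Loss for N targets that share B unique inputs (hypothetical helper).

    input_ids, attention_mask: (B, src_len), each distinct input once
    labels: (N, tgt_len), padding positions set to -100
    group: (N,) long tensor, group[i] = row of input_ids that target i uses
    """
    # One encoder forward pass per unique input.
    encoder_out = model.get_encoder()(
        input_ids=input_ids, attention_mask=attention_mask, return_dict=True
    )
    # Gather each target's encoder states. Indexing is differentiable, so the
    # gradients from all N targets are summed back into the shared output.
    hidden = encoder_out.last_hidden_state[group]
    mask = attention_mask[group]
    # Because encoder_outputs is supplied, only the decoder runs here.
    out = model(
        encoder_outputs=BaseModelOutput(last_hidden_state=hidden),
        attention_mask=mask,
        labels=labels,
        return_dict=True,
    )
    return out.loss


torch.manual_seed(0)
# Tiny randomly initialised T5 (assumed sizes) so nothing has to be downloaded.
config = T5Config(vocab_size=64, d_model=32, d_kv=8, d_ff=64,
                  num_layers=2, num_heads=4, decoder_start_token_id=0)
model = T5ForConditionalGeneration(config).eval()  # eval: dropout off

input_ids = torch.tensor([[5, 6, 7, 8, 1]])          # one (made-up) input
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([[10, 11, 1], [12, 13, 1]])    # two (made-up) targets
group = torch.tensor([0, 0])                         # both targets use input 0

loss = grouped_t5_loss(model, input_ids, attention_mask, labels, group)
loss.backward()  # one backward pass covers both targets
print(float(loss))
```

Wiring this into the built-in `Trainer` would presumably mean overriding `compute_loss` and using a collator that emits `group`; that is an assumption, not a documented option. Note that in training mode the encoder dropout is sampled once per input rather than once per target, so this is not bit-for-bit identical to feeding the duplicated inputs.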
