Longformer for Encoder-Decoder with gradient checkpointing

I’m struggling to find the right transformers class for my task.
I want to solve a seq2seq problem with an encoder-decoder Longformer. I generated one from this German RoBERTa model using this script.
I know that I could use EncoderDecoderModel(), but the issue is that it doesn’t support gradient checkpointing, which I desperately need; without it the model won’t fit on my machine.
And if I understand it correctly, LEDModel() only takes an already built encoder-decoder model, not a plain Longformer to chain together, so that is also not an option.
I thought about initializing two separate Longformers for encoder and decoder with LongformerModel(), but then I don’t know how to glue them together. Can someone explain how that would work?
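
For reference, the usual "glue" is the EncoderDecoderModel wrapper itself, which accepts two already instantiated models. A minimal sketch with placeholder checkpoint paths; note the assumption that the decoder is a RoBERTa-style causal LM, since LongformerModel has no cross-attention layers and (as far as I can tell) cannot act as the decoder:

```python
from transformers import (
    EncoderDecoderModel,
    LongformerModel,
    RobertaConfig,
    RobertaForCausalLM,
)

# Placeholder paths -- substitute the converted Longformer and a German
# RoBERTa checkpoint of your choice.
encoder = LongformerModel.from_pretrained("path/to/german-longformer")

decoder_config = RobertaConfig.from_pretrained(
    "path/to/german-roberta",
    is_decoder=True,           # run the decoder stack causally
    add_cross_attention=True,  # attend to the encoder's hidden states
)
decoder = RobertaForCausalLM.from_pretrained(
    "path/to/german-roberta", config=decoder_config
)

# EncoderDecoderModel wires the decoder's cross-attention to the encoder.
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
```
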
Or does anyone have other suggestions for how I can solve this problem?
Thank you very much!

I found a solution that at least helps a little:
When using EncoderDecoderModel(), it is possible to enable gradient checkpointing at least on the encoder part:
model.encoder.config.gradient_checkpointing = True
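
Putting it together, a minimal end-to-end sketch of that workaround, assuming placeholder checkpoint paths and an older transformers release that reads the flag from the sub-model’s config (newer releases expose gradient_checkpointing_enable() on the model instead):

```python
from transformers import EncoderDecoderModel

# Placeholder paths: the converted Longformer as encoder, a RoBERTa-style
# checkpoint as decoder. from_encoder_decoder_pretrained sets is_decoder
# and add_cross_attention on the decoder automatically.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "path/to/german-longformer",
    "path/to/german-roberta",
)

# Gradient checkpointing on the encoder only: older transformers versions
# read this flag from the encoder's config at forward time.
model.encoder.config.gradient_checkpointing = True

# On newer releases the equivalent call would be:
# model.encoder.gradient_checkpointing_enable()
```
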
