I’m struggling to find the right transformers class for my task.
I want to solve a seq2seq problem with an encoder-decoder Longformer. I generated one from this German RoBERTa model using this script.
I know that I could use EncoderDecoderModel(), but it doesn’t support gradient checkpointing, which I desperately need; without it, the model won’t fit on my machine.
And if I understand it correctly, the class LEDModel() only takes an already-built encoder-decoder model, not a plain Longformer to chain together, so that is also not an option.
I thought about initializing two separate Longformers for the encoder and decoder with LongformerModel(), but then I don’t know how to glue them together. Can someone explain how that works?
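One way to glue two separately built models together is EncoderDecoderModel itself, which accepts any encoder and any decoder-capable model. Note that transformers has no causal-LM head for Longformer, so a Longformer cannot serve as the decoder; a common workaround is a Longformer encoder paired with a RoBERTa decoder. A minimal sketch, using tiny randomly initialized configs so it runs without downloading any weights (with real checkpoints you would call EncoderDecoderModel.from_encoder_decoder_pretrained(encoder_name, decoder_name) instead):

```python
import torch
from transformers import (
    EncoderDecoderConfig,
    EncoderDecoderModel,
    LongformerConfig,
    RobertaConfig,
)

# Tiny, randomly initialized configs purely for illustration.
enc_cfg = LongformerConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
    max_position_embeddings=512, attention_window=8,
)
dec_cfg = RobertaConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
    max_position_embeddings=512,
    is_decoder=True,           # causal self-attention
    add_cross_attention=True,  # attend to the encoder outputs
)

# Glue the two configs into one seq2seq config and build the model.
config = EncoderDecoderConfig.from_encoder_decoder_configs(enc_cfg, dec_cfg)
model = EncoderDecoderModel(config=config)

ids = torch.tensor([[2, 3, 4, 5, 6]])
out = model(input_ids=ids, decoder_input_ids=ids)
print(out.logits.shape)  # torch.Size([1, 5, 100])
```

With pretrained checkpoints, from_encoder_decoder_pretrained performs the same wiring (including adding randomly initialized cross-attention layers to the decoder), so the composed model still needs seq2seq fine-tuning.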
Or does anyone have another suggestion for how I can solve this problem?
Thank you very much!
I found a solution that at least helps a little:
When using EncoderDecoderModel(), it is possible to enable gradient checkpointing at least on the encoder part:
model.encoder.config.gradient_checkpointing = True
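A fuller sketch of that workaround, again on a tiny randomly initialized Longformer-encoder/RoBERTa-decoder model so it runs without downloading anything (with real checkpoints you would build the model with EncoderDecoderModel.from_encoder_decoder_pretrained(...) first):

```python
import torch
from transformers import (
    EncoderDecoderConfig, EncoderDecoderModel,
    LongformerConfig, RobertaConfig,
)

# Tiny illustrative configs; real usage would load pretrained weights.
config = EncoderDecoderConfig.from_encoder_decoder_configs(
    LongformerConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                     num_attention_heads=2, intermediate_size=64,
                     max_position_embeddings=512, attention_window=8),
    RobertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                  num_attention_heads=2, intermediate_size=64,
                  max_position_embeddings=512),
)
model = EncoderDecoderModel(config=config)

# Enable gradient checkpointing on the encoder only: the config flag
# from older transformers releases.
model.encoder.config.gradient_checkpointing = True
# Newer releases expose a method instead:
# model.encoder.gradient_checkpointing_enable()

# Training proceeds as usual; with checkpointing active, the encoder's
# activations are recomputed during backward instead of being stored.
ids = torch.tensor([[2, 3, 4, 5, 6]])
loss = model(input_ids=ids, decoder_input_ids=ids, labels=ids).loss
loss.backward()
```

This trades extra compute in the backward pass for a large reduction in activation memory, which is usually the binding constraint for long-input models like Longformer.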