There is no need to change the Hugging Face (HF) code; you only need to restructure your own code:
class SampleModule(pl.LightningModule):
    """Lightning wrapper around a pretrained BART conditional-generation model.

    The Hugging Face model is used unmodified: batch preparation belongs in
    ``training_step`` and the model call itself goes through ``forward``.
    """

    def __init__(self, kwargs):
        """Load the pretrained model named by ``kwargs.arch``.

        Args:
            kwargs: Configuration object exposing an ``arch`` attribute, which
                is passed to ``BartForConditionalGeneration.from_pretrained``.
        """
        # The base Module initializer must run before a submodule can be
        # assigned as an attribute.
        super().__init__()
        # initialize stuff
        self.model = BartForConditionalGeneration.from_pretrained(kwargs.arch)
        # more stuff

    def forward(self, kwargs):
        """Run the wrapped model on ``kwargs.batch`` and return its output."""
        return self.model(kwargs.batch)

    def training_step(self, kwargs):
        """Prepare a batch and send it through ``forward`` (placeholder).

        Raises:
            NotImplementedError: Always, until the preparation steps described
                in the comments below are filled in.
        """
        # do all stuff with shifting decoder inputs etc here
        # then call self() as your forward method
        # NOTE(review): Lightning normally invokes training_step(batch,
        # batch_idx) -- adapt the signature when filling this in; confirm
        # against the Lightning version in use.
        raise NotImplementedError(
            "prepare the decoder inputs, then call self(...) to run forward()"
        )