Self-made Longformer doesn't take more than 512 tokens

I converted a German RoBERTa model into a Longformer with a capacity of 4096 tokens using this script. I then use it with LEDModel(). But when I tokenize the data with max_length=4096, I get the following error when printing the loss or calling loss.backward():

RuntimeError: CUDA error: device-side assert triggered

I understand that this usually means some index is outside the expected range of an embedding table, i.e. the ids exceed the number of possible tokens (read that here). When I run the same code with max_length=512 it works. Could it be that the conversion to a Longformer didn't pick up the longer possible input sequences? But if so, why would that show up as a problem with the size of the vocabulary? There is a snippet of my code below.
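A check along these lines on the CPU should make the failing lookup visible, since there an out-of-range index raises a plain IndexError instead of the asynchronous device-side assert. This is only a debugging sketch, not part of my training code; the path and the example text are placeholders:

# rough CPU sanity check (assumption: the assert is an out-of-range embedding index)
import torch
from transformers import LEDModel, RobertaTokenizer

model_path = "path/to/my-converted-longformer"  # placeholder
model = LEDModel.from_pretrained(model_path)
tokenizer = RobertaTokenizer.from_pretrained(model_path)

enc = tokenizer("Ein langer Beispieltext ...", max_length=4096,
                padding="max_length", truncation=True, return_tensors="pt")

print("max input id:", enc["input_ids"].max().item(),
      "| rows in word embedding:", model.get_input_embeddings().weight.shape[0])
print("sequence length:", enc["input_ids"].shape[1],
      "| max_encoder_position_embeddings:", model.config.max_encoder_position_embeddings)

# on CPU the forward pass raises a readable IndexError instead of the device-side assert
with torch.no_grad():
    model(input_ids=enc["input_ids"], decoder_input_ids=enc["input_ids"][:, :512])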

Edit:
I think I know what's wrong: LEDModel() expects an encoder-decoder model, not just a single transformer/Longformer, and this class doesn't chain two transformers together to create an encoder-decoder on its own (see the config comparison below).
So the question is: how can I use my Longformer for the Seq2Seq task in a way that still lets me utilize gradient checkpointing?
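For illustration only (not part of my training code), comparing the default configs shows the mismatch: LED carries a separate decoder stack that my converted RoBERTa/Longformer checkpoint doesn't provide.

# illustrative only: compare default config fields of LED vs. Longformer
from transformers import LEDConfig, LongformerConfig

led_only = set(LEDConfig().to_dict()) - set(LongformerConfig().to_dict())
print(sorted(k for k in led_only if "decoder" in k))
# decoder-specific fields such as 'decoder_layers' and 'decoder_attention_heads'
# show that LED expects weights for a separate decoder

My current training code follows.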

import sys

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from tqdm import trange
from transformers import LEDModel, RobertaTokenizer, get_linear_schedule_with_warmup

MAX_LEN = 4096
batch_size = 1
epochs = 1
max_grad_norm = 0.06
weight_decay = 0.3
FULL_FINETUNING = True
learning_rate = 1e-5
adam_eps = 1e-10

torch.cuda.empty_cache()
model_path = sys.argv[2]
tokenizer = RobertaTokenizer.from_pretrained(model_path)
tokenizer.pad_token = 0  # note: pad_token normally expects a token string like '<pad>', not an id

def train(train_dataloader, valid_dataloader, test_dataloader, voc_id_to_tok):
    model = LEDModel.from_pretrained(model_path)

    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.vocab_size = len(voc_id_to_tok)
    print("vocab size", model.config.vocab_size)
    model.config.use_cache = False
    model.gradient_checkpointing_enable()
    model.resize_token_embeddings(len(tokenizer))

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model.to(device)

    if FULL_FINETUNING:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [
            # note: the param-group key AdamW reads is 'weight_decay', not 'weight_decay_rate'
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
    else:
        param_optimizer = list(model.classifier.named_parameters())
        optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=learning_rate,
        eps=adam_eps
    )

    scaler = GradScaler()

    # Total number of training steps is number of batches * number of epochs.
    total_steps = len(train_dataloader) * epochs

    # Create the learning rate scheduler.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    loss_values = []  # average training loss per epoch, for plotting the learning curve

    for epoch in trange(epochs, desc="Epoch"):
        # ========================================
        #               Training
        # ========================================
        # Perform one full pass over the training set.

        model.train()
        total_loss = 0

        # Training loop
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)

            b_input_ids, b_output_ids = batch  # b_input_mask, b_output_ids = batch

            model.zero_grad()
            with autocast():
                outputs = model(input_ids=b_input_ids, decoder_input_ids=b_output_ids)

                # with LEDModel and no labels, outputs[0] is the decoder's last hidden state, not a loss
                loss = outputs[0].mean()

            scaler.scale(loss).backward()
            total_loss += loss.item()

            # unscale gradients before clipping so the norm is computed on the real values
            scaler.unscale_(optimizer)
            clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)

            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        avg_train_loss = total_loss / len(train_dataloader)
        print("Average train loss: {}".format(avg_train_loss))

        # Store the loss value for plotting the learning curve.
        loss_values.append(avg_train_loss)