PyTorch Transformers: TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'

I want to modify this model so that it uses xlm-roberta as the encoder and gpt-neo as the decoder.

At the beginning of training I get the following error:

Traceback (most recent call last):
  File "run.py", line 277, in <module>
    task()
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "run.py", line 166, in train
    train_model(start_epoch, eval_loss, (train_dl, valid_dl), optimizer, kwargs['base_path']+kwargs["checkpoint_path"], kwargs['base_path']+kwargs["best_model"])
  File "run.py", line 233, in train_model
    labels = labels.to(device))[0]
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py", line 628, in forward
    **kwargs_decoder,
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'

I think the problem is somewhere in the call below, but I am not sure, since the last line of the traceback references the forward() function, which isn't explicitly part of my model.

loss = model(input_ids = src_tensors.to(device),
             decoder_input_ids = tgt_tensors.to(device),
             attention_mask = src_attn_tensors.to(device),
             decoder_attention_mask = tgt_attn_tensors.to(device),
             labels = labels.to(device))[0]

How can I “access” the forward() function, and where exactly is it being called? I know the rest of the code works, because I tried it with the original configuration (BERT as encoder and GPT-2 as decoder).
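
As far as I can tell from the traceback, model(...) goes through nn.Module.__call__, which dispatches to EncoderDecoderModel.forward (the modeling_encoder_decoder.py frame), and that forward then calls self.decoder(...) with encoder_hidden_states among its keyword arguments. A minimal sketch of how I would inspect this, assuming model is the EncoderDecoderModel from the setup below:

import inspect
from transformers import EncoderDecoderModel

# Where the forward() from the traceback is defined (should match the modeling_encoder_decoder.py frame)
print(inspect.getsourcefile(EncoderDecoderModel.forward))

# Does the wrapped decoder's forward() accept the argument EncoderDecoderModel passes to it?
# False here would explain the TypeError above.
print('encoder_hidden_states' in inspect.signature(model.decoder.forward).parameters)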

This is the model setup:

model = EncoderDecoderModel.from_encoder_decoder_pretrained('xlm-roberta-base', 'EleutherAI/gpt-neo-1.3B', tie_encoder_decoder=True)
model.decoder.config.use_cache = False
tokenizer = Tokenizer(max_token_len)
model.config.decoder_start_token_id = tokenizer.autotokenizer.bos_token_id
model.config.eos_token_id = tokenizer.autotokenizer.eos_token_id
model.config.max_length = max_token_len
model.config.no_repeat_ngram_size = 3
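
As a sanity check (a sketch I put together for this question, not part of the original training code), I can compare the keyword arguments accepted by the GPT-Neo decoder's forward() with those of the GPT-2 decoder from the configuration that worked:

import inspect
from transformers import AutoModelForCausalLM

# 'gpt2' stands in for the decoder of the original, working BERT + GPT-2 setup
gpt2_decoder = AutoModelForCausalLM.from_pretrained('gpt2')
for name, decoder in [('gpt2', gpt2_decoder), ('gpt-neo', model.decoder)]:
    accepted = inspect.signature(decoder.forward).parameters
    print(name, 'encoder_hidden_states' in accepted)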

Training:

def train(**kwargs):
    print("Loading datasets...")
    train_dataset = WikiDataset(kwargs['base_path']+kwargs['src_train'], kwargs['base_path']+kwargs['tgt_train'])
    valid_dataset = WikiDataset(kwargs['base_path']+kwargs['src_valid'], kwargs['base_path']+kwargs['tgt_valid'], kwargs['base_path']+kwargs['ref_valid'], ref=True)
    print("Dataset loaded successfully")

    train_dl = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
    valid_dl = DataLoader(valid_dataset, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.0}
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=3e-5)
    
    if os.path.exists(kwargs['base_path']+kwargs["checkpoint_path"]):
        optimizer, eval_loss, start_epoch = load_checkpt(kwargs['base_path']+kwargs["checkpoint_path"], optimizer)
        print(f"Loading model from checkpoint with start epoch: {start_epoch} and loss: {eval_loss}")
        logging.info(f"Model loaded from saved checkpoint with start epoch: {start_epoch} and loss: {eval_loss}")
    
    train_model(start_epoch, eval_loss, (train_dl, valid_dl), optimizer, kwargs['base_path']+kwargs["checkpoint_path"], kwargs['base_path']+kwargs["best_model"])



def train_model(start_epoch, eval_loss, loaders, optimizer, check_pt_path, best_model_path):
    best_eval_loss = eval_loss
    print("Model training started...")
    for epoch in range(start_epoch, N_EPOCH):
        print(f"Epoch {epoch} running...")
        epoch_start_time = time.time()
        epoch_train_loss = 0
        epoch_eval_loss = 0
        model.train()
        for step, batch in enumerate(loaders[0]):
            src_tensors, src_attn_tensors, tgt_tensors, tgt_attn_tensors, labels = tokenizer.encode_batch(batch)
            optimizer.zero_grad()
            model.zero_grad()
            loss = model(input_ids = src_tensors.to(device),
                         decoder_input_ids = tgt_tensors.to(device),
                         attention_mask = src_attn_tensors.to(device),
                         decoder_attention_mask = tgt_attn_tensors.to(device),
                         labels = labels.to(device))[0]
            if step == 0:
                epoch_train_loss = loss.item()
            else:
                epoch_train_loss = (1/2.0)*(epoch_train_loss + loss.item())
            
            loss.backward()
            optimizer.step()

            if (step+1) % LOG_EVERY == 0:
                print(f'Epoch: {epoch} | iter: {step+1} | avg. train loss: {epoch_train_loss} | time elapsed: {time.time() - epoch_start_time}')
                logging.info(f'Epoch: {epoch} | iter: {step+1} | avg. train loss: {epoch_train_loss} | time elapsed: {time.time() - epoch_start_time}')
        
        eval_start_time = time.time()
        epoch_eval_loss, bleu_score, sari_score = evaluate(loaders[1], epoch_eval_loss)
        epoch_eval_loss = epoch_eval_loss/TRAIN_BATCH_SIZE
        print(f'Completed Epoch: {epoch} | avg. eval loss: {epoch_eval_loss:.5f} | BLEU score: {bleu_score} | SARI score: {sari_score} | time elapsed: {time.time() - eval_start_time}')
        logging.info(f'Completed Epoch: {epoch} | avg. eval loss: {epoch_eval_loss:.5f} | BLEU score: {bleu_score} | SARI score: {sari_score} | time elapsed: {time.time() - eval_start_time}')

        check_pt = {
            'epoch': epoch+1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'eval_loss': epoch_eval_loss,
            'sari_score': sari_score,
            'bleu_score': bleu_score
        }
        check_pt_time = time.time()
        print("Saving Checkpoint.......")
        if epoch_eval_loss < best_eval_loss:
            print("New best model found")
            logging.info(f"New best model found")
            best_eval_loss = epoch_eval_loss
            save_model_checkpt(check_pt, True, check_pt_path, best_model_path)
        else:
            save_model_checkpt(check_pt, False, check_pt_path, best_model_path)  
        print(f"Checkpoint saved successfully with time: {time.time() - check_pt_time}")
        logging.info(f"Checkpoint saved successfully with time: {time.time() - check_pt_time}")

        gc.collect()
        torch.cuda.empty_cache()