Getting the following error: "ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds"

@nielsr I am not sure about this as I am just a beginner; however, I did manage to get past that error, and now I am getting a new one about an index being out of range.
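For context, that first ValueError is what a T5 AutoModelForSeq2SeqLM raises when its forward pass only receives encoder inputs. A minimal standalone sketch of the two usual ways it is avoided (this uses the public google-t5/t5-base checkpoint and placeholder text rather than my fine-tuned model, so treat it as an illustration only):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")

enc = tokenizer("summarize: some long article text ...", return_tensors="pt")

# Option 1: pass labels, so T5 builds decoder_input_ids internally (training-style forward)
labels = tokenizer("a short summary", return_tensors="pt").input_ids
out = model(**enc, labels=labels)  # no ValueError; out.loss is the LM loss

# Option 2: use generate(), which creates the decoder inputs itself (inference)
summary_ids = model.generate(**enc, max_new_tokens=60)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))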

Code:

import random
import numpy as np
import torch
import pandas as pd
import json
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from trl import RewardTrainer, SFTTrainer
from datasets import Dataset
from transformers import (
    Trainer,
    TrainingArguments,
    default_data_collator,
    DataCollatorForLanguageModeling
)

device="cuda"
df = pd.read_parquet("/raid/ganesh/vishak/pranav/ss2113/train_rlhf.parquet")
model = AutoModelForSeq2SeqLM.from_pretrained("summarization_policy_new/").to(device)
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base", truncation=True, max_length=256, padding="max_length")
text = df.iloc[2]["prompt"]
tokenized_text = tokenizer(text, return_tensors="pt", max_length=256).to(device)

tokenizer.decode(model.generate(**tokenized_text)[0])

# Build the reward-model preference dataset from a small slice of the data
df = pd.read_parquet("/raid/ganesh/vishak/pranav/ss2113/test_summ.parquet")
df = df[:10]
raw_dataset = Dataset.from_pandas(df)
raw_dataset

# Reload the tokenizer and policy, and align the pad/eos token settings
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base", truncation=True, max_length=256, padding="max_length")
model = AutoModelForSeq2SeqLM.from_pretrained("/raid/ganesh/vishak/pranav/ss2113/summarization_policy_new").to(device)
tokenizer.pad_token = tokenizer.eos_token
model.resize_token_embeddings(len(tokenizer))
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.end_token_id = tokenizer.eos_token_id
model.config.pad_token_id = model.config.eos_token_id

# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# def formatting_func(examples):
#     kwargs = {"padding": "max_length",
#               "truncation": True,
#               "max_length": 256,
#               "return_tensors": "pt"
#               }
#     prompt_plus_chosen_response = examples["prompt"] + "\n" + examples["chosen"]
#     prompt_plus_rejected_response = examples["prompt"] + "\n" + examples["rejected"]
#     tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
#     tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)
#     # Generate summaries for chosen and rejected responses
#     generated_chosen = model.generate(input_ids=tokens_chosen["input_ids"].to(device), attention_mask=tokens_chosen["attention_mask"].to(device), decoder_input_ids=tokens_chosen["input_ids"].to(device))
#     generated_rejected = model.generate(input_ids=tokens_rejected["input_ids"].to(device), attention_mask=tokens_rejected["attention_mask"].to(device), decoder_input_ids=tokens_rejected["input_ids"].to(device))

#     # Decode the generated summaries
#     decoded_chosen = tokenizer.decode(generated_chosen[0], skip_special_tokens=True)
#     decoded_rejected = tokenizer.decode(generated_rejected[0], skip_special_tokens=True)
#     return {
#         "input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0],
#         "input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0],
#         "generated_chosen": decoded_chosen, "generated_rejected": decoded_rejected
#     }

# formatted_dataset = raw_dataset.map(formatting_func)
# formatted_dataset = formatted_dataset.train_test_split()

# Tokenize prompt + chosen and prompt + rejected pairs for the reward model
def formatting_func(examples):
    kwargs = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 256,
        "return_tensors": "pt"
    }
    prompt_plus_chosen_response = examples["prompt"] + "\n" + examples["chosen"]
    prompt_plus_rejected_response = examples["prompt"] + "\n" + examples["rejected"]
    tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
    tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)

    return {
        "input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0],
        "input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0]
    }

formatted_dataset = raw_dataset.map(formatting_func)
formatted_dataset = formatted_dataset.train_test_split()

# Custom collator that batches the chosen/rejected input_ids and attention masks
def custom_collate_fn(batch):
    input_ids_chosen = torch.stack([example["input_ids_chosen"] for example in batch])
    attention_mask_chosen = torch.stack([example["attention_mask_chosen"] for example in batch])
    input_ids_rejected = torch.stack([example["input_ids_rejected"] for example in batch])
    attention_mask_rejected = torch.stack([example["attention_mask_rejected"] for example in batch])
    
    return {
        "input_ids_chosen": input_ids_chosen,
        "attention_mask_chosen": attention_mask_chosen,
        "input_ids_rejected": input_ids_rejected,
        "attention_mask_rejected": attention_mask_rejected
    }

### Set up the TrainingArguments, then build the TRL RewardTrainer and train it
training_args = TrainingArguments(
        output_dir="t5_rm_checkpoint/",
        num_train_epochs=1,
        logging_steps=10,
        gradient_accumulation_steps=1,
        save_strategy="steps",
        evaluation_strategy="steps",
        per_device_train_batch_size=2,
        per_device_eval_batch_size=1,
        eval_accumulation_steps=1,
        eval_steps=500,
        save_steps=500,
        warmup_steps=100,
        logging_dir="./logs",
        learning_rate=1e-5,
        save_total_limit=1,
        no_cuda=True  # note: this makes the Trainer run on CPU even though the model was moved to "cuda" above
    )

# trainer = RewardTrainer(model=model,
#                         tokenizer=tokenizer,
#                         train_dataset=formatted_dataset['train'],
#                         eval_dataset=formatted_dataset['test'],
#                         args= training_args
#                         )
# trainer.train()
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=formatted_dataset['train'],
    eval_dataset=formatted_dataset['test'],
    args=training_args,
    data_collator=custom_collate_fn
)
trainer.train()
trainer.save_model("t5_rm_model/")
Error:

Traceback (most recent call last):
  File "/raid/ganesh/vishak/pranav/ss2113/rlhf_reward_part2.py", line 135, in <module>
    trainer.train()
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/transformers/trainer.py", line 1870, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/accelerate/data_loader.py", line 384, in __iter__
    current_batch = next(dataloader_iter)
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2807, in __getitems__
    batch = self.__getitem__(keys)
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2803, in __getitem__
    return self._getitem(key)
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2787, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 588, in query_table
    pa_subtable = _query_table_with_indices_mapping(table, key, indices=indices)
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 75, in _query_table_with_indices_mapping
    return _query_table(table, [indices.fast_slice(i, 1).column(0)[0].as_py() for i in key])
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 100, in _query_table
    return table.fast_gather(key % table.num_rows)
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/table.py", line 134, in fast_gather
    [
  File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/table.py", line 135, in <listcomp>
    self._batches[batch_idx].slice(i - self._offsets[batch_idx], 1)
IndexError: list index out of range

(Please help with this if possible. If I understand the error correctly, it should be resolvable one way or another, even though I am using a seq2seq model.)
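In case it helps with reproducing this: the traceback shows the IndexError is raised while datasets resolves the indices mapping that train_test_split() creates. One diagnostic I am trying (a sketch, not a confirmed fix) is to check that both splits are non-empty and to materialize that mapping with flatten_indices() before handing the splits to the trainer:

print(len(formatted_dataset["train"]), len(formatted_dataset["test"]))
print(formatted_dataset["train"][0].keys())

# flatten_indices() rewrites the underlying Arrow table so the dataloader
# no longer has to go through the indices mapping at fetch time
formatted_dataset["train"] = formatted_dataset["train"].flatten_indices()
formatted_dataset["test"] = formatted_dataset["test"].flatten_indices()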