@nielsr I am not sure about this, as I am just a beginner; however, I did manage to get past that error, and now I get a new one: an "index out of range" error.
Code:
import random
import numpy as np
import torch
import pandas as pd
import json
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from trl import RewardTrainer, SFTTrainer
from datasets import Dataset
from transformers import (
Trainer,
TrainingArguments,
default_data_collator,
DataCollatorForLanguageModeling
)
# --- Load the fine-tuned summarization policy and sanity-check generation ---
device="cuda"
df = pd.read_parquet("/raid/ganesh/vishak/pranav/ss2113/train_rlhf.parquet")
model = AutoModelForSeq2SeqLM.from_pretrained("summarization_policy_new/").to(device)
# NOTE(review): truncation/max_length/padding passed to from_pretrained become
# init kwargs and are not applied per call; pass them in tokenizer(...) — verify.
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base", truncation=True, max_length=256, padding="max_length")
text = df.iloc[2]["prompt"]
tokenized_text = tokenizer(text, return_tensors="pt", max_length=256).to(device)
tokenizer.decode(model.generate(**tokenized_text)[0])  # smoke test; result discarded at script level
# --- Build the (small) preference dataset for reward-model training ---
df = pd.read_parquet("/raid/ganesh/vishak/pranav/ss2113/test_summ.parquet")
df = df[:10]  # keep only the first 10 rows
raw_dataset = Dataset.from_pandas(df)
raw_dataset  # no-op in a script (only displays in a notebook)
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base",truncation=True, max_length=256, padding="max_length")
model = AutoModelForSeq2SeqLM.from_pretrained("/raid/ganesh/vishak/pranav/ss2113/summarization_policy_new").to(device)
# NOTE(review): T5 already ships a dedicated <pad> token; aliasing pad to EOS and
# resizing embeddings is a GPT-style workaround and is presumably unnecessary
# (possibly harmful) for a T5 model — confirm before keeping these six lines.
tokenizer.pad_token = tokenizer.eos_token
model.resize_token_embeddings(len(tokenizer))
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.end_token_id = tokenizer.eos_token_id
model.config.pad_token_id = model.config.eos_token_id
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# def formatting_func(examples):
# kwargs = {"padding": "max_length",
# "truncation": True,
# "max_length": 256,
# "return_tensors": "pt"
# }
# prompt_plus_chosen_response = examples["prompt"] + "\n" + examples["chosen"]
# prompt_plus_rejected_response = examples["prompt"] + "\n" + examples["rejected"]
# tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
# tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)
# # Generate summaries for chosen and rejected responses
# generated_chosen = model.generate(input_ids=tokens_chosen["input_ids"].to(device), attention_mask=tokens_chosen["attention_mask"].to(device), decoder_input_ids=tokens_chosen["input_ids"].to(device))
# generated_rejected = model.generate(input_ids=tokens_rejected["input_ids"].to(device), attention_mask=tokens_rejected["attention_mask"].to(device), decoder_input_ids=tokens_rejected["input_ids"].to(device))
# # Decode the generated summaries
# decoded_chosen = tokenizer.decode(generated_chosen[0], skip_special_tokens=True)
# decoded_rejected = tokenizer.decode(generated_rejected[0], skip_special_tokens=True)
# return {
# "input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0],
# "input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0],
# "generated_chosen": decoded_chosen, "generated_rejected": decoded_rejected
# }
# formatted_dataset = raw_dataset.map(formatting_func)
# formatted_dataset = formatted_dataset.train_test_split()
def formatting_func(examples):
    """Tokenize prompt+chosen and prompt+rejected pairs for RewardTrainer.

    Produces fixed-length (256) input ids and attention masks for both the
    chosen and the rejected completion of a single example.
    """
    tok_kwargs = dict(
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    chosen_text = examples["prompt"] + "\n" + examples["chosen"]
    rejected_text = examples["prompt"] + "\n" + examples["rejected"]
    encodings = {
        "chosen": tokenizer.encode_plus(chosen_text, **tok_kwargs),
        "rejected": tokenizer.encode_plus(rejected_text, **tok_kwargs),
    }
    out = {}
    for label, enc in encodings.items():
        # encode_plus returned a batch of one; take the single row.
        out[f"input_ids_{label}"] = enc["input_ids"][0]
        out[f"attention_mask_{label}"] = enc["attention_mask"][0]
    return out
# Tokenize every row, then split into train/test (default 75/25 on 10 rows).
formatted_dataset = raw_dataset.map(formatting_func)
# NOTE(review): the reported IndexError comes from datasets' indices-mapping
# path right after this split (fast_gather in table.py); calling
# .flatten_indices() on the split result, or dropping unused text columns via
# map(..., remove_columns=...) before splitting, is worth trying — verify.
formatted_dataset = formatted_dataset.train_test_split()
# Define a custom collator to handle the additional input_ids for decoder_input_ids
def custom_collate_fn(batch):
input_ids_chosen = torch.stack([example["input_ids_chosen"] for example in batch])
attention_mask_chosen = torch.stack([example["attention_mask_chosen"] for example in batch])
input_ids_rejected = torch.stack([example["input_ids_rejected"] for example in batch])
attention_mask_rejected = torch.stack([example["attention_mask_rejected"] for example in batch])
return {
"input_ids_chosen": input_ids_chosen,
"attention_mask_chosen": attention_mask_chosen,
"input_ids_rejected": input_ids_rejected,
"attention_mask_rejected": attention_mask_rejected
}
### Loading the TRL reward trainer and training the trainer
# Reward-model training configuration.
training_args = TrainingArguments(
output_dir="t5_rm_checkpoint/",
num_train_epochs=1,
logging_steps=10,
gradient_accumulation_steps=1,
save_strategy="steps",
evaluation_strategy="steps",
per_device_train_batch_size=2,
per_device_eval_batch_size=1,
eval_accumulation_steps=1,
eval_steps=500,
save_steps=500,
warmup_steps=100,
logging_dir="./logs",
learning_rate=1e-5,
save_total_limit=1,
# NOTE(review): no_cuda=True forces CPU training even though the model was
# moved to "cuda" above (Trainer will move it back) — confirm intentional.
no_cuda=True
)
# trainer = RewardTrainer(model=model,
# tokenizer=tokenizer,
# train_dataset=formatted_dataset['train'],
# eval_dataset=formatted_dataset['test'],
# args= training_args
# )
# trainer.train()
# Train the reward model with TRL's RewardTrainer and the custom collator,
# then persist the final checkpoint.
trainer = RewardTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=formatted_dataset['train'],
eval_dataset=formatted_dataset['test'],
args=training_args,
data_collator=custom_collate_fn
)
trainer.train()
trainer.save_model("t5_rm_model/")
Error:
Traceback (most recent call last):
File "/raid/ganesh/vishak/pranav/ss2113/rlhf_reward_part2.py", line 135, in <module>
trainer.train()
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
return inner_training_loop(
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/transformers/trainer.py", line 1870, in _inner_training_loop
for step, inputs in enumerate(epoch_iterator):
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/accelerate/data_loader.py", line 384, in __iter__
current_batch = next(dataloader_iter)
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = self.dataset.__getitems__(possibly_batched_index)
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2807, in __getitems__
batch = self.__getitem__(keys)
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2803, in __getitem__
return self._getitem(key)
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2787, in _getitem
pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 588, in query_table
pa_subtable = _query_table_with_indices_mapping(table, key, indices=indices)
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 75, in _query_table_with_indices_mapping
return _query_table(table, [indices.fast_slice(i, 1).column(0)[0].as_py() for i in key])
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 100, in _query_table
return table.fast_gather(key % table.num_rows)
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/table.py", line 134, in fast_gather
[
File "/raid/ganesh/vishak/miniconda3/envs/perturbation-env/lib/python3.10/site-packages/datasets/table.py", line 135, in <listcomp>
self._batches[batch_idx].slice(i - self._offsets[batch_idx], 1)
IndexError: list index out of range
(Please help with this if possible. If I understand the error correctly, it should be resolvable in one way or another, even though I am using this seq2seq model.)