Hi,
I am trying to fine-tune mT5 for multitask question answering and question generation, similar to @valhalla's model. I prepared the dataset with the datasets library as follows:
train_dataset = Dataset.from_pandas(pd.DataFrame(generate_data(mode="train")))
valid_dataset = Dataset.from_pandas(pd.DataFrame(generate_data(mode="valid")))
processor = DataProcessor(
    tokenizer,
    max_source_length=data_args.max_source_length,
    max_target_length=data_args.max_target_length
)
train_dataset = processor.process(train_dataset)
valid_dataset = processor.process(valid_dataset)
columns = ["source_ids", "target_ids", "attention_mask"]
However, when I try to train the model with the command below:
!python3 run_multi.py \
--model_name_or_path google/mt5-small \
--model_type mt5 \
--tokenizer_name_or_path mt5_qg_tokenizer \
--output_dir mt5-small-multi \
--train_file_path data/train_data_qa_qg_mt5.pt \
--valid_file_path data/valid_data_qa_qg_mt5.pt \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--gradient_accumulation_steps 2 \
--learning_rate 1e-4 \
--num_train_epochs 2 \
--seed 42 \
--do_train \
--do_eval \
--logging_steps 100 \
--prediction_loss_only True
it fails with KeyError: 'source_ids'.
I am sure the dataset has a 'source_ids' field, as loading the saved file shows:
train_dataset=torch.load(r"train_data_qa_qg_mt5.pt")
train_dataset
> Dataset({
    features: ['attention_mask', 'source_ids', 'source_text', 'target_ids', 'target_text', 'task'],
    num_rows: 3449
})
What might cause this?
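One thing I have not ruled out: as far as I understand, the Trainer in transformers 4.x has a `remove_unused_columns` option (on by default) that drops columns of a `datasets.Dataset` whose names do not match arguments of the model's `forward()`. The sketch below only imitates that filtering with the standard library so the effect is easy to see. The `forward` function is a hypothetical stand-in with a shortened argument list, not the real mT5 signature, and whether this is really the cause of the error here is an assumption.

```python
import inspect


def forward(input_ids=None, attention_mask=None, decoder_input_ids=None, labels=None):
    """Hypothetical stand-in for the model's forward(); the real one takes more arguments."""


def columns_kept(dataset_columns, forward_fn):
    # Keep only the columns whose names are parameters of forward_fn,
    # plus the label names that are always allowed through.
    accepted = set(inspect.signature(forward_fn).parameters)
    accepted |= {"label", "label_ids"}
    return [c for c in dataset_columns if c in accepted]


# Column names copied from the Dataset printout above.
dataset_columns = ["attention_mask", "source_ids", "source_text",
                   "target_ids", "target_text", "task"]

kept = columns_kept(dataset_columns, forward)
dropped = [c for c in dataset_columns if c not in kept]

print("kept:", kept)
print("dropped:", dropped)
```

If this is what happens in my run, 'source_ids' would already be gone by the time the data collator looks it up, which would match the KeyError. Setting `remove_unused_columns=False` in `TrainingArguments` might then be worth trying, but I have not confirmed that yet.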
The versions of the libraries are:
transformers == 4.4.2
datasets == 1.5.0
Thank you in advance for any replies.