MT5 Fine Tuning - KeyError: 'source_ids'

Hi,

I am trying to fine-tune MT5 for multitask question answering and question generation, similar to @valhalla's model. I prepared the dataset using the datasets library as follows:

    import pandas as pd
    from datasets import Dataset

    # build HF datasets from my custom generate_data() helper
    train_dataset = Dataset.from_pandas(pd.DataFrame(generate_data(mode="train")))
    valid_dataset = Dataset.from_pandas(pd.DataFrame(generate_data(mode="valid")))

    # tokenize source/target text into source_ids, target_ids, attention_mask
    processor = DataProcessor(
        tokenizer,
        max_source_length=data_args.max_source_length,
        max_target_length=data_args.max_target_length,
    )

    train_dataset = processor.process(train_dataset)
    valid_dataset = processor.process(valid_dataset)

    columns = ["source_ids", "target_ids", "attention_mask"]

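For context, the processed datasets are then formatted and cached to the `.pt` files used below. This is a sketch of that step modeled on @valhalla's question-generation scripts; the exact paths are assumptions:

    import torch

    # make the dataset return torch tensors for the listed columns
    train_dataset.set_format(type="torch", columns=columns)
    valid_dataset.set_format(type="torch", columns=columns)

    # cache the processed datasets for the training script
    torch.save(train_dataset, "data/train_data_qa_qg_mt5.pt")
    torch.save(valid_dataset, "data/valid_data_qa_qg_mt5.pt")
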
However, when I try to train the model with the command below:

    !python3 run_multi.py \
        --model_name_or_path google/mt5-small \
        --model_type mt5 \
        --tokenizer_name_or_path mt5_qg_tokenizer \
        --output_dir mt5-small-multi \
        --train_file_path data/train_data_qa_qg_mt5.pt \
        --valid_file_path data/valid_data_qa_qg_mt5.pt \
        --per_device_train_batch_size 16 \
        --per_device_eval_batch_size 16 \
        --gradient_accumulation_steps 2 \
        --learning_rate 1e-4 \
        --num_train_epochs 2 \
        --seed 42 \
        --do_train \
        --do_eval \
        --logging_steps 100 \
        --prediction_loss_only True

it fails with `KeyError: 'source_ids'`.

I am sure the dataset has a “source_ids” field:

    >>> train_dataset = torch.load(r"train_data_qa_qg_mt5.pt")
    >>> train_dataset
    Dataset({
        features: ['attention_mask', 'source_ids', 'source_text', 'target_ids', 'target_text', 'task'],
        num_rows: 3449
    })

What might cause this?

The versions of the libraries are:
transformers == 4.4.2
datasets == 1.5.0

Thank you in advance for any replies.


I found out that my data collator receives only the “attention_mask” field as input. I do not know why the other fields disappear :pensive:

    from typing import Dict, List

    import torch

    class T2TDataCollator():
        def __init__(self, tokenizer, mode='training'):
            self.tokenizer = tokenizer
            self.mode = mode

        def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
            """
            Take a list of samples from a Dataset and collate them into a batch.
            Returns:
                A dictionary of tensors
            """
            input_ids = torch.stack([example['source_ids'] for example in batch])
            target_ids = torch.stack([example['target_ids'] for example in batch])
            attention_mask = torch.stack([example['attention_mask'] for example in batch])
            ...

The error is thrown in this part.
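
For reference, the elided part follows the usual T5 collator pattern: mask the padding tokens in the labels and return the batch dict. A sketch of that continuation, assuming the names from the snippet above:

    # continuation of __call__ (sketch of the usual T5 pattern)
    labels = target_ids.clone()
    labels[labels == self.tokenizer.pad_token_id] = -100  # padding is ignored by the loss

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }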

Is your run_multi.py script using Trainer? By default, Trainer removes any column that is not in your model's forward signature (such as “source_ids”), so you should pass --remove_unused_columns False in your command.
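
If you build the TrainingArguments in code rather than on the command line, the equivalent setting would look roughly like this (the output_dir here is just a placeholder):

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir="mt5-small-multi",
        # keep non-signature columns such as "source_ids" for the custom collator
        remove_unused_columns=False,
    )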

Yes, I was using Trainer, but I solved the problem: I was loading the dataset via the datasets library, and when I replaced that with nlp.load_dataset it worked seamlessly. Thank you for the response, though; I had not passed --remove_unused_columns False.