RuntimeError: The expanded size of the tensor (31) must match the existing size (7) at non-singleton dimension 0. Target sizes: [31]. Tensor sizes: [7]

I run into this error every time I call trainer.train(): the run fails at the evaluation/compute_metrics step. I’ve tried numerous tweaks suggested elsewhere, but none of them have worked.

If you have any ideas, I’d be very grateful to hear them.

Here’s the code, mostly adapted from other examples:

import nltk
import numpy as np
import pandas as pd

import evaluate
from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

nltk.download('punkt')

# Load the JSON Lines dataset and keep only the abstract/title columns
f = 'dataset.json'
df = pd.read_json(f, lines=True)
df = df[['abstract', 'title']]
df = df.loc[:20]  # first 21 rows only, to keep the example small

# Hold out 10% as the test set, then carve a validation set out of the
# remaining 90%, so the overall split is roughly 80/10/10
dataset = Dataset.from_pandas(df)
dataset = dataset.train_test_split(train_size=0.9, seed=42)
dataset_clean = dataset['train'].train_test_split(train_size=0.88888, seed=42)
dataset_clean['validation'] = dataset_clean.pop('test')
dataset_clean['test'] = dataset['test']
dataset = dataset_clean
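
# Print the resulting split sizes to confirm train/validation/test all exist
# (inspection only, not part of the pipeline)
print({split: len(ds) for split, ds in dataset.items()})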

checkpoint = 'facebook/bart-base'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

max_input_length = 512  # abstracts
max_target_length = 30  # titles


def preprocess_function(examples):
    model_inputs = tokenizer(
        examples['abstract'],
        max_length=max_input_length,
        padding='max_length',
        truncation=True
    )
    labels = tokenizer(
        examples['title'],
        max_length=max_target_length,
        padding='max_length',
        truncation=True
    )
    # Use the tokenized titles as the labels (padded to max_target_length)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs


tokenized_datasets = dataset.map(preprocess_function, batched=True)
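
# Quick sanity check on the tokenized lengths (inspection only): after padding,
# these should be 512 for input_ids and 30 for labels
_sample = tokenized_datasets['train'][0]
print(len(_sample['input_ids']), len(_sample['labels']))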

rouge_score = evaluate.load('rouge')

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# Try 8 for both the train and eval batch sizes on a GPU as a starting point;
# 1 keeps this example small
batch_size = 1
num_train_epochs = 2

logging_steps = len(tokenized_datasets['train']) // batch_size  # log once per epoch
model_name = checkpoint.split('/')[-1]

args = Seq2SeqTrainingArguments(
    output_dir='{}-finetuned-arxiv'.format(model_name),
    evaluation_strategy='epoch',
    learning_rate=5.6e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps,
    # push_to_hub=True
)


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels
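
# Example of the newline insertion that rougeLSum expects (inspection only):
print(postprocess_text(['First sentence. Second sentence.'], ['A title.']))
# -> (['First sentence.\nSecond sentence.'], ['A title.'])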


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(
        decoded_preds, decoded_labels)

    result = rouge_score.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # ROUGE scores come back in [0, 1]; convert them to percentages
    result = {key: value * 100 for key, value in result.items()}

    # Mean generated length, counting non-pad tokens in each prediction
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result['gen_len'] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result
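
# Standalone check that compute_metrics handles well-formed arrays (inspection
# only, not part of the training run; "predictions" and "labels" are just the
# same tokenized string, so ROUGE should come out at 100)
_dummy = tokenizer('a placeholder title', max_length=max_target_length,
                   padding='max_length', truncation=True)['input_ids']
print(compute_metrics((np.array([_dummy]), np.array([_dummy]))))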


data_collator = DataCollatorForSeq2Seq(
    tokenizer, model=model, return_tensors='pt')

# Drop the raw text columns ('abstract' and 'title') so the collator only sees
# the tokenized features
tokenized_datasets = tokenized_datasets.remove_columns(
    dataset['train'].column_names
)

# Some of the tweaks I've tried, as an example (capping the generation length)
model.generation_config.max_new_tokens = max_target_length
model.config.max_length = max_target_length

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

trainer.train()
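
For what it’s worth, the same evaluation path that fails during training can also be exercised directly, without waiting for a training epoch. This is just a reproduction sketch reusing the trainer defined above:

metrics = trainer.evaluate()
print(metrics)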