Hi!
I am using the Trainer class to train a sequence classifier with num_labels=57 (= the number of classes in the dataset). Using a batch size of 4, I get:
ValueError: Expected input batch_size (4) to match target batch_size (228).
Since 228 = 4 * 57, I suppose the targets need to be reshaped to size (4, 57). But how can I do that with the Trainer class? I have tried the preprocess_logits_for_metrics argument, but it didn't help (as far as I can tell, it only affects metric computation, not the loss).
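For reference, here is a minimal reproduction of the shape mismatch outside of Trainer. I am assuming the "arr_labels" column holds one-hot vectors of length 57, so that flattening them (as the model's internal CrossEntropyLoss call does) yields 4 * 57 = 228 targets:

```python
import torch
import torch.nn as nn

batch_size, num_labels = 4, 57

# Logits as produced by the model: shape (batch_size, num_labels)
logits = torch.randn(batch_size, num_labels)

# One-hot labels, as I assume they look in the "arr_labels" column:
# shape (batch_size, num_labels)
labels = torch.zeros(batch_size, num_labels, dtype=torch.long)
labels[torch.arange(batch_size), torch.randint(0, num_labels, (batch_size,))] = 1

loss_fct = nn.CrossEntropyLoss()

# Flattening the one-hot labels produces 228 targets for 4 inputs:
try:
    loss_fct(logits.view(-1, num_labels), labels.view(-1))
except ValueError as e:
    print(e)  # Expected input batch_size (4) to match target batch_size (228).

# With class indices of shape (batch_size,), the same call succeeds:
loss = loss_fct(logits, labels.argmax(dim=-1))
print(loss.shape)  # torch.Size([])
```

So the error seems to occur whenever each label is a full one-hot vector rather than a single class index.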
I would like to keep using the Trainer class, but any approach is welcome.
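One workaround I have been considering (I am not sure it is the intended way) is to convert each one-hot vector into a single class index before training. Here is a sketch with a hypothetical helper, to_class_index, applied to a toy batch standing in for one mapped batch of the real dataset:

```python
# Hedged sketch: to_class_index is a hypothetical helper, not a datasets API.
def to_class_index(batch):
    # batch["arr_labels"] is assumed to be a list of one-hot lists,
    # e.g. [[0, 0, 1], [1, 0, 0]]; each row becomes the index of its 1.
    batch["labels"] = [row.index(1) for row in batch["arr_labels"]]
    return batch

# Toy batch standing in for one batched chunk of the dataset:
toy_batch = {"arr_labels": [[0, 0, 1], [1, 0, 0]]}
print(to_class_index(toy_batch)["labels"])  # [2, 0]
```

With a real datasets object this would presumably be applied via .map(to_class_index, batched=True) in place of the rename_column("arr_labels", "labels") step, but I don't know whether that is the recommended approach.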
Thanks a lot for your help!
Matthias
Here is the code to reproduce the error:
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import DataCollatorWithPadding, TrainingArguments, Trainer
from datasets import load_dataset
import torch
# load dataset, tokenize, adapt columns, and apply datacollator
checkpoint = "bert-base-cased"
transformers_tokenizer = BertTokenizer.from_pretrained(checkpoint)
def transformers_tokenize_function(item):
    return transformers_tokenizer(item["text"], padding=True, truncation=True)

transformers_tokenized_datasets = (
    load_dataset("mdroth/transformers_issues_labels")
    .map(transformers_tokenize_function, batched=True)
    .remove_columns(column_names=["url", "text", "num_labels", "labels"])
    .rename_column("arr_labels", "labels")
)
transformers_data_collator = DataCollatorWithPadding(tokenizer=transformers_tokenizer)
# training arguments
training_args = TrainingArguments(
    "5_try_transformers_dataset",
    evaluation_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
)
# model
num_labels = 57
transformers_model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
# trainer
trainer = Trainer(
    transformers_model,
    training_args,
    train_dataset=transformers_tokenized_datasets["dev"],
    eval_dataset=transformers_tokenized_datasets["dev"],
    data_collator=transformers_data_collator,
    tokenizer=transformers_tokenizer,
    # my failed attempt; note that preprocess_logits_for_metrics receives
    # (logits, labels), not a single tensor:
    # preprocess_logits_for_metrics=lambda logits, labels: torch.reshape(logits, (-1, num_labels)),
)
# train
trainer.train()