Reshaping logits when using Trainer

Hi!

I am using the Trainer class to train a sequence classifier with num_labels=57 (= the number of classes in the dataset). Using a batch size of 4, I get:
ValueError: Expected input batch_size (4) to match target batch_size (228).

Since 228 = 4*57, I suppose the target batch needs to be reshaped to size (4, 57). But how can this be done with the Trainer class? I have tried using the preprocess_logits_for_metrics argument but it didn’t help.
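For context, here is what I think is happening (I am not completely sure about the internals): my labels are multi-hot vectors of shape (4, 57), and it looks like they get flattened before being passed to CrossEntropyLoss, which then sees 228 targets for only 4 rows of logits. A minimal plain-torch sketch that reproduces the same error:

import torch

logits = torch.randn(4, 57)               # one row of logits per example
labels = torch.randint(0, 2, (4, 57))     # multi-hot labels, integer dtype
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, 57), labels.view(-1))
# ValueError: Expected input batch_size (4) to match target batch_size (228).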

I would like to use Trainer, but any approach is welcome.

Thanks a lot for your help!
Matthias

Here is the code to reproduce the error:

from transformers import BertTokenizer, BertForSequenceClassification
from transformers import DataCollatorWithPadding, TrainingArguments, Trainer
from datasets import load_dataset
import torch
# load dataset, tokenize, adapt columns, and apply datacollator
checkpoint = "bert-base-cased"
transformers_tokenizer = BertTokenizer.from_pretrained(checkpoint)
def transformers_tokenize_function(item):
    return transformers_tokenizer(item["text"], padding=True, truncation=True)
transformers_tokenized_datasets = (
    load_dataset("mdroth/transformers_issues_labels")
    .map(transformers_tokenize_function, batched=True)
    .remove_columns(column_names=["url", "text", "num_labels", "labels"])
    .rename_column("arr_labels", "labels")
)
transformers_data_collator = DataCollatorWithPadding(tokenizer=transformers_tokenizer)
# training arguments
training_args = TrainingArguments(
    "5_try_transformers_dataset",
    evaluation_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4
)
# model
num_labels = 57
transformers_model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
# trainer
trainer = Trainer(
    transformers_model,
    training_args,
    train_dataset=transformers_tokenized_datasets["dev"],
    eval_dataset=transformers_tokenized_datasets["dev"],
    data_collator=transformers_data_collator,
    tokenizer=transformers_tokenizer,
    #preprocess_logits_for_metrics=lambda x: torch.reshape(x, (-1, num_labels))
)
# train
trainer.train()

Props to jwa018 for providing the solution here. In my case, it was also necessary to handle the data types of the arguments to BCEWithLogitsLoss(): casting both the logits and the labels to float via .float() worked for me.
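To illustrate the dtype part in isolation (plain torch, independent of Trainer): BCEWithLogitsLoss expects float tensors on both sides, while the labels come out of the dataset as integers.

import torch

loss_fct = torch.nn.BCEWithLogitsLoss()
logits = torch.randn(4, 57)               # float logits, as returned by the model
labels = torch.randint(0, 2, (4, 57))     # multi-hot labels, integer dtype
# loss_fct(logits, labels)                # fails with a dtype error
loss = loss_fct(logits, labels.float())   # works after casting the labels to float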

Now, my entire code looks as follows:

from transformers import Trainer
import torch
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # forward pass without the labels, so the model does not compute its own loss
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            token_type_ids=inputs['token_type_ids']
        )
        # multi-label objective: BCEWithLogitsLoss needs float logits and float targets
        loss = torch.nn.BCEWithLogitsLoss()(outputs['logits'].float(), inputs['labels'].float())
        return (loss, outputs) if return_outputs else loss


from transformers import DataCollatorWithPadding, TrainingArguments, BertTokenizer, BertForSequenceClassification
from datasets import load_dataset
# load dataset, tokenize, adapt columns, and apply datacollator
checkpoint = "bert-base-cased"
transformers_tokenizer = BertTokenizer.from_pretrained(checkpoint)
def transformers_tokenize_function(item):
    return transformers_tokenizer(item["text"], padding=True, truncation=True)
transformers_tokenized_datasets = (
    load_dataset("mdroth/transformers_issues_labels")
    .map(transformers_tokenize_function, batched=True)
    .remove_columns(column_names=["url", "text", "num_labels", "labels"])
    .rename_column("arr_labels", "labels")
)
transformers_data_collator = DataCollatorWithPadding(tokenizer=transformers_tokenizer)
# training arguments
training_args = TrainingArguments(
    "5_try_transformers_dataset",
    evaluation_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4
)
# model
num_labels = 57
transformers_model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
# trainer
trainer = CustomTrainer(
    transformers_model,
    training_args,
    train_dataset=transformers_tokenized_datasets["dev"],
    eval_dataset=transformers_tokenized_datasets["dev"],
    data_collator=transformers_data_collator,
    tokenizer=transformers_tokenizer
)
# train
trainer.train()
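For completeness: I believe an alternative to overriding compute_loss would be to load the model with problem_type="multi_label_classification", which should make BertForSequenceClassification apply BCEWithLogitsLoss itself (the labels would then still have to be cast to float, e.g. in the tokenize/map step). I have not tested this with my dataset, though:

# untested alternative: let the model select the multi-label loss itself
transformers_model = BertForSequenceClassification.from_pretrained(
    checkpoint, num_labels=num_labels, problem_type="multi_label_classification"
)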