Reshaping logits when using Trainer

Props to jwa018 for providing the solution here. In my case, it was also necessary to handle the data types of the arguments to BCEWithLogitsLoss(): casting both the logits and the labels to float via .float() worked for me.
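For anyone hitting the same error, here is a minimal standalone sketch of the dtype issue (the shapes just mirror the batch size of 4 and the 57 labels used below): BCEWithLogitsLoss() refuses integer labels, so both arguments need to be float tensors.

import torch

logits = torch.randn(4, 57)              # raw model outputs, already float
labels = torch.randint(0, 2, (4, 57))    # multi-hot labels, but dtype is int64
loss_fct = torch.nn.BCEWithLogitsLoss()
# loss_fct(logits, labels) raises a dtype error because the labels are not float
loss = loss_fct(logits.float(), labels.float())  # casting both arguments fixes it
print(loss)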

Now, my entire code looks as follows:

from transformers import Trainer
import torch
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # forward pass without the labels so the model does not compute its own loss
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            token_type_ids=inputs['token_type_ids']
        )
        # BCEWithLogitsLoss expects float tensors for both the logits and the labels
        loss = torch.nn.BCEWithLogitsLoss()(outputs['logits'].float(), inputs['labels'].float())
        return (loss, outputs) if return_outputs else loss


from transformers import DataCollatorWithPadding, TrainingArguments, BertTokenizer, BertForSequenceClassification
from datasets import load_dataset
# load the dataset, tokenize it, adapt the columns, and set up a data collator for dynamic padding
checkpoint = "bert-base-cased"
transformers_tokenizer = BertTokenizer.from_pretrained(checkpoint)
def transformers_tokenize_function(item):
    return transformers_tokenizer(item["text"], padding=True, truncation=True)
transformers_tokenized_datasets = (
    load_dataset("mdroth/transformers_issues_labels")
    .map(transformers_tokenize_function, batched=True)
    .remove_columns(column_names=["url", "text", "num_labels", "labels"])
    .rename_column("arr_labels", "labels")
)
transformers_data_collator = DataCollatorWithPadding(tokenizer=transformers_tokenizer)
# training arguments
training_args = TrainingArguments(
    "5_try_transformers_dataset",
    evaluation_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4
)
# model
num_labels = 57
transformers_model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
# trainer
trainer = CustomTrainer(
    transformers_model,
    training_args,
    train_dataset=transformers_tokenized_datasets["dev"],
    eval_dataset=transformers_tokenized_datasets["dev"],
    data_collator=transformers_data_collator,
    tokenizer=transformers_tokenizer
)
# train
trainer.train()
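Not part of the fix itself, but as a quick sanity check after training you can turn the eval logits into multi-label predictions; the sigmoid plus the 0.5 threshold below is just an assumption for illustration, not something from the original thread:

import torch

# trainer.predict returns the raw logits for the whole eval split
predictions = trainer.predict(transformers_tokenized_datasets["dev"])
probs = torch.sigmoid(torch.tensor(predictions.predictions))
pred_labels = (probs > 0.5).int()   # hypothetical 0.5 decision threshold per label
print(pred_labels.shape)            # (num_eval_examples, 57)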