I’m using my own loss function with the Trainer. I need to pass a custom criterion I wrote, which will be used inside the loss function to compute the loss. I have the following setup:
from transformers import Trainer, TrainingArguments

class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # I compute the loss here and I need my `criterion`
        return loss

training_args = TrainingArguments(
    # the arguments...
)

# model = my model...
trainer = MyTrainer(
    model=model,
    args=training_args,
    # rest of the arguments...
)
I wonder if there is any way I can pass my custom criterion object to the Trainer, either through the Trainer constructor or through TrainingArguments? Or, what is the best way to use my criterion without changing the Trainer?
If I understand your scenario correctly, you are creating your own child class with the Trainer class as the parent class, is that right? If so, you should be able to add any arguments you want to the child class, since it is yours to amend, no?
Just like in this example: How do I add arguments to a subclass in Python 3 - Stack Overflow
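Something along these lines might work (untested sketch; the criterion keyword argument and the logits/labels handling are just my assumption of what your compute_loss needs):

import torch.nn as nn
from transformers import Trainer

class MyTrainer(Trainer):
    def __init__(self, *args, criterion=None, **kwargs):
        super().__init__(*args, **kwargs)
        # criterion is any callable loss, e.g. nn.CrossEntropyLoss()
        self.criterion = criterion

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # use the custom criterion instead of the model's built-in loss
        loss = self.criterion(logits, labels)
        return (loss, outputs) if return_outputs else loss

trainer = MyTrainer(
    model=model,
    args=training_args,
    criterion=nn.CrossEntropyLoss(),
    # rest of the arguments...
)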
Let me know if this is helpful at all or if I completely misunderstood your question.
Cheers
Heiko
I might be late, but I was also facing the same issue, so here is a solution: the Trainer subclass takes an extra parameter, custom_class_weight, which is used for a weighted CrossEntropyLoss.
import torch
import torch.nn as nn
from transformers import Trainer

class WeightedTrainer(Trainer):
    def __init__(self, custom_class_weight, **kwargs):
        super().__init__(**kwargs)
        self.custom_class_weight = custom_class_weight

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.get('logits')
        labels = inputs.get('labels')
        # CrossEntropyLoss expects the weights as a float tensor on the same device as the logits
        weight = torch.tensor(self.custom_class_weight, dtype=torch.float, device=logits.device)
        loss_func = nn.CrossEntropyLoss(weight=weight)
        loss = loss_func(logits, labels)
        return (loss, outputs) if return_outputs else loss

class MCC(object):
    def __init__(self, problem_type, batch_size, dataset, model):
        self.problem_type = problem_type
        self.batch_size = batch_size
        self.dataset = dataset
        self.model = model
        self.cls_weights = self.weights_calculation()
        self.custom_trainer = WeightedTrainer
        self.trainer = self.custom_trainer(model=self.model, custom_class_weight=self.cls_weights)

    def weights_calculation(self):
        # inverse-frequency class weights from a pandas DataFrame with a 'labels' column
        class_weights = (1 - (self.dataset['labels'].value_counts().sort_index() / len(self.dataset))).values
        return class_weights
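In case it helps, here is a minimal (hypothetical) way to instantiate WeightedTrainer directly, without the MCC wrapper; model, training_args, train_dataset and eval_dataset are placeholders for your own objects:

trainer = WeightedTrainer(
    custom_class_weight=class_weights,  # e.g. the array returned by weights_calculation()
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()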