I want to train a model that performs binary classification on each token of an input sequence. I have a model architecture in place, and I want to use torch.nn.functional.binary_cross_entropy_with_logits() to calculate the loss.
logits = model(inputs)
# inputs shape: (batch_size, max_seq_length)
# logits shape: (batch_size, max_seq_length)
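For concreteness, here are dummy tensors with the shapes described above (sizes and values are placeholders, not from my real data):

import torch

batch_size, max_seq_length = 4, 16
inputs = torch.randint(0, 30000, (batch_size, max_seq_length))  # token ids
logits = torch.randn(batch_size, max_seq_length)                # one binary logit per token
probs = torch.sigmoid(logits)                                   # per-token probability of the positive class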
The compute_loss override example in the Trainer docs applies .view(-1) to the tensors when computing the loss.
This is the code I came up with after reading the docs:
import torch.nn.functional as F
from transformers import Trainer

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        # binary_cross_entropy_with_logits expects float targets, hence .float()
        loss = F.binary_cross_entropy_with_logits(
            input=outputs.get("logits").view(-1),
            target=labels.view(-1).float(),
        )
        return (loss, outputs) if return_outputs else loss
I have two questions:
- Can I override compute_loss in a plain Trainer, or should I use Seq2SeqTrainer? The output sequence does not need to be generated with beam search.
- Is .view(-1) still required with the loss function I intend to use? (See the sketch below for what I mean.)
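To illustrate the second question, here is a self-contained check with dummy tensors (sizes and values are placeholders, not from my real data). Since binary_cross_entropy_with_logits is applied elementwise and reduces with a mean by default, I would expect the flattened and unflattened calls to agree, but I may be missing something:

import torch
import torch.nn.functional as F

batch_size, max_seq_length = 4, 16
logits = torch.randn(batch_size, max_seq_length)                    # simulated model output
labels = torch.randint(0, 2, (batch_size, max_seq_length)).float()  # per-token 0/1 targets

loss_flat = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
loss_2d = F.binary_cross_entropy_with_logits(logits, labels)
print(torch.allclose(loss_flat, loss_2d))  # mean over all elements either way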