I want to train a model that classifies each token in an input sequence. I have a model architecture in place, and I want to use torch.nn.functional.binary_cross_entropy_with_logits() to compute the loss.
```python
logits = model(inputs)
# inputs shape: (batch_size, max_seq_length)
# logits shape: (batch_size, max_seq_length)
```
The compute_loss override example in the docs calls .view(-1) on both the logits and the labels when performing the loss calculation.
This is the code I came up with after reading the docs:
```python
import torch.nn.functional as F
from transformers import Trainer

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        loss = F.binary_cross_entropy_with_logits(
            input=outputs.get("logits").view(-1),
            target=labels.view(-1).float(),  # BCE with logits expects float targets
        )
        return (loss, outputs) if return_outputs else loss
```
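As a side note, here is a minimal standalone sketch (with made-up shapes) I used to convince myself what .view(-1) does here: binary_cross_entropy_with_logits reduces elementwise by default, so flattening both tensors gives the same mean loss as passing them in 2-D.

```python
import torch
import torch.nn.functional as F

batch_size, max_seq_length = 4, 10  # hypothetical sizes
logits = torch.randn(batch_size, max_seq_length)
labels = torch.randint(0, 2, (batch_size, max_seq_length)).float()  # float targets

# Loss on the original 2-D tensors vs. on the flattened views.
loss_2d = F.binary_cross_entropy_with_logits(logits, labels)
loss_flat = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1))
print(torch.allclose(loss_2d, loss_flat))  # True
```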
I have two questions:
- Can I override compute_loss in Trainer, or do I need Seq2SeqTrainer? The output sequence does not need to be produced with beam search.
- Is .view(-1) still required in the method I intend to use?