I have overridden the `compute_loss` function as follows:
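```python
import torch
from transformers import Seq2SeqTrainer

class CustomTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # The custom model returns (token logits, subgraph logits), not a ModelOutput
        logits, g_output = model(**inputs)
        labels = inputs.get("labels")
        graph_labels = inputs.get("subgraph")
        loss_fn = torch.nn.CrossEntropyLoss()  # by default, index with value -100 will be ignored
        extra_loss_fct = torch.nn.BCEWithLogitsLoss()
        # Seq2seq token loss plus the weighted subgraph loss (ALPHA is defined elsewhere)
        token_loss = loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
        graph_loss = extra_loss_fct(g_output, graph_labels.float())
        loss = token_loss + ALPHA * graph_loss
        return (loss, logits) if return_outputs else loss
```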
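To see why a negative loss is surprising, here is a framework-free sketch of what the `compute_loss` above computes: mean cross-entropy over token positions whose label is not `-100`, plus `ALPHA` times the mean binary cross-entropy on the subgraph logits. These functions are plain-Python re-implementations written only for illustration (not the PyTorch API), assuming mean reduction, which is the default for both PyTorch losses. Every per-position term is non-negative, so with a non-negative `alpha` the sum cannot drop below zero, whatever the sign of the logits.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # mirrors the default ignore_index of torch.nn.CrossEntropyLoss


def cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX."""
    losses = []
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # padded / masked target position
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[label])  # log-sum-exp >= any entry, so >= 0
    if not losses:
        raise ValueError("every label is IGNORE_INDEX")
    return sum(losses) / len(losses)


def bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Mean binary cross-entropy computed from raw logits (stable formulation)."""
    total = 0.0
    for x, y in zip(logits, targets):
        # equals -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], >= 0 for y in [0, 1]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def combined_loss(token_logits, labels, graph_logits, graph_labels, alpha):
    """Token loss plus alpha-weighted subgraph loss, as in compute_loss above."""
    return cross_entropy(token_logits, labels) + alpha * bce_with_logits(graph_logits, graph_labels)


if __name__ == "__main__":
    # Strongly negative logits still give a non-negative loss.
    token_logits = [[-5.0, -1.0, -3.0], [-2.0, -2.0, -2.0], [0.0, 0.0, 0.0]]
    labels = [1, IGNORE_INDEX, 2]
    print(combined_loss(token_logits, labels, [-4.0, 3.0], [0.0, 1.0], alpha=0.5))
```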
Also, I've overridden the `forward` function as follows:
```python
import torch
from transformers import AutoModelForSeq2SeqLM, PreTrainedModel

class Text2MultiTargetsV2(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # model_checkpoint, MeanPooling, D_EMBEDDING and D_GRAPH are defined elsewhere
        self.transformer = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
        self.mp_layer = MeanPooling()
        self.subgraph = torch.nn.Linear(D_EMBEDDING, D_GRAPH)

    def forward(self, input_ids, attention_mask, labels=None, subgraph=None):
        encoder_output = self.transformer.get_encoder()(
            input_ids=input_ids, attention_mask=attention_mask
        )
        # compute extra output for the subgraph of shape [B, N+E]
        x_mean = self.mp_layer(encoder_output[0], attention_mask)
        subg = self.subgraph(x_mean)  # sigmoid to be included in the loss
        nlp_out = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        ).logits
        return nlp_out, subg

    def generate(self, model_inputs, **kwargs):
        return self.transformer.generate(model_inputs, **kwargs)
```
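`MeanPooling` is not defined in the post. Assuming it is the usual attention-mask-weighted mean over the encoder's token states (padding positions excluded), this framework-free sketch shows what `self.mp_layer(encoder_output[0], attention_mask)` would produce: one `[D]` vector per sequence. The name `masked_mean_pooling` is hypothetical; a PyTorch module would typically express the same thing with tensor ops (multiply by the expanded mask, sum over the token axis, divide by the clamped mask sum).

```python
from typing import List

Matrix = List[List[float]]


def masked_mean_pooling(hidden_states: List[Matrix], attention_mask: List[List[int]]) -> Matrix:
    """Average token vectors per sequence, skipping positions where the mask is 0.

    hidden_states: [B][T][D] encoder outputs; attention_mask: [B][T] of 0/1.
    Returns a [B][D] list of pooled vectors.
    """
    pooled = []
    for tokens, mask in zip(hidden_states, attention_mask):
        sums = [0.0] * len(tokens[0])
        count = 0
        for vector, keep in zip(tokens, mask):
            if keep:
                for i, value in enumerate(vector):
                    sums[i] += value
                count += 1
        count = max(count, 1)  # guard against an all-padding sequence
        pooled.append([s / count for s in sums])
    return pooled


# The third token is padding, so it does not affect the average.
print(masked_mean_pooling([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]], [[1, 1, 0]]))  # [[2.0, 3.0]]
```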
Then I used the custom trainer to train the model above.
```python
trainer = CustomTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
```
Training went well, and the model predicts correct results.
The only problem is that whenever a training epoch finishes, the evaluation loss is a negative number, which is weird because both terms of my loss are non-negative (with a positive ALPHA).
I put a debug statement in the `compute_loss` function and found that it is not called during evaluation.
Could you tell me why the overridden function is not called?
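For reference, one way to run a check like this without a debugger is to wrap the method in a call counter and read the count afterwards. The sketch below is hypothetical: `count_calls`, `call_counts` and `ToyTrainer` are illustration names only. In the real script the same patch could be applied to `CustomTrainer.compute_loss` before calling `trainer.evaluate()`.

```python
import functools
from collections import Counter

call_counts = Counter()  # method qualname -> number of calls


def count_calls(method):
    """Wrap a method so every invocation is tallied in call_counts."""
    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        call_counts[method.__qualname__] += 1
        return method(*args, **kwargs)
    return wrapper


class ToyTrainer:
    """Hypothetical stand-in for CustomTrainer, used only to show the technique."""

    def compute_loss(self, model, inputs, return_outputs=False):
        return 0.0


# Patch the class attribute so calls from every instance are counted.
ToyTrainer.compute_loss = count_calls(ToyTrainer.compute_loss)

trainer = ToyTrainer()
trainer.compute_loss(model=None, inputs={})
trainer.compute_loss(model=None, inputs={})
print(call_counts["ToyTrainer.compute_loss"])  # 2
```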
Many thanks