I am running an experiment with the T5ForConditionalGeneration model that maps an input sequence to an expected target sequence. Here is an example of my input and output formats:
Input Text: "Michael Jordan is a professor at Berkeley."
Target Text: "2 Entities [Entity1] Michael Jordan is a Person [Entity2] Berkeley is a Place"
I use this target format so that the model has to do two things:
1. Recognize how many entities the input text contains.
2. Generate that many entities, each with its type.
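To check either of these behaviours at all, a target (or decoded prediction) has to be split back into its declared count and its entity list. A minimal sketch, assuming the format is exactly `<N> Entities [Entity1] <name> is a <type> ...`; `parse_target` and both regexes are hypothetical helpers, not part of transformers:

```python
import re
from typing import List, Optional, Tuple

# Hypothetical patterns for the assumed format:
# "<N> Entities [Entity1] <name> is a <type> [Entity2] <name> is a <type> ..."
COUNT_RE = re.compile(r"^\s*(\d+)\s+Entit(?:y|ies)\b")
ENTITY_RE = re.compile(r"\[Entity\s*(\d+)\]\s*(.*?)\s+is an?\s+(.*?)\s*(?=\[Entity|$)")


def parse_target(text: str) -> Tuple[Optional[int], List[Tuple[str, str]]]:
    """Split a target string into (declared entity count, [(name, type), ...])."""
    count_match = COUNT_RE.match(text)
    # None means the "<N> Entities" prefix is missing
    declared = int(count_match.group(1)) if count_match else None
    # Each findall tuple is (index, name, type); the index is dropped here
    entities = [(name, etype) for _, name, etype in ENTITY_RE.findall(text)]
    return declared, entities


print(parse_target("2 Entities [Entity1] Michael Jordan is a Person [Entity2] Berkeley is a Place"))
# (2, [('Michael Jordan', 'Person'), ('Berkeley', 'Place')])
```

Types are taken as everything up to the next `[Entity` marker, so multi-word types also survive.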
Although the overall objective is to generate the exact target sequence, I want to make sure the model performs the two tasks above correctly. That is:
Objective 1: It has to predict the correct number of entities in a given text
Objective 2: It has to generate exactly as many entities as it predicted the sentence contains
Objective 3: It has to generate a sequence as close as possible to the expected target sequence
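Before turning Objectives 1 and 2 into losses, it can help to track them as plain evaluation metrics on decoded predictions. The sketch below assumes the same `<N> Entities [EntityX] ...` format; `declared_count` and `objective_metrics` are hypothetical names. Because it works on decoded strings, it measures the objectives but provides no gradient:

```python
import re
from typing import Dict, List, Optional

# Hypothetical helpers for the assumed "<N> Entities [Entity1] ..." format
COUNT_RE = re.compile(r"^\s*(\d+)\s+Entit(?:y|ies)\b")
MARKER_RE = re.compile(r"\[Entity\s*\d+\]")


def declared_count(text: str) -> Optional[int]:
    """Entity count stated at the start of the sequence, or None if absent."""
    m = COUNT_RE.match(text)
    return int(m.group(1)) if m else None


def objective_metrics(preds: List[str], golds: List[str]) -> Dict[str, float]:
    """Fraction of examples that satisfy Objective 1 and Objective 2."""
    count_ok = 0    # Objective 1: predicted count equals gold count
    consistent = 0  # Objective 2: predicted count equals number of [EntityX] markers generated
    for pred, gold in zip(preds, golds):
        n_pred = declared_count(pred)
        if n_pred is not None and n_pred == declared_count(gold):
            count_ok += 1
        if n_pred is not None and n_pred == len(MARKER_RE.findall(pred)):
            consistent += 1
    total = max(len(preds), 1)  # avoid division by zero on an empty batch
    return {"count_accuracy": count_ok / total, "count_consistency": consistent / total}
```

For example, a prediction that starts with "3 Entities" but contains only `[Entity1]` and `[Entity2]` fails both checks against the gold target above.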
Objective 3 is handled by the CrossEntropyLoss that the T5 model already computes. But how can I add Objectives 1 and 2 to that loss? For Objective 1, I tried the following:
```python
import torch
from torch.nn import CrossEntropyLoss

# Mask padding tokens so the seq2seq loss ignores them
lm_labels = batch["target_ids"]
lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100

t5_outputs = self.model(
    input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    labels=lm_labels,
)
loss = t5_outputs.loss          # token-level cross-entropy (Objective 3)
lm_logits = t5_outputs.logits   # shape: (batch, target_len, vocab_size)

pred_target_ids = torch.argmax(lm_logits, dim=2)
# take the first 3 tokens corresponding to the string "2 Entities" from the logits
pred_ent_count_ids = pred_target_ids[:, :3]
# take the first 3 tokens corresponding to the string "2 Entities" from the gold target sentence
gold_ent_count_ids = lm_labels[:, :3]

ent_count_ids_loss_fct = CrossEntropyLoss(ignore_index=-100)
ent_count_ids_loss = ent_count_ids_loss_fct(pred_ent_count_ids.float(), gold_ent_count_ids.float())
loss += ent_count_ids_loss
```
1. Is this the correct way to compute a loss for the number of entities?
2. How can I compute a loss that penalizes a mismatch between the predicted number of entities and the [Entity X] markers actually generated? The model should be penalized if it says "3 Entities" but generates only up to "[Entity2]", or generates more than "[Entity3]", such as "[Entity4]".
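The two failure cases in the second question can at least be scored per sequence. A minimal sketch, assuming the `<N> Entities [EntityX] ...` format; `count_mismatch_penalty` and its `missing_header_penalty` parameter are hypothetical. It charges 1 for every expected index `1..N` that never appears (under-generation) and 1 for every marker that is out of range or repeated (over-generation). Since it runs on decoded text it is not differentiable, so it could not simply be added to the cross-entropy loss; one option (an assumption, not something transformers provides) would be to use its negative as a reward in a policy-gradient (REINFORCE-style) fine-tuning step:

```python
import re

# Hypothetical patterns for the assumed "<N> Entities [Entity1] ..." format
COUNT_RE = re.compile(r"^\s*(\d+)\s+Entit(?:y|ies)\b")
MARKER_RE = re.compile(r"\[Entity\s*(\d+)\]")


def count_mismatch_penalty(text: str, missing_header_penalty: float = 1.0) -> float:
    """Penalize disagreement between the declared count and the [EntityX] markers."""
    m = COUNT_RE.match(text)
    if m is None:
        return missing_header_penalty  # no "<N> Entities" prefix was generated
    declared = int(m.group(1))
    indices = [int(i) for i in MARKER_RE.findall(text)]
    expected = set(range(1, declared + 1))
    # Under-generation: declared indices that never appear
    missing = len(expected - set(indices))
    # Over-generation: markers beyond the distinct in-range ones (out of range or repeated)
    extra = len(indices) - len(expected & set(indices))
    return float(missing + extra)


print(count_mismatch_penalty("3 Entities [Entity1] A is a Person [Entity2] B is a Place"))  # 1.0
```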