I am fine-tuning FlanT5 to generate labels for a classification task (Discourse Relation Classification). The expected input and output for the task are as follows:
Input: "what discourse relation holds between sent1 and sent2? sent1: The company is reporting the highest profits for Q3 2023. sent2: The recent increase in the company's market cap has impacted market sentiments positively. "
Output: "Cause"
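For context, here is a minimal sketch of how one such input/output pair can be turned into the tensors the model below consumes (illustrative only; the exact preprocessing is in the repo):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

source = ("what discourse relation holds between sent1 and sent2? "
          "sent1: The company is reporting the highest profits for Q3 2023. "
          "sent2: The recent increase in the company's market cap has impacted market sentiments positively.")
target = "Cause"

enc = tokenizer(source, return_tensors="pt")  # input_ids / attention_masks for the encoder
lab = tokenizer(target, return_tensors="pt")  # label_input_ids / label_attention_masks for the decoder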
The code for the FlanT5 model is as follows (I force the model to generate only the labels seen in the training data):
import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5ForConditionalGeneration

MODEL_NAME = "google/flan-t5-small"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

class T5Classifier(nn.Module):
    def __init__(self, num_labels, label_space, tokenizer):
        super().__init__()
        self.num_labels = num_labels
        # token ids of the allowed label strings; passed to generate() as force_words_ids
        self.label_space = tokenizer(label_space, add_special_tokens=False).input_ids
        self.t5 = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, num_labels=self.num_labels)
        self.t5 = self.t5.to(device)
    def forward(self, input_ids, attention_masks, label_input_ids, label_attention_masks):
        # teacher-forced training pass: the tokenized label string is the decoder target
        outputs = self.t5(
            input_ids=input_ids,
            attention_mask=attention_masks,
            labels=label_input_ids,
            decoder_attention_mask=label_attention_masks)
        loss = outputs.loss
        logits = outputs.logits
        return loss, logits
    def generator(self, input_ids, attention_mask):
        # force_words_ids triggers constrained beam search, which requires
        # num_beams > 1 and do_sample=False, so sampling is not used here
        generated_ids = self.t5.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            force_words_ids=self.label_space,
            num_beams=4)
        return generated_ids
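For completeness, a minimal sketch of how the class can be used, assuming label_space is the list of label strings and enc/lab are the tokenized tensors from the sketch above (the actual training loop, optimizer, and batching are in the repo):

model = T5Classifier(num_labels=len(label_space), label_space=label_space, tokenizer=tokenizer)

# one training step (optimizer and backward pass omitted)
loss, logits = model(enc.input_ids.to(device), enc.attention_mask.to(device),
                     lab.input_ids.to(device), lab.attention_mask.to(device))

# inference: decode the generated ids back to a label string
generated_ids = model.generator(enc.input_ids.to(device), enc.attention_mask.to(device))
prediction = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]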
My model learns to generate only the majority label; there is no diversity in the output. Can someone comment on how to make the model learn the full label set? The full code can be found in my GitHub repo.