I want to use the LongT5 model for a binary classification task. LongT5 is a seq2seq encoder-decoder generative model, and some suggest simply making it generate the class name as text, but that is not an elegant solution.
Instead, I intend to modify the final layer, attach a binary classification head, and use a loss function meant for a binary classifier, essentially replicating what is done in GPT2ForSequenceClassification (I have sketched what I am aiming for at the end of this post).
I tried to create a custom forward function like the one below and swap out the model's existing forward method, but I am not sure it is the right approach. What is the correct way to use LongT5 for binary classification?
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, AutoConfig, LongT5ForConditionalGeneration, LongT5Model
from transformers.modeling_outputs import Seq2SeqLMOutput

model_name = "google/long-t5-local-base"
num_labels = 2  # binary classification

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
model = LongT5ForConditionalGeneration.from_pretrained(model_name, config=config)

# Classification head, defined once and registered on the model so its weights
# persist across calls and can be trained (LongT5 exposes the hidden size as d_model).
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(config.d_model, config.d_model),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),  # optional dropout
    torch.nn.Linear(config.d_model, num_labels, bias=False),
)
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    # Run the full encoder-decoder stack. This relies on LongT5ForConditionalGeneration
    # exposing the same self.encoder / self.decoder attributes as LongT5Model.
    # Keyword arguments are used so nothing lands in the wrong positional slot,
    # and `labels` is not passed because LongT5Model.forward does not accept it.
    seq2seq_output = LongT5Model.forward(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        encoder_outputs=encoder_outputs,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )
    # Get the encoder's last hidden state from the returned output object
    # (the `encoder_outputs` argument above is usually None, so it cannot be indexed here)
    hidden_states = seq2seq_output.encoder_last_hidden_state
    # Pass the pooled encoder representation through the classification head
    # attached to the model above. Note that T5-style models have no [CLS] token,
    # so taking the first token is an arbitrary choice; mean pooling over the
    # sequence is a common alternative.
    logits = self.classifier(hidden_states[:, 0, :])
    # Only compute the loss when labels are provided (e.g. not at inference time)
    loss = None
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
    return Seq2SeqLMOutput(
        loss=loss,
        logits=logits,
        past_key_values=seq2seq_output.past_key_values,
        decoder_hidden_states=seq2seq_output.decoder_hidden_states,
        decoder_attentions=seq2seq_output.decoder_attentions,
        cross_attentions=seq2seq_output.cross_attentions,
        encoder_last_hidden_state=seq2seq_output.encoder_last_hidden_state,
        encoder_hidden_states=seq2seq_output.encoder_hidden_states,
        encoder_attentions=seq2seq_output.encoder_attentions,
    )
# Replace the model's forward method with the bound custom forward
model.forward = forward.__get__(model, type(model))

# Example usage:
inputs = tokenizer("Example sentence", return_tensors="pt")
# The decoder still needs some input even though only the encoder output
# feeds the classification head.
decoder_input_ids = tokenizer("This trace", return_tensors="pt").input_ids
with torch.no_grad():
    logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
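For reference, this is the kind of thing I am ultimately aiming for, analogous to GPT2ForSequenceClassification: a minimal sketch (not a working answer) that wraps the encoder-only LongT5EncoderModel with a small classification head. The class name LongT5ForBinaryClassification, the mean pooling, and the head layout are my own assumptions rather than anything that exists in transformers; I am asking whether this pattern, the monkey-patched forward above, or something else entirely is the recommended way.

import torch
from torch import nn
from transformers import AutoTokenizer, LongT5EncoderModel

class LongT5ForBinaryClassification(nn.Module):
    """Hypothetical wrapper: LongT5 encoder + a small classification head."""

    def __init__(self, model_name="google/long-t5-local-base", num_labels=2):
        super().__init__()
        self.encoder = LongT5EncoderModel.from_pretrained(model_name)
        hidden = self.encoder.config.d_model
        self.classifier = nn.Sequential(nn.Dropout(0.1), nn.Linear(hidden, num_labels))

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Encoder-only pass; no decoder inputs are needed for classification.
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # (batch, seq_len, d_model)
        # Mean-pool over non-padding tokens, since T5-style models have no [CLS] token.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).type_as(hidden_states)
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            pooled = hidden_states.mean(dim=1)
        logits = self.classifier(pooled)
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels.view(-1))
        return {"loss": loss, "logits": logits}

# Quick smoke test
tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
clf = LongT5ForBinaryClassification()
batch = tokenizer(["Example sentence"], return_tensors="pt")
out = clf(**batch, labels=torch.tensor([1]))
print(out["logits"].shape, out["loss"])

This avoids running the decoder at all, which seems cleaner for pure classification, but I do not know whether dropping the decoder loses anything important for LongT5.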