Use LongT5 model for binary classification

I want to use the LongT5 model for a binary classification task. It is a seq2seq encoder-decoder generative model. Some suggest training it to generate the class name as output, but that's not an elegant solution.
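For context, that class-name approach would look roughly like the sketch below (just a sketch; the label words are placeholders I picked):

# Sketch of the "generate the class name" approach I'd rather avoid:
# the model is fine-tuned to emit a label word such as "positive" / "negative".
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

tok = AutoTokenizer.from_pretrained("google/long-t5-local-base")
gen_model = LongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")

enc = tok("Example sentence", return_tensors="pt")
out_ids = gen_model.generate(**enc, max_new_tokens=2)
print(tok.decode(out_ids[0], skip_special_tokens=True))  # "positive" / "negative" after fine-tuning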

Instead, I intend to modify the final layer, attach a binary classification head, and use a loss function meant for a binary classifier. I am trying to replicate what is done for GPT2ForSequenceClassification.
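For reference, the GPT-2 version is simply instantiated with num_labels; as far as I can tell there is no LongT5ForSequenceClassification equivalent:

# GPT2ForSequenceClassification adds a single linear "score" head on top of the
# last non-padding token's hidden state and uses CrossEntropyLoss when num_labels=2.
from transformers import GPT2ForSequenceClassification

gpt2_clf = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=2)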

I tried to create a custom forward function like the one below and replace the model's existing forward function, but I am not sure if it's the right approach. What is the correct way to use LongT5 for binary classification?


import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoTokenizer, LongT5ForConditionalGeneration, LongT5Model
from transformers.modeling_outputs import Seq2SeqLMOutput

model_name = "google/long-t5-local-base"
num_labels = 2  # binary classification

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
model = LongT5ForConditionalGeneration.from_pretrained(model_name, config=config)

# Attach the classification head to the model once, so its weights persist
# across calls and show up in model.parameters() for training.
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(config.d_model, config.d_model),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),  # optional dropout
    torch.nn.Linear(config.d_model, num_labels, bias=False),
)

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    # Run the base encoder-decoder. LongT5Model.forward takes no `labels`
    # argument, so pass everything by keyword and keep `labels` out of the call.
    seq2seq_output = LongT5Model.forward(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        encoder_outputs=encoder_outputs,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,  # always return a Seq2SeqModelOutput so fields can be accessed by name
    )

    # Get the encoder's last hidden state from the model output (the
    # `encoder_outputs` argument is usually None when the model is called directly).
    hidden_states = seq2seq_output.encoder_last_hidden_state  # (batch, seq_len, d_model)

    # Pass the first token's representation through the classification head.
    logits = self.classifier(hidden_states[:, 0, :])  # (batch, num_labels)

    loss = None
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))

    return Seq2SeqLMOutput(
        loss=loss,
        logits=logits,
        past_key_values=seq2seq_output.past_key_values,
        decoder_hidden_states=seq2seq_output.decoder_hidden_states,
        decoder_attentions=seq2seq_output.decoder_attentions,
        cross_attentions=seq2seq_output.cross_attentions,
        encoder_last_hidden_state=seq2seq_output.encoder_last_hidden_state,
        encoder_hidden_states=seq2seq_output.encoder_hidden_states,
        encoder_attentions=seq2seq_output.encoder_attentions,
    )

# Replace the model's forward method
model.forward = forward.__get__(model, type(model))
# Example usage:
inputs = tokenizer("Example sentence", return_tensors="pt")

# The decoder still needs some input even though only the encoder output is
# used for classification; a short dummy string works for a single forward pass.
decoder_input_ids = tokenizer("This trace", return_tensors="pt").input_ids
with torch.no_grad():
    logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
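
And this is how I would expect to train it (again only a sketch, assuming the patched forward above is the right approach): pass labels so the forward returns a loss, and make sure the classification head's parameters are included in the optimizer.

# Training-step sketch (assumes the patched forward above is correct).
labels = torch.tensor([1])  # e.g. 1 = positive class
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # includes model.classifier

model.train()
outputs = model(**inputs, decoder_input_ids=decoder_input_ids, labels=labels)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()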