Extending a GPT-2 Model (DialoGPT)

I am trying to extend a pretrained model for a school project. I want to finetune DialoGPT with reinforcement learning so that it generates funny or interesting responses according to a classification head, then backprop using rewards on the final funny_head / interesting_head scores; to do that I need to override the loss (a rough sketch of what I mean is below). Right now I am just having trouble understanding the best way to implement the classification head. My current code is further down, but I would ideally like the multiple choice head to use the CLS token of the sentence for classification.
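Roughly, the loss I have in mind is a REINFORCE-style objective, something like this sketch (log_probs and reward are placeholders for the sampled response's token log-probabilities and the softmaxed head score; detaching the reward is just one common choice):

import torch

def rl_loss(log_probs, reward):
    # REINFORCE-style: scale the negative log-likelihood of the sampled
    # response by the reward coming from the classification head
    return -(reward.detach() * log_probs.sum())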

What should I pass in as mc_token_ids to accomplish this? Ideally I would take the CLS hidden state of the input sentence and output a score vector, e.g. [10, 20], then take the softmax. I would just use a linear layer, but if my input sentence is "Hi Dialo, how are you doing today?" then the hidden state has shape [1, 11, 768]. I don't really want to just sum component-wise, as I might lose something. I also don't know whether there is a pretrained CLS token I could use.
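For reference, this is the pattern I think the GPT2DoubleHeadsModel docs suggest: add a [CLS] token to the tokenizer myself and pass its position as mc_token_ids (just a sketch; the newly added [CLS] embedding starts out randomly initialized, so it isn't pretrained):

import torch
from transformers import GPT2TokenizerFast, GPT2DoubleHeadsModel

tokenizer = GPT2TokenizerFast.from_pretrained("microsoft/DialoGPT-small")
model = GPT2DoubleHeadsModel.from_pretrained("microsoft/DialoGPT-small")

# GPT-2 has no pretrained CLS token, so add one and resize the embeddings
tokenizer.add_special_tokens({"cls_token": "[CLS]"})
model.resize_token_embeddings(len(tokenizer))

encoded = tokenizer("Hi Dialo, how are you doing today? [CLS]", return_tensors="pt")
input_ids = encoded.input_ids.unsqueeze(1)  # [batch, num_choices=1, seq_len]
# mc_token_ids = index of the [CLS] token in each sequence (here, the last position)
mc_token_ids = torch.tensor([[input_ids.size(-1) - 1]])

outputs = model(input_ids, mc_token_ids=mc_token_ids)
print(outputs.mc_logits.shape)  # [1, 1] — one score per choice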

Would appreciate any advice or pointers.

Thanks,
TSM

import torch
from torch import nn
from transformers import GPT2TokenizerFast, GPT2DoubleHeadsModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class CoolNet(GPT2DoubleHeadsModel):
    # GPT2DoubleHeadsModel already provides lm_head and multiple_choice_head,
    # so there is no need to also inherit from GPT2LMHeadModel
    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("microsoft/DialoGPT-small")
        # extra scoring heads over the vocabulary, weight-tied to the input embeddings
        self.funny_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.funny_head.weight = self.transformer.wte.weight
        self.interesting_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.interesting_head.weight = self.transformer.wte.weight

    def forward(
            self,
            input_ids=None,
            past_key_values=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        hidden_states = transformer_outputs.last_hidden_state  # [batch, seq_len, n_embd]
        pkvs = transformer_outputs.past_key_values
        interesting_logits = self.interesting_head(hidden_states)
        funny_logits = self.funny_head(hidden_states)
        # multiple_choice_head pools the hidden state at the given token index
        # (here index 0); this is the spot where I'd want a CLS position instead
        mc_logits = self.multiple_choice_head(hidden_states, torch.tensor([0]))
        lm_logits = funny_logits
        loss = None
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=pkvs,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
model = CoolNet.from_pretrained("microsoft/DialoGPT-small")
tokenizer = GPT2TokenizerFast.from_pretrained("microsoft/DialoGPT-small")
inputs = tokenizer("Hi Dialo, how are you doing today?" + tokenizer.eos_token, return_tensors="pt").input_ids
# GPT-2's tokenizer has no CLS token (cls_token_id is None), so reuse EOS for padding
output = model.generate(input_ids=inputs, max_length=30, pad_token_id=tokenizer.eos_token_id)
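And downstream, the scoring I have in mind would look roughly like this (a sketch; score_head is a hypothetical linear layer from a pooled hidden state to the two scores, and I'm pooling with the last token's hidden state for lack of a CLS token):

import torch.nn.functional as F
from torch import nn

score_head = nn.Linear(768, 2)  # hypothetical: pooled hidden state -> [funny, interesting]

hidden = model.transformer(inputs).last_hidden_state  # [1, seq_len, 768]
pooled = hidden[:, -1, :]                             # last-token hidden state as the pooled vector
scores = score_head(pooled)                           # e.g. [[10., 20.]]
reward = F.softmax(scores, dim=-1)                    # probabilities over funny vs. interesting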