I am trying to extend a pretrained model for a school project. I want to use reinforcement learning to finetune this model. I want my model to generate funny responses or interesting responses according to a classification head, then I want to backprop using rewards on the final funny_head / interesting_head scores. Towards this end I need to override the loss. Right now, I am just having trouble understanding the best way to implement the classification head. This is the current code I have but I would ideally like the multiple choice head to use the CLS token of the sentence for classification?
What should I pass in as the mc_token_ids to accomplish this? I would ideally like to take the CLS hiddenstates of the input sentence and output a score vector ie. [10,20] then take the softmax I would just use a linear layer but if my input sentence is “Hi Dialo, how are you doing today?” then the hidden state size is [1,11, 768] I don’t really want to just sum component-wise as I might loose something. I also don’t know if there is a CLS token pretrained or I would just use that?
Would appreciate any advice or pointers.
Thanks,
TSM
class CoolNet(GPT2LMHeadModel,GPT2DoubleHeadsModel):
def __init__(self, config):
super().__init__(config)
self.tokenizer = GPT2TokenizerFast.from_pretrained("microsoft/DialoGPT-small")
self.funny_head = GPT2LMHead(self.transformer.wte.weight, config)
self.interesting_head = GPT2LMHead(self.transformer.wte.weight, config)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs.last_hidden_state
pkvs = transformer_outputs.past_key_values
ihl = self.interesting_head(hidden_states)
fhl = self.funny_head(hidden_states)
mhl = self.multiple_choice_head(hidden_states,torch.tensor([0]))
lm_logits = fhl
loss = None
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=pkvs,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
model = CoolNet.from_pretrained("microsoft/DialoGPT-small")
tokenizer = GPT2TokenizerFast.from_pretrained("microsoft/DialoGPT-small"
inputs = tokenizer("Hi Dialo, how are you doing today?" + tokenizer.eos_token, return_tensors="pt").input_ids
output = model.generate(input_ids=inputs,output_attentions=True,skip_special_tokens=True, max_length=30, return_dict=True, pad_token_id=tokenizer.cls_token_id)