Difference between model.generate() and model() outputs


I have been experimenting with the transformers library, and I wanted to fine-tune GPT-2 with a custom loss function that adds a term based on the cosine similarity between a given topic and the generated results (so it penalizes the model if the generated phrase has a low cosine similarity with the topic "sports", for example).

Here is my compute_loss function:

class GPT2Trainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False):
        custom_loss = super().compute_loss(model, inputs, return_outputs)
        outputs = model(**inputs)

        # Here I suppose that I will have to decode each generated output in
        # the batch, compute the mean cosine similarity of all the words with
        # the topic word, and then average over the batch size.

        if isinstance(custom_loss, tuple):
            # return_outputs=True, so custom_loss is (loss, outputs)
            custom_loss = (custom_loss[0] + cos_penalty, custom_loss[1])
        else:
            # cos_penalty is a placeholder for the cosine-similarity term
            custom_loss = custom_loss + cos_penalty

        return custom_loss
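One possible shape for that cosine-similarity term is sketched below. It is only an illustration of the idea, not a tested implementation: `topic_similarity_penalty` is a name I made up, `embedding_matrix` is assumed to be the model's input embedding table (for GPT-2 that would be `model.transformer.wte.weight`), and `topic_embedding` is assumed to be the embedding of the topic word:

```python
import torch
import torch.nn.functional as F

def topic_similarity_penalty(logits, topic_embedding, embedding_matrix):
    # logits: (batch, seq_len, vocab) from model(**inputs).logits
    # topic_embedding: (dim,) embedding of the topic word, e.g. "sports"
    # embedding_matrix: (vocab, dim), assumed to be the model's token embeddings
    pred_ids = logits.argmax(dim=-1)            # greedy token at each position
    pred_emb = embedding_matrix[pred_ids]       # (batch, seq_len, dim)
    sim = F.cosine_similarity(
        pred_emb, topic_embedding.expand_as(pred_emb), dim=-1
    )                                           # (batch, seq_len)
    return 1.0 - sim.mean()                     # high penalty = low similarity
```

One caveat with this shape: argmax is not differentiable, so a penalty built on decoded/argmax'd tokens contributes no gradient to the logits; a softmax-weighted average of embeddings would be a differentiable alternative.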

Nevertheless, I can't understand why the results of model.generate(t) and taking the argmax of each row of logits from model(t) are different:

import torch.nn.functional as F
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2")

t = tokenizer.encode("how are you?", return_tensors="pt")
res = model.generate(t)
print(tokenizer.decode(res[0], skip_special_tokens=True))

outputs = model(t)
logits = outputs.logits[0]          # (seq_len, vocab_size)

for i in logits:
    prob = F.softmax(i, dim=-1)     # softmax over the vocabulary
    word_id = prob.argmax()
    print(tokenizer.decode(word_id.item()), end="")
print()

The result of the first print is "I'm not sure. I'm not sure if I'm going to", but the result of the second one is ", you going". So I would like to know what I'm doing wrong, and what the difference between the two methods is, so that I know which one to use in compute_loss to get the correct prediction.
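To make the comparison concrete, here is a minimal toy sketch of what I understand each call to be doing, with a deterministic numpy function standing in for GPT-2 (all names here are made up for illustration): a single forward pass scores every prompt position at once, each conditioned on the real prompt prefix, while greedy generation runs that step in a loop and feeds its own predictions back in.

```python
import numpy as np

VOCAB_SIZE = 16

def next_token_logits(tokens):
    # Toy stand-in for one step of a language model: deterministic logits
    # over the vocabulary given a context (purely illustrative, not a real LM).
    seed = (sum(tokens) * 31 + len(tokens) * 7) % (2**32)
    rng = np.random.default_rng(seed)
    return rng.normal(size=VOCAB_SIZE)

def forward_pass(tokens):
    # Like model(t): one pass returns logits at EVERY position i, each
    # conditioned on the ground-truth prefix tokens[:i+1] (teacher forcing).
    return [next_token_logits(tokens[: i + 1]) for i in range(len(tokens))]

def generate(tokens, max_new_tokens=5):
    # Like greedy model.generate(t): each step is conditioned on the model's
    # OWN previous predictions, which are appended and fed back in.
    out = list(tokens)
    for _ in range(max_new_tokens):
        out.append(int(np.argmax(next_token_logits(out))))
    return out[len(tokens):]

prompt = [1, 2, 3, 4]
per_position = [int(np.argmax(row)) for row in forward_pass(prompt)]
generated = generate(prompt)

# Only the LAST row of the forward pass predicts the first new token, so it
# agrees with the first generated token; the earlier rows are next-token
# guesses for the prompt itself, not a continuation of it.
assert per_position[-1] == generated[0]
```

Under this reading, decoding the argmax of every row of model(t) mixes per-position guesses about the prompt with only one genuinely new token, which would explain the ", you going"-style output.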