Using model() instead of model.generate()

I am using pixtral-12b and I want to call model() directly to get the outputs and their logits. How can it be used to produce correct output? It seems to produce a lot of gibberish when I try to decode the outputs from model().

        outputs = self.model(**inputs)

        # For comparison, the generate() calls that work fine:
        # generated_ids = self.model.generate(**inputs)
        # generated_ids = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, min_p=0.1, temperature=0.9)
        # output = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        # print(output[0])

        # Logits at the last position, i.e. the model's prediction for the next token
        logits = outputs.logits[:, -1, :]

        # Decode the single most likely next token
        generated_ids = logits.argmax(dim=-1)
        generated_text = self.processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        print("Generated text:", generated_text)

        # Collect the argmax prediction at every position of the input sequence
        all_token_ids = []
        for i in range(outputs.logits.size(1)):  # logits has shape (batch_size, seq_length, vocab_size)
            logits_i = outputs.logits[:, i, :]
            token_id = logits_i.argmax(dim=-1)   # predicted token id at position i
            all_token_ids.append(token_id[0].item())

Also, using model() takes at least 2x the memory of model.generate(). I am really confused about using model() directly. Can someone point me to the correct docs or share any thoughts on this?
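For context: a plain model() call is a single teacher-forced forward pass. It returns logits for every position of the input at once (shape batch_size x seq_length x vocab_size), and the argmax at position i is the prediction for token i+1 given the real prefix, not an autoregressively generated continuation; decoding all of those argmaxes is why the output reads as gibberish. generate(), by contrast, loops and feeds each newly chosen token back in. A minimal greedy loop built on model() that mirrors generate() while keeping the per-step logits could look roughly like the sketch below (assumptions: batch size 1, and it reuses self.model, self.processor and inputs from the snippet above; no KV cache, so it is slow):

        import torch

        with torch.no_grad():
            step_logits = []  # next-token logits at each decoding step
            for _ in range(50):  # same budget as max_new_tokens=50
                # NOTE: this recomputes the whole prefix (and re-passes pixel_values)
                # on every step; generate() avoids that with a KV cache.
                outputs = self.model(**inputs)
                next_token_logits = outputs.logits[:, -1, :]  # prediction for the next token only
                step_logits.append(next_token_logits)
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)
                # Append the chosen token and grow the attention mask for the next pass
                inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token], dim=-1)
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = torch.cat(
                        [inputs["attention_mask"], torch.ones_like(next_token)], dim=-1
                    )
                if next_token.item() == self.processor.tokenizer.eos_token_id:
                    break

        # Decodes prompt + continuation
        print(self.processor.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True))

As for the memory difference: the extra usage from model() most likely comes from materialising that full seq_length x vocab_size logits tensor over the whole (image-token-heavy) prompt, whereas generate() only needs the last position's logits at each decoding step.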


Can you explain a bit more about your answer? Isn't it just calling the forward pass?
And how about this?


Oh, sorry. I thought it was a normal class. It’s okay, though, because it inherits from torch.nn.Module.
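In other words, model(**inputs) goes through nn.Module.__call__, which runs any registered hooks and then dispatches to forward(). A toy illustration (not the Pixtral source, just the general nn.Module behaviour):

    import torch.nn as nn

    class Toy(nn.Module):
        def forward(self, x):
            return x + 1

    m = Toy()
    # Calling the instance dispatches to forward() via nn.Module.__call__
    assert m(1) == m.forward(1) == 2

Both forms run the same single forward pass (calling the instance is preferred because it also runs hooks); neither one performs the autoregressive decoding loop that generate() adds on top.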