Issues with model.generate()

Hi, I am trying to train DistilGPT-2 with a custom head. After training for a while, when I try to generate text using model.generate(), I get the error: "InvalidArgumentError: {{function_node _wrapped__Reshape_device/job:localhost/replica:0/task:0/device:GPU:0}} Input to reshape is a tensor with 30720 values, but the requested shape requires a multiple of 76800 [Op:Reshape]"
My LM head class is as follows:

class policy_rl(TFGPT2LMHeadModel):
  """DistilGPT-2 language model with a custom (un-tied) LM head for RL fine-tuning.

  NOTE(review): reconstructed from a truncated paste — the original
  `return TFCausalLMOutputWithPast(` was cut off mid-statement.
  """

  def __init__(self, config):
    # BUG FIX: the original never called super().__init__(), so the
    # TFPreTrainedModel/Keras machinery that generate(), saving, and the
    # serving signatures rely on was never initialized.
    super().__init__(config)
    self.transformer = TFGPT2Model(config)
    # Un-tied dense head projecting hidden states (n_embd) to vocab logits.
    self.lm_head = tf.keras.layers.Dense(
        config.vocab_size, input_shape=(config.n_embd,), use_bias=False)

  def call(self, input_ids=None, past_key_values=None, attention_mask=None,
           token_type_ids=None, position_ids=None, head_mask=None,
           inputs_embeds=None, output_attentions=None,
           output_hidden_states=None, return_dict=None, labels=None,
           use_cache=True, get_hidden=False):
    """Forward pass.

    Returns the raw hidden states when get_hidden=True; otherwise a
    TFCausalLMOutputWithPast carrying the (optional) LM loss and logits.
    """
    transformer_outputs = self.transformer(
        input_ids, past_key_values=past_key_values,
        attention_mask=attention_mask, token_type_ids=token_type_ids,
        position_ids=position_ids, head_mask=head_mask,
        inputs_embeds=inputs_embeds, use_cache=use_cache)
    hidden_states = transformer_outputs[0]

    if get_hidden:
      return hidden_states

    logits = self.lm_head(hidden_states)
    loss = None
    if labels is not None:
      # Shift so tokens < n predict token n: drop the last logit and the
      # first label before computing the LM loss.
      logits = logits[:, :-1]
      labels = labels[:, 1:]
      loss = self.compute_loss(labels, logits)

    # BUG FIX: the original paste ended with an unfinished
    # `return TFCausalLMOutputWithPast(`. Completed with the fields
    # generate() expects — past_key_values in particular, so caching works.
    # NOTE(review): assumes self.transformer returns a ModelOutput with
    # these attributes — confirm against your transformers version.
    return TFCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )

  def serving_output(self, output):
    """Convert a TFCausalLMOutputWithPast into serving-friendly tensors."""
    pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
    hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
    attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
    # BUG FIX: loss is None at inference time (no labels); converting None
    # to a tensor would raise.
    loss = tf.convert_to_tensor(output.loss) if output.loss is not None else None

    return TFCausalLMOutputWithPast(loss=loss, logits=output.logits,
                                    past_key_values=pkv, hidden_states=hs,
                                    attentions=attns)

I am trying to execute this:

text = 'The fox jumped'
input_ids = tokenizer.encode(text, return_tensors='tf')  # shape (1, seq_len)
# BUG FIX: the original call never passed the prompt (input_ids was not
# given to generate) and the paste was missing the closing parenthesis.
beam_output = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,
    temperature=0.7,
)

I tried using the following as well:

# BUG FIX: generate() expects a 2-D (batch, seq_len) tensor; indexing with
# input_ids[0] drops the batch dimension and causes the quoted
# "StridedSlice ... input has only 1 dims" error. generate() also returns
# a batch of sequences, so decode the first one.
output_model = model.generate(input_ids, max_length=10)  # do greedy decoding
print(f"Generated: {tokenizer.decode(output_model[0], skip_special_tokens=True)}")

This gives the error :
InvalidArgumentError: {{function_node _wrapped__StridedSlice_device/job:localhost/replica:0/task:0/device:GPU:0}} Index out of range using input dim 1; input has only 1 dims [Op:StridedSlice] name: strided_slice/

Please can anyone help me with this issue.

@Rocketknight1 I saw your comment on Fine-Tuning TFGPT2LMHeadModel / What to pass to fit · Issue #11507 · huggingface/transformers · GitHub. if you could please have a look.