T5.generate() cannot get hidden states although output_hidden_states=True

Hi, I was trying to modify the T5 model for for generation and need to add additional layers so I need the hiddenstates output.

ids = torch.tensor([[  822,    10,   116,   405,   158,   449,   253,    91,    81, 13902,
             3,  4172,  5907,     1,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,    84,    13,     8,   826,    19,  1176,    81,     8,
         16629,    13,  1043,    45,  7450,  1043,   168,     7,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   163, 18554,     7,    11, 10235,     8,  5796,    31,
             7,     8,  2006,     1,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   113,   293,     7,     3,     9,   443,   116,    34,
            19,     3, 29107,     1,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,  1300,   572,    19,  2772,    53,  1702,     3,     9,
          1934,  1573,    13,   331,  1232,    16, 19647,   585,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   116,   405,     3,   189,   127,  2385,    16,   467,
            13,     3,  8514,  1496,     1,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   213,   405,     8,   540,    45,     8,  1481,   369,
            45,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,    38,     3,     9,   879,  3356,    46,  6152,    54,
          1472,    46,  3490,    44,    56,    21,   125,  2081,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   113,   751,     8,   166,  2045,  9365,    16,     8,
          2265,     3,    32,   120,    51,  6174,     7,   846,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   113,    19,     8,  3763,    13,  2515,   663,  1157,
          3370,   372,     1,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   113,  2832,    25,     3,     9,    77,    31,    17,
            59,  2907,    68,     3,     9,     3, 26219,  1782,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   113,  1944,   491,  2576, 21919,    31,     7,  3062,
            16,  8581, 22534,   220,     1,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   113,    47,  1381,    16,     8,   387,    40,    32,
            32, 25153, 28882,     1,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   113,   243,     3,    99,     3,    23,    43,   894,
           856,    34,    19,    57,  4125,    30,     8, 15424,    13,  6079,
             7,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   125,    19,     8,  2530,    13,     8,   564,   108,
            52,    23,     1,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   113,   808,   610,    13,     8, 24168,   789,    16,
           957,  2884,     1,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   125,   410,     8, 19085,   871,   103,   116,    79,
           608,     8,  2068,     1,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,   125,    19,     8,   711,   800,    13,     8,  1517,
          2195,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,    19,     8,  1974,     8,   256, 12549,    17,     3,
           390,    30,     3,     9,  1176,   733,     1,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  822,    10,     3,  2754,  3216,    22,     7,  4516,    13,   936,
          6620,    11, 25515,    13,  1857,   936,   523,     1,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]])

output1 = T5ForConditionalGeneration.from_pretrained('t5-large').generate(
                ids,
                max_length=20,
                output_hidden_states=True,
                prefix_allowed_tokens_fn=restrict_decode_vocab,
                early_stopping=True)

and got the following output without hidden states:

tensor([[    0,     3,  4613,     1],
        [    0,     3,   632,     1],
        [    0,     3,   632,     1],
        [    0,     3,   632,     1],
        [    0,     3,   632,     1],
        [    0,     3,  1206,     1],
        [    0,     3,     1,     0],
        [    0,     3,   632,     1],
        [    0,     3,  3959,     1],
        [    0,     3,  3891,     1],
        [    0,     3,     1,     0],
        [    0,     3,  1206,     1],
        [    0,     3,  4056,     1],
        [    0,     3,  1206,     1],
        [    0,     3,   632,     1],
        [    0,     3, 19978,     1],
        [    0,     3,   632,     1],
        [    0,     3,   632,     1],
        [    0,     3,     1,     0],
        [    0,     3,  4225,     1]])

Anyone knows what’s going wrong?
Thanks!