Weird output from model.generate()

I’m trying to build a multi-task model that does both text generation and text classification. To this end I have modified a BART model, adding several classification heads on top of the encoder hidden state. Judging by the logits output, this all seems to work fine, but model.generate() returns the same generated string for every item in the batch. For example:

INPUT IDS
tensor([[    0, 26039,  9271,  ...,     1,     1,     1],
        [    0,   133, 17842,  ...,     1,     1,     1],
        [    0, 40118,     5,  ...,     1,     1,     1],
        ...,
        [    0,  1106, 49279,  ...,     1,     1,     1],
        [    0,   863,  6537,  ...,     1,     1,     1],
        [    0, 38195,     5,  ...,     1,     1,     1]], device='cuda:0')

GENERATED TOKENS
tensor([[ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2],
        [ 2,  0, 11, 19,  4,  4, 14, 12, 13,  2]], device='cuda:0')

So I looked into how model.generate() is implemented. Since BART is an encoder-decoder model, it first calls the encoder to get encoder_outputs and then uses them to produce the logits, which makes sense. However, the encoder outputs I get at that point seem to be just the first item in the batch, repeated:

ENCODER OUTPUTS
tensor([[[-0.2394, -0.0438,  0.9226,  ..., -0.1061, -1.6471,  0.3405],
         [-0.2394, -0.0438,  0.9226,  ..., -0.1061, -1.6471,  0.3405],
         [-0.2394, -0.0438,  0.9226,  ..., -0.1061, -1.6471,  0.3405],
         ...,
         [-0.2394, -0.0438,  0.9226,  ..., -0.1061, -1.6471,  0.3405],
         [-0.2394, -0.0438,  0.9226,  ..., -0.1061, -1.6471,  0.3405],
         [-0.2394, -0.0438,  0.9226,  ..., -0.1061, -1.6471,  0.3405]],

LM LOGITS
tensor([[[12.0693, -0.3513, -2.2889,  ..., -0.4746,  1.1982, -0.8619],
         [ 3.3621, -0.7901, -1.9532,  ..., -0.8510,  1.1555, -0.0721],
         [-0.7564, -1.3166, -1.3422,  ..., -1.3504,  1.5587,  6.3961],
         ...,
         [ 0.2305, -0.8267,  7.1652,  ..., -0.8145, -0.4890, -2.4669],
         [ 0.7836, -0.3687, -5.4600,  ..., -1.0403,  0.8052, -1.3461],
         [-0.9993, -0.9102,  3.6957,  ..., -0.9280,  0.7593,  4.8513]],

        [[12.0693, -0.3513, -2.2889,  ..., -0.4746,  1.1982, -0.8619],
         [ 3.3621, -0.7901, -1.9532,  ..., -0.8510,  1.1555, -0.0721],
         [-0.7564, -1.3166, -1.3422,  ..., -1.3504,  1.5587,  6.3961],
         ...,
         [ 0.2305, -0.8267,  7.1652,  ..., -0.8145, -0.4890, -2.4669],
         [ 0.7836, -0.3687, -5.4600,  ..., -1.0403,  0.8052, -1.3461],
         [-0.9993, -0.9102,  3.6957,  ..., -0.9280,  0.7593,  4.8513]],

Has anyone seen this issue before? How can I ensure that model.generate() works for my custom model?

So far I have made sure that:

  1. config.is_encoder_decoder is set to True
  2. the model implements self.get_encoder
  3. the model produces encoder_outputs correctly when given a batch of input_ids
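
For anyone who wants to reproduce the check in item 3, here is a minimal sketch of the kind of comparison I mean. The helper name duplicate_rows is made up for this post and is not part of transformers. It assumes the batch is on the first dimension and accepts either a nested list or anything with a .tolist() method, such as a torch tensor. The sample values are shortened copies of the dumps above.

```python
def duplicate_rows(batch):
    """Return indices of batch items that exactly equal batch item 0.

    Hypothetical helper (not a transformers API). `batch` is a nested list
    or any object with .tolist(), with the batch on the first dimension.
    """
    rows = batch.tolist() if hasattr(batch, "tolist") else batch
    return [i for i in range(1, len(rows)) if rows[i] == rows[0]]


# Distinct inputs, like the INPUT IDS dump: no item repeats item 0.
input_ids = [[0, 26039, 9271], [0, 133, 17842], [0, 40118, 5]]
print(duplicate_rows(input_ids))   # []

# Collapsed outputs, like the GENERATED TOKENS dump: every item equals item 0.
generated = [[2, 0, 11, 19, 4]] * 3
print(duplicate_rows(generated))   # [1, 2]
```

Running the same check on the output of the encoder returned by self.get_encoder, once when called directly on the batch and once on what generate() passes to the decoder, should show at which step the items become identical. How the hidden states are accessed on the encoder output depends on the transformers version, so I have left that part out.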

Nothing I have tried fixes the repetition issue. My tracing suggests the duplication is coming from these lines:

What does this do? I’d appreciate any insights here.