Training BART, error when preparing decoder_input_ids. Shape of input_ids?

Hi!

I’ve been trying to train BART on a dialogue data set but got stuck on the following error.

I don’t pass explicit decoder input ids, so the forward() function calls _prepare_bart_decoder_inputs(), which in turn calls shift_tokens_right() to build the decoder input.

The index tensor built inside shift_tokens_right() doesn’t have dimensions compatible with my input_ids tensor, so the torch.gather() call fails.

Here is the error output:

~/.local/lib/python3.6/site-packages/transformers/modeling_bart.py in forward(self, input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, decoder_cached_states, use_cache, output_attentions, output_hidden_states)
    837                 decoder_input_ids=decoder_input_ids,
    838                 decoder_padding_mask=decoder_attention_mask,
--> 839                 causal_mask_dtype=self.shared.weight.dtype,
    840             )
    841         else:

~/.local/lib/python3.6/site-packages/transformers/modeling_bart.py in _prepare_bart_decoder_inputs(config, input_ids, decoder_input_ids, decoder_padding_mask, causal_mask_dtype)
    111     pad_token_id = config.pad_token_id
    112     if decoder_input_ids is None:
--> 113         decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
    114     bsz, tgt_len = decoder_input_ids.size()
    115     if decoder_padding_mask is None:

~/.local/lib/python3.6/site-packages/transformers/modeling_bart.py in shift_tokens_right(input_ids, pad_token_id)
    166 def shift_tokens_right(input_ids, pad_token_id):
    167     """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
    168     prev_output_tokens = input_ids.clone()
    169     index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
--> 170     prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    171     prev_output_tokens[:, 1:] = input_ids[:, :-1]
    172     return prev_output_tokens

RuntimeError: invalid argument 2: Input tensor must have same size as output tensor apart from the specified dimension at /pytorch/aten/src/THC/generic/THCTensorScatterGather.cu:28

My input_ids tensor has size [4, 2, 640], while index_of_eos built in shift_tokens_right() has size [4, 640, 1] (sum(dim=1) collapses the second dimension and unsqueeze(-1) appends a new one).

That combination doesn’t fit the shape requirements of torch.gather().

Permuting the input_ids tensor before calling the function doesn’t help either, since the dimensions of index_of_eos change along with it and still don’t match.
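
For comparison, shift_tokens_right() (copied from the traceback above) behaves as expected when I feed it a plain 2-D (batch_size, seq_len) dummy batch; the token ids and the pad id below are made up:

    import torch

    def shift_tokens_right(input_ids, pad_token_id):
        """Copy of the function from the traceback above."""
        prev_output_tokens = input_ids.clone()
        index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
        prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
        prev_output_tokens[:, 1:] = input_ids[:, :-1]
        return prev_output_tokens

    pad = 1
    batch = torch.tensor([[0, 5, 6, 2, pad],   # <s> ... </s> <pad>
                          [0, 7, 8, 9, 2]])    # shape (batch_size=2, seq_len=5)
    print(shift_tokens_right(batch, pad))
    # tensor([[2, 0, 5, 6, 2],
    #         [2, 0, 7, 8, 9]])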

The only thing I have found to work is to edit shift_tokens_right() itself so that the index_of_eos tensor is permuted, which seems like a bad idea.

Is the shape of my input_ids tensor wrong?

Do the tensors need to have just two dimensions, batch_size and sequence_length?

Maybe the way I create the input features for PyTorch needs to be reviewed…
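
For reference, my understanding is that the stock tokenizer call already produces exactly such 2-D tensors (with a reasonably recent transformers version; the checkpoint name and example sentences below are just placeholders, and the padding/truncation argument names differ slightly in older releases):

    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    dialogue_turns = ["Hello, how are you?", "Fine, thanks. And you?"]

    enc = tokenizer(dialogue_turns, padding=True, truncation=True, return_tensors="pt")
    print(enc["input_ids"].shape)       # torch.Size([2, seq_len]) -> (batch_size, sequence_length)
    print(enc["attention_mask"].shape)  # same 2-D shape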

I’ve been here, with both dialogue and BART. I solved all my problems by porting my code to PyTorch Lightning.

Once ported, you can very easily use the training_step() function as follows, where self() calls model.forward():

    def training_step(self, batch, batch_id):
        """See the Lightning docs. Need to make sure that the first token is not ignored."""

        # Shift the targets manually: the decoder sees everything except the last token,
        # and the labels are the same sequence shifted one position to the left.
        decoder_inputs = batch["target_ids"][:, :-1].contiguous()
        decoder_labels = batch["target_ids"][:, 1:].clone()
        # Set padding positions to -100 so the cross-entropy loss ignores them.
        decoder_labels[batch["target_ids"][:, 1:] == self.tokenizer.pad_token_id] = -100

        loss = self(source_ids=batch["source_ids"],
                    padding_mask=batch["padding_mask"],
                    decoder_inputs=decoder_inputs,
                    decoder_labels=decoder_labels)[0]

        return {"loss": loss}
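
The keyword names here (source_ids, padding_mask, decoder_inputs, decoder_labels) are just my own; my forward() simply maps them onto the usual BART arguments, roughly like this (note that labels is called lm_labels in older transformers releases):

    def forward(self, source_ids, padding_mask, decoder_inputs, decoder_labels):
        # Passing labels makes BartForConditionalGeneration return the LM loss
        # as the first element of its output, which training_step() picks up with [0].
        return self.model(input_ids=source_ids,
                          attention_mask=padding_mask,
                          decoder_input_ids=decoder_inputs,
                          labels=decoder_labels)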

Thanks for the answer.

Did you need to change the code in the BART module, or just your own code that uses it?

No need to change the HF code, just the structure of your own code:

import pytorch_lightning as pl
from transformers import BartForConditionalGeneration

class SampleModule(pl.LightningModule):
    def __init__(self, arch, **kwargs):
        super().__init__()
        # initialize stuff
        self.model = BartForConditionalGeneration.from_pretrained(arch)
        # more stuff (tokenizer, hyperparameters, ...)

    def forward(self, **batch):
        # hand the prepared batch tensors straight to the underlying BART model
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        # do all the shifting of decoder inputs / labels etc. here,
        # then call self() as your forward method
        ...
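
After that, training is just the standard Lightning loop, roughly as below (you also need a configure_optimizers() method and a DataLoader, which I’ve left out, and the checkpoint name is a placeholder):

module = SampleModule(arch="facebook/bart-large")
trainer = pl.Trainer(max_epochs=3)   # plus whatever Trainer flags you need
trainer.fit(module, train_dataloader)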