Yeah, I did just that, but I still got the error (transformers==4.9.2):
```python
batch['attention_mask'] = inputs.attention_mask
batch['input_ids'] = inputs.input_ids
batch['token_type_ids'] = inputs.token_type_ids
batch["decoder_input_ids"] = outputs.input_ids.copy()
batch["labels"] = outputs.input_ids.copy()
```
Here, `outputs` comes from decoding the translations. I guess the error I got was something else.
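For anyone reading along, here's a minimal sketch of those assignments wrapped in a function and run on plain Python lists. It assumes `inputs` and `outputs` hold list-of-lists fields (what a tokenizer returns when no `return_tensors` is set); `build_batch`, the `SimpleNamespace` stand-ins, and the token ids are all made up for illustration. It mainly shows that `list.copy()` is shallow: `labels` and `decoder_input_ids` are separate outer lists, but the per-example inner lists are shared.

```python
from types import SimpleNamespace


def build_batch(inputs, outputs):
    """Hypothetical helper that assembles encoder-decoder fields.

    Assumes `inputs`/`outputs` expose list-of-lists attributes,
    as tokenizer output does when `return_tensors` is not set.
    """
    batch = {}
    batch["attention_mask"] = inputs.attention_mask
    batch["input_ids"] = inputs.input_ids
    batch["token_type_ids"] = inputs.token_type_ids
    # list.copy() makes a new outer list, but the inner per-example lists are shared.
    batch["decoder_input_ids"] = outputs.input_ids.copy()
    batch["labels"] = outputs.input_ids.copy()
    return batch


# Hypothetical stand-ins for tokenizer output (made-up token ids).
inputs = SimpleNamespace(
    input_ids=[[101, 7592, 102]],
    attention_mask=[[1, 1, 1]],
    token_type_ids=[[0, 0, 0]],
)
outputs = SimpleNamespace(input_ids=[[101, 2088, 102]])

batch = build_batch(inputs, outputs)
print(batch["labels"] is batch["decoder_input_ids"])        # False
print(batch["labels"][0] is batch["decoder_input_ids"][0])  # True
```

So if the labels are later edited token by token in place, `decoder_input_ids` changes too; building a new list (e.g. with a list comprehension) avoids that.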