Problem
I just hit a really confusing BART bug where I was passing input_ids and decoder_input_ids, but not labels, to avoid computing the built-in loss (I wanted to compute a custom loss).
Since use_cache defaults to True and is only overwritten if labels are passed, you need to pass use_cache=False explicitly to achieve the above.
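
For reference, here is a minimal sketch of the current workaround; the facebook/bart-base checkpoint and my_custom_loss are illustrative assumptions, not part of this report:

    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    source = tokenizer(["some source text"], return_tensors="pt")
    target = tokenizer(["some target text"], return_tensors="pt")

    # No labels are passed, so use_cache has to be disabled by hand; in practice
    # the target ids would also be shifted right (e.g. with shift_tokens_right).
    outputs = model(
        input_ids=source["input_ids"],
        attention_mask=source["attention_mask"],
        decoder_input_ids=target["input_ids"],
        use_cache=False,
        output_hidden_states=True,
    )

    # Full-length logits / decoder hidden states are now available for a custom loss:
    # loss = my_custom_loss(outputs.logits, target["input_ids"])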
Proposed Solution
I think it would make more sense to have use_cache=False by default, and then set it to True in prepare_inputs_for_generation, if applicable.
Alternatively, we could default use_cache=False in BartConfig.
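
A minimal sketch of what that alternative would look like from the user side, assuming nothing beyond the public BartConfig API (today use_cache still has to be set explicitly):

    from transformers import BartConfig, BartForConditionalGeneration

    # Proposed default: caching off unless generation turns it back on.
    config = BartConfig(use_cache=False)
    model = BartForConditionalGeneration(config)

    # Under the proposal, prepare_inputs_for_generation would set use_cache=True
    # again during generate(), so decoding speed would be unaffected.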
Here is a (currently failing) test you can paste into test_modeling_bart.py at line 303:
    def test_hidden_states_shape(self):
        config, input_ids, batch_size = self._get_config_and_data()
        lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
        decoder_input_ids = shift_tokens_right(lm_labels, config.pad_token_id)
        lm_model = BartForConditionalGeneration(config)
        lm_model.to(torch_device)
        expected_shape = (batch_size, input_ids.shape[1], config.d_model)

        # With labels, use_cache is overwritten internally and the shapes match.
        outputs = lm_model(
            input_ids=input_ids,
            labels=lm_labels,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
        )
        self.assertEqual(outputs.decoder_hidden_states[0].shape, expected_shape)

        # Without labels, use_cache keeps its default (True) and this assertion fails.
        outputs = lm_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            # use_cache=False,  # test passes if you uncomment this
        )
        self.assertEqual(outputs.decoder_hidden_states[0].shape, expected_shape)