Why does PretrainedConfig.use_cache default to True?



I just hit a really confusing BART bug: I was passing input_ids and decoder_input_ids, but not labels, to avoid computing the loss (I wanted to compute a custom loss).
Since use_cache defaults to True and is only overridden (set to False) when labels are passed, you have to pass use_cache=False explicitly to get the expected outputs in that setup.
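To make the mechanism concrete, here is a minimal pure-Python sketch of the decision logic described above. It assumes (my reading of the BART decoder at the time, not something stated above) that when caching is on, the decoder only processes the last decoder position, as it would during step-by-step generation. The names resolve_use_cache and decoder_hidden_shape are hypothetical and are not part of the transformers API.

```python
from typing import Optional, Tuple

# Hypothetical stand-in for the config default discussed above.
CONFIG_USE_CACHE_DEFAULT = True


def resolve_use_cache(use_cache: Optional[bool], labels_passed: bool) -> bool:
    """Mimic how the effective use_cache value is chosen (sketch, not the real API)."""
    if use_cache is None:
        # No explicit argument: fall back to the config default, which is True.
        use_cache = CONFIG_USE_CACHE_DEFAULT
    if labels_passed:
        # Passing labels forces caching off, so the loss sees every position.
        use_cache = False
    return use_cache


def decoder_hidden_shape(batch_size: int, tgt_len: int, d_model: int,
                         use_cache: bool) -> Tuple[int, int, int]:
    """Decoder hidden-state shape, assuming caching keeps only the last position."""
    seq_len = 1 if use_cache else tgt_len
    return (batch_size, seq_len, d_model)


if __name__ == "__main__":
    batch, tgt_len, d_model = 2, 7, 16
    # With labels: caching is forced off, so full-length hidden states.
    print(decoder_hidden_shape(batch, tgt_len, d_model, resolve_use_cache(None, True)))   # (2, 7, 16)
    # No labels and no use_cache=False: only the last position survives.
    print(decoder_hidden_shape(batch, tgt_len, d_model, resolve_use_cache(None, False)))  # (2, 1, 16)
    # Explicit use_cache=False restores the expected shape.
    print(decoder_hidden_shape(batch, tgt_len, d_model, resolve_use_cache(False, False))) # (2, 7, 16)
```

The middle case is the surprising one: nothing in the call mentions caching, yet the sequence dimension collapses to 1.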

Proposed Solution

I think it would make more sense for use_cache to default to False and then be set to True in prepare_inputs_for_generation, where applicable.
Alternatively, we could default use_cache=False in BartConfig.
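A rough sketch of what the proposal could look like, assuming a simplified model where only the use_cache plumbing matters. SketchConfig, SketchSeq2SeqModel, effective_use_cache and the past argument are hypothetical names for illustration; this is not how transformers actually implements prepare_inputs_for_generation.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class SketchConfig:
    # Hypothetical config: caching is opt-in rather than on by default.
    use_cache: bool = False


class SketchSeq2SeqModel:
    """Hypothetical model showing only the use_cache plumbing of the proposal."""

    def __init__(self, config: SketchConfig) -> None:
        self.config = config

    def effective_use_cache(self, use_cache: Optional[bool] = None) -> bool:
        # An explicit argument wins; otherwise use the (now False) config default,
        # so a plain forward pass without labels sees every decoder position.
        return self.config.use_cache if use_cache is None else use_cache

    def prepare_inputs_for_generation(
        self, decoder_input_ids: List[List[int]], past: Any = None, **kwargs: Any
    ) -> Dict[str, Any]:
        # Generation is where caching pays off, so enable it here unless the
        # caller explicitly asked for use_cache=False.
        use_cache = kwargs.pop("use_cache", None)
        return {
            "decoder_input_ids": decoder_input_ids,
            "past": past,
            "use_cache": True if use_cache is None else use_cache,
            **kwargs,
        }
```

With this arrangement the custom-loss call from the post needs no extra flag, while generation still gets caching. The BartConfig-only alternative amounts to changing the default for one model family instead of for every PretrainedConfig.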

Here is a (broken) test you can paste into test_modeling_bart.py at line 303:

    def test_hidden_states_shape(self):
        config, input_ids, batch_size = self._get_config_and_data()
        lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
        decoder_input_ids = shift_tokens_right(lm_labels, config.pad_token_id)
        lm_model = BartForConditionalGeneration(config)
        expected_shape = (batch_size, input_ids.shape[1], config.d_model)
        outputs = lm_model(input_ids=input_ids, labels=lm_labels, output_hidden_states=True, decoder_input_ids=decoder_input_ids,)
        self.assertEqual(outputs.decoder_hidden_states[0].shape, expected_shape)

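        # No labels here, so use_cache falls back to the config default (True).
        outputs = lm_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            # use_cache=False,  # test passes if you uncomment this
        )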
        self.assertEqual(outputs.decoder_hidden_states[0].shape, expected_shape)