Freezing layers when using gradient checkpointing


I am using BERT with pytorch_transformers v1.1.0. This version does not implement gradient checkpointing, so I added it myself by replacing layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i]) with layer_outputs = checkpoint(layer_module, hidden_states, attention_mask, head_mask[i]) in BertEncoder. This works fine when I fine-tune the entire BERT model. However, when I try to freeze several layers of BERT with the code below:
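For reference, the same change can be sketched on a toy encoder. ToyLayer and ToyEncoder below are hypothetical stand-ins for BertLayer and BertEncoder, not the pytorch_transformers classes. use_reentrant=True is passed explicitly because the reentrant implementation was the only one available in older PyTorch versions; the flag itself requires PyTorch 1.11 or newer.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Hypothetical stand-in for BertLayer: forward(hidden_states, attention_mask, head_mask)."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, hidden_states, attention_mask, head_mask):
        out = torch.tanh(self.linear(hidden_states))
        if attention_mask is not None:
            out = out * attention_mask
        if head_mask is not None:
            out = out * head_mask
        return (out,)  # tuple output, mirroring BertLayer


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for BertEncoder with an optional checkpointed loop."""

    def __init__(self, hidden: int, num_layers: int, use_checkpoint: bool = True) -> None:
        super().__init__()
        self.layer = nn.ModuleList([ToyLayer(hidden) for _ in range(num_layers)])
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        if head_mask is None:
            head_mask = [None] * len(self.layer)
        for i, layer_module in enumerate(self.layer):
            if self.use_checkpoint:
                # Checkpointed call: activations are recomputed during backward.
                layer_outputs = checkpoint(
                    layer_module, hidden_states, attention_mask, head_mask[i],
                    use_reentrant=True,
                )
            else:
                # Original, non-checkpointed call.
                layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
            hidden_states = layer_outputs[0]
        return hidden_states
```

When the input requires grad, the checkpointed and plain loops should yield the same parameter gradients; checkpointing only trades memory for recomputation.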

        if layers_to_freeze is not None:
            modules = [self.transformer_model.embeddings,
            for module in modules:
                for param in module.parameters():
                    param.requires_grad = False

Training behaves identically no matter what value of layers_to_freeze I set (the loss trace is exactly the same for different values of layers_to_freeze), whereas after I disable gradient checkpointing, freezing works as expected. This suggests that torch.utils.checkpoint may interfere with using param.requires_grad = False to freeze layers.
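A minimal toy sketch of one mechanism that would produce this symptom, assuming the reentrant checkpoint implementation is in use (ToyModel and layers_with_grad are hypothetical names). With the reentrant version, if none of the tensors passed to checkpoint require grad, its output does not require grad either, so parameters inside the checkpointed segment receive no gradients even though their own requires_grad is True. Freezing the embeddings makes the hidden states entering the first checkpointed layer exactly such a tensor, and every later layer then inherits the same situation.

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyModel(nn.Module):
    """Hypothetical stand-in: embeddings -> checkpointed layers -> head."""

    def __init__(self, vocab=10, hidden=8, num_layers=3, use_reentrant=True):
        super().__init__()
        self.embeddings = nn.Embedding(vocab, hidden)
        self.layer = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.head = nn.Linear(hidden, 1)
        self.use_reentrant = use_reentrant

    def forward(self, input_ids):
        h = self.embeddings(input_ids)
        for layer_module in self.layer:
            h = checkpoint(layer_module, h, use_reentrant=self.use_reentrant)
        return self.head(h).sum()


def layers_with_grad(layers_to_freeze, use_reentrant):
    """Freeze embeddings + first `layers_to_freeze` layers; return indices of layers that got gradients."""
    torch.manual_seed(0)
    model = ToyModel(use_reentrant=use_reentrant)
    modules = [model.embeddings, *model.layer[:layers_to_freeze]]
    for module in modules:
        for param in module.parameters():
            param.requires_grad = False
    with warnings.catch_warnings():
        # The reentrant version warns when none of its inputs require grad.
        warnings.simplefilter("ignore")
        loss = model(torch.tensor([1, 2, 3]))
    loss.backward()  # the head still requires grad, so backward runs
    return [i for i, layer in enumerate(model.layer) if layer.weight.grad is not None]
```

In this sketch, every value of layers_to_freeze leaves all encoder layers without gradients under the reentrant version (only the head trains), while use_reentrant=False gives gradients to exactly the unfrozen layers. On PyTorch versions without that flag, a commonly used workaround is to call requires_grad_() on the hidden states before the first checkpointed layer.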

I noticed that your official implementation of gradient checkpointing first defines create_custom_forward instead of calling layer_module directly. I haven't yet tested whether doing this avoids the issue described above. Before I do, I am curious: what was the reason for implementing it this way? Would calling layer_module directly lead to any serious issue?
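For comparison, the wrapper pattern can be sketched roughly as follows. This illustrates the general idea only and is not the library's exact code; ToyLayer, its return_extra flag, and run are hypothetical. The closure hands a non-tensor argument (here a bool, similar in spirit to an output_attentions flag) to the layer, so checkpoint itself only receives tensors as positional inputs. As far as I know, the reentrant checkpoint does not forward keyword arguments to the wrapped function, which is one plausible motivation for such a wrapper.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Hypothetical layer whose forward takes a non-tensor flag."""

    def __init__(self, hidden: int = 8) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, hidden_states, attention_mask, return_extra: bool = False):
        out = torch.tanh(self.linear(hidden_states)) * attention_mask
        return (out, out.mean()) if return_extra else (out,)


def create_custom_forward(module, return_extra):
    # The closure binds the non-tensor argument, so checkpoint only
    # sees tensors as positional inputs.
    def custom_forward(*inputs):
        return module(*inputs, return_extra=return_extra)
    return custom_forward


def run(layer, x, mask, wrapped):
    """Checkpoint the layer either directly or through the closure wrapper."""
    fn = create_custom_forward(layer, return_extra=False) if wrapped else layer
    return checkpoint(fn, x, mask, use_reentrant=True)[0]
```

In this sketch the wrapper changes only how arguments reach the layer, not which tensors checkpoint sees, so on its own it would not be expected to change the requires_grad behavior.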