Freezing layers when using gradient checkpointing


I am using BERT with pytorch_transformers V1.1.0. In that version, gradient checkpointing was not implemented, so I added it myself simply by replacing layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i]) with layer_outputs = checkpoint(layer_module, hidden_states, attention_mask, head_mask[i]) in BertEncoder (with checkpoint imported from torch.utils.checkpoint). This works fine when I fine-tune the entire BERT model. However, it breaks when I try to freeze several layers of BERT with the code below:
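For context, here is a minimal self-contained sketch of the change, with stand-in linear layers in place of real BertLayer modules (the only modification is wrapping the layer call in torch.utils.checkpoint.checkpoint):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for BertEncoder's layer loop.
layers = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(3))
hidden_states = torch.randn(2, 8, requires_grad=True)

for layer_module in layers:
    # original: hidden_states = layer_module(hidden_states)
    hidden_states = checkpoint(layer_module, hidden_states)

hidden_states.sum().backward()
```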

        if layers_to_freeze is not None:
            modules = [self.transformer_model.embeddings,
                       *self.transformer_model.encoder.layer[:layers_to_freeze]]
            for module in modules:
                for param in module.parameters():
                    param.requires_grad = False

The training procedure behaves the same no matter what layers_to_freeze I set (i.e., the loss trace is exactly identical across different values of layers_to_freeze), whereas after I disable gradient checkpointing it works as expected. I think this suggests that using torch.utils.checkpoint may interfere with freezing layers via param.requires_grad = False.
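To check whether freezing is respected under checkpointing in isolation, I used a small sketch with stand-in linear layers (hypothetical names, not the real BERT modules); note that the checkpoint input must itself require grad, otherwise checkpoint emits a warning and no gradients flow back at all:

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
frozen = torch.nn.Linear(4, 4)     # stand-in for a frozen BERT layer
trainable = torch.nn.Linear(4, 4)  # stand-in for a layer left trainable
for p in frozen.parameters():
    p.requires_grad = False

# The input requires grad; if no input to checkpoint requires grad,
# checkpoint warns and the backward produces no gradients.
x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(trainable, checkpoint(frozen, x))
out.sum().backward()

print(frozen.weight.grad)                # expected: None when freezing is respected
print(trainable.weight.grad is not None)
```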

I notice that in your official implementation of gradient checkpointing, you define create_custom_forward first instead of calling layer_module directly. I haven't tested yet whether doing this would avoid the issue I described. Before I test it, I am curious what concern led you to implement it this way. Will calling layer_module directly lead to any serious issue?
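For reference, my understanding of that pattern, reduced to a runnable sketch with a stand-in linear layer (my guess is the wrapper exists so that extra non-tensor or keyword arguments can be bound in the closure while checkpoint only sees tensor positional arguments, but that is exactly what I am asking about):

```python
import torch
from torch.utils.checkpoint import checkpoint

def create_custom_forward(module):
    # Closure over the module; checkpoint receives only a callable
    # plus tensor positional inputs.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

layer = torch.nn.Linear(8, 8)  # stand-in for layer_module
hidden_states = torch.randn(2, 8, requires_grad=True)
out = checkpoint(create_custom_forward(layer), hidden_states)
out.sum().backward()
```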