I am using BERT with `pytorch_transformers` v1.1.0. Gradient checkpointing was not implemented in this version, so I added it myself, simply by replacing

```python
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
```

with

```python
layer_outputs = checkpoint(layer_module, hidden_states, attention_mask, head_mask[i])
```

in `BertEncoder`.
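In case the exact call pattern matters, here is a tiny self-contained version of what I did (a toy two-layer stand-in for the encoder, not the actual BERT code):

```python
import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])
hidden_states = torch.randn(4, 8, requires_grad=True)

for layer_module in layers:
    # same call pattern as my patched BertEncoder: checkpoint(module, *inputs)
    hidden_states = checkpoint(layer_module, hidden_states)

hidden_states.sum().backward()
print(layers[0].weight.grad is not None)  # True: gradients flow through the checkpointed calls
```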
This works fine when I fine-tune the entire BERT model. However, when I try to freeze several layers of BERT with the code below:
```python
if layers_to_freeze is not None:
    modules = [self.transformer_model.embeddings,
               *self.transformer_model.encoder.layer[:layers_to_freeze]]
    for module in modules:
        for param in module.parameters():
            param.requires_grad = False
```
the training procedure behaves the same no matter what `layers_to_freeze` I set (i.e., the loss trace is exactly the same for different values of `layers_to_freeze`). After I disable gradient checkpointing, freezing works as expected. This suggests to me that `torch.utils.checkpoint` might interfere with freezing layers via `param.requires_grad = False`.
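A minimal standalone probe of the interaction I have in mind (a toy stand-in with hypothetical module names, not my actual training code):

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
frozen = torch.nn.Linear(8, 8)  # stands in for the frozen embeddings/lower layers
inner = torch.nn.Linear(8, 8)   # stands in for a checkpointed, trainable layer
head = torch.nn.Linear(8, 1)    # stands in for the (non-checkpointed) task head

for param in frozen.parameters():
    param.requires_grad = False

x = torch.randn(4, 8)
h = frozen(x)               # h.requires_grad is False: it depends only on frozen params
out = checkpoint(inner, h)  # the checkpointed call sees no input that requires grad
head(out).sum().backward()

# If checkpoint interferes with freezing as I suspect, inner receives no gradient here
print("inner grad is None:", inner.weight.grad is None)
print("head grad is None:", head.weight.grad is None)
```

On my understanding of the re-entrant checkpoint implementation from this era, PyTorch warns here that none of the inputs have `requires_grad=True` and leaves `inner.weight.grad` as `None`, which would match the training behavior I see.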
I notice that in your official implementation of gradient checkpointing, you define a `create_custom_forward` wrapper first instead of calling `layer_module` directly. I have not yet tested whether doing this avoids the issue described above. Before I test it, I am curious: what concern led you to implement it this way? Would calling `layer_module` directly lead to any serious issue?
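For reference, the pattern I am referring to looks roughly like this (paraphrased from `modeling_bert.py`; I may be eliding some arguments, e.g. `output_attentions`):

```python
def create_custom_forward(module):
    def custom_forward(*inputs):
        # extra non-tensor arguments can be closed over here
        return module(*inputs)
    return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer_module),
    hidden_states,
    attention_mask,
    head_mask[i],
)
```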