Gradual Layer Freezing

I have a short question: how do I perform gradual layer freezing using the Hugging Face Trainer? I read that one can freeze layers with:

modules = [L1bb.embeddings, *L1bb.encoder.layer[:5]]  # Replace 5 with the number of encoder layers to freeze
for module in modules:
    for param in module.parameters():
        param.requires_grad = False

But with the Hugging Face Trainer I do not write my own training loop, so there is no obvious place to start freezing layers from, let's say, the second epoch. How can I start freezing some layers only from the second epoch on, and then gradually increase the number of frozen layers per epoch?


There is nothing out of the box in the library to unfreeze parts of your model during training. You can pass the model with some layers frozen, using the code you wrote, but it will stay this way.

You can try to use a TrainerCallback to unfreeze parts of the model in the middle of the training (after a given number of steps/epochs).
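For the bookkeeping part, such a callback mainly needs to know how many layers should be frozen at a given epoch. Here is a minimal sketch of that calculation for the case in the question (nothing frozen in the first epoch, then a few more layers frozen every epoch). The function name and its parameters are made up for illustration and are not part of the library:

```python
def num_frozen_layers(epoch: int, start_epoch: int = 1,
                      layers_per_epoch: int = 2, max_layers: int = 12) -> int:
    """Number of encoder layers to freeze at the start of `epoch` (0-based).

    All parameters are hypothetical knobs for this sketch:
    start_epoch      -- first epoch in which anything is frozen
    layers_per_epoch -- how many additional layers get frozen each epoch
    max_layers       -- upper bound, e.g. the number of encoder layers
    """
    if epoch < start_epoch:
        return 0
    return min((epoch - start_epoch + 1) * layers_per_epoch, max_layers)


for epoch in range(8):
    print(epoch, num_frozen_layers(epoch))
```

Inside the callback's on_epoch_begin you could call this with int(state.epoch) and then run the freezing loop from the question over the first n layers.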

Thank you very much for the reply. Could you show me how to achieve that using a callback?

You could do it the following way. It is a bit awkward, but works as far as I know:

import math
import re

from transformers import GPT2Config, Trainer, TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class FreezingCallback(TrainerCallback):
    """Callback to gradually unfreeze the model according to a freezing :class:`Schedule` during training. It ensures that the model is always completely unfrozen before saving it to avoid unexpected behaviour."""

    # `Schedule` is a custom helper class (not part of transformers); its `.schedule`
    # attribute is a list of (freeze_to_layer, until_epoch) tuples.
    def __init__(self, freezing_schedule: "Schedule", trainer: Trainer, model_config: GPT2Config):
        self.model_config = model_config
        self.trainer = trainer
        self.freezing_schedule = freezing_schedule
        self.current_step_idx = 0

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # move on to the next schedule entry once the current entry's end epoch is reached
        if state.epoch >= self.freezing_schedule.schedule[self.current_step_idx][1]:
            self.current_step_idx += 1
        self.freeze_model(self.freezing_schedule.schedule[self.current_step_idx][0],
                          self.model_config.n_layer, int(state.epoch))

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # unfreeze everything so that the saved model has no frozen parameters
        for name, param in self.trainer.model.named_parameters():
            param.requires_grad = True

    def freeze_model(self, freeze_to: int, highest_layer: int, epoch: int):
        print(f"\nEpoch {epoch}: Freezing model to layer {freeze_to} of {highest_layer} layers.")

        for name, param in self.trainer.model.named_parameters():
            # find out the number of every layer. GPT2-specific!
            try:
                layer_number = int(re.search(r'\.h\.\d+\.', name).group().strip(".h"))
            except AttributeError:
                # no ".h.<n>." in the name, so the parameter is not part of a transformer block
                layer_number = math.inf
            # freeze all layers up to layer freeze_to including embedding layers
            if '.wte.' in name or '.wpe.' in name or layer_number <= freeze_to:
                param.requires_grad = False

Then, before calling trainer.train(), initialize the callback and add it to the trainer:

freezing_callback = FreezingCallback(freezing_schedule, trainer, config)
trainer.add_callback(freezing_callback)

Note that the freeze_model method is GPT2-specific here, since it relies on the naming of the layers.
The Schedule object used to initialize the callback holds a list of tuples, where the first entry is the layer to freeze up to and the second is the epoch until which those layers stay frozen. I parse the schedule from the command line, but that part is optional.
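To make the name matching a bit more tangible, here is a small standalone sketch of the same check applied to plain strings. The helper names are made up for this example, and the parameter names are only what I would expect from a GPT-2 language model; print model.named_parameters() on your own model to see what you really have:

```python
import math
import re

# Matches the block index in names such as "transformer.h.3.attn.c_attn.weight".
BLOCK_PATTERN = re.compile(r"\.h\.(\d+)\.")


def layer_number(param_name: str) -> float:
    """Index of the transformer block a parameter belongs to, or inf if it is in none."""
    match = BLOCK_PATTERN.search(param_name)
    return int(match.group(1)) if match else math.inf


def is_frozen(param_name: str, freeze_to: int) -> bool:
    """Same condition as in freeze_model: embeddings plus all blocks up to freeze_to."""
    return ".wte." in param_name or ".wpe." in param_name or layer_number(param_name) <= freeze_to


# Assumed example names, not read from a real model.
names = [
    "transformer.wte.weight",
    "transformer.wpe.weight",
    "transformer.h.3.attn.c_attn.weight",
    "transformer.h.4.mlp.c_fc.weight",
    "transformer.ln_f.weight",
]
for name in names:
    print(name, is_frozen(name, freeze_to=3))
```

For another architecture you would have to swap the pattern and the embedding names for whatever that model uses.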


@arvidunt Would you be able to elaborate on the Schedule object you used?

Sure. It is not as fancy as it sounds, just a custom auxiliary object that represents a freezing schedule. An easy way to implement it would be some kind of array of tuples. Suppose this array is the following:
freezing_schedule.schedule = [(a,b), (c,d), (e,f)].
Then, according to the implementation of on_epoch_begin, in the first step of training the model would be frozen up to layer a until epoch b (since freezing_schedule.schedule[0][0] equals a and freezing_schedule.schedule[0][1] equals b), in the second step it would be frozen up to layer c until epoch d, and so on.
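If it helps, here is a tiny sketch that replays the index logic of on_epoch_begin on a plain list of tuples, without any model involved. The function name and the example numbers are made up, and the bounds check is an addition in this sketch so that the last entry simply stays active once the schedule is used up:

```python
def freeze_to_per_epoch(schedule, num_epochs):
    """Return the freeze_to value that would be applied at the start of each epoch."""
    step_idx = 0
    result = []
    for epoch in range(num_epochs):
        # Advance once the current entry's end epoch is reached. The first
        # condition is the added bounds check; the second mirrors on_epoch_begin.
        if step_idx < len(schedule) - 1 and epoch >= schedule[step_idx][1]:
            step_idx += 1
        result.append(schedule[step_idx][0])
    return result


# (freeze_to_layer, until_epoch): frozen up to layer 8 until epoch 1, and so on.
schedule = [(8, 1), (4, 2), (1, 3)]
print(freeze_to_per_epoch(schedule, 4))  # [8, 4, 1, 1]
```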
Does that help?

Sure, that helps. I figured you were using a specific Schedule library and since I just started using HuggingFace, I wanted to make sure I wasn’t missing something. Thanks for the clarification 6 months after the initial post!