Training from a checkpoint and freezing some of the model's parameters

Hi all,

I am training from a checkpoint and want to freeze some of the model's parameters before continuing training from there. This is the code I have written:


    # Freeze all the parameters
    for param in model.parameters():
        param.requires_grad = False
        
    # Unfreeze hce_encoder
    for param in model.roberta.hce_encoder.parameters():
        param.requires_grad = True

    # Create the optimizer over only the trainable (unfrozen) parameters
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4)
    

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        optimizers=(optimizer, None),
        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_tpu_available()
        else None,
    )

But it gives me the following error:

    ValueError: loaded state dict has a different number of parameter groups

This is the full error:

 File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/transformers/trainer.py", line 1679, in _inner_training_loop
    self._load_optimizer_and_scheduler(resume_from_checkpoint)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/transformers/trainer.py", line 2471, in _load_optimizer_and_scheduler
    self.optimizer.load_state_dict(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/accelerate/optimizer.py", line 103, in load_state_dict
    self.optimizer.load_state_dict(state_dict)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/optim/optimizer.py", line 385, in load_state_dict
    raise ValueError("loaded state dict has a different number of "
ValueError: loaded state dict has a different number of parameter groups

Any idea how to resolve it? Any help is appreciated.

At first glance, I'm quite sure it's because when you first trained the model, your optimizer was created with all model parameters, i.e. optim = AdamW(model.parameters(), ...), while now you create it as optim = AdamW(filter(lambda p: p.requires_grad, model.parameters()), ...), i.e. you pass only a subset of the parameters (specifically, only the hce_encoder of the roberta module). The optimizer state_dict saved in the checkpoint, however, covers all of the model's parameters, so when the Trainer resumes and calls load_state_dict, the parameter groups of the new optimizer and the saved state dict don't match.
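
If that is the case, one possible workaround (a minimal sketch, reusing model and model.roberta.hce_encoder from your snippet and assuming the checkpointed optimizer really was created over model.parameters() in a single parameter group) is to build the new optimizer over all parameters again, so its parameter groups line up with the saved state dict. Parameters with requires_grad=False never receive gradients, so AdamW simply skips them at each step:

    import torch

    # Freeze everything, then unfreeze only the sub-module that should keep training
    for param in model.parameters():
        param.requires_grad = False
    for param in model.roberta.hce_encoder.parameters():
        param.requires_grad = True

    # Pass *all* parameters so the parameter-group structure matches the
    # checkpointed optimizer; frozen parameters have grad == None and are
    # therefore left untouched by AdamW.
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

Alternatively, if you don't need to restore the optimizer state at all, you can load only the model weights and start training without resume_from_checkpoint, so the Trainer never tries to load the old optimizer state and a fresh optimizer is built for just the trainable parameters.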