Accelerator load_state for LM head with tied weights

I’m having an issue with accelerator.load_state when trying to resume training from a specific checkpoint. Here’s how I save the model checkpoints during training:

# In the training loop:
accelerator.wait_for_everyone()
save_dir = os.path.join(args.save_dir, f"step_{completed_steps}")
accelerator.save_state(save_dir)
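
To double-check what save_state actually writes, I can list the checkpoint directory; a minimal sketch (the exact file names, e.g. model.safetensors vs. pytorch_model.bin, depend on the accelerate version and serialization settings):

# Sketch: inspect the files that accelerator.save_state produced for a step.
import os

save_dir = os.path.join(args.save_dir, f"step_{completed_steps}")
print(sorted(os.listdir(save_dir)))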

When I attempt to load the state using the following code:

import os
from argparse import Namespace

from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.optim import AdamW
from transformers import AutoModelForMaskedLM, AutoTokenizer, get_scheduler

# args, config (a ProjectConfiguration), create_dataloaders, setup_logging and
# get_grouped_params are defined earlier in the script.

tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
model = AutoModelForMaskedLM.from_pretrained(args.model_ckpt)

accelerator = Accelerator(log_with=["wandb", "tensorboard"], project_config=config)

# Add accelerator state to the arguments
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
print(args)

# Check if shuffle buffer is divisible by train batch size
if args.shuffle_buffer < (args.train_batch_size * args.gradient_accumulation_steps * accelerator.state.num_processes * 5):
    args.shuffle_buffer = args.train_batch_size * args.gradient_accumulation_steps * accelerator.state.num_processes * 10

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(args=args, tokenizer=tokenizer)

set_seed(args.seed)

# Logging
logger, run_name = setup_logging(args, accelerator=accelerator)
logger.info(accelerator.state)

# Prepare optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate * accelerator.state.num_processes)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=int(args.max_train_steps * 0.05),  # 5% of the training steps
    num_training_steps=args.max_train_steps * accelerator.state.num_processes,
)

# Register the lr_scheduler to the accelerator
accelerator.register_for_checkpointing(lr_scheduler)

# Prepare model, optimizer, and dataloaders
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# Load the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
        accelerator.load_state(args.resume_from_checkpoint)
        path = os.path.basename(args.resume_from_checkpoint)
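
path is then only used to recover the completed step count and fast-forward the dataloader, roughly like this (a simplified sketch, assuming the step_{N} naming from the save snippet above and that completed_steps counts optimizer steps):

# Recover the optimizer-step count from the "step_{completed_steps}" directory name.
resume_step = int(path.split("_")[-1])

# Skip the batches consumed before the checkpoint; the gradient_accumulation_steps
# factor assumes completed_steps counts optimizer steps rather than batches.
train_dataloader = accelerator.skip_first_batches(
    train_dataloader, resume_step * args.gradient_accumulation_steps
)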

However, I get the following error:

RuntimeError: Error(s) in loading state_dict for DebertaV2ForMaskedLM:
Missing key(s) in state_dict: "cls.predictions.decoder.weight", "cls.predictions.decoder.bias".
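
To narrow things down, I can check whether the decoder keys are present at all in the file that save_state wrote; a quick sketch (this assumes safetensors serialization and the default model.safetensors file name):

# Sketch: list the tied decoder keys stored in the save_state checkpoint, if any.
# With safe serialization disabled the file would be pytorch_model.bin instead.
import os
from safetensors.torch import load_file

state_dict = load_file(os.path.join(args.resume_from_checkpoint, "model.safetensors"))
print([k for k in state_dict if k.startswith("cls.predictions.decoder")])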

Interestingly, when I upload the same checkpoint to the hub and use this code to load the model, I don’t get any errors:

tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForMaskedLM.from_pretrained(model_ckpt)
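
As a sanity check on the from_pretrained path, I can confirm that the LM head decoder really is tied to the input embeddings after loading; a small sketch using the standard helper methods:

# Check that from_pretrained re-tied the weights: the output decoder weight
# should be the very same tensor object as the input embedding weight.
decoder_weight = model.get_output_embeddings().weight
embedding_weight = model.get_input_embeddings().weight
print(decoder_weight is embedding_weight)  # expected to be True when tied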

I also found that in the DeBERTa-v2 modeling code, the tied weights are declared like this:

class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

It seems that accelerator.load_state isn’t handling the tied weights properly.
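
The workaround I'm considering is to relax the strict key check and re-tie the weights myself afterwards; a minimal sketch, assuming extra keyword arguments to load_state (such as strict) are forwarded to the underlying load_state_dict call:

# Possible workaround (unverified): ignore the missing tied keys on load,
# then re-tie the LM head to the input embeddings.
accelerator.load_state(args.resume_from_checkpoint, strict=False)
accelerator.unwrap_model(model).tie_weights()

I'm not sure this is the intended way to handle tied weights with save_state/load_state, though.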

How can I solve this issue? The same script has worked fine in other cases (e.g., training from scratch, or loading the checkpoint as a pre-trained model rather than resuming from it).

Any suggestions would be appreciated!

transformers==4.38.2
accelerate==0.28.0
torch==2.2.0