Corrupted DeepSpeed checkpoint

Hello. I configured my training run to use DeepSpeed ZeRO Stage 2 with the Hugging Face Trainer, using the following settings:

DEEPSPEED_CONFIG = {
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {"warmup_min_lr": "auto", "warmup_max_lr": "auto", "warmup_num_steps": "auto"},
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "allgather_partitions": True,
        "allgather_bucket_size": 2e8,
        "overlap_comm": False,
        "reduce_scatter": True,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}
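
For context, this is roughly how the config above gets wired into the Trainer. It is only a sketch: the output directory, model, datasets, and hyperparameter values are placeholders, not my exact setup.

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="/opt/ml/checkpoints",   # SageMaker local checkpoint dir (placeholder)
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    fp16=True,
    save_steps=500,
    deepspeed=DEEPSPEED_CONFIG,         # the dict above; a path to a JSON file also works
)

trainer = Trainer(
    model=model,                        # placeholder: the Falcon model with a custom configuration
    args=training_args,
    train_dataset=train_dataset,        # placeholder
    eval_dataset=dev_dataset,           # placeholder
)
trainer.train()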

During training I use the Trainer to create checkpoints, and the SageMaker checkpointing configuration in the Hugging Face estimator uploads them to S3.
I then run zero_to_fp32.py (from the DeepSpeed checkpoint directory) to create a pytorch_model.bin file.
When I try to load that pytorch_model.bin to resume training, I get the following error:

RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long

When I load the pytorch_model.bin to run an evaluation, I also notice that the loss values on the same dev dataset used during training are near infinite, compared to the small loss values observed at the same step during training.
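
For reference, this is a minimal sketch of how the converted file can be inspected (the path is a placeholder). The RuntimeError above suggests at least one parameter tensor ended up with an integer dtype after conversion, so this lists anything that is not floating point:

import torch

# Load the converted checkpoint on CPU and report any non-float tensors.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
for name, tensor in state_dict.items():
    if isinstance(tensor, torch.Tensor) and not torch.is_floating_point(tensor):
        print(name, tensor.dtype, tuple(tensor.shape))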

The training behavior itself is as expected: the losses reported during training at different steps are L2 losses and normal for the problem. The underlying transformer model is a Falcon model with a custom configuration.

Has anyone observed similar problems trying to resume training or run inference from deepspeed checkpoints?


That looks like a troublesome problem…
In some cases it can apparently be avoided by omitting the torch_dtype specification when loading the model.
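
For example, a minimal sketch of loading the converted checkpoint without forcing a dtype. The checkpoint directory name is a placeholder, and trust_remote_code=True is only an assumption about your custom Falcon configuration:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/converted-checkpoint",   # placeholder: directory with pytorch_model.bin and config
    trust_remote_code=True,           # may be needed for a custom Falcon configuration
    # torch_dtype=torch.bfloat16,     # deliberately omitted per the suggestion above
)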