HF Trainer shows different memory efficiency with DeepSpeed and FairScale

Hi there, I'd like to share the results of a small experiment and ask whether they are expected. I am pre-training BERT on the MLM and NSP tasks. To test how large a model I can fit, I used DeepSpeed and FairScale. The only parameter I increase is hidden_size; everything else, such as the maximum sequence length and the batch size, is fixed.
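Since hidden_size is the only knob being turned, it helps to see how the parameter count grows with it. The sketch below is an approximate count for BertForPreTraining. It assumes the configuration used in the script (24 layers, 16 heads, intermediate_size = 4 x hidden_size, vocab size 30522), BertConfig's defaults of 512 positions and 2 token types, and an MLM decoder weight tied to the word embeddings. The function name is hypothetical. It also checks that hidden_size is a multiple of the number of heads, which BERT's self-attention requires and which is why the sizes are chosen as multiples of 16.

```python
def bert_pretraining_params(hidden_size, num_layers=24, num_heads=16,
                            vocab_size=30522, max_positions=512, type_vocab_size=2):
    """Approximate trainable-parameter count of BertForPreTraining.

    Assumes intermediate_size = 4 * hidden_size (as in the script) and an MLM
    decoder weight tied to the word embeddings (so it is counted once).
    """
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be a multiple of num_attention_heads")
    h = hidden_size
    inter = 4 * h
    # word + position + token-type embeddings, plus the embedding LayerNorm
    embeddings = (vocab_size + max_positions + type_vocab_size) * h + 2 * h
    per_layer = (
        4 * (h * h + h)          # query, key, value and attention-output projections
        + 2 * h                  # LayerNorm after attention
        + (h * inter + inter)    # intermediate dense
        + (inter * h + h)        # output dense
        + 2 * h                  # LayerNorm after the feed-forward block
    )
    pooler = h * h + h
    mlm_head = (h * h + h) + 2 * h + vocab_size  # transform dense + LayerNorm + decoder bias
    nsp_head = h * 2 + 2                          # 2-way next-sentence classifier
    return embeddings + num_layers * per_layer + pooler + mlm_head + nsp_head


if __name__ == "__main__":
    for multiple in (80, 110, 120):
        h = multiple * 16
        print(f"hidden_size={h}: {bert_pretraining_params(h) / 1e6:.0f}M parameters")
```

Each layer contributes about 12 x hidden_size^2 parameters, so the count grows quadratically: by this formula 80x16, 110x16 and 120x16 correspond to roughly 0.52B, 0.95B and 1.13B parameters.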

On a single node with two GPUs, without DeepSpeed or FairScale, the largest hidden_size I can use is around 110x16 (1760); at 120x16 (1920) I get an OOM error. With DeepSpeed ZeRO stage 1, the largest hidden_size is only around 80x16 (1280), and stages 2 and 3 do not allow a larger model either.

With FairScale_simple, the largest hidden_size is around 120x16 (1920), with about 70% of GPU memory in use. I expected FairScale_zero_dp_2 or zero_dp_3 to allow an even larger model, but they did not.

In short, although DeepSpeed and FairScale are built on the same technique (ZeRO-style sharding), they behave differently here. Is this normal, is it a bug, or did I set the arguments incorrectly? Furthermore, the maximum model size did not grow with higher stages. I expected DeepSpeed stage 2 and stage 3 to do better than stage 1, and the same for FairScale. Is my assumption right?
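To check the expectation that higher stages should fit larger models, here is the usual ZeRO accounting of per-GPU model-state memory, adapted to fp32 training (fp16 is disabled in these runs): 4 bytes of weights, 4 bytes of gradients and 8 bytes of Adam moments per parameter. Stage 1 shards the optimizer states across GPUs, stage 2 also shards the gradients, and stage 3 also shards the weights. This is a rough sketch with a hypothetical function name: activations, communication buffers and the CUDA context are not counted.

```python
GIB = 1024 ** 3


def zero_model_state_bytes(num_params, num_gpus, stage,
                           param_bytes=4, grad_bytes=4, optim_bytes=8):
    """Approximate per-GPU bytes for weights, gradients and Adam states.

    Defaults assume plain fp32 training with Adam (two fp32 moments = 8 bytes
    per parameter). Activations and temporary buffers are not included.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0 (no ZeRO), 1, 2 or 3")
    if stage >= 1:
        optim_bytes = optim_bytes / num_gpus  # stage 1: shard optimizer states
    if stage >= 2:
        grad_bytes = grad_bytes / num_gpus    # stage 2: also shard gradients
    if stage >= 3:
        param_bytes = param_bytes / num_gpus  # stage 3: also shard weights
    return num_params * (param_bytes + grad_bytes + optim_bytes)


if __name__ == "__main__":
    num_params = 1_000_000_000  # hypothetical ~1B-parameter model
    for stage in range(4):
        gib = zero_model_state_bytes(num_params, num_gpus=2, stage=stage) / GIB
        print(f"stage {stage}: {gib:.1f} GiB of model states per GPU")
```

With two GPUs this gives 16, 12, 10 and 8 bytes per parameter for stages 0 to 3, so even stage 3 at best halves the model states, while activations (which also grow with hidden_size) are not sharded by any stage. That is only one factor; by itself it does not explain why DeepSpeed stage 1 fit a smaller model than the baseline.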

Thank you.

The script is below, and the config files and data are linked at the end, so you can reproduce the experiments.


import os
import sys
from dataclasses import dataclass

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import (BertConfig, BertForPreTraining, HfArgumentParser,
                          Trainer, TrainingArguments, set_seed)

class DataNumpyDataset(Dataset):
    """Toy dataset: every index returns the same pre-tokenized sample from data.npy."""

    def __init__(self):
        self.dataset_size = 1024

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        if idx >= self.dataset_size:
            raise IndexError
        try:
            data = np.load("data.npy", allow_pickle=True)
            pt_data = dict()
            # Tensors are created directly on the current GPU, which is why
            # --dataloader_pin_memory=False is needed (only CPU tensors can be pinned).
            pt_data["input_ids"] = torch.tensor(data[0]["input_ids"], dtype=torch.long)[0, :].cuda()
            # MLM labels: an unmasked copy of input_ids
            pt_data["labels"] = pt_data["input_ids"].detach().clone()
            pt_data["next_sentence_label"] = torch.tensor(data[1], dtype=torch.long)[0, :].cuda()
            pt_data["token_type_ids"] = torch.tensor(data[0]["token_type_ids"], dtype=torch.long)[0, :].cuda()
            pt_data["attention_mask"] = torch.tensor(data[0]["attention_mask"], dtype=torch.float)[0, :].cuda()
            return pt_data
        except EOFError:
            raise IOError("Reached the end of the data file")

@dataclass  # HfArgumentParser only accepts dataclass types
class ModelArguments:
    hidden_size: int = 1024

if __name__ == '__main__':
    # args_hidden_size = 1920
    args_num_hidden_layers = 24
    args_num_attention_heads = 16
    args_vocab_size = 30522

    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, training_args = parser.parse_args_into_dataclasses()

    set_seed(training_args.seed)  # make model initialization reproducible
    configuration = BertConfig(hidden_size=model_args.hidden_size, num_hidden_layers=args_num_hidden_layers,
                               num_attention_heads=args_num_attention_heads, intermediate_size=4*model_args.hidden_size,
                               vocab_size=args_vocab_size)
    model = BertForPreTraining(configuration)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    dataset = DataNumpyDataset()

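    # Assumed minimal arguments (the rest of the original call is not shown)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()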

Command to run it (DeepSpeed ZeRO stage 1 shown):

deepspeed --num_gpus=2 \
./train_HF.py \
--do_train \
--learning_rate 1e-4 \
--hidden_size=1760 \
--num_train_epochs 10 \
--per_device_train_batch_size 1 \
--output_dir "./log" \
--save_strategy="no" \
--report_to="none" \
--fp16=false \
--dataloader_pin_memory=False \
--deepspeed 'ds_config_zero1.json'
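The FairScale and baseline runs are not shown above. Assuming they used the Trainer's --sharded_ddp option (simple, zero_dp_2, zero_dp_3) launched with PyTorch's torch.distributed.launch, a small helper like the following keeps every run identical except for the memory-saving flag. The helper, the baseline launch, and the ds_config_zero2.json / ds_config_zero3.json file names are hypothetical.

```python
import shlex

# Hypothetical file names: only ds_config_zero1.json appears in the command above.
DEEPSPEED_CONFIGS = {
    "zero1": "ds_config_zero1.json",
    "zero2": "ds_config_zero2.json",
    "zero3": "ds_config_zero3.json",
}
# Values accepted by the Trainer's --sharded_ddp option (FairScale backend).
FAIRSCALE_MODES = ("simple", "zero_dp_2", "zero_dp_3")

# Flags shared by every run, copied from the DeepSpeed command above.
COMMON_ARGS = [
    "--do_train", "--learning_rate", "1e-4", "--num_train_epochs", "10",
    "--per_device_train_batch_size", "1", "--output_dir", "./log",
    "--save_strategy=no", "--report_to=none", "--fp16=false",
    "--dataloader_pin_memory=False",
]


def build_commands(hidden_size, num_gpus=2, script="./train_HF.py"):
    """Return {run_name: argv list} for plain DDP, DeepSpeed and FairScale runs."""
    common = [script, f"--hidden_size={hidden_size}", *COMMON_ARGS]
    torch_launch = ["python", "-m", "torch.distributed.launch", f"--nproc_per_node={num_gpus}"]
    commands = {"ddp": torch_launch + common}
    for name, config_file in DEEPSPEED_CONFIGS.items():
        commands[f"deepspeed_{name}"] = (
            ["deepspeed", f"--num_gpus={num_gpus}"] + common + ["--deepspeed", config_file]
        )
    for mode in FAIRSCALE_MODES:
        commands[f"fairscale_{mode}"] = torch_launch + common + ["--sharded_ddp", mode]
    return commands


if __name__ == "__main__":
    for name, argv in build_commands(1760).items():
        # shlex.join needs Python 3.8+, so quote manually for Python 3.7
        print(name + ": " + " ".join(shlex.quote(arg) for arg in argv))
```

Generating the commands from one list makes it easy to confirm that only --deepspeed or --sharded_ddp differs between runs.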

The config files ds_config_zero1.json, zero2.json and zero3.json, along with the data, are on Google Drive:
link to deep_config files and data.
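Since the config files are only on Google Drive, here is a minimal sketch of what ZeRO configs for the HF Trainer could look like; these are not the actual files used in the runs. It assumes the Trainer integration's "auto" values and omits the optimizer block, assuming the Trainer then supplies its own AdamW as the HF integration docs describe. The bucket sizes control temporary communication buffers (DeepSpeed's default is 5e8 elements), so they are worth checking when comparing memory; the 2e8 value here is an illustrative, untuned choice.

```python
import json
import os

# Illustrative value only (not tuned): smaller communication buckets mean
# smaller temporary GPU buffers, at some cost in speed.
BUCKET_SIZE = 200_000_000  # number of elements; DeepSpeed's default is 5e8


def make_ds_config(stage):
    """Build a minimal DeepSpeed config dict for ZeRO `stage` (1, 2 or 3)."""
    if stage not in (1, 2, 3):
        raise ValueError("stage must be 1, 2 or 3")
    zero = {
        "stage": stage,
        "overlap_comm": False,  # overlapping communication can use extra GPU buffers
        "contiguous_gradients": True,
        "reduce_bucket_size": BUCKET_SIZE,
    }
    if stage in (1, 2):
        zero["allgather_bucket_size"] = BUCKET_SIZE
    return {
        # "auto" asks the HF Trainer to copy these values from TrainingArguments
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "fp16": {"enabled": "auto"},
        "zero_optimization": zero,
    }


def write_ds_configs(directory="."):
    """Write ds_config_zero{1,2,3}.json into `directory` and return their paths."""
    paths = []
    for stage in (1, 2, 3):
        path = os.path.join(directory, f"ds_config_zero{stage}.json")
        with open(path, "w") as f:
            json.dump(make_ds_config(stage), f, indent=2)
        paths.append(path)
    return paths
```

Keeping all three stages identical apart from the stage-specific keys makes the memory comparison between stages fairer.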

Some other info:
  • GPU: GeForce RTX 3090, with around 24 GiB of memory
  • Driver Version: 460.32.03, CUDA Version: 11.2
  • deepspeed 0.5.4, installed from pip
  • transformers 4.12.2, installed from pip
  • fairscale 0.4.1, installed from pip
  • PyTorch 1.8.1+cu111
  • Python 3.7.7
  • fp16 is disabled
  • gradient checkpointing is disabled
  • CPU offload is disabled