ALBERT pre-training with batch size 8 runs out of GPU memory (OOM)

Environment info

  • transformers version: 4.15.0
  • Platform: Ubuntu 18.1
  • Python version: 3.7
  • PyTorch version (GPU?): 1.6+cuda 10.1
  • Tensorflow version (GPU?): None
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

The task I am working on:
Masked language modelling (using AlbertForMaskedLM), training an ALBERT model from scratch with the following command:

python --model_type albert --num_train_epochs 300 --train_file /home/kushwanth/write_chunks/sample.txt --validation_file /home/kushwanth/write_chunks/sample.txt --tokenizer_name albert --do_train=yes --output_dir=/home/kushwanth/model --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --save_steps 3000 --logging_steps 500 --report_to tensorboard --preprocessing_num_workers 10 2>&1 | tee /home/kushwanth/log_1.txt

GPU specs:
8 GPUs with 16 GB of memory each

With a per-device batch size of 8 we get a CUDA out-of-memory error, and average GPU utilisation is around 50%.

Model config:
{
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "num_attention_heads": 12,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.15.0",
  "type_vocab_size": 1,
  "vocab_size": 40000
}
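For reference, AlbertForMaskedLM returns prediction logits of shape (batch_size, sequence_length, vocab_size), so vocab_size = 40000 makes the output tensor large. A rough sketch of the per-GPU logits size, assuming float32 activations (no fp16 flag in the command) and sequences of 512 tokens (max_position_embeddings; the actual max_seq_length used by the script is not shown in this report):

```python
# Values copied from the model config above.
config = {
    "vocab_size": 40000,
    "max_position_embeddings": 512,
    "torch_dtype": "float32",
}

BYTES_PER_ELEMENT = {"float32": 4, "float16": 2}


def logits_gib(batch_size: int, seq_len: int, vocab_size: int, dtype: str = "float32") -> float:
    """Size in GiB of a (batch_size, seq_len, vocab_size) logits tensor."""
    n_elements = batch_size * seq_len * vocab_size
    return n_elements * BYTES_PER_ELEMENT[dtype] / 2**30


# Assumption: every sequence is padded/grouped to the full 512 positions.
per_device = logits_gib(
    batch_size=8,
    seq_len=config["max_position_embeddings"],
    vocab_size=config["vocab_size"],
    dtype=config["torch_dtype"],
)
print(f"logits per GPU at batch size 8: {per_device:.2f} GiB")  # 0.61 GiB
```

This counts only the logits tensor itself; the cross-entropy loss and the backward pass typically need extra buffers of the same shape.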

  File "", line 442, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/transformers/", line 1332, in train
    tr_loss_step = self.training_step(model, inputs)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/transformers/", line 1891, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/transformers/", line 1923, in compute_loss
    outputs = model(**inputs)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/parallel/", line 156, in forward
    return self.gather(outputs, self.output_device)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/parallel/", line 168, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/parallel/", line 68, in gather
    res = gather_map(outputs)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/parallel/", line 62, in gather_m
    for k in out))
  File "<string>", line 7, in __init__
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/transformers/", line 2294, in __post_init__
    for element in iterator:
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/parallel/", line 62, in <genexpr>
    for k in out))
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/parallel/", line 55, in gather_m
    return Gather.apply(target_device, dim, *outputs)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/parallel/", line 68, in forward
    return comm.gather(inputs, ctx.dim, ctx.target_device)
  File "/home/kushwanth/anaconda3/envs/py37/lib/python3.7/site-packages/torch/cuda/", line 166, in gather
    return torch._C._gather(tensors, dim, destination)
RuntimeError: CUDA out of memory. Tried to allocate 4.88 GiB (GPU 0; 15.78 GiB total capacity; 7.63 GiB already allocated; 2.08 GiB free; 12.50 GiB reserved in total by PyTorch)
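The frames under torch/nn/parallel (self.gather(outputs, self.output_device), then Gather.apply and comm.gather) indicate that the model is wrapped in nn.DataParallel, which gathers every replica's outputs, including the MLM logits, onto the output device (GPU 0 by default). A back-of-the-envelope check, assuming float32 logits and 512-token sequences, lines up with the failed 4.88 GiB request:

```python
PER_DEVICE_BATCH = 8  # --per_device_train_batch_size
NUM_GPUS = 8          # from the GPU specs above
SEQ_LEN = 512         # assumption: max_position_embeddings, sequences padded/grouped to full length
VOCAB_SIZE = 40000    # vocab_size from the model config
BYTES_FP32 = 4        # float32 logits (no fp16 flag in the command)

# DataParallel scatters the global batch across replicas, then gathers all
# of their outputs back onto a single device before returning them.
global_batch = PER_DEVICE_BATCH * NUM_GPUS
gathered_bytes = global_batch * SEQ_LEN * VOCAB_SIZE * BYTES_FP32
gathered_gib = gathered_bytes / 2**30

print(f"global batch: {global_batch}")                      # 64
print(f"gathered logits on GPU 0: {gathered_gib:.2f} GiB")  # 4.88 GiB
```

Under these assumptions GPU 0 must hold the full gathered logits tensor for all 64 sequences in addition to its own replica's activations, while the other GPUs do not.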