Accelerate/DeepSpeed: Flan-T5 OOM despite device_map

I am currently trying to run FLAN-T5 inference on my setup of 6x RTX 3090 (6 × 24 GB), but cannot get it to work in a Jupyter notebook inside a PyTorch NVIDIA container (22.06).

I have already tried configuring DeepSpeed and Accelerate to reduce the model's memory footprint and distribute it across all GPUs. Loading initially works and allocates about 10 GB on every GPU, but then GPU 0 fills up completely and I eventually hit an OOM error.

This is the code I use to initialize Accelerate and DeepSpeed.

# Set environment variable to cache model files locally and avoid redownloading.
import os
!mkdir -p /home/jovyan/datasets/huggingface
os.environ['TRANSFORMERS_CACHE'] = '/home/jovyan/datasets/huggingface'

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import accelerate

# Prepare the DeepSpeed plugin: ZeRO stage 3 with parameter and optimizer offloading to CPU
deepspeed = accelerate.DeepSpeedPlugin(gradient_accumulation_steps=1,
                                       zero3_init_flag=False,
                                       zero_stage=3,
                                       offload_optimizer_device="cpu",
                                       offload_param_device="cpu")
accelerator = accelerate.Accelerator(deepspeed_plugin=deepspeed)
# DeepSpeed's config expects a micro batch size even when only running inference
accelerate.state.AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 1
# Display the resulting state in the notebook
accelerate.state.AcceleratorState()

Afterwards, I try to load the model and tokenizer with accelerator.prepare, but this is where I run into the problem:

model, tokenizer = accelerator.prepare(
    T5ForConditionalGeneration.from_pretrained("google/flan-t5-xxl", device_map="balanced_low_0"), 
    T5Tokenizer.from_pretrained("google/flan-t5-xxl", device_map="balanced_low_0")
)

I have already tried all ZeRO stages, with and without DeepSpeed, and with and without parameter/optimizer offloading, and I am out of ideas. I also tried setting num_processes, but ran into the issue that I don't know where to set acc_gradient_steps or world_size, and I can't find any documentation for them either.
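
For what it's worth, the only place I found where num_processes seems to belong is accelerate.notebook_launcher rather than the Accelerator itself; I'm not sure how this is meant to combine with the DeepSpeed plugin above, and inference_fn below is just a placeholder name:

from accelerate import Accelerator, notebook_launcher

def inference_fn():
    # Each spawned process runs this once and builds its own Accelerator,
    # so process_index and num_processes (i.e. rank and world_size) are set automatically.
    accelerator = Accelerator()
    print(f"rank {accelerator.process_index} of {accelerator.num_processes}")

# Spawns one process per GPU; world_size then follows from num_processes.
notebook_launcher(inference_fn, num_processes=6)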

I’m sorry if this is a mundane question, but any help would be greatly appreciated.

Hello @Breenori, device_map is only supported for pure inference and isn't meant to be combined with DeepSpeed. Could you try removing it and then using DeepSpeed ZeRO Stage 3 with CPU offloading?
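
A rough, untested sketch of what I mean, reusing your accelerator from above:

model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xxl")  # no device_map
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")  # tokenizers don't take a device_map
# Let DeepSpeed ZeRO-3 partition the parameters across the GPUs instead
model = accelerator.prepare(model)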