Out of memory when training a 3B-parameter model on 8 GPUs (320 GB total) with FSDP

I’m using transformers==4.29.0, and running the train script with:

torchrun --nproc_per_node=8 train-script.py

I am unable to train a 3B-parameter model on a p4d.24xlarge instance because of out-of-memory errors. This instance has 8 GPUs with 40 GB each, for a total of 320 GB of GPU memory. Is this normal? All 8 GPUs filled up approximately uniformly until they hit the memory limit. What can I change in my configuration? Note that I am loading the model as shown below, without any explicit GPU placement (such as device_map="auto").
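A rough back-of-the-envelope estimate helps answer "is this normal?". The sketch below assumes full fine-tuning with AdamW in fp32: 4 bytes each for weights and gradients plus 8 bytes for the two optimizer moments, i.e. 16 bytes per parameter. It ignores activations, the CUDA context and the layers FSDP temporarily gathers, so treat it as a lower bound under stated assumptions, not a measurement of this run. The function name is illustrative.

```python
def training_state_gb(num_params, world_size=1, bytes_per_param=16):
    """Rough per-GPU memory (GB = 1e9 bytes) for weights + grads + AdamW states.

    Assumption: fp32 weights (4 B) + fp32 grads (4 B) + two fp32 AdamW
    moments (8 B) = 16 bytes per parameter. With FSDP full_shard this state is
    split evenly across `world_size` GPUs. Activations are NOT included.
    """
    return num_params * bytes_per_param / world_size / 1e9


params = 3e9
print(f"no sharding : {training_state_gb(params):.1f} GB per GPU")                # 48.0
print(f"full_shard/8: {training_state_gb(params, world_size=8):.1f} GB per GPU")  # 6.0
```

Under these assumptions, an unsharded copy of the training state needs about 48 GB, more than a single 40 GB A100 holds, while the same state sharded 8 ways is about 6 GB per GPU. If every GPU still fills up under full_shard, either the state is not actually being sharded or activations (which grow with batch size and sequence length) dominate.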

The torchrun command seems to launch 8 Python processes. Each of them starts by loading my pandas DataFrame, and each tokenizes the same data separately. This makes me worry that the model is also being loaded 8 separate times. The tokenization stage certainly uses a lot of memory, far more than I would consider reasonable for a tokenizer. Is torchrun supposed to automagically handle model loading so that the model is loaded only once and sharded properly?
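Every process that torchrun starts executes the entire script, which is why the DataFrame loading and tokenization happen 8 times. torchrun tells each copy which one it is through the environment variables `RANK`, `LOCAL_RANK` and `WORLD_SIZE`. The helper below is a minimal sketch that reads them; the function names are illustrative, not part of torch or transformers.

```python
import os


def dist_info():
    """Return (rank, local_rank, world_size) as exported by torchrun.

    A plain `python script.py` run sets none of these variables, so fall back
    to a single-process view (rank 0 of 1).
    """
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, local_rank, world_size


def is_main_process():
    """True only in the global rank-0 process."""
    return dist_info()[0] == 0


if __name__ == "__main__":
    rank, local_rank, world_size = dist_info()
    # Under `torchrun --nproc_per_node=8` this line is printed 8 times, once per rank.
    print(f"rank {rank}/{world_size} on local GPU {local_rank}")
```

The same applies to `from_pretrained`: each process loads the model itself. With FSDP, it is the wrapping step inside the Trainer, not torchrun, that is expected to shard the parameters across GPUs afterwards.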

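from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# ckpt (checkpoint name/path) and df (pandas DataFrame) are defined earlier in the script
model = AutoModelForCausalLM.from_pretrained(ckpt)

dataset = Dataset.from_pandas(df)
data_collator = DataCollatorForLanguageModeling(
    # ... collator arguments elided ...
)
# ... tokenize dataset ...

training_args = TrainingArguments(
    gradient_accumulation_steps=8,  # alpaca
    evaluation_strategy="no",  # alpaca
    save_total_limit=1,  # alpaca
    learning_rate=2e-5,  # alpaca
    weight_decay=0.0,  # alpaca
    warmup_ratio=0.03,  # alpaca
    lr_scheduler_type="cosine",  # alpaca
    fsdp="full_shard auto_wrap",  # alpaca
    # ... remaining arguments elided ...
)

trainer = Trainer(
    # ... arguments elided ...
)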
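Because every torchrun process executes the whole script, the tokenization step above runs 8 times and each process holds its own copy of the data. One way around this is to do the expensive work on rank 0 only and let the other ranks reuse the result. The stdlib sketch below illustrates that pattern with a file and polling; `run_once_and_share` and `tokenize_all` are hypothetical names, and JSON is used only to keep the example self-contained. In a real run you would more likely save the tokenized `Dataset` with `save_to_disk` and reload it with `datasets.load_from_disk`.

```python
import json
import os
import time
from pathlib import Path


def run_once_and_share(cache_path, build, rank, poll_s=1.0, timeout_s=3600.0):
    """Run the expensive build() on rank 0 only; other ranks reuse its result.

    Hypothetical helper, not a transformers/datasets API. Rank 0 writes the
    JSON-serializable result to a temp file and renames it into place, so other
    ranks never read a half-written file; they poll until it appears.
    """
    cache_path = Path(cache_path)
    if rank == 0 and not cache_path.exists():
        tmp = cache_path.parent / (cache_path.name + ".tmp")
        tmp.write_text(json.dumps(build()))
        os.replace(tmp, cache_path)  # atomic rename on the same filesystem
    deadline = time.monotonic() + timeout_s
    while not cache_path.exists():
        if time.monotonic() > deadline:
            raise TimeoutError(f"{cache_path} was never written by rank 0")
        time.sleep(poll_s)
    return json.loads(cache_path.read_text())


# Usage sketch (tokenize_all is hypothetical): only rank 0 tokenizes.
# token_ids = run_once_and_share("tokenized.json", tokenize_all,
#                                rank=int(os.environ.get("RANK", 0)))
```

This assumes all ranks share a filesystem (true on a single p4d node) and that a stale cache file from an earlier run is deleted before starting, since an existing file is reused as-is.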

I also tried "full_shard" this way, and the CUDA memory usage showed that every GPU simply loads a full copy of the model. So I guess this is not the right way to implement model-parameter parallelism. Unfortunately, I don't know the right way. What's more, as far as I know, model parallelism is not a good idea because it slows down training. I recommend using int8 + PEFT to save CUDA memory.
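To see why int8 + PEFT helps, here is a rough estimate assuming LoRA as the PEFT method. The model shape (32 layers, hidden size 2560) and the LoRA settings (rank r=8 on the query and value projections) are hypothetical, chosen only to illustrate the arithmetic. The frozen int8 base is counted at 1 byte per parameter and the adapters at 16 bytes per trainable parameter (fp32 weights, gradients and AdamW moments), ignoring quantization overhead and activations.

```python
def lora_param_count(num_layers, hidden, rank, adapted_per_layer=2):
    """Trainable LoRA params for square (hidden x hidden) projections.

    Each adapted weight W (d_out x d_in) gets A (r x d_in) and B (d_out x r),
    i.e. r * (d_in + d_out) extra params; here d_in = d_out = hidden.
    """
    return num_layers * adapted_per_layer * rank * (hidden + hidden)


def int8_lora_gb(base_params, trainable_params, bytes_per_trainable=16):
    """Frozen int8 base (1 B/param) + fp32 weights/grads/AdamW for the adapters."""
    return (base_params * 1 + trainable_params * bytes_per_trainable) / 1e9


# Hypothetical 3B-class shape: 32 layers, hidden size 2560, LoRA r=8 on q and v.
lora = lora_param_count(num_layers=32, hidden=2560, rank=8)
print(lora)                                  # 2621440
print(f"{int8_lora_gb(3e9, lora):.2f} GB")   # 3.04 GB, activations not included
```

The frozen base dominates: about 3 GB of weights against roughly 42 MB of adapter state, because only the small adapter matrices carry gradients and optimizer moments.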

Finally, I’m a beginner in NLP, so please forgive any mistakes in what I said.