Getting CUDA out of memory when calling save_pretrained in a script that LoRA-trains a large language model

I am trying to train a LLaMA-based LLM (“eachadea/vicuna-13b-1.1”) with LoRA on a Lambda Labs A100 40 GB.

Everything seems to work fine, including the training, but the script fails on the last line, lora_model.save_pretrained(lora_file_path), with this exception:

Traceback (most recent call last):
  File "train.py", line 151, in <module>
    lora_model.save_pretrained(lora_file_path)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/peft/peft_model.py", line 125, in save_pretrained
    output_state_dict = get_peft_model_state_dict(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/peft/utils/save_and_load.py", line 32, in get_peft_model_state_dict
    state_dict = model.state_dict()
  File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1448, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1448, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1448, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  [Previous line repeated 4 more times]
  File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 1445, in state_dict
    self._save_to_state_dict(destination, prefix, keep_vars)
  File "/usr/local/lib/python3.8/dist-packages/bitsandbytes-0.38.1-py3.8.egg/bitsandbytes/nn/modules.py", line 268, in _save_to_state_dict
    self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
  File "/usr/local/lib/python3.8/dist-packages/bitsandbytes-0.38.1-py3.8.egg/bitsandbytes/autograd/_functions.py", line 100, in undo_layout
    return outputs.reshape(rows, cols).contiguous()
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 26.00 MiB (GPU 0; 39.56 GiB total capacity; 36.42 GiB already allocated; 18.56 MiB free; 38.17 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
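
The allocator hint in the message refers to the PYTORCH_CUDA_ALLOC_CONF environment variable, which has to be set before the first CUDA allocation. For completeness, this is roughly how it would look at the top of the script (128 is just an example split size; it only helps with fragmentation, not with a card that is genuinely almost full):

import os

# Must be set before torch initializes the CUDA allocator.
# 128 MiB is an illustrative value, not something taken from this run.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch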

The training text is a 622 KB book in plain-text format.

Here is the code:

import os, sys
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from datasets import Dataset, load_dataset
import transformers
from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training, set_peft_model_state_dict)

model_name = "eachadea/vicuna-13b-1.1"
load_in_8bit=True
lora_file_path = "my_lora"
text_filename='input.txt'
output_dir='.'
cutoff_len = 512
overlap_len = 128
newline_favor_len = 128

def split_chunks(arr, step):
    for i in range(0, len(arr), step):
        yield arr[i:i + step]

def cut_chunk_for_newline(chunk: str, max_length: int):
    if '\n' not in chunk:
        return chunk
    first_newline = chunk.index('\n')
    if first_newline < max_length:
        chunk = chunk[first_newline + 1:]
    if '\n' not in chunk:
        return chunk
    last_newline = chunk.rindex('\n')
    if len(chunk) - last_newline < max_length:
        chunk = chunk[:last_newline]
    return chunk

def tokenize(prompt):
    result = tokenizer(prompt, truncation=True, max_length=cutoff_len + 1, padding="max_length")
    return {
        "input_ids": result["input_ids"][:-1], # return all elements except the last one.
        "attention_mask": result["attention_mask"][:-1], # return all elements except the last one.
    }

model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=load_in_8bit, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"

for param in model.parameters():
  param.requires_grad = False  # freeze the model - train adapters later
  if param.ndim == 1:
    # cast the small parameters (e.g. layernorm) to fp32 for stability
    param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()

# Cast the lm_head output back to fp32 so the logits and loss are computed in
# full precision while the rest of the model runs in int8/fp16.
class CastOutputToFloat(nn.Sequential):
  def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)

config = LoraConfig(
    r=16, # 32 oob
    lora_alpha=32, # 64 oob
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

if not hasattr(model, 'lm_head') or hasattr(model.lm_head, 'weight'):
    print("prepare_model_for_int8_training...")
    model = prepare_model_for_int8_training(model)

lora_model = get_peft_model(model, config)

with open(text_filename, 'r', encoding='utf-8') as file:
    raw_text = file.read()

tokens = tokenizer.encode(raw_text)
del raw_text  # be safe on RAM
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
# Prepend the last overlap_len tokens of the previous chunk to each chunk so
# consecutive training windows overlap and no context is lost at the seams.
for i in range(1, len(tokens)):
    tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]

text_chunks = [tokenizer.decode(x) for x in tokens]
del tokens
if newline_favor_len > 0:
    text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]

train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
del text_chunks

trainer = transformers.Trainer(
    model=lora_model, 
    train_dataset=train_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4, 
        gradient_accumulation_steps=4,
        warmup_steps=100, 
        max_steps=200, 
        learning_rate=2e-4, 
        fp16=True,
        evaluation_strategy="no",
        logging_steps=1, 
        output_dir=output_dir,
        ddp_find_unused_parameters=None,
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
lora_model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

if torch.__version__ >= "2" and sys.platform != "win32":
    lora_model = torch.compile(lora_model)

trainer.train()
lora_model.save_pretrained(lora_file_path)

Here are my versions:

accelerate==0.19.0
datasets==2.12.0
loralib==0.1.1
#numexpr==2.7.1
peft==0.3.0
protobuf==3.20.3
requests==2.28.2
sentencepiece==0.1.99
#tokenizers==0.13.3
#torch==1.13.1
#torchvision==0.14.1
transformers==4.29.2

User angelovAlex found the solution here: save_pretrained issue · Issue #1 · rhulha/lora · GitHub
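
For anyone who cannot open the link, here is my own summary rather than a verbatim copy of the fix from that issue: the traceback shows that bitsandbytes rebuilds a row-major copy of the int8 weights on the GPU while state_dict() walks the 8-bit base model, and after training only ~18 MiB are free. Freeing the trainer's GPU memory first, or saving just the small LoRA tensors by hand so the Int8 layers are never touched, sidesteps that allocation. A sketch (the file name below is just an example):

import gc, torch

# Release what the Trainer is still holding on the GPU (gradients,
# optimizer state, cached blocks) before saving anything.
del trainer
gc.collect()
torch.cuda.empty_cache()

# Alternative that never calls state_dict() on the bitsandbytes Int8 layers,
# so undo_layout() is never triggered: copy only the LoRA tensors to CPU.
lora_state = {
    name: param.detach().cpu()
    for name, param in lora_model.named_parameters()
    if "lora_" in name
}
torch.save(lora_state, "my_lora_adapter.pt")

# Restore later into the same PEFT-wrapped model with:
#   lora_model.load_state_dict(torch.load("my_lora_adapter.pt"), strict=False)
# (If torch.compile was used, the parameter names carry an "_orig_mod." prefix,
#  so save and load should both go through the compiled wrapper.)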

Hey @rhulha,
Did you manage to solve this problem? I’m facing the exact same issue.

Also facing this issue, training a 13B LLaMA-2 model with the HuggingFace trainer. Everything goes swimmingly until the very end, when a call to model.save_pretrained() triggers the OOM. Fine-tuning with LoRA in 8-bit.