Should 24 GB of VRAM be enough to fine-tune a 1B model?

I am currently trying to fine-tune a very small model as an experiment. I have a dataset with about 600 rows. I have tried all sorts of things, starting with:

The Supervised Fine-tuning Trainer (SFTTrainer). I started with the simple example, which seemed to work when my dataset was about 100 rows:

from datasets import load_dataset
from trl import SFTConfig, ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
from transformers import AutoTokenizer
import torch

torch.cuda.empty_cache()

dataset = load_dataset("json", data_files="Dataset.json", split="train")

tokenizer = AutoTokenizer.from_pretrained("/ContainerData/outetts/scratch/raw/new")

training_args = SFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
    },
    max_seq_length=2048,
    output_dir="train",
    learning_rate=6e-05,
    per_device_eval_batch_size=4,
    per_device_train_batch_size=4,
    eval_steps=500,
    eval_accumulation_steps=16,
)

model_args = ModelConfig(
    model_name_or_path="train/outetts",
    torch_dtype="bfloat16",  # so the bf16 intent carries through model_kwargs below
    attn_implementation="flash_attention_2",
)
torch_dtype = (
    model_args.torch_dtype
    if model_args.torch_dtype in ["auto", None]
    else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
# Hand the assembled kwargs to the trainer so it uses them when loading the model from its path
training_args.model_init_kwargs = model_kwargs

trainer = SFTTrainer(
    model=model_args.model_name_or_path,
    processing_class=tokenizer,  # tokenizer= on older TRL releases
    peft_config=get_peft_config(model_args),
    train_dataset=dataset,
    args=training_args,
)

trainer.train()

And this evolved from a much simpler attempt:

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("json", data_files="Dataset.json", split="train")
print(dataset.column_names)  # sanity check that the expected columns are present

training_args = SFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
    },
    max_seq_length=4096,
    output_dir="train",
    learning_rate=6e-05,
    per_device_eval_batch_size=6,
)

model = "train/outetts"

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=training_args,
)

trainer.train()
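
A quick way to see how close a run actually gets to the 24 GB is to read PyTorch's peak-memory counter once training has run for a few steps, roughly like this (a sketch, not part of the original script):

import torch

# Peak bytes ever allocated on the GPU during this process vs. the card's total capacity
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
print(f"peak VRAM: {peak_gb:.1f} GiB of {total_gb:.1f} GiB")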

Maybe I just need to get a RunPod instance with more VRAM? But I thought a 1B model could be trained on 24 GB of VRAM.

Thanks!

1 Like

I am not too sure; you could try setting the per_device_train_batch_size param.
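
For example, in the simpler script above that would look something like this (a sketch; 1 is just the most memory-conservative starting point, with the other arguments left as they were):

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="train",
    per_device_train_batch_size=1,  # fewer examples per step means less activation memory
)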

1 Like

OMG, it’s working… why is it working? It only works with that value set to 1.

2 Likes

You can definitely do it using PEFT; you should look into LoRA.

It’s very common these days to use LoRA or even quantized LoRA (QLoRA), which lets you fine-tune even bigger models. Hugging Face already has quite a lot of material on it: PEFT
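
A rough sketch of what that could look like with SFTTrainer, reusing the dataset and training_args from the scripts above (the rank, alpha, and target_modules values here are assumptions you would tune for your model):

from peft import LoraConfig
from trl import SFTTrainer

# LoRA trains small adapter matrices while the 1B base model stays frozen,
# which shrinks gradient and optimizer-state memory dramatically.
peft_config = LoraConfig(
    r=16,                                 # adapter rank (assumed value)
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed attention projection names; adjust for your architecture
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "train/outetts",          # model path from the original post
    train_dataset=dataset,    # dataset loaded as in the scripts above
    args=training_args,
    peft_config=peft_config,  # SFTTrainer attaches the adapters before training
)
trainer.train()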

Batch size is the number of examples that will be processed at once. The bigger the batch size, the bigger the memory requirements.

But you don’t want to reduce it to a value like 1, as that can affect training quality.

One way to get around this is to use gradient_accumulation_steps along with a batch size of 1, which is equivalent to using a larger effective batch size.
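
Roughly like this (a sketch based on the config from the original post): each step only holds one example in memory, but gradients are accumulated so the optimizer effectively sees a batch of 16.

from trl import SFTConfig

training_args = SFTConfig(
    model_init_kwargs={"torch_dtype": "bfloat16"},
    max_seq_length=2048,
    output_dir="train",
    learning_rate=6e-05,
    per_device_train_batch_size=1,   # one example per forward/backward pass
    gradient_accumulation_steps=16,  # update weights every 16 steps -> effective batch size 16
)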

See it here - Trainer

2 Likes

Follow this guide when training:

But yes, 24 GB should be sufficient to SFT a 1B model.

For a significant memory reduction, use PEFT with QLoRA/LoRA.
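
A minimal QLoRA sketch, assuming a bitsandbytes 4-bit base model plus LoRA adapters and reusing the dataset and model path from the original post (the LoRA hyperparameters and target_modules are placeholders to tune):

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("json", data_files="Dataset.json", split="train")

# Base weights are loaded in 4-bit NF4; only the LoRA adapters are trained in bf16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "train/outetts",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed; depends on the model architecture
    task_type="CAUSAL_LM",
)

training_args = SFTConfig(
    output_dir="train",
    max_seq_length=2048,
    learning_rate=6e-05,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=training_args,
    peft_config=peft_config,
)
trainer.train()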

3 Likes

I have a video card with 16 GB of VRAM plus 16 GB of system RAM, and I have successfully fine-tuned an 8B Llama model.

But it was right at the limit; the training args needed some tuning.

So a 1B model with 24 GB should be possible.

1 Like

Can you share the script you used and the data format?

1 Like

Unsloth has taken care of the hard stuff. They mostly work with LLMs, but I’m sure the source can give you hints on how to handle your specific needs.

https://unsloth.ai/blog/grpo

1 Like

I followed this one: georgesung/llama3_8b_chat_uncensored · Hugging Face

1 Like