Fine-tuning with load_in_8bit and inference without load_in_8bit possible?

Hello

I would like to load the model EleutherAI/gpt-j-6B in 8-bit with:

import torch
from transformers import GPTJForCausalLM

model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    use_cache=False,   # caching must be off when training with gradient checkpointing
    device_map="auto",
    load_in_8bit=True,
)
# passing gradient_checkpointing=True to from_pretrained is deprecated;
# enable it on the model instead:
model.gradient_checkpointing_enable()

Then, I would like to fine-tune this model and save it. Afterwards, I would like to load the fine-tuned model without load_in_8bit so that I can run it on Windows (the bitsandbytes library, which is needed for load_in_8bit, is not supported on Windows).

Is this possible, or does a model fine-tuned with load_in_8bit always have to be run with this flag?

Does somebody have an answer to this, or has anyone even tried it out?

Hi,

The LLM.int8() algorithm, as explained in the blog post, is meant for inference.

However, bitsandbytes also provides functionality to train models more efficiently, namely 8-bit optimizers. See here for more info: GitHub - TimDettmers/bitsandbytes: 8-bit CUDA functions for PyTorch
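
For example, switching to the 8-bit optimizer in a regular PyTorch training loop is essentially a one-line change (rough, untested sketch; model is assumed to be loaded normally in fp16, and dataloader and the learning rate are just placeholders):

import bitsandbytes as bnb

# instead of: optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-5)

for batch in dataloader:
    outputs = model(**batch)   # batch with input_ids, attention_mask, labels
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()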

Thanks for the explanation, that makes sense.

Do you know if it is possible to train using the 8-bit optimizer from bitsandbytes but save the model “normally” (without 8-bit), so that I can load it without bitsandbytes?

Yes, it’s only the optimizer state that is in int8, not the model weights. You just need to replace your regular optimizer with bnb.optim.Adam8bit(...).
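
So the workflow you describe should work roughly like this (sketch only; the path and dtype are just examples):

# after fine-tuning with bnb.optim.Adam8bit, the weights are still regular
# float16, so you can save them the normal way:
model.save_pretrained("./gpt-j-6B-finetuned")

# later, e.g. on Windows, load without bitsandbytes and without load_in_8bit:
import torch
from transformers import GPTJForCausalLM

model = GPTJForCausalLM.from_pretrained(
    "./gpt-j-6B-finetuned",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)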