load_in_8bit vs. loading an 8-bit quantized model

As additional context, GPU memory usage takes a sudden jump at some point during training.

[image (2): GPU memory usage during the first training steps]

For the first few steps, GPU memory usage was stable at around 50 GB (using load_in_4bit), but it soon jumped to nearly 80 GB.
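For reference, here's a minimal sketch of the kind of 4-bit loading setup I mean. The model id and config values are placeholders, not necessarily what this run used:

```python
# Sketch only: a common transformers + bitsandbytes 4-bit loading setup.
# The model id and quantization settings below are assumptions.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                        # quantize weights to 4-bit on load
    bnb_4bit_quant_type="nf4",                # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,    # dtype used for matmuls
)

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-70b-model",                # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
)
```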

[image (3): WandB memory allocation graph for part of this run]

I’m not sure what causes this or at exactly which step it occurs. The WandB memory allocation graph above covers part of this run, in case it’s helpful.
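To pin down the exact step where the jump happens, one option would be a small callback that logs allocator stats after each step. This sketch assumes the HF Trainer; with a custom training loop the same torch.cuda calls can be made directly after each optimizer step:

```python
# Sketch: log per-step CUDA memory to find where usage jumps.
# Assumes an HF Trainer-based loop.
import torch
from transformers import TrainerCallback

class MemoryLoggingCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        # memory_allocated() counts live tensors; memory_reserved()
        # includes the caching allocator's pool, which tracks closer
        # to what nvidia-smi (and WandB's system metrics) report.
        alloc = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"step {state.global_step}: "
              f"allocated={alloc:.1f} GiB, reserved={reserved:.1f} GiB")

# usage: trainer.add_callback(MemoryLoggingCallback())
```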