How to estimate the VRAM needed for a prompt according to its size (inference and fine-tuning)

Hi, I noticed that giving an LLM a huge prompt (4000 tokens) can consume around 6 GB of VRAM (8-bit model).
So it's really difficult to fine-tune on huge prompts when you use the free Colab T4.
My point is: can someone help me understand the operations behind the VRAM consumption (with regard to the length of the prompt) during inference and during fine-tuning with LoRA?
I distinguish the two because fine-tuning on a causal task consumes the inference baseline + what is needed for fine-tuning (gradients: 2 bytes per parameter, plus the same again for optimizer states, I read).
And this point is confusing me.
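To make the numbers concrete, here is the back-of-envelope math I'm assuming (fp16 gradients and Adam with fp32 states are my assumptions; the parameter count is just an example):

```python
# Rough, back-of-envelope estimate of the EXTRA fine-tuning memory per
# trainable parameter (assumptions: fp16/bf16 gradients, Adam optimizer
# keeping two fp32 moment tensors per parameter).
def training_bytes_per_param(grad_bytes=2, adam_state_bytes=8):
    # Adam's two fp32 moments: 2 * 4 bytes = 8 bytes per parameter.
    return grad_bytes + adam_state_bytes

def extra_training_gib(n_trainable_params):
    # Convert total extra bytes to GiB.
    return n_trainable_params * training_bytes_per_param() / 2**30

# Example: LoRA with ~10M trainable parameters.
print(round(extra_training_gib(10_000_000), 3))  # ~0.093 GiB
```

With LoRA only the adapter parameters are trainable, so by this arithmetic the gradient + optimizer overhead is tiny, which is exactly why the large consumption I observe confuses me.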
I mean, alright, when doing inference on a huge prompt the LLM needs to keep in its cache the embeddings of the 4000 previous tokens it has seen. But I don't understand why that is required AT THE BEGINNING of fine-tuning: it should start with the first token of my prompt, predict the next one, apply the loss by comparing the prediction with the true next token in my corpus, adjust the weights, and so on until the 4000th token (at that point I would understand it consuming more VRAM than the T4 has, but the consumption gets extremely high very quickly, so I guess it happens at the beginning of the process).
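For the inference side, here is how I'm estimating the cache that holds those 4000 previous tokens (the model dimensions below are my assumptions for a 7B-class model; the fp16 cache size is also an assumption):

```python
# Back-of-envelope KV-cache size for inference with a long prompt.
# Assumptions: 32 transformer layers, hidden size 4096, fp16 cache.
def kv_cache_gib(seq_len, n_layers=32, hidden=4096, bytes_per_val=2):
    # Each layer stores 2 tensors (keys and values), each of shape
    # [seq_len, hidden], so: 2 * n_layers * hidden * seq_len values.
    return 2 * n_layers * hidden * seq_len * bytes_per_val / 2**30

print(round(kv_cache_gib(4000), 2))  # ~1.95 GiB for a 4000-token prompt
```

By this sketch the cache alone grows linearly with prompt length, on top of the model weights themselves.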

Can someone enlighten me?

Maybe interesting for @muellerzr!