Why does GPU memory fill up in the first few iterations when training with the transformers Trainer, then drop to a much lower level as training progresses?

I'm training the vision module of the Qwen2.5-VL-3B model and freezing the other modules. The vision module is composed of transformer blocks, and I train it on a single 80GB A100. The maximum number of pixels per image is set to 1,048,576, which corresponds to about 5349 vision tokens fed into the vision module (the patch size is 14x14).
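For reference, the setup looks roughly like this (a simplified sketch, not my exact script; the checkpoint name, the `max_pixels` argument, and the `visual` parameter prefix reflect my understanding of the Qwen2.5-VL implementation and may differ):

```python
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Load the model in bf16 (checkpoint name assumed).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype=torch.bfloat16
)

# Cap image resolution so a single image produces at most
# 1,048,576 / (14 * 14) ≈ 5349 vision tokens.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", max_pixels=1048576
)

# Train only the vision tower and freeze everything else
# (the "visual" prefix for vision-tower parameters is an assumption).
for name, param in model.named_parameters():
    param.requires_grad = "visual" in name
```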
When training starts, GPU memory fills up almost completely within the first few iterations; the reported usage is 79000MiB out of 81920MiB. After a few minutes, usage drops quickly to about 35000MiB and stabilizes there.
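To see whether the drop reflects what PyTorch itself allocates rather than just the total reported for the process, I could log memory from inside the Trainer with a callback like this (a rough sketch; the callback class and the logging choice are just illustrative):

```python
import torch
from transformers import TrainerCallback

class MemoryLoggerCallback(TrainerCallback):
    """Log allocated vs. reserved CUDA memory at the end of each training step."""

    def on_step_end(self, args, state, control, **kwargs):
        allocated = torch.cuda.memory_allocated() / 1024**2  # MiB actually used by live tensors
        reserved = torch.cuda.memory_reserved() / 1024**2    # MiB held by the caching allocator
        print(f"step {state.global_step}: allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB")

# Usage: trainer.add_callback(MemoryLoggerCallback())
```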
I want to know what causes this.
