Why does per_device_train_batch_size have a severe impact on memory?

While fine-tuning a 13B model with a 2048-token context window, I noticed that the model itself took less than 10 GB of memory when loaded with 4-bit quantization via the bitsandbytes library. However, even with per_device_train_batch_size=1, memory jumped to 18 GB+ once I started fine-tuning the model using the Trainer class that comes with the transformers library.

I would like to understand whether this is normal, and why. With per_device_train_batch_size=1, I am only giving the model one "data point" (so to speak), and yet it ends up taking almost 10 GB of additional memory. That does not seem right to me.

A quick back-of-the-envelope calculation gives me this: with per_device_train_batch_size=1, we are giving the model a vector of 2048 token ids. Even if each token id is mapped to an embedding vector (say, 768 in length), the memory taken should be 2048 × 768 × 4 bytes (assuming 32-bit floats), which is only about 6 MB and way less than 10 GB. Can anyone help me understand this, please?
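To make the numbers concrete, here is my back-of-the-envelope calculation in Python (the 768 embedding size is just an assumed example; 13B models typically use a larger hidden size):

```python
# Memory for one batch of token embeddings, per my rough estimate.
seq_len = 2048        # context window: token ids per example
embed_dim = 768       # assumed embedding size for illustration
bytes_per_float = 4   # 32-bit float

batch_bytes = seq_len * embed_dim * bytes_per_float
print(f"{batch_bytes} bytes = {batch_bytes / 2**20:.1f} MiB")  # prints "6291456 bytes = 6.0 MiB"
```

Even if I scale embed_dim up to a 13B model's actual hidden size, I get tens of MB per batch, still nowhere near the ~10 GB jump I observed.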