T5 memory usage

Hi folks,

I’ve been running some profiling on flan-t5-base, and it seems to scale very poorly with batch size. In particular, a forward pass [w/ a 250-token context window; batch size of 1] allocates 947MB to store the model and ~70MB for the forward pass itself. This scales roughly linearly: at inference, a batch of 2 uses ~150MB, a batch of 3 uses ~205MB, and so on.
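
For reference, this is roughly the kind of measurement I mean (a minimal sketch; the dummy input text, padding setup, and labels are placeholders rather than my real data):

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).cuda().eval()

def forward_pass_mb(batch_size, seq_len=250):
    # Pad a dummy batch out to the full context window.
    inputs = tokenizer(
        ["hello world"] * batch_size,
        padding="max_length", max_length=seq_len,
        truncation=True, return_tensors="pt",
    ).to("cuda")
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated()  # model weights already resident
    with torch.no_grad():
        model(**inputs, labels=inputs["input_ids"])
    # Peak minus baseline ~= memory attributable to the forward pass itself.
    return (torch.cuda.max_memory_allocated() - baseline) / 2**20

for bs in (1, 2, 3):
    print(bs, f"{forward_pass_mb(bs):.0f} MB")
```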

However, at training time with a batch size of 1, the first batch’s forward pass allocates ~750MB and the second batch allocates an additional ~900MB, staying constant from then on. With a batch size of 2, the first batch’s forward pass allocates ~1.6GB and the second ~1GB. This first allocation scales linearly with batch size: with a batch size of 8, the first forward pass allocates ~7GB.
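
The training-time numbers come from a loop along these lines (again a sketch: the AdamW optimizer and learning rate are stand-ins, and it reuses `model` and `tokenizer` from the snippet above):

```python
import torch

model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # placeholder choice

def train_step(batch_size, seq_len=250):
    inputs = tokenizer(
        ["hello world"] * batch_size,
        padding="max_length", max_length=seq_len,
        truncation=True, return_tensors="pt",
    ).to("cuda")
    out = model(**inputs, labels=inputs["input_ids"])
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

for step in (1, 2, 3):
    train_step(batch_size=1)
    # Allocated (not peak) memory after each full forward/backward/step.
    print(step, f"{torch.cuda.memory_allocated() / 2**20:.0f} MB allocated")
```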

Is this expected? I’d imagine at training time we’d take a memory hit from having to store the gradients, but that should be constant. The activations, in turn, I’d expect to be relatively small: 12 layers * 256 context window * 768 hidden size * 4 bytes per float is only ~10MB – assuming I’m understanding the architecture correctly.
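
Spelled out, that back-of-the-envelope estimate is:

```python
# Counts one hidden-state tensor per encoder layer only (no decoder,
# attention matrices, or FFN intermediates included).
layers, seq_len, hidden, bytes_per_float = 12, 256, 768, 4
print(layers * seq_len * hidden * bytes_per_float / 2**20)  # ~9 MB
```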