The output dimension of models for causal LM is
(batch_size, sequence_length, config.vocab_size). I don’t understand why this is the case. I would expect the outputs to be
(batch_size, config.vocab_size), i.e. the logits for next token prediction.
The generate method always uses
next_token_logits = outputs.logits[:, -1, :], i.e. only the last token's logits, for next token prediction. However, the logits for all tokens of the sequence are computed, which very quickly blows up the memory for a large vocab_size.
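As a toy illustration of the shapes involved (NumPy standing in for a real model call, with made-up sizes), the model returns logits for every position, but generate only keeps the last slice:

```python
import numpy as np

# Hypothetical sizes, just to illustrate the shapes involved.
batch_size, sequence_length, vocab_size = 2, 5, 10

# The forward pass produces logits for EVERY position in the sequence...
logits = np.random.randn(batch_size, sequence_length, vocab_size)
assert logits.shape == (batch_size, sequence_length, vocab_size)

# ...but generate() only keeps the slice for the last position:
next_token_logits = logits[:, -1, :]
assert next_token_logits.shape == (batch_size, vocab_size)
```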
For example, for the Bloom family, the
vocab_size is 250880 (one of the largest vocab sizes around). This means that even using bloom-560m (a small model), inference on a batch size of 64 with a prompt of, say, 500 tokens (or a small prompt with a large
max_new_tokens = 500) will take AT LEAST
64 * 500 * 250880 * 2 / 1024**3 = 14.95 GiB for the logits alone (the factor of 2 assumes the model was loaded in float16; for float32, multiply this number by 2 again). And this is not even taking into account the memory needed for intermediate results in the forward pass.
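The arithmetic above can be packaged as a small helper (a sketch; the function name and the comparison with keeping only the last position are mine, the numbers are the bloom-560m ones from above):

```python
def logits_memory_gib(batch_size, seq_len, vocab_size, bytes_per_elem=2):
    """Memory held by the full logits tensor, in GiB.

    bytes_per_elem=2 assumes float16; use 4 for float32.
    """
    return batch_size * seq_len * vocab_size * bytes_per_elem / 1024**3

# Full-sequence logits: 64 x 500 x 250880 in float16.
print(round(logits_memory_gib(64, 500, 250880), 2))  # 14.95

# Keeping only the last position instead would need ~500x less:
print(round(logits_memory_gib(64, 1, 250880), 4))    # 0.0299
```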
So we are basically blowing up the memory very quickly for large sequences or large
max_new_tokens, despite the fact that we only need the logits for the next token prediction, not for the tokens we already have!
Is there any way to modify the
forward method of models for causal LM to only output logits of dimension
(batch_size, config.vocab_size)? Or is this a property of the underlying models that would require retraining from scratch?
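For what it's worth, no retraining should be needed in principle: in a typical decoder, the logits are just the hidden states projected through the LM head, so projecting only the last position gives the same next-token logits. A minimal NumPy sketch (hypothetical sizes and weights, not the actual transformers API):

```python
import numpy as np

# Hypothetical dimensions and randomly initialised weights, for illustration only.
batch_size, seq_len, hidden_size, vocab_size = 4, 16, 32, 100
hidden_states = np.random.randn(batch_size, seq_len, hidden_size)
lm_head = np.random.randn(hidden_size, vocab_size)  # stand-in LM head weight matrix

# What forward() returns today: logits for every position.
full_logits = hidden_states @ lm_head               # (4, 16, 100)

# Projecting only the last position's hidden state.
last_logits = hidden_states[:, -1, :] @ lm_head     # (4, 100)

assert last_logits.shape == (batch_size, vocab_size)
# Both give the same next-token logits, so nothing is lost:
assert np.allclose(full_logits[:, -1, :], last_logits)
```

The per-position projections are independent, which is why slicing before the matmul changes only how much is computed, not the result.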