What is the difference between forward() and generate()?

Hi!

It seems like some models implement both functions, and semantically they behave similarly, though they might be implemented differently. What is the difference? In both cases, given an input sequence, the model produces a prediction (inference)?

Thank you,

wilornel

Hi,

  • forward() can be used both for training and inference. Forward refers to a single forward pass through the network. During training, we apply a forward pass to get the model’s predictions, and then do a backward pass to compute the gradients of the loss with respect to the parameters, which we then use to update those parameters. We then do another forward pass, followed by another backward pass, etc. This is typically done on batches of data.
  • generate() can only be used at inference time, and uses forward() behind the scenes, in a sequence of time steps (see this post for a simple showcase of that). The first forward pass predicts the first new token; next we append the predicted token to the input of the next time step, which again uses forward() to predict the next token, and so on. This is called autoregressive generation. There are decoding strategies to decide which token to take as the next prediction, such as beam search, top-k sampling, and so on (a detailed blog post can be found here).
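To make the second bullet concrete, here is a minimal sketch of autoregressive generation. The model and its forward() are a hypothetical toy stand-in (a rule that favors the token after the last one, not a real neural network); the point is only the loop structure, where generate() calls forward() once per new token:

```python
VOCAB_SIZE = 4

def forward(input_ids):
    # Toy stand-in for a model's forward pass: returns logits over the
    # vocabulary, strongly favoring (last_token + 1) % VOCAB_SIZE.
    last = input_ids[-1]
    return [3.0 if tok == (last + 1) % VOCAB_SIZE else 0.0
            for tok in range(VOCAB_SIZE)]

def generate(input_ids, max_new_tokens):
    # generate() is just forward() in a loop: predict one token,
    # append it, and feed the longer sequence back in (greedy decoding).
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = forward(ids)
        next_token = max(range(VOCAB_SIZE), key=lambda t: logits[t])
        ids.append(next_token)
    return ids

print(generate([0], max_new_tokens=3))  # → [0, 1, 2, 3]
```

Each iteration is one forward pass, which is why generation cost grows with the number of tokens produced.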
5 Likes

@nielsr Can you provide any insight as to why one would prefer to use one over the other.

For example, I am realizing that with generate() we are not able to obtain the model loss at inference time. If I can also generate the text sequence using forward(), I’d rather just use that. I feel there is something else at play, though.
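On the loss point: a loss can be computed from forward() alone, because forward() returns logits at every step. A sketch of per-token cross-entropy with teacher forcing (the forward() here is a hypothetical toy model, not a real one, and sequence_loss is an illustrative helper name):

```python
import math

VOCAB_SIZE = 4

def forward(input_ids):
    # Toy stand-in for a model's forward pass: logits over the vocab,
    # strongly favoring (last_token + 1) % VOCAB_SIZE.
    last = input_ids[-1]
    return [3.0 if tok == (last + 1) % VOCAB_SIZE else 0.0
            for tok in range(VOCAB_SIZE)]

def sequence_loss(input_ids, target_ids):
    # Average cross-entropy of each target token under forward()'s
    # logits, with teacher forcing: feed the true prefix at every step.
    total = 0.0
    prefix = list(input_ids)
    for target in target_ids:
        logits = forward(prefix)
        log_norm = math.log(sum(math.exp(l) for l in logits))
        total += log_norm - logits[target]  # -log p(target | prefix)
        prefix.append(target)
    return total / len(target_ids)

# Targets the toy model "expects" score a much lower loss than others.
print(sequence_loss([0], [1, 2, 3]))  # small (≈ 0.14)
print(sequence_loss([0], [2, 2, 2]))  # large (≈ 3.14)
```

generate() discards the logits after picking each token, which is why it does not hand back a loss; forward() gives you the raw scores to compute one yourself.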

1 Like

The generate method is more feature-complete, supporting various fancier decoding methods besides greedy decoding, such as beam search and top-k sampling.
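As a rough illustration of why those decoding strategies matter, here is a self-contained sketch of greedy decoding versus top-k sampling over a single vector of logits (the logits values are made up for the example):

```python
import math
import random

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def greedy(logits):
    # Greedy decoding: always take the single most likely token.
    return max(range(len(logits)), key=lambda t: logits[t])

def top_k_sample(logits, k, rng):
    # Top-k sampling: keep only the k highest-scoring tokens,
    # renormalize their probabilities, and sample from that subset.
    top = sorted(range(len(logits)), key=lambda t: logits[t], reverse=True)[:k]
    probs = softmax([logits[t] for t in top])
    return rng.choices(top, weights=probs, k=1)[0]

logits = [0.1, 2.0, 1.5, -1.0, 0.3]
rng = random.Random(0)
print(greedy(logits))  # → 1
# Samples are drawn only from the top-2 tokens {1, 2}:
print({top_k_sample(logits, k=2, rng=rng) for _ in range(50)})
```

Greedy decoding is deterministic, while top-k trades some likelihood for diversity; beam search instead keeps several candidate sequences alive and picks the best-scoring one overall.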