What is the difference between forward() and generate()?

Hi,

  • forward() can be used for both training and inference. It refers to a single forward pass through the network. During training, we run a forward pass to get the model’s predictions and the loss, then a backward pass to compute the gradients of the loss with respect to the parameters, and then use those gradients to update the parameters. We repeat this cycle of forward pass, backward pass, and update many times, typically on batches of data.
  • generate() can only be used at inference time. It calls forward() behind the scenes, once per time step (see this post for a simple showcase of that). The first forward pass predicts the first token. That token is appended to the input for the next time step, where forward() is called again to predict the next token, and so on. This is called autoregressive generation. Decoding strategies such as beam search and top-k sampling decide which token to pick as the prediction at each step (a detailed blog post can be found here).
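
To make the relationship between the two concrete, here is a minimal sketch of greedy autoregressive decoding in plain Python. It is not the actual Transformers implementation: the toy `forward()` below, `VOCAB_SIZE`, `EOS_ID`, and the "predict last token + 1" rule are all made up for illustration, and a real model would return logits from a neural network instead.

```python
# Toy stand-in for a language model (hypothetical, not the Transformers API).
VOCAB_SIZE = 5  # assumed tiny vocabulary: token ids 0..4
EOS_ID = 0      # assumed end-of-sequence token id


def forward(input_ids):
    """One forward pass: return one score (logit) per vocabulary entry
    for the NEXT token. Made-up rule: favour (last token + 1) mod VOCAB_SIZE."""
    target = (input_ids[-1] + 1) % VOCAB_SIZE
    return [1.0 if tok == target else 0.0 for tok in range(VOCAB_SIZE)]


def generate(input_ids, max_new_tokens=10):
    """Greedy autoregressive decoding: call forward() once per time step
    and append the highest-scoring token to the input of the next step."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = forward(ids)  # a single forward pass
        next_id = max(range(VOCAB_SIZE), key=lambda t: logits[t])  # greedy pick
        ids.append(next_id)    # the prediction becomes part of the next input
        if next_id == EOS_ID:  # stop once end-of-sequence is produced
            break
    return ids


print(generate([1]))  # [1, 2, 3, 4, 0]
```

Swapping the greedy `max(...)` line for sampling or beam search is where the different decoding strategies come in; the loop around forward() stays the same.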