[SOLVED] Trying to fine-tune Llama, getting NaN gradients after a single step

I’m attempting to fine-tune Meta-Llama-3.1-8B-Instruct on some data of my own, using nothing beyond plain torch for the training loop.

The model that I’m using:

  1. Uses torch.float16 as the dtype.
  2. Uses <|reserved_special_token_0|> as the padding token (this works well when generating).
  3. Is fed well-formed tokenised input and target data, where the input and target tensors have the same shape, the input begins with <BOS>, the target ends with <EOS>, and both are right-padded (a rough setup sketch follows this list).
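For concreteness, the setup looks roughly like this (a simplified sketch rather than my exact code; texts is a stand-in for my real data):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = 'meta-llama/Meta-Llama-3.1-8B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Llama ships without a padding token, so reuse a reserved special token.
tokenizer.pad_token = '<|reserved_special_token_0|>'
tokenizer.padding_side = 'right'

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype = torch.float16,   # the dtype in question
).to('cuda')
model.config.pad_token_id = tokenizer.pad_token_id

# texts: list[str] stand-in for my real data; the targets are built to the
# same shape (ending in <EOS>) and right-padded the same way.
encoded = tokenizer(texts, padding = True, return_tensors = 'pt')
data, data_attn_mask = encoded['input_ids'], encoded['attention_mask']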

Here is my Trainer class.

import logging

import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader


class Trainer:
    def __init__(
        self,
        model: Model,
    ):
        # model.model: LlamaForCausalLM
        [...]
        self.optimizer = Adam(model.parameters(), lr = 1e-5)

    def train(self, data, data_attn_mask, target) -> float:
        dataset = TensorDataset(data, data_attn_mask, target)
        dataloader = DataLoader(dataset, batch_size = 5, shuffle = True)

        loss = 0.
        for e in range(50):
            logging.info(f'Starting epoch {e}')
            loss = self.trainEpoch(dataloader)

            logging.info(f'Epoch {e} has loss {loss}')

        return loss

    def trainEpoch(self, loader) -> float:
        epoch_loss = 0.
        for e, (data, attn, target) in enumerate(loader):
            logging.info(f'Starting minibatch {e}.')
            data = data.to(self.model.device)
            attn = attn.to(self.model.device)
            target = target.to(self.model.device)

            self.optimizer.zero_grad()

            outputs = self.model.model(input_ids = data, attention_mask = attn, labels = target)
            loss = outputs.loss   # LlamaForCausalLM shifts the labels internally and returns the mean cross-entropy

            epoch_loss += loss.cpu().item() / data.shape[0]

            if loss.isnan():
                raise ValueError('Loss is nan. Oh no!')

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.model.parameters(), 1)
            self.optimizer.step()

            logging.info(f'Batch {e}: loss = {loss.cpu().item()}')

        return epoch_loss

After the first call to self.optimizer.step(), all the gradients in my model become NaN or infinite!
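I’m detecting this with a check along these lines, run right after loss.backward() (just a quick diagnostic sketch, not part of the loop above):

# Flag any parameter whose gradient contains NaN or +/- infinity.
for name, param in self.model.model.named_parameters():
    if param.grad is not None and not torch.isfinite(param.grad).all():
        logging.warning(f'Non-finite gradient in {name}')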

I feel that I’m missing something obvious. Can anybody figure out what’s going on? Otherwise, what’s the right way to fine-tune Llama?

(I’m aware that I am overfitting; this is just a technical test).

I found the problem!

Everything is solved if I switch the dtype from torch.float16 to torch.bfloat16. Meta distributes the Llama 3.1 weights in bfloat16, which has the same 8-bit exponent (and therefore the same dynamic range) as float32; float16 only has a 5-bit exponent, so casting the model to it lets activations and gradients overflow to infinity during training unless you add loss scaling.
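Concretely, the only change is the dtype at load time (same sketch and assumptions as above):

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype = torch.bfloat16,   # same exponent range as float32, so no overflow
).to('cuda')

If you genuinely need float16 (e.g. on older hardware without bfloat16 support), the usual route is mixed precision with torch.cuda.amp.autocast plus torch.cuda.amp.GradScaler, so that loss scaling keeps the gradients inside float16’s representable range, rather than casting the whole model to float16.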