GPU OOM when training

I’m running the language modeling script provided here. I’m training a RoBERTa-base model on an RTX 3090 with 24 GB of memory. Training runs fine until around 9k steps, then an OOM error is thrown. Memory usage starts at about 12 GB, runs for a few steps, and keeps growing until the OOM error. It looks as if previous batches aren’t being freed from memory, but I’m not sure yet.

I implemented my own dataset class and passed it to the Trainer. Although I load all the raw data into RAM, I only tokenize it in the __getitem__ method, so I don’t think that is the actual issue.

Does anyone have some thoughts on this?

My dataset class:

from pathlib import Path
from typing import Optional, Union

import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer


class LMDataset(Dataset):
    def __init__(
        self,
        base_path: str,
        tokenizer: AutoTokenizer,
        set: str = "train",
    ):
        self.tokenizer = tokenizer
        src_file = Path(base_path).joinpath("processed", "{}.csv".format(set))

        # Load all raw text into memory; tokenization happens lazily in __getitem__
        df = pd.read_csv(src_file, header=0, names=["text"])
        self.samples = df["text"].to_list()

    def __len__(self):
        return len(self.samples)

    def _tokenize(
        self,
        text: str,
        padding: Optional[Union[str, bool]] = False,
        max_seq_length: Optional[int] = None,
    ):
        # Truncate to the model's maximum length unless an explicit max is given
        return self.tokenizer(
            text,
            padding=padding,
            truncation=True,
            max_length=max_seq_length or self.tokenizer.model_max_length,
            return_special_tokens_mask=True,
        )

    def __getitem__(
        self,
        i,
        padding: Optional[Union[str, bool]] = False,
        max_seq_length: Optional[int] = None,
    ):
        input_ids = self._tokenize(self.samples[i], padding, max_seq_length)[
            "input_ids"
        ]
        return torch.tensor(input_ids, dtype=torch.long)
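For context, this is roughly how I instantiate it (the path and model name below are placeholders, not my exact setup):

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
train_dataset = LMDataset(base_path="data", tokenizer=tokenizer, set="train")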

My guess would be that you have a specific sample in your dataset that is very long. Your collate function (not shown) might then be padding up to that length. That means that, for instance, your first <9k steps are of size 128x64 (seq_len x batch_size), which does not lead to an OOM. But then, around 9k steps, you hit a very long sample, which would (for instance) produce a 384x64 input, causing the OOM.
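For illustration, a dynamic-padding collate function typically looks something like the sketch below (this is a guess at what yours does, not your actual code). It pads every sequence to the longest one in the batch, so a single outlier sample blows up the whole batch tensor:

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch, pad_token_id=1):  # 1 is RoBERTa's <pad> token id
    # Every sequence is padded to the length of the longest sample in the batch,
    # so one very long sample turns a 128-wide batch into a 384-wide (or longer) one.
    return pad_sequence(batch, batch_first=True, padding_value=pad_token_id)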

So check the data distribution of your dataset, and check the collate function. You may want to specify a max_length that is smaller than the model's max length after all.
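A quick way to inspect that length distribution, assuming a tokenizer and an instance of your LMDataset (the variable names below are just for illustration):

import numpy as np

lengths = [len(tokenizer(text)["input_ids"]) for text in train_dataset.samples]
print("max:", max(lengths), "| 95th percentile:", np.percentile(lengths, 95))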

Thank you, Bram! You are totally right: a problematic, extremely long data sample caused this issue. Also, my batch size didn’t fit with the max_seq_length; when I used padding="max_length" I got an OOM on the first batch, which was expected. I reduced my batch size (sad) and truncated the samples to tokenizer.model_max_length.
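In case it helps someone else, the change boiled down to something like this (the batch size and output path are just example values, not my exact configuration):

from transformers import AutoTokenizer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# Cap every sample at the model's maximum length when tokenizing
encoding = tokenizer(
    "a very long training sample ...",
    truncation=True,
    max_length=tokenizer.model_max_length,
    return_special_tokens_mask=True,
)

# ...and train with a smaller per-device batch size
training_args = TrainingArguments(
    output_dir="out",                 # placeholder
    per_device_train_batch_size=16,   # example value, reduced from before
)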
