Out of memory when training google/bigbird-roberta-base

I’ve been struggling to train Google’s Big Bird model with the transformers library because of out-of-memory (OOM) errors. I have two Tesla V100 GPUs with 32 GB of memory each. I’m trying to train the google/bigbird-roberta-base model on the Spider dataset using the Hugging Face Trainer API. I’m using a batch size of 1 and the smallest version of this model, and I still get OOM errors. According to the Big Bird paper (arXiv:2007.14062, "Big Bird: Transformers for Longer Sequences"), Big Bird can be trained on chips with 16 GB of memory, so I’m not sure why I’m running out of memory. Has anyone else had memory problems when training Big Bird?

Here’s the code that does the training:

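import datasets
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# load_metric is the API used at the time of posting; newer datasets releases
# moved metrics into the separate evaluate library.
rouge = datasets.load_metric("rouge")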

training_args = Seq2SeqTrainingArguments(
    predict_with_generate = True,
    evaluation_strategy = "steps",
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size = batch_size,
    output_dir = "./",
    logging_steps = 2,
    save_steps = 400,
    eval_steps = 4,
)

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

trainer = Seq2SeqTrainer(
    model = model,
    tokenizer = tokenizer,
    args = training_args,
    compute_metrics = compute_metrics,
    train_dataset = train_data,
    eval_dataset = val_data,
)
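
For anyone hitting the same problem, here is a minimal sketch of memory-related options that could be added to the training arguments above. The option names are existing Seq2SeqTrainingArguments parameters in recent transformers releases, but the values are illustrative assumptions, not settings tested on Spider, and `effective_batch_size` is a hypothetical helper added only for this example.

```python
# Hypothetical overrides; values are illustrative, not tuned for this task.
memory_saving_kwargs = {
    "fp16": True,                      # mixed-precision training
    "gradient_checkpointing": True,    # recompute activations in the backward pass
    "gradient_accumulation_steps": 8,  # larger effective batch at batch size 1
    "eval_accumulation_steps": 1,      # move eval predictions to CPU every step
    "generation_max_length": 128,      # cap generated length during evaluation
}


def effective_batch_size(per_device, accumulation_steps, num_gpus):
    """Number of examples that contribute to each optimizer step."""
    return per_device * accumulation_steps * num_gpus


# With batch size 1 on the two GPUs described in the post:
print(effective_batch_size(1, memory_saving_kwargs["gradient_accumulation_steps"], 2))

# Intended use (not executed here, since it needs transformers and a GPU):
# training_args = Seq2SeqTrainingArguments(
#     predict_with_generate=True,
#     output_dir="./",
#     per_device_train_batch_size=1,
#     per_device_eval_batch_size=1,
#     **memory_saving_kwargs,
# )
```

Note that with eval_steps = 4 and predict_with_generate = True, generation runs every four training steps, so the OOM may come from evaluation rather than from the training step itself. Whether these options are enough depends on the input sequence length, which the post does not state.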


Here’s the exact error I get:

RuntimeError: CUDA out of memory. Tried to allocate 36.00 MiB (GPU 0; 31.75 GiB total capacity; 25.14 GiB already allocated; 21.50 MiB free; 26.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
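
The numbers in this message can be checked against the fragmentation hint it gives. The sketch below parses the sizes from the error text and computes two gaps. The helper `parse_oom` and its regular expressions are assumptions written for this exact message format, not part of PyTorch.

```python
import re

# Error text copied from the traceback above (hint sentence omitted).
ERROR = (
    "RuntimeError: CUDA out of memory. Tried to allocate 36.00 MiB "
    "(GPU 0; 31.75 GiB total capacity; 25.14 GiB already allocated; "
    "21.50 MiB free; 26.23 GiB reserved in total by PyTorch)"
)

UNIT_TO_GIB = {"MiB": 1 / 1024, "GiB": 1.0}


def parse_oom(message):
    """Return the sizes quoted in a CUDA OOM message, converted to GiB."""
    patterns = {
        "requested": r"Tried to allocate ([\d.]+) (MiB|GiB)",
        "total": r"([\d.]+) (MiB|GiB) total capacity",
        "allocated": r"([\d.]+) (MiB|GiB) already allocated",
        "free": r"([\d.]+) (MiB|GiB) free",
        "reserved": r"([\d.]+) (MiB|GiB) reserved",
    }
    sizes = {}
    for name, pattern in patterns.items():
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"could not find '{name}' in message")
        sizes[name] = float(match.group(1)) * UNIT_TO_GIB[match.group(2)]
    return sizes


sizes = parse_oom(ERROR)
# Memory held by PyTorch's caching allocator but not backing live tensors.
cached_but_unused = sizes["reserved"] - sizes["allocated"]
# Memory on the GPU that this process's allocator does not account for.
outside_allocator = sizes["total"] - sizes["reserved"] - sizes["free"]

print(f"reserved but not allocated: {cached_but_unused:.2f} GiB")
print(f"outside this allocator:     {outside_allocator:.2f} GiB")
```

Reserved memory is only about 1 GiB above allocated memory, so fragmentation is probably not the main cause and max_split_size_mb seems unlikely to help much here. Roughly 5.5 GiB of the card is not accounted for by this process's allocator; that could be the CUDA context or another process on GPU 0, which nvidia-smi would show.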

Thanks so much for sharing any experience you have with this!