OOM issues when fine-tuning DialoGPT-small

Hi, I'm fairly new to deep learning and to transformers in particular.

I have been trying to fine-tune DialoGPT-small on Google Colab, and I get OOM exceptions for any batch size larger than 2.

I am using this training loop:

import gc
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import get_linear_schedule_with_warmup

gc.collect()
torch.cuda.empty_cache()

def _batch_train(batch, model, optimizer, scheduler):
    inputs = batch.to(args.device)
    labels = inputs  # for a causal LM the model shifts the labels internally
    out = model(inputs, labels=labels)
    loss = out.loss
    loss.backward()
    optimizer.step()
    scheduler.step()
    model.zero_grad()
    loss_value = loss.item()
    # drop references so the tensors can be freed before the next batch
    del inputs, labels, loss, out
    return loss_value

def train(model, dataset, args: Args):
    print("Building optimizer")
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
    print("Building Dataloader")
    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
    
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)
    print("Building scheduler")
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=len(dataloader) // args.gradient_accumulation_steps * args.epochs
    )
    
    print('Starting training:')
    loss = None
    model.train()
    
    checkpoint = checkpoint_manager.load_checkpoint(model, optimizer)
    # assumption: load_checkpoint returns None when there is nothing to resume,
    # and otherwise a dict holding the epoch it was saved at
    start_epoch = 0 if checkpoint is None else checkpoint["epoch"] + 1
    for epoch in range(start_epoch, args.epochs):
        print("EPOCH %d" % epoch)
        last_loss = None
        # use a separate name for the batch index so it does not shadow the epoch counter
        for step, batch in enumerate(tqdm(dataloader)):
            last_loss = _batch_train(batch, model, optimizer, scheduler)
        checkpoint_manager.save_checkpoint(model, optimizer, epoch=epoch)
        print("EPOCH %d loss:" % epoch, last_loss)

train(model, dataset, args)
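
One thing I noticed while writing this up: the scheduler's num_training_steps divides by args.gradient_accumulation_steps, but _batch_train steps the optimizer on every batch, so there is no actual accumulation. If accumulation is the right fix (it should let me keep the per-step batch small while still simulating a larger one), I think the step function would look roughly like this - just a sketch, the name _batch_train_accumulate and the step argument are mine:

def _batch_train_accumulate(step, batch, model, optimizer, scheduler):
    inputs = batch.to(args.device)
    out = model(inputs, labels=inputs)
    # scale the loss so the accumulated gradients match one large batch
    loss = out.loss / args.gradient_accumulation_steps
    loss.backward()
    # only update the weights every gradient_accumulation_steps batches
    if (step + 1) % args.gradient_accumulation_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    return loss.item() * args.gradient_accumulation_steps

The inner loop would then pass the batch index, i.e. last_loss = _batch_train_accumulate(step, batch, model, optimizer, scheduler).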

I used the following code to see which part takes a lot of memory:

from torch.nn.utils.rnn import pad_sequence
import gc
gc.collect()
torch.cuda.empty_cache()

def print_usage():
  print("%.2f [GB]" % (torch.cuda.memory_allocated()/1e9))
optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
def collate(examples: List[torch.Tensor]):
    if tokenizer._pad_token is None:
        return pad_sequence(examples, batch_first=True)
    return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)
scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=len(dataloader) // args.gradient_accumulation_steps * args.epochs
    )
model.to(args.device);
print("Initial")
print_usage()
batch = next(iter(dataloader))  # grab the first batch

inputs = batch.to(args.device)
labels = inputs
print("Data Load:")
print_usage()

out = model(inputs, labels=labels)
print("Forward:")
print_usage()

loss = out.loss
loss.backward()
print("Backward:")
print_usage()

optimizer.step()
scheduler.step()
model.zero_grad()
optimizer.zero_grad()
last_loss = loss.item()
del inputs
del labels
del loss
del out
print("Done:")
print_usage()

and got this output:

Initial
0.51 [GB]
Data Load:
0.51 [GB]
Forward:
1.07 [GB]
Backward:
1.12 [GB]
Done:
2.12 [GB]

I am not surprised it ends up at about 4x the initial memory - I assume that comes from the allocations in the forward pass, the backward pass, and the optimizer state (I think?).
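
As a sanity check, here is the back-of-the-envelope arithmetic I had in mind, assuming DialoGPT-small has roughly the 124M parameters of the GPT-2 small architecture, stored in fp32:

n_params = 124e6      # rough parameter count of GPT-2 small / DialoGPT-small
bytes_per_param = 4   # fp32

weights    = n_params * bytes_per_param       # ~0.50 [GB], matches the "Initial" print
gradients  = n_params * bytes_per_param       # allocated by loss.backward()
adam_state = 2 * n_params * bytes_per_param   # exp_avg + exp_avg_sq, created on the first step
print("Steady state: %.2f [GB]" % ((weights + gradients + adam_state) / 1e9))
# -> about 1.98 [GB], which lines up with the 2.12 [GB] printed after optimizer.step()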

So how does it manage to fill the whole GPU (12-15 [GB]) at such small batch sizes? The maximum sequence length is only about 800 tokens.
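
I realize that memory_allocated() only reports what is still held at the moment of each print, so my instrumentation probably misses the transient peak inside the forward/backward pass. Would something like this be the right way to capture the peak instead (just a sketch, reusing batch, model and optimizer from the cell above)?

torch.cuda.reset_peak_memory_stats()

inputs = batch.to(args.device)
out = model(inputs, labels=inputs)
out.loss.backward()
optimizer.step()
model.zero_grad()

# unlike memory_allocated(), max_memory_allocated() also counts tensors that were
# allocated and freed again in between, e.g. the intermediate activations
print("Peak: %.2f [GB]" % (torch.cuda.max_memory_allocated() / 1e9))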

My aim with this project is learning, so suggestions for other models are more than appreciated.