Hi, I'm fairly new to deep learning and to transformers in particular.
I have been trying to fine-tune DialoGPT-small
on Google Colab, and I keep getting OOM exceptions for any batch size larger than 2.
I am using this training loop:
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup
import gc

gc.collect()
torch.cuda.empty_cache()

def _batch_train(batch, model, optimizer, scheduler):
    inputs = batch.to(args.device)
    labels = inputs  # causal LM: the model shifts the labels internally
    out = model(inputs, labels=labels)
    loss = out.loss
    loss.backward()
    optimizer.step()
    scheduler.step()
    model.zero_grad()
    loss_value = loss.item()
    del inputs, labels, loss, out
    return loss_value

def train(model, dataset, args: Args):
    print("Building optimizer")
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)

    print("Building Dataloader")
    def collate(examples: List[torch.Tensor]):
        # pad every example in the batch to the length of the longest one
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)

    print("Building scheduler")
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=len(dataloader) // args.gradient_accumulation_steps * args.epochs,
    )

    print("Starting training:")
    model.train()
    checkpoint = checkpoint_manager.load_checkpoint(model, optimizer)
    if checkpoint is None:
        start_epoch = 0
    for epoch in range(start_epoch, args.epochs):
        print("EPOCH %d" % epoch)
        last_loss = None
        for step, batch in enumerate(tqdm(dataloader)):
            last_loss = _batch_train(batch, model, optimizer, scheduler)
        checkpoint_manager.save_checkpoint(model, optimizer, epoch=epoch)
        print("EPOCH loss:", last_loss)

train(model, dataset, args)
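One thing I noticed while writing this up: the scheduler's num_training_steps is divided by args.gradient_accumulation_steps, but the loop steps the optimizer on every batch, so no accumulation actually happens. If I wanted to accumulate gradients to simulate a larger batch without the memory cost, I believe the step function would look roughly like this (just a sketch, with accumulation_steps standing in for args.gradient_accumulation_steps):

def _batch_train_accum(step, batch, model, optimizer, scheduler, accumulation_steps):
    inputs = batch.to(args.device)
    out = model(inputs, labels=inputs)
    # scale the loss so the accumulated gradient matches one large batch
    loss = out.loss / accumulation_steps
    loss.backward()
    # only update the weights every `accumulation_steps` batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    return loss.item() * accumulation_steps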
I used the following code to see which part takes a lot of memory:
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
import gc

gc.collect()
torch.cuda.empty_cache()

def print_usage():
    print("%.2f [GB]" % (torch.cuda.memory_allocated() / 1e9))

optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)

def collate(examples: List[torch.Tensor]):
    if tokenizer._pad_token is None:
        return pad_sequence(examples, batch_first=True)
    return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=args.warmup_steps,
    num_training_steps=len(dataloader) // args.gradient_accumulation_steps * args.epochs,
)
model.to(args.device)

print("Initial")
print_usage()

batch = next(iter(dataloader))  # grab a single batch
inputs = batch.to(args.device)
labels = inputs
print("Data Load:")
print_usage()

out = model(inputs, labels=labels)
print("Forward:")
print_usage()

loss = out.loss
loss.backward()
print("Backward:")
print_usage()

optimizer.step()
scheduler.step()
model.zero_grad()
optimizer.zero_grad()
last_loss = loss.item()
del inputs, labels, loss, out
print("Done:")
print_usage()
and got this output:
Initial
0.51 [GB]
Data Load:
0.51 [GB]
Forward:
1.07 [GB]
Backward:
1.12 [GB]
Done:
2.12 [GB]
I am not surprised it ends up at roughly 4x the initial usage - I assume that's the result of allocations in the forward pass, the backward pass, and the optimizer state (I think?).
What I don't understand is how it fills the whole GPU (12-15 [GB]) for such small batch sizes.
The maximum sequence length is about 800 tokens.
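If it matters, I suppose a better measurement would be the peak memory rather than what is currently allocated, since (as far as I understand) memory_allocated() doesn't capture the activation peak reached inside the forward/backward pass. Something like this, I think:

torch.cuda.reset_peak_memory_stats()
out = model(inputs, labels=inputs)
out.loss.backward()
optimizer.step()
# max_memory_allocated() reports the high-water mark since the reset above
print("Peak: %.2f [GB]" % (torch.cuda.max_memory_allocated() / 1e9))
# print(torch.cuda.memory_summary())  # more detailed per-allocator breakdown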
My aim with this project is learning, so suggestions for other models are more than appreciated.