You can measure GPU memory overhead/usage for an inference run like this:

```python
import torch

# Reset the peak-memory counter before the run
torch.cuda.reset_peak_memory_stats()
start_mem = torch.cuda.memory_allocated()

# Run inference
output = model.generate(input_ids, max_new_tokens=256)

# After inference
end_mem = torch.cuda.memory_allocated()
peak_mem = torch.cuda.max_memory_allocated()

print(f"allocated delta: {(end_mem - start_mem) / 2**20:.1f} MiB")
print(f"peak allocated:  {peak_mem / 2**20:.1f} MiB")
```

Note that `memory_allocated()` only counts memory occupied by live tensors; PyTorch's caching allocator typically reserves more from the driver, which you can inspect with `torch.cuda.memory_reserved()` and `torch.cuda.max_memory_reserved()`.
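For readable reports, the raw byte counts returned by these calls can be converted to MiB/GiB. A minimal sketch (the `fmt_bytes` helper is hypothetical, not part of torch):

```python
def fmt_bytes(n: int) -> str:
    """Format a raw byte count as MiB or GiB (hypothetical helper)."""
    if n >= 2**30:
        return f"{n / 2**30:.2f} GiB"
    return f"{n / 2**20:.2f} MiB"

# Usage with the measured values, e.g. print(fmt_bytes(peak_mem))
print(fmt_bytes(1_572_864))  # 1.5 * 2**20 bytes -> "1.50 MiB"
print(fmt_bytes(3 * 2**30))  # -> "3.00 GiB"
```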
