```python
import torch

# Reset the peak-memory counter so the peak reflects only this run
torch.cuda.reset_peak_memory_stats()
start_mem = torch.cuda.memory_allocated()

# Run inference
output = model.generate(input_ids, max_new_tokens=256)

# After inference, read the current and peak allocations (in bytes)
end_mem = torch.cuda.memory_allocated()
peak_mem = torch.cuda.max_memory_allocated()
```
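Note that `max_memory_allocated()` reports the peak since the last `reset_peak_memory_stats()` call, so resetting first is what makes the peak specific to this `generate` call. The values are raw byte counts; a minimal sketch for printing them in MiB:

```python
# Convert byte counts to MiB for readability
to_mib = lambda b: b / 1024**2
print(f"start: {to_mib(start_mem):.1f} MiB")
print(f"end:   {to_mib(end_mem):.1f} MiB")
print(f"peak:  {to_mib(peak_mem):.1f} MiB")
```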