I'm using the Megatron-11B model to generate sequences from prompts. It's running on a DGX with 8x A100 40GB cards.
The model is loaded onto the CPU first and then wrapped with accelerator.prepare(), which results in two ~12GB shards of the model on two GPUs.
The dataset is loaded from a local txt file. It's very small, ~200 samples as a tester, and each prompt is about 10-20 words long. A DataLoader wrapped around the dataset, called eval_loader, is also passed into accelerator.prepare(). Nothing else is prepared (no optimizer, and no other datasets or loaders).
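For context, each batch coming out of eval_loader should just be the raw prompt string under a 'text' key. A minimal sketch of the structure I expect (assuming the default collate_fn and the default batch_size of 1):

for each in eval_loader:
    # `each` is a dict like {'text': ['one prompt of 10-20 words ...']}:
    # the 'text' column comes from the datasets 'text' builder, and the
    # list has length 1 because batch_size defaults to 1
    print(each['text'])
    break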
I then loop over eval_loader and call model.generate() on each size-1 batch. One of the two GPUs throws an OOM and the program terminates: GPU0 and GPU1 both start at around 12GB occupancy, then GPU1's memory surges until the card is full and it OOMs.
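In case it helps with diagnosis, here is a minimal way to log per-GPU usage around the generate() call (log_gpu_memory is just a hypothetical debugging helper, not part of my script; torch.cuda.mem_get_info returns free/total bytes for a device):

import torch

def log_gpu_memory(tag):
    # print used/total memory in GB for every visible GPU
    for i in range(torch.cuda.device_count()):
        free, total = torch.cuda.mem_get_info(i)
        print(f'{tag} | cuda:{i}: {(total - free) / 1024**3:.1f} / {total / 1024**3:.1f} GB used')

Dropping a call like log_gpu_memory('before generate') just ahead of model.generate() is enough to watch the surge described above.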
The code is as follows:
from glob import glob
import datasets
import torch
import transformers
from accelerate import Accelerator
from datasets import load_dataset
from megatron_11b import MegatronForCausalLM, MegatronTokenizer
# from megatron_11b.megatron_policy import MegatronPolicy
# from parallelformers import parallelize
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
accelerator = Accelerator()
if accelerator.is_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()
data_files = {str(i): each for i, each in enumerate(glob('./prompt/inputs/*.txt'))}
accelerator.print(data_files)
dataset = load_dataset('text', data_files=data_files, split='2')  # each txt file becomes its own split; use the one keyed '2'
accelerator.print(dataset)
eval_loader = DataLoader(dataset)  # batch_size defaults to 1
tokenizer = MegatronTokenizer.from_pretrained("./megatron-11B")
model = MegatronForCausalLM.from_pretrained("./megatron-11B")  # loaded onto CPU here
model, eval_loader = accelerator.prepare(model, eval_loader)  # this is where the model ends up on the GPUs
progress_bar = tqdm(range(len(eval_loader)), disable=not accelerator.is_main_process)
model.eval()
for each in eval_loader:
    # tokenize the raw prompt text (tensors are created on CPU)
    encoded = tokenizer(each['text'], padding=True, return_tensors="pt")
    with torch.no_grad():
        result = model.generate(
            input_ids=encoded.input_ids[:, :-1],  # drop the final token of the encoded prompt
            max_length=64, num_beams=5, no_repeat_ngram_size=2, repetition_penalty=1.2,
        )
    progress_bar.update(1)
    accelerator.print(tokenizer.batch_decode(result))
Does anyone have an idea why this happens and how I can fix it?