I'm using the megatron-11b model to generate sequences from prompts. It's running on a DGX with 8×A100 40GB cards.
The model is loaded onto the CPU first and then wrapped with accelerator.prepare(), which results in two ~12GB shards of the model on two GPUs.
The dataset is loaded from a local txt file. It's very small, about 200 samples as a tester, and the prompts are roughly 10–20 words long. A DataLoader wrapped around the dataset, called eval_loader, is also passed into accelerator.prepare(). No other components (e.g. an optimizer, or any other datasets or loaders) are prepared.
I loop over eval_loader and call model.generate() with a batch size of 1. One of the two GPUs throws an OOM and the program terminates: GPU0 and GPU1 both start at around 12GB occupancy, then GPU1's memory surges until the card is fully loaded and it OOMs.
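For context, per-device memory can be logged from inside the loop with a small helper like the one below (a sketch only, not part of the script; note that torch.cuda.memory_allocated reports just the calling process's allocations, so nvidia-smi may show more):

import torch

def log_gpu_memory(tag):
    # Print PyTorch-allocated and reserved memory for every visible CUDA device.
    # Usage sketch: call log_gpu_memory('before generate') / log_gpu_memory('after generate')
    # around the model.generate() call to watch for the surge.
    for idx in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(idx) / 1024 ** 3
        reserved = torch.cuda.memory_reserved(idx) / 1024 ** 3
        print(f"[{tag}] cuda:{idx} allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")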
The code is as follows:
from glob import glob
import datasets
import torch
import transformers
from accelerate import Accelerator
from datasets import load_dataset
from megatron_11b import MegatronForCausalLM, MegatronTokenizer
# from megatron_11b.megatron_policy import MegatronPolicy
# from parallelformers import parallelize
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
accelerator = Accelerator()
if accelerator.is_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()
data_files = {i: each for i, each in enumerate(glob('./prompt/inputs/*.txt'))}
accelerator.print(data_files)
dataset = load_dataset('text', data_files=data_files, split=2)
accelerator.print(dataset)
eval_loader = DataLoader(dataset)
tokenizer = MegatronTokenizer.from_pretrained("./megatron-11B")
model = MegatronForCausalLM.from_pretrained("./megatron-11B")
model, eval_loader = accelerator.prepare(model, eval_loader)
progress_bar = tqdm(range(len(eval_loader)), disable=not accelerator.is_main_process)
model.eval()
for each in eval_loader:
    encoded = tokenizer(each['text'], padding=True, return_tensors="pt")
    with torch.no_grad():
        result = model.generate(
            input_ids=encoded.input_ids[:, :-1],
            max_length=64, num_beams=5, no_repeat_ngram_size=2, repetition_penalty=1.2,
        )
    progress_bar.update(1)
    accelerator.print(tokenizer.batch_decode(result))
Does anyone have an idea why this happens, and how to fix it?