OOM while generating last hidden states for entire dataset

I am trying to store the hidden states for the entire test set of XSum using PegasusForConditionalGeneration. I am able to do this for a couple of examples, but not when iterating over the full set with a DataLoader.

After the initial setup:

def tokenization(example):
    return tokenizer(example["document"], truncation=True, padding="longest", return_tensors="pt")

dataset_test = dataset["test"].map(tokenization, batched=True)

dataset_test = dataset_test.remove_columns(['document','summary','id'])

...

dataloader_test = DataLoader(dataset_test, batch_size=1)
outputs = []
with torch.no_grad():
    for batch in dataloader_test:
        output = model.generate(**batch, output_hidden_states=True, output_scores=True, return_dict_in_generate=True)
        outputs.append(output)

This raises an out-of-memory error on the GPU on the very first model.generate call, even after adding cleanup such as gc.collect() and torch.cuda.empty_cache().
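
For reference, the cleanup was attempted roughly like this (just a sketch of the pattern; the exact placement in my loop may have differed slightly):

import gc
import torch

with torch.no_grad():
    for batch in dataloader_test:
        output = model.generate(**batch, output_hidden_states=True,
                                output_scores=True, return_dict_in_generate=True)
        outputs.append(output)
        # Try to free whatever can be freed between batches -- this made no
        # difference, since the OOM already happens on the first generate call.
        gc.collect()
        torch.cuda.empty_cache()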


What is the maximum length of the texts? Can you try truncating them to a shorter sequence length?
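
For example, something along these lines (a minimal sketch; max_length=512 is only an illustrative cap to try, and set_format is one way to hand PyTorch tensors to the DataLoader):

def tokenization(example):
    # Truncate every document to a fixed, shorter length.
    # 512 is an arbitrary value -- adjust it to what your GPU can handle.
    return tokenizer(example["document"], truncation=True,
                     padding="max_length", max_length=512)

dataset_test = dataset["test"].map(tokenization, batched=True)
dataset_test = dataset_test.remove_columns(["document", "summary", "id"])
dataset_test.set_format("torch")  # return PyTorch tensors from the dataset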