I am trying to store hidden states for the entire XSum test set using PegasusForConditionalGeneration. I am able to do this for a couple of examples, but not when going through a DataLoader.
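For context, my initial setup looks roughly like this (a sketch; the google/pegasus-xsum checkpoint and the cuda device are assumptions on my part):

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

model_name = "google/pegasus-xsum"  # assumed checkpoint
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to("cuda")
model.eval()

dataset = load_dataset("xsum")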
After that initial setup, I tokenize the test split and build a DataLoader:
def tokenization(example):
    return tokenizer(example["document"], truncation=True, padding="longest", return_tensors="pt")

# tokenize the whole test split and drop the raw text columns
dataset_test = dataset["test"].map(tokenization, batched=True)
dataset_test = dataset_test.remove_columns(['document', 'summary', 'id'])
...
dataloader_test = DataLoader(dataset_test, batch_size=1)
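As I understand it, Dataset.map stores the tokenized columns as plain Python lists, so the DataLoader only yields tensors once the format is set; the elided lines presumably include a step along these lines (column names assumed to be the tokenizer defaults):

# assumed step: make the DataLoader yield torch tensors rather than lists
dataset_test.set_format(type="torch", columns=["input_ids", "attention_mask"])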
outputs = []
with torch.no_grad():
    for batch in dataloader_test:
        # batch is assumed to already be on the same device as the model
        output = model.generate(**batch, output_hidden_states=True, output_scores=True, return_dict_in_generate=True)
        outputs.append(output)
This gives an OOM on the GPU on the very first model.generate call itself, even after using things like gc.collect() and torch.cuda.empty_cache().
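Since all I ultimately need per example is the hidden states, my next idea was to detach them to the CPU immediately instead of keeping the full generate output alive on the GPU; a rough, untested sketch (attribute names as documented for transformers' encoder-decoder generate outputs), though since the OOM already happens inside the first generate call, I don't expect this alone to be the fix:

outputs = []
with torch.no_grad():
    for batch in dataloader_test:
        output = model.generate(**batch, output_hidden_states=True, output_scores=True, return_dict_in_generate=True)
        # encoder_hidden_states: one tensor per encoder layer (plus embeddings)
        enc = tuple(h.cpu() for h in output.encoder_hidden_states)
        # decoder_hidden_states: one tuple per generated token, one tensor per layer
        dec = tuple(tuple(h.cpu() for h in step) for step in output.decoder_hidden_states)
        outputs.append({"encoder": enc, "decoder": dec})
        del output
        torch.cuda.empty_cache()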