`text-generation` `Pipeline` prohibitively slow to load, even with cached model

I used the timeit module to test the difference between including and excluding the `device=0` argument when instantiating a pipeline for gpt2, and found an enormous performance benefit from adding `device=0`: with it, the best time over 50 repetitions was 184 seconds, while without it, the development node I was working on killed my process after only 3 repetitions.
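
For concreteness, the timing harness looked roughly like this (a reconstructed sketch, not my exact script; the repeat count matches what I described above):

import timeit

setup = "from transformers import pipeline"
stmt = "pipeline('text-generation', model='gpt2', device=0)"

# number=1 so each repetition is one fresh pipeline load;
# min() over the repetitions gives the best-case load time
times = timeit.repeat(stmt, setup=setup, number=1, repeat=50)
print(f"best of {len(times)}: {min(times):.1f}s")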

However, it looks like this performance benefit does not uniformly apply across different models. I am using the following code to load an LLM as a `text-generation` `Pipeline` instance:

from transformers import pipeline
pipe = pipeline('text-generation', model='MetaIX/GPT4-X-Alpasta-30b', device=0)

Even with `device=0`, my process gets killed before completing a single load (processes on this node are killed once they reach 2 hours of CPU time).

What else can I do to speed up the loading of LLMs into pipeline objects?

EDIT: The MetaIX/GPT4-X-Alpasta-30b model is already downloaded locally, and this still happens.
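
To rule out the load stalling on Hub network calls despite the local copy, transformers can be forced into offline mode, which makes it read only from the local cache and raise an error rather than download anything (this is the standard TRANSFORMERS_OFFLINE switch, nothing specific to this model):

import os
os.environ['TRANSFORMERS_OFFLINE'] = '1'  # must be set before importing transformers

from transformers import pipeline
pipe = pipeline('text-generation', model='MetaIX/GPT4-X-Alpasta-30b', device=0)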

Here’s what I figured out in case it’s helpful to anyone else.

It seems like the model I chose was just substantially larger and needed far more memory than gpt2: GPT4-X-Alpasta-30b is a 30-billion-parameter model, whereas the base gpt2 checkpoint has only about 124 million parameters. In order to get both model loading and inference to work without OOM errors, I used the following code to generate text for a given prompt:

# Imports (load_in_8bit requires the bitsandbytes package)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
checkpoint = 'MetaIX/GPT4-X-Alpasta-30b'
model = AutoModelForCausalLM.from_pretrained(checkpoint,
        torch_dtype=torch.float16, device_map='auto', load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Generate predictions
prompt = 'What is the airspeed velocity of an unladen swallow?'  # example prompt
inputs = tokenizer(prompt, return_tensors='pt')
inputs = inputs.to(0)  # move input tensors to GPU 0
output = model.generate(inputs['input_ids'], max_new_tokens=500)
response = tokenizer.decode(output[0].tolist())

The key for me was the `load_in_8bit` parameter. While I could load the model just fine without it, I would get CUDA OOM errors as soon as I called `model.generate()` for inference. Inference is still pretty slow, but at least the process doesn't get killed!
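
The back-of-the-envelope arithmetic bears this out: weight memory is roughly parameter count times bytes per parameter, so quantizing to 8-bit halves the footprint relative to fp16 (this sketch ignores activations, the KV cache, and quantization overhead):

# Rough weight-memory estimate for a 30B-parameter model
params = 30e9
for name, bytes_per_param in [('fp32', 4), ('fp16', 2), ('int8', 1)]:
    print(f'{name}: ~{params * bytes_per_param / 1e9:.0f} GB')
# fp32: ~120 GB, fp16: ~60 GB, int8: ~30 GB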