How to avoid 'Loading checkpoint shards'?

Hello, I have downloaded the model to my local computer in the hope that it would help me avoid the dreadfully slow loading process. Sadly, it didn't work as intended with the demo code. Is that possible, and if so, how can I adapt the code to do it?

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

torch.cuda.set_per_process_memory_fraction(1.0)

tokenizer = T5Tokenizer.from_pretrained("LOCAL_PATH")
model = T5ForConditionalGeneration.from_pretrained("LOCAL_PATH", device_map="auto")

input_text = "INPUT"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

Split your code into two parts. Load the model once in a Jupyter notebook cell, and run the generation in a separate cell. This way, you load the model only once, speeding up the process.

First cell (run once):

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
torch.cuda.set_per_process_memory_fraction(1.0)
tokenizer = T5Tokenizer.from_pretrained("LOCAL_PATH")
model = T5ForConditionalGeneration.from_pretrained("LOCAL_PATH", device_map="auto")

Second cell (run as needed):

input_text = "Your input text"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

Is there a way to do it without using notebooks?

You can load the model into a local server, for example with Flask, as sketched below.
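Here is a minimal Flask sketch of that idea (illustrative, not the original poster's code): the model is loaded once when the server process starts, and each request reuses it. The /generate endpoint name, the JSON request format, and the port are assumptions.

from flask import Flask, request, jsonify
from transformers import T5Tokenizer, T5ForConditionalGeneration

app = Flask(__name__)

# Load the tokenizer and model once, when the server process starts
tokenizer = T5Tokenizer.from_pretrained("LOCAL_PATH")
model = T5ForConditionalGeneration.from_pretrained("LOCAL_PATH", device_map="auto")

@app.route("/generate", methods=["POST"])
def generate():
    # Expects JSON like {"text": "..."}
    input_text = request.json["text"]
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
    outputs = model.generate(input_ids)
    return jsonify({"output": tokenizer.decode(outputs[0])})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)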

You can use FastAPI.

This is the code that works for me:

from contextlib import asynccontextmanager
from fastapi import FastAPI

service = Service()  # Service is my own wrapper that loads the model and runs searches


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Runs once at startup, so the model is loaded a single time
    service.load_model()
    yield


app = FastAPI(lifespan=lifespan)


@app.get("/")
async def root():
    return {"message": "ColPali Search API", "docs": "/docs", "health": "/health"}


@app.post("/search")
async def search(query: str):
    response = service.search(query=query)
    return {"response": response}
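If you save this as, say, main.py (the filename is just a placeholder), you can start the server with uvicorn main:app. The lifespan hook loads the model once at startup, and every request to /search then reuses the already loaded model instead of reloading the checkpoint shards.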