Using LoRA for inference

Context:
In our company, we have a requirement for a sequence classification model. We currently use fine-tuned DistilBERT models, one per client. Hence, assuming we have 50 clients, we have a FastAPI microservice which:

  • Loads all 50 models on initialization of the microservice.
  • Predicts on a span of text passed as a REST API parameter, i.e. it uses the DistilBERT model corresponding to the client specified in each API call (a minimal sketch of this setup follows below).
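
To make the current setup concrete, here is a minimal sketch of such a service. The client IDs, checkpoint paths, and endpoint shape are hypothetical, for illustration only:

```python
from fastapi import FastAPI
from transformers import pipeline

app = FastAPI()

# Hypothetical client IDs and local checkpoint paths, for illustration only.
CLIENT_IDS = [f"client_{i}" for i in range(50)]

# One fine-tuned DistilBERT pipeline per client, all resident in memory.
classifiers = {
    cid: pipeline("text-classification", model=f"models/distilbert-{cid}")
    for cid in CLIENT_IDS
}

@app.get("/predict")
def predict(client_id: str, text: str):
    # Route to the model corresponding to the client named in the API call.
    return classifiers[client_id](text)
```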

Proposed Change:
After experiments with LoRA on larger language models (e.g. DeBERTa), we find that performance is comparable to the baseline (fine-tuned DistilBERT) for each client. This is appealing because, instead of a separate fine-tuned model per client, we can use a single large base model with one LoRA adapter per client.

Our understanding is that a LoRA adapter is roughly 10x smaller than a fine-tuned DistilBERT model, leading to large space savings. LoRA also allows an adapter to be attached to the base model at inference time (i.e. per API call). This should allow us to load only one base model into CPU/memory on microservice initialization instead of 50 DistilBERTs, leading to savings in CPU and other resources. The time taken to attach a LoRA adapter to the base model was measured at under 0.02 seconds (manageable).
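
A minimal sketch of what we have in mind, assuming the adapters were trained and saved with PEFT, and using a hypothetical base checkpoint and per-client adapter paths:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel

BASE = "microsoft/deberta-v3-base"  # assumed base checkpoint
tokenizer = AutoTokenizer.from_pretrained(BASE)
base_model = AutoModelForSequenceClassification.from_pretrained(BASE, num_labels=2)

# Load the first adapter, then register the remaining ones under named slots;
# only the small adapter weights are added on top of the single base model.
model = PeftModel.from_pretrained(base_model, "adapters/client_0", adapter_name="client_0")
for i in range(1, 50):
    model.load_adapter(f"adapters/client_{i}", adapter_name=f"client_{i}")
model.eval()

def predict(client_id: str, text: str):
    model.set_adapter(client_id)  # the switch we measured at under 0.02 s
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        return model(**inputs).logits
```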

Question:

  • Is our understanding of the usage and benefits of LoRA adapters for inference correct? Is this approach scalable?
  • How do we handle multiple API calls (i.e. calls corresponding to different clients) hitting the microservice at the same time, assuming we want to load the minimum number of base model instances at startup and attach the adapter before inference on each call?
  • What suggestions and best practices should we follow for this specific use case? (We were unable to find much literature on the web.)

Thank you in advance for any help provided.


Hi Sumba,

LoRA adapters can be used that way with a DeBERTa model, so your understanding is correct. For multiple API requests, though, you have to load the different adapters for sequence classification and switch between them per request. If that works out, great; otherwise, an alternative would be to add weighted adapters to the model, which reduces the operation of attaching a different adapter for each API request as traffic grows. A sketch of both options follows.
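
Here is a hedged sketch of both options, assuming PEFT-trained adapters under hypothetical paths and only two clients for brevity. `set_adapter` mutates shared model state, so concurrent requests need to be serialized around the switch (a plain lock is one simple assumption here); `add_weighted_adapter` is PEFT's API for combining LoRA adapters, and the weights shown are purely illustrative:

```python
import threading

import torch
from fastapi import FastAPI
from peft import PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

BASE = "microsoft/deberta-v3-base"  # assumed base checkpoint
tokenizer = AutoTokenizer.from_pretrained(BASE)
model = PeftModel.from_pretrained(
    AutoModelForSequenceClassification.from_pretrained(BASE, num_labels=2),
    "adapters/client_0",  # hypothetical adapter paths
    adapter_name="client_0",
)
model.load_adapter("adapters/client_1", adapter_name="client_1")
model.eval()

app = FastAPI()
lock = threading.Lock()  # set_adapter mutates shared state: one switch at a time

@app.get("/predict")
def predict(client_id: str, text: str):
    inputs = tokenizer(text, return_tensors="pt")
    with lock:  # serialize the adapter switch and forward pass together
        model.set_adapter(client_id)
        with torch.no_grad():
            logits = model(**inputs).logits
    return {"logits": logits.tolist()}

# Weighted-adapter alternative: pre-combine adapters into a single named
# adapter so requests need no per-call switching (weights are illustrative).
model.add_weighted_adapter(
    adapters=["client_0", "client_1"],
    weights=[0.5, 0.5],
    adapter_name="clients_0_1_merged",
    combination_type="linear",
)
model.set_adapter("clients_0_1_merged")
```

One design note: merging adapters changes the per-client prediction behavior, so if strict per-client isolation matters, per-request switching under a lock (or one worker process per adapter) is likely the safer choice.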