Train LoRA Adapters on Multiple Datasets in Parallel for LLaMA-7B

What I love about LoRA is that the LoRA weights are “swappable”: for one frozen, pre-trained model, you can swap in different adapter weights at inference time depending on your task.
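A minimal NumPy sketch of why swapping is cheap, assuming the standard LoRA form y = Wx + (alpha / r) * B A x. The task names and random adapter values are hypothetical stand-ins for trained weights. The base weight W is shared, and only the small (A, B) pair changes per task:

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, D_OUT, RANK, ALPHA = 16, 8, 4, 8.0

# Frozen, pre-trained weight: loaded once and shared by every task.
W = rng.standard_normal((D_OUT, D_IN))

# Hypothetical per-task adapters: random stand-ins for trained (A, B) pairs.
adapters = {
    task: (0.1 * rng.standard_normal((RANK, D_IN)), 0.1 * rng.standard_normal((D_OUT, RANK)))
    for task in ["task_a", "task_b"]
}

def lora_forward(x, A, B):
    """y = W x + (ALPHA / RANK) * B (A x); only A and B depend on the task."""
    return W @ x + (ALPHA / RANK) * (B @ (A @ x))

x = rng.standard_normal(D_IN)
# "Swapping" an adapter is just choosing which (A, B) pair to pass in; W is untouched.
outputs = {task: lora_forward(x, *adapters[task]) for task in adapters}

# Each adapter stores RANK * (D_IN + D_OUT) numbers instead of D_IN * D_OUT.
adapter_size = RANK * (D_IN + D_OUT)
print(adapter_size, W.size)
```

At real model sizes the gap is far larger than in this toy, which is what makes keeping many adapters around for one base model practical.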

Given a limited number of GPUs, is there a recommended way to train multiple sets of LoRA adapters at once? For example, say I have 5 datasets and want to train 5 sets of LoRA weights, one per dataset. I want to optimize for speed.

I am using LLaMA-7B/13B, so we can assume the entire model fits in the VRAM of a single A100 40GB.
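A rough back-of-the-envelope check of that assumption, as a sketch: it assumes fp16/bf16 weights (2 bytes per parameter), nominal parameter counts, and, hypothetically, rank-8 adapters on two square hidden x hidden projections per layer (e.g. query and value). It ignores activations, the KV cache, and optimizer state, so real usage during training is higher:

```python
BYTES_PER_PARAM = 2  # assumes fp16/bf16 weights

models = {
    # name: (nominal parameter count, hidden size, number of layers)
    "llama-7b": (7e9, 4096, 32),
    "llama-13b": (13e9, 5120, 40),
}

def base_weights_gb(n_params):
    """Memory for the frozen weights alone, in GB (1e9 bytes)."""
    return n_params * BYTES_PER_PARAM / 1e9

def lora_params(hidden, layers, rank=8, modules_per_layer=2):
    """Assumes rank-`rank` adapters on `modules_per_layer` square hidden x hidden
    projections per layer; each adds rank * (hidden + hidden) trainable params."""
    return layers * modules_per_layer * rank * (hidden + hidden)

for name, (n, hidden, layers) in models.items():
    print(f"{name}: ~{base_weights_gb(n):.0f} GB of weights, "
          f"{lora_params(hidden, layers):,} LoRA params at rank 8")
# llama-7b: ~14 GB of weights, 4,194,304 LoRA params at rank 8
# llama-13b: ~26 GB of weights, 6,553,600 LoRA params at rank 8
```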

I can think of a few approaches, but I'm open to suggestions:

Naive Independent Training: Assign each GPU to a different dataset and train the LoRA weights for that dataset independently.

  • Pros: No gradient reconciliation required, maximum parallelism across datasets.
  • Cons: Doesn’t benefit from data parallelism within each dataset. GPUs whose datasets finish first sit idle.
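A toy NumPy simulation of the naive independent approach, under stated assumptions: each "GPU" is a thread, each dataset is a hypothetical synthetic regression task, and the LoRA scale alpha / r is taken as 1. In a real setup each worker would instead be a separate process pinned to one GPU (for example via `CUDA_VISIBLE_DEVICES`). The key property is that workers share the frozen W but never exchange gradients:

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

D_IN, D_OUT, RANK = 16, 8, 4
NUM_DATASETS = NUM_GPUS = 5
# Shared frozen base weight; the LoRA scale alpha / r is taken as 1 for simplicity.
W = np.random.default_rng(0).standard_normal((D_OUT, D_IN))

def make_dataset(seed, n=256):
    """Hypothetical toy task: targets come from W plus a task-specific low-rank shift."""
    rng = np.random.default_rng(seed)
    delta = (0.3 * rng.standard_normal((D_OUT, 2))) @ (0.3 * rng.standard_normal((2, D_IN)))
    X = rng.standard_normal((D_IN, n))
    return X, (W + delta) @ X

def loss(A, B, X, Y):
    E = (W + B @ A) @ X - Y
    return 0.5 * np.sum(E ** 2) / X.shape[1]

def train_adapter(dataset_id, steps=500, lr=0.05):
    """Train one (A, B) pair on one dataset by gradient descent; W is never updated."""
    X, Y = make_dataset(seed=dataset_id)
    rng = np.random.default_rng(100 + dataset_id)
    A = 0.1 * rng.standard_normal((RANK, D_IN))  # random init
    B = np.zeros((D_OUT, RANK))  # zero init, so training starts from the base model
    start = loss(A, B, X, Y)
    for _ in range(steps):
        G = ((W + B @ A) @ X - Y) @ X.T / X.shape[1]  # dLoss / d(B @ A)
        A, B = A - lr * (B.T @ G), B - lr * (G @ A.T)
    return dataset_id, start, loss(A, B, X, Y), (A, B)

# One worker per "GPU", each owning one dataset; workers never exchange gradients.
with ThreadPoolExecutor(max_workers=NUM_GPUS) as pool:
    results = list(pool.map(train_adapter, range(NUM_DATASETS)))

for dataset_id, start, end, _ in results:
    print(f"dataset {dataset_id}: loss {start:.4f} -> {end:.4f}")
```

Because each worker's runtime depends only on its own dataset, uneven dataset sizes translate directly into idle GPUs, which is the drawback listed above.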

Parallelize on Each Dataset: Each GPU gets a copy of the pre-trained model and a copy of the LoRA weights. One dataset is split across the GPUs, and the gradients are averaged across GPUs during training (torch.nn.parallel.DistributedDataParallel). Once a dataset is done training, its LoRA adapters are moved off VRAM and saved, and fresh ones are initialized on VRAM for the next dataset.

  • Pros: Each dataset trains faster because its data is split across GPUs. The pre-trained model can stay in VRAM between datasets.
  • Cons: Time spent reconciling gradients; datasets are still trained one at a time.
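A toy NumPy simulation of the per-dataset data-parallel approach, as a sketch only: the "GPUs" are equal-sized shards processed in a loop, and the "all-reduce" is a plain mean of the shard gradients, which is what DistributedDataParallel computes across ranks. The synthetic datasets and the alpha / r = 1 scale are assumptions. As far as I know, DDP only synchronizes gradients for parameters with `requires_grad=True`, so with a frozen base model only the small LoRA gradients are communicated, which keeps the reconciliation cost modest:

```python
import numpy as np

D_IN, D_OUT, RANK, NUM_GPUS = 16, 8, 4, 4
# Frozen base weight: it "stays in VRAM" across every dataset and is never updated.
W = np.random.default_rng(0).standard_normal((D_OUT, D_IN))

def make_dataset(seed, n=256):
    """Hypothetical toy task: targets come from W plus a task-specific low-rank shift."""
    rng = np.random.default_rng(seed)
    delta = (0.3 * rng.standard_normal((D_OUT, 2))) @ (0.3 * rng.standard_normal((2, D_IN)))
    X = rng.standard_normal((D_IN, n))
    return X, (W + delta) @ X

def loss(A, B, X, Y):
    E = (W + B @ A) @ X - Y
    return 0.5 * np.sum(E ** 2) / X.shape[1]

def lora_grads(A, B, X, Y):
    """Gradients of `loss` with respect to the LoRA factors only (alpha / r taken as 1)."""
    G = ((W + B @ A) @ X - Y) @ X.T / X.shape[1]
    return B.T @ G, G @ A.T  # (dA, dB)

def train_data_parallel(dataset_id, steps=500, lr=0.05):
    X, Y = make_dataset(dataset_id)
    # Split one dataset into equal shards, one per simulated GPU.
    shards = list(zip(np.array_split(X, NUM_GPUS, axis=1), np.array_split(Y, NUM_GPUS, axis=1)))
    rng = np.random.default_rng(100 + dataset_id)
    A, B = 0.1 * rng.standard_normal((RANK, D_IN)), np.zeros((D_OUT, RANK))  # fresh adapter
    start = loss(A, B, X, Y)
    for _ in range(steps):
        local = [lora_grads(A, B, Xs, Ys) for Xs, Ys in shards]  # per-GPU backward pass
        # "All-reduce": average the gradients so every replica applies the same update.
        dA = np.mean([g[0] for g in local], axis=0)
        dB = np.mean([g[1] for g in local], axis=0)
        A, B = A - lr * dA, B - lr * dB
    return start, loss(A, B, X, Y), (A, B)

W_before = W.copy()
saved_adapters, losses = {}, {}
for dataset_id in range(5):  # still one dataset at a time
    start, end, adapter = train_data_parallel(dataset_id)
    saved_adapters[dataset_id] = adapter  # stand-in for "move off VRAM and save"
    losses[dataset_id] = (start, end)
```

With equal-sized shards, the averaged shard gradients equal the full-batch gradient, so this matches single-GPU training on the whole dataset while splitting the compute.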

Are there more sophisticated ways to do this?

