I am training a fairly large multimodal model on a very large dataset; the model takes the embeddings of two other large pretrained models as inputs.
The cleanest solution is obviously to treat everything as one large module. But I am worried about GPU memory even with a sharded model, since I only have easy access to 4 H100s. So I am wondering whether there is a way to temporarily offload the model being trained, load the pretrained ones, and compute the embeddings efficiently, potentially doing this at the scale of multiple batches instead of every batch. A rough sketch of what I'm picturing is below.
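To make it concrete, here is roughly the kind of swap I have in mind. All the modules and the loader below are toy placeholders, and it assumes plain `.to()` device moves rather than an FSDP-sharded model, so take it as a sketch of the idea rather than anything real:

```python
import torch
import torch.nn as nn

# Toy stand-ins; in reality these would be the two frozen pretrained
# encoders and the multimodal model being trained.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_a = nn.Linear(128, 64).eval().requires_grad_(False)
encoder_b = nn.Linear(128, 64).eval().requires_grad_(False)
model = nn.Linear(128, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def embed_chunk(batches):
    """Swap the trainable model off the GPU, bring the frozen encoders on,
    and cache embeddings for several batches on the CPU."""
    model.to("cpu")  # note: optimizer state would likely also need offloading in practice
    encoder_a.to(device); encoder_b.to(device)
    cached = []
    with torch.no_grad():
        for xa, xb, y in batches:
            ea = encoder_a(xa.to(device)).cpu()
            eb = encoder_b(xb.to(device)).cpu()
            cached.append((ea, eb, y))
    encoder_a.to("cpu"); encoder_b.to("cpu")
    torch.cuda.empty_cache()  # release the allocator blocks the encoders used
    return cached

def train_chunk(cached):
    """Bring the trainable model back onto the GPU and step on the cached embeddings."""
    model.to(device)
    for ea, eb, y in cached:
        optimizer.zero_grad(set_to_none=True)
        out = model(torch.cat([ea, eb], dim=-1).to(device))
        loss = nn.functional.cross_entropy(out, y.to(device))
        loss.backward()
        optimizer.step()

# Fake loader yielding "chunks" of several batches at a time, so the
# model swap only happens once per chunk rather than once per batch.
def chunks(n_chunks=2, batches_per_chunk=4, bs=8):
    for _ in range(n_chunks):
        yield [(torch.randn(bs, 128), torch.randn(bs, 128),
                torch.randint(0, 10, (bs,))) for _ in range(batches_per_chunk)]

for chunk in chunks():
    train_chunk(embed_chunk(chunk))
```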
How much of a nightmare would this be, and if it is doable, are there any recommendations for how to approach it?
Thanks,
Evan