Fine-tuning Mistral/Mixtral for sequence classification on long context

I would like to fine-tune Mistral (or, if possible, Mixtral) for classification of long sequences, ideally up to the full 32k context. These models already need a lot of memory to train on their own, and if I understand correctly the attention memory grows quadratically with sequence length, which is why I run out of memory as I try to increase the context.
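To make the quadratic growth concrete, here is my rough back-of-envelope estimate for the full attention-score matrix of a Mistral-7B-sized model (32 heads, fp16). These numbers are my own guess and ignore memory-efficient kernels like FlashAttention or Mistral's sliding-window attention, which avoid materializing this matrix:

```python
# Rough estimate of one layer's full seq_len x seq_len attention-score
# matrix, assuming 32 heads and fp16 (2 bytes per element), with no
# memory-efficient attention. Illustrative numbers only.

def attn_matrix_bytes(seq_len, num_heads=32, bytes_per_el=2):
    """Bytes needed to materialize one layer's attention scores."""
    return num_heads * seq_len * seq_len * bytes_per_el

for ctx in (2_048, 8_192, 32_768):
    gib = attn_matrix_bytes(ctx) / 2**30
    print(f"{ctx:>6} tokens -> {gib:8.2f} GiB per layer")
```

At 32k tokens that is 64 GiB per layer for the scores alone, which matches the out-of-memory behaviour I am seeing.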

For this reason I tried running it on an A100, as well as using quantization and LoRA, which enabled me to run the code, but as I increase the context I still get out-of-memory errors.

I started looking at the ZeRO implementation in DeepSpeed and Accelerate, as well as model and pipeline parallelism, and how to run these across multiple A100s. But being quite new to this, I am not really sure how to implement it or whether it will actually resolve my problem.
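For concreteness, this is the kind of DeepSpeed ZeRO-3 config (passed to `accelerate launch` via a `deepspeed_config_file`) I have been looking at. The values are illustrative guesses on my part, not a tested setup:

```json
{
  "zero_optimization": {
    "stage": 3,
    "offload_param": { "device": "cpu" },
    "offload_optimizer": { "device": "cpu" },
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "bf16": { "enabled": true },
  "gradient_accumulation_steps": "auto",
  "train_micro_batch_size_per_gpu": 1
}
```

My uncertainty is whether sharding parameters and optimizer states like this actually helps with the activation memory that grows with context length, or only with the model-weight side.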

I would be grateful for any advice on whether I am going in the right direction, how I should approach this, or any good example implementations.

Thank you!