Llama 2 & 8K Training


I am using SFTTrainer and bitsandbytes to fine-tune Llama-2-7B on a dataset where the input will consistently be 5K-8K tokens. I am using an A10G and have no problems setting max_seq_length to 2K or 4K, but whenever I set it to 5K or more I run out of VRAM. Also worth mentioning: I am using RoPE scaling to hopefully accommodate the larger context.

Do I need to set max_seq_length to 8K to train effectively on this dataset? Also, what is the relationship between max_seq_length and num_of_sequences? Is there any way to accommodate a larger sequence length like 8K in the VRAM I have available? I thought I could do it by decreasing num_of_sequences, but that doesn't seem to have any effect on the amount of VRAM being reserved.
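To make the question about the two settings concrete, here is a rough sketch of how they might relate. It assumes the packing behaviour of TRL's ConstantLengthDataset, where num_of_sequences only sizes a text buffer (about max_seq_length * chars_per_token * num_of_sequences characters) that is filled before tokenizing and slicing into chunks. The helper names are hypothetical and are not part of any library, and the default of 3.6 characters per token is an assumption.

```python
# Hypothetical helpers: plain arithmetic only, no library calls.

def packing_buffer_chars(max_seq_length, num_of_sequences, chars_per_token=3.6):
    """Approximate number of characters buffered before packing.

    Assumption: mirrors the buffer-size formula used for packing in TRL's
    ConstantLengthDataset. This is a data-loading buffer, not GPU memory.
    """
    return max_seq_length * chars_per_token * num_of_sequences


def tokens_per_step(max_seq_length, per_device_batch_size):
    """Tokens in one forward/backward pass on a single device.

    Depends only on sequence length and batch size; num_of_sequences
    does not appear here at all.
    """
    return max_seq_length * per_device_batch_size


if __name__ == "__main__":
    for seq_len in (2048, 4096, 8192):
        for n_seq in (64, 1024):
            print(
                f"max_seq_length={seq_len} num_of_sequences={n_seq} "
                f"buffer_chars={packing_buffer_chars(seq_len, n_seq):.0f} "
                f"tokens_per_step={tokens_per_step(seq_len, 1)}"
            )
```

If that assumption holds, it would be consistent with what I am seeing: changing num_of_sequences changes only the buffer size, while the number of tokens per step (and so, presumably, the activation memory) is driven by max_seq_length and the batch size.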

Thank you!
