This is the part of the source code that loads the model. I want to avoid 4-bit quantization because I've read it can significantly degrade model quality.
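In other words, I load the unquantized weights directly, along these lines (simplified sketch; the checkpoint name is a placeholder for my actual model, and I use bfloat16 rather than a 4-bit BitsAndBytesConfig):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint name for illustration.
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load in bfloat16 instead of passing a 4-bit quantization config,
# so the weights stay unquantized.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
```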
For additional context, I use accelerate to enable distributed data parallel (DDP) training, and my fine-tuning dataset has approximately 1k entries. I preprocess it with datasets.map() and the model's tokenizer to produce examples of the form { "input_ids": List[int], "attention_mask": List[int], "labels": List[int] }, which I then feed to the transformers Trainer (see the sketch below).
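The preprocessing step looks roughly like this (a sketch of my setup; the "text" column, checkpoint name, data file path, and max_length are placeholders for my actual values):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Placeholder checkpoint name for illustration.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

def tokenize(example):
    # Tokenize the raw text; for causal-LM fine-tuning the labels
    # are a copy of input_ids (the model shifts them internally).
    tokens = tokenizer(
        example["text"],
        truncation=True,
        max_length=1024,
    )
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens

# Placeholder data file; my real dataset has ~1k entries.
dataset = load_dataset("json", data_files="train.jsonl", split="train")
dataset = dataset.map(tokenize, remove_columns=dataset.column_names)
```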