Fine-tuning with Different Model Heads

I have a question about the heads that models are initialized with. From my understanding, classes like AutoModelForCausalLM or AutoModelForSeq2SeqLM initialize LLMs with a certain type of head. How might one determine which head is appropriate for their fine-tuning approach, and how does the head influence the training process?
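For concreteness, here is roughly how I picture the two classes (gpt2 and t5-small are just stand-in checkpoints I picked for the example):

```python
# The Auto* class decides which head gets attached on top of the base model.
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

# Decoder-only model + language-modelling head: predicts the next token
# from the tokens seen so far (causal / autoregressive objective).
causal_lm = AutoModelForCausalLM.from_pretrained("gpt2")

# Encoder-decoder model + LM head on the decoder: the encoder reads the
# whole input, the decoder generates the target sequence token by token.
seq2seq_lm = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

print(type(causal_lm).__name__)   # GPT2LMHeadModel
print(type(seq2seq_lm).__name__)  # T5ForConditionalGeneration
```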

For example, this GitHub repository has demo code that initializes a model with AutoModelForCausalLM but then performs instruction tuning with a DataCollatorForSeq2Seq. I understand that the DataCollatorForCausalLM isn’t suitable for labelled data (since it sets the labels to be copies of the input_ids even if you manually create labels for your dataset), but I don’t understand why the model was initialized with the CausalLM head.
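To illustrate the labelling behaviour I mean, here is a rough sketch using transformers’ stock DataCollatorForLanguageModeling(mlm=False) to stand in for a causal-LM collator (the repo’s own collator may differ), with the gpt2 tokenizer as a placeholder:

```python
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# The causal-LM collator ignores any labels you built and clones input_ids
# into labels (padding positions become -100).
lm_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
lm_batch = lm_collator([{"input_ids": [10, 11, 12]}, {"input_ids": [10, 11]}])
print(lm_batch["labels"])  # a padded copy of input_ids

# DataCollatorForSeq2Seq keeps the labels you created yourself (e.g. with the
# prompt tokens masked out as -100 for instruction tuning) and just pads them.
s2s_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100)
s2s_batch = s2s_collator([
    {"input_ids": [10, 11, 12], "labels": [-100, -100, 12]},
    {"input_ids": [10, 11], "labels": [-100, 11]},
])
print(s2s_batch["labels"])  # your masked labels, padded with -100
```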

If anyone has any insights into the different model heads and how they tie in with the data collator and trainer in the fine-tuning process, I would appreciate the guidance.

The DataCollator just prepares the batches before they go to the model’s forward pass. You could write your own DataCollator. Mostly they just apply padding and convert data to tensors; sometimes the tokeniser is invoked here. So it just depends on what your data looks like and what it needs to look like to be valid input for the forward pass of the model (and to ensure a consistent batch).
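For example, a bare-bones collator can be just a few lines. This is only a sketch; the field names and the -100 label padding follow the usual transformers conventions:

```python
import torch

def simple_collate(features, pad_token_id=0):
    # Pad every field to the longest sequence in the batch and return tensors.
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * pad)
        batch["labels"].append(f["labels"] + [-100] * pad)  # -100 is ignored by the loss
    return {k: torch.tensor(v) for k, v in batch.items()}

# Usage: DataLoader(dataset, batch_size=8, collate_fn=simple_collate)
```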

So the reason they chose that data collator is simply that it fit the task they needed it for.

I think my problem is that I’ve been too concerned with the names of the different classes in the transformers library. Most of my confusion has come from a shallow understanding of each class in the fine-tuning pipeline and why they are named the way they are. Thank you for the clarification :hugs:

You may benefit from a deeper knowledge of the models and of the concepts behind machine learning with neural networks. Hugging Face gives us a good framework, but it abstracts us away from a deeper understanding.

My advice would be to begin with a simple encoder model and use it for classification, though without using the Hugging Face Trainer. Implement your own training loop, your own data collator and data loaders, evaluations, and optimisations, and you will find that you understand why the Hugging Face classes are the way they are.
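Something along these lines, as a rough sketch of what Trainer does under the hood (the checkpoint, toy data, and hyperparameters are just placeholders):

```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification

checkpoint = "distilbert-base-uncased"  # placeholder encoder checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# Tiny toy dataset: a list of dicts the collator can pad into batches.
texts, labels = ["great movie", "terrible movie"], [1, 0]
enc = tokenizer(texts, truncation=True)
train_dataset = [
    {"input_ids": enc["input_ids"][i],
     "attention_mask": enc["attention_mask"][i],
     "labels": labels[i]}
    for i in range(len(texts))
]

def collate(features):
    # Pad to the longest example in the batch, like the stock collators do.
    batch = tokenizer.pad(
        [{"input_ids": f["input_ids"], "attention_mask": f["attention_mask"]}
         for f in features],
        return_tensors="pt",
    )
    batch["labels"] = torch.tensor([f["labels"] for f in features])
    return batch

loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()
for epoch in range(3):
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss  # the classification head computes the loss from labels
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

Once that feels natural, the Trainer, TrainingArguments, and the built-in collators are just the same pieces packaged up for you.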

Then, hit the theory a bit more by looking at resources like “The Illustrated Transformer”, or dive into the concept of “embeddings”.

