How do I fine-tune a model for my use case?

I would like to fine-tune a model to do the following task: given an input text, return relevant labels that describe it. One example would be applying labels to Stack Overflow questions; the key points here are that there are many possible labels (on the order of tens of thousands) and that multiple labels can apply to a given input. I would like to use MPT-7B for this task.
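To make the task concrete, here is roughly the shape of a training example I have in mind, and one way I imagine serializing it for a generative model (the text and labels below are made up purely for illustration):

# A single, made-up training example for the tagging task.
example = {
    "text": "How do I merge two dictionaries in a single expression?",
    # There are tens of thousands of possible labels overall;
    # each input carries only a handful of them.
    "labels": ["python", "dictionary", "merge"],
}

# One way I imagine framing this for a causal LM: put the input in a
# prompt and serialize the label set into the target text.
prompt = f"Question: {example['text']}\nTags:"
target = " " + ", ".join(example["labels"])
print(prompt + target)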

From my research into this problem, it seemed that the best options would be a Sequence-to-Sequence model or a Sequence Classification model; however, since I want to use MPT-7B, my understanding is that I am limited to using a Causal LM model.
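For reference, this is how I expect to load MPT-7B once I get past the toy example below. As far as I can tell from the model card, the trust_remote_code flag and the GPT-NeoX-20B tokenizer are what is required, but please correct me if I have that wrong:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# MPT-7B ships custom modeling code, so (as I understand it) it has to be
# loaded with trust_remote_code=True.
mpt_model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
# The model card says MPT-7B was trained with the EleutherAI/gpt-neox-20b tokenizer.
mpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")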

As a first pass, I've been trying to adapt the code from the Hugging Face causal language modeling guide to do something closer to my task. My thinking is that if I can fine-tune the model in that example to take an ELI5 question as input and generate an ELI5 answer as output, I can then transition to building out my classification use case.

The first approach I tried resulted in this code:

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer

eli5 = load_dataset("eli5", split="train_asks[:5000]")
eli5 = eli5.train_test_split(test_size=0.2)

block_size = 1024

tokenizer = AutoTokenizer.from_pretrained("distilgpt2", 
                                          padding='max_length', 
                                          truncation=True, 
                                          max_length=block_size)

def preprocess_function(examples):
    # use the question as the text, and the first answer as the target
    return tokenizer(text=[" ".join([x[0], x[1]]) for x in zip(examples["title"], 
                                                               examples["selftext"])], 
                     text_target=[x["text"][0] for x in examples["answers"]])

tokenized_eli5 = eli5.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=eli5["train"].column_names,
)


def group_texts(examples):
    # Truncate both the tokenized inputs and the labels to block_size.
    examples["input_ids"] = [x[0:block_size] for x in examples["input_ids"]]
    examples["labels"] = [x[0:block_size] for x in examples["labels"]]
    return examples

lm_dataset = tokenized_eli5.map(group_texts, batched=True, num_proc=4)

from transformers import DataCollatorForSeq2Seq

tokenizer.pad_token = tokenizer.eos_token
# This was originally a DataCollatorForLanguageModeling but since I'm 
# essentially performing a Seq2Seq task, I thought the 
# DataCollatorForSeq2Seq might be a better option, but maybe this is
# what is causing my issues. 
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

model = AutoModelForCausalLM.from_pretrained("distilgpt2")

training_args = TrainingArguments(
    output_dir="my_awesome_eli5_clm-model",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["test"],
    data_collator=data_collator,
)
trainer.train()

This results in the following trace:

│ in <cell line: 1>:1                                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1645 in train                    │
│                                                                                                  │
│   1642 │   │   inner_training_loop = find_executable_batch_size(                                 │
│   1643 │   │   │   self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size  │
│   1644 │   │   )                                                                                 │
│ ❱ 1645 │   │   return inner_training_loop(                                                       │
│   1646 │   │   │   args=args,                                                                    │
│   1647 │   │   │   resume_from_checkpoint=resume_from_checkpoint,                                │
│   1648 │   │   │   trial=trial,                                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1938 in _inner_training_loop     │
│                                                                                                  │
│   1935 │   │   │   │   │   self.control = self.callback_handler.on_step_begin(args, self.state,  │
│   1936 │   │   │   │                                                                             │
│   1937 │   │   │   │   with self.accelerator.accumulate(model):                                  │
│ ❱ 1938 │   │   │   │   │   tr_loss_step = self.training_step(model, inputs)                      │
│   1939 │   │   │   │                                                                             │
│   1940 │   │   │   │   if (                                                                      │
│   1941 │   │   │   │   │   args.logging_nan_inf_filter                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2759 in training_step            │
│                                                                                                  │
│   2756 │   │   │   return loss_mb.reduce_mean().detach().to(self.args.device)                    │
│   2757 │   │                                                                                     │
│   2758 │   │   with self.compute_loss_context_manager():                                         │
│ ❱ 2759 │   │   │   loss = self.compute_loss(model, inputs)                                       │
│   2760 │   │                                                                                     │
│   2761 │   │   if self.args.n_gpu > 1:                                                           │
│   2762 │   │   │   loss = loss.mean()  # mean() to average on multi-gpu parallel training        │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2784 in compute_loss             │
│                                                                                                  │
│   2781 │   │   │   labels = inputs.pop("labels")                                                 │
│   2782 │   │   else:                                                                             │
│   2783 │   │   │   labels = None                                                                 │
│ ❱ 2784 │   │   outputs = model(**inputs)                                                         │
│   2785 │   │   # Save past state if it exists                                                    │
│   2786 │   │   # TODO: this needs to be fixed and made cleaner later.                            │
│   2787 │   │   if self.args.past_index >= 0:                                                     │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/gpt2/modeling_gpt2.py:1113 in        │
│ forward                                                                                          │
│                                                                                                  │
│   1110 │   │   │   shift_labels = labels[..., 1:].contiguous()                                   │
│   1111 │   │   │   # Flatten the tokens                                                          │
│   1112 │   │   │   loss_fct = CrossEntropyLoss()                                                 │
│ ❱ 1113 │   │   │   loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.v  │
│   1114 │   │                                                                                     │
│   1115 │   │   if not return_dict:                                                               │
│   1116 │   │   │   output = (lm_logits,) + transformer_outputs[1:]                               │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/loss.py:1174 in forward                 │
│                                                                                                  │
│   1171 │   │   self.label_smoothing = label_smoothing                                            │
│   1172 │                                                                                         │
│   1173 │   def forward(self, input: Tensor, target: Tensor) -> Tensor:                           │
│ ❱ 1174 │   │   return F.cross_entropy(input, target, weight=self.weight,                         │
│   1175 │   │   │   │   │   │   │      ignore_index=self.ignore_index, reduction=self.reduction,  │
│   1176 │   │   │   │   │   │   │      label_smoothing=self.label_smoothing)                      │
│   1177                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:3029 in cross_entropy             │
│                                                                                                  │
│   3026 │   │   )                                                                                 │
│   3027 │   if size_average is not None or reduce is not None:                                    │
│   3028 │   │   reduction = _Reduction.legacy_get_string(size_average, reduce)                    │
│ ❱ 3029 │   return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(re  │
│   3030                                                                                           │
│   3031                                                                                           │
│   3032 def binary_cross_entropy(                                                                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Expected input batch_size (1896) to match target batch_size (1808).

I suspect that this has something to do with my misuse of AutoModelForCausalLM or of DataCollatorForSeq2Seq. If I'm reading the error correctly, the flattened logits and labels end up with different sequence lengths, which suggests the collator is padding input_ids and labels to different lengths. Some potential issues that come to mind:

  • The Colab notebook says "What's cool about language modeling tasks is you don't need labels (also known as an unsupervised task) because the next word is the label," but I'm explicitly setting labels (via text_target) when preprocessing the data.
  • The DataCollatorForSeq2Seq is probably incompatible with AutoModelForCausalLM in some way (see the sketch after this list for the collator setup I think the tutorial actually intends).
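For comparison, here is the variant I believe the tutorial actually intends, where the question and answer are concatenated into a single string and DataCollatorForLanguageModeling derives the labels from the inputs, so I never set labels myself (this reuses tokenizer, eli5, and block_size from the code above):

from transformers import DataCollatorForLanguageModeling

def preprocess_concat(examples):
    # Concatenate question and answer into one sequence; no explicit targets.
    texts = [
        " ".join([title, selftext, answers["text"][0]])
        for title, selftext, answers in zip(
            examples["title"], examples["selftext"], examples["answers"]
        )
    ]
    return tokenizer(texts, truncation=True, max_length=block_size)

tokenized = eli5.map(
    preprocess_concat,
    batched=True,
    remove_columns=eli5["train"].column_names,
)

# With mlm=False the collator copies input_ids into labels (masking pad
# tokens with -100), which is what a causal LM expects.
tokenizer.pad_token = tokenizer.eos_token
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)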

Any suggestions for what I’m doing wrong (or if there is a better way to approach my task altogether) would be very helpful.