Error when training with `peft` + `lora`

Hello, I am trying to use the tutorial here, Google Colab and Iโ€™m finetuning it on a custom dataset. I am loading my dataset from a pandas dataframe and Iโ€™m not sure what the error means here. Can anyone help me with this? TIA!

โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ Traceback (most recent call last) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ
โ”‚ in <cell line: 21>:21                                                                            โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1664 in train                    โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   1661 โ”‚   โ”‚   inner_training_loop = find_executable_batch_size(                                 โ”‚
โ”‚   1662 โ”‚   โ”‚   โ”‚   self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size  โ”‚
โ”‚   1663 โ”‚   โ”‚   )                                                                                 โ”‚
โ”‚ โฑ 1664 โ”‚   โ”‚   return inner_training_loop(                                                       โ”‚
โ”‚   1665 โ”‚   โ”‚   โ”‚   args=args,                                                                    โ”‚
โ”‚   1666 โ”‚   โ”‚   โ”‚   resume_from_checkpoint=resume_from_checkpoint,                                โ”‚
โ”‚   1667 โ”‚   โ”‚   โ”‚   trial=trial,                                                                  โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1909 in _inner_training_loop     โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   1906 โ”‚   โ”‚   โ”‚   โ”‚   rng_to_sync = True                                                        โ”‚
โ”‚   1907 โ”‚   โ”‚   โ”‚                                                                                 โ”‚
โ”‚   1908 โ”‚   โ”‚   โ”‚   step = -1                                                                     โ”‚
โ”‚ โฑ 1909 โ”‚   โ”‚   โ”‚   for step, inputs in enumerate(epoch_iterator):                                โ”‚
โ”‚   1910 โ”‚   โ”‚   โ”‚   โ”‚   total_batched_samples += 1                                                โ”‚
โ”‚   1911 โ”‚   โ”‚   โ”‚   โ”‚   if rng_to_sync:                                                           โ”‚
โ”‚   1912 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   self._load_rng_state(resume_from_checkpoint)                          โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:634 in __next__           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    631 โ”‚   โ”‚   โ”‚   if self._sampler_iter is None:                                                โ”‚
โ”‚    632 โ”‚   โ”‚   โ”‚   โ”‚   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   โ”‚
โ”‚    633 โ”‚   โ”‚   โ”‚   โ”‚   self._reset()  # type: ignore[call-arg]                                   โ”‚
โ”‚ โฑ  634 โ”‚   โ”‚   โ”‚   data = self._next_data()                                                      โ”‚
โ”‚    635 โ”‚   โ”‚   โ”‚   self._num_yielded += 1                                                        โ”‚
โ”‚    636 โ”‚   โ”‚   โ”‚   if self._dataset_kind == _DatasetKind.Iterable and \                          โ”‚
โ”‚    637 โ”‚   โ”‚   โ”‚   โ”‚   โ”‚   self._IterableDataset_len_called is not None and \                    โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:678 in _next_data         โ”‚
โ”‚                                                                                                  โ”‚
โ”‚    675 โ”‚                                                                                         โ”‚
โ”‚    676 โ”‚   def _next_data(self):                                                                 โ”‚
โ”‚    677 โ”‚   โ”‚   index = self._next_index()  # may raise StopIteration                             โ”‚
โ”‚ โฑ  678 โ”‚   โ”‚   data = self._dataset_fetcher.fetch(index)  # may raise StopIteration              โ”‚
โ”‚    679 โ”‚   โ”‚   if self._pin_memory:                                                              โ”‚
โ”‚    680 โ”‚   โ”‚   โ”‚   data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)            โ”‚
โ”‚    681 โ”‚   โ”‚   return data                                                                       โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py:49 in fetch             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   46 โ”‚   def fetch(self, possibly_batched_index):                                                โ”‚
โ”‚   47 โ”‚   โ”‚   if self.auto_collation:                                                             โ”‚
โ”‚   48 โ”‚   โ”‚   โ”‚   if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:         โ”‚
โ”‚ โฑ 49 โ”‚   โ”‚   โ”‚   โ”‚   data = self.dataset.__getitems__(possibly_batched_index)                    โ”‚
โ”‚   50 โ”‚   โ”‚   โ”‚   else:                                                                           โ”‚
โ”‚   51 โ”‚   โ”‚   โ”‚   โ”‚   data = [self.dataset[idx] for idx in possibly_batched_index]                โ”‚
โ”‚   52 โ”‚   โ”‚   else:                                                                               โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py:2782 in __getitems__           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   2779 โ”‚                                                                                         โ”‚
โ”‚   2780 โ”‚   def __getitems__(self, keys: List) -> List:                                           โ”‚
โ”‚   2781 โ”‚   โ”‚   """Can be used to get a batch using a list of integers indices."""                โ”‚
โ”‚ โฑ 2782 โ”‚   โ”‚   batch = self.__getitem__(keys)                                                    โ”‚
โ”‚   2783 โ”‚   โ”‚   n_examples = len(batch[next(iter(batch))])                                        โ”‚
โ”‚   2784 โ”‚   โ”‚   return [{col: array[i] for col, array in batch.items()} for i in range(n_example  โ”‚
โ”‚   2785                                                                                           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py:2778 in __getitem__            โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   2775 โ”‚                                                                                         โ”‚
โ”‚   2776 โ”‚   def __getitem__(self, key):  # noqa: F811                                             โ”‚
โ”‚   2777 โ”‚   โ”‚   """Can be used to index columns (by string names) or rows (by integer index or i  โ”‚
โ”‚ โฑ 2778 โ”‚   โ”‚   return self._getitem(key)                                                         โ”‚
โ”‚   2779 โ”‚                                                                                         โ”‚
โ”‚   2780 โ”‚   def __getitems__(self, keys: List) -> List:                                           โ”‚
โ”‚   2781 โ”‚   โ”‚   """Can be used to get a batch using a list of integers indices."""                โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py:2762 in _getitem               โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   2759 โ”‚   โ”‚   format_kwargs = kwargs["format_kwargs"] if "format_kwargs" in kwargs else self._  โ”‚
โ”‚   2760 โ”‚   โ”‚   format_kwargs = format_kwargs if format_kwargs is not None else {}                โ”‚
โ”‚   2761 โ”‚   โ”‚   formatter = get_formatter(format_type, features=self._info.features, **format_kw  โ”‚
โ”‚ โฑ 2762 โ”‚   โ”‚   pa_subtable = query_table(self._data, key, indices=self._indices if self._indice  โ”‚
โ”‚   2763 โ”‚   โ”‚   formatted_output = format_table(                                                  โ”‚
โ”‚   2764 โ”‚   โ”‚   โ”‚   pa_subtable, key, formatter=formatter, format_columns=format_columns, output  โ”‚
โ”‚   2765 โ”‚   โ”‚   )                                                                                 โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py:578 in query_table     โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   575 โ”‚   โ”‚   _check_valid_column_key(key, table.column_names)                                   โ”‚
โ”‚   576 โ”‚   else:                                                                                  โ”‚
โ”‚   577 โ”‚   โ”‚   size = indices.num_rows if indices is not None else table.num_rows                 โ”‚
โ”‚ โฑ 578 โ”‚   โ”‚   _check_valid_index_key(key, size)                                                  โ”‚
โ”‚   579 โ”‚   # Query the main table                                                                 โ”‚
โ”‚   580 โ”‚   if indices is None:                                                                    โ”‚
โ”‚   581 โ”‚   โ”‚   pa_subtable = _query_table(table, key)                                             โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py:531 in                 โ”‚
โ”‚ _check_valid_index_key                                                                           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   528 โ”‚   โ”‚   โ”‚   _check_valid_index_key(min(key), size=size)                                    โ”‚
โ”‚   529 โ”‚   elif isinstance(key, Iterable):                                                        โ”‚
โ”‚   530 โ”‚   โ”‚   if len(key) > 0:                                                                   โ”‚
โ”‚ โฑ 531 โ”‚   โ”‚   โ”‚   _check_valid_index_key(int(max(key)), size=size)                               โ”‚
โ”‚   532 โ”‚   โ”‚   โ”‚   _check_valid_index_key(int(min(key)), size=size)                               โ”‚
โ”‚   533 โ”‚   else:                                                                                  โ”‚
โ”‚   534 โ”‚   โ”‚   _raise_bad_key_type(key)                                                           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚ /usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py:521 in                 โ”‚
โ”‚ _check_valid_index_key                                                                           โ”‚
โ”‚                                                                                                  โ”‚
โ”‚   518 def _check_valid_index_key(key: Union[int, slice, range, Iterable], size: int) -> None:    โ”‚
โ”‚   519 โ”‚   if isinstance(key, int):                                                               โ”‚
โ”‚   520 โ”‚   โ”‚   if (key < 0 and key + size < 0) or (key >= size):                                  โ”‚
โ”‚ โฑ 521 โ”‚   โ”‚   โ”‚   raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")       โ”‚
โ”‚   522 โ”‚   โ”‚   return                                                                             โ”‚
โ”‚   523 โ”‚   elif isinstance(key, slice):                                                           โ”‚
โ”‚   524 โ”‚   โ”‚   pass                                                                               โ”‚
โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ
IndexError: Invalid key: 19 is out of bounds for size 0

I think there is a bug in the Trainer class (probably in _remove_unused_columns function). This error happens only when I use LoRA fine-tuning. For now, I resolved the error by setting remove_unused_columns=False in the TrainingArguments.