KeyError:664 with Seq2Seq trainer()

Hello all, I have two columns of data which I have loaded as a df and then dfs for training_set, testing_set, and a val_set. I keep getting the error posted below. I’ve tried resetting indexes, reindexing… Any ideas? Thank you very much!

The code block is:
# NOTE(review): the curly "smart quotes" (“TESTING”, “epoch”) in the original
# were syntax errors — replaced with straight ASCII quotes below.
#
# Root cause of the reported `KeyError: 664`: the traceback ends inside
# pandas `DataFrame.__getitem__` → `columns.get_loc(key)`, which means
# `train_set` / `val_set` are raw pandas DataFrames. The Trainer's DataLoader
# fetches examples with `dataset[idx]` (integer row index), but a DataFrame's
# `[]` is a *column* lookup, so the integer 664 is treated as a column name
# and raises KeyError. Resetting/reindexing the DataFrame cannot fix this.
# Fix: convert each split to a `datasets.Dataset`
# (e.g. `Dataset.from_pandas(df)`) and tokenize it before passing it here.

training_args = Seq2SeqTrainingArguments(
    output_dir="TESTING",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    fp16=True,
    predict_with_generate=True,
)

# train_dataset / eval_dataset must be index-by-row datasets (see note above),
# not pandas DataFrames.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

The error is:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /usr/local/lib/python3.10/dist-packages/pandas/core/indexes/base.py:3802 in get_loc │
│ │
│ 3799 │ │ │ │ ) │
│ 3800 │ │ │ casted_key = self._maybe_cast_indexer(key) │
│ 3801 │ │ │ try: │
│ ❱ 3802 │ │ │ │ return self._engine.get_loc(casted_key) │
│ 3803 │ │ │ except KeyError as err: │
│ 3804 │ │ │ │ raise KeyError(key) from err │
│ 3805 │ │ │ except TypeError: │
│ │
│ in pandas._libs.index.IndexEngine.get_loc:138 │
│ │
│ in pandas._libs.index.IndexEngine.get_loc:165 │
│ │
│ in pandas._libs.hashtable.PyObjectHashTable.get_item:5745 │
│ │
│ in pandas._libs.hashtable.PyObjectHashTable.get_item:5753 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 664

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 25>:25 │
│ │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1645 in train │
│ │
│ 1642 │ │ inner_training_loop = find_executable_batch_size( │
│ 1643 │ │ │ self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size │
│ 1644 │ │ ) │
│ ❱ 1645 │ │ return inner_training_loop( │
│ 1646 │ │ │ args=args, │
│ 1647 │ │ │ resume_from_checkpoint=resume_from_checkpoint, │
│ 1648 │ │ │ trial=trial, │
│ │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1916 in _inner_training_loop │
│ │
│ 1913 │ │ │ │ rng_to_sync = True │
│ 1914 │ │ │ │
│ 1915 │ │ │ step = -1 │
│ ❱ 1916 │ │ │ for step, inputs in enumerate(epoch_iterator): │
│ 1917 │ │ │ │ total_batched_samples += 1 │
│ 1918 │ │ │ │ if rng_to_sync: │
│ 1919 │ │ │ │ │ self._load_rng_state(resume_from_checkpoint) │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:633 in __next__ │
│ │
│ 630 │ │ │ if self._sampler_iter is None: │
│ 631 │ │ │ │ # TODO(pytorch/pytorch#76750: bug in dataloader iterator found by mypy) │
│ 632 │ │ │ │ self._reset() # type: ignore[call-arg] │
│ ❱ 633 │ │ │ data = self._next_data() │
│ 634 │ │ │ self._num_yielded += 1 │
│ 635 │ │ │ if self._dataset_kind == _DatasetKind.Iterable and \ │
│ 636 │ │ │ │ │ self._IterableDataset_len_called is not None and \ │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:677 in _next_data │
│ │
│ 674 │ │
│ 675 │ def _next_data(self): │
│ 676 │ │ index = self._next_index() # may raise StopIteration │
│ ❱ 677 │ │ data = self._dataset_fetcher.fetch(index) # may raise StopIteration │
│ 678 │ │ if self._pin_memory: │
│ 679 │ │ │ data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) │
│ 680 │ │ return data │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py:51 in fetch │
│ │
│ 48 │ │ │ if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: │
│ 49 │ │ │ │ data = self.dataset.getitems(possibly_batched_index) │
│ 50 │ │ │ else: │
│ ❱ 51 │ │ │ │ data = [self.dataset[idx] for idx in possibly_batched_index] │
│ 52 │ │ else: │
│ 53 │ │ │ data = self.dataset[possibly_batched_index] │
│ 54 │ │ return self.collate_fn(data) │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py:51 in │
│ │
│ 48 │ │ │ if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: │
│ 49 │ │ │ │ data = self.dataset.getitems(possibly_batched_index) │
│ 50 │ │ │ else: │
│ ❱ 51 │ │ │ │ data = [self.dataset[idx] for idx in possibly_batched_index] │
│ 52 │ │ else: │
│ 53 │ │ │ data = self.dataset[possibly_batched_index] │
│ 54 │ │ return self.collate_fn(data) │
│ │
│ /usr/local/lib/python3.10/dist-packages/pandas/core/frame.py:3807 in __getitem__ │
│ │
│ 3804 │ │ if is_single_key: │
│ 3805 │ │ │ if self.columns.nlevels > 1: │
│ 3806 │ │ │ │ return self._getitem_multilevel(key) │
│ ❱ 3807 │ │ │ indexer = self.columns.get_loc(key) │
│ 3808 │ │ │ if is_integer(indexer): │
│ 3809 │ │ │ │ indexer = [indexer] │
│ 3810 │ │ else: │
│ │
│ /usr/local/lib/python3.10/dist-packages/pandas/core/indexes/base.py:3804 in get_loc │
│ │
│ 3801 │ │ │ try: │
│ 3802 │ │ │ │ return self._engine.get_loc(casted_key) │
│ 3803 │ │ │ except KeyError as err: │
│ ❱ 3804 │ │ │ │ raise KeyError(key) from err │
│ 3805 │ │ │ except TypeError: │
│ 3806 │ │ │ │ # If we have a listlike key, _check_indexing_error will raise │
│ 3807 │ │ │ │ # InvalidIndexError. Otherwise we fall through and re-raise │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 664