Trainer's dataloader influenced by target model?

Hello all :wave:,

I'm playing with the Trainer and ran into a behaviour I cannot explain by myself. If any of you has an idea of why this happens, I would be very glad to know :smiley:

It appears that the dataloader used by the Trainer can crash outright depending on the model provided to the Trainer.

For example, the following code works:

from transformers import (
    Trainer,
    Wav2Vec2CTCTokenizer,
    SeamlessM4TFeatureExtractor,
    Wav2Vec2BertForCTC,
    Wav2Vec2BertProcessor,
)

processor = Wav2Vec2BertProcessor(
    feature_extractor=SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0"),
    tokenizer=Wav2Vec2CTCTokenizer.from_pretrained(
        "/home/vdm/W2V2-BERT-CTC-EN/outputs/",
        unk_token="UNK_TOEN",
        pad_token="PAD_TOKEN",
        word_delimiter_token="WORD_DELIMITER_TOKEN",
    ),
)

data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
)

model = Wav2Vec2BertForCTC.from_pretrained(
    'facebook/w2v-bert-2.0',
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    add_adapter=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    train_dataset=common_voice_train,
)

for i in trainer.get_train_dataloader():
    print(i)
    exit()
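
(For context, DataCollatorCTCWithPadding is not a transformers class; it is the usual custom CTC collator from the fine-tuning examples, defined elsewhere in my script. A minimal sketch, assuming each dataset row carries input_features and labels:)

from dataclasses import dataclass
from typing import Dict, List, Union

import torch

@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2BertProcessor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # pad audio inputs and text labels separately, since they need different padding
        input_features = [{"input_features": f["input_features"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
        labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")

        # replace padding with -100 so those positions are ignored by the CTC loss
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        return batch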

But the following doesn't:

from transformers import (
    Trainer,
    Wav2Vec2CTCTokenizer,
    SeamlessM4TFeatureExtractor,
    Wav2Vec2BertForCTC,
    Wav2Vec2BertProcessor,
)

processor = Wav2Vec2BertProcessor(
    feature_extractor=SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0"),
    tokenizer=Wav2Vec2CTCTokenizer.from_pretrained(
        "/home/vdm/W2V2-BERT-CTC-EN/outputs/",
        unk_token="UNK_TOEN",
        pad_token="PAD_TOKEN",
        word_delimiter_token="WORD_DELIMITER_TOKEN",
    ),
)

data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
)

from torch import nn

# ---- HERE ----
class dummy(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, t):
        return t
# ---------------

trainer = Trainer(
    model=dummy(),  # the only change: the dummy model instead of Wav2Vec2BertForCTC
    data_collator=data_collator,
    train_dataset=common_voice_train,
)

for i in trainer.get_train_dataloader():
    print(i)
    exit()

Resulting error:

Traceback (most recent call last):
  File "/home/vdm/SLAM-ASR/src/training.py", line 265, in <module>
    main()
  File "/home/vdm/SLAM-ASR/src/training.py", line 256, in main
    for i in trainer.get_train_dataloader():
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/accelerate/data_loader.py", line 452, in __iter__
    current_batch = next(dataloader_iter)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2814, in __getitems__
    batch = self.__getitem__(keys)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2810, in __getitem__
    return self._getitem(key)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2794, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 583, in query_table
    _check_valid_index_key(key, size)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 536, in _check_valid_index_key
    _check_valid_index_key(int(max(key)), size=size)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 526, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 198 is out of bounds for size 0

Note that only the model changes (and it is never even called); we only try to iterate over the dataloader once.

Is the dataloader behaviour influenced by the provided model? If so, what should one add to dummy in order to be able to iterate over the dataloader?

Solved! :tada:

So basically, as the documentation says, "columns not accepted by the model.forward() method are automatically removed" (the remove_unused_columns behaviour of the Trainer).

Since the parameters of dummy.forward() do not correspond to any columns in the dataset provided, all columns are eliminated, leading to an empty dataset and triggering the IndexError.
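
A rough illustration of what happens under the hood (a simplified sketch of the logic, not the actual transformers implementation, and assuming the prepared dataset has input_features and labels columns): the Trainer inspects the signature of model.forward() and keeps only the matching dataset columns.

import inspect

def kept_columns(model, column_names):
    # roughly what the Trainer's column pruning does: keep only the columns
    # whose names match a parameter of model.forward() (plus any label names
    # the Trainer could infer for that model class)
    accepted = set(inspect.signature(model.forward).parameters)
    return [name for name in column_names if name in accepted]

print(kept_columns(model, ["input_features", "labels"]))    # ['input_features', 'labels']
print(kept_columns(dummy(), ["input_features", "labels"]))  # [] -> empty dataset -> IndexError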

While this behaviour is documented, I think a proper error message (or at least a warning) should be raised when no dataset column matches the model.forward() parameters :slight_smile:
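
In the meantime, two workarounds (sketches, again assuming input_features/labels columns): either name the forward() parameters after the dataset columns, or disable the pruning entirely with remove_unused_columns=False.

from torch import nn
from transformers import Trainer, TrainingArguments

# Option 1: name the forward() parameters after the dataset columns
class dummy(nn.Module):
    def forward(self, input_features=None, labels=None):
        return input_features

# Option 2: keep forward(self, t) and disable the column pruning instead
training_args = TrainingArguments(output_dir="outputs", remove_unused_columns=False)

trainer = Trainer(
    model=dummy(),
    args=training_args,  # only needed for option 2
    data_collator=data_collator,
    train_dataset=common_voice_train,
)

for i in trainer.get_train_dataloader():
    print(i)
    break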
