Hello all,

I'm playing with AutoTrain and encountered a behaviour I cannot explain by myself. If any of you has an idea of why this happens, I would be very glad to know.

It appears that the dataloader used by the `Trainer` can fully crash depending on the model provided to the `Trainer`.

For example, the following code works:
```python
from transformers import (
    Trainer,
    Wav2Vec2CTCTokenizer,
    SeamlessM4TFeatureExtractor,
    Wav2Vec2BertForCTC,
    Wav2Vec2BertProcessor,
)

processor = Wav2Vec2BertProcessor(
    feature_extractor=SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0"),
    tokenizer=Wav2Vec2CTCTokenizer.from_pretrained(
        "/home/vdm/W2V2-BERT-CTC-EN/outputs/",
        unk_token="UNK_TOKEN",
        pad_token="PAD_TOKEN",
        word_delimiter_token="WORD_DELIMITER_TOKEN",
    ),
)

# DataCollatorCTCWithPadding and common_voice_train are defined
# earlier in my script (omitted here for brevity).
data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
)

model = Wav2Vec2BertForCTC.from_pretrained(
    "facebook/w2v-bert-2.0",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    add_adapter=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    train_dataset=common_voice_train,
)

for i in trainer.get_train_dataloader():
    print(i)
    exit()
```
But the following doesn't:
```python
from transformers import (
    Trainer,
    Wav2Vec2CTCTokenizer,
    SeamlessM4TFeatureExtractor,
    Wav2Vec2BertForCTC,
    Wav2Vec2BertProcessor,
)

processor = Wav2Vec2BertProcessor(
    feature_extractor=SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0"),
    tokenizer=Wav2Vec2CTCTokenizer.from_pretrained(
        "/home/vdm/W2V2-BERT-CTC-EN/outputs/",
        unk_token="UNK_TOKEN",
        pad_token="PAD_TOKEN",
        word_delimiter_token="WORD_DELIMITER_TOKEN",
    ),
)

data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
)

from torch import nn

# ---- HERE: the only change is the model ----
class dummy(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, t):
        return t
# --------------------------------------------

trainer = Trainer(
    model=dummy(),
    data_collator=data_collator,
    train_dataset=common_voice_train,
)

for i in trainer.get_train_dataloader():
    print(i)
    exit()
```
Resulting error:
```
Traceback (most recent call last):
  File "/home/vdm/SLAM-ASR/src/training.py", line 265, in <module>
    main()
  File "/home/vdm/SLAM-ASR/src/training.py", line 256, in main
    for i in trainer.get_train_dataloader():
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/accelerate/data_loader.py", line 452, in __iter__
    current_batch = next(dataloader_iter)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2814, in __getitems__
    batch = self.__getitem__(keys)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2810, in __getitem__
    return self._getitem(key)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2794, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 583, in query_table
    _check_valid_index_key(key, size)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 536, in _check_valid_index_key
    _check_valid_index_key(int(max(key)), size=size)
  File "/home/vdm/.pyenv/versions/3.10.13/envs/SLAM-ASR/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 526, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 198 is out of bounds for size 0
```
Note that only the model changes (and it is never even called); we only try to loop over the dataloader once.

Is the dataloader behaviour influenced by the provided model? If so, what should one add to `dummy` in order to be able to iterate over the dataloader?
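For instance, would something like the sketch below be needed? This is purely a hypothetical variant I'm wondering about: the argument names `input_features` and `labels` are just my guess at the keys my collator produces, and I haven't confirmed that the `Trainer` actually looks at the forward signature.

```python
import torch
from torch import nn

class dummy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # One trainable parameter so the module isn't completely empty.
        self.w = nn.Parameter(torch.zeros(1))

    # Hypothetical: name the forward arguments after the batch keys
    # (guessed here as input_features / labels), in case the Trainer
    # inspects the signature when building the train dataloader.
    def forward(self, input_features=None, labels=None):
        return input_features
```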